diff --git a/mocks/pkg/ovs/interface.go b/mocks/pkg/ovs/interface.go index 0313c639e32..a7cca286441 100644 --- a/mocks/pkg/ovs/interface.go +++ b/mocks/pkg/ovs/interface.go @@ -847,17 +847,22 @@ func (mr *MockLoadBalancerMockRecorder) LoadBalancerAddVips(lbName, vips interfa } // LoadBalancerDeleteVips mocks base method. -func (m *MockLoadBalancer) LoadBalancerDeleteVips(lbName string, vips map[string]struct{}) error { +func (m *MockLoadBalancer) LoadBalancerDeleteVips(lbName string, vips ...string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadBalancerDeleteVips", lbName, vips) + varargs := []interface{}{lbName} + for _, a := range vips { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "LoadBalancerDeleteVips", varargs...) ret0, _ := ret[0].(error) return ret0 } // LoadBalancerDeleteVips indicates an expected call of LoadBalancerDeleteVips. -func (mr *MockLoadBalancerMockRecorder) LoadBalancerDeleteVips(lbName, vips interface{}) *gomock.Call { +func (mr *MockLoadBalancerMockRecorder) LoadBalancerDeleteVips(lbName interface{}, vips ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadBalancerDeleteVips", reflect.TypeOf((*MockLoadBalancer)(nil).LoadBalancerDeleteVips), lbName, vips) + varargs := append([]interface{}{lbName}, vips...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadBalancerDeleteVips", reflect.TypeOf((*MockLoadBalancer)(nil).LoadBalancerDeleteVips), varargs...) } // LoadBalancerExists mocks base method. @@ -2718,17 +2723,22 @@ func (mr *MockOvnClientMockRecorder) LoadBalancerAddVips(lbName, vips interface{ } // LoadBalancerDeleteVips mocks base method. -func (m *MockOvnClient) LoadBalancerDeleteVips(lbName string, vips map[string]struct{}) error { +func (m *MockOvnClient) LoadBalancerDeleteVips(lbName string, vips ...string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LoadBalancerDeleteVips", lbName, vips) + varargs := []interface{}{lbName} + for _, a := range vips { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "LoadBalancerDeleteVips", varargs...) ret0, _ := ret[0].(error) return ret0 } // LoadBalancerDeleteVips indicates an expected call of LoadBalancerDeleteVips. -func (mr *MockOvnClientMockRecorder) LoadBalancerDeleteVips(lbName, vips interface{}) *gomock.Call { +func (mr *MockOvnClientMockRecorder) LoadBalancerDeleteVips(lbName interface{}, vips ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadBalancerDeleteVips", reflect.TypeOf((*MockOvnClient)(nil).LoadBalancerDeleteVips), lbName, vips) + varargs := append([]interface{}{lbName}, vips...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadBalancerDeleteVips", reflect.TypeOf((*MockOvnClient)(nil).LoadBalancerDeleteVips), varargs...) } // LoadBalancerExists mocks base method. diff --git a/pkg/ovs/interface.go b/pkg/ovs/interface.go index b79bdcf04fa..115003405d0 100644 --- a/pkg/ovs/interface.go +++ b/pkg/ovs/interface.go @@ -69,7 +69,7 @@ type LogicalSwitchPort interface { type LoadBalancer interface { CreateLoadBalancer(lbName, protocol, selectFields string) error LoadBalancerAddVips(lbName string, vips map[string]string) error - LoadBalancerDeleteVips(lbName string, vips map[string]struct{}) error + LoadBalancerDeleteVips(lbName string, vips ...string) error SetLoadBalancerAffinityTimeout(lbName string, timeout int) error DeleteLoadBalancers(filter func(lb *ovnnb.LoadBalancer) bool) error GetLoadBalancer(lbName string, ignoreNotFound bool) (*ovnnb.LoadBalancer, error) diff --git a/pkg/ovs/ovn-nb-load_balancer.go b/pkg/ovs/ovn-nb-load_balancer.go index 62f2af5569f..6e42b8d3ab3 100644 --- a/pkg/ovs/ovn-nb-load_balancer.go +++ b/pkg/ovs/ovn-nb-load_balancer.go @@ -3,6 +3,7 @@ package ovs import ( "context" "fmt" + "reflect" "strconv" "github.com/ovn-org/libovsdb/model" @@ -60,9 +61,19 @@ func (c *ovnClient) UpdateLoadBalancer(lb *ovnnb.LoadBalancer, fields ...interfa return nil } -// LoadBalancerAddVips add vips -func (c *ovnClient) LoadBalancerAddVips(lbName string, vips map[string]string) error { - if len(vips) == 0 { +func (c *ovnClient) loadBalancerUpdateVips(lbName string, vips interface{}) error { + var toAdd map[string]string + var toDelete []string + value := reflect.ValueOf(vips) + switch value.Type().Kind() { + case reflect.Slice: + toDelete = vips.([]string) + case reflect.Map: + toAdd = vips.(map[string]string) + default: + return fmt.Errorf("program error: invalid data type of vips %+v", vips) + } + if value.Len() == 0 { return nil } @@ -71,14 +82,21 @@ func (c *ovnClient) LoadBalancerAddVips(lbName string, vips map[string]string) e return err } - if lb.Vips == nil { - lb.Vips = make(map[string]string) + m := make(map[string]string, len(lb.Vips)+len(toAdd)) + for k, v := range lb.Vips { + m[k] = v } - - for vip, backends := range vips { - lb.Vips[vip] = backends + for k, v := range toAdd { + m[k] = v + } + for _, k := range toDelete { + delete(m, k) + } + if reflect.DeepEqual(m, lb.Vips) { + return nil } + lb.Vips = m if err := c.UpdateLoadBalancer(lb, &lb.Vips); err != nil { return fmt.Errorf("add vips %v to lb %s: %v", vips, lbName, err) } @@ -86,26 +104,14 @@ func (c *ovnClient) LoadBalancerAddVips(lbName string, vips map[string]string) e return nil } -// LoadBalancerDeleteVips delete load balancer vips -func (c *ovnClient) LoadBalancerDeleteVips(lbName string, vips map[string]struct{}) error { - if len(vips) == 0 { - return nil - } - - lb, err := c.GetLoadBalancer(lbName, false) - if err != nil { - return err - } - - for vip := range vips { - delete(lb.Vips, vip) - } - - if err := c.UpdateLoadBalancer(lb, &lb.Vips); err != nil { - return fmt.Errorf("delete vips %v from lb %s: %v", vips, lbName, err) - } +// LoadBalancerAddVips adds or updates vips +func (c *ovnClient) LoadBalancerAddVips(lbName string, vips map[string]string) error { + return c.loadBalancerUpdateVips(lbName, vips) +} - return nil +// LoadBalancerDeleteVips deletes load balancer vips +func (c *ovnClient) LoadBalancerDeleteVips(lbName string, vips ...string) error { + return c.loadBalancerUpdateVips(lbName, vips) } // SetLoadBalancerAffinityTimeout sets the LB's affinity timeout in seconds @@ -114,15 +120,18 @@ func (c *ovnClient) SetLoadBalancerAffinityTimeout(lbName string, timeout int) e if err != nil { return err } - - if lb.Options == nil { - lb.Options = make(map[string]string) + value := strconv.Itoa(timeout) + if len(lb.Options) != 0 && lb.Options["affinity_timeout"] == value { + return nil } - lb.Options["affinity_timeout"] = strconv.Itoa(timeout) - + options := make(map[string]string, len(lb.Options)+1) + for k, v := range lb.Options { + options[k] = v + } + options["affinity_timeout"] = value if err := c.UpdateLoadBalancer(lb, &lb.Options); err != nil { - return fmt.Errorf("set affinity timeout of lb %s to %d, err: %v", lbName, timeout, err) + return fmt.Errorf("failed to set affinity timeout of lb %s to %d: %v", lbName, timeout, err) } return nil diff --git a/pkg/ovs/ovn-nb-load_balancer_test.go b/pkg/ovs/ovn-nb-load_balancer_test.go index 997840e45f0..90880bb3d9c 100644 --- a/pkg/ovs/ovn-nb-load_balancer_test.go +++ b/pkg/ovs/ovn-nb-load_balancer_test.go @@ -187,11 +187,11 @@ func (suite *OvnClientTestSuite) testLoadBalancerDeleteVips() { err = ovnClient.LoadBalancerAddVips(lbName, vips) require.NoError(t, err) - err = ovnClient.LoadBalancerDeleteVips(lbName, map[string]struct{}{ - "10.96.0.1:443": {}, - "[fd00:10:96::e82f]:8080": {}, - "10.96.0.100:1443": {}, // non-existent vip - }) + err = ovnClient.LoadBalancerDeleteVips(lbName, + "10.96.0.1:443", + "[fd00:10:96::e82f]:8080", + "10.96.0.100:1443", // non-existent vip + ) require.NoError(t, err) lb, err := ovnClient.GetLoadBalancer(lbName, false) diff --git a/pkg/ovs/ovn-nb-logical_router_policy.go b/pkg/ovs/ovn-nb-logical_router_policy.go index df3e8fc4e86..b4c15b4d6f9 100644 --- a/pkg/ovs/ovn-nb-logical_router_policy.go +++ b/pkg/ovs/ovn-nb-logical_router_policy.go @@ -120,6 +120,9 @@ func (c *ovnClient) DeleteLogicalRouterPolicies(lrName string, priority int, ext if err != nil { return err } + if len(policies) == 0 { + return nil + } policiesUUIDs := make([]string, 0, len(policies)) for _, policy := range policies { diff --git a/pkg/ovs/ovn-nb-suite_test.go b/pkg/ovs/ovn-nb-suite_test.go index 7cd905defff..212e4bea64b 100644 --- a/pkg/ovs/ovn-nb-suite_test.go +++ b/pkg/ovs/ovn-nb-suite_test.go @@ -629,9 +629,7 @@ func Test_scratch(t *testing.T) { err = ovnClient.LoadBalancerAddVips(lbName, vips) require.NoError(t, err) - err = ovnClient.LoadBalancerDeleteVips(lbName, map[string]struct{}{ - "10.96.0.1:443": {}, - }) + err = ovnClient.LoadBalancerDeleteVips(lbName, "10.96.0.1:443") require.NoError(t, err) } diff --git a/pkg/ovs/ovn-nb_global.go b/pkg/ovs/ovn-nb_global.go index 7a3a65ac7c5..c54be3ec5d0 100644 --- a/pkg/ovs/ovn-nb_global.go +++ b/pkg/ovs/ovn-nb_global.go @@ -3,7 +3,7 @@ package ovs import ( "context" "fmt" - "strconv" + "reflect" "strings" "github.com/kubeovn/kube-ovn/pkg/ovsdb/ovnnb" @@ -55,13 +55,6 @@ func (c *ovnClient) GetNbGlobal() (*ovnnb.NBGlobal, error) { } func (c *ovnClient) UpdateNbGlobal(nbGlobal *ovnnb.NBGlobal, fields ...interface{}) error { - /* // list nb_global which connections != nil - op, err := c.Where(nbGlobal, model.Condition{ - Field: &nbGlobal.Connections, - Function: ovsdb.ConditionNotEqual, - Value: []string{""}, - }).Update(nbGlobal) */ - op, err := c.Where(nbGlobal).Update(nbGlobal, fields...) if err != nil { return fmt.Errorf("generate operations for updating nb global: %v", err) @@ -79,13 +72,11 @@ func (c *ovnClient) SetAzName(azName string) error { if err != nil { return fmt.Errorf("get nb global: %v", err) } - if azName == nbGlobal.Name { return nil // no need to update } nbGlobal.Name = azName - if err := c.UpdateNbGlobal(nbGlobal, &nbGlobal.Name); err != nil { return fmt.Errorf("set nb_global az name %s: %v", azName, err) } @@ -93,40 +84,58 @@ func (c *ovnClient) SetAzName(azName string) error { return nil } -func (c *ovnClient) SetUseCtInvMatch() error { +func (c *ovnClient) SetNbGlobalOptions(key string, value interface{}) error { nbGlobal, err := c.GetNbGlobal() if err != nil { - return fmt.Errorf("get nb global: %v", err) + return fmt.Errorf("failed to get nb global: %v", err) } - nbGlobal.Options["use_ct_inv_match"] = "false" + v := fmt.Sprintf("%v", value) + if len(nbGlobal.Options) != 0 && nbGlobal.Options[key] == v { + return nil + } + options := make(map[string]string, len(nbGlobal.Options)+1) + for k, v := range nbGlobal.Options { + options[k] = v + } + nbGlobal.Options[key] = v if err := c.UpdateNbGlobal(nbGlobal, &nbGlobal.Options); err != nil { - return fmt.Errorf("set use_ct_inv_match to false, %v", err) + return fmt.Errorf("failed to set nb global option %s to %v: %v", key, value, err) } return nil } +func (c *ovnClient) SetUseCtInvMatch() error { + return c.SetNbGlobalOptions("use_ct_inv_match", false) +} + func (c *ovnClient) SetICAutoRoute(enable bool, blackList []string) error { nbGlobal, err := c.GetNbGlobal() if err != nil { return fmt.Errorf("get nb global: %v", err) } + options := make(map[string]string, len(nbGlobal.Options)+3) + for k, v := range nbGlobal.Options { + options[k] = v + } if enable { - nbGlobal.Options = map[string]string{ - "ic-route-adv": "true", - "ic-route-learn": "true", - "ic-route-blacklist": strings.Join(blackList, ","), - } + options["ic-route-adv"] = "true" + options["ic-route-learn"] = "true" + options["ic-route-blacklist"] = strings.Join(blackList, ",") } else { - nbGlobal.Options = map[string]string{ - "ic-route-adv": "false", - "ic-route-learn": "false", - } + delete(options, "ic-route-adv") + delete(options, "ic-route-learn") + delete(options, "ic-route-blacklist") + } + if reflect.DeepEqual(options, nbGlobal.Options) { + nbGlobal.Options = options + return nil } + nbGlobal.Options = options if err := c.UpdateNbGlobal(nbGlobal, &nbGlobal.Options); err != nil { return fmt.Errorf("enable ovn-ic auto route, %v", err) } @@ -134,31 +143,9 @@ func (c *ovnClient) SetICAutoRoute(enable bool, blackList []string) error { } func (c *ovnClient) SetLBCIDR(serviceCIDR string) error { - nbGlobal, err := c.GetNbGlobal() - if err != nil { - return fmt.Errorf("get nb global: %v", err) - } - - nbGlobal.Options["svc_ipv4_cidr"] = serviceCIDR - - if err := c.UpdateNbGlobal(nbGlobal, &nbGlobal.Options); err != nil { - return fmt.Errorf("set svc cidr %s for lb, %v", serviceCIDR, err) - } - - return nil + return c.SetNbGlobalOptions("svc_ipv4_cidr", serviceCIDR) } func (c *ovnClient) SetLsDnatModDlDst(enabled bool) error { - nbGlobal, err := c.GetNbGlobal() - if err != nil { - return fmt.Errorf("get nb global: %v", err) - } - - nbGlobal.Options["ls_dnat_mod_dl_dst"] = strconv.FormatBool(enabled) - - if err := c.UpdateNbGlobal(nbGlobal, &nbGlobal.Options); err != nil { - return fmt.Errorf("set NB_Global option ls_dnat_mod_dl_dst to %v: %v", enabled, err) - } - - return nil + return c.SetNbGlobalOptions("ls_dnat_mod_dl_dst", enabled) } diff --git a/pkg/ovs/ovn-nb_global_test.go b/pkg/ovs/ovn-nb_global_test.go index 4d2dda86bd9..f57734de360 100644 --- a/pkg/ovs/ovn-nb_global_test.go +++ b/pkg/ovs/ovn-nb_global_test.go @@ -165,9 +165,9 @@ func (suite *OvnClientTestSuite) testSetICAutoRoute() { out, err := ovnClient.GetNbGlobal() require.NoError(t, err) - require.Equal(t, "false", out.Options["ic-route-adv"]) - require.Equal(t, "false", out.Options["ic-route-learn"]) - require.Empty(t, out.Options["ic-route-blacklist"]) + require.NotContains(t, out.Options, "ic-route-adv") + require.NotContains(t, out.Options, "ic-route-learn") + require.NotContains(t, out.Options, "ic-route-blacklist") }) }