From d5462b10c9d26dc99e287b5cc111e5c4f94cc57b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E7=A5=96=E5=BB=BA?= Date: Fri, 30 Jun 2023 16:00:44 +0800 Subject: [PATCH] libovsdb: various bug fixes (#2998) --- mocks/pkg/ovs/interface.go | 8 +++--- pkg/controller/node.go | 8 +++--- pkg/ovs/interface.go | 2 +- pkg/ovs/ovn-nb-logical_router_policy.go | 23 ++++++---------- pkg/ovs/ovn-nb-logical_router_policy_test.go | 29 +++++++++++--------- pkg/ovs/ovn-nb-logical_router_route.go | 2 +- pkg/ovs/ovn-nb-nat.go | 2 +- 7 files changed, 35 insertions(+), 39 deletions(-) diff --git a/mocks/pkg/ovs/interface.go b/mocks/pkg/ovs/interface.go index c488252cba9..3eb9448326e 100644 --- a/mocks/pkg/ovs/interface.go +++ b/mocks/pkg/ovs/interface.go @@ -1674,10 +1674,10 @@ func (mr *MockLogicalRouterPolicyMockRecorder) DeleteLogicalRouterPolicyByUUID(l } // GetLogicalRouterPolicy mocks base method. -func (m *MockLogicalRouterPolicy) GetLogicalRouterPolicy(lrName string, priority int, match string, ignoreNotFound bool) (*ovnnb.LogicalRouterPolicy, error) { +func (m *MockLogicalRouterPolicy) GetLogicalRouterPolicy(lrName string, priority int, match string, ignoreNotFound bool) ([]*ovnnb.LogicalRouterPolicy, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetLogicalRouterPolicy", lrName, priority, match, ignoreNotFound) - ret0, _ := ret[0].(*ovnnb.LogicalRouterPolicy) + ret0, _ := ret[0].([]*ovnnb.LogicalRouterPolicy) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -2668,10 +2668,10 @@ func (mr *MockOvnClientMockRecorder) GetLogicalRouter(lrName, ignoreNotFound int } // GetLogicalRouterPolicy mocks base method. -func (m *MockOvnClient) GetLogicalRouterPolicy(lrName string, priority int, match string, ignoreNotFound bool) (*ovnnb.LogicalRouterPolicy, error) { +func (m *MockOvnClient) GetLogicalRouterPolicy(lrName string, priority int, match string, ignoreNotFound bool) ([]*ovnnb.LogicalRouterPolicy, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetLogicalRouterPolicy", lrName, priority, match, ignoreNotFound) - ret0, _ := ret[0].(*ovnnb.LogicalRouterPolicy) + ret0, _ := ret[0].([]*ovnnb.LogicalRouterPolicy) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/pkg/controller/node.go b/pkg/controller/node.go index ea2bde95f99..f269a70578d 100644 --- a/pkg/controller/node.go +++ b/pkg/controller/node.go @@ -1061,15 +1061,15 @@ func (c *Controller) getPolicyRouteParas(cidr string, priority int) (*strset.Set ipSuffix = "ip6" } match := fmt.Sprintf("%s.src == %s", ipSuffix, cidr) - policy, err := c.ovnClient.GetLogicalRouterPolicy(c.config.ClusterRouter, priority, match, true) + policyList, err := c.ovnClient.GetLogicalRouterPolicy(c.config.ClusterRouter, priority, match, true) if err != nil { klog.Errorf("failed to get logical router policy: %v", err) return nil, nil, err } - if policy == nil { - return nil, nil, err + if len(policyList) == 0 { + return strset.New(), map[string]string{}, nil } - return strset.New(policy.Nexthops...), policy.ExternalIDs, nil + return strset.New(policyList[0].Nexthops...), policyList[0].ExternalIDs, nil } func (c *Controller) checkPolicyRouteExistForNode(nodeName, cidr, nexthop string, priority int) (bool, error) { diff --git a/pkg/ovs/interface.go b/pkg/ovs/interface.go index b0963bc8f47..b0e66ecd5cc 100644 --- a/pkg/ovs/interface.go +++ b/pkg/ovs/interface.go @@ -138,7 +138,7 @@ type LogicalRouterPolicy interface { DeleteLogicalRouterPolicyByNexthop(lrName string, priority int, nexthop string) error ClearLogicalRouterPolicy(lrName string) error ListLogicalRouterPolicies(lrName string, priority int, externalIDs map[string]string) ([]*ovnnb.LogicalRouterPolicy, error) - GetLogicalRouterPolicy(lrName string, priority int, match string, ignoreNotFound bool) (*ovnnb.LogicalRouterPolicy, error) + GetLogicalRouterPolicy(lrName string, priority int, match string, ignoreNotFound bool) ([]*ovnnb.LogicalRouterPolicy, error) } type NAT interface { diff --git a/pkg/ovs/ovn-nb-logical_router_policy.go b/pkg/ovs/ovn-nb-logical_router_policy.go index 0d342c65146..5c16621b7f7 100644 --- a/pkg/ovs/ovn-nb-logical_router_policy.go +++ b/pkg/ovs/ovn-nb-logical_router_policy.go @@ -87,18 +87,15 @@ func (c *ovnClient) CreateLogicalRouterPolicies(lrName string, policies ...*ovnn // DeleteLogicalRouterPolicy delete policy from logical router func (c *ovnClient) DeleteLogicalRouterPolicy(lrName string, priority int, match string) error { - policy, err := c.GetLogicalRouterPolicy(lrName, priority, match, true) + policyList, err := c.GetLogicalRouterPolicy(lrName, priority, match, true) if err != nil { return err } - // not found, skip - if policy == nil { - return nil - } - - if err := c.DeleteLogicalRouterPolicyByUUID(lrName, policy.UUID); err != nil { - return err + for _, p := range policyList { + if err := c.DeleteLogicalRouterPolicyByUUID(lrName, p.UUID); err != nil { + return err + } } return nil @@ -182,7 +179,7 @@ func (c *ovnClient) ClearLogicalRouterPolicy(lrName string) error { // GetLogicalRouterPolicy get logical router policy by priority and match, // be consistent with ovn-nbctl which priority and match determine one policy in logical router -func (c *ovnClient) GetLogicalRouterPolicy(lrName string, priority int, match string, ignoreNotFound bool) (*ovnnb.LogicalRouterPolicy, error) { +func (c *ovnClient) GetLogicalRouterPolicy(lrName string, priority int, match string, ignoreNotFound bool) ([]*ovnnb.LogicalRouterPolicy, error) { // this is necessary because may exist same priority and match policy in different logical router if len(lrName) == 0 { return nil, fmt.Errorf("the logical router name is required") @@ -204,11 +201,7 @@ func (c *ovnClient) GetLogicalRouterPolicy(lrName string, priority int, match st return nil, fmt.Errorf("not found policy priority %d match %s in logical router %s", priority, match, lrName) } - if len(policyList) > 1 { - return nil, fmt.Errorf("more than one policy with same priority %d match %s in logical router %s", priority, match, lrName) - } - - return policyList[0], nil + return policyList, nil } // GetLogicalRouterPolicyByUUID get logical router policy by UUID @@ -218,7 +211,7 @@ func (c *ovnClient) GetLogicalRouterPolicyByUUID(uuid string) (*ovnnb.LogicalRou policy := &ovnnb.LogicalRouterPolicy{UUID: uuid} if err := c.Get(ctx, policy); err != nil { - return nil, fmt.Errorf("get logical router policy by UUID %s: %v", uuid, err) + return nil, err } return policy, nil diff --git a/pkg/ovs/ovn-nb-logical_router_policy_test.go b/pkg/ovs/ovn-nb-logical_router_policy_test.go index eab2603f1cd..c0c77da23ec 100644 --- a/pkg/ovs/ovn-nb-logical_router_policy_test.go +++ b/pkg/ovs/ovn-nb-logical_router_policy_test.go @@ -42,9 +42,10 @@ func (suite *OvnClientTestSuite) testAddLogicalRouterPolicy() { lr, err := ovnClient.GetLogicalRouter(lrName, false) require.NoError(t, err) - policy, err := ovnClient.GetLogicalRouterPolicy(lrName, priority, match, false) + policyList, err := ovnClient.GetLogicalRouterPolicy(lrName, priority, match, false) require.NoError(t, err) - require.Contains(t, lr.Policies, policy.UUID) + require.Len(t, policyList, 1) + require.Contains(t, lr.Policies, policyList[0].UUID) err = ovnClient.AddLogicalRouterPolicy(lrName, priority, match, action, nextHops, nil) require.NoError(t, err) @@ -80,11 +81,11 @@ func (suite *OvnClientTestSuite) testCreateLogicalRouterPolicies() { for i := 0; i < 3; i++ { match := fmt.Sprintf("%s && tcp.dst == %d", matchPrefix, basePort+i) - policy, err := ovnClient.GetLogicalRouterPolicy(lrName, priority, match, false) + policyList, err := ovnClient.GetLogicalRouterPolicy(lrName, priority, match, false) require.NoError(t, err) - require.Equal(t, match, policy.Match) - - require.Contains(t, lr.Policies, policy.UUID) + require.Len(t, policyList, 1) + require.Equal(t, match, policyList[0].Match) + require.Contains(t, lr.Policies, policyList[0].UUID) } }) } @@ -110,9 +111,10 @@ func (suite *OvnClientTestSuite) testDeleteLogicalRouterPolicy() { lr, err := ovnClient.GetLogicalRouter(lrName, false) require.NoError(t, err) - policy, err := ovnClient.GetLogicalRouterPolicy(lrName, priority, match, false) + policyList, err := ovnClient.GetLogicalRouterPolicy(lrName, priority, match, false) require.NoError(t, err) - require.Contains(t, lr.Policies, policy.UUID) + require.Len(t, policyList, 1) + require.Contains(t, lr.Policies, policyList[0].UUID) err = ovnClient.DeleteLogicalRouterPolicy(lrName, priority, match) require.NoError(t, err) @@ -122,7 +124,7 @@ func (suite *OvnClientTestSuite) testDeleteLogicalRouterPolicy() { lr, err = ovnClient.GetLogicalRouter(lrName, false) require.NoError(t, err) - require.NotContains(t, lr.Policies, policy.UUID) + require.NotContains(t, lr.Policies, policyList[0].UUID) }) } @@ -255,11 +257,12 @@ func (suite *OvnClientTestSuite) testGetLogicalRouterPolicy() { t.Run("priority and match are same", func(t *testing.T) { t.Parallel() - policy, err := ovnClient.GetLogicalRouterPolicy(lrName, priority, match, false) + policyList, err := ovnClient.GetLogicalRouterPolicy(lrName, priority, match, false) require.NoError(t, err) - require.Equal(t, priority, policy.Priority) - require.Equal(t, match, policy.Match) - require.Equal(t, ovnnb.LogicalRouterPolicyActionAllow, policy.Action) + require.Len(t, policyList, 1) + require.Equal(t, priority, policyList[0].Priority) + require.Equal(t, match, policyList[0].Match) + require.Equal(t, ovnnb.LogicalRouterPolicyActionAllow, policyList[0].Action) }) t.Run("priority and match are not all same", func(t *testing.T) { diff --git a/pkg/ovs/ovn-nb-logical_router_route.go b/pkg/ovs/ovn-nb-logical_router_route.go index 7a6e8a4265c..408fcdf9de1 100644 --- a/pkg/ovs/ovn-nb-logical_router_route.go +++ b/pkg/ovs/ovn-nb-logical_router_route.go @@ -183,7 +183,7 @@ func (c *ovnClient) GetLogicalRouterStaticRouteByUUID(uuid string) (*ovnnb.Logic route := &ovnnb.LogicalRouterStaticRoute{UUID: uuid} if err := c.Get(ctx, route); err != nil { - return nil, fmt.Errorf("get logical router static route by UUID %s: %v", uuid, err) + return nil, err } return route, nil diff --git a/pkg/ovs/ovn-nb-nat.go b/pkg/ovs/ovn-nb-nat.go index ed4445bee46..6f5c0840e41 100644 --- a/pkg/ovs/ovn-nb-nat.go +++ b/pkg/ovs/ovn-nb-nat.go @@ -199,7 +199,7 @@ func (c *ovnClient) GetNATByUUID(uuid string) (*ovnnb.NAT, error) { nat := &ovnnb.NAT{UUID: uuid} if err := c.Get(ctx, nat); err != nil { - return nil, fmt.Errorf("get NAT by UUID %s: %v", uuid, err) + return nil, err } return nat, nil