Skip to content

Commit

Permalink
Don't restore eventFactory in case the connection has already been cl…
Browse files Browse the repository at this point in the history
…osed/unregistered

Signed-off-by: Artem Glazychev <artem.glazychev@xored.com>
  • Loading branch information
glazychev-art committed Sep 23, 2022
1 parent 4056e30 commit 439dd9b
Show file tree
Hide file tree
Showing 11 changed files with 542 additions and 16 deletions.
4 changes: 2 additions & 2 deletions pkg/networkservice/common/begin/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (b *beginClient) Request(ctx context.Context, request *networkservice.Netwo
<-eventFactoryClient.executor.AsyncExec(func() {
// If the eventFactory has changed, usually because the connection has been Closed and re-established
// go back to the beginning and try again.
currentEventFactoryClient, _ := b.LoadOrStore(request.GetConnection().GetId(), eventFactoryClient)
currentEventFactoryClient, _ := b.Load(request.GetConnection().GetId())
if currentEventFactoryClient != eventFactoryClient {
log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryClient != eventFactoryClient")
conn, err = b.Request(ctx, request, opts...)
Expand Down Expand Up @@ -103,7 +103,7 @@ func (b *beginClient) Close(ctx context.Context, conn *networkservice.Connection
}

// If this isn't the connection we started with, do nothing
currentEventFactoryClient, _ := b.LoadOrStore(conn.GetId(), eventFactoryClient)
currentEventFactoryClient, _ := b.Load(conn.GetId())
if currentEventFactoryClient != eventFactoryClient {
return
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/networkservice/common/begin/event_factory.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 Cisco and/or its affiliates.
// Copyright (c) 2021-2022 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -60,7 +60,7 @@ func newEventFactoryClient(ctx context.Context, afterClose func(), opts ...grpc.
client: next.Client(ctx),
opts: opts,
}
ctxFunc := postpone.Context(ctx)
ctxFunc := postpone.ContextWithValues(ctx)
f.ctxFunc = func() (context.Context, context.CancelFunc) {
eventCtx, cancel := ctxFunc()
return withEventFactory(eventCtx, f), cancel
Expand Down Expand Up @@ -155,7 +155,7 @@ func newEventFactoryServer(ctx context.Context, afterClose func()) *eventFactory
f := &eventFactoryServer{
server: next.Server(ctx),
}
ctxFunc := postpone.Context(ctx)
ctxFunc := postpone.ContextWithValues(ctx)
f.ctxFunc = func() (context.Context, context.CancelFunc) {
eventCtx, cancel := ctxFunc()
return withEventFactory(eventCtx, f), cancel
Expand Down
132 changes: 132 additions & 0 deletions pkg/networkservice/common/begin/event_factory_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// Copyright (c) 2022 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package begin_test

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/emptypb"

"github.com/networkservicemesh/api/pkg/api/networkservice"

"github.com/networkservicemesh/sdk/pkg/networkservice/common/begin"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/chain"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
)

// This test reproduces the situation when Close and Request were called at the same time
// nolint:dupl
func TestRefreshDuringClose_Client(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

syncChan := make(chan struct{})
checkCtxCl := &checkContextClient{t: t}
eventFactoryCl := &eventFactoryClient{ch: syncChan}
client := chain.NewNetworkServiceClient(
begin.NewClient(),
checkCtxCl,
eventFactoryCl,
)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Set any value to context
ctx = context.WithValue(ctx, contextKey{}, "value_1")
checkCtxCl.setExpectedValue("value_1")

// Do Request with this context
request := testRequest("1")
conn, err := client.Request(ctx, request.Clone())
assert.NotNil(t, t, conn)
assert.NoError(t, err)

// Change context value before refresh Request
ctx = context.WithValue(ctx, contextKey{}, "value_2")
checkCtxCl.setExpectedValue("value_2")
request.Connection = conn.Clone()

// Call Close from eventFactory
eventFactoryCl.callClose()
<-syncChan

// Call refresh (should be called at the same time as Close)
conn, err = client.Request(ctx, request.Clone())
assert.NotNil(t, t, conn)
assert.NoError(t, err)

// Call refresh from eventFactory. We are expecting updated value in the context
eventFactoryCl.callRefresh()
<-syncChan
}

type eventFactoryClient struct {
ctx context.Context
ch chan<- struct{}
}

func (s *eventFactoryClient) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (*networkservice.Connection, error) {
s.ctx = ctx
return next.Client(ctx).Request(ctx, request, opts...)
}

func (s *eventFactoryClient) Close(ctx context.Context, conn *networkservice.Connection, opts ...grpc.CallOption) (*emptypb.Empty, error) {
// Wait to be sure that rerequest was called
time.Sleep(time.Millisecond * 100)
return next.Client(ctx).Close(ctx, conn, opts...)
}

func (s *eventFactoryClient) callClose() {
eventFactory := begin.FromContext(s.ctx)
go func() {
s.ch <- struct{}{}
eventFactory.Close()
}()
}

func (s *eventFactoryClient) callRefresh() {
eventFactory := begin.FromContext(s.ctx)
go func() {
s.ch <- struct{}{}
eventFactory.Request()
}()
}

type contextKey struct{}

type checkContextClient struct {
t *testing.T
expectedValue string
}

func (c *checkContextClient) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (*networkservice.Connection, error) {
assert.Equal(c.t, c.expectedValue, ctx.Value(contextKey{}))
return next.Client(ctx).Request(ctx, request, opts...)
}

func (c *checkContextClient) Close(ctx context.Context, conn *networkservice.Connection, opts ...grpc.CallOption) (*emptypb.Empty, error) {
return next.Client(ctx).Close(ctx, conn, opts...)
}

func (c *checkContextClient) setExpectedValue(value string) {
c.expectedValue = value
}
129 changes: 129 additions & 0 deletions pkg/networkservice/common/begin/event_factory_server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Copyright (c) 2022 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package begin_test

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"google.golang.org/protobuf/types/known/emptypb"

"github.com/networkservicemesh/api/pkg/api/networkservice"

"github.com/networkservicemesh/sdk/pkg/networkservice/common/begin"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/chain"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
)

// This test reproduces the situation when Close and Request were called at the same time
// nolint:dupl
func TestRefreshDuringClose_Server(t *testing.T) {
t.Cleanup(func() { goleak.VerifyNone(t) })

syncChan := make(chan struct{})
checkCtxServ := &checkContextServer{t: t}
eventFactoryServ := &eventFactoryServer{ch: syncChan}
server := chain.NewNetworkServiceServer(
begin.NewServer(),
checkCtxServ,
eventFactoryServ,
)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Set any value to context
ctx = context.WithValue(ctx, contextKey{}, "value_1")
checkCtxServ.setExpectedValue("value_1")

// Do Request with this context
request := testRequest("1")
conn, err := server.Request(ctx, request.Clone())
assert.NotNil(t, t, conn)
assert.NoError(t, err)

// Change context value before refresh Request
ctx = context.WithValue(ctx, contextKey{}, "value_2")
checkCtxServ.setExpectedValue("value_2")
request.Connection = conn.Clone()

// Call Close from eventFactory
eventFactoryServ.callClose()
<-syncChan

// Call refresh (should be called at the same time as Close)
conn, err = server.Request(ctx, request.Clone())
assert.NotNil(t, t, conn)
assert.NoError(t, err)

// Call refresh from eventFactory. We are expecting updated value in the context
eventFactoryServ.callRefresh()
<-syncChan
}

type eventFactoryServer struct {
ctx context.Context
ch chan<- struct{}
}

func (e *eventFactoryServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
e.ctx = ctx
return next.Server(ctx).Request(ctx, request)
}

func (e *eventFactoryServer) Close(ctx context.Context, conn *networkservice.Connection) (*emptypb.Empty, error) {
// Wait to be sure that rerequest was called
time.Sleep(time.Millisecond * 100)
return next.Server(ctx).Close(ctx, conn)
}

func (e *eventFactoryServer) callClose() {
eventFactory := begin.FromContext(e.ctx)
go func() {
e.ch <- struct{}{}
eventFactory.Close()
}()
}

func (e *eventFactoryServer) callRefresh() {
eventFactory := begin.FromContext(e.ctx)
go func() {
e.ch <- struct{}{}
eventFactory.Request()
}()
}

type checkContextServer struct {
t *testing.T
expectedValue string
}

func (c *checkContextServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
assert.Equal(c.t, c.expectedValue, ctx.Value(contextKey{}))
return next.Server(ctx).Request(ctx, request)
}

func (c *checkContextServer) Close(ctx context.Context, conn *networkservice.Connection) (*emptypb.Empty, error) {
return next.Server(ctx).Close(ctx, conn)
}

func (c *checkContextServer) setExpectedValue(value string) {
c.expectedValue = value
}
6 changes: 3 additions & 3 deletions pkg/networkservice/common/begin/server.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2021 Cisco and/or its affiliates.
// Copyright (c) 2021-2022 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -55,7 +55,7 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo
),
)
<-eventFactoryServer.executor.AsyncExec(func() {
currentEventFactoryServer, _ := b.LoadOrStore(request.GetConnection().GetId(), eventFactoryServer)
currentEventFactoryServer, _ := b.Load(request.GetConnection().GetId())
if currentEventFactoryServer != eventFactoryServer {
log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryServer != eventFactoryServer")
conn, err = b.Request(ctx, request)
Expand Down Expand Up @@ -93,7 +93,7 @@ func (b *beginServer) Close(ctx context.Context, conn *networkservice.Connection
if eventFactoryServer.state != established || eventFactoryServer.request == nil {
return
}
currentServerClient, _ := b.LoadOrStore(conn.GetId(), eventFactoryServer)
currentServerClient, _ := b.Load(conn.GetId())
if currentServerClient != eventFactoryServer {
return
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/registry/common/begin/ns_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (b *beginNSClient) Register(ctx context.Context, in *registry.NetworkServic
<-eventFactoryClient.executor.AsyncExec(func() {
// If the eventFactory has changed, usually because the connection has been Closed and re-established
// go back to the beginning and try again.
currentEventFactoryClient, _ := b.LoadOrStore(id, eventFactoryClient)
currentEventFactoryClient, _ := b.Load(id)
if currentEventFactoryClient != eventFactoryClient {
log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryClient != eventFactoryClient")
resp, err = b.Register(ctx, in, opts...)
Expand Down Expand Up @@ -101,7 +101,7 @@ func (b *beginNSClient) Unregister(ctx context.Context, in *registry.NetworkServ
}

// If this isn't the connection we started with, do nothing
currentEventFactoryClient, _ := b.LoadOrStore(id, eventFactoryClient)
currentEventFactoryClient, _ := b.Load(id)
if currentEventFactoryClient != eventFactoryClient {
return
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/registry/common/begin/ns_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (b *beginNSServer) Register(ctx context.Context, in *registry.NetworkServic
var err error

<-eventFactoryServer.executor.AsyncExec(func() {
currentEventFactoryServer, _ := b.LoadOrStore(id, eventFactoryServer)
currentEventFactoryServer, _ := b.Load(id)
if currentEventFactoryServer != eventFactoryServer {
log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryServer != eventFactoryServer")
resp, err = b.Register(ctx, in)
Expand Down Expand Up @@ -96,7 +96,7 @@ func (b *beginNSServer) Unregister(ctx context.Context, in *registry.NetworkServ
if eventFactoryServer.state != established || eventFactoryServer.registration == nil {
return
}
currentServerClient, _ := b.LoadOrStore(id, eventFactoryServer)
currentServerClient, _ := b.Load(id)
if currentServerClient != eventFactoryServer {
return
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/registry/common/begin/nse_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (b *beginNSEClient) Register(ctx context.Context, in *registry.NetworkServi
<-eventFactoryClient.executor.AsyncExec(func() {
// If the eventFactory has changed, usually because the connection has been Closed and re-established
// go back to the beginning and try again.
currentEventFactoryClient, _ := b.LoadOrStore(id, eventFactoryClient)
currentEventFactoryClient, _ := b.Load(id)
if currentEventFactoryClient != eventFactoryClient {
log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryClient != eventFactoryClient")
resp, err = b.Register(ctx, in, opts...)
Expand Down Expand Up @@ -101,7 +101,7 @@ func (b *beginNSEClient) Unregister(ctx context.Context, in *registry.NetworkSer
}

// If this isn't the connection we started with, do nothing
currentEventFactoryClient, _ := b.LoadOrStore(id, eventFactoryClient)
currentEventFactoryClient, _ := b.Load(id)
if currentEventFactoryClient != eventFactoryClient {
return
}
Expand Down

0 comments on commit 439dd9b

Please sign in to comment.