diff --git a/api/ca.pb.go b/api/ca.pb.go index 619421b3b0..343a182e7c 100644 --- a/api/ca.pb.go +++ b/api/ca.pb.go @@ -836,12 +836,12 @@ func encodeVarintCa(data []byte, offset int, v uint64) int { } type raftProxyCAServer struct { - local CAServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local CAServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyCAServer(local CAServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) CAServer { +func NewRaftProxyCAServer(local CAServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) CAServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -858,18 +858,24 @@ func NewRaftProxyCAServer(local CAServer, connSelector raftselector.ConnProvider md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyCAServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyCAServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyCAServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -906,11 +912,15 @@ func (p *raftProxyCAServer) GetRootCACertificate(ctx context.Context, r *GetRoot conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetRootCACertificate(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -937,11 +947,15 @@ func (p *raftProxyCAServer) GetUnlockKey(ctx context.Context, r *GetUnlockKeyReq conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetUnlockKey(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -964,12 +978,12 @@ func (p *raftProxyCAServer) GetUnlockKey(ctx context.Context, r *GetUnlockKeyReq } type raftProxyNodeCAServer struct { - local NodeCAServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local NodeCAServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyNodeCAServer(local NodeCAServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) NodeCAServer { +func NewRaftProxyNodeCAServer(local NodeCAServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) NodeCAServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -986,18 +1000,24 @@ func NewRaftProxyNodeCAServer(local NodeCAServer, connSelector raftselector.Conn md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyNodeCAServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyNodeCAServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyNodeCAServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -1034,11 +1054,15 @@ func (p *raftProxyNodeCAServer) IssueNodeCertificate(ctx context.Context, r *Iss conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.IssueNodeCertificate(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -1065,11 +1089,15 @@ func (p *raftProxyNodeCAServer) NodeCertificateStatus(ctx context.Context, r *No conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.NodeCertificateStatus(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } diff --git a/api/control.pb.go b/api/control.pb.go index 6f36208c2c..17b4f4113d 100644 --- a/api/control.pb.go +++ b/api/control.pb.go @@ -5256,12 +5256,12 @@ func encodeVarintControl(data []byte, offset int, v uint64) int { } type raftProxyControlServer struct { - local ControlServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local ControlServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyControlServer(local ControlServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) ControlServer { +func NewRaftProxyControlServer(local ControlServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) ControlServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -5278,18 +5278,24 @@ func NewRaftProxyControlServer(local ControlServer, connSelector raftselector.Co md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyControlServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyControlServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyControlServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -5326,11 +5332,15 @@ func (p *raftProxyControlServer) GetNode(ctx context.Context, r *GetNodeRequest) conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetNode(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5357,11 +5367,15 @@ func (p *raftProxyControlServer) ListNodes(ctx context.Context, r *ListNodesRequ conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ListNodes(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5388,11 +5402,15 @@ func (p *raftProxyControlServer) UpdateNode(ctx context.Context, r *UpdateNodeRe conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.UpdateNode(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5419,11 +5437,15 @@ func (p *raftProxyControlServer) RemoveNode(ctx context.Context, r *RemoveNodeRe conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.RemoveNode(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5450,11 +5472,15 @@ func (p *raftProxyControlServer) GetTask(ctx context.Context, r *GetTaskRequest) conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetTask(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5481,11 +5507,15 @@ func (p *raftProxyControlServer) ListTasks(ctx context.Context, r *ListTasksRequ conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ListTasks(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5512,11 +5542,15 @@ func (p *raftProxyControlServer) RemoveTask(ctx context.Context, r *RemoveTaskRe conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.RemoveTask(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5543,11 +5577,15 @@ func (p *raftProxyControlServer) GetService(ctx context.Context, r *GetServiceRe conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetService(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5574,11 +5612,15 @@ func (p *raftProxyControlServer) ListServices(ctx context.Context, r *ListServic conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ListServices(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5605,11 +5647,15 @@ func (p *raftProxyControlServer) CreateService(ctx context.Context, r *CreateSer conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.CreateService(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5636,11 +5682,15 @@ func (p *raftProxyControlServer) UpdateService(ctx context.Context, r *UpdateSer conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.UpdateService(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5667,11 +5717,15 @@ func (p *raftProxyControlServer) RemoveService(ctx context.Context, r *RemoveSer conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.RemoveService(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5698,11 +5752,15 @@ func (p *raftProxyControlServer) GetNetwork(ctx context.Context, r *GetNetworkRe conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetNetwork(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5729,11 +5787,15 @@ func (p *raftProxyControlServer) ListNetworks(ctx context.Context, r *ListNetwor conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ListNetworks(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5760,11 +5822,15 @@ func (p *raftProxyControlServer) CreateNetwork(ctx context.Context, r *CreateNet conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.CreateNetwork(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5791,11 +5857,15 @@ func (p *raftProxyControlServer) RemoveNetwork(ctx context.Context, r *RemoveNet conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.RemoveNetwork(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5822,11 +5892,15 @@ func (p *raftProxyControlServer) GetCluster(ctx context.Context, r *GetClusterRe conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetCluster(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5853,11 +5927,15 @@ func (p *raftProxyControlServer) ListClusters(ctx context.Context, r *ListCluste conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ListClusters(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5884,11 +5962,15 @@ func (p *raftProxyControlServer) UpdateCluster(ctx context.Context, r *UpdateClu conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.UpdateCluster(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5915,11 +5997,15 @@ func (p *raftProxyControlServer) GetSecret(ctx context.Context, r *GetSecretRequ conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetSecret(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5946,11 +6032,15 @@ func (p *raftProxyControlServer) UpdateSecret(ctx context.Context, r *UpdateSecr conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.UpdateSecret(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5977,11 +6067,15 @@ func (p *raftProxyControlServer) ListSecrets(ctx context.Context, r *ListSecrets conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ListSecrets(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -6008,11 +6102,15 @@ func (p *raftProxyControlServer) CreateSecret(ctx context.Context, r *CreateSecr conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.CreateSecret(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -6039,11 +6137,15 @@ func (p *raftProxyControlServer) RemoveSecret(ctx context.Context, r *RemoveSecr conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.RemoveSecret(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } diff --git a/api/dispatcher.pb.go b/api/dispatcher.pb.go index 751c48d37c..7c5c70b5e7 100644 --- a/api/dispatcher.pb.go +++ b/api/dispatcher.pb.go @@ -1670,12 +1670,12 @@ func encodeVarintDispatcher(data []byte, offset int, v uint64) int { } type raftProxyDispatcherServer struct { - local DispatcherServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local DispatcherServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyDispatcherServer(local DispatcherServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) DispatcherServer { +func NewRaftProxyDispatcherServer(local DispatcherServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) DispatcherServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -1692,18 +1692,24 @@ func NewRaftProxyDispatcherServer(local DispatcherServer, connSelector raftselec md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyDispatcherServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyDispatcherServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyDispatcherServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -1735,17 +1741,33 @@ func (p *raftProxyDispatcherServer) pollNewLeaderConn(ctx context.Context) (*grp } } -func (p *raftProxyDispatcherServer) Session(r *SessionRequest, stream Dispatcher_SessionServer) error { +type Dispatcher_SessionServerWrapper struct { + Dispatcher_SessionServer + ctx context.Context +} +func (s Dispatcher_SessionServerWrapper) Context() context.Context { + return s.ctx +} + +func (p *raftProxyDispatcherServer) Session(r *SessionRequest, stream Dispatcher_SessionServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.Session(r, stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := Dispatcher_SessionServerWrapper{ + Dispatcher_SessionServer: stream, + ctx: ctx, + } + return p.local.Session(r, streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } @@ -1775,11 +1797,15 @@ func (p *raftProxyDispatcherServer) Heartbeat(ctx context.Context, r *HeartbeatR conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.Heartbeat(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -1806,11 +1832,15 @@ func (p *raftProxyDispatcherServer) UpdateTaskStatus(ctx context.Context, r *Upd conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.UpdateTaskStatus(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -1832,17 +1862,33 @@ func (p *raftProxyDispatcherServer) UpdateTaskStatus(ctx context.Context, r *Upd return resp, err } -func (p *raftProxyDispatcherServer) Tasks(r *TasksRequest, stream Dispatcher_TasksServer) error { +type Dispatcher_TasksServerWrapper struct { + Dispatcher_TasksServer + ctx context.Context +} + +func (s Dispatcher_TasksServerWrapper) Context() context.Context { + return s.ctx +} +func (p *raftProxyDispatcherServer) Tasks(r *TasksRequest, stream Dispatcher_TasksServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.Tasks(r, stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := Dispatcher_TasksServerWrapper{ + Dispatcher_TasksServer: stream, + ctx: ctx, + } + return p.local.Tasks(r, streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } @@ -1867,17 +1913,33 @@ func (p *raftProxyDispatcherServer) Tasks(r *TasksRequest, stream Dispatcher_Tas return nil } -func (p *raftProxyDispatcherServer) Assignments(r *AssignmentsRequest, stream Dispatcher_AssignmentsServer) error { +type Dispatcher_AssignmentsServerWrapper struct { + Dispatcher_AssignmentsServer + ctx context.Context +} + +func (s Dispatcher_AssignmentsServerWrapper) Context() context.Context { + return s.ctx +} +func (p *raftProxyDispatcherServer) Assignments(r *AssignmentsRequest, stream Dispatcher_AssignmentsServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.Assignments(r, stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := Dispatcher_AssignmentsServerWrapper{ + Dispatcher_AssignmentsServer: stream, + ctx: ctx, + } + return p.local.Assignments(r, streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } diff --git a/api/health.pb.go b/api/health.pb.go index 13c40143df..5e53c97bd0 100644 --- a/api/health.pb.go +++ b/api/health.pb.go @@ -321,12 +321,12 @@ func encodeVarintHealth(data []byte, offset int, v uint64) int { } type raftProxyHealthServer struct { - local HealthServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local HealthServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) HealthServer { +func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) HealthServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -343,18 +343,24 @@ func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.Conn md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyHealthServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyHealthServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyHealthServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -391,11 +397,15 @@ func (p *raftProxyHealthServer) Check(ctx context.Context, r *HealthCheckRequest conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.Check(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } diff --git a/api/logbroker.pb.go b/api/logbroker.pb.go index a066add9eb..b3dc176c96 100644 --- a/api/logbroker.pb.go +++ b/api/logbroker.pb.go @@ -1279,12 +1279,12 @@ func encodeVarintLogbroker(data []byte, offset int, v uint64) int { } type raftProxyLogsServer struct { - local LogsServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local LogsServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyLogsServer(local LogsServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) LogsServer { +func NewRaftProxyLogsServer(local LogsServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) LogsServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -1301,18 +1301,24 @@ func NewRaftProxyLogsServer(local LogsServer, connSelector raftselector.ConnProv md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyLogsServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyLogsServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyLogsServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -1344,17 +1350,33 @@ func (p *raftProxyLogsServer) pollNewLeaderConn(ctx context.Context) (*grpc.Clie } } -func (p *raftProxyLogsServer) SubscribeLogs(r *SubscribeLogsRequest, stream Logs_SubscribeLogsServer) error { +type Logs_SubscribeLogsServerWrapper struct { + Logs_SubscribeLogsServer + ctx context.Context +} +func (s Logs_SubscribeLogsServerWrapper) Context() context.Context { + return s.ctx +} + +func (p *raftProxyLogsServer) SubscribeLogs(r *SubscribeLogsRequest, stream Logs_SubscribeLogsServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.SubscribeLogs(r, stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := Logs_SubscribeLogsServerWrapper{ + Logs_SubscribeLogsServer: stream, + ctx: ctx, + } + return p.local.SubscribeLogs(r, streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } @@ -1380,12 +1402,12 @@ func (p *raftProxyLogsServer) SubscribeLogs(r *SubscribeLogsRequest, stream Logs } type raftProxyLogBrokerServer struct { - local LogBrokerServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local LogBrokerServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyLogBrokerServer(local LogBrokerServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) LogBrokerServer { +func NewRaftProxyLogBrokerServer(local LogBrokerServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) LogBrokerServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -1402,18 +1424,24 @@ func NewRaftProxyLogBrokerServer(local LogBrokerServer, connSelector raftselecto md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyLogBrokerServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyLogBrokerServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyLogBrokerServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -1445,17 +1473,33 @@ func (p *raftProxyLogBrokerServer) pollNewLeaderConn(ctx context.Context) (*grpc } } -func (p *raftProxyLogBrokerServer) ListenSubscriptions(r *ListenSubscriptionsRequest, stream LogBroker_ListenSubscriptionsServer) error { +type LogBroker_ListenSubscriptionsServerWrapper struct { + LogBroker_ListenSubscriptionsServer + ctx context.Context +} +func (s LogBroker_ListenSubscriptionsServerWrapper) Context() context.Context { + return s.ctx +} + +func (p *raftProxyLogBrokerServer) ListenSubscriptions(r *ListenSubscriptionsRequest, stream LogBroker_ListenSubscriptionsServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.ListenSubscriptions(r, stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := LogBroker_ListenSubscriptionsServerWrapper{ + LogBroker_ListenSubscriptionsServer: stream, + ctx: ctx, + } + return p.local.ListenSubscriptions(r, streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } @@ -1480,17 +1524,33 @@ func (p *raftProxyLogBrokerServer) ListenSubscriptions(r *ListenSubscriptionsReq return nil } -func (p *raftProxyLogBrokerServer) PublishLogs(stream LogBroker_PublishLogsServer) error { +type LogBroker_PublishLogsServerWrapper struct { + LogBroker_PublishLogsServer + ctx context.Context +} +func (s LogBroker_PublishLogsServerWrapper) Context() context.Context { + return s.ctx +} + +func (p *raftProxyLogBrokerServer) PublishLogs(stream LogBroker_PublishLogsServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.PublishLogs(stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := LogBroker_PublishLogsServerWrapper{ + LogBroker_PublishLogsServer: stream, + ctx: ctx, + } + return p.local.PublishLogs(streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } diff --git a/api/raft.pb.go b/api/raft.pb.go index e824d66a9c..8a96952f76 100644 --- a/api/raft.pb.go +++ b/api/raft.pb.go @@ -1498,12 +1498,12 @@ func encodeVarintRaft(data []byte, offset int, v uint64) int { } type raftProxyRaftServer struct { - local RaftServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local RaftServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyRaftServer(local RaftServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) RaftServer { +func NewRaftProxyRaftServer(local RaftServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) RaftServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -1520,18 +1520,24 @@ func NewRaftProxyRaftServer(local RaftServer, connSelector raftselector.ConnProv md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyRaftServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyRaftServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyRaftServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -1568,11 +1574,15 @@ func (p *raftProxyRaftServer) ProcessRaftMessage(ctx context.Context, r *Process conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ProcessRaftMessage(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -1599,11 +1609,15 @@ func (p *raftProxyRaftServer) ResolveAddress(ctx context.Context, r *ResolveAddr conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ResolveAddress(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -1626,12 +1640,12 @@ func (p *raftProxyRaftServer) ResolveAddress(ctx context.Context, r *ResolveAddr } type raftProxyRaftMembershipServer struct { - local RaftMembershipServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local RaftMembershipServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyRaftMembershipServer(local RaftMembershipServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) RaftMembershipServer { +func NewRaftProxyRaftMembershipServer(local RaftMembershipServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) RaftMembershipServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -1648,18 +1662,24 @@ func NewRaftProxyRaftMembershipServer(local RaftMembershipServer, connSelector r md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyRaftMembershipServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyRaftMembershipServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyRaftMembershipServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -1696,11 +1716,15 @@ func (p *raftProxyRaftMembershipServer) Join(ctx context.Context, r *JoinRequest conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.Join(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -1727,11 +1751,15 @@ func (p *raftProxyRaftMembershipServer) Leave(ctx context.Context, r *LeaveReque conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.Leave(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } diff --git a/api/resource.pb.go b/api/resource.pb.go index 52d1e4e4ab..a764a6ccee 100644 --- a/api/resource.pb.go +++ b/api/resource.pb.go @@ -451,12 +451,12 @@ func encodeVarintResource(data []byte, offset int, v uint64) int { } type raftProxyResourceAllocatorServer struct { - local ResourceAllocatorServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local ResourceAllocatorServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyResourceAllocatorServer(local ResourceAllocatorServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) ResourceAllocatorServer { +func NewRaftProxyResourceAllocatorServer(local ResourceAllocatorServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) ResourceAllocatorServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -473,18 +473,24 @@ func NewRaftProxyResourceAllocatorServer(local ResourceAllocatorServer, connSele md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyResourceAllocatorServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyResourceAllocatorServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyResourceAllocatorServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -521,11 +527,15 @@ func (p *raftProxyResourceAllocatorServer) AttachNetwork(ctx context.Context, r conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.AttachNetwork(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -552,11 +562,15 @@ func (p *raftProxyResourceAllocatorServer) DetachNetwork(ctx context.Context, r conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.DetachNetwork(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } diff --git a/ca/auth.go b/ca/auth.go index d81b543da0..bc7c629a54 100644 --- a/ca/auth.go +++ b/ca/auth.go @@ -16,6 +16,13 @@ import ( "google.golang.org/grpc/peer" ) +type localRequestKeyType struct{} + +// LocalRequestKey is a context key to mark a request that originating on the +// local node. The assocated value is a RemoteNodeInfo structure describing the +// local node. +var LocalRequestKey = localRequestKeyType{} + // LogTLSState logs information about the TLS connection and remote peers func LogTLSState(ctx context.Context, tlsState *tls.ConnectionState) { if tlsState == nil { @@ -189,6 +196,17 @@ type RemoteNodeInfo struct { // well as the forwarder's ID. This function does not do authorization checks - // it only looks up the node ID. func RemoteNode(ctx context.Context) (RemoteNodeInfo, error) { + // If we have a value on the context that marks this as a local + // request, we return the node info from the context. + localNodeInfo := ctx.Value(LocalRequestKey) + + if localNodeInfo != nil { + nodeInfo, ok := localNodeInfo.(RemoteNodeInfo) + if ok { + return nodeInfo, nil + } + } + certSubj, err := certSubjectFromContext(ctx) if err != nil { return RemoteNodeInfo{}, err diff --git a/ca/server.go b/ca/server.go index fa55d38534..29fc5a217f 100644 --- a/ca/server.go +++ b/ca/server.go @@ -211,6 +211,15 @@ func (s *Server) IssueNodeCertificate(ctx context.Context, request *api.IssueNod blacklistedCerts = clusters[0].BlacklistedCertificates } + // Renewing the cert with a local (unix socket) is always valid. + localNodeInfo := ctx.Value(LocalRequestKey) + if localNodeInfo != nil { + nodeInfo, ok := localNodeInfo.(RemoteNodeInfo) + if ok && nodeInfo.NodeID != "" { + return s.issueRenewCertificate(ctx, nodeInfo.NodeID, request.CSR) + } + } + // If the remote node is a worker (either forwarded by a manager, or calling directly), // issue a renew worker certificate entry with the correct ID nodeID, err := AuthorizeForwardedRoleAndOrg(ctx, []string{WorkerRole}, []string{ManagerRole}, s.securityConfig.ClientTLSCreds.Organization(), blacklistedCerts) diff --git a/manager/manager.go b/manager/manager.go index 649f229532..39e17295b2 100644 --- a/manager/manager.go +++ b/manager/manager.go @@ -328,23 +328,46 @@ func (m *Manager) Run(parent context.Context) error { authenticatedHealthAPI := api.NewAuthenticatedWrapperHealthServer(healthServer, authorize) authenticatedRaftMembershipAPI := api.NewAuthenticatedWrapperRaftMembershipServer(m.raftNode, authorize) - proxyDispatcherAPI := api.NewRaftProxyDispatcherServer(authenticatedDispatcherAPI, m.raftNode, ca.WithMetadataForwardTLSInfo) - proxyCAAPI := api.NewRaftProxyCAServer(authenticatedCAAPI, m.raftNode, ca.WithMetadataForwardTLSInfo) - proxyNodeCAAPI := api.NewRaftProxyNodeCAServer(authenticatedNodeCAAPI, m.raftNode, ca.WithMetadataForwardTLSInfo) - proxyRaftMembershipAPI := api.NewRaftProxyRaftMembershipServer(authenticatedRaftMembershipAPI, m.raftNode, ca.WithMetadataForwardTLSInfo) - proxyResourceAPI := api.NewRaftProxyResourceAllocatorServer(authenticatedResourceAPI, m.raftNode, ca.WithMetadataForwardTLSInfo) - proxyLogBrokerAPI := api.NewRaftProxyLogBrokerServer(authenticatedLogBrokerAPI, m.raftNode, ca.WithMetadataForwardTLSInfo) - - // localProxyControlAPI is a special kind of proxy. It is only wired up - // to receive requests from a trusted local socket, and these requests - // don't use TLS, therefore the requests it handles locally should - // bypass authorization. When it proxies, it sends them as requests from - // this manager rather than forwarded requests (it has no TLS - // information to put in the metadata map). + proxyDispatcherAPI := api.NewRaftProxyDispatcherServer(authenticatedDispatcherAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo) + proxyCAAPI := api.NewRaftProxyCAServer(authenticatedCAAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo) + proxyNodeCAAPI := api.NewRaftProxyNodeCAServer(authenticatedNodeCAAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo) + proxyRaftMembershipAPI := api.NewRaftProxyRaftMembershipServer(authenticatedRaftMembershipAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo) + proxyResourceAPI := api.NewRaftProxyResourceAllocatorServer(authenticatedResourceAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo) + proxyLogBrokerAPI := api.NewRaftProxyLogBrokerServer(authenticatedLogBrokerAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo) + + // The following local proxies are only wired up to receive requests + // from a trusted local socket, and these requests don't use TLS, + // therefore the requests they handle locally should bypass + // authorization. When requests are proxied from these servers, they + // are sent as requests from this manager rather than forwarded + // requests (it has no TLS information to put in the metadata map). forwardAsOwnRequest := func(ctx context.Context) (context.Context, error) { return ctx, nil } - localProxyControlAPI := api.NewRaftProxyControlServer(baseControlAPI, m.raftNode, forwardAsOwnRequest) - localProxyLogsAPI := api.NewRaftProxyLogsServer(m.logbroker, m.raftNode, forwardAsOwnRequest) - localCAAPI := api.NewRaftProxyCAServer(m.caserver, m.raftNode, forwardAsOwnRequest) + handleRequestLocally := func(ctx context.Context) (context.Context, error) { + var remoteAddr string + if m.config.RemoteAPI.AdvertiseAddr != "" { + remoteAddr = m.config.RemoteAPI.AdvertiseAddr + } else { + remoteAddr = m.config.RemoteAPI.ListenAddr + } + + creds := m.config.SecurityConfig.ClientTLSCreds + + nodeInfo := ca.RemoteNodeInfo{ + Roles: []string{creds.Role()}, + Organization: creds.Organization(), + NodeID: creds.NodeID(), + RemoteAddr: remoteAddr, + } + + return context.WithValue(ctx, ca.LocalRequestKey, nodeInfo), nil + } + localProxyControlAPI := api.NewRaftProxyControlServer(baseControlAPI, m.raftNode, handleRequestLocally, forwardAsOwnRequest) + localProxyLogsAPI := api.NewRaftProxyLogsServer(m.logbroker, m.raftNode, handleRequestLocally, forwardAsOwnRequest) + localProxyDispatcherAPI := api.NewRaftProxyDispatcherServer(m.dispatcher, m.raftNode, handleRequestLocally, forwardAsOwnRequest) + localProxyCAAPI := api.NewRaftProxyCAServer(m.caserver, m.raftNode, handleRequestLocally, forwardAsOwnRequest) + localProxyNodeCAAPI := api.NewRaftProxyNodeCAServer(m.caserver, m.raftNode, handleRequestLocally, forwardAsOwnRequest) + localProxyResourceAPI := api.NewRaftProxyResourceAllocatorServer(baseResourceAPI, m.raftNode, handleRequestLocally, forwardAsOwnRequest) + localProxyLogBrokerAPI := api.NewRaftProxyLogBrokerServer(m.logbroker, m.raftNode, handleRequestLocally, forwardAsOwnRequest) // Everything registered on m.server should be an authenticated // wrapper, or a proxy wrapping an authenticated wrapper! @@ -362,7 +385,11 @@ func (m *Manager) Run(parent context.Context) error { api.RegisterControlServer(m.localserver, localProxyControlAPI) api.RegisterLogsServer(m.localserver, localProxyLogsAPI) api.RegisterHealthServer(m.localserver, localHealthServer) - api.RegisterCAServer(m.localserver, localCAAPI) + api.RegisterDispatcherServer(m.localserver, localProxyDispatcherAPI) + api.RegisterCAServer(m.localserver, localProxyCAAPI) + api.RegisterNodeCAServer(m.localserver, localProxyNodeCAAPI) + api.RegisterResourceAllocatorServer(m.localserver, localProxyResourceAPI) + api.RegisterLogBrokerServer(m.localserver, localProxyLogBrokerAPI) healthServer.SetServingStatus("Raft", api.HealthCheckResponse_NOT_SERVING) localHealthServer.SetServingStatus("ControlAPI", api.HealthCheckResponse_NOT_SERVING) diff --git a/protobuf/plugin/raftproxy/raftproxy.go b/protobuf/plugin/raftproxy/raftproxy.go index 931dfdf23b..bb8582113b 100644 --- a/protobuf/plugin/raftproxy/raftproxy.go +++ b/protobuf/plugin/raftproxy/raftproxy.go @@ -27,12 +27,12 @@ func (g *raftProxyGen) genProxyStruct(s *descriptor.ServiceDescriptorProto) { g.gen.P("type " + serviceTypeName(s) + " struct {") g.gen.P("\tlocal " + s.GetName() + "Server") g.gen.P("\tconnSelector raftselector.ConnProvider") - g.gen.P("\tctxMods []func(context.Context)(context.Context, error)") + g.gen.P("\tlocalCtxMods, remoteCtxMods []func(context.Context)(context.Context, error)") g.gen.P("}") } func (g *raftProxyGen) genProxyConstructor(s *descriptor.ServiceDescriptorProto) { - g.gen.P("func NewRaftProxy" + s.GetName() + "Server(local " + s.GetName() + "Server, connSelector raftselector.ConnProvider, ctxMod func(context.Context)(context.Context, error)) " + s.GetName() + "Server {") + g.gen.P("func NewRaftProxy" + s.GetName() + "Server(local " + s.GetName() + "Server, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context)(context.Context, error)) " + s.GetName() + "Server {") g.gen.P(`redirectChecker := func(ctx context.Context)(context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -49,21 +49,27 @@ func (g *raftProxyGen) genProxyConstructor(s *descriptor.ServiceDescriptorProto) md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context)(context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context)(context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context)(context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context)(context.Context, error){localCtxMod} + } `) g.gen.P("return &" + serviceTypeName(s) + `{ local: local, connSelector: connSelector, - ctxMods: mods, + localCtxMods: localMods, + remoteCtxMods: remoteMods, }`) g.gen.P("}") } func (g *raftProxyGen) genRunCtxMods(s *descriptor.ServiceDescriptorProto) { - g.gen.P("func (p *" + serviceTypeName(s) + `) runCtxMods(ctx context.Context) (context.Context, error) { + g.gen.P("func (p *" + serviceTypeName(s) + `) runCtxMods(ctx context.Context, ctxMods []func(context.Context)(context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -91,18 +97,43 @@ func sigPrefix(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescrip return "func (p *" + serviceTypeName(s) + ") " + m.GetName() + "(" } +func (g *raftProxyGen) genStreamWrapper(streamType string) { + // Generate stream wrapper that returns a modified context + g.gen.P(`type ` + streamType + `Wrapper struct { + ` + streamType + ` + ctx context.Context +} +`) + g.gen.P(`func (s ` + streamType + `Wrapper) Context() context.Context { + return s.ctx +} +`) +} + func (g *raftProxyGen) genClientStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { - g.gen.P(sigPrefix(s, m) + "stream " + s.GetName() + "_" + m.GetName() + "Server) error {") - g.gen.P(` + streamType := s.GetName() + "_" + m.GetName() + "Server" + + // Generate stream wrapper that returns a modified context + g.genStreamWrapper(streamType) + + g.gen.P(sigPrefix(s, m) + "stream " + streamType + `) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.` + m.GetName() + `(stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := ` + streamType + `Wrapper{ + ` + streamType + `: stream, + ctx: ctx, + } + return p.local.` + m.GetName() + `(streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err }`) @@ -135,17 +166,28 @@ func (g *raftProxyGen) genClientStreamingMethod(s *descriptor.ServiceDescriptorP } func (g *raftProxyGen) genServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { - g.gen.P(sigPrefix(s, m) + "r *" + getInputTypeName(m) + ", stream " + s.GetName() + "_" + m.GetName() + "Server) error {") - g.gen.P(` + streamType := s.GetName() + "_" + m.GetName() + "Server" + + g.genStreamWrapper(streamType) + + g.gen.P(sigPrefix(s, m) + "r *" + getInputTypeName(m) + ", stream " + streamType + `) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.` + m.GetName() + `(r, stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := ` + streamType + `Wrapper{ + ` + streamType + `: stream, + ctx: ctx, + } + return p.local.` + m.GetName() + `(r, streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err }`) @@ -172,17 +214,28 @@ func (g *raftProxyGen) genServerStreamingMethod(s *descriptor.ServiceDescriptorP } func (g *raftProxyGen) genClientServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { - g.gen.P(sigPrefix(s, m) + "stream " + s.GetName() + "_" + m.GetName() + "Server) error {") - g.gen.P(` + streamType := s.GetName() + "_" + m.GetName() + "Server" + + g.genStreamWrapper(streamType) + + g.gen.P(sigPrefix(s, m) + "stream " + streamType + `) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.` + m.GetName() + `(stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := ` + streamType + `Wrapper{ + ` + streamType + `: stream, + ctx: ctx, + } + return p.local.` + m.GetName() + `(streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err }`) @@ -231,11 +284,15 @@ func (g *raftProxyGen) genSimpleMethod(s *descriptor.ServiceDescriptorProto, m * conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.` + m.GetName() + `(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err }`) diff --git a/protobuf/plugin/raftproxy/test/raftproxy_test.go b/protobuf/plugin/raftproxy/test/raftproxy_test.go index 2f4e6fd364..3dd8990661 100644 --- a/protobuf/plugin/raftproxy/test/raftproxy_test.go +++ b/protobuf/plugin/raftproxy/test/raftproxy_test.go @@ -51,7 +51,7 @@ func TestSimpleRedirect(t *testing.T) { cluster := &mockCluster{conn: conn} forwardAsOwnRequest := func(ctx context.Context) (context.Context, error) { return ctx, nil } - api := NewRaftProxyRouteGuideServer(testRouteGuide{}, cluster, forwardAsOwnRequest) + api := NewRaftProxyRouteGuideServer(testRouteGuide{}, cluster, nil, forwardAsOwnRequest) srv := grpc.NewServer() RegisterRouteGuideServer(srv, api) go srv.Serve(l) diff --git a/protobuf/plugin/raftproxy/test/service.pb.go b/protobuf/plugin/raftproxy/test/service.pb.go index 1ebe84bcfb..9285ec8964 100644 --- a/protobuf/plugin/raftproxy/test/service.pb.go +++ b/protobuf/plugin/raftproxy/test/service.pb.go @@ -906,12 +906,12 @@ func encodeVarintService(data []byte, offset int, v uint64) int { } type raftProxyRouteGuideServer struct { - local RouteGuideServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local RouteGuideServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyRouteGuideServer(local RouteGuideServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) RouteGuideServer { +func NewRaftProxyRouteGuideServer(local RouteGuideServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) RouteGuideServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -928,18 +928,24 @@ func NewRaftProxyRouteGuideServer(local RouteGuideServer, connSelector raftselec md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyRouteGuideServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyRouteGuideServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyRouteGuideServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -976,11 +982,15 @@ func (p *raftProxyRouteGuideServer) GetFeature(ctx context.Context, r *Point) (* conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetFeature(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -1002,17 +1012,33 @@ func (p *raftProxyRouteGuideServer) GetFeature(ctx context.Context, r *Point) (* return resp, err } -func (p *raftProxyRouteGuideServer) ListFeatures(r *Rectangle, stream RouteGuide_ListFeaturesServer) error { +type RouteGuide_ListFeaturesServerWrapper struct { + RouteGuide_ListFeaturesServer + ctx context.Context +} +func (s RouteGuide_ListFeaturesServerWrapper) Context() context.Context { + return s.ctx +} + +func (p *raftProxyRouteGuideServer) ListFeatures(r *Rectangle, stream RouteGuide_ListFeaturesServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.ListFeatures(r, stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := RouteGuide_ListFeaturesServerWrapper{ + RouteGuide_ListFeaturesServer: stream, + ctx: ctx, + } + return p.local.ListFeatures(r, streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } @@ -1037,17 +1063,33 @@ func (p *raftProxyRouteGuideServer) ListFeatures(r *Rectangle, stream RouteGuide return nil } -func (p *raftProxyRouteGuideServer) RecordRoute(stream RouteGuide_RecordRouteServer) error { +type RouteGuide_RecordRouteServerWrapper struct { + RouteGuide_RecordRouteServer + ctx context.Context +} + +func (s RouteGuide_RecordRouteServerWrapper) Context() context.Context { + return s.ctx +} +func (p *raftProxyRouteGuideServer) RecordRoute(stream RouteGuide_RecordRouteServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.RecordRoute(stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := RouteGuide_RecordRouteServerWrapper{ + RouteGuide_RecordRouteServer: stream, + ctx: ctx, + } + return p.local.RecordRoute(streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } @@ -1078,17 +1120,33 @@ func (p *raftProxyRouteGuideServer) RecordRoute(stream RouteGuide_RecordRouteSer return stream.SendAndClose(reply) } -func (p *raftProxyRouteGuideServer) RouteChat(stream RouteGuide_RouteChatServer) error { +type RouteGuide_RouteChatServerWrapper struct { + RouteGuide_RouteChatServer + ctx context.Context +} + +func (s RouteGuide_RouteChatServerWrapper) Context() context.Context { + return s.ctx +} +func (p *raftProxyRouteGuideServer) RouteChat(stream RouteGuide_RouteChatServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.RouteChat(stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := RouteGuide_RouteChatServerWrapper{ + RouteGuide_RouteChatServer: stream, + ctx: ctx, + } + return p.local.RouteChat(streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } @@ -1131,12 +1189,12 @@ func (p *raftProxyRouteGuideServer) RouteChat(stream RouteGuide_RouteChatServer) } type raftProxyHealthServer struct { - local HealthServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local HealthServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) HealthServer { +func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) HealthServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -1153,18 +1211,24 @@ func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.Conn md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyHealthServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyHealthServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyHealthServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -1201,11 +1265,15 @@ func (p *raftProxyHealthServer) Check(ctx context.Context, r *HealthCheckRequest conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.Check(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err }