diff --git a/proxy/grpc_handler.go b/proxy/grpc_handler.go index 0310564a8..bed811a62 100644 --- a/proxy/grpc_handler.go +++ b/proxy/grpc_handler.go @@ -55,8 +55,7 @@ func GetGRPCDirector(tlscfg *tls.Config) func(ctx context.Context, fullMethodNam return ctx, nil, fmt.Errorf("error extracting metadata from request") } - outCtx, _ := context.WithCancel(ctx) - outCtx = metadata.NewOutgoingContext(outCtx, md.Copy()) + outCtx := metadata.NewOutgoingContext(ctx, md.Copy()) target, _ := ctx.Value(targetKey{}).(*route.Target) @@ -89,6 +88,10 @@ func (p proxyStream) Context() context.Context { return p.ctx } +func makeGRPCTargetKey(t *route.Target) string { + return t.URL.String() +} + func (g GrpcProxyInterceptor) Stream(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { ctx := stream.Context() @@ -203,7 +206,7 @@ func (h *GrpcStatsHandler) HandleConn(ctx context.Context, conn stats.ConnStats) } type grpcConnectionPool struct { - connections map[*route.Target]*grpc.ClientConn + connections map[string]*grpc.ClientConn lock sync.RWMutex cleanupInterval time.Duration tlscfg *tls.Config @@ -211,7 +214,7 @@ type grpcConnectionPool struct { func newGrpcConnectionPool(tlscfg *tls.Config) *grpcConnectionPool { cp := &grpcConnectionPool{ - connections: make(map[*route.Target]*grpc.ClientConn), + connections: make(map[string]*grpc.ClientConn), lock: sync.RWMutex{}, cleanupInterval: time.Second * 5, tlscfg: tlscfg, @@ -224,7 +227,7 @@ func newGrpcConnectionPool(tlscfg *tls.Config) *grpcConnectionPool { func (p *grpcConnectionPool) Get(ctx context.Context, target *route.Target) (*grpc.ClientConn, error) { p.lock.RLock() - conn := p.connections[target] + conn := p.connections[makeGRPCTargetKey(target)] p.lock.RUnlock() if conn != nil && conn.GetState() != connectivity.Shutdown { @@ -265,23 +268,24 @@ func (p *grpcConnectionPool) newConnection(ctx context.Context, target *route.Ta func (p *grpcConnectionPool) Set(target *route.Target, conn *grpc.ClientConn) { p.lock.Lock() defer p.lock.Unlock() - p.connections[target] = conn + + p.connections[makeGRPCTargetKey(target)] = conn } func (p *grpcConnectionPool) cleanup() { for { p.lock.Lock() table := route.GetTable() - for target, cs := range p.connections { + for tKey, cs := range p.connections { if cs.GetState() == connectivity.Shutdown { - delete(p.connections, target) + delete(p.connections, tKey) continue } - if !hasTarget(target, table) { - log.Println("[DEBUG] grpc: cleaning up connection to", target.URL.Host) + if !hasTarget(tKey, table) { + log.Println("[DEBUG] grpc: cleaning up connection to", tKey) cs.Close() - delete(p.connections, target) + delete(p.connections, tKey) } } p.lock.Unlock() @@ -289,11 +293,11 @@ func (p *grpcConnectionPool) cleanup() { } } -func hasTarget(target *route.Target, table route.Table) bool { +func hasTarget(tKey string, table route.Table) bool { for _, routes := range table { for _, r := range routes { for _, t := range r.Targets { - if target == t { + if tKey == makeGRPCTargetKey(t) { return true } }