diff --git a/gossip/comm/comm_impl.go b/gossip/comm/comm_impl.go index 572eafce8f8..89ebeab3306 100644 --- a/gossip/comm/comm_impl.go +++ b/gossip/comm/comm_impl.go @@ -72,17 +72,19 @@ func NewCommInstanceWithServer(port int, idMapper identity.Mapper, peerIdentity var ll net.Listener var s *grpc.Server var secOpt grpc.DialOption + var certHash []byte if len(dialOpts) == 0 { dialOpts = []grpc.DialOption{grpc.WithTimeout(dialTimeout)} } if port > 0 { - s, ll, secOpt = createGRPCLayer(port) + s, ll, secOpt, certHash = createGRPCLayer(port) dialOpts = append(dialOpts, secOpt) } commInst := &commImpl{ + selfCertHash: certHash, PKIID: idMapper.GetPKIidOfCert(peerIdentity), idMapper: idMapper, logger: util.GetLogger(util.LOGGING_COMM_MODULE, fmt.Sprintf("%d", port)), @@ -117,16 +119,28 @@ func NewCommInstanceWithServer(port int, idMapper identity.Mapper, peerIdentity } // NewCommInstance creates a new comm instance that binds itself to the given gRPC server -func NewCommInstance(s *grpc.Server, idStore identity.Mapper, peerIdentity api.PeerIdentityType, dialOpts ...grpc.DialOption) (Comm, error) { +func NewCommInstance(s *grpc.Server, cert *tls.Certificate, idStore identity.Mapper, peerIdentity api.PeerIdentityType, dialOpts ...grpc.DialOption) (Comm, error) { commInst, err := NewCommInstanceWithServer(-1, idStore, peerIdentity, dialOpts...) if err != nil { return nil, err } + + if cert != nil { + inst := commInst.(*commImpl) + if len(cert.Certificate) == 0 { + inst.logger.Panic("Certificate supplied but certificate chain is empty") + } else { + inst.selfCertHash = certHashFromRawCert(cert.Certificate[0]) + } + } + proto.RegisterGossipServer(s, commInst.(*commImpl)) + return commInst, nil } type commImpl struct { + selfCertHash []byte peerIdentity api.PeerIdentityType idMapper identity.Mapper logger *util.Logger @@ -373,13 +387,16 @@ func extractRemoteAddress(stream stream) string { func (c *commImpl) authenticateRemotePeer(stream stream) (common.PKIidType, error) { ctx := stream.Context() remoteAddress := extractRemoteAddress(stream) - tlsUnique := ExtractTLSUnique(ctx) + remoteCertHash := extractCertificateHashFromContext(ctx) var sig []byte var err error - if tlsUnique != nil { - sig, err = c.idMapper.Sign(tlsUnique) + + // If TLS is detected, sign the hash of our cert to bind our TLS cert + // to the gRPC session + if remoteCertHash != nil && c.selfCertHash != nil { + sig, err = c.idMapper.Sign(c.selfCertHash) if err != nil { - c.logger.Error("Failed signing TLS-Unique:", err) + c.logger.Error("Failed signing self certificate hash:", err) return nil, err } } @@ -414,8 +431,9 @@ func (c *commImpl) authenticateRemotePeer(stream stream) (common.PKIidType, erro return nil, err } - if tlsUnique != nil { - err = c.idMapper.Verify(receivedMsg.PkiID, receivedMsg.Sig, tlsUnique) + // if TLS is detected, verify remote peer + if remoteCertHash != nil && c.selfCertHash != nil { + err = c.idMapper.Verify(receivedMsg.PkiID, receivedMsg.Sig, remoteCertHash) if err != nil { c.logger.Error("Failed verifying signature from", remoteAddress, ":", err) return nil, err @@ -424,7 +442,6 @@ func (c *commImpl) authenticateRemotePeer(stream stream) (common.PKIidType, erro c.logger.Debug("Authenticated", remoteAddress) return receivedMsg.PkiID, nil - } func (c *commImpl) GossipStream(stream proto.Gossip_GossipStreamServer) error { @@ -518,7 +535,8 @@ type stream interface { grpc.Stream } -func createGRPCLayer(port int) (*grpc.Server, net.Listener, grpc.DialOption) { +func createGRPCLayer(port int) (*grpc.Server, net.Listener, grpc.DialOption, []byte) { + var returnedCertHash []byte var s *grpc.Server var ll net.Listener var err error @@ -533,10 +551,25 @@ func createGRPCLayer(port int) (*grpc.Server, net.Listener, grpc.DialOption) { err = generateCertificates(keyFileName, certFileName) if err == nil { - var creds credentials.TransportCredentials - creds, err = credentials.NewServerTLSFromFile(certFileName, keyFileName) - serverOpts = append(serverOpts, grpc.Creds(creds)) + cert, err := tls.LoadX509KeyPair(certFileName, keyFileName) + if err != nil { + panic(err) + } + + if len(cert.Certificate) == 0 { + panic(fmt.Errorf("Certificate chain is nil")) + } + + returnedCertHash = certHashFromRawCert(cert.Certificate[0]) + + tlsConf := &tls.Config{ + Certificates: []tls.Certificate{cert}, + ClientAuth: tls.RequestClientCert, + InsecureSkipVerify: true, + } + serverOpts = append(serverOpts, grpc.Creds(credentials.NewTLS(tlsConf))) ta := credentials.NewTLS(&tls.Config{ + Certificates: []tls.Certificate{cert}, InsecureSkipVerify: true, }) dialOpts = grpc.WithTransportCredentials(&authCreds{tlsCreds: ta}) @@ -551,5 +584,5 @@ func createGRPCLayer(port int) (*grpc.Server, net.Listener, grpc.DialOption) { } s = grpc.NewServer(serverOpts...) - return s, ll, dialOpts + return s, ll, dialOpts, returnedCertHash } diff --git a/gossip/comm/comm_test.go b/gossip/comm/comm_test.go index 43704dfc04a..ab0c3ebb9e1 100644 --- a/gossip/comm/comm_test.go +++ b/gossip/comm/comm_test.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "fmt" "math/rand" + "os" "sync" "testing" "time" @@ -88,8 +89,13 @@ func newCommInstance(port int, sec api.MessageCryptoService) (Comm, error) { } func handshaker(endpoint string, comm Comm, t *testing.T, sigMutator func([]byte) []byte, pkiIDmutator func([]byte) []byte) <-chan ReceivedMessage { + err := generateCertificates("key.pem", "cert.pem") + defer os.Remove("cert.pem") + defer os.Remove("key.pem") + cert, err := tls.LoadX509KeyPair("cert.pem", "key.pem") ta := credentials.NewTLS(&tls.Config{ InsecureSkipVerify: true, + Certificates: []tls.Certificate{cert}, }) acceptChan := comm.Accept(acceptAll) conn, err := grpc.Dial("localhost:9611", grpc.WithTransportCredentials(&authCreds{tlsCreds: ta}), grpc.WithBlock(), grpc.WithTimeout(time.Second)) @@ -103,8 +109,8 @@ func handshaker(endpoint string, comm Comm, t *testing.T, sigMutator func([]byte if err != nil { return nil } - clientTLSUnique := ExtractTLSUnique(stream.Context()) - sig, err := naiveSec.Sign(clientTLSUnique) + clientCertHash := certHashFromRawCert(cert.Certificate[0]) + sig, err := naiveSec.Sign(clientCertHash) if sigMutator != nil { sig = sigMutator(sig) } @@ -119,7 +125,7 @@ func handshaker(endpoint string, comm Comm, t *testing.T, sigMutator func([]byte msg, err = stream.Recv() assert.NoError(t, err, "%v", err) if sigMutator == nil { - assert.Equal(t, clientTLSUnique, msg.GetConn().Sig) + assert.Equal(t, extractCertificateHashFromContext(stream.Context()), msg.GetConn().Sig) } assert.Equal(t, []byte("localhost:9611"), msg.GetConn().PkiID) msg2Send := createGossipMsg() @@ -152,10 +158,10 @@ func TestHandshake(t *testing.T) { assert.Equal(t, 0, len(acceptChan)) // negative path, nothing should be read from the channel because the PKIid doesn't match the identity - mutateEndpoint := func(b []byte) []byte { + mutatePKIID := func(b []byte) []byte { return []byte("localhost:9650") } - acceptChan = handshaker("localhost:9613", comm, t, nil, mutateEndpoint) + acceptChan = handshaker("localhost:9613", comm, t, nil, mutatePKIID) time.Sleep(time.Second) assert.Equal(t, 0, len(acceptChan)) } diff --git a/gossip/comm/crypto.go b/gossip/comm/crypto.go index 3b9a99bff3c..c63b90e986d 100644 --- a/gossip/comm/crypto.go +++ b/gossip/comm/crypto.go @@ -20,13 +20,13 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "crypto/sha256" + "crypto/tls" "crypto/x509" "encoding/pem" "math/big" - "os" - - "crypto/tls" "net" + "os" "time" "golang.org/x/net/context" @@ -71,8 +71,17 @@ func generateCertificates(privKeyFile string, certKeyFile string) error { return err } -// ExtractTLSUnique extracts the TLS-Unique from the stream -func ExtractTLSUnique(ctx context.Context) []byte { +func certHashFromRawCert(rawCert []byte) []byte { + if len(rawCert) == 0 { + return nil + } + hash := sha256.New() + hash.Write(rawCert) + return hash.Sum(nil) +} + +// ExtractCertificateHash extracts the hash of the certificate from the stream +func extractCertificateHashFromContext(ctx context.Context) []byte { pr, extracted := peer.FromContext(ctx) if !extracted { return nil @@ -87,7 +96,12 @@ func ExtractTLSUnique(ctx context.Context) []byte { if !isTLSConn { return nil } - return tlsInfo.State.TLSUnique + certs := tlsInfo.State.PeerCertificates + if len(certs) == 0 { + return nil + } + raw := certs[0].Raw + return certHashFromRawCert(raw) } type authCreds struct { diff --git a/gossip/comm/crypto_test.go b/gossip/comm/crypto_test.go index ecf31ad9f97..88b9bcaaaeb 100644 --- a/gossip/comm/crypto_test.go +++ b/gossip/comm/crypto_test.go @@ -17,7 +17,6 @@ limitations under the License. package comm import ( - "bytes" "crypto/tls" "fmt" "net" @@ -34,75 +33,76 @@ import ( ) type gossipTestServer struct { - lock sync.Mutex - msgChan chan uint64 - tlsUnique []byte + lock sync.Mutex + remoteCertHash []byte + selfCertHash []byte + ll net.Listener + s *grpc.Server } -func (s *gossipTestServer) GossipStream(stream proto.Gossip_GossipStreamServer) error { - s.lock.Lock() - s.tlsUnique = ExtractTLSUnique(stream.Context()) - s.lock.Unlock() - m, err := stream.Recv() - if err != nil { - fmt.Println(err) - } else { - s.msgChan <- m.Nonce +func createTestServer(t *testing.T, cert *tls.Certificate) *gossipTestServer { + tlsConf := &tls.Config{ + Certificates: []tls.Certificate{*cert}, + ClientAuth: tls.RequestClientCert, + InsecureSkipVerify: true, } + s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConf))) + ll, err := net.Listen("tcp", fmt.Sprintf("%s:%d", "", 5611)) + assert.NoError(t, err, "%v", err) + srv := &gossipTestServer{s: s, ll: ll, selfCertHash: certHashFromRawCert(cert.Certificate[0])} + proto.RegisterGossipServer(s, srv) + go s.Serve(ll) + return srv +} + +func (s *gossipTestServer) stop() { + s.s.Stop() + s.ll.Close() +} + +func (s *gossipTestServer) GossipStream(stream proto.Gossip_GossipStreamServer) error { + s.lock.Lock() + defer s.lock.Unlock() + s.remoteCertHash = extractCertificateHashFromContext(stream.Context()) return nil } -func (s *gossipTestServer) getTLSUnique() []byte { +func (s *gossipTestServer) getClientCertHash() []byte { s.lock.Lock() defer s.lock.Unlock() - return s.tlsUnique + return s.remoteCertHash } func (s *gossipTestServer) Ping(context.Context, *proto.Empty) (*proto.Empty, error) { return &proto.Empty{}, nil } -func TestCertificateGeneration(t *testing.T) { +func TestCertificateExtraction(t *testing.T) { err := generateCertificates("key.pem", "cert.pem") - assert.NoError(t, err, "%v", err) - if err != nil { - return - } defer os.Remove("cert.pem") defer os.Remove("key.pem") - var ll net.Listener - creds, err := credentials.NewServerTLSFromFile("cert.pem", "key.pem") assert.NoError(t, err, "%v", err) - if err != nil { - return - } - s := grpc.NewServer(grpc.Creds(creds)) - ll, err = net.Listen("tcp", fmt.Sprintf("%s:%d", "", 5511)) + serverCert, err := tls.LoadX509KeyPair("cert.pem", "key.pem") assert.NoError(t, err, "%v", err) - if err != nil { - return - } - srv := &gossipTestServer{msgChan: make(chan uint64)} - proto.RegisterGossipServer(s, srv) - go s.Serve(ll) - defer func() { - s.Stop() - ll.Close() - }() - time.Sleep(time.Second * time.Duration(2)) + + srv := createTestServer(t, &serverCert) + defer srv.stop() + + generateCertificates("key2.pem", "cert2.pem") + defer os.Remove("cert2.pem") + defer os.Remove("key2.pem") + clientCert, err := tls.LoadX509KeyPair("cert2.pem", "key2.pem") + clientCertHash := certHashFromRawCert(clientCert.Certificate[0]) + assert.NoError(t, err) ta := credentials.NewTLS(&tls.Config{ + Certificates: []tls.Certificate{clientCert}, InsecureSkipVerify: true, }) assert.NoError(t, err, "%v", err) - if err != nil { - return - } - conn, err := grpc.Dial("localhost:5511", grpc.WithTransportCredentials(&authCreds{tlsCreds: ta}), grpc.WithBlock(), grpc.WithTimeout(time.Second)) + conn, err := grpc.Dial("localhost:5611", grpc.WithTransportCredentials(&authCreds{tlsCreds: ta}), grpc.WithBlock(), grpc.WithTimeout(time.Second)) assert.NoError(t, err, "%v", err) - if err != nil { - return - } + cl := proto.NewGossipClient(conn) stream, err := cl.GossipStream(context.Background()) assert.NoError(t, err, "%v", err) @@ -110,23 +110,16 @@ func TestCertificateGeneration(t *testing.T) { return } - time.Sleep(time.Duration(1) * time.Second) + time.Sleep(time.Second) + clientSideCertHash := extractCertificateHashFromContext(stream.Context()) + serverSideCertHash := srv.getClientCertHash() - clientTLSUnique := ExtractTLSUnique(stream.Context()) - serverTLSUnique := srv.getTLSUnique() + assert.NotNil(t, clientSideCertHash) + assert.NotNil(t, serverSideCertHash) - assert.NotNil(t, clientTLSUnique) - assert.NotNil(t, serverTLSUnique) + assert.Equal(t, 32, len(clientSideCertHash), "client side cert hash is %v", clientSideCertHash) + assert.Equal(t, 32, len(serverSideCertHash), "server side cert hash is %v", serverSideCertHash) - assert.True(t, bytes.Equal(clientTLSUnique, serverTLSUnique), "Client and server TLSUnique are not equal") - - msg := createGossipMsg() - stream.Send(msg) - select { - case nonce := <-srv.msgChan: - assert.Equal(t, msg.Nonce, nonce) - break - case <-time.NewTicker(time.Second).C: - assert.Fail(t, "Timed out reading from stream") - } + assert.Equal(t, clientSideCertHash, srv.selfCertHash, "Server self hash isn't equal to client side hash") + assert.Equal(t, clientCertHash, srv.remoteCertHash, "Server side and client hash aren't equal") } diff --git a/gossip/gossip/gossip.go b/gossip/gossip/gossip.go index 3606a64a672..57aad6aa1f6 100644 --- a/gossip/gossip/gossip.go +++ b/gossip/gossip/gossip.go @@ -17,6 +17,7 @@ limitations under the License. package gossip import ( + "crypto/tls" "time" "github.com/hyperledger/fabric/gossip/comm" @@ -68,4 +69,6 @@ type Config struct { PullPeerNum int PublishCertPeriod time.Duration + + TLSServerCert *tls.Certificate } diff --git a/gossip/gossip/gossip_impl.go b/gossip/gossip/gossip_impl.go index bf6648771c1..3651f38f72f 100644 --- a/gossip/gossip/gossip_impl.go +++ b/gossip/gossip/gossip_impl.go @@ -18,6 +18,7 @@ package gossip import ( "bytes" + "crypto/tls" "fmt" "sync" "sync/atomic" @@ -73,7 +74,7 @@ func NewGossipService(conf *Config, s *grpc.Server, mcs api.MessageCryptoService if s == nil { c, err = createCommWithServer(conf.BindPort, idMapper, selfIdentity) } else { - c, err = createCommWithoutServer(s, idMapper, selfIdentity, dialOpts...) + c, err = createCommWithoutServer(s, conf.TLSServerCert, idMapper, selfIdentity, dialOpts...) } if err != nil { @@ -165,8 +166,8 @@ func createCommWithServer(port int, idStore identity.Mapper, identity api.PeerId return comm.NewCommInstanceWithServer(port, idStore, identity) } -func createCommWithoutServer(s *grpc.Server, idStore identity.Mapper, identity api.PeerIdentityType, dialOpts ...grpc.DialOption) (comm.Comm, error) { - return comm.NewCommInstance(s, idStore, identity, dialOpts...) +func createCommWithoutServer(s *grpc.Server, cert *tls.Certificate, idStore identity.Mapper, identity api.PeerIdentityType, dialOpts ...grpc.DialOption) (comm.Comm, error) { + return comm.NewCommInstance(s, cert, idStore, identity, dialOpts...) } func (g *gossipServiceImpl) toDie() bool {