diff --git a/network/network_integration_test.go b/network/network_integration_test.go index baf194847e..f225767557 100644 --- a/network/network_integration_test.go +++ b/network/network_integration_test.go @@ -953,6 +953,40 @@ func TestNetworkIntegration_TLSOffloading(t *testing.T) { assert.EqualError(t, err, "rpc error: code = Unauthenticated desc = TLS client certificate authentication failed") assert.Nil(t, msg) }) + t.Run("certificate revoked/denied", func(t *testing.T) { + testDirectory := io.TestDirectory(t) + // Start server node (node1) + node1 := startNode(t, "node1", testDirectory, func(serverCfg *core.ServerConfig, cfg *Config) { + serverCfg.TLS.Offload = core.OffloadIncomingTLS + serverCfg.TLS.ClientCertHeaderName = "client-cert" + }) + + // Load client cert and add it to the denylist + clientCertBytes, err := os.ReadFile(testCertAndKeyFile) + require.NoError(t, err) + cert, err := core.ParseCertificates(clientCertBytes) + require.NoError(t, err) + pki.SetNewDenylistWithCert(t, node1.network.pkiValidator, cert[0]) + + // Create client (node2) that connects to server node + grpcConn, err := grpcLib.Dial(nameToAddress(t, "node1"), grpcLib.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer grpcConn.Close() + ctx := context.Background() + outgoingMD := metadata.MD{} + outgoingMD.Set("peerID", "client") + outgoingMD.Set("nodeDID", "did:nuts:node2") + outgoingMD.Set("client-cert", url.QueryEscape(string(clientCertBytes))) + outgoingContext := metadata.NewOutgoingContext(ctx, outgoingMD) + client := v2.NewProtocolClient(grpcConn) + result, err := client.Stream(outgoingContext) + require.NoError(t, err) + + // Assert connection is rejected + msg, err := result.Recv() + assert.EqualError(t, err, "rpc error: code = Unauthenticated desc = TLS client certificate validation failed") + assert.Nil(t, msg) + }) }) } diff --git a/network/network_test.go b/network/network_test.go index c549481ded..f07da5971a 100644 --- a/network/network_test.go +++ b/network/network_test.go @@ -213,6 +213,7 @@ func TestNetwork_Configure(t *testing.T) { ctx.protocol.EXPECT().Configure(gomock.Any()) ctx.pkiValidator.EXPECT().AddTruststore(gomock.Any()) ctx.pkiValidator.EXPECT().SetVerifyPeerCertificateFunc(gomock.Any()).Times(2) // tls.Configs: client, selfTestDialer + ctx.pkiValidator.EXPECT().SubscribeDenied(gomock.Any()) ctx.network.connectionManager = nil cfg := *core.NewServerConfig() diff --git a/network/transport/grpc/authenticator.go b/network/transport/grpc/authenticator.go index 18b9dcdb15..a968c37b19 100644 --- a/network/transport/grpc/authenticator.go +++ b/network/transport/grpc/authenticator.go @@ -19,6 +19,7 @@ package grpc import ( + "crypto/x509" "errors" "fmt" "github.com/nuts-foundation/go-did/did" @@ -83,6 +84,7 @@ func (t tlsAuthenticator) Authenticate(nodeDID did.DID, grpcPeer grpcPeer.Peer, Debug("Connection successfully authenticated") peer.NodeDID = nodeDID peer.Authenticated = true + peer.Certificate = peerCertificate return peer, nil } @@ -96,5 +98,6 @@ type dummyAuthenticator struct{} func (d dummyAuthenticator) Authenticate(nodeDID did.DID, _ grpcPeer.Peer, peer transport.Peer) (transport.Peer, error) { peer.NodeDID = nodeDID peer.Authenticated = true + peer.Certificate = &x509.Certificate{} return peer, nil } diff --git a/network/transport/grpc/authenticator_test.go b/network/transport/grpc/authenticator_test.go index bd6364173e..68c4709a93 100644 --- a/network/transport/grpc/authenticator_test.go +++ b/network/transport/grpc/authenticator_test.go @@ -54,6 +54,7 @@ func Test_tlsAuthenticator_Authenticate(t *testing.T) { expectedPeer := transport.Peer{ NodeDID: nodeDID, Authenticated: true, + Certificate: cert, } t.Run("ok", func(t *testing.T) { @@ -90,6 +91,11 @@ func Test_tlsAuthenticator_Authenticate(t *testing.T) { }, }, } + expectedPeer := transport.Peer{ + NodeDID: nodeDID, + Authenticated: true, + Certificate: wildcardCert, + } authenticatedPeer, err := authenticator.Authenticate(nodeDID, grpcPeer, transport.Peer{}) diff --git a/network/transport/grpc/connection.go b/network/transport/grpc/connection.go index 01d39cd696..8f938f5a5b 100644 --- a/network/transport/grpc/connection.go +++ b/network/transport/grpc/connection.go @@ -81,9 +81,6 @@ type Connection interface { // IsConnected returns whether the connection is active or not. IsConnected() bool - // IsProtocolConnected returns whether the given protocol is active on the connection. - IsProtocolConnected(protocol Protocol) bool - // IsAuthenticated returns whether teh given connection is authenticated. IsAuthenticated() bool @@ -152,9 +149,8 @@ func (mc *conn) waitUntilDisconnected() { mc.mux.RUnlock() return } - done := mc.ctx.Done() mc.mux.RUnlock() - <-done + <-mc.ctx.Done() } func (mc *conn) verifyOrSetPeerID(id transport.PeerID) bool { @@ -210,10 +206,10 @@ func (mc *conn) registerStream(protocol Protocol, stream Stream) bool { mc.startSending(protocol, stream) // A connection can have multiple active streams, but if one of them is closed, all of them should be closed, also closing the underlying connection. - go func(cancel func()) { + go func() { <-stream.Context().Done() - cancel() - }(mc.cancelCtx) + mc.cancelCtx() + }() return true } @@ -221,11 +217,15 @@ func (mc *conn) registerStream(protocol Protocol, stream Stream) bool { func (mc *conn) startReceiving(protocol Protocol, stream Stream) { peer := mc.Peer() // copy Peer, because it will be nil when logging after disconnecting. atomic.AddInt32(&mc.activeGoroutines, 1) - go func(activeGoroutines *int32, cancel func()) { + go func(activeGoroutines *int32) { defer atomic.AddInt32(activeGoroutines, -1) for { message := protocol.CreateEnvelope() - err := stream.RecvMsg(message) + err := stream.RecvMsg(message) // blocking + if mc.ctx.Err() != nil { + // connection has been closed: drop message and stop receiving + return + } if err != nil { errStatus, isStatusError := status.FromError(err) if errors.Is(err, io.EOF) || (isStatusError && errStatus.Code() == codes.Canceled) { @@ -241,7 +241,7 @@ func (mc *conn) startReceiving(protocol Protocol, stream Stream) { Warn("Peer connection error") } mc.status.Store(errStatus) - cancel() + mc.cancelCtx() break } @@ -255,12 +255,11 @@ func (mc *conn) startReceiving(protocol Protocol, stream Stream) { Warn("Error handling message") } } - }(&mc.activeGoroutines, mc.cancelCtx) + }(&mc.activeGoroutines) } func (mc *conn) startSending(protocol Protocol, stream Stream) { outbox := mc.outboxes[protocol.MethodName()] - done := mc.ctx.Done() atomic.AddInt32(&mc.activeGoroutines, 1) go func(activeGoroutines *int32) { @@ -268,7 +267,7 @@ func (mc *conn) startSending(protocol Protocol, stream Stream) { loop: for { select { - case <-done: + case <-mc.ctx.Done(): break loop case envelope := <-outbox: if envelope == nil { @@ -315,14 +314,6 @@ func (mc *conn) IsConnected() bool { return len(mc.streams) > 0 } -func (mc *conn) IsProtocolConnected(protocol Protocol) bool { - mc.mux.RLock() - defer mc.mux.RUnlock() - - _, ok := mc.streams[protocol.MethodName()] - return ok -} - func (mc *conn) IsAuthenticated() bool { return mc.Peer().Authenticated } diff --git a/network/transport/grpc/connection_manager.go b/network/transport/grpc/connection_manager.go index 9c42257135..77664a6ad2 100644 --- a/network/transport/grpc/connection_manager.go +++ b/network/transport/grpc/connection_manager.go @@ -20,10 +20,12 @@ package grpc import ( "context" + "crypto/x509" "errors" "fmt" "net" "sync" + "sync/atomic" "time" "github.com/nuts-foundation/go-did/did" @@ -36,7 +38,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" grpcPeer "google.golang.org/grpc/peer" "google.golang.org/grpc/status" @@ -123,7 +124,10 @@ func NewGRPCConnectionManager(config Config, connectionStore stoabs.KVStore, nod cm.addressBook = newAddressBook(connectionStore, config.backoffCreator) cm.registerPrometheusMetrics() cm.ctx, cm.ctxCancel = context.WithCancel(context.Background()) - + cm.lastCertificateValidation.Store(&time.Time{}) + if config.tlsEnabled() { + config.pkiValidator.SubscribeDenied(cm.revalidatePeers) + } return cm, nil } @@ -145,10 +149,11 @@ type grpcConnectionManager struct { addressBook *addressBook dialer - connectLoopWG sync.WaitGroup - dialOptions []grpc.DialOption - connectionTimeout time.Duration - connections *connectionList + connectLoopWG sync.WaitGroup + dialOptions []grpc.DialOption + connectionTimeout time.Duration + connections *connectionList + lastCertificateValidation atomic.Pointer[time.Time] } // newGrpcServer configures a new grpc.Server @@ -183,18 +188,6 @@ func newGrpcServer(config Config) (*grpc.Server, error) { serverInterceptors = append(serverInterceptors, ipInterceptor) serverOpts = append(serverOpts, grpc.ChainStreamInterceptor(serverInterceptors...)) - // Define the keepalive policy for the grpc server in such a way that connections are not long-lived. - // By blocking long-lived connections we ensure that connections are periodically reauthorized, namely - // so that a remote host which was authorized at the time of connection can become unauthorized and - // this is correctly enforced. - // - // Configured per https://github.com/grpc/grpc-go/blob/c9d3ea5673252d212c69f3d3c10ce1d7b287a86b/examples/features/keepalive/server/main.go#L43 - keepaliveParams := keepalive.ServerParameters{ - MaxConnectionAge: 15 * time.Minute, // If any connection is alive for too long, send a GOAWAY - MaxConnectionAgeGrace: 15 * time.Second, // Allow time for pending RPCs to complete before forcibly closing connections - } - serverOpts = append(serverOpts, grpc.KeepaliveParams(keepaliveParams)) - // Create gRPC server for inbound connectionList and associate it with the protocols return grpc.NewServer(serverOpts...), nil } @@ -401,7 +394,12 @@ func (s *grpcConnectionManager) Contacts() []transport.Contact { } func (s *grpcConnectionManager) Diagnostics() []core.DiagnosticResult { - return append(append([]core.DiagnosticResult{ownPeerIDStatistic{s.config.peerID}}, s.connections.Diagnostics()...)) + return append( + []core.DiagnosticResult{ + lastCertificateValidationStatistic{*s.lastCertificateValidation.Load()}, + ownPeerIDStatistic{s.config.peerID}, + }, + s.connections.Diagnostics()...) } // RegisterService implements grpc.ServiceRegistrar to register the gRPC services protocols expose. @@ -544,6 +542,26 @@ func (s *grpcConnectionManager) authenticate(nodeDID did.DID, peer transport.Pee return peer, nil } +// revalidatePeers verifies for all peers the x509.Certificate provided during TLS handshake is still valid. +func (s *grpcConnectionManager) revalidatePeers() { + var err error + now := nowFunc() + s.lastCertificateValidation.Store(&now) + s.connections.forEach(func(conn Connection) { + peerCert := conn.Peer().Certificate + if nowFunc().After(peerCert.NotAfter) { + log.Logger().WithError(errors.New("certificate expired while in use")).WithFields(conn.Peer().ToFields()).Info("Disconnected peer") + conn.disconnect() + return + } + err = s.config.pkiValidator.Validate([]*x509.Certificate{peerCert}) + if err != nil { + log.Logger().WithError(err).WithFields(conn.Peer().ToFields()).Warn("Disconnected peer") + conn.disconnect() + } + }) +} + func (s *grpcConnectionManager) handleInboundStream(protocol Protocol, inboundStream grpc.ServerStream) error { peerFromCtx, _ := grpcPeer.FromContext(inboundStream.Context()) log.Logger(). diff --git a/network/transport/grpc/connection_manager_test.go b/network/transport/grpc/connection_manager_test.go index ddbeba150c..da3d0d6748 100644 --- a/network/transport/grpc/connection_manager_test.go +++ b/network/transport/grpc/connection_manager_test.go @@ -29,6 +29,7 @@ import ( "hash/crc32" "io" "net" + "os" "path/filepath" "sync" "sync/atomic" @@ -142,6 +143,7 @@ func Test_grpcConnectionManager_Connect(t *testing.T) { pkiMock := pki.NewMockValidator(ctrl) pkiMock.EXPECT().AddTruststore(ts.Certificates()) pkiMock.EXPECT().SetVerifyPeerCertificateFunc(gomock.Any()) + pkiMock.EXPECT().SubscribeDenied(gomock.Any()) config, err := NewConfig("", "test", WithTLS(clientCert, ts, pkiMock)) require.NoError(t, err) @@ -569,6 +571,7 @@ func Test_grpcConnectionManager_Start(t *testing.T) { t.Run("ok - gRPC server bound, TLS enabled", func(t *testing.T) { pkiMock.EXPECT().SetVerifyPeerCertificateFunc(gomock.Any()).Times(2) + pkiMock.EXPECT().SubscribeDenied(gomock.Any()) cfg, err := NewConfig( fmt.Sprintf("127.0.0.1:%d", test.FreeTCPPort()), @@ -586,6 +589,7 @@ func Test_grpcConnectionManager_Start(t *testing.T) { t.Run("ok - gRPC server bound, incoming TLS offloaded", func(t *testing.T) { pkiMock.EXPECT().SetVerifyPeerCertificateFunc(gomock.Any()) + pkiMock.EXPECT().SubscribeDenied(gomock.Any()) cfg, err := NewConfig( fmt.Sprintf("127.0.0.1:%d", test.FreeTCPPort()), @@ -632,6 +636,7 @@ func Test_grpcConnectionManager_Start(t *testing.T) { return nil } }).Times(2) // on inbound and outbound TLS config + pkiMock.EXPECT().SubscribeDenied(gomock.Any()) cfg, err := NewConfig(fmt.Sprintf(":%d", test.FreeTCPPort()), "peerID", WithTLS(serverCert, &core.TrustStore{CertPool: x509.NewCertPool()}, pkiMock)) require.NoError(t, err) @@ -644,6 +649,7 @@ func Test_grpcConnectionManager_Start(t *testing.T) { t.Run("error - invalid server TLS config", func(t *testing.T) { cfg, err := NewConfig(fmt.Sprintf(":%d", test.FreeTCPPort()), "peerID", WithTLS(serverCert, &core.TrustStore{CertPool: x509.NewCertPool()}, pkiMock)) pkiMock.EXPECT().SetVerifyPeerCertificateFunc(gomock.Any()) + pkiMock.EXPECT().SubscribeDenied(gomock.Any()) cm, err := NewGRPCConnectionManager(cfg, nil, *nodeDID, nil, &TestProtocol{}) require.NoError(t, err) @@ -699,16 +705,25 @@ func Test_grpcConnectionManager_Stop(t *testing.T) { func Test_grpcConnectionManager_Diagnostics(t *testing.T) { const peerID = "server-peer-id" + testTime := time.Now() t.Run("no peers", func(t *testing.T) { cm, err := NewGRPCConnectionManager(Config{peerID: peerID}, nil, *nodeDID, nil) require.NoError(t, err) defer cm.Stop() - assert.Equal(t, "0", cm.Diagnostics()[1].String()) // assert number_of_peers + cm.lastCertificateValidation.Store(&testTime) + + diag := cm.Diagnostics() + + require.Len(t, diag, 4) + assert.Equal(t, testTime.String(), diag[0].String()) // assert certificates_last_validated + assert.Equal(t, peerID, diag[1].String()) // assert peer_id + assert.Equal(t, "0", diag[2].String()) // assert number_of_peers }) t.Run("with peers", func(t *testing.T) { cm, err := NewGRPCConnectionManager(Config{peerID: peerID}, nil, *nodeDID, nil) require.NoError(t, err) defer cm.Stop() + cm.lastCertificateValidation.Store(&testTime) go cm.handleInboundStream(&TestProtocol{}, newServerStream("peer1", "")) go cm.handleInboundStream(&TestProtocol{}, newServerStream("peer2", "")) @@ -717,8 +732,13 @@ func Test_grpcConnectionManager_Diagnostics(t *testing.T) { return len(cm.Peers()) == 2, nil }, 5*time.Second, "time-out while waiting for peers to connect") - assert.Equal(t, "2", cm.Diagnostics()[1].String()) // assert number_of_peers - assert.Equal(t, "peer2@127.0.0.1:1028 peer1@127.0.0.1:6718", cm.Diagnostics()[2].String()) // assert peers + diag := cm.Diagnostics() + + require.Len(t, diag, 4) + assert.Equal(t, testTime.String(), diag[0].String()) // assert certificates_last_validated + assert.Equal(t, peerID, diag[1].String()) // assert peer_id + assert.Equal(t, "2", diag[2].String()) // assert number_of_peers + assert.Equal(t, "peer2@127.0.0.1:1028 peer1@127.0.0.1:6718", diag[3].String()) // assert peers }) } @@ -1156,6 +1176,63 @@ func Test_grpcConnectionManager_handleInboundStream(t *testing.T) { }) } +func Test_grpcConnectionManager_revalidatePeers(t *testing.T) { + mockValidator := pki.NewMockValidator(gomock.NewController(t)) + clientCertBytes, err := os.ReadFile(testCertAndKeyFile) + require.NoError(t, err) + certs, err := core.ParseCertificates(clientCertBytes) + cert := certs[0] + require.NoError(t, err) + + t.Run("ok", func(t *testing.T) { + mockValidator.EXPECT().Validate([]*x509.Certificate{cert}) + cm, err := NewGRPCConnectionManager(Config{pkiValidator: mockValidator}, nil, *nodeDID, nil) + require.NoError(t, err) + connection := NewStubConnection(transport.Peer{Certificate: cert}) + cm.connections.list = append(cm.connections.list, connection) + + cm.revalidatePeers() + + assert.Equal(t, 0, connection.disconnectCalls) + }) + t.Run("denied", func(t *testing.T) { + mockValidator.EXPECT().Validate([]*x509.Certificate{cert}).Return(pki.ErrCertBanned) + cm, err := NewGRPCConnectionManager(Config{pkiValidator: mockValidator}, nil, *nodeDID, nil) + require.NoError(t, err) + connection := NewStubConnection(transport.Peer{Certificate: cert}) + cm.connections.list = append(cm.connections.list, connection) + + cm.revalidatePeers() + + assert.Equal(t, 1, connection.disconnectCalls) + }) + t.Run("denied multiple", func(t *testing.T) { + mockValidator.EXPECT().Validate([]*x509.Certificate{cert}).Return(pki.ErrCertBanned).Times(3) + cm, err := NewGRPCConnectionManager(Config{pkiValidator: mockValidator}, nil, *nodeDID, nil) + require.NoError(t, err) + connection := NewStubConnection(transport.Peer{Certificate: cert}) + cm.connections.list = append(cm.connections.list, connection, connection, connection) + + cm.revalidatePeers() + + assert.Equal(t, 3, connection.disconnectCalls) + }) + t.Run("expired", func(t *testing.T) { + nowFunc = func() time.Time { + return time.Now().AddDate(50, 0, 0) + } + defer func() { nowFunc = time.Now }() + cm, err := NewGRPCConnectionManager(Config{pkiValidator: mockValidator}, nil, *nodeDID, nil) + require.NoError(t, err) + connection := NewStubConnection(transport.Peer{Certificate: cert}) + cm.connections.list = append(cm.connections.list, connection) + + cm.revalidatePeers() + + assert.Equal(t, 1, connection.disconnectCalls) + }) +} + func newServerStream(clientPeerID transport.PeerID, nodeDID string) *stubServerStream { md := metadata.New(map[string]string{peerIDHeader: clientPeerID.String()}) if nodeDID != "" { diff --git a/network/transport/grpc/connection_mock.go b/network/transport/grpc/connection_mock.go index 37ba52a7e6..3e82610be1 100644 --- a/network/transport/grpc/connection_mock.go +++ b/network/transport/grpc/connection_mock.go @@ -63,20 +63,6 @@ func (mr *MockConnectionMockRecorder) IsConnected() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockConnection)(nil).IsConnected)) } -// IsProtocolConnected mocks base method. -func (m *MockConnection) IsProtocolConnected(protocol Protocol) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsProtocolConnected", protocol) - ret0, _ := ret[0].(bool) - return ret0 -} - -// IsProtocolConnected indicates an expected call of IsProtocolConnected. -func (mr *MockConnectionMockRecorder) IsProtocolConnected(protocol interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsProtocolConnected", reflect.TypeOf((*MockConnection)(nil).IsProtocolConnected), protocol) -} - // Peer mocks base method. func (m *MockConnection) Peer() transport.Peer { m.ctrl.T.Helper() diff --git a/network/transport/grpc/connection_test.go b/network/transport/grpc/connection_test.go index 7487f3a084..60ce6248a4 100644 --- a/network/transport/grpc/connection_test.go +++ b/network/transport/grpc/connection_test.go @@ -55,19 +55,6 @@ func Test_conn_disconnect(t *testing.T) { }) } -func Test_conn_IsProtocolConnected(t *testing.T) { - p := &TestProtocol{} - t.Run("not connected", func(t *testing.T) { - conn := createConnection(context.Background(), transport.Peer{}) - assert.False(t, conn.IsProtocolConnected(p)) - }) - t.Run("connected", func(t *testing.T) { - conn := createConnection(context.Background(), transport.Peer{}).(*conn) - conn.streams[p.MethodName()] = &MockStream{} - assert.True(t, conn.IsProtocolConnected(p)) - }) -} - func Test_conn_waitUntilDisconnected(t *testing.T) { t.Run("never open, should return immediately", func(t *testing.T) { conn := createConnection(context.Background(), transport.Peer{}) @@ -118,10 +105,7 @@ func Test_conn_registerStream(t *testing.T) { } func Test_conn_startSending(t *testing.T) { - t.Run("disconnect causes panic in startSending", func(t *testing.T) { - // startSending reads from the outbox channel, which is closed when disconnect() is called. Closing the channel - // causes startSending to read a nil message from the channel, which causes a panic. - // If the message to be sent is nil, it indicates the connection is closing and the loop should exit. + t.Run("disconnect does not panic", func(t *testing.T) { connection := createConnection(context.Background(), transport.Peer{}).(*conn) stream := newServerStream("foo", "") @@ -139,8 +123,8 @@ func Test_conn_startSending(t *testing.T) { return atomic.LoadInt32(&connection.activeGoroutines) == 0, nil }, 5*time.Second, "waiting for all goroutines to exit") - // err status is set on connection. Due to EOF it's an unknown error - assert.Equal(t, codes.Unknown, connection.status.Load().Code()) + // Last received message is dropped and no status is set. Default value is OK. + assert.Equal(t, codes.OK, connection.status.Load().Code()) }) } diff --git a/network/transport/grpc/stats.go b/network/transport/grpc/stats.go index b0602ad44c..771be6e35a 100644 --- a/network/transport/grpc/stats.go +++ b/network/transport/grpc/stats.go @@ -25,6 +25,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "sort" "strings" + "time" ) // numberOfPeersStatistic contains node's number of peers it's connected to. @@ -92,6 +93,25 @@ func (o ownPeerIDStatistic) String() string { return o.peerID.String() } +// lastCertificateValidationStatistic contains the timestamp of the most recent certificate validation of all peers. +type lastCertificateValidationStatistic struct { + lastCheck time.Time +} + +func (o lastCertificateValidationStatistic) Result() interface{} { + return o.lastCheck +} + +// Name returns the name of the statistic. +func (o lastCertificateValidationStatistic) Name() string { + return "certificates_last_validated" +} + +// String returns the statistic as string. +func (o lastCertificateValidationStatistic) String() string { + return o.lastCheck.String() +} + type prometheusStreamWrapper struct { stream Stream protocol Protocol diff --git a/network/transport/grpc/test.go b/network/transport/grpc/test.go index bc77fd5a8f..8fb624747c 100644 --- a/network/transport/grpc/test.go +++ b/network/transport/grpc/test.go @@ -19,6 +19,7 @@ package grpc import ( + "crypto/x509" "github.com/nuts-foundation/go-did/did" "github.com/nuts-foundation/nuts-node/network/transport" "google.golang.org/grpc/status" @@ -69,12 +70,14 @@ var _ Connection = (*StubConnection)(nil) // StubConnection is a stub implementation of the Connection interface type StubConnection struct { - Open bool - NodeDID did.DID - SentMsgs []interface{} - PeerID transport.PeerID - Authenticated bool - Address string + Open bool + NodeDID did.DID + SentMsgs []interface{} + PeerID transport.PeerID + Authenticated bool + Address string + Certificate *x509.Certificate + disconnectCalls int } func NewStubConnection(peer transport.Peer) *StubConnection { @@ -83,7 +86,8 @@ func NewStubConnection(peer transport.Peer) *StubConnection { PeerID: peer.ID, NodeDID: peer.NodeDID, Authenticated: peer.Authenticated, - Address: peer.Address} + Address: peer.Address, + Certificate: peer.Certificate} } func (s *StubConnection) ID() transport.PeerID { @@ -104,6 +108,7 @@ func (s *StubConnection) Peer() transport.Peer { NodeDID: s.NodeDID, Authenticated: s.Authenticated, Address: s.Address, + Certificate: s.Certificate, } } @@ -131,7 +136,7 @@ func (s *StubConnection) SetErrorStatus(_ *status.Status) { } func (s *StubConnection) disconnect() { - panic("implement me") + s.disconnectCalls++ } func (s *StubConnection) waitUntilDisconnected() { diff --git a/network/transport/grpc/tls_offloading.go b/network/transport/grpc/tls_offloading.go index 7ced803609..89223b87c1 100644 --- a/network/transport/grpc/tls_offloading.go +++ b/network/transport/grpc/tls_offloading.go @@ -43,16 +43,15 @@ func newAuthenticationInterceptor(clientCertHeaderName string, pkiValidator pki. return (&tlsOffloadingAuthenticator{clientCertHeaderName: clientCertHeaderName, pkiValidator: pkiValidator}).intercept } +// tlsOffloadingAuthenticator get the TLS certificate from the 'clientCertHeaderName' header and set it on the grpc.peer. type tlsOffloadingAuthenticator struct { clientCertHeaderName string pkiValidator pki.Validator } func (t *tlsOffloadingAuthenticator) intercept(srv interface{}, serverStream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + // Get certificate from header certificates, err := t.authenticate(serverStream) - if err == nil { - err = t.pkiValidator.Validate(certificates) - } if err != nil { log.Logger(). WithError(err). @@ -60,6 +59,14 @@ func (t *tlsOffloadingAuthenticator) intercept(srv interface{}, serverStream grp return status.Error(codes.Unauthenticated, "TLS client certificate authentication failed") } + // Validate revocation/deny list status + if err = t.pkiValidator.Validate(certificates); err != nil { + log.Logger(). + WithError(err). + Warnf("Validation of offloaded TLS certificate failed") + return status.Error(codes.Unauthenticated, "TLS client certificate validation failed") + } + // Build TLS info and override in Peer info, which is set on the incoming context peerInfo, _ := peer.FromContext(serverStream.Context()) if peerInfo == nil { diff --git a/network/transport/types.go b/network/transport/types.go index 2f1876f4b1..266df3443f 100644 --- a/network/transport/types.go +++ b/network/transport/types.go @@ -19,6 +19,7 @@ package transport import ( + "crypto/x509" "encoding/json" "errors" "fmt" @@ -51,6 +52,8 @@ type Peer struct { NodeDID did.DID `json:"nodedid"` // Authenticated is true when NodeDID is set and authentication is successful. Authenticated bool `json:"authenticated"` + // Certificate presented by peer during TLS handshake. + Certificate *x509.Certificate `json:"-" yaml:"-"` } // ToFields returns the peer as a map of fields, to be used when logging the peer details. diff --git a/pki/denylist.go b/pki/denylist.go index e645b5e09f..62a96a9104 100644 --- a/pki/denylist.go +++ b/pki/denylist.go @@ -22,7 +22,6 @@ package pki import ( "crypto/x509" "encoding/json" - "errors" "fmt" "io" "net/http" @@ -35,29 +34,6 @@ import ( "github.com/lestrrat-go/jwx/jws" ) -var ( - // ErrDenylistMissing occurs when the denylist cannot be downloaded - ErrDenylistMissing = errors.New("denylist cannot be retrieved") - - // ErrCertBanned means the certificate was banned by a denylist rather than revoked by a CRL - ErrCertBanned = errors.New("certificate is banned") -) - -// Denylist implements a global certificate rejection -type Denylist interface { - // LastUpdated provides the time at which the denylist was last retrieved - LastUpdated() time.Time - - // Update fetches a new copy of the denylist - Update() error - - // URL returns the URL of the denylist - URL() string - - // ValidateCert returns an error if a certificate should not be used - ValidateCert(cert *x509.Certificate) error -} - // denylistImpl implements arbitrary certificate rejection using issuer and serial number tuples type denylistImpl struct { // url specifies the URL where the denylist is downloaded @@ -71,6 +47,9 @@ type denylistImpl struct { // lastUpdated contains the time the certificate was last updated lastUpdated time.Time + + // subscribers for denylist updates + subscribers []func() } // denylistEntry contains parameters for an X.509 certificate that must not be accepted for TLS connections @@ -224,10 +203,19 @@ func (b *denylistImpl) Update() error { // Log when the denylist is updated logger().Debug("Denylist updated successfully") + // Notify all subscribers synchronously + for _, sub := range b.subscribers { + sub() + } + // Return a nil error as the denylist was successfully updated return nil } +func (b *denylistImpl) Subscribe(f func()) { + b.subscribers = append(b.subscribers, f) +} + // download retrieves and parses the denylist func (b *denylistImpl) download() ([]byte, error) { // Make an HTTP GET request for the denylist URL diff --git a/pki/denylist_test.go b/pki/denylist_test.go index 525f7fb567..1ec55f2425 100644 --- a/pki/denylist_test.go +++ b/pki/denylist_test.go @@ -288,9 +288,14 @@ func TestUpdateValidDenylist(t *testing.T) { // Ensure the new denylist update time is zero assert.True(t, denylist.LastUpdated().IsZero()) + // Ensure that subscribers are notified + var iscalled bool + denylist.Subscribe(func() { iscalled = true }) + // Update the denylist data and ensure there are no errors err = denylist.Update() require.NoError(t, err) + assert.True(t, iscalled) // Ensure the entries are present as expected in the denylist structure entriesPtr := denylist.(*denylistImpl).entries.Load() @@ -465,4 +470,3 @@ func TestRSACertificateJWKThumbprint(t *testing.T) { keyID := certKeyJWKThumbprint(cert) assert.Equal(t, "PVOjk-5d4Lb-FGxurW-fNMUv3rYZZBWF3gGaP5s1UVQ", keyID) } - diff --git a/pki/interface.go b/pki/interface.go index 894200df0b..95fd2c026b 100644 --- a/pki/interface.go +++ b/pki/interface.go @@ -23,6 +23,7 @@ import ( "crypto/x509" "errors" "github.com/nuts-foundation/nuts-node/core" + "time" ) // errors @@ -31,8 +32,31 @@ var ( ErrCRLExpired = errors.New("crl has expired") ErrCertRevoked = errors.New("certificate is revoked") ErrCertUntrusted = errors.New("certificate's issuer is not trusted") + // ErrDenylistMissing occurs when the denylist cannot be downloaded + ErrDenylistMissing = errors.New("denylist cannot be retrieved") + + // ErrCertBanned means the certificate was banned by a denylist rather than revoked by a CRL + ErrCertBanned = errors.New("certificate is banned") ) +// Denylist implements a global certificate rejection +type Denylist interface { + // LastUpdated provides the time at which the denylist was last retrieved + LastUpdated() time.Time + + // Update fetches a new copy of the denylist + Update() error + + // URL returns the URL of the denylist + URL() string + + // ValidateCert returns an error if a certificate should not be used + ValidateCert(cert *x509.Certificate) error + + // Subscribe registers a callback that is triggered everytime the denylist is updated + Subscribe(f func()) +} + type Validator interface { // Validate returns an error if any of the certificates in the chain has been revoked, or if the request cannot be processed. // ErrCertRevoked and ErrCertUntrusted indicate that at least one of the certificates is revoked, or signed by a CA that is not in the truststore. @@ -48,6 +72,10 @@ type Validator interface { // AddTruststore adds all CAs to the truststore for validation of CRL signatures. It also adds all CRL Distribution Endpoints found in the chain. // CRL Distribution Points encountered during operation, such as on end user certificates, are only added to the monitored CRLs if their issuer is in the truststore. AddTruststore(chain []*x509.Certificate) error + + // SubscribeDenied registers a callback that is triggered everytime the denylist is updated. + // This can be used to revalidate all certificates on long-lasting connections by calling Validate on them again. + SubscribeDenied(f func()) } // Provider is an interface for providing PKI services (e.g. TLS configuration, certificate validation). diff --git a/pki/mock.go b/pki/mock.go index 6b6b243f4a..cfd13b1414 100644 --- a/pki/mock.go +++ b/pki/mock.go @@ -8,11 +8,103 @@ import ( tls "crypto/tls" x509 "crypto/x509" reflect "reflect" + time "time" gomock "github.com/golang/mock/gomock" core "github.com/nuts-foundation/nuts-node/core" ) +// MockDenylist is a mock of Denylist interface. +type MockDenylist struct { + ctrl *gomock.Controller + recorder *MockDenylistMockRecorder +} + +// MockDenylistMockRecorder is the mock recorder for MockDenylist. +type MockDenylistMockRecorder struct { + mock *MockDenylist +} + +// NewMockDenylist creates a new mock instance. +func NewMockDenylist(ctrl *gomock.Controller) *MockDenylist { + mock := &MockDenylist{ctrl: ctrl} + mock.recorder = &MockDenylistMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDenylist) EXPECT() *MockDenylistMockRecorder { + return m.recorder +} + +// LastUpdated mocks base method. +func (m *MockDenylist) LastUpdated() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LastUpdated") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// LastUpdated indicates an expected call of LastUpdated. +func (mr *MockDenylistMockRecorder) LastUpdated() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastUpdated", reflect.TypeOf((*MockDenylist)(nil).LastUpdated)) +} + +// Subscribe mocks base method. +func (m *MockDenylist) Subscribe(f func()) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Subscribe", f) +} + +// Subscribe indicates an expected call of Subscribe. +func (mr *MockDenylistMockRecorder) Subscribe(f interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockDenylist)(nil).Subscribe), f) +} + +// URL mocks base method. +func (m *MockDenylist) URL() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "URL") + ret0, _ := ret[0].(string) + return ret0 +} + +// URL indicates an expected call of URL. +func (mr *MockDenylistMockRecorder) URL() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "URL", reflect.TypeOf((*MockDenylist)(nil).URL)) +} + +// Update mocks base method. +func (m *MockDenylist) Update() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Update") + ret0, _ := ret[0].(error) + return ret0 +} + +// Update indicates an expected call of Update. +func (mr *MockDenylistMockRecorder) Update() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDenylist)(nil).Update)) +} + +// ValidateCert mocks base method. +func (m *MockDenylist) ValidateCert(cert *x509.Certificate) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateCert", cert) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateCert indicates an expected call of ValidateCert. +func (mr *MockDenylistMockRecorder) ValidateCert(cert interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateCert", reflect.TypeOf((*MockDenylist)(nil).ValidateCert), cert) +} + // MockValidator is a mock of Validator interface. type MockValidator struct { ctrl *gomock.Controller @@ -64,6 +156,18 @@ func (mr *MockValidatorMockRecorder) SetVerifyPeerCertificateFunc(config interfa return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetVerifyPeerCertificateFunc", reflect.TypeOf((*MockValidator)(nil).SetVerifyPeerCertificateFunc), config) } +// SubscribeDenied mocks base method. +func (m *MockValidator) SubscribeDenied(f func()) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SubscribeDenied", f) +} + +// SubscribeDenied indicates an expected call of SubscribeDenied. +func (mr *MockValidatorMockRecorder) SubscribeDenied(f interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeDenied", reflect.TypeOf((*MockValidator)(nil).SubscribeDenied), f) +} + // Validate mocks base method. func (m *MockValidator) Validate(chain []*x509.Certificate) error { m.ctrl.T.Helper() @@ -144,6 +248,18 @@ func (mr *MockProviderMockRecorder) SetVerifyPeerCertificateFunc(config interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetVerifyPeerCertificateFunc", reflect.TypeOf((*MockProvider)(nil).SetVerifyPeerCertificateFunc), config) } +// SubscribeDenied mocks base method. +func (m *MockProvider) SubscribeDenied(f func()) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SubscribeDenied", f) +} + +// SubscribeDenied indicates an expected call of SubscribeDenied. +func (mr *MockProviderMockRecorder) SubscribeDenied(f interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeDenied", reflect.TypeOf((*MockProvider)(nil).SubscribeDenied), f) +} + // Validate mocks base method. func (m *MockProvider) Validate(chain []*x509.Certificate) error { m.ctrl.T.Helper() diff --git a/pki/test.go b/pki/test.go new file mode 100644 index 0000000000..9459c0a71b --- /dev/null +++ b/pki/test.go @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2023 Nuts community + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +package pki + +import ( + "crypto/x509" + "testing" + "time" +) + +// SetNewDenylistWithCert sets a new Denylist on the Validator and adds the certificate. +// This is useful in integrations tests etc. +func SetNewDenylistWithCert(t *testing.T, val Validator, cert *x509.Certificate) { + dl := &denylistImpl{ + url: "some-url", + lastUpdated: time.Now(), + } + dl.entries.Store(&[]denylistEntry{ + { + Issuer: cert.Issuer.String(), + SerialNumber: cert.SerialNumber.String(), + JWKThumbprint: certKeyJWKThumbprint(cert), + Reason: `testing purposes`, + }, + }) + switch v := val.(type) { + case *PKI: + v.denylist = dl + case *validator: + v.denylist = dl + default: + t.Fatal("cannot set Denylist on val") + } +} diff --git a/pki/validator.go b/pki/validator.go index 5e2cc900fe..cf0936c86b 100644 --- a/pki/validator.go +++ b/pki/validator.go @@ -258,6 +258,10 @@ func (v *validator) AddTruststore(chain []*x509.Certificate) error { return nil } +func (v *validator) SubscribeDenied(f func()) { + v.denylist.Subscribe(f) +} + func (v *validator) getCert(subject string) (*x509.Certificate, bool) { issuer, ok := v.truststore.Load(subject) if !ok { diff --git a/pki/validator_test.go b/pki/validator_test.go index f8d61fc57e..528b1f0a31 100644 --- a/pki/validator_test.go +++ b/pki/validator_test.go @@ -27,6 +27,7 @@ import ( "crypto/x509" "encoding/pem" "errors" + "github.com/golang/mock/gomock" "github.com/nuts-foundation/nuts-node/core" "go.uber.org/goleak" "math/big" @@ -223,6 +224,17 @@ func TestValidator_AddTruststore(t *testing.T) { }) } +func TestValidator_SubscribeDenied(t *testing.T) { + mockDenylist := NewMockDenylist(gomock.NewController(t)) + mockDenylist.EXPECT().Subscribe(gomock.Any()) + + val, err := newValidator(DefaultConfig()) + require.NoError(t, err) + val.denylist = mockDenylist + + val.SubscribeDenied(func() { _ = "functions handles cannot be tested for equality" }) +} + func Test_NewValidator(t *testing.T) { cfg := DefaultConfig()