diff --git a/ssh/handshake.go b/ssh/handshake.go index 49bbba7692..56cdc7c21c 100644 --- a/ssh/handshake.go +++ b/ssh/handshake.go @@ -35,6 +35,16 @@ type keyingTransport interface { // direction will be effected if a msgNewKeys message is sent // or received. prepareKeyChange(*algorithms, *kexResult) error + + // setStrictMode sets the strict KEX mode, notably triggering + // sequence number resets on sending or receiving msgNewKeys. + // If the sequence number is already > 1 when setStrictMode + // is called, an error is returned. + setStrictMode() error + + // setInitialKEXDone indicates to the transport that the initial key exchange + // was completed + setInitialKEXDone() } // handshakeTransport implements rekeying on top of a keyingTransport @@ -100,6 +110,10 @@ type handshakeTransport struct { // The session ID or nil if first kex did not complete yet. sessionID []byte + + // strictMode indicates if the other side of the handshake indicated + // that we should be following the strict KEX protocol restrictions. + strictMode bool } type pendingKex struct { @@ -209,7 +223,10 @@ func (t *handshakeTransport) readLoop() { close(t.incoming) break } - if p[0] == msgIgnore || p[0] == msgDebug { + // If this is the first kex, and strict KEX mode is enabled, + // we don't ignore any messages, as they may be used to manipulate + // the packet sequence numbers. + if !(t.sessionID == nil && t.strictMode) && (p[0] == msgIgnore || p[0] == msgDebug) { continue } t.incoming <- p @@ -441,6 +458,11 @@ func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) { return successPacket, nil } +const ( + kexStrictClient = "kex-strict-c-v00@openssh.com" + kexStrictServer = "kex-strict-s-v00@openssh.com" +) + // sendKexInit sends a key change message. func (t *handshakeTransport) sendKexInit() error { t.mu.Lock() @@ -454,7 +476,6 @@ func (t *handshakeTransport) sendKexInit() error { } msg := &kexInitMsg{ - KexAlgos: t.config.KeyExchanges, CiphersClientServer: t.config.Ciphers, CiphersServerClient: t.config.Ciphers, MACsClientServer: t.config.MACs, @@ -464,6 +485,13 @@ func (t *handshakeTransport) sendKexInit() error { } io.ReadFull(rand.Reader, msg.Cookie[:]) + // We mutate the KexAlgos slice, in order to add the kex-strict extension algorithm, + // and possibly to add the ext-info extension algorithm. Since the slice may be the + // user owned KeyExchanges, we create our own slice in order to avoid using user + // owned memory by mistake. + msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+2) // room for kex-strict and ext-info + msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...) + isServer := len(t.hostKeys) > 0 if isServer { for _, k := range t.hostKeys { @@ -488,17 +516,24 @@ func (t *handshakeTransport) sendKexInit() error { msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat) } } + + if t.sessionID == nil { + msg.KexAlgos = append(msg.KexAlgos, kexStrictServer) + } } else { msg.ServerHostKeyAlgos = t.hostKeyAlgorithms // As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what // algorithms the server supports for public key authentication. See RFC // 8308, Section 2.1. + // + // We also send the strict KEX mode extension algorithm, in order to opt + // into the strict KEX mode. if firstKeyExchange := t.sessionID == nil; firstKeyExchange { - msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+1) - msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...) msg.KexAlgos = append(msg.KexAlgos, "ext-info-c") + msg.KexAlgos = append(msg.KexAlgos, kexStrictClient) } + } packet := Marshal(msg) @@ -604,6 +639,13 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { return err } + if t.sessionID == nil && ((isClient && contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && contains(clientInit.KexAlgos, kexStrictClient))) { + t.strictMode = true + if err := t.conn.setStrictMode(); err != nil { + return err + } + } + // We don't send FirstKexFollows, but we handle receiving it. // // RFC 4253 section 7 defines the kex and the agreement method for @@ -679,6 +721,12 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { return unexpectedMessageError(msgNewKeys, packet[0]) } + if firstKeyExchange { + // Indicates to the transport that the first key exchange is completed + // after receiving SSH_MSG_NEWKEYS. + t.conn.setInitialKEXDone() + } + return nil } diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go index 65afc2059e..2bc607b649 100644 --- a/ssh/handshake_test.go +++ b/ssh/handshake_test.go @@ -395,6 +395,10 @@ func (n *errorKeyingTransport) readPacket() ([]byte, error) { return n.packetConn.readPacket() } +func (n *errorKeyingTransport) setStrictMode() error { return nil } + +func (n *errorKeyingTransport) setInitialKEXDone() {} + func TestHandshakeErrorHandlingRead(t *testing.T) { for i := 0; i < 20; i++ { testHandshakeErrorHandlingN(t, i, -1, false) @@ -710,3 +714,308 @@ func TestPickIncompatibleHostKeyAlgo(t *testing.T) { t.Fatal("incompatible signer returned") } } + +func TestStrictKEXResetSeqFirstKEX(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + } + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + <-checker.called + + t.Cleanup(func() { + trC.Close() + trS.Close() + }) + + // Throw away the msgExtInfo packet sent during the handshake by the server + _, err = trC.readPacket() + if err != nil { + t.Fatalf("readPacket failed: %s", err) + } + + // close the handshake transports before checking the sequence number to + // avoid races. + trC.Close() + trS.Close() + + // check that the sequence number counters. We reset after msgNewKeys, but + // then the server immediately writes msgExtInfo, and we close the + // transports so we expect read 2, write 0 on the client and read 1, write 1 + // on the server. + if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 || + trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 { + t.Errorf( + "unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)", + trC.conn.(*transport).reader.seqNum, + trC.conn.(*transport).writer.seqNum, + trS.conn.(*transport).reader.seqNum, + trS.conn.(*transport).writer.seqNum, + ) + } +} + +func TestStrictKEXResetSeqSuccessiveKEX(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + } + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + <-checker.called + + t.Cleanup(func() { + trC.Close() + trS.Close() + }) + + // Throw away the msgExtInfo packet sent during the handshake by the server + _, err = trC.readPacket() + if err != nil { + t.Fatalf("readPacket failed: %s", err) + } + + // write and read five packets on either side to bump the sequence numbers + for i := 0; i < 5; i++ { + if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trS.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + } + + // Request a key exchange, which should cause the sequence numbers to reset + checker.waitCall <- 1 + trC.requestKeyExchange() + <-checker.called + + // write a packet on the client, and then read it, to verify the key change has actually happened, since + // the HostKeyCallback is called _during_ the handshake, so isn't actually indicative of the handshake + // finishing. + dummyPacket := []byte{99} + if err := trS.writePacket(dummyPacket); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if p, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } else if !bytes.Equal(p, dummyPacket) { + t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket) + } + + // close the handshake transports before checking the sequence number to + // avoid races. + trC.Close() + trS.Close() + + if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 || + trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 { + t.Errorf( + "unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)", + trC.conn.(*transport).reader.seqNum, + trC.conn.(*transport).writer.seqNum, + trS.conn.(*transport).reader.seqNum, + trS.conn.(*transport).writer.seqNum, + ) + } +} + +func TestSeqNumIncrease(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + } + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + <-checker.called + + t.Cleanup(func() { + trC.Close() + trS.Close() + }) + + // Throw away the msgExtInfo packet sent during the handshake by the server + _, err = trC.readPacket() + if err != nil { + t.Fatalf("readPacket failed: %s", err) + } + + // write and read five packets on either side to bump the sequence numbers + for i := 0; i < 5; i++ { + if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trS.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + } + + // close the handshake transports before checking the sequence number to + // avoid races. + trC.Close() + trS.Close() + + if trC.conn.(*transport).reader.seqNum != 7 || trC.conn.(*transport).writer.seqNum != 5 || + trS.conn.(*transport).reader.seqNum != 6 || trS.conn.(*transport).writer.seqNum != 6 { + t.Errorf( + "unexpected sequence counters:\nclient: reader %d (expected 7), writer %d (expected 5)\nserver: reader %d (expected 6), writer %d (expected 6)", + trC.conn.(*transport).reader.seqNum, + trC.conn.(*transport).writer.seqNum, + trS.conn.(*transport).reader.seqNum, + trS.conn.(*transport).writer.seqNum, + ) + } +} + +func TestStrictKEXUnexpectedMsg(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + // Check that unexpected messages during the handshake cause failure + _, _, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", true) + if err == nil { + t.Fatal("handshake should fail when there are unexpected messages during the handshake") + } + + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", false) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + + // Check that ignore/debug pacekts are still ignored outside of the handshake + if err := trC.writePacket([]byte{msgIgnore}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if err := trC.writePacket([]byte{msgDebug}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + dummyPacket := []byte{99} + if err := trC.writePacket(dummyPacket); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + + if p, err := trS.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } else if !bytes.Equal(p, dummyPacket) { + t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket) + } +} + +func TestStrictKEXMixed(t *testing.T) { + // Test that we still support a mixed connection, where one side sends kex-strict but the other + // side doesn't. + + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe failed: %s", err) + } + + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + trS = newTransport(b, rand.Reader, false) + trS = addNoiseTransport(trS) + + clientConf := &ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }} + clientConf.SetDefaults() + + v := []byte("version") + client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr()) + + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.SetDefaults() + + transport := newHandshakeTransport(trS, &serverConf.Config, []byte("version"), []byte("version")) + transport.hostKeys = serverConf.hostKeys + transport.publicKeyAuthAlgorithms = serverConf.PublicKeyAuthAlgorithms + + readOneFailure := make(chan error, 1) + go func() { + if _, err := transport.readOnePacket(true); err != nil { + readOneFailure <- err + } + }() + + // Basically sendKexInit, but without the kex-strict extension algorithm + msg := &kexInitMsg{ + KexAlgos: transport.config.KeyExchanges, + CiphersClientServer: transport.config.Ciphers, + CiphersServerClient: transport.config.Ciphers, + MACsClientServer: transport.config.MACs, + MACsServerClient: transport.config.MACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + ServerHostKeyAlgos: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA}, + } + packet := Marshal(msg) + // writePacket destroys the contents, so save a copy. + packetCopy := make([]byte, len(packet)) + copy(packetCopy, packet) + if err := transport.pushPacket(packetCopy); err != nil { + t.Fatalf("pushPacket: %s", err) + } + transport.sentInitMsg = msg + transport.sentInitPacket = packet + + if err := transport.getWriteError(); err != nil { + t.Fatalf("getWriteError failed: %s", err) + } + var request *pendingKex + select { + case err = <-readOneFailure: + t.Fatalf("server readOnePacket failed: %s", err) + case request = <-transport.startKex: + break + } + + // We expect the following calls to fail if the side which does not support + // kex-strict sends unexpected/ignored packets during the handshake, even if + // the other side does support kex-strict. + + if err := transport.enterKeyExchange(request.otherInit); err != nil { + t.Fatalf("enterKeyExchange failed: %s", err) + } + if err := client.waitSession(); err != nil { + t.Fatalf("client.waitSession: %v", err) + } +} diff --git a/ssh/transport.go b/ssh/transport.go index da015801ea..0424d2d37c 100644 --- a/ssh/transport.go +++ b/ssh/transport.go @@ -49,6 +49,9 @@ type transport struct { rand io.Reader isClient bool io.Closer + + strictMode bool + initialKEXDone bool } // packetCipher represents a combination of SSH encryption/MAC @@ -74,6 +77,18 @@ type connectionState struct { pendingKeyChange chan packetCipher } +func (t *transport) setStrictMode() error { + if t.reader.seqNum != 1 { + return errors.New("ssh: sequence number != 1 when strict KEX mode requested") + } + t.strictMode = true + return nil +} + +func (t *transport) setInitialKEXDone() { + t.initialKEXDone = true +} + // prepareKeyChange sets up key material for a keychange. The key changes in // both directions are triggered by reading and writing a msgNewKey packet // respectively. @@ -112,11 +127,12 @@ func (t *transport) printPacket(p []byte, write bool) { // Read and decrypt next packet. func (t *transport) readPacket() (p []byte, err error) { for { - p, err = t.reader.readPacket(t.bufReader) + p, err = t.reader.readPacket(t.bufReader, t.strictMode) if err != nil { break } - if len(p) == 0 || (p[0] != msgIgnore && p[0] != msgDebug) { + // in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX + if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) { break } } @@ -127,7 +143,7 @@ func (t *transport) readPacket() (p []byte, err error) { return p, err } -func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) { +func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) { packet, err := s.packetCipher.readCipherPacket(s.seqNum, r) s.seqNum++ if err == nil && len(packet) == 0 { @@ -140,6 +156,9 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) { select { case cipher := <-s.pendingKeyChange: s.packetCipher = cipher + if strictMode { + s.seqNum = 0 + } default: return nil, errors.New("ssh: got bogus newkeys message") } @@ -170,10 +189,10 @@ func (t *transport) writePacket(packet []byte) error { if debugTransport { t.printPacket(packet, true) } - return t.writer.writePacket(t.bufWriter, t.rand, packet) + return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode) } -func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error { +func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error { changeKeys := len(packet) > 0 && packet[0] == msgNewKeys err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet) @@ -188,6 +207,9 @@ func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet [] select { case cipher := <-s.pendingKeyChange: s.packetCipher = cipher + if strictMode { + s.seqNum = 0 + } default: panic("ssh: no key material for msgNewKeys") }