Skip to content

Commit

Permalink
add proper version negotiation (#109)
Browse files Browse the repository at this point in the history
Until now, the version was strictly checked against the software version.

Now that the project gains in popularity and that a spec is coming, other people may try implementing the protocol.

This adds a version negotiation mechanism that is still quite strict.

- Add the concept of protocol experimental spec version, this currently implements the alpha-00 experimental spec version.
- Do not check the software version during the negotiation, except for versions before alpha-00, where the implementation name then MUST be francoismichel/ssh3
- Right now, the server supports clients with version alpha-00 and older, i.e. clients with implementation name francoismichel/ssh3 and software version between 0.1.4 (included) and 0.1.5 (included)
- Future clients implementing future draft versions will have to adapt their version to the server version. If the server does not support the client version, the client has to retry a new connection with a version that matches the server's version. A message should be displayed to the user, ensuring that the user knows that they should update the server version. Such a message will disappear with stable versions of the protocol.


* add version negotiation

* server: log client version when debug enabled

* client:  try request with matching version upon firstversion mismatch

* better logging about version negotiation and conversation establishment

* update go.mod

* add integration tests for backwards compatibility
  • Loading branch information
francoismichel committed Jan 16, 2024
1 parent 7495fc2 commit 53d7264
Show file tree
Hide file tree
Showing 7 changed files with 393 additions and 60 deletions.
5 changes: 2 additions & 3 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ func Dial(ctx context.Context, options *Options, qconn quic.EarlyConnection,
log.Fatal().Msgf("%s", err)
}
req.Proto = "ssh3"
req.Header.Set("User-Agent", ssh3.GetCurrentVersionString())

var identity ssh3.Identity
for _, method := range options.authMethods {
Expand Down Expand Up @@ -363,8 +362,8 @@ func Dial(ctx context.Context, options *Options, qconn quic.EarlyConnection,
return nil, err
}

log.Debug().Msgf("send CONNECT request to the server")
err = conv.EstablishClientConversation(req, roundTripper)
log.Debug().Msgf("establish conversation with the server")
err = conv.EstablishClientConversation(req, roundTripper, ssh3.AVAILABLE_CLIENT_VERSIONS)
if errors.Is(err, util.Unauthorized{}) {
log.Error().Msgf("Access denied from the server: unauthorized")
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion cmd/ssh3/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ func mainWithStatusCode() int {

if qconn == nil {
if status != 0 {
log.Error().Msgf("could not setup transport for proxy client: %s", err)
log.Error().Msgf("could not setup transport for proxy client.")
}
return status
}
Expand Down
79 changes: 67 additions & 12 deletions conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net/http"

"github.com/francoismichel/ssh3/util"
"golang.org/x/exp/slices"

"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
Expand All @@ -35,6 +36,7 @@ type Conversation struct {
context context.Context
cancelContext context.CancelCauseFunc
conversationID ConversationID // generated using TLS exporters
peerVersion Version

channelsAcceptQueue *util.AcceptQueue[Channel]
}
Expand Down Expand Up @@ -68,11 +70,13 @@ func NewClientConversation(maxPacketsize uint64, defaultDatagramsQueueSize uint6
context: backgroundCtx,
cancelContext: backgroundCancelCauseFunc,
conversationID: convID,

// peerVersion set afterwards
}
return conv, nil
}

func (c *Conversation) EstablishClientConversation(req *http.Request, roundTripper *http3.RoundTripper) error {
func (c *Conversation) EstablishClientConversation(req *http.Request, roundTripper *http3.RoundTripper, supportedVersions []Version) error {

roundTripper.StreamHijacker = func(frameType http3.FrameType, qconn quic.Connection, stream quic.Stream, err error) (bool, error) {
if err != nil {
Expand Down Expand Up @@ -107,24 +111,73 @@ func (c *Conversation) EstablishClientConversation(req *http.Request, roundTripp
c.channelsAcceptQueue.Add(newChannel)
return true, nil
}
rsp, err := roundTripper.RoundTripOpt(req, http3.RoundTripOpt{DontCloseRequestStream: true})

doReq := func(version Version, req *http.Request) (*http.Response, Version, error) {
req.Header.Set("User-Agent", version.GetVersionString())
log.Debug().Msgf("send %s request on URL %s, User-Agent=\"%s\"", req.Method, req.URL, req.Header.Get("User-Agent"))
rsp, err := roundTripper.RoundTripOpt(req, http3.RoundTripOpt{DontCloseRequestStream: true})
if err != nil {
return rsp, Version{}, err
}

log.Debug().Msgf("got response with %s status code", rsp.Status)

serverVersionStr := rsp.Header.Get("Server")
serverVersion, err := ParseVersionString(serverVersionStr)
if err != nil {
log.Error().Msgf("Could not parse server version: \"%s\"", serverVersionStr)
if rsp.StatusCode == 200 {
return rsp, Version{}, InvalidSSHVersion{versionString: serverVersionStr}
}
} else {
log.Debug().Msgf("server has valid version \"%s\" (protocol version = %s, software version = %s)",
serverVersionStr, serverVersion.GetProtocolVersion(), serverVersion.GetSoftwareVersion())
}
return rsp, serverVersion, nil
}

rsp, serverVersion, err := doReq(ThisVersion(), req)
if err != nil {
return err
}

serverVersion := rsp.Header.Get("Server")
major, minor, patch, err := ParseVersionString(serverVersion)
if err != nil {
log.Error().Msgf("Could not parse server version: \"%s\"", serverVersion)
if rsp.StatusCode == 200 {
return InvalidSSHVersion{versionString: serverVersion}
serverProtocolVersion := serverVersion.GetProtocolVersion()
thisProtocolVersion := ThisVersion().GetProtocolVersion()
if rsp.StatusCode == http.StatusForbidden && serverProtocolVersion != thisProtocolVersion {
// This version negotiation code might feel a bit heavy but is only there for a smooth transition
// between early versions and versions coming from an actual IETF specification that include
// proper version negotiation. Older version of this implementation strictly check the exact protocol
// version (i.e. must be 3.0) and then check the software version. In next iterations, everything will be
// based on the protocol version for better interoperability.

// see if there is an exact version match (including software version, which is useful
// for old versions that do not support version negotiation based on the protocol version)
matchingVersionIndex := slices.Index(supportedVersions, serverVersion)

// there is no exact match, the implementation/software version might differ, but the
// protocol version may still match
if matchingVersionIndex == -1 {
matchingVersionIndex = slices.IndexFunc(supportedVersions, func(supportedVersion Version) bool {
return serverProtocolVersion == supportedVersion.GetProtocolVersion()
})
}
if matchingVersionIndex != -1 {
log.Warn().Msgf("The server runs an old version of the protocol (%s). This software is still experimental, "+
"you may want to update the server version before support is removed. Also, note that connecting to old "+
"servers may increase the connection establishment time.", serverVersion.GetVersionString())
// now retry the request with the compatible version
rsp, serverVersion, err = doReq(supportedVersions[matchingVersionIndex], req)
if err != nil {
return err
}
}
} else if major > MAJOR || minor > MINOR {
log.Warn().Msgf("The server runs a higher SSH version (%d.%d.%d), you may want to consider to update the client (currently %d.%d.%d)",
major, minor, patch, MAJOR, MINOR, PATCH)
}

if rsp.StatusCode == 200 {
if !IsVersionSupported(serverVersion) {
log.Warn().Msgf("The server runs an unsupported SSH version (%s), you may want to consider to update the client (currently %s)",
serverVersion.GetProtocolVersion(), ThisVersion().GetProtocolVersion())
}
c.controlStream = rsp.Body.(http3.HTTPStreamer).HTTPStream()
c.streamCreator = rsp.Body.(http3.Hijacker).StreamCreator()
qconn := c.streamCreator.(quic.Connection)
Expand Down Expand Up @@ -159,6 +212,7 @@ func (c *Conversation) EstablishClientConversation(req *http.Request, roundTripp
}
}
}()
c.peerVersion = serverVersion
return nil
} else if rsp.StatusCode == http.StatusUnauthorized {
return util.Unauthorized{}
Expand All @@ -177,7 +231,7 @@ func (c *Conversation) EstablishClientConversation(req *http.Request, roundTripp
}
}

func NewServerConversation(ctx context.Context, controlStream http3.Stream, qconn quic.Connection, messageSender util.DatagramSender, maxPacketsize uint64) (*Conversation, error) {
func NewServerConversation(ctx context.Context, controlStream http3.Stream, qconn quic.Connection, messageSender util.DatagramSender, maxPacketsize uint64, peerVersion Version) (*Conversation, error) {
backgroundContext, backgroundCancelFunc := context.WithCancelCause(ctx)

tls := qconn.ConnectionState().TLS
Expand All @@ -197,6 +251,7 @@ func NewServerConversation(ctx context.Context, controlStream http3.Stream, qcon
context: backgroundContext,
cancelContext: backgroundCancelFunc,
conversationID: convID,
peerVersion: peerVersion,
}
return conv, nil
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
github.com/quic-go/quic-go v0.40.1-0.20240102075208-1083d1fb8f98
github.com/rs/zerolog v1.31.0
golang.org/x/crypto v0.17.0
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db
golang.org/x/oauth2 v0.13.0
golang.org/x/term v0.15.0
)
Expand All @@ -26,7 +27,6 @@ require (
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/quic-go/qpack v0.4.0 // indirect
go.uber.org/mock v0.3.0 // indirect
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.15.0 // indirect
Expand Down
63 changes: 57 additions & 6 deletions integration_tests/ssh3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net"
"os"
"os/exec"
"path"
"time"

. "github.com/onsi/ginkgo/v2"
Expand All @@ -22,7 +23,7 @@ const DEFAULT_URL_PATH = "/ssh3-tests"
const DEFAULT_PROXY_URL_PATH = "/ssh3-tests-proxy"

var serverCommand *exec.Cmd
var serverSession *Session
var serverSessions map[string]*Session = make(map[string]*Session) // bind address to session
var proxyServerCommand *exec.Cmd
var proxyServerSession *Session
var rsaPrivKeyPath string
Expand All @@ -33,6 +34,11 @@ var username string
const serverBind = "127.0.0.1:4433"
const proxyServerBind = "127.0.0.1:4444"

var oldServerBinds map[string]string = map[string]string{
"v0.1.5-rc1": "127.0.0.1:5000",
"v0.1.5-rc5": "127.0.0.1:5001",
} // tag version to bind string

func IPv6LoopbackAvailable(addrs []net.Addr) bool {
for _, addr := range addrs {
Expect(addr).To(BeAssignableToTypeOf(&net.IPNet{}))
Expand Down Expand Up @@ -69,9 +75,34 @@ var _ = BeforeSuite(func() {
"-cert", os.Getenv("CERT_PEM"),
"-key", os.Getenv("CERT_PRIV_KEY"))
serverCommand.Env = append(serverCommand.Env, "SSH3_LOG_LEVEL=debug")
serverSession, err = Start(serverCommand, GinkgoWriter, GinkgoWriter)
session, err := Start(serverCommand, GinkgoWriter, GinkgoWriter)
Expect(err).ToNot(HaveOccurred())

serverSessions[serverBind] = session

for tag, bind := range oldServerBinds {
gobin, err := os.MkdirTemp("", fmt.Sprintf("ssh3-backwards-compatible-versions-%s", tag))
Expect(err).ToNot(HaveOccurred())
cmd := exec.Command("go", "install", fmt.Sprintf("github.com/francoismichel/ssh3/cmd/ssh3-server@%s", tag))
cmd.Env = os.Environ()
cmd.Env = append(cmd.Env, fmt.Sprintf("GOBIN=%s", gobin))
err = cmd.Run()
Expect(err).ToNot(HaveOccurred())
serverPath := path.Join(gobin, "ssh3-server")
Expect(err).ToNot(HaveOccurred())
backwardsCompatibleServerCommand := exec.Command(serverPath,
"-bind", bind,
"-v",
"-enable-password-login",
"-url-path", DEFAULT_URL_PATH,
"-cert", os.Getenv("CERT_PEM"),
"-key", os.Getenv("CERT_PRIV_KEY"))
serverCommand.Env = append(backwardsCompatibleServerCommand.Env, "SSH3_LOG_LEVEL=debug")
session, err = Start(backwardsCompatibleServerCommand, GinkgoWriter, GinkgoWriter)
Expect(err).ToNot(HaveOccurred())
serverSessions[bind] = session
}

proxyServerCommand = exec.Command(ssh3ServerPath,
"-bind", proxyServerBind,
"-v",
Expand All @@ -83,6 +114,8 @@ var _ = BeforeSuite(func() {
proxyServerSession, err = Start(proxyServerCommand, GinkgoWriter, GinkgoWriter)
Expect(err).ToNot(HaveOccurred())

serverSessions[proxyServerBind] = proxyServerSession

rsaPrivKeyPath = os.Getenv("TESTUSER_PRIVKEY")
ed25519PrivKeyPath = os.Getenv("TESTUSER_ED25519_PRIVKEY")
attackerPrivKeyPath = os.Getenv("ATTACKER_PRIVKEY")
Expand All @@ -96,7 +129,7 @@ var _ = BeforeSuite(func() {

var _ = AfterSuite(func() {
CleanupBuildArtifacts()
if serverSession != nil {
for _, serverSession := range serverSessions {
serverSession.Terminate()
}
})
Expand All @@ -118,21 +151,24 @@ var _ = Describe("Testing the ssh3 cli", func() {
if os.Getenv("SSH3_INTEGRATION_TESTS_WITH_SERVER_ENABLED") != "1" {
Skip("skipping integration tests")
}
Consistently(serverSession, "200ms").ShouldNot(Exit())
Consistently(serverSessions[serverBind], "200ms").ShouldNot(Exit())
})

Context("Insecure", func() {
var clientArgs []string
getClientArgs := func(privKeyPath string, additionalArgs ...string) []string {
getClientArgsWithBind := func(privKeyPath string, bind string, additionalArgs ...string) []string {
args := []string{
"-v",
"-insecure",
"-privkey", privKeyPath,
}
args = append(args, additionalArgs...)
args = append(args, fmt.Sprintf("%s@%s%s", username, serverBind, DEFAULT_URL_PATH))
args = append(args, fmt.Sprintf("%s@%s%s", username, bind, DEFAULT_URL_PATH))
return args
}
getClientArgs := func(privKeyPath string, additionalArgs ...string) []string {
return getClientArgsWithBind(privKeyPath, serverBind, additionalArgs...)
}

Context("Client behaviour", func() {
It("Should connect using an RSA privkey", func() {
Expand All @@ -153,6 +189,21 @@ var _ = Describe("Testing the ssh3 cli", func() {
Eventually(session).Should(Say("Hello, World!\n"))
})

for key, val := range oldServerBinds {
// actually capture the values of key,val, as directly referring them in the code below will only keep the value of the last iteration
tag, bind := key, val
When("server version is"+tag+", bind is"+bind, func() {
It("Should connect using an RSA privkey to old supported server", func() {
clientArgs = append(getClientArgsWithBind(rsaPrivKeyPath, bind), "echo", "Hello, World!")
command := exec.Command(ssh3Path, clientArgs...)
session, err := Start(command, GinkgoWriter, GinkgoWriter)
Expect(err).ToNot(HaveOccurred())
Eventually(session).Should(Exit(0))
Eventually(session).Should(Say("Hello, World!\n"))
})
})
}

It("Should connect using an ed25519 privkey", func() {
clientArgs = append(getClientArgs(ed25519PrivKeyPath), "echo", "Hello, World!")
command := exec.Command(ssh3Path, clientArgs...)
Expand Down
19 changes: 10 additions & 9 deletions unix_server/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ func HandleAuths(ctx context.Context, enablePasswordLogin bool, defaultMaxPacket
}
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Server", ssh3.GetCurrentVersionString())
major, minor, patch, err := ssh3.ParseVersionString(r.UserAgent())
log.Debug().Msgf("received request from User-Agent %s (major %d, minor %d, patch %d)", r.UserAgent(), major, minor, patch)
peerVersion, err := ssh3.ParseVersionString(r.UserAgent())
log.Debug().Msgf("received request from User-Agent %s", r.UserAgent())
log.Debug().Msgf("peer version: protocol version %s, software version %s", peerVersion.GetProtocolVersion(), peerVersion.GetSoftwareVersion())
// currently apply strict version rules
if err != nil || major != ssh3.MAJOR || minor != ssh3.MINOR {
if err == nil {
http.Error(w, fmt.Sprintf("Unsupported version: %d.%d.%d not supported by server with version %s", major, minor, patch, ssh3.GetCurrentVersionString()), http.StatusForbidden)
} else {
http.Error(w, "Unsupported user-agent", http.StatusForbidden)
}
if err != nil {
http.Error(w, fmt.Sprintf("Unsupported user-agent: %s", r.UserAgent()[:100]), http.StatusForbidden)
return
}
if !ssh3.IsVersionSupported(peerVersion) {
http.Error(w, fmt.Sprintf("Unsupported version: %s not supported by server with version %s", peerVersion.GetProtocolVersion(), ssh3.ThisVersion().GetProtocolVersion()), http.StatusForbidden)
return
}
// Only call Flush() here, as calling flush prevents from adding the Content-Length header to the response
Expand All @@ -52,7 +53,7 @@ func HandleAuths(ctx context.Context, enablePasswordLogin bool, defaultMaxPacket
return
}
str := r.Body.(http3.HTTPStreamer).HTTPStream()
conv, err := ssh3.NewServerConversation(ctx, str, qconn, qconn, defaultMaxPacketSize)
conv, err := ssh3.NewServerConversation(ctx, str, qconn, qconn, defaultMaxPacketSize, peerVersion)
if err != nil {
log.Error().Msgf("could not create new server conversation")
w.WriteHeader(http.StatusInternalServerError)
Expand Down
Loading

0 comments on commit 53d7264

Please sign in to comment.