Skip to content

Commit

Permalink
Ensure the redis TLS connection uses k6's netext.Dialer under the hood
Browse files Browse the repository at this point in the history
  • Loading branch information
oleiade committed Nov 2, 2023
1 parent d542b1f commit 9bb7ed0
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 32 deletions.
99 changes: 67 additions & 32 deletions redis/client.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package redis

import (
"context"
"crypto/tls"
"fmt"
"net"
"time"

"github.com/dop251/goja"
"github.com/redis/go-redis/v9"
"go.k6.io/k6/js/common"
"go.k6.io/k6/js/modules"
"go.k6.io/k6/lib/netext"
"go.k6.io/k6/lib"
)

// Client represents the Client constructor (i.e. `new redis.Client()`) and
Expand Down Expand Up @@ -1080,37 +1082,34 @@ func (c *Client) connect() error {
}

tlsCfg := c.redisOptions.TLSConfig
if tlsCfg != nil {
if vuState.TLSConfig != nil {
// Merge k6 TLS configuration with the one we received from the
// Client constructor. This will need adjusting depending on which
// options we want to expose in the Redis module, and how we want
// the override to work.
tlsCfg.InsecureSkipVerify = vuState.TLSConfig.InsecureSkipVerify
tlsCfg.CipherSuites = vuState.TLSConfig.CipherSuites
tlsCfg.MinVersion = vuState.TLSConfig.MinVersion
tlsCfg.MaxVersion = vuState.TLSConfig.MaxVersion
tlsCfg.Renegotiation = vuState.TLSConfig.Renegotiation
tlsCfg.KeyLogWriter = vuState.TLSConfig.KeyLogWriter

tlsCfg.Certificates = append(tlsCfg.Certificates, vuState.TLSConfig.Certificates...)

// TODO: Merge vuState.TLSConfig.RootCAs with
// c.redisOptions.TLSConfig. k6 currently doesn't allow setting
// this, so it doesn't matter right now, but these should be merged.
// I couldn't find a way to do this with the x509.CertPool API
// though...
}

k6dialer, ok := vuState.Dialer.(*netext.Dialer)
if !ok {
panic(fmt.Sprintf("expected *netext.Dialer, got: %T", vuState.Dialer))
}
tlsDialer := &tls.Dialer{
NetDialer: &k6dialer.Dialer,
Config: tlsCfg,
}
c.redisOptions.Dialer = tlsDialer.DialContext
if tlsCfg != nil && vuState.TLSConfig != nil {
// Merge k6 TLS configuration with the one we received from the
// Client constructor. This will need adjusting depending on which
// options we want to expose in the Redis module, and how we want
// the override to work.
tlsCfg.InsecureSkipVerify = vuState.TLSConfig.InsecureSkipVerify
tlsCfg.CipherSuites = vuState.TLSConfig.CipherSuites
tlsCfg.MinVersion = vuState.TLSConfig.MinVersion
tlsCfg.MaxVersion = vuState.TLSConfig.MaxVersion
tlsCfg.Renegotiation = vuState.TLSConfig.Renegotiation
tlsCfg.KeyLogWriter = vuState.TLSConfig.KeyLogWriter
tlsCfg.Certificates = append(tlsCfg.Certificates, vuState.TLSConfig.Certificates...)

// TODO: Merge vuState.TLSConfig.RootCAs with
// c.redisOptions.TLSConfig. k6 currently doesn't allow setting
// this, so it doesn't matter right now, but these should be merged.
// I couldn't find a way to do this with the x509.CertPool API
// though...

// In order to preserve the underlying effects of the [netext.Dialer], such
// as handling blocked hostnames, or handling hostname resolution, we override
// the redis client's dialer with our own function which uses the VU's [netext.Dialer]
// and manually upgrades the connection to TLS.
//
// See Pull Request's #17 [discussion] for more details.
//
// [discussion]: https://github.com/grafana/xk6-redis/pull/17#discussion_r1369707388
c.redisOptions.Dialer = c.upgradeDialerToTLS(vuState.Dialer, tlsCfg)
} else {
c.redisOptions.Dialer = vuState.Dialer.DialContext
}
Expand Down Expand Up @@ -1154,3 +1153,39 @@ func (c *Client) isSupportedType(offset int, args ...interface{}) error {

return nil
}

// DialContextFunc is a function that can be used to dial a connection to a redis server.
type DialContextFunc func(ctx context.Context, network, addr string) (net.Conn, error)

// upgradeDialerToTLS returns a DialContextFunc that uses the provided dialer to
// establish a connection, and then upgrades it to TLS using the provided config.
//
// We use this function to make sure the k6 [netext.Dialer], our redis module uses to establish
// the connection and handle network-related options such as blocked hostnames,
// or hostname resolution, but we also want to use the TLS configuration provided
// by the user.
func (c *Client) upgradeDialerToTLS(dialer lib.DialContexter, config *tls.Config) DialContextFunc {
return func(ctx context.Context, network string, addr string) (net.Conn, error) {
// Use netext.Dialer to establish the connection
rawConn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}

// Upgrade the connection to TLS if needed
tlsConn := tls.Client(rawConn, config)
err = tlsConn.Handshake()
if err != nil {
if closeErr := rawConn.Close(); closeErr != nil {
return nil, fmt.Errorf("failed to close connection after TLS handshake error: %w", closeErr)
}

return nil, err
}

// Overwrite rawConn with the TLS connection
rawConn = tlsConn

return rawConn, nil
}
}
51 changes: 51 additions & 0 deletions redis/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2560,6 +2560,7 @@ type testSetup struct {
state *lib.State
samples chan metrics.SampleContainer
ev *eventloop.EventLoop
tb *httpmultibin.HTTPMultiBin
}

// newTestSetup initializes a new test setup.
Expand Down Expand Up @@ -2618,6 +2619,7 @@ func newTestSetup(t testing.TB) testSetup {
state: state,
samples: samples,
ev: ev,
tb: tb,
}
}

Expand Down Expand Up @@ -2732,3 +2734,52 @@ func TestClientTLSAuth(t *testing.T) {
{"PING"},
}, rs.GotCommands())
}

func TestClientTLSRespectsNetworkOPtions(t *testing.T) {
t.Parallel()

clientCert, clientPKey, err := generateTLSCert()
require.NoError(t, err)

ts := newTestSetup(t)
rs := RunTSecure(t, clientCert)

err = ts.rt.Set("caCert", string(rs.TLSCertificate()))
require.NoError(t, err)
err = ts.rt.Set("clientCert", string(clientCert))
require.NoError(t, err)
err = ts.rt.Set("clientPKey", string(clientPKey))
require.NoError(t, err)

// Set the redis server's IP to be blacklisted.
net, err := lib.ParseCIDR(rs.Addr().IP.String() + "/32")
require.NoError(t, err)
ts.tb.Dialer.Blacklist = []*lib.IPNet{net}

gotScriptErr := ts.ev.Start(func() error {
_, err := ts.rt.RunString(fmt.Sprintf(`
const redis = new Client({
socket: {
host: '%s',
port: %d,
tls: {
ca: [caCert],
cert: clientCert,
key: clientPKey
}
}
});
// This operation triggers a connection to the redis
// server under the hood, and should therefore fail, since
// the server's IP is blacklisted by k6.
redis.sendCommand("PING")
`, rs.Addr().IP.String(), rs.Addr().Port))

return err
})

assert.Error(t, gotScriptErr)
assert.ErrorContains(t, gotScriptErr, "IP ("+rs.Addr().IP.String()+") is in a blacklisted range")
assert.Equal(t, 0, rs.HandledCommandsCount())
}

0 comments on commit 9bb7ed0

Please sign in to comment.