Skip to content

Commit

Permalink
fix(netxlite): allow overriding default cert pool
Browse files Browse the repository at this point in the history
This diff tweaks #1068
to make sure overriding the default cert pool works.

In #1068 we introduced
code to add this functionality but we never tested it was working
as intended. It turns out it was not!

Because this diff amends the previous diff, we'll consider it
part of ooni/probe#2135.
  • Loading branch information
bassosimone committed Feb 1, 2023
1 parent 7485153 commit faa5355
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 22 deletions.
2 changes: 1 addition & 1 deletion internal/netxlite/quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func (d *quicDialerQUICGo) dialEarlyContext(ctx context.Context,
func (d *quicDialerQUICGo) maybeApplyTLSDefaults(config *tls.Config, port int) *tls.Config {
config = config.Clone()
if config.RootCAs == nil {
config.RootCAs = defaultCertPool
config.RootCAs = NewDefaultCertPool()
}
if len(config.NextProtos) <= 0 {
switch port {
Expand Down
8 changes: 4 additions & 4 deletions internal/netxlite/quic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,8 @@ func TestQUICDialerQUICGo(t *testing.T) {
if tlsConfig.RootCAs != nil {
t.Fatal("tlsConfig.RootCAs should not have been changed")
}
if gotTLSConfig.RootCAs != defaultCertPool {
t.Fatal("invalid gotTLSConfig.RootCAs")
if gotTLSConfig.RootCAs == nil {
t.Fatal("gotTLSConfig.RootCAs should have been set")
}
if tlsConfig.NextProtos != nil {
t.Fatal("tlsConfig.NextProtos should not have been changed")
Expand Down Expand Up @@ -289,8 +289,8 @@ func TestQUICDialerQUICGo(t *testing.T) {
if tlsConfig.RootCAs != nil {
t.Fatal("tlsConfig.RootCAs should not have been changed")
}
if gotTLSConfig.RootCAs != defaultCertPool {
t.Fatal("invalid gotTLSConfig.RootCAs")
if gotTLSConfig.RootCAs == nil {
t.Fatal("gotTLSConfig.RootCAs should have been set")
}
if tlsConfig.NextProtos != nil {
t.Fatal("tlsConfig.NextProtos should not have been changed")
Expand Down
6 changes: 1 addition & 5 deletions internal/netxlite/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,6 @@ type tlsHandshakerConfigurable struct {

var _ model.TLSHandshaker = &tlsHandshakerConfigurable{}

// defaultCertPool is the cert pool we use by default. We store this
// value into a private variable to enable for unit testing.
var defaultCertPool = NewDefaultCertPool()

// tlsMaybeConnectionState returns the connection state if error is nil
// and otherwise just returns an empty state to the caller.
func tlsMaybeConnectionState(conn TLSConn, err error) tls.ConnectionState {
Expand All @@ -213,7 +209,7 @@ func (h *tlsHandshakerConfigurable) Handshake(
conn.SetDeadline(time.Now().Add(timeout))
if config.RootCAs == nil {
config = config.Clone()
config.RootCAs = defaultCertPool
config.RootCAs = NewDefaultCertPool()
}
tlsconn, err := h.newConn(conn, config)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/netxlite/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func TestTLSHandshakerConfigurable(t *testing.T) {
if config.RootCAs != nil {
t.Fatal("config.RootCAs should still be nil")
}
if gotTLSConfig.RootCAs != defaultCertPool {
if gotTLSConfig.RootCAs == nil {
t.Fatal("gotTLSConfig.RootCAs has not been correctly set")
}
})
Expand Down
57 changes: 46 additions & 11 deletions internal/netxlite/tproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ package netxlite
import (
"context"
"crypto/x509"
"net"
"net/http"
"net/http/httptest"
"runtime"
"strings"
"testing"
"time"

"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
"github.com/ooni/probe-cli/v3/internal/runtimex"
)

func TestDefaultTProxy(t *testing.T) {
Expand All @@ -36,16 +39,48 @@ func TestDefaultTProxy(t *testing.T) {
}

func TestWithCustomTProxy(t *testing.T) {
expected := x509.NewCertPool()
tproxy := &mocks.UnderlyingNetwork{
MockMaybeModifyPool: func(pool *x509.CertPool) *x509.CertPool {
runtimex.Assert(expected != pool, "got unexpected pool")
return expected
},
}
WithCustomTProxy(tproxy, func() {
if NewDefaultCertPool() != expected {
t.Fatal("unexpected pool")

t.Run("we can override the default cert pool", func(t *testing.T) {
srvr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(444)
}))
defer srvr.Close()

// TODO(bassosimone): we need a more compact and ergonomic
// way of overriding the underlying network
tproxy := &mocks.UnderlyingNetwork{
MockDialContext: func(ctx context.Context, timeout time.Duration, network string, address string) (net.Conn, error) {
return (&DefaultTProxy{}).DialContext(ctx, timeout, network, address)
},
MockListenUDP: func(network string, addr *net.UDPAddr) (model.UDPLikeConn, error) {
return (&DefaultTProxy{}).ListenUDP(network, addr)
},
MockGetaddrinfoLookupANY: func(ctx context.Context, domain string) ([]string, string, error) {
return (&DefaultTProxy{}).GetaddrinfoLookupANY(ctx, domain)
},
MockGetaddrinfoResolverNetwork: func() string {
return (&DefaultTProxy{}).GetaddrinfoResolverNetwork()
},
MockMaybeModifyPool: func(*x509.CertPool) *x509.CertPool {
pool := x509.NewCertPool()
pool.AddCert(srvr.Certificate())
return pool
},
}

WithCustomTProxy(tproxy, func() {
clnt := NewHTTPClientStdlib(model.DiscardLogger)
req, err := http.NewRequestWithContext(context.Background(), "GET", srvr.URL, nil)
if err != nil {
t.Fatal(err)
}
resp, err := clnt.Do(req)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != 444 {
t.Fatal("unexpected status code")
}
})
})
}

0 comments on commit faa5355

Please sign in to comment.