Skip to content

Commit

Permalink
refactor: start refactoring session resolver (#807)
Browse files Browse the repository at this point in the history
This diff addresses the following points of ooni/probe#2135:

- [x] the `childResolver` type is useless and we can use `model.Resolver` directly;
- [x] we should use `model/mocks` instead of custom fakes;
- [x] we should not use `log.Log` rather we should use `model.DiscardLogger`;
- [x] make `timeLimitedLookup` easier to test with a `-short` tests;
- [x] ensure `timeLimitedLookup` returns as soon as its context expires regardless of the child resolver;

Subsequent diffs will address more points mentioned in there.
  • Loading branch information
bassosimone committed Jun 8, 2022
1 parent dea23b4 commit fe29b43
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 84 deletions.
51 changes: 36 additions & 15 deletions internal/engine/internal/sessionresolver/childresolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,46 @@ package sessionresolver
import (
"context"
"time"

"github.com/ooni/probe-cli/v3/internal/model"
)

// childResolver is the DNS client that this package uses
// to perform individual domain name resolutions.
type childResolver interface {
// LookupHost performs a DNS lookup.
LookupHost(ctx context.Context, domain string) ([]string, error)
// defaultTimeLimitedLookupTimeout is the default timeout the code should
// pass to the timeLimitedLookup function.
//
// This algorithm is similar to Firefox using TRR2 mode. See:
// https://wiki.mozilla.org/Trusted_Recursive_Resolver#DNS-over-HTTPS_Prefs_in_Firefox
//
// We use a higher timeout than Firefox's timeout (1.5s) to be on the safe side
// and therefore see to use DoH more often.
const defaultTimeLimitedLookupTimeout = 4 * time.Second

// CloseIdleConnections closes idle connections.
CloseIdleConnections()
// timeLimitedLookup performs a time-limited lookup using the given re.
func timeLimitedLookup(ctx context.Context, re model.Resolver, hostname string) ([]string, error) {
return timeLimitedLookupWithTimeout(ctx, re, hostname, defaultTimeLimitedLookupTimeout)
}

// timeLimitedLookup performs a time-limited lookup using the given re.
func (r *Resolver) timeLimitedLookup(ctx context.Context, re childResolver, hostname string) ([]string, error) {
// Algorithm similar to Firefox TRR2 mode. See:
// https://wiki.mozilla.org/Trusted_Recursive_Resolver#DNS-over-HTTPS_Prefs_in_Firefox
// We use a higher timeout than Firefox's timeout (1.5s) to be on the safe side
// and therefore see to use DoH more often.
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
// timeLimitedLookupResult is the result of a timeLimitedLookup
type timeLimitedLookupResult struct {
addrs []string
err error
}

// timeLimitedLookupWithTimeout is like timeLimitedLookup but with explicit timeout.
func timeLimitedLookupWithTimeout(ctx context.Context, re model.Resolver,
hostname string, timeout time.Duration) ([]string, error) {
outch := make(chan *timeLimitedLookupResult, 1) // buffer
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return re.LookupHost(ctx, hostname)
go func() {
out := &timeLimitedLookupResult{}
out.addrs, out.err = re.LookupHost(ctx, hostname)
outch <- out
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case out := <-outch:
return out.addrs, out.err
}
}
63 changes: 25 additions & 38 deletions internal/engine/internal/sessionresolver/childresolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,51 +8,35 @@ import (
"time"

"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)

type FakeResolver struct {
Closed bool
Data []string
Err error
Sleep time.Duration
}

func (r *FakeResolver) LookupHost(ctx context.Context, hostname string) ([]string, error) {
select {
case <-time.After(r.Sleep):
return r.Data, r.Err
case <-ctx.Done():
return nil, ctx.Err()
}
}

func (r *FakeResolver) CloseIdleConnections() {
r.Closed = true
}

func TestTimeLimitedLookupSuccess(t *testing.T) {
reso := &Resolver{}
re := &FakeResolver{
Data: []string{"8.8.8.8", "8.8.4.4"},
expected := []string{"8.8.8.8", "8.8.4.4"}
re := &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return expected, nil
},
}
ctx := context.Background()
out, err := reso.timeLimitedLookup(ctx, re, "dns.google")
out, err := timeLimitedLookup(ctx, re, "dns.google")
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(re.Data, out); diff != "" {
if diff := cmp.Diff(expected, out); diff != "" {
t.Fatal(diff)
}
}

func TestTimeLimitedLookupFailure(t *testing.T) {
reso := &Resolver{}
re := &FakeResolver{
Err: io.EOF,
re := &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, io.EOF
},
}
ctx := context.Background()
out, err := reso.timeLimitedLookup(ctx, re, "dns.google")
if !errors.Is(err, re.Err) {
out, err := timeLimitedLookup(ctx, re, "dns.google")
if !errors.Is(err, io.EOF) {
t.Fatal("not the error we expected", err)
}
if out != nil {
Expand All @@ -61,20 +45,23 @@ func TestTimeLimitedLookupFailure(t *testing.T) {
}

func TestTimeLimitedLookupWillTimeout(t *testing.T) {
if testing.Short() {
t.Skip("skip test in short mode")
}
reso := &Resolver{}
re := &FakeResolver{
Err: io.EOF,
Sleep: 20 * time.Second,
done := make(chan bool)
block := make(chan bool)
re := &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
defer close(done)
<-block
return nil, io.EOF
},
}
ctx := context.Background()
out, err := reso.timeLimitedLookup(ctx, re, "dns.google")
out, err := timeLimitedLookupWithTimeout(ctx, re, "dns.google", 10*time.Millisecond)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatal("not the error we expected", err)
}
if out != nil {
t.Fatal("expected nil here")
}
close(block)
<-done
}
9 changes: 6 additions & 3 deletions internal/engine/internal/sessionresolver/clientmaker.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package sessionresolver

import "github.com/ooni/probe-cli/v3/internal/engine/netx"
import (
"github.com/ooni/probe-cli/v3/internal/engine/netx"
"github.com/ooni/probe-cli/v3/internal/model"
)

// dnsclientmaker makes a new resolver.
type dnsclientmaker interface {
// Make makes a new resolver.
Make(config netx.Config, URL string) (childResolver, error)
Make(config netx.Config, URL string) (model.Resolver, error)
}

// clientmaker returns a valid dnsclientmaker
Expand All @@ -20,6 +23,6 @@ func (r *Resolver) clientmaker() dnsclientmaker {
type defaultDNSClientMaker struct{}

// Make implements dnsclientmaker.Make.
func (*defaultDNSClientMaker) Make(config netx.Config, URL string) (childResolver, error) {
func (*defaultDNSClientMaker) Make(config netx.Config, URL string) (model.Resolver, error) {
return netx.NewDNSClient(config, URL)
}
5 changes: 3 additions & 2 deletions internal/engine/internal/sessionresolver/clientmaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@ import (
"testing"

"github.com/ooni/probe-cli/v3/internal/engine/netx"
"github.com/ooni/probe-cli/v3/internal/model"
)

type fakeDNSClientMaker struct {
reso childResolver
reso model.Resolver
err error
savedConfig netx.Config
savedURL string
}

func (c *fakeDNSClientMaker) Make(config netx.Config, URL string) (childResolver, error) {
func (c *fakeDNSClientMaker) Make(config netx.Config, URL string) (model.Resolver, error) {
c.savedConfig = config
c.savedURL = URL
return c.reso, c.err
Expand Down
12 changes: 4 additions & 8 deletions internal/engine/internal/sessionresolver/resolvermaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"strings"
"time"

"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/bytecounter"
"github.com/ooni/probe-cli/v3/internal/engine/netx"
"github.com/ooni/probe-cli/v3/internal/model"
Expand Down Expand Up @@ -71,15 +70,12 @@ func (r *Resolver) byteCounter() *bytecounter.Counter {

// logger returns the configured logger or a default
func (r *Resolver) logger() model.Logger {
if r.Logger != nil {
return r.Logger
}
return log.Log
return model.ValidLoggerOrDefault(r.Logger)
}

// newresolver creates a new resolver with the given config and URL. This is
// where we expand http3 to https and set the h3 options.
func (r *Resolver) newresolver(URL string) (childResolver, error) {
func (r *Resolver) newresolver(URL string) (model.Resolver, error) {
h3 := strings.HasPrefix(URL, "http3://")
if h3 {
URL = strings.Replace(URL, "http3://", "https://", 1)
Expand All @@ -95,7 +91,7 @@ func (r *Resolver) newresolver(URL string) (childResolver, error) {

// getresolver returns a resolver with the given URL. This function caches
// already allocated resolvers so we only allocate them once.
func (r *Resolver) getresolver(URL string) (childResolver, error) {
func (r *Resolver) getresolver(URL string) (model.Resolver, error) {
defer r.mu.Unlock()
r.mu.Lock()
if re, found := r.res[URL]; found {
Expand All @@ -106,7 +102,7 @@ func (r *Resolver) getresolver(URL string) (childResolver, error) {
return nil, err // config err?
}
if r.res == nil {
r.res = make(map[string]childResolver)
r.res = make(map[string]model.Resolver)
}
r.res[URL] = re
return re, nil
Expand Down
47 changes: 34 additions & 13 deletions internal/engine/internal/sessionresolver/resolvermaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ import (
"strings"
"testing"

"github.com/apex/log"
"github.com/ooni/probe-cli/v3/internal/bytecounter"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/model/mocks"
)

func TestDefaultByteCounter(t *testing.T) {
Expand All @@ -18,18 +19,33 @@ func TestDefaultByteCounter(t *testing.T) {
}

func TestDefaultLogger(t *testing.T) {
logger := &log.Logger{}
reso := &Resolver{Logger: logger}
lo := reso.logger()
if lo != logger {
t.Fatal("expected another logger here counter")
}
t.Run("when using a different logger", func(t *testing.T) {
logger := &mocks.Logger{}
reso := &Resolver{Logger: logger}
lo := reso.logger()
if lo != logger {
t.Fatal("expected another logger here")
}
})

t.Run("when no logger is set", func(t *testing.T) {
reso := &Resolver{Logger: nil}
lo := reso.logger()
if lo != model.DiscardLogger {
t.Fatal("expected another logger here")
}
})
}

func TestGetResolverHTTPSStandard(t *testing.T) {
bc := bytecounter.New()
URL := "https://dns.google"
re := &FakeResolver{}
var closed bool
re := &mocks.Resolver{
MockCloseIdleConnections: func() {
closed = true
},
}
cmk := &fakeDNSClientMaker{reso: re}
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
out, err := reso.getresolver(URL)
Expand All @@ -47,7 +63,7 @@ func TestGetResolverHTTPSStandard(t *testing.T) {
t.Fatal("not the result we expected")
}
reso.closeall()
if re.Closed != true {
if closed != true {
t.Fatal("was not closed")
}
if cmk.savedURL != URL {
Expand All @@ -62,15 +78,20 @@ func TestGetResolverHTTPSStandard(t *testing.T) {
if cmk.savedConfig.HTTP3Enabled != false {
t.Fatal("unexpected HTTP3Enabled")
}
if cmk.savedConfig.Logger != log.Log {
if cmk.savedConfig.Logger != model.DiscardLogger {
t.Fatal("unexpected Log")
}
}

func TestGetResolverHTTP3(t *testing.T) {
bc := bytecounter.New()
URL := "http3://dns.google"
re := &FakeResolver{}
var closed bool
re := &mocks.Resolver{
MockCloseIdleConnections: func() {
closed = true
},
}
cmk := &fakeDNSClientMaker{reso: re}
reso := &Resolver{dnsClientMaker: cmk, ByteCounter: bc}
out, err := reso.getresolver(URL)
Expand All @@ -88,7 +109,7 @@ func TestGetResolverHTTP3(t *testing.T) {
t.Fatal("not the result we expected")
}
reso.closeall()
if re.Closed != true {
if closed != true {
t.Fatal("was not closed")
}
if cmk.savedURL != strings.Replace(URL, "http3://", "https://", 1) {
Expand All @@ -103,7 +124,7 @@ func TestGetResolverHTTP3(t *testing.T) {
if cmk.savedConfig.HTTP3Enabled != true {
t.Fatal("unexpected HTTP3Enabled")
}
if cmk.savedConfig.Logger != log.Log {
if cmk.savedConfig.Logger != model.DiscardLogger {
t.Fatal("unexpected Log")
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/engine/internal/sessionresolver/sessionresolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ type Resolver struct {
// res maps a URL to a child resolver. We will
// construct child resolvers just once and we
// will track them into this field.
res map[string]childResolver
res map[string]model.Resolver
}

// CloseIdleConnections closes the idle connections, if any. This
Expand Down Expand Up @@ -169,7 +169,7 @@ func (r *Resolver) lookupHost(ctx context.Context, ri *resolverinfo, hostname st
ri.Score = 0 // this is a hard error
return nil, err
}
addrs, err := r.timeLimitedLookup(ctx, re, hostname)
addrs, err := timeLimitedLookup(ctx, re, hostname)
if err == nil {
r.logger().Infof("sessionresolver: %s... %v", ri.URL, model.ErrorToStringOrOK(nil))
ri.Score = ewma*1.0 + (1-ewma)*ri.Score // increase score
Expand Down
Loading

0 comments on commit fe29b43

Please sign in to comment.