Skip to content

Commit

Permalink
home: introduce marshalable duration
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Jun 12, 2021
1 parent bfb1a57 commit 64b00fd
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 28 deletions.
20 changes: 0 additions & 20 deletions internal/aghtest/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"net"
"strings"
"sync"
"time"

"github.com/miekg/dns"
)
Expand Down Expand Up @@ -180,22 +179,3 @@ func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
func (u *TestErrUpstream) Address() string {
return ""
}

// TestTimeoutUpstream implements upstream.Upstream interface for replacing real
// upstream in tests.
type TestTimeoutUpstream struct {
// waitFor
waitFor time.Duration
}

// Exchange always returns nil Msg and non-nil error.
func (u *TestTimeoutUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
time.Sleep(u.waitFor)

return nil, nil
}

// Address always returns an empty string.
func (u *TestTimeoutUpstream) Address() string {
return ""
}
13 changes: 13 additions & 0 deletions internal/dnsforward/dnsforward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,19 @@ func TestServer(t *testing.T) {
}

func TestServer_timeout(t *testing.T) {
const timeout time.Duration = time.Second

srvConf := &ServerConfig{
UpstreamTimeout: timeout,
}

s, err := NewServer(DNSCreateParams{})
require.NoError(t, err)

err = s.Prepare(srvConf)
require.NoError(t, err)

assert.Equal(t, timeout, s.conf.UpstreamTimeout)
}

func TestServerWithProtectionDisabled(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion internal/home/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ func (clients *clientsContainer) findUpstreams(
upstreams,
upstream.Options{
Bootstrap: config.DNS.BootstrapDNS,
Timeout: time.Duration(config.DNS.UpstreamTimeout),
Timeout: config.DNS.UpstreamTimeout.Duration,
},
)
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions internal/home/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ type dnsConfig struct {

// UpstreamTimeout is the timeout for querying upstream servers in
// seconds.
UpstreamTimeout uint32 `yaml:"upstream_timeout"`
UpstreamTimeout Duration `yaml:"upstream_timeout"`

// LocalDomainName is the domain name used for known internal hosts.
// For example, a machine called "myhost" can be addressed as
Expand Down Expand Up @@ -186,7 +186,7 @@ var config = configuration{
},
FilteringEnabled: true, // whether or not use filter lists
FiltersUpdateIntervalHours: 24,
UpstreamTimeout: uint32(dnsforward.DefaultTimeout.Seconds()),
UpstreamTimeout: Duration{dnsforward.DefaultTimeout},
LocalDomainName: "lan",
ResolveClients: true,
UsePrivateRDNS: true,
Expand Down Expand Up @@ -281,8 +281,8 @@ func parseConfig() error {
config.DNS.FiltersUpdateIntervalHours = 24
}

if config.DNS.UpstreamTimeout == 0 {
config.DNS.UpstreamTimeout = uint32(dnsforward.DefaultTimeout.Seconds())
if config.DNS.UpstreamTimeout.Duration == 0 {
config.DNS.UpstreamTimeout = Duration{dnsforward.DefaultTimeout}
}

return nil
Expand Down Expand Up @@ -348,7 +348,7 @@ func (c *configuration) write() error {
dns.LocalPTRResolvers,
dns.ResolveClients,
dns.UsePrivateRDNS = s.RDNSSettings()
dns.UpstreamTimeout = uint32(s.UpstreamTimeout().Seconds())
dns.UpstreamTimeout = Duration{s.UpstreamTimeout()}
}

if Context.dhcpServer != nil {
Expand Down
3 changes: 1 addition & 2 deletions internal/home/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net/url"
"os"
"path/filepath"
"time"

"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
Expand Down Expand Up @@ -203,7 +202,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
newConf.ResolveClients = dnsConf.ResolveClients
newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS
newConf.LocalPTRResolvers = dnsConf.LocalPTRResolvers
newConf.UpstreamTimeout = time.Duration(dnsConf.UpstreamTimeout) * time.Second
newConf.UpstreamTimeout = dnsConf.UpstreamTimeout.Duration

return newConf, nil
}
Expand Down
29 changes: 29 additions & 0 deletions internal/home/duration.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package home

import (
"time"

"github.com/AdguardTeam/golibs/errors"
)

// Duration is a wrapper of time.Duration extending it's functionality.
type Duration struct {
// time.Duration is embedded here to avoid implementing all the methods.
time.Duration
}

// MarshalText implements the encoding.TextMarshaler interface for Duration.
func (d Duration) MarshalText() (text []byte, err error) {
return []byte(d.String()), nil
}

const unmarshalAnnotation = "unmarshalling duration:"

// UnmarshalText implements the encoding.TextUnmarshaler interface for Duration.
func (d *Duration) UnmarshalText(b []byte) (err error) {
defer func() { err = errors.Annotate(err, unmarshalAnnotation+" %w") }()

d.Duration, err = time.ParseDuration(string(b))

return err
}
164 changes: 164 additions & 0 deletions internal/home/duration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package home

import (
"encoding"
"encoding/json"
"encoding/xml"
"fmt"
"io"
"strings"
"testing"
"time"

"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v2"
)

const (
NotTextMarshalerErr errors.Error = "not a text marshaler"
NotTextUnmarshalerErr errors.Error = "not a text unmarshaler"
)

type directTextEncoder struct {
w io.Writer
r io.Reader
}

func (e *directTextEncoder) Encode(v interface{}) (err error) {
val, ok := v.(encoding.TextMarshaler)
if !ok {
return NotTextMarshalerErr
}

var data []byte
data, err = val.MarshalText()
if err != nil {
return err
}

_, err = e.w.Write(data)
if err != nil {
return err
}

return nil
}

func (e *directTextEncoder) Decode(v interface{}) (err error) {
val, ok := v.(encoding.TextUnmarshaler)
if !ok {
return NotTextUnmarshalerErr
}

var data []byte
data, err = io.ReadAll(e.r)
if err != nil {
return err
}

err = val.UnmarshalText(data)
if err != nil {
return err
}

return nil
}

const (
val = 1 * time.Millisecond
valStr = "1ms"
)

func TestDuration_MarshalText(t *testing.T) {
d := Duration{val}
b := &strings.Builder{}

testCases := []struct {
enc interface {
Encode(v interface{}) (err error)
}
name string
fmtStr string
}{{
enc: yaml.NewEncoder(b),
name: "yaml",
fmtStr: "%s\n",
}, {
enc: json.NewEncoder(b),
name: "json",
fmtStr: "%q\n",
}, {
enc: xml.NewEncoder(b),
name: "xml",
fmtStr: "<Duration>%s</Duration>",
}, {
enc: &directTextEncoder{
w: b,
},
name: "direct",
fmtStr: "%s",
}}

for _, tc := range testCases {
b.Reset()
t.Run(tc.name, func(t *testing.T) {
err := tc.enc.Encode(d)
require.NoError(t, err)

assert.Equal(t, fmt.Sprintf(tc.fmtStr, val), b.String())
})
}
}

func TestDuration_UnmarshalText(t *testing.T) {
d := Duration{}

testCases := []struct {
dec interface {
Decode(v interface{}) (err error)
}
name string
}{{
dec: yaml.NewDecoder(
strings.NewReader(valStr),
),
name: "yaml",
}, {
dec: json.NewDecoder(
strings.NewReader(`"` + valStr + `"`),
),
name: "json",
}, {
dec: xml.NewDecoder(
strings.NewReader("<Duration>" + valStr + "</Duration>"),
),
name: "xml",
}, {
dec: &directTextEncoder{
r: strings.NewReader(valStr),
},
name: "direct",
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.dec.Decode(&d)
require.NoError(t, err)

assert.Equal(t, val, d.Duration)
})
}

t.Run("bad_data", func(t *testing.T) {
dec := &directTextEncoder{
r: strings.NewReader("abc"),
}

err := dec.Decode(&d)
require.Error(t, err)

assert.Contains(t, err.Error(), unmarshalAnnotation)
})
}

0 comments on commit 64b00fd

Please sign in to comment.