Skip to content

Commit

Permalink
Merge pull request #87 from matsuyoshi30/feature/add-dns-resolver
Browse files Browse the repository at this point in the history
Add support for resolvers option
  • Loading branch information
nakabonne committed Dec 5, 2020
2 parents 04b67c9 + 8b2ea13 commit 057c16c
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 0 deletions.
4 changes: 4 additions & 0 deletions attacker/attacker.go
Expand Up @@ -44,6 +44,7 @@ type Options struct {
HTTP2 bool
LocalAddr net.IPAddr
Buckets []time.Duration
Resolvers []string

Attacker Attacker
}
Expand Down Expand Up @@ -83,6 +84,9 @@ func Attack(ctx context.Context, target string, resCh chan<- *Result, metricsCh
if opts.LocalAddr.IP == nil {
opts.LocalAddr = DefaultLocalAddr
}
if len(opts.Resolvers) > 0 {
net.DefaultResolver = NewResolver(opts.Resolvers)
}
if opts.Attacker == nil {
opts.Attacker = vegeta.NewAttacker(
vegeta.Timeout(opts.Timeout),
Expand Down
31 changes: 31 additions & 0 deletions attacker/resolver.go
@@ -0,0 +1,31 @@
package attacker

import (
"context"
"net"
"sync/atomic"
"time"
)

type resolver struct {
addrs []string
idx uint64
}

func NewResolver(addrs []string) *net.Resolver {
r := &resolver{addrs: addrs}

return &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: time.Millisecond * time.Duration(2000),
}
return d.DialContext(ctx, network, r.address())
},
}
}

func (r *resolver) address() string {
return r.addrs[atomic.AddUint64(&r.idx, 1)%uint64(len(r.addrs))]
}
92 changes: 92 additions & 0 deletions attacker/resolver_test.go
@@ -0,0 +1,92 @@
package attacker

import (
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

"github.com/miekg/dns"
)

const (
testDomain = "test.notadomain"
DNSServerAddress = "127.0.0.1"
message = "Test Server"
)

func TestNewResolver(t *testing.T) {
done := make(chan struct{}) // for ensuring ds.PacketConn is not nil

// prepare custom DNS server
dns.HandleFunc(".", handleRequest)
ds := dns.Server{
Addr: DNSServerAddress + ":0",
Net: "udp",
ReadTimeout: time.Millisecond * time.Duration(2000),
WriteTimeout: time.Millisecond * time.Duration(2000),
NotifyStartedFunc: func() { close(done) },
}

go func() {
if err := ds.ListenAndServe(); err != nil {
t.Logf("got error during dns ListenAndServe: %s", err)
}
}()
defer func() {
_ = ds.Shutdown()
}()

<-done

net.DefaultResolver = NewResolver([]string{ds.PacketConn.LocalAddr().String()})

// test server for name resolution
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, message)
}))
defer ts.Close()

tsURL, _ := url.Parse(ts.URL)
_, port, _ := net.SplitHostPort(tsURL.Host)
tsURL.Host = net.JoinHostPort(testDomain, port)

resp, err := http.Get(tsURL.String())
if err != nil {
t.Fatalf("failed resolver round trip: %v", err)
}
defer resp.Body.Close()

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read respose body: %v", err)
}

if strings.TrimSpace(string(body)) != message {
t.Errorf("reponse body mismatch, expected: '%s', but got '%s'", message, body)
}
}

func handleRequest(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)

m.Answer = []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: r.Question[0].Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 1,
},
A: net.ParseIP(DNSServerAddress),
},
}

w.WriteMsg(m)
}
1 change: 1 addition & 0 deletions go.mod
Expand Up @@ -7,6 +7,7 @@ require (
github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 // indirect
github.com/k0kubun/pp v3.0.1+incompatible
github.com/mattn/go-colorable v0.1.7 // indirect
github.com/miekg/dns v1.1.17
github.com/mum4k/termdash v0.12.2
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.6.1
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Expand Up @@ -60,6 +60,7 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky
github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/miekg/dns v1.1.17 h1:BhJxdA7bH51vKFZSY8Sn9pR7++LREvg0eYFzHA452ew=
github.com/miekg/dns v1.1.17/go.mod h1:WgzbA6oji13JREwiNsRDNfl7jYdPnmz+VEuLrA+/48M=
github.com/mum4k/termdash v0.12.2 h1:S2frz71OrXUKIVVZ3snYBEzyYlUNRTu0ElV6d5Pf6gI=
github.com/mum4k/termdash v0.12.2/go.mod h1:haerPCSO0U8pehROAecmuOHDF+2UXw2KaCTxdWooDFE=
Expand All @@ -84,6 +85,7 @@ go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/goleak v1.1.10 h1:z+mqJhf6ss6BSfSM671tgKyZBFPTTJM+HLxnhPC3wu0=
go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 h1:Gv7RPwsi3eZ2Fgewe3CBsuOebPwO27PoXzRpJPsvSSM=
golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
Expand Down
47 changes: 47 additions & 0 deletions main.go
Expand Up @@ -11,6 +11,7 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -46,6 +47,7 @@ type cli struct {
localAddress string
noKeepAlive bool
buckets string
resolvers string

debug bool
version bool
Expand Down Expand Up @@ -84,6 +86,7 @@ func parseFlags(stdout, stderr io.Writer) (*cli, error) {
flagSet.StringVar(&c.localAddress, "local-addr", "0.0.0.0", "Local IP address.")
// TODO: Re-enable when making it capable of drawing histogram bar chart.
//flagSet.StringVar(&c.buckets, "buckets", "", "Histogram buckets; comma-separated list.")
flagSet.StringVar(&c.resolvers, "resolvers", "", "Custom DNS resolver addresses; comma-separated list.")
flagSet.Usage = c.usage
if err := flagSet.Parse(os.Args[1:]); err != nil {
if !errors.Is(err, flag.ErrHelp) {
Expand Down Expand Up @@ -186,6 +189,11 @@ func (c *cli) makeOptions() (*attacker.Options, error) {
return nil, fmt.Errorf("wrong buckets format %w", err)
}

parsedResolvers, err := parseResolvers(c.resolvers)
if err != nil {
return nil, err
}

return &attacker.Options{
Rate: c.rate,
Duration: c.duration,
Expand All @@ -201,6 +209,7 @@ func (c *cli) makeOptions() (*attacker.Options, error) {
HTTP2: !c.noHTTP2,
LocalAddr: localAddr,
Buckets: parsedBuckets,
Resolvers: parsedResolvers,
}, nil
}

Expand Down Expand Up @@ -232,6 +241,44 @@ func parseBucketOptions(rawBuckets string) ([]time.Duration, error) {
return result, nil
}

func parseResolvers(addrs string) ([]string, error) {
if addrs == "" {
return nil, nil
}

stringAddrs := strings.Split(addrs, ",")
result := make([]string, 0, len(stringAddrs))

for _, addr := range stringAddrs {
trimmedAddr := strings.TrimSpace(addr)

// if given address has no port, append "53"
if !strings.Contains(trimmedAddr, ":") {
trimmedAddr += ":53"
}

host, port, err := net.SplitHostPort(trimmedAddr)
if err != nil {
return nil, err
}

// validate port
_, err = strconv.ParseUint(port, 10, 16)
if err != nil {
return nil, fmt.Errorf("port of given address %q has a wrong format", addr)
}

// validate IP
if ip := net.ParseIP(host); ip == nil {
return nil, fmt.Errorf("given address %q has a wrong format", addr)
}

result = append(result, trimmedAddr)
}

return result, nil
}

// Makes a new file under the working directory only when debug use.
func setDebug(w io.Writer, debug bool) {
if !debug {
Expand Down
51 changes: 51 additions & 0 deletions main_test.go
Expand Up @@ -60,6 +60,7 @@ func TestParseFlags(t *testing.T) {
stderr: new(bytes.Buffer),
noHTTP2: false,
localAddress: "0.0.0.0",
resolvers: "",
},
wantErr: false,
},
Expand Down Expand Up @@ -301,6 +302,56 @@ func TestMakeOptions(t *testing.T) {
},
wantErr: false,
},
{
name: "use custom DNS resolvers",
cli: &cli{
method: "GET",
resolvers: "1.2.3.4,192.168.11.1:53",
},
want: &attacker.Options{
Rate: 0,
Duration: 0,
Timeout: 0,
Method: "GET",
Body: []byte{},
Header: http.Header{},
Workers: 0,
MaxWorkers: 0,
MaxBody: 0,
HTTP2: true,
KeepAlive: true,
Buckets: []time.Duration{},
Resolvers: []string{"1.2.3.4:53", "192.168.11.1:53"},
},
wantErr: false,
},
{
name: "wrong format",
cli: &cli{
method: "GET",
resolvers: "1.2.3.4:1:1",
},
want: nil,
wantErr: true,
},
{
name: "wrong IP address",
cli: &cli{
method: "GET",
resolvers: "1111.2.3.4",
},
want: nil,
wantErr: true,
},
{
name: "wrong port number",
cli: &cli{
method: "GET",
resolvers: "192.168.11.1:65536",
},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down

0 comments on commit 057c16c

Please sign in to comment.