diff --git a/cmd/options.go b/cmd/options.go index 78480102526..17bf1b477ca 100644 --- a/cmd/options.go +++ b/cmd/options.go @@ -72,6 +72,8 @@ func optionFlagSet() *pflag.FlagSet { flags.Duration("min-iteration-duration", 0, "minimum amount of time k6 will take executing a single iteration") flags.BoolP("throw", "w", false, "throw warnings (like failed http requests) as errors") flags.StringSlice("blacklist-ip", nil, "blacklist an `ip range` from being called") + flags.StringSlice("block-hostname", nil, "block a case-insensitive hostname `pattern`,"+ + " with optional leading wildcard, from being called") // The comment about system-tags also applies for summary-trend-stats. The default values // are set in applyDefault(). @@ -187,6 +189,17 @@ func getOptions(flags *pflag.FlagSet) (lib.Options, error) { opts.BlacklistIPs = append(opts.BlacklistIPs, net) } + blockedHostnameStrings, err := flags.GetStringSlice("block-hostname") + if err != nil { + return opts, err + } + if flags.Changed("block-hostname") { + opts.BlockedHostnames, err = types.NewNullHostnameTrie(blockedHostnameStrings) + if err != nil { + return opts, err + } + } + if flags.Changed("summary-trend-stats") { trendStats, errSts := flags.GetStringSlice("summary-trend-stats") if errSts != nil { diff --git a/js/runner.go b/js/runner.go index 629c77d3726..eb584e64bed 100644 --- a/js/runner.go +++ b/js/runner.go @@ -159,10 +159,11 @@ func (r *Runner) newVU(id int64, samplesOut chan<- stats.SampleContainer) (*VU, } dialer := &netext.Dialer{ - Dialer: r.BaseDialer, - Resolver: r.Resolver, - Blacklist: r.Bundle.Options.BlacklistIPs, - Hosts: r.Bundle.Options.Hosts, + Dialer: r.BaseDialer, + Resolver: r.Resolver, + Blacklist: r.Bundle.Options.BlacklistIPs, + BlockedHostnames: r.Bundle.Options.BlockedHostnames.Trie, + Hosts: r.Bundle.Options.Hosts, } tlsConfig := &tls.Config{ InsecureSkipVerify: r.Bundle.Options.InsecureSkipTLSVerify.Bool, diff --git a/js/runner_test.go b/js/runner_test.go index 1937e3989fe..00a46631fc4 100644 --- a/js/runner_test.go +++ b/js/runner_test.go @@ -885,6 +885,77 @@ func TestVUIntegrationBlacklistScript(t *testing.T) { } } +func TestVUIntegrationBlockHostnamesOption(t *testing.T) { + r1, err := getSimpleRunner(t, "/script.js", ` + var http = require("k6/http"); + exports.default = function() { http.get("https://k6.io/"); } + `) + require.NoError(t, err) + + hostnames, err := types.NewNullHostnameTrie([]string{"*.io"}) + require.NoError(t, err) + require.NoError(t, r1.SetOptions(lib.Options{ + Throw: null.BoolFrom(true), + BlockedHostnames: hostnames, + })) + + r2, err := NewFromArchive(testutils.NewLogger(t), r1.MakeArchive(), lib.RuntimeOptions{}) + if !assert.NoError(t, err) { + return + } + + runners := map[string]*Runner{"Source": r1, "Archive": r2} + + for name, r := range runners { + r := r + t.Run(name, func(t *testing.T) { + initVu, err := r.NewVU(1, make(chan stats.SampleContainer, 100)) + require.NoError(t, err) + vu := initVu.Activate(&lib.VUActivationParams{RunContext: context.Background()}) + err = vu.RunOnce() + require.Error(t, err) + assert.Contains(t, err.Error(), "hostname (k6.io) is in a blocked pattern (*.io)") + }) + } +} + +func TestVUIntegrationBlockHostnamesScript(t *testing.T) { + r1, err := getSimpleRunner(t, "/script.js", ` + var http = require("k6/http"); + + exports.options = { + throw: true, + blockHostnames: ["*.io"], + }; + + exports.default = function() { http.get("https://k6.io/"); } + `) + if !assert.NoError(t, err) { + return + } + + r2, err := NewFromArchive(testutils.NewLogger(t), r1.MakeArchive(), lib.RuntimeOptions{}) + if !assert.NoError(t, err) { + return + } + + runners := map[string]*Runner{"Source": r1, "Archive": r2} + + for name, r := range runners { + r := r + t.Run(name, func(t *testing.T) { + initVu, err := r.NewVU(0, make(chan stats.SampleContainer, 100)) + if !assert.NoError(t, err) { + return + } + vu := initVu.Activate(&lib.VUActivationParams{RunContext: context.Background()}) + err = vu.RunOnce() + require.Error(t, err) + assert.Contains(t, err.Error(), "hostname (k6.io) is in a blocked pattern (*.io)") + }) + } +} + func TestVUIntegrationHosts(t *testing.T) { tb := httpmultibin.NewHTTPMultiBin(t) defer tb.Cleanup() diff --git a/lib/netext/dialer.go b/lib/netext/dialer.go index bec251fa7f7..050c4660067 100644 --- a/lib/netext/dialer.go +++ b/lib/netext/dialer.go @@ -33,6 +33,7 @@ import ( "github.com/loadimpact/k6/lib" "github.com/loadimpact/k6/lib/metrics" + "github.com/loadimpact/k6/lib/types" "github.com/loadimpact/k6/stats" ) @@ -47,9 +48,10 @@ type dnsResolver interface { type Dialer struct { net.Dialer - Resolver dnsResolver - Blacklist []*lib.IPNet - Hosts map[string]*lib.HostAddress + Resolver dnsResolver + Blacklist []*lib.IPNet + BlockedHostnames *types.HostnameTrie + Hosts map[string]*lib.HostAddress BytesRead int64 BytesWritten int64 @@ -77,6 +79,16 @@ func (b BlackListedIPError) Error() string { return fmt.Sprintf("IP (%s) is in a blacklisted range (%s)", b.ip, b.net) } +// BlockedHostError is returned when a given hostname is blocked +type BlockedHostError struct { + hostname string + match string +} + +func (b BlockedHostError) Error() string { + return fmt.Sprintf("hostname (%s) is in a blocked pattern (%s)", b.hostname, b.match) +} + // DialContext wraps the net.Dialer.DialContext and handles the k6 specifics func (d *Dialer) DialContext(ctx context.Context, proto, addr string) (net.Conn, error) { dialAddr, err := d.getDialAddr(addr) @@ -163,12 +175,18 @@ func (d *Dialer) findRemote(addr string) (*lib.HostAddress, error) { return nil, err } + ip := net.ParseIP(host) + if d.BlockedHostnames != nil && ip == nil { + if match, blocked := d.BlockedHostnames.Contains(host); blocked { + return nil, BlockedHostError{hostname: host, match: match} + } + } + remote, err := d.getConfiguredHost(addr, host, port) if err != nil || remote != nil { return remote, err } - ip := net.ParseIP(host) if ip != nil { return lib.NewHostAddress(ip, port) } diff --git a/lib/netext/dialer_test.go b/lib/netext/dialer_test.go index 15a1d67c9c6..71db79cc07e 100644 --- a/lib/netext/dialer_test.go +++ b/lib/netext/dialer_test.go @@ -25,6 +25,7 @@ import ( "testing" "github.com/loadimpact/k6/lib" + "github.com/loadimpact/k6/lib/types" "github.com/stretchr/testify/require" ) @@ -74,6 +75,42 @@ func TestDialerAddr(t *testing.T) { {"[::1.2.3.4]", "", "address [::1.2.3.4]: missing port in address"}, {"example-ipv6-deny-resolver.com:80", "", "IP (::1) is in a blacklisted range (::/24)"}, {"example-ipv6-deny-host.com:80", "", "IP (::1) is in a blacklisted range (::/24)"}, + {"example-ipv6-deny-host.com:80", "", "IP (::1) is in a blacklisted range (::/24)"}, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.address, func(t *testing.T) { + addr, err := dialer.getDialAddr(tc.address) + + if tc.expErr != "" { + require.EqualError(t, err, tc.expErr) + } else { + require.NoError(t, err) + require.Equal(t, tc.expAddress, addr) + } + }) + } +} + +func TestDialerAddrBlockHostnamesStar(t *testing.T) { + dialer := newDialerWithResolver(net.Dialer{}, newResolver()) + dialer.Hosts = map[string]*lib.HostAddress{ + "example.com": {IP: net.ParseIP("3.4.5.6")}, + } + + blocked, err := types.NewHostnameTrie([]string{"*"}) + require.NoError(t, err) + dialer.BlockedHostnames = blocked + testCases := []struct { + address, expAddress, expErr string + }{ + // IPv4 + {"example.com:80", "", "hostname (example.com) is in a blocked pattern (*)"}, + {"example.com:443", "", "hostname (example.com) is in a blocked pattern (*)"}, + {"not.com:30", "", "hostname (not.com) is in a blocked pattern (*)"}, + {"1.2.3.4:80", "1.2.3.4:80", ""}, } for _, tc := range testCases { diff --git a/lib/netext/httpext/error_codes.go b/lib/netext/httpext/error_codes.go index 1d771629d34..6e7f5de53f9 100644 --- a/lib/netext/httpext/error_codes.go +++ b/lib/netext/httpext/error_codes.go @@ -47,9 +47,10 @@ const ( defaultErrorCode errCode = 1000 defaultNetNonTCPErrorCode errCode = 1010 // DNS errors - defaultDNSErrorCode errCode = 1100 - dnsNoSuchHostErrorCode errCode = 1101 - blackListedIPErrorCode errCode = 1110 + defaultDNSErrorCode errCode = 1100 + dnsNoSuchHostErrorCode errCode = 1101 + blackListedIPErrorCode errCode = 1110 + blockedHostnameErrorCode errCode = 1111 // tcp errors defaultTCPErrorCode errCode = 1200 tcpBrokenPipeErrorCode errCode = 1201 @@ -91,6 +92,7 @@ const ( netUnknownErrnoErrorCodeMsg = "%s: unknown errno `%d` on %s with message `%s`" dnsNoSuchHostErrorCodeMsg = "lookup: no such host" blackListedIPErrorCodeMsg = "ip is blacklisted" + blockedHostnameErrorMsg = "hostname is blocked" http2GoAwayErrorCodeMsg = "http2: received GoAway with http2 ErrCode %s" http2StreamErrorCodeMsg = "http2: stream error with http2 ErrCode %s" http2ConnectionErrorCodeMsg = "http2: connection error with http2 ErrCode %s" @@ -119,6 +121,8 @@ func errorCodeForError(err error) (errCode, string) { } case netext.BlackListedIPError: return blackListedIPErrorCode, blackListedIPErrorCodeMsg + case netext.BlockedHostError: + return blockedHostnameErrorCode, blockedHostnameErrorMsg case *http2.GoAwayError: return unknownHTTP2GoAwayErrorCode + http2ErrCodeOffset(e.ErrCode), fmt.Sprintf(http2GoAwayErrorCodeMsg, e.ErrCode) diff --git a/lib/options.go b/lib/options.go index cde380d656e..20555d4fd41 100644 --- a/lib/options.go +++ b/lib/options.go @@ -339,6 +339,9 @@ type Options struct { // Blacklist IP ranges that tests may not contact. Mainly useful in hosted setups. BlacklistIPs []*IPNet `json:"blacklistIPs" envconfig:"K6_BLACKLIST_IPS"` + // Block hostname patterns that tests may not contact. + BlockedHostnames types.NullHostnameTrie `json:"blockHostnames" envconfig:"K6_BLOCK_HOSTNAMES"` + // Hosts overrides dns entries for given hosts Hosts map[string]*HostAddress `json:"hosts" envconfig:"K6_HOSTS"` @@ -494,6 +497,9 @@ func (o Options) Apply(opts Options) Options { if opts.BlacklistIPs != nil { o.BlacklistIPs = opts.BlacklistIPs } + if opts.BlockedHostnames.Valid { + o.BlockedHostnames = opts.BlockedHostnames + } if opts.Hosts != nil { o.Hosts = opts.Hosts } diff --git a/lib/options_test.go b/lib/options_test.go index 0125e4837ce..a7eb55b4147 100644 --- a/lib/options_test.go +++ b/lib/options_test.go @@ -314,6 +314,13 @@ func TestOptions(t *testing.T) { assert.Equal(t, net.IPv4zero, opts.BlacklistIPs[0].IP) assert.Equal(t, net.CIDRMask(1, 1), opts.BlacklistIPs[0].Mask) }) + t.Run("BlockedHostnames", func(t *testing.T) { + blockedHostnames, err := types.NewNullHostnameTrie([]string{"test.k6.io", "*valid.pattern"}) + require.NoError(t, err) + opts := Options{}.Apply(Options{BlockedHostnames: blockedHostnames}) + assert.NotNil(t, opts.BlockedHostnames) + assert.Equal(t, blockedHostnames, opts.BlockedHostnames) + }) t.Run("Hosts", func(t *testing.T) { host, err := NewHostAddress(net.ParseIP("192.0.2.1"), "80") diff --git a/lib/types/hostnametrie.go b/lib/types/hostnametrie.go new file mode 100644 index 00000000000..eb3ab7d500d --- /dev/null +++ b/lib/types/hostnametrie.go @@ -0,0 +1,182 @@ +/* + * + * k6 - a next-generation load testing tool + * Copyright (C) 2020 Load Impact + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * + */ + +package types + +import ( + "bytes" + "encoding/json" + "regexp" + "strings" + + "github.com/pkg/errors" +) + +// NullHostnameTrie is a nullable HostnameTrie, in the same vein as the nullable types provided by +// package gopkg.in/guregu/null.v3 +type NullHostnameTrie struct { + Trie *HostnameTrie + Valid bool +} + +// UnmarshalText converts text data to a valid NullHostnameTrie +func (d *NullHostnameTrie) UnmarshalText(data []byte) error { + if len(data) == 0 { + *d = NullHostnameTrie{} + return nil + } + var err error + d.Trie, err = NewHostnameTrie(strings.Split(string(data), ",")) + if err != nil { + return err + } + d.Valid = true + return nil +} + +// UnmarshalJSON converts JSON data to a valid NullHostnameTrie +func (d *NullHostnameTrie) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, []byte(`null`)) { + d.Valid = false + return nil + } + + var m []string + var err error + if err = json.Unmarshal(data, &m); err != nil { + return err + } + d.Trie, err = NewHostnameTrie(m) + if err != nil { + return err + } + d.Valid = true + return nil +} + +// MarshalJSON implements json.Marshaler interface +func (d NullHostnameTrie) MarshalJSON() ([]byte, error) { + if !d.Valid { + return []byte(`null`), nil + } + return json.Marshal(d.Trie.source) +} + +// HostnameTrie is a tree-structured list of hostname matches with support +// for wildcards exclusively at the start of the pattern. Items may only +// be inserted and searched. Internationalized hostnames are valid. +type HostnameTrie struct { + source []string + + children map[rune]*HostnameTrie +} + +// NewNullHostnameTrie returns a NullHostnameTrie encapsulating HostnameTrie or an error if the +// input is incorrect +func NewNullHostnameTrie(source []string) (NullHostnameTrie, error) { + h, err := NewHostnameTrie(source) + if err != nil { + return NullHostnameTrie{}, err + } + return NullHostnameTrie{ + Valid: true, + Trie: h, + }, nil +} + +// NewHostnameTrie returns a pointer to a new HostnameTrie or an error if the input is incorrect +func NewHostnameTrie(source []string) (*HostnameTrie, error) { + h := &HostnameTrie{ + source: source, + } + for _, s := range h.source { + if err := h.insert(s); err != nil { + return nil, err + } + } + return h, nil +} + +// Regex description of hostname pattern to enforce blocks by. Global var +// to avoid compilation penalty at runtime. +// based on regex from https://stackoverflow.com/a/106223/5427244 +//nolint:gochecknoglobals,lll +var validHostnamePattern *regexp.Regexp = regexp.MustCompile(`^(\*\.?)?((([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9]))?$`) + +func isValidHostnamePattern(s string) error { + if len(validHostnamePattern.FindString(s)) != len(s) { + return errors.Errorf("invalid hostname pattern %s", s) + } + return nil +} + +// insert a hostname pattern into the given HostnameTrie. Returns an error +// if hostname pattern is invalid. +func (t *HostnameTrie) insert(s string) error { + s = strings.ToLower(s) + if err := isValidHostnamePattern(s); err != nil { + return err + } + + return t.childInsert(s) +} + +func (t *HostnameTrie) childInsert(s string) error { + if len(s) == 0 { + return nil + } + + // mask creation of the trie by initializing the root here + if t.children == nil { + t.children = make(map[rune]*HostnameTrie) + } + + rStr := []rune(s) // need to iterate by runes for intl' names + last := len(rStr) - 1 + if c, ok := t.children[rStr[last]]; ok { + return c.childInsert(string(rStr[:last])) + } + + t.children[rStr[last]] = &HostnameTrie{children: make(map[rune]*HostnameTrie)} + return t.children[rStr[last]].childInsert(string(rStr[:last])) +} + +// Contains returns whether s matches a pattern in the HostnameTrie +// along with the matching pattern, if one was found. +func (t *HostnameTrie) Contains(s string) (matchedPattern string, matchFound bool) { + s = strings.ToLower(s) + if len(s) == 0 { + return s, len(t.children) == 0 + } + + rStr := []rune(s) + last := len(rStr) - 1 + if c, ok := t.children[rStr[last]]; ok { + if match, matched := c.Contains(string(rStr[:last])); matched { + return match + string(rStr[last]), true + } + } + + if _, wild := t.children['*']; wild { + return "*", true + } + + return "", false +} diff --git a/lib/types/hostnametrie_test.go b/lib/types/hostnametrie_test.go new file mode 100644 index 00000000000..50fda4d48fd --- /dev/null +++ b/lib/types/hostnametrie_test.go @@ -0,0 +1,60 @@ +/* + * + * k6 - a next-generation load testing tool + * Copyright (C) 2020 Load Impact + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * + */ + +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHostnameTrieInsert(t *testing.T) { + hostnames := HostnameTrie{} + assert.NoError(t, hostnames.insert("test.k6.io")) + assert.Error(t, hostnames.insert("inval*d.pattern")) + assert.NoError(t, hostnames.insert("*valid.pattern")) +} + +func TestHostnameTrieContains(t *testing.T) { + trie, err := NewHostnameTrie([]string{"test.k6.io", "*valid.pattern"}) + require.NoError(t, err) + cases := map[string]string{ + "K6.Io": "", + "tEsT.k6.Io": "test.k6.io", + "TESt.K6.IO": "test.k6.io", + "blocked.valId.paTtern": "*valid.pattern", + "example.test.k6.io": "", + } + for key, value := range cases { + host, pattern := key, value + t.Run(host, func(t *testing.T) { + match, matches := trie.Contains(host) + if pattern == "" { + assert.False(t, matches) + assert.Empty(t, match) + } else { + assert.True(t, matches) + assert.Equal(t, pattern, match) + } + }) + } +}