From ae69e4958dcabfbd4e87d197e3403a49452fc68a Mon Sep 17 00:00:00 2001 From: James Rouzier Date: Fri, 9 Sep 2022 20:36:24 +0000 Subject: [PATCH] Allow filtering via the network_source --- go/plugin/coredns/forward/forward.go | 8 +++++++- go/plugin/coredns/forward/setup.go | 11 +++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/go/plugin/coredns/forward/forward.go b/go/plugin/coredns/forward/forward.go index 90ae1aef885e..fcbf8e531260 100644 --- a/go/plugin/coredns/forward/forward.go +++ b/go/plugin/coredns/forward/forward.go @@ -8,6 +8,7 @@ import ( "context" "crypto/tls" "errors" + "net" "sync/atomic" "time" @@ -42,6 +43,7 @@ type Forward struct { maxfails uint32 expire time.Duration maxConcurrent int64 + sourceNetwork net.IPNet opts options // also here for testing @@ -190,13 +192,17 @@ func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg } func (f *Forward) match(state request.Request) bool { - if !plugin.Name(f.from).Matches(state.Name()) || !f.isAllowedDomain(state.Name()) { + if !plugin.Name(f.from).Matches(state.Name()) || !f.isAllowedDomain(state.Name()) || !f.isIpAllowed(net.ParseIP(state.IP())) { return false } return true } +func (f *Forward) isIpAllowed(ip net.IP) bool { + return len(f.sourceNetwork.IP) != len(ip) || f.sourceNetwork.Contains(ip) +} + func (f *Forward) isAllowedDomain(name string) bool { if dns.Name(name) == dns.Name(f.from) { return true diff --git a/go/plugin/coredns/forward/setup.go b/go/plugin/coredns/forward/setup.go index dfae70d37806..b06f55c87a9f 100644 --- a/go/plugin/coredns/forward/setup.go +++ b/go/plugin/coredns/forward/setup.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "errors" "fmt" + "net" "strconv" "time" @@ -281,6 +282,16 @@ func parseBlock(c *caddy.Controller, f *Forward) error { } f.ErrLimitExceeded = errors.New("concurrent queries exceeded maximum " + c.Val()) f.maxConcurrent = int64(n) + case "network_source": + if !c.NextArg() { + return c.ArgErr() + } + _, ipNet, err := net.ParseCIDR(c.Val()) + if err != nil { + return c.Err("Unable to parse network_source configuration parameter") + } + + f.sourceNetwork = *ipNet default: return c.Errf("unknown property '%s'", c.Val())