From 7dd6eb369daea79b4bdf6b39407c16f1c5e92c45 Mon Sep 17 00:00:00 2001 From: Olivier Poitrey Date: Sat, 28 Dec 2019 17:03:32 -0800 Subject: [PATCH] Use /etc/hosts file to resolve listen address and list on all IPs listed Fixes #22 --- activate.go | 6 +- host/{dns_freebsd.go => dns_bsd.go} | 2 + host/dns_other.go | 2 +- host/dns_resolvconf.go | 2 +- host/{log_freebsd.go => log_bsd.go} | 2 + host/log_other.go | 2 +- host/log_syslog.go | 7 +- host/log_windows.go | 25 ++-- host/service/bsd/service.go | 4 +- host/service/run.go | 4 +- host/service/run_unix.go | 2 +- host/service/run_windows.go | 10 +- host/service/service_windows.go | 4 +- host/service/systemd/service.go | 2 +- host/service/windows/service.go | 71 +++++------ host/service_bsd.go | 2 +- host/service_windows.go | 2 +- hosts/hosts.go | 127 ++++++++++++++++++++ hosts/hosts_test.go | 175 ++++++++++++++++++++++++++++ hosts/parse.go | 153 ++++++++++++++++++++++++ hosts/testdata/case-hosts | 2 + hosts/testdata/hosts | 11 ++ hosts/testdata/ipv4-hosts | 12 ++ hosts/testdata/ipv6-hosts | 11 ++ hosts/testdata/singleline-hosts | 1 + proxy/proxy.go | 101 ++++++++++------ run.go | 16 +-- 27 files changed, 635 insertions(+), 123 deletions(-) rename host/{dns_freebsd.go => dns_bsd.go} (98%) rename host/{log_freebsd.go => log_bsd.go} (84%) create mode 100644 hosts/hosts.go create mode 100644 hosts/hosts_test.go create mode 100644 hosts/parse.go create mode 100644 hosts/testdata/case-hosts create mode 100644 hosts/testdata/hosts create mode 100644 hosts/testdata/ipv4-hosts create mode 100644 hosts/testdata/ipv6-hosts create mode 100644 hosts/testdata/singleline-hosts diff --git a/activate.go b/activate.go index ae1f0367..e9a082b8 100644 --- a/activate.go +++ b/activate.go @@ -6,6 +6,7 @@ import ( "github.com/nextdns/nextdns/config" "github.com/nextdns/nextdns/host" + "github.com/nextdns/nextdns/hosts" ) func activation(args []string) error { @@ -42,10 +43,7 @@ func listenIP(listen string) (string, error) { case "::": return "::1", nil } - addrs, err := net.LookupHost(host) - if err != nil { - return "", fmt.Errorf("activate: %s: %v", listen, err) - } + addrs := hosts.LookupHost(host) if len(addrs) == 0 { return "", fmt.Errorf("activate: %s: no address found", listen) } diff --git a/host/dns_freebsd.go b/host/dns_bsd.go similarity index 98% rename from host/dns_freebsd.go rename to host/dns_bsd.go index 0dd52f9d..a4fe1248 100644 --- a/host/dns_freebsd.go +++ b/host/dns_bsd.go @@ -1,3 +1,5 @@ +// +build freebsd openbsd netbsd dragonfly + package host import ( diff --git a/host/dns_other.go b/host/dns_other.go index 01564222..67504bb5 100644 --- a/host/dns_other.go +++ b/host/dns_other.go @@ -1,4 +1,4 @@ -// +build !darwin,!linux,!freebsd +// +build !darwin,!linux,!freebsd,!openbsd,!netbsd,!dragonfly package host diff --git a/host/dns_resolvconf.go b/host/dns_resolvconf.go index 2ee9eccf..e8c3eb68 100644 --- a/host/dns_resolvconf.go +++ b/host/dns_resolvconf.go @@ -1,4 +1,4 @@ -// +build linux freebsd +// +build linux freebsd openbsd netbsd dragonfly package host diff --git a/host/log_freebsd.go b/host/log_bsd.go similarity index 84% rename from host/log_freebsd.go rename to host/log_bsd.go index 61d9cc5f..ccee4c2b 100644 --- a/host/log_freebsd.go +++ b/host/log_bsd.go @@ -1,3 +1,5 @@ +// +build freebsd openbsd netbsd dragonfly + package host import ( diff --git a/host/log_other.go b/host/log_other.go index da5cd275..03b2b278 100644 --- a/host/log_other.go +++ b/host/log_other.go @@ -1,4 +1,4 @@ -// +build !linux,!freebsd,!darwin +// +build !darwin,!linux,!freebsd,!openbsd,!netbsd,!dragonfly package host diff --git a/host/log_syslog.go b/host/log_syslog.go index 67c80178..247e28df 100644 --- a/host/log_syslog.go +++ b/host/log_syslog.go @@ -1,3 +1,5 @@ +// +build !windows + package host import ( @@ -18,11 +20,12 @@ func newSyslogLogger(name string) (Logger, error) { } func (l syslogLogger) Info(v ...interface{}) { - _ = l.syslog.Info(fmt.Sprint(v...)) + // Use notice instead of info as many systems filter < notice level + _ = l.syslog.Notice(fmt.Sprint(v...)) } func (l syslogLogger) Infof(format string, a ...interface{}) { - _ = l.syslog.Info(fmt.Sprintf(format, a...)) + _ = l.syslog.Notice(fmt.Sprintf(format, a...)) } func (l syslogLogger) Warning(v ...interface{}) { diff --git a/host/log_windows.go b/host/log_windows.go index 5c3434cb..17490256 100644 --- a/host/log_windows.go +++ b/host/log_windows.go @@ -1,6 +1,9 @@ package host import ( + "fmt" + "strings" + "golang.org/x/sys/windows/svc/debug" "golang.org/x/sys/windows/svc/eventlog" ) @@ -13,18 +16,18 @@ func newConsoleLogger(name string) Logger { return windowsLogger{log: debug.New(name)} } -func newServiceLogger(name string) (log.Logger, error) { - err = eventlog.InstallAsEventCreate(name, eventlog.Error|eventlog.Warning|eventlog.Info) +func newServiceLogger(name string) (Logger, error) { + err := eventlog.InstallAsEventCreate(name, eventlog.Error|eventlog.Warning|eventlog.Info) if err != nil { if !strings.Contains(err.Error(), "exists") { - return err + return nil, err } } - el, err := logentlog.Open(name) + el, err := eventlog.Open(name) if err != nil { return nil, err } - return windowsLogger{log: el} + return windowsLogger{log: el}, nil } func (l windowsLogger) Info(v ...interface{}) { @@ -36,21 +39,17 @@ func (l windowsLogger) Infof(format string, a ...interface{}) { } func (l windowsLogger) Warning(v ...interface{}) { - return l.log.Warning(2, fmt.Sprint(v...)) + l.log.Warning(2, fmt.Sprint(v...)) } func (l windowsLogger) Warningf(format string, a ...interface{}) { - return l.log.Warning(2, fmt.Sprintf(format, a...)) + l.log.Warning(2, fmt.Sprintf(format, a...)) } func (l windowsLogger) Error(v ...interface{}) { - return l.log.Error(3, fmt.Sprint(v...)) + l.log.Error(3, fmt.Sprint(v...)) } func (l windowsLogger) Errorf(format string, a ...interface{}) { - return l.log.Error(3, fmt.Sprintf(format, a...)) + l.log.Error(3, fmt.Sprintf(format, a...)) } - - - - diff --git a/host/service/bsd/service.go b/host/service/bsd/service.go index e5909b18..46dc793e 100644 --- a/host/service/bsd/service.go +++ b/host/service/bsd/service.go @@ -94,8 +94,8 @@ name="{{.Name}}" {{.Name}}_env="{{.RunModeEnv}}=1" pidfile="/var/run/${name}.pid" command="/usr/sbin/daemon" -daemon_args="-P ${pidfile} -r -t \"${name}: daemon\"{{if .WorkingDirectory}} -c {{.WorkingDirectory}}{{end}}" -command_args="${daemon_args} {{.Path}}{{range .Arguments}} {{.}}{{end}}" +daemon_args="-P ${pidfile} -r -t \"${name}: daemon\"" +command_args="${daemon_args} {{.Executable}}{{range .Arguments}} {{.}}{{end}}" run_rc_command "$1" ` diff --git a/host/service/run.go b/host/service/run.go index 528e770d..9379b000 100644 --- a/host/service/run.go +++ b/host/service/run.go @@ -11,11 +11,11 @@ type Runner interface { Stop() error } -func Run(r Runner) error { +func Run(name string, r Runner) error { if CurrentRunMode() == RunModeNone { return runForeground(r) } - return runService(r) + return runService(name, r) } func runForeground(r Runner) error { diff --git a/host/service/run_unix.go b/host/service/run_unix.go index d330c541..eae4f813 100644 --- a/host/service/run_unix.go +++ b/host/service/run_unix.go @@ -2,6 +2,6 @@ package service -func runService(r Runner) error { +func runService(name string, r Runner) error { return runForeground(r) } diff --git a/host/service/run_windows.go b/host/service/run_windows.go index 4d675668..f6222a83 100644 --- a/host/service/run_windows.go +++ b/host/service/run_windows.go @@ -13,7 +13,7 @@ type windowService struct { func (s windowService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) { const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown changes <- svc.Status{State: svc.StartPending} - if err := s.Start(s.log); err != nil { + if err := s.Start(); err != nil { s.lastErr = err return true, 1 } @@ -27,7 +27,7 @@ loop: changes <- c.CurrentStatus case svc.Stop, svc.Shutdown: changes <- svc.Status{State: svc.StopPending} - if err := s.Stop(s.log); err != nil { + if err := s.Stop(); err != nil { s.lastErr = err return true, 2 } @@ -38,12 +38,12 @@ loop: return false, 0 } -func runService(r Runner) error { +func runService(name string, r Runner) error { runner := svc.Run - if isDebug { + if interactive, _ := svc.IsAnInteractiveSession(); interactive { runner = debug.Run } - s := &windowService{r} + s := &windowService{Runner: r} err := runner(name, s) if s.lastErr != nil { return s.lastErr diff --git a/host/service/service_windows.go b/host/service/service_windows.go index b6f50820..a1c0dcfd 100644 --- a/host/service/service_windows.go +++ b/host/service/service_windows.go @@ -5,8 +5,8 @@ import ( ) func CurrentRunMode() RunMode { - if interactive, err = svc.IsAnInteractiveSession(); interactive || err != nil { - return RunModeNode + if interactive, err := svc.IsAnInteractiveSession(); interactive || err != nil { + return RunModeNone } return RunModeService } diff --git a/host/service/systemd/service.go b/host/service/systemd/service.go index 6e9944ba..00df7cd8 100644 --- a/host/service/systemd/service.go +++ b/host/service/systemd/service.go @@ -27,7 +27,7 @@ func New(c service.Config) (Service, error) { return Service{ Config: c, ConfigFileStorer: service.ConfigFileStorer{File: "/etc/" + c.Name + ".conf"}, - Path: "/etc/Service/system/" + c.Name + ".service", + Path: "/etc/systemd/system/" + c.Name + ".service", }, nil } diff --git a/host/service/windows/service.go b/host/service/windows/service.go index 0d0e1050..8e3936f7 100644 --- a/host/service/windows/service.go +++ b/host/service/windows/service.go @@ -27,7 +27,7 @@ func New(c service.Config) (Service, error) { if err != nil { return Service{}, err } - confPath := filepath.Join(filepath.Dir(ep), c.Name+".conf"), nil + confPath := filepath.Join(filepath.Dir(ep), c.Name+".conf") return Service{ Config: c, ConfigFileStorer: service.ConfigFileStorer{File: confPath}, @@ -44,12 +44,12 @@ func (s Service) Install() error { return err } defer m.Disconnect() - s, err := m.OpenService(s.Name) + srv, err := m.OpenService(s.Name) if err == nil { - s.Close() - return ErrAlreadyInstalled + srv.Close() + return service.ErrAlreadyInstalled } - s, err = m.CreateService(s.Name, ep, mgr.Config{ + srv, err = m.CreateService(s.Name, ep, mgr.Config{ DisplayName: s.DisplayName, Description: s.Description, StartType: mgr.StartAutomatic, @@ -57,8 +57,8 @@ func (s Service) Install() error { if err != nil { return err } - defer s.Close() - err = s.SetRecoveryActions([]mgr.RecoveryAction{ + defer srv.Close() + err = srv.SetRecoveryActions([]mgr.RecoveryAction{ mgr.RecoveryAction{ Type: mgr.ServiceRestart, Delay: 5 * time.Second, @@ -77,55 +77,43 @@ func (s Service) Uninstall() error { return err } defer m.Disconnect() - s, err := m.OpenService(s.Name) + srv, err := m.OpenService(s.Name) if err != nil { - return ErrNoInstalled + return service.ErrNoInstalled } - defer s.Close() - err = s.Delete() + defer srv.Close() + err = srv.Delete() if err != nil { return err } return nil } -func (s Service) Status() (Status, error) { +func (s Service) Status() (service.Status, error) { m, err := mgr.Connect() if err != nil { - return StatusUnknown, err + return service.StatusUnknown, err } defer m.Disconnect() - s, err := m.OpenService(s.Name) + srv, err := m.OpenService(s.Name) if err != nil { if err.Error() == "The specified service does not exist as an installed service." { - return StatusNotInstalled, nil + return service.StatusNotInstalled, nil } - return StatusUnknown, err + return service.StatusUnknown, err } - status, err := s.Query() + status, err := srv.Query() if err != nil { - return StatusUnknown, err + return service.StatusUnknown, err } switch status.State { - case svc.StartPending: - fallthrough - case svc.Running: - return StatusRunning, nil - case svc.PausePending: - fallthrough - case svc.Paused: - fallthrough - case svc.ContinuePending: - fallthrough - case svc.StopPending: - fallthrough - case svc.Stopped: - return StatusStopped, nil + case svc.StartPending, svc.Running, svc.PausePending, svc.Paused, svc.ContinuePending, svc.StopPending, svc.Stopped: + return service.StatusStopped, nil default: - return StatusUnknown, fmt.Errorf("unknown status %v", status) + return service.StatusUnknown, fmt.Errorf("unknown status %v", status) } } @@ -136,12 +124,12 @@ func (s Service) Start() error { } defer m.Disconnect() - svc, err := m.OpenService(ws.Name) + srv, err := m.OpenService(s.Name) if err != nil { return err } - defer svc.Close() - return svc.Start() + defer srv.Close() + return srv.Start() } func (s Service) Stop() error { @@ -150,12 +138,12 @@ func (s Service) Stop() error { return err } defer m.Disconnect() - svc, err := m.OpenService(name) + srv, err := m.OpenService(s.Name) if err != nil { return fmt.Errorf("could not access service: %v", err) } - defer svc.Close() - status, err := svc.Control(svc.Stop) + defer srv.Close() + status, err := srv.Control(svc.Stop) if err != nil { return fmt.Errorf("could not send control=%d: %v", svc.Stop, err) } @@ -165,12 +153,13 @@ func (s Service) Stop() error { return fmt.Errorf("timeout waiting for service to go to state=%d", svc.Stopped) } time.Sleep(300 * time.Millisecond) - status, err = svc.Query() + status, err = srv.Query() if err != nil { return fmt.Errorf("could not retrieve service status: %v", err) } } - return nil} + return nil +} func (s Service) Restart() error { if err := s.Stop(); err != nil { diff --git a/host/service_bsd.go b/host/service_bsd.go index 691a49c3..85462177 100644 --- a/host/service_bsd.go +++ b/host/service_bsd.go @@ -1,4 +1,4 @@ -// +build freebsd,openbsd,netbsd,dragonfly +// +build freebsd openbsd netbsd dragonfly package host diff --git a/host/service_windows.go b/host/service_windows.go index b21f6127..6de33654 100644 --- a/host/service_windows.go +++ b/host/service_windows.go @@ -1,4 +1,4 @@ -package service +package host import ( "github.com/nextdns/nextdns/host/service" diff --git a/hosts/hosts.go b/hosts/hosts.go new file mode 100644 index 00000000..60c9a678 --- /dev/null +++ b/hosts/hosts.go @@ -0,0 +1,127 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package hosts + +import ( + "fmt" + "strings" + "sync" + "time" +) + +const cacheMaxAge = 5 * time.Second + +var testHookHostsPath = "/etc/hosts" + +// LookupHost looks up the addresses for the given host from /etc/hosts. +func LookupHost(host string) []string { + hosts.Lock() + defer hosts.Unlock() + readHosts() + if len(hosts.byName) != 0 { + // TODO(jbd,bradfitz): avoid this alloc if host is already all lowercase? + // or linear scan the byName map if it's small enough? + lowerHost := []byte(host) + lowerASCIIBytes(lowerHost) + if ips, ok := hosts.byName[absDomainName(lowerHost)]; ok { + ipsCp := make([]string, len(ips)) + copy(ipsCp, ips) + return ipsCp + } + } + return nil +} + +// LookupAddr looks up the hosts for the given address from /etc/hosts. +func LookupAddr(addr string) []string { + hosts.Lock() + defer hosts.Unlock() + readHosts() + addr = parseLiteralIP(addr) + if addr == "" { + return nil + } + if len(hosts.byAddr) != 0 { + if hosts, ok := hosts.byAddr[addr]; ok { + hostsCp := make([]string, len(hosts)) + copy(hostsCp, hosts) + return hostsCp + } + } + return nil +} + +// hosts contains known host entries. +var hosts struct { + sync.Mutex + + // Key for the list of literal IP addresses must be a host + // name. It would be part of DNS labels, a FQDN or an absolute + // FQDN. + // For now the key is converted to lower case for convenience. + byName map[string][]string + + // Key for the list of host names must be a literal IP address + // including IPv6 address with zone identifier. + // We don't support old-classful IP address notation. + byAddr map[string][]string + + expire time.Time + path string + mtime time.Time + size int64 +} + +func readHosts() { + now := time.Now() + hp := testHookHostsPath + + if now.Before(hosts.expire) && hosts.path == hp && len(hosts.byName) > 0 { + return + } + mtime, size, err := stat(hp) + if err == nil && hosts.path == hp && hosts.mtime.Equal(mtime) && hosts.size == size { + hosts.expire = now.Add(cacheMaxAge) + return + } + + hs := make(map[string][]string) + is := make(map[string][]string) + var file *file + if file, err = open(hp); file == nil { + fmt.Println("return 1", err) + return + } + for line, ok := file.readLine(); ok; line, ok = file.readLine() { + if i := strings.IndexByte(line, '#'); i >= 0 { + // Discard comments. + line = line[0:i] + } + f := strings.Fields(line) + if len(f) < 2 { + continue + } + addr := parseLiteralIP(f[0]) + if addr == "" { + continue + } + for i := 1; i < len(f); i++ { + name := absDomainName([]byte(f[i])) + h := []byte(f[i]) + lowerASCIIBytes(h) + key := absDomainName(h) + hs[key] = append(hs[key], addr) + is[addr] = append(is[addr], name) + } + } + // Update the data cache. + hosts.expire = now.Add(cacheMaxAge) + hosts.path = hp + hosts.byName = hs + hosts.byAddr = is + hosts.mtime = mtime + hosts.size = size + file.close() +} diff --git a/hosts/hosts_test.go b/hosts/hosts_test.go new file mode 100644 index 00000000..94592d85 --- /dev/null +++ b/hosts/hosts_test.go @@ -0,0 +1,175 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package hosts + +import ( + "reflect" + "strings" + "testing" +) + +type staticHostEntry struct { + in string + out []string +} + +var lookupStaticHostTests = []struct { + name string + ents []staticHostEntry +}{ + { + "testdata/hosts", + []staticHostEntry{ + {"odin", []string{"127.0.0.2", "127.0.0.3", "::2"}}, + {"thor", []string{"127.1.1.1"}}, + {"ullr", []string{"127.1.1.2"}}, + {"ullrhost", []string{"127.1.1.2"}}, + {"localhost", []string{"fe80::1%lo0"}}, + }, + }, + { + "testdata/singleline-hosts", // see golang.org/issue/6646 + []staticHostEntry{ + {"odin", []string{"127.0.0.2"}}, + }, + }, + { + "testdata/ipv4-hosts", // see golang.org/issue/8996 + []staticHostEntry{ + {"localhost", []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}}, + {"localhost.localdomain", []string{"127.0.0.3"}}, + }, + }, + { + "testdata/ipv6-hosts", // see golang.org/issue/8996 + []staticHostEntry{ + {"localhost", []string{"::1", "fe80::1", "fe80::2%lo0", "fe80::3%lo0"}}, + {"localhost.localdomain", []string{"fe80::3%lo0"}}, + }, + }, + { + "testdata/case-hosts", // see golang.org/issue/12806 + []staticHostEntry{ + {"PreserveMe", []string{"127.0.0.1", "::1"}}, + {"PreserveMe.local", []string{"127.0.0.1", "::1"}}, + }, + }, +} + +func TestLookupStaticHost(t *testing.T) { + defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) + + for _, tt := range lookupStaticHostTests { + testHookHostsPath = tt.name + for _, ent := range tt.ents { + testStaticHost(t, tt.name, ent) + } + } +} + +func testStaticHost(t *testing.T, hostsPath string, ent staticHostEntry) { + ins := []string{ent.in, absDomainName([]byte(ent.in)), strings.ToLower(ent.in), strings.ToUpper(ent.in)} + for _, in := range ins { + addrs := LookupHost(in) + if !reflect.DeepEqual(addrs, ent.out) { + t.Errorf("%s, lookupStaticHost(%s) = %v; want %v", hostsPath, in, addrs, ent.out) + } + } +} + +var lookupStaticAddrTests = []struct { + name string + ents []staticHostEntry +}{ + { + "testdata/hosts", + []staticHostEntry{ + {"255.255.255.255", []string{"broadcasthost"}}, + {"127.0.0.2", []string{"odin"}}, + {"127.0.0.3", []string{"odin"}}, + {"::2", []string{"odin"}}, + {"127.1.1.1", []string{"thor"}}, + {"127.1.1.2", []string{"ullr", "ullrhost"}}, + {"fe80::1%lo0", []string{"localhost"}}, + }, + }, + { + "testdata/singleline-hosts", // see golang.org/issue/6646 + []staticHostEntry{ + {"127.0.0.2", []string{"odin"}}, + }, + }, + { + "testdata/ipv4-hosts", // see golang.org/issue/8996 + []staticHostEntry{ + {"127.0.0.1", []string{"localhost"}}, + {"127.0.0.2", []string{"localhost"}}, + {"127.0.0.3", []string{"localhost", "localhost.localdomain"}}, + }, + }, + { + "testdata/ipv6-hosts", // see golang.org/issue/8996 + []staticHostEntry{ + {"::1", []string{"localhost"}}, + {"fe80::1", []string{"localhost"}}, + {"fe80::2%lo0", []string{"localhost"}}, + {"fe80::3%lo0", []string{"localhost", "localhost.localdomain"}}, + }, + }, + { + "testdata/case-hosts", // see golang.org/issue/12806 + []staticHostEntry{ + {"127.0.0.1", []string{"PreserveMe", "PreserveMe.local"}}, + {"::1", []string{"PreserveMe", "PreserveMe.local"}}, + }, + }, +} + +func TestLookupStaticAddr(t *testing.T) { + defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) + + for _, tt := range lookupStaticAddrTests { + testHookHostsPath = tt.name + for _, ent := range tt.ents { + testStaticAddr(t, tt.name, ent) + } + } +} + +func testStaticAddr(t *testing.T, hostsPath string, ent staticHostEntry) { + hosts := LookupAddr(ent.in) + for i := range ent.out { + ent.out[i] = absDomainName([]byte(ent.out[i])) + } + if !reflect.DeepEqual(hosts, ent.out) { + t.Errorf("%s, lookupStaticAddr(%s) = %v; want %v", hostsPath, ent.in, hosts, ent.out) + } +} + +func TestHostCacheModification(t *testing.T) { + // Ensure that programs can't modify the internals of the host cache. + // See https://golang.org/issues/14212. + defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) + + testHookHostsPath = "testdata/ipv4-hosts" + ent := staticHostEntry{"localhost", []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}} + testStaticHost(t, testHookHostsPath, ent) + // Modify the addresses return by lookupStaticHost. + addrs := LookupHost(ent.in) + for i := range addrs { + addrs[i] += "junk" + } + testStaticHost(t, testHookHostsPath, ent) + + testHookHostsPath = "testdata/ipv6-hosts" + ent = staticHostEntry{"::1", []string{"localhost"}} + testStaticAddr(t, testHookHostsPath, ent) + // Modify the hosts return by lookupStaticAddr. + hosts := LookupAddr(ent.in) + for i := range hosts { + hosts[i] += "junk" + } + testStaticAddr(t, testHookHostsPath, ent) +} diff --git a/hosts/parse.go b/hosts/parse.go new file mode 100644 index 00000000..4d9c832c --- /dev/null +++ b/hosts/parse.go @@ -0,0 +1,153 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package hosts + +import ( + "io" + "net" + "os" + "strings" + "time" +) + +type file struct { + file *os.File + data []byte + atEOF bool +} + +func (f *file) close() { f.file.Close() } + +func (f *file) getLineFromData() (s string, ok bool) { + data := f.data + i := 0 + for i = 0; i < len(data); i++ { + if data[i] == '\n' { + s = string(data[0:i]) + ok = true + // move data + i++ + n := len(data) - i + copy(data[0:], data[i:]) + f.data = data[0:n] + return + } + } + if f.atEOF && len(f.data) > 0 { + // EOF, return all we have + s = string(data) + f.data = f.data[0:0] + ok = true + } + return +} + +func (f *file) readLine() (s string, ok bool) { + if s, ok = f.getLineFromData(); ok { + return + } + if len(f.data) < cap(f.data) { + ln := len(f.data) + n, err := io.ReadFull(f.file, f.data[ln:cap(f.data)]) + if n >= 0 { + f.data = f.data[0 : ln+n] + } + if err == io.EOF || err == io.ErrUnexpectedEOF { + f.atEOF = true + } + } + s, ok = f.getLineFromData() + return +} + +func open(name string) (*file, error) { + fd, err := os.Open(name) + if err != nil { + return nil, err + } + return &file{fd, make([]byte, 0, 64*1024), false}, nil +} + +func stat(name string) (mtime time.Time, size int64, err error) { + st, err := os.Stat(name) + if err != nil { + return time.Time{}, 0, err + } + return st.ModTime(), st.Size(), nil +} + +// lowerASCIIBytes makes x ASCII lowercase in-place. +func lowerASCIIBytes(x []byte) { + for i, b := range x { + if 'A' <= b && b <= 'Z' { + x[i] += 'a' - 'A' + } + } +} + +// absDomainName returns an absolute domain name which ends with a +// trailing dot to match pure Go reverse resolver and all other lookup +// routines. +// See golang.org/issue/12189. +// But we don't want to add dots for local names from /etc/hosts. +// It's hard to tell so we settle on the heuristic that names without dots +// (like "localhost" or "myhost") do not get trailing dots, but any other +// names do. +func absDomainName(b []byte) string { + hasDots := false + for _, x := range b { + if x == '.' { + hasDots = true + break + } + } + if hasDots && b[len(b)-1] != '.' { + b = append(b, '.') + } + return string(b) +} + +func parseLiteralIP(addr string) string { + ip, zone := parseIPZone(addr) + if ip == nil { + return "" + } + if zone == "" { + return ip.String() + } + return ip.String() + "%" + zone +} + +// parseIPZone parses s as an IP address, return it and its associated zone +// identifier (IPv6 only). +func parseIPZone(s string) (net.IP, string) { + for i := 0; i < len(s); i++ { + switch s[i] { + case '.': + return net.ParseIP(s), "" + case ':': + return parseIPv6Zone(s) + } + } + return nil, "" +} + +// parseIPv6Zone parses s as a literal IPv6 address and its associated zone +// identifier which is described in RFC 4007. +func parseIPv6Zone(s string) (net.IP, string) { + s, zone := splitHostZone(s) + return net.ParseIP(s), zone +} + +func splitHostZone(s string) (host, zone string) { + // The IPv6 scoped addressing zone identifier starts after the + // last percent sign. + if i := strings.LastIndexByte(s, '%'); i > 0 { + host, zone = s[:i], s[i+1:] + } else { + host = s + } + return +} diff --git a/hosts/testdata/case-hosts b/hosts/testdata/case-hosts new file mode 100644 index 00000000..1f30df11 --- /dev/null +++ b/hosts/testdata/case-hosts @@ -0,0 +1,2 @@ +127.0.0.1 PreserveMe PreserveMe.local +::1 PreserveMe PreserveMe.local diff --git a/hosts/testdata/hosts b/hosts/testdata/hosts new file mode 100644 index 00000000..3ed83ff8 --- /dev/null +++ b/hosts/testdata/hosts @@ -0,0 +1,11 @@ +255.255.255.255 broadcasthost +127.0.0.2 odin +127.0.0.3 odin # inline comment +::2 odin +127.1.1.1 thor +# aliases +127.1.1.2 ullr ullrhost +fe80::1%lo0 localhost +# Bogus entries that must be ignored. +123.123.123 loki +321.321.321.321 diff --git a/hosts/testdata/ipv4-hosts b/hosts/testdata/ipv4-hosts new file mode 100644 index 00000000..5208bb44 --- /dev/null +++ b/hosts/testdata/ipv4-hosts @@ -0,0 +1,12 @@ +# See https://tools.ietf.org/html/rfc1123. +# +# The literal IPv4 address parser in the net package is a relaxed +# one. It may accept a literal IPv4 address in dotted-decimal notation +# with leading zeros such as "001.2.003.4". + +# internet address and host name +127.0.0.1 localhost # inline comment separated by tab +127.000.000.002 localhost # inline comment separated by space + +# internet address, host name and aliases +127.000.000.003 localhost localhost.localdomain diff --git a/hosts/testdata/ipv6-hosts b/hosts/testdata/ipv6-hosts new file mode 100644 index 00000000..f78b7fcf --- /dev/null +++ b/hosts/testdata/ipv6-hosts @@ -0,0 +1,11 @@ +# See https://tools.ietf.org/html/rfc5952, https://tools.ietf.org/html/rfc4007. + +# internet address and host name +::1 localhost # inline comment separated by tab +fe80:0000:0000:0000:0000:0000:0000:0001 localhost # inline comment separated by space + +# internet address with zone identifier and host name +fe80:0000:0000:0000:0000:0000:0000:0002%lo0 localhost + +# internet address, host name and aliases +fe80::3%lo0 localhost localhost.localdomain diff --git a/hosts/testdata/singleline-hosts b/hosts/testdata/singleline-hosts new file mode 100644 index 00000000..5f5f74a3 --- /dev/null +++ b/hosts/testdata/singleline-hosts @@ -0,0 +1 @@ +127.0.0.2 odin \ No newline at end of file diff --git a/proxy/proxy.go b/proxy/proxy.go index 4387ac64..74514974 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -7,6 +7,7 @@ import ( "net" "time" + "github.com/nextdns/nextdns/hosts" "github.com/nextdns/nextdns/resolver" ) @@ -42,6 +43,10 @@ type Proxy struct { // QueryLog specifies an optional log function called for each received query. QueryLog func(QueryInfo) + // InfoLog specifies an option log function called when some actions are + // performed. + InfoLog func(string) + // ErrorLog specifies an optional log function for errors. If not set, // errors are not reported. ErrorLog func(error) @@ -56,50 +61,70 @@ func (p Proxy) ListenAndServe(ctx context.Context) error { addr = ":53" } - var udp net.PacketConn - var tcp net.Listener - lc := &net.ListenConfig{} - ctx, cancel := context.WithCancel(ctx) - errs := make(chan error, 3) + var addrs []string - go func() { - var err error - udp, err = lc.ListenPacket(ctx, "udp", addr) - if err == nil { - err = p.serveUDP(udp) - } - cancel() - if err != nil { - err = fmt.Errorf("udp: %w", err) - } - errs <- err - }() - - go func() { - var err error - tcp, err = lc.Listen(ctx, "tcp", addr) - if err == nil { - err = p.serveTCP(tcp) + // Try to lookup the given addr in the /etc/hosts file (for localhost for + // instance). + if host, port, err := net.SplitHostPort(addr); err == nil { + if ips := hosts.LookupHost(host); len(ips) > 0 { + for _, ip := range ips { + addrs = append(addrs, net.JoinHostPort(ip, port)) + } } - cancel() - if err != nil { - err = fmt.Errorf("tcp: %w", err) - } - errs <- err - }() + } + + if len(addrs) == 0 { + addrs = []string{addr} + } + + lc := &net.ListenConfig{} + ctx, cancel := context.WithCancel(ctx) + defer cancel() + expReturns := (len(addrs) * 2) + 1 + errs := make(chan error, expReturns) + var closeAll []func() error + + for _, addr := range addrs { + go func(addr string) { + var err error + p.logInfof("Listening on UDP/%s", addr) + udp, err := lc.ListenPacket(ctx, "udp", addr) + if err == nil { + closeAll = append(closeAll, udp.Close) + err = p.serveUDP(udp) + } + cancel() + if err != nil { + err = fmt.Errorf("udp: %w", err) + } + errs <- err + }(addr) + + go func(addr string) { + var err error + p.logInfof("Listening on TCP/%s", addr) + tcp, err := lc.Listen(ctx, "tcp", addr) + if err == nil { + closeAll = append(closeAll, tcp.Close) + err = p.serveTCP(tcp) + } + cancel() + if err != nil { + err = fmt.Errorf("tcp: %w", err) + } + errs <- err + }(addr) + } <-ctx.Done() errs <- ctx.Err() - if udp != nil { - udp.Close() - } - if tcp != nil { - tcp.Close() + for _, close := range closeAll { + close() } // Wait for the two sockets (+ ctx err) to be terminated and return the // initial error. var err error - for i := 0; i < 3; i++ { + for i := 0; i < expReturns; i++ { if e := <-errs; (err == nil || errors.Is(err, context.Canceled)) && e != nil { err = e } @@ -123,6 +148,12 @@ func (p Proxy) logQuery(q QueryInfo) { } } +func (p Proxy) logInfof(format string, a ...interface{}) { + if p.InfoLog != nil { + p.InfoLog(fmt.Sprintf(format, a...)) + } +} + func (p Proxy) logErr(err error) { if err != nil && p.ErrorLog != nil { p.ErrorLog(err) diff --git a/run.go b/run.go index 36604daf..5bfe4033 100644 --- a/run.go +++ b/run.go @@ -7,7 +7,6 @@ import ( "net" "net/http" "os" - "os/signal" "runtime" "strconv" "strings" @@ -105,12 +104,13 @@ func (p *proxySvc) Restart() error { } func (p *proxySvc) Stop() error { - p.log.Infof("Stopping NextDNS on %s", p.Addr) + p.log.Infof("Stopping NextDNS %s/%s", version, platform) if p.stop() { if p.OnStopped != nil { p.OnStopped() } } + p.log.Infof("NextDNS %s/%s stopped", version, platform) return nil } @@ -206,6 +206,9 @@ func run(args []string) error { errStr) } } + p.InfoLog = func(msg string) { + log.Info(msg) + } p.ErrorLog = func(err error) { log.Error(err) } @@ -233,14 +236,7 @@ func run(args []string) error { }) } - if err := p.Start(); err != nil { - return err - } - - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGTERM, os.Interrupt) - <-sig - return p.Stop() + return service.Run("nextdns", p) } // isLocalhostMode returns true if listen is only listening for the local host.