Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support disabled resolved stub server mode #1493

Merged
merged 20 commits into from
Jan 24, 2024
138 changes: 66 additions & 72 deletions client/internal/dns/file_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package dns

import (
"bufio"
"bytes"
"fmt"
"os"
Expand All @@ -24,11 +23,15 @@ const (
)

type fileConfigurator struct {
repair *repair

originalPerms os.FileMode
}

func newFileConfigurator() (hostManager, error) {
return &fileConfigurator{}, nil
fc := &fileConfigurator{}
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
return fc, nil
}

func (f *fileConfigurator) supportCustomPort() bool {
Expand Down Expand Up @@ -59,22 +62,35 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
}
}

searchDomainList := searchDomains(config)
nbSearchDomains := searchDomains(config)
nbNameserverIP := config.ServerIP

originalSearchDomains, nameServers, others, err := originalDNSConfigs(fileDefaultResolvConfBackupLocation)
resolvConf, err := parseBackupResolvConf()
if err != nil {
log.Error(err)
}

searchDomainList = mergeSearchDomains(searchDomainList, originalSearchDomains)
f.repair.stopWatchFileChanges()

err = f.updateConfig(nbSearchDomains, nbNameserverIP, resolvConf)
if err != nil {
return err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we enrich this with some context?

Suggested change
return err
return fmt.Errorf("parse backup resolv.conf: %w", err)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right but in the updateConfig(...) we return with a long error message. Does it make sense to add more info here?

}
f.repair.watchFileChanges(nbSearchDomains, nbNameserverIP)
return nil
}

func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error {
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
nameServers := generateNsList(nbNameserverIP, cfg)

buf := prepareResolvConfContent(
searchDomainList,
append([]string{config.ServerIP}, nameServers...),
others)
nameServers,
cfg.others)

log.Debugf("creating managed file %s", defaultResolvConfPath)
err = os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
err := os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
if err != nil {
restoreErr := f.restore()
if restoreErr != nil {
Expand All @@ -88,6 +104,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
}

func (f *fileConfigurator) restoreHostDNS() error {
f.repair.stopWatchFileChanges()
return f.restore()
}

Expand Down Expand Up @@ -115,6 +132,18 @@ func (f *fileConfigurator) restore() error {
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
}

// generateNsList generates a list of nameservers from the config and adds the primary nameserver to the beginning of the list
func generateNsList(nbNameserverIP string, cfg *resolvConf) []string {
ns := make([]string, 1, len(cfg.nameServers)+1)
ns[0] = nbNameserverIP
for _, cfgNs := range cfg.nameServers {
if nbNameserverIP != cfgNs {
ns = append(ns, cfgNs)
}
}
return ns
}

func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
var buf bytes.Buffer
buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine)
Expand Down Expand Up @@ -150,70 +179,6 @@ func searchDomains(config HostDNSConfig) []string {
return listOfDomains
}

func originalDNSConfigs(resolvconfFile string) (searchDomains, nameServers, others []string, err error) {
file, err := os.Open(resolvconfFile)
if err != nil {
err = fmt.Errorf(`could not read existing resolv.conf`)
return
}
defer file.Close()

reader := bufio.NewReader(file)

for {
lineBytes, isPrefix, readErr := reader.ReadLine()
if readErr != nil {
break
}

if isPrefix {
err = fmt.Errorf(`resolv.conf line too long`)
return
}

line := strings.TrimSpace(string(lineBytes))

if strings.HasPrefix(line, "#") {
continue
}

if strings.HasPrefix(line, "domain") {
continue
}

if strings.HasPrefix(line, "options") && strings.Contains(line, "rotate") {
line = strings.ReplaceAll(line, "rotate", "")
splitLines := strings.Fields(line)
if len(splitLines) == 1 {
continue
}
line = strings.Join(splitLines, " ")
}

if strings.HasPrefix(line, "search") {
splitLines := strings.Fields(line)
if len(splitLines) < 2 {
continue
}

searchDomains = splitLines[1:]
continue
}

if strings.HasPrefix(line, "nameserver") {
splitLines := strings.Fields(line)
if len(splitLines) != 2 {
continue
}
nameServers = append(nameServers, splitLines[1])
continue
}

others = append(others, line)
}
return
}

// merge search Domains lists and cut off the list if it is too long
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
lineSize := len("search")
Expand All @@ -230,6 +195,19 @@ func mergeSearchDomains(searchDomains []string, originalSearchDomains []string)
// return with the number of characters in the searchDomains line
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
for _, sd := range vs {
duplicated := false
for _, fs := range *s {
if fs == sd {
duplicated = true
break
}

}

if duplicated {
continue
}

tmpCharsNumber := initialLineChars + 1 + len(sd)
if tmpCharsNumber > fileMaxLineCharsLimit {
// lets log all skipped Domains
Expand All @@ -246,6 +224,7 @@ func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string
}
*s = append(*s, sd)
}

return initialLineChars
}

Expand All @@ -266,3 +245,18 @@ func copyFile(src, dest string) error {
}
return nil
}

func isContains(subList []string, list []string) bool {
for _, sl := range subList {
var found bool
for _, l := range list {
if sl == l {
found = true
}
}
if !found {
return false
}
}
return true
}
51 changes: 50 additions & 1 deletion client/internal/dns/file_linux_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//go:build !android

package dns

import (
Expand All @@ -7,7 +9,7 @@ import (

func Test_mergeSearchDomains(t *testing.T) {
searchDomains := []string{"a", "b"}
originDomains := []string{"a", "b"}
originDomains := []string{"c", "d"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 4 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 4)
Expand Down Expand Up @@ -49,6 +51,53 @@ func Test_mergeSearchTooLongDomain(t *testing.T) {
}
}

func Test_isContains(t *testing.T) {
type args struct {
subList []string
list []string
}
tests := []struct {
args args
want bool
}{
{
args: args{
subList: []string{"a", "b", "c"},
list: []string{"a", "b", "c"},
},
want: true,
},
{
args: args{
subList: []string{"a"},
list: []string{"a", "b", "c"},
},
want: true,
},
{
args: args{
subList: []string{"d"},
list: []string{"a", "b", "c"},
},
want: false,
},
{
args: args{
subList: []string{"a"},
list: []string{},
},
want: false,
},
}
for _, tt := range tests {
t.Run("list check test", func(t *testing.T) {
if got := isContains(tt.args.subList, tt.args.list); got != tt.want {
t.Errorf("isContains() = %v, want %v", got, tt.want)
}
})
}
}

func getLongLine() string {
x := "search "
for {
Expand Down
98 changes: 98 additions & 0 deletions client/internal/dns/file_parser_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
//go:build !android

package dns

import (
"fmt"
"os"
"strings"
)

const (
defaultResolvConfPath = "/etc/resolv.conf"
)

type resolvConf struct {
nameServers []string
searchDomains []string
others []string
}

func (r *resolvConf) String() string {
return fmt.Sprintf("search domains: %v, name servers: %v, others: %s", r.searchDomains, r.nameServers, r.others)
}

func parseDefaultResolvConf() (*resolvConf, error) {
return parseResolvConfFile(defaultResolvConfPath)
}

func parseBackupResolvConf() (*resolvConf, error) {
return parseResolvConfFile(fileDefaultResolvConfBackupLocation)
}

func parseResolvConfFile(resolvConfFile string) (*resolvConf, error) {
file, err := os.Open(resolvConfFile)
if err != nil {
return nil, err
pappz marked this conversation as resolved.
Show resolved Hide resolved
}
defer file.Close()
pappz marked this conversation as resolved.
Show resolved Hide resolved

cur, err := os.ReadFile(resolvConfFile)
if err != nil {
return nil, err
pappz marked this conversation as resolved.
Show resolved Hide resolved
}

if len(cur) == 0 {
return nil, fmt.Errorf("file is empty")
}

rconf := &resolvConf{
nameServers: make([]string, 0),
pappz marked this conversation as resolved.
Show resolved Hide resolved
others: make([]string, 0),
pappz marked this conversation as resolved.
Show resolved Hide resolved
}

for _, line := range strings.Split(string(cur), "\n") {
line = strings.TrimSpace(line)

if strings.HasPrefix(line, "#") {
continue
}

if strings.HasPrefix(line, "domain") {
continue
}

if strings.HasPrefix(line, "options") && strings.Contains(line, "rotate") {
line = strings.ReplaceAll(line, "rotate", "")
splitLines := strings.Fields(line)
if len(splitLines) == 1 {
continue
}
line = strings.Join(splitLines, " ")
}

if strings.HasPrefix(line, "search") {
splitLines := strings.Fields(line)
if len(splitLines) < 2 {
continue
}

rconf.searchDomains = splitLines[1:]
continue
}

if strings.HasPrefix(line, "nameserver") {
splitLines := strings.Fields(line)
if len(splitLines) != 2 {
continue
}
rconf.nameServers = append(rconf.nameServers, splitLines[1])
continue
}

if line != "" {
rconf.others = append(rconf.others, line)
}
}
return rconf, nil
}