Skip to content

Commit

Permalink
Add backend weight round robin select (#34)
Browse files Browse the repository at this point in the history
* Add upstream selector, there are two selector now:
    - random selector
    - weight random selector

random selector will choose upstream at random; weight random selector will choose upstream at random with weight

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Rewrite config and config file example, prepare for weight round robbin selector

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Replace bad implement of weight random selector with weight round robbin selector, the algorithm is nginx weight round robbin like

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Use new config module

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Disable deprecated DualStack set

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Fix typo

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Optimize upstreamSelector judge

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Fix typo

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Add config timeout unit tips

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Set wrr http client timeout to replace http request timeout

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Add weight value range

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Add a line ending for .gitignore

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Optimize config file style

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Modify Weight type to int32

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Add upstreamError

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Rewrite Selector interface and wrr implement

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Use http module predefined constant to judge req.response.StatusCode

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Use Selector.ReportUpstreamError to report upstream error for evaluation loop in real time

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Make client selector field private

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Replace config file url to URL
Add miss space for 'weight= 50'

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Rewrite Selector.ReportUpstreamError to Selector.ReportUpstreamStatus, report upstream ok in real time

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Fix checkIETFResponse: if upstream OK, won't increase weight

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>

* Fix typo

Signed-off-by: Sherlock Holo <sherlockya@gmail.com>
  • Loading branch information
Sherlock-Holo authored and m13253 committed Mar 9, 2019
1 parent 8f2004d commit fec1e84
Show file tree
Hide file tree
Showing 14 changed files with 552 additions and 122 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@

# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
.glide/

.idea/
184 changes: 122 additions & 62 deletions doh-client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,22 @@ import (
"net"
"net/http"
"net/http/cookiejar"
"net/url"
"strconv"
"strings"
"sync"
"time"

"github.com/m13253/dns-over-https/doh-client/config"
"github.com/m13253/dns-over-https/doh-client/selector"
"github.com/m13253/dns-over-https/json-dns"
"github.com/miekg/dns"
"golang.org/x/net/http2"
"golang.org/x/net/idna"
)

type Client struct {
conf *config
conf *config.Config
bootstrap []string
passthrough []string
udpClient *dns.Client
Expand All @@ -56,6 +59,7 @@ type Client struct {
httpTransport *http.Transport
httpClient *http.Client
httpClientLastCreate time.Time
selector selector.Selector
}

type DNSRequest struct {
Expand All @@ -68,7 +72,7 @@ type DNSRequest struct {
err error
}

func NewClient(conf *config) (c *Client, err error) {
func NewClient(conf *config.Config) (c *Client, err error) {
c = &Client{
conf: conf,
}
Expand All @@ -78,11 +82,11 @@ func NewClient(conf *config) (c *Client, err error) {
c.udpClient = &dns.Client{
Net: "udp",
UDPSize: dns.DefaultMsgSize,
Timeout: time.Duration(conf.Timeout) * time.Second,
Timeout: time.Duration(conf.Other.Timeout) * time.Second,
}
c.tcpClient = &dns.Client{
Net: "tcp",
Timeout: time.Duration(conf.Timeout) * time.Second,
Timeout: time.Duration(conf.Other.Timeout) * time.Second,
}
for _, addr := range conf.Listen {
c.udpServers = append(c.udpServers, &dns.Server{
Expand All @@ -98,9 +102,9 @@ func NewClient(conf *config) (c *Client, err error) {
})
}
c.bootstrapResolver = net.DefaultResolver
if len(conf.Bootstrap) != 0 {
c.bootstrap = make([]string, len(conf.Bootstrap))
for i, bootstrap := range conf.Bootstrap {
if len(conf.Other.Bootstrap) != 0 {
c.bootstrap = make([]string, len(conf.Other.Bootstrap))
for i, bootstrap := range conf.Other.Bootstrap {
bootstrapAddr, err := net.ResolveUDPAddr("udp", bootstrap)
if err != nil {
bootstrapAddr, err = net.ResolveUDPAddr("udp", "["+bootstrap+"]:53")
Expand All @@ -120,9 +124,9 @@ func NewClient(conf *config) (c *Client, err error) {
return conn, err
},
}
if len(conf.Passthrough) != 0 {
c.passthrough = make([]string, len(conf.Passthrough))
for i, passthrough := range conf.Passthrough {
if len(conf.Other.Passthrough) != 0 {
c.passthrough = make([]string, len(conf.Other.Passthrough))
for i, passthrough := range conf.Other.Passthrough {
if punycode, err := idna.ToASCII(passthrough); err != nil {
passthrough = punycode
}
Expand All @@ -133,7 +137,7 @@ func NewClient(conf *config) (c *Client, err error) {
// Most CDNs require Cookie support to prevent DDoS attack.
// Disabling Cookie does not effectively prevent tracking,
// so I will leave it on to make anti-DDoS services happy.
if !c.conf.NoCookies {
if !c.conf.Other.NoCookies {
c.cookieJar, err = cookiejar.New(nil)
if err != nil {
return nil, err
Expand All @@ -147,23 +151,59 @@ func NewClient(conf *config) (c *Client, err error) {
if err != nil {
return nil, err
}

switch c.conf.Upstream.UpstreamSelector {
default:
// if selector is invalid or random, use random selector, or should we stop program and let user knows he is wrong?
s := selector.NewRandomSelector()
for _, u := range c.conf.Upstream.UpstreamGoogle {
if err := s.Add(u.URL, selector.Google); err != nil {
return nil, err
}
}

for _, u := range c.conf.Upstream.UpstreamIETF {
if err := s.Add(u.URL, selector.IETF); err != nil {
return nil, err
}
}

c.selector = s

case config.WeightedRoundRobin:
s := selector.NewWeightRoundRobinSelector(time.Duration(c.conf.Other.Timeout) * time.Second)
for _, u := range c.conf.Upstream.UpstreamGoogle {
if err := s.Add(u.URL, selector.Google, u.Weight); err != nil {
return nil, err
}
}

for _, u := range c.conf.Upstream.UpstreamIETF {
if err := s.Add(u.URL, selector.IETF, u.Weight); err != nil {
return nil, err
}
}

c.selector = s
}

return c, nil
}

func (c *Client) newHTTPClient() error {
c.httpClientMux.Lock()
defer c.httpClientMux.Unlock()
if !c.httpClientLastCreate.IsZero() && time.Since(c.httpClientLastCreate) < time.Duration(c.conf.Timeout)*time.Second {
if !c.httpClientLastCreate.IsZero() && time.Since(c.httpClientLastCreate) < time.Duration(c.conf.Other.Timeout)*time.Second {
return nil
}
if c.httpTransport != nil {
c.httpTransport.CloseIdleConnections()
}
dialer := &net.Dialer{
Timeout: time.Duration(c.conf.Timeout) * time.Second,
Timeout: time.Duration(c.conf.Other.Timeout) * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
Resolver: c.bootstrapResolver,
// DualStack: true,
Resolver: c.bootstrapResolver,
}
c.httpTransport = &http.Transport{
DialContext: dialer.DialContext,
Expand All @@ -172,9 +212,9 @@ func (c *Client) newHTTPClient() error {
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
Proxy: http.ProxyFromEnvironment,
TLSHandshakeTimeout: time.Duration(c.conf.Timeout) * time.Second,
TLSHandshakeTimeout: time.Duration(c.conf.Other.Timeout) * time.Second,
}
if c.conf.NoIPv6 {
if c.conf.Other.NoIPv6 {
c.httpTransport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
if strings.HasPrefix(network, "tcp") {
network = "tcp4"
Expand Down Expand Up @@ -213,11 +253,15 @@ func (c *Client) Start() error {
}
}
close(results)

// start evaluation poll
c.selector.StartEvaluate()

return nil
}

func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.conf.Timeout)*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(c.conf.Other.Timeout)*time.Second)
defer cancel()

if r.Response {
Expand Down Expand Up @@ -246,7 +290,7 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
} else {
questionType = strconv.FormatUint(uint64(question.Qtype), 10)
}
if c.conf.Verbose {
if c.conf.Other.Verbose {
fmt.Printf("%s - - [%s] \"%s %s %s\"\n", w.RemoteAddr(), time.Now().Format("02/Jan/2006:15:04:05 -0700"), questionName, questionClass, questionType)
}

Expand Down Expand Up @@ -284,64 +328,80 @@ func (c *Client) handlerFunc(w dns.ResponseWriter, r *dns.Msg, isTCP bool) {
return
}

requestType := ""
if len(c.conf.UpstreamIETF) == 0 {
requestType = "application/dns-json"
} else if len(c.conf.UpstreamGoogle) == 0 {
requestType = "application/dns-message"
} else {
numServers := len(c.conf.UpstreamGoogle) + len(c.conf.UpstreamIETF)
random := rand.Intn(numServers)
if random < len(c.conf.UpstreamGoogle) {
requestType = "application/dns-json"
} else {
requestType = "application/dns-message"
}
upstream := c.selector.Get()
requestType := upstream.RequestType

if c.conf.Other.Verbose {
log.Println("choose upstream:", upstream)
}

var req *DNSRequest
if requestType == "application/dns-json" {
req = c.generateRequestGoogle(ctx, w, r, isTCP)
} else if requestType == "application/dns-message" {
req = c.generateRequestIETF(ctx, w, r, isTCP)
} else {
switch requestType {
case "application/dns-json":
req = c.generateRequestGoogle(ctx, w, r, isTCP, upstream)

case "application/dns-message":
req = c.generateRequestIETF(ctx, w, r, isTCP, upstream)

default:
panic("Unknown request Content-Type")
}

if req.response != nil {
defer req.response.Body.Close()
for _, header := range c.conf.DebugHTTPHeaders {
if value := req.response.Header.Get(header); value != "" {
log.Printf("%s: %s\n", header, value)
if req.err != nil {
if urlErr, ok := req.err.(*url.Error); ok {
// should we only check timeout?
if urlErr.Timeout() {
c.selector.ReportUpstreamStatus(upstream, selector.Timeout)
}
}
}
if req.err != nil {

return
}

contentType := ""
candidateType := strings.SplitN(req.response.Header.Get("Content-Type"), ";", 2)[0]
if candidateType == "application/json" {
contentType = "application/json"
} else if candidateType == "application/dns-message" {
contentType = "application/dns-message"
} else if candidateType == "application/dns-udpwireformat" {
contentType = "application/dns-message"
} else {
if requestType == "application/dns-json" {
contentType = "application/json"
} else if requestType == "application/dns-message" {
contentType = "application/dns-message"
// if req.err == nil, req.response != nil
defer req.response.Body.Close()

for _, header := range c.conf.Other.DebugHTTPHeaders {
if value := req.response.Header.Get(header); value != "" {
log.Printf("%s: %s\n", header, value)
}
}

if contentType == "application/json" {
candidateType := strings.SplitN(req.response.Header.Get("Content-Type"), ";", 2)[0]

switch candidateType {
case "application/json":
c.parseResponseGoogle(ctx, w, r, isTCP, req)
} else if contentType == "application/dns-message" {

case "application/dns-message", "application/dns-udpwireformat":
c.parseResponseIETF(ctx, w, r, isTCP, req)
} else {
panic("Unknown response Content-Type")

default:
switch requestType {
case "application/dns-json":
c.parseResponseGoogle(ctx, w, r, isTCP, req)

case "application/dns-message":
c.parseResponseIETF(ctx, w, r, isTCP, req)

default:
panic("Unknown response Content-Type")
}
}

// https://developers.cloudflare.com/1.1.1.1/dns-over-https/request-structure/ says
// returns code will be 200 / 400 / 413 / 415 / 504, some server will return 503, so
// I think if status code is 5xx, upstream must has some problems
/*if req.response.StatusCode/100 == 5 {
c.selector.ReportUpstreamStatus(upstream, selector.Medium)
}*/

switch req.response.StatusCode / 100 {
case 5:
c.selector.ReportUpstreamStatus(upstream, selector.Error)

case 2:
c.selector.ReportUpstreamStatus(upstream, selector.OK)
}
}

Expand All @@ -360,7 +420,7 @@ var (

func (c *Client) findClientIP(w dns.ResponseWriter, r *dns.Msg) (ednsClientAddress net.IP, ednsClientNetmask uint8) {
ednsClientNetmask = 255
if c.conf.NoECS {
if c.conf.Other.NoECS {
return net.IPv4(0, 0, 0, 0), 0
}
if opt := r.IsEdns0(); opt != nil {
Expand Down
Loading

0 comments on commit fec1e84

Please sign in to comment.