Skip to content

Commit

Permalink
feat: Refactor DNS resolvers for concurrent lookups
Browse files Browse the repository at this point in the history
This commit significantly improves the performance of DNS lookups by
implementing concurrent query execution across all resolver types.
It also refactors the common lookup logic into a shared function to
reduce code duplication and improve maintainability.

Performance improvements:
- Reduced lookup time for multiple queries by ~78%
  (from 1.356s to 0.297s for a sample query with 10 record types)
- Improved CPU utilization (from 1% to 4%) indicating better resource use
  • Loading branch information
mr-karan committed Jul 2, 2024
1 parent adfd23a commit a6447cf
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 42 deletions.
44 changes: 25 additions & 19 deletions cmd/doggo/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"os"
"sync"
"time"

"github.com/knadh/koanf/providers/posflag"
Expand Down Expand Up @@ -85,6 +86,13 @@ func main() {
os.Exit(0)
}

var (
wg sync.WaitGroup
mu sync.Mutex
allResponses []resolvers.Response
allErrors []error
)

queryFlags := resolvers.QueryFlags{
AA: k.Bool("aa"),
AD: k.Bool("ad"),
Expand All @@ -94,9 +102,24 @@ func main() {
DO: k.Bool("do"),
}

responses, responseErrors := resolveQueries(&app, queryFlags)
for _, resolver := range app.Resolvers {
wg.Add(1)
go func(r resolvers.Resolver) {
defer wg.Done()
responses, err := r.Lookup(app.Questions, queryFlags)
mu.Lock()
if err != nil {
allErrors = append(allErrors, err)
} else {
allResponses = append(allResponses, responses...)
}
mu.Unlock()
}(resolver)
}

outputResults(&app, responses, responseErrors)
wg.Wait()

outputResults(&app, allResponses, allErrors)

os.Exit(0)
}
Expand Down Expand Up @@ -166,23 +189,6 @@ func loadNameservers(app *app.App, args []string) {
app.QueryFlags.QNames = append(app.QueryFlags.QNames, qn...)
}

func resolveQueries(app *app.App, flags resolvers.QueryFlags) ([]resolvers.Response, []error) {
var responses []resolvers.Response
var responseErrors []error

for _, q := range app.Questions {
for _, rslv := range app.Resolvers {
resp, err := rslv.Lookup(q, flags)
if err != nil {
responseErrors = append(responseErrors, err)
}
responses = append(responses, resp)
}
}

return responses, responseErrors
}

func outputResults(app *app.App, responses []resolvers.Response, responseErrors []error) {
if app.QueryFlags.ShowJSON {
outputJSON(app.Logger, responses, responseErrors)
Expand Down
9 changes: 7 additions & 2 deletions pkg/resolvers/classic.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func NewClassicResolver(server string, classicOpts ClassicResolverOpts, resolver

// Lookup takes a dns.Question and sends them to DNS Server.
// It parses the Response from the server in a custom output format.
func (r *ClassicResolver) Lookup(question dns.Question, flags QueryFlags) (Response, error) {
func (r *ClassicResolver) query(question dns.Question, flags QueryFlags) (Response, error) {
var (
rsp Response
messages = prepareMessages(question, flags, r.resolverOptions.Ndots, r.resolverOptions.SearchList)
Expand Down Expand Up @@ -93,7 +93,7 @@ func (r *ClassicResolver) Lookup(question dns.Question, flags QueryFlags) (Respo
r.client.Net = "tcp"
}
r.resolverOptions.Logger.Debug("Response truncated; retrying now", "protocol", r.client.Net)
return r.Lookup(question, flags)
return r.query(question, flags)
}

// Pack questions in output.
Expand All @@ -119,3 +119,8 @@ func (r *ClassicResolver) Lookup(question dns.Question, flags QueryFlags) (Respo
}
return rsp, nil
}

// Lookup implements the Resolver interface
func (r *ClassicResolver) Lookup(questions []dns.Question, flags QueryFlags) ([]Response, error) {
return ConcurrentLookup(questions, flags, r.query, r.resolverOptions.Logger)
}
42 changes: 42 additions & 0 deletions pkg/resolvers/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package resolvers

import (
"log/slog"
"sync"

"github.com/miekg/dns"
)

// QueryFunc represents the signature of a query function
type QueryFunc func(question dns.Question, flags QueryFlags) (Response, error)

// ConcurrentLookup performs concurrent DNS lookups
func ConcurrentLookup(questions []dns.Question, flags QueryFlags, queryFunc QueryFunc, logger *slog.Logger) ([]Response, error) {
var wg sync.WaitGroup
responses := make([]Response, len(questions))
errors := make([]error, len(questions))

for i, q := range questions {
wg.Add(1)
go func(i int, q dns.Question) {
defer wg.Done()
resp, err := queryFunc(q, flags)
responses[i] = resp
errors[i] = err
}(i, q)
}

wg.Wait()

// Collect non-nil responses and handle errors
var validResponses []Response
for i, resp := range responses {
if errors[i] != nil {
logger.Error("error in lookup", "error", errors[i])
} else {
validResponses = append(validResponses, resp)
}
}

return validResponses, nil
}
10 changes: 7 additions & 3 deletions pkg/resolvers/dnscrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ func NewDNSCryptResolver(server string, dnscryptOpts DNSCryptResolverOpts, resol
}, nil
}

// Lookup takes a dns.Question and sends them to DNS Server.
// It parses the Response from the server in a custom output format.
func (r *DNSCryptResolver) Lookup(question dns.Question, flags QueryFlags) (Response, error) {
// Lookup implements the Resolver interface
func (r *DNSCryptResolver) Lookup(questions []dns.Question, flags QueryFlags) ([]Response, error) {
return ConcurrentLookup(questions, flags, r.query, r.resolverOptions.Logger)
}

// query performs a single DNS query
func (r *DNSCryptResolver) query(question dns.Question, flags QueryFlags) (Response, error) {
var (
rsp Response
messages = prepareMessages(question, flags, r.resolverOptions.Ndots, r.resolverOptions.SearchList)
Expand Down
9 changes: 7 additions & 2 deletions pkg/resolvers/doh.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ func NewDOHResolver(server string, resolverOpts Options) (Resolver, error) {
}, nil
}

// Lookup takes a dns.Question and sends them to DNS Server.
// query takes a dns.Question and sends them to DNS Server.
// It parses the Response from the server in a custom output format.
func (r *DOHResolver) Lookup(question dns.Question, flags QueryFlags) (Response, error) {
func (r *DOHResolver) query(question dns.Question, flags QueryFlags) (Response, error) {
var (
rsp Response
messages = prepareMessages(question, flags, r.resolverOptions.Ndots, r.resolverOptions.SearchList)
Expand Down Expand Up @@ -123,3 +123,8 @@ func (r *DOHResolver) Lookup(question dns.Question, flags QueryFlags) (Response,
}
return rsp, nil
}

// Lookup implements the Resolver interface
func (r *DOHResolver) Lookup(questions []dns.Question, flags QueryFlags) ([]Response, error) {
return ConcurrentLookup(questions, flags, r.query, r.resolverOptions.Logger)
}
7 changes: 6 additions & 1 deletion pkg/resolvers/doq.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,14 @@ func NewDOQResolver(server string, resolverOpts Options) (Resolver, error) {
}, nil
}

// Lookup implements the Resolver interface
func (r *DOQResolver) Lookup(questions []dns.Question, flags QueryFlags) ([]Response, error) {
return ConcurrentLookup(questions, flags, r.query, r.resolverOptions.Logger)
}

// Lookup takes a dns.Question and sends them to DNS Server.
// It parses the Response from the server in a custom output format.
func (r *DOQResolver) Lookup(question dns.Question, flags QueryFlags) (Response, error) {
func (r *DOQResolver) query(question dns.Question, flags QueryFlags) (Response, error) {
var (
rsp Response
messages = prepareMessages(question, flags, r.resolverOptions.Ndots, r.resolverOptions.SearchList)
Expand Down
2 changes: 1 addition & 1 deletion pkg/resolvers/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type Options struct {
// Client. Different types of providers can load
// a DNS Resolver satisfying this interface.
type Resolver interface {
Lookup(dns.Question, QueryFlags) (Response, error)
Lookup([]dns.Question, QueryFlags) ([]Response, error)
}

// Response represents a custom output format
Expand Down
55 changes: 41 additions & 14 deletions web/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"sync"
"time"

"github.com/mr-karan/doggo/internal/app"
Expand Down Expand Up @@ -41,14 +42,14 @@ func handleLookup(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
if err != nil {
app.Logger.Error("error reading request body", "error", err)
sendErrorResponse(w, fmt.Sprintf("Invalid JSON payload"), http.StatusBadRequest, nil)
sendErrorResponse(w, "Invalid JSON payload", http.StatusBadRequest, nil)
return
}
// Prepare query flags.
var qFlags models.QueryFlags
if err := json.Unmarshal(b, &qFlags); err != nil {
app.Logger.Error("error unmarshalling payload", "error", err)
sendErrorResponse(w, fmt.Sprintf("Invalid JSON payload"), http.StatusBadRequest, nil)
sendErrorResponse(w, "Invalid JSON payload", http.StatusBadRequest, nil)
return
}

Expand All @@ -60,14 +61,14 @@ func handleLookup(w http.ResponseWriter, r *http.Request) {
app.PrepareQuestions()

if len(app.Questions) == 0 {
sendErrorResponse(w, fmt.Sprintf("Missing field `query`."), http.StatusBadRequest, nil)
sendErrorResponse(w, "Missing field `query`.", http.StatusBadRequest, nil)
return
}

// Load Nameservers.
if err := app.LoadNameservers(); err != nil {
app.Logger.Error("error loading nameservers", "error", err)
sendErrorResponse(w, fmt.Sprintf("Error looking up for records."), http.StatusInternalServerError, nil)
sendErrorResponse(w, "Error looking up for records.", http.StatusInternalServerError, nil)
return
}

Expand Down Expand Up @@ -96,19 +97,45 @@ func handleLookup(w http.ResponseWriter, r *http.Request) {
RD: true,
}

var responses []resolvers.Response
for _, q := range app.Questions {
for _, rslv := range app.Resolvers {
resp, err := rslv.Lookup(q, queryFlags)
// ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
// defer cancel()

var (
wg sync.WaitGroup
mu sync.Mutex
allResponses []resolvers.Response
allErrors []error
)

for _, resolver := range app.Resolvers {
wg.Add(1)
go func(r resolvers.Resolver) {
defer wg.Done()
responses, err := r.Lookup(app.Questions, queryFlags)
mu.Lock()
if err != nil {
app.Logger.Error("error looking up DNS records", "error", err)
sendErrorResponse(w, "Error looking up for records.", http.StatusInternalServerError, nil)
return
allErrors = append(allErrors, err)
} else {
allResponses = append(allResponses, responses...)
}
responses = append(responses, resp)
}
mu.Unlock()
}(resolver)
}

wg.Wait()

if len(allErrors) > 0 {
app.Logger.Error("errors looking up DNS records", "errors", allErrors)
sendErrorResponse(w, "Error looking up for records.", http.StatusInternalServerError, nil)
return
}
sendResponse(w, http.StatusOK, responses)

if len(allResponses) == 0 {
sendErrorResponse(w, "No records found.", http.StatusNotFound, nil)
return
}

sendResponse(w, http.StatusOK, allResponses)
}

// wrap is a middleware that wraps HTTP handlers and injects the "app" context.
Expand Down

0 comments on commit a6447cf

Please sign in to comment.