Skip to content

Commit

Permalink
stats: handle client ids
Browse files Browse the repository at this point in the history
  • Loading branch information
ainar-g committed Jan 26, 2021
1 parent 4e14ab3 commit 2d68df4
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 44 deletions.
37 changes: 17 additions & 20 deletions internal/dnsforward/stats.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package dnsforward

import (
"net"
"strings"
"time"

Expand All @@ -16,10 +15,10 @@ import (
func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
elapsed := time.Since(ctx.startTime)
s := ctx.srv
d := ctx.proxyCtx
pctx := ctx.proxyCtx

shouldLog := true
msg := d.Req
msg := pctx.Req

// don't log ANY request if refuseAny is enabled
if len(msg.Question) >= 1 && msg.Question[0].Qtype == dns.TypeANY && s.conf.RefuseAny {
Expand All @@ -32,15 +31,15 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
if shouldLog && s.queryLog != nil {
p := querylog.AddParams{
Question: msg,
Answer: d.Res,
Answer: pctx.Res,
OrigAnswer: ctx.origResp,
Result: ctx.result,
Elapsed: elapsed,
ClientIP: IPFromAddr(d.Addr),
ClientIP: IPFromAddr(pctx.Addr),
ClientID: ctx.clientID,
}

switch d.Proto {
switch pctx.Proto {
case proxy.ProtoHTTPS:
p.ClientProto = querylog.ClientProtoDOH
case proxy.ProtoQUIC:
Expand All @@ -54,46 +53,44 @@ func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) {
// request.
}

if d.Upstream != nil {
p.Upstream = d.Upstream.Address()
if pctx.Upstream != nil {
p.Upstream = pctx.Upstream.Address()
}
s.queryLog.Add(p)
}

s.updateStats(d, elapsed, *ctx.result)
s.updateStats(ctx, elapsed, *ctx.result)
s.RUnlock()

return resultCodeSuccess
}

func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dnsfilter.Result) {
func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res dnsfilter.Result) {
if s.stats == nil {
return
}

pctx := ctx.proxyCtx
e := stats.Entry{}
e.Domain = strings.ToLower(d.Req.Question[0].Name)
e.Domain = strings.ToLower(pctx.Req.Question[0].Name)
e.Domain = e.Domain[:len(e.Domain)-1] // remove last "."
switch addr := d.Addr.(type) {
case *net.UDPAddr:
e.Client = addr.IP
case *net.TCPAddr:
e.Client = addr.IP

if clientID := ctx.clientID; clientID != "" {
e.Client = clientID
} else if pctx.Addr != nil {
e.Client = pctx.Addr.String()
}

e.Time = uint32(elapsed / 1000)
e.Result = stats.RNotFiltered

switch res.Reason {

case dnsfilter.FilteredSafeBrowsing:
e.Result = stats.RSafeBrowsing

case dnsfilter.FilteredParental:
e.Result = stats.RParental

case dnsfilter.FilteredSafeSearch:
e.Result = stats.RSafeSearch

case dnsfilter.FilteredBlockList:
fallthrough
case dnsfilter.FilteredInvalid:
Expand Down
8 changes: 6 additions & 2 deletions internal/stats/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,14 @@ const (
rLast
)

// Entry - data to add
// Entry is a statistics data entry.
type Entry struct {
// Clients is the client's primary ID.
//
// TODO(a.garipov): Make this a {net.IP, string} enum?
Client string

Domain string
Client net.IP
Result Result
Time uint32 // processing time (msec)
}
11 changes: 6 additions & 5 deletions internal/stats/stats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ func TestStats(t *testing.T) {
e := Entry{}

e.Domain = "domain"
e.Client = net.IP{127, 0, 0, 1}
e.Client = "127.0.0.1"
e.Result = RFiltered
e.Time = 123456
s.Update(e)

e.Domain = "domain"
e.Client = net.IP{127, 0, 0, 1}
e.Client = "127.0.0.1"
e.Result = RNotFiltered
e.Time = 123456
s.Update(e)
Expand Down Expand Up @@ -113,9 +113,10 @@ func TestLargeNumbers(t *testing.T) {
}
for i := 0; i != n; i++ {
e.Domain = fmt.Sprintf("domain%d", i)
e.Client = net.IP{127, 0, 0, 1}
e.Client[2] = byte((i & 0xff00) >> 8)
e.Client[3] = byte(i & 0xff)
ip := net.IP{127, 0, 0, 1}
ip[2] = byte((i & 0xff00) >> 8)
ip[3] = byte(i & 0xff)
e.Client = ip.String()
e.Result = RNotFiltered
e.Time = 123456
s.Update(e)
Expand Down
44 changes: 27 additions & 17 deletions internal/stats/unit.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,15 @@ func (s *statsCtx) periodicFlush() {
s.unitLock.Lock()
ptr := s.unit
s.unitLock.Unlock()

if ptr == nil {
break
}

id := s.conf.UnitID()
if ptr.id == id {
time.Sleep(time.Second)

continue
}

Expand All @@ -243,6 +245,7 @@ func (s *statsCtx) periodicFlush() {
if tx == nil {
continue
}

ok1 := s.flushUnitToDB(tx, u.id, udb)
ok2 := s.deleteUnit(tx, id-s.conf.limit)
if ok1 || ok2 {
Expand All @@ -251,6 +254,7 @@ func (s *statsCtx) periodicFlush() {
_ = tx.Rollback()
}
}

log.Tracef("periodicFlush() exited")
}

Expand All @@ -265,7 +269,7 @@ func (s *statsCtx) deleteUnit(tx *bolt.Tx, id uint32) bool {
return true
}

func convertMapToArray(m map[string]uint64, max int) []countPair {
func convertMapToSlice(m map[string]uint64, max int) []countPair {
a := []countPair{}
for k, v := range m {
pair := countPair{}
Expand All @@ -283,7 +287,7 @@ func convertMapToArray(m map[string]uint64, max int) []countPair {
return a[:max]
}

func convertArrayToMap(a []countPair) map[string]uint64 {
func convertSliceToMap(a []countPair) map[string]uint64 {
m := map[string]uint64{}
for _, it := range a {
m[it.Name] = it.Count
Expand All @@ -301,9 +305,9 @@ func serialize(u *unit) *unitDB {
udb.TimeAvg = uint32(u.timeSum / u.nTotal)
}

udb.Domains = convertMapToArray(u.domains, maxDomains)
udb.BlockedDomains = convertMapToArray(u.blockedDomains, maxDomains)
udb.Clients = convertMapToArray(u.clients, maxClients)
udb.Domains = convertMapToSlice(u.domains, maxDomains)
udb.BlockedDomains = convertMapToSlice(u.blockedDomains, maxDomains)
udb.Clients = convertMapToSlice(u.clients, maxClients)

return &udb
}
Expand All @@ -319,9 +323,9 @@ func deserialize(u *unit, udb *unitDB) {
u.nResult[i] = udb.NResult[i]
}

u.domains = convertArrayToMap(udb.Domains)
u.blockedDomains = convertArrayToMap(udb.BlockedDomains)
u.clients = convertArrayToMap(udb.Clients)
u.domains = convertSliceToMap(udb.Domains)
u.blockedDomains = convertSliceToMap(udb.BlockedDomains)
u.clients = convertSliceToMap(udb.Clients)
u.timeSum = uint64(udb.TimeAvg) * u.nTotal
}

Expand Down Expand Up @@ -372,7 +376,7 @@ func (s *statsCtx) loadUnitFromDB(tx *bolt.Tx, id uint32) *unitDB {
return &udb
}

func convertTopArray(a []countPair) []map[string]uint64 {
func convertTopSlice(a []countPair) []map[string]uint64 {
m := []map[string]uint64{}
for _, it := range a {
ent := map[string]uint64{}
Expand Down Expand Up @@ -461,13 +465,20 @@ func (s *statsCtx) getClientIP(ip net.IP) (clientIP net.IP) {
func (s *statsCtx) Update(e Entry) {
if e.Result == 0 ||
e.Result >= rLast ||
len(e.Domain) == 0 ||
!(len(e.Client) == 4 || len(e.Client) == 16) {
e.Domain == "" ||
e.Client == "" {
return
}
client := s.getClientIP(e.Client)

clientID := e.Client
if ip := net.ParseIP(clientID); ip != nil {
ip = s.getClientIP(ip)
clientID = ip.String()
}

s.unitLock.Lock()
defer s.unitLock.Unlock()

u := s.unit

u.nResult[e.Result]++
Expand All @@ -478,10 +489,9 @@ func (s *statsCtx) Update(e Entry) {
u.blockedDomains[e.Domain]++
}

u.clients[client.String()]++
u.clients[clientID]++
u.timeSum += uint64(e.Time)
u.nTotal++
s.unitLock.Unlock()
}

func (s *statsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) {
Expand Down Expand Up @@ -594,8 +604,8 @@ func (s *statsCtx) getData() (statsResponse, bool) {
m[it.Name] += it.Count
}
}
a2 := convertMapToArray(m, max)
return convertTopArray(a2)
a2 := convertMapToSlice(m, max)
return convertTopSlice(a2)
}

dnsQueries := statsCollector(func(u *unitDB) (num uint64) { return u.NTotal })
Expand Down Expand Up @@ -661,7 +671,7 @@ func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP {
m[it.Name] += it.Count
}
}
a := convertMapToArray(m, int(maxCount))
a := convertMapToSlice(m, int(maxCount))
d := []net.IP{}
for _, it := range a {
d = append(d, net.ParseIP(it.Name))
Expand Down

0 comments on commit 2d68df4

Please sign in to comment.