Skip to content

Commit

Permalink
Cleaned up in response to reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
pboothe committed Oct 25, 2018
1 parent 7e97c19 commit b11b7fd
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 12 deletions.
46 changes: 41 additions & 5 deletions inetdiag/inetdiag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ package inetdiag_test

import (
"encoding/json"
"io/ioutil"
"log"
"net"
"os"
"sync"
"syscall"
"testing"
"unsafe"
Expand Down Expand Up @@ -213,20 +217,52 @@ func TestParseGarbage(t *testing.T) {
}

func TestOneType(t *testing.T) {
res4, err := inetdiag.OneType(syscall.AF_INET)
// Open an AF_LOCAL socket connection.
// Get a safe name for the AF_LOCAL socket
f, err := ioutil.TempFile("", "TestOneType")
if err != nil {
t.Error(err)
}
res6, err := inetdiag.OneType(syscall.AF_INET6)
name := f.Name()
os.Remove(name)

// Open a listening UNIX socket at that mostly-safe name.
l, err := net.Listen("unix", name)
if err != nil {
t.Error(err)
}
resUnix, err := inetdiag.OneType(syscall.AF_UNIX)
defer l.Close()

// Unblock all goroutines when the function exits.
wg := sync.WaitGroup{}
wg.Add(1)
defer wg.Done()

// Start a client connection in a goroutine.
go func() {
c, err := net.Dial("unix", name)
if err != nil {
t.Error(err)
}
c.Write([]byte("hi"))
wg.Wait()
c.Close()
}()

// Accept the client connection.
fd, err := l.Accept()
if err != nil {
t.Error(err)
}
defer fd.Close()

// Verify that OneType(AF_LOCAL) finds at least one connection.
res, err := inetdiag.OneType(syscall.AF_LOCAL)
if err != nil {
t.Error(err)
}
if len(res4) == 0 && len(res6) == 0 && len(resUnix) == 0 {
t.Error("There are never no active streams.")
if len(res) == 0 {
t.Error("We have at least one active stream open right now.")
}
}

Expand Down
17 changes: 10 additions & 7 deletions inetdiag/socket-monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@ import (
const TCPF_ALL = 0xFFF

var (
errBadPid = errors.New("Bad PID. Can't listen to NL socket.")
errBadSequence = errors.New("Bad sequence number. Can't interpret NetLink response.")
// ErrBadPid is used when the PID is mismatched between the netlink socket and the calling process.
ErrBadPid = errors.New("bad PID, can't listen to NL socket")

// ErrBadSequence is used when the Netlink response has a bad sequence number.
ErrBadSequence = errors.New("bad sequence number, can't interpret NetLink response")
)

func makeReq(inetType uint8) *nl.NetlinkRequest {
Expand Down Expand Up @@ -56,12 +59,12 @@ func processSingleMessage(m *syscall.NetlinkMessage, seq uint32, pid uint32) (*s
if m.Header.Seq != seq {
log.Printf("Wrong Seq nr %d, expected %d", m.Header.Seq, seq)
metrics.ErrorCount.With(prometheus.Labels{"source": "wrong seq num"}).Inc()
return nil, false, errBadSequence
return nil, false, ErrBadSequence
}
if m.Header.Pid != pid {
log.Printf("Wrong pid %d, expected %d", m.Header.Pid, pid)
metrics.ErrorCount.With(prometheus.Labels{"source": "wrong pid"}).Inc()
return nil, false, errBadPid
return nil, false, ErrBadPid
}
if m.Header.Type == unix.NLMSG_DONE {
return nil, false, nil
Expand Down Expand Up @@ -131,12 +134,12 @@ func OneType(inetType uint8) ([]*syscall.NetlinkMessage, error) {
// TODO avoid the copy.
for i := range msgs {
m, shouldContinue, err := processSingleMessage(&msgs[i], req.Seq, pid)
if m != nil {
res = append(res, m)
}
if err != nil {
return res, err
}
if m != nil {
res = append(res, m)
}
if !shouldContinue {
return res, nil
}
Expand Down

0 comments on commit b11b7fd

Please sign in to comment.