Skip to content

Commit

Permalink
Refactor inetdiag to become more testable
Browse files Browse the repository at this point in the history
Also test it.
  • Loading branch information
pboothe committed Oct 24, 2018
1 parent 5fd103f commit 7e97c19
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 55 deletions.
11 changes: 5 additions & 6 deletions inetdiag/inetdiag.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,10 @@ type ParsedMessage struct {
Attributes [INET_DIAG_MAX]*syscall.NetlinkRouteAttr
}

func isLocal(addr net.IP) bool {
return addr.IsLoopback() || addr.IsLinkLocalUnicast() || addr.IsMulticast() || addr.IsUnspecified()
}

// Parse parses the NetlinkMessage into a ParsedMessage. If skipLocal is true, it will return nil for
// loopback, local unicast, multicast, and unspecified connections.
// Note that Parse does not populate the Timestamp field, so caller should do so.
Expand All @@ -246,12 +250,7 @@ func Parse(msg *syscall.NetlinkMessage, skipLocal bool) (*ParsedMessage, error)
return nil, ErrParseFailed
}
if skipLocal {
srcIP := idm.ID.SrcIP()
if srcIP.IsLoopback() || srcIP.IsLinkLocalUnicast() || srcIP.IsMulticast() || srcIP.IsUnspecified() {
return nil, nil
}
dstIP := idm.ID.DstIP()
if dstIP.IsLoopback() || dstIP.IsLinkLocalUnicast() || dstIP.IsMulticast() || dstIP.IsUnspecified() {
if isLocal(idm.ID.SrcIP()) || isLocal(idm.ID.DstIP()) {
return nil, nil
}
}
Expand Down
47 changes: 46 additions & 1 deletion inetdiag/inetdiag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ func TestParseInetDiagMsg(t *testing.T) {
data[i] = byte(i + 2)
}
hdr, value := inetdiag.ParseInetDiagMsg(data[:])
if hdr.ID.Interface() == 0 || hdr.ID.Cookie() == 0 || hdr.ID.DPort() == 0 || hdr.ID.String() == "" {
t.Errorf("None of the accessed values should be zero")
}
if hdr.IDiagFamily != syscall.AF_INET {
t.Errorf("Failed %+v\n", hdr)
}
Expand All @@ -48,6 +51,11 @@ func TestParseInetDiagMsg(t *testing.T) {
if len(value) != 28 {
t.Error("Len", len(value))
}

hdr, value = inetdiag.ParseInetDiagMsg(data[:1])
if hdr != nil || value != nil {
t.Error("This should fail, the data is too small.")
}
}

func TestSerialize(t *testing.T) {
Expand Down Expand Up @@ -79,7 +87,10 @@ func TestID4(t *testing.T) {

hdr, _ := inetdiag.ParseInetDiagMsg(data[:])
if !hdr.ID.SrcIP().IsLoopback() {
log.Println(hdr.ID.SrcIP().IsLoopback())
t.Errorf("Should be loopback but isn't")
}
if hdr.ID.DstIP().IsLoopback() {
t.Errorf("Shouldn't be loopback but is")
}
if hdr.ID.SPort() != 0x3412 {
t.Errorf("SPort should be 0x3412 %+v\n", hdr.ID)
Expand Down Expand Up @@ -135,6 +146,9 @@ func TestParse(t *testing.T) {
if len(mp.Attributes) != inetdiag.INET_DIAG_MAX {
t.Error("Should be", inetdiag.INET_DIAG_MAX, "attribute entries")
}
if mp.InetDiagMsg.String() == "" {
t.Error("Empty string made from InetDiagMsg")
}

nonNil := 0
for i := range mp.Attributes {
Expand All @@ -149,6 +163,8 @@ func TestParse(t *testing.T) {
if mp.Attributes[inetdiag.INET_DIAG_INFO] == nil {
t.Error("Should not be nil")
}

// TODO: verify that skiplocal actually skips a message when src or dst is 127.0.0.1
}

func TestParseGarbage(t *testing.T) {
Expand All @@ -159,6 +175,15 @@ func TestParseGarbage(t *testing.T) {
if err != nil {
log.Fatal(err)
}

// Truncate the data down to something that makes no sense.
badNm := nm
badNm.Data = badNm.Data[:1]
_, err = inetdiag.Parse(&badNm, true)
if err == nil {
t.Error("The parse should have failed")
}

// Replace the header type with one that we don't support.
nm.Header.Type = 10
_, err = inetdiag.Parse(&nm, false)
Expand Down Expand Up @@ -186,3 +211,23 @@ func TestParseGarbage(t *testing.T) {
t.Error(err)
}
}

func TestOneType(t *testing.T) {
res4, err := inetdiag.OneType(syscall.AF_INET)
if err != nil {
t.Error(err)
}
res6, err := inetdiag.OneType(syscall.AF_INET6)
if err != nil {
t.Error(err)
}
resUnix, err := inetdiag.OneType(syscall.AF_UNIX)
if err != nil {
t.Error(err)
}
if len(res4) == 0 && len(res6) == 0 && len(resUnix) == 0 {
t.Error("There are never no active streams.")
}
}

// TODO: add whitebox testing of socket-monitor to exercise error handling.
107 changes: 61 additions & 46 deletions inetdiag/socket-monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import "C"
*/

import (
"errors"
"log"
"syscall"
"time"
Expand All @@ -23,6 +24,11 @@ import (

const TCPF_ALL = 0xFFF

var (
errBadPid = errors.New("Bad PID. Can't listen to NL socket.")
errBadSequence = errors.New("Bad sequence number. Can't interpret NetLink response.")
)

func makeReq(inetType uint8) *nl.NetlinkRequest {
req := nl.NewNetlinkRequest(SOCK_DIAG_BY_FAMILY, syscall.NLM_F_DUMP|syscall.NLM_F_REQUEST)
msg := NewInetDiagReqV2(inetType, syscall.IPPROTO_TCP,
Expand All @@ -46,88 +52,97 @@ func makeReq(inetType uint8) *nl.NetlinkRequest {
return req
}

func processSingleMessage(m *syscall.NetlinkMessage, seq uint32, pid uint32) (*syscall.NetlinkMessage, bool, error) {
if m.Header.Seq != seq {
log.Printf("Wrong Seq nr %d, expected %d", m.Header.Seq, seq)
metrics.ErrorCount.With(prometheus.Labels{"source": "wrong seq num"}).Inc()
return nil, false, errBadSequence
}
if m.Header.Pid != pid {
log.Printf("Wrong pid %d, expected %d", m.Header.Pid, pid)
metrics.ErrorCount.With(prometheus.Labels{"source": "wrong pid"}).Inc()
return nil, false, errBadPid
}
if m.Header.Type == unix.NLMSG_DONE {
return nil, false, nil
}
if m.Header.Type == unix.NLMSG_ERROR {
native := nl.NativeEndian()
error := int32(native.Uint32(m.Data[0:4]))
if error == 0 {
return nil, false, nil
}
log.Println(syscall.Errno(-error))
metrics.ErrorCount.With(prometheus.Labels{"source": "NLMSG_ERROR"}).Inc()
}
if m.Header.Flags&unix.NLM_F_MULTI == 0 {
return m, false, nil
}
return m, true, nil
}

// OneType handles the request and response for a single type, e.g. INET or INET6
// TODO maybe move this to top level?
func OneType(inetType uint8) []*syscall.NetlinkMessage {
func OneType(inetType uint8) ([]*syscall.NetlinkMessage, error) {
var res []*syscall.NetlinkMessage

start := time.Now()
defer func() {
af := "unknown"
switch inetType {
case syscall.AF_INET:
af = "ipv4"
case syscall.AF_INET6:
af = "ipv6"
}
metrics.FetchTimeMsecSummary.With(prometheus.Labels{"af": af}).Observe(1000 * time.Since(start).Seconds())
metrics.ConnectionCountSummary.With(prometheus.Labels{"af": af}).Observe(float64(len(res)))
}()

req := makeReq(inetType)

// Copied this from req.Execute in nl_linux.go
sockType := syscall.NETLINK_INET_DIAG
s, err := nl.Subscribe(sockType)
if err != nil {
log.Println(err)
return nil
return nil, err
}
defer s.Close()

if err := s.Send(req); err != nil {
log.Println(err)
return nil
return nil, err
}

pid, err := s.GetPid()
if err != nil {
log.Println(err)
return nil
return nil, err
}

var res []*syscall.NetlinkMessage

done:
// Adapted this from req.Execute in nl_linux.go
for {
msgs, err := s.Receive()
if err != nil {
log.Println(err)
return nil
return nil, err
}
// TODO avoid the copy.
for i := range msgs {
m := &msgs[i]
if m.Header.Seq != req.Seq {
log.Printf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
metrics.ErrorCount.With(prometheus.Labels{"source": "wrong seq num"}).Inc()
return nil
m, shouldContinue, err := processSingleMessage(&msgs[i], req.Seq, pid)
if m != nil {
res = append(res, m)
}
if m.Header.Pid != pid {
log.Printf("Wrong pid %d, expected %d", m.Header.Pid, pid)
metrics.ErrorCount.With(prometheus.Labels{"source": "wrong pid"}).Inc()
return nil
if err != nil {
return res, err
}
if m.Header.Type == unix.NLMSG_DONE {
break done
}
if m.Header.Type == unix.NLMSG_ERROR {
native := nl.NativeEndian()
error := int32(native.Uint32(m.Data[0:4]))
if error == 0 {
break done
}
log.Println(syscall.Errno(-error))
metrics.ErrorCount.With(prometheus.Labels{"source": "NLMSG_ERROR"}).Inc()
if !shouldContinue {
return res, nil
}
// if resType != 0 && m.Header.Type != resType {
// continue
// }
res = append(res, m)
if m.Header.Flags&unix.NLM_F_MULTI == 0 {
break done
}
}
}

switch inetType {
case syscall.AF_INET:
metrics.FetchTimeMsecSummary.With(prometheus.Labels{"af": "ipv4"}).Observe(1000 * time.Since(start).Seconds())
metrics.ConnectionCountSummary.With(prometheus.Labels{"af": "ipv4"}).Observe(float64(len(res)))
case syscall.AF_INET6:
metrics.FetchTimeMsecSummary.With(prometheus.Labels{"af": "ipv6"}).Observe(1000 * time.Since(start).Seconds())
metrics.ConnectionCountSummary.With(prometheus.Labels{"af": "ipv6"}).Observe(float64(len(res)))
default:
metrics.FetchTimeMsecSummary.With(prometheus.Labels{"af": "unknown"}).Observe(1000 * time.Since(start).Seconds())
metrics.ConnectionCountSummary.With(prometheus.Labels{"af": "unknown"}).Observe(float64(len(res)))
}

return res
}
4 changes: 2 additions & 2 deletions play.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func ParseAndQueue(cache *cache.Cache, msg *syscall.NetlinkMessage, queue bool)
func Demo(cache *cache.Cache, svr chan<- []*inetdiag.ParsedMessage) (int, int) {
all := make([]*inetdiag.ParsedMessage, 0, 500)
remoteCount := 0
res6 := inetdiag.OneType(syscall.AF_INET6)
res6, _ := inetdiag.OneType(syscall.AF_INET6) // Ignoring errors in Demo code
ts := time.Now()
for i := range res6 {
pm := ParseAndQueue(cache, res6[i], false)
Expand All @@ -167,7 +167,7 @@ func Demo(cache *cache.Cache, svr chan<- []*inetdiag.ParsedMessage) (int, int) {
}
}

res4 := inetdiag.OneType(syscall.AF_INET)
res4, _ := inetdiag.OneType(syscall.AF_INET) // Ignoring errors in Demo code
ts = time.Now()
for i := range res4 {
pm := ParseAndQueue(cache, res4[i], false)
Expand Down

0 comments on commit 7e97c19

Please sign in to comment.