diff --git a/.travis.yml b/.travis.yml index 0d8f3d7..295b2db 100644 --- a/.travis.yml +++ b/.travis.yml @@ -56,7 +56,7 @@ script: - cd $TRAVIS_BUILD_DIR # To start, run all the non-integration tests. -- MODULES="inetdiag zstd nl-proto/tools" +- MODULES="inetdiag zstd nl-proto/tools cache" - for module in $MODULES; do COVER_PKGS=${COVER_PKGS}./$module/..., ; done diff --git a/README.md b/README.md index b0b16ee..fa5ffbf 100644 --- a/README.md +++ b/README.md @@ -11,13 +11,13 @@ This repository uses protobuffers and zstd. To build it locally you will need t compiler ```bash -`wget https://github.com/google/protobuf/releases/download/v3.5.1/protoc-3.5.1-linux-x86_64.zip` -`unzip protoc-3.5.1-linux-x86_64.zip` -`cd nl-proto && ../bin/protoc --go_out=. *.proto` +wget https://github.com/google/protobuf/releases/download/v3.5.1/protoc-3.5.1-linux-x86_64.zip +unzip protoc-3.5.1-linux-x86_64.zip +cd nl-proto && ../bin/protoc --go_out=. *.proto ``` To run the tools, you will also require zstd, which can be installed with: ```bash -`bash <(curl -fsSL https://raw.githubusercontent.com/horta/zstd.install/master/install)` +bash <(curl -fsSL https://raw.githubusercontent.com/horta/zstd.install/master/install) ``` \ No newline at end of file diff --git a/cache/cache.go b/cache/cache.go new file mode 100644 index 0000000..8b936f4 --- /dev/null +++ b/cache/cache.go @@ -0,0 +1,50 @@ +// Package cache keeps a cache of connection info records. +package cache + +import ( + "errors" + + "github.com/m-lab/tcp-info/inetdiag" +) + +// Package error messages +var ( + ErrInetDiagParseFailed = errors.New("Error parsing inetdiag message") + ErrLocal = errors.New("Connection is loopback") + ErrUnknownMessageType = errors.New("Unknown netlink message type") +) + +// Cache is a cache of all connection status. +type Cache struct { + // Map from inode to ParsedMessage + current map[uint32]*inetdiag.ParsedMessage // Cache of most recent messages. + previous map[uint32]*inetdiag.ParsedMessage // Cache of previous round of messages. +} + +// NewCache creates a cache object with capacity of 1000. +func NewCache() *Cache { + return &Cache{current: make(map[uint32]*inetdiag.ParsedMessage, 1000), + previous: make(map[uint32]*inetdiag.ParsedMessage, 0)} +} + +// Update swaps msg with the cache contents, and returns the evicted value. +func (c *Cache) Update(msg *inetdiag.ParsedMessage) *inetdiag.ParsedMessage { + inode := msg.InetDiagMsg.IDiagInode + c.current[inode] = msg + evicted, ok := c.previous[inode] + if ok { + delete(c.previous, inode) + } + return evicted +} + +// EndCycle marks the completion of updates from one set of netlink messages. +// It returns all messages that did not have corresponding inodes in the most recent +// batch of messages. +func (c *Cache) EndCycle() map[uint32]*inetdiag.ParsedMessage { + tmp := c.previous + c.previous = c.current + // Allocate a bit more than last time, to accommodate new connections. + c.current = make(map[uint32]*inetdiag.ParsedMessage, len(c.previous)+len(c.previous)/10+10) + return tmp +} diff --git a/cache/cache_test.go b/cache/cache_test.go new file mode 100644 index 0000000..9199567 --- /dev/null +++ b/cache/cache_test.go @@ -0,0 +1,43 @@ +package cache_test + +import ( + "testing" + + "github.com/m-lab/tcp-info/cache" + "github.com/m-lab/tcp-info/inetdiag" +) + +func TestUpdate(t *testing.T) { + c := cache.NewCache() + pm1 := inetdiag.ParsedMessage{InetDiagMsg: &inetdiag.InetDiagMsg{IDiagInode: 1234}} + old := c.Update(&pm1) + if old != nil { + t.Error("old should be nil") + } + pm2 := inetdiag.ParsedMessage{InetDiagMsg: &inetdiag.InetDiagMsg{IDiagInode: 4321}} + old = c.Update(&pm2) + if old != nil { + t.Error("old should be nil") + } + + leftover := c.EndCycle() + if len(leftover) > 0 { + t.Error("Should be empty") + } + + pm3 := inetdiag.ParsedMessage{InetDiagMsg: &inetdiag.InetDiagMsg{IDiagInode: 4321}} + old = c.Update(&pm3) + if old == nil { + t.Error("old should NOT be nil") + } + + leftover = c.EndCycle() + if len(leftover) != 1 { + t.Error("Should not be empty") + } + for k := range leftover { + if *leftover[k] != pm1 { + t.Error("Should have found pm1") + } + } +} diff --git a/inetdiag/inetdiag.go b/inetdiag/inetdiag.go index 3f8e838..f464547 100644 --- a/inetdiag/inetdiag.go +++ b/inetdiag/inetdiag.go @@ -29,15 +29,24 @@ expressed in host-byte order" */ import ( + "errors" "fmt" "log" "net" "syscall" "unsafe" + "golang.org/x/sys/unix" + tcpinfo "github.com/m-lab/tcp-info/nl-proto" ) +// Error types. +var ( + ErrParseFailed = errors.New("Unable to parse InetDiagMsg") + ErrNotType20 = errors.New("NetlinkMessage wrong type") +) + // Constants from linux. const ( TCPDIAG_GETSOCK = 18 // uapi/linux/inet_diag.h @@ -178,12 +187,6 @@ func (msg *InetDiagMsg) String() string { return fmt.Sprintf("%s, %s, %s", diagFamilyMap[msg.IDiagFamily], tcpinfo.TCPState(msg.IDiagState), msg.ID.String()) } -// rtaAlignOf round the length of a netlink route attribute up to align it -// properly. -func rtaAlignOf(attrlen int) int { - return (attrlen + syscall.RTA_ALIGNTO - 1) & ^(syscall.RTA_ALIGNTO - 1) -} - // ParseInetDiagMsg returns the InetDiagMsg itself, and the aligned byte array containing the message content. // Modified from original to also return attribute data array. func ParseInetDiagMsg(data []byte) (*InetDiagMsg, []byte) { @@ -195,3 +198,73 @@ func ParseInetDiagMsg(data []byte) (*InetDiagMsg, []byte) { } return (*InetDiagMsg)(unsafe.Pointer(&data[0])), data[rtaAlignOf(int(unsafe.Sizeof(InetDiagMsg{}))):] } + +// ParsedMessage is a container for parsed InetDiag messages and attributes. +type ParsedMessage struct { + Header syscall.NlMsghdr + InetDiagMsg *InetDiagMsg + Attributes [INET_DIAG_MAX]*syscall.NetlinkRouteAttr +} + +// Parse parsed the NetlinkMessage into a ParsedMessage. If skipLocal is true, it will return nil for +// loopback, local unicast, multicast, and unspecified connections. +func Parse(msg *syscall.NetlinkMessage, skipLocal bool) (*ParsedMessage, error) { + if msg.Header.Type != 20 { + return nil, ErrNotType20 + } + idm, attrBytes := ParseInetDiagMsg(msg.Data) + if idm == nil { + return nil, ErrParseFailed + } + if skipLocal { + srcIP := idm.ID.SrcIP() + if srcIP.IsLoopback() || srcIP.IsLinkLocalUnicast() || srcIP.IsMulticast() || srcIP.IsUnspecified() { + return nil, nil + } + dstIP := idm.ID.DstIP() + if dstIP.IsLoopback() || dstIP.IsLinkLocalUnicast() || dstIP.IsMulticast() || dstIP.IsUnspecified() { + return nil, nil + } + } + parsedMsg := ParsedMessage{Header: msg.Header, InetDiagMsg: idm} + attrs, err := ParseRouteAttr(attrBytes) + if err != nil { + return nil, err + } + for i := range attrs { + parsedMsg.Attributes[attrs[i].Attr.Type] = &attrs[i] + } + return &parsedMsg, nil +} + +/*********************************************************************************************/ +/* Copied from "github.com/vishvananda/netlink/nl/nl_linux.go" */ +/*********************************************************************************************/ + +// ParseRouteAttr parses a byte array into a NetlinkRouteAttr struct. +func ParseRouteAttr(b []byte) ([]syscall.NetlinkRouteAttr, error) { + var attrs []syscall.NetlinkRouteAttr + for len(b) >= unix.SizeofRtAttr { + a, vbuf, alen, err := netlinkRouteAttrAndValue(b) + if err != nil { + return nil, err + } + ra := syscall.NetlinkRouteAttr{Attr: syscall.RtAttr(*a), Value: vbuf[:int(a.Len)-unix.SizeofRtAttr]} + attrs = append(attrs, ra) + b = b[alen:] + } + return attrs, nil +} + +// rtaAlignOf rounds the length of a netlink route attribute up to align it properly. +func rtaAlignOf(attrlen int) int { + return (attrlen + unix.RTA_ALIGNTO - 1) & ^(unix.RTA_ALIGNTO - 1) +} + +func netlinkRouteAttrAndValue(b []byte) (*unix.RtAttr, []byte, int, error) { + a := (*unix.RtAttr)(unsafe.Pointer(&b[0])) + if int(a.Len) < unix.SizeofRtAttr || int(a.Len) > len(b) { + return nil, nil, 0, unix.EINVAL + } + return a, b[unix.SizeofRtAttr:], rtaAlignOf(int(a.Len)), nil +} diff --git a/inetdiag/inetdiag_test.go b/inetdiag/inetdiag_test.go index 50ce730..9ce0b5a 100644 --- a/inetdiag/inetdiag_test.go +++ b/inetdiag/inetdiag_test.go @@ -1,12 +1,14 @@ package inetdiag_test import ( + "encoding/json" "log" "syscall" "testing" "unsafe" "github.com/m-lab/tcp-info/inetdiag" + "golang.org/x/sys/unix" tcpinfo "github.com/m-lab/tcp-info/nl-proto" ) @@ -14,6 +16,11 @@ import ( // This is not exhaustive, but covers the basics. Integration tests will expose any more subtle // problems. +func init() { + // Always prepend the filename and line number. + log.SetFlags(log.LstdFlags | log.Lshortfile) +} + func TestSizes(t *testing.T) { if unsafe.Sizeof(inetdiag.InetDiagSockID{}) != 48 { t.Error("SockID wrong size", unsafe.Sizeof(inetdiag.InetDiagSockID{})) @@ -107,3 +114,75 @@ func TestID6(t *testing.T) { t.Errorf("Should not be identified as loopback") } } + +func TestParse(t *testing.T) { + var json1 = `{"Header":{"Len":356,"Type":20,"Flags":2,"Seq":1,"Pid":148940},"Data":"CgEAAOpWE6cmIAAAEAMEFbM+nWqBv4ehJgf4sEANDAoAAAAAAAAAgQAAAAAdWwAAAAAAAAAAAAAAAAAAAAAAAAAAAAC13zIBBQAIAAAAAAAFAAUAIAAAAAUABgAgAAAAFAABAAAAAAAAAAAAAAAAAAAAAAAoAAcAAAAAAICiBQAAAAAAALQAAAAAAAAAAAAAAAAAAAAAAAAAAAAArAACAAEAAAAAB3gBQIoDAECcAABEBQAAuAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAUCEAAAAAAAAgIQAAQCEAANwFAACsywIAJW8AAIRKAAD///9/CgAAAJQFAAADAAAALMkAAIBwAAAAAAAALnUOAAAAAAD///////////ayBAAAAAAASfQPAAAAAADMEQAANRMAAAAAAABiNQAAxAsAAGMIAABX5AUAAAAAAAoABABjdWJpYwAAAA=="}` + nm := syscall.NetlinkMessage{} + err := json.Unmarshal([]byte(json1), &nm) + if err != nil { + log.Fatal(err) + } + mp, err := inetdiag.Parse(&nm, true) + if err != nil { + log.Fatal(err) + } + if mp.Header.Len != 356 { + t.Error("wrong length") + } + if mp.InetDiagMsg.IDiagFamily != unix.AF_INET6 { + t.Error("Should not be IPv6") + } + if len(mp.Attributes) != inetdiag.INET_DIAG_MAX { + t.Error("Should be", inetdiag.INET_DIAG_MAX, "attribute entries") + } + + nonNil := 0 + for i := range mp.Attributes { + if mp.Attributes[i] != nil { + nonNil++ + } + } + if nonNil != 7 { + t.Error("Incorrect number of attribs") + } + + if mp.Attributes[inetdiag.INET_DIAG_INFO] == nil { + t.Error("Should not be nil") + } +} + +func TestParseGarbage(t *testing.T) { + // Json encoding of a good netlink message containing inet diag info. + var good = `{"Header":{"Len":356,"Type":20,"Flags":2,"Seq":1,"Pid":148940},"Data":"CgEAAOpWE6cmIAAAEAMEFbM+nWqBv4ehJgf4sEANDAoAAAAAAAAAgQAAAAAdWwAAAAAAAAAAAAAAAAAAAAAAAAAAAAC13zIBBQAIAAAAAAAFAAUAIAAAAAUABgAgAAAAFAABAAAAAAAAAAAAAAAAAAAAAAAoAAcAAAAAAICiBQAAAAAAALQAAAAAAAAAAAAAAAAAAAAAAAAAAAAArAACAAEAAAAAB3gBQIoDAECcAABEBQAAuAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAUCEAAAAAAAAgIQAAQCEAANwFAACsywIAJW8AAIRKAAD///9/CgAAAJQFAAADAAAALMkAAIBwAAAAAAAALnUOAAAAAAD///////////ayBAAAAAAASfQPAAAAAADMEQAANRMAAAAAAABiNQAAxAsAAGMIAABX5AUAAAAAAAoABABjdWJpYwAAAA=="}` + nm := syscall.NetlinkMessage{} + err := json.Unmarshal([]byte(good), &nm) + if err != nil { + log.Fatal(err) + } + // Replace the header type with one that we don't support. + nm.Header.Type = 10 + _, err = inetdiag.Parse(&nm, false) + if err == nil { + t.Error("Should detect wrong type") + } + + // Restore the header type. + nm.Header.Type = 20 + // Replace the payload with garbage. + for i := range nm.Data { + // Replace the attribute records with garbage + nm.Data[i] = byte(i) + } + + _, err = inetdiag.Parse(&nm, false) + if err == nil || err.Error() != "invalid argument" { + t.Error(err) + } + + // Replace length with garbage so that data is incomplete. + nm.Header.Len = 400 + _, err = inetdiag.Parse(&nm, false) + if err == nil || err.Error() != "invalid argument" { + t.Error(err) + } +} diff --git a/nl-proto/tools/convert_test.go b/nl-proto/tools/convert_test.go index 811f03b..8ba36ba 100644 --- a/nl-proto/tools/convert_test.go +++ b/nl-proto/tools/convert_test.go @@ -12,7 +12,6 @@ import ( tcpinfo "github.com/m-lab/tcp-info/nl-proto" "github.com/m-lab/tcp-info/nl-proto/tools" "github.com/m-lab/tcp-info/zstd" - "github.com/vishvananda/netlink/nl" ) func init() { @@ -44,34 +43,9 @@ var ( ) func convertToProto(msg *syscall.NetlinkMessage, t *testing.T) *tcpinfo.TCPDiagnosticsProto { - if msg.Header.Type != 20 { - t.Error("Skipping unknown message type:", msg.Header) - } - idm, attrBytes := inetdiag.ParseInetDiagMsg(msg.Data) - if idm == nil { - t.Error("Couldn't parse InetDiagMsg") - } - srcIP := idm.ID.SrcIP() - if srcIP.IsLoopback() || srcIP.IsLinkLocalUnicast() || srcIP.IsMulticast() || srcIP.IsUnspecified() { - return nil - } - dstIP := idm.ID.DstIP() - if dstIP.IsLoopback() || dstIP.IsLinkLocalUnicast() || dstIP.IsMulticast() || dstIP.IsUnspecified() { - return nil - } - type ParsedMessage struct { - Header syscall.NlMsghdr - InetDiagMsg *inetdiag.InetDiagMsg - Attributes [inetdiag.INET_DIAG_MAX]*syscall.NetlinkRouteAttr - } - - parsedMsg := ParsedMessage{Header: msg.Header, InetDiagMsg: idm} - attrs, err := nl.ParseRouteAttr(attrBytes) + parsedMsg, err := inetdiag.Parse(msg, true) if err != nil { - t.Error(err) - } - for i := range attrs { - parsedMsg.Attributes[attrs[i].Attr.Type] = &attrs[i] + t.Fatal(err) } return tools.CreateProto(msg.Header, parsedMsg.InetDiagMsg, parsedMsg.Attributes[:]) }