diff --git a/internal/unix/types_linux.go b/internal/unix/types_linux.go index 2ccf386..460b72b 100644 --- a/internal/unix/types_linux.go +++ b/internal/unix/types_linux.go @@ -15,6 +15,7 @@ const ( SizeofIfInfomsg = linux.SizeofIfInfomsg SizeofNdMsg = linux.SizeofNdMsg SizeofRtMsg = linux.SizeofRtMsg + SizeofRtNexthop = linux.SizeofRtNexthop RTM_NEWADDR = linux.RTM_NEWADDR RTM_DELADDR = linux.RTM_DELADDR RTM_GETADDR = linux.RTM_GETADDR @@ -73,6 +74,7 @@ const ( RTA_MARK = linux.RTA_MARK RTA_EXPIRES = linux.RTA_EXPIRES RTA_METRICS = linux.RTA_METRICS + RTA_MULTIPATH = linux.RTA_MULTIPATH RTAX_ADVMSS = linux.RTAX_ADVMSS RTAX_FEATURES = linux.RTAX_FEATURES RTAX_INITCWND = linux.RTAX_INITCWND diff --git a/internal/unix/types_other.go b/internal/unix/types_other.go index bd5b5a7..f86e791 100644 --- a/internal/unix/types_other.go +++ b/internal/unix/types_other.go @@ -11,6 +11,7 @@ const ( SizeofIfInfomsg = 0x10 SizeofNdMsg = 0xc SizeofRtMsg = 0xc + SizeofRtNexthop = 0x8 RTM_NEWADDR = 0x14 RTM_DELADDR = 0x15 RTM_GETADDR = 0x16 @@ -69,6 +70,7 @@ const ( RTA_MARK = 0x10 RTA_EXPIRES = 0x17 RTA_METRICS = 0x8 + RTA_MULTIPATH = 0x9 RTAX_ADVMSS = 0x8 RTAX_FEATURES = 0xc RTAX_INITCWND = 0xb diff --git a/route.go b/route.go index a45dc46..4766443 100644 --- a/route.go +++ b/route.go @@ -3,6 +3,7 @@ package rtnetlink import ( "errors" "net" + "unsafe" "github.com/jsimonetti/rtnetlink/internal/unix" @@ -47,7 +48,6 @@ func (m *RouteMessage) MarshalBinary() ([]byte, error) { nativeEndian.PutUint32(b[8:12], m.Flags) ae := netlink.NewAttributeEncoder() - ae.ByteOrder = nativeEndian err := m.Attributes.encode(ae) if err != nil { return nil, err @@ -80,7 +80,6 @@ func (m *RouteMessage) UnmarshalBinary(b []byte) error { if l > unix.SizeofRtMsg { m.Attributes = RouteAttributes{} ad, err := netlink.NewAttributeDecoder(b[unix.SizeofRtMsg:]) - ad.ByteOrder = nativeEndian if err != nil { return err } @@ -170,15 +169,16 @@ func (r *RouteService) List() ([]RouteMessage, error) { } type RouteAttributes struct { - Dst net.IP - Src net.IP - Gateway net.IP - OutIface uint32 - Priority uint32 - Table uint32 - Mark uint32 - Expires *uint32 - Metrics *RouteMetrics + Dst net.IP + Src net.IP + Gateway net.IP + OutIface uint32 + Priority uint32 + Table uint32 + Mark uint32 + Expires *uint32 + Metrics *RouteMetrics + Multipath []NextHop } func (a *RouteAttributes) decode(ad *netlink.AttributeDecoder) error { @@ -219,6 +219,8 @@ func (a *RouteAttributes) decode(ad *netlink.AttributeDecoder) error { case unix.RTA_METRICS: a.Metrics = &RouteMetrics{} ad.Nested(a.Metrics.decode) + case unix.RTA_MULTIPATH: + ad.Do(a.parseMultipath) } } @@ -226,7 +228,6 @@ func (a *RouteAttributes) decode(ad *netlink.AttributeDecoder) error { } func (a *RouteAttributes) encode(ae *netlink.AttributeEncoder) error { - if a.Dst != nil { if ipv4 := a.Dst.To4(); ipv4 == nil { // Dst Addr is IPv6 @@ -281,6 +282,10 @@ func (a *RouteAttributes) encode(ae *netlink.AttributeEncoder) error { ae.Nested(unix.RTA_METRICS, a.Metrics.encode) } + if len(a.Multipath) > 0 { + ae.Do(unix.RTA_MULTIPATH, a.encodeMultipath) + } + return nil } @@ -329,3 +334,127 @@ func (rm *RouteMetrics) encode(ae *netlink.AttributeEncoder) error { return nil } + +// TODO(mdlayher): probably eliminate Length field from the API to avoid the +// caller possibly tampering with it since we can compute it. + +// RTNextHop represents the netlink rtnexthop struct (not an attribute) +type RTNextHop struct { + Length uint16 // length of this hop including nested values + Flags uint8 // flags defined in rtnetlink.h line 311 + Hops uint8 + IfIndex uint32 // the interface index number +} + +// NextHop wraps struct rtnexthop to provide access to nested attributes +type NextHop struct { + Hop RTNextHop // a rtnexthop struct + Gateway net.IP // that struct's nested Gateway attribute +} + +func (a *RouteAttributes) encodeMultipath() ([]byte, error) { + var b []byte + for _, nh := range a.Multipath { + // Encode the attributes first so their total length can be used to + // compute the length of each (rtnexthop, attributes) pair. + ae := netlink.NewAttributeEncoder() + + if a.Gateway != nil { + // TODO(mdlayher): more validation. + ae.Bytes(unix.RTA_GATEWAY, nh.Gateway) + } + + ab, err := ae.Encode() + if err != nil { + return nil, err + } + + // Assume the caller wants the length updated so they don't have to + // keep track of it themselves when encoding attributes. + nh.Hop.Length = unix.SizeofRtNexthop + uint16(len(ab)) + var nhb [unix.SizeofRtNexthop]byte + + copy( + nhb[:], + (*(*[unix.SizeofRtNexthop]byte)(unsafe.Pointer(&nh.Hop)))[:], + ) + + // rtnexthop first, then attributes. + b = append(b, nhb[:]...) + b = append(b, ab...) + } + + return b, nil +} + +func (a *RouteAttributes) parseMultipath(b []byte) error { + // check for truncated message + if len(b) <= unix.SizeofRtNexthop { + return errInvalidRouteMessageAttr + } + + // Iterate through the nested array of rtnexthop, unpacking each and appending them to mp + for i := 0; i <= len(b); { + // check for end of message + if len(b)-i < unix.SizeofRtNexthop { + return nil + } + + // Copy over the struct portion + var nh NextHop + var nhb [unix.SizeofRtNexthop]byte + copy(nhb[:], b[i:i+unix.SizeofRtNexthop]) + + copy( + (*(*[unix.SizeofRtNexthop]byte)(unsafe.Pointer(&nh.Hop)))[:], + (*(*[unix.SizeofRtNexthop]byte)(unsafe.Pointer(&nhb[0])))[:], + ) + + // check again for a truncated message + if int(nh.Hop.Length) > len(b) { + return errInvalidRouteMessageAttr + } + + // grab a new attributedecoder for the nested attributes + start := (i + unix.SizeofRtNexthop) + end := (i + int(nh.Hop.Length)) + + ad, err := netlink.NewAttributeDecoder(b[start:end]) + if err != nil { + return err + } + + // read in the nested attributes + if err := nh.decode(ad); err != nil { + return err + } + + // append this hop to the parent Multipath struct + a.Multipath = append(a.Multipath, nh) + + // move forward to the next element in multipath.[]nexthop + i += int(nh.Hop.Length) + } + + return nil +} + +// TODO: Implement func (mp *RTMultiPath) encode() + +// rtnexthop payload is at least one nested attribute RTA_GATEWAY +// possibly others? +func (nh *NextHop) decode(ad *netlink.AttributeDecoder) error { + for ad.Next() { + switch ad.Type() { + case unix.RTA_GATEWAY: + l := len(ad.Bytes()) + if l != 4 && l != 16 { + return errInvalidRouteMessageAttr + } + + nh.Gateway = ad.Bytes() + } + } + + return ad.Err() +} diff --git a/route_test.go b/route_test.go index 5c2bca5..7dee8e6 100644 --- a/route_test.go +++ b/route_test.go @@ -63,6 +63,22 @@ func TestRouteMessageMarshalUnmarshalBinary(t *testing.T) { InitCwnd: 2, MTU: 1500, }, + Multipath: []NextHop{ + { + Hop: RTNextHop{ + Length: 16, + IfIndex: 1, + }, + Gateway: net.IPv4(10, 0, 0, 2), + }, + { + Hop: RTNextHop{ + Length: 16, + IfIndex: 2, + }, + Gateway: net.IPv4(10, 0, 0, 3), + }, + }, }, }, b: []byte{ @@ -128,6 +144,25 @@ func TestRouteMessageMarshalUnmarshalBinary(t *testing.T) { // MTU 0x08, 0x00, 0x02, 0x00, 0xdc, 0x05, 0x00, 0x00, + // Multipath + // + // 2 bytes length, 2 bytes type, then repeated 8 byte rtnexthop + // structures followed by their nested netlink attributes. + 0x24, 0x00, 0x09, 0x00, + // rtnexthop + 0x10, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + // rtnexthop attributes + 0x08, 0x00, 0x05, 0x00, + // Gateway + 10, 0, 0, 2, + // rtnexthop + 0x10, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + // rtnexthop attributes + 0x08, 0x00, 0x05, 0x00, + // Gateway + 10, 0, 0, 3, }, }, }