From df3d675005fa8bd6e15a6db3c3d64ef0bde489c9 Mon Sep 17 00:00:00 2001 From: Motok1 Date: Fri, 13 Jan 2023 15:12:51 +0000 Subject: [PATCH] Use netip package --- cmd/pola/grpc_client.go | 39 +++++++++++++------- cmd/pola/session.go | 5 +-- cmd/pola/sr_policy_list.go | 11 +++--- internal/pkg/cspf/cspf.go | 8 ++-- internal/pkg/gobgp/interface.go | 15 +++++--- internal/pkg/table/ted.go | 26 ++++++------- pkg/packet/pcep/message.go | 13 ++++--- pkg/packet/pcep/object.go | 65 ++++++++++++++++++--------------- pkg/server/grpc_server.go | 31 +++++++++------- pkg/server/server.go | 57 +++++++++++++++-------------- pkg/server/session.go | 20 ++++------ 11 files changed, 159 insertions(+), 131 deletions(-) diff --git a/cmd/pola/grpc_client.go b/cmd/pola/grpc_client.go index f399680..effc0d4 100644 --- a/cmd/pola/grpc_client.go +++ b/cmd/pola/grpc_client.go @@ -8,7 +8,7 @@ package main import ( "context" "errors" - "net" + "net/netip" "time" "github.com/golang/protobuf/ptypes/empty" @@ -17,16 +17,16 @@ import ( ) type srPolicyInfo struct { - peerAddr net.IP //TODO: Change to ("loopback addr" or "router name") + peerAddr netip.Addr //TODO: Change to ("loopback addr" or "router name") name string path []uint32 - srcAddr net.IP - dstAddr net.IP + srcAddr netip.Addr + dstAddr netip.Addr color uint32 preference uint32 } -func getSessionAddrList(client pb.PceServiceClient) ([]net.IP, error) { +func getSessionAddrList(client pb.PceServiceClient) ([]netip.Addr, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() var empty empty.Empty @@ -34,9 +34,10 @@ func getSessionAddrList(client pb.PceServiceClient) ([]net.IP, error) { if err != nil { return nil, err } - var peerAddrList []net.IP + var peerAddrList []netip.Addr for _, peerAddr := range ret.GetPeerAddrs() { - peerAddrList = append(peerAddrList, net.IP(peerAddr)) + peer, _ := netip.AddrFromSlice(peerAddr) + peerAddrList = append(peerAddrList, peer) } return peerAddrList, nil } @@ -51,11 +52,14 @@ func getSrPolicyList(client pb.PceServiceClient) ([]srPolicyInfo, error) { } srPolicyList := []srPolicyInfo{} for _, lsp := range ret.GetSrPolicies() { + peerAddr, _ := netip.AddrFromSlice(lsp.GetPcepSessionAddr()) + srcAddr, _ := netip.AddrFromSlice(lsp.GetSrcAddr()) + dstAddr, _ := netip.AddrFromSlice(lsp.GetDstAddr()) tmp := srPolicyInfo{ name: lsp.PolicyName, - peerAddr: net.IP(lsp.GetPcepSessionAddr()), - srcAddr: net.IP(lsp.GetSrcAddr()), - dstAddr: net.IP(lsp.GetDstAddr()), + peerAddr: peerAddr, + srcAddr: srcAddr, + dstAddr: dstAddr, color: lsp.Color, preference: lsp.Preference, } @@ -123,8 +127,14 @@ func getTed(client pb.PceServiceClient) (*table.LsTed, error) { for _, link := range node.LsLinks { lsLink := table.NewLsLink(ted.Nodes[link.LocalAsn][link.LocalRouterId], ted.Nodes[link.RemoteAsn][link.RemoteRouterId]) lsLink.AdjSid = link.GetAdjSid() - lsLink.LocalIP = net.ParseIP(link.GetLocalIp()) - lsLink.RemoteIP = net.ParseIP(link.GetRemoteIp()) + lsLink.LocalIP, err = netip.ParseAddr(link.GetLocalIp()) + if err != nil { + return nil, err + } + lsLink.RemoteIP, err = netip.ParseAddr(link.GetRemoteIp()) + if err != nil { + return nil, err + } for _, metricInfo := range link.GetMetrics() { var metric *table.Metric switch metricInfo.GetType().String() { @@ -146,7 +156,10 @@ func getTed(client pb.PceServiceClient) (*table.LsTed, error) { for _, prefix := range node.LsPrefixes { lsPrefix := table.NewLsPrefixV4(ted.Nodes[node.GetAsn()][node.GetRouterId()]) - _, lsPrefix.Prefix, _ = net.ParseCIDR(prefix.GetPrefix()) + lsPrefix.Prefix, err = netip.ParsePrefix(prefix.GetPrefix()) + if err != nil { + return nil, err + } lsPrefix.SidIndex = prefix.GetSidIndex() ted.Nodes[node.GetAsn()][node.GetRouterId()].Prefixes = append(ted.Nodes[node.GetAsn()][node.GetRouterId()].Prefixes, lsPrefix) } diff --git a/cmd/pola/session.go b/cmd/pola/session.go index d7f0685..084de40 100644 --- a/cmd/pola/session.go +++ b/cmd/pola/session.go @@ -8,7 +8,6 @@ package main import ( "encoding/json" "fmt" - "net" "github.com/spf13/cobra" ) @@ -37,7 +36,7 @@ func showSession(jsonFlag bool) error { peerAddrs := []map[string]string{} for _, peerAddr := range sessionAddrList { peerAddrInfo := map[string]string{ - "address": net.IP(peerAddr).String(), + "address": peerAddr.String(), "status": "active", } peerAddrs = append(peerAddrs, peerAddrInfo) @@ -53,7 +52,7 @@ func showSession(jsonFlag bool) error { } else { //output user-friendly format for i, peerAddr := range sessionAddrList { - fmt.Printf("sessionAddr(%d): %v\n", i, net.IP(peerAddr)) + fmt.Printf("sessionAddr(%d): %s\n", i, peerAddr.String()) } } return nil diff --git a/cmd/pola/sr_policy_list.go b/cmd/pola/sr_policy_list.go index 6b77abc..5f964d7 100644 --- a/cmd/pola/sr_policy_list.go +++ b/cmd/pola/sr_policy_list.go @@ -8,7 +8,6 @@ package main import ( "encoding/json" "fmt" - "net" "github.com/spf13/cobra" ) @@ -41,8 +40,8 @@ func showSrPolicyList(jsonFlag bool) error { tmp := map[string]interface{}{ // TODO: Fix format according to readme "peerAddr": lsp.peerAddr.String(), "policyName": lsp.name, - "srcAddr": net.IP(lsp.srcAddr).String(), - "dstAddr": net.IP(lsp.dstAddr).String(), + "srcAddr": lsp.srcAddr.String(), + "dstAddr": lsp.dstAddr.String(), "color": lsp.color, "preference": lsp.preference, "segmentList": lsp.path, @@ -67,11 +66,11 @@ func showSrPolicyList(jsonFlag bool) error { fmt.Printf("LSP(%d): \n", i) fmt.Printf(" PcepSessionAddr: %s\n", lsp.peerAddr) fmt.Printf(" PolicyName: %s\n", lsp.name) - fmt.Printf(" SrcAddr: %s\n", net.IP(lsp.srcAddr)) - fmt.Printf(" DstAddr: %s\n", net.IP(lsp.dstAddr)) + fmt.Printf(" SrcAddr: %s\n", lsp.srcAddr.String()) + fmt.Printf(" DstAddr: %s\n", lsp.dstAddr.String()) fmt.Printf(" Color: %d\n", lsp.color) fmt.Printf(" Preference: %d\n", lsp.preference) - fmt.Printf(" DstAddr: %s\n", net.IP(lsp.dstAddr)) + fmt.Printf(" DstAddr: %s\n", lsp.dstAddr.String()) fmt.Printf(" SegmentList: ") if len(lsp.path) == 0 { diff --git a/internal/pkg/cspf/cspf.go b/internal/pkg/cspf/cspf.go index 87bde9c..76c2bf6 100644 --- a/internal/pkg/cspf/cspf.go +++ b/internal/pkg/cspf/cspf.go @@ -7,7 +7,7 @@ package cspf import ( "errors" - "net" + "net/netip" "github.com/nttcom/pola/internal/pkg/table" "github.com/nttcom/pola/pkg/packet/pcep" @@ -19,10 +19,10 @@ type node struct { cost uint32 prevNode string nodeSid uint32 - LoAddr net.IP + LoAddr netip.Addr } -func newNode(id string, cost uint32, nodeSid uint32, loAddr net.IP) *node { +func newNode(id string, cost uint32, nodeSid uint32, loAddr netip.Addr) *node { node := &node{ id: id, cost: cost, @@ -105,7 +105,7 @@ func spf(srcRouterId string, dstRouterId string, metric table.MetricType, networ for pathNode := calculatingNodes[dstRouterId]; pathNode.id != srcRouterId; pathNode = calculatingNodes[pathNode.prevNode] { segment := pcep.Label{ Sid: pathNode.nodeSid, - LoAddr: pathNode.LoAddr.To4(), + LoAddr: pathNode.LoAddr, } if len(segmentList) == 0 { segmentList = append(segmentList, segment) diff --git a/internal/pkg/gobgp/interface.go b/internal/pkg/gobgp/interface.go index 5cd78c3..e8f8799 100644 --- a/internal/pkg/gobgp/interface.go +++ b/internal/pkg/gobgp/interface.go @@ -10,7 +10,7 @@ import ( "encoding/hex" "errors" "io" - "net" + "net/netip" "strings" "github.com/nttcom/pola/internal/pkg/table" @@ -127,9 +127,14 @@ func ConvertToTedElem(dst *api.Destination) ([]table.TedElem, error) { localNodeAsn := typedLinkStateNlri.GetLocalNode().GetAsn() remoteNodeId := typedLinkStateNlri.GetRemoteNode().GetIgpRouterId() remoteNodeAsn := typedLinkStateNlri.GetRemoteNode().GetAsn() - localIP := net.ParseIP(typedLinkStateNlri.GetLinkDescriptor().GetInterfaceAddrIpv4()) - remoteIP := net.ParseIP(typedLinkStateNlri.GetLinkDescriptor().GetNeighborAddrIpv4()) - + localIP, err := netip.ParseAddr(typedLinkStateNlri.GetLinkDescriptor().GetInterfaceAddrIpv4()) + if err != nil { + return nil, err + } + remoteIP, err := netip.ParseAddr(typedLinkStateNlri.GetLinkDescriptor().GetNeighborAddrIpv4()) + if err != nil { + return nil, err + } localNode := table.NewLsNode(localNodeAsn, localNodeId) remoteNode := table.NewLsNode(remoteNodeAsn, remoteNodeId) lsLink := table.NewLsLink(localNode, remoteNode) @@ -191,7 +196,7 @@ func ConvertToTedElem(dst *api.Destination) ([]table.TedElem, error) { if len(prefixV4) != 1 { return nil, errors.New("invalid prefix length") } - _, lsPrefixV4.Prefix, _ = net.ParseCIDR(prefixV4[0]) + lsPrefixV4.Prefix, _ = netip.ParsePrefix(prefixV4[0]) tedElems = append(tedElems, lsPrefixV4) } } diff --git a/internal/pkg/table/ted.go b/internal/pkg/table/ted.go index 779e2bf..fcd4924 100644 --- a/internal/pkg/table/ted.go +++ b/internal/pkg/table/ted.go @@ -8,7 +8,7 @@ package table import ( "errors" "fmt" - "net" + "net/netip" ) type LsTed struct { @@ -110,14 +110,14 @@ func (node LsNode) NodeSid() (uint32, error) { return 0, errors.New("node doesn't have node-sid") } -func (node LsNode) LoopbackAddr() (net.IP, error) { +func (node LsNode) LoopbackAddr() (netip.Addr, error) { for _, prefix := range node.Prefixes { // If it's a loopback prefix, it should be non-zero. if prefix.SidIndex != 0 { - return prefix.Prefix.IP, nil + return prefix.Prefix.Addr(), nil } } - return nil, errors.New("node doesn't have loopback addr") + return netip.Addr{}, errors.New("node doesn't have loopback addr") } func (lsNode *LsNode) UpdateTed(ted *LsTed) { @@ -136,12 +136,12 @@ func (lsNode *LsNode) UpdateTed(ted *LsTed) { } type LsLink struct { - LocalNode *LsNode // primary key, in MP_REACH_NLRI Attr - RemoteNode *LsNode // primary key, in MP_REACH_NLRI Attr - LocalIP net.IP // in MP_REACH_NLRI Attr - RemoteIP net.IP // in MP_REACH_NLRI Attr - Metrics []*Metric // in BGP-LS Attr - AdjSid uint32 // in BGP-LS Attr + LocalNode *LsNode // primary key, in MP_REACH_NLRI Attr + RemoteNode *LsNode // primary key, in MP_REACH_NLRI Attr + LocalIP netip.Addr // in MP_REACH_NLRI Attr + RemoteIP netip.Addr // in MP_REACH_NLRI Attr + Metrics []*Metric // in BGP-LS Attr + AdjSid uint32 // in BGP-LS Attr } func NewLsLink(localNode *LsNode, remoteNode *LsNode) *LsLink { @@ -178,9 +178,9 @@ func (lsLink *LsLink) UpdateTed(ted *LsTed) { } type LsPrefixV4 struct { - LocalNode *LsNode // primary key, in MP_REACH_NLRI Attr - Prefix *net.IPNet // in MP_REACH_NLRI Attr - SidIndex uint32 // in BGP-LS Attr (only for Lo Address Prefix) + LocalNode *LsNode // primary key, in MP_REACH_NLRI Attr + Prefix netip.Prefix // in MP_REACH_NLRI Attr + SidIndex uint32 // in BGP-LS Attr (only for Lo Address Prefix) } func NewLsPrefixV4(localNode *LsNode) *LsPrefixV4 { diff --git a/pkg/packet/pcep/message.go b/pkg/packet/pcep/message.go index 29908e0..71c582f 100644 --- a/pkg/packet/pcep/message.go +++ b/pkg/packet/pcep/message.go @@ -5,7 +5,10 @@ package pcep -import "fmt" +import ( + "fmt" + "net/netip" +) // Open Message type OpenMessage struct { @@ -153,7 +156,7 @@ type PCInitiateMessage struct { VendorInformationObject *VendorInformationObject } -func NewPCInitiateMessage(srpId uint32, lspName string, labels []Label, color uint32, preference uint32, srcIPv4 []uint8, dstIPv4 []uint8, opt ...Opt) (PCInitiateMessage, error) { +func NewPCInitiateMessage(srpId uint32, lspName string, labels []Label, color uint32, preference uint32, srcAddr netip.Addr, dstAddr netip.Addr, opt ...Opt) (PCInitiateMessage, error) { opts := optParams{ pccType: RFC_COMPLIANT, } @@ -165,18 +168,18 @@ func NewPCInitiateMessage(srpId uint32, lspName string, labels []Label, color ui var pcInitiateMessage PCInitiateMessage pcInitiateMessage.SrpObject = NewSrpObject(srpId, false) pcInitiateMessage.LspObject = NewLspObject(lspName, 0) // PLSP-ID = 0 - pcInitiateMessage.EndpointsObject = NewEndpointsObject(1, dstIPv4, srcIPv4) // objectType = 1 (IPv4) + pcInitiateMessage.EndpointsObject = NewEndpointsObject(1, dstAddr, srcAddr) // objectType = 1 (IPv4) var err error pcInitiateMessage.EroObject, err = NewEroObject(labels) if err != nil { return pcInitiateMessage, err } if opts.pccType == JUNIPER_LEGACY { - pcInitiateMessage.AssociationObject = NewAssociationObject(srcIPv4, dstIPv4, color, preference, VendorSpecific(opts.pccType)) + pcInitiateMessage.AssociationObject = NewAssociationObject(srcAddr, dstAddr, color, preference, VendorSpecific(opts.pccType)) } else if opts.pccType == CISCO_LEGACY { pcInitiateMessage.VendorInformationObject = NewVendorInformationObject(CISCO_LEGACY, color, preference) } else if opts.pccType == RFC_COMPLIANT { - pcInitiateMessage.AssociationObject = NewAssociationObject(srcIPv4, dstIPv4, color, preference) + pcInitiateMessage.AssociationObject = NewAssociationObject(srcAddr, dstAddr, color, preference) // FRRouting is treated as an RFC compliant pcInitiateMessage.VendorInformationObject = NewVendorInformationObject(CISCO_LEGACY, color, preference) } diff --git a/pkg/packet/pcep/object.go b/pkg/packet/pcep/object.go index f4a158a..9df4121 100644 --- a/pkg/packet/pcep/object.go +++ b/pkg/packet/pcep/object.go @@ -9,7 +9,7 @@ import ( "encoding/binary" "errors" "math" - "net" + "net/netip" ) type PccType int @@ -303,7 +303,7 @@ func DecodeTLVsFromBytes(data []uint8) ([]Tlv, error) { type Label struct { Sid uint32 - LoAddr []uint8 + LoAddr netip.Addr } type optParams struct { @@ -525,8 +525,8 @@ func NewSrpObject(srpId uint32, isRemove bool) *SrpObject { // LSP Object (RFC8281 5.3.1) type LspObject struct { Name string - SrcAddr net.IP - DstAddr net.IP + SrcAddr netip.Addr + DstAddr netip.Addr PlspId uint32 OFlag uint8 AFlag bool @@ -554,8 +554,13 @@ func (o *LspObject) DecodeFromBytes(objectBody []uint8) error { } if tlv.Type == uint16(TLV_IPV4_LSP_IDENTIFIERS) { // TODO: Obtain true srcAddr - o.SrcAddr = net.IP(tlv.Value[0:4]) - o.DstAddr = net.IP(tlv.Value[12:16]) + var ok bool + if o.SrcAddr, ok = netip.AddrFromSlice(tlv.Value[0:4]); !ok { + return errors.New("lsp tlv decode error") + } + if o.DstAddr, ok = netip.AddrFromSlice(tlv.Value[12:16]); !ok { + return errors.New("lsp tlv decode error") + } } o.Tlvs = append(o.Tlvs, tlv) @@ -726,7 +731,7 @@ type SrEroSubobject struct { CFlag bool MFlag bool Sid uint32 - Nai []uint8 + Nai netip.Addr } func (o *SrEroSubobject) DecodeFromBytes(subObj []uint8) { @@ -739,7 +744,9 @@ func (o *SrEroSubobject) DecodeFromBytes(subObj []uint8) { o.CFlag = (subObj[3] & 0x02) != 0 o.MFlag = (subObj[3] & 0x01) != 0 o.Sid = binary.BigEndian.Uint32(subObj[4:8]) >> 12 - o.Nai = subObj[8:12] + if o.NaiType == 1 { + o.Nai, _ = netip.AddrFromSlice(subObj[8:12]) + } } func (o *SrEroSubobject) Serialize() []uint8 { @@ -765,7 +772,7 @@ func (o *SrEroSubobject) Serialize() []uint8 { byteSid := make([]uint8, 4) binary.BigEndian.PutUint32(byteSid, o.Sid<<12) - byteSrEroSubobject := AppendByteSlices(buf, byteSid, o.Nai) + byteSrEroSubobject := AppendByteSlices(buf, byteSid, o.Nai.AsSlice()) return byteSrEroSubobject } @@ -784,7 +791,7 @@ func (o SrEroSubobject) getByteLength() (uint16, error) { } } -func NewSrEroSubObject(sid uint32, loAddr []uint8) (SrEroSubobject, error) { +func NewSrEroSubObject(sid uint32, loAddr netip.Addr) (SrEroSubobject, error) { srEroSubObject := SrEroSubobject{ LFlag: false, SubobjectType: ERO_SUBOBJECT_SR, @@ -820,14 +827,14 @@ const ( // END-POINTS Object (RFC5440 7.6) type EndpointsObject struct { ObjectType uint8 // IPv4: 1, IPv6: 2 - srcIPv4 []uint8 - dstIPv4 []uint8 + SrcAddr netip.Addr + DstAddr netip.Addr } func (o EndpointsObject) Serialize() []uint8 { EndpointsObjectHeader := NewCommonObjectHeader(OC_END_POINTS, 1, o.getByteLength()) byteEroObjectHeader := EndpointsObjectHeader.Serialize() - byteEndpointsObject := AppendByteSlices(byteEroObjectHeader, o.srcIPv4, o.dstIPv4) + byteEndpointsObject := AppendByteSlices(byteEroObjectHeader, o.SrcAddr.AsSlice(), o.DstAddr.AsSlice()) return byteEndpointsObject } @@ -837,23 +844,23 @@ func (o EndpointsObject) getByteLength() uint16 { return uint16(COMMON_OBJECT_HEADER_LENGTH + 4 + 4) } -func NewEndpointsObject(objType uint8, dstIPv4 []uint8, srcIPv4 []uint8) *EndpointsObject { +func NewEndpointsObject(objType uint8, dstAddr netip.Addr, srcAddr netip.Addr) *EndpointsObject { // TODO: Expantion for IPv6 Endpoint EndpointsObject := &EndpointsObject{ ObjectType: objType, - dstIPv4: dstIPv4, - srcIPv4: srcIPv4, + DstAddr: dstAddr, + SrcAddr: srcAddr, } return EndpointsObject } // ASSOCIATION Object (RFC8697 6.) type AssociationObject struct { - RFlag bool - AssocType uint16 - AssocId uint16 - Ipv4AssocSrc []uint8 - Tlvs []Tlv + RFlag bool + AssocType uint16 + AssocId uint16 + AssocSrc netip.Addr + Tlvs []Tlv } // (I.D. pce-segment-routing-policy-cp-08 5.1) @@ -902,7 +909,7 @@ func (o *AssociationObject) DecodeFromBytes(objectBody []uint8) error { o.RFlag = (objectBody[3] & 0x01) != 0 o.AssocType = uint16(binary.BigEndian.Uint16(objectBody[4:6])) o.AssocId = uint16(binary.BigEndian.Uint16(objectBody[6:8])) - o.Ipv4AssocSrc = objectBody[8:12] + o.AssocSrc, _ = netip.AddrFromSlice(objectBody[8:12]) if len(objectBody) > 12 { byteTlvs := objectBody[12:] for { @@ -942,7 +949,7 @@ func (o AssociationObject) Serialize() []uint8 { } byteAssociationObject := AppendByteSlices( - byteAssociationObjectHeader, buf, assocType, assocId, o.Ipv4AssocSrc, byteTlvs, + byteAssociationObjectHeader, buf, assocType, assocId, o.AssocSrc.AsSlice(), byteTlvs, ) return byteAssociationObject } @@ -957,7 +964,7 @@ func (o AssociationObject) getByteLength() uint16 { return COMMON_OBJECT_HEADER_LENGTH + associationObjectBodyLength } -func NewAssociationObject(srcIPv4 []uint8, dstIPv4 []uint8, color uint32, preference uint32, opt ...Opt) *AssociationObject { +func NewAssociationObject(srcAddr netip.Addr, dstAddr netip.Addr, color uint32, preference uint32, opt ...Opt) *AssociationObject { opts := optParams{ pccType: RFC_COMPLIANT, } @@ -968,9 +975,9 @@ func NewAssociationObject(srcIPv4 []uint8, dstIPv4 []uint8, color uint32, prefer // TODO: Expantion for IPv6 Endpoint associationObject := &AssociationObject{ - RFlag: false, - Tlvs: []Tlv{}, - Ipv4AssocSrc: srcIPv4, + RFlag: false, + Tlvs: []Tlv{}, + AssocSrc: srcAddr, } if opts.pccType == JUNIPER_LEGACY { associationObject.AssocId = 0 @@ -980,7 +987,7 @@ func NewAssociationObject(srcIPv4 []uint8, dstIPv4 []uint8, color uint32, prefer Type: JUNIPER_SPEC_TLV_EXTENDED_ASSOCIATION_ID, Length: TLV_EXTENDED_ASSOCIATION_ID_LENGTH, // TODO: 20 if ipv6 endpoint Value: AppendByteSlices( - uint32ToListUint8(color), dstIPv4, + uint32ToListUint8(color), dstAddr.AsSlice(), ), }, { @@ -1009,7 +1016,7 @@ func NewAssociationObject(srcIPv4 []uint8, dstIPv4 []uint8, color uint32, prefer Type: TLV_EXTENDED_ASSOCIATION_ID, Length: TLV_EXTENDED_ASSOCIATION_ID_LENGTH, // TODO: 20 if ipv6 endpoint Value: AppendByteSlices( - uint32ToListUint8(color), dstIPv4, + uint32ToListUint8(color), dstAddr.AsSlice(), ), }, { diff --git a/pkg/server/grpc_server.go b/pkg/server/grpc_server.go index 8d2acb9..19b8054 100644 --- a/pkg/server/grpc_server.go +++ b/pkg/server/grpc_server.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "net" + "net/netip" "strings" "github.com/golang/protobuf/ptypes/empty" @@ -64,10 +65,11 @@ func (s *APIServer) CreateSrPolicy(ctx context.Context, input *pb.CreateSrPolicy return &pb.SrPolicyStatus{IsSuccess: false}, errors.New("input is invalid") } + pcepSessionAddr, _ := netip.AddrFromSlice(input.GetSrPolicy().GetPcepSessionAddr()) inputJson := map[string]interface{}{ "asn": fmt.Sprint(input.GetAsn()), "srPolicy": map[string]interface{}{ - "pcepSessionAddr": net.IP(input.GetSrPolicy().GetPcepSessionAddr()).String(), + "pcepSessionAddr": pcepSessionAddr.String(), "color": input.GetSrPolicy().GetColor(), "dstRouterId": input.GetSrPolicy().GetDstRouterId(), "srcRouterId": input.GetSrPolicy().GetSrcRouterId(), @@ -78,18 +80,17 @@ func (s *APIServer) CreateSrPolicy(ctx context.Context, input *pb.CreateSrPolicy } s.pce.logger.Info("Receive CreateSrPolicy API request", zap.Any("input", inputJson), zap.String("server", "grpc")) - pcepSessionAddr := net.IP(input.GetSrPolicy().GetPcepSessionAddr()) pcepSession := s.pce.getSession(pcepSessionAddr) if pcepSession == nil { return &pb.SrPolicyStatus{IsSuccess: false}, fmt.Errorf("no session with %s", pcepSessionAddr) } segmentList := []pcep.Label{} - srcIPAddr, err := s.pce.ted.Nodes[input.GetAsn()][input.GetSrPolicy().SrcRouterId].LoopbackAddr() + srcAddr, err := s.pce.ted.Nodes[input.GetAsn()][input.GetSrPolicy().SrcRouterId].LoopbackAddr() if err != nil { return &pb.SrPolicyStatus{IsSuccess: false}, err } - dstIPAddr, err := s.pce.ted.Nodes[input.GetAsn()][input.GetSrPolicy().DstRouterId].LoopbackAddr() + dstAddr, err := s.pce.ted.Nodes[input.GetAsn()][input.GetSrPolicy().DstRouterId].LoopbackAddr() if err != nil { return &pb.SrPolicyStatus{IsSuccess: false}, err } @@ -108,9 +109,10 @@ func (s *APIServer) CreateSrPolicy(ctx context.Context, input *pb.CreateSrPolicy if err != nil { return &pb.SrPolicyStatus{IsSuccess: false}, err } + pcepSegment := pcep.Label{ Sid: segment.GetSid(), - LoAddr: loAddr.To4(), + LoAddr: loAddr, } segmentList = append(segmentList, pcepSegment) } @@ -144,7 +146,7 @@ func (s *APIServer) CreateSrPolicy(ctx context.Context, input *pb.CreateSrPolicy return &pb.SrPolicyStatus{IsSuccess: false}, err } } else { - if err := pcepSession.SendPCInitiate(input.GetSrPolicy().GetPolicyName(), segmentList, input.GetSrPolicy().GetColor(), uint32(100), srcIPAddr.To4(), dstIPAddr.To4()); err != nil { + if err := pcepSession.SendPCInitiate(input.GetSrPolicy().GetPolicyName(), segmentList, input.GetSrPolicy().GetColor(), uint32(100), srcAddr, dstAddr); err != nil { return &pb.SrPolicyStatus{IsSuccess: false}, err } } @@ -159,16 +161,17 @@ func (s *APIServer) CreateSrPolicyWithoutLinkState(ctx context.Context, input *p } s.pce.logger.Info("Receive CreateSrPolicyWithoutLinkState API request", zap.Any("SR Policy", input.GetSrPolicy()), zap.String("server", "grpc")) - pcepSessionAddr := net.IP(input.GetSrPolicy().GetPcepSessionAddr()) + pcepSessionAddr, _ := netip.AddrFromSlice(input.GetSrPolicy().GetPcepSessionAddr()) pcepSession := s.pce.getSession(pcepSessionAddr) if pcepSession == nil { return &pb.SrPolicyStatus{IsSuccess: false}, fmt.Errorf("no session with %s", pcepSessionAddr) } segmentList := []pcep.Label{} for _, receivedLsp := range input.GetSrPolicy().GetSegmentList() { + loAddr, _ := netip.AddrFromSlice(receivedLsp.GetLoAddr()) pcepSegment := pcep.Label{ Sid: receivedLsp.GetSid(), - LoAddr: receivedLsp.GetLoAddr(), + LoAddr: loAddr, } segmentList = append(segmentList, pcepSegment) } @@ -179,7 +182,9 @@ func (s *APIServer) CreateSrPolicyWithoutLinkState(ctx context.Context, input *p return &pb.SrPolicyStatus{IsSuccess: false}, err } } else { - if err := pcepSession.SendPCInitiate(input.GetSrPolicy().GetPolicyName(), segmentList, input.GetSrPolicy().GetColor(), uint32(100), input.GetSrPolicy().GetSrcAddr(), input.GetSrPolicy().GetDstAddr()); err != nil { + srcAddr, _ := netip.AddrFromSlice(input.GetSrPolicy().GetSrcAddr()) + dstAddr, _ := netip.AddrFromSlice(input.GetSrPolicy().GetDstAddr()) + if err := pcepSession.SendPCInitiate(input.GetSrPolicy().GetPolicyName(), segmentList, input.GetSrPolicy().GetColor(), uint32(100), srcAddr, dstAddr); err != nil { return &pb.SrPolicyStatus{IsSuccess: false}, err } } @@ -190,7 +195,7 @@ func (s *APIServer) GetPeerAddrList(context.Context, *empty.Empty) (*pb.PeerAddr s.pce.logger.Info("Receive GetPeerAddrList API request", zap.String("server", "grpc")) var ret pb.PeerAddrList for _, pcepSession := range s.pce.sessionList { - ret.PeerAddrs = append(ret.PeerAddrs, []byte(pcepSession.peerAddr)) + ret.PeerAddrs = append(ret.PeerAddrs, pcepSession.peerAddr.AsSlice()) } s.pce.logger.Info("Send GetPeerAddrList API reply", zap.String("server", "grpc")) return &ret, nil @@ -201,13 +206,13 @@ func (s *APIServer) GetSrPolicyList(context.Context, *empty.Empty) (*pb.SrPolicy var ret pb.SrPolicyList for _, lsp := range s.pce.lspList { srPolicyData := &pb.SrPolicy{ - PcepSessionAddr: []byte(lsp.peerAddr), + PcepSessionAddr: lsp.peerAddr.AsSlice(), SegmentList: []*pb.Segment{}, Color: lsp.color, Preference: lsp.preference, PolicyName: lsp.name, - SrcAddr: []byte(lsp.srcAddr), - DstAddr: []byte(lsp.dstAddr), + SrcAddr: lsp.srcAddr.AsSlice(), + DstAddr: lsp.dstAddr.AsSlice(), } for _, sid := range lsp.path { segment := pb.Segment{ diff --git a/pkg/server/server.go b/pkg/server/server.go index 65c8ce2..693a2f0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -7,7 +7,7 @@ package server import ( "net" - "strings" + "net/netip" "go.uber.org/zap" grpc "google.golang.org/grpc" @@ -18,12 +18,12 @@ import ( ) type Lsp struct { - peerAddr net.IP //TODO: Change to ("loopback addr" or "router name") + peerAddr netip.Addr //TODO: Change to ("loopback addr" or "router name") plspId uint32 name string path []uint32 - srcAddr net.IP - dstAddr net.IP + srcAddr netip.Addr + dstAddr netip.Addr color uint32 preference uint32 } @@ -79,7 +79,7 @@ func NewPce(o *PceOptions, logger *zap.Logger, tedElemsChan chan []table.TedElem errChan := make(chan ServerError) // Start PCEP listen go func() { - if err := s.Listen(o.PcepAddr, o.PcepPort, lspChan); err != nil { + if err := s.Serve(o.PcepAddr, o.PcepPort, lspChan); err != nil { errChan <- ServerError{ Server: "pcep", Error: err, @@ -110,33 +110,35 @@ func NewPce(o *PceOptions, logger *zap.Logger, tedElemsChan chan []table.TedElem } } -func (s *Server) Listen(address string, port string, lspChan chan Lsp) error { - var listenInfo strings.Builder - listenInfo.WriteString(address) - listenInfo.WriteString(":") - listenInfo.WriteString(port) - s.logger.Info("PCEP listen", zap.String("listenInfo", listenInfo.String())) - listener, err := net.Listen("tcp", listenInfo.String()) +func (s *Server) Serve(address string, port string, lspChan chan Lsp) error { + localAddr, err := netip.ParseAddrPort(address + ":" + port) + if err != nil { + return err + } + s.logger.Info("PCEP listen", zap.String("listenInfo", localAddr.String())) + l, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(localAddr)) if err != nil { return err } - defer listener.Close() + defer l.Close() sessionId := uint8(1) for { - session := NewSession(sessionId, lspChan, s.logger) - session.tcpConn, err = listener.Accept() + ss := NewSession(sessionId, lspChan, s.logger) + ss.tcpConn, err = l.AcceptTCP() + if err != nil { + return err + } + peerAddrPort, err := netip.ParseAddrPort(ss.tcpConn.RemoteAddr().String()) if err != nil { return err } - strPeerAddr := session.tcpConn.RemoteAddr().String() - sessionAddr := net.ParseIP(strings.Split(strPeerAddr, ":")[0]) - session.peerAddr = sessionAddr - s.sessionList = append(s.sessionList, session) + ss.peerAddr = peerAddrPort.Addr() + s.sessionList = append(s.sessionList, ss) go func() { - session.Established() - s.closeSession(session) - s.logger.Info("Close PCEP session", zap.String("session", session.peerAddr.String())) + ss.Established() + s.closeSession(ss) + s.logger.Info("Close PCEP session", zap.String("session", ss.peerAddr.String())) }() sessionId += 1 } @@ -156,7 +158,7 @@ func (s *Server) closeSession(session *Session) { // Remove Lsp List newLspList := []Lsp{} for _, v := range s.lspList { - if !v.peerAddr.Equal(session.peerAddr) { + if v.peerAddr != session.peerAddr { newLspList = append(newLspList, v) } } @@ -165,7 +167,8 @@ func (s *Server) closeSession(session *Session) { func (s *Server) getPlspId(lspData *pb.SrPolicy) uint32 { for _, v := range s.lspList { - if v.name == lspData.GetPolicyName() && v.peerAddr.Equal(net.IP(lspData.GetPcepSessionAddr())) { + pcepSessionAddr, _ := netip.AddrFromSlice(lspData.GetPcepSessionAddr()) + if v.name == lspData.GetPolicyName() && v.peerAddr == pcepSessionAddr { return v.plspId } } @@ -176,7 +179,7 @@ func (s *Server) getPlspId(lspData *pb.SrPolicy) uint32 { func (s *Server) removeLsp(e Lsp) { // Deletes a LSP with name, PLSP-ID, and sessionAddr matching from lspList for i, v := range s.lspList { - if v.name == e.name && v.plspId == e.plspId && v.peerAddr.Equal(e.peerAddr) { + if v.name == e.name && v.plspId == e.plspId && v.peerAddr == e.peerAddr { s.lspList[i] = s.lspList[len(s.lspList)-1] s.lspList = s.lspList[:len(s.lspList)-1] break @@ -184,9 +187,9 @@ func (s *Server) removeLsp(e Lsp) { } } -func (s *Server) getSession(peerAddr net.IP) *Session { +func (s *Server) getSession(peerAddr netip.Addr) *Session { for _, pcepSession := range s.sessionList { - if pcepSession.peerAddr.Equal(peerAddr) { + if pcepSession.peerAddr == peerAddr { if !pcepSession.isSynced { break } diff --git a/pkg/server/session.go b/pkg/server/session.go index 7b7673c..970c548 100644 --- a/pkg/server/session.go +++ b/pkg/server/session.go @@ -9,6 +9,7 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "time" "github.com/nttcom/pola/pkg/packet/pcep" @@ -20,8 +21,8 @@ const KEEPALIVE uint8 = 30 type Session struct { sessionId uint8 - peerAddr net.IP - tcpConn net.Conn + peerAddr netip.Addr + tcpConn *net.TCPConn isSynced bool srpIdHead uint32 // 0x00000000 and 0xFFFFFFFF are reserved. lspChan chan Lsp @@ -214,8 +215,8 @@ func (s *Session) ReceivePcepMessage() error { } } -func (s *Session) SendPCInitiate(policyName string, labels []pcep.Label, color uint32, preference uint32, srcIPv4 []uint8, dstIPv4 []uint8) error { - pcinitiateMessage, err := pcep.NewPCInitiateMessage(s.srpIdHead, policyName, labels, color, preference, srcIPv4, dstIPv4, pcep.VendorSpecific(s.pccType)) +func (s *Session) SendPCInitiate(policyName string, labels []pcep.Label, color uint32, preference uint32, srcAddr netip.Addr, dstAddr netip.Addr) error { + pcinitiateMessage, err := pcep.NewPCInitiateMessage(s.srpIdHead, policyName, labels, color, preference, srcAddr, dstAddr, pcep.VendorSpecific(s.pccType)) if err != nil { return err } @@ -223,15 +224,8 @@ func (s *Session) SendPCInitiate(policyName string, labels []pcep.Label, color u if err != nil { return err } - labelsJson := []map[string]interface{}{} - for _, l := range labels { - labelJson := map[string]interface{}{ - "Sid": l.Sid, - "LoAddr": net.IP(l.LoAddr).String(), - } - labelsJson = append(labelsJson, labelJson) - } - s.logger.Info("Send PCInitiate", zap.String("session", s.peerAddr.String()), zap.Uint32("srpId", s.srpIdHead), zap.String("policyName", policyName), zap.Any("labels", labelsJson), zap.Uint32("color", color), zap.Uint32("preference", preference), zap.String("srcIPv4", net.IP(srcIPv4).String()), zap.Any("dstIPv4", net.IP(dstIPv4).String())) + + s.logger.Info("Send PCInitiate", zap.String("session", s.peerAddr.String()), zap.Uint32("srpId", s.srpIdHead), zap.String("policyName", policyName), zap.Any("labels", labels), zap.Uint32("color", color), zap.Uint32("preference", preference), zap.String("srcIPv4", srcAddr.String()), zap.Any("dstIPv4", dstAddr.String())) if _, err := s.tcpConn.Write(bytePCInitiateMessage); err != nil { return err }