diff --git a/.changeset/cuddly-buttons-guess.md b/.changeset/cuddly-buttons-guess.md new file mode 100644 index 000000000..ecbb7e28d --- /dev/null +++ b/.changeset/cuddly-buttons-guess.md @@ -0,0 +1,5 @@ +--- +"github.com/livekit/protocol": patch +--- + +Match SIP Trunks by source IP or mask. diff --git a/sip/sip.go b/sip/sip.go index 6751f34fb..a711c437e 100644 --- a/sip/sip.go +++ b/sip/sip.go @@ -17,9 +17,11 @@ package sip import ( "fmt" "math" + "net/netip" "regexp" "sort" "strconv" + "strings" "golang.org/x/exp/slices" @@ -222,9 +224,36 @@ func ValidateTrunks(trunks []*livekit.SIPTrunkInfo) error { return nil } +func matchAddrs(addr string, mask string) bool { + if !strings.Contains(mask, "/") { + return addr == mask + } + ip, err := netip.ParseAddr(addr) + if err != nil { + return false + } + pref, err := netip.ParsePrefix(mask) + if err != nil { + return false + } + return pref.Contains(ip) +} + +func matchAddr(addr string, masks []string) bool { + if addr == "" { + return true + } + for _, mask := range masks { + if !matchAddrs(addr, mask) { + return false + } + } + return true +} + // MatchTrunk finds a SIP Trunk definition matching the request. // Returns nil if no rules matched or an error if there are conflicting definitions. -func MatchTrunk(trunks []*livekit.SIPTrunkInfo, calling, called string) (*livekit.SIPTrunkInfo, error) { +func MatchTrunk(trunks []*livekit.SIPTrunkInfo, srcIP, calling, called string) (*livekit.SIPTrunkInfo, error) { var ( selectedTrunk *livekit.SIPTrunkInfo defaultTrunk *livekit.SIPTrunkInfo @@ -235,6 +264,9 @@ func MatchTrunk(trunks []*livekit.SIPTrunkInfo, calling, called string) (*liveki if len(tr.InboundNumbers) != 0 && !slices.Contains(tr.InboundNumbers, calling) { continue } + if !matchAddr(srcIP, tr.InboundAddresses) { + continue + } // Deprecated, but we still check it for backward compatibility. matchesRe := len(tr.InboundNumbersRegex) == 0 for _, reStr := range tr.InboundNumbersRegex { diff --git a/sip/sip_test.go b/sip/sip_test.go index a98bbb2c8..a2fa6fffc 100644 --- a/sip/sip_test.go +++ b/sip/sip_test.go @@ -157,7 +157,7 @@ func TestSIPMatchTrunk(t *testing.T) { for _, c := range trunkCases { c := c t.Run(c.name, func(t *testing.T) { - got, err := MatchTrunk(c.trunks, sipNumber1, sipNumber2) + got, err := MatchTrunk(c.trunks, "", sipNumber1, sipNumber2) if c.expErr { require.Error(t, err) require.Nil(t, got) @@ -515,3 +515,23 @@ func TestSIPValidateDispatchRules(t *testing.T) { }) } } + +func TestMatchIP(t *testing.T) { + cases := []struct { + addr string + mask string + exp bool + }{ + {addr: "192.168.0.10", mask: "192.168.0.10", exp: true}, + {addr: "192.168.0.10", mask: "192.168.0.11", exp: false}, + {addr: "192.168.0.10", mask: "192.168.0.0/24", exp: true}, + {addr: "192.168.0.10", mask: "192.168.0.10/0", exp: true}, + {addr: "192.168.0.10", mask: "192.170.0.0/24", exp: false}, + } + for _, c := range cases { + t.Run(c.mask, func(t *testing.T) { + got := matchAddrs(c.addr, c.mask) + require.Equal(t, c.exp, got) + }) + } +}