diff --git a/.changeset/chubby-insects-cut.md b/.changeset/chubby-insects-cut.md new file mode 100644 index 000000000..a64d6d5db --- /dev/null +++ b/.changeset/chubby-insects-cut.md @@ -0,0 +1,5 @@ +--- +"github.com/livekit/protocol": patch +--- + +adding detailed responses to the trunk match logic which can help with decisions on blocking diff --git a/sip/sip.go b/sip/sip.go index b7dd2f987..f1f1d9f13 100644 --- a/sip/sip.go +++ b/sip/sip.go @@ -451,6 +451,30 @@ func matchNumbers(num string, allowed []string) bool { return false } +// TrunkMatchType indicates how a trunk was matched +type TrunkMatchType int + +const ( + // TrunkMatchEmpty indicates no trunks were defined + TrunkMatchEmpty TrunkMatchType = iota + // TrunkMatchNone indicates trunks exist but none matched + TrunkMatchNone + // TrunkMatchDefault indicates only a default trunk (with no specific numbers) matched + TrunkMatchDefault + // TrunkMatchSpecific indicates a trunk with specific numbers matched + TrunkMatchSpecific +) + +// TrunkMatchResult provides detailed information about the trunk matching process +type TrunkMatchResult struct { + // The matched trunk, if any + Trunk *livekit.SIPInboundTrunkInfo + // How the trunk was matched + MatchType TrunkMatchType + // Number of default trunks found + DefaultTrunkCount int +} + // MatchTrunk finds a SIP Trunk definition matching the request. // Returns nil if no rules matched or an error if there are conflicting definitions. // @@ -459,6 +483,99 @@ func MatchTrunk(trunks []*livekit.SIPInboundTrunkInfo, call *rpc.SIPCall, opts . return MatchTrunkIter(iters.Slice(trunks), call, opts...) } +// MatchTrunkDetailed is like MatchTrunkIter but returns detailed match information +func MatchTrunkDetailed(it iters.Iter[*livekit.SIPInboundTrunkInfo], call *rpc.SIPCall, opts ...MatchTrunkOpt) (*TrunkMatchResult, error) { + defer it.Close() + var opt matchTrunkOpts + for _, fnc := range opts { + fnc(&opt) + } + opt.defaults() + + result := &TrunkMatchResult{ + MatchType: TrunkMatchEmpty, // Start with assumption it's empty + } + + var ( + selectedTrunk *livekit.SIPInboundTrunkInfo + defaultTrunk *livekit.SIPInboundTrunkInfo + defaultTrunkPrev *livekit.SIPInboundTrunkInfo + sawAnyTrunk bool + ) + calledNorm := NormalizeNumber(call.To.User) + for { + tr, err := it.Next() + if err == io.EOF { + break + } else if err != nil { + return nil, err + } + if !sawAnyTrunk { + sawAnyTrunk = true + result.MatchType = TrunkMatchNone // We have trunks but haven't matched any yet + } + tr = opt.Replace(tr) + // Do not consider it if number doesn't match. + if !matchNumbers(call.From.User, tr.AllowedNumbers) { + if !opt.Filtered(tr, TrunkFilteredCallingNumberDisallowed) { + continue + } + } + if !matchAddrMasks(call.SourceIp, call.From.Host, tr.AllowedAddresses) { + if !opt.Filtered(tr, TrunkFilteredSourceAddressDisallowed) { + continue + } + } + if len(tr.Numbers) == 0 { + // Default/wildcard trunk. + defaultTrunkPrev = defaultTrunk + defaultTrunk = tr + result.DefaultTrunkCount++ + } else { + for _, num := range tr.Numbers { + if num == call.To.User || NormalizeNumber(num) == calledNorm { + // Trunk specific to the number. + if selectedTrunk != nil { + opt.Conflict(selectedTrunk, tr, TrunkConflictCalledNumber) + if opt.AllowConflicts { + // This path is unreachable, since we pick the first trunk. Kept for completeness. + continue + } + return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple SIP Trunks matched for %q", call.To.User) + } + selectedTrunk = tr + if opt.AllowConflicts { + // Pick the first match as soon as it's found. We don't care about conflicts. + result.Trunk = selectedTrunk + result.MatchType = TrunkMatchSpecific + return result, nil + } + // Keep searching! We want to know if there are any conflicting Trunk definitions. + } else { + opt.Filtered(tr, TrunkFilteredCalledNumberDisallowed) + } + } + } + } + + if selectedTrunk != nil { + result.Trunk = selectedTrunk + result.MatchType = TrunkMatchSpecific + return result, nil + } + if result.DefaultTrunkCount > 1 { + opt.Conflict(defaultTrunk, defaultTrunkPrev, TrunkConflictDefault) + if !opt.AllowConflicts { + return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple default SIP Trunks matched for %q", call.To.User) + } + } + if defaultTrunk != nil { + result.Trunk = defaultTrunk + result.MatchType = TrunkMatchDefault + } + return result, nil +} + type matchTrunkOpts struct { AllowConflicts bool Filtered TrunkFilteredFunc @@ -541,78 +658,11 @@ func WithTrunkReplace(fnc TrunkReplaceFunc) MatchTrunkOpt { // MatchTrunkIter finds a SIP Trunk definition matching the request. // Returns nil if no rules matched or an error if there are conflicting definitions. func MatchTrunkIter(it iters.Iter[*livekit.SIPInboundTrunkInfo], call *rpc.SIPCall, opts ...MatchTrunkOpt) (*livekit.SIPInboundTrunkInfo, error) { - defer it.Close() - var opt matchTrunkOpts - for _, fnc := range opts { - fnc(&opt) - } - opt.defaults() - var ( - selectedTrunk *livekit.SIPInboundTrunkInfo - defaultTrunk *livekit.SIPInboundTrunkInfo - defaultTrunkPrev *livekit.SIPInboundTrunkInfo - defaultTrunkCnt int // to error in case there are multiple ones - ) - calledNorm := NormalizeNumber(call.To.User) - for { - tr, err := it.Next() - if err == io.EOF { - break - } else if err != nil { - return nil, err - } - tr = opt.Replace(tr) - // Do not consider it if number doesn't match. - if !matchNumbers(call.From.User, tr.AllowedNumbers) { - if !opt.Filtered(tr, TrunkFilteredCallingNumberDisallowed) { - continue - } - } - if !matchAddrMasks(call.SourceIp, call.From.Host, tr.AllowedAddresses) { - if !opt.Filtered(tr, TrunkFilteredSourceAddressDisallowed) { - continue - } - } - if len(tr.Numbers) == 0 { - // Default/wildcard trunk. - defaultTrunkPrev = defaultTrunk - defaultTrunk = tr - defaultTrunkCnt++ - } else { - for _, num := range tr.Numbers { - if num == call.To.User || NormalizeNumber(num) == calledNorm { - // Trunk specific to the number. - if selectedTrunk != nil { - opt.Conflict(selectedTrunk, tr, TrunkConflictCalledNumber) - if opt.AllowConflicts { - // This path is unreachable, since we pick the first trunk. Kept for completeness. - continue - } - return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple SIP Trunks matched for %q", call.To.User) - } - selectedTrunk = tr - if opt.AllowConflicts { - // Pick the first match as soon as it's found. We don't care about conflicts. - return selectedTrunk, nil - } - // Keep searching! We want to know if there are any conflicting Trunk definitions. - } else { - opt.Filtered(tr, TrunkFilteredCalledNumberDisallowed) - } - } - } - } - if selectedTrunk != nil { - return selectedTrunk, nil - } - if defaultTrunkCnt > 1 { - opt.Conflict(defaultTrunk, defaultTrunkPrev, TrunkConflictDefault) - if !opt.AllowConflicts { - return nil, twirp.NewErrorf(twirp.FailedPrecondition, "Multiple default SIP Trunks matched for %q", call.To.User) - } + result, err := MatchTrunkDetailed(it, call, opts...) + if err != nil { + return nil, err } - // Could still be nil here. - return defaultTrunk, nil + return result.Trunk, nil } // MatchDispatchRule finds the best dispatch rule matching the request parameters. Returns an error if no rule matched. diff --git a/sip/sip_test.go b/sip/sip_test.go index ba5fdc5b9..c8a11b5ee 100644 --- a/sip/sip_test.go +++ b/sip/sip_test.go @@ -824,3 +824,140 @@ func TestMatchMasks(t *testing.T) { }) } } + +func TestMatchTrunkDetailed(t *testing.T) { + for _, c := range []struct { + name string + trunks []*livekit.SIPInboundTrunkInfo + expMatchType TrunkMatchType + expTrunkID string + expDefaultCount int + expErr bool + from string + to string + src string + host string + }{ + { + name: "empty", + trunks: nil, + expMatchType: TrunkMatchEmpty, + expTrunkID: "", + expErr: false, + }, + { + name: "one wildcard", + trunks: []*livekit.SIPInboundTrunkInfo{ + {SipTrunkId: "aaa"}, + }, + expMatchType: TrunkMatchDefault, + expTrunkID: "aaa", + expDefaultCount: 1, + expErr: false, + }, + { + name: "specific match", + trunks: []*livekit.SIPInboundTrunkInfo{ + {SipTrunkId: "aaa", Numbers: []string{sipNumber2}}, + }, + expMatchType: TrunkMatchSpecific, + expTrunkID: "aaa", + expDefaultCount: 0, + expErr: false, + }, + { + name: "no match with trunks", + trunks: []*livekit.SIPInboundTrunkInfo{ + {SipTrunkId: "aaa", Numbers: []string{sipNumber3}}, + }, + expMatchType: TrunkMatchNone, + expTrunkID: "", + expDefaultCount: 0, + expErr: false, + }, + { + name: "multiple defaults", + trunks: []*livekit.SIPInboundTrunkInfo{ + {SipTrunkId: "aaa"}, + {SipTrunkId: "bbb"}, + }, + expMatchType: TrunkMatchDefault, + expTrunkID: "aaa", + expDefaultCount: 2, + expErr: true, + }, + { + name: "specific over default", + trunks: []*livekit.SIPInboundTrunkInfo{ + {SipTrunkId: "aaa"}, + {SipTrunkId: "bbb", Numbers: []string{sipNumber2}}, + }, + expMatchType: TrunkMatchSpecific, + expTrunkID: "bbb", + expDefaultCount: 1, + expErr: false, + }, + { + name: "multiple specific", + trunks: []*livekit.SIPInboundTrunkInfo{ + {SipTrunkId: "aaa", Numbers: []string{sipNumber2}}, + {SipTrunkId: "bbb", Numbers: []string{sipNumber2}}, + }, + expMatchType: TrunkMatchSpecific, + expTrunkID: "aaa", + expDefaultCount: 0, + expErr: true, + }, + } { + c := c + t.Run(c.name, func(t *testing.T) { + from, to, src, host := c.from, c.to, c.src, c.host + if from == "" { + from = sipNumber1 + } + if to == "" { + to = sipNumber2 + } + if src == "" { + src = "1.1.1.1" + } + if host == "" { + host = "sip.example.com" + } + call := &rpc.SIPCall{ + SourceIp: src, + From: &livekit.SIPUri{ + User: from, + Host: host, + }, + To: &livekit.SIPUri{ + User: to, + }, + } + call.Address = call.To + + var conflicts []string + result, err := MatchTrunkDetailed(iters.Slice(c.trunks), call, WithTrunkConflict(func(t1, t2 *livekit.SIPInboundTrunkInfo, reason TrunkConflictReason) { + conflicts = append(conflicts, fmt.Sprintf("%v: %v vs %v", reason, t1.SipTrunkId, t2.SipTrunkId)) + })) + + if c.expErr { + require.Error(t, err) + require.NotEmpty(t, conflicts, "expected conflicts but got none") + } else { + require.NoError(t, err) + require.Empty(t, conflicts, "unexpected conflicts: %v", conflicts) + + if c.expTrunkID == "" { + require.Nil(t, result.Trunk) + } else { + require.NotNil(t, result.Trunk) + require.Equal(t, c.expTrunkID, result.Trunk.SipTrunkId) + } + + require.Equal(t, c.expMatchType, result.MatchType) + require.Equal(t, c.expDefaultCount, result.DefaultTrunkCount) + } + }) + } +}