diff --git a/go.mod b/go.mod index d97dbbf91..d3c914d35 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/dennwc/iters v1.1.0 github.com/frostbyte73/core v0.1.1 github.com/fsnotify/fsnotify v1.9.0 - github.com/gammazero/deque v1.0.0 + github.com/gammazero/deque v1.1.0 github.com/go-jose/go-jose/v3 v3.0.4 github.com/go-logr/logr v1.4.3 github.com/hashicorp/go-retryablehttp v0.7.7 diff --git a/go.sum b/go.sum index cae99b906..f2bdf84d7 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,8 @@ github.com/frostbyte73/core v0.1.1 h1:ChhJOR7bAKOCPbA+lqDLE2cGKlCG5JXsDvvQr4YaJI github.com/frostbyte73/core v0.1.1/go.mod h1:mhfOtR+xWAvwXiwor7jnqPMnu4fxbv1F2MwZ0BEpzZo= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= -github.com/gammazero/deque v1.0.0 h1:LTmimT8H7bXkkCy6gZX7zNLtkbz4NdS2z8LZuor3j34= -github.com/gammazero/deque v1.0.0/go.mod h1:iflpYvtGfM3U8S8j+sZEKIak3SAKYpA5/SQewgfXDKo= +github.com/gammazero/deque v1.1.0 h1:OyiyReBbnEG2PP0Bnv1AASLIYvyKqIFN5xfl1t8oGLo= +github.com/gammazero/deque v1.1.0/go.mod h1:JVrR+Bj1NMQbPnYclvDlvSX0nVGReLrQZ0aUMuWLctg= github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY= github.com/go-jose/go-jose/v3 v3.0.4/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= diff --git a/signalling/signalfragment_test.go b/signalling/signalfragment_test.go new file mode 100644 index 000000000..5237a32df --- /dev/null +++ b/signalling/signalfragment_test.go @@ -0,0 +1,142 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/livekit" +) + +func TestSignalFragment(t *testing.T) { + inputMessage := &livekit.Envelope{ + ServerMessages: []*livekit.Signalv2ServerMessage{ + { + Message: &livekit.Signalv2ServerMessage_ConnectResponse{ + ConnectResponse: &livekit.ConnectResponse{ + SifTrailer: []byte("abcdefghijklmnopqrstuvwxyz0123456789"), + }, + }, + }, + { + Message: &livekit.Signalv2ServerMessage_ConnectResponse{ + ConnectResponse: &livekit.ConnectResponse{ + SifTrailer: []byte("0123456789abcdefghijklmnopqrstuvwxyz0123456789"), + }, + }, + }, + { + Message: &livekit.Signalv2ServerMessage_ConnectResponse{ + ConnectResponse: &livekit.ConnectResponse{ + SifTrailer: []byte("ABCDEFGHIJKLMNOPQRSTabcdefghijklmnopqrstuvwxyz0123456789"), + }, + }, + }, + }, + } + + t.Run("no segmentation needed", func(t *testing.T) { + sr := NewSignalSegmenter(SignalSegmenterParams{ + MaxFragmentSize: 5_000_000, + }) + + marshalled, err := proto.Marshal(inputMessage) + require.NoError(t, err) + require.Nil(t, sr.Segment(marshalled)) + }) + + t.Run("segmentation + reassembly", func(t *testing.T) { + maxFragmentSize := 5 + sr := NewSignalSegmenter(SignalSegmenterParams{ + MaxFragmentSize: maxFragmentSize, + }) + + marshalled, err := proto.Marshal(inputMessage) + require.NoError(t, err) + + expectedNumFragments := (len(marshalled) + maxFragmentSize - 1) / maxFragmentSize + + fragments := sr.Segment(marshalled) + require.NotZero(t, len(fragments)) + require.Equal(t, uint32(len(marshalled)), fragments[0].TotalSize) + + rr := NewSignalReassembler(SignalReassemblerParams{}) + var reassembled []byte + for idx, fragment := range fragments { + require.Equal(t, uint32(idx+1), fragment.FragmentNumber) + require.NotZero(t, fragment.FragmentSize) + require.Equal(t, uint32(expectedNumFragments), fragment.NumFragments) + require.Equal(t, fragment.FragmentSize, uint32(len(fragment.Data))) + + reassembled = rr.Reassemble(fragment) + } + require.Equal(t, marshalled, reassembled) + }) + + t.Run("runt", func(t *testing.T) { + maxFragmentSize := 5 + sr := NewSignalSegmenter(SignalSegmenterParams{ + MaxFragmentSize: maxFragmentSize, + }) + + marshalled, err := proto.Marshal(inputMessage) + require.NoError(t, err) + + fragments := sr.Segment(marshalled) + + rr := NewSignalReassembler(SignalReassemblerParams{}) + var reassembled []byte + for idx, fragment := range fragments { + // do not send one packet into re-assembly initially, re-assembly should not succeed + if idx == 0 { + continue + } + + reassembled = rr.Reassemble(fragment) + } + require.Zero(t, len(reassembled)) + + // submit 1st fragment and ensure reassembly completes + reassembled = rr.Reassemble(fragments[0]) + require.Equal(t, marshalled, reassembled) + }) + + t.Run("corrupted", func(t *testing.T) { + maxFragmentSize := 5 + sr := NewSignalSegmenter(SignalSegmenterParams{ + MaxFragmentSize: maxFragmentSize, + }) + + marshalled, err := proto.Marshal(inputMessage) + require.NoError(t, err) + + fragments := sr.Segment(marshalled) + + rr := NewSignalReassembler(SignalReassemblerParams{}) + var reassembled []byte + for idx, fragment := range fragments { + // corrupt a fragment, re-assembly should fail + if idx == 0 { + fragment.FragmentSize += 1 + } + + reassembled = rr.Reassemble(fragment) + } + require.Zero(t, len(reassembled)) + }) +} diff --git a/signalling/signalreassembler.go b/signalling/signalreassembler.go new file mode 100644 index 000000000..233e5a24b --- /dev/null +++ b/signalling/signalreassembler.go @@ -0,0 +1,154 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "sync" + "time" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" + "go.uber.org/zap/zapcore" +) + +const ( + reassemblerTimeout = time.Minute +) + +type reassembly struct { + packetId uint32 + startedAt time.Time + fragments []*livekit.Fragment + isCorrupted bool + tqi *utils.TimeoutQueueItem[*reassembly] +} + +func (r *reassembly) MarshalLogObject(e zapcore.ObjectEncoder) error { + if r == nil { + return nil + } + + e.AddUint32("packetId", r.packetId) + e.AddTime("startAt", r.startedAt) + e.AddDuration("age", time.Since(r.startedAt)) + + expectedNumberOfFragments := len(r.fragments) + expectedTotalSize := uint32(0) + availableSize := uint32(0) + var availableFragments []uint32 + for _, fragment := range r.fragments { + if fragment == nil { + continue + } + + expectedTotalSize = fragment.TotalSize + availableSize += fragment.FragmentSize + availableFragments = append(availableFragments, fragment.FragmentNumber) + } + e.AddInt("expectedNumberOfFragments", expectedNumberOfFragments) + e.AddUint32("expectedTotalSize", expectedTotalSize) + e.AddUint32("availableSize", availableSize) + e.AddArray("availableFragments", logger.Uint32Slice(availableFragments)) + + e.AddBool("isCorrupted", r.isCorrupted) + return nil +} + +// ------------------------------------------------ + +type SignalReassemblerParams struct { + Logger logger.Logger +} + +type SignalReassembler struct { + params SignalReassemblerParams + + lock sync.Mutex + reassemblies map[uint32]*reassembly + + timeoutQueue utils.TimeoutQueue[*reassembly] +} + +func NewSignalReassembler(params SignalReassemblerParams) *SignalReassembler { + return &SignalReassembler{ + params: params, + reassemblies: make(map[uint32]*reassembly), + } +} + +func (s *SignalReassembler) Reassemble(fragment *livekit.Fragment) []byte { + s.lock.Lock() + defer s.lock.Unlock() + + re, ok := s.reassemblies[fragment.PacketId] + if !ok { + re = &reassembly{ + packetId: fragment.PacketId, + startedAt: time.Now(), + fragments: make([]*livekit.Fragment, fragment.NumFragments), + } + re.tqi = &utils.TimeoutQueueItem[*reassembly]{Value: re} + + s.reassemblies[fragment.PacketId] = re + } + if int(fragment.FragmentNumber) <= len(re.fragments) { + if int(fragment.FragmentSize) != len(fragment.Data) { + re.isCorrupted = true // runt packet, data size of blob does not match fragment size + } else { + re.fragments[fragment.FragmentNumber-1] = fragment + } + } else { + re.isCorrupted = true + } + + if re.isCorrupted { + return nil + } + + // try to reassemble + expectedTotalSize := uint32(0) + totalSize := 0 + for _, fr := range re.fragments { + if fr == nil { + return nil // not received all fragments of packet yet + } + + expectedTotalSize = fr.TotalSize // can read this from any fragment of packet + totalSize += len(fr.Data) + } + if expectedTotalSize != 0 && uint32(totalSize) != expectedTotalSize { + re.isCorrupted = true + return nil + } + + data := make([]byte, 0, expectedTotalSize) + for _, fr := range re.fragments { + data = append(data, fr.Data...) + } + delete(s.reassemblies, re.packetId) // fully re-assembled, can be deleted from cache + return data +} + +func (s *SignalReassembler) Prune() { + for it := s.timeoutQueue.IterateRemoveAfter(reassemblerTimeout); it.Next(); { + re := it.Item().Value + s.params.Logger.Infow("pruning stale reassembly packet", "reassembly", re) + + s.lock.Lock() + delete(s.reassemblies, re.packetId) + s.lock.Unlock() + } +} diff --git a/signalling/signalsegmenter.go b/signalling/signalsegmenter.go new file mode 100644 index 000000000..d187c6acf --- /dev/null +++ b/signalling/signalsegmenter.go @@ -0,0 +1,81 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "math/rand" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "go.uber.org/atomic" +) + +const ( + defaultMaxFragmentSize = 8192 +) + +type SignalSegmenterParams struct { + Logger logger.Logger + MaxFragmentSize int + FirstPacketId uint32 // should be used for testing only +} + +type SignalSegmenter struct { + params SignalSegmenterParams + + packetId atomic.Uint32 +} + +func NewSignalSegmenter(params SignalSegmenterParams) *SignalSegmenter { + s := &SignalSegmenter{ + params: params, + } + if s.params.MaxFragmentSize == 0 { + s.params.MaxFragmentSize = defaultMaxFragmentSize + } + s.packetId.Store(params.FirstPacketId) + if s.packetId.Load() == 0 { + s.packetId.Store(uint32(rand.Intn(1<<8) + 1)) + } + return s +} + +func (s *SignalSegmenter) Segment(data []byte) []*livekit.Fragment { + if len(data) <= s.params.MaxFragmentSize { + return nil + } + + var fragments []*livekit.Fragment + numFragments := uint32((len(data) + s.params.MaxFragmentSize - 1) / s.params.MaxFragmentSize) + fragmentNumber := uint32(1) + consumed := 0 + packetId := s.packetId.Inc() + for len(data[consumed:]) != 0 { + fragmentSize := min(len(data[consumed:]), s.params.MaxFragmentSize) + fragment := &livekit.Fragment{ + PacketId: packetId, + FragmentNumber: fragmentNumber, + NumFragments: numFragments, + FragmentSize: uint32(fragmentSize), + TotalSize: uint32(len(data)), + Data: data[consumed : consumed+fragmentSize], + } + fragments = append(fragments, fragment) + fragmentNumber++ + consumed += fragmentSize + } + + return fragments +} diff --git a/signalling/signalv2cache.go b/signalling/signalv2cache.go new file mode 100644 index 000000000..62938129f --- /dev/null +++ b/signalling/signalv2cache.go @@ -0,0 +1,125 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "math/rand" + "sync" + + "github.com/gammazero/deque" + + "github.com/livekit/protocol/livekit" + "github.com/livekit/protocol/logger" + "github.com/livekit/protocol/utils" +) + +type Signalv2CacheParams struct { + Logger logger.Logger + FirstMessageId uint32 // should be used for testing only +} + +type Signalv2Cache struct { + params Signalv2CacheParams + + lock sync.Mutex + messageId uint32 + lastProcessedRemoteMessageId uint32 + messages deque.Deque[*livekit.Signalv2ServerMessage] +} + +func NewSignalv2Cache(params Signalv2CacheParams) *Signalv2Cache { + s := &Signalv2Cache{ + params: params, + messageId: params.FirstMessageId, + } + if s.messageId == 0 { + s.messageId = uint32(rand.Intn(1<<8) + 1) + } + s.messages.SetBaseCap(16) + return s +} + +func (s *Signalv2Cache) SetLastProcessedRemoteMessageId(lastProcessedRemoteMessageId uint32) { + s.lock.Lock() + defer s.lock.Unlock() + + s.lastProcessedRemoteMessageId = lastProcessedRemoteMessageId +} + +func (s *Signalv2Cache) Add(msg *livekit.Signalv2ServerMessage) *livekit.Signalv2ServerMessage { + if msg != nil { + s.AddBatch([]*livekit.Signalv2ServerMessage{msg}) + } + + return msg +} + +// SIGNALLING-V2-TODO: may not need this API +func (s *Signalv2Cache) AddBatch(msgs []*livekit.Signalv2ServerMessage) { + s.lock.Lock() + defer s.lock.Unlock() + + for _, msg := range msgs { + msg.Sequencer = &livekit.Sequencer{ + MessageId: s.messageId, + } + s.messageId++ + + s.messages.PushBack(msg) + } +} + +func (s *Signalv2Cache) Clear(till uint32) { + s.lock.Lock() + defer s.lock.Unlock() + + s.clearLocked(till) +} + +func (s *Signalv2Cache) clearLocked(till uint32) { + for s.messages.Len() != 0 { + front := s.messages.Front() + if front.Sequencer.GetMessageId() > till { + break + } + s.messages.PopFront() + } +} + +func (s *Signalv2Cache) GetFromFront() []*livekit.Signalv2ServerMessage { + s.lock.Lock() + defer s.lock.Unlock() + + return s.getFromFrontLocked() +} + +func (s *Signalv2Cache) getFromFrontLocked() []*livekit.Signalv2ServerMessage { + var msgs []*livekit.Signalv2ServerMessage + for msg := range s.messages.Iter() { + clone := utils.CloneProto(msg) + clone.Sequencer.LastProcessedRemoteMessageId = s.lastProcessedRemoteMessageId + msgs = append(msgs, clone) + } + + return msgs +} + +func (s *Signalv2Cache) ClearAndGetFrom(from uint32) []*livekit.Signalv2ServerMessage { + s.lock.Lock() + defer s.lock.Unlock() + + s.clearLocked(from - 1) + return s.getFromFrontLocked() +} diff --git a/signalling/signalv2cache_test.go b/signalling/signalv2cache_test.go new file mode 100644 index 000000000..9355dd243 --- /dev/null +++ b/signalling/signalv2cache_test.go @@ -0,0 +1,171 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/livekit/protocol/livekit" +) + +func TestSignalv2Cache(t *testing.T) { + firstMessageId := uint32(10) + lastProcessedRemoteMessageId := uint32(2345) + cache := NewSignalv2Cache(Signalv2CacheParams{ + FirstMessageId: firstMessageId, + }) + + inputMessages := []*livekit.Signalv2ServerMessage{ + &livekit.Signalv2ServerMessage{ + Message: &livekit.Signalv2ServerMessage_ConnectResponse{}, + }, + &livekit.Signalv2ServerMessage{ + Message: &livekit.Signalv2ServerMessage_PublisherSdp{}, + }, + &livekit.Signalv2ServerMessage{ + Message: &livekit.Signalv2ServerMessage_SubscriberSdp{}, + }, + &livekit.Signalv2ServerMessage{ + Message: &livekit.Signalv2ServerMessage_RoomUpdate{}, + }, + } + + expectedOutputMessages := []*livekit.Signalv2ServerMessage{ + &livekit.Signalv2ServerMessage{ + Sequencer: &livekit.Sequencer{ + MessageId: firstMessageId, + LastProcessedRemoteMessageId: lastProcessedRemoteMessageId, + }, + Message: &livekit.Signalv2ServerMessage_ConnectResponse{}, + }, + &livekit.Signalv2ServerMessage{ + Sequencer: &livekit.Sequencer{ + MessageId: firstMessageId + 1, + LastProcessedRemoteMessageId: lastProcessedRemoteMessageId, + }, + Message: &livekit.Signalv2ServerMessage_PublisherSdp{}, + }, + &livekit.Signalv2ServerMessage{ + Sequencer: &livekit.Sequencer{ + MessageId: firstMessageId + 2, + LastProcessedRemoteMessageId: lastProcessedRemoteMessageId, + }, + Message: &livekit.Signalv2ServerMessage_SubscriberSdp{}, + }, + &livekit.Signalv2ServerMessage{ + Sequencer: &livekit.Sequencer{ + MessageId: firstMessageId + 3, + LastProcessedRemoteMessageId: lastProcessedRemoteMessageId, + }, + Message: &livekit.Signalv2ServerMessage_RoomUpdate{}, + }, + } + + cache.SetLastProcessedRemoteMessageId(lastProcessedRemoteMessageId) + + // Add() - add one message at a time + for _, inputMessage := range inputMessages { + cache.Add(inputMessage) + } + + // get all messages in cache + outputMessages := cache.GetFromFront() + require.True(t, compareProtoSlices(expectedOutputMessages, outputMessages)) + + // clear one and get again + cache.Clear(firstMessageId) + + outputMessages = cache.GetFromFront() + require.True(t, compareProtoSlices(expectedOutputMessages[1:], outputMessages)) + + // clearing some evicted messages should not clear anything + cache.Clear(firstMessageId) // firstMessageId has been cleared already at this point + + outputMessages = cache.GetFromFront() + require.True(t, compareProtoSlices(expectedOutputMessages[1:], outputMessages)) + + // clear some and get rest in one go + outputMessages = cache.ClearAndGetFrom(firstMessageId + 3) + require.Equal(t, 1, len(outputMessages)) + require.True(t, compareProtoSlices(expectedOutputMessages[3:], outputMessages)) + + // getting again should get the same messages again as they sill should in cache + outputMessages = cache.GetFromFront() + require.True(t, compareProtoSlices(expectedOutputMessages[3:], outputMessages)) + + // clearing all and getting should return nil + require.Nil(t, cache.ClearAndGetFrom(firstMessageId+uint32(len(inputMessages)))) + + // getting again should return nil as the cache is fully cleared above + require.Nil(t, cache.GetFromFront()) + + lastProcessedRemoteMessageId = 4567 + cache.SetLastProcessedRemoteMessageId(lastProcessedRemoteMessageId) + + expectedOutputMessages = []*livekit.Signalv2ServerMessage{ + &livekit.Signalv2ServerMessage{ + Sequencer: &livekit.Sequencer{ + MessageId: firstMessageId + 4, + LastProcessedRemoteMessageId: lastProcessedRemoteMessageId, + }, + Message: &livekit.Signalv2ServerMessage_ConnectResponse{}, + }, + &livekit.Signalv2ServerMessage{ + Sequencer: &livekit.Sequencer{ + MessageId: firstMessageId + 1 + 4, + LastProcessedRemoteMessageId: lastProcessedRemoteMessageId, + }, + Message: &livekit.Signalv2ServerMessage_PublisherSdp{}, + }, + &livekit.Signalv2ServerMessage{ + Sequencer: &livekit.Sequencer{ + MessageId: firstMessageId + 2 + 4, + LastProcessedRemoteMessageId: lastProcessedRemoteMessageId, + }, + Message: &livekit.Signalv2ServerMessage_SubscriberSdp{}, + }, + &livekit.Signalv2ServerMessage{ + Sequencer: &livekit.Sequencer{ + MessageId: firstMessageId + 3 + 4, + LastProcessedRemoteMessageId: lastProcessedRemoteMessageId, + }, + Message: &livekit.Signalv2ServerMessage_RoomUpdate{}, + }, + } + + // AddBatch() - add all messages at once + cache.AddBatch(inputMessages) + + // get all messages in cache + outputMessages = cache.GetFromFront() + require.True(t, compareProtoSlices(expectedOutputMessages, outputMessages)) +} + +func compareProtoSlices(a []*livekit.Signalv2ServerMessage, b []*livekit.Signalv2ServerMessage) bool { + if len(a) != len(b) { + return false + } + + for i := 0; i < len(a); i++ { + if !proto.Equal(a[i], b[i]) { + return false + } + } + + return true +} diff --git a/signalling/utils.go b/signalling/utils.go new file mode 100644 index 000000000..e86980b68 --- /dev/null +++ b/signalling/utils.go @@ -0,0 +1,81 @@ +// Copyright 2023 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package signalling + +import ( + "encoding/json" + + "github.com/livekit/protocol/livekit" + "github.com/pion/webrtc/v4" +) + +func ToProtoSessionDescription(sd webrtc.SessionDescription, id uint32) *livekit.SessionDescription { + return &livekit.SessionDescription{ + Type: sd.Type.String(), + Sdp: sd.SDP, + Id: id, + } +} + +func FromProtoSessionDescription(sd *livekit.SessionDescription) (webrtc.SessionDescription, uint32) { + var sdType webrtc.SDPType + switch sd.Type { + case webrtc.SDPTypeOffer.String(): + sdType = webrtc.SDPTypeOffer + case webrtc.SDPTypeAnswer.String(): + sdType = webrtc.SDPTypeAnswer + case webrtc.SDPTypePranswer.String(): + sdType = webrtc.SDPTypePranswer + case webrtc.SDPTypeRollback.String(): + sdType = webrtc.SDPTypeRollback + } + return webrtc.SessionDescription{ + Type: sdType, + SDP: sd.Sdp, + }, sd.Id +} + +func ToProtoTrickle(candidateInit webrtc.ICECandidateInit, target livekit.SignalTarget, final bool) *livekit.TrickleRequest { + data, _ := json.Marshal(candidateInit) + return &livekit.TrickleRequest{ + CandidateInit: string(data), + Target: target, + Final: final, + } +} + +func FromProtoTrickle(trickle *livekit.TrickleRequest) (webrtc.ICECandidateInit, error) { + ci := webrtc.ICECandidateInit{} + err := json.Unmarshal([]byte(trickle.CandidateInit), &ci) + if err != nil { + return webrtc.ICECandidateInit{}, err + } + return ci, nil +} + +func FromProtoIceServers(iceservers []*livekit.ICEServer) []webrtc.ICEServer { + if iceservers == nil { + return nil + } + servers := make([]webrtc.ICEServer, 0, len(iceservers)) + for _, server := range iceservers { + servers = append(servers, webrtc.ICEServer{ + URLs: server.Urls, + Username: server.Username, + Credential: server.Credential, + }) + } + return servers +}