diff --git a/src/aiortc/rtcpeerconnection.py b/src/aiortc/rtcpeerconnection.py index 17e06cb0a..22ced5543 100644 --- a/src/aiortc/rtcpeerconnection.py +++ b/src/aiortc/rtcpeerconnection.py @@ -928,6 +928,8 @@ async def setRemoteDescription( iceTransport._role_set = True # set DTLS role + if description.type == "offer" and media.dtls.role == "client": + dtlsTransport._set_role(role="server") if description.type == "answer": dtlsTransport._set_role( role="server" if media.dtls.role == "client" else "client" @@ -1272,9 +1274,6 @@ def __validate_description( if not media.ice.usernameFragment or not media.ice.password: raise ValueError("ICE username fragment or password is missing") - # check DTLS role is allowed - if description.type == "offer" and media.dtls.role != "auto": - raise ValueError("DTLS setup attribute must be 'actpass' for an offer") if description.type in ["answer", "pranswer"] and media.dtls.role not in [ "client", "server", diff --git a/tests/test_rtcpeerconnection.py b/tests/test_rtcpeerconnection.py index ba51b4af1..3f5a2d33d 100644 --- a/tests/test_rtcpeerconnection.py +++ b/tests/test_rtcpeerconnection.py @@ -4685,30 +4685,6 @@ async def test_setRemoteDescription_media_mismatch(self): await pc1.close() await pc2.close() - @asynctest - async def test_setRemoteDescription_with_invalid_dtls_setup_for_offer(self): - pc1 = RTCPeerConnection() - pc2 = RTCPeerConnection() - - # apply offer - pc1.addTrack(AudioStreamTrack()) - offer = await pc1.createOffer() - await pc1.setLocalDescription(offer) - mangled = RTCSessionDescription( - sdp=pc1.localDescription.sdp.replace("a=setup:actpass", "a=setup:active"), - type=pc1.localDescription.type, - ) - with self.assertRaises(ValueError) as cm: - await pc2.setRemoteDescription(mangled) - self.assertEqual( - str(cm.exception), - "DTLS setup attribute must be 'actpass' for an offer", - ) - - # close - await pc1.close() - await pc2.close() - @asynctest async def test_setRemoteDescription_with_invalid_dtls_setup_for_answer(self): pc1 = RTCPeerConnection() @@ -4976,3 +4952,193 @@ async def test_setRemoteDescription_media_datachannel_bundled(self): "closed", ], ) + + @asynctest + async def test_dtls_role_offer_actpass(self): + pc1 = RTCPeerConnection() + pc2 = RTCPeerConnection() + + pc1_states = track_states(pc1) + pc2_states = track_states(pc2) + + self.assertEqual(pc1.iceConnectionState, "new") + self.assertEqual(pc1.iceGatheringState, "new") + self.assertIsNone(pc1.localDescription) + self.assertIsNone(pc1.remoteDescription) + + self.assertEqual(pc2.iceConnectionState, "new") + self.assertEqual(pc2.iceGatheringState, "new") + self.assertIsNone(pc2.localDescription) + self.assertIsNone(pc2.remoteDescription) + + # create offer + pc1.createDataChannel("chat", protocol="") + offer = await pc1.createOffer() + self.assertEqual(offer.type, "offer") + + await pc1.setLocalDescription(offer) + self.assertEqual(pc1.iceConnectionState, "new") + self.assertEqual(pc1.iceGatheringState, "complete") + + # set remote description + await pc2.setRemoteDescription(pc1.localDescription) + + # create answer + answer = await pc2.createAnswer() + self.assertHasDtls(answer, "active") + + await pc2.setLocalDescription(answer) + await self.assertIceChecking(pc2) + + # handle answer + await pc1.setRemoteDescription(pc2.localDescription) + self.assertEqual(pc1.remoteDescription, pc2.localDescription) + + # check outcome + await self.assertIceCompleted(pc1, pc2) + + self.assertEqual(pc1.sctp.transport._role, "server") + self.assertEqual(pc2.sctp.transport._role, "client") + # close + await pc1.close() + await pc2.close() + self.assertClosed(pc1) + self.assertClosed(pc2) + + # check state changes + self.assertEqual( + pc1_states["connectionState"], ["new", "connecting", "connected", "closed"] + ) + self.assertEqual( + pc2_states["connectionState"], ["new", "connecting", "connected", "closed"] + ) + + @asynctest + async def test_dtls_role_offer_passive(self): + pc1 = RTCPeerConnection() + pc2 = RTCPeerConnection() + + pc1_states = track_states(pc1) + pc2_states = track_states(pc2) + + self.assertEqual(pc1.iceConnectionState, "new") + self.assertEqual(pc1.iceGatheringState, "new") + self.assertIsNone(pc1.localDescription) + self.assertIsNone(pc1.remoteDescription) + + self.assertEqual(pc2.iceConnectionState, "new") + self.assertEqual(pc2.iceGatheringState, "new") + self.assertIsNone(pc2.localDescription) + self.assertIsNone(pc2.remoteDescription) + + # create offer + pc1.createDataChannel("chat", protocol="") + offer = await pc1.createOffer() + self.assertEqual(offer.type, "offer") + + await pc1.setLocalDescription(offer) + self.assertEqual(pc1.iceConnectionState, "new") + self.assertEqual(pc1.iceGatheringState, "complete") + + # handle offer with replaced DTLS role + await pc2.setRemoteDescription( + RTCSessionDescription( + type="offer", sdp=pc1.localDescription.sdp.replace("actpass", "passive") + ) + ) + + # create answer + answer = await pc2.createAnswer() + self.assertHasDtls(answer, "active") + + await pc2.setLocalDescription(answer) + await self.assertIceChecking(pc2) + + # handle answer + await pc1.setRemoteDescription(pc2.localDescription) + self.assertEqual(pc1.remoteDescription, pc2.localDescription) + + # check outcome + await self.assertIceCompleted(pc1, pc2) + + # pc1 is explicity passive so server. + self.assertEqual(pc1.sctp.transport._role, "server") + self.assertEqual(pc2.sctp.transport._role, "client") + # close + await pc1.close() + await pc2.close() + self.assertClosed(pc1) + self.assertClosed(pc2) + + # check state changes + self.assertEqual( + pc1_states["connectionState"], ["new", "connecting", "connected", "closed"] + ) + self.assertEqual( + pc2_states["connectionState"], ["new", "connecting", "connected", "closed"] + ) + + @asynctest + async def test_dtls_role_offer_active(self): + pc1 = RTCPeerConnection() + pc2 = RTCPeerConnection() + + pc1_states = track_states(pc1) + pc2_states = track_states(pc2) + + self.assertEqual(pc1.iceConnectionState, "new") + self.assertEqual(pc1.iceGatheringState, "new") + self.assertIsNone(pc1.localDescription) + self.assertIsNone(pc1.remoteDescription) + + self.assertEqual(pc2.iceConnectionState, "new") + self.assertEqual(pc2.iceGatheringState, "new") + self.assertIsNone(pc2.localDescription) + self.assertIsNone(pc2.remoteDescription) + + # create offer + pc1.createDataChannel("chat", protocol="") + offer = await pc1.createOffer() + self.assertEqual(offer.type, "offer") + + await pc1.setLocalDescription(offer) + self.assertEqual(pc1.iceConnectionState, "new") + self.assertEqual(pc1.iceGatheringState, "complete") + + # handle offer with replaced DTLS role + await pc2.setRemoteDescription( + RTCSessionDescription( + type="offer", sdp=pc1.localDescription.sdp.replace("actpass", "active") + ) + ) + + # create answer + answer = await pc2.createAnswer() + self.assertHasDtls(answer, "passive") + + await pc2.setLocalDescription(answer) + await self.assertIceChecking(pc2) + + # handle answer + await pc1.setRemoteDescription(pc2.localDescription) + self.assertEqual(pc1.remoteDescription, pc2.localDescription) + + # check outcome + await self.assertIceCompleted(pc1, pc2) + + # pc1 is explicity active so client. + self.assertEqual(pc1.sctp.transport._role, "client") + self.assertEqual(pc2.sctp.transport._role, "server") + # close + await pc1.close() + await pc2.close() + self.assertClosed(pc1) + self.assertClosed(pc2) + + # check state changes + self.assertEqual( + pc1_states["connectionState"], ["new", "connecting", "connected", "closed"] + ) + self.assertEqual( + pc2_states["connectionState"], ["new", "connecting", "connected", "closed"] + )