From d68ab0567bda3c6f4531949bf9089005bd433c47 Mon Sep 17 00:00:00 2001 From: xicilion Date: Tue, 30 Apr 2024 20:40:33 +0800 Subject: [PATCH] datachannel, feat: add datachannel library. --- datachannel/CMakeLists.txt | 18 + datachannel/LICENSE | 373 ++++ datachannel/README.md | 182 ++ datachannel/include/plog/.DS_Store | Bin 0 -> 6148 bytes .../include/plog/Appenders/AndroidAppender.h | 47 + .../include/plog/Appenders/ArduinoAppender.h | 23 + .../plog/Appenders/ColorConsoleAppender.h | 108 ++ .../include/plog/Appenders/ConsoleAppender.h | 83 + .../plog/Appenders/DebugOutputAppender.h | 16 + .../include/plog/Appenders/DynamicAppender.h | 42 + .../include/plog/Appenders/EventLogAppender.h | 117 ++ .../include/plog/Appenders/IAppender.h | 16 + .../plog/Appenders/RollingFileAppender.h | 148 ++ .../plog/Converters/NativeEOLConverter.h | 44 + .../include/plog/Converters/UTF8Converter.h | 28 + .../include/plog/Formatters/CsvFormatter.h | 57 + .../plog/Formatters/FuncMessageFormatter.h | 23 + .../plog/Formatters/MessageOnlyFormatter.h | 23 + .../include/plog/Formatters/TxtFormatter.h | 36 + datachannel/include/plog/Helpers/AscDump.h | 40 + datachannel/include/plog/Helpers/HexDump.h | 79 + datachannel/include/plog/Helpers/PrintVar.h | 24 + datachannel/include/plog/Init.h | 17 + .../plog/Initializers/ConsoleInitializer.h | 22 + .../Initializers/RollingFileInitializer.h | 80 + datachannel/include/plog/Log.h | 202 ++ datachannel/include/plog/Logger.h | 84 + datachannel/include/plog/Record.h | 435 +++++ datachannel/include/plog/Severity.h | 61 + datachannel/include/plog/Util.h | 616 ++++++ datachannel/include/plog/WinApi.h | 175 ++ datachannel/include/rtc/av1rtppacketizer.hpp | 56 + datachannel/include/rtc/candidate.hpp | 77 + datachannel/include/rtc/channel.hpp | 61 + datachannel/include/rtc/common.hpp | 86 + datachannel/include/rtc/configuration.hpp | 122 ++ datachannel/include/rtc/datachannel.hpp | 80 + datachannel/include/rtc/description.hpp | 324 ++++ datachannel/include/rtc/global.hpp | 59 + datachannel/include/rtc/h264rtppacketizer.hpp | 58 + datachannel/include/rtc/h265nalunit.hpp | 186 ++ datachannel/include/rtc/h265rtppacketizer.hpp | 56 + datachannel/include/rtc/mediahandler.hpp | 58 + datachannel/include/rtc/message.hpp | 79 + datachannel/include/rtc/nalunit.hpp | 226 +++ datachannel/include/rtc/peerconnection.hpp | 130 ++ datachannel/include/rtc/plihandler.hpp | 36 + datachannel/include/rtc/reliability.hpp | 43 + datachannel/include/rtc/rtc.h | 518 +++++ datachannel/include/rtc/rtc.hpp | 41 + datachannel/include/rtc/rtcpnackresponder.hpp | 76 + .../include/rtc/rtcpreceivingsession.hpp | 54 + datachannel/include/rtc/rtcpsrreporter.hpp | 46 + datachannel/include/rtc/rtp.hpp | 380 ++++ .../include/rtc/rtppacketizationconfig.hpp | 99 + datachannel/include/rtc/rtppacketizer.hpp | 88 + datachannel/include/rtc/track.hpp | 61 + datachannel/include/rtc/utils.hpp | 159 ++ datachannel/include/rtc/websocket.hpp | 67 + datachannel/include/rtc/websocketserver.hpp | 48 + datachannel/src/av1rtppacketizer.cpp | 225 +++ datachannel/src/candidate.cpp | 287 +++ datachannel/src/capi.cpp | 1673 +++++++++++++++++ datachannel/src/channel.cpp | 62 + datachannel/src/configuration.cpp | 155 ++ datachannel/src/datachannel.cpp | 57 + datachannel/src/description.cpp | 1398 ++++++++++++++ datachannel/src/global.cpp | 118 ++ datachannel/src/h264rtppacketizer.cpp | 112 ++ datachannel/src/h265nalunit.cpp | 100 + datachannel/src/h265rtppacketizer.cpp | 113 ++ datachannel/src/impl/certificate.cpp | 578 ++++++ datachannel/src/impl/certificate.hpp | 76 + datachannel/src/impl/channel.cpp | 96 + datachannel/src/impl/channel.hpp | 52 + datachannel/src/impl/datachannel.cpp | 393 ++++ datachannel/src/impl/datachannel.hpp | 93 + datachannel/src/impl/dtlssrtptransport.cpp | 393 ++++ datachannel/src/impl/dtlssrtptransport.hpp | 68 + datachannel/src/impl/dtlstransport.cpp | 1095 +++++++++++ datachannel/src/impl/dtlstransport.hpp | 125 ++ datachannel/src/impl/http.cpp | 66 + datachannel/src/impl/http.hpp | 30 + datachannel/src/impl/httpproxytransport.cpp | 129 ++ datachannel/src/impl/httpproxytransport.hpp | 50 + datachannel/src/impl/icetransport.cpp | 893 +++++++++ datachannel/src/impl/icetransport.hpp | 114 ++ datachannel/src/impl/init.cpp | 181 ++ datachannel/src/impl/init.hpp | 56 + datachannel/src/impl/internals.hpp | 54 + datachannel/src/impl/logcounter.cpp | 40 + datachannel/src/impl/logcounter.hpp | 41 + datachannel/src/impl/peerconnection.cpp | 1323 +++++++++++++ datachannel/src/impl/peerconnection.hpp | 164 ++ datachannel/src/impl/pollinterrupter.cpp | 125 ++ datachannel/src/impl/pollinterrupter.hpp | 44 + datachannel/src/impl/pollservice.cpp | 229 +++ datachannel/src/impl/pollservice.hpp | 82 + datachannel/src/impl/processor.cpp | 42 + datachannel/src/impl/processor.hpp | 76 + datachannel/src/impl/queue.hpp | 129 ++ datachannel/src/impl/sctptransport.cpp | 1005 ++++++++++ datachannel/src/impl/sctptransport.hpp | 135 ++ datachannel/src/impl/sha.cpp | 74 + datachannel/src/impl/sha.hpp | 25 + datachannel/src/impl/socket.hpp | 132 ++ datachannel/src/impl/tcpserver.cpp | 190 ++ datachannel/src/impl/tcpserver.hpp | 48 + datachannel/src/impl/tcptransport.cpp | 473 +++++ datachannel/src/impl/tcptransport.hpp | 80 + datachannel/src/impl/threadpool.cpp | 97 + datachannel/src/impl/threadpool.hpp | 118 ++ datachannel/src/impl/tls.cpp | 231 +++ datachannel/src/impl/tls.hpp | 96 + datachannel/src/impl/tlstransport.cpp | 834 ++++++++ datachannel/src/impl/tlstransport.hpp | 102 + datachannel/src/impl/track.cpp | 229 +++ datachannel/src/impl/track.hpp | 78 + datachannel/src/impl/transport.cpp | 79 + datachannel/src/impl/transport.hpp | 60 + datachannel/src/impl/utils.cpp | 183 ++ datachannel/src/impl/utils.hpp | 88 + datachannel/src/impl/verifiedtlstransport.cpp | 71 + datachannel/src/impl/verifiedtlstransport.hpp | 35 + datachannel/src/impl/websocket.cpp | 533 ++++++ datachannel/src/impl/websocket.hpp | 95 + datachannel/src/impl/websocketserver.cpp | 102 + datachannel/src/impl/websocketserver.hpp | 55 + datachannel/src/impl/wshandshake.cpp | 254 +++ datachannel/src/impl/wshandshake.hpp | 68 + datachannel/src/impl/wstransport.cpp | 424 +++++ datachannel/src/impl/wstransport.hpp | 91 + datachannel/src/mediahandler.cpp | 80 + datachannel/src/message.cpp | 79 + datachannel/src/nalunit.cpp | 99 + datachannel/src/peerconnection.cpp | 464 +++++ datachannel/src/plihandler.cpp | 45 + datachannel/src/rtcpnackresponder.cpp | 114 ++ datachannel/src/rtcpreceivingsession.cpp | 133 ++ datachannel/src/rtcpsrreporter.cpp | 90 + datachannel/src/rtp.cpp | 663 +++++++ datachannel/src/rtppacketizationconfig.cpp | 56 + datachannel/src/rtppacketizer.cpp | 109 ++ datachannel/src/track.cpp | 73 + datachannel/src/websocket.cpp | 96 + datachannel/src/websocketserver.cpp | 36 + libs.cmake | 1 + 147 files changed, 25953 insertions(+) create mode 100644 datachannel/CMakeLists.txt create mode 100644 datachannel/LICENSE create mode 100644 datachannel/README.md create mode 100644 datachannel/include/plog/.DS_Store create mode 100644 datachannel/include/plog/Appenders/AndroidAppender.h create mode 100644 datachannel/include/plog/Appenders/ArduinoAppender.h create mode 100644 datachannel/include/plog/Appenders/ColorConsoleAppender.h create mode 100644 datachannel/include/plog/Appenders/ConsoleAppender.h create mode 100644 datachannel/include/plog/Appenders/DebugOutputAppender.h create mode 100644 datachannel/include/plog/Appenders/DynamicAppender.h create mode 100644 datachannel/include/plog/Appenders/EventLogAppender.h create mode 100644 datachannel/include/plog/Appenders/IAppender.h create mode 100644 datachannel/include/plog/Appenders/RollingFileAppender.h create mode 100644 datachannel/include/plog/Converters/NativeEOLConverter.h create mode 100644 datachannel/include/plog/Converters/UTF8Converter.h create mode 100644 datachannel/include/plog/Formatters/CsvFormatter.h create mode 100644 datachannel/include/plog/Formatters/FuncMessageFormatter.h create mode 100644 datachannel/include/plog/Formatters/MessageOnlyFormatter.h create mode 100644 datachannel/include/plog/Formatters/TxtFormatter.h create mode 100644 datachannel/include/plog/Helpers/AscDump.h create mode 100644 datachannel/include/plog/Helpers/HexDump.h create mode 100644 datachannel/include/plog/Helpers/PrintVar.h create mode 100644 datachannel/include/plog/Init.h create mode 100644 datachannel/include/plog/Initializers/ConsoleInitializer.h create mode 100644 datachannel/include/plog/Initializers/RollingFileInitializer.h create mode 100644 datachannel/include/plog/Log.h create mode 100644 datachannel/include/plog/Logger.h create mode 100644 datachannel/include/plog/Record.h create mode 100644 datachannel/include/plog/Severity.h create mode 100644 datachannel/include/plog/Util.h create mode 100644 datachannel/include/plog/WinApi.h create mode 100644 datachannel/include/rtc/av1rtppacketizer.hpp create mode 100644 datachannel/include/rtc/candidate.hpp create mode 100644 datachannel/include/rtc/channel.hpp create mode 100644 datachannel/include/rtc/common.hpp create mode 100644 datachannel/include/rtc/configuration.hpp create mode 100644 datachannel/include/rtc/datachannel.hpp create mode 100644 datachannel/include/rtc/description.hpp create mode 100644 datachannel/include/rtc/global.hpp create mode 100644 datachannel/include/rtc/h264rtppacketizer.hpp create mode 100644 datachannel/include/rtc/h265nalunit.hpp create mode 100644 datachannel/include/rtc/h265rtppacketizer.hpp create mode 100644 datachannel/include/rtc/mediahandler.hpp create mode 100644 datachannel/include/rtc/message.hpp create mode 100644 datachannel/include/rtc/nalunit.hpp create mode 100644 datachannel/include/rtc/peerconnection.hpp create mode 100644 datachannel/include/rtc/plihandler.hpp create mode 100644 datachannel/include/rtc/reliability.hpp create mode 100644 datachannel/include/rtc/rtc.h create mode 100644 datachannel/include/rtc/rtc.hpp create mode 100644 datachannel/include/rtc/rtcpnackresponder.hpp create mode 100644 datachannel/include/rtc/rtcpreceivingsession.hpp create mode 100644 datachannel/include/rtc/rtcpsrreporter.hpp create mode 100644 datachannel/include/rtc/rtp.hpp create mode 100644 datachannel/include/rtc/rtppacketizationconfig.hpp create mode 100644 datachannel/include/rtc/rtppacketizer.hpp create mode 100644 datachannel/include/rtc/track.hpp create mode 100644 datachannel/include/rtc/utils.hpp create mode 100644 datachannel/include/rtc/websocket.hpp create mode 100644 datachannel/include/rtc/websocketserver.hpp create mode 100644 datachannel/src/av1rtppacketizer.cpp create mode 100644 datachannel/src/candidate.cpp create mode 100644 datachannel/src/capi.cpp create mode 100644 datachannel/src/channel.cpp create mode 100644 datachannel/src/configuration.cpp create mode 100644 datachannel/src/datachannel.cpp create mode 100644 datachannel/src/description.cpp create mode 100644 datachannel/src/global.cpp create mode 100644 datachannel/src/h264rtppacketizer.cpp create mode 100644 datachannel/src/h265nalunit.cpp create mode 100644 datachannel/src/h265rtppacketizer.cpp create mode 100644 datachannel/src/impl/certificate.cpp create mode 100644 datachannel/src/impl/certificate.hpp create mode 100644 datachannel/src/impl/channel.cpp create mode 100644 datachannel/src/impl/channel.hpp create mode 100644 datachannel/src/impl/datachannel.cpp create mode 100644 datachannel/src/impl/datachannel.hpp create mode 100644 datachannel/src/impl/dtlssrtptransport.cpp create mode 100644 datachannel/src/impl/dtlssrtptransport.hpp create mode 100644 datachannel/src/impl/dtlstransport.cpp create mode 100644 datachannel/src/impl/dtlstransport.hpp create mode 100644 datachannel/src/impl/http.cpp create mode 100644 datachannel/src/impl/http.hpp create mode 100644 datachannel/src/impl/httpproxytransport.cpp create mode 100644 datachannel/src/impl/httpproxytransport.hpp create mode 100644 datachannel/src/impl/icetransport.cpp create mode 100644 datachannel/src/impl/icetransport.hpp create mode 100644 datachannel/src/impl/init.cpp create mode 100644 datachannel/src/impl/init.hpp create mode 100644 datachannel/src/impl/internals.hpp create mode 100644 datachannel/src/impl/logcounter.cpp create mode 100644 datachannel/src/impl/logcounter.hpp create mode 100644 datachannel/src/impl/peerconnection.cpp create mode 100644 datachannel/src/impl/peerconnection.hpp create mode 100644 datachannel/src/impl/pollinterrupter.cpp create mode 100644 datachannel/src/impl/pollinterrupter.hpp create mode 100644 datachannel/src/impl/pollservice.cpp create mode 100644 datachannel/src/impl/pollservice.hpp create mode 100644 datachannel/src/impl/processor.cpp create mode 100644 datachannel/src/impl/processor.hpp create mode 100644 datachannel/src/impl/queue.hpp create mode 100644 datachannel/src/impl/sctptransport.cpp create mode 100644 datachannel/src/impl/sctptransport.hpp create mode 100644 datachannel/src/impl/sha.cpp create mode 100644 datachannel/src/impl/sha.hpp create mode 100644 datachannel/src/impl/socket.hpp create mode 100644 datachannel/src/impl/tcpserver.cpp create mode 100644 datachannel/src/impl/tcpserver.hpp create mode 100644 datachannel/src/impl/tcptransport.cpp create mode 100644 datachannel/src/impl/tcptransport.hpp create mode 100644 datachannel/src/impl/threadpool.cpp create mode 100644 datachannel/src/impl/threadpool.hpp create mode 100644 datachannel/src/impl/tls.cpp create mode 100644 datachannel/src/impl/tls.hpp create mode 100644 datachannel/src/impl/tlstransport.cpp create mode 100644 datachannel/src/impl/tlstransport.hpp create mode 100644 datachannel/src/impl/track.cpp create mode 100644 datachannel/src/impl/track.hpp create mode 100644 datachannel/src/impl/transport.cpp create mode 100644 datachannel/src/impl/transport.hpp create mode 100644 datachannel/src/impl/utils.cpp create mode 100644 datachannel/src/impl/utils.hpp create mode 100644 datachannel/src/impl/verifiedtlstransport.cpp create mode 100644 datachannel/src/impl/verifiedtlstransport.hpp create mode 100644 datachannel/src/impl/websocket.cpp create mode 100644 datachannel/src/impl/websocket.hpp create mode 100644 datachannel/src/impl/websocketserver.cpp create mode 100644 datachannel/src/impl/websocketserver.hpp create mode 100644 datachannel/src/impl/wshandshake.cpp create mode 100644 datachannel/src/impl/wshandshake.hpp create mode 100644 datachannel/src/impl/wstransport.cpp create mode 100644 datachannel/src/impl/wstransport.hpp create mode 100644 datachannel/src/mediahandler.cpp create mode 100644 datachannel/src/message.cpp create mode 100644 datachannel/src/nalunit.cpp create mode 100644 datachannel/src/peerconnection.cpp create mode 100644 datachannel/src/plihandler.cpp create mode 100644 datachannel/src/rtcpnackresponder.cpp create mode 100644 datachannel/src/rtcpreceivingsession.cpp create mode 100644 datachannel/src/rtcpsrreporter.cpp create mode 100644 datachannel/src/rtp.cpp create mode 100644 datachannel/src/rtppacketizationconfig.cpp create mode 100644 datachannel/src/rtppacketizer.cpp create mode 100644 datachannel/src/track.cpp create mode 100644 datachannel/src/websocket.cpp create mode 100644 datachannel/src/websocketserver.cpp diff --git a/datachannel/CMakeLists.txt b/datachannel/CMakeLists.txt new file mode 100644 index 000000000..ad17ffdcf --- /dev/null +++ b/datachannel/CMakeLists.txt @@ -0,0 +1,18 @@ +cmake_minimum_required(VERSION 3.5) + +include_directories( + ${PROJECT_SOURCE_DIR}/include/rtc + ${PROJECT_SOURCE_DIR}/src + ${PROJECT_SOURCE_DIR}/../openssl/include + ${PROJECT_SOURCE_DIR}/../juice/include + ${PROJECT_SOURCE_DIR}/../usrsctp/include +) + +add_definitions( + -DRTC_STATIC + -DJUICE_STATIC + -DRTC_ENABLE_MEDIA=0 + -DRTC_ENABLE_WEBSOCKET=0 +) + +include(../build_tools/cmake/Library.cmake) diff --git a/datachannel/LICENSE b/datachannel/LICENSE new file mode 100644 index 000000000..14e2f777f --- /dev/null +++ b/datachannel/LICENSE @@ -0,0 +1,373 @@ +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. "Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants and Conditions +-------------------------------- + +2.1. Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. * +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. * +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. diff --git a/datachannel/README.md b/datachannel/README.md new file mode 100644 index 000000000..793b33c85 --- /dev/null +++ b/datachannel/README.md @@ -0,0 +1,182 @@ +# libdatachannel - C/C++ WebRTC network library + +[![License: MPL 2.0](https://img.shields.io/badge/License-MPL_2.0-blue.svg)](https://www.mozilla.org/en-US/MPL/2.0/) +[![Build with GnuTLS](https://github.com/paullouisageneau/libdatachannel/actions/workflows/build-gnutls.yml/badge.svg)](https://github.com/paullouisageneau/libdatachannel/actions/workflows/build-gnutls.yml) +[![Build with Mbed TLS](https://github.com/paullouisageneau/libdatachannel/actions/workflows/build-mbedtls.yml/badge.svg)](https://github.com/paullouisageneau/libdatachannel/actions/workflows/build-mbedtls.yml) +[![Build with OpenSSL](https://github.com/paullouisageneau/libdatachannel/actions/workflows/build-openssl.yml/badge.svg)](https://github.com/paullouisageneau/libdatachannel/actions/workflows/build-openssl.yml) + +[![AUR package](https://repology.org/badge/version-for-repo/aur/libdatachannel.svg)](https://repology.org/project/libdatachannel/versions) [![FreeBSD port](https://repology.org/badge/version-for-repo/freebsd/libdatachannel.svg)](https://repology.org/project/libdatachannel/versions) [![Vcpkg package](https://repology.org/badge/version-for-repo/vcpkg/libdatachannel.svg)](https://repology.org/project/libdatachannel/versions) +[![Gitter](https://badges.gitter.im/libdatachannel/community.svg)](https://gitter.im/libdatachannel/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Discord](https://img.shields.io/discord/903257095539925006?logo=discord)](https://discord.gg/jXAP8jp3Nn) + +libdatachannel is a standalone implementation of WebRTC Data Channels, WebRTC Media Transport, and WebSockets in C++17 with C bindings for POSIX platforms (including GNU/Linux, Android, FreeBSD, Apple macOS and iOS) and Microsoft Windows. WebRTC is a W3C and IETF standard enabling real-time peer-to-peer data and media exchange between two devices. + +The library aims at being both straightforward and lightweight with minimal external dependencies, to enable direct connectivity between native applications and web browsers without the pain of importing Google's bloated [reference library](https://webrtc.googlesource.com/src/). The interface consists of somewhat simplified versions of the JavaScript WebRTC and WebSocket APIs present in browsers, in order to ease the design of cross-environment applications. + +It can be compiled with multiple backends: +- The security layer can be provided through [GnuTLS](https://www.gnutls.org/), [Mbed TLS](https://www.trustedfirmware.org/projects/mbed-tls/), or [OpenSSL](https://www.openssl.org/). +- The connectivity for WebRTC can be provided through my ad-hoc ICE library [libjuice](https://github.com/paullouisageneau/libjuice) as submodule or through [libnice](https://github.com/libnice/libnice). + +The WebRTC stack is fully compatible with browsers like Firefox and Chromium, see [Compatibility](#Compatibility) below. Additionally, code using Data Channels and WebSockets from the library may be compiled as is to WebAssembly for browsers with [datachannel-wasm](https://github.com/paullouisageneau/datachannel-wasm). + +libdatachannel is licensed under MPL 2.0 since version 0.18, see [LICENSE](https://github.com/paullouisageneau/libdatachannel/blob/master/LICENSE) (previous versions were licensed under LGPLv2.1 or later). + +libdatachannel is available on [AUR](https://aur.archlinux.org/packages/libdatachannel/), [vcpkg](https://vcpkg.io/en/getting-started), and [FreeBSD ports](https://www.freshports.org/www/libdatachannel). Bindings are available for [Rust](https://crates.io/crates/datachannel) and [Node.js](https://www.npmjs.com/package/node-datachannel). + +## Dependencies + +- [GnuTLS](https://www.gnutls.org/), [Mbed TLS](https://www.trustedfirmware.org/projects/mbed-tls/), or [OpenSSL](https://www.openssl.org/) +- [usrsctp](https://github.com/sctplab/usrsctp) (as submodule by default) +- [plog](https://github.com/SergiusTheBest/plog) (as submodule by default) +- [libjuice](https://github.com/paullouisageneau/libjuice) (as submodule by default) or [libnice](https://nice.freedesktop.org/) as an ICE backend. +- [libsrtp](https://github.com/cisco/libsrtp) (as submodule by default) required if compiled with media support. +- [nlohmann JSON](https://github.com/nlohmann/json) (as submodule by default) required to build examples. + +## Building + +See [BUILDING.md](https://github.com/paullouisageneau/libdatachannel/blob/master/BUILDING.md) for building instructions. + +## Examples + +See [examples](https://github.com/paullouisageneau/libdatachannel/blob/master/examples/) for complete usage examples with signaling server (under MPL 2.0). + +Additionally, you might want to have a look at the [C API documentation](https://github.com/paullouisageneau/libdatachannel/blob/master/DOC.md). + +### Signal a PeerConnection + +```cpp +#include "rtc/rtc.hpp" +``` + +```cpp +rtc::Configuration config; +config.iceServers.emplace_back("mystunserver.org:3478"); + +rtc::PeerConnection pc(config); + +pc.onLocalDescription([](rtc::Description sdp) { + // Send the SDP to the remote peer + MY_SEND_DESCRIPTION_TO_REMOTE(std::string(sdp)); +}); + +pc.onLocalCandidate([](rtc::Candidate candidate) { + // Send the candidate to the remote peer + MY_SEND_CANDIDATE_TO_REMOTE(candidate.candidate(), candidate.mid()); +}); + +MY_ON_RECV_DESCRIPTION_FROM_REMOTE([&pc](std::string sdp) { + pc.setRemoteDescription(rtc::Description(sdp)); +}); + +MY_ON_RECV_CANDIDATE_FROM_REMOTE([&pc](std::string candidate, std::string mid) { + pc.addRemoteCandidate(rtc::Candidate(candidate, mid)); +}); +``` + +### Observe the PeerConnection state + +```cpp +pc.onStateChange([](rtc::PeerConnection::State state) { + std::cout << "State: " << state << std::endl; +}); + +pc.onGatheringStateChange([](rtc::PeerConnection::GatheringState state) { + std::cout << "Gathering state: " << state << std::endl; +}); +``` + +### Create a DataChannel + +```cpp +auto dc = pc.createDataChannel("test"); + +dc->onOpen([]() { + std::cout << "Open" << std::endl; +}); + +dc->onMessage([](std::variant message) { + if (std::holds_alternative(message)) { + std::cout << "Received: " << get(message) << std::endl; + } +}); +``` + +### Receive a DataChannel + +```cpp +std::shared_ptr dc; +pc.onDataChannel([&dc](std::shared_ptr incoming) { + dc = incoming; + dc->send("Hello world!"); +}); +``` + +### Open a WebSocket + +```cpp +rtc::WebSocket ws; + +ws.onOpen([]() { + std::cout << "WebSocket open" << std::endl; +}); + +ws.onMessage([](std::variant message) { + if (std::holds_alternative(message)) { + std::cout << "WebSocket received: " << std::get(message) << endl; + } +}); + +ws.open("wss://my.websocket/service"); +``` + +## Compatibility + +The library implements the following communication protocols: + +### WebRTC Data Channels and Media Transport + +WebRTC allows real-time data and media exchange between two devices through a Peer Connection (or RTCPeerConnection), a signaled peer-to-peer connection which can carry both Data Channels and media tracks. It is compatible with browsers Firefox, Chromium, and Safari, and other WebRTC libraries (see [webrtc-echoes](https://github.com/sipsorcery/webrtc-echoes)). Media transport is optional and can be disabled at compile time. + +Protocol stack: +- SCTP-based Data Channels ([RFC8831](https://www.rfc-editor.org/rfc/rfc8831.html)) +- SRTP-based Media Transport ([RFC8834](https://www.rfc-editor.org/rfc/rfc8834.html)) +- DTLS/UDP ([RFC7350](https://www.rfc-editor.org/rfc/rfc7350.html) and [RFC8261](https://www.rfc-editor.org/rfc/rfc8261.html)) +- ICE ([RFC8445](https://www.rfc-editor.org/rfc/rfc8445.html)) with STUN ([RFC8489](https://www.rfc-editor.org/rfc/rfc8489.html)) and its extension TURN ([RFC8656](https://www.rfc-editor.org/rfc/rfc8656.html)) + +Features: +- Full IPv6 support (as mandated by [RFC8835](https://www.rfc-editor.org/rfc/rfc8835.html)) +- Trickle ICE ([RFC8838](https://www.rfc-editor.org/rfc/rfc8838.html)) +- JSEP-compatible session establishment with SDP ([RFC8829](https://www.rfc-editor.org/rfc/rfc8829.html)) +- SCTP over DTLS with SDP offer/answer ([RFC8841](https://www.rfc-editor.org/rfc/rfc8841.html)) +- DTLS with ECDSA or RSA keys ([RFC8827](https://www.rfc-editor.org/rfc/rfc8827.html)) +- SRTP and SRTCP key derivation from DTLS ([RFC5764](https://www.rfc-editor.org/rfc/rfc5764.html)) +- Differentiated Services QoS ([RFC8837](https://www.rfc-editor.org/rfc/rfc8837.html)) where possible +- Multicast DNS candidates ([draft-ietf-rtcweb-mdns-ice-candidates-04](https://datatracker.ietf.org/doc/html/draft-ietf-rtcweb-mdns-ice-candidates-04)) +- Multiplexing connections on a single UDP port with libjuice as ICE backend + +Note only SDP BUNDLE mode is supported for media multiplexing ([RFC8843](https://www.rfc-editor.org/rfc/rfc8843.html)). The behavior is equivalent to the JSEP bundle-only policy: the library always negotiates one unique network component, where SRTP media streams are multiplexed with SRTCP control packets ([RFC5761](https://www.rfc-editor.org/rfc/rfc5761.html)) and SCTP/DTLS data traffic ([RFC8261](https://www.rfc-editor.org/rfc/rfc8261.html)). + +### WebSocket + +WebSocket is the protocol of choice for WebRTC signaling. The support is optional and can be disabled at compile time. + +Protocol stack: +- WebSocket protocol ([RFC6455](https://www.rfc-editor.org/rfc/rfc6455.html)), client and server side +- HTTP over TLS ([RFC2818](https://www.rfc-editor.org/rfc/rfc2818.html)) + +Features: +- IPv6 and IPv4/IPv6 dual-stack support +- Keepalive with ping/pong + +## External resources +- Rust bindings for libdatachannel: [datachannel-rs](https://github.com/lerouxrgd/datachannel-rs) +- Node.js bindings for libdatachannel: [node-datachannel](https://github.com/murat-dogan/node-datachannel) +- Unity bindings for Windows 10 and Hololens: [datachannel-unity](https://github.com/hanseuljun/datachannel-unity) +- WebAssembly wrapper compatible with libdatachannel: [datachannel-wasm](https://github.com/paullouisageneau/datachannel-wasm) +- Lightweight STUN/TURN server: [Violet](https://github.com/paullouisageneau/violet) +- Native platform (Android/iOS/macOS) wrapper for libdatachannel: [datachannel-native](https://github.com/swarm-cloud/datachannel-native) + +## Thanks + +Thanks to [Streamr](https://streamr.network/), [Vagon](https://vagon.io/), [Shiguredo](https://github.com/shiguredo), [Deon Botha](https://github.com/dbotha), and [Michael Cho](https://github.com/micoolcho) for sponsoring this work! + diff --git a/datachannel/include/plog/.DS_Store b/datachannel/include/plog/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..f907fb4f2908f6869c4847cc37fa529e412d2d9b GIT binary patch literal 6148 zcmeHKK}+jE5T4a)Q$^@OK|C#ZE!e1t;w4sncs8O3m70+7UNB}$lD3CZ$Pdszi{5yX%~ z5e{cg+YQTrW#G3lKzFwRXVAkZhEMmm=!K~Y(fS_t(xfQc?GLfETw7VKJ9TH>dG!CO zp!?EY7t<62*Ar;V#J{HSDN<78TmpHwLHS zG`h{5(dfiIYDstBJs7v-$Za1Sw&d~A@p#;Dws!YUue;CbOQv28nF4c+i82Ez>Au877}aEl>~JMK$9=P_0WjXMZ#J_LVRaEBuF*YW*Q9|z$Xw4-IfGVqmw zWwUHj{lEOa|NrV_SC#?Gz(O$~Y8U=R2Sei8Iujh#T8a9EN +#include + +namespace plog +{ + template + class PLOG_LINKAGE_HIDDEN AndroidAppender : public IAppender + { + public: + AndroidAppender(const char* tag) : m_tag(tag) + { + } + + virtual void write(const Record& record) PLOG_OVERRIDE + { + std::string str = Formatter::format(record); + + __android_log_print(toPriority(record.getSeverity()), m_tag, "%s", str.c_str()); + } + + private: + static android_LogPriority toPriority(Severity severity) + { + switch (severity) + { + case fatal: + return ANDROID_LOG_FATAL; + case error: + return ANDROID_LOG_ERROR; + case warning: + return ANDROID_LOG_WARN; + case info: + return ANDROID_LOG_INFO; + case debug: + return ANDROID_LOG_DEBUG; + case verbose: + return ANDROID_LOG_VERBOSE; + default: + return ANDROID_LOG_UNKNOWN; + } + } + + private: + const char* const m_tag; + }; +} diff --git a/datachannel/include/plog/Appenders/ArduinoAppender.h b/datachannel/include/plog/Appenders/ArduinoAppender.h new file mode 100644 index 000000000..276af323f --- /dev/null +++ b/datachannel/include/plog/Appenders/ArduinoAppender.h @@ -0,0 +1,23 @@ +#pragma once +#include +#include + +namespace plog +{ + template + class PLOG_LINKAGE_HIDDEN ArduinoAppender : public IAppender + { + public: + ArduinoAppender(Stream &stream) : m_stream(stream) + { + } + + virtual void write(const Record &record) PLOG_OVERRIDE + { + m_stream.print(Formatter::format(record).c_str()); + } + + private: + Stream &m_stream; + }; +} diff --git a/datachannel/include/plog/Appenders/ColorConsoleAppender.h b/datachannel/include/plog/Appenders/ColorConsoleAppender.h new file mode 100644 index 000000000..24bbc7d90 --- /dev/null +++ b/datachannel/include/plog/Appenders/ColorConsoleAppender.h @@ -0,0 +1,108 @@ +#pragma once +#include +#include + +namespace plog +{ + template + class PLOG_LINKAGE_HIDDEN ColorConsoleAppender : public ConsoleAppender + { + public: +#ifdef _WIN32 +# ifdef _MSC_VER +# pragma warning(suppress: 26812) // Prefer 'enum class' over 'enum' +# endif + ColorConsoleAppender(OutputStream outStream = streamStdOut) + : ConsoleAppender(outStream) + , m_originalAttr() + { + if (this->m_isatty) + { + CONSOLE_SCREEN_BUFFER_INFO csbiInfo; + GetConsoleScreenBufferInfo(this->m_outputHandle, &csbiInfo); + + m_originalAttr = csbiInfo.wAttributes; + } + } +#else + ColorConsoleAppender(OutputStream outStream = streamStdOut) + : ConsoleAppender(outStream) + {} +#endif + + virtual void write(const Record& record) PLOG_OVERRIDE + { + util::nstring str = Formatter::format(record); + util::MutexLock lock(this->m_mutex); + + setColor(record.getSeverity()); + this->writestr(str); + resetColor(); + } + + protected: + void setColor(Severity severity) + { + if (this->m_isatty) + { + switch (severity) + { +#ifdef _WIN32 + case fatal: + SetConsoleTextAttribute(this->m_outputHandle, foreground::kRed | foreground::kGreen | foreground::kBlue | foreground::kIntensity | background::kRed); // white on red background + break; + + case error: + SetConsoleTextAttribute(this->m_outputHandle, static_cast(foreground::kRed | foreground::kIntensity | (m_originalAttr & 0xf0))); // red + break; + + case warning: + SetConsoleTextAttribute(this->m_outputHandle, static_cast(foreground::kRed | foreground::kGreen | foreground::kIntensity | (m_originalAttr & 0xf0))); // yellow + break; + + case debug: + case verbose: + SetConsoleTextAttribute(this->m_outputHandle, static_cast(foreground::kGreen | foreground::kBlue | foreground::kIntensity | (m_originalAttr & 0xf0))); // cyan + break; +#else + case fatal: + this->m_outputStream << "\x1B[97m\x1B[41m"; // white on red background + break; + + case error: + this->m_outputStream << "\x1B[91m"; // red + break; + + case warning: + this->m_outputStream << "\x1B[93m"; // yellow + break; + + case debug: + case verbose: + this->m_outputStream << "\x1B[96m"; // cyan + break; +#endif + default: + break; + } + } + } + + void resetColor() + { + if (this->m_isatty) + { +#ifdef _WIN32 + SetConsoleTextAttribute(this->m_outputHandle, m_originalAttr); +#else + this->m_outputStream << "\x1B[0m\x1B[0K"; +#endif + } + } + + private: +#ifdef _WIN32 + WORD m_originalAttr; +#endif + }; +} diff --git a/datachannel/include/plog/Appenders/ConsoleAppender.h b/datachannel/include/plog/Appenders/ConsoleAppender.h new file mode 100644 index 000000000..a8925a049 --- /dev/null +++ b/datachannel/include/plog/Appenders/ConsoleAppender.h @@ -0,0 +1,83 @@ +#pragma once +#include +#include +#include +#include + +namespace plog +{ + enum OutputStream + { + streamStdOut, + streamStdErr + }; + + template + class PLOG_LINKAGE_HIDDEN ConsoleAppender : public IAppender + { + public: +#ifdef _WIN32 +# ifdef _MSC_VER +# pragma warning(suppress: 26812) // Prefer 'enum class' over 'enum' +# endif + ConsoleAppender(OutputStream outStream = streamStdOut) + : m_isatty(!!_isatty(_fileno(outStream == streamStdOut ? stdout : stderr))) + , m_outputStream(outStream == streamStdOut ? std::cout : std::cerr) + , m_outputHandle() + { + if (m_isatty) + { + m_outputHandle = GetStdHandle(outStream == streamStdOut ? stdHandle::kOutput : stdHandle::kErrorOutput); + } + } +#else + ConsoleAppender(OutputStream outStream = streamStdOut) + : m_isatty(!!isatty(fileno(outStream == streamStdOut ? stdout : stderr))) + , m_outputStream(outStream == streamStdOut ? std::cout : std::cerr) + {} +#endif + + virtual void write(const Record& record) PLOG_OVERRIDE + { + util::nstring str = Formatter::format(record); + util::MutexLock lock(m_mutex); + + writestr(str); + } + + protected: + void writestr(const util::nstring& str) + { +#ifdef _WIN32 + if (m_isatty) + { + const std::wstring& wstr = util::toWide(str); + WriteConsoleW(m_outputHandle, wstr.c_str(), static_cast(wstr.size()), NULL, NULL); + } + else + { +# if PLOG_CHAR_IS_UTF8 + m_outputStream << str << std::flush; +# else + m_outputStream << util::toNarrow(str, codePage::kActive) << std::flush; +# endif + } +#else + m_outputStream << str << std::flush; +#endif + } + + private: +#ifdef __BORLANDC__ + static int _isatty(int fd) { return ::isatty(fd); } +#endif + + protected: + util::Mutex m_mutex; + const bool m_isatty; + std::ostream& m_outputStream; +#ifdef _WIN32 + HANDLE m_outputHandle; +#endif + }; +} diff --git a/datachannel/include/plog/Appenders/DebugOutputAppender.h b/datachannel/include/plog/Appenders/DebugOutputAppender.h new file mode 100644 index 000000000..5d7c95ef2 --- /dev/null +++ b/datachannel/include/plog/Appenders/DebugOutputAppender.h @@ -0,0 +1,16 @@ +#pragma once +#include +#include + +namespace plog +{ + template + class PLOG_LINKAGE_HIDDEN DebugOutputAppender : public IAppender + { + public: + virtual void write(const Record& record) PLOG_OVERRIDE + { + OutputDebugStringW(util::toWide(Formatter::format(record)).c_str()); + } + }; +} diff --git a/datachannel/include/plog/Appenders/DynamicAppender.h b/datachannel/include/plog/Appenders/DynamicAppender.h new file mode 100644 index 000000000..f078c08cc --- /dev/null +++ b/datachannel/include/plog/Appenders/DynamicAppender.h @@ -0,0 +1,42 @@ +#pragma once +#include +#include + +namespace plog +{ + class PLOG_LINKAGE_HIDDEN DynamicAppender : public IAppender + { + public: + DynamicAppender& addAppender(IAppender* appender) + { + assert(appender != this); + + util::MutexLock lock(m_mutex); + m_appenders.insert(appender); + + return *this; + } + + DynamicAppender& removeAppender(IAppender* appender) + { + util::MutexLock lock(m_mutex); + m_appenders.erase(appender); + + return *this; + } + + virtual void write(const Record& record) PLOG_OVERRIDE + { + util::MutexLock lock(m_mutex); + + for (std::set::iterator it = m_appenders.begin(); it != m_appenders.end(); ++it) + { + (*it)->write(record); + } + } + + private: + mutable util::Mutex m_mutex; + std::set m_appenders; + }; +} diff --git a/datachannel/include/plog/Appenders/EventLogAppender.h b/datachannel/include/plog/Appenders/EventLogAppender.h new file mode 100644 index 000000000..3b7065be1 --- /dev/null +++ b/datachannel/include/plog/Appenders/EventLogAppender.h @@ -0,0 +1,117 @@ +#pragma once +#include +#include + +namespace plog +{ + template + class PLOG_LINKAGE_HIDDEN EventLogAppender : public IAppender + { + public: + EventLogAppender(const util::nchar* sourceName) : m_eventSource(RegisterEventSourceW(NULL, util::toWide(sourceName).c_str())) + { + } + + ~EventLogAppender() + { + DeregisterEventSource(m_eventSource); + } + + virtual void write(const Record& record) PLOG_OVERRIDE + { + util::nstring str = Formatter::format(record); + + write(record.getSeverity(), util::toWide(str).c_str()); + } + + private: + void write(Severity severity, const wchar_t* str) + { + const wchar_t* logMessagePtr[] = { str }; + + ReportEventW(m_eventSource, logSeverityToType(severity), static_cast(severity), 0, NULL, 1, 0, logMessagePtr, NULL); + } + + static WORD logSeverityToType(plog::Severity severity) + { + switch (severity) + { + case plog::fatal: + case plog::error: + return eventLog::kErrorType; + + case plog::warning: + return eventLog::kWarningType; + + case plog::info: + case plog::debug: + case plog::verbose: + default: + return eventLog::kInformationType; + } + } + + private: + HANDLE m_eventSource; + }; + + class EventLogAppenderRegistry + { + public: + static bool add(const util::nchar* sourceName, const util::nchar* logName = PLOG_NSTR("Application")) + { + std::wstring logKeyName; + std::wstring sourceKeyName; + getKeyNames(sourceName, logName, sourceKeyName, logKeyName); + + HKEY sourceKey; + if (0 != RegCreateKeyExW(hkey::kLocalMachine, sourceKeyName.c_str(), 0, NULL, 0, regSam::kSetValue, NULL, &sourceKey, NULL)) + { + return false; + } + + const DWORD kTypesSupported = eventLog::kErrorType | eventLog::kWarningType | eventLog::kInformationType; + RegSetValueExW(sourceKey, L"TypesSupported", 0, regType::kDword, reinterpret_cast(&kTypesSupported), sizeof(kTypesSupported)); + + const wchar_t kEventMessageFile[] = L"%windir%\\Microsoft.NET\\Framework\\v4.0.30319\\EventLogMessages.dll;%windir%\\Microsoft.NET\\Framework\\v2.0.50727\\EventLogMessages.dll"; + RegSetValueExW(sourceKey, L"EventMessageFile", 0, regType::kExpandSz, reinterpret_cast(kEventMessageFile), sizeof(kEventMessageFile) - sizeof(*kEventMessageFile)); + + RegCloseKey(sourceKey); + return true; + } + + static bool exists(const util::nchar* sourceName, const util::nchar* logName = PLOG_NSTR("Application")) + { + std::wstring logKeyName; + std::wstring sourceKeyName; + getKeyNames(sourceName, logName, sourceKeyName, logKeyName); + + HKEY sourceKey; + if (0 != RegOpenKeyExW(hkey::kLocalMachine, sourceKeyName.c_str(), 0, regSam::kQueryValue, &sourceKey)) + { + return false; + } + + RegCloseKey(sourceKey); + return true; + } + + static void remove(const util::nchar* sourceName, const util::nchar* logName = PLOG_NSTR("Application")) + { + std::wstring logKeyName; + std::wstring sourceKeyName; + getKeyNames(sourceName, logName, sourceKeyName, logKeyName); + + RegDeleteKeyW(hkey::kLocalMachine, sourceKeyName.c_str()); + RegDeleteKeyW(hkey::kLocalMachine, logKeyName.c_str()); + } + + private: + static void getKeyNames(const util::nchar* sourceName, const util::nchar* logName, std::wstring& sourceKeyName, std::wstring& logKeyName) + { + const std::wstring kPrefix = L"SYSTEM\\CurrentControlSet\\Services\\EventLog\\"; + logKeyName = kPrefix + util::toWide(logName); + sourceKeyName = logKeyName + L"\\" + util::toWide(sourceName); + } + }; +} diff --git a/datachannel/include/plog/Appenders/IAppender.h b/datachannel/include/plog/Appenders/IAppender.h new file mode 100644 index 000000000..56b7827ae --- /dev/null +++ b/datachannel/include/plog/Appenders/IAppender.h @@ -0,0 +1,16 @@ +#pragma once +#include +#include + +namespace plog +{ + class PLOG_LINKAGE IAppender + { + public: + virtual ~IAppender() + { + } + + virtual void write(const Record& record) = 0; + }; +} diff --git a/datachannel/include/plog/Appenders/RollingFileAppender.h b/datachannel/include/plog/Appenders/RollingFileAppender.h new file mode 100644 index 000000000..3b667287a --- /dev/null +++ b/datachannel/include/plog/Appenders/RollingFileAppender.h @@ -0,0 +1,148 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace plog +{ + template > + class PLOG_LINKAGE_HIDDEN RollingFileAppender : public IAppender + { + public: + RollingFileAppender(const util::nchar* fileName, size_t maxFileSize = 0, int maxFiles = 0) + : m_fileSize() + , m_maxFileSize() + , m_maxFiles(maxFiles) + , m_firstWrite(true) + { + setFileName(fileName); + setMaxFileSize(maxFileSize); + } + +#if defined(_WIN32) && !PLOG_CHAR_IS_UTF8 + RollingFileAppender(const char* fileName, size_t maxFileSize = 0, int maxFiles = 0) + : m_fileSize() + , m_maxFileSize() + , m_maxFiles(maxFiles) + , m_firstWrite(true) + { + setFileName(fileName); + setMaxFileSize(maxFileSize); + } +#endif + + virtual void write(const Record& record) PLOG_OVERRIDE + { + util::MutexLock lock(m_mutex); + + if (m_firstWrite) + { + openLogFile(); + m_firstWrite = false; + } + else if (m_maxFiles > 0 && m_fileSize > m_maxFileSize && static_cast(-1) != m_fileSize) + { + rollLogFiles(); + } + + size_t bytesWritten = m_file.write(Converter::convert(Formatter::format(record))); + + if (static_cast(-1) != bytesWritten) + { + m_fileSize += bytesWritten; + } + } + + void setFileName(const util::nchar* fileName) + { + util::MutexLock lock(m_mutex); + + util::splitFileName(fileName, m_fileNameNoExt, m_fileExt); + + m_file.close(); + m_firstWrite = true; + } + +#if defined(_WIN32) && !PLOG_CHAR_IS_UTF8 + void setFileName(const char* fileName) + { + setFileName(util::toWide(fileName).c_str()); + } +#endif + + void setMaxFiles(int maxFiles) + { + m_maxFiles = maxFiles; + } + + void setMaxFileSize(size_t maxFileSize) + { + m_maxFileSize = (std::max)(maxFileSize, static_cast(1000)); // set a lower limit for the maxFileSize + } + + void rollLogFiles() + { + m_file.close(); + + util::nstring lastFileName = buildFileName(m_maxFiles - 1); + util::File::unlink(lastFileName); + + for (int fileNumber = m_maxFiles - 2; fileNumber >= 0; --fileNumber) + { + util::nstring currentFileName = buildFileName(fileNumber); + util::nstring nextFileName = buildFileName(fileNumber + 1); + + util::File::rename(currentFileName, nextFileName); + } + + openLogFile(); + m_firstWrite = false; + } + + private: + void openLogFile() + { + m_fileSize = m_file.open(buildFileName()); + + if (0 == m_fileSize) + { + size_t bytesWritten = m_file.write(Converter::header(Formatter::header())); + + if (static_cast(-1) != bytesWritten) + { + m_fileSize += bytesWritten; + } + } + } + + util::nstring buildFileName(int fileNumber = 0) + { + util::nostringstream ss; + ss << m_fileNameNoExt; + + if (fileNumber > 0) + { + ss << '.' << fileNumber; + } + + if (!m_fileExt.empty()) + { + ss << '.' << m_fileExt; + } + + return ss.str(); + } + + private: + util::Mutex m_mutex; + util::File m_file; + size_t m_fileSize; + size_t m_maxFileSize; + int m_maxFiles; + util::nstring m_fileExt; + util::nstring m_fileNameNoExt; + bool m_firstWrite; + }; +} diff --git a/datachannel/include/plog/Converters/NativeEOLConverter.h b/datachannel/include/plog/Converters/NativeEOLConverter.h new file mode 100644 index 000000000..a395e4e89 --- /dev/null +++ b/datachannel/include/plog/Converters/NativeEOLConverter.h @@ -0,0 +1,44 @@ +#pragma once +#include +#include + +namespace plog +{ + template + class NativeEOLConverter : public NextConverter + { +#ifdef _WIN32 + public: + static std::string header(const util::nstring& str) + { + return NextConverter::header(fixLineEndings(str)); + } + + static std::string convert(const util::nstring& str) + { + return NextConverter::convert(fixLineEndings(str)); + } + + private: + static util::nstring fixLineEndings(const util::nstring& str) + { + util::nstring output; + output.reserve(str.length() * 2); // the worst case requires 2x chars + + for (size_t i = 0; i < str.size(); ++i) + { + util::nchar ch = str[i]; + + if (ch == PLOG_NSTR('\n')) + { + output.push_back(PLOG_NSTR('\r')); + } + + output.push_back(ch); + } + + return output; + } +#endif + }; +} diff --git a/datachannel/include/plog/Converters/UTF8Converter.h b/datachannel/include/plog/Converters/UTF8Converter.h new file mode 100644 index 000000000..9de5a6303 --- /dev/null +++ b/datachannel/include/plog/Converters/UTF8Converter.h @@ -0,0 +1,28 @@ +#pragma once +#include + +namespace plog +{ + class UTF8Converter + { + public: + static std::string header(const util::nstring& str) + { + const char kBOM[] = "\xEF\xBB\xBF"; + + return std::string(kBOM) + convert(str); + } + +#if PLOG_CHAR_IS_UTF8 + static const std::string& convert(const util::nstring& str) + { + return str; + } +#else + static std::string convert(const util::nstring& str) + { + return util::toNarrow(str, codePage::kUTF8); + } +#endif + }; +} diff --git a/datachannel/include/plog/Formatters/CsvFormatter.h b/datachannel/include/plog/Formatters/CsvFormatter.h new file mode 100644 index 000000000..282c57f19 --- /dev/null +++ b/datachannel/include/plog/Formatters/CsvFormatter.h @@ -0,0 +1,57 @@ +#pragma once +#include +#include +#include + +namespace plog +{ + template + class CsvFormatterImpl + { + public: + static util::nstring header() + { + return PLOG_NSTR("Date;Time;Severity;TID;This;Function;Message\n"); + } + + static util::nstring format(const Record& record) + { + tm t; + useUtcTime ? util::gmtime_s(&t, &record.getTime().time) : util::localtime_s(&t, &record.getTime().time); + + util::nostringstream ss; + ss << t.tm_year + 1900 << PLOG_NSTR("/") << std::setfill(PLOG_NSTR('0')) << std::setw(2) << t.tm_mon + 1 << PLOG_NSTR("/") << std::setfill(PLOG_NSTR('0')) << std::setw(2) << t.tm_mday << PLOG_NSTR(";"); + ss << std::setfill(PLOG_NSTR('0')) << std::setw(2) << t.tm_hour << PLOG_NSTR(":") << std::setfill(PLOG_NSTR('0')) << std::setw(2) << t.tm_min << PLOG_NSTR(":") << std::setfill(PLOG_NSTR('0')) << std::setw(2) << t.tm_sec << PLOG_NSTR(".") << std::setfill(PLOG_NSTR('0')) << std::setw(3) << static_cast (record.getTime().millitm) << PLOG_NSTR(";"); + ss << severityToString(record.getSeverity()) << PLOG_NSTR(";"); + ss << record.getTid() << PLOG_NSTR(";"); + ss << record.getObject() << PLOG_NSTR(";"); + ss << record.getFunc() << PLOG_NSTR("@") << record.getLine() << PLOG_NSTR(";"); + + util::nstring message = record.getMessage(); + + if (message.size() > kMaxMessageSize) + { + message.resize(kMaxMessageSize); + message.append(PLOG_NSTR("...")); + } + + util::nistringstream split(message); + util::nstring token; + + while (!split.eof()) + { + std::getline(split, token, PLOG_NSTR('"')); + ss << PLOG_NSTR("\"") << token << PLOG_NSTR("\""); + } + + ss << PLOG_NSTR("\n"); + + return ss.str(); + } + + static const size_t kMaxMessageSize = 32000; + }; + + class CsvFormatter : public CsvFormatterImpl {}; + class CsvFormatterUtcTime : public CsvFormatterImpl {}; +} diff --git a/datachannel/include/plog/Formatters/FuncMessageFormatter.h b/datachannel/include/plog/Formatters/FuncMessageFormatter.h new file mode 100644 index 000000000..aa57806b8 --- /dev/null +++ b/datachannel/include/plog/Formatters/FuncMessageFormatter.h @@ -0,0 +1,23 @@ +#pragma once +#include +#include + +namespace plog +{ + class FuncMessageFormatter + { + public: + static util::nstring header() + { + return util::nstring(); + } + + static util::nstring format(const Record& record) + { + util::nostringstream ss; + ss << record.getFunc() << PLOG_NSTR("@") << record.getLine() << PLOG_NSTR(": ") << record.getMessage() << PLOG_NSTR("\n"); + + return ss.str(); + } + }; +} diff --git a/datachannel/include/plog/Formatters/MessageOnlyFormatter.h b/datachannel/include/plog/Formatters/MessageOnlyFormatter.h new file mode 100644 index 000000000..b2db5a5a0 --- /dev/null +++ b/datachannel/include/plog/Formatters/MessageOnlyFormatter.h @@ -0,0 +1,23 @@ +#pragma once +#include +#include + +namespace plog +{ + class MessageOnlyFormatter + { + public: + static util::nstring header() + { + return util::nstring(); + } + + static util::nstring format(const Record& record) + { + util::nostringstream ss; + ss << record.getMessage() << PLOG_NSTR("\n"); + + return ss.str(); + } + }; +} diff --git a/datachannel/include/plog/Formatters/TxtFormatter.h b/datachannel/include/plog/Formatters/TxtFormatter.h new file mode 100644 index 000000000..48e2d50b8 --- /dev/null +++ b/datachannel/include/plog/Formatters/TxtFormatter.h @@ -0,0 +1,36 @@ +#pragma once +#include +#include +#include + +namespace plog +{ + template + class TxtFormatterImpl + { + public: + static util::nstring header() + { + return util::nstring(); + } + + static util::nstring format(const Record& record) + { + tm t; + useUtcTime ? util::gmtime_s(&t, &record.getTime().time) : util::localtime_s(&t, &record.getTime().time); + + util::nostringstream ss; + ss << t.tm_year + 1900 << "-" << std::setfill(PLOG_NSTR('0')) << std::setw(2) << t.tm_mon + 1 << PLOG_NSTR("-") << std::setfill(PLOG_NSTR('0')) << std::setw(2) << t.tm_mday << PLOG_NSTR(" "); + ss << std::setfill(PLOG_NSTR('0')) << std::setw(2) << t.tm_hour << PLOG_NSTR(":") << std::setfill(PLOG_NSTR('0')) << std::setw(2) << t.tm_min << PLOG_NSTR(":") << std::setfill(PLOG_NSTR('0')) << std::setw(2) << t.tm_sec << PLOG_NSTR(".") << std::setfill(PLOG_NSTR('0')) << std::setw(3) << static_cast (record.getTime().millitm) << PLOG_NSTR(" "); + ss << std::setfill(PLOG_NSTR(' ')) << std::setw(5) << std::left << severityToString(record.getSeverity()) << PLOG_NSTR(" "); + ss << PLOG_NSTR("[") << record.getTid() << PLOG_NSTR("] "); + ss << PLOG_NSTR("[") << record.getFunc() << PLOG_NSTR("@") << record.getLine() << PLOG_NSTR("] "); + ss << record.getMessage() << PLOG_NSTR("\n"); + + return ss.str(); + } + }; + + class TxtFormatter : public TxtFormatterImpl {}; + class TxtFormatterUtcTime : public TxtFormatterImpl {}; +} diff --git a/datachannel/include/plog/Helpers/AscDump.h b/datachannel/include/plog/Helpers/AscDump.h new file mode 100644 index 000000000..18b9a6f4f --- /dev/null +++ b/datachannel/include/plog/Helpers/AscDump.h @@ -0,0 +1,40 @@ +#pragma once +#include +#include + +namespace plog +{ + class AscDump + { + public: + AscDump(const void* ptr, size_t size) + : m_ptr(static_cast(ptr)) + , m_size(size) + { + } + + friend util::nostringstream& operator<<(util::nostringstream& stream, const AscDump& ascDump); + + private: + const char* m_ptr; + size_t m_size; + }; + + inline util::nostringstream& operator<<(util::nostringstream& stream, const AscDump& ascDump) + { + for (size_t i = 0; i < ascDump.m_size; ++i) + { + stream << (std::isprint(ascDump.m_ptr[i]) ? ascDump.m_ptr[i] : '.'); + } + + return stream; + } + + inline AscDump ascdump(const void* ptr, size_t size) { return AscDump(ptr, size); } + + template + inline AscDump ascdump(const Container& container) { return AscDump(container.data(), container.size() * sizeof(*container.data())); } + + template + inline AscDump ascdump(const T (&arr)[N]) { return AscDump(arr, N * sizeof(*arr)); } +} diff --git a/datachannel/include/plog/Helpers/HexDump.h b/datachannel/include/plog/Helpers/HexDump.h new file mode 100644 index 000000000..b0493d707 --- /dev/null +++ b/datachannel/include/plog/Helpers/HexDump.h @@ -0,0 +1,79 @@ +#pragma once +#include +#include + +namespace plog +{ + class HexDump + { + public: + HexDump(const void* ptr, size_t size) + : m_ptr(static_cast(ptr)) + , m_size(size) + , m_group(8) + , m_digitSeparator(" ") + , m_groupSeparator(" ") + { + } + + HexDump& group(size_t group) + { + m_group = group; + return *this; + } + + HexDump& separator(const char* digitSeparator) + { + m_digitSeparator = digitSeparator; + return *this; + } + + HexDump& separator(const char* digitSeparator, const char* groupSeparator) + { + m_digitSeparator = digitSeparator; + m_groupSeparator = groupSeparator; + return *this; + } + + friend util::nostringstream& operator<<(util::nostringstream& stream, const HexDump&); + + private: + const unsigned char* m_ptr; + size_t m_size; + size_t m_group; + const char* m_digitSeparator; + const char* m_groupSeparator; + }; + + inline util::nostringstream& operator<<(util::nostringstream& stream, const HexDump& hexDump) + { + stream << std::hex << std::setfill(PLOG_NSTR('0')); + + for (size_t i = 0; i < hexDump.m_size;) + { + stream << std::setw(2) << static_cast(hexDump.m_ptr[i]); + + if (++i < hexDump.m_size) + { + if (hexDump.m_group > 0 && i % hexDump.m_group == 0) + { + stream << hexDump.m_groupSeparator; + } + else + { + stream << hexDump.m_digitSeparator; + } + } + } + + return stream; + } + + inline HexDump hexdump(const void* ptr, size_t size) { return HexDump(ptr, size); } + + template + inline HexDump hexdump(const Container& container) { return HexDump(container.data(), container.size() * sizeof(*container.data())); } + + template + inline HexDump hexdump(const T (&arr)[N]) { return HexDump(arr, N * sizeof(*arr)); } +} diff --git a/datachannel/include/plog/Helpers/PrintVar.h b/datachannel/include/plog/Helpers/PrintVar.h new file mode 100644 index 000000000..465e1938f --- /dev/null +++ b/datachannel/include/plog/Helpers/PrintVar.h @@ -0,0 +1,24 @@ +#pragma once + +#define PLOG_IMPL_PRINT_VAR_1(a1) #a1 ": " << a1 +#define PLOG_IMPL_PRINT_VAR_2(a1, a2) PLOG_IMPL_PRINT_VAR_1(a1) PLOG_IMPL_PRINT_VAR_TAIL(a2) +#define PLOG_IMPL_PRINT_VAR_3(a1, a2, a3) PLOG_IMPL_PRINT_VAR_2(a1, a2) PLOG_IMPL_PRINT_VAR_TAIL(a3) +#define PLOG_IMPL_PRINT_VAR_4(a1, a2, a3, a4) PLOG_IMPL_PRINT_VAR_3(a1, a2, a3) PLOG_IMPL_PRINT_VAR_TAIL(a4) +#define PLOG_IMPL_PRINT_VAR_5(a1, a2, a3, a4, a5) PLOG_IMPL_PRINT_VAR_4(a1, a2, a3, a4) PLOG_IMPL_PRINT_VAR_TAIL(a5) +#define PLOG_IMPL_PRINT_VAR_6(a1, a2, a3, a4, a5, a6) PLOG_IMPL_PRINT_VAR_5(a1, a2, a3, a4, a5) PLOG_IMPL_PRINT_VAR_TAIL(a6) +#define PLOG_IMPL_PRINT_VAR_7(a1, a2, a3, a4, a5, a6, a7) PLOG_IMPL_PRINT_VAR_6(a1, a2, a3, a4, a5, a6) PLOG_IMPL_PRINT_VAR_TAIL(a7) +#define PLOG_IMPL_PRINT_VAR_8(a1, a2, a3, a4, a5, a6, a7, a8) PLOG_IMPL_PRINT_VAR_7(a1, a2, a3, a4, a5, a6, a7) PLOG_IMPL_PRINT_VAR_TAIL(a8) +#define PLOG_IMPL_PRINT_VAR_9(a1, a2, a3, a4, a5, a6, a7, a8, a9) PLOG_IMPL_PRINT_VAR_8(a1, a2, a3, a4, a5, a6, a7, a8) PLOG_IMPL_PRINT_VAR_TAIL(a9) +#define PLOG_IMPL_PRINT_VAR_TAIL(a) << ", " PLOG_IMPL_PRINT_VAR_1(a) + +#define PLOG_IMPL_PRINT_VAR_EXPAND(x) x + +#ifdef __GNUC__ +#pragma GCC system_header // disable warning: variadic macros are a C99 feature +#endif + +#define PLOG_IMPL_PRINT_VAR_GET_MACRO(x1, x2, x3, x4, x5, x6, x7, x8, x9, NAME, ...) NAME + +#define PLOG_PRINT_VAR(...) PLOG_IMPL_PRINT_VAR_EXPAND(PLOG_IMPL_PRINT_VAR_GET_MACRO(__VA_ARGS__,\ + PLOG_IMPL_PRINT_VAR_9, PLOG_IMPL_PRINT_VAR_8, PLOG_IMPL_PRINT_VAR_7, PLOG_IMPL_PRINT_VAR_6, PLOG_IMPL_PRINT_VAR_5,\ + PLOG_IMPL_PRINT_VAR_4, PLOG_IMPL_PRINT_VAR_3, PLOG_IMPL_PRINT_VAR_2, PLOG_IMPL_PRINT_VAR_1, UNUSED)(__VA_ARGS__)) diff --git a/datachannel/include/plog/Init.h b/datachannel/include/plog/Init.h new file mode 100644 index 000000000..615c41d6c --- /dev/null +++ b/datachannel/include/plog/Init.h @@ -0,0 +1,17 @@ +#pragma once +#include + +namespace plog +{ + template + PLOG_LINKAGE_HIDDEN inline Logger& init(Severity maxSeverity = none, IAppender* appender = NULL) + { + static Logger logger(maxSeverity); + return appender ? logger.addAppender(appender) : logger; + } + + inline Logger& init(Severity maxSeverity = none, IAppender* appender = NULL) + { + return init(maxSeverity, appender); + } +} diff --git a/datachannel/include/plog/Initializers/ConsoleInitializer.h b/datachannel/include/plog/Initializers/ConsoleInitializer.h new file mode 100644 index 000000000..5b504abd5 --- /dev/null +++ b/datachannel/include/plog/Initializers/ConsoleInitializer.h @@ -0,0 +1,22 @@ +#pragma once +#include +#include + +namespace plog +{ + ////////////////////////////////////////////////////////////////////////// + // ColorConsoleAppender with any Formatter + + template + PLOG_LINKAGE_HIDDEN inline Logger& init(Severity maxSeverity, OutputStream outputStream) + { + static ColorConsoleAppender appender(outputStream); + return init(maxSeverity, &appender); + } + + template + inline Logger& init(Severity maxSeverity, OutputStream outputStream) + { + return init(maxSeverity, outputStream); + } +} diff --git a/datachannel/include/plog/Initializers/RollingFileInitializer.h b/datachannel/include/plog/Initializers/RollingFileInitializer.h new file mode 100644 index 000000000..dc0adda14 --- /dev/null +++ b/datachannel/include/plog/Initializers/RollingFileInitializer.h @@ -0,0 +1,80 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace plog +{ + ////////////////////////////////////////////////////////////////////////// + // RollingFileAppender with any Formatter + + template + PLOG_LINKAGE_HIDDEN inline Logger& init(Severity maxSeverity, const util::nchar* fileName, size_t maxFileSize = 0, int maxFiles = 0) + { + static RollingFileAppender rollingFileAppender(fileName, maxFileSize, maxFiles); + return init(maxSeverity, &rollingFileAppender); + } + + template + inline Logger& init(Severity maxSeverity, const util::nchar* fileName, size_t maxFileSize = 0, int maxFiles = 0) + { + return init(maxSeverity, fileName, maxFileSize, maxFiles); + } + + ////////////////////////////////////////////////////////////////////////// + // RollingFileAppender with TXT/CSV chosen by file extension + + namespace + { + inline bool isCsv(const util::nchar* fileName) + { + const util::nchar* dot = util::findExtensionDot(fileName); +#if PLOG_CHAR_IS_UTF8 + return dot && 0 == std::strcmp(dot, ".csv"); +#else + return dot && 0 == std::wcscmp(dot, L".csv"); +#endif + } + } + + template + inline Logger& init(Severity maxSeverity, const util::nchar* fileName, size_t maxFileSize = 0, int maxFiles = 0) + { + return isCsv(fileName) ? init(maxSeverity, fileName, maxFileSize, maxFiles) : init(maxSeverity, fileName, maxFileSize, maxFiles); + } + + inline Logger& init(Severity maxSeverity, const util::nchar* fileName, size_t maxFileSize = 0, int maxFiles = 0) + { + return init(maxSeverity, fileName, maxFileSize, maxFiles); + } + + ////////////////////////////////////////////////////////////////////////// + // CHAR variants for Windows + +#if defined(_WIN32) && !PLOG_CHAR_IS_UTF8 + template + inline Logger& init(Severity maxSeverity, const char* fileName, size_t maxFileSize = 0, int maxFiles = 0) + { + return init(maxSeverity, util::toWide(fileName).c_str(), maxFileSize, maxFiles); + } + + template + inline Logger& init(Severity maxSeverity, const char* fileName, size_t maxFileSize = 0, int maxFiles = 0) + { + return init(maxSeverity, fileName, maxFileSize, maxFiles); + } + + template + inline Logger& init(Severity maxSeverity, const char* fileName, size_t maxFileSize = 0, int maxFiles = 0) + { + return init(maxSeverity, util::toWide(fileName).c_str(), maxFileSize, maxFiles); + } + + inline Logger& init(Severity maxSeverity, const char* fileName, size_t maxFileSize = 0, int maxFiles = 0) + { + return init(maxSeverity, fileName, maxFileSize, maxFiles); + } +#endif +} diff --git a/datachannel/include/plog/Log.h b/datachannel/include/plog/Log.h new file mode 100644 index 000000000..138d92f5f --- /dev/null +++ b/datachannel/include/plog/Log.h @@ -0,0 +1,202 @@ +////////////////////////////////////////////////////////////////////////// +// Plog - portable and simple log for C++ +// Documentation and sources: https://github.com/SergiusTheBest/plog +// License: MIT, https://choosealicense.com/licenses/mit + +#pragma once +#include + +////////////////////////////////////////////////////////////////////////// +// Helper macros that get context info + +#if defined(PLOG_ENABLE_GET_THIS) && defined(_MSC_VER) && _MSC_VER >= 1600 && !defined(__INTELLISENSE__) && !defined(__INTEL_COMPILER) && !defined(__llvm__) && !defined(__RESHARPER__) // >= Visual Studio 2010, skip IntelliSense, Intel Compiler, Clang Code Model and ReSharper +# define PLOG_GET_THIS() __if_exists(this) { this } __if_not_exists(this) { 0 } +#else +# define PLOG_GET_THIS() reinterpret_cast(0) +#endif + +#ifdef _MSC_VER +# define PLOG_GET_FUNC() __FUNCTION__ +#elif defined(__BORLANDC__) +# define PLOG_GET_FUNC() __FUNC__ +#else +# define PLOG_GET_FUNC() __PRETTY_FUNCTION__ +#endif + +#ifdef PLOG_CAPTURE_FILE +# define PLOG_GET_FILE() __FILE__ +#else +# define PLOG_GET_FILE() "" +#endif + +////////////////////////////////////////////////////////////////////////// +// Log severity level checker + +#ifdef PLOG_DISABLE_LOGGING +# ifdef _MSC_VER +# define IF_PLOG_(instanceId, severity) __pragma(warning(push)) __pragma(warning(disable:4127)) if (true) {;} else __pragma(warning(pop)) // conditional expression is constant +# else +# define IF_PLOG_(instanceId, severity) if (true) {;} else +# endif +#else +# define IF_PLOG_(instanceId, severity) if (!plog::get() || !plog::get()->checkSeverity(severity)) {;} else +#endif + +#define IF_PLOG(severity) IF_PLOG_(PLOG_DEFAULT_INSTANCE_ID, severity) + +////////////////////////////////////////////////////////////////////////// +// Main logging macros + +#define PLOG_(instanceId, severity) IF_PLOG_(instanceId, severity) (*plog::get()) += plog::Record(severity, PLOG_GET_FUNC(), __LINE__, PLOG_GET_FILE(), PLOG_GET_THIS(), instanceId).ref() +#define PLOG(severity) PLOG_(PLOG_DEFAULT_INSTANCE_ID, severity) + +#define PLOG_VERBOSE PLOG(plog::verbose) +#define PLOG_DEBUG PLOG(plog::debug) +#define PLOG_INFO PLOG(plog::info) +#define PLOG_WARNING PLOG(plog::warning) +#define PLOG_ERROR PLOG(plog::error) +#define PLOG_FATAL PLOG(plog::fatal) +#define PLOG_NONE PLOG(plog::none) + +#define PLOG_VERBOSE_(instanceId) PLOG_(instanceId, plog::verbose) +#define PLOG_DEBUG_(instanceId) PLOG_(instanceId, plog::debug) +#define PLOG_INFO_(instanceId) PLOG_(instanceId, plog::info) +#define PLOG_WARNING_(instanceId) PLOG_(instanceId, plog::warning) +#define PLOG_ERROR_(instanceId) PLOG_(instanceId, plog::error) +#define PLOG_FATAL_(instanceId) PLOG_(instanceId, plog::fatal) +#define PLOG_NONE_(instanceId) PLOG_(instanceId, plog::none) + +#define PLOGV PLOG_VERBOSE +#define PLOGD PLOG_DEBUG +#define PLOGI PLOG_INFO +#define PLOGW PLOG_WARNING +#define PLOGE PLOG_ERROR +#define PLOGF PLOG_FATAL +#define PLOGN PLOG_NONE + +#define PLOGV_(instanceId) PLOG_VERBOSE_(instanceId) +#define PLOGD_(instanceId) PLOG_DEBUG_(instanceId) +#define PLOGI_(instanceId) PLOG_INFO_(instanceId) +#define PLOGW_(instanceId) PLOG_WARNING_(instanceId) +#define PLOGE_(instanceId) PLOG_ERROR_(instanceId) +#define PLOGF_(instanceId) PLOG_FATAL_(instanceId) +#define PLOGN_(instanceId) PLOG_NONE_(instanceId) + +////////////////////////////////////////////////////////////////////////// +// Conditional logging macros + +#define PLOG_IF_(instanceId, severity, condition) if (!(condition)) {;} else PLOG_(instanceId, severity) +#define PLOG_IF(severity, condition) PLOG_IF_(PLOG_DEFAULT_INSTANCE_ID, severity, condition) + +#define PLOG_VERBOSE_IF(condition) PLOG_IF(plog::verbose, condition) +#define PLOG_DEBUG_IF(condition) PLOG_IF(plog::debug, condition) +#define PLOG_INFO_IF(condition) PLOG_IF(plog::info, condition) +#define PLOG_WARNING_IF(condition) PLOG_IF(plog::warning, condition) +#define PLOG_ERROR_IF(condition) PLOG_IF(plog::error, condition) +#define PLOG_FATAL_IF(condition) PLOG_IF(plog::fatal, condition) +#define PLOG_NONE_IF(condition) PLOG_IF(plog::none, condition) + +#define PLOG_VERBOSE_IF_(instanceId, condition) PLOG_IF_(instanceId, plog::verbose, condition) +#define PLOG_DEBUG_IF_(instanceId, condition) PLOG_IF_(instanceId, plog::debug, condition) +#define PLOG_INFO_IF_(instanceId, condition) PLOG_IF_(instanceId, plog::info, condition) +#define PLOG_WARNING_IF_(instanceId, condition) PLOG_IF_(instanceId, plog::warning, condition) +#define PLOG_ERROR_IF_(instanceId, condition) PLOG_IF_(instanceId, plog::error, condition) +#define PLOG_FATAL_IF_(instanceId, condition) PLOG_IF_(instanceId, plog::fatal, condition) +#define PLOG_NONE_IF_(instanceId, condition) PLOG_IF_(instanceId, plog::none, condition) + +#define PLOGV_IF(condition) PLOG_VERBOSE_IF(condition) +#define PLOGD_IF(condition) PLOG_DEBUG_IF(condition) +#define PLOGI_IF(condition) PLOG_INFO_IF(condition) +#define PLOGW_IF(condition) PLOG_WARNING_IF(condition) +#define PLOGE_IF(condition) PLOG_ERROR_IF(condition) +#define PLOGF_IF(condition) PLOG_FATAL_IF(condition) +#define PLOGN_IF(condition) PLOG_NONE_IF(condition) + +#define PLOGV_IF_(instanceId, condition) PLOG_VERBOSE_IF_(instanceId, condition) +#define PLOGD_IF_(instanceId, condition) PLOG_DEBUG_IF_(instanceId, condition) +#define PLOGI_IF_(instanceId, condition) PLOG_INFO_IF_(instanceId, condition) +#define PLOGW_IF_(instanceId, condition) PLOG_WARNING_IF_(instanceId, condition) +#define PLOGE_IF_(instanceId, condition) PLOG_ERROR_IF_(instanceId, condition) +#define PLOGF_IF_(instanceId, condition) PLOG_FATAL_IF_(instanceId, condition) +#define PLOGN_IF_(instanceId, condition) PLOG_NONE_IF_(instanceId, condition) + +// Old macro names for downward compatibility. To bypass including these macro names, add +// #define PLOG_OMIT_LOG_DEFINES before #include +#ifndef PLOG_OMIT_LOG_DEFINES + +////////////////////////////////////////////////////////////////////////// +// Main logging macros - can be changed later to point at macros for a different logging package + +#define LOG_(instanceId, severity) IF_PLOG_(instanceId, severity) (*plog::get()) += plog::Record(severity, PLOG_GET_FUNC(), __LINE__, PLOG_GET_FILE(), PLOG_GET_THIS(), instanceId).ref() +#define LOG(severity) PLOG_(PLOG_DEFAULT_INSTANCE_ID, severity) + +#define LOG_VERBOSE PLOG(plog::verbose) +#define LOG_DEBUG PLOG(plog::debug) +#define LOG_INFO PLOG(plog::info) +#define LOG_WARNING PLOG(plog::warning) +#define LOG_ERROR PLOG(plog::error) +#define LOG_FATAL PLOG(plog::fatal) +#define LOG_NONE PLOG(plog::none) + +#define LOG_VERBOSE_(instanceId) PLOG_(instanceId, plog::verbose) +#define LOG_DEBUG_(instanceId) PLOG_(instanceId, plog::debug) +#define LOG_INFO_(instanceId) PLOG_(instanceId, plog::info) +#define LOG_WARNING_(instanceId) PLOG_(instanceId, plog::warning) +#define LOG_ERROR_(instanceId) PLOG_(instanceId, plog::error) +#define LOG_FATAL_(instanceId) PLOG_(instanceId, plog::fatal) +#define LOG_NONE_(instanceId) PLOG_(instanceId, plog::none) + +#define LOGV PLOG_VERBOSE +#define LOGD PLOG_DEBUG +#define LOGI PLOG_INFO +#define LOGW PLOG_WARNING +#define LOGE PLOG_ERROR +#define LOGF PLOG_FATAL +#define LOGN PLOG_NONE + +#define LOGV_(instanceId) PLOG_VERBOSE_(instanceId) +#define LOGD_(instanceId) PLOG_DEBUG_(instanceId) +#define LOGI_(instanceId) PLOG_INFO_(instanceId) +#define LOGW_(instanceId) PLOG_WARNING_(instanceId) +#define LOGE_(instanceId) PLOG_ERROR_(instanceId) +#define LOGF_(instanceId) PLOG_FATAL_(instanceId) +#define LOGN_(instanceId) PLOG_NONE_(instanceId) + +////////////////////////////////////////////////////////////////////////// +// Conditional logging macros + +#define LOG_IF_(instanceId, severity, condition) if (!(condition)) {;} else PLOG_(instanceId, severity) +#define LOG_IF(severity, condition) PLOG_IF_(PLOG_DEFAULT_INSTANCE_ID, severity, condition) + +#define LOG_VERBOSE_IF(condition) PLOG_IF(plog::verbose, condition) +#define LOG_DEBUG_IF(condition) PLOG_IF(plog::debug, condition) +#define LOG_INFO_IF(condition) PLOG_IF(plog::info, condition) +#define LOG_WARNING_IF(condition) PLOG_IF(plog::warning, condition) +#define LOG_ERROR_IF(condition) PLOG_IF(plog::error, condition) +#define LOG_FATAL_IF(condition) PLOG_IF(plog::fatal, condition) +#define LOG_NONE_IF(condition) PLOG_IF(plog::none, condition) + +#define LOG_VERBOSE_IF_(instanceId, condition) PLOG_IF_(instanceId, plog::verbose, condition) +#define LOG_DEBUG_IF_(instanceId, condition) PLOG_IF_(instanceId, plog::debug, condition) +#define LOG_INFO_IF_(instanceId, condition) PLOG_IF_(instanceId, plog::info, condition) +#define LOG_WARNING_IF_(instanceId, condition) PLOG_IF_(instanceId, plog::warning, condition) +#define LOG_ERROR_IF_(instanceId, condition) PLOG_IF_(instanceId, plog::error, condition) +#define LOG_FATAL_IF_(instanceId, condition) PLOG_IF_(instanceId, plog::fatal, condition) +#define LOG_NONE_IF_(instanceId, condition) PLOG_IF_(instanceId, plog::none, condition) + +#define LOGV_IF(condition) PLOG_VERBOSE_IF(condition) +#define LOGD_IF(condition) PLOG_DEBUG_IF(condition) +#define LOGI_IF(condition) PLOG_INFO_IF(condition) +#define LOGW_IF(condition) PLOG_WARNING_IF(condition) +#define LOGE_IF(condition) PLOG_ERROR_IF(condition) +#define LOGF_IF(condition) PLOG_FATAL_IF(condition) +#define LOGN_IF(condition) PLOG_NONE_IF(condition) + +#define LOGV_IF_(instanceId, condition) PLOG_VERBOSE_IF_(instanceId, condition) +#define LOGD_IF_(instanceId, condition) PLOG_DEBUG_IF_(instanceId, condition) +#define LOGI_IF_(instanceId, condition) PLOG_INFO_IF_(instanceId, condition) +#define LOGW_IF_(instanceId, condition) PLOG_WARNING_IF_(instanceId, condition) +#define LOGE_IF_(instanceId, condition) PLOG_ERROR_IF_(instanceId, condition) +#define LOGF_IF_(instanceId, condition) PLOG_FATAL_IF_(instanceId, condition) +#define LOGN_IF_(instanceId, condition) PLOG_NONE_IF_(instanceId, condition) +#endif diff --git a/datachannel/include/plog/Logger.h b/datachannel/include/plog/Logger.h new file mode 100644 index 000000000..0e116e4c6 --- /dev/null +++ b/datachannel/include/plog/Logger.h @@ -0,0 +1,84 @@ +#pragma once +#include +#include +#include + +#ifdef PLOG_DEFAULT_INSTANCE // for backward compatibility +# define PLOG_DEFAULT_INSTANCE_ID PLOG_DEFAULT_INSTANCE +#endif + +#ifndef PLOG_DEFAULT_INSTANCE_ID +# define PLOG_DEFAULT_INSTANCE_ID 0 +#endif + +namespace plog +{ + template + class PLOG_LINKAGE Logger : public util::Singleton >, public IAppender + { + public: + Logger(Severity maxSeverity = none) : m_maxSeverity(maxSeverity) + { + } + + Logger& addAppender(IAppender* appender) + { + assert(appender != this); + m_appenders.push_back(appender); + return *this; + } + + Severity getMaxSeverity() const + { + return m_maxSeverity; + } + + void setMaxSeverity(Severity severity) + { + m_maxSeverity = severity; + } + + bool checkSeverity(Severity severity) const + { + return severity <= m_maxSeverity; + } + + virtual void write(const Record& record) PLOG_OVERRIDE + { + if (checkSeverity(record.getSeverity())) + { + *this += record; + } + } + + void operator+=(const Record& record) + { + for (std::vector::iterator it = m_appenders.begin(); it != m_appenders.end(); ++it) + { + (*it)->write(record); + } + } + + private: + Severity m_maxSeverity; +#ifdef _MSC_VER +# pragma warning(push) +# pragma warning(disable:4251) // needs to have dll-interface to be used by clients of class +#endif + std::vector m_appenders; +#ifdef _MSC_VER +# pragma warning(pop) +#endif + }; + + template + inline Logger* get() + { + return Logger::getInstance(); + } + + inline Logger* get() + { + return Logger::getInstance(); + } +} diff --git a/datachannel/include/plog/Record.h b/datachannel/include/plog/Record.h new file mode 100644 index 000000000..c0c4e8720 --- /dev/null +++ b/datachannel/include/plog/Record.h @@ -0,0 +1,435 @@ +#pragma once +#include +#include +#include + +#ifdef __cplusplus_cli +#include // For PtrToStringChars +#endif + +namespace plog +{ + namespace detail + { +#if !defined(_MSC_VER) || _MSC_VER > 1400 // MSVC 2005 doesn't understand `enableIf`, so drop all `meta` + namespace meta + { + template + inline T& declval() + { +#ifdef __INTEL_COMPILER +# pragma warning(suppress: 327) // NULL reference is not allowed +#endif + return *reinterpret_cast(0); + } + + template + struct enableIf {}; + + template + struct enableIf { typedef T type; }; + + struct No { char a[1]; }; + struct Yes { char a[2]; }; + + template + struct isConvertible + { + // `+ sizeof(U*)` is required for GCC 4.5-4.7 + template + static typename enableIf(meta::declval())) + sizeof(U*)), Yes>::type test(int); + + template + static No test(...); + + enum { value = sizeof(test(0)) == sizeof(Yes) }; + }; + + template + struct isConvertibleToNString : isConvertible {}; + + template + struct isConvertibleToString : isConvertible {}; + + template + struct isContainer + { + template + static typename meta::enableIf().begin()) + sizeof(meta::declval().end() +#else + typename U::const_iterator +#endif + )), Yes>::type test(int); + + template + static No test(...); + + enum { value = sizeof(test(0)) == sizeof(Yes) }; + }; + + // Detects `std::filesystem::path` and `boost::filesystem::path`. They look like containers + // but we don't want to treat them as containers, so we use this detector to filter them out. + template + struct isFilesystemPath + { + template + static typename meta::enableIf().preferred_separator)), Yes>::type test(int); + + template + static No test(...); + + enum { value = sizeof(test(0)) == sizeof(Yes) }; + }; + } +#endif + + ////////////////////////////////////////////////////////////////////////// + // Stream output operators as free functions + +#if PLOG_ENABLE_WCHAR_INPUT + inline void operator<<(util::nostringstream& stream, const wchar_t* data) + { + data = data ? data : L"(null)"; + +# ifdef _WIN32 +# if PLOG_CHAR_IS_UTF8 + std::operator<<(stream, util::toNarrow(data, codePage::kUTF8)); +# else + std::operator<<(stream, data); +# endif +# else + std::operator<<(stream, util::toNarrow(data)); +# endif + } + + inline void operator<<(util::nostringstream& stream, wchar_t* data) + { + plog::detail::operator<<(stream, const_cast(data)); + } + + inline void operator<<(util::nostringstream& stream, const std::wstring& data) + { + plog::detail::operator<<(stream, data.c_str()); + } +#endif + + inline void operator<<(util::nostringstream& stream, const char* data) + { + data = data ? data : "(null)"; + +#if defined(_WIN32) && defined(__BORLANDC__) +# if PLOG_CHAR_IS_UTF8 + stream << data; +# else + stream << util::toWide(data); +# endif +#elif defined(_WIN32) +# if PLOG_CHAR_IS_UTF8 + std::operator<<(stream, data); +# else + std::operator<<(stream, util::toWide(data)); +# endif +#else + std::operator<<(stream, data); +#endif + } + + inline void operator<<(util::nostringstream& stream, char* data) + { + plog::detail::operator<<(stream, const_cast(data)); + } + + inline void operator<<(util::nostringstream& stream, const std::string& data) + { + plog::detail::operator<<(stream, data.c_str()); + } + +#ifdef __cpp_char8_t + inline void operator<<(util::nostringstream& stream, const char8_t* data) + { +# if PLOG_CHAR_IS_UTF8 + plog::detail::operator<<(stream, reinterpret_cast(data)); +# else + plog::detail::operator<<(stream, util::toWide(reinterpret_cast(data), codePage::kUTF8)); +# endif + } +#endif //__cpp_char8_t + + // Print `std::pair` + template + inline void operator<<(util::nostringstream& stream, const std::pair& data) + { + stream << data.first; + stream << ":"; + stream << data.second; + } + +#if defined(__clang__) || !defined(__GNUC__) || (__GNUC__ * 100 + __GNUC_MINOR__) >= 405 // skip for GCC < 4.5 due to https://gcc.gnu.org/bugzilla/show_bug.cgi?id=38600 +#if !defined(_MSC_VER) || _MSC_VER > 1400 // MSVC 2005 doesn't understand `enableIf`, so drop all `meta` + // Print data that can be casted to `std::basic_string`. + template + inline typename meta::enableIf::value, void>::type operator<<(util::nostringstream& stream, const T& data) + { + plog::detail::operator<<(stream, static_cast(data)); + } + + // Print std containers + template + inline typename meta::enableIf::value && + !meta::isConvertibleToNString::value && + !meta::isConvertibleToString::value && + !meta::isFilesystemPath::value, void>::type operator<<(util::nostringstream& stream, const T& data) + { + stream << "["; + + for (typename T::const_iterator it = data.begin(); it != data.end();) + { + stream << *it; + + if (++it == data.end()) + { + break; + } + + stream << ", "; + } + + stream << "]"; + } +#endif +#endif + +#ifdef __cplusplus_cli + inline void operator<<(util::nostringstream& stream, System::String^ data) + { + cli::pin_ptr ptr = PtrToStringChars(data); + plog::detail::operator<<(stream, static_cast(ptr)); + } +#endif + +#if defined(_WIN32) && (!defined(_MSC_VER) || _MSC_VER > 1400) // MSVC 2005 doesn't understand `enableIf`, so drop all `meta` + namespace meta + { + template + struct valueType { enum { value = Value }; }; + + template + inline No operator<<(Stream&, const T&); + + template + struct isStreamable : valueType(), meta::declval())) != sizeof(No)> {}; + + template + struct isStreamable : valueType {}; + + template + struct isStreamable : valueType {}; + + template + struct isStreamable : valueType {}; + + // meta doesn't work well for deleted functions and C++20 has `operator<<(std::ostream&, const wchar_t*) = delete` so exlicitly define it + template<> + struct isStreamable : valueType {}; + +# ifdef __cpp_char8_t + // meta doesn't work well for deleted functions and C++20 has `operator<<(std::ostream&, const char8_t*) = delete` so exlicitly define it + template + struct isStreamable : valueType {}; + + template + struct isStreamable : valueType {}; +# endif //__cpp_char8_t + } + + template + inline typename meta::enableIf::value && !meta::isStreamable::value, void>::type operator<<(std::wostringstream& stream, const T& data) + { + std::ostringstream ss; + ss << data; + stream << ss.str(); + } +#endif + } + + class Record + { + public: + Record(Severity severity, const char* func, size_t line, const char* file, const void* object, int instanceId) + : m_severity(severity), m_tid(util::gettid()), m_object(object), m_line(line), m_func(func), m_file(file), m_instanceId(instanceId) + { + util::ftime(&m_time); + } + + Record& ref() + { + return *this; + } + + ////////////////////////////////////////////////////////////////////////// + // Stream output operators + + Record& operator<<(char data) + { + char str[] = { data, 0 }; + return *this << str; + } + +#if PLOG_ENABLE_WCHAR_INPUT + Record& operator<<(wchar_t data) + { + wchar_t str[] = { data, 0 }; + return *this << str; + } +#endif + + Record& operator<<(util::nostream& (PLOG_CDECL *data)(util::nostream&)) + { + m_message << data; + return *this; + } + +#ifdef QT_VERSION + Record& operator<<(const QString& data) + { +# if PLOG_CHAR_IS_UTF8 + return *this << data.toStdString(); +# else + return *this << data.toStdWString(); +# endif + } + +# if QT_VERSION < 0x060000 + Record& operator<<(const QStringRef& data) + { + return *this << data.toString(); + } +# endif + +# ifdef QSTRINGVIEW_H + Record& operator<<(QStringView data) + { + return *this << data.toString(); + } +# endif +#endif + + template + Record& operator<<(const T& data) + { + using namespace plog::detail; + + m_message << data; + return *this; + } + +#ifndef __cplusplus_cli + Record& printf(const char* format, ...) + { + using namespace util; + + char* str = NULL; + va_list ap; + + va_start(ap, format); + int len = vasprintf(&str, format, ap); + static_cast(len); + va_end(ap); + + *this << str; + free(str); + + return *this; + } + +#ifdef _WIN32 + Record& printf(const wchar_t* format, ...) + { + using namespace util; + + wchar_t* str = NULL; + va_list ap; + + va_start(ap, format); + int len = vaswprintf(&str, format, ap); + static_cast(len); + va_end(ap); + + *this << str; + free(str); + + return *this; + } +#endif +#endif //__cplusplus_cli + + ////////////////////////////////////////////////////////////////////////// + // Getters + + virtual const util::Time& getTime() const + { + return m_time; + } + + virtual Severity getSeverity() const + { + return m_severity; + } + + virtual unsigned int getTid() const + { + return m_tid; + } + + virtual const void* getObject() const + { + return m_object; + } + + virtual size_t getLine() const + { + return m_line; + } + + virtual const util::nchar* getMessage() const + { + m_messageStr = m_message.str(); + return m_messageStr.c_str(); + } + + virtual const char* getFunc() const + { + m_funcStr = util::processFuncName(m_func); + return m_funcStr.c_str(); + } + + virtual const char* getFile() const + { + return m_file; + } + + virtual ~Record() // virtual destructor to satisfy -Wnon-virtual-dtor warning + { + } + + virtual int getInstanceId() const + { + return m_instanceId; + } + + private: + util::Time m_time; + const Severity m_severity; + const unsigned int m_tid; + const void* const m_object; + const size_t m_line; + util::nostringstream m_message; + const char* const m_func; + const char* const m_file; + const int m_instanceId; + mutable std::string m_funcStr; + mutable util::nstring m_messageStr; + }; +} diff --git a/datachannel/include/plog/Severity.h b/datachannel/include/plog/Severity.h new file mode 100644 index 000000000..446768e8f --- /dev/null +++ b/datachannel/include/plog/Severity.h @@ -0,0 +1,61 @@ +#pragma once +#include + +namespace plog +{ + enum Severity + { + none = 0, + fatal = 1, + error = 2, + warning = 3, + info = 4, + debug = 5, + verbose = 6 + }; + +#ifdef _MSC_VER +# pragma warning(suppress: 26812) // Prefer 'enum class' over 'enum' +#endif + inline const char* severityToString(Severity severity) + { + switch (severity) + { + case fatal: + return "FATAL"; + case error: + return "ERROR"; + case warning: + return "WARN"; + case info: + return "INFO"; + case debug: + return "DEBUG"; + case verbose: + return "VERB"; + default: + return "NONE"; + } + } + + inline Severity severityFromString(const char* str) + { + switch (std::toupper(str[0])) + { + case 'F': + return fatal; + case 'E': + return error; + case 'W': + return warning; + case 'I': + return info; + case 'D': + return debug; + case 'V': + return verbose; + default: + return none; + } + } +} diff --git a/datachannel/include/plog/Util.h b/datachannel/include/plog/Util.h new file mode 100644 index 000000000..ac01a526e --- /dev/null +++ b/datachannel/include/plog/Util.h @@ -0,0 +1,616 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +#ifndef PLOG_ENABLE_WCHAR_INPUT +# ifdef _WIN32 +# define PLOG_ENABLE_WCHAR_INPUT 1 +# else +# define PLOG_ENABLE_WCHAR_INPUT 0 +# endif +#endif + +////////////////////////////////////////////////////////////////////////// +// PLOG_CHAR_IS_UTF8 specifies character encoding of `char` type. On *nix +// systems it's set to UTF-8 while on Windows in can be ANSI or UTF-8. It +// automatically detects `/utf-8` command line option in MSVC. Also it can +// be set manually if required. +// This option allows to support http://utf8everywhere.org approach. + +#ifndef PLOG_CHAR_IS_UTF8 +# if defined(_WIN32) && !defined(_UTF8) +# define PLOG_CHAR_IS_UTF8 0 +# else +# define PLOG_CHAR_IS_UTF8 1 +# endif +#endif + +#ifdef _WIN32 +# if defined(PLOG_EXPORT) +# define PLOG_LINKAGE __declspec(dllexport) +# elif defined(PLOG_IMPORT) +# define PLOG_LINKAGE __declspec(dllimport) +# endif +# if defined(PLOG_GLOBAL) +# error "PLOG_GLOBAL isn't supported on Windows" +# endif +#else +# if defined(PLOG_GLOBAL) +# define PLOG_LINKAGE __attribute__ ((visibility ("default"))) +# elif defined(PLOG_LOCAL) +# define PLOG_LINKAGE __attribute__ ((visibility ("hidden"))) +# define PLOG_LINKAGE_HIDDEN PLOG_LINKAGE +# endif +# if defined(PLOG_EXPORT) || defined(PLOG_IMPORT) +# error "PLOG_EXPORT/PLOG_IMPORT is supported only on Windows" +# endif +#endif + +#ifndef PLOG_LINKAGE +# define PLOG_LINKAGE +#endif + +#ifndef PLOG_LINKAGE_HIDDEN +# define PLOG_LINKAGE_HIDDEN +#endif + +#ifdef _WIN32 +# include +# include +# include +# include +# include +#else +# include +# include +# if defined(__linux__) || defined(__FreeBSD__) +# include +# elif defined(__rtems__) +# include +# endif +# if defined(_POSIX_THREADS) +# include +# endif +# if PLOG_ENABLE_WCHAR_INPUT +# include +# endif +#endif + +#if PLOG_CHAR_IS_UTF8 +# define PLOG_NSTR(x) x +#else +# define _PLOG_NSTR(x) L##x +# define PLOG_NSTR(x) _PLOG_NSTR(x) +#endif + +#ifdef _WIN32 +# define PLOG_CDECL __cdecl +#else +# define PLOG_CDECL +#endif + +#if __cplusplus >= 201103L || defined(_MSC_VER) && _MSC_VER >= 1700 +# define PLOG_OVERRIDE override +#else +# define PLOG_OVERRIDE +#endif + +namespace plog +{ + namespace util + { +#if PLOG_CHAR_IS_UTF8 + typedef std::string nstring; + typedef std::ostringstream nostringstream; + typedef std::istringstream nistringstream; + typedef std::ostream nostream; + typedef char nchar; +#else + typedef std::wstring nstring; + typedef std::wostringstream nostringstream; + typedef std::wistringstream nistringstream; + typedef std::wostream nostream; + typedef wchar_t nchar; +#endif + + inline void localtime_s(struct tm* t, const time_t* time) + { +#if defined(_WIN32) && defined(__BORLANDC__) + ::localtime_s(time, t); +#elif defined(_WIN32) && defined(__MINGW32__) && !defined(__MINGW64_VERSION_MAJOR) + *t = *::localtime(time); +#elif defined(_WIN32) + ::localtime_s(t, time); +#else + ::localtime_r(time, t); +#endif + } + + inline void gmtime_s(struct tm* t, const time_t* time) + { +#if defined(_WIN32) && defined(__BORLANDC__) + ::gmtime_s(time, t); +#elif defined(_WIN32) && defined(__MINGW32__) && !defined(__MINGW64_VERSION_MAJOR) + *t = *::gmtime(time); +#elif defined(_WIN32) + ::gmtime_s(t, time); +#else + ::gmtime_r(time, t); +#endif + } + +#ifdef _WIN32 + typedef timeb Time; + + inline void ftime(Time* t) + { + ::ftime(t); + } +#else + struct Time + { + time_t time; + unsigned short millitm; + }; + + inline void ftime(Time* t) + { + timeval tv; + ::gettimeofday(&tv, NULL); + + t->time = tv.tv_sec; + t->millitm = static_cast(tv.tv_usec / 1000); + } +#endif + + inline unsigned int gettid() + { +#ifdef _WIN32 + return GetCurrentThreadId(); +#elif defined(__linux__) + return static_cast(::syscall(__NR_gettid)); +#elif defined(__FreeBSD__) + long tid; + syscall(SYS_thr_self, &tid); + return static_cast(tid); +#elif defined(__rtems__) + return rtems_task_self(); +#elif defined(__APPLE__) + uint64_t tid64; + pthread_threadid_np(NULL, &tid64); + return static_cast(tid64); +#else + return 0; +#endif + } + +#ifndef _GNU_SOURCE + inline int vasprintf(char** strp, const char* format, va_list ap) + { + va_list ap_copy; +#if defined(_MSC_VER) && _MSC_VER <= 1600 + ap_copy = ap; // there is no va_copy on Visual Studio 2010 +#else + va_copy(ap_copy, ap); +#endif +#ifndef __STDC_SECURE_LIB__ + int charCount = vsnprintf(NULL, 0, format, ap_copy); +#else + int charCount = _vscprintf(format, ap_copy); +#endif + va_end(ap_copy); + if (charCount < 0) + { + return -1; + } + + size_t bufferCharCount = static_cast(charCount) + 1; + + char* str = static_cast(malloc(bufferCharCount)); + if (!str) + { + return -1; + } + +#ifndef __STDC_SECURE_LIB__ + int retval = vsnprintf(str, bufferCharCount, format, ap); +#else + int retval = vsnprintf_s(str, bufferCharCount, static_cast(-1), format, ap); +#endif + if (retval < 0) + { + free(str); + return -1; + } + + *strp = str; + return retval; + } +#endif + +#ifdef _WIN32 + inline int vaswprintf(wchar_t** strp, const wchar_t* format, va_list ap) + { +#if defined(__BORLANDC__) + int charCount = 0x1000; // there is no _vscwprintf on Borland/Embarcadero +#else + int charCount = _vscwprintf(format, ap); + if (charCount < 0) + { + return -1; + } +#endif + + size_t bufferCharCount = static_cast(charCount) + 1; + + wchar_t* str = static_cast(malloc(bufferCharCount * sizeof(wchar_t))); + if (!str) + { + return -1; + } + +#if defined(__BORLANDC__) + int retval = vsnwprintf_s(str, bufferCharCount, format, ap); +#elif defined(__MINGW32__) && !defined(__MINGW64_VERSION_MAJOR) + int retval = _vsnwprintf(str, bufferCharCount, format, ap); +#else + int retval = _vsnwprintf_s(str, bufferCharCount, charCount, format, ap); +#endif + if (retval < 0) + { + free(str); + return -1; + } + + *strp = str; + return retval; + } +#endif + +#ifdef _WIN32 + inline std::wstring toWide(const char* str, UINT cp = codePage::kChar) + { + size_t len = ::strlen(str); + std::wstring wstr(len, 0); + + if (!wstr.empty()) + { + int wlen = MultiByteToWideChar(cp, 0, str, static_cast(len), &wstr[0], static_cast(wstr.size())); + wstr.resize(wlen); + } + + return wstr; + } + + inline std::wstring toWide(const std::string& str, UINT cp = codePage::kChar) + { + return toWide(str.c_str(), cp); + } + + inline const std::wstring& toWide(const std::wstring& str) // do nothing for already wide string + { + return str; + } + + inline std::string toNarrow(const std::wstring& wstr, long page) + { + int len = WideCharToMultiByte(page, 0, wstr.c_str(), static_cast(wstr.size()), 0, 0, 0, 0); + std::string str(len, 0); + + if (!str.empty()) + { + WideCharToMultiByte(page, 0, wstr.c_str(), static_cast(wstr.size()), &str[0], len, 0, 0); + } + + return str; + } +#elif PLOG_ENABLE_WCHAR_INPUT + inline std::string toNarrow(const wchar_t* wstr) + { + size_t wlen = ::wcslen(wstr); + std::string str(wlen * sizeof(wchar_t), 0); + + if (!str.empty()) + { + const char* in = reinterpret_cast(&wstr[0]); + char* out = &str[0]; + size_t inBytes = wlen * sizeof(wchar_t); + size_t outBytes = str.size(); + + iconv_t cd = ::iconv_open("UTF-8", "WCHAR_T"); + ::iconv(cd, const_cast(&in), &inBytes, &out, &outBytes); + ::iconv_close(cd); + + str.resize(str.size() - outBytes); + } + + return str; + } +#endif + + inline std::string processFuncName(const char* func) + { +#if (defined(_WIN32) && !defined(__MINGW32__)) || defined(__OBJC__) + return std::string(func); +#else + const char* funcBegin = func; + const char* funcEnd = ::strchr(funcBegin, '('); + int foundTemplate = 0; + + if (!funcEnd) + { + return std::string(func); + } + + for (const char* i = funcEnd - 1; i >= funcBegin; --i) // search backwards for the first space char + { + if (*i == '>') + { + foundTemplate++; + } + else if (*i == '<') + { + foundTemplate--; + } + else if (*i == ' ' && foundTemplate == 0) + { + funcBegin = i + 1; + break; + } + } + + return std::string(funcBegin, funcEnd); +#endif + } + + inline const nchar* findExtensionDot(const nchar* fileName) + { +#if PLOG_CHAR_IS_UTF8 + return std::strrchr(fileName, '.'); +#else + return std::wcsrchr(fileName, L'.'); +#endif + } + + inline void splitFileName(const nchar* fileName, nstring& fileNameNoExt, nstring& fileExt) + { + const nchar* dot = findExtensionDot(fileName); + + if (dot) + { + fileNameNoExt.assign(fileName, dot); + fileExt.assign(dot + 1); + } + else + { + fileNameNoExt.assign(fileName); + fileExt.clear(); + } + } + + class PLOG_LINKAGE NonCopyable + { + protected: + NonCopyable() + { + } + + private: + NonCopyable(const NonCopyable&); + NonCopyable& operator=(const NonCopyable&); + }; + + class PLOG_LINKAGE_HIDDEN File : NonCopyable + { + public: + File() : m_file(-1) + { + } + + ~File() + { + close(); + } + + size_t open(const nstring& fileName) + { +#if defined(_WIN32) && (defined(__BORLANDC__) || defined(__MINGW32__)) + m_file = ::_wsopen(toWide(fileName).c_str(), _O_CREAT | _O_WRONLY | _O_BINARY | _O_NOINHERIT, SH_DENYWR, _S_IREAD | _S_IWRITE); +#elif defined(_WIN32) + ::_wsopen_s(&m_file, toWide(fileName).c_str(), _O_CREAT | _O_WRONLY | _O_BINARY | _O_NOINHERIT, _SH_DENYWR, _S_IREAD | _S_IWRITE); +#elif defined(O_CLOEXEC) + m_file = ::open(fileName.c_str(), O_CREAT | O_APPEND | O_WRONLY | O_CLOEXEC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); +#else + m_file = ::open(fileName.c_str(), O_CREAT | O_APPEND | O_WRONLY, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); +#endif + return seek(0, SEEK_END); + } + + size_t write(const void* buf, size_t count) + { + return m_file != -1 ? static_cast( +#ifdef _WIN32 + ::_write(m_file, buf, static_cast(count)) +#else + ::write(m_file, buf, count) +#endif + ) : static_cast(-1); + } + + template + size_t write(const std::basic_string& str) + { + return write(str.data(), str.size() * sizeof(CharType)); + } + + size_t seek(size_t offset, int whence) + { + return m_file != -1 ? static_cast( +#if defined(_WIN32) && (defined(__BORLANDC__) || defined(__MINGW32__)) + ::_lseek(m_file, static_cast(offset), whence) +#elif defined(_WIN32) + ::_lseeki64(m_file, static_cast(offset), whence) +#else + ::lseek(m_file, static_cast(offset), whence) +#endif + ) : static_cast(-1); + } + + void close() + { + if (m_file != -1) + { +#ifdef _WIN32 + ::_close(m_file); +#else + ::close(m_file); +#endif + m_file = -1; + } + } + + static int unlink(const nstring& fileName) + { +#ifdef _WIN32 + return ::_wunlink(toWide(fileName).c_str()); +#else + return ::unlink(fileName.c_str()); +#endif + } + + static int rename(const nstring& oldFilename, const nstring& newFilename) + { +#ifdef _WIN32 + return MoveFileW(toWide(oldFilename).c_str(), toWide(newFilename).c_str()); +#else + return ::rename(oldFilename.c_str(), newFilename.c_str()); +#endif + } + + private: + int m_file; + }; + + class PLOG_LINKAGE_HIDDEN Mutex : NonCopyable + { + public: + Mutex() + { +#ifdef _WIN32 + InitializeCriticalSection(&m_sync); +#elif defined(__rtems__) + rtems_semaphore_create(0, 1, + RTEMS_PRIORITY | + RTEMS_BINARY_SEMAPHORE | + RTEMS_INHERIT_PRIORITY, 1, &m_sync); +#elif defined(_POSIX_THREADS) + ::pthread_mutex_init(&m_sync, 0); +#endif + } + + ~Mutex() + { +#ifdef _WIN32 + DeleteCriticalSection(&m_sync); +#elif defined(__rtems__) + rtems_semaphore_delete(m_sync); +#elif defined(_POSIX_THREADS) + ::pthread_mutex_destroy(&m_sync); +#endif + } + + friend class MutexLock; + + private: + void lock() + { +#ifdef _WIN32 + EnterCriticalSection(&m_sync); +#elif defined(__rtems__) + rtems_semaphore_obtain(m_sync, RTEMS_WAIT, RTEMS_NO_TIMEOUT); +#elif defined(_POSIX_THREADS) + ::pthread_mutex_lock(&m_sync); +#endif + } + + void unlock() + { +#ifdef _WIN32 + LeaveCriticalSection(&m_sync); +#elif defined(__rtems__) + rtems_semaphore_release(m_sync); +#elif defined(_POSIX_THREADS) + ::pthread_mutex_unlock(&m_sync); +#endif + } + + private: +#ifdef _WIN32 + CRITICAL_SECTION m_sync; +#elif defined(__rtems__) + rtems_id m_sync; +#elif defined(_POSIX_THREADS) + pthread_mutex_t m_sync; +#endif + }; + + class PLOG_LINKAGE_HIDDEN MutexLock : NonCopyable + { + public: + MutexLock(Mutex& mutex) : m_mutex(mutex) + { + m_mutex.lock(); + } + + ~MutexLock() + { + m_mutex.unlock(); + } + + private: + Mutex& m_mutex; + }; + + template +#ifdef _WIN32 + class Singleton : NonCopyable +#else + class PLOG_LINKAGE Singleton : NonCopyable +#endif + { + public: +#if (defined(__clang__) || defined(__GNUC__) && __GNUC__ >= 8) && !defined(__BORLANDC__) + // This constructor is called before the `T` object is fully constructed, and + // pointers are not dereferenced anyway, so UBSan shouldn't check vptrs. + __attribute__((no_sanitize("vptr"))) +#endif + Singleton() + { + assert(!m_instance); + m_instance = static_cast(this); + } + + ~Singleton() + { + assert(m_instance); + m_instance = 0; + } + + static T* getInstance() + { + return m_instance; + } + + private: + static T* m_instance; + }; + + template + T* Singleton::m_instance = NULL; + } +} diff --git a/datachannel/include/plog/WinApi.h b/datachannel/include/plog/WinApi.h new file mode 100644 index 000000000..ccf44af0a --- /dev/null +++ b/datachannel/include/plog/WinApi.h @@ -0,0 +1,175 @@ +#pragma once + +#ifdef _WIN32 + +// These windows structs must be in a global namespace +struct HKEY__; +struct _SECURITY_ATTRIBUTES; +struct _CONSOLE_SCREEN_BUFFER_INFO; +struct _RTL_CRITICAL_SECTION; + +namespace plog +{ + typedef unsigned long DWORD; + typedef unsigned short WORD; + typedef unsigned char BYTE; + typedef unsigned int UINT; + typedef int BOOL; + typedef long LSTATUS; + typedef char* LPSTR; + typedef wchar_t* LPWSTR; + typedef const char* LPCSTR; + typedef const wchar_t* LPCWSTR; + typedef void* HANDLE; + typedef HKEY__* HKEY; + typedef size_t ULONG_PTR; + + struct CRITICAL_SECTION + { + void* DebugInfo; + long LockCount; + long RecursionCount; + HANDLE OwningThread; + HANDLE LockSemaphore; + ULONG_PTR SpinCount; + }; + + struct COORD + { + short X; + short Y; + }; + + struct SMALL_RECT + { + short Left; + short Top; + short Right; + short Bottom; + }; + + struct CONSOLE_SCREEN_BUFFER_INFO + { + COORD dwSize; + COORD dwCursorPosition; + WORD wAttributes; + SMALL_RECT srWindow; + COORD dwMaximumWindowSize; + }; + + namespace codePage + { + const UINT kActive = 0; + const UINT kUTF8 = 65001; +#if PLOG_CHAR_IS_UTF8 + const UINT kChar = kUTF8; +#else + const UINT kChar = kActive; +#endif + } + + namespace eventLog + { + const WORD kErrorType = 0x0001; + const WORD kWarningType = 0x0002; + const WORD kInformationType = 0x0004; + } + + namespace hkey + { + const HKEY kLocalMachine = reinterpret_cast(static_cast(0x80000002)); + } + + namespace regSam + { + const DWORD kQueryValue = 0x0001; + const DWORD kSetValue = 0x0002; + } + + namespace regType + { + const DWORD kExpandSz = 2; + const DWORD kDword = 4; + } + + namespace stdHandle + { + const DWORD kOutput = static_cast(-11); + const DWORD kErrorOutput = static_cast(-12); + } + + namespace foreground + { + const WORD kBlue = 0x0001; + const WORD kGreen = 0x0002; + const WORD kRed = 0x0004; + const WORD kIntensity = 0x0008; + } + + namespace background + { + const WORD kBlue = 0x0010; + const WORD kGreen = 0x0020; + const WORD kRed = 0x0040; + const WORD kIntensity = 0x0080; + } + + extern "C" + { + __declspec(dllimport) int __stdcall MultiByteToWideChar(UINT CodePage, DWORD dwFlags, LPCSTR lpMultiByteStr, int cbMultiByte, LPWSTR lpWideCharStr, int cchWideChar); + __declspec(dllimport) int __stdcall WideCharToMultiByte(UINT CodePage, DWORD dwFlags, LPCWSTR lpWideCharStr, int cchWideChar, LPSTR lpMultiByteStr, int cbMultiByte, const char* lpDefaultChar, BOOL* lpUsedDefaultChar); + + __declspec(dllimport) DWORD __stdcall GetCurrentThreadId(); + + __declspec(dllimport) BOOL __stdcall MoveFileW(LPCWSTR lpExistingFileName, LPCWSTR lpNewFileName); + + __declspec(dllimport) void __stdcall InitializeCriticalSection(_RTL_CRITICAL_SECTION* lpCriticalSection); + __declspec(dllimport) void __stdcall EnterCriticalSection(_RTL_CRITICAL_SECTION* lpCriticalSection); + __declspec(dllimport) void __stdcall LeaveCriticalSection(_RTL_CRITICAL_SECTION* lpCriticalSection); + __declspec(dllimport) void __stdcall DeleteCriticalSection(_RTL_CRITICAL_SECTION* lpCriticalSection); + + __declspec(dllimport) HANDLE __stdcall RegisterEventSourceW(LPCWSTR lpUNCServerName, LPCWSTR lpSourceName); + __declspec(dllimport) BOOL __stdcall DeregisterEventSource(HANDLE hEventLog); + __declspec(dllimport) BOOL __stdcall ReportEventW(HANDLE hEventLog, WORD wType, WORD wCategory, DWORD dwEventID, void* lpUserSid, WORD wNumStrings, DWORD dwDataSize, LPCWSTR* lpStrings, void* lpRawData); + + __declspec(dllimport) LSTATUS __stdcall RegCreateKeyExW(HKEY hKey, LPCWSTR lpSubKey, DWORD Reserved, LPWSTR lpClass, DWORD dwOptions, DWORD samDesired, _SECURITY_ATTRIBUTES* lpSecurityAttributes, HKEY* phkResult, DWORD* lpdwDisposition); + __declspec(dllimport) LSTATUS __stdcall RegSetValueExW(HKEY hKey, LPCWSTR lpValueName, DWORD Reserved, DWORD dwType, const BYTE* lpData, DWORD cbData); + __declspec(dllimport) LSTATUS __stdcall RegCloseKey(HKEY hKey); + __declspec(dllimport) LSTATUS __stdcall RegOpenKeyExW(HKEY hKey, LPCWSTR lpSubKey, DWORD ulOptions, DWORD samDesired, HKEY* phkResult); + __declspec(dllimport) LSTATUS __stdcall RegDeleteKeyW(HKEY hKey, LPCWSTR lpSubKey); + + __declspec(dllimport) HANDLE __stdcall GetStdHandle(DWORD nStdHandle); + + __declspec(dllimport) BOOL __stdcall WriteConsoleW(HANDLE hConsoleOutput, const void* lpBuffer, DWORD nNumberOfCharsToWrite, DWORD* lpNumberOfCharsWritten, void* lpReserved); + __declspec(dllimport) BOOL __stdcall GetConsoleScreenBufferInfo(HANDLE hConsoleOutput, _CONSOLE_SCREEN_BUFFER_INFO* lpConsoleScreenBufferInfo); + __declspec(dllimport) BOOL __stdcall SetConsoleTextAttribute(HANDLE hConsoleOutput, WORD wAttributes); + + __declspec(dllimport) void __stdcall OutputDebugStringW(LPCWSTR lpOutputString); + } + + inline void InitializeCriticalSection(CRITICAL_SECTION* criticalSection) + { + plog::InitializeCriticalSection(reinterpret_cast<_RTL_CRITICAL_SECTION*>(criticalSection)); + } + + inline void EnterCriticalSection(CRITICAL_SECTION* criticalSection) + { + plog::EnterCriticalSection(reinterpret_cast<_RTL_CRITICAL_SECTION*>(criticalSection)); + } + + inline void LeaveCriticalSection(CRITICAL_SECTION* criticalSection) + { + plog::LeaveCriticalSection(reinterpret_cast<_RTL_CRITICAL_SECTION*>(criticalSection)); + } + + inline void DeleteCriticalSection(CRITICAL_SECTION* criticalSection) + { + plog::DeleteCriticalSection(reinterpret_cast<_RTL_CRITICAL_SECTION*>(criticalSection)); + } + + inline BOOL GetConsoleScreenBufferInfo(HANDLE consoleOutput, CONSOLE_SCREEN_BUFFER_INFO* consoleScreenBufferInfo) + { + return plog::GetConsoleScreenBufferInfo(consoleOutput, reinterpret_cast<_CONSOLE_SCREEN_BUFFER_INFO*>(consoleScreenBufferInfo)); + } +} +#endif // _WIN32 diff --git a/datachannel/include/rtc/av1rtppacketizer.hpp b/datachannel/include/rtc/av1rtppacketizer.hpp new file mode 100644 index 000000000..b56a875f9 --- /dev/null +++ b/datachannel/include/rtc/av1rtppacketizer.hpp @@ -0,0 +1,56 @@ +/** + * Copyright (c) 2023 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_AV1_RTP_PACKETIZER_H +#define RTC_AV1_RTP_PACKETIZER_H + +#if RTC_ENABLE_MEDIA + +#include "mediahandler.hpp" +#include "nalunit.hpp" +#include "rtppacketizer.hpp" + +namespace rtc { + +// RTP packetization of AV1 payload +class RTC_CPP_EXPORT AV1RtpPacketizer final : public RtpPacketizer { +public: + // Default clock rate for AV1 in RTP + inline static const uint32_t defaultClockRate = 90 * 1000; + + // Define how OBUs are seperated in a AV1 Sample + enum class Packetization { + Obu = RTC_OBU_PACKETIZED_OBU, + TemporalUnit = RTC_OBU_PACKETIZED_TEMPORAL_UNIT, + }; + + // Constructs AV1 payload packetizer with given RTP configuration. + // @note RTP configuration is used in packetization process which may change some configuration + // properties such as sequence number. + AV1RtpPacketizer(Packetization packetization, shared_ptr rtpConfig, + uint16_t maxFragmentSize = NalUnits::defaultMaximumFragmentSize); + + void outgoing(message_vector &messages, const message_callback &send) override; + +private: + shared_ptr splitMessage(binary_ptr message); + std::vector> packetizeObu(binary_ptr message, uint16_t maxFragmentSize); + + const uint16_t maxFragmentSize; + const Packetization packetization; + std::shared_ptr sequenceHeader; +}; + +// For backward compatibility, do not use +using AV1PacketizationHandler [[deprecated("Add AV1RtpPacketizer directly")]] = PacketizationHandler; + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ + +#endif /* RTC_AV1_RTP_PACKETIZER_H */ diff --git a/datachannel/include/rtc/candidate.hpp b/datachannel/include/rtc/candidate.hpp new file mode 100644 index 000000000..00ca20d84 --- /dev/null +++ b/datachannel/include/rtc/candidate.hpp @@ -0,0 +1,77 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_CANDIDATE_H +#define RTC_CANDIDATE_H + +#include "common.hpp" + +#include + +namespace rtc { + +class RTC_CPP_EXPORT Candidate { +public: + enum class Family { Unresolved, Ipv4, Ipv6 }; + enum class Type { Unknown, Host, ServerReflexive, PeerReflexive, Relayed }; + enum class TransportType { Unknown, Udp, TcpActive, TcpPassive, TcpSo, TcpUnknown }; + + Candidate(); + Candidate(string candidate); + Candidate(string candidate, string mid); + + void hintMid(string mid); + void changeAddress(string addr); + void changeAddress(string addr, uint16_t port); + void changeAddress(string addr, string service); + + enum class ResolveMode { Simple, Lookup }; + bool resolve(ResolveMode mode = ResolveMode::Simple); + + Type type() const; + TransportType transportType() const; + uint32_t priority() const; + string candidate() const; + string mid() const; + operator string() const; + + bool operator==(const Candidate &other) const; + bool operator!=(const Candidate &other) const; + + bool isResolved() const; + Family family() const; + optional address() const; + optional port() const; + +private: + void parse(string candidate); + + string mFoundation; + uint32_t mComponent, mPriority; + string mTypeString, mTransportString; + Type mType; + TransportType mTransportType; + string mNode, mService; + string mTail; + + optional mMid; + + // Extracted on resolution + Family mFamily; + string mAddress; + uint16_t mPort; +}; + +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, const Candidate &candidate); +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, const Candidate::Type &type); +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, + const Candidate::TransportType &transportType); + +} // namespace rtc + +#endif diff --git a/datachannel/include/rtc/channel.hpp b/datachannel/include/rtc/channel.hpp new file mode 100644 index 000000000..384279d76 --- /dev/null +++ b/datachannel/include/rtc/channel.hpp @@ -0,0 +1,61 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_CHANNEL_H +#define RTC_CHANNEL_H + +#include "common.hpp" + +#include +#include + +namespace rtc { + +namespace impl { +struct Channel; +} + +class RTC_CPP_EXPORT Channel : private CheshireCat { +public: + virtual ~Channel(); + + virtual void close() = 0; + virtual bool send(message_variant data) = 0; // returns false if buffered + virtual bool send(const byte *data, size_t size) = 0; + + virtual bool isOpen() const = 0; + virtual bool isClosed() const = 0; + virtual size_t maxMessageSize() const; // max message size in a call to send + virtual size_t bufferedAmount() const; // total size buffered to send + + void onOpen(std::function callback); + void onClosed(std::function callback); + void onError(std::function callback); + + void onMessage(std::function callback); + void onMessage(std::function binaryCallback, + std::function stringCallback); + + void onBufferedAmountLow(std::function callback); + void setBufferedAmountLowThreshold(size_t amount); + + void resetCallbacks(); + + // Extended API + optional receive(); // only if onMessage unset + optional peek(); // only if onMessage unset + size_t availableAmount() const; // total size available to receive + void onAvailable(std::function callback); + +protected: + Channel(impl_ptr impl); +}; + +} // namespace rtc + +#endif // RTC_CHANNEL_H diff --git a/datachannel/include/rtc/common.hpp b/datachannel/include/rtc/common.hpp new file mode 100644 index 000000000..08f981cb8 --- /dev/null +++ b/datachannel/include/rtc/common.hpp @@ -0,0 +1,86 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_COMMON_H +#define RTC_COMMON_H + +#ifdef RTC_STATIC +#define RTC_CPP_EXPORT +#else // dynamic library +#ifdef _WIN32 +#ifdef RTC_EXPORTS +#define RTC_CPP_EXPORT __declspec(dllexport) // building the library +#else +#define RTC_CPP_EXPORT __declspec(dllimport) // using the library +#endif +#else // not WIN32 +#define RTC_CPP_EXPORT +#endif +#endif + +#ifdef _WIN32 +#ifndef _WIN32_WINNT +#define _WIN32_WINNT 0x0602 // Windows 8 +#endif +#ifdef _MSC_VER +#pragma warning(disable : 4251) // disable "X needs to have dll-interface..." +#endif +#endif + +#ifndef RTC_ENABLE_WEBSOCKET +#define RTC_ENABLE_WEBSOCKET 1 +#endif + +#ifndef RTC_ENABLE_MEDIA +#define RTC_ENABLE_MEDIA 1 +#endif + +#include "rtc.h" // for C API defines + +#include "utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace rtc { + +using std::byte; +using std::nullopt; +using std::optional; +using std::shared_ptr; +using std::string; +using std::string_view; +using std::unique_ptr; +using std::variant; +using std::weak_ptr; + +using binary = std::vector; +using binary_ptr = shared_ptr; +using message_variant = variant; + +using std::int16_t; +using std::int32_t; +using std::int64_t; +using std::int8_t; +using std::ptrdiff_t; +using std::size_t; +using std::uint16_t; +using std::uint32_t; +using std::uint64_t; +using std::uint8_t; + +} // namespace rtc + +#endif diff --git a/datachannel/include/rtc/configuration.hpp b/datachannel/include/rtc/configuration.hpp new file mode 100644 index 000000000..41bea91d7 --- /dev/null +++ b/datachannel/include/rtc/configuration.hpp @@ -0,0 +1,122 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_ICE_CONFIGURATION_H +#define RTC_ICE_CONFIGURATION_H + +#include "common.hpp" + +#include + +namespace rtc { + +struct RTC_CPP_EXPORT IceServer { + enum class Type { Stun, Turn }; + enum class RelayType { TurnUdp, TurnTcp, TurnTls }; + + // Any type + IceServer(const string &url); + + // STUN + IceServer(string hostname_, uint16_t port_); + IceServer(string hostname_, string service_); + + // TURN + IceServer(string hostname_, uint16_t port, string username_, string password_, + RelayType relayType_ = RelayType::TurnUdp); + IceServer(string hostname_, string service_, string username_, string password_, + RelayType relayType_ = RelayType::TurnUdp); + + string hostname; + uint16_t port; + Type type; + string username; + string password; + RelayType relayType; +}; + +struct RTC_CPP_EXPORT ProxyServer { + enum class Type { Http, Socks5 }; + + ProxyServer(const string &url); + + ProxyServer(Type type_, string hostname_, uint16_t port_); + ProxyServer(Type type_, string hostname_, uint16_t port_, string username_, string password_); + + Type type; + string hostname; + uint16_t port; + optional username; + optional password; +}; + +enum class CertificateType { + Default = RTC_CERTIFICATE_DEFAULT, // ECDSA + Ecdsa = RTC_CERTIFICATE_ECDSA, + Rsa = RTC_CERTIFICATE_RSA +}; + +enum class TransportPolicy { All = RTC_TRANSPORT_POLICY_ALL, Relay = RTC_TRANSPORT_POLICY_RELAY }; + +struct RTC_CPP_EXPORT Configuration { + // ICE settings + std::vector iceServers; + optional proxyServer; // libnice only + optional bindAddress; // libjuice only, default any + + // Options + CertificateType certificateType = CertificateType::Default; + TransportPolicy iceTransportPolicy = TransportPolicy::All; + bool enableIceTcp = false; // libnice only + bool enableIceUdpMux = false; // libjuice only + bool disableAutoNegotiation = false; + bool forceMediaTransport = false; + + // Port range + uint16_t portRangeBegin = 1024; + uint16_t portRangeEnd = 65535; + + // Network MTU + optional mtu; + + // Local maximum message size for Data Channels + optional maxMessageSize; +}; + +#ifdef RTC_ENABLE_WEBSOCKET + +struct WebSocketConfiguration { + bool disableTlsVerification = false; // if true, don't verify the TLS certificate + optional proxyServer; // only non-authenticated http supported for now + std::vector protocols; + optional connectionTimeout; // zero to disable + optional pingInterval; // zero to disable + optional maxOutstandingPings; + optional caCertificatePemFile; + optional certificatePemFile; + optional keyPemFile; + optional keyPemPass; + optional maxMessageSize; +}; + +struct WebSocketServerConfiguration { + uint16_t port = 8080; + bool enableTls = false; + optional certificatePemFile; + optional keyPemFile; + optional keyPemPass; + optional bindAddress; + optional connectionTimeout; + optional maxMessageSize; +}; + +#endif + +} // namespace rtc + +#endif diff --git a/datachannel/include/rtc/datachannel.hpp b/datachannel/include/rtc/datachannel.hpp new file mode 100644 index 000000000..0e83a9790 --- /dev/null +++ b/datachannel/include/rtc/datachannel.hpp @@ -0,0 +1,80 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_DATA_CHANNEL_H +#define RTC_DATA_CHANNEL_H + +#include "channel.hpp" +#include "common.hpp" +#include "reliability.hpp" + +#include + +namespace rtc { + +namespace impl { + +struct DataChannel; +struct PeerConnection; + +} // namespace impl + +class RTC_CPP_EXPORT DataChannel final : private CheshireCat, public Channel { +public: + DataChannel(impl_ptr impl); + ~DataChannel() override; + + optional stream() const; + optional id() const; + string label() const; + string protocol() const; + Reliability reliability() const; + + bool isOpen(void) const override; + bool isClosed(void) const override; + size_t maxMessageSize() const override; + + void close(void) override; + bool send(message_variant data) override; + bool send(const byte *data, size_t size) override; + template bool sendBuffer(const Buffer &buf); + template bool sendBuffer(Iterator first, Iterator last); + +private: + using CheshireCat::impl; +}; + +template std::pair to_bytes(const Buffer &buf) { + using T = typename std::remove_pointer::type; + using E = typename std::conditional::value, byte, T>::type; + return std::make_pair(static_cast(static_cast(buf.data())), + buf.size() * sizeof(E)); +} + +template bool DataChannel::sendBuffer(const Buffer &buf) { + auto [bytes, size] = to_bytes(buf); + return send(bytes, size); +} + +template bool DataChannel::sendBuffer(Iterator first, Iterator last) { + size_t size = 0; + for (Iterator it = first; it != last; ++it) + size += it->size(); + + binary buffer(size); + byte *pos = buffer.data(); + for (Iterator it = first; it != last; ++it) { + auto [bytes, len] = to_bytes(*it); + pos = std::copy(bytes, bytes + len, pos); + } + return send(std::move(buffer)); +} + +} // namespace rtc + +#endif diff --git a/datachannel/include/rtc/description.hpp b/datachannel/include/rtc/description.hpp new file mode 100644 index 000000000..0d0c58b5f --- /dev/null +++ b/datachannel/include/rtc/description.hpp @@ -0,0 +1,324 @@ +/** + * Copyright (c) 2019-2020 Paul-Louis Ageneau + * Copyright (c) 2020 Staz Modrzynski + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_DESCRIPTION_H +#define RTC_DESCRIPTION_H + +#include "candidate.hpp" +#include "common.hpp" + +#include +#include +#include + +namespace rtc { + +const string DEFAULT_OPUS_AUDIO_PROFILE = + "minptime=10;maxaveragebitrate=96000;stereo=1;sprop-stereo=1;useinbandfec=1"; + +// Use Constrained Baseline profile Level 3.1 (necessary for Firefox) +// https://developer.mozilla.org/en-US/docs/Web/Media/Formats/WebRTC_codecs#Supported_video_codecs +// TODO: Should be 42E0 but 42C0 appears to be more compatible. Investigate this. +const string DEFAULT_H264_VIDEO_PROFILE = + "profile-level-id=42e01f;packetization-mode=1;level-asymmetry-allowed=1"; + +struct CertificateFingerprint { + enum class Algorithm { Sha1, Sha224, Sha256, Sha384, Sha512 }; + static string AlgorithmIdentifier(Algorithm algorithm); + static size_t AlgorithmSize(Algorithm algorithm); + + bool isValid() const; + + Algorithm algorithm; + string value; +}; + +class RTC_CPP_EXPORT Description { +public: + enum class Type { Unspec, Offer, Answer, Pranswer, Rollback }; + enum class Role { ActPass, Passive, Active }; + + enum class Direction { + SendOnly = RTC_DIRECTION_SENDONLY, + RecvOnly = RTC_DIRECTION_RECVONLY, + SendRecv = RTC_DIRECTION_SENDRECV, + Inactive = RTC_DIRECTION_INACTIVE, + Unknown = RTC_DIRECTION_UNKNOWN + }; + + Description(const string &sdp, Type type = Type::Unspec, Role role = Role::ActPass); + Description(const string &sdp, string typeString); + + Type type() const; + string typeString() const; + Role role() const; + string bundleMid() const; + std::vector iceOptions() const; + optional iceUfrag() const; + optional icePwd() const; + optional fingerprint() const; + bool ended() const; + + void hintType(Type type); + void setFingerprint(CertificateFingerprint f); + void addIceOption(string option); + void removeIceOption(const string &option); + + std::vector attributes() const; + void addAttribute(string attr); + void removeAttribute(const string &attr); + + std::vector candidates() const; + std::vector extractCandidates(); + bool hasCandidate(const Candidate &candidate) const; + void addCandidate(Candidate candidate); + void addCandidates(std::vector candidates); + void endCandidates(); + + operator string() const; + string generateSdp(string_view eol = "\r\n") const; + string generateApplicationSdp(string_view eol = "\r\n") const; + + class RTC_CPP_EXPORT Entry { + public: + virtual ~Entry() = default; + + virtual string type() const; + virtual string description() const; + virtual string mid() const; + + Direction direction() const; + void setDirection(Direction dir); + + bool isRemoved() const; + void markRemoved(); + + std::vector attributes() const; + void addAttribute(string attr); + void removeAttribute(const string &attr); + void addRid(string rid); + + struct RTC_CPP_EXPORT ExtMap { + static int parseId(string_view description); + + ExtMap(int id, string uri, Direction direction = Direction::Unknown); + ExtMap(string_view description); + + void setDescription(string_view description); + + int id; + string uri; + string attributes; + Direction direction = Direction::Unknown; + }; + + std::vector extIds(); + ExtMap *extMap(int id); + const ExtMap *extMap(int id) const; + void addExtMap(ExtMap map); + void removeExtMap(int id); + + operator string() const; + string generateSdp(string_view eol = "\r\n", string_view addr = "0.0.0.0", + uint16_t port = 9) const; + + virtual void parseSdpLine(string_view line); + + protected: + Entry(const string &mline, string mid, Direction dir = Direction::Unknown); + + virtual string generateSdpLines(string_view eol) const; + + std::vector mAttributes; + std::map mExtMaps; + + private: + string mType; + string mDescription; + string mMid; + std::vector mRids; + Direction mDirection; + bool mIsRemoved; + }; + + struct RTC_CPP_EXPORT Application : public Entry { + public: + Application(string mid = "data"); + Application(const string &mline, string mid); + virtual ~Application() = default; + + string description() const override; + Application reciprocate() const; + + void setSctpPort(uint16_t port); + void hintSctpPort(uint16_t port); + void setMaxMessageSize(size_t size); + + optional sctpPort() const; + optional maxMessageSize() const; + + virtual void parseSdpLine(string_view line) override; + + private: + virtual string generateSdpLines(string_view eol) const override; + + optional mSctpPort; + optional mMaxMessageSize; + }; + + // Media (non-data) + class RTC_CPP_EXPORT Media : public Entry { + public: + Media(const string &sdp); + Media(const string &mline, string mid, Direction dir = Direction::SendOnly); + virtual ~Media() = default; + + string description() const override; + Media reciprocate() const; + + void addSSRC(uint32_t ssrc, optional name, optional msid = nullopt, + optional trackId = nullopt); + void removeSSRC(uint32_t ssrc); + void replaceSSRC(uint32_t old, uint32_t ssrc, optional name, + optional msid = nullopt, optional trackID = nullopt); + bool hasSSRC(uint32_t ssrc) const; + void clearSSRCs(); + std::vector getSSRCs() const; + optional getCNameForSsrc(uint32_t ssrc) const; + + int bitrate() const; + void setBitrate(int bitrate); + + struct RTC_CPP_EXPORT RtpMap { + static int parsePayloadType(string_view description); + + explicit RtpMap(int payloadType); + RtpMap(string_view description); + + void setDescription(string_view description); + + void addFeedback(string fb); + void removeFeedback(const string &str); + void addParameter(string p); + void removeParameter(const string &str); + + int payloadType; + string format; + int clockRate; + string encParams; + + std::vector rtcpFbs; + std::vector fmtps; + }; + + bool hasPayloadType(int payloadType) const; + std::vector payloadTypes() const; + RtpMap *rtpMap(int payloadType); + const RtpMap *rtpMap(int payloadType) const; + void addRtpMap(RtpMap map); + void removeRtpMap(int payloadType); + void removeFormat(const string &format); + + void addRtxCodec(int payloadType, int origPayloadType, unsigned int clockRate); + + virtual void parseSdpLine(string_view line) override; + + private: + virtual string generateSdpLines(string_view eol) const override; + + int mBas = -1; + + std::map mRtpMaps; + std::vector mSsrcs; + std::map mCNameMap; + }; + + class RTC_CPP_EXPORT Audio : public Media { + public: + Audio(string mid = "audio", Direction dir = Direction::SendOnly); + + void addAudioCodec(int payloadType, string codec, optional profile = std::nullopt); + void addOpusCodec(int payloadType, optional profile = DEFAULT_OPUS_AUDIO_PROFILE); + void addPCMACodec(int payloadType, optional profile = std::nullopt); + void addPCMUCodec(int payloadType, optional profile = std::nullopt); + void addAACCodec(int payloadType, optional profile = std::nullopt); + + [[deprecated("Use addAACCodec")]] inline void + addAacCodec(int payloadType, optional profile = std::nullopt) { + addAACCodec(payloadType, std::move(profile)); + }; + }; + + class RTC_CPP_EXPORT Video : public Media { + public: + Video(string mid = "video", Direction dir = Direction::SendOnly); + + void addVideoCodec(int payloadType, string codec, optional profile = std::nullopt); + + void addH264Codec(int payloadType, optional profile = DEFAULT_H264_VIDEO_PROFILE); + void addH265Codec(int payloadType, optional profile = std::nullopt); + void addVP8Codec(int payloadType, optional profile = std::nullopt); + void addVP9Codec(int payloadType, optional profile = std::nullopt); + void addAV1Codec(int payloadType, optional profile = std::nullopt); + }; + + bool hasApplication() const; + bool hasAudioOrVideo() const; + bool hasMid(string_view mid) const; + + int addMedia(Media media); + int addMedia(Application application); + int addApplication(string mid = "data"); + int addVideo(string mid = "video", Direction dir = Direction::SendOnly); + int addAudio(string mid = "audio", Direction dir = Direction::SendOnly); + void clearMedia(); + + variant media(unsigned int index); + variant media(unsigned int index) const; + unsigned int mediaCount() const; + + const Application *application() const; + Application *application(); + + static Type stringToType(const string &typeString); + static string typeToString(Type type); + +private: + optional defaultCandidate() const; + shared_ptr createEntry(string mline, string mid, Direction dir); + void removeApplication(); + + Type mType; + + // Session-level attributes + Role mRole; + string mUsername; + string mSessionId; + std::vector mIceOptions; + optional mIceUfrag, mIcePwd; + optional mFingerprint; + std::vector mAttributes; // other attributes + + // Entries + std::vector> mEntries; + shared_ptr mApplication; + + // Candidates + std::vector mCandidates; + bool mEnded = false; +}; + +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, const Description &description); +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, Description::Type type); +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, Description::Role role); +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, const Description::Direction &direction); + +} // namespace rtc + +#endif diff --git a/datachannel/include/rtc/global.hpp b/datachannel/include/rtc/global.hpp new file mode 100644 index 000000000..84317f4dd --- /dev/null +++ b/datachannel/include/rtc/global.hpp @@ -0,0 +1,59 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_GLOBAL_H +#define RTC_GLOBAL_H + +#include "common.hpp" + +#include +#include +#include + +namespace rtc { + +enum class LogLevel { // Don't change, it must match plog severity + None = 0, + Fatal = 1, + Error = 2, + Warning = 3, + Info = 4, + Debug = 5, + Verbose = 6 +}; + +typedef std::function LogCallback; + +RTC_CPP_EXPORT void InitLogger(LogLevel level, LogCallback callback = nullptr); + +RTC_CPP_EXPORT void Preload(); +RTC_CPP_EXPORT std::shared_future Cleanup(); + +struct SctpSettings { + // For the following settings, not set means optimized default + optional recvBufferSize; // in bytes + optional sendBufferSize; // in bytes + optional maxChunksOnQueue; // in chunks + optional initialCongestionWindow; // in MTUs + optional maxBurst; // in MTUs + optional congestionControlModule; // 0: RFC2581, 1: HSTCP, 2: H-TCP, 3: RTCC + optional delayedSackTime; + optional minRetransmitTimeout; + optional maxRetransmitTimeout; + optional initialRetransmitTimeout; + optional maxRetransmitAttempts; + optional heartbeatInterval; +}; + +RTC_CPP_EXPORT void SetSctpSettings(SctpSettings s); + +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, LogLevel level); + +} // namespace rtc + +#endif diff --git a/datachannel/include/rtc/h264rtppacketizer.hpp b/datachannel/include/rtc/h264rtppacketizer.hpp new file mode 100644 index 000000000..9aeb1147c --- /dev/null +++ b/datachannel/include/rtc/h264rtppacketizer.hpp @@ -0,0 +1,58 @@ +/** + * Copyright (c) 2020 Filip Klembara (in2core) + * Copyright (c) 2023 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_H264_RTP_PACKETIZER_H +#define RTC_H264_RTP_PACKETIZER_H + +#if RTC_ENABLE_MEDIA + +#include "nalunit.hpp" +#include "rtppacketizer.hpp" + +namespace rtc { + +/// RTP packetization for H264 +class RTC_CPP_EXPORT H264RtpPacketizer final : public RtpPacketizer { +public: + using Separator = NalUnit::Separator; + + /// Default clock rate for H264 in RTP + inline static const uint32_t defaultClockRate = 90 * 1000; + + /// Constructs h264 payload packetizer with given RTP configuration. + /// @note RTP configuration is used in packetization process which may change some configuration + /// properties such as sequence number. + /// @param separator NAL unit separator + /// @param rtpConfig RTP configuration + /// @param maxFragmentSize maximum size of one NALU fragment + H264RtpPacketizer(Separator separator, shared_ptr rtpConfig, + uint16_t maxFragmentSize = NalUnits::defaultMaximumFragmentSize); + + // For backward compatibility, do not use + [[deprecated]] H264RtpPacketizer( + shared_ptr rtpConfig, + uint16_t maxFragmentSize = NalUnits::defaultMaximumFragmentSize); + + void outgoing(message_vector &messages, const message_callback &send) override; + +private: + shared_ptr splitMessage(binary_ptr message); + + const uint16_t maxFragmentSize; + const Separator separator; +}; + +// For backward compatibility, do not use +using H264PacketizationHandler [[deprecated("Add H264RtpPacketizer directly")]] = PacketizationHandler; + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ + +#endif /* RTC_H264_RTP_PACKETIZER_H */ diff --git a/datachannel/include/rtc/h265nalunit.hpp b/datachannel/include/rtc/h265nalunit.hpp new file mode 100644 index 000000000..b322fc6bd --- /dev/null +++ b/datachannel/include/rtc/h265nalunit.hpp @@ -0,0 +1,186 @@ +/** + * Copyright (c) 2023 Zita Liao (Dolby) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_H265_NAL_UNIT_H +#define RTC_H265_NAL_UNIT_H + +#if RTC_ENABLE_MEDIA + +#include "common.hpp" +#include "nalunit.hpp" + +#include + +namespace rtc { + +#pragma pack(push, 1) + +#define H265_FU_HEADER_SIZE 1 +/// Nalu header +struct RTC_CPP_EXPORT H265NalUnitHeader { + /* + * nal_unit_header( ) { + * forbidden_zero_bit f(1) + * nal_unit_type u(6) + * nuh_layer_id u(6) + * nuh_temporal_id_plus1 u(3) + } + */ + uint8_t _first = 0; // high byte of header + uint8_t _second = 0; // low byte of header + + bool forbiddenBit() const { return _first >> 7; } + uint8_t unitType() const { return (_first & 0b0111'1110) >> 1; } + uint8_t nuhLayerId() const { return ((_first & 0x1) << 5) | ((_second & 0b1111'1000) >> 3); } + uint8_t nuhTempIdPlus1() const { return _second & 0b111; } + + void setForbiddenBit(bool isSet) { _first = (_first & 0x7F) | (isSet << 7); } + void setUnitType(uint8_t type) { _first = (_first & 0b1000'0001) | ((type & 0b11'1111) << 1); } + void setNuhLayerId(uint8_t nuhLayerId) { + _first = (_first & 0b1111'1110) | ((nuhLayerId & 0b10'0000) >> 5); + _second = (_second & 0b0000'0111) | ((nuhLayerId & 0b01'1111) << 3); + } + void setNuhTempIdPlus1(uint8_t nuhTempIdPlus1) { + _second = (_second & 0b1111'1000) | (nuhTempIdPlus1 & 0b111); + } +}; + +/// Nalu fragment header +struct RTC_CPP_EXPORT H265NalUnitFragmentHeader { + /* + * +---------------+ + * |0|1|2|3|4|5|6|7| + * +-+-+-+-+-+-+-+-+ + * |S|E| FuType | + * +---------------+ + */ + uint8_t _first = 0; + + bool isStart() const { return _first >> 7; } + bool isEnd() const { return (_first >> 6) & 0x01; } + uint8_t unitType() const { return _first & 0b11'1111; } + + void setStart(bool isSet) { _first = (_first & 0x7F) | (isSet << 7); } + void setEnd(bool isSet) { _first = (_first & 0b1011'1111) | (isSet << 6); } + void setUnitType(uint8_t type) { _first = (_first & 0b1100'0000) | (type & 0b11'1111); } +}; + +#pragma pack(pop) + +/// Nal unit +struct RTC_CPP_EXPORT H265NalUnit : NalUnit { + H265NalUnit(const H265NalUnit &unit) = default; + H265NalUnit(size_t size, bool includingHeader = true) + : NalUnit(size, includingHeader, NalUnit::Type::H265) {} + H265NalUnit(binary &&data) : NalUnit(std::move(data)) {} + H265NalUnit() : NalUnit(NalUnit::Type::H265) {} + + template + H265NalUnit(Iterator begin_, Iterator end_) : NalUnit(begin_, end_) {} + + bool forbiddenBit() const { return header()->forbiddenBit(); } + uint8_t unitType() const { return header()->unitType(); } + uint8_t nuhLayerId() const { return header()->nuhLayerId(); } + uint8_t nuhTempIdPlus1() const { return header()->nuhTempIdPlus1(); } + + binary payload() const { + assert(size() >= H265_NAL_HEADER_SIZE); + return {begin() + H265_NAL_HEADER_SIZE, end()}; + } + + void setForbiddenBit(bool isSet) { header()->setForbiddenBit(isSet); } + void setUnitType(uint8_t type) { header()->setUnitType(type); } + void setNuhLayerId(uint8_t nuhLayerId) { header()->setNuhLayerId(nuhLayerId); } + void setNuhTempIdPlus1(uint8_t nuhTempIdPlus1) { header()->setNuhTempIdPlus1(nuhTempIdPlus1); } + + void setPayload(binary payload) { + assert(size() >= H265_NAL_HEADER_SIZE); + erase(begin() + H265_NAL_HEADER_SIZE, end()); + insert(end(), payload.begin(), payload.end()); + } + +protected: + const H265NalUnitHeader *header() const { + assert(size() >= H265_NAL_HEADER_SIZE); + return reinterpret_cast(data()); + } + + H265NalUnitHeader *header() { + assert(size() >= H265_NAL_HEADER_SIZE); + return reinterpret_cast(data()); + } +}; + +/// Nal unit fragment A +struct RTC_CPP_EXPORT H265NalUnitFragment : H265NalUnit { + static std::vector> fragmentsFrom(shared_ptr nalu, + uint16_t maxFragmentSize); + + enum class FragmentType { Start, Middle, End }; + + H265NalUnitFragment(FragmentType type, bool forbiddenBit, uint8_t nuhLayerId, + uint8_t nuhTempIdPlus1, uint8_t unitType, binary data); + + uint8_t unitType() const { return fragmentHeader()->unitType(); } + + binary payload() const { + assert(size() >= H265_NAL_HEADER_SIZE + H265_FU_HEADER_SIZE); + return {begin() + H265_NAL_HEADER_SIZE + H265_FU_HEADER_SIZE, end()}; + } + + FragmentType type() const { + if (fragmentHeader()->isStart()) { + return FragmentType::Start; + } else if (fragmentHeader()->isEnd()) { + return FragmentType::End; + } else { + return FragmentType::Middle; + } + } + + void setUnitType(uint8_t type) { fragmentHeader()->setUnitType(type); } + + void setPayload(binary payload) { + assert(size() >= H265_NAL_HEADER_SIZE + H265_FU_HEADER_SIZE); + erase(begin() + H265_NAL_HEADER_SIZE + H265_FU_HEADER_SIZE, end()); + insert(end(), payload.begin(), payload.end()); + } + + void setFragmentType(FragmentType type); + +protected: + const uint8_t nal_type_fu = 49; + + H265NalUnitHeader *fragmentIndicator() { return reinterpret_cast(data()); } + + const H265NalUnitHeader *fragmentIndicator() const { + return reinterpret_cast(data()); + } + + H265NalUnitFragmentHeader *fragmentHeader() { + return reinterpret_cast(data() + H265_NAL_HEADER_SIZE); + } + + const H265NalUnitFragmentHeader *fragmentHeader() const { + return reinterpret_cast(data() + H265_NAL_HEADER_SIZE); + } +}; + +class RTC_CPP_EXPORT H265NalUnits : public std::vector> { +public: + static const uint16_t defaultMaximumFragmentSize = + uint16_t(RTC_DEFAULT_MTU - 12 - 8 - 40); // SRTP/UDP/IPv6 + + std::vector> generateFragments(uint16_t maxFragmentSize); +}; + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ + +#endif /* RTC_NAL_UNIT_H */ diff --git a/datachannel/include/rtc/h265rtppacketizer.hpp b/datachannel/include/rtc/h265rtppacketizer.hpp new file mode 100644 index 000000000..b629c6aa7 --- /dev/null +++ b/datachannel/include/rtc/h265rtppacketizer.hpp @@ -0,0 +1,56 @@ +/** + * Copyright (c) 2023 Zita Liao (Dolby) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_H265_RTP_PACKETIZER_H +#define RTC_H265_RTP_PACKETIZER_H + +#if RTC_ENABLE_MEDIA + +#include "h265nalunit.hpp" +#include "rtppacketizer.hpp" + +namespace rtc { + +// RTP packetization for H265 +class RTC_CPP_EXPORT H265RtpPacketizer final : public RtpPacketizer { +public: + using Separator = NalUnit::Separator; + + // Default clock rate for H265 in RTP + inline static const uint32_t defaultClockRate = 90 * 1000; + + // Constructs h265 payload packetizer with given RTP configuration. + // @note RTP configuration is used in packetization process which may change some configuration + // properties such as sequence number. + // @param separator NAL unit separator + // @param rtpConfig RTP configuration + // @param maxFragmentSize maximum size of one NALU fragment + H265RtpPacketizer(Separator separator, shared_ptr rtpConfig, + uint16_t maxFragmentSize = H265NalUnits::defaultMaximumFragmentSize); + + // for backward compatibility + [[deprecated]] H265RtpPacketizer(shared_ptr rtpConfig, + uint16_t maxFragmentSize = H265NalUnits::defaultMaximumFragmentSize); + + void outgoing(message_vector &messages, const message_callback &send) override; + +private: + shared_ptr splitMessage(binary_ptr message); + + const uint16_t maxFragmentSize; + const NalUnit::Separator separator; +}; + +// For backward compatibility, do not use +using H265PacketizationHandler [[deprecated("Add H265RtpPacketizer directly")]] = PacketizationHandler; + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ + +#endif /* RTC_H265_RTP_PACKETIZER_H */ diff --git a/datachannel/include/rtc/mediahandler.hpp b/datachannel/include/rtc/mediahandler.hpp new file mode 100644 index 000000000..04a676062 --- /dev/null +++ b/datachannel/include/rtc/mediahandler.hpp @@ -0,0 +1,58 @@ +/** + * Copyright (c) 2020 Staz Modrzynski + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_MEDIA_HANDLER_H +#define RTC_MEDIA_HANDLER_H + +#include "common.hpp" +#include "description.hpp" +#include "message.hpp" + +namespace rtc { + +class RTC_CPP_EXPORT MediaHandler : public std::enable_shared_from_this { +public: + MediaHandler(); + virtual ~MediaHandler(); + + /// Called when a media is added or updated + /// @param desc Description of the media + virtual void media([[maybe_unused]] const Description::Media &desc) {} + + /// Called when there is traffic coming from the peer + /// @param messages Incoming messages from the peer, can be modified by the handler + /// @param send Send callback to send messages back to the peer + virtual void incoming([[maybe_unused]] message_vector &messages, [[maybe_unused]] const message_callback &send) {} + + /// Called when there is traffic that needs to be sent to the peer + /// @param messages Outgoing messages to the peer, can be modified by the handler + /// @param send Send callback to send messages back to the peer + virtual void outgoing([[maybe_unused]] message_vector &messages, [[maybe_unused]] const message_callback &send) {} + + virtual bool requestKeyframe(const message_callback &send); + virtual bool requestBitrate(unsigned int bitrate, const message_callback &send); + + void addToChain(shared_ptr handler); + void setNext(shared_ptr handler); + shared_ptr next(); + shared_ptr next() const; + shared_ptr last(); // never null + shared_ptr last() const; // never null + + void mediaChain(const Description::Media &desc); + void incomingChain(message_vector &messages, const message_callback &send); + void outgoingChain(message_vector &messages, const message_callback &send); + +private: + shared_ptr mNext; +}; + +} // namespace rtc + +#endif // RTC_MEDIA_HANDLER_H diff --git a/datachannel/include/rtc/message.hpp b/datachannel/include/rtc/message.hpp new file mode 100644 index 000000000..486210723 --- /dev/null +++ b/datachannel/include/rtc/message.hpp @@ -0,0 +1,79 @@ +/** + * Copyright (c) 2019-2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_MESSAGE_H +#define RTC_MESSAGE_H + +#include "common.hpp" +#include "reliability.hpp" + +#include + +namespace rtc { + +struct RTC_CPP_EXPORT Message : binary { + enum Type { Binary, String, Control, Reset }; + + Message(const Message &message) = default; + Message(size_t size, Type type_ = Binary) : binary(size), type(type_) {} + + template + Message(Iterator begin_, Iterator end_, Type type_ = Binary) + : binary(begin_, end_), type(type_) {} + + Message(binary &&data, Type type_ = Binary) : binary(std::move(data)), type(type_) {} + + Type type; + unsigned int stream = 0; // Stream id (SCTP stream or SSRC) + unsigned int dscp = 0; // Differentiated Services Code Point + shared_ptr reliability; +}; + +using message_ptr = shared_ptr; +using message_callback = std::function; +using message_vector = std::vector; + +inline size_t message_size_func(const message_ptr &m) { + return m->type == Message::Binary || m->type == Message::String ? m->size() : 0; +} + +template +message_ptr make_message(Iterator begin, Iterator end, Message::Type type = Message::Binary, + unsigned int stream = 0, shared_ptr reliability = nullptr) { + auto message = std::make_shared(begin, end, type); + message->stream = stream; + message->reliability = reliability; + return message; +} + +RTC_CPP_EXPORT message_ptr make_message(size_t size, Message::Type type = Message::Binary, + unsigned int stream = 0, + shared_ptr reliability = nullptr); + +RTC_CPP_EXPORT message_ptr make_message(binary &&data, Message::Type type = Message::Binary, + unsigned int stream = 0, + shared_ptr reliability = nullptr); + +RTC_CPP_EXPORT message_ptr make_message(size_t size, message_ptr orig); + +RTC_CPP_EXPORT message_ptr make_message(message_variant data); + +#if RTC_ENABLE_MEDIA + +// Reconstructs a message_ptr from an opaque rtcMessage pointer that +// was allocated by rtcCreateOpaqueMessage(). +message_ptr make_message_from_opaque_ptr(rtcMessage *&&message); + +#endif + +RTC_CPP_EXPORT message_variant to_variant(Message &&message); +RTC_CPP_EXPORT message_variant to_variant(const Message &message); + +} // namespace rtc + +#endif diff --git a/datachannel/include/rtc/nalunit.hpp b/datachannel/include/rtc/nalunit.hpp new file mode 100644 index 000000000..030d8ea1e --- /dev/null +++ b/datachannel/include/rtc/nalunit.hpp @@ -0,0 +1,226 @@ +/** + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_NAL_UNIT_H +#define RTC_NAL_UNIT_H + +#if RTC_ENABLE_MEDIA + +#include "common.hpp" + +#include + +namespace rtc { + +#pragma pack(push, 1) + +/// Nalu header +struct RTC_CPP_EXPORT NalUnitHeader { + uint8_t _first = 0; + + bool forbiddenBit() const { return _first >> 7; } + uint8_t nri() const { return _first >> 5 & 0x03; } + uint8_t unitType() const { return _first & 0x1F; } + + void setForbiddenBit(bool isSet) { _first = (_first & 0x7F) | (isSet << 7); } + void setNRI(uint8_t nri) { _first = (_first & 0x9F) | ((nri & 0x03) << 5); } + void setUnitType(uint8_t type) { _first = (_first & 0xE0) | (type & 0x1F); } +}; + +/// Nalu fragment header +struct RTC_CPP_EXPORT NalUnitFragmentHeader { + uint8_t _first = 0; + + bool isStart() const { return _first >> 7; } + bool reservedBit6() const { return (_first >> 5) & 0x01; } + bool isEnd() const { return (_first >> 6) & 0x01; } + uint8_t unitType() const { return _first & 0x1F; } + + void setStart(bool isSet) { _first = (_first & 0x7F) | (isSet << 7); } + void setEnd(bool isSet) { _first = (_first & 0xBF) | (isSet << 6); } + void setReservedBit6(bool isSet) { _first = (_first & 0xDF) | (isSet << 5); } + void setUnitType(uint8_t type) { _first = (_first & 0xE0) | (type & 0x1F); } +}; + +#pragma pack(pop) + +enum NalUnitStartSequenceMatch { + NUSM_noMatch, + NUSM_firstZero, + NUSM_secondZero, + NUSM_thirdZero, + NUSM_shortMatch, + NUSM_longMatch +}; + +static const size_t H264_NAL_HEADER_SIZE = 1; +static const size_t H265_NAL_HEADER_SIZE = 2; +/// Nal unit +struct RTC_CPP_EXPORT NalUnit : binary { + enum class Type { H264, H265 }; + + NalUnit(const NalUnit &unit) = default; + NalUnit(size_t size, bool includingHeader = true, Type type = Type::H264) + : binary(size + (includingHeader + ? 0 + : (type == Type::H264 ? H264_NAL_HEADER_SIZE : H265_NAL_HEADER_SIZE))) {} + NalUnit(binary &&data) : binary(std::move(data)) {} + NalUnit(Type type = Type::H264) + : binary(type == Type::H264 ? H264_NAL_HEADER_SIZE : H265_NAL_HEADER_SIZE) {} + template NalUnit(Iterator begin_, Iterator end_) : binary(begin_, end_) {} + + bool forbiddenBit() const { return header()->forbiddenBit(); } + uint8_t nri() const { return header()->nri(); } + uint8_t unitType() const { return header()->unitType(); } + + binary payload() const { + assert(size() >= 1); + return {begin() + 1, end()}; + } + + void setForbiddenBit(bool isSet) { header()->setForbiddenBit(isSet); } + void setNRI(uint8_t nri) { header()->setNRI(nri); } + void setUnitType(uint8_t type) { header()->setUnitType(type); } + + void setPayload(binary payload) { + assert(size() >= 1); + erase(begin() + 1, end()); + insert(end(), payload.begin(), payload.end()); + } + + /// NAL unit separator + enum class Separator { + Length = RTC_NAL_SEPARATOR_LENGTH, // first 4 bytes are NAL unit length + LongStartSequence = RTC_NAL_SEPARATOR_LONG_START_SEQUENCE, // 0x00, 0x00, 0x00, 0x01 + ShortStartSequence = RTC_NAL_SEPARATOR_SHORT_START_SEQUENCE, // 0x00, 0x00, 0x01 + StartSequence = RTC_NAL_SEPARATOR_START_SEQUENCE, // LongStartSequence or ShortStartSequence + }; + + static NalUnitStartSequenceMatch StartSequenceMatchSucc(NalUnitStartSequenceMatch match, + std::byte _byte, Separator separator) { + assert(separator != Separator::Length); + auto byte = (uint8_t)_byte; + auto detectShort = + separator == Separator::ShortStartSequence || separator == Separator::StartSequence; + auto detectLong = + separator == Separator::LongStartSequence || separator == Separator::StartSequence; + switch (match) { + case NUSM_noMatch: + if (byte == 0x00) { + return NUSM_firstZero; + } + break; + case NUSM_firstZero: + if (byte == 0x00) { + return NUSM_secondZero; + } + break; + case NUSM_secondZero: + if (byte == 0x00 && detectLong) { + return NUSM_thirdZero; + } else if (byte == 0x00 && detectShort) { + return NUSM_secondZero; + } else if (byte == 0x01 && detectShort) { + return NUSM_shortMatch; + } + break; + case NUSM_thirdZero: + if (byte == 0x00 && detectLong) { + return NUSM_thirdZero; + } else if (byte == 0x01 && detectLong) { + return NUSM_longMatch; + } + break; + case NUSM_shortMatch: + return NUSM_shortMatch; + case NUSM_longMatch: + return NUSM_longMatch; + } + return NUSM_noMatch; + } + +protected: + const NalUnitHeader *header() const { + assert(size() >= 1); + return reinterpret_cast(data()); + } + + NalUnitHeader *header() { + assert(size() >= 1); + return reinterpret_cast(data()); + } +}; + +/// Nal unit fragment A +struct RTC_CPP_EXPORT NalUnitFragmentA : NalUnit { + static std::vector> fragmentsFrom(shared_ptr nalu, + uint16_t maxFragmentSize); + + enum class FragmentType { Start, Middle, End }; + + NalUnitFragmentA(FragmentType type, bool forbiddenBit, uint8_t nri, uint8_t unitType, + binary data); + + uint8_t unitType() const { return fragmentHeader()->unitType(); } + + binary payload() const { + assert(size() >= 2); + return {begin() + 2, end()}; + } + + FragmentType type() const { + if (fragmentHeader()->isStart()) { + return FragmentType::Start; + } else if (fragmentHeader()->isEnd()) { + return FragmentType::End; + } else { + return FragmentType::Middle; + } + } + + void setUnitType(uint8_t type) { fragmentHeader()->setUnitType(type); } + + void setPayload(binary payload) { + assert(size() >= 2); + erase(begin() + 2, end()); + insert(end(), payload.begin(), payload.end()); + } + + void setFragmentType(FragmentType type); + +protected: + const uint8_t nal_type_fu_A = 28; + + NalUnitHeader *fragmentIndicator() { return reinterpret_cast(data()); } + + const NalUnitHeader *fragmentIndicator() const { + return reinterpret_cast(data()); + } + + NalUnitFragmentHeader *fragmentHeader() { + return reinterpret_cast(fragmentIndicator() + 1); + } + + const NalUnitFragmentHeader *fragmentHeader() const { + return reinterpret_cast(fragmentIndicator() + 1); + } +}; + +class RTC_CPP_EXPORT NalUnits : public std::vector> { +public: + static const uint16_t defaultMaximumFragmentSize = + uint16_t(RTC_DEFAULT_MTU - 12 - 8 - 40); // SRTP/UDP/IPv6 + + std::vector> generateFragments(uint16_t maxFragmentSize); +}; + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ + +#endif /* RTC_NAL_UNIT_H */ diff --git a/datachannel/include/rtc/peerconnection.hpp b/datachannel/include/rtc/peerconnection.hpp new file mode 100644 index 000000000..86ea410cd --- /dev/null +++ b/datachannel/include/rtc/peerconnection.hpp @@ -0,0 +1,130 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_PEER_CONNECTION_H +#define RTC_PEER_CONNECTION_H + +#include "candidate.hpp" +#include "common.hpp" +#include "configuration.hpp" +#include "datachannel.hpp" +#include "description.hpp" +#include "reliability.hpp" +#include "track.hpp" + +#include +#include + +namespace rtc { + +namespace impl { + +struct PeerConnection; + +} + +struct RTC_CPP_EXPORT DataChannelInit { + Reliability reliability = {}; + bool negotiated = false; + optional id = nullopt; + string protocol = ""; +}; + +class RTC_CPP_EXPORT PeerConnection final : CheshireCat { +public: + enum class State : int { + New = RTC_NEW, + Connecting = RTC_CONNECTING, + Connected = RTC_CONNECTED, + Disconnected = RTC_DISCONNECTED, + Failed = RTC_FAILED, + Closed = RTC_CLOSED + }; + + enum class IceState : int { + New = RTC_ICE_NEW, + Checking = RTC_ICE_CHECKING, + Connected = RTC_ICE_CONNECTED, + Completed = RTC_ICE_COMPLETED, + Failed = RTC_ICE_FAILED, + Disconnected = RTC_ICE_DISCONNECTED, + Closed = RTC_ICE_CLOSED + }; + + enum class GatheringState : int { + New = RTC_GATHERING_NEW, + InProgress = RTC_GATHERING_INPROGRESS, + Complete = RTC_GATHERING_COMPLETE + }; + + enum class SignalingState : int { + Stable = RTC_SIGNALING_STABLE, + HaveLocalOffer = RTC_SIGNALING_HAVE_LOCAL_OFFER, + HaveRemoteOffer = RTC_SIGNALING_HAVE_REMOTE_OFFER, + HaveLocalPranswer = RTC_SIGNALING_HAVE_LOCAL_PRANSWER, + HaveRemotePranswer = RTC_SIGNALING_HAVE_REMOTE_PRANSWER, + }; + + PeerConnection(); + PeerConnection(Configuration config); + ~PeerConnection(); + + void close(); + + const Configuration *config() const; + State state() const; + IceState iceState() const; + GatheringState gatheringState() const; + SignalingState signalingState() const; + bool hasMedia() const; + optional localDescription() const; + optional remoteDescription() const; + size_t remoteMaxMessageSize() const; + optional localAddress() const; + optional remoteAddress() const; + uint16_t maxDataChannelId() const; + bool getSelectedCandidatePair(Candidate *local, Candidate *remote); + + void setLocalDescription(Description::Type type = Description::Type::Unspec); + void setRemoteDescription(Description description); + void addRemoteCandidate(Candidate candidate); + + void setMediaHandler(shared_ptr handler); + shared_ptr getMediaHandler(); + + [[nodiscard]] shared_ptr createDataChannel(string label, + DataChannelInit init = {}); + void onDataChannel(std::function dataChannel)> callback); + + [[nodiscard]] shared_ptr addTrack(Description::Media description); + void onTrack(std::function track)> callback); + + void onLocalDescription(std::function callback); + void onLocalCandidate(std::function callback); + void onStateChange(std::function callback); + void onIceStateChange(std::function callback); + void onGatheringStateChange(std::function callback); + void onSignalingStateChange(std::function callback); + + void resetCallbacks(); + + // Stats + void clearStats(); + size_t bytesSent(); + size_t bytesReceived(); + optional rtt(); +}; + +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, PeerConnection::State state); +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, PeerConnection::IceState state); +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, PeerConnection::GatheringState state); +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, PeerConnection::SignalingState state); + +} // namespace rtc + +#endif diff --git a/datachannel/include/rtc/plihandler.hpp b/datachannel/include/rtc/plihandler.hpp new file mode 100644 index 000000000..ac149edc8 --- /dev/null +++ b/datachannel/include/rtc/plihandler.hpp @@ -0,0 +1,36 @@ +/** + * Copyright (c) 2023 Arda Cinar + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_PLI_RESPONDER_H +#define RTC_PLI_RESPONDER_H + +#if RTC_ENABLE_MEDIA + +#include "mediahandler.hpp" +#include "utils.hpp" + +namespace rtc { + +/// Responds to PLI and FIR messages sent by the receiver. The sender should respond to these +/// messages by sending an intra. +class RTC_CPP_EXPORT PliHandler final : public MediaHandler { + rtc::synchronized_callback<> mOnPli; + +public: + /// Constructs the PLIResponder object to notify whenever a new intra frame is requested + /// @param onPli The callback that gets called whenever an intra frame is requested by the receiver + PliHandler(std::function onPli); + + void incoming(message_vector &messages, const message_callback &send) override; +}; + +} + +#endif // RTC_ENABLE_MEDIA + +#endif // RTC_PLI_RESPONDER_H diff --git a/datachannel/include/rtc/reliability.hpp b/datachannel/include/rtc/reliability.hpp new file mode 100644 index 000000000..df63b93b7 --- /dev/null +++ b/datachannel/include/rtc/reliability.hpp @@ -0,0 +1,43 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_RELIABILITY_H +#define RTC_RELIABILITY_H + +#include "common.hpp" + +#include + +namespace rtc { + +struct Reliability { + // It true, the channel does not enforce message ordering and out-of-order delivery is allowed + bool unordered = false; + + // If both maxPacketLifeTime or maxRetransmits are unset, the channel is reliable. + // If either maxPacketLifeTime or maxRetransmits is set, the channel is unreliable. + // (The settings are exclusive so both maxPacketLifetime and maxRetransmits must not be set.) + + // Time window during which transmissions and retransmissions may occur + optional maxPacketLifeTime; + + // Maximum number of retransmissions that are attempted + optional maxRetransmits; + + // For backward compatibility, do not use + enum class Type { Reliable = 0, Rexmit, Timed }; + union { + Type typeDeprecated = Type::Reliable; + [[deprecated("Use maxPacketLifeTime or maxRetransmits")]] Type type; + }; + variant rexmit = 0; +}; + +} // namespace rtc + +#endif diff --git a/datachannel/include/rtc/rtc.h b/datachannel/include/rtc/rtc.h new file mode 100644 index 000000000..5a7c214e6 --- /dev/null +++ b/datachannel/include/rtc/rtc.h @@ -0,0 +1,518 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_C_API +#define RTC_C_API + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include + +#ifdef RTC_STATIC +#define RTC_C_EXPORT +#else // dynamic library +#ifdef _WIN32 +#ifdef RTC_EXPORTS +#define RTC_C_EXPORT __declspec(dllexport) // building the library +#else +#define RTC_C_EXPORT __declspec(dllimport) // using the library +#endif +#else // not WIN32 +#define RTC_C_EXPORT +#endif +#endif + +#ifndef RTC_ENABLE_WEBSOCKET +#define RTC_ENABLE_WEBSOCKET 1 +#endif + +#ifndef RTC_ENABLE_MEDIA +#define RTC_ENABLE_MEDIA 1 +#endif + +#define RTC_DEFAULT_MTU 1280 // IPv6 minimum guaranteed MTU + +#if RTC_ENABLE_MEDIA +#define RTC_DEFAULT_MAX_FRAGMENT_SIZE ((uint16_t)(RTC_DEFAULT_MTU - 12 - 8 - 40)) // SRTP/UDP/IPv6 +#define RTC_DEFAULT_MAX_STORED_PACKET_COUNT 512 +// Deprecated, do not use +#define RTC_DEFAULT_MAXIMUM_FRAGMENT_SIZE RTC_DEFAULT_MAX_FRAGMENT_SIZE +#define RTC_DEFAULT_MAXIMUM_PACKET_COUNT_FOR_NACK_CACHE RTC_DEFAULT_MAX_STORED_PACKET_COUNT +#endif + +#ifdef _WIN32 +#ifdef CAPI_STDCALL +#define RTC_API __stdcall +#else +#define RTC_API +#endif +#else // not WIN32 +#define RTC_API +#endif + +#if defined(__GNUC__) || defined(__clang__) +#define RTC_DEPRECATED __attribute__((deprecated)) +#elif defined(_MSC_VER) +#define RTC_DEPRECATED __declspec(deprecated) +#else +#define DEPRECATED +#endif + +// libdatachannel C API + +typedef enum { + RTC_NEW = 0, + RTC_CONNECTING = 1, + RTC_CONNECTED = 2, + RTC_DISCONNECTED = 3, + RTC_FAILED = 4, + RTC_CLOSED = 5 +} rtcState; + +typedef enum { + RTC_ICE_NEW = 0, + RTC_ICE_CHECKING = 1, + RTC_ICE_CONNECTED = 2, + RTC_ICE_COMPLETED = 3, + RTC_ICE_FAILED = 4, + RTC_ICE_DISCONNECTED = 5, + RTC_ICE_CLOSED = 6 +} rtcIceState; + +typedef enum { + RTC_GATHERING_NEW = 0, + RTC_GATHERING_INPROGRESS = 1, + RTC_GATHERING_COMPLETE = 2 +} rtcGatheringState; + +typedef enum { + RTC_SIGNALING_STABLE = 0, + RTC_SIGNALING_HAVE_LOCAL_OFFER = 1, + RTC_SIGNALING_HAVE_REMOTE_OFFER = 2, + RTC_SIGNALING_HAVE_LOCAL_PRANSWER = 3, + RTC_SIGNALING_HAVE_REMOTE_PRANSWER = 4, +} rtcSignalingState; + +typedef enum { // Don't change, it must match plog severity + RTC_LOG_NONE = 0, + RTC_LOG_FATAL = 1, + RTC_LOG_ERROR = 2, + RTC_LOG_WARNING = 3, + RTC_LOG_INFO = 4, + RTC_LOG_DEBUG = 5, + RTC_LOG_VERBOSE = 6 +} rtcLogLevel; + +typedef enum { + RTC_CERTIFICATE_DEFAULT = 0, // ECDSA + RTC_CERTIFICATE_ECDSA = 1, + RTC_CERTIFICATE_RSA = 2, +} rtcCertificateType; + +typedef enum { + // video + RTC_CODEC_H264 = 0, + RTC_CODEC_VP8 = 1, + RTC_CODEC_VP9 = 2, + RTC_CODEC_H265 = 3, + RTC_CODEC_AV1 = 4, + + // audio + RTC_CODEC_OPUS = 128, + RTC_CODEC_PCMU = 129, + RTC_CODEC_PCMA = 130, + RTC_CODEC_AAC = 131, +} rtcCodec; + +typedef enum { + RTC_DIRECTION_UNKNOWN = 0, + RTC_DIRECTION_SENDONLY = 1, + RTC_DIRECTION_RECVONLY = 2, + RTC_DIRECTION_SENDRECV = 3, + RTC_DIRECTION_INACTIVE = 4 +} rtcDirection; + +typedef enum { RTC_TRANSPORT_POLICY_ALL = 0, RTC_TRANSPORT_POLICY_RELAY = 1 } rtcTransportPolicy; + +#define RTC_ERR_SUCCESS 0 +#define RTC_ERR_INVALID -1 // invalid argument +#define RTC_ERR_FAILURE -2 // runtime error +#define RTC_ERR_NOT_AVAIL -3 // element not available +#define RTC_ERR_TOO_SMALL -4 // buffer too small + +typedef void(RTC_API *rtcLogCallbackFunc)(rtcLogLevel level, const char *message); +typedef void(RTC_API *rtcDescriptionCallbackFunc)(int pc, const char *sdp, const char *type, + void *ptr); +typedef void(RTC_API *rtcCandidateCallbackFunc)(int pc, const char *cand, const char *mid, + void *ptr); +typedef void(RTC_API *rtcStateChangeCallbackFunc)(int pc, rtcState state, void *ptr); +typedef void(RTC_API *rtcIceStateChangeCallbackFunc)(int pc, rtcIceState state, void *ptr); +typedef void(RTC_API *rtcGatheringStateCallbackFunc)(int pc, rtcGatheringState state, void *ptr); +typedef void(RTC_API *rtcSignalingStateCallbackFunc)(int pc, rtcSignalingState state, void *ptr); +typedef void(RTC_API *rtcDataChannelCallbackFunc)(int pc, int dc, void *ptr); +typedef void(RTC_API *rtcTrackCallbackFunc)(int pc, int tr, void *ptr); +typedef void(RTC_API *rtcOpenCallbackFunc)(int id, void *ptr); +typedef void(RTC_API *rtcClosedCallbackFunc)(int id, void *ptr); +typedef void(RTC_API *rtcErrorCallbackFunc)(int id, const char *error, void *ptr); +typedef void(RTC_API *rtcMessageCallbackFunc)(int id, const char *message, int size, void *ptr); +typedef void *(RTC_API *rtcInterceptorCallbackFunc)(int pc, const char *message, int size, + void *ptr); +typedef void(RTC_API *rtcBufferedAmountLowCallbackFunc)(int id, void *ptr); +typedef void(RTC_API *rtcAvailableCallbackFunc)(int id, void *ptr); +typedef void(RTC_API *rtcPliHandlerCallbackFunc)(int tr, void *ptr); + +// Log + +// NULL cb on the first call will log to stdout +RTC_C_EXPORT void rtcInitLogger(rtcLogLevel level, rtcLogCallbackFunc cb); + +// User pointer +RTC_C_EXPORT void rtcSetUserPointer(int id, void *ptr); +RTC_C_EXPORT void *rtcGetUserPointer(int i); + +// PeerConnection + +typedef struct { + const char **iceServers; + int iceServersCount; + const char *proxyServer; // libnice only + const char *bindAddress; // libjuice only, NULL means any + rtcCertificateType certificateType; + rtcTransportPolicy iceTransportPolicy; + bool enableIceTcp; // libnice only + bool enableIceUdpMux; // libjuice only + bool disableAutoNegotiation; + bool forceMediaTransport; + uint16_t portRangeBegin; // 0 means automatic + uint16_t portRangeEnd; // 0 means automatic + int mtu; // <= 0 means automatic + int maxMessageSize; // <= 0 means default +} rtcConfiguration; + +RTC_C_EXPORT int rtcCreatePeerConnection(const rtcConfiguration *config); // returns pc id +RTC_C_EXPORT int rtcClosePeerConnection(int pc); +RTC_C_EXPORT int rtcDeletePeerConnection(int pc); + +RTC_C_EXPORT int rtcSetLocalDescriptionCallback(int pc, rtcDescriptionCallbackFunc cb); +RTC_C_EXPORT int rtcSetLocalCandidateCallback(int pc, rtcCandidateCallbackFunc cb); +RTC_C_EXPORT int rtcSetStateChangeCallback(int pc, rtcStateChangeCallbackFunc cb); +RTC_C_EXPORT int rtcSetIceStateChangeCallback(int pc, rtcIceStateChangeCallbackFunc cb); +RTC_C_EXPORT int rtcSetGatheringStateChangeCallback(int pc, rtcGatheringStateCallbackFunc cb); +RTC_C_EXPORT int rtcSetSignalingStateChangeCallback(int pc, rtcSignalingStateCallbackFunc cb); + +RTC_C_EXPORT int rtcSetLocalDescription(int pc, const char *type); +RTC_C_EXPORT int rtcSetRemoteDescription(int pc, const char *sdp, const char *type); +RTC_C_EXPORT int rtcAddRemoteCandidate(int pc, const char *cand, const char *mid); + +RTC_C_EXPORT int rtcGetLocalDescription(int pc, char *buffer, int size); +RTC_C_EXPORT int rtcGetRemoteDescription(int pc, char *buffer, int size); + +RTC_C_EXPORT int rtcGetLocalDescriptionType(int pc, char *buffer, int size); +RTC_C_EXPORT int rtcGetRemoteDescriptionType(int pc, char *buffer, int size); + +RTC_C_EXPORT int rtcGetLocalAddress(int pc, char *buffer, int size); +RTC_C_EXPORT int rtcGetRemoteAddress(int pc, char *buffer, int size); + +RTC_C_EXPORT int rtcGetSelectedCandidatePair(int pc, char *local, int localSize, char *remote, + int remoteSize); + +RTC_C_EXPORT int rtcGetMaxDataChannelStream(int pc); +RTC_C_EXPORT int rtcGetRemoteMaxMessageSize(int pc); + +// DataChannel, Track, and WebSocket common API + +RTC_C_EXPORT int rtcSetOpenCallback(int id, rtcOpenCallbackFunc cb); +RTC_C_EXPORT int rtcSetClosedCallback(int id, rtcClosedCallbackFunc cb); +RTC_C_EXPORT int rtcSetErrorCallback(int id, rtcErrorCallbackFunc cb); +RTC_C_EXPORT int rtcSetMessageCallback(int id, rtcMessageCallbackFunc cb); +RTC_C_EXPORT int rtcSendMessage(int id, const char *data, int size); +RTC_C_EXPORT int rtcClose(int id); +RTC_C_EXPORT int rtcDelete(int id); +RTC_C_EXPORT bool rtcIsOpen(int id); +RTC_C_EXPORT bool rtcIsClosed(int id); + +RTC_C_EXPORT int rtcMaxMessageSize(int id); +RTC_C_EXPORT int rtcGetBufferedAmount(int id); // total size buffered to send +RTC_C_EXPORT int rtcSetBufferedAmountLowThreshold(int id, int amount); +RTC_C_EXPORT int rtcSetBufferedAmountLowCallback(int id, rtcBufferedAmountLowCallbackFunc cb); + +// DataChannel, Track, and WebSocket common extended API + +RTC_C_EXPORT int rtcGetAvailableAmount(int id); // total size available to receive +RTC_C_EXPORT int rtcSetAvailableCallback(int id, rtcAvailableCallbackFunc cb); +RTC_C_EXPORT int rtcReceiveMessage(int id, char *buffer, int *size); + +// DataChannel + +typedef struct { + bool unordered; + bool unreliable; + unsigned int maxPacketLifeTime; // ignored if reliable + unsigned int maxRetransmits; // ignored if reliable +} rtcReliability; + +typedef struct { + rtcReliability reliability; + const char *protocol; // empty string if NULL + bool negotiated; + bool manualStream; + uint16_t stream; // numeric ID 0-65534, ignored if manualStream is false +} rtcDataChannelInit; + +RTC_C_EXPORT int rtcSetDataChannelCallback(int pc, rtcDataChannelCallbackFunc cb); +RTC_C_EXPORT int rtcCreateDataChannel(int pc, const char *label); // returns dc id +RTC_C_EXPORT int rtcCreateDataChannelEx(int pc, const char *label, + const rtcDataChannelInit *init); // returns dc id +RTC_C_EXPORT int rtcDeleteDataChannel(int dc); + +RTC_C_EXPORT int rtcGetDataChannelStream(int dc); +RTC_C_EXPORT int rtcGetDataChannelLabel(int dc, char *buffer, int size); +RTC_C_EXPORT int rtcGetDataChannelProtocol(int dc, char *buffer, int size); +RTC_C_EXPORT int rtcGetDataChannelReliability(int dc, rtcReliability *reliability); + +// Track + +typedef struct { + rtcDirection direction; + rtcCodec codec; + int payloadType; + uint32_t ssrc; + const char *mid; + const char *name; // optional + const char *msid; // optional + const char *trackId; // optional, track ID used in MSID + const char *profile; // optional, codec profile +} rtcTrackInit; + +RTC_C_EXPORT int rtcSetTrackCallback(int pc, rtcTrackCallbackFunc cb); +RTC_C_EXPORT int rtcAddTrack(int pc, const char *mediaDescriptionSdp); // returns tr id +RTC_C_EXPORT int rtcAddTrackEx(int pc, const rtcTrackInit *init); // returns tr id +RTC_C_EXPORT int rtcDeleteTrack(int tr); + +RTC_C_EXPORT int rtcGetTrackDescription(int tr, char *buffer, int size); +RTC_C_EXPORT int rtcGetTrackMid(int tr, char *buffer, int size); +RTC_C_EXPORT int rtcGetTrackDirection(int tr, rtcDirection *direction); + +RTC_C_EXPORT int rtcRequestKeyframe(int tr); +RTC_C_EXPORT int rtcRequestBitrate(int tr, unsigned int bitrate); + +#if RTC_ENABLE_MEDIA + +// Media + +// Define how OBUs are packetizied in a AV1 Sample +typedef enum { + RTC_OBU_PACKETIZED_OBU = 0, + RTC_OBU_PACKETIZED_TEMPORAL_UNIT = 1, +} rtcObuPacketization; + +// Define how NAL units are separated in a H264/H265 sample +typedef enum { + RTC_NAL_SEPARATOR_LENGTH = 0, // first 4 bytes are NAL unit length + RTC_NAL_SEPARATOR_LONG_START_SEQUENCE = 1, // 0x00, 0x00, 0x00, 0x01 + RTC_NAL_SEPARATOR_SHORT_START_SEQUENCE = 2, // 0x00, 0x00, 0x01 + RTC_NAL_SEPARATOR_START_SEQUENCE = 3, // long or short start sequence +} rtcNalUnitSeparator; + +typedef struct { + uint32_t ssrc; + const char *cname; + uint8_t payloadType; + uint32_t clockRate; + uint16_t sequenceNumber; + uint32_t timestamp; + + // H264, H265, AV1 + uint16_t maxFragmentSize; // Maximum fragment size, 0 means default + + // H264/H265 only + rtcNalUnitSeparator nalSeparator; // NAL unit separator + + // AV1 only + rtcObuPacketization obuPacketization; // OBU paketization for AV1 samples + +} rtcPacketizerInit; + +// Deprecated, do not use +typedef rtcPacketizerInit rtcPacketizationHandlerInit; + +typedef struct { + uint32_t ssrc; + const char *name; // optional + const char *msid; // optional + const char *trackId; // optional, track ID used in MSID +} rtcSsrcForTypeInit; + +// Opaque type used (via rtcMessage*) to reference an rtc::Message +typedef void *rtcMessage; + +// Allocate a new opaque message. +// Must be explicitly freed by rtcDeleteOpaqueMessage() unless +// explicitly returned by a media interceptor callback; +RTC_C_EXPORT rtcMessage *rtcCreateOpaqueMessage(void *data, int size); +RTC_C_EXPORT void rtcDeleteOpaqueMessage(rtcMessage *msg); + +// Set MediaInterceptor on peer connection +RTC_C_EXPORT int rtcSetMediaInterceptorCallback(int id, rtcInterceptorCallbackFunc cb); + +// Set a packetizer on track +RTC_C_EXPORT int rtcSetH264Packetizer(int tr, const rtcPacketizerInit *init); +RTC_C_EXPORT int rtcSetH265Packetizer(int tr, const rtcPacketizerInit *init); +RTC_C_EXPORT int rtcSetAV1Packetizer(int tr, const rtcPacketizerInit *init); +RTC_C_EXPORT int rtcSetOpusPacketizer(int tr, const rtcPacketizerInit *init); +RTC_C_EXPORT int rtcSetAACPacketizer(int tr, const rtcPacketizerInit *init); + +// Deprecated, do not use +RTC_DEPRECATED static inline int +rtcSetH264PacketizationHandler(int tr, const rtcPacketizationHandlerInit *init) { + return rtcSetH264Packetizer(tr, init); +} +RTC_DEPRECATED static inline int +rtcSetH265PacketizationHandler(int tr, const rtcPacketizationHandlerInit *init) { + return rtcSetH265Packetizer(tr, init); +} +RTC_DEPRECATED static inline int +rtcSetAV1PacketizationHandler(int tr, const rtcPacketizationHandlerInit *init) { + return rtcSetAV1Packetizer(tr, init); +} +RTC_DEPRECATED static inline int +rtcSetOpusPacketizationHandler(int tr, const rtcPacketizationHandlerInit *init) { + return rtcSetOpusPacketizer(tr, init); +} +RTC_DEPRECATED static inline int +rtcSetAACPacketizationHandler(int tr, const rtcPacketizationHandlerInit *init) { + return rtcSetAACPacketizer(tr, init); +} + +// Chain RtcpReceivingSession on track +RTC_C_EXPORT int rtcChainRtcpReceivingSession(int tr); + +// Chain RtcpSrReporter on track +RTC_C_EXPORT int rtcChainRtcpSrReporter(int tr); + +// Chain RtcpNackResponder on track +RTC_C_EXPORT int rtcChainRtcpNackResponder(int tr, unsigned int maxStoredPacketsCount); + +// Chain PliHandler on track +RTC_C_EXPORT int rtcChainPliHandler(int tr, rtcPliHandlerCallbackFunc cb); + +// Transform seconds to timestamp using track's clock rate, result is written to timestamp +RTC_C_EXPORT int rtcTransformSecondsToTimestamp(int id, double seconds, uint32_t *timestamp); + +// Transform timestamp to seconds using track's clock rate, result is written to seconds +RTC_C_EXPORT int rtcTransformTimestampToSeconds(int id, uint32_t timestamp, double *seconds); + +// Get current timestamp, result is written to timestamp +RTC_C_EXPORT int rtcGetCurrentTrackTimestamp(int id, uint32_t *timestamp); + +// Set RTP timestamp for track identified by given id +RTC_C_EXPORT int rtcSetTrackRtpTimestamp(int id, uint32_t timestamp); + +// Get timestamp of last RTCP SR, result is written to timestamp +RTC_C_EXPORT int rtcGetLastTrackSenderReportTimestamp(int id, uint32_t *timestamp); + +// Set NeedsToReport flag in RtcpSrReporter handler identified by given track id +RTC_C_EXPORT int rtcSetNeedsToSendRtcpSr(int id); + +// Get all available payload types for given codec and stores them in buffer, does nothing if +// buffer is NULL +int rtcGetTrackPayloadTypesForCodec(int tr, const char *ccodec, int *buffer, int size); + +// Get all SSRCs for given track +int rtcGetSsrcsForTrack(int tr, uint32_t *buffer, int count); + +// Get CName for SSRC +int rtcGetCNameForSsrc(int tr, uint32_t ssrc, char *cname, int cnameSize); + +// Get all SSRCs for given media type in given SDP +int rtcGetSsrcsForType(const char *mediaType, const char *sdp, uint32_t *buffer, int bufferSize); + +// Set SSRC for given media type in given SDP +int rtcSetSsrcForType(const char *mediaType, const char *sdp, char *buffer, const int bufferSize, + rtcSsrcForTypeInit *init); + +#endif // RTC_ENABLE_MEDIA + +#if RTC_ENABLE_WEBSOCKET + +// WebSocket + +typedef struct { + bool disableTlsVerification; // if true, don't verify the TLS certificate + const char *proxyServer; // only non-authenticated http supported for now + const char **protocols; + int protocolsCount; + int connectionTimeoutMs; // in milliseconds, 0 means default, < 0 means disabled + int pingIntervalMs; // in milliseconds, 0 means default, < 0 means disabled + int maxOutstandingPings; // 0 means default, < 0 means disabled + int maxMessageSize; // <= 0 means default +} rtcWsConfiguration; + +RTC_C_EXPORT int rtcCreateWebSocket(const char *url); // returns ws id +RTC_C_EXPORT int rtcCreateWebSocketEx(const char *url, const rtcWsConfiguration *config); +RTC_C_EXPORT int rtcDeleteWebSocket(int ws); + +RTC_C_EXPORT int rtcGetWebSocketRemoteAddress(int ws, char *buffer, int size); +RTC_C_EXPORT int rtcGetWebSocketPath(int ws, char *buffer, int size); + +// WebSocketServer + +typedef void(RTC_API *rtcWebSocketClientCallbackFunc)(int wsserver, int ws, void *ptr); + +typedef struct { + uint16_t port; // 0 means automatic selection + bool enableTls; // if true, enable TLS (WSS) + const char *certificatePemFile; // NULL for autogenerated certificate + const char *keyPemFile; // NULL for autogenerated certificate + const char *keyPemPass; // NULL if no pass + const char *bindAddress; // NULL for any + int connectionTimeoutMs; // in milliseconds, 0 means default, < 0 means disabled + int maxMessageSize; // <= 0 means default +} rtcWsServerConfiguration; + +RTC_C_EXPORT int rtcCreateWebSocketServer(const rtcWsServerConfiguration *config, + rtcWebSocketClientCallbackFunc cb); // returns wsserver id +RTC_C_EXPORT int rtcDeleteWebSocketServer(int wsserver); + +RTC_C_EXPORT int rtcGetWebSocketServerPort(int wsserver); + +#endif + +// Optional global preload and cleanup + +RTC_C_EXPORT void rtcPreload(void); +RTC_C_EXPORT void rtcCleanup(void); + +// SCTP global settings + +typedef struct { + int recvBufferSize; // in bytes, <= 0 means optimized default + int sendBufferSize; // in bytes, <= 0 means optimized default + int maxChunksOnQueue; // in chunks, <= 0 means optimized default + int initialCongestionWindow; // in MTUs, <= 0 means optimized default + int maxBurst; // in MTUs, 0 means optimized default, < 0 means disabled + int congestionControlModule; // 0: RFC2581 (default), 1: HSTCP, 2: H-TCP, 3: RTCC + int delayedSackTimeMs; // in milliseconds, 0 means optimized default, < 0 means disabled + int minRetransmitTimeoutMs; // in milliseconds, <= 0 means optimized default + int maxRetransmitTimeoutMs; // in milliseconds, <= 0 means optimized default + int initialRetransmitTimeoutMs; // in milliseconds, <= 0 means optimized default + int maxRetransmitAttempts; // number of retransmissions, <= 0 means optimized default + int heartbeatIntervalMs; // in milliseconds, <= 0 means optimized default +} rtcSctpSettings; + +// Note: SCTP settings apply to newly-created PeerConnections only +RTC_C_EXPORT int rtcSetSctpSettings(const rtcSctpSettings *settings); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/datachannel/include/rtc/rtc.hpp b/datachannel/include/rtc/rtc.hpp new file mode 100644 index 000000000..ed10fa31f --- /dev/null +++ b/datachannel/include/rtc/rtc.hpp @@ -0,0 +1,41 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +// C API +#include "rtc.h" + +// C++ API +#include "common.hpp" +#include "global.hpp" +// +#include "datachannel.hpp" +#include "peerconnection.hpp" +#include "track.hpp" + +#if RTC_ENABLE_WEBSOCKET + +// WebSocket +#include "websocket.hpp" +#include "websocketserver.hpp" + +#endif // RTC_ENABLE_WEBSOCKET + +#if RTC_ENABLE_MEDIA + +// Media +#include "av1rtppacketizer.hpp" +#include "h264rtppacketizer.hpp" +#include "h265rtppacketizer.hpp" +#include "mediahandler.hpp" +#include "plihandler.hpp" +#include "rtcpnackresponder.hpp" +#include "rtcpreceivingsession.hpp" +#include "rtcpsrreporter.hpp" +#include "rtppacketizer.hpp" + +#endif // RTC_ENABLE_MEDIA diff --git a/datachannel/include/rtc/rtcpnackresponder.hpp b/datachannel/include/rtc/rtcpnackresponder.hpp new file mode 100644 index 000000000..71ced1320 --- /dev/null +++ b/datachannel/include/rtc/rtcpnackresponder.hpp @@ -0,0 +1,76 @@ +/** + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_RTCP_NACK_RESPONDER_H +#define RTC_RTCP_NACK_RESPONDER_H + +#if RTC_ENABLE_MEDIA + +#include "mediahandler.hpp" + +#include +#include + +namespace rtc { + +class RTC_CPP_EXPORT RtcpNackResponder final : public MediaHandler { +public: + static const size_t DefaultMaxSize = 512; + + RtcpNackResponder(size_t maxSize = DefaultMaxSize); + + void incoming(message_vector &messages, const message_callback &send) override; + void outgoing(message_vector &messages, const message_callback &send) override; + +private: + // Packet storage + class RTC_CPP_EXPORT Storage { + + /// Packet storage element + struct RTC_CPP_EXPORT Element { + Element(binary_ptr packet, uint16_t sequenceNumber, shared_ptr next = nullptr); + const binary_ptr packet; + const uint16_t sequenceNumber; + /// Pointer to newer element + shared_ptr next = nullptr; + }; + + private: + /// Oldest packet in storage + shared_ptr oldest = nullptr; + /// Newest packet in storage + shared_ptr newest = nullptr; + /// Inner storage + std::unordered_map> storage{}; + std::mutex mutex; + + /// Maximum storage size + const size_t maxSize; + + /// Returns current size + size_t size(); + + public: + Storage(size_t _maxSize); + + /// Returns packet with given sequence number + optional get(uint16_t sequenceNumber); + + /// Stores packet + /// @param packet Packet + void store(binary_ptr packet); + }; + + const shared_ptr mStorage; +}; + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ + +#endif /* RTC_RTCP_NACK_RESPONDER_H */ diff --git a/datachannel/include/rtc/rtcpreceivingsession.hpp b/datachannel/include/rtc/rtcpreceivingsession.hpp new file mode 100644 index 000000000..00d47d51b --- /dev/null +++ b/datachannel/include/rtc/rtcpreceivingsession.hpp @@ -0,0 +1,54 @@ +/** + * Copyright (c) 2020 Staz Modrzynski + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_RTCP_RECEIVING_SESSION_H +#define RTC_RTCP_RECEIVING_SESSION_H + +#if RTC_ENABLE_MEDIA + +#include "common.hpp" +#include "mediahandler.hpp" +#include "message.hpp" +#include "rtp.hpp" + +#include + +namespace rtc { + +// An RtcpSession can be plugged into a Track to handle the whole RTCP session +class RTC_CPP_EXPORT RtcpReceivingSession : public MediaHandler { +public: + RtcpReceivingSession() = default; + virtual ~RtcpReceivingSession() = default; + + void incoming(message_vector &messages, const message_callback &send) override; + bool requestKeyframe(const message_callback &send) override; + bool requestBitrate(unsigned int bitrate, const message_callback &send) override; + + // For backward compatibility + [[deprecated("Use Track::requestKeyframe()")]] inline bool requestKeyframe() { return false; }; + [[deprecated("Use Track::requestBitrate()")]] inline void requestBitrate(unsigned int) {}; + +protected: + void pushREMB(const message_callback &send, unsigned int bitrate); + void pushRR(const message_callback &send,unsigned int lastSrDelay); + void pushPLI(const message_callback &send); + + SSRC mSsrc = 0; + uint32_t mGreatestSeqNo = 0; + uint64_t mSyncRTPTS, mSyncNTPTS; + + std::atomic mRequestedBitrate = 0; +}; + +} // namespace rtc + +#endif // RTC_ENABLE_MEDIA + +#endif // RTC_RTCP_RECEIVING_SESSION_H diff --git a/datachannel/include/rtc/rtcpsrreporter.hpp b/datachannel/include/rtc/rtcpsrreporter.hpp new file mode 100644 index 000000000..1fb10a467 --- /dev/null +++ b/datachannel/include/rtc/rtcpsrreporter.hpp @@ -0,0 +1,46 @@ +/** + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_RTCP_SR_REPORTER_H +#define RTC_RTCP_SR_REPORTER_H + +#if RTC_ENABLE_MEDIA + +#include "mediahandler.hpp" +#include "rtppacketizationconfig.hpp" +#include "rtp.hpp" + +namespace rtc { + +class RTC_CPP_EXPORT RtcpSrReporter final : public MediaHandler { +public: + RtcpSrReporter(shared_ptr rtpConfig); + + uint32_t lastReportedTimestamp() const; + void setNeedsToReport(); + + void outgoing(message_vector &messages, const message_callback &send) override; + + // TODO: remove this + const shared_ptr rtpConfig; + +private: + void addToReport(RtpHeader *rtp, uint32_t rtpSize); + message_ptr getSenderReport(uint32_t timestamp); + + uint32_t mPacketCount = 0; + uint32_t mPayloadOctets = 0; + uint32_t mLastReportedTimestamp = 0; + bool mNeedsToReport = false; +}; + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ + +#endif /* RTC_RTCP_SR_REPORTER_H */ diff --git a/datachannel/include/rtc/rtp.hpp b/datachannel/include/rtc/rtp.hpp new file mode 100644 index 000000000..503a61ff5 --- /dev/null +++ b/datachannel/include/rtc/rtp.hpp @@ -0,0 +1,380 @@ +/** + * Copyright (c) 2020 Staz Modrzynski + * Copyright (c) 2020 Paul-Louis Ageneau + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_RTP_HPP +#define RTC_RTP_HPP + +#include "common.hpp" + +#include + +namespace rtc { + +typedef uint32_t SSRC; + +RTC_CPP_EXPORT bool IsRtcp(const binary &data); + +#pragma pack(push, 1) + +struct RTC_CPP_EXPORT RtpExtensionHeader { + uint16_t _profileSpecificId; + uint16_t _headerLength; + + [[nodiscard]] uint16_t profileSpecificId() const; + [[nodiscard]] uint16_t headerLength() const; + + [[nodiscard]] size_t getSize() const; + [[nodiscard]] const char *getBody() const; + [[nodiscard]] char *getBody(); + + void setProfileSpecificId(uint16_t profileSpecificId); + void setHeaderLength(uint16_t headerLength); + + void clearBody(); + void writeCurrentVideoOrientation(size_t offset, uint8_t id, uint8_t value); + void writeOneByteHeader(size_t offset, uint8_t id, const byte *value, size_t size); +}; + +struct RTC_CPP_EXPORT RtpHeader { + uint8_t _first; + uint8_t _payloadType; + uint16_t _seqNumber; + uint32_t _timestamp; + SSRC _ssrc; + // The following field is SSRC _csrc[] + + [[nodiscard]] uint8_t version() const; + [[nodiscard]] bool padding() const; + [[nodiscard]] bool extension() const; + [[nodiscard]] uint8_t csrcCount() const; + [[nodiscard]] uint8_t marker() const; + [[nodiscard]] uint8_t payloadType() const; + [[nodiscard]] uint16_t seqNumber() const; + [[nodiscard]] uint32_t timestamp() const; + [[nodiscard]] uint32_t ssrc() const; + + [[nodiscard]] size_t getSize() const; + [[nodiscard]] size_t getExtensionHeaderSize() const; + [[nodiscard]] const RtpExtensionHeader *getExtensionHeader() const; + [[nodiscard]] RtpExtensionHeader *getExtensionHeader(); + [[nodiscard]] const char *getBody() const; + [[nodiscard]] char *getBody(); + + void log() const; + + void preparePacket(); + void setSeqNumber(uint16_t newSeqNo); + void setPayloadType(uint8_t newPayloadType); + void setSsrc(uint32_t in_ssrc); + void setMarker(bool marker); + void setTimestamp(uint32_t i); + void setExtension(bool extension); +}; + +struct RTC_CPP_EXPORT RtcpReportBlock { + SSRC _ssrc; + uint32_t _fractionLostAndPacketsLost; // fraction lost is 8-bit, packets lost is 24-bit + uint16_t _seqNoCycles; + uint16_t _highestSeqNo; + uint32_t _jitter; + uint32_t _lastReport; + uint32_t _delaySinceLastReport; + + [[nodiscard]] uint16_t seqNoCycles() const; + [[nodiscard]] uint16_t highestSeqNo() const; + [[nodiscard]] uint32_t extendedHighestSeqNo() const; + [[nodiscard]] uint32_t jitter() const; + [[nodiscard]] uint32_t delaySinceSR() const; + + [[nodiscard]] SSRC getSSRC() const; + [[nodiscard]] uint32_t getNTPOfSR() const; + [[nodiscard]] uint8_t getFractionLost() const; + [[nodiscard]] unsigned int getPacketsLostCount() const; + + void preparePacket(SSRC in_ssrc, unsigned int packetsLost, unsigned int totalPackets, + uint16_t highestSeqNo, uint16_t seqNoCycles, uint32_t jitter, + uint64_t lastSR_NTP, uint64_t lastSR_DELAY); + void setSSRC(SSRC in_ssrc); + void setPacketsLost(uint8_t fractionLost, unsigned int packetsLostCount); + void setSeqNo(uint16_t highestSeqNo, uint16_t seqNoCycles); + void setJitter(uint32_t jitter); + void setNTPOfSR(uint64_t ntp); + void setDelaySinceSR(uint32_t sr); + + void log() const; +}; + +struct RTC_CPP_EXPORT RtcpHeader { + uint8_t _first; + uint8_t _payloadType; + uint16_t _length; + + [[nodiscard]] uint8_t version() const; + [[nodiscard]] bool padding() const; + [[nodiscard]] uint8_t reportCount() const; + [[nodiscard]] uint8_t payloadType() const; + [[nodiscard]] uint16_t length() const; + [[nodiscard]] size_t lengthInBytes() const; + + void prepareHeader(uint8_t payloadType, uint8_t reportCount, uint16_t length); + void setPayloadType(uint8_t type); + void setReportCount(uint8_t count); + void setLength(uint16_t length); + + void log() const; +}; + +struct RTC_CPP_EXPORT RtcpFbHeader { + RtcpHeader header; + + SSRC _packetSender; + SSRC _mediaSource; + + [[nodiscard]] SSRC packetSenderSSRC() const; + [[nodiscard]] SSRC mediaSourceSSRC() const; + + void setPacketSenderSSRC(SSRC ssrc); + void setMediaSourceSSRC(SSRC ssrc); + + void log() const; +}; + +struct RTC_CPP_EXPORT RtcpSr { + RtcpHeader header; + + SSRC _senderSSRC; + uint64_t _ntpTimestamp; + uint32_t _rtpTimestamp; + uint32_t _packetCount; + uint32_t _octetCount; + + RtcpReportBlock _reportBlocks; + + [[nodiscard]] static unsigned int Size(unsigned int reportCount); + + [[nodiscard]] uint64_t ntpTimestamp() const; + [[nodiscard]] uint32_t rtpTimestamp() const; + [[nodiscard]] uint32_t packetCount() const; + [[nodiscard]] uint32_t octetCount() const; + [[nodiscard]] uint32_t senderSSRC() const; + + [[nodiscard]] const RtcpReportBlock *getReportBlock(int num) const; + [[nodiscard]] RtcpReportBlock *getReportBlock(int num); + [[nodiscard]] unsigned int size(unsigned int reportCount); + [[nodiscard]] size_t getSize() const; + + void preparePacket(SSRC senderSSRC, uint8_t reportCount); + void setNtpTimestamp(uint64_t ts); + void setRtpTimestamp(uint32_t ts); + void setOctetCount(uint32_t ts); + void setPacketCount(uint32_t ts); + + void log() const; +}; + +struct RTC_CPP_EXPORT RtcpSdesItem { + uint8_t type; + + uint8_t _length; + char _text[1]; + + [[nodiscard]] static unsigned int Size(uint8_t textLength); + + [[nodiscard]] string text() const; + [[nodiscard]] uint8_t length() const; + + void setText(string text); +}; + +struct RTC_CPP_EXPORT RtcpSdesChunk { + SSRC _ssrc; + RtcpSdesItem _items; + + [[nodiscard]] static unsigned int Size(const std::vector textLengths); + + [[nodiscard]] SSRC ssrc() const; + + void setSSRC(SSRC ssrc); + + // Get item at given index + // All items with index < num must be valid, otherwise this function has undefined behaviour + // (use safelyCountChunkSize() to check if chunk is valid). + [[nodiscard]] const RtcpSdesItem *getItem(int num) const; + [[nodiscard]] RtcpSdesItem *getItem(int num); + + // Get size of chunk + // All items must be valid, otherwise this function has undefined behaviour (use + // safelyCountChunkSize() to check if chunk is valid) + [[nodiscard]] unsigned int getSize() const; + + long safelyCountChunkSize(size_t maxChunkSize) const; +}; + +struct RTC_CPP_EXPORT RtcpSdes { + RtcpHeader header; + RtcpSdesChunk _chunks; + + [[nodiscard]] static unsigned int Size(const std::vector> lengths); + + bool isValid() const; + + // Returns number of chunks in this packet + // Returns 0 if packet is invalid + unsigned int chunksCount() const; + + // Get chunk at given index + // All chunks (and their items) with index < `num` must be valid, otherwise this function has + // undefined behaviour (use `isValid` to check if chunk is valid). + const RtcpSdesChunk *getChunk(int num) const; + RtcpSdesChunk *getChunk(int num); + + void preparePacket(uint8_t chunkCount); +}; + +struct RTC_CPP_EXPORT RtcpRr { + RtcpHeader header; + + SSRC _senderSSRC; + RtcpReportBlock _reportBlocks; + + [[nodiscard]] static size_t SizeWithReportBlocks(uint8_t reportCount); + + SSRC senderSSRC() const; + bool isSenderReport(); + bool isReceiverReport(); + + [[nodiscard]] RtcpReportBlock *getReportBlock(int num); + [[nodiscard]] const RtcpReportBlock *getReportBlock(int num) const; + [[nodiscard]] size_t getSize() const; + + void preparePacket(SSRC senderSSRC, uint8_t reportCount); + void setSenderSSRC(SSRC ssrc); + + void log() const; +}; + +struct RTC_CPP_EXPORT RtcpRemb { + RtcpFbHeader header; + + char _id[4]; // Unique identifier ('R' 'E' 'M' 'B') + uint32_t _bitrate; // Num SSRC, Br Exp, Br Mantissa (bit mask) + SSRC _ssrc[1]; + + [[nodiscard]] static size_t SizeWithSSRCs(int count); + + [[nodiscard]] unsigned int getSize() const; + + void preparePacket(SSRC senderSSRC, unsigned int numSSRC, unsigned int in_bitrate); + void setBitrate(unsigned int numSSRC, unsigned int in_bitrate); + void setSsrc(int iterator, SSRC newSssrc); +}; + +struct RTC_CPP_EXPORT RtcpPli { + RtcpFbHeader header; + + [[nodiscard]] static unsigned int Size(); + + void preparePacket(SSRC messageSSRC); + + void log() const; +}; + +struct RTC_CPP_EXPORT RtcpFirPart { + uint32_t ssrc; + uint8_t seqNo; + uint8_t dummy1; + uint16_t dummy2; +}; + +struct RTC_CPP_EXPORT RtcpFir { + RtcpFbHeader header; + RtcpFirPart parts[1]; + + static unsigned int Size(); + + void preparePacket(SSRC messageSSRC, uint8_t seqNo); + + void log() const; +}; + +struct RTC_CPP_EXPORT RtcpNackPart { + uint16_t _pid; + uint16_t _blp; + + uint16_t pid(); + uint16_t blp(); + + void setPid(uint16_t pid); + void setBlp(uint16_t blp); + + std::vector getSequenceNumbers(); +}; + +struct RTC_CPP_EXPORT RtcpNack { + RtcpFbHeader header; + RtcpNackPart parts[1]; + + [[nodiscard]] static unsigned int Size(unsigned int discreteSeqNoCount); + + [[nodiscard]] unsigned int getSeqNoCount(); + + void preparePacket(SSRC ssrc, unsigned int discreteSeqNoCount); + + /** + * Add a packet to the list of missing packets. + * @param fciCount The number of FCI fields that are present in this packet. + * Let the number start at zero and let this function grow the number. + * @param fciPID The seq no of the active FCI. It will be initialized automatically, and will + * change automatically. + * @param missingPacket The seq no of the missing packet. This will be added to the queue. + * @return true if the packet has grown, false otherwise. + */ + bool addMissingPacket(unsigned int *fciCount, uint16_t *fciPID, uint16_t missingPacket); +}; + +struct RTC_CPP_EXPORT RtpRtx { + RtpHeader header; + + [[nodiscard]] const char *getBody() const; + [[nodiscard]] char *getBody(); + [[nodiscard]] size_t getBodySize(size_t totalSize) const; + [[nodiscard]] size_t getSize() const; + [[nodiscard]] uint16_t getOriginalSeqNo() const; + + // Returns the new size of the packet + size_t normalizePacket(size_t totalSize, SSRC originalSSRC, uint8_t originalPayloadType); + + size_t copyTo(RtpHeader *dest, size_t totalSize, uint8_t originalPayloadType); +}; + +// For backward compatibility, do not use +using RTP_ExtensionHeader [[deprecated]] = RtpExtensionHeader; +using RTP [[deprecated]] = RtpHeader; +using RTCP_ReportBlock [[deprecated]] = RtcpReportBlock; +using RTCP_HEADER [[deprecated]] = RtcpHeader; +using RTCP_FB_HEADER [[deprecated]] = RtcpFbHeader; +using RTCP_SR [[deprecated]] = RtcpSr; +using RTCP_SDES_ITEM [[deprecated]] = RtcpSdesItem; +using RTCP_SDES_CHUNK [[deprecated]] = RtcpSdesChunk; +using RTCP_SDES [[deprecated]] = RtcpSdes; +using RTCP_RR [[deprecated]] = RtcpRr; +using RTCP_REMB [[deprecated]] = RtcpRemb; +using RTCP_PLI [[deprecated]] = RtcpPli; +using RTCP_FIR_PART [[deprecated]] = RtcpFirPart; +using RTCP_FIR [[deprecated]] = RtcpFir; +using RTCP_NACK_PART [[deprecated]] = RtcpNackPart; +using RTCP_NACK [[deprecated]] = RtcpNack; +using RTP_RTX [[deprecated]] = RtpRtx; + +#pragma pack(pop) + +} // namespace rtc + +#endif diff --git a/datachannel/include/rtc/rtppacketizationconfig.hpp b/datachannel/include/rtc/rtppacketizationconfig.hpp new file mode 100644 index 000000000..0e6dcada2 --- /dev/null +++ b/datachannel/include/rtc/rtppacketizationconfig.hpp @@ -0,0 +1,99 @@ +/** + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_RTP_PACKETIZATION_CONFIG_H +#define RTC_RTP_PACKETIZATION_CONFIG_H + +#if RTC_ENABLE_MEDIA + +#include "rtp.hpp" + +namespace rtc { + +// RTP configuration used in packetization process +class RTC_CPP_EXPORT RtpPacketizationConfig { +public: + SSRC ssrc; + std::string cname; + uint8_t payloadType; + uint32_t clockRate; + uint8_t videoOrientationId; + + // current sequence number + uint16_t sequenceNumber; + + // current timestamp + uint32_t timestamp; + + // start timestamp + uint32_t startTimestamp; + + /// Current video orientation + /// + /// Bit# 7 6 5 4 3 2 1 0 + /// Definition 0 0 0 0 C F R1 R0 + /// + /// C + /// 0 - Front-facing camera (use this if unsure) + /// 1 - Back-facing camera + /// + /// F + /// 0 - No Flip + /// 1 - Horizontal flip + /// + /// R1 R0 - CW rotation that receiver must apply + /// 0 - 0 degrees + /// 1 - 90 degrees + /// 2 - 180 degrees + /// 3 - 270 degrees + uint8_t videoOrientation = 0; + + // MID Extension Header + uint8_t midId = 0; + optional mid; + + // RID Extension Header + uint8_t ridId = 0; + optional rid; + + /// Construct RTP configuration used in packetization process + /// @param ssrc SSRC of source + /// @param cname CNAME of source + /// @param payloadType Payload type of source + /// @param clockRate Clock rate of source used in timestamps + /// nullopt) + /// @param videoOrientationId Video orientation (see above) + RtpPacketizationConfig(SSRC ssrc, std::string cname, uint8_t payloadType, uint32_t clockRate, + uint8_t videoOrientationId = 0); + + RtpPacketizationConfig(const RtpPacketizationConfig &) = delete; + + /// Convert timestamp to seconds + /// @param timestamp Timestamp + /// @param clockRate Clock rate for timestamp calculation + static double getSecondsFromTimestamp(uint32_t timestamp, uint32_t clockRate); + + /// Convert timestamp to seconds + /// @param timestamp Timestamp + double timestampToSeconds(uint32_t timestamp); + + /// Convert seconds to timestamp + /// @param seconds Number of seconds + /// @param clockRate Clock rate for timestamp calculation + static uint32_t getTimestampFromSeconds(double seconds, uint32_t clockRate); + + /// Convert seconds to timestamp + /// @param seconds Number of seconds + uint32_t secondsToTimestamp(double seconds); +}; + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ + +#endif /* RTC_RTP_PACKETIZATION_CONFIG_H */ diff --git a/datachannel/include/rtc/rtppacketizer.hpp b/datachannel/include/rtc/rtppacketizer.hpp new file mode 100644 index 000000000..99839b9d4 --- /dev/null +++ b/datachannel/include/rtc/rtppacketizer.hpp @@ -0,0 +1,88 @@ +/** + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_RTP_PACKETIZER_H +#define RTC_RTP_PACKETIZER_H + +#if RTC_ENABLE_MEDIA + +#include "mediahandler.hpp" +#include "message.hpp" +#include "rtppacketizationconfig.hpp" + +namespace rtc { + +/// RTP packetizer +class RTC_CPP_EXPORT RtpPacketizer : public MediaHandler { +public: + /// Constructs packetizer with given RTP configuration + /// @note RTP configuration is used in packetization process which may change some configuration + /// properties such as sequence number. + /// @param rtpConfig RTP configuration + RtpPacketizer(shared_ptr rtpConfig); + virtual ~RtpPacketizer(); + + virtual void media(const Description::Media &desc) override; + virtual void outgoing(message_vector &messages, const message_callback &send) override; + + /// RTP packetization config + const shared_ptr rtpConfig; + +protected: + /// Creates RTP packet for given payload + /// @note This function increase sequence number after packetization. + /// @param payload RTP payload + /// @param setMark Set marker flag in RTP packet if true + virtual message_ptr packetize(shared_ptr payload, bool mark); + +private: + static const auto RtpHeaderSize = 12; + static const auto RtpExtHeaderCvoSize = 8; +}; + +// Generic audio RTP packetizer +template +class RTC_CPP_EXPORT AudioRtpPacketizer final : public RtpPacketizer { +public: + inline static const uint32_t DefaultClockRate = DEFAULT_CLOCK_RATE; + inline static const uint32_t defaultClockRate [[deprecated("Use DefaultClockRate")]] = + DEFAULT_CLOCK_RATE; // for backward compatibility + + AudioRtpPacketizer(shared_ptr rtpConfig) + : RtpPacketizer(std::move(rtpConfig)) {} +}; + +// Audio RTP packetizers +using OpusRtpPacketizer = AudioRtpPacketizer<48000>; +using AACRtpPacketizer = AudioRtpPacketizer<48000>; + +// Dummy wrapper for backward compatibility, do not use +class RTC_CPP_EXPORT PacketizationHandler final : public MediaHandler { +public: + PacketizationHandler(shared_ptr packetizer) + : mPacketizer(std::move(packetizer)) {} + + inline void outgoing(message_vector &messages, const message_callback &send) { + return mPacketizer->outgoing(messages, send); + } + +private: + shared_ptr mPacketizer; +}; + +// Audio packetization handlers for backward compatibility, do not use +using OpusPacketizationHandler [[deprecated("Add OpusRtpPacketizer directly")]] = + PacketizationHandler; +using AACPacketizationHandler [[deprecated("Add AACRtpPacketizer directly")]] = + PacketizationHandler; + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ + +#endif /* RTC_RTP_PACKETIZER_H */ diff --git a/datachannel/include/rtc/track.hpp b/datachannel/include/rtc/track.hpp new file mode 100644 index 000000000..4bd9fde3e --- /dev/null +++ b/datachannel/include/rtc/track.hpp @@ -0,0 +1,61 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_TRACK_H +#define RTC_TRACK_H + +#include "channel.hpp" +#include "common.hpp" +#include "description.hpp" +#include "mediahandler.hpp" + +namespace rtc { + +namespace impl { + +class Track; + +} // namespace impl + +class RTC_CPP_EXPORT Track final : private CheshireCat, public Channel { +public: + Track(impl_ptr impl); + ~Track() override; + + string mid() const; + Description::Direction direction() const; + Description::Media description() const; + + void setDescription(Description::Media description); + + void close(void) override; + bool send(message_variant data) override; + bool send(const byte *data, size_t size) override; + + bool isOpen(void) const override; + bool isClosed(void) const override; + size_t maxMessageSize() const override; + + bool requestKeyframe(); + bool requestBitrate(unsigned int bitrate); + + void setMediaHandler(shared_ptr handler); + void chainMediaHandler(shared_ptr handler); + shared_ptr getMediaHandler(); + + // Deprecated, use setMediaHandler() and getMediaHandler() + inline void setRtcpHandler(shared_ptr handler) { setMediaHandler(handler); } + inline shared_ptr getRtcpHandler() { return getMediaHandler(); } + +private: + using CheshireCat::impl; +}; + +} // namespace rtc + +#endif diff --git a/datachannel/include/rtc/utils.hpp b/datachannel/include/rtc/utils.hpp new file mode 100644 index 000000000..be2773369 --- /dev/null +++ b/datachannel/include/rtc/utils.hpp @@ -0,0 +1,159 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_UTILS_H +#define RTC_UTILS_H + +#include +#include +#include +#include +#include +#include + +namespace rtc { + +// overloaded helper +template struct overloaded : Ts... { using Ts::operator()...; }; +template overloaded(Ts...) -> overloaded; + +// weak_ptr bind helper +template auto weak_bind(F &&f, T *t, Args &&..._args) { + return [bound = std::bind(f, t, _args...), weak_this = t->weak_from_this()](auto &&...args) { + if (auto shared_this = weak_this.lock()) + return bound(args...); + else + return static_cast(false); + }; +} + +// scope_guard helper +class scope_guard final { +public: + scope_guard(std::function func) : function(std::move(func)) {} + scope_guard(scope_guard &&other) = delete; + scope_guard(const scope_guard &) = delete; + void operator=(const scope_guard &) = delete; + + ~scope_guard() { + if (function) + function(); + } + +private: + std::function function; +}; + +// callback with built-in synchronization +template class synchronized_callback { +public: + synchronized_callback() = default; + synchronized_callback(synchronized_callback &&cb) { *this = std::move(cb); } + synchronized_callback(const synchronized_callback &cb) { *this = cb; } + synchronized_callback(std::function func) { *this = std::move(func); } + virtual ~synchronized_callback() { *this = nullptr; } + + synchronized_callback &operator=(synchronized_callback &&cb) { + std::scoped_lock lock(mutex, cb.mutex); + set(std::exchange(cb.callback, nullptr)); + return *this; + } + + synchronized_callback &operator=(const synchronized_callback &cb) { + std::scoped_lock lock(mutex, cb.mutex); + set(cb.callback); + return *this; + } + + synchronized_callback &operator=(std::function func) { + std::lock_guard lock(mutex); + set(std::move(func)); + return *this; + } + + bool operator()(Args... args) const { + std::lock_guard lock(mutex); + return call(std::move(args)...); + } + + operator bool() const { + std::lock_guard lock(mutex); + return callback ? true : false; + } + +protected: + virtual void set(std::function func) { callback = std::move(func); } + virtual bool call(Args... args) const { + if (!callback) + return false; + + callback(std::move(args)...); + return true; + } + + std::function callback; + mutable std::recursive_mutex mutex; +}; + +// callback with built-in synchronization and replay of the last missed call +template +class synchronized_stored_callback final : public synchronized_callback { +public: + template + synchronized_stored_callback(CArgs &&...cargs) + : synchronized_callback(std::forward(cargs)...) {} + ~synchronized_stored_callback() {} + +private: + void set(std::function func) { + synchronized_callback::set(func); + if (func && stored) { + std::apply(func, std::move(*stored)); + stored.reset(); + } + } + + bool call(Args... args) const { + if (!synchronized_callback::call(args...)) + stored.emplace(std::move(args)...); + + return true; + } + + mutable std::optional> stored; +}; + +// pimpl base class +template using impl_ptr = std::shared_ptr; +template class CheshireCat { +public: + CheshireCat(impl_ptr impl) : mImpl(std::move(impl)) {} + template + CheshireCat(Args... args) : mImpl(std::make_shared(std::forward(args)...)) {} + CheshireCat(CheshireCat &&cc) { *this = std::move(cc); } + CheshireCat(const CheshireCat &) = delete; + + virtual ~CheshireCat() = default; + + CheshireCat &operator=(CheshireCat &&cc) { + mImpl = std::move(cc.mImpl); + return *this; + }; + CheshireCat &operator=(const CheshireCat &) = delete; + +protected: + impl_ptr impl() { return mImpl; } + impl_ptr impl() const { return mImpl; } + +private: + impl_ptr mImpl; +}; + +} // namespace rtc + +#endif diff --git a/datachannel/include/rtc/websocket.hpp b/datachannel/include/rtc/websocket.hpp new file mode 100644 index 000000000..0b20faef2 --- /dev/null +++ b/datachannel/include/rtc/websocket.hpp @@ -0,0 +1,67 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_WEBSOCKET_H +#define RTC_WEBSOCKET_H + +#if RTC_ENABLE_WEBSOCKET + +#include "channel.hpp" +#include "common.hpp" +#include "configuration.hpp" + +namespace rtc { + +namespace impl { + +struct WebSocket; + +} + +class RTC_CPP_EXPORT WebSocket final : private CheshireCat, public Channel { +public: + enum class State : int { + Connecting = 0, + Open = 1, + Closing = 2, + Closed = 3, + }; + + using Configuration = WebSocketConfiguration; + + WebSocket(); + WebSocket(Configuration config); + WebSocket(impl_ptr impl); + ~WebSocket() override; + + State readyState() const; + + bool isOpen() const override; + bool isClosed() const override; + size_t maxMessageSize() const override; + + void open(const string &url); + void close() override; + void forceClose(); + bool send(const message_variant data) override; + bool send(const byte *data, size_t size) override; + + optional remoteAddress() const; + optional path() const; + +private: + using CheshireCat::impl; +}; + +std::ostream &operator<<(std::ostream &out, WebSocket::State state); + +} // namespace rtc + +#endif + +#endif // RTC_WEBSOCKET_H diff --git a/datachannel/include/rtc/websocketserver.hpp b/datachannel/include/rtc/websocketserver.hpp new file mode 100644 index 000000000..9bb1c0d25 --- /dev/null +++ b/datachannel/include/rtc/websocketserver.hpp @@ -0,0 +1,48 @@ +/** + * Copyright (c) 2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_WEBSOCKETSERVER_H +#define RTC_WEBSOCKETSERVER_H + +#if RTC_ENABLE_WEBSOCKET + +#include "common.hpp" +#include "configuration.hpp" +#include "websocket.hpp" + +namespace rtc { + +namespace impl { + +struct WebSocketServer; + +} + +class RTC_CPP_EXPORT WebSocketServer final : private CheshireCat { +public: + using Configuration = WebSocketServerConfiguration; + + WebSocketServer(); + WebSocketServer(Configuration config); + ~WebSocketServer(); + + void stop(); + + uint16_t port() const; + + void onClient(std::function)> callback); + +private: + using CheshireCat::impl; +}; + +} // namespace rtc + +#endif + +#endif // RTC_WEBSOCKET_H diff --git a/datachannel/src/av1rtppacketizer.cpp b/datachannel/src/av1rtppacketizer.cpp new file mode 100644 index 000000000..b7cfc55e8 --- /dev/null +++ b/datachannel/src/av1rtppacketizer.cpp @@ -0,0 +1,225 @@ +/** + * Copyright (c) 2023 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#if RTC_ENABLE_MEDIA + +#include "av1rtppacketizer.hpp" + +#include "impl/internals.hpp" + +namespace rtc { + +const auto payloadHeaderSize = 1; + +const auto zMask = byte(0b10000000); +const auto yMask = byte(0b01000000); +const auto nMask = byte(0b00001000); + +const auto wBitshift = 4; + +const auto obuFrameTypeMask = byte(0b01111000); +const auto obuFrameTypeBitshift = 3; + +const auto obuHeaderSize = 1; +const auto obuHasExtensionMask = byte(0b00000100); +const auto obuHasSizeMask = byte(0b00000010); + +const auto obuFrameTypeSequenceHeader = byte(1); + +const auto obuTemporalUnitDelimiter = std::vector{byte(0x12), byte(0x00)}; + +const auto oneByteLeb128Size = 1; + +const uint8_t sevenLsbBitmask = 0b01111111; +const uint8_t msbBitmask = 0b10000000; + +std::vector extractTemporalUnitObus(binary_ptr message) { + std::vector> obus{}; + + if (message->size() <= 2 || (message->at(0) != obuTemporalUnitDelimiter.at(0)) || + (message->at(1) != obuTemporalUnitDelimiter.at(1))) { + return obus; + } + + size_t messageIndex = 2; + while (messageIndex < message->size()) { + if ((message->at(messageIndex) & obuHasSizeMask) == byte(0)) { + return obus; + } + + if ((message->at(messageIndex) & obuHasExtensionMask) != byte(0)) { + messageIndex++; + } + + // https://aomediacodec.github.io/av1-spec/#leb128 + uint32_t obuLength = 0; + uint8_t leb128Size = 0; + while (leb128Size < 8) { + auto leb128Index = messageIndex + leb128Size + obuHeaderSize; + if (message->size() < leb128Index) { + break; + } + + auto leb128_byte = uint8_t(message->at(leb128Index)); + + obuLength |= ((leb128_byte & sevenLsbBitmask) << (leb128Size * 7)); + leb128Size++; + + if (!(leb128_byte & msbBitmask)) { + break; + } + } + + obus.push_back(std::make_shared(message->begin() + messageIndex, + message->begin() + messageIndex + obuHeaderSize + + leb128Size + obuLength)); + + messageIndex += obuHeaderSize + leb128Size + obuLength; + } + + return obus; +} + +/* + * 0 1 2 3 4 5 6 7 + * +-+-+-+-+-+-+-+-+ + * |Z|Y| W |N|-|-|-| + * +-+-+-+-+-+-+-+-+ + * + * Z: MUST be set to 1 if the first OBU element is an + * OBU fragment that is a continuation of an OBU fragment + * from the previous packet, and MUST be set to 0 otherwise. + * + * Y: MUST be set to 1 if the last OBU element is an OBU fragment + * that will continue in the next packet, and MUST be set to 0 otherwise. + * + * W: two bit field that describes the number of OBU elements in the packet. + * This field MUST be set equal to 0 or equal to the number of OBU elements + * contained in the packet. If set to 0, each OBU element MUST be preceded by + * a length field. If not set to 0 (i.e., W = 1, 2 or 3) the last OBU element + * MUST NOT be preceded by a length field. Instead, the length of the last OBU + * element contained in the packet can be calculated as follows: + * Length of the last OBU element = + * length of the RTP payload + * - length of aggregation header + * - length of previous OBU elements including length fields + * + * N: MUST be set to 1 if the packet is the first packet of a coded video sequence, and MUST be set + * to 0 otherwise. + * + * https://aomediacodec.github.io/av1-rtp-spec/#44-av1-aggregation-header + * + **/ + +std::vector AV1RtpPacketizer::packetizeObu(binary_ptr message, + uint16_t maxFragmentSize) { + + std::vector> payloads{}; + size_t messageIndex = 0; + + if (message->size() < 1) { + return payloads; + } + + // Cache sequence header and packetize with next OBU + auto frameType = (message->at(0) & obuFrameTypeMask) >> obuFrameTypeBitshift; + if (frameType == obuFrameTypeSequenceHeader) { + sequenceHeader = std::make_shared(message->begin(), message->end()); + return payloads; + } + + size_t messageRemaining = message->size(); + while (messageRemaining > 0) { + auto obuCount = 1; + auto metadataSize = payloadHeaderSize; + + if (sequenceHeader != nullptr) { + obuCount++; + metadataSize += /* 1 byte leb128 */ 1 + int(sequenceHeader->size()); + } + + auto payload = std::make_shared( + std::min(size_t(maxFragmentSize), messageRemaining + metadataSize)); + auto payloadOffset = payloadHeaderSize; + + payload->at(0) = byte(obuCount) << wBitshift; + + // Packetize cached SequenceHeader + if (obuCount == 2) { + payload->at(0) ^= nMask; + payload->at(1) = byte(sequenceHeader->size() & sevenLsbBitmask); + payloadOffset += oneByteLeb128Size; + + std::memcpy(payload->data() + payloadOffset, sequenceHeader->data(), + sequenceHeader->size()); + payloadOffset += int(sequenceHeader->size()); + + sequenceHeader = nullptr; + } + + // Copy as much of OBU as possible into Payload + auto payloadRemaining = payload->size() - payloadOffset; + std::memcpy(payload->data() + payloadOffset, message->data() + messageIndex, + payloadRemaining); + messageRemaining -= payloadRemaining; + messageIndex += payloadRemaining; + + // Does this Fragment contain an OBU that started in a previous payload + if (payloads.size() > 0) { + payload->at(0) ^= zMask; + } + + // This OBU will be continued in next Payload + if (messageIndex < message->size()) { + payload->at(0) ^= yMask; + } + + payloads.push_back(payload); + } + + return payloads; +} + +AV1RtpPacketizer::AV1RtpPacketizer(AV1RtpPacketizer::Packetization packetization, + shared_ptr rtpConfig, + uint16_t maxFragmentSize) + : RtpPacketizer(rtpConfig), maxFragmentSize(maxFragmentSize), + packetization(packetization) {} + +void AV1RtpPacketizer::outgoing(message_vector &messages, + [[maybe_unused]] const message_callback &send) { + message_vector result; + for (const auto &message : messages) { + std::vector obus; + if (packetization == AV1RtpPacketizer::Packetization::TemporalUnit) { + obus = extractTemporalUnitObus(message); + } else { + obus.push_back(message); + } + + std::vector fragments; + for (auto obu : obus) { + auto p = packetizeObu(obu, maxFragmentSize); + fragments.insert(fragments.end(), p.begin(), p.end()); + } + + if (fragments.size() == 0) + continue; + + for (size_t i = 0; i < fragments.size() - 1; i++) + result.push_back(packetize(fragments[i], false)); + + result.push_back(packetize(fragments[fragments.size() - 1], true)); + } + + messages.swap(result); +} + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ diff --git a/datachannel/src/candidate.cpp b/datachannel/src/candidate.cpp new file mode 100644 index 000000000..13ee90029 --- /dev/null +++ b/datachannel/src/candidate.cpp @@ -0,0 +1,287 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "candidate.hpp" + +#include "impl/internals.hpp" + +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#include +#else +#include +#include +#include +#endif + +#include + +using std::array; +using std::string; + +namespace { + +inline bool match_prefix(const string &str, const string &prefix) { + return str.size() >= prefix.size() && + std::mismatch(prefix.begin(), prefix.end(), str.begin()).first == prefix.end(); +} + +inline void trim_begin(string &str) { + str.erase(str.begin(), + std::find_if(str.begin(), str.end(), [](char c) { return !std::isspace(c); })); +} + +inline void trim_end(string &str) { + str.erase( + std::find_if(str.rbegin(), str.rend(), [](char c) { return !std::isspace(c); }).base(), + str.end()); +} + +} // namespace + +namespace rtc { + +Candidate::Candidate() + : mFoundation("none"), mComponent(0), mPriority(0), mTypeString("unknown"), + mTransportString("unknown"), mType(Type::Unknown), mTransportType(TransportType::Unknown), + mNode("0.0.0.0"), mService("9"), mFamily(Family::Unresolved), mPort(0) {} + +Candidate::Candidate(string candidate) : Candidate() { + if (!candidate.empty()) + parse(std::move(candidate)); +} + +Candidate::Candidate(string candidate, string mid) : Candidate() { + if (!candidate.empty()) + parse(std::move(candidate)); + if (!mid.empty()) + mMid.emplace(std::move(mid)); +} + +void Candidate::parse(string candidate) { + using TypeMap_t = std::unordered_map; + using TcpTypeMap_t = std::unordered_map; + + static const TypeMap_t TypeMap = {{"host", Type::Host}, + {"srflx", Type::ServerReflexive}, + {"prflx", Type::PeerReflexive}, + {"relay", Type::Relayed}}; + + static const TcpTypeMap_t TcpTypeMap = {{"active", TransportType::TcpActive}, + {"passive", TransportType::TcpPassive}, + {"so", TransportType::TcpSo}}; + + const std::array prefixes{"a=", "candidate:"}; + for (string prefix : prefixes) + if (match_prefix(candidate, prefix)) + candidate.erase(0, prefix.size()); + + PLOG_VERBOSE << "Parsing candidate: " << candidate; + + // See RFC 8445 for format + std::istringstream iss(candidate); + string typ_; + if (!(iss >> mFoundation >> mComponent >> mTransportString >> mPriority && + iss >> mNode >> mService >> typ_ >> mTypeString && typ_ == "typ")) + throw std::invalid_argument("Invalid candidate format"); + + std::getline(iss, mTail); + trim_begin(mTail); + trim_end(mTail); + + if (auto it = TypeMap.find(mTypeString); it != TypeMap.end()) + mType = it->second; + else + mType = Type::Unknown; + + if (mTransportString == "UDP" || mTransportString == "udp") { + mTransportType = TransportType::Udp; + } else if (mTransportString == "TCP" || mTransportString == "tcp") { + // Peek tail to find TCP type + std::istringstream tiss(mTail); + string tcptype_, tcptype; + if (tiss >> tcptype_ >> tcptype && tcptype_ == "tcptype") { + if (auto it = TcpTypeMap.find(tcptype); it != TcpTypeMap.end()) + mTransportType = it->second; + else + mTransportType = TransportType::TcpUnknown; + + } else { + mTransportType = TransportType::TcpUnknown; + } + } else { + mTransportType = TransportType::Unknown; + } +} + +void Candidate::hintMid(string mid) { + if (!mMid) + mMid.emplace(std::move(mid)); +} + +void Candidate::changeAddress(string addr) { changeAddress(std::move(addr), mService); } + +void Candidate::changeAddress(string addr, uint16_t port) { + changeAddress(std::move(addr), std::to_string(port)); +} + +void Candidate::changeAddress(string addr, string service) { + mNode = std::move(addr); + mService = std::move(service); + + mFamily = Family::Unresolved; + mAddress.clear(); + mPort = 0; + + if (!resolve(ResolveMode::Simple)) + throw std::invalid_argument("Invalid candidate address \"" + addr + ":" + service + "\""); +} + +bool Candidate::resolve(ResolveMode mode) { + PLOG_VERBOSE << "Resolving candidate (mode=" + << (mode == ResolveMode::Simple ? "simple" : "lookup") << "): " << mNode << ' ' + << mService; + + // Try to resolve the node and service + struct addrinfo hints = {}; + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_ADDRCONFIG; + if (mTransportType == TransportType::Udp) { + hints.ai_socktype = SOCK_DGRAM; + hints.ai_protocol = IPPROTO_UDP; + } else if (mTransportType != TransportType::Unknown) { + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + } + + if (mode == ResolveMode::Simple) + hints.ai_flags |= AI_NUMERICHOST; + + struct addrinfo *result = nullptr; + if (getaddrinfo(mNode.c_str(), mService.c_str(), &hints, &result) == 0) { + for (auto p = result; p; p = p->ai_next) { + if (p->ai_family == AF_INET || p->ai_family == AF_INET6) { + char nodebuffer[MAX_NUMERICNODE_LEN]; + char servbuffer[MAX_NUMERICSERV_LEN]; + if (getnameinfo(p->ai_addr, socklen_t(p->ai_addrlen), nodebuffer, + MAX_NUMERICNODE_LEN, servbuffer, MAX_NUMERICSERV_LEN, + NI_NUMERICHOST | NI_NUMERICSERV) == 0) { + try { + mPort = uint16_t(std::stoul(servbuffer)); + } catch (...) { + return false; + } + mAddress = nodebuffer; + mFamily = p->ai_family == AF_INET6 ? Family::Ipv6 : Family::Ipv4; + PLOG_VERBOSE << "Resolved candidate: " << mAddress << ' ' << mPort; + break; + } + } + } + + freeaddrinfo(result); + } + + return mFamily != Family::Unresolved; +} + +Candidate::Type Candidate::type() const { return mType; } + +Candidate::TransportType Candidate::transportType() const { return mTransportType; } + +uint32_t Candidate::priority() const { return mPriority; } + +string Candidate::candidate() const { + const char sp{' '}; + std::ostringstream oss; + oss << "candidate:"; + oss << mFoundation << sp << mComponent << sp << mTransportString << sp << mPriority << sp; + if (isResolved()) + oss << mAddress << sp << mPort; + else + oss << mNode << sp << mService; + + oss << sp << "typ" << sp << mTypeString; + + if (!mTail.empty()) + oss << sp << mTail; + + return oss.str(); +} + +string Candidate::mid() const { return mMid.value_or("0"); } + +Candidate::operator string() const { + std::ostringstream line; + line << "a=" << candidate(); + return line.str(); +} + +bool Candidate::operator==(const Candidate &other) const { + return (mFoundation == other.mFoundation && mService == other.mService && mNode == other.mNode); +} + +bool Candidate::operator!=(const Candidate &other) const { + return mFoundation != other.mFoundation; +} + +bool Candidate::isResolved() const { return mFamily != Family::Unresolved; } + +Candidate::Family Candidate::family() const { return mFamily; } + +optional Candidate::address() const { + return isResolved() ? std::make_optional(mAddress) : nullopt; +} + +optional Candidate::port() const { + return isResolved() ? std::make_optional(mPort) : nullopt; +} + +std::ostream &operator<<(std::ostream &out, const Candidate &candidate) { + return out << string(candidate); +} + +std::ostream &operator<<(std::ostream &out, const Candidate::Type &type) { + switch (type) { + case Candidate::Type::Host: + return out << "host"; + case Candidate::Type::PeerReflexive: + return out << "prflx"; + case Candidate::Type::ServerReflexive: + return out << "srflx"; + case Candidate::Type::Relayed: + return out << "relay"; + default: + return out << "unknown"; + } +} + +std::ostream &operator<<(std::ostream &out, const Candidate::TransportType &transportType) { + switch (transportType) { + case Candidate::TransportType::Udp: + return out << "UDP"; + case Candidate::TransportType::TcpActive: + return out << "TCP_active"; + case Candidate::TransportType::TcpPassive: + return out << "TCP_passive"; + case Candidate::TransportType::TcpSo: + return out << "TCP_so"; + case Candidate::TransportType::TcpUnknown: + return out << "TCP_unknown"; + default: + return out << "unknown"; + } +} + +} // namespace rtc diff --git a/datachannel/src/capi.cpp b/datachannel/src/capi.cpp new file mode 100644 index 000000000..b6942214a --- /dev/null +++ b/datachannel/src/capi.cpp @@ -0,0 +1,1673 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "rtc.h" +#include "rtc.hpp" + +#include "impl/internals.hpp" + +#include +#include +#include +#include +#include +#include +#include + +using namespace rtc; +using namespace std::chrono_literals; +using std::chrono::milliseconds; + +namespace { + +std::unordered_map> peerConnectionMap; +std::unordered_map> dataChannelMap; +std::unordered_map> trackMap; +#if RTC_ENABLE_MEDIA +std::unordered_map> rtcpSrReporterMap; +std::unordered_map> rtpConfigMap; +#endif +#if RTC_ENABLE_WEBSOCKET +std::unordered_map> webSocketMap; +std::unordered_map> webSocketServerMap; +#endif +std::unordered_map userPointerMap; +std::mutex mutex; +int lastId = 0; + +optional getUserPointer(int id) { + std::lock_guard lock(mutex); + auto it = userPointerMap.find(id); + return it != userPointerMap.end() ? std::make_optional(it->second) : nullopt; +} + +void setUserPointer(int i, void *ptr) { + std::lock_guard lock(mutex); + userPointerMap[i] = ptr; +} + +shared_ptr getPeerConnection(int id) { + std::lock_guard lock(mutex); + if (auto it = peerConnectionMap.find(id); it != peerConnectionMap.end()) + return it->second; + else + throw std::invalid_argument("PeerConnection ID does not exist"); +} + +shared_ptr getDataChannel(int id) { + std::lock_guard lock(mutex); + if (auto it = dataChannelMap.find(id); it != dataChannelMap.end()) + return it->second; + else + throw std::invalid_argument("DataChannel ID does not exist"); +} + +shared_ptr getTrack(int id) { + std::lock_guard lock(mutex); + if (auto it = trackMap.find(id); it != trackMap.end()) + return it->second; + else + throw std::invalid_argument("Track ID does not exist"); +} + +int emplacePeerConnection(shared_ptr ptr) { + std::lock_guard lock(mutex); + int pc = ++lastId; + peerConnectionMap.emplace(std::make_pair(pc, ptr)); + userPointerMap.emplace(std::make_pair(pc, nullptr)); + return pc; +} + +int emplaceDataChannel(shared_ptr ptr) { + std::lock_guard lock(mutex); + int dc = ++lastId; + dataChannelMap.emplace(std::make_pair(dc, ptr)); + userPointerMap.emplace(std::make_pair(dc, nullptr)); + return dc; +} + +int emplaceTrack(shared_ptr ptr) { + std::lock_guard lock(mutex); + int tr = ++lastId; + trackMap.emplace(std::make_pair(tr, ptr)); + userPointerMap.emplace(std::make_pair(tr, nullptr)); + return tr; +} + +void erasePeerConnection(int pc) { + std::lock_guard lock(mutex); + if (peerConnectionMap.erase(pc) == 0) + throw std::invalid_argument("Peer Connection ID does not exist"); + userPointerMap.erase(pc); +} + +void eraseDataChannel(int dc) { + std::lock_guard lock(mutex); + if (dataChannelMap.erase(dc) == 0) + throw std::invalid_argument("Data Channel ID does not exist"); + userPointerMap.erase(dc); +} + +void eraseTrack(int tr) { + std::lock_guard lock(mutex); + if (trackMap.erase(tr) == 0) + throw std::invalid_argument("Track ID does not exist"); +#if RTC_ENABLE_MEDIA + rtcpSrReporterMap.erase(tr); + rtpConfigMap.erase(tr); +#endif + userPointerMap.erase(tr); +} + +size_t eraseAll() { + std::lock_guard lock(mutex); + size_t count = dataChannelMap.size() + trackMap.size() + peerConnectionMap.size(); + dataChannelMap.clear(); + trackMap.clear(); + peerConnectionMap.clear(); +#if RTC_ENABLE_MEDIA + count += rtcpSrReporterMap.size() + rtpConfigMap.size(); + rtcpSrReporterMap.clear(); + rtpConfigMap.clear(); +#endif +#if RTC_ENABLE_WEBSOCKET + count += webSocketMap.size() + webSocketServerMap.size(); + webSocketMap.clear(); + webSocketServerMap.clear(); +#endif + userPointerMap.clear(); + return count; +} + +shared_ptr getChannel(int id) { + std::lock_guard lock(mutex); + if (auto it = dataChannelMap.find(id); it != dataChannelMap.end()) + return it->second; + if (auto it = trackMap.find(id); it != trackMap.end()) + return it->second; +#if RTC_ENABLE_WEBSOCKET + if (auto it = webSocketMap.find(id); it != webSocketMap.end()) + return it->second; +#endif + throw std::invalid_argument("DataChannel, Track, or WebSocket ID does not exist"); +} + +void eraseChannel(int id) { + std::lock_guard lock(mutex); + if (dataChannelMap.erase(id) != 0) { + userPointerMap.erase(id); + return; + } + if (trackMap.erase(id) != 0) { + userPointerMap.erase(id); +#if RTC_ENABLE_MEDIA + rtcpSrReporterMap.erase(id); + rtpConfigMap.erase(id); +#endif + return; + } +#if RTC_ENABLE_WEBSOCKET + if (webSocketMap.erase(id) != 0) { + userPointerMap.erase(id); + return; + } +#endif + throw std::invalid_argument("DataChannel, Track, or WebSocket ID does not exist"); +} + +int copyAndReturn(string s, char *buffer, int size) { + if (!buffer) + return int(s.size() + 1); + + if (size < int(s.size() + 1)) + return RTC_ERR_TOO_SMALL; + + std::copy(s.begin(), s.end(), buffer); + buffer[s.size()] = '\0'; + return int(s.size() + 1); +} + +int copyAndReturn(binary b, char *buffer, int size) { + if (!buffer) + return int(b.size()); + + if (size < int(b.size())) + return RTC_ERR_TOO_SMALL; + + auto data = reinterpret_cast(b.data()); + std::copy(data, data + b.size(), buffer); + return int(b.size()); +} + +template int copyAndReturn(std::vector b, T *buffer, int size) { + if (!buffer) + return int(b.size()); + + if (size < int(b.size())) + return RTC_ERR_TOO_SMALL; + std::copy(b.begin(), b.end(), buffer); + return int(b.size()); +} + +template int wrap(F func) { + try { + return int(func()); + + } catch (const std::invalid_argument &e) { + PLOG_ERROR << e.what(); + return RTC_ERR_INVALID; + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + return RTC_ERR_FAILURE; + } +} + +#if RTC_ENABLE_MEDIA + +string lowercased(string str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return str; +} + +shared_ptr getRtcpSrReporter(int id) { + std::lock_guard lock(mutex); + if (auto it = rtcpSrReporterMap.find(id); it != rtcpSrReporterMap.end()) { + return it->second; + } else { + throw std::invalid_argument("RTCP SR reporter ID does not exist"); + } +} + +void emplaceRtcpSrReporter(shared_ptr ptr, int tr) { + std::lock_guard lock(mutex); + rtcpSrReporterMap.emplace(std::make_pair(tr, ptr)); +} + +shared_ptr getRtpConfig(int id) { + std::lock_guard lock(mutex); + if (auto it = rtpConfigMap.find(id); it != rtpConfigMap.end()) { + return it->second; + } else { + throw std::invalid_argument("RTP configuration ID does not exist"); + } +} + +void emplaceRtpConfig(shared_ptr ptr, int tr) { + std::lock_guard lock(mutex); + rtpConfigMap.emplace(std::make_pair(tr, ptr)); +} + +shared_ptr +createRtpPacketizationConfig(const rtcPacketizationHandlerInit *init) { + if (!init) + throw std::invalid_argument("Unexpected null pointer for packetization handler init"); + + if (!init->cname) + throw std::invalid_argument("Unexpected null pointer for cname"); + + auto config = std::make_shared(init->ssrc, init->cname, + init->payloadType, init->clockRate); + config->sequenceNumber = init->sequenceNumber; + config->timestamp = init->timestamp; + return config; +} + +class MediaInterceptor final : public MediaHandler { +public: + using MessageCallback = std::function; + + MediaInterceptor(MessageCallback cb) : incomingCallback(cb) {} + + // Called when there is traffic coming from the peer + void incoming(message_vector &messages, + [[maybe_unused]] const message_callback &send) override { + // If no callback is provided, just forward the message on + if (!incomingCallback) + return; + + message_vector result; + for (auto &msg : messages) { + auto res = incomingCallback(reinterpret_cast(msg->data()), int(msg->size())); + + // If a null pointer was returned, drop the incoming message + if (!res) + continue; + + if (res == msg->data()) { + // If the original data pointer was returned, forward the incoming message + result.push_back(std::move(msg)); + } else { + // else construct a true message_ptr from the returned opaque pointer + result.push_back( + make_message_from_opaque_ptr(std::move(reinterpret_cast(res)))); + } + } + } + +private: + MessageCallback incomingCallback; +}; + +#endif // RTC_ENABLE_MEDIA + +#if RTC_ENABLE_WEBSOCKET + +shared_ptr getWebSocket(int id) { + std::lock_guard lock(mutex); + if (auto it = webSocketMap.find(id); it != webSocketMap.end()) + return it->second; + else + throw std::invalid_argument("WebSocket ID does not exist"); +} + +int emplaceWebSocket(shared_ptr ptr) { + std::lock_guard lock(mutex); + int ws = ++lastId; + webSocketMap.emplace(std::make_pair(ws, ptr)); + userPointerMap.emplace(std::make_pair(ws, nullptr)); + return ws; +} + +void eraseWebSocket(int ws) { + std::lock_guard lock(mutex); + if (webSocketMap.erase(ws) == 0) + throw std::invalid_argument("WebSocket ID does not exist"); + userPointerMap.erase(ws); +} + +shared_ptr getWebSocketServer(int id) { + std::lock_guard lock(mutex); + if (auto it = webSocketServerMap.find(id); it != webSocketServerMap.end()) + return it->second; + else + throw std::invalid_argument("WebSocketServer ID does not exist"); +} + +int emplaceWebSocketServer(shared_ptr ptr) { + std::lock_guard lock(mutex); + int wsserver = ++lastId; + webSocketServerMap.emplace(std::make_pair(wsserver, ptr)); + userPointerMap.emplace(std::make_pair(wsserver, nullptr)); + return wsserver; +} + +void eraseWebSocketServer(int wsserver) { + std::lock_guard lock(mutex); + if (webSocketServerMap.erase(wsserver) == 0) + throw std::invalid_argument("WebSocketServer ID does not exist"); + userPointerMap.erase(wsserver); +} + +#endif + +} // namespace + +void rtcInitLogger(rtcLogLevel level, rtcLogCallbackFunc cb) { + LogCallback callback = nullptr; + if (cb) + callback = [cb](LogLevel level, string message) { + cb(static_cast(level), message.c_str()); + }; + + InitLogger(static_cast(level), callback); +} + +void rtcSetUserPointer(int i, void *ptr) { setUserPointer(i, ptr); } + +void *rtcGetUserPointer(int i) { return getUserPointer(i).value_or(nullptr); } + +int rtcCreatePeerConnection(const rtcConfiguration *config) { + return wrap([config] { + Configuration c; + for (int i = 0; i < config->iceServersCount; ++i) + c.iceServers.emplace_back(string(config->iceServers[i])); + + if (config->proxyServer) + c.proxyServer.emplace(config->proxyServer); + + if (config->bindAddress) + c.bindAddress = string(config->bindAddress); + + if (config->portRangeBegin > 0 || config->portRangeEnd > 0) { + c.portRangeBegin = config->portRangeBegin; + c.portRangeEnd = config->portRangeEnd; + } + + c.certificateType = static_cast(config->certificateType); + c.iceTransportPolicy = static_cast(config->iceTransportPolicy); + c.enableIceTcp = config->enableIceTcp; + c.enableIceUdpMux = config->enableIceUdpMux; + c.disableAutoNegotiation = config->disableAutoNegotiation; + c.forceMediaTransport = config->forceMediaTransport; + + if (config->mtu > 0) + c.mtu = size_t(config->mtu); + + if (config->maxMessageSize) + c.maxMessageSize = size_t(config->maxMessageSize); + + return emplacePeerConnection(std::make_shared(std::move(c))); + }); +} + +int rtcClosePeerConnection(int pc) { + return wrap([pc] { + auto peerConnection = getPeerConnection(pc); + peerConnection->close(); + return RTC_ERR_SUCCESS; + }); +} + +int rtcDeletePeerConnection(int pc) { + return wrap([pc] { + auto peerConnection = getPeerConnection(pc); + peerConnection->close(); + erasePeerConnection(pc); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetLocalDescriptionCallback(int pc, rtcDescriptionCallbackFunc cb) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + if (cb) + peerConnection->onLocalDescription([pc, cb](Description desc) { + if (auto ptr = getUserPointer(pc)) + cb(pc, string(desc).c_str(), desc.typeString().c_str(), *ptr); + }); + else + peerConnection->onLocalDescription(nullptr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetLocalCandidateCallback(int pc, rtcCandidateCallbackFunc cb) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + if (cb) + peerConnection->onLocalCandidate([pc, cb](Candidate cand) { + if (auto ptr = getUserPointer(pc)) + cb(pc, cand.candidate().c_str(), cand.mid().c_str(), *ptr); + }); + else + peerConnection->onLocalCandidate(nullptr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetStateChangeCallback(int pc, rtcStateChangeCallbackFunc cb) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + if (cb) + peerConnection->onStateChange([pc, cb](PeerConnection::State state) { + if (auto ptr = getUserPointer(pc)) + cb(pc, static_cast(state), *ptr); + }); + else + peerConnection->onStateChange(nullptr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetIceStateChangeCallback(int pc, rtcIceStateChangeCallbackFunc cb) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + if (cb) + peerConnection->onIceStateChange([pc, cb](PeerConnection::IceState state) { + if (auto ptr = getUserPointer(pc)) + cb(pc, static_cast(state), *ptr); + }); + else + peerConnection->onIceStateChange(nullptr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetGatheringStateChangeCallback(int pc, rtcGatheringStateCallbackFunc cb) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + if (cb) + peerConnection->onGatheringStateChange([pc, cb](PeerConnection::GatheringState state) { + if (auto ptr = getUserPointer(pc)) + cb(pc, static_cast(state), *ptr); + }); + else + peerConnection->onGatheringStateChange(nullptr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetSignalingStateChangeCallback(int pc, rtcSignalingStateCallbackFunc cb) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + if (cb) + peerConnection->onSignalingStateChange([pc, cb](PeerConnection::SignalingState state) { + if (auto ptr = getUserPointer(pc)) + cb(pc, static_cast(state), *ptr); + }); + else + peerConnection->onSignalingStateChange(nullptr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetDataChannelCallback(int pc, rtcDataChannelCallbackFunc cb) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + if (cb) + peerConnection->onDataChannel([pc, cb](shared_ptr dataChannel) { + int dc = emplaceDataChannel(dataChannel); + if (auto ptr = getUserPointer(pc)) { + rtcSetUserPointer(dc, *ptr); + cb(pc, dc, *ptr); + } + }); + else + peerConnection->onDataChannel(nullptr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetTrackCallback(int pc, rtcTrackCallbackFunc cb) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + if (cb) + peerConnection->onTrack([pc, cb](shared_ptr track) { + int tr = emplaceTrack(track); + if (auto ptr = getUserPointer(pc)) { + rtcSetUserPointer(tr, *ptr); + cb(pc, tr, *ptr); + } + }); + else + peerConnection->onTrack(nullptr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetLocalDescription(int pc, const char *type) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + peerConnection->setLocalDescription(type ? Description::stringToType(type) + : Description::Type::Unspec); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetRemoteDescription(int pc, const char *sdp, const char *type) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + + if (!sdp) + throw std::invalid_argument("Unexpected null pointer for remote description"); + + peerConnection->setRemoteDescription({string(sdp), type ? string(type) : ""}); + return RTC_ERR_SUCCESS; + }); +} + +int rtcAddRemoteCandidate(int pc, const char *cand, const char *mid) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + + if (!cand) + throw std::invalid_argument("Unexpected null pointer for remote candidate"); + + peerConnection->addRemoteCandidate({string(cand), mid ? string(mid) : ""}); + return RTC_ERR_SUCCESS; + }); +} + +int rtcGetLocalDescription(int pc, char *buffer, int size) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + + if (auto desc = peerConnection->localDescription()) + return copyAndReturn(string(*desc), buffer, size); + else + return RTC_ERR_NOT_AVAIL; + }); +} + +int rtcGetRemoteDescription(int pc, char *buffer, int size) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + + if (auto desc = peerConnection->remoteDescription()) + return copyAndReturn(string(*desc), buffer, size); + else + return RTC_ERR_NOT_AVAIL; + }); +} + +int rtcGetLocalDescriptionType(int pc, char *buffer, int size) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + + if (auto desc = peerConnection->localDescription()) + return copyAndReturn(desc->typeString(), buffer, size); + else + return RTC_ERR_NOT_AVAIL; + }); +} + +int rtcGetRemoteDescriptionType(int pc, char *buffer, int size) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + + if (auto desc = peerConnection->remoteDescription()) + return copyAndReturn(desc->typeString(), buffer, size); + else + return RTC_ERR_NOT_AVAIL; + }); +} + +int rtcGetLocalAddress(int pc, char *buffer, int size) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + + if (auto addr = peerConnection->localAddress()) + return copyAndReturn(std::move(*addr), buffer, size); + else + return RTC_ERR_NOT_AVAIL; + }); +} + +int rtcGetRemoteAddress(int pc, char *buffer, int size) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + + if (auto addr = peerConnection->remoteAddress()) + return copyAndReturn(std::move(*addr), buffer, size); + else + return RTC_ERR_NOT_AVAIL; + }); +} + +int rtcGetSelectedCandidatePair(int pc, char *local, int localSize, char *remote, int remoteSize) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + + Candidate localCand; + Candidate remoteCand; + if (!peerConnection->getSelectedCandidatePair(&localCand, &remoteCand)) + return RTC_ERR_NOT_AVAIL; + + int localRet = copyAndReturn(string(localCand), local, localSize); + if (localRet < 0) + return localRet; + + int remoteRet = copyAndReturn(string(remoteCand), remote, remoteSize); + if (remoteRet < 0) + return remoteRet; + + return std::max(localRet, remoteRet); + }); +} + +int rtcGetMaxDataChannelStream(int pc) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + return int(peerConnection->maxDataChannelId()); + }); +} + +int rtcGetRemoteMaxMessageSize(int pc) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + return int(peerConnection->remoteMaxMessageSize()); + }); +} + +int rtcSetOpenCallback(int id, rtcOpenCallbackFunc cb) { + return wrap([&] { + auto channel = getChannel(id); + if (cb) + channel->onOpen([id, cb]() { + if (auto ptr = getUserPointer(id)) + cb(id, *ptr); + }); + else + channel->onOpen(nullptr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetClosedCallback(int id, rtcClosedCallbackFunc cb) { + return wrap([&] { + auto channel = getChannel(id); + if (cb) + channel->onClosed([id, cb]() { + if (auto ptr = getUserPointer(id)) + cb(id, *ptr); + }); + else + channel->onClosed(nullptr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetErrorCallback(int id, rtcErrorCallbackFunc cb) { + return wrap([&] { + auto channel = getChannel(id); + if (cb) + channel->onError([id, cb](string error) { + if (auto ptr = getUserPointer(id)) + cb(id, error.c_str(), *ptr); + }); + else + channel->onError(nullptr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetMessageCallback(int id, rtcMessageCallbackFunc cb) { + return wrap([&] { + auto channel = getChannel(id); + if (cb) + channel->onMessage( + [id, cb](binary b) { + if (auto ptr = getUserPointer(id)) + cb(id, reinterpret_cast(b.data()), int(b.size()), *ptr); + }, + [id, cb](string s) { + if (auto ptr = getUserPointer(id)) + cb(id, s.c_str(), -int(s.size() + 1), *ptr); + }); + else + channel->onMessage(nullptr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSendMessage(int id, const char *data, int size) { + return wrap([&] { + auto channel = getChannel(id); + + if (!data && size != 0) + throw std::invalid_argument("Unexpected null pointer for data"); + + if (size >= 0) { + auto b = reinterpret_cast(data); + channel->send(binary(b, b + size)); + } else { + channel->send(string(data)); + } + return RTC_ERR_SUCCESS; + }); +} + +int rtcClose(int id) { + return wrap([&] { + auto channel = getChannel(id); + channel->close(); + return RTC_ERR_SUCCESS; + }); +} + +int rtcDelete(int id) { + return wrap([id] { + auto channel = getChannel(id); + channel->close(); + eraseChannel(id); + return RTC_ERR_SUCCESS; + }); +} + +bool rtcIsOpen(int id) { + return wrap([id] { return getChannel(id)->isOpen() ? 0 : 1; }) == 0 ? true : false; +} + +bool rtcIsClosed(int id) { + return wrap([id] { return getChannel(id)->isClosed() ? 0 : 1; }) == 0 ? true : false; +} + +int rtcMaxMessageSize(int id) { + return wrap([id] { + auto channel = getChannel(id); + return int(channel->maxMessageSize()); + }); +} + +int rtcGetBufferedAmount(int id) { + return wrap([id] { + auto channel = getChannel(id); + return int(channel->bufferedAmount()); + }); +} + +int rtcSetBufferedAmountLowThreshold(int id, int amount) { + return wrap([&] { + auto channel = getChannel(id); + channel->setBufferedAmountLowThreshold(size_t(amount)); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetBufferedAmountLowCallback(int id, rtcBufferedAmountLowCallbackFunc cb) { + return wrap([&] { + auto channel = getChannel(id); + if (cb) + channel->onBufferedAmountLow([id, cb]() { + if (auto ptr = getUserPointer(id)) + cb(id, *ptr); + }); + else + channel->onBufferedAmountLow(nullptr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcGetAvailableAmount(int id) { + return wrap([id] { return int(getChannel(id)->availableAmount()); }); +} + +int rtcSetAvailableCallback(int id, rtcAvailableCallbackFunc cb) { + return wrap([&] { + auto channel = getChannel(id); + if (cb) + channel->onAvailable([id, cb]() { + if (auto ptr = getUserPointer(id)) + cb(id, *ptr); + }); + else + channel->onAvailable(nullptr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcReceiveMessage(int id, char *buffer, int *size) { + return wrap([&] { + auto channel = getChannel(id); + + if (!size) + throw std::invalid_argument("Unexpected null pointer for size"); + + *size = std::abs(*size); + + auto message = channel->peek(); + if (!message) + return RTC_ERR_NOT_AVAIL; + + return std::visit( // + overloaded{ + [&](binary b) { + int ret = copyAndReturn(std::move(b), buffer, *size); + if (ret >= 0) { + *size = ret; + if (buffer) { + channel->receive(); // discard + } + + return RTC_ERR_SUCCESS; + } else { + *size = int(b.size()); + return ret; + } + }, + [&](string s) { + int ret = copyAndReturn(std::move(s), buffer, *size); + if (ret >= 0) { + *size = -ret; + if (buffer) { + channel->receive(); // discard + } + + return RTC_ERR_SUCCESS; + } else { + *size = -int(s.size() + 1); + return ret; + } + }, + }, + *message); + }); +} + +int rtcCreateDataChannel(int pc, const char *label) { + return rtcCreateDataChannelEx(pc, label, nullptr); +} + +int rtcCreateDataChannelEx(int pc, const char *label, const rtcDataChannelInit *init) { + return wrap([&] { + DataChannelInit dci = {}; + if (init) { + auto *reliability = &init->reliability; + dci.reliability.unordered = reliability->unordered; + if (reliability->unreliable) { + if (reliability->maxPacketLifeTime > 0) + dci.reliability.maxPacketLifeTime.emplace(milliseconds(reliability->maxPacketLifeTime)); + else + dci.reliability.maxRetransmits.emplace(reliability->maxRetransmits); + } + + dci.negotiated = init->negotiated; + dci.id = init->manualStream ? std::make_optional(init->stream) : nullopt; + dci.protocol = init->protocol ? init->protocol : ""; + } + + auto peerConnection = getPeerConnection(pc); + int dc = emplaceDataChannel( + peerConnection->createDataChannel(string(label ? label : ""), std::move(dci))); + + if (auto ptr = getUserPointer(pc)) + rtcSetUserPointer(dc, *ptr); + + return dc; + }); +} + +int rtcDeleteDataChannel(int dc) { + return wrap([dc] { + auto dataChannel = getDataChannel(dc); + dataChannel->close(); + eraseDataChannel(dc); + return RTC_ERR_SUCCESS; + }); +} + +int rtcGetDataChannelStream(int dc) { + return wrap([dc] { + auto dataChannel = getDataChannel(dc); + if (auto stream = dataChannel->stream()) + return int(*stream); + else + return RTC_ERR_NOT_AVAIL; + }); +} + +int rtcGetDataChannelLabel(int dc, char *buffer, int size) { + return wrap([&] { + auto dataChannel = getDataChannel(dc); + return copyAndReturn(dataChannel->label(), buffer, size); + }); +} + +int rtcGetDataChannelProtocol(int dc, char *buffer, int size) { + return wrap([&] { + auto dataChannel = getDataChannel(dc); + return copyAndReturn(dataChannel->protocol(), buffer, size); + }); +} + +int rtcGetDataChannelReliability(int dc, rtcReliability *reliability) { + return wrap([&] { + auto dataChannel = getDataChannel(dc); + + if (!reliability) + throw std::invalid_argument("Unexpected null pointer for reliability"); + + Reliability dcr = dataChannel->reliability(); + std::memset(reliability, 0, sizeof(*reliability)); + reliability->unordered = dcr.unordered; + if(dcr.maxPacketLifeTime) { + reliability->unreliable = true; + reliability->maxPacketLifeTime = static_cast(dcr.maxPacketLifeTime->count()); + } else if (dcr.maxRetransmits) { + reliability->unreliable = true; + reliability->maxRetransmits = *dcr.maxRetransmits; + } else { + reliability->unreliable = false; + } + return RTC_ERR_SUCCESS; + }); +} + +int rtcAddTrack(int pc, const char *mediaDescriptionSdp) { + return wrap([&] { + if (!mediaDescriptionSdp) + throw std::invalid_argument("Unexpected null pointer for track media description"); + + auto peerConnection = getPeerConnection(pc); + Description::Media media{string(mediaDescriptionSdp)}; + int tr = emplaceTrack(peerConnection->addTrack(std::move(media))); + if (auto ptr = getUserPointer(pc)) + rtcSetUserPointer(tr, *ptr); + + return tr; + }); +} + +int rtcAddTrackEx(int pc, const rtcTrackInit *init) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + + if (!init) + throw std::invalid_argument("Unexpected null pointer for track init"); + + auto direction = static_cast(init->direction); + + string mid; + if (init->mid) { + mid = string(init->mid); + } else { + switch (init->codec) { + case RTC_CODEC_AV1: + case RTC_CODEC_H264: + case RTC_CODEC_H265: + case RTC_CODEC_VP8: + case RTC_CODEC_VP9: + mid = "video"; + break; + case RTC_CODEC_OPUS: + case RTC_CODEC_PCMU: + case RTC_CODEC_PCMA: + case RTC_CODEC_AAC: + mid = "audio"; + break; + default: + mid = "video"; + break; + } + } + + int pt = init->payloadType; + auto profile = init->profile ? std::make_optional(string(init->profile)) : nullopt; + + unique_ptr description; + switch (init->codec) { + case RTC_CODEC_AV1: + case RTC_CODEC_H264: + case RTC_CODEC_H265: + case RTC_CODEC_VP8: + case RTC_CODEC_VP9: { + auto video = std::make_unique(mid, direction); + switch (init->codec) { + case RTC_CODEC_AV1: + video->addAV1Codec(pt, profile); + break; + case RTC_CODEC_H264: + video->addH264Codec(pt, profile); + break; + case RTC_CODEC_H265: + video->addH265Codec(pt, profile); + break; + case RTC_CODEC_VP8: + video->addVP8Codec(pt, profile); + break; + case RTC_CODEC_VP9: + video->addVP9Codec(pt, profile); + break; + default: + break; + } + description = std::move(video); + break; + } + case RTC_CODEC_OPUS: + case RTC_CODEC_PCMU: + case RTC_CODEC_PCMA: + case RTC_CODEC_AAC: { + auto audio = std::make_unique(mid, direction); + switch (init->codec) { + case RTC_CODEC_OPUS: + audio->addOpusCodec(pt, profile); + break; + case RTC_CODEC_PCMU: + audio->addPCMUCodec(pt, profile); + break; + case RTC_CODEC_PCMA: + audio->addPCMACodec(pt, profile); + break; + case RTC_CODEC_AAC: + audio->addAACCodec(pt, profile); + break; + default: + break; + } + description = std::move(audio); + break; + } + default: + break; + } + + if (!description) + throw std::invalid_argument("Unexpected codec"); + + description->addSSRC(init->ssrc, + init->name ? std::make_optional(string(init->name)) : nullopt, + init->msid ? std::make_optional(string(init->msid)) : nullopt, + init->trackId ? std::make_optional(string(init->trackId)) : nullopt); + + int tr = emplaceTrack(peerConnection->addTrack(std::move(*description))); + + if (auto ptr = getUserPointer(pc)) + rtcSetUserPointer(tr, *ptr); + + return tr; + }); +} + +int rtcDeleteTrack(int tr) { + return wrap([&] { + auto track = getTrack(tr); + track->close(); + eraseTrack(tr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcGetTrackDescription(int tr, char *buffer, int size) { + return wrap([&] { + auto track = getTrack(tr); + return copyAndReturn(track->description(), buffer, size); + }); +} + +int rtcGetTrackMid(int tr, char *buffer, int size) { + return wrap([&] { + auto track = getTrack(tr); + return copyAndReturn(track->mid(), buffer, size); + }); +} + +int rtcGetTrackDirection(int tr, rtcDirection *direction) { + return wrap([&] { + if (!direction) + throw std::invalid_argument("Unexpected null pointer for track direction"); + + auto track = getTrack(tr); + *direction = static_cast(track->direction()); + return RTC_ERR_SUCCESS; + }); +} + +int rtcRequestKeyframe(int tr) { + return wrap([&] { + auto track = getTrack(tr); + track->requestKeyframe(); + return RTC_ERR_SUCCESS; + }); +} + +int rtcRequestBitrate(int tr, unsigned int bitrate) { + return wrap([&] { + auto track = getTrack(tr); + track->requestBitrate(bitrate); + return RTC_ERR_SUCCESS; + }); +} + +#if RTC_ENABLE_MEDIA + +void setSSRC(Description::Media *description, uint32_t ssrc, const char *_name, const char *_msid, + const char *_trackID) { + + optional name = nullopt; + if (_name) { + name = string(_name); + } + + optional msid = nullopt; + if (_msid) { + msid = string(_msid); + } + + optional trackID = nullopt; + if (_trackID) { + trackID = string(_trackID); + } + + description->addSSRC(ssrc, name, msid, trackID); +} + +rtcMessage *rtcCreateOpaqueMessage(void *data, int size) { + auto src = reinterpret_cast(data); + auto msg = new Message(src, src + size); + // Downgrade the message pointer to the opaque rtcMessage* type + return reinterpret_cast(msg); +} + +void rtcDeleteOpaqueMessage(rtcMessage *msg) { + // Cast the opaque pointer back to it's true type before deleting + delete reinterpret_cast(msg); +} + +int rtcSetMediaInterceptorCallback(int pc, rtcInterceptorCallbackFunc cb) { + return wrap([&] { + auto peerConnection = getPeerConnection(pc); + + if (cb == nullptr) { + peerConnection->setMediaHandler(nullptr); + return RTC_ERR_SUCCESS; + } + + auto interceptor = std::make_shared([pc, cb](void *data, int size) { + if (auto ptr = getUserPointer(pc)) + return cb(pc, reinterpret_cast(data), size, *ptr); + return data; + }); + + peerConnection->setMediaHandler(interceptor); + + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetH264Packetizer(int tr, const rtcPacketizerInit *init) { + return wrap([&] { + auto track = getTrack(tr); + // create RTP configuration + auto rtpConfig = createRtpPacketizationConfig(init); + emplaceRtpConfig(rtpConfig, tr); + // create packetizer + auto nalSeparator = init ? init->nalSeparator : RTC_NAL_SEPARATOR_LENGTH; + auto maxFragmentSize = init && init->maxFragmentSize ? init->maxFragmentSize + : RTC_DEFAULT_MAX_FRAGMENT_SIZE; + auto packetizer = std::make_shared( + static_cast(nalSeparator), rtpConfig, maxFragmentSize); + track->setMediaHandler(packetizer); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetH265Packetizer(int tr, const rtcPacketizerInit *init) { + return wrap([&] { + auto track = getTrack(tr); + // create RTP configuration + auto rtpConfig = createRtpPacketizationConfig(init); + // create packetizer + auto nalSeparator = init ? init->nalSeparator : RTC_NAL_SEPARATOR_LENGTH; + auto maxFragmentSize = init && init->maxFragmentSize ? init->maxFragmentSize + : RTC_DEFAULT_MAX_FRAGMENT_SIZE; + auto packetizer = std::make_shared( + static_cast(nalSeparator), rtpConfig, maxFragmentSize); + track->setMediaHandler(packetizer); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetAV1Packetizer(int tr, const rtcPacketizerInit *init) { + return wrap([&] { + auto track = getTrack(tr); + // create RTP configuration + auto rtpConfig = createRtpPacketizationConfig(init); + // create packetizer + auto maxFragmentSize = init && init->maxFragmentSize ? init->maxFragmentSize + : RTC_DEFAULT_MAX_FRAGMENT_SIZE; + auto packetization = init->obuPacketization == RTC_OBU_PACKETIZED_TEMPORAL_UNIT + ? AV1RtpPacketizer::Packetization::TemporalUnit + : AV1RtpPacketizer::Packetization::Obu; + auto packetizer = + std::make_shared(packetization, rtpConfig, maxFragmentSize); + track->setMediaHandler(packetizer); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetOpusPacketizer(int tr, const rtcPacketizerInit *init) { + return wrap([&] { + auto track = getTrack(tr); + // create RTP configuration + auto rtpConfig = createRtpPacketizationConfig(init); + emplaceRtpConfig(rtpConfig, tr); + // create packetizer + auto packetizer = std::make_shared(rtpConfig); + track->setMediaHandler(packetizer); + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetAACPacketizer(int tr, const rtcPacketizerInit *init) { + return wrap([&] { + auto track = getTrack(tr); + // create RTP configuration + auto rtpConfig = createRtpPacketizationConfig(init); + // create packetizer + auto packetizer = std::make_shared(rtpConfig); + track->setMediaHandler(packetizer); + return RTC_ERR_SUCCESS; + }); +} + +int rtcChainRtcpReceivingSession(int tr) { + return wrap([&] { + auto track = getTrack(tr); + auto session = std::make_shared(); + track->chainMediaHandler(session); + return RTC_ERR_SUCCESS; + }); +} + +int rtcChainRtcpSrReporter(int tr) { + return wrap([&] { + auto track = getTrack(tr); + auto config = getRtpConfig(tr); + auto reporter = std::make_shared(config); + track->chainMediaHandler(reporter); + emplaceRtcpSrReporter(reporter, tr); + return RTC_ERR_SUCCESS; + }); +} + +int rtcChainRtcpNackResponder(int tr, unsigned int maxStoredPacketsCount) { + return wrap([&] { + auto track = getTrack(tr); + auto responder = std::make_shared(maxStoredPacketsCount); + track->chainMediaHandler(responder); + return RTC_ERR_SUCCESS; + }); +} + +int rtcChainPliHandler(int tr, rtcPliHandlerCallbackFunc cb) { + return wrap([&] { + auto track = getTrack(tr); + auto handler = std::make_shared([tr, cb] { + if (auto ptr = getUserPointer(tr)) + cb(tr, *ptr); + }); + track->chainMediaHandler(handler); + return RTC_ERR_SUCCESS; + }); +} + +int rtcTransformSecondsToTimestamp(int id, double seconds, uint32_t *timestamp) { + return wrap([&] { + auto config = getRtpConfig(id); + if (timestamp) + *timestamp = config->secondsToTimestamp(seconds); + + return RTC_ERR_SUCCESS; + }); +} + +int rtcTransformTimestampToSeconds(int id, uint32_t timestamp, double *seconds) { + return wrap([&] { + auto config = getRtpConfig(id); + if (seconds) + *seconds = config->timestampToSeconds(timestamp); + + return RTC_ERR_SUCCESS; + }); +} + +int rtcGetCurrentTrackTimestamp(int id, uint32_t *timestamp) { + return wrap([&] { + auto config = getRtpConfig(id); + if (timestamp) + *timestamp = config->timestamp; + + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetTrackRtpTimestamp(int id, uint32_t timestamp) { + return wrap([&] { + auto config = getRtpConfig(id); + config->timestamp = timestamp; + return RTC_ERR_SUCCESS; + }); +} + +int rtcGetLastTrackSenderReportTimestamp(int id, uint32_t *timestamp) { + return wrap([&] { + auto sender = getRtcpSrReporter(id); + if (timestamp) + *timestamp = sender->lastReportedTimestamp(); + + return RTC_ERR_SUCCESS; + }); +} + +int rtcSetNeedsToSendRtcpSr(int id) { + return wrap([id] { + auto sender = getRtcpSrReporter(id); + sender->setNeedsToReport(); + return RTC_ERR_SUCCESS; + }); +} + +int rtcGetTrackPayloadTypesForCodec(int tr, const char *ccodec, int *buffer, int size) { + return wrap([&] { + auto track = getTrack(tr); + auto codec = lowercased(string(ccodec)); + auto description = track->description(); + std::vector payloadTypes; + for (int pt : description.payloadTypes()) + if (lowercased(description.rtpMap(pt)->format) == codec) + payloadTypes.push_back(pt); + + return copyAndReturn(payloadTypes, buffer, size); + }); +} + +int rtcGetSsrcsForTrack(int tr, uint32_t *buffer, int count) { + return wrap([&] { + auto track = getTrack(tr); + auto ssrcs = track->description().getSSRCs(); + return copyAndReturn(ssrcs, buffer, count); + }); +} + +int rtcGetCNameForSsrc(int tr, uint32_t ssrc, char *cname, int cnameSize) { + return wrap([&] { + auto track = getTrack(tr); + auto description = track->description(); + auto optCName = description.getCNameForSsrc(ssrc); + if (optCName.has_value()) { + return copyAndReturn(optCName.value(), cname, cnameSize); + } else { + return 0; + } + }); +} + +int rtcGetSsrcsForType(const char *mediaType, const char *sdp, uint32_t *buffer, int bufferSize) { + return wrap([&] { + auto type = lowercased(string(mediaType)); + auto oldSDP = string(sdp); + auto description = Description(oldSDP, "unspec"); + auto mediaCount = description.mediaCount(); + for (unsigned int i = 0; i < mediaCount; i++) { + if (std::holds_alternative(description.media(i))) { + auto media = std::get(description.media(i)); + auto currentMediaType = lowercased(media->type()); + if (currentMediaType == type) { + auto ssrcs = media->getSSRCs(); + return copyAndReturn(ssrcs, buffer, bufferSize); + } + } + } + return 0; + }); +} + +int rtcSetSsrcForType(const char *mediaType, const char *sdp, char *buffer, const int bufferSize, + rtcSsrcForTypeInit *init) { + return wrap([&] { + auto type = lowercased(string(mediaType)); + auto prevSDP = string(sdp); + auto description = Description(prevSDP, "unspec"); + auto mediaCount = description.mediaCount(); + for (unsigned int i = 0; i < mediaCount; i++) { + if (std::holds_alternative(description.media(i))) { + auto media = std::get(description.media(i)); + auto currentMediaType = lowercased(media->type()); + if (currentMediaType == type) { + setSSRC(media, init->ssrc, init->name, init->msid, init->trackId); + break; + } + } + } + return copyAndReturn(string(description), buffer, bufferSize); + }); +} + +#endif // RTC_ENABLE_MEDIA + +#if RTC_ENABLE_WEBSOCKET + +int rtcCreateWebSocket(const char *url) { + return wrap([&] { + auto webSocket = std::make_shared(); + webSocket->open(url); + return emplaceWebSocket(webSocket); + }); +} + +int rtcCreateWebSocketEx(const char *url, const rtcWsConfiguration *config) { + return wrap([&] { + if (!url) + throw std::invalid_argument("Unexpected null pointer for URL"); + + if (!config) + throw std::invalid_argument("Unexpected null pointer for config"); + + WebSocket::Configuration c; + c.disableTlsVerification = config->disableTlsVerification; + + if (config->proxyServer) + c.proxyServer.emplace(config->proxyServer); + + for (int i = 0; i < config->protocolsCount; ++i) + c.protocols.emplace_back(string(config->protocols[i])); + + if (config->connectionTimeoutMs > 0) + c.connectionTimeout = milliseconds(config->connectionTimeoutMs); + else if (config->connectionTimeoutMs < 0) + c.connectionTimeout = milliseconds::zero(); // setting to 0 disables, + // not setting keeps default + if (config->pingIntervalMs > 0) + c.pingInterval = milliseconds(config->pingIntervalMs); + else if (config->pingIntervalMs < 0) + c.pingInterval = milliseconds::zero(); // setting to 0 disables, + // not setting keeps default + if (config->maxOutstandingPings > 0) + c.maxOutstandingPings = config->maxOutstandingPings; + else if (config->maxOutstandingPings < 0) + c.maxOutstandingPings = 0; // setting to 0 disables, not setting keeps default + + if(config->maxMessageSize > 0) + c.maxMessageSize = size_t(config->maxMessageSize); + + auto webSocket = std::make_shared(std::move(c)); + webSocket->open(url); + return emplaceWebSocket(webSocket); + }); +} + +int rtcDeleteWebSocket(int ws) { + return wrap([&] { + auto webSocket = getWebSocket(ws); + webSocket->forceClose(); + webSocket->resetCallbacks(); // not done on close by WebSocket + eraseWebSocket(ws); + return RTC_ERR_SUCCESS; + }); +} + +int rtcGetWebSocketRemoteAddress(int ws, char *buffer, int size) { + return wrap([&] { + auto webSocket = getWebSocket(ws); + if (auto remoteAddress = webSocket->remoteAddress()) + return copyAndReturn(*remoteAddress, buffer, size); + else + return RTC_ERR_NOT_AVAIL; + }); +} + +int rtcGetWebSocketPath(int ws, char *buffer, int size) { + return wrap([&] { + auto webSocket = getWebSocket(ws); + if (auto path = webSocket->path()) + return copyAndReturn(*path, buffer, size); + else + return RTC_ERR_NOT_AVAIL; + }); +} + +RTC_C_EXPORT int rtcCreateWebSocketServer(const rtcWsServerConfiguration *config, + rtcWebSocketClientCallbackFunc cb) { + return wrap([&] { + if (!config) + throw std::invalid_argument("Unexpected null pointer for config"); + + if (!cb) + throw std::invalid_argument("Unexpected null pointer for client callback"); + + WebSocketServer::Configuration c; + c.port = config->port; + c.enableTls = config->enableTls; + c.certificatePemFile = config->certificatePemFile + ? make_optional(string(config->certificatePemFile)) + : nullopt; + c.keyPemFile = config->keyPemFile ? make_optional(string(config->keyPemFile)) : nullopt; + c.keyPemPass = config->keyPemPass ? make_optional(string(config->keyPemPass)) : nullopt; + c.bindAddress = config->bindAddress ? make_optional(string(config->bindAddress)) : nullopt; + + if(config->maxMessageSize > 0) + c.maxMessageSize = size_t(config->maxMessageSize); + + auto webSocketServer = std::make_shared(std::move(c)); + int wsserver = emplaceWebSocketServer(webSocketServer); + + webSocketServer->onClient([wsserver, cb](shared_ptr webSocket) { + int ws = emplaceWebSocket(webSocket); + if (auto ptr = getUserPointer(wsserver)) { + rtcSetUserPointer(wsserver, *ptr); + cb(wsserver, ws, *ptr); + } + }); + + return wsserver; + }); +} + +RTC_C_EXPORT int rtcDeleteWebSocketServer(int wsserver) { + return wrap([&] { + auto webSocketServer = getWebSocketServer(wsserver); + webSocketServer->onClient(nullptr); + webSocketServer->stop(); + eraseWebSocketServer(wsserver); + return RTC_ERR_SUCCESS; + }); +} + +RTC_C_EXPORT int rtcGetWebSocketServerPort(int wsserver) { + return wrap([&] { + auto webSocketServer = getWebSocketServer(wsserver); + return int(webSocketServer->port()); + }); +} + +#endif + +void rtcPreload() { + try { + rtc::Preload(); + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + } +} + +void rtcCleanup() { + try { + size_t count = eraseAll(); + if (count != 0) { + PLOG_INFO << count << " objects were not properly destroyed before cleanup"; + } + + if (rtc::Cleanup().wait_for(10s) == std::future_status::timeout) + throw std::runtime_error( + "Cleanup timeout (possible deadlock or undestructible object)"); + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + } +} + +int rtcSetSctpSettings(const rtcSctpSettings *settings) { + return wrap([&] { + SctpSettings s = {}; + + if (settings->recvBufferSize > 0) + s.recvBufferSize = size_t(settings->recvBufferSize); + + if (settings->sendBufferSize > 0) + s.sendBufferSize = size_t(settings->sendBufferSize); + + if (settings->maxChunksOnQueue > 0) + s.maxChunksOnQueue = size_t(settings->maxChunksOnQueue); + + if (settings->initialCongestionWindow > 0) + s.initialCongestionWindow = size_t(settings->initialCongestionWindow); + + if (settings->maxBurst > 0) + s.maxBurst = size_t(settings->maxBurst); + else if (settings->maxBurst < 0) + s.maxBurst = size_t(0); // setting to 0 disables, not setting chooses optimized default + + if (settings->congestionControlModule >= 0) + s.congestionControlModule = unsigned(settings->congestionControlModule); + + if (settings->delayedSackTimeMs > 0) + s.delayedSackTime = milliseconds(settings->delayedSackTimeMs); + else if (settings->delayedSackTimeMs < 0) + s.delayedSackTime = milliseconds(0); + + if (settings->minRetransmitTimeoutMs > 0) + s.minRetransmitTimeout = milliseconds(settings->minRetransmitTimeoutMs); + + if (settings->maxRetransmitTimeoutMs > 0) + s.maxRetransmitTimeout = milliseconds(settings->maxRetransmitTimeoutMs); + + if (settings->initialRetransmitTimeoutMs > 0) + s.initialRetransmitTimeout = milliseconds(settings->initialRetransmitTimeoutMs); + + if (settings->maxRetransmitAttempts > 0) + s.maxRetransmitAttempts = settings->maxRetransmitAttempts; + + if (settings->heartbeatIntervalMs > 0) + s.heartbeatInterval = milliseconds(settings->heartbeatIntervalMs); + + SetSctpSettings(std::move(s)); + return RTC_ERR_SUCCESS; + }); +} diff --git a/datachannel/src/channel.cpp b/datachannel/src/channel.cpp new file mode 100644 index 000000000..a63f91501 --- /dev/null +++ b/datachannel/src/channel.cpp @@ -0,0 +1,62 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "channel.hpp" + +#include "impl/channel.hpp" +#include "impl/internals.hpp" + +namespace rtc { + +Channel::~Channel() { impl()->resetCallbacks(); } + +Channel::Channel(impl_ptr impl) : CheshireCat(std::move(impl)) {} + +size_t Channel::maxMessageSize() const { return 0; } + +size_t Channel::bufferedAmount() const { return impl()->bufferedAmount; } + +void Channel::onOpen(std::function callback) { impl()->openCallback = callback; } + +void Channel::onClosed(std::function callback) { impl()->closedCallback = callback; } + +void Channel::onError(std::function callback) { + impl()->errorCallback = callback; +} + +void Channel::onMessage(std::function callback) { + impl()->messageCallback = callback; + impl()->flushPendingMessages(); +} + +void Channel::onMessage(std::function binaryCallback, + std::function stringCallback) { + onMessage([binaryCallback, stringCallback](variant data) { + std::visit(overloaded{binaryCallback, stringCallback}, std::move(data)); + }); +} + +void Channel::onBufferedAmountLow(std::function callback) { + impl()->bufferedAmountLowCallback = callback; +} + +void Channel::setBufferedAmountLowThreshold(size_t amount) { + impl()->bufferedAmountLowThreshold = amount; +} + +void Channel::resetCallbacks() { impl()->resetCallbacks(); } + +optional Channel::receive() { return impl()->receive(); } + +optional Channel::peek() { return impl()->peek(); } + +size_t Channel::availableAmount() const { return impl()->availableAmount(); } + +void Channel::onAvailable(std::function callback) { impl()->availableCallback = callback; } + +} // namespace rtc diff --git a/datachannel/src/configuration.cpp b/datachannel/src/configuration.cpp new file mode 100644 index 000000000..fe28aef71 --- /dev/null +++ b/datachannel/src/configuration.cpp @@ -0,0 +1,155 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "configuration.hpp" + +#include "impl/utils.hpp" + +#include +#include + +namespace { + +bool parse_url(const std::string &url, std::vector> &result) { + // Modified regex from RFC 3986, see https://www.rfc-editor.org/rfc/rfc3986.html#appendix-B + static const char *rs = + R"(^(([^:.@/?#]+):)?(/{0,2}((([^:@]*)(:([^@]*))?)@)?(([^:/?#]*)(:([^/?#]*))?))?([^?#]*)(\?([^#]*))?(#(.*))?)"; + static const std::regex r(rs, std::regex::extended); + + std::smatch m; + if (!std::regex_match(url, m, r) || m[10].length() == 0) + return false; + + result.resize(m.size()); + std::transform(m.begin(), m.end(), result.begin(), [](const auto &sm) { + return sm.length() > 0 ? std::make_optional(std::string(sm)) : std::nullopt; + }); + + assert(result.size() == 18); + return true; +} + +} // namespace + +namespace rtc { + +namespace utils = impl::utils; + +IceServer::IceServer(const string &url) { + std::vector> opt; + if (!parse_url(url, opt)) + throw std::invalid_argument("Invalid ICE server URL: " + url); + + string scheme = opt[2].value_or("stun"); + relayType = RelayType::TurnUdp; + if (scheme == "stun" || scheme == "STUN") + type = Type::Stun; + else if (scheme == "turn" || scheme == "TURN") + type = Type::Turn; + else if (scheme == "turns" || scheme == "TURNS") { + type = Type::Turn; + relayType = RelayType::TurnTls; + } else + throw std::invalid_argument("Unknown ICE server protocol: " + scheme); + + if (auto &query = opt[15]) { + if (query->find("transport=udp") != string::npos) + relayType = RelayType::TurnUdp; + if (query->find("transport=tcp") != string::npos) + relayType = RelayType::TurnTcp; + if (query->find("transport=tls") != string::npos) + relayType = RelayType::TurnTls; + } + + username = utils::url_decode(opt[6].value_or("")); + password = utils::url_decode(opt[8].value_or("")); + + hostname = opt[10].value(); + if (hostname.front() == '[' && hostname.back() == ']') { + // IPv6 literal + hostname.erase(hostname.begin()); + hostname.pop_back(); + } else { + hostname = utils::url_decode(hostname); + } + + string service = opt[12].value_or(relayType == RelayType::TurnTls ? "5349" : "3478"); + try { + port = uint16_t(std::stoul(service)); + } catch (...) { + throw std::invalid_argument("Invalid ICE server port in URL: " + service); + } +} + +IceServer::IceServer(string hostname_, uint16_t port_) + : hostname(std::move(hostname_)), port(port_), type(Type::Stun) {} + +IceServer::IceServer(string hostname_, string service_) + : hostname(std::move(hostname_)), type(Type::Stun) { + try { + port = uint16_t(std::stoul(service_)); + } catch (...) { + throw std::invalid_argument("Invalid ICE server port: " + service_); + } +} + +IceServer::IceServer(string hostname_, uint16_t port_, string username_, string password_, + RelayType relayType_) + : hostname(std::move(hostname_)), port(port_), type(Type::Turn), username(std::move(username_)), + password(std::move(password_)), relayType(relayType_) {} + +IceServer::IceServer(string hostname_, string service_, string username_, string password_, + RelayType relayType_) + : hostname(std::move(hostname_)), type(Type::Turn), username(std::move(username_)), + password(std::move(password_)), relayType(relayType_) { + try { + port = uint16_t(std::stoul(service_)); + } catch (...) { + throw std::invalid_argument("Invalid ICE server port: " + service_); + } +} + +ProxyServer::ProxyServer(const string &url) { + std::vector> opt; + if (!parse_url(url, opt)) + throw std::invalid_argument("Invalid proxy server URL: " + url); + + string scheme = opt[2].value_or("http"); + if (scheme == "http" || scheme == "HTTP") + type = Type::Http; + else if (scheme == "socks5" || scheme == "SOCKS5") + type = Type::Socks5; + else + throw std::invalid_argument("Unknown proxy server protocol: " + scheme); + + username = opt[6]; + password = opt[8]; + + hostname = opt[10].value(); + while (!hostname.empty() && hostname.front() == '[') + hostname.erase(hostname.begin()); + while (!hostname.empty() && hostname.back() == ']') + hostname.pop_back(); + + string service = opt[12].value_or(type == Type::Socks5 ? "1080" : "3128"); + try { + port = uint16_t(std::stoul(service)); + } catch (...) { + throw std::invalid_argument("Invalid proxy server port in URL: " + service); + } +} + +ProxyServer::ProxyServer(Type type_, string hostname_, uint16_t port_) + : type(type_), hostname(std::move(hostname_)), port(port_) {} + +ProxyServer::ProxyServer(Type type_, string hostname_, uint16_t port_, string username_, + string password_) + : type(type_), hostname(std::move(hostname_)), port(port_), username(std::move(username_)), + password(std::move(password_)) {} + +} // namespace rtc diff --git a/datachannel/src/datachannel.cpp b/datachannel/src/datachannel.cpp new file mode 100644 index 000000000..80a151f33 --- /dev/null +++ b/datachannel/src/datachannel.cpp @@ -0,0 +1,57 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "datachannel.hpp" +#include "common.hpp" +#include "peerconnection.hpp" + +#include "impl/datachannel.hpp" +#include "impl/internals.hpp" +#include "impl/peerconnection.hpp" + +#ifdef _WIN32 +#include +#else +#include +#endif + +namespace rtc { + +DataChannel::DataChannel(impl_ptr impl) + : CheshireCat(impl), + Channel(std::dynamic_pointer_cast(impl)) {} + +DataChannel::~DataChannel() {} + +void DataChannel::close() { return impl()->close(); } + +optional DataChannel::stream() const { return impl()->stream(); } + +optional DataChannel::id() const { return impl()->stream(); } + +string DataChannel::label() const { return impl()->label(); } + +string DataChannel::protocol() const { return impl()->protocol(); } + +Reliability DataChannel::reliability() const { return impl()->reliability(); } + +bool DataChannel::isOpen(void) const { return impl()->isOpen(); } + +bool DataChannel::isClosed(void) const { return impl()->isClosed(); } + +size_t DataChannel::maxMessageSize() const { return impl()->maxMessageSize(); } + +bool DataChannel::send(message_variant data) { + return impl()->outgoing(make_message(std::move(data))); +} + +bool DataChannel::send(const byte *data, size_t size) { + return impl()->outgoing(std::make_shared(data, data + size, Message::Binary)); +} + +} // namespace rtc diff --git a/datachannel/src/description.cpp b/datachannel/src/description.cpp new file mode 100644 index 000000000..b2165a895 --- /dev/null +++ b/datachannel/src/description.cpp @@ -0,0 +1,1398 @@ +/** + * Copyright (c) 2019-2020 Paul-Louis Ageneau + * Copyright (c) 2020 Staz Modrzynski + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "description.hpp" + +#include "impl/internals.hpp" +#include "impl/utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +using std::chrono::system_clock; + +namespace { + +using std::string; +using std::string_view; + +inline bool match_prefix(string_view str, string_view prefix) { + return str.size() >= prefix.size() && + std::mismatch(prefix.begin(), prefix.end(), str.begin()).first == prefix.end(); +} + +inline void trim_end(string &str) { + str.erase( + std::find_if(str.rbegin(), str.rend(), [](char c) { return !std::isspace(c); }).base(), + str.end()); +} + +inline string get_first_line(const string &str) { + string line; + std::istringstream ss(str); + std::getline(ss, line); + return line; +} + +inline std::pair parse_pair(string_view attr) { + string_view key, value; + if (size_t separator = attr.find(':'); separator != string::npos) { + key = attr.substr(0, separator); + value = attr.substr(separator + 1); + } else { + key = attr; + } + return std::make_pair(std::move(key), std::move(value)); +} + +template T to_integer(string_view s) { + const string str(s); + try { + return std::is_signed::value ? T(std::stol(str)) : T(std::stoul(str)); + } catch (...) { + throw std::invalid_argument("Invalid integer \"" + str + "\" in description"); + } +} + +} // namespace + +namespace rtc { + +namespace utils = impl::utils; + +Description::Description(const string &sdp, Type type, Role role) + : mType(Type::Unspec), mRole(role) { + hintType(type); + + int index = -1; + shared_ptr current; + std::istringstream ss(sdp); + while (ss) { + string line; + std::getline(ss, line); + trim_end(line); + if (line.empty()) + continue; + + if (match_prefix(line, "m=")) { // Media description line (aka m-line) + current = createEntry(line.substr(2), std::to_string(++index), Direction::Unknown); + + } else if (match_prefix(line, "o=")) { // Origin line + std::istringstream origin(line.substr(2)); + origin >> mUsername >> mSessionId; + + } else if (match_prefix(line, "a=")) { // Attribute line + string attr = line.substr(2); + auto [key, value] = parse_pair(attr); + + if (key == "setup") { + if (value == "active") + mRole = Role::Active; + else if (value == "passive") + mRole = Role::Passive; + else + mRole = Role::ActPass; + + } else if (key == "fingerprint") { + // RFC 8122: The fingerprint attribute may be either a session-level or a + // media-level SDP attribute. If it is a session-level attribute, it applies to all + // TLS sessions for which no media-level fingerprint attribute is defined. + if (!mFingerprint || index == 0) { // first media overrides session-level + auto fingerprintExploded = utils::explode(string(value), ' '); + if (fingerprintExploded.size() != 2) { + PLOG_WARNING << "Unknown SDP fingerprint format: " << value; + continue; + } + + auto first = fingerprintExploded.at(0); + std::transform(first.begin(), first.end(), first.begin(), + [](char c) { return char(std::tolower(c)); }); + + std::optional fingerprintAlgorithm; + + for (auto a : std::array{ + CertificateFingerprint::Algorithm::Sha1, + CertificateFingerprint::Algorithm::Sha224, + CertificateFingerprint::Algorithm::Sha256, + CertificateFingerprint::Algorithm::Sha384, + CertificateFingerprint::Algorithm::Sha512}) { + if (first == CertificateFingerprint::AlgorithmIdentifier(a)) { + fingerprintAlgorithm = a; + break; + } + } + + if (fingerprintAlgorithm.has_value()) { + setFingerprint(CertificateFingerprint{ + fingerprintAlgorithm.value(), std::move(fingerprintExploded.at(1))}); + } else { + PLOG_WARNING << "Unknown certificate fingerprint algorithm: " << first; + } + } + } else if (key == "ice-ufrag") { + // RFC 8839: The "ice-pwd" and "ice-ufrag" attributes can appear at either the + // session-level or media-level. When present in both, the value in the media-level + // takes precedence. + if (!mIceUfrag || index == 0) // media-level for first media overrides session-level + mIceUfrag = value; + } else if (key == "ice-pwd") { + // RFC 8839: The "ice-pwd" and "ice-ufrag" attributes can appear at either the + // session-level or media-level. When present in both, the value in the media-level + // takes precedence. + if (!mIcePwd || index == 0) // media-level for first media overrides session-level + mIcePwd = value; + } else if (key == "ice-options") { + // RFC 8839: The "ice-options" attribute is a session-level and media-level + // attribute. + if (mIceOptions.empty()) + mIceOptions = utils::explode(string(value), ','); + } else if (key == "candidate") { + addCandidate(Candidate(attr, bundleMid())); + } else if (key == "end-of-candidates") { + mEnded = true; + } else if (current) { + current->parseSdpLine(std::move(line)); + } else { + mAttributes.emplace_back(attr); + } + + } else if (current) { + current->parseSdpLine(std::move(line)); + } + } + + if (mUsername.empty()) + mUsername = "rtc"; + + if (mSessionId.empty()) { + auto uniform = std::bind(std::uniform_int_distribution(), utils::random_engine()); + mSessionId = std::to_string(uniform()); + } +} + +Description::Description(const string &sdp, string typeString) + : Description(sdp, !typeString.empty() ? stringToType(typeString) : Type::Unspec, + Role::ActPass) {} + +Description::Type Description::type() const { return mType; } + +string Description::typeString() const { return typeToString(mType); } + +Description::Role Description::role() const { return mRole; } + +string Description::bundleMid() const { + // Get the mid of the first non-removed media + for (const auto &entry : mEntries) + if (!entry->isRemoved()) + return entry->mid(); + + return "0"; +} + +optional Description::iceUfrag() const { return mIceUfrag; } + +std::vector Description::iceOptions() const { return mIceOptions; } + +optional Description::icePwd() const { return mIcePwd; } + +optional Description::fingerprint() const { return mFingerprint; } + +bool Description::ended() const { return mEnded; } + +void Description::hintType(Type type) { + if (mType == Type::Unspec) + mType = type; +} + +void Description::setFingerprint(CertificateFingerprint f) { + if (!f.isValid()) + throw std::invalid_argument("Invalid " + CertificateFingerprint::AlgorithmIdentifier(f.algorithm) + " fingerprint \"" + f.value + "\""); + + std::transform(f.value.begin(), f.value.end(), f.value.begin(), + [](char c) { return char(std::toupper(c)); }); + mFingerprint = std::move(f); +} + +void Description::addIceOption(string option) { + if (std::find(mIceOptions.begin(), mIceOptions.end(), option) == mIceOptions.end()) + mIceOptions.emplace_back(std::move(option)); +} + +void Description::removeIceOption(const string &option) { + mIceOptions.erase(std::remove(mIceOptions.begin(), mIceOptions.end(), option), + mIceOptions.end()); +} + +std::vector Description::Entry::attributes() const { return mAttributes; } + +void Description::Entry::addAttribute(string attr) { + if (std::find(mAttributes.begin(), mAttributes.end(), attr) == mAttributes.end()) + mAttributes.emplace_back(std::move(attr)); +} + +void Description::Entry::removeAttribute(const string &attr) { + mAttributes.erase( + std::remove_if(mAttributes.begin(), mAttributes.end(), + [&](const auto &a) { return a == attr || parse_pair(a).first == attr; }), + mAttributes.end()); +} + +std::vector Description::candidates() const { return mCandidates; } + +std::vector Description::extractCandidates() { + std::vector result; + std::swap(mCandidates, result); + mEnded = false; + return result; +} + +bool Description::hasCandidate(const Candidate &candidate) const { + return std::find(mCandidates.begin(), mCandidates.end(), candidate) != mCandidates.end(); +} + +void Description::addCandidate(Candidate candidate) { + candidate.hintMid(bundleMid()); + + if (!hasCandidate(candidate)) + mCandidates.emplace_back(std::move(candidate)); +} + +void Description::addCandidates(std::vector candidates) { + for (Candidate candidate : candidates) + addCandidate(std::move(candidate)); +} + +void Description::endCandidates() { mEnded = true; } + +Description::operator string() const { return generateSdp("\r\n"); } + +string Description::generateSdp(string_view eol) const { + std::ostringstream sdp; + + // Header + sdp << "v=0" << eol; + sdp << "o=" << mUsername << " " << mSessionId << " 0 IN IP4 127.0.0.1" << eol; + sdp << "s=-" << eol; + sdp << "t=0 0" << eol; + + // BUNDLE (RFC 8843 Negotiating Media Multiplexing Using the Session Description Protocol) + // https://www.rfc-editor.org/rfc/rfc8843.html + std::ostringstream bundleGroup; + for (const auto &entry : mEntries) + if (!entry->isRemoved()) + bundleGroup << ' ' << entry->mid(); + + if (!bundleGroup.str().empty()) + sdp << "a=group:BUNDLE" << bundleGroup.str() << eol; + + // Lip-sync + std::ostringstream lsGroup; + for (const auto &entry : mEntries) + if (!entry->isRemoved() && entry != mApplication) + lsGroup << ' ' << entry->mid(); + + if (!lsGroup.str().empty()) + sdp << "a=group:LS" << lsGroup.str() << eol; + + // Session-level attributes + sdp << "a=msid-semantic:WMS *" << eol; + sdp << "a=setup:" << mRole << eol; + + if (mIceUfrag) + sdp << "a=ice-ufrag:" << *mIceUfrag << eol; + if (mIcePwd) + sdp << "a=ice-pwd:" << *mIcePwd << eol; + if (!mIceOptions.empty()) + sdp << "a=ice-options:" << utils::implode(mIceOptions, ',') << eol; + if (mFingerprint) + sdp << "a=fingerprint:" + << CertificateFingerprint::AlgorithmIdentifier(mFingerprint->algorithm) << " " + << mFingerprint->value << eol; + + for (const auto &attr : mAttributes) + sdp << "a=" << attr << eol; + + auto cand = defaultCandidate(); + const string addr = cand && cand->isResolved() + ? (string(cand->family() == Candidate::Family::Ipv6 ? "IP6" : "IP4") + + " " + *cand->address()) + : "IP4 0.0.0.0"; + const uint16_t port = + cand && cand->isResolved() ? *cand->port() : 9; // Port 9 is the discard protocol + + // Entries + bool first = true; + for (const auto &entry : mEntries) { + sdp << entry->generateSdp(eol, addr, port); + + if (!entry->isRemoved() && std::exchange(first, false)) { + // Candidates + for (const auto &candidate : mCandidates) + sdp << string(candidate) << eol; + + if (mEnded) + sdp << "a=end-of-candidates" << eol; + } + } + + return sdp.str(); +} + +string Description::generateApplicationSdp(string_view eol) const { + std::ostringstream sdp; + + // Header + sdp << "v=0" << eol; + sdp << "o=" << mUsername << " " << mSessionId << " 0 IN IP4 127.0.0.1" << eol; + sdp << "s=-" << eol; + sdp << "t=0 0" << eol; + + auto cand = defaultCandidate(); + const string addr = cand && cand->isResolved() + ? (string(cand->family() == Candidate::Family::Ipv6 ? "IP6" : "IP4") + + " " + *cand->address()) + : "IP4 0.0.0.0"; + const uint16_t port = + cand && cand->isResolved() ? *cand->port() : 9; // Port 9 is the discard protocol + + // Application + auto app = mApplication ? mApplication : std::make_shared(); + sdp << app->generateSdp(eol, addr, port); + + // Session-level attributes + sdp << "a=msid-semantic:WMS *" << eol; + sdp << "a=setup:" << mRole << eol; + + if (mIceUfrag) + sdp << "a=ice-ufrag:" << *mIceUfrag << eol; + if (mIcePwd) + sdp << "a=ice-pwd:" << *mIcePwd << eol; + if (!mIceOptions.empty()) + sdp << "a=ice-options:" << utils::implode(mIceOptions, ',') << eol; + if (mFingerprint) + sdp << "a=fingerprint:" + << CertificateFingerprint::AlgorithmIdentifier(mFingerprint->algorithm) << " " + << mFingerprint->value << eol; + + for (const auto &attr : mAttributes) + sdp << "a=" << attr << eol; + + // Candidates + for (const auto &candidate : mCandidates) + sdp << string(candidate) << eol; + + if (mEnded) + sdp << "a=end-of-candidates" << eol; + + return sdp.str(); +} + +optional Description::defaultCandidate() const { + // Return the first host candidate with highest priority, favoring IPv4 + optional result; + for (const auto &c : mCandidates) { + if (c.type() == Candidate::Type::Host) { + if (!result || + (result->family() == Candidate::Family::Ipv6 && + c.family() == Candidate::Family::Ipv4) || + (result->family() == c.family() && result->priority() < c.priority())) + result.emplace(c); + } + } + return result; +} + +shared_ptr Description::createEntry(string mline, string mid, Direction dir) { + string type = mline.substr(0, mline.find(' ')); + if (type == "application") { + removeApplication(); + mApplication = std::make_shared(mline, std::move(mid)); + mEntries.emplace_back(mApplication); + return mApplication; + } else { + auto media = std::make_shared(std::move(mline), std::move(mid), dir); + mEntries.emplace_back(media); + return media; + } +} + +void Description::removeApplication() { + if (!mApplication) + return; + + auto it = std::find(mEntries.begin(), mEntries.end(), mApplication); + if (it != mEntries.end()) + mEntries.erase(it); + + mApplication.reset(); +} + +bool Description::hasApplication() const { return mApplication && !mApplication->isRemoved(); } + +bool Description::hasAudioOrVideo() const { + for (auto entry : mEntries) + if (entry != mApplication && !entry->isRemoved()) + return true; + + return false; +} + +bool Description::hasMid(string_view mid) const { + for (const auto &entry : mEntries) + if (entry->mid() == mid) + return true; + + return false; +} + +int Description::addMedia(Media media) { + mEntries.emplace_back(std::make_shared(std::move(media))); + return int(mEntries.size()) - 1; +} + +int Description::addMedia(Application application) { + removeApplication(); + mApplication = std::make_shared(std::move(application)); + mEntries.emplace_back(mApplication); + return int(mEntries.size()) - 1; +} + +int Description::addApplication(string mid) { return addMedia(Application(std::move(mid))); } + +const Description::Application *Description::application() const { return mApplication.get(); } + +Description::Application *Description::application() { return mApplication.get(); } + +int Description::addVideo(string mid, Direction dir) { + return addMedia(Video(std::move(mid), dir)); +} + +int Description::addAudio(string mid, Direction dir) { + return addMedia(Audio(std::move(mid), dir)); +} + +void Description::clearMedia() { + mEntries.clear(); + mApplication.reset(); +} + +variant Description::media(unsigned int index) { + if (index >= mEntries.size()) + throw std::out_of_range("Media index out of range"); + + const auto &entry = mEntries[index]; + if (entry == mApplication) { + auto result = dynamic_cast(entry.get()); + if (!result) + throw std::logic_error("Bad type of application in description"); + + return result; + + } else { + auto result = dynamic_cast(entry.get()); + if (!result) + throw std::logic_error("Bad type of media in description"); + + return result; + } +} + +variant +Description::media(unsigned int index) const { + if (index >= mEntries.size()) + throw std::out_of_range("Media index out of range"); + + const auto &entry = mEntries[index]; + if (entry == mApplication) { + auto result = dynamic_cast(entry.get()); + if (!result) + throw std::logic_error("Bad type of application in description"); + + return result; + + } else { + auto result = dynamic_cast(entry.get()); + if (!result) + throw std::logic_error("Bad type of media in description"); + + return result; + } +} + +unsigned int Description::mediaCount() const { return unsigned(mEntries.size()); } + +Description::Entry::Entry(const string &mline, string mid, Direction dir) + : mMid(std::move(mid)), mDirection(dir) { + + uint16_t port = 0; + std::istringstream ss(match_prefix(mline, "m=") ? mline.substr(2) : mline); + ss >> mType; + ss >> port; + ss >> mDescription; + + if (mType.empty() || mDescription.empty()) + throw std::invalid_argument("Invalid media description line"); + + // RFC 3264: Existing media streams are removed by creating a new SDP with the port number for + // that stream set to zero. + // RFC 8843: If the offerer assigns a zero port value to a bundled "m=" section, but does not + // include an SDP 'bundle-only' attribute in the "m=" section, it is an indication that the + // offerer wants to disable the "m=" section. + mIsRemoved = (port == 0); +} + +string Description::Entry::type() const { return mType; } + +string Description::Entry::description() const { return mDescription; } + +string Description::Entry::mid() const { return mMid; } + +Description::Direction Description::Entry::direction() const { return mDirection; } + +void Description::Entry::setDirection(Direction dir) { mDirection = dir; } + +bool Description::Entry::isRemoved() const { return mIsRemoved; } + +void Description::Entry::markRemoved() { mIsRemoved = true; } + +std::vector Description::attributes() const { return mAttributes; } + +void Description::addAttribute(string attr) { + if (std::find(mAttributes.begin(), mAttributes.end(), attr) == mAttributes.end()) + mAttributes.emplace_back(std::move(attr)); +} + +void Description::Entry::addRid(string rid) { mRids.emplace_back(rid); } + +void Description::removeAttribute(const string &attr) { + mAttributes.erase( + std::remove_if(mAttributes.begin(), mAttributes.end(), + [&](const auto &a) { return a == attr || parse_pair(a).first == attr; }), + mAttributes.end()); +} + +std::vector Description::Entry::extIds() { + std::vector result; + for (auto it = mExtMaps.begin(); it != mExtMaps.end(); ++it) + result.push_back(it->first); + + return result; +} + +Description::Entry::ExtMap *Description::Entry::extMap(int id) { + auto it = mExtMaps.find(id); + if (it == mExtMaps.end()) + throw std::invalid_argument("extmap not found"); + + return &it->second; +} + +const Description::Entry::ExtMap *Description::Entry::extMap(int id) const { + auto it = mExtMaps.find(id); + if (it == mExtMaps.end()) + throw std::invalid_argument("extmap not found"); + + return &it->second; +} + +void Description::Entry::addExtMap(ExtMap map) { + auto id = map.id; + mExtMaps.emplace(id, std::move(map)); +} + +void Description::Entry::removeExtMap(int id) { mExtMaps.erase(id); } + +Description::Entry::operator string() const { return generateSdp("\r\n", "IP4 0.0.0.0", 9); } + +string Description::Entry::generateSdp(string_view eol, string_view addr, uint16_t port) const { + std::ostringstream sdp; + // RFC 3264: Existing media streams are removed by creating a new SDP with the port number for + // that stream set to zero. [...] A stream that is offered with a port of zero MUST be marked + // with port zero in the answer. + sdp << "m=" << type() << ' ' << (mIsRemoved ? 0 : port) << ' ' << description() << eol; + sdp << "c=IN " << addr << eol; + sdp << generateSdpLines(eol); + + return sdp.str(); +} + +string Description::Entry::generateSdpLines(string_view eol) const { + std::ostringstream sdp; + sdp << "a=mid:" << mMid << eol; + + for (auto it = mExtMaps.begin(); it != mExtMaps.end(); ++it) { + auto &map = it->second; + + sdp << "a=extmap:" << map.id; + if (map.direction != Direction::Unknown) + sdp << '/' << map.direction; + + sdp << ' ' << map.uri; + if (!map.attributes.empty()) + sdp << ' ' << map.attributes; + + sdp << eol; + } + + if (mDirection != Direction::Unknown) + sdp << "a=" << mDirection << eol; + + for (const auto &attr : mAttributes) { + if (mRids.size() != 0 && match_prefix(attr, "ssrc:")) { + continue; + } + + sdp << "a=" << attr << eol; + } + + for (const auto &rid : mRids) { + sdp << "a=rid:" << rid << " send" << eol; + } + + if (mRids.size() != 0) { + sdp << "a=simulcast:send "; + + bool first = true; + for (const auto &rid : mRids) { + if (first) { + first = false; + } else { + sdp << ";"; + } + + sdp << rid; + } + + sdp << eol; + } + + return sdp.str(); +} + +void Description::Entry::parseSdpLine(string_view line) { + if (match_prefix(line, "a=")) { + string_view attr = line.substr(2); + auto [key, value] = parse_pair(attr); + + if (key == "mid") { + mMid = value; + } else if (key == "extmap") { + auto id = Description::Media::ExtMap::parseId(value); + auto it = mExtMaps.find(id); + if (it == mExtMaps.end()) + it = mExtMaps.insert(std::make_pair(id, Description::Media::ExtMap(value))).first; + else + it->second.setDescription(value); + + } else if (attr == "sendonly") + mDirection = Direction::SendOnly; + else if (attr == "recvonly") + mDirection = Direction::RecvOnly; + else if (key == "sendrecv") + mDirection = Direction::SendRecv; + else if (key == "inactive") + mDirection = Direction::Inactive; + else if (key == "bundle-only") { + // RFC 8843: When an offerer generates a subsequent offer, in which it wants to disable + // a bundled "m=" section from a BUNDLE group, the offerer [...] MUST NOT assign an SDP + // 'bundle-only' attribute to the "m=" section. + mIsRemoved = false; + } else { + mAttributes.emplace_back(attr); + } + } +} + +int Description::Entry::ExtMap::parseId(string_view description) { + size_t p = description.find(' '); + return to_integer(description.substr(0, p)); +} + +Description::Entry::ExtMap::ExtMap(int id, string uri, Direction direction) { + this->id = id; + this->uri = std::move(uri); + this->direction = direction; +} + +Description::Entry::ExtMap::ExtMap(string_view description) { setDescription(description); } + +void Description::Entry::ExtMap::setDescription(string_view description) { + const size_t uriStart = description.find(' '); + if (uriStart == string::npos) + throw std::invalid_argument("Invalid description for extmap"); + + const string_view idAndDirection = description.substr(0, uriStart); + const size_t idSplit = idAndDirection.find('/'); + if (idSplit == string::npos) { + this->id = to_integer(idAndDirection); + } else { + this->id = to_integer(idAndDirection.substr(0, idSplit)); + + const string_view directionStr = idAndDirection.substr(idSplit + 1); + if (directionStr == "sendonly") + this->direction = Direction::SendOnly; + else if (directionStr == "recvonly") + this->direction = Direction::RecvOnly; + else if (directionStr == "sendrecv") + this->direction = Direction::SendRecv; + else if (directionStr == "inactive") + this->direction = Direction::Inactive; + else + throw std::invalid_argument("Invalid direction for extmap"); + } + + const string_view uriAndAttributes = description.substr(uriStart + 1); + const size_t attributeSplit = uriAndAttributes.find(' '); + + if (attributeSplit == string::npos) + this->uri = uriAndAttributes; + else { + this->uri = uriAndAttributes.substr(0, attributeSplit); + this->attributes = uriAndAttributes.substr(attributeSplit + 1); + } +} + +void Description::Media::addSSRC(uint32_t ssrc, optional name, optional msid, + optional trackId) { + if (name) { + mAttributes.emplace_back("ssrc:" + std::to_string(ssrc) + " cname:" + *name); + mCNameMap.emplace(ssrc, *name); + } else { + mAttributes.emplace_back("ssrc:" + std::to_string(ssrc)); + } + + if (msid) { + mAttributes.emplace_back("ssrc:" + std::to_string(ssrc) + " msid:" + *msid + " " + + trackId.value_or(*msid)); + mAttributes.emplace_back("msid:" + *msid + " " + trackId.value_or(*msid)); + } + + mSsrcs.emplace_back(ssrc); +} + +void Description::Media::removeSSRC(uint32_t ssrc) { + string prefix = "ssrc:" + std::to_string(ssrc); + mAttributes.erase(std::remove_if(mAttributes.begin(), mAttributes.end(), + [&](const auto &a) { return match_prefix(a, prefix); }), + mAttributes.end()); + + mSsrcs.erase(std::remove(mSsrcs.begin(), mSsrcs.end(), ssrc), mSsrcs.end()); +} + +void Description::Media::replaceSSRC(uint32_t old, uint32_t ssrc, optional name, + optional msid, optional trackID) { + removeSSRC(old); + addSSRC(ssrc, std::move(name), std::move(msid), std::move(trackID)); +} + +bool Description::Media::hasSSRC(uint32_t ssrc) const { + return std::find(mSsrcs.begin(), mSsrcs.end(), ssrc) != mSsrcs.end(); +} + +void Description::Media::clearSSRCs() { + auto it = mAttributes.begin(); + while (it != mAttributes.end()) { + if (match_prefix(*it, "ssrc:")) + it = mAttributes.erase(it); + else + ++it; + } + + mSsrcs.clear(); + mCNameMap.clear(); +} + +std::vector Description::Media::getSSRCs() const { return mSsrcs; } + +optional Description::Media::getCNameForSsrc(uint32_t ssrc) const { + auto it = mCNameMap.find(ssrc); + if (it != mCNameMap.end()) { + return it->second; + } + return nullopt; +} + +Description::Application::Application(string mid) + : Entry("application 9 UDP/DTLS/SCTP", std::move(mid), Direction::SendRecv) {} + +Description::Application::Application(const string &mline, string mid) + : Entry(mline, std::move(mid), Direction::SendRecv) {} + +string Description::Application::description() const { + return Entry::description() + " webrtc-datachannel"; +} + +Description::Application Description::Application::reciprocate() const { + Application reciprocated(*this); + + reciprocated.mMaxMessageSize.reset(); + + return reciprocated; +} + +void Description::Application::setSctpPort(uint16_t port) { mSctpPort = port; } + +void Description::Application::hintSctpPort(uint16_t port) { mSctpPort = mSctpPort.value_or(port); } + +void Description::Application::setMaxMessageSize(size_t size) { mMaxMessageSize = size; } + +optional Description::Application::sctpPort() const { return mSctpPort; } + +optional Description::Application::maxMessageSize() const { return mMaxMessageSize; } + +string Description::Application::generateSdpLines(string_view eol) const { + std::ostringstream sdp; + sdp << Entry::generateSdpLines(eol); + + if (mSctpPort) + sdp << "a=sctp-port:" << *mSctpPort << eol; + + if (mMaxMessageSize) + sdp << "a=max-message-size:" << *mMaxMessageSize << eol; + + return sdp.str(); +} + +void Description::Application::parseSdpLine(string_view line) { + if (match_prefix(line, "a=")) { + string_view attr = line.substr(2); + auto [key, value] = parse_pair(attr); + + if (key == "sctp-port") { + mSctpPort = to_integer(value); + } else if (key == "max-message-size") { + mMaxMessageSize = to_integer(value); + } else { + Entry::parseSdpLine(line); + } + } else { + Entry::parseSdpLine(line); + } +} + +Description::Media::Media(const string &sdp) : Entry(get_first_line(sdp), "", Direction::Unknown) { + string line; + std::istringstream ss(sdp); + std::getline(ss, line); // discard first line + while (ss) { + std::getline(ss, line); + trim_end(line); + if (line.empty()) + continue; + + parseSdpLine(line); + } + + if (mid().empty()) + throw std::invalid_argument("Missing mid in media description"); +} + +Description::Media::Media(const string &mline, string mid, Direction dir) + : Entry(mline, std::move(mid), dir) {} + +string Description::Media::description() const { + std::ostringstream desc; + desc << Entry::description(); + for (auto it = mRtpMaps.begin(); it != mRtpMaps.end(); ++it) + desc << ' ' << it->first; + + return desc.str(); +} + +Description::Media Description::Media::reciprocate() const { + Media reciprocated(*this); + + // Invert direction + switch (reciprocated.direction()) { + case Direction::RecvOnly: + reciprocated.setDirection(Direction::SendOnly); + break; + case Direction::SendOnly: + reciprocated.setDirection(Direction::RecvOnly); + break; + default: + // We are good + break; + } + + // Invert directions of extmap + auto &extMaps = reciprocated.mExtMaps; + for (auto it = extMaps.begin(); it != extMaps.end(); ++it) { + auto &map = it->second; + switch (map.direction) { + case Direction::RecvOnly: + map.direction = Direction::SendOnly; + break; + case Direction::SendOnly: + map.direction = Direction::RecvOnly; + break; + default: + // We are good + break; + } + } + + // Clear sent SSRCs + reciprocated.clearSSRCs(); + + // Remove rtcp-rsize attribute as Reduced-Size RTCP is not supported (see RFC 5506) + reciprocated.removeAttribute("rtcp-rsize"); + + return reciprocated; +} + +int Description::Media::bitrate() const { return mBas; } + +void Description::Media::setBitrate(int bitrate) { mBas = bitrate; } + +bool Description::Media::hasPayloadType(int payloadType) const { + return mRtpMaps.find(payloadType) != mRtpMaps.end(); +} + +std::vector Description::Media::payloadTypes() const { + std::vector result; + result.reserve(mRtpMaps.size()); + for (auto it = mRtpMaps.begin(); it != mRtpMaps.end(); ++it) + result.push_back(it->first); + + return result; +} + +Description::Media::RtpMap *Description::Media::rtpMap(int payloadType) { + auto it = mRtpMaps.find(payloadType); + if (it == mRtpMaps.end()) + throw std::invalid_argument("rtpmap not found"); + + return &it->second; +} + +const Description::Media::RtpMap *Description::Media::rtpMap(int payloadType) const { + auto it = mRtpMaps.find(payloadType); + if (it == mRtpMaps.end()) + throw std::invalid_argument("rtpmap not found"); + + return &it->second; +} + +void Description::Media::addRtpMap(RtpMap map) { + auto payloadType = map.payloadType; + mRtpMaps.emplace(payloadType, std::move(map)); +} + +void Description::Media::removeRtpMap(int payloadType) { + // Remove the actual format + mRtpMaps.erase(payloadType); + + // Remove any other rtpmaps that depend on the format we just removed + auto it = mRtpMaps.begin(); + while (it != mRtpMaps.end()) { + const auto &fmtps = it->second.fmtps; + if (std::find(fmtps.begin(), fmtps.end(), "apt=" + std::to_string(payloadType)) != + fmtps.end()) + it = mRtpMaps.erase(it); + else + ++it; + } +} + +void Description::Media::removeFormat(const string &format) { + std::vector payloadTypes; + for (const auto &it : mRtpMaps) { + if (it.second.format == format) + payloadTypes.push_back(it.first); + } + for (int pt : payloadTypes) + removeRtpMap(pt); +} + +void Description::Media::addRtxCodec(int payloadType, int origPayloadType, unsigned int clockRate) { + RtpMap rtp(std::to_string(payloadType) + " RTX/" + std::to_string(clockRate)); + rtp.fmtps.emplace_back("apt=" + std::to_string(origPayloadType)); + addRtpMap(rtp); +} + +string Description::Media::generateSdpLines(string_view eol) const { + std::ostringstream sdp; + if (mBas >= 0) + sdp << "b=AS:" << mBas << eol; + + sdp << Entry::generateSdpLines(eol); + sdp << "a=rtcp-mux" << eol; + + for (auto it = mRtpMaps.begin(); it != mRtpMaps.end(); ++it) { + auto &map = it->second; + + // Create the a=rtpmap + sdp << "a=rtpmap:" << map.payloadType << ' ' << map.format << '/' << map.clockRate; + if (!map.encParams.empty()) + sdp << '/' << map.encParams; + + sdp << eol; + + for (const auto &val : map.rtcpFbs) + sdp << "a=rtcp-fb:" << map.payloadType << ' ' << val << eol; + + for (const auto &val : map.fmtps) + sdp << "a=fmtp:" << map.payloadType << ' ' << val << eol; + } + + return sdp.str(); +} + +void Description::Media::parseSdpLine(string_view line) { + if (match_prefix(line, "a=")) { + string_view attr = line.substr(2); + auto [key, value] = parse_pair(attr); + + if (key == "rtpmap") { + auto pt = Description::Media::RtpMap::parsePayloadType(value); + auto it = mRtpMaps.find(pt); + if (it == mRtpMaps.end()) + it = mRtpMaps.insert(std::make_pair(pt, Description::Media::RtpMap(value))).first; + else + it->second.setDescription(value); + + } else if (key == "rtcp-fb") { + size_t p = value.find(' '); + int pt = to_integer(value.substr(0, p)); + auto it = mRtpMaps.find(pt); + if (it == mRtpMaps.end()) + it = mRtpMaps.insert(std::make_pair(pt, Description::Media::RtpMap(pt))).first; + + it->second.rtcpFbs.emplace_back(value.substr(p + 1)); + + } else if (key == "fmtp") { + size_t p = value.find(' '); + int pt = to_integer(value.substr(0, p)); + auto it = mRtpMaps.find(pt); + if (it == mRtpMaps.end()) + it = mRtpMaps.insert(std::make_pair(pt, Description::Media::RtpMap(pt))).first; + + it->second.fmtps.emplace_back(value.substr(p + 1)); + + } else if (key == "rtcp-mux") { + // always added + + } else if (key == "ssrc") { + auto ssrc = to_integer(value); + if (!hasSSRC(ssrc)) + mSsrcs.emplace_back(ssrc); + + auto cnamePos = value.find("cname:"); + if (cnamePos != string::npos) { + auto cname = value.substr(cnamePos + 6); + mCNameMap.emplace(ssrc, cname); + } + mAttributes.emplace_back(attr); + + } else { + Entry::parseSdpLine(line); + } + + } else if (match_prefix(line, "b=AS")) { + mBas = to_integer(line.substr(line.find(':') + 1)); + } else { + Entry::parseSdpLine(line); + } +} + +Description::Media::RtpMap::RtpMap(int payloadType) { + this->payloadType = payloadType; + this->clockRate = 0; +} + +int Description::Media::RtpMap::parsePayloadType(string_view mline) { + size_t p = mline.find(' '); + return to_integer(mline.substr(0, p)); +} + +Description::Media::RtpMap::RtpMap(string_view description) { setDescription(description); } + +void Description::Media::RtpMap::setDescription(string_view description) { + size_t p = description.find(' '); + if (p == string::npos) + throw std::invalid_argument("Invalid format description for rtpmap"); + + this->payloadType = to_integer(description.substr(0, p)); + + string_view line = description.substr(p + 1); + size_t spl = line.find('/'); + if (spl == string::npos) + throw std::invalid_argument("Invalid format description for rtpmap"); + + this->format = line.substr(0, spl); + + line = line.substr(spl + 1); + spl = line.find('/'); + if (spl == string::npos) { + spl = line.find(' '); + } + if (spl == string::npos) + this->clockRate = to_integer(line); + else { + this->clockRate = to_integer(line.substr(0, spl)); + this->encParams = line.substr(spl + 1); + } +} + +void Description::Media::RtpMap::addFeedback(string fb) { + if (std::find(rtcpFbs.begin(), rtcpFbs.end(), fb) == rtcpFbs.end()) + rtcpFbs.emplace_back(std::move(fb)); +} + +void Description::Media::RtpMap::removeFeedback(const string &str) { + auto it = rtcpFbs.begin(); + while (it != rtcpFbs.end()) { + if (it->find(str) != string::npos) + it = rtcpFbs.erase(it); + else + it++; + } +} + +void Description::Media::RtpMap::addParameter(string p) { + if (std::find(fmtps.begin(), fmtps.end(), p) == fmtps.end()) + fmtps.emplace_back(std::move(p)); +} + +void Description::Media::RtpMap::removeParameter(const string &str) { + fmtps.erase(std::remove_if(fmtps.begin(), fmtps.end(), + [&](const auto &p) { return p.find(str) != string::npos; }), + fmtps.end()); +} + +Description::Audio::Audio(string mid, Direction dir) + : Media("audio 9 UDP/TLS/RTP/SAVPF", std::move(mid), dir) {} + +void Description::Audio::addAudioCodec(int payloadType, string codec, optional profile) { + if (codec.find('/') == string::npos) { + if (codec == "PCMA" || codec == "PCMU") + codec += "/8000/1"; + else + codec += "/48000/2"; + } + + RtpMap map(std::to_string(payloadType) + ' ' + codec); + + if (profile) + map.fmtps.emplace_back(*profile); + + addRtpMap(map); +} + +void Description::Audio::addOpusCodec(int payloadType, optional profile) { + addAudioCodec(payloadType, "opus", profile); +} + +void Description::Audio::addPCMACodec(int payloadType, optional profile) { + addAudioCodec(payloadType, "PCMA", profile); +} + +void Description::Audio::addPCMUCodec(int payloadType, optional profile) { + addAudioCodec(payloadType, "PCMU", profile); +} + +void Description::Audio::addAACCodec(int payloadType, optional profile) { + if (profile) { + addAudioCodec(payloadType, "MP4A-LATM", profile); + } else { + addAudioCodec(payloadType, "MP4A-LATM", "cpresent=1"); + } +} + +Description::Video::Video(string mid, Direction dir) + : Media("video 9 UDP/TLS/RTP/SAVPF", std::move(mid), dir) {} + +void Description::Video::addVideoCodec(int payloadType, string codec, optional profile) { + if (codec.find('/') == string::npos) + codec += "/90000"; + + RtpMap map(std::to_string(payloadType) + ' ' + codec); + + map.addFeedback("nack"); + map.addFeedback("nack pli"); + // map.addFB("ccm fir"); + map.addFeedback("goog-remb"); + + if (profile) + map.fmtps.emplace_back(*profile); + + addRtpMap(map); + + /* TODO + * TIL that Firefox does not properly support the negotiation of RTX! It works, but doesn't + * negotiate the SSRC so we have no idea what SSRC is RTX going to be. Three solutions: One) we + * don't negotitate it and (maybe) break RTX support with Edge. Two) we do negotiate it and + * rebuild the original packet before we send it distribute it to each track. Three) we complain + * to mozilla. This one probably won't do much. + */ + // RTX Packets + // Format rtx(std::to_string(payloadType+1) + " rtx/90000"); + // // TODO rtx-time is how long can a request be stashed for before needing to resend it. + // Needs to be parameterized rtx.addAttribute("apt=" + std::to_string(payloadType) + + // ";rtx-time=3000"); addFormat(rtx); +} + +void Description::Video::addH264Codec(int payloadType, optional profile) { + addVideoCodec(payloadType, "H264", profile); +} + +void Description::Video::addH265Codec(int payloadType, optional profile) { + addVideoCodec(payloadType, "H265", profile); +} + +void Description::Video::addVP8Codec(int payloadType, optional profile) { + addVideoCodec(payloadType, "VP8", profile); +} + +void Description::Video::addVP9Codec(int payloadType, optional profile) { + addVideoCodec(payloadType, "VP9", profile); +} + +void Description::Video::addAV1Codec(int payloadType, optional profile) { + addVideoCodec(payloadType, "AV1", profile); +} + +Description::Type Description::stringToType(const string &typeString) { + using TypeMap_t = std::unordered_map; + static const TypeMap_t TypeMap = {{"unspec", Type::Unspec}, + {"offer", Type::Offer}, + {"answer", Type::Answer}, + {"pranswer", Type::Pranswer}, + {"rollback", Type::Rollback}}; + auto it = TypeMap.find(typeString); + return it != TypeMap.end() ? it->second : Type::Unspec; +} + +string Description::typeToString(Type type) { + switch (type) { + case Type::Unspec: + return "unspec"; + case Type::Offer: + return "offer"; + case Type::Answer: + return "answer"; + case Type::Pranswer: + return "pranswer"; + case Type::Rollback: + return "rollback"; + default: + return "unknown"; + } +} + +size_t +CertificateFingerprint::AlgorithmSize(CertificateFingerprint::Algorithm fingerprintAlgorithm) { + switch (fingerprintAlgorithm) { + case CertificateFingerprint::Algorithm::Sha1: + return 20; + case CertificateFingerprint::Algorithm::Sha224: + return 28; + case CertificateFingerprint::Algorithm::Sha256: + return 32; + case CertificateFingerprint::Algorithm::Sha384: + return 48; + case CertificateFingerprint::Algorithm::Sha512: + return 64; + default: + return 0; + } +} + +std::string CertificateFingerprint::AlgorithmIdentifier( + CertificateFingerprint::Algorithm fingerprintAlgorithm) { + switch (fingerprintAlgorithm) { + case CertificateFingerprint::Algorithm::Sha1: + return "sha-1"; + case CertificateFingerprint::Algorithm::Sha224: + return "sha-224"; + case CertificateFingerprint::Algorithm::Sha256: + return "sha-256"; + case CertificateFingerprint::Algorithm::Sha384: + return "sha-256"; + case CertificateFingerprint::Algorithm::Sha512: + return "sha-512"; + default: + return "unknown"; + } +} + +bool CertificateFingerprint::isValid() const { + size_t expectedSize = AlgorithmSize(this->algorithm); + if (expectedSize == 0 || this->value.size() != expectedSize * 3 - 1) { + return false; + } + + for (size_t i = 0; i < this->value.size(); ++i) { + if (i % 3 == 2) { + if (this->value[i] != ':') + return false; + } else { + if (!std::isxdigit(this->value[i])) + return false; + } + } + return true; +} + +std::ostream &operator<<(std::ostream &out, const Description &description) { + return out << string(description); +} + +std::ostream &operator<<(std::ostream &out, Description::Type type) { + return out << Description::typeToString(type); +} + +std::ostream &operator<<(std::ostream &out, Description::Role role) { + using Role = Description::Role; + // Used for SDP generation, do not change + switch (role) { + case Role::Active: + out << "active"; + break; + case Role::Passive: + out << "passive"; + break; + default: + out << "actpass"; + break; + } + return out; +} + +std::ostream &operator<<(std::ostream &out, const Description::Direction &direction) { + // Used for SDP generation, do not change + switch (direction) { + case Description::Direction::RecvOnly: + out << "recvonly"; + break; + case Description::Direction::SendOnly: + out << "sendonly"; + break; + case Description::Direction::SendRecv: + out << "sendrecv"; + break; + case Description::Direction::Inactive: + out << "inactive"; + break; + case Description::Direction::Unknown: + default: + out << "unknown"; + break; + } + return out; +} + +} // namespace rtc diff --git a/datachannel/src/global.cpp b/datachannel/src/global.cpp new file mode 100644 index 000000000..959053749 --- /dev/null +++ b/datachannel/src/global.cpp @@ -0,0 +1,118 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "plog/Appenders/ColorConsoleAppender.h" +#include "plog/Converters/UTF8Converter.h" +#include "plog/Formatters/FuncMessageFormatter.h" +#include "plog/Formatters/TxtFormatter.h" +#include "plog/Init.h" +#include "plog/Log.h" +#include "plog/Logger.h" +// +#include "global.hpp" + +#include "impl/init.hpp" + +#include + +namespace { + +void plogInit(plog::Severity severity, plog::IAppender *appender) { + using Logger = plog::Logger; + static Logger *logger = nullptr; + if (!logger) { + PLOG_DEBUG << "Initializing logger"; + logger = new Logger(severity); + if (appender) { + logger->addAppender(appender); + } else { + using ConsoleAppender = plog::ColorConsoleAppender; + static ConsoleAppender *consoleAppender = new ConsoleAppender(); + logger->addAppender(consoleAppender); + } + } else { + logger->setMaxSeverity(severity); + if (appender) + logger->addAppender(appender); + } +} + +} // namespace + +namespace rtc { + +struct LogAppender : public plog::IAppender { + synchronized_callback callback; + + void write(const plog::Record &record) override { + const auto severity = record.getSeverity(); + auto formatted = plog::FuncMessageFormatter::format(record); + formatted.pop_back(); // remove newline + + const auto &converted = + plog::UTF8Converter::convert(formatted); // does nothing on non-Windows systems + + if (!callback(static_cast(severity), converted)) + std::cout << plog::severityToString(severity) << " " << converted << std::endl; + } +}; + +void InitLogger(LogLevel level, LogCallback callback) { + const auto severity = static_cast(level); + static LogAppender *appender = nullptr; + static std::mutex mutex; + std::lock_guard lock(mutex); + if (appender) { + appender->callback = std::move(callback); + plogInit(severity, nullptr); // change the severity + } else if (callback) { + appender = new LogAppender(); + appender->callback = std::move(callback); + plogInit(severity, appender); + } else { + plogInit(severity, nullptr); // log to cout + } +} + +void InitLogger(plog::Severity severity, plog::IAppender *appender) { + plogInit(severity, appender); +} + +void Preload() { impl::Init::Instance().preload(); } +std::shared_future Cleanup() { return impl::Init::Instance().cleanup(); } + +void SetSctpSettings(SctpSettings s) { impl::Init::Instance().setSctpSettings(std::move(s)); } + +RTC_CPP_EXPORT std::ostream &operator<<(std::ostream &out, LogLevel level) { + switch (level) { + case LogLevel::Fatal: + out << "fatal"; + break; + case LogLevel::Error: + out << "error"; + break; + case LogLevel::Warning: + out << "warning"; + break; + case LogLevel::Info: + out << "info"; + break; + case LogLevel::Debug: + out << "debug"; + break; + case LogLevel::Verbose: + out << "verbose"; + break; + default: + out << "none"; + break; + } + return out; +} + +} // namespace rtc diff --git a/datachannel/src/h264rtppacketizer.cpp b/datachannel/src/h264rtppacketizer.cpp new file mode 100644 index 000000000..cb01d00b2 --- /dev/null +++ b/datachannel/src/h264rtppacketizer.cpp @@ -0,0 +1,112 @@ +/** + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#if RTC_ENABLE_MEDIA + +#include "h264rtppacketizer.hpp" + +#include "impl/internals.hpp" + +#include + +#ifdef _WIN32 +#include +#else +#include +#endif + +namespace rtc { + +shared_ptr H264RtpPacketizer::splitMessage(binary_ptr message) { + auto nalus = std::make_shared(); + if (separator == Separator::Length) { + size_t index = 0; + while (index < message->size()) { + assert(index + 4 < message->size()); + if (index + 4 >= message->size()) { + LOG_WARNING << "Invalid NAL Unit data (incomplete length), ignoring!"; + break; + } + auto lengthPtr = (uint32_t *)(message->data() + index); + uint32_t length = ntohl(*lengthPtr); + auto naluStartIndex = index + 4; + auto naluEndIndex = naluStartIndex + length; + + assert(naluEndIndex <= message->size()); + if (naluEndIndex > message->size()) { + LOG_WARNING << "Invalid NAL Unit data (incomplete unit), ignoring!"; + break; + } + auto begin = message->begin() + naluStartIndex; + auto end = message->begin() + naluEndIndex; + nalus->push_back(std::make_shared(begin, end)); + index = naluEndIndex; + } + } else { + NalUnitStartSequenceMatch match = NUSM_noMatch; + size_t index = 0; + while (index < message->size()) { + match = NalUnit::StartSequenceMatchSucc(match, (*message)[index++], separator); + if (match == NUSM_longMatch || match == NUSM_shortMatch) { + match = NUSM_noMatch; + break; + } + } + + size_t naluStartIndex = index; + + while (index < message->size()) { + match = NalUnit::StartSequenceMatchSucc(match, (*message)[index], separator); + if (match == NUSM_longMatch || match == NUSM_shortMatch) { + auto sequenceLength = match == NUSM_longMatch ? 4 : 3; + size_t naluEndIndex = index - sequenceLength; + match = NUSM_noMatch; + auto begin = message->begin() + naluStartIndex; + auto end = message->begin() + naluEndIndex + 1; + nalus->push_back(std::make_shared(begin, end)); + naluStartIndex = index + 1; + } + index++; + } + auto begin = message->begin() + naluStartIndex; + auto end = message->end(); + nalus->push_back(std::make_shared(begin, end)); + } + return nalus; +} + +H264RtpPacketizer::H264RtpPacketizer(shared_ptr rtpConfig, + uint16_t maxFragmentSize) + : RtpPacketizer(std::move(rtpConfig)), maxFragmentSize(maxFragmentSize), + separator(Separator::Length) {} + +H264RtpPacketizer::H264RtpPacketizer(Separator separator, + shared_ptr rtpConfig, + uint16_t maxFragmentSize) + : RtpPacketizer(rtpConfig), maxFragmentSize(maxFragmentSize), separator(separator) {} + +void H264RtpPacketizer::outgoing(message_vector &messages, [[maybe_unused]] const message_callback &send) { + message_vector result; + for(const auto &message : messages) { + auto nalus = splitMessage(message); + auto fragments = nalus->generateFragments(maxFragmentSize); + if (fragments.size() == 0) + continue; + + for (size_t i = 0; i < fragments.size() - 1; i++) + result.push_back(packetize(fragments[i], false)); + + result.push_back(packetize(fragments[fragments.size() - 1], true)); + } + + messages.swap(result); +} + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ diff --git a/datachannel/src/h265nalunit.cpp b/datachannel/src/h265nalunit.cpp new file mode 100644 index 000000000..5fda10545 --- /dev/null +++ b/datachannel/src/h265nalunit.cpp @@ -0,0 +1,100 @@ +/** + * Copyright (c) 2023 Zita Liao (Dolby) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#if RTC_ENABLE_MEDIA + +#include "h265nalunit.hpp" + +#include "impl/internals.hpp" + +#include + +namespace rtc { + +H265NalUnitFragment::H265NalUnitFragment(FragmentType type, bool forbiddenBit, uint8_t nuhLayerId, + uint8_t nuhTempIdPlus1, uint8_t unitType, binary data) + : H265NalUnit(data.size() + H265_NAL_HEADER_SIZE + H265_FU_HEADER_SIZE) { + setForbiddenBit(forbiddenBit); + setNuhLayerId(nuhLayerId); + setNuhTempIdPlus1(nuhTempIdPlus1); + fragmentIndicator()->setUnitType(H265NalUnitFragment::nal_type_fu); + setFragmentType(type); + setUnitType(unitType); + copy(data.begin(), data.end(), begin() + H265_NAL_HEADER_SIZE + H265_FU_HEADER_SIZE); +} + +std::vector> +H265NalUnitFragment::fragmentsFrom(shared_ptr nalu, uint16_t maxFragmentSize) { + assert(nalu->size() > maxFragmentSize); + auto fragments_count = ceil(double(nalu->size()) / maxFragmentSize); + maxFragmentSize = uint16_t(int(ceil(nalu->size() / fragments_count))); + + // 3 bytes for FU indicator and FU header + maxFragmentSize -= (H265_NAL_HEADER_SIZE + H265_FU_HEADER_SIZE); + auto f = nalu->forbiddenBit(); + uint8_t nuhLayerId = nalu->nuhLayerId() & 0x3F; // 6 bits + uint8_t nuhTempIdPlus1 = nalu->nuhTempIdPlus1() & 0x7; // 3 bits + uint8_t naluType = nalu->unitType() & 0x3F; // 6 bits + auto payload = nalu->payload(); + vector> result{}; + uint64_t offset = 0; + while (offset < payload.size()) { + vector fragmentData; + FragmentType fragmentType; + if (offset == 0) { + fragmentType = FragmentType::Start; + } else if (offset + maxFragmentSize < payload.size()) { + fragmentType = FragmentType::Middle; + } else { + if (offset + maxFragmentSize > payload.size()) { + maxFragmentSize = uint16_t(payload.size() - offset); + } + fragmentType = FragmentType::End; + } + fragmentData = {payload.begin() + offset, payload.begin() + offset + maxFragmentSize}; + auto fragment = std::make_shared( + fragmentType, f, nuhLayerId, nuhTempIdPlus1, naluType, fragmentData); + result.push_back(fragment); + offset += maxFragmentSize; + } + return result; +} + +void H265NalUnitFragment::setFragmentType(FragmentType type) { + switch (type) { + case FragmentType::Start: + fragmentHeader()->setStart(true); + fragmentHeader()->setEnd(false); + break; + case FragmentType::End: + fragmentHeader()->setStart(false); + fragmentHeader()->setEnd(true); + break; + default: + fragmentHeader()->setStart(false); + fragmentHeader()->setEnd(false); + } +} + +std::vector> H265NalUnits::generateFragments(uint16_t maxFragmentSize) { + vector> result{}; + for (auto nalu : *this) { + if (nalu->size() > maxFragmentSize) { + std::vector> fragments = + H265NalUnitFragment::fragmentsFrom(nalu, maxFragmentSize); + result.insert(result.end(), fragments.begin(), fragments.end()); + } else { + result.push_back(nalu); + } + } + return result; +} + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ diff --git a/datachannel/src/h265rtppacketizer.cpp b/datachannel/src/h265rtppacketizer.cpp new file mode 100644 index 000000000..5776dafa8 --- /dev/null +++ b/datachannel/src/h265rtppacketizer.cpp @@ -0,0 +1,113 @@ +/** + * Copyright (c) 2023 Zita Liao (Dolby) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#if RTC_ENABLE_MEDIA + +#include "h265rtppacketizer.hpp" + +#include "impl/internals.hpp" + +#include + +#ifdef _WIN32 +#include +#else +#include +#endif + +namespace rtc { + +shared_ptr H265RtpPacketizer::splitMessage(binary_ptr message) { + auto nalus = std::make_shared(); + if (separator == NalUnit::Separator::Length) { + size_t index = 0; + while (index < message->size()) { + assert(index + 4 < message->size()); + if (index + 4 >= message->size()) { + LOG_WARNING << "Invalid NAL Unit data (incomplete length), ignoring!"; + break; + } + auto lengthPtr = (uint32_t *)(message->data() + index); + uint32_t length = ntohl(*lengthPtr); + auto naluStartIndex = index + 4; + auto naluEndIndex = naluStartIndex + length; + + assert(naluEndIndex <= message->size()); + if (naluEndIndex > message->size()) { + LOG_WARNING << "Invalid NAL Unit data (incomplete unit), ignoring!"; + break; + } + auto begin = message->begin() + naluStartIndex; + auto end = message->begin() + naluEndIndex; + nalus->push_back(std::make_shared(begin, end)); + index = naluEndIndex; + } + } else { + NalUnitStartSequenceMatch match = NUSM_noMatch; + size_t index = 0; + while (index < message->size()) { + match = NalUnit::StartSequenceMatchSucc(match, (*message)[index++], separator); + if (match == NUSM_longMatch || match == NUSM_shortMatch) { + match = NUSM_noMatch; + break; + } + } + + size_t naluStartIndex = index; + + while (index < message->size()) { + match = NalUnit::StartSequenceMatchSucc(match, (*message)[index], separator); + if (match == NUSM_longMatch || match == NUSM_shortMatch) { + auto sequenceLength = match == NUSM_longMatch ? 4 : 3; + size_t naluEndIndex = index - sequenceLength; + match = NUSM_noMatch; + auto begin = message->begin() + naluStartIndex; + auto end = message->begin() + naluEndIndex + 1; + nalus->push_back(std::make_shared(begin, end)); + naluStartIndex = index + 1; + } + index++; + } + auto begin = message->begin() + naluStartIndex; + auto end = message->end(); + nalus->push_back(std::make_shared(begin, end)); + } + return nalus; +} + +H265RtpPacketizer::H265RtpPacketizer(shared_ptr rtpConfig, + uint16_t maxFragmentSize) + : RtpPacketizer(std::move(rtpConfig)), maxFragmentSize(maxFragmentSize), + separator(NalUnit::Separator::Length) {} + +H265RtpPacketizer::H265RtpPacketizer(NalUnit::Separator separator, + shared_ptr rtpConfig, + uint16_t maxFragmentSize) + : RtpPacketizer(std::move(rtpConfig)), maxFragmentSize(maxFragmentSize), + separator(separator) {} + +void H265RtpPacketizer::outgoing(message_vector &messages, [[maybe_unused]] const message_callback &send) { + message_vector result; + for (const auto &message : messages) { + auto nalus = splitMessage(message); + auto fragments = nalus->generateFragments(maxFragmentSize); + if (fragments.size() == 0) + continue; + + for (size_t i = 0; i < fragments.size() - 1; i++) + result.push_back(packetize(fragments[i], false)); + + result.push_back(packetize(fragments[fragments.size() - 1], true)); + } + + messages.swap(result); +} + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ diff --git a/datachannel/src/impl/certificate.cpp b/datachannel/src/impl/certificate.cpp new file mode 100644 index 000000000..fb4f680e5 --- /dev/null +++ b/datachannel/src/impl/certificate.cpp @@ -0,0 +1,578 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "certificate.hpp" +#include "threadpool.hpp" + +#include +#include +#include +#include +#include +#include + +namespace rtc::impl { + +#if USE_GNUTLS + +Certificate Certificate::FromString(string crt_pem, string key_pem) { + PLOG_DEBUG << "Importing certificate from PEM string (GnuTLS)"; + + shared_ptr creds(gnutls::new_credentials(), + gnutls::free_credentials); + gnutls_datum_t crt_datum = gnutls::make_datum(crt_pem.data(), crt_pem.size()); + gnutls_datum_t key_datum = gnutls::make_datum(key_pem.data(), key_pem.size()); + gnutls::check( + gnutls_certificate_set_x509_key_mem(*creds, &crt_datum, &key_datum, GNUTLS_X509_FMT_PEM), + "Unable to import PEM certificate and key"); + + return Certificate(std::move(creds)); +} + +Certificate Certificate::FromFile(const string &crt_pem_file, const string &key_pem_file, + const string &pass) { + PLOG_DEBUG << "Importing certificate from PEM file (GnuTLS): " << crt_pem_file; + + shared_ptr creds(gnutls::new_credentials(), + gnutls::free_credentials); + gnutls::check(gnutls_certificate_set_x509_key_file2(*creds, crt_pem_file.c_str(), + key_pem_file.c_str(), GNUTLS_X509_FMT_PEM, + pass.c_str(), 0), + "Unable to import PEM certificate and key from file"); + + return Certificate(std::move(creds)); +} + +Certificate Certificate::Generate(CertificateType type, const string &commonName) { + PLOG_DEBUG << "Generating certificate (GnuTLS)"; + + using namespace gnutls; + unique_ptr crt(new_crt(), free_crt); + unique_ptr privkey(new_privkey(), free_privkey); + + switch (type) { + // RFC 8827 WebRTC Security Architecture 6.5. Communications Security + // All implementations MUST support DTLS 1.2 with the TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + // cipher suite and the P-256 curve + // See https://www.rfc-editor.org/rfc/rfc8827.html#section-6.5 + case CertificateType::Default: + case CertificateType::Ecdsa: { + gnutls::check(gnutls_x509_privkey_generate(*privkey, GNUTLS_PK_ECDSA, + GNUTLS_CURVE_TO_BITS(GNUTLS_ECC_CURVE_SECP256R1), + 0), + "Unable to generate ECDSA P-256 key pair"); + break; + } + case CertificateType::Rsa: { + const unsigned int bits = 2048; + gnutls::check(gnutls_x509_privkey_generate(*privkey, GNUTLS_PK_RSA, bits, 0), + "Unable to generate RSA key pair"); + break; + } + default: + throw std::invalid_argument("Unknown certificate type"); + } + + using namespace std::chrono; + auto now = time_point_cast(system_clock::now()); + gnutls_x509_crt_set_activation_time(*crt, (now - hours(1)).time_since_epoch().count()); + gnutls_x509_crt_set_expiration_time(*crt, (now + hours(24 * 365)).time_since_epoch().count()); + gnutls_x509_crt_set_version(*crt, 1); + gnutls_x509_crt_set_key(*crt, *privkey); + gnutls_x509_crt_set_dn_by_oid(*crt, GNUTLS_OID_X520_COMMON_NAME, 0, commonName.data(), + commonName.size()); + + const size_t serialSize = 16; + char serial[serialSize]; + gnutls_rnd(GNUTLS_RND_NONCE, serial, serialSize); + gnutls_x509_crt_set_serial(*crt, serial, serialSize); + + gnutls::check(gnutls_x509_crt_sign2(*crt, *crt, *privkey, GNUTLS_DIG_SHA256, 0), + "Unable to auto-sign certificate"); + + return Certificate(*crt, *privkey); +} + +Certificate::Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey) + : mCredentials(gnutls::new_credentials(), gnutls::free_credentials), + mFingerprint(make_fingerprint(crt, CertificateFingerprint::Algorithm::Sha256)) { + + gnutls::check(gnutls_certificate_set_x509_key(*mCredentials, &crt, 1, privkey), + "Unable to set certificate and key pair in credentials"); +} + +Certificate::Certificate(shared_ptr creds) + : mCredentials(std::move(creds)), + mFingerprint(make_fingerprint(*mCredentials, CertificateFingerprint::Algorithm::Sha256)) {} + +gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; } + +string make_fingerprint(gnutls_certificate_credentials_t credentials, + CertificateFingerprint::Algorithm fingerprintAlgorithm) { + auto new_crt_list = [credentials]() -> gnutls_x509_crt_t * { + gnutls_x509_crt_t *crt_list = nullptr; + unsigned int crt_list_size = 0; + gnutls::check(gnutls_certificate_get_x509_crt(credentials, 0, &crt_list, &crt_list_size)); + assert(crt_list_size == 1); + return crt_list; + }; + + auto free_crt_list = [](gnutls_x509_crt_t *crt_list) { + gnutls_x509_crt_deinit(crt_list[0]); + gnutls_free(crt_list); + }; + + unique_ptr crt_list(new_crt_list(), free_crt_list); + + return make_fingerprint(*crt_list, fingerprintAlgorithm); +} + +string make_fingerprint(gnutls_x509_crt_t crt, + CertificateFingerprint::Algorithm fingerprintAlgorithm) { + const size_t size = CertificateFingerprint::AlgorithmSize(fingerprintAlgorithm); + std::vector buffer(size); + size_t len = size; + + gnutls_digest_algorithm_t hashFunc; + switch (fingerprintAlgorithm) { + case CertificateFingerprint::Algorithm::Sha1: + hashFunc = GNUTLS_DIG_SHA1; + break; + case CertificateFingerprint::Algorithm::Sha224: + hashFunc = GNUTLS_DIG_SHA224; + break; + case CertificateFingerprint::Algorithm::Sha256: + hashFunc = GNUTLS_DIG_SHA256; + break; + case CertificateFingerprint::Algorithm::Sha384: + hashFunc = GNUTLS_DIG_SHA384; + break; + case CertificateFingerprint::Algorithm::Sha512: + hashFunc = GNUTLS_DIG_SHA512; + break; + default: + throw std::invalid_argument("Unknown fingerprint algorithm"); + } + + gnutls::check(gnutls_x509_crt_get_fingerprint(crt, hashFunc, buffer.data(), &len), + "X509 fingerprint error"); + + std::ostringstream oss; + oss << std::hex << std::uppercase << std::setfill('0'); + for (size_t i = 0; i < len; ++i) { + if (i) + oss << std::setw(1) << ':'; + oss << std::setw(2) << unsigned(buffer.at(i)); + } + return oss.str(); +} + +#elif USE_MBEDTLS +string make_fingerprint(mbedtls_x509_crt *crt, + CertificateFingerprint::Algorithm fingerprintAlgorithm) { + const int size = CertificateFingerprint::AlgorithmSize(fingerprintAlgorithm); + std::vector buffer(size); + std::stringstream fingerprint; + + switch (fingerprintAlgorithm) { + case CertificateFingerprint::Algorithm::Sha1: + mbedtls::check(mbedtls_sha1(crt->raw.p, crt->raw.len, buffer.data()), + "Failed to generate certificate fingerprint"); + break; + case CertificateFingerprint::Algorithm::Sha224: + mbedtls::check(mbedtls_sha256(crt->raw.p, crt->raw.len, buffer.data(), 1), + "Failed to generate certificate fingerprint"); + + break; + case CertificateFingerprint::Algorithm::Sha256: + mbedtls::check(mbedtls_sha256(crt->raw.p, crt->raw.len, buffer.data(), 0), + "Failed to generate certificate fingerprint"); + break; + case CertificateFingerprint::Algorithm::Sha384: + mbedtls::check(mbedtls_sha512(crt->raw.p, crt->raw.len, buffer.data(), 1), + "Failed to generate certificate fingerprint"); + break; + case CertificateFingerprint::Algorithm::Sha512: + mbedtls::check(mbedtls_sha512(crt->raw.p, crt->raw.len, buffer.data(), 0), + "Failed to generate certificate fingerprint"); + break; + default: + throw std::invalid_argument("Unknown fingerprint algorithm"); + } + + for (auto i = 0; i < size; i++) { + fingerprint << std::setfill('0') << std::setw(2) << std::hex + << static_cast(buffer.at(i)); + if (i != (size - 1)) { + fingerprint << ":"; + } + } + + return fingerprint.str(); +} + +Certificate::Certificate(shared_ptr crt, shared_ptr pk) + : mCrt(crt), mPk(pk), + mFingerprint(make_fingerprint(crt.get(), CertificateFingerprint::Algorithm::Sha256)) {} + +Certificate Certificate::FromString(string crt_pem, string key_pem) { + PLOG_DEBUG << "Importing certificate from PEM string (MbedTLS)"; + + auto crt = mbedtls::new_x509_crt(); + auto pk = mbedtls::new_pk_context(); + + mbedtls::check(mbedtls_x509_crt_parse(crt.get(), + reinterpret_cast(crt_pem.c_str()), + crt_pem.length()), + "Failed to parse certificate"); + mbedtls::check(mbedtls_pk_parse_key(pk.get(), + reinterpret_cast(key_pem.c_str()), + key_pem.size(), NULL, 0, NULL, 0), + "Failed to parse key"); + + return Certificate(std::move(crt), std::move(pk)); +} + +Certificate Certificate::FromFile(const string &crt_pem_file, const string &key_pem_file, + const string &pass) { + PLOG_DEBUG << "Importing certificate from PEM file (MbedTLS): " << crt_pem_file; + + auto crt = mbedtls::new_x509_crt(); + auto pk = mbedtls::new_pk_context(); + + mbedtls::check(mbedtls_x509_crt_parse_file(crt.get(), crt_pem_file.c_str()), + "Failed to parse certificate"); + mbedtls::check(mbedtls_pk_parse_keyfile(pk.get(), key_pem_file.c_str(), pass.c_str(), 0, NULL), + "Failed to parse key"); + + return Certificate(std::move(crt), std::move(pk)); +} + +Certificate Certificate::Generate(CertificateType type, const string &commonName) { + PLOG_DEBUG << "Generating certificate (MbedTLS)"; + + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context drbg; + mbedtls_x509write_cert wcrt; + mbedtls_mpi serial; + auto crt = mbedtls::new_x509_crt(); + auto pk = mbedtls::new_pk_context(); + + mbedtls_entropy_init(&entropy); + mbedtls_ctr_drbg_init(&drbg); + mbedtls_ctr_drbg_set_prediction_resistance(&drbg, MBEDTLS_CTR_DRBG_PR_ON); + mbedtls_x509write_crt_init(&wcrt); + mbedtls_mpi_init(&serial); + + try { + mbedtls::check(mbedtls_ctr_drbg_seed( + &drbg, mbedtls_entropy_func, &entropy, + reinterpret_cast(commonName.data()), commonName.size())); + + switch (type) { + // RFC 8827 WebRTC Security Architecture 6.5. Communications Security + // All implementations MUST support DTLS 1.2 with the + // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 cipher suite and the P-256 curve + // See https://www.rfc-editor.org/rfc/rfc8827.html#section-6.5 + case CertificateType::Default: + case CertificateType::Ecdsa: { + mbedtls::check(mbedtls_pk_setup(pk.get(), mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY))); + mbedtls::check(mbedtls_ecp_gen_key(MBEDTLS_ECP_DP_SECP256R1, mbedtls_pk_ec(*pk.get()), + mbedtls_ctr_drbg_random, &drbg), + "Unable to generate ECDSA P-256 key pair"); + break; + } + case CertificateType::Rsa: { + const unsigned int nbits = 2048; + const int exponent = 65537; + + mbedtls::check(mbedtls_pk_setup(pk.get(), mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))); + mbedtls::check(mbedtls_rsa_gen_key(mbedtls_pk_rsa(*pk.get()), mbedtls_ctr_drbg_random, + &drbg, nbits, exponent), + "Unable to generate RSA key pair"); + break; + } + default: + throw std::invalid_argument("Unknown certificate type"); + } + + auto now = std::chrono::system_clock::now(); + string notBefore = mbedtls::format_time(now - std::chrono::hours(1)); + string notAfter = mbedtls::format_time(now + std::chrono::hours(24 * 365)); + + const size_t serialBufferSize = 16; + unsigned char serialBuffer[serialBufferSize]; + mbedtls::check(mbedtls_ctr_drbg_random(&drbg, serialBuffer, serialBufferSize), + "Failed to generate certificate"); + mbedtls::check(mbedtls_mpi_read_binary(&serial, serialBuffer, serialBufferSize), + "Failed to generate certificate"); + + std::string name = std::string("O=" + commonName + ",CN=" + commonName); + mbedtls::check(mbedtls_x509write_crt_set_serial(&wcrt, &serial), + "Failed to generate certificate"); + mbedtls::check(mbedtls_x509write_crt_set_subject_name(&wcrt, name.c_str()), + "Failed to generate certificate"); + mbedtls::check(mbedtls_x509write_crt_set_issuer_name(&wcrt, name.c_str()), + "Failed to generate certificate"); + mbedtls::check( + mbedtls_x509write_crt_set_validity(&wcrt, notBefore.c_str(), notAfter.c_str()), + "Failed to generate certificate"); + + mbedtls_x509write_crt_set_version(&wcrt, MBEDTLS_X509_CRT_VERSION_3); + mbedtls_x509write_crt_set_subject_key(&wcrt, pk.get()); + mbedtls_x509write_crt_set_issuer_key(&wcrt, pk.get()); + mbedtls_x509write_crt_set_md_alg(&wcrt, MBEDTLS_MD_SHA256); + + const size_t certificateBufferSize = 4096; + unsigned char certificateBuffer[certificateBufferSize]; + std::memset(certificateBuffer, 0, certificateBufferSize); + + auto certificateLen = mbedtls_x509write_crt_der( + &wcrt, certificateBuffer, certificateBufferSize, mbedtls_ctr_drbg_random, &drbg); + if (certificateLen <= 0) { + throw std::runtime_error("Certificate generation failed"); + } + + mbedtls::check(mbedtls_x509_crt_parse_der( + crt.get(), (certificateBuffer + certificateBufferSize - certificateLen), + certificateLen), + "Failed to generate certificate"); + } catch (...) { + mbedtls_entropy_free(&entropy); + mbedtls_ctr_drbg_free(&drbg); + mbedtls_x509write_crt_free(&wcrt); + mbedtls_mpi_free(&serial); + throw; + } + + mbedtls_entropy_free(&entropy); + mbedtls_ctr_drbg_free(&drbg); + mbedtls_x509write_crt_free(&wcrt); + mbedtls_mpi_free(&serial); + return Certificate(std::move(crt), std::move(pk)); +} + +std::tuple, shared_ptr> +Certificate::credentials() const { + return {mCrt, mPk}; +} + +#else // OPENSSL + +#include +#include +#include + +namespace { + +// Dummy password callback that copies the password from user data +int dummy_pass_cb(char *buf, int size, int /*rwflag*/, void *u) { + const char *pass = static_cast(u); + return snprintf(buf, size, "%s", pass); +} + +} // namespace + +Certificate Certificate::FromString(string crt_pem, string key_pem) { + PLOG_DEBUG << "Importing certificate from PEM string (OpenSSL)"; + + BIO *bio = BIO_new(BIO_s_mem()); + BIO_write(bio, crt_pem.data(), int(crt_pem.size())); + auto x509 = shared_ptr(PEM_read_bio_X509(bio, nullptr, nullptr, nullptr), X509_free); + BIO_free(bio); + if (!x509) + throw std::invalid_argument("Unable to import PEM certificate"); + + bio = BIO_new(BIO_s_mem()); + BIO_write(bio, key_pem.data(), int(key_pem.size())); + auto pkey = shared_ptr(PEM_read_bio_PrivateKey(bio, nullptr, nullptr, nullptr), + EVP_PKEY_free); + BIO_free(bio); + if (!pkey) + throw std::invalid_argument("Unable to import PEM key"); + + return Certificate(x509, pkey); +} + +Certificate Certificate::FromFile(const string &crt_pem_file, const string &key_pem_file, + const string &pass) { + PLOG_DEBUG << "Importing certificate from PEM file (OpenSSL): " << crt_pem_file; + + BIO *bio = openssl::BIO_new_from_file(crt_pem_file); + if (!bio) + throw std::invalid_argument("Unable to open PEM certificate file"); + + auto x509 = shared_ptr(PEM_read_bio_X509(bio, nullptr, nullptr, nullptr), X509_free); + BIO_free(bio); + if (!x509) + throw std::invalid_argument("Unable to import PEM certificate from file"); + + bio = openssl::BIO_new_from_file(key_pem_file); + if (!bio) + throw std::invalid_argument("Unable to open PEM key file"); + + auto pkey = shared_ptr( + PEM_read_bio_PrivateKey(bio, nullptr, dummy_pass_cb, const_cast(pass.c_str())), + EVP_PKEY_free); + BIO_free(bio); + if (!pkey) + throw std::invalid_argument("Unable to import PEM key from file"); + + return Certificate(x509, pkey); +} + +Certificate Certificate::Generate(CertificateType type, const string &commonName) { + PLOG_DEBUG << "Generating certificate (OpenSSL)"; + + shared_ptr x509(X509_new(), X509_free); + unique_ptr serial_number(BN_new(), BN_free); + unique_ptr name(X509_NAME_new(), X509_NAME_free); + if (!x509 || !serial_number || !name) + throw std::runtime_error("Unable to allocate structures for certificate generation"); + + shared_ptr pkey; + switch (type) { + // RFC 8827 WebRTC Security Architecture 6.5. Communications Security + // All implementations MUST support DTLS 1.2 with the TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + // cipher suite and the P-256 curve + // See https://www.rfc-editor.org/rfc/rfc8827.html#section-6.5 + case CertificateType::Default: + case CertificateType::Ecdsa: { + PLOG_VERBOSE << "Generating ECDSA P-256 key pair"; +#if OPENSSL_VERSION_NUMBER >= 0x30000000 + pkey = shared_ptr(EVP_EC_gen("prime256v1"), EVP_PKEY_free); + if (!pkey) + throw std::runtime_error("Unable to generate ECDSA P-256 key pair"); +#else + pkey = shared_ptr(EVP_PKEY_new(), EVP_PKEY_free); + unique_ptr ecc( + EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free); + if (!pkey || !ecc) + throw std::runtime_error("Unable to allocate structure for ECDSA P-256 key pair"); + + EC_KEY_set_asn1_flag(ecc.get(), OPENSSL_EC_NAMED_CURVE); // Set ASN1 OID + if (!EC_KEY_generate_key(ecc.get()) || !EVP_PKEY_assign_EC_KEY(pkey.get(), ecc.get())) + throw std::runtime_error("Unable to generate ECDSA P-256 key pair"); + + ecc.release(); // the key will be freed when pkey is freed +#endif + break; + } + case CertificateType::Rsa: { + PLOG_VERBOSE << "Generating RSA key pair"; + const unsigned int bits = 2048; +#if OPENSSL_VERSION_NUMBER >= 0x30000000 + pkey = shared_ptr(EVP_RSA_gen(bits), EVP_PKEY_free); + if (!pkey) + throw std::runtime_error("Unable to generate RSA key pair"); +#else + pkey = shared_ptr(EVP_PKEY_new(), EVP_PKEY_free); + unique_ptr rsa(RSA_new(), RSA_free); + unique_ptr exponent(BN_new(), BN_free); + if (!pkey || !rsa || !exponent) + throw std::runtime_error("Unable to allocate structures for RSA key pair"); + + const unsigned int e = 65537; // 2^16 + 1 + if (!BN_set_word(exponent.get(), e) || + !RSA_generate_key_ex(rsa.get(), bits, exponent.get(), NULL) || + !EVP_PKEY_assign_RSA(pkey.get(), rsa.get())) + throw std::runtime_error("Unable to generate RSA key pair"); + + rsa.release(); // the key will be freed when pkey is freed +#endif + break; + } + default: + throw std::invalid_argument("Unknown certificate type"); + } + + const size_t serialSize = 16; + auto *commonNameBytes = + reinterpret_cast(const_cast(commonName.c_str())); + + if (!X509_set_pubkey(x509.get(), pkey.get())) + throw std::runtime_error("Unable to set certificate public key"); + + if (!X509_gmtime_adj(X509_getm_notBefore(x509.get()), 3600 * -1) || + !X509_gmtime_adj(X509_getm_notAfter(x509.get()), 3600 * 24 * 365) || + !X509_set_version(x509.get(), 1) || !BN_rand(serial_number.get(), serialSize, 0, 0) || + !BN_to_ASN1_INTEGER(serial_number.get(), X509_get_serialNumber(x509.get())) || + !X509_NAME_add_entry_by_NID(name.get(), NID_commonName, MBSTRING_UTF8, commonNameBytes, -1, + -1, 0) || + !X509_set_subject_name(x509.get(), name.get()) || + !X509_set_issuer_name(x509.get(), name.get())) + throw std::runtime_error("Unable to set certificate properties"); + + if (!X509_sign(x509.get(), pkey.get(), EVP_sha256())) + throw std::runtime_error("Unable to auto-sign certificate"); + + return Certificate(x509, pkey); +} + +Certificate::Certificate(shared_ptr x509, shared_ptr pkey) + : mX509(std::move(x509)), mPKey(std::move(pkey)), + mFingerprint(make_fingerprint(mX509.get(), CertificateFingerprint::Algorithm::Sha256)) {} + +std::tuple Certificate::credentials() const { + return {mX509.get(), mPKey.get()}; +} + +string make_fingerprint(X509 *x509, CertificateFingerprint::Algorithm fingerprintAlgorithm) { + size_t size = CertificateFingerprint::AlgorithmSize(fingerprintAlgorithm); + std::vector buffer(size); + auto len = static_cast(size); + + const EVP_MD *hashFunc; + switch (fingerprintAlgorithm) { + case CertificateFingerprint::Algorithm::Sha1: + hashFunc = EVP_sha1(); + break; + case CertificateFingerprint::Algorithm::Sha224: + hashFunc = EVP_sha224(); + break; + case CertificateFingerprint::Algorithm::Sha256: + hashFunc = EVP_sha256(); + break; + case CertificateFingerprint::Algorithm::Sha384: + hashFunc = EVP_sha384(); + break; + case CertificateFingerprint::Algorithm::Sha512: + hashFunc = EVP_sha512(); + break; + default: + throw std::invalid_argument("Unknown fingerprint algorithm"); + } + + if (!X509_digest(x509, hashFunc, buffer.data(), &len)) + throw std::runtime_error("X509 fingerprint error"); + + std::ostringstream oss; + oss << std::hex << std::uppercase << std::setfill('0'); + for (size_t i = 0; i < len; ++i) { + if (i) + oss << std::setw(1) << ':'; + oss << std::setw(2) << unsigned(buffer.at(i)); + } + return oss.str(); +} + +#endif + +// Common for GnuTLS, Mbed TLS, and OpenSSL + +future_certificate_ptr make_certificate(CertificateType type) { + return ThreadPool::Instance().enqueue([type, token = Init::Instance().token()]() { + return std::make_shared(Certificate::Generate(type, "libdatachannel")); + }); +} + +CertificateFingerprint Certificate::fingerprint() const { + return CertificateFingerprint{CertificateFingerprint::Algorithm::Sha256, mFingerprint}; +} + +} // namespace rtc::impl diff --git a/datachannel/src/impl/certificate.hpp b/datachannel/src/impl/certificate.hpp new file mode 100644 index 000000000..66111ccb4 --- /dev/null +++ b/datachannel/src/impl/certificate.hpp @@ -0,0 +1,76 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_CERTIFICATE_H +#define RTC_IMPL_CERTIFICATE_H + +#include "description.hpp" // for CertificateFingerprint +#include "common.hpp" +#include "configuration.hpp" // for CertificateType +#include "init.hpp" +#include "tls.hpp" + +#include +#include + +namespace rtc::impl { + +class Certificate { +public: + static Certificate FromString(string crt_pem, string key_pem); + static Certificate FromFile(const string &crt_pem_file, const string &key_pem_file, + const string &pass = ""); + static Certificate Generate(CertificateType type, const string &commonName); + +#if USE_GNUTLS + Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey); + gnutls_certificate_credentials_t credentials() const; +#elif USE_MBEDTLS + Certificate(shared_ptr crt, shared_ptr pk); + std::tuple, shared_ptr> credentials() const; +#else // OPENSSL + Certificate(shared_ptr x509, shared_ptr pkey); + std::tuple credentials() const; +#endif + + CertificateFingerprint fingerprint() const; + +private: + const init_token mInitToken = Init::Instance().token(); + +#if USE_GNUTLS + Certificate(shared_ptr creds); + const shared_ptr mCredentials; +#elif USE_MBEDTLS + const shared_ptr mCrt; + const shared_ptr mPk; +#else + const shared_ptr mX509; + const shared_ptr mPKey; +#endif + + const string mFingerprint; +}; + +#if USE_GNUTLS +string make_fingerprint(gnutls_certificate_credentials_t credentials, CertificateFingerprint::Algorithm fingerprintAlgorithm); +string make_fingerprint(gnutls_x509_crt_t crt, CertificateFingerprint::Algorithm fingerprintAlgorithm); +#elif USE_MBEDTLS +string make_fingerprint(mbedtls_x509_crt *crt, CertificateFingerprint::Algorithm fingerprintAlgorithm); +#else +string make_fingerprint(X509 *x509, CertificateFingerprint::Algorithm fingerprintAlgorithm); +#endif + +using certificate_ptr = shared_ptr; +using future_certificate_ptr = std::shared_future; + +future_certificate_ptr make_certificate(CertificateType type = CertificateType::Default); + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/channel.cpp b/datachannel/src/impl/channel.cpp new file mode 100644 index 000000000..e545f8ede --- /dev/null +++ b/datachannel/src/impl/channel.cpp @@ -0,0 +1,96 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "channel.hpp" +#include "internals.hpp" + +namespace rtc::impl { + +void Channel::triggerOpen() { + mOpenTriggered = true; + try { + openCallback(); + } catch (const std::exception &e) { + PLOG_WARNING << "Uncaught exception in callback: " << e.what(); + } + flushPendingMessages(); +} + +void Channel::triggerClosed() { + try { + closedCallback(); + } catch (const std::exception &e) { + PLOG_WARNING << "Uncaught exception in callback: " << e.what(); + } +} + +void Channel::triggerError(string error) { + try { + errorCallback(std::move(error)); + } catch (const std::exception &e) { + PLOG_WARNING << "Uncaught exception in callback: " << e.what(); + } +} + +void Channel::triggerAvailable(size_t count) { + if (count == 1) { + try { + availableCallback(); + } catch (const std::exception &e) { + PLOG_WARNING << "Uncaught exception in callback: " << e.what(); + } + } + + flushPendingMessages(); +} + +void Channel::triggerBufferedAmount(size_t amount) { + size_t previous = bufferedAmount.exchange(amount); + size_t threshold = bufferedAmountLowThreshold.load(); + if (previous > threshold && amount <= threshold) { + try { + bufferedAmountLowCallback(); + } catch (const std::exception &e) { + PLOG_WARNING << "Uncaught exception in callback: " << e.what(); + } + } +} + +void Channel::flushPendingMessages() { + if (!mOpenTriggered) + return; + + while (messageCallback) { + auto next = receive(); + if (!next) + break; + + try { + messageCallback(*next); + } catch (const std::exception &e) { + PLOG_WARNING << "Uncaught exception in callback: " << e.what(); + } + } +} + +void Channel::resetOpenCallback() { + mOpenTriggered = false; + openCallback = nullptr; +} + +void Channel::resetCallbacks() { + mOpenTriggered = false; + openCallback = nullptr; + closedCallback = nullptr; + errorCallback = nullptr; + availableCallback = nullptr; + bufferedAmountLowCallback = nullptr; + messageCallback = nullptr; +} + +} // namespace rtc::impl diff --git a/datachannel/src/impl/channel.hpp b/datachannel/src/impl/channel.hpp new file mode 100644 index 000000000..85093b8d5 --- /dev/null +++ b/datachannel/src/impl/channel.hpp @@ -0,0 +1,52 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_CHANNEL_H +#define RTC_IMPL_CHANNEL_H + +#include "common.hpp" +#include "message.hpp" + +#include +#include + +namespace rtc::impl { + +struct Channel { + virtual optional receive() = 0; + virtual optional peek() = 0; + virtual size_t availableAmount() const = 0; + + virtual void triggerOpen(); + virtual void triggerClosed(); + virtual void triggerError(string error); + virtual void triggerAvailable(size_t count); + virtual void triggerBufferedAmount(size_t amount); + + void flushPendingMessages(); + void resetOpenCallback(); + void resetCallbacks(); + + synchronized_stored_callback<> openCallback; + synchronized_stored_callback<> closedCallback; + synchronized_stored_callback errorCallback; + synchronized_stored_callback<> availableCallback; + synchronized_stored_callback<> bufferedAmountLowCallback; + + synchronized_callback messageCallback; + + std::atomic bufferedAmount = 0; + std::atomic bufferedAmountLowThreshold = 0; + +private: + std::atomic mOpenTriggered = false; +}; + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/datachannel.cpp b/datachannel/src/impl/datachannel.cpp new file mode 100644 index 000000000..044952280 --- /dev/null +++ b/datachannel/src/impl/datachannel.cpp @@ -0,0 +1,393 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "datachannel.hpp" +#include "common.hpp" +#include "internals.hpp" +#include "logcounter.hpp" +#include "peerconnection.hpp" +#include "sctptransport.hpp" +#include "utils.hpp" +#include "rtc/datachannel.hpp" +#include "rtc/track.hpp" + +#include + +#ifdef _WIN32 +#include +#else +#include +#endif + +using std::chrono::milliseconds; + +namespace rtc::impl { + +using utils::to_uint16; +using utils::to_uint32; + +// Messages for the DataChannel establishment protocol (RFC 8832) +// See https://www.rfc-editor.org/rfc/rfc8832.html + +enum MessageType : uint8_t { + MESSAGE_OPEN_REQUEST = 0x00, + MESSAGE_OPEN_RESPONSE = 0x01, + MESSAGE_ACK = 0x02, + MESSAGE_OPEN = 0x03 +}; + +enum ChannelType : uint8_t { + CHANNEL_RELIABLE = 0x00, + CHANNEL_PARTIAL_RELIABLE_REXMIT = 0x01, + CHANNEL_PARTIAL_RELIABLE_TIMED = 0x02 +}; + +#pragma pack(push, 1) +struct OpenMessage { + uint8_t type = MESSAGE_OPEN; + uint8_t channelType; + uint16_t priority; + uint32_t reliabilityParameter; + uint16_t labelLength; + uint16_t protocolLength; + // The following fields are: + // uint8_t[labelLength] label + // uint8_t[protocolLength] protocol +}; + +struct AckMessage { + uint8_t type = MESSAGE_ACK; +}; + +#pragma pack(pop) + +bool DataChannel::IsOpenMessage(message_ptr message) { + if (message->type != Message::Control) + return false; + + auto raw = reinterpret_cast(message->data()); + return !message->empty() && raw[0] == MESSAGE_OPEN; +} + +DataChannel::DataChannel(weak_ptr pc, string label, string protocol, + Reliability reliability) + : mPeerConnection(pc), mLabel(std::move(label)), mProtocol(std::move(protocol)), + mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) { + + if(reliability.maxPacketLifeTime && reliability.maxRetransmits) + throw std::invalid_argument("Both maxPacketLifeTime and maxRetransmits are set"); + + mReliability = std::make_shared(std::move(reliability)); +} + +DataChannel::~DataChannel() { + PLOG_VERBOSE << "Destroying DataChannel"; + try { + close(); + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + } +} + +void DataChannel::close() { + PLOG_VERBOSE << "Closing DataChannel"; + + shared_ptr transport; + { + std::shared_lock lock(mMutex); + transport = mSctpTransport.lock(); + } + + if (!mIsClosed.exchange(true)) { + if (transport && mStream.has_value()) + transport->closeStream(mStream.value()); + + triggerClosed(); + } + + resetCallbacks(); +} + +void DataChannel::remoteClose() { close(); } + +optional DataChannel::receive() { + auto next = mRecvQueue.pop(); + return next ? std::make_optional(to_variant(std::move(**next))) : nullopt; +} + +optional DataChannel::peek() { + auto next = mRecvQueue.peek(); + return next ? std::make_optional(to_variant(**next)) : nullopt; +} + +size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); } + +optional DataChannel::stream() const { + std::shared_lock lock(mMutex); + return mStream; +} + +string DataChannel::label() const { + std::shared_lock lock(mMutex); + return mLabel; +} + +string DataChannel::protocol() const { + std::shared_lock lock(mMutex); + return mProtocol; +} + +Reliability DataChannel::reliability() const { + std::shared_lock lock(mMutex); + return *mReliability; +} + +bool DataChannel::isOpen(void) const { return !mIsClosed && mIsOpen; } + +bool DataChannel::isClosed(void) const { return mIsClosed; } + +size_t DataChannel::maxMessageSize() const { + auto pc = mPeerConnection.lock(); + return pc ? pc->remoteMaxMessageSize() : DEFAULT_REMOTE_MAX_MESSAGE_SIZE; +} + +void DataChannel::assignStream(uint16_t stream) { + std::unique_lock lock(mMutex); + + if (mStream.has_value()) + throw std::logic_error("DataChannel already has a stream assigned"); + + mStream = stream; +} + +void DataChannel::open(shared_ptr transport) { + { + std::unique_lock lock(mMutex); + mSctpTransport = transport; + } + + if (!mIsClosed && !mIsOpen.exchange(true)) + triggerOpen(); +} + +void DataChannel::processOpenMessage(message_ptr) { + PLOG_WARNING << "Received an open message for a user-negotiated DataChannel, ignoring"; +} + +bool DataChannel::outgoing(message_ptr message) { + shared_ptr transport; + { + std::shared_lock lock(mMutex); + transport = mSctpTransport.lock(); + + if (!transport || mIsClosed) + throw std::runtime_error("DataChannel is closed"); + + if (!mStream.has_value()) + throw std::logic_error("DataChannel has no stream assigned"); + + if (message->size() > maxMessageSize()) + throw std::invalid_argument("Message size exceeds limit"); + + // Before the ACK has been received on a DataChannel, all messages must be sent ordered + message->reliability = mIsOpen ? mReliability : nullptr; + message->stream = mStream.value(); + } + + return transport->send(message); +} + +void DataChannel::incoming(message_ptr message) { + if (!message || mIsClosed) + return; + + switch (message->type) { + case Message::Control: { + if (message->size() == 0) + break; // Ignore + auto raw = reinterpret_cast(message->data()); + switch (raw[0]) { + case MESSAGE_OPEN: + processOpenMessage(message); + break; + case MESSAGE_ACK: + if (!mIsOpen.exchange(true)) { + triggerOpen(); + } + break; + default: + // Ignore + break; + } + break; + } + case Message::Reset: + remoteClose(); + break; + case Message::String: + case Message::Binary: + mRecvQueue.push(message); + triggerAvailable(mRecvQueue.size()); + break; + default: + // Ignore + break; + } +} + +OutgoingDataChannel::OutgoingDataChannel(weak_ptr pc, string label, string protocol, + Reliability reliability) + : DataChannel(pc, std::move(label), std::move(protocol), std::move(reliability)) {} + +OutgoingDataChannel::~OutgoingDataChannel() {} + +void OutgoingDataChannel::open(shared_ptr transport) { + std::unique_lock lock(mMutex); + mSctpTransport = transport; + + if (!mStream.has_value()) + throw std::runtime_error("DataChannel has no stream assigned"); + + uint8_t channelType; + uint32_t reliabilityParameter; + if (mReliability->maxPacketLifeTime) { + channelType = CHANNEL_PARTIAL_RELIABLE_TIMED; + reliabilityParameter = to_uint32(mReliability->maxPacketLifeTime->count()); + } else if (mReliability->maxRetransmits) { + channelType = CHANNEL_PARTIAL_RELIABLE_REXMIT; + reliabilityParameter = to_uint32(*mReliability->maxRetransmits); + } + // else { + // channelType = CHANNEL_RELIABLE; + // reliabilityParameter = 0; + // } + // Deprecated + else + switch (mReliability->typeDeprecated) { + case Reliability::Type::Rexmit: + channelType = CHANNEL_PARTIAL_RELIABLE_REXMIT; + reliabilityParameter = to_uint32(std::max(std::get(mReliability->rexmit), 0)); + break; + + case Reliability::Type::Timed: + channelType = CHANNEL_PARTIAL_RELIABLE_TIMED; + reliabilityParameter = to_uint32(std::get(mReliability->rexmit).count()); + break; + + default: + channelType = CHANNEL_RELIABLE; + reliabilityParameter = 0; + break; + } + + if (mReliability->unordered) + channelType |= 0x80; + + const size_t len = sizeof(OpenMessage) + mLabel.size() + mProtocol.size(); + binary buffer(len, byte(0)); + auto &open = *reinterpret_cast(buffer.data()); + open.type = MESSAGE_OPEN; + open.channelType = channelType; + open.priority = htons(0); + open.reliabilityParameter = htonl(reliabilityParameter); + open.labelLength = htons(to_uint16(mLabel.size())); + open.protocolLength = htons(to_uint16(mProtocol.size())); + + auto end = reinterpret_cast(buffer.data() + sizeof(OpenMessage)); + std::copy(mLabel.begin(), mLabel.end(), end); + std::copy(mProtocol.begin(), mProtocol.end(), end + mLabel.size()); + + lock.unlock(); + + transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream.value())); +} + +void OutgoingDataChannel::processOpenMessage(message_ptr) { + PLOG_WARNING << "Received an open message for a locally-created DataChannel, ignoring"; +} + +IncomingDataChannel::IncomingDataChannel(weak_ptr pc, + weak_ptr transport) + : DataChannel(pc, "", "", {}) { + + mSctpTransport = transport; +} + +IncomingDataChannel::~IncomingDataChannel() {} + +void IncomingDataChannel::open(shared_ptr) { + // Ignore +} + +void IncomingDataChannel::processOpenMessage(message_ptr message) { + std::unique_lock lock(mMutex); + auto transport = mSctpTransport.lock(); + if (!transport) + throw std::logic_error("DataChannel has no transport"); + + if (!mStream.has_value()) + throw std::logic_error("DataChannel has no stream assigned"); + + if (message->size() < sizeof(OpenMessage)) + throw std::invalid_argument("DataChannel open message too small"); + + OpenMessage open = *reinterpret_cast(message->data()); + open.priority = ntohs(open.priority); + open.reliabilityParameter = ntohl(open.reliabilityParameter); + open.labelLength = ntohs(open.labelLength); + open.protocolLength = ntohs(open.protocolLength); + + if (message->size() < sizeof(OpenMessage) + size_t(open.labelLength + open.protocolLength)) + throw std::invalid_argument("DataChannel open message truncated"); + + auto end = reinterpret_cast(message->data() + sizeof(OpenMessage)); + mLabel.assign(end, open.labelLength); + mProtocol.assign(end + open.labelLength, open.protocolLength); + + mReliability->unordered = (open.channelType & 0x80) != 0; + mReliability->maxPacketLifeTime.reset(); + mReliability->maxRetransmits.reset(); + switch (open.channelType & 0x7F) { + case CHANNEL_PARTIAL_RELIABLE_REXMIT: + mReliability->maxRetransmits.emplace(open.reliabilityParameter); + break; + case CHANNEL_PARTIAL_RELIABLE_TIMED: + mReliability->maxPacketLifeTime.emplace(milliseconds(open.reliabilityParameter)); + break; + default: + break; + } + + // Deprecated + switch (open.channelType & 0x7F) { + case CHANNEL_PARTIAL_RELIABLE_REXMIT: + mReliability->typeDeprecated = Reliability::Type::Rexmit; + mReliability->rexmit = int(open.reliabilityParameter); + break; + case CHANNEL_PARTIAL_RELIABLE_TIMED: + mReliability->typeDeprecated = Reliability::Type::Timed; + mReliability->rexmit = milliseconds(open.reliabilityParameter); + break; + default: + mReliability->typeDeprecated = Reliability::Type::Reliable; + mReliability->rexmit = int(0); + } + + lock.unlock(); + + binary buffer(sizeof(AckMessage), byte(0)); + auto &ack = *reinterpret_cast(buffer.data()); + ack.type = MESSAGE_ACK; + + transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream.value())); + + if (!mIsOpen.exchange(true)) + triggerOpen(); +} + +} // namespace rtc::impl diff --git a/datachannel/src/impl/datachannel.hpp b/datachannel/src/impl/datachannel.hpp new file mode 100644 index 000000000..cd501bff2 --- /dev/null +++ b/datachannel/src/impl/datachannel.hpp @@ -0,0 +1,93 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_DATA_CHANNEL_H +#define RTC_IMPL_DATA_CHANNEL_H + +#include "channel.hpp" +#include "common.hpp" +#include "message.hpp" +#include "peerconnection.hpp" +#include "queue.hpp" +#include "reliability.hpp" +#include "sctptransport.hpp" + +#include +#include + +namespace rtc::impl { + +struct PeerConnection; + +struct DataChannel : Channel, std::enable_shared_from_this { + static bool IsOpenMessage(message_ptr message); + + DataChannel(weak_ptr pc, string label, string protocol, + Reliability reliability); + virtual ~DataChannel(); + + void close(); + void remoteClose(); + bool outgoing(message_ptr message); + void incoming(message_ptr message); + + optional receive() override; + optional peek() override; + size_t availableAmount() const override; + + optional stream() const; + string label() const; + string protocol() const; + Reliability reliability() const; + + bool isOpen(void) const; + bool isClosed(void) const; + size_t maxMessageSize() const; + + virtual void assignStream(uint16_t stream); + virtual void open(shared_ptr transport); + virtual void processOpenMessage(message_ptr); + +protected: + const weak_ptr mPeerConnection; + weak_ptr mSctpTransport; + + optional mStream; + string mLabel; + string mProtocol; + shared_ptr mReliability; + + mutable std::shared_mutex mMutex; + + std::atomic mIsOpen = false; + std::atomic mIsClosed = false; + +private: + Queue mRecvQueue; +}; + +struct OutgoingDataChannel final : public DataChannel { + OutgoingDataChannel(weak_ptr pc, string label, string protocol, + Reliability reliability); + ~OutgoingDataChannel(); + + void open(shared_ptr transport) override; + void processOpenMessage(message_ptr message) override; +}; + +struct IncomingDataChannel final : public DataChannel { + IncomingDataChannel(weak_ptr pc, weak_ptr transport); + ~IncomingDataChannel(); + + void open(shared_ptr transport) override; + void processOpenMessage(message_ptr message) override; +}; + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/dtlssrtptransport.cpp b/datachannel/src/impl/dtlssrtptransport.cpp new file mode 100644 index 000000000..bf7a0ff11 --- /dev/null +++ b/datachannel/src/impl/dtlssrtptransport.cpp @@ -0,0 +1,393 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "dtlssrtptransport.hpp" +#include "logcounter.hpp" +#include "rtp.hpp" +#include "tls.hpp" + +#if RTC_ENABLE_MEDIA + +#include +#include + +using std::to_integer; +using std::to_string; + +namespace rtc::impl { + +static LogCounter COUNTER_MEDIA_TRUNCATED(plog::warning, + "Number of truncated SRT(C)P packets received"); +static LogCounter + COUNTER_UNKNOWN_PACKET_TYPE(plog::warning, + "Number of RTP packets received with an unknown packet type"); +static LogCounter COUNTER_SRTCP_REPLAY(plog::warning, "Number of SRTCP replay packets received"); +static LogCounter + COUNTER_SRTCP_AUTH_FAIL(plog::warning, + "Number of SRTCP packets received that failed authentication checks"); +static LogCounter + COUNTER_SRTCP_FAIL(plog::warning, + "Number of SRTCP packets received that had an unknown libSRTP failure"); +static LogCounter COUNTER_SRTP_REPLAY(plog::warning, "Number of SRTP replay packets received"); +static LogCounter + COUNTER_SRTP_AUTH_FAIL(plog::warning, + "Number of SRTP packets received that failed authentication checks"); +static LogCounter + COUNTER_SRTP_FAIL(plog::warning, + "Number of SRTP packets received that had an unknown libSRTP failure"); + +void DtlsSrtpTransport::Init() { srtp_init(); } + +void DtlsSrtpTransport::Cleanup() { srtp_shutdown(); } + +bool DtlsSrtpTransport::IsGcmSupported() { +#if RTC_SYSTEM_SRTP + // system libSRTP may not have GCM support + srtp_policy_t policy = {}; + return srtp_crypto_policy_set_from_profile_for_rtp( + &policy.rtp, srtp_profile_aead_aes_256_gcm) == srtp_err_status_ok; +#else + return true; +#endif +} + +DtlsSrtpTransport::DtlsSrtpTransport(shared_ptr lower, + shared_ptr certificate, optional mtu, + CertificateFingerprint::Algorithm fingerprintAlgorithm, + verifier_callback verifierCallback, + message_callback srtpRecvCallback, + state_callback stateChangeCallback) + : DtlsTransport(lower, certificate, mtu, fingerprintAlgorithm, std::move(verifierCallback), + std::move(stateChangeCallback)), + mSrtpRecvCallback(std::move(srtpRecvCallback)) { // distinct from Transport recv callback + + PLOG_DEBUG << "Initializing DTLS-SRTP transport"; + + if (srtp_err_status_t err = srtp_create(&mSrtpIn, nullptr)) { + throw std::runtime_error("srtp_create failed, status=" + to_string(static_cast(err))); + } + if (srtp_err_status_t err = srtp_create(&mSrtpOut, nullptr)) { + srtp_dealloc(mSrtpIn); + throw std::runtime_error("srtp_create failed, status=" + to_string(static_cast(err))); + } +} + +DtlsSrtpTransport::~DtlsSrtpTransport() { + stop(); // stop before deallocating + + srtp_dealloc(mSrtpIn); + srtp_dealloc(mSrtpOut); +} + +bool DtlsSrtpTransport::sendMedia(message_ptr message) { + std::lock_guard lock(sendMutex); + if (!message) + return false; + + if (!mInitDone) { + PLOG_ERROR << "SRTP media sent before keys are derived"; + return false; + } + + int size = int(message->size()); + PLOG_VERBOSE << "Send size=" << size; + + // The RTP header has a minimum size of 12 bytes + // An RTCP packet can have a minimum size of 8 bytes + if (size < 8) + throw std::runtime_error("RTP/RTCP packet too short"); + + // srtp_protect() and srtp_protect_rtcp() assume that they can write SRTP_MAX_TRAILER_LEN (for + // the authentication tag) into the location in memory immediately following the RTP packet. + // Copy instead of resizing so we don't interfere with media handlers keeping references + message = make_message(size + SRTP_MAX_TRAILER_LEN, message); + + if (IsRtcp(*message)) { // Demultiplex RTCP and RTP using payload type + if (srtp_err_status_t err = srtp_protect_rtcp(mSrtpOut, message->data(), &size)) { + if (err == srtp_err_status_replay_fail) + throw std::runtime_error("Outgoing SRTCP packet is a replay"); + else + throw std::runtime_error("SRTCP protect error, status=" + + to_string(static_cast(err))); + } + PLOG_VERBOSE << "Protected SRTCP packet, size=" << size; + + } else { + if (srtp_err_status_t err = srtp_protect(mSrtpOut, message->data(), &size)) { + if (err == srtp_err_status_replay_fail) + throw std::runtime_error("Outgoing SRTP packet is a replay"); + else + throw std::runtime_error("SRTP protect error, status=" + + to_string(static_cast(err))); + } + PLOG_VERBOSE << "Protected SRTP packet, size=" << size; + } + + message->resize(size); + + if (message->dscp == 0) { // Track might override the value + // Set recommended medium-priority DSCP value + // See https://www.rfc-editor.org/rfc/rfc8837.html#section-5 + message->dscp = 36; // AF42: Assured Forwarding class 4, medium drop probability + } + + return Transport::outgoing(message); // bypass DTLS DSCP marking +} + +void DtlsSrtpTransport::recvMedia(message_ptr message) { + // The RTP header has a minimum size of 12 bytes + // An RTCP packet can have a minimum size of 8 bytes + int size = int(message->size()); + if (size < 8) { + COUNTER_MEDIA_TRUNCATED++; + PLOG_VERBOSE << "Incoming SRTP/SRTCP packet too short, size=" << size; + return; + } + + uint8_t value2 = to_integer(*(message->begin() + 1)) & 0x7F; + PLOG_VERBOSE << "Demultiplexing SRTCP and SRTP with RTP payload type, value=" + << unsigned(value2); + + if (IsRtcp(*message)) { // Demultiplex RTCP and RTP using payload type + PLOG_VERBOSE << "Incoming SRTCP packet, size=" << size; + if (srtp_err_status_t err = srtp_unprotect_rtcp(mSrtpIn, message->data(), &size)) { + if (err == srtp_err_status_replay_fail) { + PLOG_VERBOSE << "Incoming SRTCP packet is a replay"; + COUNTER_SRTCP_REPLAY++; + } else if (err == srtp_err_status_auth_fail) { + PLOG_DEBUG << "Incoming SRTCP packet failed authentication check"; + COUNTER_SRTCP_AUTH_FAIL++; + } else { + PLOG_DEBUG << "SRTCP unprotect error, status=" << err; + COUNTER_SRTCP_FAIL++; + } + + return; + } + PLOG_VERBOSE << "Unprotected SRTCP packet, size=" << size; + message->type = Message::Control; + message->stream = reinterpret_cast(message->data())->senderSSRC(); + + } else { + PLOG_VERBOSE << "Incoming SRTP packet, size=" << size; + if (srtp_err_status_t err = srtp_unprotect(mSrtpIn, message->data(), &size)) { + if (err == srtp_err_status_replay_fail) { + PLOG_VERBOSE << "Incoming SRTP packet is a replay"; + COUNTER_SRTP_REPLAY++; + } else if (err == srtp_err_status_auth_fail) { + PLOG_DEBUG << "Incoming SRTP packet failed authentication check"; + COUNTER_SRTP_AUTH_FAIL++; + } else { + PLOG_DEBUG << "SRTP unprotect error, status=" << err; + COUNTER_SRTP_FAIL++; + } + return; + } + PLOG_VERBOSE << "Unprotected SRTP packet, size=" << size; + message->type = Message::Binary; + message->stream = reinterpret_cast(message->data())->ssrc(); + } + + message->resize(size); + mSrtpRecvCallback(message); +} + +bool DtlsSrtpTransport::demuxMessage(message_ptr message) { + if (!mInitDone) { + // Bypass + return false; + } + + if (message->size() == 0) + return false; + + // RFC 5764 5.1.2. Reception + // https://www.rfc-editor.org/rfc/rfc5764.html#section-5.1.2 + // The process for demultiplexing a packet is as follows. The receiver looks at the first byte + // of the packet. [...] If the value is in between 128 and 191 (inclusive), then the packet is + // RTP (or RTCP [...]). If the value is between 20 and 63 (inclusive), the packet is DTLS. + uint8_t value1 = to_integer(*message->begin()); + PLOG_VERBOSE << "Demultiplexing DTLS and SRTP/SRTCP with first byte, value=" + << unsigned(value1); + + if (value1 >= 20 && value1 <= 63) { + PLOG_VERBOSE << "Incoming DTLS packet, size=" << message->size(); + return false; + + } else if (value1 >= 128 && value1 <= 191) { + recvMedia(std::move(message)); + return true; + + } else { + COUNTER_UNKNOWN_PACKET_TYPE++; + PLOG_DEBUG << "Unknown packet type, value=" << unsigned(value1) + << ", size=" << message->size(); + return true; + } +} + +void DtlsSrtpTransport::postHandshake() { + if (mInitDone) + return; + +#if USE_GNUTLS + PLOG_INFO << "Deriving SRTP keying material (GnuTLS)"; + + const srtp_profile_t srtpProfile = srtp_profile_aes128_cm_sha1_80; + const size_t keySize = SRTP_AES_128_KEY_LEN; + const size_t saltSize = SRTP_SALT_LEN; + const size_t keySizeWithSalt = SRTP_AES_ICM_128_KEY_LEN_WSALT; + + const size_t materialLen = keySizeWithSalt * 2; + std::vector material(materialLen); + gnutls_datum_t clientKeyDatum, clientSaltDatum, serverKeyDatum, serverSaltDatum; + gnutls::check(gnutls_srtp_get_keys(mSession, material.data(), materialLen, &clientKeyDatum, + &clientSaltDatum, &serverKeyDatum, &serverSaltDatum), + "Failed to derive SRTP keys"); + + if (clientKeyDatum.size != keySize) + throw std::logic_error("Unexpected SRTP master key length: " + + to_string(clientKeyDatum.size)); + if (clientSaltDatum.size != saltSize) + throw std::logic_error("Unexpected SRTP salt length: " + to_string(clientSaltDatum.size)); + if (serverKeyDatum.size != keySize) + throw std::logic_error("Unexpected SRTP master key length: " + + to_string(serverKeyDatum.size)); + if (serverSaltDatum.size != saltSize) + throw std::logic_error("Unexpected SRTP salt size: " + to_string(serverSaltDatum.size)); + + const unsigned char *clientKey = reinterpret_cast(clientKeyDatum.data); + const unsigned char *clientSalt = reinterpret_cast(clientSaltDatum.data); + const unsigned char *serverKey = reinterpret_cast(serverKeyDatum.data); + const unsigned char *serverSalt = reinterpret_cast(serverSaltDatum.data); + +#elif USE_MBEDTLS + PLOG_INFO << "Deriving SRTP keying material (Mbed TLS)"; + + mbedtls_dtls_srtp_info srtpInfo; + mbedtls_ssl_get_dtls_srtp_negotiation_result(&mSsl, &srtpInfo); + if (srtpInfo.MBEDTLS_PRIVATE(chosen_dtls_srtp_profile) != MBEDTLS_TLS_SRTP_AES128_CM_HMAC_SHA1_80) + throw std::runtime_error("Failed to get SRTP profile"); + + const srtp_profile_t srtpProfile = srtp_profile_aes128_cm_sha1_80; + const size_t keySize = SRTP_AES_128_KEY_LEN; + const size_t saltSize = SRTP_SALT_LEN; + const size_t keySizeWithSalt = SRTP_AES_ICM_128_KEY_LEN_WSALT; + + if (mTlsProfile == MBEDTLS_SSL_TLS_PRF_NONE) + throw std::logic_error("TLS PRF type is not set"); + + // The extractor provides the client write master key, the server write master key, the client + // write master salt and the server write master salt in that order. + const string label = "EXTRACTOR-dtls_srtp"; + const size_t materialLen = keySizeWithSalt * 2; + std::vector material(materialLen); + + if (mbedtls_ssl_tls_prf(mTlsProfile, reinterpret_cast(mMasterSecret), 48, + label.c_str(), reinterpret_cast(mRandBytes), 64, + material.data(), materialLen) != 0) + throw std::runtime_error("Failed to derive SRTP keys"); + + // Order is client key, server key, client salt, and server salt + const unsigned char *clientKey = material.data(); + const unsigned char *serverKey = clientKey + keySize; + const unsigned char *clientSalt = serverKey + keySize; + const unsigned char *serverSalt = clientSalt + saltSize; + +#else // OpenSSL + PLOG_INFO << "Deriving SRTP keying material (OpenSSL)"; + auto profile = SSL_get_selected_srtp_profile(mSsl); + if (!profile) + throw std::runtime_error("Failed to get SRTP profile: " + + openssl::error_string(ERR_get_error())); + + PLOG_DEBUG << "SRTP profile is: " << profile->name; + + const auto [srtpProfile, keySize, saltSize] = getProfileParamsFromName(profile->name); + const size_t keySizeWithSalt = keySize + saltSize; + + // The extractor provides the client write master key, the server write master key, the client + // write master salt and the server write master salt in that order. + const string label = "EXTRACTOR-dtls_srtp"; + const size_t materialLen = keySizeWithSalt * 2; + std::vector material(materialLen); + + // returns 1 on success, 0 or -1 on failure (OpenSSL API is a complete mess...) + if (SSL_export_keying_material(mSsl, material.data(), materialLen, label.c_str(), label.size(), + nullptr, 0, 0) <= 0) + throw std::runtime_error("Failed to derive SRTP keys: " + + openssl::error_string(ERR_get_error())); + + // Order is client key, server key, client salt, and server salt + const unsigned char *clientKey = material.data(); + const unsigned char *serverKey = clientKey + keySize; + const unsigned char *clientSalt = serverKey + keySize; + const unsigned char *serverSalt = clientSalt + saltSize; +#endif + + mClientSessionKey.resize(keySizeWithSalt); + mServerSessionKey.resize(keySizeWithSalt); + std::memcpy(mClientSessionKey.data(), clientKey, keySize); + std::memcpy(mClientSessionKey.data() + keySize, clientSalt, saltSize); + + std::memcpy(mServerSessionKey.data(), serverKey, keySize); + std::memcpy(mServerSessionKey.data() + keySize, serverSalt, saltSize); + + srtp_policy_t inbound = {}; + if (srtp_crypto_policy_set_from_profile_for_rtp(&inbound.rtp, srtpProfile)) + throw std::runtime_error("SRTP profile is not supported"); + if (srtp_crypto_policy_set_from_profile_for_rtcp(&inbound.rtcp, srtpProfile)) + throw std::runtime_error("SRTP profile is not supported"); + + inbound.ssrc.type = ssrc_any_inbound; + inbound.key = mIsClient ? mServerSessionKey.data() : mClientSessionKey.data(); + inbound.window_size = 1024; + inbound.allow_repeat_tx = true; + inbound.next = nullptr; + + if (srtp_err_status_t err = srtp_add_stream(mSrtpIn, &inbound)) + throw std::runtime_error("SRTP add inbound stream failed, status=" + + to_string(static_cast(err))); + + srtp_policy_t outbound = {}; + if (srtp_crypto_policy_set_from_profile_for_rtp(&outbound.rtp, srtpProfile)) + throw std::runtime_error("SRTP profile is not supported"); + if (srtp_crypto_policy_set_from_profile_for_rtcp(&outbound.rtcp, srtpProfile)) + throw std::runtime_error("SRTP profile is not supported"); + + outbound.ssrc.type = ssrc_any_outbound; + outbound.key = mIsClient ? mClientSessionKey.data() : mServerSessionKey.data(); + outbound.window_size = 1024; + outbound.allow_repeat_tx = true; + outbound.next = nullptr; + + if (srtp_err_status_t err = srtp_add_stream(mSrtpOut, &outbound)) + throw std::runtime_error("SRTP add outbound stream failed, status=" + + to_string(static_cast(err))); + + mInitDone = true; +} + +#if !USE_GNUTLS && !USE_MBEDTLS +DtlsSrtpTransport::ProfileParams DtlsSrtpTransport::getProfileParamsFromName(string_view name) { + if (name == "SRTP_AES128_CM_SHA1_80") + return {srtp_profile_aes128_cm_sha1_80, SRTP_AES_128_KEY_LEN, SRTP_SALT_LEN}; + if (name == "SRTP_AES128_CM_SHA1_32") + return {srtp_profile_aes128_cm_sha1_32, SRTP_AES_128_KEY_LEN, SRTP_SALT_LEN}; + if (name == "SRTP_AEAD_AES_128_GCM") + return {srtp_profile_aead_aes_128_gcm, SRTP_AES_128_KEY_LEN, SRTP_AEAD_SALT_LEN}; + if (name == "SRTP_AEAD_AES_256_GCM") + return {srtp_profile_aead_aes_256_gcm, SRTP_AES_256_KEY_LEN, SRTP_AEAD_SALT_LEN}; + + throw std::logic_error("Unknown SRTP profile name: " + std::string(name)); +} +#endif + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/dtlssrtptransport.hpp b/datachannel/src/impl/dtlssrtptransport.hpp new file mode 100644 index 000000000..208afab2a --- /dev/null +++ b/datachannel/src/impl/dtlssrtptransport.hpp @@ -0,0 +1,68 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_DTLS_SRTP_TRANSPORT_H +#define RTC_IMPL_DTLS_SRTP_TRANSPORT_H + +#include "common.hpp" +#include "dtlstransport.hpp" + +#if RTC_ENABLE_MEDIA + +#if RTC_SYSTEM_SRTP +#include +#else +#include "srtp.h" +#endif + +#include + +namespace rtc::impl { + +class DtlsSrtpTransport final : public DtlsTransport { +public: + static void Init(); + static void Cleanup(); + static bool IsGcmSupported(); + + DtlsSrtpTransport(shared_ptr lower, certificate_ptr certificate, + optional mtu, CertificateFingerprint::Algorithm fingerprintAlgorithm, + verifier_callback verifierCallback, message_callback srtpRecvCallback, + state_callback stateChangeCallback); + ~DtlsSrtpTransport(); + + bool sendMedia(message_ptr message); + +private: + void recvMedia(message_ptr message); + bool demuxMessage(message_ptr message) override; + void postHandshake() override; + +#if !USE_GNUTLS && !USE_MBEDTLS + struct ProfileParams { + srtp_profile_t srtpProfile; + size_t keySize; + size_t saltSize; + }; + + ProfileParams getProfileParamsFromName(string_view name); +#endif + + message_callback mSrtpRecvCallback; + srtp_t mSrtpIn, mSrtpOut; + std::atomic mInitDone = false; + std::vector mClientSessionKey; + std::vector mServerSessionKey; + std::mutex sendMutex; +}; + +} // namespace rtc::impl + +#endif + +#endif diff --git a/datachannel/src/impl/dtlstransport.cpp b/datachannel/src/impl/dtlstransport.cpp new file mode 100644 index 000000000..caad9c710 --- /dev/null +++ b/datachannel/src/impl/dtlstransport.cpp @@ -0,0 +1,1095 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "dtlstransport.hpp" +#include "dtlssrtptransport.hpp" +#include "icetransport.hpp" +#include "internals.hpp" +#include "threadpool.hpp" + +#include +#include +#include +#include + +#if !USE_GNUTLS +#ifdef _WIN32 +#include // for timeval +#else +#include // for timeval +#endif +#endif + +using namespace std::chrono; + +namespace rtc::impl { + +void DtlsTransport::enqueueRecv() { + if (mPendingRecvCount > 0) + return; + + if (auto shared_this = weak_from_this().lock()) { + ++mPendingRecvCount; + ThreadPool::Instance().enqueue(&DtlsTransport::doRecv, std::move(shared_this)); + } +} + +#if USE_GNUTLS + +void DtlsTransport::Init() { + gnutls_global_init(); // optional +} + +void DtlsTransport::Cleanup() { gnutls_global_deinit(); } + +DtlsTransport::DtlsTransport(shared_ptr lower, certificate_ptr certificate, + optional mtu, + CertificateFingerprint::Algorithm fingerprintAlgorithm, + verifier_callback verifierCallback, state_callback stateChangeCallback) + : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate), + mFingerprintAlgorithm(fingerprintAlgorithm), mVerifierCallback(std::move(verifierCallback)), + mIsClient(lower->role() == Description::Role::Active) { + + PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)"; + + if (!mCertificate) + throw std::invalid_argument("DTLS certificate is null"); + + gnutls_certificate_credentials_t creds = mCertificate->credentials(); + gnutls_certificate_set_verify_function(creds, CertificateCallback); + + unsigned int flags = + GNUTLS_DATAGRAM | GNUTLS_NONBLOCK | (mIsClient ? GNUTLS_CLIENT : GNUTLS_SERVER); + gnutls::check(gnutls_init(&mSession, flags)); + + try { + // RFC 8261: SCTP performs segmentation and reassembly based on the path MTU. + // Therefore, the DTLS layer MUST NOT use any compression algorithm. + // See https://www.rfc-editor.org/rfc/rfc8261.html#section-5 + const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL"; + const char *err_pos = NULL; + gnutls::check(gnutls_priority_set_direct(mSession, priorities, &err_pos), + "Failed to set TLS priorities"); + + // RFC 8827: The DTLS-SRTP protection profile SRTP_AES128_CM_HMAC_SHA1_80 MUST be supported + // See https://www.rfc-editor.org/rfc/rfc8827.html#section-6.5 + gnutls::check(gnutls_srtp_set_profile(mSession, GNUTLS_SRTP_AES128_CM_HMAC_SHA1_80), + "Failed to set SRTP profile"); + + gnutls::check(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, creds)); + + gnutls_dtls_set_timeouts(mSession, + 1000, // 1s retransmission timeout recommended by RFC 6347 + 30000); // 30s total timeout + gnutls_handshake_set_timeout(mSession, 30000); + + gnutls_session_set_ptr(mSession, this); + gnutls_transport_set_ptr(mSession, this); + gnutls_transport_set_push_function(mSession, WriteCallback); + gnutls_transport_set_pull_function(mSession, ReadCallback); + gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback); + + } catch (...) { + gnutls_deinit(mSession); + throw; + } + + // Set recommended medium-priority DSCP value for handshake + // See https://www.rfc-editor.org/rfc/rfc8837.html#section-5 + mCurrentDscp = 10; // AF11: Assured Forwarding class 1, low drop probability +} + +DtlsTransport::~DtlsTransport() { + stop(); + + PLOG_DEBUG << "Destroying DTLS transport"; + gnutls_deinit(mSession); +} + +void DtlsTransport::start() { + PLOG_DEBUG << "Starting DTLS transport"; + registerIncoming(); + changeState(State::Connecting); + + size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6 + gnutls_dtls_set_mtu(mSession, static_cast(mtu)); + PLOG_VERBOSE << "DTLS MTU set to " << mtu; + + enqueueRecv(); // to initiate the handshake +} + +void DtlsTransport::stop() { + PLOG_DEBUG << "Stopping DTLS transport"; + unregisterIncoming(); + mIncomingQueue.stop(); + enqueueRecv(); +} + +bool DtlsTransport::send(message_ptr message) { + if (!message || state() != State::Connected) + return false; + + PLOG_VERBOSE << "Send size=" << message->size(); + + ssize_t ret; + do { + std::lock_guard lock(mSendMutex); + mCurrentDscp = message->dscp; + ret = gnutls_record_send(mSession, message->data(), message->size()); + } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN); + + if (ret == GNUTLS_E_LARGE_PACKET) + return false; + + if (!gnutls::check(ret)) + return false; + + return mOutgoingResult; +} + +void DtlsTransport::incoming(message_ptr message) { + if (!message) { + mIncomingQueue.stop(); + return; + } + + PLOG_VERBOSE << "Incoming size=" << message->size(); + mIncomingQueue.push(message); + enqueueRecv(); +} + +bool DtlsTransport::outgoing(message_ptr message) { + message->dscp = mCurrentDscp; + + bool result = Transport::outgoing(std::move(message)); + mOutgoingResult = result; + return result; +} + +bool DtlsTransport::demuxMessage(message_ptr) { + // Dummy + return false; +} + +void DtlsTransport::postHandshake() { + // Dummy +} + +void DtlsTransport::doRecv() { + std::lock_guard lock(mRecvMutex); + --mPendingRecvCount; + + if (state() != State::Connecting && state() != State::Connected) + return; + + try { + const size_t bufferSize = 4096; + char buffer[bufferSize]; + + // Handle handshake if connecting + if (state() == State::Connecting) { + int ret; + do { + ret = gnutls_handshake(mSession); + + if (ret == GNUTLS_E_AGAIN) { + // Schedule next call on timeout and return + auto timeout = milliseconds(gnutls_dtls_get_timeout(mSession)); + ThreadPool::Instance().schedule(timeout, [weak_this = weak_from_this()]() { + if (auto locked = weak_this.lock()) + locked->doRecv(); + }); + return; + } + + if (ret == GNUTLS_E_LARGE_PACKET) { + throw std::runtime_error("MTU is too low"); + } + + } while (!gnutls::check(ret, "Handshake failed")); // Re-call on non-fatal error + + // RFC 8261: DTLS MUST support sending messages larger than the current path MTU + // See https://www.rfc-editor.org/rfc/rfc8261.html#section-5 + gnutls_dtls_set_mtu(mSession, bufferSize + 1); + + PLOG_INFO << "DTLS handshake finished"; + changeState(State::Connected); + postHandshake(); + } + + if (state() == State::Connected) { + while (true) { + ssize_t ret = gnutls_record_recv(mSession, buffer, bufferSize); + + if (ret == GNUTLS_E_AGAIN) { + return; + } + + // RFC 8827: Implementations MUST NOT implement DTLS renegotiation and MUST reject + // it with a "no_renegotiation" alert if offered. See + // https://www.rfc-editor.org/rfc/rfc8827.html#section-6.5 + if (ret == GNUTLS_E_REHANDSHAKE) { + do { + std::lock_guard lock(mSendMutex); + ret = gnutls_alert_send(mSession, GNUTLS_AL_WARNING, + GNUTLS_A_NO_RENEGOTIATION); + } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN); + continue; + } + + // Consider premature termination as remote closing + if (ret == GNUTLS_E_PREMATURE_TERMINATION) { + PLOG_DEBUG << "DTLS connection terminated"; + break; + } + + if (gnutls::check(ret)) { + if (ret == 0) { + // Closed + PLOG_DEBUG << "DTLS connection cleanly closed"; + break; + } + auto *b = reinterpret_cast(buffer); + recv(make_message(b, b + ret)); + } + } + } + } catch (const std::exception &e) { + PLOG_ERROR << "DTLS recv: " << e.what(); + } + + gnutls_bye(mSession, GNUTLS_SHUT_WR); + + if (state() == State::Connected) { + PLOG_INFO << "DTLS closed"; + changeState(State::Disconnected); + recv(nullptr); + } else { + PLOG_ERROR << "DTLS handshake failed"; + changeState(State::Failed); + } +} + +int DtlsTransport::CertificateCallback(gnutls_session_t session) { + DtlsTransport *t = static_cast(gnutls_session_get_ptr(session)); + try { + if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509) { + return GNUTLS_E_CERTIFICATE_ERROR; + } + + unsigned int count = 0; + const gnutls_datum_t *array = gnutls_certificate_get_peers(session, &count); + if (!array || count == 0) { + return GNUTLS_E_CERTIFICATE_ERROR; + } + + gnutls_x509_crt_t crt; + gnutls::check(gnutls_x509_crt_init(&crt)); + int ret = gnutls_x509_crt_import(crt, &array[0], GNUTLS_X509_FMT_DER); + if (ret != GNUTLS_E_SUCCESS) { + gnutls_x509_crt_deinit(crt); + return GNUTLS_E_CERTIFICATE_ERROR; + } + + string fingerprint = make_fingerprint(crt, t->mFingerprintAlgorithm); + gnutls_x509_crt_deinit(crt); + + bool success = t->mVerifierCallback(fingerprint); + return success ? GNUTLS_E_SUCCESS : GNUTLS_E_CERTIFICATE_ERROR; + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + return GNUTLS_E_CERTIFICATE_ERROR; + } +} + +ssize_t DtlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) { + DtlsTransport *t = static_cast(ptr); + try { + if (len > 0) { + auto b = reinterpret_cast(data); + t->outgoing(make_message(b, b + len)); + } + gnutls_transport_set_errno(t->mSession, 0); + return ssize_t(len); + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + gnutls_transport_set_errno(t->mSession, ECONNRESET); + return -1; + } +} + +ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) { + DtlsTransport *t = static_cast(ptr); + try { + while (t->mIncomingQueue.running()) { + auto next = t->mIncomingQueue.pop(); + if (!next) { + gnutls_transport_set_errno(t->mSession, EAGAIN); + return -1; + } + + message_ptr message = std::move(*next); + if (t->demuxMessage(message)) + continue; + + ssize_t len = std::min(maxlen, message->size()); + std::memcpy(data, message->data(), len); + gnutls_transport_set_errno(t->mSession, 0); + return len; + } + + // Closed + gnutls_transport_set_errno(t->mSession, 0); + return 0; + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + gnutls_transport_set_errno(t->mSession, ECONNRESET); + return -1; + } +} + +int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int /* ms */) { + DtlsTransport *t = static_cast(ptr); + try { + return !t->mIncomingQueue.empty() ? 1 : 0; + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + return 1; + } +} + +#elif USE_MBEDTLS + +const mbedtls_ssl_srtp_profile srtpSupportedProtectionProfiles[] = { + MBEDTLS_TLS_SRTP_AES128_CM_HMAC_SHA1_80, + MBEDTLS_TLS_SRTP_UNSET, +}; + +DtlsTransport::DtlsTransport(shared_ptr lower, certificate_ptr certificate, + optional mtu, + CertificateFingerprint::Algorithm fingerprintAlgorithm, + verifier_callback verifierCallback, state_callback stateChangeCallback) + : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate), + mFingerprintAlgorithm(fingerprintAlgorithm), mVerifierCallback(std::move(verifierCallback)), + mIsClient(lower->role() == Description::Role::Active) { + + PLOG_DEBUG << "Initializing DTLS transport (MbedTLS)"; + + if (!mCertificate) + throw std::invalid_argument("DTLS certificate is null"); + + mbedtls_entropy_init(&mEntropy); + mbedtls_ctr_drbg_init(&mDrbg); + mbedtls_ssl_init(&mSsl); + mbedtls_ssl_config_init(&mConf); + mbedtls_ctr_drbg_set_prediction_resistance(&mDrbg, MBEDTLS_CTR_DRBG_PR_ON); + + try { + mbedtls::check(mbedtls_ctr_drbg_seed(&mDrbg, mbedtls_entropy_func, &mEntropy, NULL, 0)); + + mbedtls::check(mbedtls_ssl_config_defaults( + &mConf, mIsClient ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER, + MBEDTLS_SSL_TRANSPORT_DATAGRAM, MBEDTLS_SSL_PRESET_DEFAULT)); + + mbedtls_ssl_conf_max_version(&mConf, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3); // TLS 1.2 + mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_OPTIONAL); + mbedtls_ssl_conf_verify(&mConf, DtlsTransport::CertificateCallback, this); + mbedtls_ssl_conf_rng(&mConf, mbedtls_ctr_drbg_random, &mDrbg); + + auto [crt, pk] = mCertificate->credentials(); + mbedtls::check(mbedtls_ssl_conf_own_cert(&mConf, crt.get(), pk.get())); + + mbedtls_ssl_conf_dtls_cookies(&mConf, NULL, NULL, NULL); + mbedtls_ssl_conf_dtls_srtp_protection_profiles(&mConf, srtpSupportedProtectionProfiles); + + mbedtls::check(mbedtls_ssl_setup(&mSsl, &mConf)); + + mbedtls_ssl_set_export_keys_cb(&mSsl, DtlsTransport::ExportKeysCallback, this); + mbedtls_ssl_set_bio(&mSsl, this, WriteCallback, ReadCallback, NULL); + mbedtls_ssl_set_timer_cb(&mSsl, this, SetTimerCallback, GetTimerCallback); + + } catch (...) { + mbedtls_entropy_free(&mEntropy); + mbedtls_ctr_drbg_free(&mDrbg); + mbedtls_ssl_free(&mSsl); + mbedtls_ssl_config_free(&mConf); + throw; + } + + // Set recommended medium-priority DSCP value for handshake + // See https://www.rfc-editor.org/rfc/rfc8837.html#section-5 + mCurrentDscp = 10; // AF11: Assured Forwarding class 1, low drop probability +} + +DtlsTransport::~DtlsTransport() { + stop(); + + PLOG_DEBUG << "Destroying DTLS transport"; + mbedtls_entropy_free(&mEntropy); + mbedtls_ctr_drbg_free(&mDrbg); + mbedtls_ssl_free(&mSsl); + mbedtls_ssl_config_free(&mConf); +} + +void DtlsTransport::Init() { + // Nothing to do +} + +void DtlsTransport::Cleanup() { + // Nothing to do +} + +void DtlsTransport::start() { + PLOG_DEBUG << "Starting DTLS transport"; + registerIncoming(); + changeState(State::Connecting); + + { + std::lock_guard lock(mSslMutex); + size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6 + mbedtls_ssl_set_mtu(&mSsl, static_cast(mtu)); + PLOG_VERBOSE << "DTLS MTU set to " << mtu; + } + + enqueueRecv(); // to initiate the handshake +} + +void DtlsTransport::stop() { + PLOG_DEBUG << "Stopping DTLS transport"; + unregisterIncoming(); + mIncomingQueue.stop(); + enqueueRecv(); +} + +bool DtlsTransport::send(message_ptr message) { + if (!message || state() != State::Connected) + return false; + + PLOG_VERBOSE << "Send size=" << message->size(); + + int ret; + do { + std::lock_guard lock(mSslMutex); + if (message->size() > size_t(mbedtls_ssl_get_max_out_record_payload(&mSsl))) + return false; + + mCurrentDscp = message->dscp; + ret = mbedtls_ssl_write(&mSsl, reinterpret_cast(message->data()), + message->size()); + } while (!mbedtls::check(ret)); + + return mOutgoingResult; +} + +void DtlsTransport::incoming(message_ptr message) { + if (!message) { + mIncomingQueue.stop(); + return; + } + + PLOG_VERBOSE << "Incoming size=" << message->size(); + mIncomingQueue.push(message); + enqueueRecv(); +} + +bool DtlsTransport::outgoing(message_ptr message) { + message->dscp = mCurrentDscp; + + bool result = Transport::outgoing(std::move(message)); + mOutgoingResult = result; + return result; +} + +bool DtlsTransport::demuxMessage(message_ptr) { + // Dummy + return false; +} + +void DtlsTransport::postHandshake() { + // Dummy +} + +void DtlsTransport::doRecv() { + std::lock_guard lock(mRecvMutex); + --mPendingRecvCount; + + if (state() != State::Connecting && state() != State::Connected) + return; + + try { + const size_t bufferSize = 4096; + char buffer[bufferSize]; + + // Handle handshake if connecting + if (state() == State::Connecting) { + while (true) { + int ret; + { + std::lock_guard lock(mSslMutex); + ret = mbedtls_ssl_handshake(&mSsl); + } + + if (ret == MBEDTLS_ERR_SSL_WANT_READ) { + ThreadPool::Instance().schedule(mTimerSetAt + milliseconds(mFinMs), + [weak_this = weak_from_this()]() { + if (auto locked = weak_this.lock()) + locked->doRecv(); + }); + return; + } + + if (mbedtls::check(ret, "Handshake failed")) { + // RFC 8261: DTLS MUST support sending messages larger than the current path MTU + // See https://www.rfc-editor.org/rfc/rfc8261.html#section-5 + { + std::lock_guard lock(mSslMutex); + mbedtls_ssl_set_mtu(&mSsl, static_cast(bufferSize + 1)); + } + + PLOG_INFO << "DTLS handshake finished"; + changeState(State::Connected); + postHandshake(); + break; + } + } + } + + if (state() == State::Connected) { + while (true) { + int ret; + { + std::lock_guard lock(mSslMutex); + ret = mbedtls_ssl_read(&mSsl, reinterpret_cast(buffer), + bufferSize); + } + + if (ret == MBEDTLS_ERR_SSL_WANT_READ) { + return; + } + + if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + PLOG_DEBUG << "DTLS connection cleanly closed"; + break; + } + + if (mbedtls::check(ret)) { + if (ret == 0) { + PLOG_DEBUG << "DTLS connection terminated"; + break; + } + auto *b = reinterpret_cast(buffer); + recv(make_message(b, b + ret)); + } + } + } + } catch (const std::exception &e) { + PLOG_ERROR << "DTLS recv: " << e.what(); + } + + if (state() == State::Connected) { + PLOG_INFO << "DTLS closed"; + changeState(State::Disconnected); + recv(nullptr); + } else { + PLOG_ERROR << "DTLS handshake failed"; + changeState(State::Failed); + } +} + +int DtlsTransport::CertificateCallback(void *ctx, mbedtls_x509_crt *crt, int /*depth*/, + uint32_t * /*flags*/) { + auto this_ = static_cast(ctx); + string fingerprint = make_fingerprint(crt, this_->mFingerprintAlgorithm); + std::transform(fingerprint.begin(), fingerprint.end(), fingerprint.begin(), + [](char c) { return char(std::toupper(c)); }); + return this_->mVerifierCallback(fingerprint) ? 0 : 1; +} + +void DtlsTransport::ExportKeysCallback(void *ctx, mbedtls_ssl_key_export_type /*type*/, + const unsigned char *secret, size_t secret_len, + const unsigned char client_random[32], + const unsigned char server_random[32], + mbedtls_tls_prf_types tls_prf_type) { + auto dtlsTransport = static_cast(ctx); + std::memcpy(dtlsTransport->mMasterSecret, secret, secret_len); + std::memcpy(dtlsTransport->mRandBytes, client_random, 32); + std::memcpy(dtlsTransport->mRandBytes + 32, server_random, 32); + dtlsTransport->mTlsProfile = tls_prf_type; +} + +int DtlsTransport::WriteCallback(void *ctx, const unsigned char *buf, size_t len) { + auto *t = static_cast(ctx); + try { + if (len > 0) { + auto b = reinterpret_cast(buf); + t->outgoing(make_message(b, b + len)); + } + return int(len); + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + return MBEDTLS_ERR_SSL_INTERNAL_ERROR; + } +} + +int DtlsTransport::ReadCallback(void *ctx, unsigned char *buf, size_t len) { + auto *t = static_cast(ctx); + try { + while (t->mIncomingQueue.running()) { + auto next = t->mIncomingQueue.pop(); + if (!next) { + return MBEDTLS_ERR_SSL_WANT_READ; + } + + message_ptr message = std::move(*next); + if (t->demuxMessage(message)) + continue; + + auto bufMin = std::min(len, size_t(message->size())); + std::memcpy(buf, message->data(), bufMin); + return int(len); + } + + // Closed + return 0; + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + return MBEDTLS_ERR_SSL_INTERNAL_ERROR; + ; + } +} + +void DtlsTransport::SetTimerCallback(void *ctx, uint32_t int_ms, uint32_t fin_ms) { + auto dtlsTransport = static_cast(ctx); + dtlsTransport->mIntMs = int_ms; + dtlsTransport->mFinMs = fin_ms; + + if (fin_ms != 0) { + dtlsTransport->mTimerSetAt = std::chrono::steady_clock::now(); + } +} + +int DtlsTransport::GetTimerCallback(void *ctx) { + auto dtlsTransport = static_cast(ctx); + auto now = std::chrono::steady_clock::now(); + + if (dtlsTransport->mFinMs == 0) { + return -1; + } else if (now >= dtlsTransport->mTimerSetAt + milliseconds(dtlsTransport->mFinMs)) { + return 2; + } else if (now >= dtlsTransport->mTimerSetAt + milliseconds(dtlsTransport->mIntMs)) { + return 1; + } else { + return 0; + } +} + +#else // OPENSSL + +BIO_METHOD *DtlsTransport::BioMethods = NULL; +int DtlsTransport::TransportExIndex = -1; +std::mutex DtlsTransport::GlobalMutex; + +void DtlsTransport::Init() { + std::lock_guard lock(GlobalMutex); + + openssl::init(); + + if (!BioMethods) { + BioMethods = BIO_meth_new(BIO_TYPE_BIO, "DTLS writer"); + if (!BioMethods) + throw std::runtime_error("Failed to create BIO methods for DTLS writer"); + BIO_meth_set_create(BioMethods, BioMethodNew); + BIO_meth_set_destroy(BioMethods, BioMethodFree); + BIO_meth_set_write(BioMethods, BioMethodWrite); + BIO_meth_set_ctrl(BioMethods, BioMethodCtrl); + } + if (TransportExIndex < 0) { + TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL); + } +} + +void DtlsTransport::Cleanup() { + // Nothing to do +} + +DtlsTransport::DtlsTransport(shared_ptr lower, certificate_ptr certificate, + optional mtu, + CertificateFingerprint::Algorithm fingerprintAlgorithm, + verifier_callback verifierCallback, state_callback stateChangeCallback) + : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate), + mFingerprintAlgorithm(fingerprintAlgorithm), mVerifierCallback(std::move(verifierCallback)), + mIsClient(lower->role() == Description::Role::Active) { + PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)"; + + if (!mCertificate) + throw std::invalid_argument("DTLS certificate is null"); + + try { + mCtx = SSL_CTX_new(DTLS_method()); + if (!mCtx) + throw std::runtime_error("Failed to create SSL context"); + + // RFC 8261: SCTP performs segmentation and reassembly based on the path MTU. + // Therefore, the DTLS layer MUST NOT use any compression algorithm. + // See https://www.rfc-editor.org/rfc/rfc8261.html#section-5 + // RFC 8827: Implementations MUST NOT implement DTLS renegotiation + // See https://www.rfc-editor.org/rfc/rfc8827.html#section-6.5 + SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION | SSL_OP_NO_QUERY_MTU | + SSL_OP_NO_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(mCtx, DTLS1_VERSION); + SSL_CTX_set_read_ahead(mCtx, 1); + SSL_CTX_set_quiet_shutdown(mCtx, 0); // send the close_notify alert + SSL_CTX_set_info_callback(mCtx, InfoCallback); + + SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + CertificateCallback); + SSL_CTX_set_verify_depth(mCtx, 1); + + openssl::check(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"), + "Failed to set SSL priorities"); + +#if OPENSSL_VERSION_NUMBER >= 0x30000000 + openssl::check(SSL_CTX_set1_groups_list(mCtx, "P-256"), "Failed to set SSL groups"); +#else + auto ecdh = unique_ptr( + EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free); + SSL_CTX_set_tmp_ecdh(mCtx, ecdh.get()); +#endif + + auto [x509, pkey] = mCertificate->credentials(); + SSL_CTX_use_certificate(mCtx, x509); + SSL_CTX_use_PrivateKey(mCtx, pkey); + openssl::check(SSL_CTX_check_private_key(mCtx), "SSL local private key check failed"); + + mSsl = SSL_new(mCtx); + if (!mSsl) + throw std::runtime_error("Failed to create SSL instance"); + + SSL_set_ex_data(mSsl, TransportExIndex, this); + + if (mIsClient) + SSL_set_connect_state(mSsl); + else + SSL_set_accept_state(mSsl); + + mInBio = BIO_new(BIO_s_mem()); + mOutBio = BIO_new(BioMethods); + if (!mInBio || !mOutBio) + throw std::runtime_error("Failed to create BIO"); + + BIO_set_mem_eof_return(mInBio, BIO_EOF); + BIO_set_data(mOutBio, this); + SSL_set_bio(mSsl, mInBio, mOutBio); + + // RFC 8827: The DTLS-SRTP protection profile SRTP_AES128_CM_HMAC_SHA1_80 MUST be supported + // See https://www.rfc-editor.org/rfc/rfc8827.html#section-6.5 + // Warning: SSL_set_tlsext_use_srtp() returns 0 on success and 1 on error +#if RTC_ENABLE_MEDIA + // Try to use GCM suite + if (!DtlsSrtpTransport::IsGcmSupported() || + SSL_set_tlsext_use_srtp( + mSsl, "SRTP_AEAD_AES_256_GCM:SRTP_AEAD_AES_128_GCM:SRTP_AES128_CM_SHA1_80")) { + PLOG_WARNING << "AES-GCM for SRTP is not supported, falling back to default profile"; + if (SSL_set_tlsext_use_srtp(mSsl, "SRTP_AES128_CM_SHA1_80")) + throw std::runtime_error("Failed to set SRTP profile: " + + openssl::error_string(ERR_get_error())); + } +#else + if (SSL_set_tlsext_use_srtp(mSsl, "SRTP_AES128_CM_SHA1_80")) + throw std::runtime_error("Failed to set SRTP profile: " + + openssl::error_string(ERR_get_error())); +#endif + } catch (...) { + if (mSsl) + SSL_free(mSsl); + if (mCtx) + SSL_CTX_free(mCtx); + throw; + } + + // Set recommended medium-priority DSCP value for handshake + // See https://www.rfc-editor.org/rfc/rfc8837.html#section-5 + mCurrentDscp = 10; // AF11: Assured Forwarding class 1, low drop probability +} + +DtlsTransport::~DtlsTransport() { + stop(); + + PLOG_DEBUG << "Destroying DTLS transport"; + SSL_free(mSsl); + SSL_CTX_free(mCtx); +} + +void DtlsTransport::start() { + PLOG_DEBUG << "Starting DTLS transport"; + registerIncoming(); + changeState(State::Connecting); + + int ret, err; + { + std::lock_guard lock(mSslMutex); + + size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6 + SSL_set_mtu(mSsl, static_cast(mtu)); + PLOG_VERBOSE << "DTLS MTU set to " << mtu; + + // Initiate the handshake + ret = SSL_do_handshake(mSsl); + err = SSL_get_error(mSsl, ret); + } + + openssl::check_error(err, "Handshake failed"); + + handleTimeout(); +} + +void DtlsTransport::stop() { + PLOG_DEBUG << "Stopping DTLS transport"; + unregisterIncoming(); + mIncomingQueue.stop(); + enqueueRecv(); +} + +bool DtlsTransport::send(message_ptr message) { + if (!message || state() != State::Connected) + return false; + + PLOG_VERBOSE << "Send size=" << message->size(); + + int ret, err; + { + std::lock_guard lock(mSslMutex); + mCurrentDscp = message->dscp; + ret = SSL_write(mSsl, message->data(), int(message->size())); + err = SSL_get_error(mSsl, ret); + } + + if (!openssl::check_error(err)) + return false; + + return mOutgoingResult; +} + +void DtlsTransport::incoming(message_ptr message) { + if (!message) { + mIncomingQueue.stop(); + enqueueRecv(); + return; + } + + PLOG_VERBOSE << "Incoming size=" << message->size(); + mIncomingQueue.push(message); + enqueueRecv(); +} + +bool DtlsTransport::outgoing(message_ptr message) { + message->dscp = mCurrentDscp; + + bool result = Transport::outgoing(std::move(message)); + mOutgoingResult = result; + return result; +} + +bool DtlsTransport::demuxMessage(message_ptr) { + // Dummy + return false; +} + +void DtlsTransport::postHandshake() { + // Dummy +} + +void DtlsTransport::doRecv() { + std::lock_guard lock(mRecvMutex); + --mPendingRecvCount; + + if (state() != State::Connecting && state() != State::Connected) + return; + + try { + const size_t bufferSize = 4096; + byte buffer[bufferSize]; + + // Process pending messages + while (mIncomingQueue.running()) { + auto next = mIncomingQueue.pop(); + if (!next) { + // No more messages pending, handle timeout if connecting + if (state() == State::Connecting) + handleTimeout(); + + return; + } + + message_ptr message = std::move(*next); + if (demuxMessage(message)) + continue; + + BIO_write(mInBio, message->data(), int(message->size())); + + if (state() == State::Connecting) { + // Continue the handshake + int ret, err; + { + std::lock_guard lock(mSslMutex); + ret = SSL_do_handshake(mSsl); + err = SSL_get_error(mSsl, ret); + } + + if (openssl::check_error(err, "Handshake failed")) { + // RFC 8261: DTLS MUST support sending messages larger than the current path MTU + // See https://www.rfc-editor.org/rfc/rfc8261.html#section-5 + { + std::lock_guard lock(mSslMutex); + SSL_set_mtu(mSsl, bufferSize + 1); + } + + PLOG_INFO << "DTLS handshake finished"; + postHandshake(); + changeState(State::Connected); + } + } + + if (state() == State::Connected) { + int ret, err; + { + std::lock_guard lock(mSslMutex); + ret = SSL_read(mSsl, buffer, bufferSize); + err = SSL_get_error(mSsl, ret); + } + + if (err == SSL_ERROR_ZERO_RETURN) { + PLOG_DEBUG << "TLS connection cleanly closed"; + break; + } + + if (openssl::check_error(err)) + recv(make_message(buffer, buffer + ret)); + } + } + + std::lock_guard lock(mSslMutex); + SSL_shutdown(mSsl); + + } catch (const std::exception &e) { + PLOG_ERROR << "DTLS recv: " << e.what(); + } + + if (state() == State::Connected) { + PLOG_INFO << "DTLS closed"; + changeState(State::Disconnected); + recv(nullptr); + } else { + PLOG_ERROR << "DTLS handshake failed"; + changeState(State::Failed); + } +} + +void DtlsTransport::handleTimeout() { + std::lock_guard lock(mSslMutex); + + // Warning: This function breaks the usual return value convention + int ret = DTLSv1_handle_timeout(mSsl); + if (ret < 0) { + throw std::runtime_error("Handshake timeout"); // write BIO can't fail + } else if (ret > 0) { + LOG_VERBOSE << "DTLS retransmit done"; + } + + struct timeval tv = {}; + if (DTLSv1_get_timeout(mSsl, &tv)) { + auto timeout = milliseconds(tv.tv_sec * 1000 + tv.tv_usec / 1000); + // Also handle handshake timeout manually because OpenSSL actually + // doesn't... OpenSSL backs off exponentially in base 2 starting from the + // recommended 1s so this allows for 5 retransmissions and fails after + // roughly 30s. + if (timeout > 30s) + throw std::runtime_error("Handshake timeout"); + + LOG_VERBOSE << "DTLS retransmit timeout is " << timeout.count() << "ms"; + ThreadPool::Instance().schedule(timeout, [weak_this = weak_from_this()]() { + if (auto locked = weak_this.lock()) + locked->doRecv(); + }); + } +} + +int DtlsTransport::CertificateCallback(int /*preverify_ok*/, X509_STORE_CTX *ctx) { + SSL *ssl = + static_cast(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx())); + DtlsTransport *t = + static_cast(SSL_get_ex_data(ssl, DtlsTransport::TransportExIndex)); + + X509 *crt = X509_STORE_CTX_get_current_cert(ctx); + string fingerprint = make_fingerprint(crt, t->mFingerprintAlgorithm); + + return t->mVerifierCallback(fingerprint) ? 1 : 0; +} + +void DtlsTransport::InfoCallback(const SSL *ssl, int where, int ret) { + DtlsTransport *t = + static_cast(SSL_get_ex_data(ssl, DtlsTransport::TransportExIndex)); + + if (where & SSL_CB_ALERT) { + if (ret != 256) { // Close Notify + PLOG_ERROR << "DTLS alert: " << SSL_alert_desc_string_long(ret); + } + t->mIncomingQueue.stop(); // Close the connection + } +} + +int DtlsTransport::BioMethodNew(BIO *bio) { + BIO_set_init(bio, 1); + BIO_set_data(bio, NULL); + BIO_set_shutdown(bio, 0); + return 1; +} + +int DtlsTransport::BioMethodFree(BIO *bio) { + if (!bio) + return 0; + BIO_set_data(bio, NULL); + return 1; +} + +int DtlsTransport::BioMethodWrite(BIO *bio, const char *in, int inl) { + if (inl <= 0) + return inl; + auto transport = reinterpret_cast(BIO_get_data(bio)); + if (!transport) + return -1; + auto b = reinterpret_cast(in); + transport->outgoing(make_message(b, b + inl)); + return inl; // can't fail +} + +long DtlsTransport::BioMethodCtrl(BIO * /*bio*/, int cmd, long /*num*/, void * /*ptr*/) { + switch (cmd) { + case BIO_CTRL_FLUSH: + return 1; + case BIO_CTRL_DGRAM_QUERY_MTU: + return 0; // SSL_OP_NO_QUERY_MTU must be set + case BIO_CTRL_WPENDING: + case BIO_CTRL_PENDING: + return 0; + default: + break; + } + return 0; +} + +#endif + +} // namespace rtc::impl diff --git a/datachannel/src/impl/dtlstransport.hpp b/datachannel/src/impl/dtlstransport.hpp new file mode 100644 index 000000000..96565b6a1 --- /dev/null +++ b/datachannel/src/impl/dtlstransport.hpp @@ -0,0 +1,125 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_DTLS_TRANSPORT_H +#define RTC_IMPL_DTLS_TRANSPORT_H + +#include "certificate.hpp" +#include "common.hpp" +#include "queue.hpp" +#include "tls.hpp" +#include "transport.hpp" + +#include +#include +#include +#include + +namespace rtc::impl { + +class IceTransport; + +class DtlsTransport : public Transport, public std::enable_shared_from_this { +public: + static void Init(); + static void Cleanup(); + + using verifier_callback = std::function; + + DtlsTransport(shared_ptr lower, certificate_ptr certificate, optional mtu, + CertificateFingerprint::Algorithm fingerprintAlgorithm, + verifier_callback verifierCallback, state_callback stateChangeCallback); + ~DtlsTransport(); + + virtual void start() override; + virtual void stop() override; + virtual bool send(message_ptr message) override; // false if dropped + + bool isClient() const { return mIsClient; } + +protected: + virtual void incoming(message_ptr message) override; + virtual bool outgoing(message_ptr message) override; + virtual bool demuxMessage(message_ptr message); + virtual void postHandshake(); + + void enqueueRecv(); + void doRecv(); + + const optional mMtu; + const certificate_ptr mCertificate; + CertificateFingerprint::Algorithm mFingerprintAlgorithm; + const verifier_callback mVerifierCallback; + const bool mIsClient; + + Queue mIncomingQueue; + std::atomic mPendingRecvCount = 0; + std::mutex mRecvMutex; + std::atomic mCurrentDscp = 0; + std::atomic mOutgoingResult = true; + +#if USE_GNUTLS + gnutls_session_t mSession; + std::mutex mSendMutex; + + static int CertificateCallback(gnutls_session_t session); + static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len); + static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen); + static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms); + +#elif USE_MBEDTLS + mbedtls_entropy_context mEntropy; + mbedtls_ctr_drbg_context mDrbg; + mbedtls_ssl_config mConf; + mbedtls_ssl_context mSsl; + + std::mutex mSslMutex; + + uint32_t mFinMs = 0, mIntMs = 0; + std::chrono::time_point mTimerSetAt; + + char mMasterSecret[48]; + char mRandBytes[64]; + mbedtls_tls_prf_types mTlsProfile = MBEDTLS_SSL_TLS_PRF_NONE; + + static int CertificateCallback(void *ctx, mbedtls_x509_crt *crt, int depth, uint32_t *flags); + static int WriteCallback(void *ctx, const unsigned char *buf, size_t len); + static int ReadCallback(void *ctx, unsigned char *buf, size_t len); + static void ExportKeysCallback(void *ctx, mbedtls_ssl_key_export_type type, + const unsigned char *secret, size_t secret_len, + const unsigned char client_random[32], + const unsigned char server_random[32], + mbedtls_tls_prf_types tls_prf_type); + static void SetTimerCallback(void *ctx, uint32_t int_ms, uint32_t fin_ms); + static int GetTimerCallback(void *ctx); + +#else // OPENSSL + SSL_CTX *mCtx = NULL; + SSL *mSsl = NULL; + BIO *mInBio, *mOutBio; + std::mutex mSslMutex; + + void handleTimeout(); + + static BIO_METHOD *BioMethods; + static int TransportExIndex; + static std::mutex GlobalMutex; + + static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx); + static void InfoCallback(const SSL *ssl, int where, int ret); + + static int BioMethodNew(BIO *bio); + static int BioMethodFree(BIO *bio); + static int BioMethodWrite(BIO *bio, const char *in, int inl); + static long BioMethodCtrl(BIO *bio, int cmd, long num, void *ptr); +#endif +}; + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/http.cpp b/datachannel/src/impl/http.cpp new file mode 100644 index 000000000..7aa12d63f --- /dev/null +++ b/datachannel/src/impl/http.cpp @@ -0,0 +1,66 @@ +/** + * Copyright (c) 2020-2023 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "http.hpp" + +#include + +namespace rtc::impl { + +bool isHttpRequest(const byte *buffer, size_t size) { + // Check the buffer starts with a valid-looking HTTP method + for (size_t i = 0; i < size; ++i) { + char c = static_cast(buffer[i]); + if (i > 0 && c == ' ') + break; + else if (i >= 8 || c < 'A' || c > 'Z') + return false; + } + return true; +} + +size_t parseHttpLines(const byte *buffer, size_t size, std::list &lines) { + lines.clear(); + auto begin = reinterpret_cast(buffer); + auto end = begin + size; + auto cur = begin; + while (true) { + auto last = cur; + cur = std::find(cur, end, '\n'); + if (cur == end) + return 0; + string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++); + if (line.empty()) + break; + lines.emplace_back(std::move(line)); + } + + return cur - begin; +} + +std::multimap parseHttpHeaders(const std::list &lines) { + std::multimap headers; + for (const auto &line : lines) { + if (size_t pos = line.find_first_of(':'); pos != string::npos) { + string key = line.substr(0, pos); + string value = ""; + if (size_t subPos = line.find_first_not_of(' ', pos + 1); subPos != string::npos) { + value = line.substr(subPos); + } + std::transform(key.begin(), key.end(), key.begin(), + [](char c) { return std::tolower(c); }); + headers.emplace(std::move(key), std::move(value)); + } else { + headers.emplace(line, ""); + } + } + + return headers; +} + +} // namespace rtc::impl diff --git a/datachannel/src/impl/http.hpp b/datachannel/src/impl/http.hpp new file mode 100644 index 000000000..c78869f76 --- /dev/null +++ b/datachannel/src/impl/http.hpp @@ -0,0 +1,30 @@ +/** + * Copyright (c) 2020-2023 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_HTTP_H +#define RTC_IMPL_HTTP_H + +#include "common.hpp" + +#include +#include + +namespace rtc::impl { + +// Check the buffer contains the beginning of an HTTP request +bool isHttpRequest(const byte *buffer, size_t size); + +// Parse an http message into lines +size_t parseHttpLines(const byte *buffer, size_t size, std::list &lines); + +// Parse headers of a http message +std::multimap parseHttpHeaders(const std::list &lines); + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/httpproxytransport.cpp b/datachannel/src/impl/httpproxytransport.cpp new file mode 100644 index 000000000..e038c3428 --- /dev/null +++ b/datachannel/src/impl/httpproxytransport.cpp @@ -0,0 +1,129 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * Copyright (c) 2023 Eric Gressman + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "httpproxytransport.hpp" +#include "http.hpp" +#include "tcptransport.hpp" + +#if RTC_ENABLE_WEBSOCKET + +namespace rtc::impl { + +using std::to_string; +using std::chrono::system_clock; + +HttpProxyTransport::HttpProxyTransport(shared_ptr lower, std::string hostname, + std::string service, state_callback stateCallback) + : Transport(lower, std::move(stateCallback)), mHostname(std::move(hostname)), + mService(std::move(service)) { + PLOG_DEBUG << "Initializing HTTP proxy transport"; + if (!lower->isActive()) + throw std::logic_error("HTTP proxy transport expects the lower transport to be active"); +} + +HttpProxyTransport::~HttpProxyTransport() { unregisterIncoming(); } + +void HttpProxyTransport::start() { + registerIncoming(); + + changeState(State::Connecting); + sendHttpRequest(); +} + +void HttpProxyTransport::stop() { unregisterIncoming(); } + +bool HttpProxyTransport::send(message_ptr message) { + if (state() != State::Connected) + throw std::runtime_error("HTTP proxy connection is not open"); + + PLOG_VERBOSE << "Send size=" << message->size(); + return outgoing(message); +} + +bool HttpProxyTransport::isActive() const { return true; } + +void HttpProxyTransport::incoming(message_ptr message) { + auto s = state(); + if (s != State::Connecting && s != State::Connected) + return; // Drop + + if (message) { + PLOG_VERBOSE << "Incoming size=" << message->size(); + + try { + if (state() == State::Connecting) { + mBuffer.insert(mBuffer.end(), message->begin(), message->end()); + if (size_t len = parseHttpResponse(mBuffer.data(), mBuffer.size())) { + PLOG_INFO << "HTTP proxy connection open"; + changeState(State::Connected); + mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len); + + if (!mBuffer.empty()) { + recv(make_message(mBuffer)); + mBuffer.clear(); + } + } + } else if (state() == State::Connected) { + recv(std::move(message)); + } + + return; + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + } + } + + if (state() == State::Connected) { + PLOG_INFO << "HTTP proxy disconnected"; + changeState(State::Disconnected); + recv(nullptr); + } else { + PLOG_ERROR << "HTTP proxy connection failed"; + changeState(State::Failed); + } +} + +bool HttpProxyTransport::sendHttpRequest() { + PLOG_DEBUG << "Sending HTTP request to proxy"; + + const string request = generateHttpRequest(); + auto data = reinterpret_cast(request.data()); + return outgoing(make_message(data, data + request.size())); +} + +string HttpProxyTransport::generateHttpRequest() { + return "CONNECT " + mHostname + ":" + mService + " HTTP/1.1\r\nHost: " + mHostname + "\r\n\r\n"; +} + +size_t HttpProxyTransport::parseHttpResponse(std::byte *buffer, size_t size) { + std::list lines; + size_t length = parseHttpLines(buffer, size, lines); + if (length == 0) + return 0; + + if (lines.empty()) + throw std::runtime_error("Invalid response from HTTP proxy"); + + std::istringstream status(std::move(lines.front())); + lines.pop_front(); + + string protocol; + unsigned int code = 0; + status >> protocol >> code; + + if (code != 200) + throw std::runtime_error("Unexpected response code " + to_string(code) + + " from HTTP proxy"); + + return length; +} + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/httpproxytransport.hpp b/datachannel/src/impl/httpproxytransport.hpp new file mode 100644 index 000000000..d99cf7c23 --- /dev/null +++ b/datachannel/src/impl/httpproxytransport.hpp @@ -0,0 +1,50 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * Copyright (c) 2023 Eric Gressman + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_TCP_PROXY_TRANSPORT_H +#define RTC_IMPL_TCP_PROXY_TRANSPORT_H + +#include "common.hpp" +#include "transport.hpp" + +#if RTC_ENABLE_WEBSOCKET + +namespace rtc::impl { + +class TcpTransport; + +class HttpProxyTransport final : public Transport, + public std::enable_shared_from_this { +public: + HttpProxyTransport(shared_ptr lower, std::string hostname, std::string service, + state_callback stateCallback); + ~HttpProxyTransport(); + + void start() override; + void stop() override; + bool send(message_ptr message) override; + + bool isActive() const; + +private: + void incoming(message_ptr message) override; + bool sendHttpRequest(); + string generateHttpRequest(); + size_t parseHttpResponse(std::byte *buffer, size_t size); + + string mHostname; + string mService; + binary mBuffer; +}; + +} // namespace rtc::impl + +#endif + +#endif diff --git a/datachannel/src/impl/icetransport.cpp b/datachannel/src/impl/icetransport.cpp new file mode 100644 index 000000000..269d16a3f --- /dev/null +++ b/datachannel/src/impl/icetransport.cpp @@ -0,0 +1,893 @@ +/** + * Copyright (c) 2019-2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "icetransport.hpp" +#include "configuration.hpp" +#include "internals.hpp" +#include "transport.hpp" +#include "utils.hpp" + +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#include +#else +#include +#include +#include +#include +#endif + +#include + +using namespace std::chrono_literals; +using std::chrono::system_clock; + +namespace rtc::impl { + +#if !USE_NICE // libjuice + +const int MAX_TURN_SERVERS_COUNT = 2; + +void IceTransport::Init() { + // Dummy +} + +void IceTransport::Cleanup() { + // Dummy +} + +IceTransport::IceTransport(const Configuration &config, candidate_callback candidateCallback, + state_callback stateChangeCallback, + gathering_state_callback gatheringStateChangeCallback) + : Transport(nullptr, std::move(stateChangeCallback)), mRole(Description::Role::ActPass), + mMid("0"), mGatheringState(GatheringState::New), + mCandidateCallback(std::move(candidateCallback)), + mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)), + mAgent(nullptr, nullptr) { + + PLOG_DEBUG << "Initializing ICE transport (libjuice)"; + + juice_log_level_t level; + auto logger = plog::get(); + switch (logger ? logger->getMaxSeverity() : plog::none) { + case plog::none: + level = JUICE_LOG_LEVEL_NONE; + break; + case plog::fatal: + level = JUICE_LOG_LEVEL_FATAL; + break; + case plog::error: + level = JUICE_LOG_LEVEL_ERROR; + break; + case plog::warning: + level = JUICE_LOG_LEVEL_WARN; + break; + case plog::info: + case plog::debug: // juice debug is output as verbose + level = JUICE_LOG_LEVEL_INFO; + break; + default: + level = JUICE_LOG_LEVEL_VERBOSE; + break; + } + juice_set_log_handler(IceTransport::LogCallback); + juice_set_log_level(level); + + juice_config_t jconfig = {}; + jconfig.cb_state_changed = IceTransport::StateChangeCallback; + jconfig.cb_candidate = IceTransport::CandidateCallback; + jconfig.cb_gathering_done = IceTransport::GatheringDoneCallback; + jconfig.cb_recv = IceTransport::RecvCallback; + jconfig.user_ptr = this; + + if (config.enableIceTcp) { + PLOG_WARNING << "ICE-TCP is not supported with libjuice"; + } + + if (config.enableIceUdpMux) { + PLOG_DEBUG << "Enabling ICE UDP mux"; + jconfig.concurrency_mode = JUICE_CONCURRENCY_MODE_MUX; + } else { + jconfig.concurrency_mode = JUICE_CONCURRENCY_MODE_POLL; + } + + // Randomize servers order + std::vector servers = config.iceServers; + std::shuffle(servers.begin(), servers.end(), utils::random_engine()); + + // Pick a STUN server + for (auto &server : servers) { + if (!server.hostname.empty() && server.type == IceServer::Type::Stun) { + if (server.port == 0) + server.port = 3478; // STUN UDP port + PLOG_INFO << "Using STUN server \"" << server.hostname << ":" << server.port << "\""; + jconfig.stun_server_host = server.hostname.c_str(); + jconfig.stun_server_port = server.port; + break; + } + } + + juice_turn_server_t turn_servers[MAX_TURN_SERVERS_COUNT]; + std::memset(turn_servers, 0, sizeof(turn_servers)); + + // Add TURN servers + int k = 0; + for (auto &server : servers) { + if (!server.hostname.empty() && server.type == IceServer::Type::Turn) { + if (server.port == 0) + server.port = 3478; // TURN UDP port + PLOG_INFO << "Using TURN server \"" << server.hostname << ":" << server.port << "\""; + turn_servers[k].host = server.hostname.c_str(); + turn_servers[k].username = server.username.c_str(); + turn_servers[k].password = server.password.c_str(); + turn_servers[k].port = server.port; + if (++k >= MAX_TURN_SERVERS_COUNT) + break; + } + } + jconfig.turn_servers = k > 0 ? turn_servers : nullptr; + jconfig.turn_servers_count = k; + + // Bind address + if (config.bindAddress) { + jconfig.bind_address = config.bindAddress->c_str(); + } + + // Port range + if (config.portRangeBegin > 1024 || + (config.portRangeEnd != 0 && config.portRangeEnd != 65535)) { + jconfig.local_port_range_begin = config.portRangeBegin; + jconfig.local_port_range_end = config.portRangeEnd; + } + + // Create agent + mAgent = decltype(mAgent)(juice_create(&jconfig), juice_destroy); + if (!mAgent) + throw std::runtime_error("Failed to create the ICE agent"); +} + +IceTransport::~IceTransport() { + PLOG_DEBUG << "Destroying ICE transport"; + mAgent.reset(); +} + +Description::Role IceTransport::role() const { return mRole; } + +Description IceTransport::getLocalDescription(Description::Type type) const { + char sdp[JUICE_MAX_SDP_STRING_LEN]; + if (juice_get_local_description(mAgent.get(), sdp, JUICE_MAX_SDP_STRING_LEN) < 0) + throw std::runtime_error("Failed to generate local SDP"); + + // RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of + // setup:actpass. + // See https://www.rfc-editor.org/rfc/rfc5763.html#section-5 + Description desc(string(sdp), type, + type == Description::Type::Offer ? Description::Role::ActPass : mRole); + desc.addIceOption("trickle"); + return desc; +} + +void IceTransport::setRemoteDescription(const Description &description) { + // RFC 5763: The answerer MUST use either a setup attribute value of setup:active or + // setup:passive. + // See https://www.rfc-editor.org/rfc/rfc5763.html#section-5 + if (description.type() == Description::Type::Answer && + description.role() == Description::Role::ActPass) + throw std::invalid_argument("Illegal role actpass in remote answer description"); + + // RFC 5763: Note that if the answerer uses setup:passive, then the DTLS handshake + // will not begin until the answerer is received, which adds additional latency. + // setup:active allows the answer and the DTLS handshake to occur in parallel. Thus, + // setup:active is RECOMMENDED. + if (mRole == Description::Role::ActPass) + mRole = description.role() == Description::Role::Active ? Description::Role::Passive + : Description::Role::Active; + + if (mRole == description.role()) + throw std::invalid_argument("Incompatible roles with remote description"); + + mMid = description.bundleMid(); + if (juice_set_remote_description(mAgent.get(), + description.generateApplicationSdp("\r\n").c_str()) < 0) + throw std::invalid_argument("Invalid ICE settings from remote SDP"); +} + +bool IceTransport::addRemoteCandidate(const Candidate &candidate) { + // Don't try to pass unresolved candidates for more safety + if (!candidate.isResolved()) + return false; + + return juice_add_remote_candidate(mAgent.get(), string(candidate).c_str()) >= 0; +} + +void IceTransport::gatherLocalCandidates(string mid) { + mMid = std::move(mid); + + // Change state now as candidates calls can be synchronous + changeGatheringState(GatheringState::InProgress); + + if (juice_gather_candidates(mAgent.get()) < 0) { + throw std::runtime_error("Failed to gather local ICE candidates"); + } +} + +optional IceTransport::getLocalAddress() const { + char str[JUICE_MAX_ADDRESS_STRING_LEN]; + if (juice_get_selected_addresses(mAgent.get(), str, JUICE_MAX_ADDRESS_STRING_LEN, NULL, 0) == + 0) { + return std::make_optional(string(str)); + } + return nullopt; +} +optional IceTransport::getRemoteAddress() const { + char str[JUICE_MAX_ADDRESS_STRING_LEN]; + if (juice_get_selected_addresses(mAgent.get(), NULL, 0, str, JUICE_MAX_ADDRESS_STRING_LEN) == + 0) { + return std::make_optional(string(str)); + } + return nullopt; +} + +bool IceTransport::getSelectedCandidatePair(Candidate *local, Candidate *remote) { + char sdpLocal[JUICE_MAX_CANDIDATE_SDP_STRING_LEN]; + char sdpRemote[JUICE_MAX_CANDIDATE_SDP_STRING_LEN]; + if (juice_get_selected_candidates(mAgent.get(), sdpLocal, JUICE_MAX_CANDIDATE_SDP_STRING_LEN, + sdpRemote, JUICE_MAX_CANDIDATE_SDP_STRING_LEN) == 0) { + if (local) { + *local = Candidate(sdpLocal, mMid); + local->resolve(Candidate::ResolveMode::Simple); + } + if (remote) { + *remote = Candidate(sdpRemote, mMid); + remote->resolve(Candidate::ResolveMode::Simple); + } + return true; + } + return false; +} + +bool IceTransport::send(message_ptr message) { + auto s = state(); + if (!message || (s != State::Connected && s != State::Completed)) + return false; + + PLOG_VERBOSE << "Send size=" << message->size(); + return outgoing(message); +} + +bool IceTransport::outgoing(message_ptr message) { + // Explicit Congestion Notification takes the least-significant 2 bits of the DS field + int ds = int(message->dscp << 2); + return juice_send_diffserv(mAgent.get(), reinterpret_cast(message->data()), + message->size(), ds) >= 0; +} + +void IceTransport::changeGatheringState(GatheringState state) { + if (mGatheringState.exchange(state) != state) + mGatheringStateChangeCallback(mGatheringState); +} + +void IceTransport::processStateChange(unsigned int state) { + switch (state) { + case JUICE_STATE_DISCONNECTED: + changeState(State::Disconnected); + break; + case JUICE_STATE_CONNECTING: + changeState(State::Connecting); + break; + case JUICE_STATE_CONNECTED: + changeState(State::Connected); + break; + case JUICE_STATE_COMPLETED: + changeState(State::Completed); + break; + case JUICE_STATE_FAILED: + changeState(State::Failed); + break; + }; +} + +void IceTransport::processCandidate(const string &candidate) { + mCandidateCallback(Candidate(candidate, mMid)); +} + +void IceTransport::processGatheringDone() { changeGatheringState(GatheringState::Complete); } + +void IceTransport::StateChangeCallback(juice_agent_t *, juice_state_t state, void *user_ptr) { + auto iceTransport = static_cast(user_ptr); + try { + iceTransport->processStateChange(static_cast(state)); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } +} + +void IceTransport::CandidateCallback(juice_agent_t *, const char *sdp, void *user_ptr) { + auto iceTransport = static_cast(user_ptr); + try { + iceTransport->processCandidate(sdp); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } +} + +void IceTransport::GatheringDoneCallback(juice_agent_t *, void *user_ptr) { + auto iceTransport = static_cast(user_ptr); + try { + iceTransport->processGatheringDone(); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } +} + +void IceTransport::RecvCallback(juice_agent_t *, const char *data, size_t size, void *user_ptr) { + auto iceTransport = static_cast(user_ptr); + try { + PLOG_VERBOSE << "Incoming size=" << size; + auto b = reinterpret_cast(data); + iceTransport->incoming(make_message(b, b + size)); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } +} + +void IceTransport::LogCallback(juice_log_level_t level, const char *message) { + plog::Severity severity; + switch (level) { + case JUICE_LOG_LEVEL_FATAL: + severity = plog::fatal; + break; + case JUICE_LOG_LEVEL_ERROR: + severity = plog::error; + break; + case JUICE_LOG_LEVEL_WARN: + severity = plog::warning; + break; + case JUICE_LOG_LEVEL_INFO: + severity = plog::info; + break; + default: + severity = plog::verbose; // libjuice debug as verbose + break; + } + PLOG(severity) << "juice: " << message; +} + +#else // USE_NICE == 1 + +unique_ptr IceTransport::MainLoop(nullptr, nullptr); +std::thread IceTransport::MainLoopThread; + +void IceTransport::Init() { + g_log_set_handler("libnice", G_LOG_LEVEL_MASK, LogCallback, nullptr); + + IF_PLOG(plog::verbose) { + nice_debug_enable(false); // do not output STUN debug messages + } + + MainLoop = decltype(MainLoop)(g_main_loop_new(nullptr, FALSE), g_main_loop_unref); + if (!MainLoop) + throw std::runtime_error("Failed to create the main loop"); + + MainLoopThread = std::thread(g_main_loop_run, MainLoop.get()); +} + +void IceTransport::Cleanup() { + g_main_loop_quit(MainLoop.get()); + MainLoopThread.join(); + MainLoop.reset(); +} + +static void closeNiceAgentCallback(GObject *niceAgent, GAsyncResult *, gpointer) { + g_object_unref(niceAgent); +} + +static void closeNiceAgent(NiceAgent *niceAgent) { + // close the agent to prune alive TURN refreshes, before releasing it + nice_agent_close_async(niceAgent, closeNiceAgentCallback, nullptr); +} + +IceTransport::IceTransport(const Configuration &config, candidate_callback candidateCallback, + state_callback stateChangeCallback, + gathering_state_callback gatheringStateChangeCallback) + : Transport(nullptr, std::move(stateChangeCallback)), mRole(Description::Role::ActPass), + mMid("0"), mGatheringState(GatheringState::New), + mCandidateCallback(std::move(candidateCallback)), + mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)), + mNiceAgent(nullptr, nullptr), mOutgoingDscp(0) { + + PLOG_DEBUG << "Initializing ICE transport (libnice)"; + + if (!MainLoop) + throw std::logic_error("Main loop for nice agent is not created"); + + // RFC 8445: The nomination process that was referred to as "aggressive nomination" in RFC 5245 + // has been deprecated in this specification. + // libnice defaults to aggressive nomation therefore we change to regular nomination. + // See https://gitlab.freedesktop.org/libnice/libnice/-/merge_requests/125 + NiceAgentOption flags = NICE_AGENT_OPTION_REGULAR_NOMINATION; + + // Create agent + mNiceAgent = decltype(mNiceAgent)( + nice_agent_new_full( + g_main_loop_get_context(MainLoop.get()), + NICE_COMPATIBILITY_RFC5245, // RFC 5245 was obsoleted by RFC 8445 but this should be OK + flags), + closeNiceAgent); + + if (!mNiceAgent) + throw std::runtime_error("Failed to create the nice agent"); + + mStreamId = nice_agent_add_stream(mNiceAgent.get(), 1); + if (!mStreamId) + throw std::runtime_error("Failed to add a stream"); + + g_object_set(G_OBJECT(mNiceAgent.get()), "controlling-mode", TRUE, nullptr); // decided later + g_object_set(G_OBJECT(mNiceAgent.get()), "ice-udp", TRUE, nullptr); + g_object_set(G_OBJECT(mNiceAgent.get()), "ice-tcp", config.enableIceTcp ? TRUE : FALSE, + nullptr); + + // RFC 8445: Agents MUST NOT use an RTO value smaller than 500 ms. + g_object_set(G_OBJECT(mNiceAgent.get()), "stun-initial-timeout", 500, nullptr); + g_object_set(G_OBJECT(mNiceAgent.get()), "stun-max-retransmissions", 3, nullptr); + + // RFC 8445: ICE agents SHOULD use a default Ta value, 50 ms, but MAY use another value based on + // the characteristics of the associated data. + g_object_set(G_OBJECT(mNiceAgent.get()), "stun-pacing-timer", 25, nullptr); + + g_object_set(G_OBJECT(mNiceAgent.get()), "upnp", FALSE, nullptr); + g_object_set(G_OBJECT(mNiceAgent.get()), "upnp-timeout", 200, nullptr); + + // Proxy + if (config.proxyServer) { + const auto &proxyServer = *config.proxyServer; + + NiceProxyType type; + switch (proxyServer.type) { + case ProxyServer::Type::Http: + type = NICE_PROXY_TYPE_HTTP; + break; + case ProxyServer::Type::Socks5: + type = NICE_PROXY_TYPE_SOCKS5; + break; + default: + PLOG_WARNING << "Proxy server type is not supported"; + type = NICE_PROXY_TYPE_NONE; + break; + } + + g_object_set(G_OBJECT(mNiceAgent.get()), "proxy-type", type, nullptr); + g_object_set(G_OBJECT(mNiceAgent.get()), "proxy-ip", proxyServer.hostname.c_str(), nullptr); + g_object_set(G_OBJECT(mNiceAgent.get()), "proxy-port", guint(proxyServer.port), nullptr); + + if (proxyServer.username) + g_object_set(G_OBJECT(mNiceAgent.get()), "proxy-username", + proxyServer.username->c_str(), nullptr); + + if (proxyServer.password) + g_object_set(G_OBJECT(mNiceAgent.get()), "proxy-password", + proxyServer.password->c_str(), nullptr); + } + + if (config.enableIceUdpMux) { + PLOG_WARNING << "ICE UDP mux is not available with libnice"; + } + + // Randomize order + std::vector servers = config.iceServers; + std::shuffle(servers.begin(), servers.end(), utils::random_engine()); + + // Add one STUN server + bool success = false; + for (auto &server : servers) { + if (server.hostname.empty()) + continue; + if (server.type != IceServer::Type::Stun) + continue; + if (server.port == 0) + server.port = 3478; // STUN UDP port + + struct addrinfo hints = {}; + hints.ai_family = AF_INET; // IPv4 + hints.ai_socktype = SOCK_DGRAM; + hints.ai_protocol = IPPROTO_UDP; + hints.ai_flags = AI_ADDRCONFIG; + struct addrinfo *result = nullptr; + if (getaddrinfo(server.hostname.c_str(), std::to_string(server.port).c_str(), &hints, + &result) != 0) { + PLOG_WARNING << "Unable to resolve STUN server address: " << server.hostname << ':' + << server.port; + continue; + } + + for (auto p = result; p; p = p->ai_next) { + if (p->ai_family == AF_INET) { + char nodebuffer[MAX_NUMERICNODE_LEN]; + char servbuffer[MAX_NUMERICSERV_LEN]; + if (getnameinfo(p->ai_addr, p->ai_addrlen, nodebuffer, MAX_NUMERICNODE_LEN, + servbuffer, MAX_NUMERICSERV_LEN, + NI_NUMERICHOST | NI_NUMERICSERV) == 0) { + PLOG_INFO << "Using STUN server \"" << server.hostname << ":" << server.port + << "\""; + g_object_set(G_OBJECT(mNiceAgent.get()), "stun-server", nodebuffer, nullptr); + g_object_set(G_OBJECT(mNiceAgent.get()), "stun-server-port", + std::stoul(servbuffer), nullptr); + success = true; + break; + } + } + } + + freeaddrinfo(result); + if (success) + break; + } + + // Add TURN servers + for (auto &server : servers) { + if (server.hostname.empty()) + continue; + if (server.type != IceServer::Type::Turn) + continue; + if (server.port == 0) + server.port = server.relayType == IceServer::RelayType::TurnTls ? 5349 : 3478; + + struct addrinfo hints = {}; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = + server.relayType == IceServer::RelayType::TurnUdp ? SOCK_DGRAM : SOCK_STREAM; + hints.ai_protocol = + server.relayType == IceServer::RelayType::TurnUdp ? IPPROTO_UDP : IPPROTO_TCP; + hints.ai_flags = AI_ADDRCONFIG; + struct addrinfo *result = nullptr; + if (getaddrinfo(server.hostname.c_str(), std::to_string(server.port).c_str(), &hints, + &result) != 0) { + PLOG_WARNING << "Unable to resolve TURN server address: " << server.hostname << ':' + << server.port; + continue; + } + + for (auto p = result; p; p = p->ai_next) { + if (p->ai_family == AF_INET || p->ai_family == AF_INET6) { + char nodebuffer[MAX_NUMERICNODE_LEN]; + char servbuffer[MAX_NUMERICSERV_LEN]; + if (getnameinfo(p->ai_addr, p->ai_addrlen, nodebuffer, MAX_NUMERICNODE_LEN, + servbuffer, MAX_NUMERICSERV_LEN, + NI_NUMERICHOST | NI_NUMERICSERV) == 0) { + PLOG_INFO << "Using TURN server \"" << server.hostname << ":" << server.port + << "\""; + NiceRelayType niceRelayType; + switch (server.relayType) { + case IceServer::RelayType::TurnTcp: + niceRelayType = NICE_RELAY_TYPE_TURN_TCP; + break; + case IceServer::RelayType::TurnTls: + niceRelayType = NICE_RELAY_TYPE_TURN_TLS; + break; + default: + niceRelayType = NICE_RELAY_TYPE_TURN_UDP; + break; + } + nice_agent_set_relay_info(mNiceAgent.get(), mStreamId, 1, nodebuffer, + std::stoul(servbuffer), server.username.c_str(), + server.password.c_str(), niceRelayType); + } + } + } + + freeaddrinfo(result); + } + + g_signal_connect(G_OBJECT(mNiceAgent.get()), "component-state-changed", + G_CALLBACK(StateChangeCallback), this); + g_signal_connect(G_OBJECT(mNiceAgent.get()), "new-candidate-full", + G_CALLBACK(CandidateCallback), this); + g_signal_connect(G_OBJECT(mNiceAgent.get()), "candidate-gathering-done", + G_CALLBACK(GatheringDoneCallback), this); + + nice_agent_set_stream_name(mNiceAgent.get(), mStreamId, "application"); + nice_agent_set_port_range(mNiceAgent.get(), mStreamId, 1, config.portRangeBegin, + config.portRangeEnd); + + nice_agent_attach_recv(mNiceAgent.get(), mStreamId, 1, g_main_loop_get_context(MainLoop.get()), + RecvCallback, this); +} + +IceTransport::~IceTransport() { + PLOG_DEBUG << "Destroying ICE transport"; + nice_agent_attach_recv(mNiceAgent.get(), mStreamId, 1, g_main_loop_get_context(MainLoop.get()), + NULL, NULL); + nice_agent_remove_stream(mNiceAgent.get(), mStreamId); + mNiceAgent.reset(); + + if (mTimeoutId) + g_source_remove(mTimeoutId); +} + +Description::Role IceTransport::role() const { return mRole; } + +Description IceTransport::getLocalDescription(Description::Type type) const { + // RFC 8445: The initiating agent that started the ICE processing MUST take the controlling + // role, and the other MUST take the controlled role. + g_object_set(G_OBJECT(mNiceAgent.get()), "controlling-mode", + type == Description::Type::Offer ? TRUE : FALSE, nullptr); + + unique_ptr sdp(nice_agent_generate_local_sdp(mNiceAgent.get()), + g_free); + + // RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of + // setup:actpass. + // See https://www.rfc-editor.org/rfc/rfc5763.html#section-5 + Description desc(string(sdp.get()), type, + type == Description::Type::Offer ? Description::Role::ActPass : mRole); + desc.addIceOption("trickle"); + return desc; +} + +void IceTransport::setRemoteDescription(const Description &description) { + // RFC 5763: The answerer MUST use either a setup attribute value of setup:active or + // setup:passive. + // See https://www.rfc-editor.org/rfc/rfc5763.html#section-5 + if (description.type() == Description::Type::Answer && + description.role() == Description::Role::ActPass) + throw std::invalid_argument("Illegal role actpass in remote answer description"); + + // RFC 5763: Note that if the answerer uses setup:passive, then the DTLS handshake + // will not begin until the answerer is received, which adds additional latency. + // setup:active allows the answer and the DTLS handshake to occur in parallel. Thus, + // setup:active is RECOMMENDED. + if (mRole == Description::Role::ActPass) + mRole = description.role() == Description::Role::Active ? Description::Role::Passive + : Description::Role::Active; + + if (mRole == description.role()) + throw std::invalid_argument("Incompatible roles with remote description"); + + mMid = description.bundleMid(); + mTrickleTimeout = !description.ended() ? 30s : 0s; + + // Warning: libnice expects "\n" as end of line + if (nice_agent_parse_remote_sdp(mNiceAgent.get(), + description.generateApplicationSdp("\n").c_str()) < 0) + throw std::invalid_argument("Invalid ICE settings from remote SDP"); +} + +bool IceTransport::addRemoteCandidate(const Candidate &candidate) { + // Don't try to pass unresolved candidates to libnice for more safety + if (!candidate.isResolved()) + return false; + + // Warning: the candidate string must start with "a=candidate:" and it must not end with a + // newline or whitespace, else libnice will reject it. + string sdp(candidate); + NiceCandidate *cand = + nice_agent_parse_remote_candidate_sdp(mNiceAgent.get(), mStreamId, sdp.c_str()); + if (!cand) { + PLOG_WARNING << "Rejected ICE candidate: " << sdp; + return false; + } + + GSList *list = g_slist_append(nullptr, cand); + int ret = nice_agent_set_remote_candidates(mNiceAgent.get(), mStreamId, 1, list); + + g_slist_free_full(list, reinterpret_cast(nice_candidate_free)); + return ret > 0; +} + +void IceTransport::gatherLocalCandidates(string mid) { + mMid = std::move(mid); + + // Change state now as candidates calls can be synchronous + changeGatheringState(GatheringState::InProgress); + + if (!nice_agent_gather_candidates(mNiceAgent.get(), mStreamId)) { + throw std::runtime_error("Failed to gather local ICE candidates"); + } +} + +optional IceTransport::getLocalAddress() const { + NiceCandidate *local = nullptr; + NiceCandidate *remote = nullptr; + if (nice_agent_get_selected_pair(mNiceAgent.get(), mStreamId, 1, &local, &remote)) { + return std::make_optional(AddressToString(local->addr)); + } + return nullopt; +} + +optional IceTransport::getRemoteAddress() const { + NiceCandidate *local = nullptr; + NiceCandidate *remote = nullptr; + if (nice_agent_get_selected_pair(mNiceAgent.get(), mStreamId, 1, &local, &remote)) { + return std::make_optional(AddressToString(remote->addr)); + } + return nullopt; +} + +bool IceTransport::send(message_ptr message) { + auto s = state(); + if (!message || (s != State::Connected && s != State::Completed)) + return false; + + PLOG_VERBOSE << "Send size=" << message->size(); + return outgoing(message); +} + +bool IceTransport::outgoing(message_ptr message) { + std::lock_guard lock(mOutgoingMutex); + if (mOutgoingDscp != message->dscp) { + mOutgoingDscp = message->dscp; + // Explicit Congestion Notification takes the least-significant 2 bits of the DS field + int ds = int(message->dscp << 2); + nice_agent_set_stream_tos(mNiceAgent.get(), mStreamId, ds); // ToS is the legacy name for DS + } + return nice_agent_send(mNiceAgent.get(), mStreamId, 1, message->size(), + reinterpret_cast(message->data())) >= 0; +} + +void IceTransport::changeGatheringState(GatheringState state) { + if (mGatheringState.exchange(state) != state) + mGatheringStateChangeCallback(mGatheringState); +} + +void IceTransport::processTimeout() { + PLOG_WARNING << "ICE timeout"; + mTimeoutId = 0; + changeState(State::Failed); +} + +void IceTransport::processCandidate(const string &candidate) { + mCandidateCallback(Candidate(candidate, mMid)); +} + +void IceTransport::processGatheringDone() { changeGatheringState(GatheringState::Complete); } + +void IceTransport::processStateChange(unsigned int state) { + if (state == NICE_COMPONENT_STATE_FAILED && mTrickleTimeout.count() > 0) { + if (mTimeoutId) + g_source_remove(mTimeoutId); + mTimeoutId = g_timeout_add(mTrickleTimeout.count() /* ms */, TimeoutCallback, this); + return; + } + + if (state == NICE_COMPONENT_STATE_CONNECTED && mTimeoutId) { + g_source_remove(mTimeoutId); + mTimeoutId = 0; + } + + switch (state) { + case NICE_COMPONENT_STATE_DISCONNECTED: + changeState(State::Disconnected); + break; + case NICE_COMPONENT_STATE_CONNECTING: + changeState(State::Connecting); + break; + case NICE_COMPONENT_STATE_CONNECTED: + changeState(State::Connected); + break; + case NICE_COMPONENT_STATE_READY: + changeState(State::Completed); + break; + case NICE_COMPONENT_STATE_FAILED: + changeState(State::Failed); + break; + }; +} + +string IceTransport::AddressToString(const NiceAddress &addr) { + char buffer[NICE_ADDRESS_STRING_LEN]; + nice_address_to_string(&addr, buffer); + unsigned int port = nice_address_get_port(&addr); + std::ostringstream ss; + ss << buffer << ":" << port; + return ss.str(); +} + +void IceTransport::CandidateCallback(NiceAgent *agent, NiceCandidate *candidate, + gpointer userData) { + auto iceTransport = static_cast(userData); + gchar *cand = nice_agent_generate_local_candidate_sdp(agent, candidate); + try { + iceTransport->processCandidate(cand); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } + g_free(cand); +} + +void IceTransport::GatheringDoneCallback(NiceAgent * /*agent*/, guint /*streamId*/, + gpointer userData) { + auto iceTransport = static_cast(userData); + try { + iceTransport->processGatheringDone(); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } +} + +void IceTransport::StateChangeCallback(NiceAgent * /*agent*/, guint /*streamId*/, + guint /*componentId*/, guint state, gpointer userData) { + auto iceTransport = static_cast(userData); + try { + iceTransport->processStateChange(state); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } +} + +void IceTransport::RecvCallback(NiceAgent * /*agent*/, guint /*streamId*/, guint /*componentId*/, + guint len, gchar *buf, gpointer userData) { + auto iceTransport = static_cast(userData); + try { + PLOG_VERBOSE << "Incoming size=" << len; + auto b = reinterpret_cast(buf); + iceTransport->incoming(make_message(b, b + len)); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } +} + +gboolean IceTransport::TimeoutCallback(gpointer userData) { + auto iceTransport = static_cast(userData); + try { + iceTransport->processTimeout(); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } + return FALSE; +} + +void IceTransport::LogCallback(const gchar * /*logDomain*/, GLogLevelFlags logLevel, + const gchar *message, gpointer /*userData*/) { + plog::Severity severity; + unsigned int flags = logLevel & G_LOG_LEVEL_MASK; + if (flags & G_LOG_LEVEL_ERROR) + severity = plog::fatal; + else if (flags & G_LOG_LEVEL_CRITICAL) + severity = plog::error; + else if (flags & G_LOG_LEVEL_WARNING) + severity = plog::warning; + else if (flags & G_LOG_LEVEL_MESSAGE) + severity = plog::info; + else if (flags & G_LOG_LEVEL_INFO) + severity = plog::info; + else + severity = plog::verbose; // libnice debug as verbose + + PLOG(severity) << "nice: " << message; +} + +bool IceTransport::getSelectedCandidatePair(Candidate *local, Candidate *remote) { + NiceCandidate *niceLocal, *niceRemote; + if (!nice_agent_get_selected_pair(mNiceAgent.get(), mStreamId, 1, &niceLocal, &niceRemote)) + return false; + + gchar *sdpLocal = nice_agent_generate_local_candidate_sdp(mNiceAgent.get(), niceLocal); + if (local) + *local = Candidate(sdpLocal, mMid); + g_free(sdpLocal); + + gchar *sdpRemote = nice_agent_generate_local_candidate_sdp(mNiceAgent.get(), niceRemote); + if (remote) + *remote = Candidate(sdpRemote, mMid); + g_free(sdpRemote); + + if (local) + local->resolve(Candidate::ResolveMode::Simple); + if (remote) + remote->resolve(Candidate::ResolveMode::Simple); + return true; +} + +#endif + +} // namespace rtc::impl diff --git a/datachannel/src/impl/icetransport.hpp b/datachannel/src/impl/icetransport.hpp new file mode 100644 index 000000000..7724e2bda --- /dev/null +++ b/datachannel/src/impl/icetransport.hpp @@ -0,0 +1,114 @@ +/** + * Copyright (c) 2019-2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_ICE_TRANSPORT_H +#define RTC_IMPL_ICE_TRANSPORT_H + +#include "candidate.hpp" +#include "common.hpp" +#include "configuration.hpp" +#include "description.hpp" +#include "global.hpp" +#include "peerconnection.hpp" +#include "transport.hpp" + +#if !USE_NICE +#include +#else +#include +#endif + +#include +#include +#include +#include + +namespace rtc::impl { + +class IceTransport : public Transport { +public: + static void Init(); + static void Cleanup(); + + enum class GatheringState { New = 0, InProgress = 1, Complete = 2 }; + + using candidate_callback = std::function; + using gathering_state_callback = std::function; + + IceTransport(const Configuration &config, candidate_callback candidateCallback, + state_callback stateChangeCallback, + gathering_state_callback gatheringStateChangeCallback); + ~IceTransport(); + + Description::Role role() const; + GatheringState gatheringState() const; + Description getLocalDescription(Description::Type type) const; + void setRemoteDescription(const Description &description); + bool addRemoteCandidate(const Candidate &candidate); + void gatherLocalCandidates(string mid); + + optional getLocalAddress() const; + optional getRemoteAddress() const; + + bool send(message_ptr message) override; // false if dropped + + bool getSelectedCandidatePair(Candidate *local, Candidate *remote); + +private: + bool outgoing(message_ptr message) override; + + void changeGatheringState(GatheringState state); + + void processStateChange(unsigned int state); + void processCandidate(const string &candidate); + void processGatheringDone(); + void processTimeout(); + + Description::Role mRole; + string mMid; + std::chrono::milliseconds mTrickleTimeout; + std::atomic mGatheringState; + + candidate_callback mCandidateCallback; + gathering_state_callback mGatheringStateChangeCallback; + +#if !USE_NICE + unique_ptr mAgent; + + static void StateChangeCallback(juice_agent_t *agent, juice_state_t state, void *user_ptr); + static void CandidateCallback(juice_agent_t *agent, const char *sdp, void *user_ptr); + static void GatheringDoneCallback(juice_agent_t *agent, void *user_ptr); + static void RecvCallback(juice_agent_t *agent, const char *data, size_t size, void *user_ptr); + static void LogCallback(juice_log_level_t level, const char *message); +#else + static unique_ptr MainLoop; + static std::thread MainLoopThread; + + unique_ptr mNiceAgent; + uint32_t mStreamId = 0; + guint mTimeoutId = 0; + std::mutex mOutgoingMutex; + unsigned int mOutgoingDscp; + + static string AddressToString(const NiceAddress &addr); + + static void CandidateCallback(NiceAgent *agent, NiceCandidate *candidate, gpointer userData); + static void GatheringDoneCallback(NiceAgent *agent, guint streamId, gpointer userData); + static void StateChangeCallback(NiceAgent *agent, guint streamId, guint componentId, + guint state, gpointer userData); + static void RecvCallback(NiceAgent *agent, guint stream_id, guint component_id, guint len, + gchar *buf, gpointer userData); + static gboolean TimeoutCallback(gpointer userData); + static void LogCallback(const gchar *log_domain, GLogLevelFlags log_level, const gchar *message, + gpointer user_data); +#endif +}; + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/init.cpp b/datachannel/src/impl/init.cpp new file mode 100644 index 000000000..aad72dd38 --- /dev/null +++ b/datachannel/src/impl/init.cpp @@ -0,0 +1,181 @@ +/** + * Copyright (c) 2020-2022 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "init.hpp" +#include "certificate.hpp" +#include "dtlstransport.hpp" +#include "icetransport.hpp" +#include "internals.hpp" +#include "pollservice.hpp" +#include "sctptransport.hpp" +#include "threadpool.hpp" +#include "tls.hpp" +#include "utils.hpp" + +#if RTC_ENABLE_WEBSOCKET +#include "tlstransport.hpp" +#endif + +#if RTC_ENABLE_MEDIA +#include "dtlssrtptransport.hpp" +#endif + +#ifdef _WIN32 +#include +#endif + +#include + +namespace rtc::impl { + +struct Init::TokenPayload { + TokenPayload(std::shared_future *cleanupFuture) { + Init::Instance().doInit(); + if (cleanupFuture) + *cleanupFuture = cleanupPromise.get_future().share(); + } + + ~TokenPayload() { + std::thread t( + [](std::promise promise) { + utils::this_thread::set_name("RTC cleanup"); + try { + Init::Instance().doCleanup(); + promise.set_value(); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + promise.set_exception(std::make_exception_ptr(e)); + } + }, + std::move(cleanupPromise)); + t.detach(); + } + + std::promise cleanupPromise; +}; + +Init &Init::Instance() { + static Init *instance = new Init; + return *instance; +} + +Init::Init() { + std::promise p; + p.set_value(); + mCleanupFuture = p.get_future(); // make it ready +} + +Init::~Init() {} + +init_token Init::token() { + std::lock_guard lock(mMutex); + if (auto locked = mWeak.lock()) + return locked; + + mGlobal = std::make_shared(&mCleanupFuture); + mWeak = *mGlobal; + return *mGlobal; +} + +void Init::preload() { + std::lock_guard lock(mMutex); + if (!mGlobal) { + mGlobal = std::make_shared(&mCleanupFuture); + mWeak = *mGlobal; + } +} + +std::shared_future Init::cleanup() { + std::lock_guard lock(mMutex); + mGlobal.reset(); + return mCleanupFuture; +} + +void Init::setSctpSettings(SctpSettings s) { + std::lock_guard lock(mMutex); + if (mGlobal) + SctpTransport::SetSettings(s); + + mCurrentSctpSettings = std::move(s); // store for next init +} + +void Init::doInit() { + // mMutex needs to be locked + + if (std::exchange(mInitialized, true)) + return; + + PLOG_DEBUG << "Global initialization"; + +#ifdef _WIN32 + WSADATA wsaData; + if (WSAStartup(MAKEWORD(2, 2), &wsaData)) + throw std::runtime_error("WSAStartup failed, error=" + std::to_string(WSAGetLastError())); +#endif + + int concurrency = std::thread::hardware_concurrency(); + int count = std::max(concurrency, MIN_THREADPOOL_SIZE); + PLOG_DEBUG << "Spawning " << count << " threads"; + ThreadPool::Instance().spawn(count); + +#if RTC_ENABLE_WEBSOCKET + PollService::Instance().start(); +#endif + +#if USE_GNUTLS + // Nothing to do +#elif USE_MBEDTLS + // Nothing to do +#else + openssl::init(); +#endif + + SctpTransport::Init(); + SctpTransport::SetSettings(mCurrentSctpSettings); + DtlsTransport::Init(); +#if RTC_ENABLE_WEBSOCKET + TlsTransport::Init(); +#endif +#if RTC_ENABLE_MEDIA + DtlsSrtpTransport::Init(); +#endif + IceTransport::Init(); +} + +void Init::doCleanup() { + std::lock_guard lock(mMutex); + if (mGlobal) + return; + + if (!std::exchange(mInitialized, false)) + return; + + PLOG_DEBUG << "Global cleanup"; + + ThreadPool::Instance().join(); + ThreadPool::Instance().clear(); +#if RTC_ENABLE_WEBSOCKET + PollService::Instance().join(); +#endif + + SctpTransport::Cleanup(); + DtlsTransport::Cleanup(); +#if RTC_ENABLE_WEBSOCKET + TlsTransport::Cleanup(); +#endif +#if RTC_ENABLE_MEDIA + DtlsSrtpTransport::Cleanup(); +#endif + IceTransport::Cleanup(); + +#ifdef _WIN32 + WSACleanup(); +#endif +} + +} // namespace rtc::impl diff --git a/datachannel/src/impl/init.hpp b/datachannel/src/impl/init.hpp new file mode 100644 index 000000000..cd42711ba --- /dev/null +++ b/datachannel/src/impl/init.hpp @@ -0,0 +1,56 @@ +/** + * Copyright (c) 2020-2022 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_INIT_H +#define RTC_IMPL_INIT_H + +#include "common.hpp" +#include "global.hpp" // for SctpSettings + +#include +#include +#include + +namespace rtc::impl { + +using init_token = shared_ptr; + +class Init { +public: + static Init &Instance(); + + Init(const Init &) = delete; + Init &operator=(const Init &) = delete; + Init(Init &&) = delete; + Init &operator=(Init &&) = delete; + + init_token token(); + void preload(); + std::shared_future cleanup(); + void setSctpSettings(SctpSettings s); + +private: + Init(); + ~Init(); + + void doInit(); + void doCleanup(); + + std::optional> mGlobal; + weak_ptr mWeak; + bool mInitialized = false; + SctpSettings mCurrentSctpSettings = {}; + std::mutex mMutex; + std::shared_future mCleanupFuture; + + struct TokenPayload; +}; + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/internals.hpp b/datachannel/src/impl/internals.hpp new file mode 100644 index 000000000..63341c9cc --- /dev/null +++ b/datachannel/src/impl/internals.hpp @@ -0,0 +1,54 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_INTERNALS_H +#define RTC_IMPL_INTERNALS_H + +#include "common.hpp" + +// Disable warnings before including plog +#if defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wall" +#elif defined(_MSC_VER) +#pragma warning(push, 0) +#endif + +#include "plog/Log.h" + +#if defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic pop +#elif defined(_MSC_VER) +#pragma warning(pop) +#endif + +namespace rtc { + +const size_t MAX_NUMERICNODE_LEN = 48; // Max IPv6 string representation length +const size_t MAX_NUMERICSERV_LEN = 6; // Max port string representation length + +const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default + +const uint16_t MAX_SCTP_STREAMS_COUNT = 1024; // Max number of negotiated SCTP streams + // RFC 8831 recommends 65535 but usrsctp needs a lot + // of memory, Chromium historically limits to 1024. + +const size_t DEFAULT_LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Default local max message size +const size_t DEFAULT_REMOTE_MAX_MESSAGE_SIZE = 65536; // Remote max message size if not in SDP + +const size_t DEFAULT_WS_MAX_MESSAGE_SIZE = 256 * 1024; // Default max message size for WebSockets + +const size_t RECV_QUEUE_LIMIT = 1024 * 1024; // Max per-channel queue size + +const int MIN_THREADPOOL_SIZE = 4; // Minimum number of threads in the global thread pool (>= 2) + +const size_t DEFAULT_MTU = RTC_DEFAULT_MTU; // defined in rtc.h + +} // namespace rtc + +#endif diff --git a/datachannel/src/impl/logcounter.cpp b/datachannel/src/impl/logcounter.cpp new file mode 100644 index 000000000..513506f09 --- /dev/null +++ b/datachannel/src/impl/logcounter.cpp @@ -0,0 +1,40 @@ +/** + * Copyright (c) 2021 Staz Modrzynski + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "logcounter.hpp" + +namespace rtc::impl { + +LogCounter::LogCounter(plog::Severity severity, const std::string &text, + std::chrono::seconds duration) { + mData = std::make_shared(); + mData->mDuration = duration; + mData->mSeverity = severity; + mData->mText = text; +} + +LogCounter &LogCounter::operator++(int) { + if (mData->mCount++ == 0) { + ThreadPool::Instance().schedule( + mData->mDuration, + [](weak_ptr data) { + if (auto ptr = data.lock()) { + int countCopy; + countCopy = ptr->mCount.exchange(0); + PLOG(ptr->mSeverity) + << ptr->mText << ": " << countCopy << " (over " + << std::chrono::duration_cast(ptr->mDuration).count() + << " seconds)"; + } + }, + mData); + } + return *this; +} + +} // namespace rtc::impl diff --git a/datachannel/src/impl/logcounter.hpp b/datachannel/src/impl/logcounter.hpp new file mode 100644 index 000000000..9a14ed58a --- /dev/null +++ b/datachannel/src/impl/logcounter.hpp @@ -0,0 +1,41 @@ +/** + * Copyright (c) 2021 Staz Modrzynski + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_SERVER_LOGCOUNTER_HPP +#define RTC_SERVER_LOGCOUNTER_HPP + +#include "common.hpp" +#include "threadpool.hpp" + +#include +#include + +namespace rtc::impl { + +class LogCounter { +private: + struct LogData { + plog::Severity mSeverity; + std::string mText; + std::chrono::steady_clock::duration mDuration; + + std::atomic mCount = 0; + }; + + shared_ptr mData; + +public: + LogCounter(plog::Severity severity, const std::string &text, + std::chrono::seconds duration = std::chrono::seconds(1)); + + LogCounter &operator++(int); +}; + +} // namespace rtc::impl + +#endif // RTC_SERVER_LOGCOUNTER_HPP diff --git a/datachannel/src/impl/peerconnection.cpp b/datachannel/src/impl/peerconnection.cpp new file mode 100644 index 000000000..3133fe328 --- /dev/null +++ b/datachannel/src/impl/peerconnection.cpp @@ -0,0 +1,1323 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "peerconnection.hpp" +#include "certificate.hpp" +#include "dtlstransport.hpp" +#include "icetransport.hpp" +#include "internals.hpp" +#include "logcounter.hpp" +#include "peerconnection.hpp" +#include "processor.hpp" +#include "rtp.hpp" +#include "sctptransport.hpp" +#include "utils.hpp" + +#if RTC_ENABLE_MEDIA +#include "dtlssrtptransport.hpp" +#endif + +#include +#include +#include +#include +#include +#include + +using namespace std::placeholders; + +namespace rtc::impl { + +static LogCounter COUNTER_MEDIA_TRUNCATED(plog::warning, + "Number of truncated RTP packets over past second"); +static LogCounter COUNTER_SRTP_DECRYPT_ERROR(plog::warning, + "Number of SRTP decryption errors over past second"); +static LogCounter COUNTER_SRTP_ENCRYPT_ERROR(plog::warning, + "Number of SRTP encryption errors over past second"); +static LogCounter + COUNTER_UNKNOWN_PACKET_TYPE(plog::warning, + "Number of unknown RTCP packet types over past second"); + +PeerConnection::PeerConnection(Configuration config_) + : config(std::move(config_)), mCertificate(make_certificate(config.certificateType)) { + PLOG_VERBOSE << "Creating PeerConnection"; + + if (config.portRangeEnd && config.portRangeBegin > config.portRangeEnd) + throw std::invalid_argument("Invalid port range"); + + if (config.mtu) { + if (*config.mtu < 576) // Min MTU for IPv4 + throw std::invalid_argument("Invalid MTU value"); + + if (*config.mtu > 1500) { // Standard Ethernet + PLOG_WARNING << "MTU set to " << *config.mtu; + } else { + PLOG_VERBOSE << "MTU set to " << *config.mtu; + } + } +} + +PeerConnection::~PeerConnection() { + PLOG_VERBOSE << "Destroying PeerConnection"; + mProcessor.join(); +} + +void PeerConnection::close() { + negotiationNeeded = false; + if (!closing.exchange(true)) { + PLOG_VERBOSE << "Closing PeerConnection"; + if (auto transport = std::atomic_load(&mSctpTransport)) + transport->stop(); + else + remoteClose(); + } +} + +void PeerConnection::remoteClose() { + close(); + if (state.load() != State::Closed) { + // Close data channels and tracks asynchronously + mProcessor.enqueue(&PeerConnection::closeDataChannels, shared_from_this()); + mProcessor.enqueue(&PeerConnection::closeTracks, shared_from_this()); + + closeTransports(); + } +} + +optional PeerConnection::localDescription() const { + std::lock_guard lock(mLocalDescriptionMutex); + return mLocalDescription; +} + +optional PeerConnection::remoteDescription() const { + std::lock_guard lock(mRemoteDescriptionMutex); + return mRemoteDescription; +} + +size_t PeerConnection::remoteMaxMessageSize() const { + const size_t localMax = config.maxMessageSize.value_or(DEFAULT_LOCAL_MAX_MESSAGE_SIZE); + + size_t remoteMax = DEFAULT_REMOTE_MAX_MESSAGE_SIZE; + std::lock_guard lock(mRemoteDescriptionMutex); + if (mRemoteDescription) + if (auto *application = mRemoteDescription->application()) + if (auto max = application->maxMessageSize()) { + // RFC 8841: If the SDP "max-message-size" attribute contains a maximum message + // size value of zero, it indicates that the SCTP endpoint will handle messages + // of any size, subject to memory capacity, etc. + remoteMax = *max > 0 ? *max : std::numeric_limits::max(); + } + + return std::min(remoteMax, localMax); +} + +// Helper for PeerConnection::initXTransport methods: start and emplace the transport +template +shared_ptr emplaceTransport(PeerConnection *pc, shared_ptr *member, shared_ptr transport) { + std::atomic_store(member, transport); + try { + transport->start(); + } catch (...) { + std::atomic_store(member, decltype(transport)(nullptr)); + throw; + } + + if (pc->closing.load() || pc->state.load() == PeerConnection::State::Closed) { + std::atomic_store(member, decltype(transport)(nullptr)); + transport->stop(); + return nullptr; + } + + return transport; +} + +shared_ptr PeerConnection::initIceTransport() { + try { + if (auto transport = std::atomic_load(&mIceTransport)) + return transport; + + PLOG_VERBOSE << "Starting ICE transport"; + + auto transport = std::make_shared( + config, weak_bind(&PeerConnection::processLocalCandidate, this, _1), + [this, weak_this = weak_from_this()](IceTransport::State transportState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (transportState) { + case IceTransport::State::Connecting: + changeIceState(IceState::Checking); + changeState(State::Connecting); + break; + case IceTransport::State::Connected: + changeIceState(IceState::Connected); + initDtlsTransport(); + break; + case IceTransport::State::Completed: + changeIceState(IceState::Completed); + break; + case IceTransport::State::Failed: + changeIceState(IceState::Failed); + changeState(State::Failed); + mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this()); + break; + case IceTransport::State::Disconnected: + changeIceState(IceState::Disconnected); + changeState(State::Disconnected); + mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this()); + break; + default: + // Ignore + break; + } + }, + [this, weak_this = weak_from_this()](IceTransport::GatheringState gatheringState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (gatheringState) { + case IceTransport::GatheringState::InProgress: + changeGatheringState(GatheringState::InProgress); + break; + case IceTransport::GatheringState::Complete: + endLocalCandidates(); + changeGatheringState(GatheringState::Complete); + break; + default: + // Ignore + break; + } + }); + + return emplaceTransport(this, &mIceTransport, std::move(transport)); + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + changeState(State::Failed); + throw std::runtime_error("ICE transport initialization failed"); + } +} + +shared_ptr PeerConnection::initDtlsTransport() { + try { + if (auto transport = std::atomic_load(&mDtlsTransport)) + return transport; + + PLOG_VERBOSE << "Starting DTLS transport"; + + auto fingerprintAlgorithm = CertificateFingerprint::Algorithm::Sha256; + if (auto remote = remoteDescription(); remote && remote->fingerprint()) { + fingerprintAlgorithm = remote->fingerprint()->algorithm; + } + + auto lower = std::atomic_load(&mIceTransport); + if (!lower) + throw std::logic_error("No underlying ICE transport for DTLS transport"); + + auto certificate = mCertificate.get(); + auto verifierCallback = weak_bind(&PeerConnection::checkFingerprint, this, _1); + auto dtlsStateChangeCallback = + [this, weak_this = weak_from_this()](DtlsTransport::State transportState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + + switch (transportState) { + case DtlsTransport::State::Connected: + if (auto remote = remoteDescription(); remote && remote->hasApplication()) + initSctpTransport(); + else + changeState(State::Connected); + + mProcessor.enqueue(&PeerConnection::openTracks, shared_from_this()); + break; + case DtlsTransport::State::Failed: + changeState(State::Failed); + mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this()); + break; + case DtlsTransport::State::Disconnected: + changeState(State::Disconnected); + mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this()); + break; + default: + // Ignore + break; + } + }; + + shared_ptr transport; + auto local = localDescription(); + if (config.forceMediaTransport || (local && local->hasAudioOrVideo())) { +#if RTC_ENABLE_MEDIA + PLOG_INFO << "This connection requires media support"; + + // DTLS-SRTP + transport = std::make_shared( + lower, certificate, config.mtu, fingerprintAlgorithm, verifierCallback, + weak_bind(&PeerConnection::forwardMedia, this, _1), dtlsStateChangeCallback); +#else + PLOG_WARNING << "Ignoring media support (not compiled with media support)"; +#endif + } + + if (!transport) { + // DTLS only + transport = std::make_shared(lower, certificate, config.mtu, + fingerprintAlgorithm, verifierCallback, + dtlsStateChangeCallback); + } + + return emplaceTransport(this, &mDtlsTransport, std::move(transport)); + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + changeState(State::Failed); + throw std::runtime_error("DTLS transport initialization failed"); + } +} + +shared_ptr PeerConnection::initSctpTransport() { + try { + if (auto transport = std::atomic_load(&mSctpTransport)) + return transport; + + PLOG_VERBOSE << "Starting SCTP transport"; + + auto lower = std::atomic_load(&mDtlsTransport); + if (!lower) + throw std::logic_error("No underlying DTLS transport for SCTP transport"); + + auto local = localDescription(); + if (!local || !local->application()) + throw std::logic_error("Starting SCTP transport without local application description"); + + auto remote = remoteDescription(); + if (!remote || !remote->application()) + throw std::logic_error( + "Starting SCTP transport without remote application description"); + + SctpTransport::Ports ports = {}; + ports.local = local->application()->sctpPort().value_or(DEFAULT_SCTP_PORT); + ports.remote = remote->application()->sctpPort().value_or(DEFAULT_SCTP_PORT); + + auto transport = std::make_shared( + lower, config, std::move(ports), weak_bind(&PeerConnection::forwardMessage, this, _1), + weak_bind(&PeerConnection::forwardBufferedAmount, this, _1, _2), + [this, weak_this = weak_from_this()](SctpTransport::State transportState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + + switch (transportState) { + case SctpTransport::State::Connected: + changeState(State::Connected); + assignDataChannels(); + mProcessor.enqueue(&PeerConnection::openDataChannels, shared_from_this()); + break; + case SctpTransport::State::Failed: + changeState(State::Failed); + mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this()); + break; + case SctpTransport::State::Disconnected: + changeState(State::Disconnected); + mProcessor.enqueue(&PeerConnection::remoteClose, shared_from_this()); + break; + default: + // Ignore + break; + } + }); + + return emplaceTransport(this, &mSctpTransport, std::move(transport)); + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + changeState(State::Failed); + throw std::runtime_error("SCTP transport initialization failed"); + } +} + +shared_ptr PeerConnection::getIceTransport() const { + return std::atomic_load(&mIceTransport); +} + +shared_ptr PeerConnection::getDtlsTransport() const { + return std::atomic_load(&mDtlsTransport); +} + +shared_ptr PeerConnection::getSctpTransport() const { + return std::atomic_load(&mSctpTransport); +} + +void PeerConnection::closeTransports() { + PLOG_VERBOSE << "Closing transports"; + + // Change ICE state to sink state Closed + changeIceState(IceState::Closed); + + // Change state to sink state Closed + if (!changeState(State::Closed)) + return; // already closed + + // Reset interceptor and callbacks now that state is changed + setMediaHandler(nullptr); + resetCallbacks(); + + // Pass the pointers to a thread, allowing to terminate a transport from its own thread + auto sctp = std::atomic_exchange(&mSctpTransport, decltype(mSctpTransport)(nullptr)); + auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr)); + auto ice = std::atomic_exchange(&mIceTransport, decltype(mIceTransport)(nullptr)); + + if (sctp) { + sctp->onRecv(nullptr); + sctp->onBufferedAmount(nullptr); + } + + using array = std::array, 3>; + array transports{std::move(sctp), std::move(dtls), std::move(ice)}; + + for (const auto &t : transports) + if (t) + t->onStateChange(nullptr); + + TearDownProcessor::Instance().enqueue( + [transports = std::move(transports), token = Init::Instance().token()]() mutable { + for (const auto &t : transports) { + if (t) { + t->stop(); + break; + } + } + + for (auto &t : transports) + t.reset(); + }); +} + +void PeerConnection::endLocalCandidates() { + std::lock_guard lock(mLocalDescriptionMutex); + if (mLocalDescription) + mLocalDescription->endCandidates(); +} + +void PeerConnection::rollbackLocalDescription() { + PLOG_DEBUG << "Rolling back pending local description"; + + std::unique_lock lock(mLocalDescriptionMutex); + if (mCurrentLocalDescription) { + std::vector existingCandidates; + if (mLocalDescription) + existingCandidates = mLocalDescription->extractCandidates(); + + mLocalDescription.emplace(std::move(*mCurrentLocalDescription)); + mLocalDescription->addCandidates(std::move(existingCandidates)); + mCurrentLocalDescription.reset(); + } +} + +bool PeerConnection::checkFingerprint(const std::string &fingerprint) const { + std::lock_guard lock(mRemoteDescriptionMutex); + if (!mRemoteDescription || !mRemoteDescription->fingerprint()) + return false; + + auto expectedFingerprint = mRemoteDescription->fingerprint()->value; + if (expectedFingerprint == fingerprint) { + PLOG_VERBOSE << "Valid fingerprint \"" << fingerprint << "\""; + return true; + } + + PLOG_ERROR << "Invalid fingerprint \"" << fingerprint << "\", expected \"" << expectedFingerprint << "\""; + return false; +} + +void PeerConnection::forwardMessage(message_ptr message) { + if (!message) { + remoteCloseDataChannels(); + return; + } + + auto iceTransport = std::atomic_load(&mIceTransport); + auto sctpTransport = std::atomic_load(&mSctpTransport); + if (!iceTransport || !sctpTransport) + return; + + const uint16_t stream = uint16_t(message->stream); + auto [channel, found] = findDataChannel(stream); + + if (DataChannel::IsOpenMessage(message)) { + if (found) { + // The stream is already used, the receiver must close the DataChannel + PLOG_WARNING << "Got open message on already used stream " << stream; + if (channel && !channel->isClosed()) + channel->close(); + else + sctpTransport->closeStream(message->stream); + + return; + } + + const uint16_t remoteParity = (iceTransport->role() == Description::Role::Active) ? 1 : 0; + if (stream % 2 != remoteParity) { + // The odd/even rule is violated, the receiver must close the DataChannel + PLOG_WARNING << "Got open message violating the odd/even rule on stream " << stream; + sctpTransport->closeStream(message->stream); + return; + } + + channel = std::make_shared(weak_from_this(), sctpTransport); + channel->assignStream(stream); + channel->openCallback = + weak_bind(&PeerConnection::triggerDataChannel, this, weak_ptr{channel}); + + std::unique_lock lock(mDataChannelsMutex); // we are going to emplace + mDataChannels.emplace(stream, channel); + } else if (!found) { + if (message->type == Message::Reset) + return; // ignore + + // Invalid, close the DataChannel + PLOG_WARNING << "Got unexpected message on stream " << stream; + sctpTransport->closeStream(message->stream); + return; + } + + if (message->type == Message::Reset) { + // Incoming stream is reset, unregister it + removeDataChannel(stream); + } + + if (channel) { + // Forward the message + channel->incoming(message); + } else { + // DataChannel was destroyed, ignore + PLOG_DEBUG << "Ignored message on stream " << stream << ", DataChannel is destroyed"; + } +} + +void PeerConnection::forwardMedia([[maybe_unused]] message_ptr message) { +#if RTC_ENABLE_MEDIA + if (!message) + return; + + // TODO: outgoing + if (auto handler = getMediaHandler()) { + message_vector messages{std::move(message)}; + + handler->incoming(messages, [this](message_ptr message) { + auto transport = std::atomic_load(&mDtlsTransport); + if (auto srtpTransport = std::dynamic_pointer_cast(transport)) + srtpTransport->send(std::move(message)); + }); + + for (auto &m : messages) + dispatchMedia(std::move(m)); + + } else { + dispatchMedia(std::move(message)); + } +#endif +} + +void PeerConnection::dispatchMedia([[maybe_unused]] message_ptr message) { +#if RTC_ENABLE_MEDIA + std::shared_lock lock(mTracksMutex); // read-only + if (mTrackLines.size()==1) { + if (auto track = mTrackLines.front().lock()) + track->incoming(message); + return; + } + // Browsers like to compound their packets with a random SSRC. + // we have to do this monstrosity to distribute the report blocks + if (message->type == Message::Control) { + std::set ssrcs; + size_t offset = 0; + while ((sizeof(RtcpHeader) + offset) <= message->size()) { + auto header = reinterpret_cast(message->data() + offset); + if (header->lengthInBytes() > message->size() - offset) { + COUNTER_MEDIA_TRUNCATED++; + break; + } + offset += header->lengthInBytes(); + if (header->payloadType() == 205 || header->payloadType() == 206) { + auto rtcpfb = reinterpret_cast(header); + ssrcs.insert(rtcpfb->packetSenderSSRC()); + ssrcs.insert(rtcpfb->mediaSourceSSRC()); + + } else if (header->payloadType() == 200) { + auto rtcpsr = reinterpret_cast(header); + ssrcs.insert(rtcpsr->senderSSRC()); + for (int i = 0; i < rtcpsr->header.reportCount(); ++i) + ssrcs.insert(rtcpsr->getReportBlock(i)->getSSRC()); + } else if (header->payloadType() == 201) { + auto rtcprr = reinterpret_cast(header); + ssrcs.insert(rtcprr->senderSSRC()); + for (int i = 0; i < rtcprr->header.reportCount(); ++i) + ssrcs.insert(rtcprr->getReportBlock(i)->getSSRC()); + } else if (header->payloadType() == 202) { + auto sdes = reinterpret_cast(header); + if (!sdes->isValid()) { + PLOG_WARNING << "RTCP SDES packet is invalid"; + continue; + } + for (unsigned int i = 0; i < sdes->chunksCount(); i++) { + auto chunk = sdes->getChunk(i); + ssrcs.insert(chunk->ssrc()); + } + } else { + // PT=203 == Goodbye + // PT=204 == Application Specific + // PT=207 == Extended Report + if (header->payloadType() != 203 && header->payloadType() != 204 && + header->payloadType() != 207) { + COUNTER_UNKNOWN_PACKET_TYPE++; + } + } + } + + if (!ssrcs.empty()) { + for (uint32_t ssrc : ssrcs) { + if (auto it = mTracksBySsrc.find(ssrc); it != mTracksBySsrc.end()) { + if (auto track = it->second.lock()) + track->incoming(message); + } + } + return; + } + } + + uint32_t ssrc = uint32_t(message->stream); + + if (auto it = mTracksBySsrc.find(ssrc); it != mTracksBySsrc.end()) { + if (auto track = it->second.lock()) + track->incoming(message); + } else { + /* + * TODO: So the problem is that when stop sending streams, we stop getting report blocks for + * those streams Therefore when we get compound RTCP packets, they are empty, and we can't + * forward them. Therefore, it is expected that we don't know where to forward packets. Is + * this ideal? No! Do I know how to fix it? No! + */ + // PLOG_WARNING << "Track not found for SSRC " << ssrc << ", dropping"; + return; + } +#endif +} + +void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) { + [[maybe_unused]] auto [channel, found] = findDataChannel(stream); + if (channel) + channel->triggerBufferedAmount(amount); +} + +shared_ptr PeerConnection::emplaceDataChannel(string label, DataChannelInit init) { + std::unique_lock lock(mDataChannelsMutex); // we are going to emplace + + // If the DataChannel is user-negotiated, do not negotiate it in-band + auto channel = + init.negotiated + ? std::make_shared(weak_from_this(), std::move(label), + std::move(init.protocol), std::move(init.reliability)) + : std::make_shared(weak_from_this(), std::move(label), + std::move(init.protocol), + std::move(init.reliability)); + + // If the user supplied a stream id, use it, otherwise assign it later + if (init.id) { + uint16_t stream = *init.id; + if (stream > maxDataChannelStream()) + throw std::invalid_argument("DataChannel stream id is too high"); + + channel->assignStream(stream); + mDataChannels.emplace(std::make_pair(stream, channel)); + + } else { + mUnassignedDataChannels.push_back(channel); + } + + lock.unlock(); // we are going to call assignDataChannels() + + // If SCTP is connected, assign and open now + auto sctpTransport = std::atomic_load(&mSctpTransport); + if (sctpTransport && sctpTransport->state() == SctpTransport::State::Connected) { + assignDataChannels(); + channel->open(sctpTransport); + } + + return channel; +} + +std::pair, bool> PeerConnection::findDataChannel(uint16_t stream) { + std::shared_lock lock(mDataChannelsMutex); // read-only + if (auto it = mDataChannels.find(stream); it != mDataChannels.end()) + return std::make_pair(it->second.lock(), true); + else + return std::make_pair(nullptr, false); +} + +bool PeerConnection::removeDataChannel(uint16_t stream) { + std::unique_lock lock(mDataChannelsMutex); // we are going to erase + return mDataChannels.erase(stream) != 0; +} + +uint16_t PeerConnection::maxDataChannelStream() const { + auto sctpTransport = std::atomic_load(&mSctpTransport); + return sctpTransport ? sctpTransport->maxStream() : (MAX_SCTP_STREAMS_COUNT - 1); +} + +void PeerConnection::assignDataChannels() { + std::unique_lock lock(mDataChannelsMutex); // we are going to emplace + + auto iceTransport = std::atomic_load(&mIceTransport); + if (!iceTransport) + throw std::logic_error("Attempted to assign DataChannels without ICE transport"); + + const uint16_t maxStream = maxDataChannelStream(); + for (auto it = mUnassignedDataChannels.begin(); it != mUnassignedDataChannels.end(); ++it) { + auto channel = it->lock(); + if (!channel) + continue; + + // RFC 8832: The peer that initiates opening a data channel selects a stream identifier + // for which the corresponding incoming and outgoing streams are unused. If the side is + // acting as the DTLS client, it MUST choose an even stream identifier; if the side is + // acting as the DTLS server, it MUST choose an odd one. See + // https://www.rfc-editor.org/rfc/rfc8832.html#section-6 + uint16_t stream = (iceTransport->role() == Description::Role::Active) ? 0 : 1; + while (true) { + if (stream > maxStream) + throw std::runtime_error("Too many DataChannels"); + + if (mDataChannels.find(stream) == mDataChannels.end()) + break; + + stream += 2; + } + + PLOG_DEBUG << "Assigning stream " << stream << " to DataChannel"; + + channel->assignStream(stream); + mDataChannels.emplace(std::make_pair(stream, channel)); + } + + mUnassignedDataChannels.clear(); +} + +void PeerConnection::iterateDataChannels( + std::function channel)> func) { + std::vector> locked; + { + std::shared_lock lock(mDataChannelsMutex); // read-only + locked.reserve(mDataChannels.size()); + for(auto it = mDataChannels.begin(); it != mDataChannels.end(); ++it) { + auto channel = it->second.lock(); + if (channel && !channel->isClosed()) + locked.push_back(std::move(channel)); + } + } + + for (auto &channel : locked) { + try { + func(std::move(channel)); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } + } +} + +void PeerConnection::openDataChannels() { + if (auto transport = std::atomic_load(&mSctpTransport)) + iterateDataChannels([&](shared_ptr channel) { + if (!channel->isOpen()) + channel->open(transport); + }); +} + +void PeerConnection::closeDataChannels() { + iterateDataChannels([&](shared_ptr channel) { channel->close(); }); +} + +void PeerConnection::remoteCloseDataChannels() { + iterateDataChannels([&](shared_ptr channel) { channel->remoteClose(); }); +} + +shared_ptr PeerConnection::emplaceTrack(Description::Media description) { + std::unique_lock lock(mTracksMutex); // we are going to emplace + +#if !RTC_ENABLE_MEDIA + // No media support, mark as removed + PLOG_WARNING << "Tracks are disabled (not compiled with media support)"; + description.markRemoved(); +#endif + + shared_ptr track; + if (auto it = mTracks.find(description.mid()); it != mTracks.end()) + if (auto t = it->second.lock(); t && !t->isClosed()) + track = std::move(t); + + if (track) { + track->setDescription(std::move(description)); + } else { + track = std::make_shared(weak_from_this(), std::move(description)); + mTracks.emplace(std::make_pair(track->mid(), track)); + mTrackLines.emplace_back(track); + } + + auto handler = getMediaHandler(); + if (handler) + handler->media(track->description()); + + if (track->description().isRemoved()) + track->close(); + + return track; +} + +void PeerConnection::iterateTracks(std::function track)> func) { + std::vector> locked; + { + std::shared_lock lock(mTracksMutex); // read-only + locked.reserve(mTrackLines.size()); + for(auto it = mTrackLines.begin(); it != mTrackLines.end(); ++it) { + auto track = it->lock(); + if (track && !track->isClosed()) + locked.push_back(std::move(track)); + } + } + + for (auto &track : locked) { + try { + func(std::move(track)); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } + } +} + +void PeerConnection::openTracks() { +#if RTC_ENABLE_MEDIA + if (auto transport = std::atomic_load(&mDtlsTransport)) { + auto srtpTransport = std::dynamic_pointer_cast(transport); + + iterateTracks([&](const shared_ptr &track) { + if (!track->isOpen()) { + if (srtpTransport) { + track->open(srtpTransport); + } else { + // A track was added during a latter renegotiation, whereas SRTP transport was + // not initialized. This is an optimization to use the library with data + // channels only. Set forceMediaTransport to true to initialize the transport + // before dynamically adding tracks. + auto errorMsg = "The connection has no media transport"; + PLOG_ERROR << errorMsg; + track->triggerError(errorMsg); + } + } + }); + } +#endif +} + +void PeerConnection::closeTracks() { + std::shared_lock lock(mTracksMutex); // read-only + iterateTracks([&](shared_ptr track) { track->close(); }); +} + +void PeerConnection::validateRemoteDescription(const Description &description) { + if (!description.iceUfrag()) + throw std::invalid_argument("Remote description has no ICE user fragment"); + + if (!description.icePwd()) + throw std::invalid_argument("Remote description has no ICE password"); + + if (!description.fingerprint()) + throw std::invalid_argument("Remote description has no valid fingerprint"); + + if (description.mediaCount() == 0) + throw std::invalid_argument("Remote description has no media line"); + + int activeMediaCount = 0; + for (unsigned int i = 0; i < description.mediaCount(); ++i) + std::visit(rtc::overloaded{[&](const Description::Application *application) { + if (!application->isRemoved()) + ++activeMediaCount; + }, + [&](const Description::Media *media) { + if (!media->isRemoved() || + media->direction() != Description::Direction::Inactive) + ++activeMediaCount; + }}, + description.media(i)); + + if (activeMediaCount == 0) + throw std::invalid_argument("Remote description has no active media"); + + if (auto local = localDescription(); local && local->iceUfrag() && local->icePwd()) + if (*description.iceUfrag() == *local->iceUfrag() && + *description.icePwd() == *local->icePwd()) + throw std::logic_error("Got the local description as remote description"); + + PLOG_VERBOSE << "Remote description looks valid"; +} + +void PeerConnection::processLocalDescription(Description description) { + const uint16_t localSctpPort = DEFAULT_SCTP_PORT; + const size_t localMaxMessageSize = + config.maxMessageSize.value_or(DEFAULT_LOCAL_MAX_MESSAGE_SIZE); + + // Clean up the application entry the ICE transport might have added already (libnice) + description.clearMedia(); + + if (auto remote = remoteDescription()) { + // Reciprocate remote description + for (unsigned int i = 0; i < remote->mediaCount(); ++i) + std::visit( // reciprocate each media + rtc::overloaded{ + [&](Description::Application *remoteApp) { + std::shared_lock lock(mDataChannelsMutex); + if (!mDataChannels.empty() || !mUnassignedDataChannels.empty()) { + // Prefer local description + Description::Application app(remoteApp->mid()); + app.setSctpPort(localSctpPort); + app.setMaxMessageSize(localMaxMessageSize); + + PLOG_DEBUG << "Adding application to local description, mid=\"" + << app.mid() << "\""; + + description.addMedia(std::move(app)); + return; + } + + auto reciprocated = remoteApp->reciprocate(); + reciprocated.hintSctpPort(localSctpPort); + reciprocated.setMaxMessageSize(localMaxMessageSize); + + PLOG_DEBUG << "Reciprocating application in local description, mid=\"" + << reciprocated.mid() << "\""; + + description.addMedia(std::move(reciprocated)); + }, + [&](Description::Media *remoteMedia) { + std::unique_lock lock(mTracksMutex); // we may emplace a track + if (auto it = mTracks.find(remoteMedia->mid()); it != mTracks.end()) { + // Prefer local description + if (auto track = it->second.lock()) { + auto media = track->description(); + + PLOG_DEBUG << "Adding media to local description, mid=\"" + << media.mid() << "\", removed=" << std::boolalpha + << media.isRemoved(); + + description.addMedia(std::move(media)); + + } else { + auto reciprocated = remoteMedia->reciprocate(); + reciprocated.markRemoved(); + + PLOG_DEBUG << "Adding media to local description, mid=\"" + << reciprocated.mid() + << "\", removed=true (track is destroyed)"; + + description.addMedia(std::move(reciprocated)); + } + return; + } + + auto reciprocated = remoteMedia->reciprocate(); +#if !RTC_ENABLE_MEDIA + if (!reciprocated.isRemoved()) { + // No media support, mark as removed + PLOG_WARNING << "Rejecting track (not compiled with media support)"; + reciprocated.markRemoved(); + } +#endif + + PLOG_DEBUG << "Reciprocating media in local description, mid=\"" + << reciprocated.mid() << "\", removed=" << std::boolalpha + << reciprocated.isRemoved(); + + // Create incoming track + auto track = + std::make_shared(weak_from_this(), std::move(reciprocated)); + mTracks.emplace(std::make_pair(track->mid(), track)); + mTrackLines.emplace_back(track); + triggerTrack(track); // The user may modify the track description + + auto handler = getMediaHandler(); + if (handler) + handler->media(track->description()); + + if (track->description().isRemoved()) + track->close(); + + description.addMedia(track->description()); + }, + }, + remote->media(i)); + + // We need to update the SSRC cache for newly-created incoming tracks + updateTrackSsrcCache(*remote); + } + + if (description.type() == Description::Type::Offer) { + // This is an offer, add locally created data channels and tracks + // Add media for local tracks + std::shared_lock lock(mTracksMutex); + for (auto it = mTrackLines.begin(); it != mTrackLines.end(); ++it) { + if (auto track = it->lock()) { + if (description.hasMid(track->mid())) + continue; + + auto media = track->description(); + + PLOG_DEBUG << "Adding media to local description, mid=\"" << media.mid() + << "\", removed=" << std::boolalpha << media.isRemoved(); + + description.addMedia(std::move(media)); + } + } + + // Add application for data channels + if (!description.hasApplication()) { + std::shared_lock lock(mDataChannelsMutex); + if (!mDataChannels.empty() || !mUnassignedDataChannels.empty()) { + // Prevents mid collision with remote or local tracks + unsigned int m = 0; + while (description.hasMid(std::to_string(m))) + ++m; + + Description::Application app(std::to_string(m)); + app.setSctpPort(localSctpPort); + app.setMaxMessageSize(localMaxMessageSize); + + PLOG_DEBUG << "Adding application to local description, mid=\"" << app.mid() + << "\""; + + description.addMedia(std::move(app)); + } + } + + // There might be no media at this point if the user created a Track, deleted it, + // then called setLocalDescription(). + if (description.mediaCount() == 0) + throw std::runtime_error("No DataChannel or Track to negotiate"); + } + + // Set local fingerprint (wait for certificate if necessary) + description.setFingerprint(mCertificate.get()->fingerprint()); + + PLOG_VERBOSE << "Issuing local description: " << description; + + if (description.mediaCount() == 0) + throw std::logic_error("Local description has no media line"); + + updateTrackSsrcCache(description); + + { + // Set as local description + std::lock_guard lock(mLocalDescriptionMutex); + + std::vector existingCandidates; + if (mLocalDescription) { + existingCandidates = mLocalDescription->extractCandidates(); + mCurrentLocalDescription.emplace(std::move(*mLocalDescription)); + } + + mLocalDescription.emplace(description); + mLocalDescription->addCandidates(std::move(existingCandidates)); + } + + mProcessor.enqueue(&PeerConnection::trigger, shared_from_this(), + &localDescriptionCallback, std::move(description)); + + // Reciprocated tracks might need to be open + if (auto dtlsTransport = std::atomic_load(&mDtlsTransport); + dtlsTransport && dtlsTransport->state() == Transport::State::Connected) + mProcessor.enqueue(&PeerConnection::openTracks, shared_from_this()); +} + +void PeerConnection::processLocalCandidate(Candidate candidate) { + std::lock_guard lock(mLocalDescriptionMutex); + if (!mLocalDescription) + throw std::logic_error("Got a local candidate without local description"); + + if (config.iceTransportPolicy == TransportPolicy::Relay && + candidate.type() != Candidate::Type::Relayed) { + PLOG_VERBOSE << "Not issuing local candidate because of transport policy: " << candidate; + return; + } + + PLOG_VERBOSE << "Issuing local candidate: " << candidate; + + candidate.resolve(Candidate::ResolveMode::Simple); + mLocalDescription->addCandidate(candidate); + + mProcessor.enqueue(&PeerConnection::trigger, shared_from_this(), + &localCandidateCallback, std::move(candidate)); +} + +void PeerConnection::processRemoteDescription(Description description) { + // Update the SSRC cache for existing tracks + updateTrackSsrcCache(description); + + { + // Set as remote description + std::lock_guard lock(mRemoteDescriptionMutex); + + std::vector existingCandidates; + if (mRemoteDescription) + existingCandidates = mRemoteDescription->extractCandidates(); + + mRemoteDescription.emplace(description); + mRemoteDescription->addCandidates(std::move(existingCandidates)); + } + + if (description.hasApplication()) { + auto dtlsTransport = std::atomic_load(&mDtlsTransport); + auto sctpTransport = std::atomic_load(&mSctpTransport); + if (!sctpTransport && dtlsTransport && + dtlsTransport->state() == Transport::State::Connected) + initSctpTransport(); + } else { + mProcessor.enqueue(&PeerConnection::remoteCloseDataChannels, shared_from_this()); + } +} + +void PeerConnection::processRemoteCandidate(Candidate candidate) { + auto iceTransport = std::atomic_load(&mIceTransport); + { + // Set as remote candidate + std::lock_guard lock(mRemoteDescriptionMutex); + if (!mRemoteDescription) + throw std::logic_error("Got a remote candidate without remote description"); + + if (!iceTransport) + throw std::logic_error("Got a remote candidate without ICE transport"); + + candidate.hintMid(mRemoteDescription->bundleMid()); + + if (mRemoteDescription->hasCandidate(candidate)) + return; // already in description, ignore + + candidate.resolve(Candidate::ResolveMode::Simple); + mRemoteDescription->addCandidate(candidate); + } + + if (candidate.isResolved()) { + iceTransport->addRemoteCandidate(std::move(candidate)); + } else { + // We might need a lookup, do it asynchronously + // We don't use the thread pool because we have no control on the timeout + if ((iceTransport = std::atomic_load(&mIceTransport))) { + weak_ptr weakIceTransport{iceTransport}; + std::thread t([weakIceTransport, candidate = std::move(candidate)]() mutable { + utils::this_thread::set_name("RTC resolver"); + if (candidate.resolve(Candidate::ResolveMode::Lookup)) + if (auto iceTransport = weakIceTransport.lock()) + iceTransport->addRemoteCandidate(std::move(candidate)); + }); + t.detach(); + } + } +} + +string PeerConnection::localBundleMid() const { + std::lock_guard lock(mLocalDescriptionMutex); + return mLocalDescription ? mLocalDescription->bundleMid() : "0"; +} + +void PeerConnection::setMediaHandler(shared_ptr handler) { + std::unique_lock lock(mMediaHandlerMutex); + mMediaHandler = handler; +} + +shared_ptr PeerConnection::getMediaHandler() { + std::shared_lock lock(mMediaHandlerMutex); + return mMediaHandler; +} + +void PeerConnection::triggerDataChannel(weak_ptr weakDataChannel) { + auto dataChannel = weakDataChannel.lock(); + if (dataChannel) { + dataChannel->resetOpenCallback(); // might be set internally + mPendingDataChannels.push(std::move(dataChannel)); + } + triggerPendingDataChannels(); +} + +void PeerConnection::triggerTrack(weak_ptr weakTrack) { + auto track = weakTrack.lock(); + if (track) { + track->resetOpenCallback(); // might be set internally + mPendingTracks.push(std::move(track)); + } + triggerPendingTracks(); +} + +void PeerConnection::triggerPendingDataChannels() { + while (dataChannelCallback) { + auto next = mPendingDataChannels.pop(); + if (!next) + break; + + auto impl = std::move(*next); + + try { + dataChannelCallback(std::make_shared(impl)); + } catch (const std::exception &e) { + PLOG_WARNING << "Uncaught exception in callback: " << e.what(); + } + + impl->triggerOpen(); + } +} + +void PeerConnection::triggerPendingTracks() { + while (trackCallback) { + auto next = mPendingTracks.pop(); + if (!next) + break; + + auto impl = std::move(*next); + + try { + trackCallback(std::make_shared(impl)); + } catch (const std::exception &e) { + PLOG_WARNING << "Uncaught exception in callback: " << e.what(); + } + + // Do not trigger open immediately for tracks as it'll be done later + } +} + +void PeerConnection::flushPendingDataChannels() { + mProcessor.enqueue(&PeerConnection::triggerPendingDataChannels, shared_from_this()); +} + +void PeerConnection::flushPendingTracks() { + mProcessor.enqueue(&PeerConnection::triggerPendingTracks, shared_from_this()); +} + +bool PeerConnection::changeState(State newState) { + State current; + do { + current = state.load(); + if (current == State::Closed) + return false; + if (current == newState) + return false; + + } while (!state.compare_exchange_weak(current, newState)); + + std::ostringstream s; + s << newState; + PLOG_INFO << "Changed state to " << s.str(); + + if (newState == State::Closed) { + auto callback = std::move(stateChangeCallback); // steal the callback + callback(State::Closed); // call it synchronously + } else { + mProcessor.enqueue(&PeerConnection::trigger, shared_from_this(), + &stateChangeCallback, newState); + } + return true; +} + +bool PeerConnection::changeIceState(IceState newState) { + if (iceState.exchange(newState) == newState) + return false; + + std::ostringstream s; + s << newState; + PLOG_INFO << "Changed ICE state to " << s.str(); + + if (newState == IceState::Closed) { + auto callback = std::move(iceStateChangeCallback); // steal the callback + callback(IceState::Closed); // call it synchronously + } else { + mProcessor.enqueue(&PeerConnection::trigger, shared_from_this(), + &iceStateChangeCallback, newState); + } + return true; +} + +bool PeerConnection::changeGatheringState(GatheringState newState) { + if (gatheringState.exchange(newState) == newState) + return false; + + std::ostringstream s; + s << newState; + PLOG_INFO << "Changed gathering state to " << s.str(); + mProcessor.enqueue(&PeerConnection::trigger, shared_from_this(), + &gatheringStateChangeCallback, newState); + + return true; +} + +bool PeerConnection::changeSignalingState(SignalingState newState) { + if (signalingState.exchange(newState) == newState) + return false; + + std::ostringstream s; + s << newState; + PLOG_INFO << "Changed signaling state to " << s.str(); + mProcessor.enqueue(&PeerConnection::trigger, shared_from_this(), + &signalingStateChangeCallback, newState); + + return true; +} + +void PeerConnection::resetCallbacks() { + // Unregister all callbacks + dataChannelCallback = nullptr; + localDescriptionCallback = nullptr; + localCandidateCallback = nullptr; + stateChangeCallback = nullptr; + iceStateChangeCallback = nullptr; + gatheringStateChangeCallback = nullptr; + signalingStateChangeCallback = nullptr; + trackCallback = nullptr; +} + +void PeerConnection::updateTrackSsrcCache(const Description &description) { + std::unique_lock lock(mTracksMutex); // for safely writing to mTracksBySsrc + + // Setup SSRC -> Track mapping + for (unsigned int i = 0; i < description.mediaCount(); ++i) + std::visit( // ssrc -> track mapping + rtc::overloaded{ + [&](Description::Application const *) { return; }, + [&](Description::Media const *media) { + const auto ssrcs = media->getSSRCs(); + + // Note: We don't want to lock (or do any other lookups), if we + // already know there's no SSRCs to loop over. + if (ssrcs.size() <= 0) { + return; + } + + std::shared_ptr track{nullptr}; + if (auto it = mTracks.find(media->mid()); it != mTracks.end()) + if (auto track_for_mid = it->second.lock()) + track = track_for_mid; + + if (!track) { + // Unable to find track for MID + return; + } + + for (auto ssrc : ssrcs) { + mTracksBySsrc.insert_or_assign(ssrc, track); + } + }, + }, + description.media(i)); +} + +} // namespace rtc::impl diff --git a/datachannel/src/impl/peerconnection.hpp b/datachannel/src/impl/peerconnection.hpp new file mode 100644 index 000000000..ad6f41e28 --- /dev/null +++ b/datachannel/src/impl/peerconnection.hpp @@ -0,0 +1,164 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_PEER_CONNECTION_H +#define RTC_IMPL_PEER_CONNECTION_H + +#include "common.hpp" +#include "datachannel.hpp" +#include "dtlstransport.hpp" +#include "icetransport.hpp" +#include "init.hpp" +#include "processor.hpp" +#include "sctptransport.hpp" +#include "track.hpp" + +#include "rtc/peerconnection.hpp" + +#include +#include +#include +#include + +namespace rtc::impl { + +struct PeerConnection : std::enable_shared_from_this { + using State = rtc::PeerConnection::State; + using IceState = rtc::PeerConnection::IceState; + using GatheringState = rtc::PeerConnection::GatheringState; + using SignalingState = rtc::PeerConnection::SignalingState; + + PeerConnection(Configuration config_); + ~PeerConnection(); + + void close(); + void remoteClose(); + + optional localDescription() const; + optional remoteDescription() const; + size_t remoteMaxMessageSize() const; + + shared_ptr initIceTransport(); + shared_ptr initDtlsTransport(); + shared_ptr initSctpTransport(); + shared_ptr getIceTransport() const; + shared_ptr getDtlsTransport() const; + shared_ptr getSctpTransport() const; + void closeTransports(); + + void endLocalCandidates(); + void rollbackLocalDescription(); + bool checkFingerprint(const std::string &fingerprint) const; + void forwardMessage(message_ptr message); + void forwardMedia(message_ptr message); + void forwardBufferedAmount(uint16_t stream, size_t amount); + + shared_ptr emplaceDataChannel(string label, DataChannelInit init); + std::pair, bool> findDataChannel(uint16_t stream); + bool removeDataChannel(uint16_t stream); + uint16_t maxDataChannelStream() const; + void assignDataChannels(); + void iterateDataChannels(std::function channel)> func); + void openDataChannels(); + void closeDataChannels(); + void remoteCloseDataChannels(); + + shared_ptr emplaceTrack(Description::Media description); + void iterateTracks(std::function track)> func); + void openTracks(); + void closeTracks(); + + void validateRemoteDescription(const Description &description); + void processLocalDescription(Description description); + void processLocalCandidate(Candidate candidate); + void processRemoteDescription(Description description); + void processRemoteCandidate(Candidate candidate); + string localBundleMid() const; + + void setMediaHandler(shared_ptr handler); + shared_ptr getMediaHandler(); + + void triggerDataChannel(weak_ptr weakDataChannel); + void triggerTrack(weak_ptr weakTrack); + + void triggerPendingDataChannels(); + void triggerPendingTracks(); + + void flushPendingDataChannels(); + void flushPendingTracks(); + + bool changeState(State newState); + bool changeIceState(IceState newState); + bool changeGatheringState(GatheringState newState); + bool changeSignalingState(SignalingState newState); + + void resetCallbacks(); + + // Helper method for asynchronous callback invocation + template void trigger(synchronized_callback *cb, Args... args) { + try { + (*cb)(std::move(args...)); + } catch (const std::exception &e) { + PLOG_WARNING << "Uncaught exception in callback: " << e.what(); + } + } + + const Configuration config; + std::atomic state = State::New; + std::atomic iceState = IceState::New; + std::atomic gatheringState = GatheringState::New; + std::atomic signalingState = SignalingState::Stable; + std::atomic negotiationNeeded = false; + std::atomic closing = false; + std::mutex signalingMutex; + + synchronized_callback> dataChannelCallback; + synchronized_callback localDescriptionCallback; + synchronized_callback localCandidateCallback; + synchronized_callback stateChangeCallback; + synchronized_callback iceStateChangeCallback; + synchronized_callback gatheringStateChangeCallback; + synchronized_callback signalingStateChangeCallback; + synchronized_callback> trackCallback; + +private: + void dispatchMedia(message_ptr message); + void updateTrackSsrcCache(const Description &description); + + const init_token mInitToken = Init::Instance().token(); + const future_certificate_ptr mCertificate; + + Processor mProcessor; + optional mLocalDescription, mRemoteDescription; + optional mCurrentLocalDescription; + mutable std::mutex mLocalDescriptionMutex, mRemoteDescriptionMutex; + + shared_ptr mMediaHandler; + + mutable std::shared_mutex mMediaHandlerMutex; + + shared_ptr mIceTransport; + shared_ptr mDtlsTransport; + shared_ptr mSctpTransport; + + std::unordered_map> mDataChannels; // by stream ID + std::vector> mUnassignedDataChannels; + std::shared_mutex mDataChannelsMutex; + + std::unordered_map> mTracks; // by mid + std::unordered_map> mTracksBySsrc; // by SSRC + std::vector> mTrackLines; // by SDP order + std::shared_mutex mTracksMutex; + + Queue> mPendingDataChannels; + Queue> mPendingTracks; +}; + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/pollinterrupter.cpp b/datachannel/src/impl/pollinterrupter.cpp new file mode 100644 index 000000000..0d7fb82f7 --- /dev/null +++ b/datachannel/src/impl/pollinterrupter.cpp @@ -0,0 +1,125 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "pollinterrupter.hpp" +#include "internals.hpp" + +#if RTC_ENABLE_WEBSOCKET + +#ifndef _WIN32 +#include +#include +#endif + +namespace rtc::impl { + +PollInterrupter::PollInterrupter() { +#ifdef _WIN32 + struct addrinfo *ai = NULL; + struct addrinfo hints = {}; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_DGRAM; + hints.ai_protocol = IPPROTO_UDP; + hints.ai_flags = AI_PASSIVE | AI_NUMERICSERV; + if (getaddrinfo("localhost", "0", &hints, &ai) != 0) + throw std::runtime_error("Resolution failed for localhost address"); + + try { + mSock = ::socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); + if (mSock == INVALID_SOCKET) + throw std::runtime_error("UDP socket creation failed"); + + // Set non-blocking + ctl_t nbio = 1; + ::ioctlsocket(mSock, FIONBIO, &nbio); + + // Bind + if (::bind(mSock, ai->ai_addr, (socklen_t)ai->ai_addrlen) < 0) + throw std::runtime_error("Failed to bind UDP socket"); + + // Connect to self + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + if (::getsockname(mSock, reinterpret_cast(&addr), &addrlen) < 0) + throw std::runtime_error("getsockname failed"); + + if (::connect(mSock, reinterpret_cast(&addr), addrlen) < 0) + throw std::runtime_error("Failed to connect UDP socket"); + + } catch (...) { + freeaddrinfo(ai); + if (mSock != INVALID_SOCKET) + ::closesocket(mSock); + + throw; + } + + freeaddrinfo(ai); + +#else + int pipefd[2]; + if (::pipe(pipefd) != 0) + throw std::runtime_error("Failed to create pipe"); + + ::fcntl(pipefd[0], F_SETFL, O_NONBLOCK); + ::fcntl(pipefd[1], F_SETFL, O_NONBLOCK); + mPipeOut = pipefd[1]; // read + mPipeIn = pipefd[0]; // write +#endif +} + +PollInterrupter::~PollInterrupter() { +#ifdef _WIN32 + ::closesocket(mSock); +#else + ::close(mPipeIn); + ::close(mPipeOut); +#endif +} + +void PollInterrupter::prepare(struct pollfd &pfd) { +#ifdef _WIN32 + pfd.fd = mSock; +#else + pfd.fd = mPipeIn; +#endif + pfd.events = POLLIN; +} + +void PollInterrupter::process(struct pollfd &pfd) { + if (pfd.revents & POLLIN) { +#ifdef _WIN32 + char dummy; + while (::recv(pfd.fd, &dummy, 1, 0) >= 0) { + // Ignore + } +#else + char dummy; + while (::read(pfd.fd, &dummy, 1) > 0) { + // Ignore + } +#endif + } +} + +void PollInterrupter::interrupt() { +#ifdef _WIN32 + if (::send(mSock, NULL, 0, 0) < 0 && sockerrno != SEAGAIN && sockerrno != SEWOULDBLOCK) { + PLOG_WARNING << "Writing to interrupter socket failed, errno=" << sockerrno; + } +#else + char dummy = 0; + if (::write(mPipeOut, &dummy, 1) < 0 && errno != EAGAIN && errno != EWOULDBLOCK) { + PLOG_WARNING << "Writing to interrupter pipe failed, errno=" << errno; + } +#endif +} + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/pollinterrupter.hpp b/datachannel/src/impl/pollinterrupter.hpp new file mode 100644 index 000000000..8ab948b83 --- /dev/null +++ b/datachannel/src/impl/pollinterrupter.hpp @@ -0,0 +1,44 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_POLL_INTERRUPTER_H +#define RTC_IMPL_POLL_INTERRUPTER_H + +#include "common.hpp" +#include "socket.hpp" + +#if RTC_ENABLE_WEBSOCKET + +namespace rtc::impl { + +// Utility class to interrupt poll() +class PollInterrupter final { +public: + PollInterrupter(); + ~PollInterrupter(); + + PollInterrupter(const PollInterrupter &other) = delete; + void operator=(const PollInterrupter &other) = delete; + + void prepare(struct pollfd &pfd); + void process(struct pollfd &pfd); + void interrupt(); + +private: +#ifdef _WIN32 + socket_t mSock; +#else // assume POSIX + int mPipeIn, mPipeOut; +#endif +}; + +} // namespace rtc::impl + +#endif + +#endif diff --git a/datachannel/src/impl/pollservice.cpp b/datachannel/src/impl/pollservice.cpp new file mode 100644 index 000000000..c03e6fc41 --- /dev/null +++ b/datachannel/src/impl/pollservice.cpp @@ -0,0 +1,229 @@ +/** + * Copyright (c) 2022 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "pollservice.hpp" +#include "internals.hpp" +#include "utils.hpp" + +#if RTC_ENABLE_WEBSOCKET + +#include +#include + +namespace rtc::impl { + +using namespace std::chrono_literals; +using std::chrono::duration_cast; +using std::chrono::milliseconds; + +PollService &PollService::Instance() { + static PollService *instance = new PollService; + return *instance; +} + +PollService::PollService() : mStopped(true) {} + +PollService::~PollService() {} + +void PollService::start() { + mSocks = std::make_unique(); + mInterrupter = std::make_unique(); + mStopped = false; + mThread = std::thread(&PollService::runLoop, this); +} + +void PollService::join() { + std::unique_lock lock(mMutex); + if (std::exchange(mStopped, true)) + return; + + lock.unlock(); + + mInterrupter->interrupt(); + mThread.join(); + + mSocks.reset(); + mInterrupter.reset(); +} + +void PollService::add(socket_t sock, Params params) { + assert(sock != INVALID_SOCKET); + assert(params.callback); + + std::unique_lock lock(mMutex); + PLOG_VERBOSE << "Registering socket in poll service, direction=" << params.direction; + auto until = params.timeout ? std::make_optional(clock::now() + *params.timeout) : nullopt; + assert(mSocks); + mSocks->insert_or_assign(sock, SocketEntry{std::move(params), std::move(until)}); + + assert(mInterrupter); + mInterrupter->interrupt(); +} + +void PollService::remove(socket_t sock) { + assert(sock != INVALID_SOCKET); + + std::unique_lock lock(mMutex); + PLOG_VERBOSE << "Unregistering socket in poll service"; + assert(mSocks); + mSocks->erase(sock); + + assert(mInterrupter); + mInterrupter->interrupt(); +} + +void PollService::prepare(std::vector &pfds, optional &next) { + std::unique_lock lock(mMutex); + pfds.resize(1 + mSocks->size()); + next.reset(); + + auto it = pfds.begin(); + mInterrupter->prepare(*it++); + for (const auto &[sock, entry] : *mSocks) { + it->fd = sock; + switch (entry.params.direction) { + case Direction::In: + it->events = POLLIN; + break; + case Direction::Out: + it->events = POLLOUT; + break; + default: + it->events = POLLIN | POLLOUT; + break; + } + if (entry.until) + next = next ? std::min(*next, *entry.until) : *entry.until; + + ++it; + } +} + +void PollService::process(std::vector &pfds) { + std::unique_lock lock(mMutex); + auto it = pfds.begin(); + if (it != pfds.end()) { + mInterrupter->process(*it++); + } + while (it != pfds.end()) { + socket_t sock = it->fd; + auto jt = mSocks->find(sock); + if (jt != mSocks->end()) { + try { + auto &entry = jt->second; + const auto ¶ms = entry.params; + + if (it->revents & POLLNVAL || it->revents & POLLERR || + (it->revents & POLLHUP && + !(it->events & POLLIN))) { // MacOS sets POLLHUP on connection failure + PLOG_VERBOSE << "Poll error event"; + auto callback = std::move(params.callback); + mSocks->erase(sock); + callback(Event::Error); + + } else if (it->revents & POLLIN || it->revents & POLLOUT || it->revents & POLLHUP) { + entry.until = params.timeout + ? std::make_optional(clock::now() + *params.timeout) + : nullopt; + + auto callback = params.callback; + if (it->revents & POLLIN || + it->revents & POLLHUP) { // Windows does not set POLLIN on close + PLOG_VERBOSE << "Poll in event"; + callback(Event::In); + } + if (it->revents & POLLOUT) { + PLOG_VERBOSE << "Poll out event"; + callback(Event::Out); + } + + } else if (entry.until && clock::now() >= *entry.until) { + PLOG_VERBOSE << "Poll timeout event"; + auto callback = std::move(params.callback); + mSocks->erase(sock); + callback(Event::Timeout); + } + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + mSocks->erase(sock); + } + } + + ++it; + } +} + +void PollService::runLoop() { + utils::this_thread::set_name("RTC poll"); + PLOG_DEBUG << "Poll service started"; + + try { + assert(mSocks); + std::vector pfds; + optional next; + while (!mStopped) { + prepare(pfds, next); + + int ret; + do { + int timeout; + if (next) { + auto msecs = duration_cast( + std::max(clock::duration::zero(), *next - clock::now() + 1ms)); + PLOG_VERBOSE << "Entering poll, timeout=" << msecs.count() << "ms"; + timeout = static_cast(msecs.count()); + } else { + PLOG_VERBOSE << "Entering poll"; + timeout = -1; + } + + ret = ::poll(pfds.data(), static_cast(pfds.size()), timeout); + + PLOG_VERBOSE << "Exiting poll"; + + } while (ret < 0 && (sockerrno == SEINTR || sockerrno == SEAGAIN)); + +#ifdef _WIN32 + if (ret == WSAENOTSOCK) + continue; // prepare again as the fd has been removed +#endif + if (ret < 0) + throw std::runtime_error("poll failed, errno=" + std::to_string(sockerrno)); + + process(pfds); + } + } catch (const std::exception &e) { + PLOG_FATAL << "Poll service failed: " << e.what(); + } + + PLOG_DEBUG << "Poll service stopped"; +} + +std::ostream &operator<<(std::ostream &out, PollService::Direction direction) { + const char *str; + switch (direction) { + case PollService::Direction::In: + str = "in"; + break; + case PollService::Direction::Out: + str = "out"; + break; + case PollService::Direction::Both: + str = "both"; + break; + default: + str = "unknown"; + break; + } + return out << str; +} + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/pollservice.hpp b/datachannel/src/impl/pollservice.hpp new file mode 100644 index 000000000..d0fd6e784 --- /dev/null +++ b/datachannel/src/impl/pollservice.hpp @@ -0,0 +1,82 @@ +/** + * Copyright (c) 2022 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_POLL_SERVICE_H +#define RTC_IMPL_POLL_SERVICE_H + +#include "common.hpp" +#include "internals.hpp" +#include "pollinterrupter.hpp" +#include "socket.hpp" + +#if RTC_ENABLE_WEBSOCKET + +#include +#include +#include +#include +#include +#include + +namespace rtc::impl { + +class PollService { +public: + using clock = std::chrono::steady_clock; + + static PollService &Instance(); + + PollService(const PollService &) = delete; + PollService &operator=(const PollService &) = delete; + PollService(PollService &&) = delete; + PollService &operator=(PollService &&) = delete; + + void start(); + void join(); + + enum class Direction { Both, In, Out }; + enum class Event { None, Error, Timeout, In, Out }; + + struct Params { + Direction direction; + optional timeout; + std::function callback; + }; + + void add(socket_t sock, Params params); + void remove(socket_t sock); + +private: + PollService(); + ~PollService(); + + void prepare(std::vector &pfds, optional &next); + void process(std::vector &pfds); + void runLoop(); + + struct SocketEntry { + Params params; + optional until; + }; + + using SocketMap = std::unordered_map; + unique_ptr mSocks; + unique_ptr mInterrupter; + + std::recursive_mutex mMutex; + std::thread mThread; + bool mStopped; +}; + +std::ostream &operator<<(std::ostream &out, PollService::Direction direction); + +} // namespace rtc::impl + +#endif + +#endif diff --git a/datachannel/src/impl/processor.cpp b/datachannel/src/impl/processor.cpp new file mode 100644 index 000000000..f24cc097b --- /dev/null +++ b/datachannel/src/impl/processor.cpp @@ -0,0 +1,42 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "processor.hpp" + +namespace rtc::impl { + +Processor::Processor(size_t limit) : mTasks(limit) {} + +Processor::~Processor() { join(); } + +void Processor::join() { + std::unique_lock lock(mMutex); + mCondition.wait(lock, [this]() { return !mPending && mTasks.empty(); }); +} + +void Processor::schedule() { + std::unique_lock lock(mMutex); + if (auto next = mTasks.pop()) { + ThreadPool::Instance().enqueue(std::move(*next)); + } else { + // No more tasks + mPending = false; + mCondition.notify_all(); + } +} + +TearDownProcessor &TearDownProcessor::Instance() { + static TearDownProcessor *instance = new TearDownProcessor; + return *instance; +} + +TearDownProcessor::TearDownProcessor() {} + +TearDownProcessor::~TearDownProcessor() {} + +} // namespace rtc::impl diff --git a/datachannel/src/impl/processor.hpp b/datachannel/src/impl/processor.hpp new file mode 100644 index 000000000..d26bda55f --- /dev/null +++ b/datachannel/src/impl/processor.hpp @@ -0,0 +1,76 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_PROCESSOR_H +#define RTC_IMPL_PROCESSOR_H + +#include "common.hpp" +#include "queue.hpp" +#include "threadpool.hpp" + +#include +#include +#include +#include +#include + +namespace rtc::impl { + +// Processed tasks in order by delegating them to the thread pool +class Processor { +public: + Processor(size_t limit = 0); + virtual ~Processor(); + + Processor(const Processor &) = delete; + Processor &operator=(const Processor &) = delete; + Processor(Processor &&) = delete; + Processor &operator=(Processor &&) = delete; + + void join(); + + template void enqueue(F &&f, Args &&...args) noexcept; + +private: + void schedule(); + + Queue> mTasks; + bool mPending = false; // true iff a task is pending in the thread pool + + mutable std::mutex mMutex; + std::condition_variable mCondition; +}; + +class TearDownProcessor final : public Processor { +public: + static TearDownProcessor &Instance(); + +private: + TearDownProcessor(); + ~TearDownProcessor(); +}; + +template void Processor::enqueue(F &&f, Args &&...args) noexcept { + std::unique_lock lock(mMutex); + auto bound = std::bind(std::forward(f), std::forward(args)...); + auto task = [this, bound = std::move(bound)]() mutable { + scope_guard guard(std::bind(&Processor::schedule, this)); // chain the next task + return bound(); + }; + + if (!mPending) { + ThreadPool::Instance().enqueue(std::move(task)); + mPending = true; + } else { + mTasks.push(std::move(task)); + } +} + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/queue.hpp b/datachannel/src/impl/queue.hpp new file mode 100644 index 000000000..6a5ee2205 --- /dev/null +++ b/datachannel/src/impl/queue.hpp @@ -0,0 +1,129 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_QUEUE_H +#define RTC_IMPL_QUEUE_H + +#include "common.hpp" + +#include +#include +#include +#include +#include + +namespace rtc::impl { + +template class Queue { +public: + using amount_function = std::function; + + Queue(size_t limit = 0, amount_function func = nullptr); + ~Queue(); + + void stop(); + bool running() const; + bool empty() const; + bool full() const; + size_t size() const; // elements + size_t amount() const; // amount + void push(T element); + optional pop(); + optional peek(); + optional exchange(T element); + +private: + const size_t mLimit; + size_t mAmount; + std::queue mQueue; + std::condition_variable mPushCondition; + amount_function mAmountFunction; + bool mStopping = false; + + mutable std::mutex mMutex; +}; + +template +Queue::Queue(size_t limit, amount_function func) : mLimit(limit), mAmount(0) { + mAmountFunction = func ? func : [](const T &element) -> size_t { + static_cast(element); + return 1; + }; +} + +template Queue::~Queue() { stop(); } + +template void Queue::stop() { + std::lock_guard lock(mMutex); + mStopping = true; + mPushCondition.notify_all(); +} + +template bool Queue::running() const { + std::lock_guard lock(mMutex); + return !mQueue.empty() || !mStopping; +} + +template bool Queue::empty() const { + std::lock_guard lock(mMutex); + return mQueue.empty(); +} + +template bool Queue::full() const { + std::lock_guard lock(mMutex); + return mQueue.size() >= mLimit; +} + +template size_t Queue::size() const { + std::lock_guard lock(mMutex); + return mQueue.size(); +} + +template size_t Queue::amount() const { + std::lock_guard lock(mMutex); + return mAmount; +} + +template void Queue::push(T element) { + std::unique_lock lock(mMutex); + mPushCondition.wait(lock, [this]() { return !mLimit || mQueue.size() < mLimit || mStopping; }); + if (mStopping) + return; + + mAmount += mAmountFunction(element); + mQueue.emplace(std::move(element)); +} + +template optional Queue::pop() { + std::unique_lock lock(mMutex); + if (mQueue.empty()) + return nullopt; + + mAmount -= mAmountFunction(mQueue.front()); + optional element{std::move(mQueue.front())}; + mQueue.pop(); + return element; +} + +template optional Queue::peek() { + std::unique_lock lock(mMutex); + return !mQueue.empty() ? std::make_optional(mQueue.front()) : nullopt; +} + +template optional Queue::exchange(T element) { + std::unique_lock lock(mMutex); + if (mQueue.empty()) + return nullopt; + + std::swap(mQueue.front(), element); + return std::make_optional(std::move(element)); +} + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/sctptransport.cpp b/datachannel/src/impl/sctptransport.cpp new file mode 100644 index 000000000..0946d2542 --- /dev/null +++ b/datachannel/src/impl/sctptransport.cpp @@ -0,0 +1,1005 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "sctptransport.hpp" +#include "dtlstransport.hpp" +#include "internals.hpp" +#include "logcounter.hpp" +#include "utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// RFC 8831: SCTP MUST support performing Path MTU discovery without relying on ICMP or ICMPv6 as +// specified in [RFC4821] by using probing messages specified in [RFC4820]. +// See https://www.rfc-editor.org/rfc/rfc8831.html#section-5 +// +// However, usrsctp does not implement Path MTU discovery, so we need to disable it for now. +// See https://github.com/sctplab/usrsctp/issues/205 +#define USE_PMTUD 0 + +// TODO: When Path MTU discovery is supported, it needs to be enabled with libjuice as ICE backend +// on all platforms except Mac OS where the Don't Fragment (DF) flag can't be set: +/* +#if !USE_NICE +#ifndef __APPLE__ +// libjuice enables Linux path MTU discovery or sets the DF flag +#define USE_PMTUD 1 +#else +// Setting the DF flag is not available on Mac OS +#define USE_PMTUD 0 +#endif +#else // USE_NICE == 1 +#define USE_PMTUD 0 +#endif +*/ + +using namespace std::chrono_literals; +using namespace std::chrono; + +namespace rtc::impl { + +using utils::to_uint16; +using utils::to_uint32; + +static LogCounter COUNTER_UNKNOWN_PPID(plog::warning, + "Number of SCTP packets received with an unknown PPID"); + +class SctpTransport::InstancesSet { +public: + void insert(SctpTransport *instance) { + std::unique_lock lock(mMutex); + mSet.insert(instance); + } + + void erase(SctpTransport *instance) { + std::unique_lock lock(mMutex); + mSet.erase(instance); + } + + using shared_lock = std::shared_lock; + optional lock(SctpTransport *instance) noexcept { + shared_lock lock(mMutex); + return mSet.find(instance) != mSet.end() ? std::make_optional(std::move(lock)) : nullopt; + } + +private: + std::unordered_set mSet; + std::shared_mutex mMutex; +}; + +SctpTransport::InstancesSet *SctpTransport::Instances = new InstancesSet; + +void SctpTransport::Init() { + usrsctp_init(0, SctpTransport::WriteCallback, SctpTransport::DebugCallback); + usrsctp_sysctl_set_sctp_pr_enable(1); // Enable Partial Reliability Extension (RFC 3758) + usrsctp_sysctl_set_sctp_ecn_enable(0); // Disable Explicit Congestion Notification +#ifndef SCTP_ACCEPT_ZERO_CHECKSUM + usrsctp_enable_crc32c_offload(); // We'll compute CRC32 only for outgoing packets +#endif +#ifdef SCTP_DEBUG + usrsctp_sysctl_set_sctp_debug_on(SCTP_DEBUG_ALL); +#endif +} + +void SctpTransport::SetSettings(const SctpSettings &s) { + // The send and receive window size of usrsctp is 256KiB, which is too small for realistic RTTs, + // therefore we increase it to 1MiB by default for better performance. + // See https://bugzilla.mozilla.org/show_bug.cgi?id=1051685 + usrsctp_sysctl_set_sctp_recvspace(to_uint32(s.recvBufferSize.value_or(1024 * 1024))); + usrsctp_sysctl_set_sctp_sendspace(to_uint32(s.sendBufferSize.value_or(1024 * 1024))); + + // Increase maximum chunks number on queue to 10K by default + usrsctp_sysctl_set_sctp_max_chunks_on_queue(to_uint32(s.maxChunksOnQueue.value_or(10 * 1024))); + + // Increase initial congestion window size to 10 MTUs (RFC 6928) by default + usrsctp_sysctl_set_sctp_initial_cwnd(to_uint32(s.initialCongestionWindow.value_or(10))); + + // Set max burst to 10 MTUs by default (max burst is initially 0, meaning disabled) + usrsctp_sysctl_set_sctp_max_burst_default(to_uint32(s.maxBurst.value_or(10))); + + // Use standard SCTP congestion control (RFC 4960) by default + // See https://github.com/paullouisageneau/libdatachannel/issues/354 + usrsctp_sysctl_set_sctp_default_cc_module(to_uint32(s.congestionControlModule.value_or(0))); + + // Reduce SACK delay to 20ms by default (the recommended default value from RFC 4960 is 200ms) + usrsctp_sysctl_set_sctp_delayed_sack_time_default( + to_uint32(s.delayedSackTime.value_or(20ms).count())); + + // RTO settings + // RFC 2988 recommends a 1s min RTO, which is very high, but TCP on Linux has a 200ms min RTO + usrsctp_sysctl_set_sctp_rto_min_default( + to_uint32(s.minRetransmitTimeout.value_or(200ms).count())); + // Set only 10s as max RTO instead of 60s for shorter connection timeout + usrsctp_sysctl_set_sctp_rto_max_default( + to_uint32(s.maxRetransmitTimeout.value_or(10000ms).count())); + usrsctp_sysctl_set_sctp_init_rto_max_default( + to_uint32(s.maxRetransmitTimeout.value_or(10000ms).count())); + // Still set 1s as initial RTO + usrsctp_sysctl_set_sctp_rto_initial_default( + to_uint32(s.initialRetransmitTimeout.value_or(1000ms).count())); + + // RTX settings + // 5 retransmissions instead of 8 to shorten the backoff for shorter connection timeout + auto maxRtx = to_uint32(s.maxRetransmitAttempts.value_or(5)); + usrsctp_sysctl_set_sctp_init_rtx_max_default(maxRtx); + usrsctp_sysctl_set_sctp_assoc_rtx_max_default(maxRtx); + usrsctp_sysctl_set_sctp_path_rtx_max_default(maxRtx); // single path + + // Heartbeat interval + usrsctp_sysctl_set_sctp_heartbeat_interval_default( + to_uint32(s.heartbeatInterval.value_or(10000ms).count())); +} + +void SctpTransport::Cleanup() { + while (usrsctp_finish()) + std::this_thread::sleep_for(100ms); +} + +SctpTransport::SctpTransport(shared_ptr lower, const Configuration &config, Ports ports, + message_callback recvCallback, amount_callback bufferedAmountCallback, + state_callback stateChangeCallback) + : Transport(lower, std::move(stateChangeCallback)), + mMaxMessageSize(config.maxMessageSize.value_or(DEFAULT_LOCAL_MAX_MESSAGE_SIZE)), + mPorts(std::move(ports)), mSendQueue(0, message_size_func), + mBufferedAmountCallback(std::move(bufferedAmountCallback)) { + onRecv(std::move(recvCallback)); + + PLOG_DEBUG << "Initializing SCTP transport"; + + mSock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, nullptr, nullptr, 0, nullptr); + if (!mSock) + throw std::runtime_error("Could not create SCTP socket, errno=" + std::to_string(errno)); + + usrsctp_set_upcall(mSock, &SctpTransport::UpcallCallback, this); + + if (usrsctp_set_non_blocking(mSock, 1)) + throw std::runtime_error("Unable to set non-blocking mode, errno=" + std::to_string(errno)); + + // SCTP must stop sending after the lower layer is shut down, so disable linger + struct linger sol = {}; + sol.l_onoff = 1; + sol.l_linger = 0; + if (usrsctp_setsockopt(mSock, SOL_SOCKET, SO_LINGER, &sol, sizeof(sol))) + throw std::runtime_error("Could not set socket option SO_LINGER, errno=" + + std::to_string(errno)); + + struct sctp_assoc_value av = {}; + av.assoc_id = SCTP_ALL_ASSOC; + av.assoc_value = 1; + if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_ENABLE_STREAM_RESET, &av, sizeof(av))) + throw std::runtime_error("Could not set socket option SCTP_ENABLE_STREAM_RESET, errno=" + + std::to_string(errno)); + int on = 1; + if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RECVRCVINFO, &on, sizeof(on))) + throw std::runtime_error("Could set socket option SCTP_RECVRCVINFO, errno=" + + std::to_string(errno)); + + struct sctp_event se = {}; + se.se_assoc_id = SCTP_ALL_ASSOC; + se.se_on = 1; + se.se_type = SCTP_ASSOC_CHANGE; + if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se))) + throw std::runtime_error("Could not subscribe to event SCTP_ASSOC_CHANGE, errno=" + + std::to_string(errno)); + se.se_type = SCTP_SENDER_DRY_EVENT; + if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se))) + throw std::runtime_error("Could not subscribe to event SCTP_SENDER_DRY_EVENT, errno=" + + std::to_string(errno)); + se.se_type = SCTP_STREAM_RESET_EVENT; + if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se))) + throw std::runtime_error("Could not subscribe to event SCTP_STREAM_RESET_EVENT, errno=" + + std::to_string(errno)); + + // RFC 8831 6.6. Transferring User Data on a Data Channel + // The sender SHOULD disable the Nagle algorithm (see [RFC1122) to minimize the latency + // See https://www.rfc-editor.org/rfc/rfc8831.html#section-6.6 + int nodelay = 1; + if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_NODELAY, &nodelay, sizeof(nodelay))) + throw std::runtime_error("Could not set socket option SCTP_NODELAY, errno=" + + std::to_string(errno)); + + struct sctp_paddrparams spp = {}; + // Enable SCTP heartbeats + spp.spp_flags = SPP_HB_ENABLE; + + // RFC 8261 5. DTLS considerations: + // If path MTU discovery is performed by the SCTP layer and IPv4 is used as the network-layer + // protocol, the DTLS implementation SHOULD allow the DTLS user to enforce that the + // corresponding IPv4 packet is sent with the Don't Fragment (DF) bit set. If controlling the DF + // bit is not possible (for example, due to implementation restrictions), a safe value for the + // path MTU has to be used by the SCTP stack. It is RECOMMENDED that the safe value not exceed + // 1200 bytes. + // See https://www.rfc-editor.org/rfc/rfc8261.html#section-5 +#if USE_PMTUD + if (!config.mtu.has_value()) { +#else + if (false) { +#endif + // Enable SCTP path MTU discovery + spp.spp_flags |= SPP_PMTUD_ENABLE; + PLOG_VERBOSE << "Path MTU discovery enabled"; + + } else { + // Fall back to a safe MTU value. + spp.spp_flags |= SPP_PMTUD_DISABLE; + // The MTU value provided specifies the space available for chunks in the + // packet, so we also subtract the SCTP header size. + size_t pmtu = config.mtu.value_or(DEFAULT_MTU) - 12 - 48 - 8 - 40; // SCTP/DTLS/UDP/IPv6 + spp.spp_pathmtu = to_uint32(pmtu); + PLOG_VERBOSE << "Path MTU discovery disabled, SCTP MTU set to " << pmtu; + } + + if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &spp, sizeof(spp))) + throw std::runtime_error("Could not set socket option SCTP_PEER_ADDR_PARAMS, errno=" + + std::to_string(errno)); + + // RFC 8831 6.2. SCTP Association Management + // The number of streams negotiated during SCTP association setup SHOULD be 65535, which is the + // maximum number of streams that can be negotiated during the association setup. + // See https://www.rfc-editor.org/rfc/rfc8831.html#section-6.2 + // However, usrsctp allocates tables to hold the stream states. For 65535 streams, it results in + // the waste of a few MBs for each association. Therefore, we use a lower limit to save memory. + // See https://github.com/sctplab/usrsctp/issues/121 + struct sctp_initmsg sinit = {}; + sinit.sinit_num_ostreams = MAX_SCTP_STREAMS_COUNT; + sinit.sinit_max_instreams = MAX_SCTP_STREAMS_COUNT; + if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_INITMSG, &sinit, sizeof(sinit))) + throw std::runtime_error("Could not set socket option SCTP_INITMSG, errno=" + + std::to_string(errno)); + + // Prevent fragmented interleave of messages (i.e. level 0), see RFC 6458 section 8.1.20. + // Unless the user has set the fragmentation interleave level to 0, notifications + // may also be interleaved with partially delivered messages. + int level = 0; + if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_FRAGMENT_INTERLEAVE, &level, sizeof(level))) + throw std::runtime_error("Could not disable SCTP fragmented interleave, errno=" + + std::to_string(errno)); + +#ifdef SCTP_ACCEPT_ZERO_CHECKSUM // not available in usrsctp v0.9.5.0 + // When using SCTP over DTLS, the data integrity is ensured by DTLS. Therefore, there's no + // need to check CRC32c additionally when receiving. See + // https://datatracker.ietf.org/doc/html/draft-ietf-tsvwg-sctp-zero-checksum + int edmid = SCTP_EDMID_LOWER_LAYER_DTLS; + if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_ACCEPT_ZERO_CHECKSUM, &edmid, sizeof(edmid))) + throw std::runtime_error("Could set socket option SCTP_ACCEPT_ZERO_CHECKSUM, errno=" + + std::to_string(errno)); +#endif + + int rcvBuf = 0; + socklen_t rcvBufLen = sizeof(rcvBuf); + if (usrsctp_getsockopt(mSock, SOL_SOCKET, SO_RCVBUF, &rcvBuf, &rcvBufLen)) + throw std::runtime_error("Could not get SCTP recv buffer size, errno=" + + std::to_string(errno)); + int sndBuf = 0; + socklen_t sndBufLen = sizeof(sndBuf); + if (usrsctp_getsockopt(mSock, SOL_SOCKET, SO_SNDBUF, &sndBuf, &sndBufLen)) + throw std::runtime_error("Could not get SCTP send buffer size, errno=" + + std::to_string(errno)); + + // Ensure the buffer is also large enough to accomodate the largest messages + const int minBuf = int(std::min(mMaxMessageSize, size_t(std::numeric_limits::max()))); + rcvBuf = std::max(rcvBuf, minBuf); + sndBuf = std::max(sndBuf, minBuf); + + if (usrsctp_setsockopt(mSock, SOL_SOCKET, SO_RCVBUF, &rcvBuf, sizeof(rcvBuf))) + throw std::runtime_error("Could not set SCTP recv buffer size, errno=" + + std::to_string(errno)); + + if (usrsctp_setsockopt(mSock, SOL_SOCKET, SO_SNDBUF, &sndBuf, sizeof(sndBuf))) + throw std::runtime_error("Could not set SCTP send buffer size, errno=" + + std::to_string(errno)); + + usrsctp_register_address(this); + Instances->insert(this); +} + +SctpTransport::~SctpTransport() { + PLOG_DEBUG << "Destroying SCTP transport"; + + mProcessor.join(); // if we are here, the processor must be empty + + // Before unregistering incoming() from the lower layer, we need to make sure the thread from + // lower layers is not blocked in incoming() by the WrittenOnce condition. + mWrittenOnce = true; + mWrittenCondition.notify_all(); + + unregisterIncoming(); + + usrsctp_close(mSock); + + usrsctp_deregister_address(this); + Instances->erase(this); +} + +void SctpTransport::onBufferedAmount(amount_callback callback) { + mBufferedAmountCallback = std::move(callback); +} + +void SctpTransport::start() { + registerIncoming(); + connect(); +} + +void SctpTransport::stop() { close(); } + +struct sockaddr_conn SctpTransport::getSockAddrConn(uint16_t port) { + struct sockaddr_conn sconn = {}; + sconn.sconn_family = AF_CONN; + sconn.sconn_port = htons(port); + sconn.sconn_addr = this; +#ifdef HAVE_SCONN_LEN + sconn.sconn_len = sizeof(sconn); +#endif + return sconn; +} + +void SctpTransport::connect() { + PLOG_DEBUG << "SCTP connecting (local port=" << mPorts.local + << ", remote port=" << mPorts.remote << ")"; + changeState(State::Connecting); + + auto local = getSockAddrConn(mPorts.local); + if (usrsctp_bind(mSock, reinterpret_cast(&local), sizeof(local))) + throw std::runtime_error("Could not bind usrsctp socket, errno=" + std::to_string(errno)); + + // According to RFC 8841, both endpoints must initiate the SCTP association, in a + // simultaneous-open manner, irrelevent to the SDP setup role. + // See https://www.rfc-editor.org/rfc/rfc8841.html#section-9.3 + auto remote = getSockAddrConn(mPorts.remote); + int ret = usrsctp_connect(mSock, reinterpret_cast(&remote), sizeof(remote)); + if (ret && errno != EINPROGRESS) + throw std::runtime_error("Connection attempt failed, errno=" + std::to_string(errno)); +} + +bool SctpTransport::send(message_ptr message) { + std::lock_guard lock(mSendMutex); + if (state() != State::Connected) + return false; + + if (!message) + return trySendQueue(); + + PLOG_VERBOSE << "Send size=" << message->size(); + + if (message->size() > mMaxMessageSize) + throw std::invalid_argument("Message is too large"); + + // Flush the queue, and if nothing is pending, try to send directly + if (trySendQueue() && trySendMessage(message)) + return true; + + mSendQueue.push(message); + updateBufferedAmount(to_uint16(message->stream), ptrdiff_t(message_size_func(message))); + return false; +} + +bool SctpTransport::flush() { + try { + std::lock_guard lock(mSendMutex); + if (state() != State::Connected) + return false; + + trySendQueue(); + return true; + + } catch (const std::exception &e) { + PLOG_WARNING << "SCTP flush: " << e.what(); + return false; + } +} + +void SctpTransport::closeStream(unsigned int stream) { + std::lock_guard lock(mSendMutex); + + // RFC 8831 6.7. Closing a Data Channel + // Closing of a data channel MUST be signaled by resetting the corresponding outgoing streams + // See https://www.rfc-editor.org/rfc/rfc8831.html#section-6.7 + mSendQueue.push(make_message(0, Message::Reset, to_uint16(stream))); + + // This method must not call the buffered callback synchronously + mProcessor.enqueue(&SctpTransport::flush, shared_from_this()); +} + +void SctpTransport::close() { + mSendQueue.stop(); + if (state() == State::Connected) { + mProcessor.enqueue(&SctpTransport::flush, shared_from_this()); + } else if (state() == State::Connecting) { + PLOG_DEBUG << "SCTP early shutdown"; + if (usrsctp_shutdown(mSock, SHUT_RDWR)) { + if (errno == ENOTCONN) { + PLOG_VERBOSE << "SCTP already shut down"; + } else { + PLOG_WARNING << "SCTP shutdown failed, errno=" << errno; + } + } + changeState(State::Failed); + mWrittenCondition.notify_all(); + } +} + +unsigned int SctpTransport::maxStream() const { + unsigned int streamsCount = mNegotiatedStreamsCount.value_or(MAX_SCTP_STREAMS_COUNT); + return streamsCount > 0 ? streamsCount - 1 : 0; +} + +void SctpTransport::incoming(message_ptr message) { + // There could be a race condition here where we receive the remote INIT before the local one is + // sent, which would result in the connection being aborted. Therefore, we need to wait for data + // to be sent on our side (i.e. the local INIT) before proceeding. + if (!mWrittenOnce) { // test the atomic boolean is not set first to prevent a lock contention + std::unique_lock lock(mWriteMutex); + mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || state() == State::Failed; }); + } + + if (state() == State::Failed) + return; + + if (!message) { + PLOG_INFO << "SCTP disconnected"; + changeState(State::Disconnected); + recv(nullptr); + return; + } + + PLOG_VERBOSE << "Incoming size=" << message->size(); + + usrsctp_conninput(this, message->data(), message->size(), 0); +} + +bool SctpTransport::outgoing(message_ptr message) { + // Set recommended medium-priority DSCP value + // See https://www.rfc-editor.org/rfc/rfc8837.html#section-5 + message->dscp = 10; // AF11: Assured Forwarding class 1, low drop probability + return Transport::outgoing(std::move(message)); +} + +void SctpTransport::doRecv() { + std::lock_guard lock(mRecvMutex); + --mPendingRecvCount; + try { + while (state() != State::Disconnected && state() != State::Failed) { + const size_t bufferSize = 65536; + byte buffer[bufferSize]; + socklen_t fromlen = 0; + struct sctp_rcvinfo info = {}; + socklen_t infolen = sizeof(info); + unsigned int infotype = 0; + int flags = 0; + ssize_t len = usrsctp_recvv(mSock, buffer, bufferSize, nullptr, &fromlen, &info, + &infolen, &infotype, &flags); + if (len < 0) { + if (errno == EWOULDBLOCK || errno == EAGAIN || errno == ECONNRESET) + break; + else + throw std::runtime_error("SCTP recv failed, errno=" + std::to_string(errno)); + } else if (len == 0) { + break; + } + + PLOG_VERBOSE << "SCTP recv, len=" << len; + + // SCTP_FRAGMENT_INTERLEAVE does not seem to work as expected for messages > 64KB, + // therefore partial notifications and messages need to be handled separately. + if (flags & MSG_NOTIFICATION) { + // SCTP event notification + mPartialNotification.insert(mPartialNotification.end(), buffer, buffer + len); + + if (flags & MSG_EOR) { + // Notification is complete, process it + binary notification; + mPartialNotification.swap(notification); + auto n = reinterpret_cast(notification.data()); + processNotification(n, notification.size()); + } + + } else { + // SCTP message + mPartialMessage.insert(mPartialMessage.end(), buffer, buffer + len); + if (mPartialMessage.size() > mMaxMessageSize) { + PLOG_WARNING << "SCTP message is too large, truncating it"; + mPartialMessage.resize(mMaxMessageSize); + } + + if (flags & MSG_EOR) { + // Message is complete, process it + binary message; + mPartialMessage.swap(message); + if (infotype != SCTP_RECVV_RCVINFO) + throw std::runtime_error("Missing SCTP recv info"); + + processData(std::move(message), info.rcv_sid, PayloadId(ntohl(info.rcv_ppid))); + } + } + } + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } +} + +void SctpTransport::doFlush() { + std::lock_guard lock(mSendMutex); + --mPendingFlushCount; + try { + trySendQueue(); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } +} + +void SctpTransport::enqueueRecv() { + if (mPendingRecvCount > 0) + return; + + if (auto shared_this = weak_from_this().lock()) { + // This is called from the upcall callback, we must not release the shared ptr here + ++mPendingRecvCount; + mProcessor.enqueue(&SctpTransport::doRecv, std::move(shared_this)); + } +} + +void SctpTransport::enqueueFlush() { + if (mPendingFlushCount > 0) + return; + + if (auto shared_this = weak_from_this().lock()) { + // This is called from the upcall callback, we must not release the shared ptr here + ++mPendingFlushCount; + mProcessor.enqueue(&SctpTransport::doFlush, std::move(shared_this)); + } +} + +bool SctpTransport::trySendQueue() { + // Requires mSendMutex to be locked + while (auto next = mSendQueue.peek()) { + message_ptr message = std::move(*next); + if (!trySendMessage(message)) + return false; + + mSendQueue.pop(); + updateBufferedAmount(to_uint16(message->stream), -ptrdiff_t(message_size_func(message))); + } + + if (!mSendQueue.running() && !std::exchange(mSendShutdown, true)) { + PLOG_DEBUG << "SCTP shutdown"; + if (usrsctp_shutdown(mSock, SHUT_WR)) { + if (errno == ENOTCONN) { + PLOG_VERBOSE << "SCTP already shut down"; + } else { + PLOG_WARNING << "SCTP shutdown failed, errno=" << errno; + changeState(State::Disconnected); + recv(nullptr); + } + } + } + + return true; +} + +bool SctpTransport::trySendMessage(message_ptr message) { + // Requires mSendMutex to be locked + if (state() != State::Connected) + return false; + + uint32_t ppid; + switch (message->type) { + case Message::String: + ppid = !message->empty() ? PPID_STRING : PPID_STRING_EMPTY; + break; + case Message::Binary: + ppid = !message->empty() ? PPID_BINARY : PPID_BINARY_EMPTY; + break; + case Message::Control: + ppid = PPID_CONTROL; + break; + case Message::Reset: + sendReset(uint16_t(message->stream)); + return true; + default: + // Ignore + return true; + } + + PLOG_VERBOSE << "SCTP try send size=" << message->size(); + + // TODO: Implement SCTP ndata specification draft when supported everywhere + // See https://datatracker.ietf.org/doc/html/draft-ietf-tsvwg-sctp-ndata-08 + + const Reliability reliability = message->reliability ? *message->reliability : Reliability(); + + struct sctp_sendv_spa spa = {}; + + // set sndinfo + spa.sendv_flags |= SCTP_SEND_SNDINFO_VALID; + spa.sendv_sndinfo.snd_sid = uint16_t(message->stream); + spa.sendv_sndinfo.snd_ppid = htonl(ppid); + spa.sendv_sndinfo.snd_flags |= SCTP_EOR; // implicit here + + // set prinfo + spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; + if (reliability.unordered) + spa.sendv_sndinfo.snd_flags |= SCTP_UNORDERED; + + if (reliability.maxPacketLifeTime) { + spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; + spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL; + spa.sendv_prinfo.pr_value = to_uint32(reliability.maxPacketLifeTime->count()); + } else if (reliability.maxRetransmits) { + spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; + spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX; + spa.sendv_prinfo.pr_value = to_uint32(*reliability.maxRetransmits); + } + // else { + // spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_NONE; + // } + // Deprecated + else switch (reliability.typeDeprecated) { + case Reliability::Type::Rexmit: + spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; + spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_RTX; + spa.sendv_prinfo.pr_value = to_uint32(std::get(reliability.rexmit)); + break; + case Reliability::Type::Timed: + spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; + spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_TTL; + spa.sendv_prinfo.pr_value = to_uint32(std::get(reliability.rexmit).count()); + break; + default: + spa.sendv_prinfo.pr_policy = SCTP_PR_SCTP_NONE; + break; + } + + ssize_t ret; + if (!message->empty()) { + ret = usrsctp_sendv(mSock, message->data(), message->size(), nullptr, 0, &spa, sizeof(spa), + SCTP_SENDV_SPA, 0); + } else { + const char zero = 0; + ret = usrsctp_sendv(mSock, &zero, 1, nullptr, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0); + } + + if (ret < 0) { + if (errno == EWOULDBLOCK || errno == EAGAIN) { + PLOG_VERBOSE << "SCTP sending not possible"; + return false; + } + + PLOG_ERROR << "SCTP sending failed, errno=" << errno; + throw std::runtime_error("Sending failed, errno=" + std::to_string(errno)); + } + + PLOG_VERBOSE << "SCTP sent size=" << message->size(); + if (message->type == Message::Binary || message->type == Message::String) + mBytesSent += message->size(); + return true; +} + +void SctpTransport::updateBufferedAmount(uint16_t streamId, ptrdiff_t delta) { + // Requires mSendMutex to be locked + + if (delta == 0) + return; + + auto it = mBufferedAmount.insert(std::make_pair(streamId, 0)).first; + size_t amount = size_t(std::max(ptrdiff_t(it->second) + delta, ptrdiff_t(0))); + if (amount == 0) + mBufferedAmount.erase(it); + else + it->second = amount; + + // Synchronously call the buffered amount callback + triggerBufferedAmount(streamId, amount); +} + +void SctpTransport::triggerBufferedAmount(uint16_t streamId, size_t amount) { + try { + mBufferedAmountCallback(streamId, amount); + } catch (const std::exception &e) { + PLOG_WARNING << "SCTP buffered amount callback: " << e.what(); + } +} + +void SctpTransport::sendReset(uint16_t streamId) { + // Requires mSendMutex to be locked + if (state() != State::Connected) + return; + + PLOG_DEBUG << "SCTP resetting stream " << streamId; + + using srs_t = struct sctp_reset_streams; + const size_t len = sizeof(srs_t) + sizeof(uint16_t); + byte buffer[len] = {}; + srs_t &srs = *reinterpret_cast(buffer); + srs.srs_flags = SCTP_STREAM_RESET_OUTGOING; + srs.srs_number_streams = 1; + srs.srs_stream_list[0] = streamId; + + mWritten = false; + if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) { + std::unique_lock lock(mWriteMutex); // locking before setsockopt might deadlock usrsctp... + mWrittenCondition.wait_for(lock, 1000ms, + [&]() { return mWritten || state() != State::Connected; }); + } else if (errno == EINVAL) { + PLOG_DEBUG << "SCTP stream " << streamId << " already reset"; + } else { + PLOG_WARNING << "SCTP reset stream " << streamId << " failed, errno=" << errno; + } +} + +void SctpTransport::handleUpcall() noexcept { + try { + PLOG_VERBOSE << "Handle upcall"; + + int events = usrsctp_get_events(mSock); + + if (events & SCTP_EVENT_READ) + enqueueRecv(); + + if (events & SCTP_EVENT_WRITE) + enqueueFlush(); + + } catch (const std::exception &e) { + PLOG_ERROR << "SCTP upcall: " << e.what(); + } +} + +int SctpTransport::handleWrite(byte *data, size_t len, uint8_t /*tos*/, + uint8_t /*set_df*/) noexcept { + try { + std::unique_lock lock(mWriteMutex); + PLOG_VERBOSE << "Handle write, len=" << len; + + if (!outgoing(make_message(data, data + len))) + return -1; + + mWritten = true; + mWrittenOnce = true; + mWrittenCondition.notify_all(); + + } catch (const std::exception &e) { + PLOG_ERROR << "SCTP write: " << e.what(); + return -1; + } + return 0; // success +} + +void SctpTransport::processData(binary &&data, uint16_t sid, PayloadId ppid) { + PLOG_VERBOSE << "Process data, size=" << data.size(); + + // RFC 8831: The usage of the PPIDs "WebRTC String Partial" and "WebRTC Binary Partial" is + // deprecated. They were used for a PPID-based fragmentation and reassembly of user messages + // belonging to reliable and ordered data channels. + // See https://www.rfc-editor.org/rfc/rfc8831.html#section-6.6 + // We handle those PPIDs at reception for compatibility reasons but shall never send them. + switch (ppid) { + case PPID_CONTROL: + recv(make_message(std::move(data), Message::Control, sid)); + break; + + case PPID_STRING_PARTIAL: // deprecated + mPartialStringData.insert(mPartialStringData.end(), data.begin(), data.end()); + mPartialStringData.resize(mMaxMessageSize); + break; + + case PPID_STRING: + if (mPartialStringData.empty()) { + mBytesReceived += data.size(); + recv(make_message(std::move(data), Message::String, sid)); + } else { + mPartialStringData.insert(mPartialStringData.end(), data.begin(), data.end()); + mPartialStringData.resize(mMaxMessageSize); + mBytesReceived += mPartialStringData.size(); + auto message = make_message(std::move(mPartialStringData), Message::String, sid); + mPartialStringData.clear(); + recv(std::move(message)); + } + break; + + case PPID_STRING_EMPTY: + recv(make_message(std::move(mPartialStringData), Message::String, sid)); + mPartialStringData.clear(); + break; + + case PPID_BINARY_PARTIAL: // deprecated + mPartialBinaryData.insert(mPartialBinaryData.end(), data.begin(), data.end()); + mPartialBinaryData.resize(mMaxMessageSize); + break; + + case PPID_BINARY: + if (mPartialBinaryData.empty()) { + mBytesReceived += data.size(); + recv(make_message(std::move(data), Message::Binary, sid)); + } else { + mPartialBinaryData.insert(mPartialBinaryData.end(), data.begin(), data.end()); + mPartialBinaryData.resize(mMaxMessageSize); + mBytesReceived += mPartialBinaryData.size(); + auto message = make_message(std::move(mPartialBinaryData), Message::Binary, sid); + mPartialBinaryData.clear(); + recv(std::move(message)); + } + break; + + case PPID_BINARY_EMPTY: + recv(make_message(std::move(mPartialBinaryData), Message::Binary, sid)); + mPartialBinaryData.clear(); + break; + + default: + // Unknown + COUNTER_UNKNOWN_PPID++; + PLOG_VERBOSE << "Unknown PPID: " << uint32_t(ppid); + return; + } +} + +void SctpTransport::processNotification(const union sctp_notification *notify, size_t len) { + if (len != size_t(notify->sn_header.sn_length)) { + PLOG_WARNING << "Unexpected notification length, expected=" << notify->sn_header.sn_length + << ", actual=" << len; + return; + } + + auto type = notify->sn_header.sn_type; + PLOG_VERBOSE << "Processing notification, type=" << type; + + switch (type) { + case SCTP_ASSOC_CHANGE: { + PLOG_VERBOSE << "SCTP association change event"; + const struct sctp_assoc_change &sac = notify->sn_assoc_change; + if (sac.sac_state == SCTP_COMM_UP) { + PLOG_DEBUG << "SCTP negotiated streams: incoming=" << sac.sac_inbound_streams + << ", outgoing=" << sac.sac_outbound_streams; + mNegotiatedStreamsCount.emplace( + std::min(sac.sac_inbound_streams, sac.sac_outbound_streams)); + + PLOG_INFO << "SCTP connected"; + changeState(State::Connected); + } else { + if (state() == State::Connected) { + PLOG_INFO << "SCTP disconnected"; + changeState(State::Disconnected); + recv(nullptr); + } else { + PLOG_ERROR << "SCTP connection failed"; + changeState(State::Failed); + } + mWrittenCondition.notify_all(); + } + break; + } + + case SCTP_SENDER_DRY_EVENT: { + PLOG_VERBOSE << "SCTP sender dry event"; + // It should not be necessary since the send callback should have been called already, + // but to be sure, let's try to send now. + flush(); + break; + } + + case SCTP_STREAM_RESET_EVENT: { + const struct sctp_stream_reset_event &reset_event = notify->sn_strreset_event; + const int count = (reset_event.strreset_length - sizeof(reset_event)) / sizeof(uint16_t); + const uint16_t flags = reset_event.strreset_flags; + + IF_PLOG(plog::verbose) { + std::ostringstream desc; + desc << "flags="; + if (flags & SCTP_STREAM_RESET_OUTGOING_SSN && flags & SCTP_STREAM_RESET_INCOMING_SSN) + desc << "outgoing|incoming"; + else if (flags & SCTP_STREAM_RESET_OUTGOING_SSN) + desc << "outgoing"; + else if (flags & SCTP_STREAM_RESET_INCOMING_SSN) + desc << "incoming"; + else + desc << "0"; + + desc << ", streams=["; + for (int i = 0; i < count; ++i) { + uint16_t streamId = reset_event.strreset_stream_list[i]; + desc << (i != 0 ? "," : "") << streamId; + } + desc << "]"; + + PLOG_VERBOSE << "SCTP reset event, " << desc.str(); + } + + // RFC 8831 6.7. Closing a Data Channel + // If one side decides to close the data channel, it resets the corresponding outgoing + // stream. When the peer sees that an incoming stream was reset, it also resets its + // corresponding outgoing stream. + // See https://www.rfc-editor.org/rfc/rfc8831.html#section-6.7 + if (flags & SCTP_STREAM_RESET_INCOMING_SSN) { + for (int i = 0; i < count; ++i) { + uint16_t streamId = reset_event.strreset_stream_list[i]; + recv(make_message(0, Message::Reset, streamId)); + } + } + break; + } + + default: + // Ignore + break; + } +} + +void SctpTransport::clearStats() { + mBytesReceived = 0; + mBytesSent = 0; +} + +size_t SctpTransport::bytesSent() { return mBytesSent; } + +size_t SctpTransport::bytesReceived() { return mBytesReceived; } + +optional SctpTransport::rtt() { + if (state() != State::Connected) + return nullopt; + + struct sctp_status status = {}; + socklen_t len = sizeof(status); + if (usrsctp_getsockopt(mSock, IPPROTO_SCTP, SCTP_STATUS, &status, &len)) + return nullopt; + + return milliseconds(status.sstat_primary.spinfo_srtt); +} + +void SctpTransport::UpcallCallback(struct socket *, void *arg, int /* flags */) { + auto *transport = static_cast(arg); + + if (auto locked = Instances->lock(transport)) + transport->handleUpcall(); +} + +int SctpTransport::WriteCallback(void *ptr, void *data, size_t len, uint8_t tos, uint8_t set_df) { + auto *transport = static_cast(ptr); + +#ifndef SCTP_ACCEPT_ZERO_CHECKSUM + // Set the CRC32 ourselves as we have enabled CRC32 offloading + if (len >= 12) { + uint32_t *checksum = reinterpret_cast(data) + 2; + *checksum = 0; + *checksum = usrsctp_crc32c(data, len); + } +#endif + + // Workaround for sctplab/usrsctp#405: Send callback is invoked on already closed socket + // https://github.com/sctplab/usrsctp/issues/405 + if (auto locked = Instances->lock(transport)) + return transport->handleWrite(static_cast(data), len, tos, set_df); + else + return -1; +} + +void SctpTransport::DebugCallback(const char *format, ...) { + const size_t bufferSize = 1024; + char buffer[bufferSize]; + va_list va; + va_start(va, format); + int len = std::vsnprintf(buffer, bufferSize, format, va); + va_end(va); + if (len <= 0) + return; + + len = std::min(len, int(bufferSize - 1)); + buffer[len - 1] = '\0'; // remove newline + + PLOG_VERBOSE << "usrsctp: " << buffer; // usrsctp debug as verbose +} + +} // namespace rtc::impl diff --git a/datachannel/src/impl/sctptransport.hpp b/datachannel/src/impl/sctptransport.hpp new file mode 100644 index 000000000..82b02a70c --- /dev/null +++ b/datachannel/src/impl/sctptransport.hpp @@ -0,0 +1,135 @@ +/** + * Copyright (c) 2019-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_SCTP_TRANSPORT_H +#define RTC_IMPL_SCTP_TRANSPORT_H + +#include "common.hpp" +#include "configuration.hpp" +#include "global.hpp" +#include "processor.hpp" +#include "queue.hpp" +#include "transport.hpp" + +#include +#include +#include +#include + +#include "usrsctp.h" + +namespace rtc::impl { + +class SctpTransport final : public Transport, public std::enable_shared_from_this { +public: + static void Init(); + static void SetSettings(const SctpSettings &s); + static void Cleanup(); + + using amount_callback = std::function; + + struct Ports { + uint16_t local = DEFAULT_SCTP_PORT; + uint16_t remote = DEFAULT_SCTP_PORT; + }; + + SctpTransport(shared_ptr lower, const Configuration &config, Ports ports, + message_callback recvCallback, amount_callback bufferedAmountCallback, + state_callback stateChangeCallback); + ~SctpTransport(); + + void onBufferedAmount(amount_callback callback); + + void start() override; + void stop() override; + bool send(message_ptr message) override; // false if buffered + bool flush(); + void closeStream(unsigned int stream); + void close(); + + unsigned int maxStream() const; + + // Stats + void clearStats(); + size_t bytesSent(); + size_t bytesReceived(); + optional rtt(); + +private: + // Order seems wrong but these are the actual values + // See https://datatracker.ietf.org/doc/html/draft-ietf-rtcweb-data-channel-13#section-8 + enum PayloadId : uint32_t { + PPID_CONTROL = 50, + PPID_STRING = 51, + PPID_BINARY_PARTIAL = 52, + PPID_BINARY = 53, + PPID_STRING_PARTIAL = 54, + PPID_STRING_EMPTY = 56, + PPID_BINARY_EMPTY = 57 + }; + + struct sockaddr_conn getSockAddrConn(uint16_t port); + + void connect(); + void shutdown(); + void incoming(message_ptr message) override; + bool outgoing(message_ptr message) override; + + void doRecv(); + void doFlush(); + void enqueueRecv(); + void enqueueFlush(); + bool trySendQueue(); + bool trySendMessage(message_ptr message); + void updateBufferedAmount(uint16_t streamId, ptrdiff_t delta); + void triggerBufferedAmount(uint16_t streamId, size_t amount); + void sendReset(uint16_t streamId); + + void handleUpcall() noexcept; + int handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df) noexcept; + + void processData(binary &&data, uint16_t streamId, PayloadId ppid); + void processNotification(const union sctp_notification *notify, size_t len); + + const size_t mMaxMessageSize; + const Ports mPorts; + struct socket *mSock; + std::optional mNegotiatedStreamsCount; + + Processor mProcessor; + std::atomic mPendingRecvCount = 0; + std::atomic mPendingFlushCount = 0; + std::mutex mRecvMutex; + std::recursive_mutex mSendMutex; // buffered amount callback is synchronous + Queue mSendQueue; + bool mSendShutdown = false; + std::map mBufferedAmount; + amount_callback mBufferedAmountCallback; + + std::mutex mWriteMutex; + std::condition_variable mWrittenCondition; + std::atomic mWritten = false; // written outside lock + std::atomic mWrittenOnce = false; // same + + binary mPartialMessage, mPartialNotification; + binary mPartialStringData, mPartialBinaryData; + + // Stats + std::atomic mBytesSent = 0, mBytesReceived = 0; + + static void UpcallCallback(struct socket *sock, void *arg, int flags); + static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df); + static void DebugCallback(const char *format, ...); + + class InstancesSet; + static InstancesSet *Instances; +}; + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/sha.cpp b/datachannel/src/impl/sha.cpp new file mode 100644 index 000000000..c93626e1d --- /dev/null +++ b/datachannel/src/impl/sha.cpp @@ -0,0 +1,74 @@ +/** + * Copyright (c) 2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "sha.hpp" + +#if RTC_ENABLE_WEBSOCKET + +#if USE_GNUTLS + +#include + +#elif USE_MBEDTLS + +#include + +#else + +#ifndef OPENSSL_API_COMPAT +#define OPENSSL_API_COMPAT 0x10100000L +#endif + +#include + +#endif + +namespace rtc::impl { + +namespace { + +binary Sha1(const byte *data, size_t size) { +#if USE_GNUTLS + + binary output(SHA1_DIGEST_SIZE); + struct sha1_ctx ctx; + sha1_init(&ctx); + sha1_update(&ctx, size, reinterpret_cast(data)); + sha1_digest(&ctx, SHA1_DIGEST_SIZE, reinterpret_cast(output.data())); + return output; + +#elif USE_MBEDTLS + + binary output(20); + mbedtls_sha1(reinterpret_cast(data), size, + reinterpret_cast(output.data())); + return output; + +#else + + binary output(SHA_DIGEST_LENGTH); + SHA_CTX ctx; + SHA1_Init(&ctx); + SHA1_Update(&ctx, data, size); + SHA1_Final(reinterpret_cast(output.data()), &ctx); + return output; + +#endif +} + +} // namespace + +binary Sha1(const binary &input) { return Sha1(input.data(), input.size()); } + +binary Sha1(const string &input) { + return Sha1(reinterpret_cast(input.data()), input.size()); +} + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/sha.hpp b/datachannel/src/impl/sha.hpp new file mode 100644 index 000000000..0ef960ba2 --- /dev/null +++ b/datachannel/src/impl/sha.hpp @@ -0,0 +1,25 @@ +/** + * Copyright (c) 2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_SHA_H +#define RTC_IMPL_SHA_H + +#if RTC_ENABLE_WEBSOCKET + +#include "common.hpp" + +namespace rtc::impl { + +binary Sha1(const binary &input); +binary Sha1(const string &input); + +} // namespace rtc::impl + +#endif + +#endif diff --git a/datachannel/src/impl/socket.hpp b/datachannel/src/impl/socket.hpp new file mode 100644 index 000000000..7ad9b50d1 --- /dev/null +++ b/datachannel/src/impl/socket.hpp @@ -0,0 +1,132 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +// This header defines types to allow cross-platform socket API usage. + +#ifndef RTC_SOCKET_H +#define RTC_SOCKET_H + +#ifdef _WIN32 + +#ifndef _WIN32_WINNT +#define _WIN32_WINNT 0x0601 // Windows 7 +#endif +#ifndef __MSVCRT_VERSION__ +#define __MSVCRT_VERSION__ 0x0601 +#endif + +#include +#include +// +#include +#include + +#ifdef __MINGW32__ +#include +#include +#ifndef IPV6_V6ONLY +#define IPV6_V6ONLY 27 +#endif +#endif + +#define NO_IFADDRS +#define NO_PMTUDISC + +typedef SOCKET socket_t; +typedef SOCKADDR sockaddr; +typedef ULONG ctl_t; +typedef DWORD sockopt_t; +#define sockerrno ((int)WSAGetLastError()) +#define IP_DONTFRAG IP_DONTFRAGMENT +#define HOST_NAME_MAX 256 + +#define poll WSAPoll +typedef ULONG nfds_t; + +#define SEADDRINUSE WSAEADDRINUSE +#define SEINTR WSAEINTR +#define SEAGAIN WSAEWOULDBLOCK +#define SEACCES WSAEACCES +#define SEWOULDBLOCK WSAEWOULDBLOCK +#define SEINPROGRESS WSAEINPROGRESS +#define SECONNREFUSED WSAECONNREFUSED +#define SECONNRESET WSAECONNRESET +#define SENETRESET WSAENETRESET + +#else // assume POSIX + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef __linux__ +#define NO_PMTUDISC +#endif + +#ifdef __ANDROID__ +#define NO_IFADDRS +#else +#include +#endif + +typedef int socket_t; +typedef int ctl_t; +typedef int sockopt_t; +#define sockerrno errno +#define INVALID_SOCKET -1 +#define ioctlsocket ioctl +#define closesocket close + +#define SEADDRINUSE EADDRINUSE +#define SEINTR EINTR +#define SEAGAIN EAGAIN +#define SEACCES EACCES +#define SEWOULDBLOCK EWOULDBLOCK +#define SEINPROGRESS EINPROGRESS +#define SECONNREFUSED ECONNREFUSED +#define SECONNRESET ECONNRESET +#define SENETRESET ENETRESET + +#endif // _WIN32 + +#ifndef IN6_IS_ADDR_LOOPBACK +#define IN6_IS_ADDR_LOOPBACK(a) \ + (((const uint32_t *)(a))[0] == 0 && ((const uint32_t *)(a))[1] == 0 && \ + ((const uint32_t *)(a))[2] == 0 && ((const uint32_t *)(a))[3] == htonl(1)) +#endif + +#ifndef IN6_IS_ADDR_LINKLOCAL +#define IN6_IS_ADDR_LINKLOCAL(a) \ + ((((const uint32_t *)(a))[0] & htonl(0xffc00000)) == htonl(0xfe800000)) +#endif + +#ifndef IN6_IS_ADDR_SITELOCAL +#define IN6_IS_ADDR_SITELOCAL(a) \ + ((((const uint32_t *)(a))[0] & htonl(0xffc00000)) == htonl(0xfec00000)) +#endif + +#ifndef IN6_IS_ADDR_V4MAPPED +#define IN6_IS_ADDR_V4MAPPED(a) \ + ((((const uint32_t *)(a))[0] == 0) && (((const uint32_t *)(a))[1] == 0) && \ + (((const uint32_t *)(a))[2] == htonl(0xFFFF))) +#endif + +#endif // JUICE_SOCKET_H diff --git a/datachannel/src/impl/tcpserver.cpp b/datachannel/src/impl/tcpserver.cpp new file mode 100644 index 000000000..25df61dcb --- /dev/null +++ b/datachannel/src/impl/tcpserver.cpp @@ -0,0 +1,190 @@ +/** + * Copyright (c) 2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "tcpserver.hpp" +#include "internals.hpp" + +#if RTC_ENABLE_WEBSOCKET + +#ifdef _WIN32 +#include +#else +#include +#include +#include +#endif + +namespace rtc::impl { + +TcpServer::TcpServer(uint16_t port, const char *bindAddress) { + PLOG_DEBUG << "Initializing TCP server"; + listen(port, bindAddress); +} + +TcpServer::~TcpServer() { close(); } + +shared_ptr TcpServer::accept() { + while (true) { + std::unique_lock lock(mSockMutex); + + if (mSock == INVALID_SOCKET) + break; + + struct pollfd pfd[2]; + mInterrupter.prepare(pfd[0]); + pfd[1].fd = mSock; + pfd[1].events = POLLIN; + + lock.unlock(); + int ret = ::poll(pfd, 2, -1); + lock.lock(); + + if (mSock == INVALID_SOCKET) + break; + + if (ret < 0) { + if (sockerrno == SEINTR || sockerrno == SEAGAIN) // interrupted + continue; + else + throw std::runtime_error("Failed to wait for socket connection"); + } + + mInterrupter.process(pfd[0]); + + if (pfd[1].revents & POLLNVAL || pfd[1].revents & POLLERR) { + throw std::runtime_error("Error while waiting for socket connection"); + } + + if (pfd[1].revents & POLLIN) { + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + socket_t incomingSock = ::accept(mSock, (struct sockaddr *)&addr, &addrlen); + + if (incomingSock != INVALID_SOCKET) { + return std::make_shared(incomingSock, nullptr); // no state callback + + } else if (sockerrno != SEAGAIN && sockerrno != SEWOULDBLOCK) { + PLOG_ERROR << "TCP server failed, errno=" << sockerrno; + throw std::runtime_error("TCP server failed"); + } + } + } + + PLOG_DEBUG << "TCP server closed"; + return nullptr; +} + +void TcpServer::close() { + std::unique_lock lock(mSockMutex); + if (mSock != INVALID_SOCKET) { + PLOG_DEBUG << "Closing TCP server socket"; + ::closesocket(mSock); + mSock = INVALID_SOCKET; + mInterrupter.interrupt(); + } +} + +void TcpServer::listen(uint16_t port, const char *bindAddress) { + PLOG_DEBUG << "Listening on port " << port; + + struct addrinfo hints = {}; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + hints.ai_flags = AI_PASSIVE | AI_NUMERICSERV; + + struct addrinfo *result = nullptr; + if (getaddrinfo(bindAddress, std::to_string(port).c_str(), &hints, &result)) + throw std::runtime_error("Resolution failed for local address"); + + try { + static const auto find_family = [](struct addrinfo *ai_list, int family) { + struct addrinfo *ai = ai_list; + while (ai && ai->ai_family != family) + ai = ai->ai_next; + return ai; + }; + + struct addrinfo *ai; + if ((ai = find_family(result, AF_INET6)) == NULL && + (ai = find_family(result, AF_INET)) == NULL) + throw std::runtime_error("No suitable address family found"); + + std::unique_lock lock(mSockMutex); + PLOG_VERBOSE << "Creating TCP server socket"; + + // Create socket + mSock = ::socket(ai->ai_family, SOCK_STREAM, IPPROTO_TCP); + if (mSock == INVALID_SOCKET) + throw std::runtime_error("TCP server socket creation failed"); + + const sockopt_t enabled = 1; + const sockopt_t disabled = 0; + + // Enable REUSEADDR + ::setsockopt(mSock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&enabled), + sizeof(enabled)); + + // Listen on both IPv6 and IPv4 + if (ai->ai_family == AF_INET6) + ::setsockopt(mSock, IPPROTO_IPV6, IPV6_V6ONLY, + reinterpret_cast(&disabled), sizeof(disabled)); + + // Set non-blocking + ctl_t nbio = 1; + if (::ioctlsocket(mSock, FIONBIO, &nbio) < 0) + throw std::runtime_error("Failed to set socket non-blocking mode"); + + // Bind socket + if (::bind(mSock, ai->ai_addr, socklen_t(ai->ai_addrlen)) < 0) { + PLOG_WARNING << "TCP server socket binding on port " << port + << " failed, errno=" << sockerrno; + throw std::runtime_error("TCP server socket binding failed"); + } + + // Listen + const int backlog = 10; + if (::listen(mSock, backlog) < 0) { + PLOG_WARNING << "TCP server socket listening failed, errno=" << sockerrno; + throw std::runtime_error("TCP server socket listening failed"); + } + + if (port != 0) { + mPort = port; + } else { + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + if (::getsockname(mSock, reinterpret_cast(&addr), &addrlen) < 0) + throw std::runtime_error("getsockname failed"); + + switch (addr.ss_family) { + case AF_INET: + mPort = ntohs(reinterpret_cast(&addr)->sin_port); + break; + case AF_INET6: + mPort = ntohs(reinterpret_cast(&addr)->sin6_port); + break; + default: + throw std::logic_error("Unknown address family"); + } + } + } catch (...) { + freeaddrinfo(result); + if (mSock != INVALID_SOCKET) { + ::closesocket(mSock); + mSock = INVALID_SOCKET; + } + throw; + } + + freeaddrinfo(result); +} + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/tcpserver.hpp b/datachannel/src/impl/tcpserver.hpp new file mode 100644 index 000000000..e16557557 --- /dev/null +++ b/datachannel/src/impl/tcpserver.hpp @@ -0,0 +1,48 @@ +/** + * Copyright (c) 2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_TCP_SERVER_H +#define RTC_IMPL_TCP_SERVER_H + +#include "common.hpp" +#include "pollinterrupter.hpp" +#include "queue.hpp" +#include "socket.hpp" +#include "tcptransport.hpp" + +#if RTC_ENABLE_WEBSOCKET + +namespace rtc::impl { + +class TcpServer final { +public: + TcpServer(uint16_t port, const char *bindAddress = nullptr); + ~TcpServer(); + + TcpServer(const TcpServer &other) = delete; + void operator=(const TcpServer &other) = delete; + + shared_ptr accept(); + void close(); + + uint16_t port() const { return mPort; } + +private: + void listen(uint16_t port, const char *bindAddress); + + uint16_t mPort; + socket_t mSock = INVALID_SOCKET; + std::mutex mSockMutex; + PollInterrupter mInterrupter; +}; + +} // namespace rtc::impl + +#endif + +#endif diff --git a/datachannel/src/impl/tcptransport.cpp b/datachannel/src/impl/tcptransport.cpp new file mode 100644 index 000000000..b3af2c209 --- /dev/null +++ b/datachannel/src/impl/tcptransport.cpp @@ -0,0 +1,473 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "tcptransport.hpp" +#include "internals.hpp" +#include "threadpool.hpp" + +#if RTC_ENABLE_WEBSOCKET + +#ifndef _WIN32 +#include +#include +#endif + +#include + +namespace rtc::impl { + +using namespace std::placeholders; +using namespace std::chrono_literals; +using std::chrono::duration_cast; +using std::chrono::milliseconds; + +namespace { + +bool unmap_inet6_v4mapped(struct sockaddr *sa, socklen_t *len) { + if (sa->sa_family != AF_INET6) + return false; + + const auto *sin6 = reinterpret_cast(sa); + if (!IN6_IS_ADDR_V4MAPPED(&sin6->sin6_addr)) + return false; + + struct sockaddr_in6 copy = *sin6; + sin6 = © + + auto *sin = reinterpret_cast(sa); + std::memset(sin, 0, sizeof(*sin)); + sin->sin_family = AF_INET; + sin->sin_port = sin6->sin6_port; + std::memcpy(&sin->sin_addr, reinterpret_cast(&sin6->sin6_addr) + 12, 4); + *len = sizeof(*sin); + return true; +} + +} + +TcpTransport::TcpTransport(string hostname, string service, state_callback callback) + : Transport(nullptr, std::move(callback)), mIsActive(true), mHostname(std::move(hostname)), + mService(std::move(service)), mSock(INVALID_SOCKET) { + + PLOG_DEBUG << "Initializing TCP transport"; +} + +TcpTransport::TcpTransport(socket_t sock, state_callback callback) + : Transport(nullptr, std::move(callback)), mIsActive(false), mSock(sock) { + + PLOG_DEBUG << "Initializing TCP transport with socket"; + + // Configure socket + configureSocket(); + + // Retrieve hostname and service + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + if (::getpeername(mSock, reinterpret_cast(&addr), &addrlen) < 0) + throw std::runtime_error("getsockname failed"); + + unmap_inet6_v4mapped(reinterpret_cast(&addr), &addrlen); + + char node[MAX_NUMERICNODE_LEN]; + char serv[MAX_NUMERICSERV_LEN]; + if (::getnameinfo(reinterpret_cast(&addr), addrlen, node, + MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN, + NI_NUMERICHOST | NI_NUMERICSERV) != 0) + throw std::runtime_error("getnameinfo failed"); + + mHostname = node; + mService = serv; +} + +TcpTransport::~TcpTransport() { close(); } + +void TcpTransport::onBufferedAmount(amount_callback callback) { + mBufferedAmountCallback = std::move(callback); +} + +void TcpTransport::setReadTimeout(std::chrono::milliseconds readTimeout) { + mReadTimeout = readTimeout; +} + +void TcpTransport::start() { + if (mSock == INVALID_SOCKET) { + connect(); + } else { + changeState(State::Connected); + setPoll(PollService::Direction::In); + } +} + +bool TcpTransport::send(message_ptr message) { + std::lock_guard lock(mSendMutex); + + if (state() != State::Connected) + throw std::runtime_error("Connection is not open"); + + if (!message || message->size() == 0) + return trySendQueue(); + + PLOG_VERBOSE << "Send size=" << message->size(); + return outgoing(message); +} + +void TcpTransport::incoming(message_ptr message) { + if (!message) + return; + + PLOG_VERBOSE << "Incoming size=" << message->size(); + recv(message); +} + +bool TcpTransport::outgoing(message_ptr message) { + // mSendMutex must be locked + // Flush the queue, and if nothing is pending, try to send directly + if (trySendQueue() && trySendMessage(message)) + return true; + + mSendQueue.push(message); + updateBufferedAmount(ptrdiff_t(message->size())); + setPoll(PollService::Direction::Both); + return false; +} + +bool TcpTransport::isActive() const { return mIsActive; } + +string TcpTransport::remoteAddress() const { return mHostname + ':' + mService; } + +void TcpTransport::connect() { + if (state() == State::Connecting) + throw std::logic_error("TCP connection is already in progress"); + + if (state() == State::Connected) + throw std::logic_error("TCP is already connected"); + + PLOG_DEBUG << "Connecting to " << mHostname << ":" << mService; + changeState(State::Connecting); + + ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::resolve, this)); +} + +void TcpTransport::resolve() { + std::lock_guard lock(mSendMutex); + mResolved.clear(); + + if (state() != State::Connecting) + return; // Cancelled + + try { + PLOG_DEBUG << "Resolving " << mHostname << ":" << mService; + + struct addrinfo hints = {}; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + hints.ai_flags = AI_ADDRCONFIG; + + struct addrinfo *result = nullptr; + if (getaddrinfo(mHostname.c_str(), mService.c_str(), &hints, &result)) + throw std::runtime_error("Resolution failed for \"" + mHostname + ":" + mService + + "\""); + + try { + struct addrinfo *ai = result; + while (ai) { + struct sockaddr_storage addr; + std::memcpy(&addr, ai->ai_addr, ai->ai_addrlen); + mResolved.emplace_back(addr, socklen_t(ai->ai_addrlen)); + ai = ai->ai_next; + } + + } catch (...) { + freeaddrinfo(result); + throw; + } + + freeaddrinfo(result); + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + changeState(State::Failed); + return; + } + + ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this)); +} + +void TcpTransport::attempt() { + std::lock_guard lock(mSendMutex); + + if (state() != State::Connecting) + return; // Cancelled + + if (mSock == INVALID_SOCKET) { + ::closesocket(mSock); + mSock = INVALID_SOCKET; + } + + if (mResolved.empty()) { + PLOG_WARNING << "Connection to " << mHostname << ":" << mService << " failed"; + changeState(State::Failed); + return; + } + + try { + auto [addr, addrlen] = mResolved.front(); + mResolved.pop_front(); + + createSocket(reinterpret_cast(&addr), addrlen); + + } catch (const std::runtime_error &e) { + PLOG_DEBUG << e.what(); + ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this)); + return; + } + + // Poll out event callback + auto callback = [this](PollService::Event event) { + try { + if (event == PollService::Event::Error) + throw std::runtime_error("TCP connection failed"); + + if (event == PollService::Event::Timeout) + throw std::runtime_error("TCP connection timed out"); + + if (event != PollService::Event::Out) + return; + + int err = 0; + socklen_t errlen = sizeof(err); + if (::getsockopt(mSock, SOL_SOCKET, SO_ERROR, reinterpret_cast(&err), + &errlen) != 0) + throw std::runtime_error("Failed to get socket error code"); + + if (err != 0) { + std::ostringstream msg; + msg << "TCP connection failed, errno=" << err; + throw std::runtime_error(msg.str()); + } + + // Success + PLOG_INFO << "TCP connected"; + changeState(State::Connected); + setPoll(PollService::Direction::In); + + } catch (const std::exception &e) { + PLOG_DEBUG << e.what(); + PollService::Instance().remove(mSock); + ThreadPool::Instance().enqueue(weak_bind(&TcpTransport::attempt, this)); + } + }; + + const auto timeout = 10s; + PollService::Instance().add(mSock, {PollService::Direction::Out, timeout, std::move(callback)}); +} + +void TcpTransport::createSocket(const struct sockaddr *addr, socklen_t addrlen) { + try { + char node[MAX_NUMERICNODE_LEN]; + char serv[MAX_NUMERICSERV_LEN]; + if (getnameinfo(addr, addrlen, node, MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN, + NI_NUMERICHOST | NI_NUMERICSERV) == 0) { + PLOG_DEBUG << "Trying address " << node << ":" << serv; + } + + PLOG_VERBOSE << "Creating TCP socket"; + + // Create socket + mSock = ::socket(addr->sa_family, SOCK_STREAM, IPPROTO_TCP); + if (mSock == INVALID_SOCKET) + throw std::runtime_error("TCP socket creation failed"); + + // Configure socket + configureSocket(); + + // Initiate connection + int ret = ::connect(mSock, addr, addrlen); + if (ret < 0 && sockerrno != SEINPROGRESS && sockerrno != SEWOULDBLOCK) { + std::ostringstream msg; + msg << "TCP connection to " << node << ":" << serv << " failed, errno=" << sockerrno; + throw std::runtime_error(msg.str()); + } + + } catch (...) { + if (mSock != INVALID_SOCKET) { + ::closesocket(mSock); + mSock = INVALID_SOCKET; + } + throw; + } +} + +void TcpTransport::configureSocket() { + // Set non-blocking + ctl_t nbio = 1; + if (::ioctlsocket(mSock, FIONBIO, &nbio) < 0) + throw std::runtime_error("Failed to set socket non-blocking mode"); + + // Disable the Nagle algorithm + int nodelay = 1; + ::setsockopt(mSock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&nodelay), + sizeof(nodelay)); + +#ifdef __APPLE__ + // MacOS lacks MSG_NOSIGNAL and requires SO_NOSIGPIPE instead + const sockopt_t enabled = 1; + if (::setsockopt(mSock, SOL_SOCKET, SO_NOSIGPIPE, &enabled, sizeof(enabled)) < 0) + throw std::runtime_error("Failed to disable SIGPIPE for socket"); +#endif +} + +void TcpTransport::setPoll(PollService::Direction direction) { + PollService::Instance().add( + mSock, {direction, direction == PollService::Direction::In ? mReadTimeout : nullopt, + std::bind(&TcpTransport::process, this, _1)}); +} + +void TcpTransport::close() { + std::lock_guard lock(mSendMutex); + if (mSock != INVALID_SOCKET) { + PLOG_DEBUG << "Closing TCP socket"; + PollService::Instance().remove(mSock); + ::closesocket(mSock); + mSock = INVALID_SOCKET; + } + changeState(State::Disconnected); +} + +bool TcpTransport::trySendQueue() { + // mSendMutex must be locked + while (auto next = mSendQueue.peek()) { + message_ptr message = std::move(*next); + size_t size = message->size(); + if (!trySendMessage(message)) { // replaces message + mSendQueue.exchange(message); + updateBufferedAmount(-ptrdiff_t(size) + ptrdiff_t(message->size())); + return false; + } + + mSendQueue.pop(); + updateBufferedAmount(-ptrdiff_t(size)); + } + + return true; +} + +bool TcpTransport::trySendMessage(message_ptr &message) { + // mSendMutex must be locked + + auto data = reinterpret_cast(message->data()); + auto size = message->size(); + while (size) { +#if defined(__APPLE__) || defined(_WIN32) + int flags = 0; +#else + int flags = MSG_NOSIGNAL; +#endif + int len = ::send(mSock, data, int(size), flags); + if (len < 0) { + if (sockerrno == SEAGAIN || sockerrno == SEWOULDBLOCK) { + message = make_message(message->end() - size, message->end()); + return false; + } else { + PLOG_ERROR << "Connection closed, errno=" << sockerrno; + throw std::runtime_error("Connection closed"); + } + } + + data += len; + size -= len; + } + message = nullptr; + return true; +} + +void TcpTransport::updateBufferedAmount(ptrdiff_t delta) { + // Requires mSendMutex to be locked + + if (delta == 0) + return; + + mBufferedAmount = size_t(std::max(ptrdiff_t(mBufferedAmount) + delta, ptrdiff_t(0))); + + // Synchronously call the buffered amount callback + triggerBufferedAmount(mBufferedAmount); +} + +void TcpTransport::triggerBufferedAmount(size_t amount) { + try { + mBufferedAmountCallback(amount); + } catch (const std::exception &e) { + PLOG_WARNING << "TCP buffered amount callback: " << e.what(); + } +} + +void TcpTransport::process(PollService::Event event) { + auto self = weak_from_this().lock(); + if (!self) + return; + + try { + switch (event) { + case PollService::Event::Error: { + PLOG_WARNING << "TCP connection terminated"; + break; + } + + case PollService::Event::Timeout: { + PLOG_VERBOSE << "TCP is idle"; + incoming(make_message(0)); + setPoll(PollService::Direction::In); + return; + } + + case PollService::Event::Out: { + if (trySendQueue()) + setPoll(PollService::Direction::In); + + return; + } + + case PollService::Event::In: { + const size_t bufferSize = 4096; + char buffer[bufferSize]; + int len; + while ((len = ::recv(mSock, buffer, bufferSize, 0)) > 0) { + auto *b = reinterpret_cast(buffer); + incoming(make_message(b, b + len)); + } + + if (len == 0) + break; // clean close + + if (sockerrno != SEAGAIN && sockerrno != SEWOULDBLOCK) { + PLOG_WARNING << "TCP connection lost"; + break; + } + + return; + } + + default: + // Ignore + return; + } + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + } + + PLOG_INFO << "TCP disconnected"; + PollService::Instance().remove(mSock); + changeState(State::Disconnected); + recv(nullptr); +} + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/tcptransport.hpp b/datachannel/src/impl/tcptransport.hpp new file mode 100644 index 000000000..02436ea82 --- /dev/null +++ b/datachannel/src/impl/tcptransport.hpp @@ -0,0 +1,80 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_TCP_TRANSPORT_H +#define RTC_IMPL_TCP_TRANSPORT_H + +#include "common.hpp" +#include "pollservice.hpp" +#include "queue.hpp" +#include "socket.hpp" +#include "transport.hpp" + +#if RTC_ENABLE_WEBSOCKET + +#include +#include +#include +#include + +namespace rtc::impl { + +class TcpTransport final : public Transport, public std::enable_shared_from_this { +public: + using amount_callback = std::function; + + TcpTransport(string hostname, string service, state_callback callback); // active + TcpTransport(socket_t sock, state_callback callback); // passive + ~TcpTransport(); + + void onBufferedAmount(amount_callback callback); + void setReadTimeout(std::chrono::milliseconds readTimeout); + + void start() override; + bool send(message_ptr message) override; + + void incoming(message_ptr message) override; + bool outgoing(message_ptr message) override; + + bool isActive() const; + string remoteAddress() const; + +private: + void connect(); + void resolve(); + void attempt(); + void createSocket(const struct sockaddr *addr, socklen_t addrlen); + void configureSocket(); + void setPoll(PollService::Direction direction); + void close(); + + bool trySendQueue(); + bool trySendMessage(message_ptr &message); + void updateBufferedAmount(ptrdiff_t delta); + void triggerBufferedAmount(size_t amount); + + void process(PollService::Event event); + + const bool mIsActive; + string mHostname, mService; + amount_callback mBufferedAmountCallback; + optional mReadTimeout; + + std::list> mResolved; + + socket_t mSock; + Queue mSendQueue; + size_t mBufferedAmount = 0; + std::mutex mSendMutex; +}; + +} // namespace rtc::impl + +#endif + +#endif diff --git a/datachannel/src/impl/threadpool.cpp b/datachannel/src/impl/threadpool.cpp new file mode 100644 index 000000000..1fda4e26e --- /dev/null +++ b/datachannel/src/impl/threadpool.cpp @@ -0,0 +1,97 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "threadpool.hpp" +#include "utils.hpp" + +namespace rtc::impl { + +ThreadPool &ThreadPool::Instance() { + static ThreadPool *instance = new ThreadPool; + return *instance; +} + +ThreadPool::ThreadPool() {} + +ThreadPool::~ThreadPool() {} + +int ThreadPool::count() const { + std::unique_lock lock(mWorkersMutex); + return int(mWorkers.size()); +} + +void ThreadPool::spawn(int count) { + std::unique_lock lock(mWorkersMutex); + while (count-- > 0) + mWorkers.emplace_back(std::bind(&ThreadPool::run, this)); +} + +void ThreadPool::join() { + { + std::unique_lock lock(mMutex); + mWaitingCondition.wait(lock, [&]() { return mBusyWorkers == 0; }); + mJoining = true; + mTasksCondition.notify_all(); + } + + std::unique_lock lock(mWorkersMutex); + for (auto &w : mWorkers) + w.join(); + + mWorkers.clear(); + + mJoining = false; +} + +void ThreadPool::clear() { + std::unique_lock lock(mMutex); + while (!mTasks.empty()) + mTasks.pop(); +} + +void ThreadPool::run() { + utils::this_thread::set_name("RTC worker"); + ++mBusyWorkers; + scope_guard guard([&]() { --mBusyWorkers; }); + while (runOne()) { + } +} + +bool ThreadPool::runOne() { + if (auto task = dequeue()) { + task(); + return true; + } + return false; +} + +std::function ThreadPool::dequeue() { + std::unique_lock lock(mMutex); + while (!mJoining) { + std::optional time; + if (!mTasks.empty()) { + time = mTasks.top().time; + if (*time <= clock::now()) { + auto func = std::move(mTasks.top().func); + mTasks.pop(); + return func; + } + } + + --mBusyWorkers; + scope_guard guard([&]() { ++mBusyWorkers; }); + mWaitingCondition.notify_all(); + if (time) + mTasksCondition.wait_until(lock, *time); + else + mTasksCondition.wait(lock); + } + return nullptr; +} + +} // namespace rtc::impl diff --git a/datachannel/src/impl/threadpool.hpp b/datachannel/src/impl/threadpool.hpp new file mode 100644 index 000000000..1cc207342 --- /dev/null +++ b/datachannel/src/impl/threadpool.hpp @@ -0,0 +1,118 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_THREADPOOL_H +#define RTC_IMPL_THREADPOOL_H + +#include "common.hpp" +#include "init.hpp" +#include "internals.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace rtc::impl { + +template +using invoke_future_t = std::future, std::decay_t...>>; + +class ThreadPool final { +public: + using clock = std::chrono::steady_clock; + + static ThreadPool &Instance(); + + ThreadPool(const ThreadPool &) = delete; + ThreadPool &operator=(const ThreadPool &) = delete; + ThreadPool(ThreadPool &&) = delete; + ThreadPool &operator=(ThreadPool &&) = delete; + + int count() const; + void spawn(int count = 1); + void join(); + void clear(); + void run(); + bool runOne(); + + template + auto enqueue(F &&f, Args &&...args) noexcept -> invoke_future_t; + + template + auto schedule(clock::duration delay, F &&f, Args &&...args) noexcept + -> invoke_future_t; + + template + auto schedule(clock::time_point time, F &&f, Args &&...args) noexcept + -> invoke_future_t; + +private: + ThreadPool(); + ~ThreadPool(); + + std::function dequeue(); // returns null function if joining + + std::vector mWorkers; + std::atomic mBusyWorkers = 0; + std::atomic mJoining = false; + + struct Task { + clock::time_point time; + std::function func; + bool operator>(const Task &other) const { return time > other.time; } + bool operator<(const Task &other) const { return time < other.time; } + }; + std::priority_queue, std::greater> mTasks; + + std::condition_variable mTasksCondition, mWaitingCondition; + mutable std::mutex mMutex, mWorkersMutex; +}; + +template +auto ThreadPool::enqueue(F &&f, Args &&...args) noexcept -> invoke_future_t { + return schedule(clock::now(), std::forward(f), std::forward(args)...); +} + +template +auto ThreadPool::schedule(clock::duration delay, F &&f, Args &&...args) noexcept + -> invoke_future_t { + return schedule(clock::now() + delay, std::forward(f), std::forward(args)...); +} + +template +auto ThreadPool::schedule(clock::time_point time, F &&f, Args &&...args) noexcept + -> invoke_future_t { + std::unique_lock lock(mMutex); + using R = std::invoke_result_t, std::decay_t...>; + auto bound = std::bind(std::forward(f), std::forward(args)...); + auto task = std::make_shared>([bound = std::move(bound)]() mutable { + try { + return bound(); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + throw; + } + }); + std::future result = task->get_future(); + + mTasks.push({time, [task = std::move(task)]() { return (*task)(); }}); + mTasksCondition.notify_one(); + return result; +} + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/tls.cpp b/datachannel/src/impl/tls.cpp new file mode 100644 index 000000000..bc9d7ba1a --- /dev/null +++ b/datachannel/src/impl/tls.cpp @@ -0,0 +1,231 @@ +/** + * Copyright (c) 2019-2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "tls.hpp" + +#include +#include + +#if USE_GNUTLS + +namespace rtc::gnutls { + +// Return false on non-fatal error +bool check(int ret, const string &message) { + if (ret < 0) { + if (!gnutls_error_is_fatal(ret)) { + return false; + } + throw std::runtime_error(message + ": " + gnutls_strerror(ret)); + } + return true; +} + +gnutls_certificate_credentials_t *new_credentials() { + auto creds = new gnutls_certificate_credentials_t; + gnutls::check(gnutls_certificate_allocate_credentials(creds)); + return creds; +} + +void free_credentials(gnutls_certificate_credentials_t *creds) { + gnutls_certificate_free_credentials(*creds); + delete creds; +} + +gnutls_x509_crt_t *new_crt() { + auto crt = new gnutls_x509_crt_t; + gnutls::check(gnutls_x509_crt_init(crt)); + return crt; +} + +void free_crt(gnutls_x509_crt_t *crt) { + gnutls_x509_crt_deinit(*crt); + delete crt; +} + +gnutls_x509_privkey_t *new_privkey() { + auto privkey = new gnutls_x509_privkey_t; + gnutls::check(gnutls_x509_privkey_init(privkey)); + return privkey; +} + +void free_privkey(gnutls_x509_privkey_t *privkey) { + gnutls_x509_privkey_deinit(*privkey); + delete privkey; +} + +gnutls_datum_t make_datum(char *data, size_t size) { + gnutls_datum_t datum; + datum.data = reinterpret_cast(data); + datum.size = size; + return datum; +} + +} // namespace rtc::gnutls + +#elif USE_MBEDTLS + +#include + +namespace { + +// Safe gmtime +int my_gmtime(const time_t *t, struct tm *buf) { +#ifdef _WIN32 + return ::gmtime_s(buf, t) == 0 ? 0 : -1; +#else // POSIX + return ::gmtime_r(t, buf) != NULL ? 0 : -1; +#endif +} + +// Format time_t as UTC +size_t my_strftme(char *buf, size_t size, const char *format, const time_t *t) { + struct tm g; + if (my_gmtime(t, &g) != 0) + return 0; + + return ::strftime(buf, size, format, &g); +} + +} // namespace + +namespace rtc::mbedtls { + +// Return false on non-fatal error +bool check(int ret, const string &message) { + if (ret < 0) { + if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE || + ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS || ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS || + ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || ret == MBEDTLS_ERR_SSL_RECEIVED_NEW_SESSION_TICKET) + return false; + + const size_t bufferSize = 1024; + char buffer[bufferSize]; + mbedtls_strerror(ret, reinterpret_cast(buffer), bufferSize); + throw std::runtime_error(message + ": " + std::string(buffer)); + } + return true; +} + +string format_time(const std::chrono::system_clock::time_point &tp) { + time_t t = std::chrono::system_clock::to_time_t(tp); + const size_t bufferSize = 256; + char buffer[bufferSize]; + if (my_strftme(buffer, bufferSize, "%Y%m%d%H%M%S", &t) == 0) + throw std::runtime_error("Time conversion failed"); + + return string(buffer); +}; + +std::shared_ptr new_pk_context() { + return std::shared_ptr{[]() { + auto p = new mbedtls_pk_context; + mbedtls_pk_init(p); + return p; + }(), + [](mbedtls_pk_context *p) { + mbedtls_pk_free(p); + delete p; + }}; +} + +std::shared_ptr new_x509_crt() { + return std::shared_ptr{[]() { + auto p = new mbedtls_x509_crt; + mbedtls_x509_crt_init(p); + return p; + }(), + [](mbedtls_x509_crt *crt) { + mbedtls_x509_crt_free(crt); + delete crt; + }}; +} + +} // namespace rtc::mbedtls + +#else // OPENSSL + +namespace rtc::openssl { + +void init() { + static std::mutex mutex; + static bool done = false; + + std::lock_guard lock(mutex); + if (!std::exchange(done, true)) { + OPENSSL_init_ssl(OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, nullptr); + OPENSSL_init_crypto(OPENSSL_INIT_LOAD_CRYPTO_STRINGS, nullptr); + } +} + +string error_string(unsigned long error) { + const size_t bufferSize = 256; + char buffer[bufferSize]; + ERR_error_string_n(error, buffer, bufferSize); + return string(buffer); +} + +bool check(int success, const string &message) { + unsigned long last_error = ERR_peek_last_error(); + ERR_clear_error(); + + if (success > 0) + return true; + + throw std::runtime_error(message + (last_error != 0 ? ": " + error_string(last_error) : "")); +} + +// Return false on recoverable error +bool check_error(int err, const string &message) { + unsigned long last_error = ERR_peek_last_error(); + ERR_clear_error(); + + if (err == SSL_ERROR_NONE) + return true; + + if (err == SSL_ERROR_ZERO_RETURN) + throw std::runtime_error(message + ": peer closed connection"); + + if (err == SSL_ERROR_SYSCALL) + throw std::runtime_error(message + ": fatal I/O error"); + + if (err == SSL_ERROR_SSL) + throw std::runtime_error(message + + (last_error != 0 ? ": " + error_string(last_error) : "")); + + // SSL_ERROR_WANT_READ and SSL_ERROR_WANT_WRITE end up here + return false; +} + +BIO *BIO_new_from_file(const string &filename) { + BIO *bio = nullptr; + try { + std::ifstream ifs(filename, std::ifstream::in | std::ifstream::binary); + if (!ifs.is_open()) + return nullptr; + + bio = BIO_new(BIO_s_mem()); + + const size_t bufferSize = 4096; + char buffer[bufferSize]; + while (ifs.good()) { + ifs.read(buffer, bufferSize); + BIO_write(bio, buffer, int(ifs.gcount())); + } + ifs.close(); + return bio; + + } catch (const std::exception &) { + BIO_free(bio); + return nullptr; + } +} + +} // namespace rtc::openssl + +#endif diff --git a/datachannel/src/impl/tls.hpp b/datachannel/src/impl/tls.hpp new file mode 100644 index 000000000..36ad6155f --- /dev/null +++ b/datachannel/src/impl/tls.hpp @@ -0,0 +1,96 @@ +/** + * Copyright (c) 2019-2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_TLS_H +#define RTC_TLS_H + +#include "common.hpp" + +#include + +#if USE_GNUTLS + +#include + +#include +#include +#include + +namespace rtc::gnutls { + +bool check(int ret, const string &message = "GnuTLS error"); + +gnutls_certificate_credentials_t *new_credentials(); +void free_credentials(gnutls_certificate_credentials_t *creds); + +gnutls_x509_crt_t *new_crt(); +void free_crt(gnutls_x509_crt_t *crt); + +gnutls_x509_privkey_t *new_privkey(); +void free_privkey(gnutls_x509_privkey_t *privkey); + +gnutls_datum_t make_datum(char *data, size_t size); + +} // namespace rtc::gnutls + +#elif USE_MBEDTLS + +#include "mbedtls/ctr_drbg.h" +#include "mbedtls/ecdsa.h" +#include "mbedtls/entropy.h" +#include "mbedtls/error.h" +#include "mbedtls/pk.h" +#include "mbedtls/rsa.h" +#include "mbedtls/sha256.h" +#include "mbedtls/ssl.h" +#include "mbedtls/x509_crt.h" + +namespace rtc::mbedtls { + +bool check(int ret, const string &message = "MbedTLS error"); + +string format_time(const std::chrono::system_clock::time_point &tp); + +std::shared_ptr new_pk_context(); +std::shared_ptr new_x509_crt(); + +} // namespace rtc::mbedtls + +#else // OPENSSL + +#ifdef _WIN32 +// Include winsock2.h header first since OpenSSL may include winsock.h +#include +#endif + +#include + +#include +#include +#include +#include + +#ifndef BIO_EOF +#define BIO_EOF -1 +#endif + +namespace rtc::openssl { + +void init(); +string error_string(unsigned long error); + +bool check(int success, const string &message = "OpenSSL error"); +bool check_error(int err, const string &message = "OpenSSL error"); + +BIO *BIO_new_from_file(const string &filename); + +} // namespace rtc::openssl + +#endif + +#endif diff --git a/datachannel/src/impl/tlstransport.cpp b/datachannel/src/impl/tlstransport.cpp new file mode 100644 index 000000000..8f66d200f --- /dev/null +++ b/datachannel/src/impl/tlstransport.cpp @@ -0,0 +1,834 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "tlstransport.hpp" +#include "httpproxytransport.hpp" +#include "tcptransport.hpp" +#include "threadpool.hpp" + +#if RTC_ENABLE_WEBSOCKET + +#include +#include +#include +#include + +using namespace std::chrono; + +namespace rtc::impl { + +void TlsTransport::enqueueRecv() { + if (mPendingRecvCount > 0) + return; + + if (auto shared_this = weak_from_this().lock()) { + ++mPendingRecvCount; + ThreadPool::Instance().enqueue(&TlsTransport::doRecv, std::move(shared_this)); + } +} + +#if USE_GNUTLS + +namespace { + +gnutls_certificate_credentials_t default_certificate_credentials() { + static std::mutex mutex; + static shared_ptr creds; + + std::lock_guard lock(mutex); + if (!creds) { + creds = shared_ptr(gnutls::new_credentials(), + gnutls::free_credentials); + gnutls::check(gnutls_certificate_set_x509_system_trust(*creds)); + } + return *creds; +} + +} // namespace + +void TlsTransport::Init() { + // Nothing to do +} + +void TlsTransport::Cleanup() { + // Nothing to do +} + +TlsTransport::TlsTransport(variant, shared_ptr> lower, + optional host, certificate_ptr certificate, + state_callback callback) + : Transport(std::visit([](auto l) { return std::static_pointer_cast(l); }, lower), + std::move(callback)), + mHost(std::move(host)), mIsClient(std::visit([](auto l) { return l->isActive(); }, lower)), + mIncomingQueue(RECV_QUEUE_LIMIT, message_size_func) { + + PLOG_DEBUG << "Initializing TLS transport (GnuTLS)"; + + unsigned int flags = GNUTLS_NONBLOCK | (mIsClient ? GNUTLS_CLIENT : GNUTLS_SERVER); + gnutls::check(gnutls_init(&mSession, flags)); + + try { + const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128"; + const char *err_pos = NULL; + gnutls::check(gnutls_priority_set_direct(mSession, priorities, &err_pos), + "Failed to set TLS priorities"); + + gnutls::check(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, + certificate ? certificate->credentials() + : default_certificate_credentials())); + + if (mIsClient && mHost) { + PLOG_VERBOSE << "Server Name Indication: " << *mHost; + gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, mHost->data(), mHost->size()); + } + + gnutls_session_set_ptr(mSession, this); + gnutls_transport_set_ptr(mSession, this); + gnutls_transport_set_push_function(mSession, WriteCallback); + gnutls_transport_set_pull_function(mSession, ReadCallback); + gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback); + + } catch (...) { + gnutls_deinit(mSession); + throw; + } +} + +TlsTransport::~TlsTransport() { + stop(); + + PLOG_DEBUG << "Destroying TLS transport"; + gnutls_deinit(mSession); +} + +void TlsTransport::start() { + PLOG_DEBUG << "Starting TLS transport"; + registerIncoming(); + changeState(State::Connecting); + enqueueRecv(); // to initiate the handshake +} + +void TlsTransport::stop() { + PLOG_DEBUG << "Stopping TLS transport"; + unregisterIncoming(); + mIncomingQueue.stop(); + enqueueRecv(); +} + +bool TlsTransport::send(message_ptr message) { + if (state() != State::Connected) + throw std::runtime_error("TLS is not open"); + + if (!message || message->size() == 0) + return outgoing(message); // pass through + + PLOG_VERBOSE << "Send size=" << message->size(); + + ssize_t ret; + do { + ret = gnutls_record_send(mSession, message->data(), message->size()); + } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN); + + if (!gnutls::check(ret)) + throw std::runtime_error("TLS send failed"); + + return mOutgoingResult; +} + +void TlsTransport::incoming(message_ptr message) { + if (!message) { + mIncomingQueue.stop(); + enqueueRecv(); + return; + } + + PLOG_VERBOSE << "Incoming size=" << message->size(); + mIncomingQueue.push(message); + enqueueRecv(); +} + +bool TlsTransport::outgoing(message_ptr message) { + bool result = Transport::outgoing(std::move(message)); + mOutgoingResult = result; + return result; +} + +void TlsTransport::postHandshake() { + // Dummy +} + +void TlsTransport::doRecv() { + std::lock_guard lock(mRecvMutex); + --mPendingRecvCount; + + const size_t bufferSize = 4096; + char buffer[bufferSize]; + + try { + // Handle handshake if connecting + if (state() == State::Connecting) { + int ret; + do { + ret = gnutls_handshake(mSession); + + if (ret == GNUTLS_E_AGAIN) + return; + + } while (!gnutls::check(ret, "Handshake failed")); // Re-call on non-fatal error + + PLOG_INFO << "TLS handshake finished"; + changeState(State::Connected); + postHandshake(); + } + + if (state() == State::Connected) { + while (true) { + ssize_t ret = gnutls_record_recv(mSession, buffer, bufferSize); + + if (ret == GNUTLS_E_AGAIN) + return; + + // Consider premature termination as remote closing + if (ret == GNUTLS_E_PREMATURE_TERMINATION) { + PLOG_DEBUG << "TLS connection terminated"; + break; + } + + if (gnutls::check(ret)) { + if (ret == 0) { + // Closed + PLOG_DEBUG << "TLS connection cleanly closed"; + break; + } + auto *b = reinterpret_cast(buffer); + recv(make_message(b, b + ret)); + } + } + } + } catch (const std::exception &e) { + PLOG_ERROR << "TLS recv: " << e.what(); + } + + gnutls_bye(mSession, GNUTLS_SHUT_WR); + + if (state() == State::Connected) { + PLOG_INFO << "TLS closed"; + changeState(State::Disconnected); + recv(nullptr); + } else { + PLOG_ERROR << "TLS handshake failed"; + changeState(State::Failed); + } +} + +ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) { + TlsTransport *t = static_cast(ptr); + try { + if (len > 0) { + auto b = reinterpret_cast(data); + t->outgoing(make_message(b, b + len)); + } + gnutls_transport_set_errno(t->mSession, 0); + return ssize_t(len); + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + gnutls_transport_set_errno(t->mSession, ECONNRESET); + return -1; + } +} + +ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) { + TlsTransport *t = static_cast(ptr); + try { + message_ptr &message = t->mIncomingMessage; + size_t &position = t->mIncomingMessagePosition; + + if (message && position >= message->size()) + message.reset(); + + if (!message) { + position = 0; + while (auto next = t->mIncomingQueue.pop()) { + message = *next; + if (message->size() > 0) + break; + else + t->recv(message); // Pass zero-sized messages through + } + } + + if (message) { + size_t available = message->size() - position; + ssize_t len = std::min(maxlen, available); + std::memcpy(data, message->data() + position, len); + position += len; + gnutls_transport_set_errno(t->mSession, 0); + return len; + } else if (t->mIncomingQueue.running()) { + gnutls_transport_set_errno(t->mSession, EAGAIN); + return -1; + } else { + // Closed + gnutls_transport_set_errno(t->mSession, 0); + return 0; + } + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + gnutls_transport_set_errno(t->mSession, ECONNRESET); + return -1; + } +} + +int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int /* ms */) { + TlsTransport *t = static_cast(ptr); + try { + message_ptr &message = t->mIncomingMessage; + size_t &position = t->mIncomingMessagePosition; + + if (message && position < message->size()) + return 1; + + return !t->mIncomingQueue.empty() ? 1 : 0; + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + return 1; + } +} + +#elif USE_MBEDTLS + +void TlsTransport::Init() { + // Nothing to do +} + +void TlsTransport::Cleanup() { + // Nothing to do +} + +TlsTransport::TlsTransport(variant, shared_ptr> lower, + optional host, certificate_ptr certificate, + state_callback callback) + : Transport(std::visit([](auto l) { return std::static_pointer_cast(l); }, lower), + std::move(callback)), + mHost(std::move(host)), mIsClient(std::visit([](auto l) { return l->isActive(); }, lower)), + mIncomingQueue(RECV_QUEUE_LIMIT, message_size_func) { + + PLOG_DEBUG << "Initializing TLS transport (MbedTLS)"; + + psa_crypto_init(); + mbedtls_entropy_init(&mEntropy); + mbedtls_ctr_drbg_init(&mDrbg); + mbedtls_ssl_init(&mSsl); + mbedtls_ssl_config_init(&mConf); + mbedtls_ctr_drbg_set_prediction_resistance(&mDrbg, MBEDTLS_CTR_DRBG_PR_ON); + + try { + mbedtls::check(mbedtls_ctr_drbg_seed(&mDrbg, mbedtls_entropy_func, &mEntropy, NULL, 0)); + + mbedtls::check(mbedtls_ssl_config_defaults( + &mConf, mIsClient ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER, + MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)); + + mbedtls_ssl_conf_max_version(&mConf, MBEDTLS_SSL_MAJOR_VERSION_3, MBEDTLS_SSL_MINOR_VERSION_3); // TLS 1.2 + mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_OPTIONAL); + mbedtls_ssl_conf_rng(&mConf, mbedtls_ctr_drbg_random, &mDrbg); + + if (certificate) { + auto [crt, pk] = certificate->credentials(); + mbedtls::check(mbedtls_ssl_conf_own_cert(&mConf, crt.get(), pk.get())); + } + + if (mIsClient && mHost) { + PLOG_VERBOSE << "Server Name Indication: " << *mHost; + mbedtls_ssl_set_hostname(&mSsl, mHost->c_str()); + } + + mbedtls::check(mbedtls_ssl_setup(&mSsl, &mConf)); + mbedtls_ssl_set_bio(&mSsl, static_cast(this), WriteCallback, ReadCallback, NULL); + + } catch (...) { + mbedtls_entropy_free(&mEntropy); + mbedtls_ctr_drbg_free(&mDrbg); + mbedtls_ssl_free(&mSsl); + mbedtls_ssl_config_free(&mConf); + throw; + } +} + +TlsTransport::~TlsTransport() { + stop(); + + PLOG_DEBUG << "Destroying TLS transport"; + mbedtls_entropy_free(&mEntropy); + mbedtls_ctr_drbg_free(&mDrbg); + mbedtls_ssl_free(&mSsl); + mbedtls_ssl_config_free(&mConf); +} + +void TlsTransport::start() { + PLOG_DEBUG << "Starting TLS transport"; + registerIncoming(); + changeState(State::Connecting); + enqueueRecv(); // to initiate the handshake +} + +void TlsTransport::stop() { + PLOG_DEBUG << "Stopping TLS transport"; + unregisterIncoming(); + mIncomingQueue.stop(); + enqueueRecv(); +} + +bool TlsTransport::send(message_ptr message) { + if (state() != State::Connected) + throw std::runtime_error("TLS is not open"); + + if (!message || message->size() == 0) + return outgoing(message); // pass through + + PLOG_VERBOSE << "Send size=" << message->size(); + + int ret; + do { + std::lock_guard lock(mSslMutex); + ret = mbedtls_ssl_write(&mSsl, reinterpret_cast(message->data()), + int(message->size())); + } while (ret == MBEDTLS_ERR_SSL_WANT_WRITE); + + if (!mbedtls::check(ret)) + throw std::runtime_error("TLS send failed"); + + return mOutgoingResult; +} + +void TlsTransport::incoming(message_ptr message) { + if (!message) { + mIncomingQueue.stop(); + enqueueRecv(); + return; + } + + PLOG_VERBOSE << "Incoming size=" << message->size(); + mIncomingQueue.push(message); + enqueueRecv(); +} + +bool TlsTransport::outgoing(message_ptr message) { + bool result = Transport::outgoing(std::move(message)); + mOutgoingResult = result; + return result; +} + +void TlsTransport::postHandshake() { + // Dummy +} + +void TlsTransport::doRecv() { + std::lock_guard lock(mRecvMutex); + --mPendingRecvCount; + + if (state() != State::Connecting && state() != State::Connected) + return; + + try { + const size_t bufferSize = 4096; + char buffer[bufferSize]; + + // Handle handshake if connecting + if (state() == State::Connecting) { + while (true) { + int ret; + { + std::lock_guard lock(mSslMutex); + ret = mbedtls_ssl_handshake(&mSsl); + } + + if (ret == MBEDTLS_ERR_SSL_WANT_READ) { + return; + } + + if (mbedtls::check(ret, "Handshake failed")) { + PLOG_INFO << "TLS handshake finished"; + changeState(State::Connected); + postHandshake(); + break; + } + } + } + + if (state() == State::Connected) { + while (true) { + int ret; + { + std::lock_guard lock(mSslMutex); + ret = mbedtls_ssl_read(&mSsl, reinterpret_cast(buffer), + bufferSize); + } + + if (ret == MBEDTLS_ERR_SSL_WANT_READ) { + return; + } + + if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + PLOG_DEBUG << "TLS connection cleanly closed"; + break; + } + + if (mbedtls::check(ret)) { + if (ret == 0) { + PLOG_DEBUG << "TLS connection terminated"; + break; + } + auto *b = reinterpret_cast(buffer); + recv(make_message(b, b + ret)); + } + } + } + } catch (const std::exception &e) { + PLOG_ERROR << "TLS recv: " << e.what(); + } + + if (state() == State::Connected) { + PLOG_INFO << "TLS closed"; + changeState(State::Disconnected); + recv(nullptr); + } else { + PLOG_ERROR << "TLS handshake failed"; + changeState(State::Failed); + } +} + +int TlsTransport::WriteCallback(void *ctx, const unsigned char *buf, size_t len) { + auto *t = static_cast(ctx); + auto *b = reinterpret_cast(buf); + t->outgoing(make_message(b, b + len)); + + return int(len); +} + +int TlsTransport::ReadCallback(void *ctx, unsigned char *buf, size_t len) { + TlsTransport *t = static_cast(ctx); + try { + message_ptr &message = t->mIncomingMessage; + size_t &position = t->mIncomingMessagePosition; + + if (message && position >= message->size()) + message.reset(); + + if (!message) { + position = 0; + while (auto next = t->mIncomingQueue.pop()) { + message = *next; + if (message->size() > 0) + break; + else + t->recv(message); // Pass zero-sized messages through + } + } + + if (message) { + size_t available = message->size() - position; + size_t writeLen = std::min(len, available); + std::memcpy(buf, message->data() + position, writeLen); + position += writeLen; + return int(writeLen); + } else if (t->mIncomingQueue.running()) { + return MBEDTLS_ERR_SSL_WANT_READ; + } else { + return MBEDTLS_ERR_SSL_CONN_EOF; + } + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + return MBEDTLS_ERR_SSL_INTERNAL_ERROR; + } +} + +#else + +int TlsTransport::TransportExIndex = -1; + +void TlsTransport::Init() { + openssl::init(); + + if (TransportExIndex < 0) { + TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL); + } +} + +void TlsTransport::Cleanup() { + // Nothing to do +} + +TlsTransport::TlsTransport(variant, shared_ptr> lower, + optional host, certificate_ptr certificate, + state_callback callback) + : Transport(std::visit([](auto l) { return std::static_pointer_cast(l); }, lower), + std::move(callback)), + mHost(std::move(host)), mIsClient(std::visit([](auto l) { return l->isActive(); }, lower)), + mIncomingQueue(RECV_QUEUE_LIMIT, message_size_func) { + + PLOG_DEBUG << "Initializing TLS transport (OpenSSL)"; + + try { + if (!(mCtx = SSL_CTX_new(TLS_method()))) // version-flexible + throw std::runtime_error("Failed to create SSL context"); + + openssl::check(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"), + "Failed to set SSL priorities"); + +#if OPENSSL_VERSION_NUMBER >= 0x30000000 + openssl::check(SSL_CTX_set1_groups_list(mCtx, "P-256"), "Failed to set SSL groups"); +#else + auto ecdh = unique_ptr( + EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free); + SSL_CTX_set_tmp_ecdh(mCtx, ecdh.get()); +#endif + + if(mIsClient) { + if (!SSL_CTX_set_default_verify_paths(mCtx)) { + PLOG_WARNING << "SSL root CA certificates unavailable"; + } + } + + if (certificate) { + auto [x509, pkey] = certificate->credentials(); + SSL_CTX_use_certificate(mCtx, x509); + SSL_CTX_use_PrivateKey(mCtx, pkey); + } + + SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3 | SSL_OP_NO_RENEGOTIATION); + SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION); + SSL_CTX_set_read_ahead(mCtx, 1); + SSL_CTX_set_quiet_shutdown(mCtx, 0); // send the close_notify alert + SSL_CTX_set_info_callback(mCtx, InfoCallback); + SSL_CTX_set_verify(mCtx, SSL_VERIFY_NONE, NULL); + + if (!(mSsl = SSL_new(mCtx))) + throw std::runtime_error("Failed to create SSL instance"); + + SSL_set_ex_data(mSsl, TransportExIndex, this); + + if (mIsClient && mHost) { + SSL_set_hostflags(mSsl, 0); + openssl::check(SSL_set1_host(mSsl, mHost->c_str()), "Failed to set SSL host"); + + PLOG_VERBOSE << "Server Name Indication: " << *mHost; + SSL_set_tlsext_host_name(mSsl, mHost->c_str()); + } + + if (mIsClient) + SSL_set_connect_state(mSsl); + else + SSL_set_accept_state(mSsl); + + if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem()))) + throw std::runtime_error("Failed to create BIO"); + + BIO_set_mem_eof_return(mInBio, BIO_EOF); + BIO_set_mem_eof_return(mOutBio, BIO_EOF); + SSL_set_bio(mSsl, mInBio, mOutBio); + + } catch (...) { + if (mSsl) + SSL_free(mSsl); + if (mCtx) + SSL_CTX_free(mCtx); + throw; + } +} + +TlsTransport::~TlsTransport() { + stop(); + + PLOG_DEBUG << "Destroying TLS transport"; + SSL_free(mSsl); + SSL_CTX_free(mCtx); +} + +void TlsTransport::start() { + PLOG_DEBUG << "Starting TLS transport"; + registerIncoming(); + changeState(State::Connecting); + + // Initiate the handshake + int ret, err; + { + std::lock_guard lock(mSslMutex); + ret = SSL_do_handshake(mSsl); + err = SSL_get_error(mSsl, ret); + flushOutput(); + } + + openssl::check_error(err, "Handshake failed"); +} + +void TlsTransport::stop() { + PLOG_DEBUG << "Stopping TLS transport"; + unregisterIncoming(); + mIncomingQueue.stop(); + enqueueRecv(); +} + +bool TlsTransport::send(message_ptr message) { + if (state() != State::Connected) + throw std::runtime_error("TLS is not open"); + + if (!message || message->size() == 0) + return outgoing(message); // pass through + + PLOG_VERBOSE << "Send size=" << message->size(); + + int err; + bool result; + { + std::lock_guard lock(mSslMutex); + int ret = SSL_write(mSsl, message->data(), int(message->size())); + err = SSL_get_error(mSsl, ret); + result = flushOutput(); + } + + if (!openssl::check_error(err)) + throw std::runtime_error("TLS send failed"); + + return result; +} + +void TlsTransport::incoming(message_ptr message) { + if (!message) { + mIncomingQueue.stop(); + enqueueRecv(); + return; + } + + PLOG_VERBOSE << "Incoming size=" << message->size(); + mIncomingQueue.push(message); + enqueueRecv(); +} + +bool TlsTransport::outgoing(message_ptr message) { return Transport::outgoing(std::move(message)); } + +void TlsTransport::postHandshake() { + // Dummy +} + +void TlsTransport::doRecv() { + std::lock_guard lock(mRecvMutex); + --mPendingRecvCount; + + if (state() != State::Connecting && state() != State::Connected) + return; + + try { + const size_t bufferSize = 4096; + byte buffer[bufferSize]; + + // Read incoming messages + while (mIncomingQueue.running()) { + auto next = mIncomingQueue.pop(); + if (!next) + return; + + message_ptr message = std::move(*next); + if (message->size() > 0) + BIO_write(mInBio, message->data(), int(message->size())); // Input + else + recv(message); // Pass zero-sized messages through + + if (state() == State::Connecting) { + // Continue the handshake + int ret, err; + { + std::lock_guard lock(mSslMutex); + ret = SSL_do_handshake(mSsl); + err = SSL_get_error(mSsl, ret); + flushOutput(); + } + + if (openssl::check_error(err, "Handshake failed")) { + PLOG_INFO << "TLS handshake finished"; + changeState(State::Connected); + postHandshake(); + } + } + + if (state() == State::Connected) { + int ret, err; + while (true) { + { + std::lock_guard lock(mSslMutex); + ret = SSL_read(mSsl, buffer, bufferSize); + err = SSL_get_error(mSsl, ret); + flushOutput(); // SSL_read() can also cause write operations + } + + if (err == SSL_ERROR_ZERO_RETURN) + break; + + if (openssl::check_error(err)) + recv(make_message(buffer, buffer + ret)); + else + break; + } + + if (err == SSL_ERROR_ZERO_RETURN) { + PLOG_DEBUG << "TLS connection cleanly closed"; + break; // No more data can be read + } + } + } + + std::lock_guard lock(mSslMutex); + SSL_shutdown(mSsl); + + } catch (const std::exception &e) { + PLOG_ERROR << "TLS recv: " << e.what(); + } + + if (state() == State::Connected) { + PLOG_INFO << "TLS closed"; + changeState(State::Disconnected); + recv(nullptr); + } else { + PLOG_ERROR << "TLS handshake failed"; + changeState(State::Failed); + } +} + +bool TlsTransport::flushOutput() { + // Requires mSslMutex to be locked + bool result = true; + const size_t bufferSize = 4096; + byte buffer[bufferSize]; + int len; + while ((len = BIO_read(mOutBio, buffer, bufferSize)) > 0) + result = outgoing(make_message(buffer, buffer + len)); + + return result; +} + +void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) { + TlsTransport *t = + static_cast(SSL_get_ex_data(ssl, TlsTransport::TransportExIndex)); + + if (where & SSL_CB_ALERT) { + if (ret != 256) { // Close Notify + PLOG_ERROR << "TLS alert: " << SSL_alert_desc_string_long(ret); + } + t->mIncomingQueue.stop(); // Close the connection + } +} + +#endif + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/tlstransport.hpp b/datachannel/src/impl/tlstransport.hpp new file mode 100644 index 000000000..392bbc087 --- /dev/null +++ b/datachannel/src/impl/tlstransport.hpp @@ -0,0 +1,102 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_TLS_TRANSPORT_H +#define RTC_IMPL_TLS_TRANSPORT_H + +#include "certificate.hpp" +#include "common.hpp" +#include "queue.hpp" +#include "tls.hpp" +#include "transport.hpp" + +#if RTC_ENABLE_WEBSOCKET + +#include +#include + +namespace rtc::impl { + +class TcpTransport; +class HttpProxyTransport; + +class TlsTransport : public Transport, public std::enable_shared_from_this { +public: + static void Init(); + static void Cleanup(); + + TlsTransport(variant, shared_ptr> lower, + optional host, certificate_ptr certificate, state_callback callback); + virtual ~TlsTransport(); + + void start() override; + void stop() override; + bool send(message_ptr message) override; + + bool isClient() const { return mIsClient; } + +protected: + virtual void incoming(message_ptr message) override; + virtual bool outgoing(message_ptr message) override; + virtual void postHandshake(); + + void enqueueRecv(); + void doRecv(); + + const optional mHost; + const bool mIsClient; + + Queue mIncomingQueue; + std::atomic mPendingRecvCount = 0; + std::mutex mRecvMutex; + +#if USE_GNUTLS + gnutls_session_t mSession; + + message_ptr mIncomingMessage; + size_t mIncomingMessagePosition = 0; + std::atomic mOutgoingResult = true; + + static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len); + static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen); + static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms); + +#elif USE_MBEDTLS + mbedtls_entropy_context mEntropy; + mbedtls_ctr_drbg_context mDrbg; + mbedtls_ssl_config mConf; + mbedtls_ssl_context mSsl; + + std::mutex mSslMutex; + std::atomic mOutgoingResult = true; + + message_ptr mIncomingMessage; + size_t mIncomingMessagePosition = 0; + + static int WriteCallback(void *ctx, const unsigned char *buf, size_t len); + static int ReadCallback(void *ctx, unsigned char *buf, size_t len); + +#else + SSL_CTX *mCtx; + SSL *mSsl; + BIO *mInBio, *mOutBio; + std::mutex mSslMutex; + + bool flushOutput(); + + static int TransportExIndex; + + static void InfoCallback(const SSL *ssl, int where, int ret); +#endif +}; + +} // namespace rtc::impl + +#endif + +#endif diff --git a/datachannel/src/impl/track.cpp b/datachannel/src/impl/track.cpp new file mode 100644 index 000000000..80927fd76 --- /dev/null +++ b/datachannel/src/impl/track.cpp @@ -0,0 +1,229 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "track.hpp" +#include "internals.hpp" +#include "logcounter.hpp" +#include "peerconnection.hpp" +#include "rtp.hpp" + +namespace rtc::impl { + +static LogCounter COUNTER_MEDIA_BAD_DIRECTION(plog::warning, + "Number of media packets sent in invalid directions"); +static LogCounter COUNTER_QUEUE_FULL(plog::warning, + "Number of media packets dropped due to a full queue"); + +Track::Track(weak_ptr pc, Description::Media desc) + : mPeerConnection(pc), mMediaDescription(std::move(desc)), + mRecvQueue(RECV_QUEUE_LIMIT, [](const message_ptr &m) { return m->size(); }) { + + // Discard messages by default if track is send only + if (mMediaDescription.direction() == Description::Direction::SendOnly) + messageCallback = [](message_variant) {}; +} + +Track::~Track() { + PLOG_VERBOSE << "Destroying Track"; + try { + close(); + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + } +} + +string Track::mid() const { + std::shared_lock lock(mMutex); + return mMediaDescription.mid(); +} + +Description::Direction Track::direction() const { + std::shared_lock lock(mMutex); + return mMediaDescription.direction(); +} + +Description::Media Track::description() const { + std::shared_lock lock(mMutex); + return mMediaDescription; +} + +void Track::setDescription(Description::Media desc) { + { + std::unique_lock lock(mMutex); + if (desc.mid() != mMediaDescription.mid()) + throw std::logic_error("Media description mid does not match track mid"); + + mMediaDescription = std::move(desc); + } + + if (auto handler = getMediaHandler()) + handler->media(description()); +} + +void Track::close() { + PLOG_VERBOSE << "Closing Track"; + + if (!mIsClosed.exchange(true)) + triggerClosed(); + + setMediaHandler(nullptr); + resetCallbacks(); +} + +optional Track::receive() { + if (auto next = mRecvQueue.pop()) { + message_ptr message = *next; + if (message->type == Message::Control) + return to_variant(**next); // The same message may be frowarded into multiple Tracks + else + return to_variant(std::move(*message)); + } + return nullopt; +} + +optional Track::peek() { + if (auto next = mRecvQueue.peek()) { + message_ptr message = *next; + if (message->type == Message::Control) + return to_variant(**next); // The same message may be forwarded into multiple Tracks + else + return to_variant(std::move(*message)); + } + return nullopt; +} + +size_t Track::availableAmount() const { return mRecvQueue.amount(); } + +bool Track::isOpen(void) const { +#if RTC_ENABLE_MEDIA + std::shared_lock lock(mMutex); + return !mIsClosed && mDtlsSrtpTransport.lock(); +#else + return false; +#endif +} + +bool Track::isClosed(void) const { return mIsClosed; } + +size_t Track::maxMessageSize() const { + optional mtu; + if (auto pc = mPeerConnection.lock()) + mtu = pc->config.mtu; + + return mtu.value_or(DEFAULT_MTU) - 12 - 8 - 40; // SRTP/UDP/IPv6 +} + +#if RTC_ENABLE_MEDIA +void Track::open(shared_ptr transport) { + { + std::lock_guard lock(mMutex); + mDtlsSrtpTransport = transport; + } + + if (!mIsClosed) + triggerOpen(); +} +#endif + +void Track::incoming(message_ptr message) { + if (!message) + return; + + auto dir = direction(); + if ((dir == Description::Direction::SendOnly || dir == Description::Direction::Inactive) && + message->type != Message::Control) { + COUNTER_MEDIA_BAD_DIRECTION++; + return; + } + + message_vector messages{std::move(message)}; + if (auto handler = getMediaHandler()) + handler->incomingChain(messages, [this](message_ptr m) { transportSend(m); }); + + for (auto &m : messages) { + // Tail drop if queue is full + if (mRecvQueue.full()) { + COUNTER_QUEUE_FULL++; + return; + } + + mRecvQueue.push(m); + triggerAvailable(mRecvQueue.size()); + } +} + +bool Track::outgoing(message_ptr message) { + if (mIsClosed) + throw std::runtime_error("Track is closed"); + + auto handler = getMediaHandler(); + + // If there is no handler, the track expects RTP or RTCP packets + if (!handler && IsRtcp(*message)) + message->type = Message::Control; // to allow sending RTCP packets irrelevant of direction + + auto dir = direction(); + if ((dir == Description::Direction::RecvOnly || dir == Description::Direction::Inactive) && + message->type != Message::Control) { + COUNTER_MEDIA_BAD_DIRECTION++; + return false; + } + + if (handler) { + message_vector messages{std::move(message)}; + handler->outgoingChain(messages, [this](message_ptr m) { transportSend(m); }); + bool ret = false; + for (auto &m : messages) + ret = transportSend(std::move(m)); + + return ret; + + } else { + return transportSend(std::move(message)); + } +} + +bool Track::transportSend([[maybe_unused]] message_ptr message) { +#if RTC_ENABLE_MEDIA + shared_ptr transport; + { + std::shared_lock lock(mMutex); + transport = mDtlsSrtpTransport.lock(); + if (!transport) + throw std::runtime_error("Track is closed"); + + // Set recommended medium-priority DSCP value + // See https://www.rfc-editor.org/rfc/rfc8837.html#section-5 + if (mMediaDescription.type() == "audio") + message->dscp = 46; // EF: Expedited Forwarding + else + message->dscp = 36; // AF42: Assured Forwarding class 4, medium drop probability + } + + return transport->sendMedia(message); +#else + throw std::runtime_error("Track is disabled (not compiled with media support)"); +#endif +} + +void Track::setMediaHandler(shared_ptr handler) { + { + std::unique_lock lock(mMutex); + mMediaHandler = handler; + } + + if(handler) + handler->media(description()); +} + +shared_ptr Track::getMediaHandler() { + std::shared_lock lock(mMutex); + return mMediaHandler; +} + +} // namespace rtc::impl diff --git a/datachannel/src/impl/track.hpp b/datachannel/src/impl/track.hpp new file mode 100644 index 000000000..ea1446b34 --- /dev/null +++ b/datachannel/src/impl/track.hpp @@ -0,0 +1,78 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_TRACK_H +#define RTC_IMPL_TRACK_H + +#include "channel.hpp" +#include "common.hpp" +#include "description.hpp" +#include "mediahandler.hpp" +#include "queue.hpp" + +#if RTC_ENABLE_MEDIA +#include "dtlssrtptransport.hpp" +#endif + +#include +#include + +namespace rtc::impl { + +struct PeerConnection; + +class Track final : public std::enable_shared_from_this, public Channel { +public: + Track(weak_ptr pc, Description::Media desc); + ~Track(); + + void close(); + void incoming(message_ptr message); + bool outgoing(message_ptr message); + + optional receive() override; + optional peek() override; + size_t availableAmount() const override; + + bool isOpen() const; + bool isClosed() const; + size_t maxMessageSize() const; + + string mid() const; + Description::Direction direction() const; + Description::Media description() const; + void setDescription(Description::Media desc); + + shared_ptr getMediaHandler(); + void setMediaHandler(shared_ptr handler); + +#if RTC_ENABLE_MEDIA + void open(shared_ptr transport); +#endif + + bool transportSend(message_ptr message); + +private: + const weak_ptr mPeerConnection; +#if RTC_ENABLE_MEDIA + weak_ptr mDtlsSrtpTransport; +#endif + + Description::Media mMediaDescription; + shared_ptr mMediaHandler; + + mutable std::shared_mutex mMutex; + + std::atomic mIsClosed = false; + + Queue mRecvQueue; +}; + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/transport.cpp b/datachannel/src/impl/transport.cpp new file mode 100644 index 000000000..1c28e4f36 --- /dev/null +++ b/datachannel/src/impl/transport.cpp @@ -0,0 +1,79 @@ +/** + * Copyright (c) 2019-2022 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "transport.hpp" + +namespace rtc::impl { + +Transport::Transport(shared_ptr lower, state_callback callback) + : mLower(std::move(lower)), mStateChangeCallback(std::move(callback)) {} + +Transport::~Transport() { + unregisterIncoming(); + + if (mLower) { + mLower->stop(); + mLower.reset(); + } +} + +void Transport::registerIncoming() { + if (mLower) { + PLOG_VERBOSE << "Registering incoming callback"; + mLower->onRecv(std::bind(&Transport::incoming, this, std::placeholders::_1)); + } +} + +void Transport::unregisterIncoming() { + if (mLower) { + PLOG_VERBOSE << "Unregistering incoming callback"; + mLower->onRecv(nullptr); + } +} + +Transport::State Transport::state() const { return mState; } + +void Transport::onRecv(message_callback callback) { mRecvCallback = std::move(callback); } + +void Transport::onStateChange(state_callback callback) { + mStateChangeCallback = std::move(callback); +} + +void Transport::start() { registerIncoming(); } + +void Transport::stop() { unregisterIncoming(); } + +bool Transport::send(message_ptr message) { return outgoing(message); } + +void Transport::recv(message_ptr message) { + try { + mRecvCallback(message); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } +} + +void Transport::changeState(State state) { + try { + if (mState.exchange(state) != state) + mStateChangeCallback(state); + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } +} + +void Transport::incoming(message_ptr message) { recv(message); } + +bool Transport::outgoing(message_ptr message) { + if (mLower) + return mLower->send(message); + else + return false; +} + +} // namespace rtc::impl diff --git a/datachannel/src/impl/transport.hpp b/datachannel/src/impl/transport.hpp new file mode 100644 index 000000000..fc879df2b --- /dev/null +++ b/datachannel/src/impl/transport.hpp @@ -0,0 +1,60 @@ +/** + * Copyright (c) 2019-2022 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_TRANSPORT_H +#define RTC_IMPL_TRANSPORT_H + +#include "common.hpp" +#include "init.hpp" +#include "internals.hpp" +#include "message.hpp" + +#include +#include +#include + +namespace rtc::impl { + +class Transport { +public: + enum class State { Disconnected, Connecting, Connected, Completed, Failed }; + using state_callback = std::function; + + Transport(shared_ptr lower = nullptr, state_callback callback = nullptr); + virtual ~Transport(); + + void registerIncoming(); + void unregisterIncoming(); + State state() const; + + void onRecv(message_callback callback); + void onStateChange(state_callback callback); + + virtual void start(); + virtual void stop(); + virtual bool send(message_ptr message); + +protected: + void recv(message_ptr message); + void changeState(State state); + virtual void incoming(message_ptr message); + virtual bool outgoing(message_ptr message); + +private: + const init_token mInitToken = Init::Instance().token(); + + shared_ptr mLower; + synchronized_callback mStateChangeCallback; + synchronized_callback mRecvCallback; + + std::atomic mState = State::Disconnected; +}; + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/utils.cpp b/datachannel/src/impl/utils.cpp new file mode 100644 index 000000000..ea5f10d41 --- /dev/null +++ b/datachannel/src/impl/utils.cpp @@ -0,0 +1,183 @@ +/** + * Copyright (c) 2020-2022 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "utils.hpp" + +#include "impl/internals.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#if defined(_WIN32) +#include + +typedef HRESULT(WINAPI *pfnSetThreadDescription)(HANDLE, PCWSTR); +#endif +#if defined(__linux__) +#include // for prctl(PR_SET_NAME) +#endif +#if defined(__FreeBSD__) +#include // for pthread_set_name_np +#endif + +namespace rtc::impl::utils { + +using std::to_integer; + +std::vector explode(const string &str, char delim) { + std::vector result; + std::istringstream ss(str); + string token; + while (std::getline(ss, token, delim)) + result.push_back(token); + + return result; +} + +string implode(const std::vector &tokens, char delim) { + string sdelim(1, delim); + std::ostringstream ss; + std::copy(tokens.begin(), tokens.end(), std::ostream_iterator(ss, sdelim.c_str())); + string result = ss.str(); + if (result.size() > 0) + result.resize(result.size() - 1); + + return result; +} + +string url_decode(const string &str) { + string result; + size_t i = 0; + while (i < str.size()) { + char c = str[i++]; + if (c == '%') { + auto value = str.substr(i, 2); + try { + if (value.size() != 2 || !std::isxdigit(value[0]) || !std::isxdigit(value[1])) + throw std::exception(); + + c = static_cast(std::stoi(value, nullptr, 16)); + i += 2; + + } catch (...) { + PLOG_WARNING << "Invalid percent-encoded character in URL: \"%" + value + "\""; + } + } + + result.push_back(c); + } + + return result; +} + +string base64_encode(const binary &data) { + static const char tab[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + string out; + out.reserve(3 * ((data.size() + 3) / 4)); + int i = 0; + while (data.size() - i >= 3) { + auto d0 = to_integer(data[i]); + auto d1 = to_integer(data[i + 1]); + auto d2 = to_integer(data[i + 2]); + out += tab[d0 >> 2]; + out += tab[((d0 & 3) << 4) | (d1 >> 4)]; + out += tab[((d1 & 0x0F) << 2) | (d2 >> 6)]; + out += tab[d2 & 0x3F]; + i += 3; + } + + int left = int(data.size() - i); + if (left) { + auto d0 = to_integer(data[i]); + out += tab[d0 >> 2]; + if (left == 1) { + out += tab[(d0 & 3) << 4]; + out += '='; + } else { // left == 2 + auto d1 = to_integer(data[i + 1]); + out += tab[((d0 & 3) << 4) | (d1 >> 4)]; + out += tab[(d1 & 0x0F) << 2]; + } + out += '='; + } + + return out; +} + +std::seed_seq random_seed() { + std::vector seed; + + // Seed with random device + try { + // On some systems an exception might be thrown if the random_device can't be initialized + std::random_device device; + // 128 bits should be more than enough + std::generate_n(std::back_inserter(seed), 4, std::ref(device)); + } catch (...) { + // Ignore + } + + // Seed with high-resolution clock + using std::chrono::high_resolution_clock; + seed.push_back( + static_cast(high_resolution_clock::now().time_since_epoch().count())); + + // Seed with thread id + seed.push_back( + static_cast(std::hash{}(std::this_thread::get_id()))); + + return std::seed_seq(seed.begin(), seed.end()); +} + +namespace { + +void thread_set_name_self(const char *name) { +#if defined(_WIN32) + int name_length = (int)strlen(name); + int wname_length = + MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, name, name_length, nullptr, 0); + if (wname_length > 0) { + std::wstring wname(wname_length, L'\0'); + wname_length = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, name, name_length, + &wname[0], wname_length + 1); + + HMODULE kernel32 = GetModuleHandleW(L"kernel32.dll"); + if (kernel32 != nullptr) { + auto pSetThreadDescription = + (pfnSetThreadDescription)GetProcAddress(kernel32, "SetThreadDescription"); + if (pSetThreadDescription != nullptr) { + pSetThreadDescription(GetCurrentThread(), wname.c_str()); + } + } + } +#elif defined(__linux__) + prctl(PR_SET_NAME, name); +#elif defined(__APPLE__) + pthread_setname_np(name); +#elif defined(__FreeBSD__) + pthread_set_name_np(pthread_self(), name); +#else + (void)name; +#endif +} + +} // namespace + +namespace this_thread { + +void set_name(const string &name) { thread_set_name_self(name.c_str()); } + +} // namespace this_thread + +} // namespace rtc::impl::utils diff --git a/datachannel/src/impl/utils.hpp b/datachannel/src/impl/utils.hpp new file mode 100644 index 000000000..808fe1ed5 --- /dev/null +++ b/datachannel/src/impl/utils.hpp @@ -0,0 +1,88 @@ +/** + * Copyright (c) 2020-2022 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_UTILS_H +#define RTC_IMPL_UTILS_H + +#include "common.hpp" + +#include +#include +#include +#include +#include +#include + +namespace rtc::impl::utils { + +std::vector explode(const string &str, char delim); +string implode(const std::vector &tokens, char delim); + +// Decode URL percent-encoding (RFC 3986) +// See https://www.rfc-editor.org/rfc/rfc3986.html#section-2.1 +string url_decode(const string &str); + +// Encode as base64 (RFC 4648) +// See https://www.rfc-editor.org/rfc/rfc4648.html#section-4 +string base64_encode(const binary &data); + +// Return a random seed sequence +std::seed_seq random_seed(); + +template +struct random_engine_wrapper { + Generator &engine; + using result_type = Result; + static constexpr result_type min() { return static_cast(Generator::min()); } + static constexpr result_type max() { return static_cast(Generator::max()); } + inline result_type operator()() { return static_cast(engine()); } + inline void discard(unsigned long long z) { engine.discard(z); } +}; + +// Return a wrapped thread-local seeded random number generator +template +auto random_engine() { + static thread_local std::seed_seq seed = random_seed(); + static thread_local Generator engine{seed}; + return random_engine_wrapper{engine}; +} + +// Return a wrapped thread-local seeded random bytes generator +template auto random_bytes_engine() { + using char_independent_bits_engine = + std::independent_bits_engine; + static_assert(char_independent_bits_engine::min() == std::numeric_limits::min()); + static_assert(char_independent_bits_engine::max() == std::numeric_limits::max()); + return random_engine(); +} + +template uint16_t to_uint16(T i) { + if (i >= 0 && static_cast::type>(i) <= + std::numeric_limits::max()) + return static_cast(i); + else + throw std::invalid_argument("Integer out of range"); +} + +template uint32_t to_uint32(T i) { + if (i >= 0 && static_cast::type>(i) <= + std::numeric_limits::max()) + return static_cast(i); + else + throw std::invalid_argument("Integer out of range"); +} + +namespace this_thread { + +void set_name(const string &name); + +} // namespace this_thread + +} // namespace rtc::impl::utils + +#endif diff --git a/datachannel/src/impl/verifiedtlstransport.cpp b/datachannel/src/impl/verifiedtlstransport.cpp new file mode 100644 index 000000000..cea7b0cf8 --- /dev/null +++ b/datachannel/src/impl/verifiedtlstransport.cpp @@ -0,0 +1,71 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "verifiedtlstransport.hpp" +#include "common.hpp" + +#if RTC_ENABLE_WEBSOCKET + +namespace rtc::impl { + +static const string PemBeginCertificateTag = "-----BEGIN CERTIFICATE-----"; + +VerifiedTlsTransport::VerifiedTlsTransport( + variant, shared_ptr> lower, string host, + certificate_ptr certificate, state_callback callback, [[maybe_unused]] optional cacert) + : TlsTransport(std::move(lower), std::move(host), std::move(certificate), std::move(callback)) { + + PLOG_DEBUG << "Setting up TLS certificate verification"; + +#if USE_GNUTLS + gnutls_session_set_verify_cert(mSession, mHost->c_str(), 0); +#elif USE_MBEDTLS + mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_REQUIRED); + mbedtls_x509_crt_init(&mCaCert); + try { + if (cacert) { + if (cacert->find(PemBeginCertificateTag) == string::npos) { + // *cacert is a file path + mbedtls::check(mbedtls_x509_crt_parse_file(&mCaCert, cacert->c_str())); + } else { + // *cacert is a PEM content + mbedtls::check(mbedtls_x509_crt_parse( + &mCaCert, reinterpret_cast(cacert->c_str()), + cacert->size() + 1)); + } + mbedtls_ssl_conf_ca_chain(&mConf, &mCaCert, NULL); + } + } catch (...) { + mbedtls_x509_crt_free(&mCaCert); + throw; + } +#else + if (cacert) { + if (cacert->find(PemBeginCertificateTag) == string::npos) { + // *cacert is a file path + openssl::check(SSL_CTX_load_verify_locations(mCtx, cacert->c_str(), NULL), "Failed to load CA certificate"); + } else { + // *cacert is a PEM content + PLOG_WARNING << "CA certificate as PEM is not supported for OpenSSL"; + } + } + SSL_set_verify(mSsl, SSL_VERIFY_PEER, NULL); + SSL_set_verify_depth(mSsl, 4); +#endif +} + +VerifiedTlsTransport::~VerifiedTlsTransport() { + stop(); +#if USE_MBEDTLS + mbedtls_x509_crt_free(&mCaCert); +#endif +} + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/verifiedtlstransport.hpp b/datachannel/src/impl/verifiedtlstransport.hpp new file mode 100644 index 000000000..0d38feba5 --- /dev/null +++ b/datachannel/src/impl/verifiedtlstransport.hpp @@ -0,0 +1,35 @@ +/** + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_VERIFIED_TLS_TRANSPORT_H +#define RTC_IMPL_VERIFIED_TLS_TRANSPORT_H + +#include "tlstransport.hpp" + +#if RTC_ENABLE_WEBSOCKET + +namespace rtc::impl { + +class VerifiedTlsTransport final : public TlsTransport { +public: + VerifiedTlsTransport(variant, shared_ptr> lower, + string host, certificate_ptr certificate, state_callback callback, + optional cacert); + ~VerifiedTlsTransport(); + +private: +#if USE_MBEDTLS + mbedtls_x509_crt mCaCert; +#endif +}; + +} // namespace rtc::impl + +#endif + +#endif diff --git a/datachannel/src/impl/websocket.cpp b/datachannel/src/impl/websocket.cpp new file mode 100644 index 000000000..70381a7e6 --- /dev/null +++ b/datachannel/src/impl/websocket.cpp @@ -0,0 +1,533 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#if RTC_ENABLE_WEBSOCKET + +#include "websocket.hpp" +#include "common.hpp" +#include "internals.hpp" +#include "processor.hpp" +#include "utils.hpp" + +#include "httpproxytransport.hpp" +#include "tcptransport.hpp" +#include "tlstransport.hpp" +#include "verifiedtlstransport.hpp" +#include "wstransport.hpp" + +#include +#include +#include + +#ifdef _WIN32 +#include +#endif + +namespace rtc::impl { + +using namespace std::placeholders; +using namespace std::chrono_literals; +using std::chrono::milliseconds; + +WebSocket::WebSocket(optional optConfig, certificate_ptr certificate) + : config(optConfig ? std::move(*optConfig) : Configuration()), + mCertificate(certificate ? std::move(certificate) : std::move(loadCertificate(config))), + mIsSecure(mCertificate != nullptr), mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) { + PLOG_VERBOSE << "Creating WebSocket"; + if (config.proxyServer) { + if (config.proxyServer->type == ProxyServer::Type::Socks5) + throw std::invalid_argument( + "Proxy server support for WebSocket is not implemented for Socks5"); + if (config.proxyServer->username || config.proxyServer->password) { + PLOG_WARNING << "HTTP authentication support for proxy is not implemented"; + } + } +} + +certificate_ptr WebSocket::loadCertificate(const Configuration& config) { + if (!config.certificatePemFile) + return nullptr; + + if (config.keyPemFile) + return std::make_shared( + Certificate::FromFile(*config.certificatePemFile, *config.keyPemFile, + config.keyPemPass.value_or(""))); + + throw std::invalid_argument( + "Either none or both certificate and key PEM files must be specified"); +} + +WebSocket::~WebSocket() { PLOG_VERBOSE << "Destroying WebSocket"; } + +void WebSocket::open(const string &url) { + PLOG_VERBOSE << "Opening WebSocket to URL: " << url; + + if (state != State::Closed) + throw std::logic_error("WebSocket must be closed before opening"); + + // Modified regex from RFC 3986, see https://www.rfc-editor.org/rfc/rfc3986.html#appendix-B + static const char *rs = + R"(^(([^:.@/?#]+):)?(/{0,2}((([^:@]*)(:([^@]*))?)@)?(([^:/?#]*)(:([^/?#]*))?))?([^?#]*)(\?([^#]*))?(#(.*))?)"; + + static const std::regex r(rs, std::regex::extended); + + std::smatch m; + if (!std::regex_match(url, m, r) || m[10].length() == 0) + throw std::invalid_argument("Invalid WebSocket URL: " + url); + + string scheme = m[2]; + if (scheme.empty()) + scheme = "ws"; + + if (scheme != "ws" && scheme != "wss") + throw std::invalid_argument("Invalid WebSocket scheme: " + scheme); + + mIsSecure = (scheme != "ws"); + + string username = utils::url_decode(m[6]); + string password = utils::url_decode(m[8]); + if (!username.empty() || !password.empty()) { + PLOG_WARNING << "HTTP authentication support for WebSocket is not implemented"; + } + + string host; + string hostname = m[10]; + string service = m[12]; + if (service.empty()) { + service = mIsSecure ? "443" : "80"; + host = hostname; + } else { + host = hostname + ':' + service; + } + + if (hostname.front() == '[' && hostname.back() == ']') { + // IPv6 literal + hostname.erase(hostname.begin()); + hostname.pop_back(); + } else { + hostname = utils::url_decode(hostname); + } + + string path = m[13]; + if (path.empty()) + path += '/'; + + if (string query = m[15]; !query.empty()) + path += "?" + query; + + mHostname = hostname; // for TLS SNI and Proxy + mService = service; // For proxy + std::atomic_store(&mWsHandshake, std::make_shared(host, path, config.protocols)); + + changeState(State::Connecting); + + if (config.proxyServer) { + setTcpTransport(std::make_shared( + config.proxyServer->hostname, std::to_string(config.proxyServer->port), nullptr)); + } else { + setTcpTransport(std::make_shared(hostname, service, nullptr)); + } +} + +void WebSocket::close() { + auto s = state.load(); + if (s == State::Connecting || s == State::Open) { + PLOG_VERBOSE << "Closing WebSocket"; + changeState(State::Closing); + if (auto transport = std::atomic_load(&mWsTransport)) + transport->stop(); + else + remoteClose(); + } +} + +void WebSocket::remoteClose() { + close(); + if (state.load() != State::Closed) + closeTransports(); +} + +bool WebSocket::isOpen() const { return state == State::Open; } + +bool WebSocket::isClosed() const { return state == State::Closed; } + +size_t WebSocket::maxMessageSize() const { return config.maxMessageSize.value_or(DEFAULT_WS_MAX_MESSAGE_SIZE); } + +optional WebSocket::receive() { + auto next = mRecvQueue.pop(); + return next ? std::make_optional(to_variant(std::move(**next))) : nullopt; +} + +optional WebSocket::peek() { + auto next = mRecvQueue.peek(); + return next ? std::make_optional(to_variant(std::move(**next))) : nullopt; +} + +size_t WebSocket::availableAmount() const { return mRecvQueue.amount(); } + +bool WebSocket::changeState(State newState) { return state.exchange(newState) != newState; } + +bool WebSocket::outgoing(message_ptr message) { + if (state != State::Open || !mWsTransport) + throw std::runtime_error("WebSocket is not open"); + + if (message->size() > maxMessageSize()) + throw std::runtime_error("Message size exceeds limit"); + + return mWsTransport->send(message); +} + +void WebSocket::incoming(message_ptr message) { + if (!message) { + remoteClose(); + return; + } + + if (message->type == Message::String || message->type == Message::Binary) { + mRecvQueue.push(message); + triggerAvailable(mRecvQueue.size()); + } +} + +// Helper for WebSocket::initXTransport methods: start and emplace the transport +template +shared_ptr emplaceTransport(WebSocket *ws, shared_ptr *member, shared_ptr transport) { + std::atomic_store(member, transport); + try { + transport->start(); + } catch (...) { + std::atomic_store(member, decltype(transport)(nullptr)); + transport->stop(); + throw; + } + + if (ws->state == WebSocket::State::Closed) { + std::atomic_store(member, decltype(transport)(nullptr)); + transport->stop(); + return nullptr; + } + + return transport; +} + +shared_ptr WebSocket::setTcpTransport(shared_ptr transport) { + PLOG_VERBOSE << "Starting TCP transport"; + + if (!transport) + throw std::logic_error("TCP transport is null"); + + using State = TcpTransport::State; + try { + if (std::atomic_load(&mTcpTransport)) + throw std::logic_error("TCP transport is already set"); + + transport->onBufferedAmount(weak_bind(&WebSocket::triggerBufferedAmount, this, _1)); + + transport->onStateChange([this, weak_this = weak_from_this()](State transportState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (transportState) { + case State::Connected: + if (config.proxyServer) + initProxyTransport(); + else if (mIsSecure) + initTlsTransport(); + else + initWsTransport(); + break; + case State::Failed: + triggerError("TCP connection failed"); + remoteClose(); + break; + case State::Disconnected: + remoteClose(); + break; + default: + // Ignore + break; + } + }); + + // WS transport sends a ping on read timeout + auto pingInterval = config.pingInterval.value_or(10000ms); + if (pingInterval > milliseconds::zero()) + transport->setReadTimeout(pingInterval); + + scheduleConnectionTimeout(); + + return emplaceTransport(this, &mTcpTransport, std::move(transport)); + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + remoteClose(); + throw std::runtime_error("TCP transport initialization failed"); + } +} + +shared_ptr WebSocket::initProxyTransport() { + PLOG_VERBOSE << "Starting Tcp Proxy transport"; + using State = HttpProxyTransport::State; + try { + if (auto transport = std::atomic_load(&mProxyTransport)) + return transport; + + auto lower = std::atomic_load(&mTcpTransport); + if (!lower) + throw std::logic_error("No underlying TCP transport for Proxy transport"); + + auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (transportState) { + case State::Connected: + if (mIsSecure) + initTlsTransport(); + else + initWsTransport(); + break; + case State::Failed: + triggerError("Proxy connection failed"); + remoteClose(); + break; + case State::Disconnected: + remoteClose(); + break; + default: + // Ignore + break; + } + }; + + auto transport = std::make_shared( + lower, mHostname.value(), mService.value(), stateChangeCallback); + + return emplaceTransport(this, &mProxyTransport, std::move(transport)); + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + remoteClose(); + throw std::runtime_error("Tcp Proxy transport initialization failed"); + } +} + +shared_ptr WebSocket::initTlsTransport() { + PLOG_VERBOSE << "Starting TLS transport"; + using State = TlsTransport::State; + try { + if (auto transport = std::atomic_load(&mTlsTransport)) + return transport; + + variant, shared_ptr> lower; + if (config.proxyServer) { + auto transport = std::atomic_load(&mProxyTransport); + if (!transport) + throw std::logic_error("No underlying proxy transport for TLS transport"); + + lower = transport; + } else { + auto transport = std::atomic_load(&mTcpTransport); + if (!transport) + throw std::logic_error("No underlying TCP transport for TLS transport"); + + lower = transport; + } + + auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (transportState) { + case State::Connected: + initWsTransport(); + break; + case State::Failed: + triggerError("TLS connection failed"); + remoteClose(); + break; + case State::Disconnected: + remoteClose(); + break; + default: + // Ignore + break; + } + }; + + bool verify = mHostname.has_value() && !config.disableTlsVerification; + +#ifdef _WIN32 + if (std::exchange(verify, false)) { + PLOG_WARNING << "TLS certificate verification with root CA is not supported on Windows"; + } +#endif + + shared_ptr transport; + if (verify) + transport = std::make_shared(lower, mHostname.value(), + mCertificate, stateChangeCallback, + config.caCertificatePemFile); + else + transport = + std::make_shared(lower, mHostname, mCertificate, stateChangeCallback); + + return emplaceTransport(this, &mTlsTransport, std::move(transport)); + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + remoteClose(); + throw std::runtime_error("TLS transport initialization failed"); + } +} + +shared_ptr WebSocket::initWsTransport() { + PLOG_VERBOSE << "Starting WebSocket transport"; + using State = WsTransport::State; + try { + if (auto transport = std::atomic_load(&mWsTransport)) + return transport; + + variant, shared_ptr, shared_ptr> + lower; + if (mIsSecure) { + auto transport = std::atomic_load(&mTlsTransport); + if (!transport) + throw std::logic_error("No underlying TLS transport for WebSocket transport"); + + lower = transport; + } else if (config.proxyServer) { + auto transport = std::atomic_load(&mProxyTransport); + if (!transport) + throw std::logic_error("No underlying proxy transport for WebSocket transport"); + + lower = transport; + } else { + auto transport = std::atomic_load(&mTcpTransport); + if (!transport) + throw std::logic_error("No underlying TCP transport for WebSocket transport"); + + lower = transport; + } + + if (!atomic_load(&mWsHandshake)) + atomic_store(&mWsHandshake, std::make_shared()); + + auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) { + auto shared_this = weak_this.lock(); + if (!shared_this) + return; + switch (transportState) { + case State::Connected: + if (state == WebSocket::State::Connecting) { + PLOG_DEBUG << "WebSocket open"; + if (changeState(WebSocket::State::Open)) + triggerOpen(); + } + break; + case State::Failed: + triggerError("WebSocket connection failed"); + remoteClose(); + break; + case State::Disconnected: + remoteClose(); + break; + default: + // Ignore + break; + } + }; + + auto transport = std::make_shared(lower, mWsHandshake, config, + weak_bind(&WebSocket::incoming, this, _1), + stateChangeCallback); + + return emplaceTransport(this, &mWsTransport, std::move(transport)); + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + remoteClose(); + throw std::runtime_error("WebSocket transport initialization failed"); + } +} + +shared_ptr WebSocket::getTcpTransport() const { + return std::atomic_load(&mTcpTransport); +} + +shared_ptr WebSocket::getTlsTransport() const { + return std::atomic_load(&mTlsTransport); +} + +shared_ptr WebSocket::getWsTransport() const { + return std::atomic_load(&mWsTransport); +} + +shared_ptr WebSocket::getWsHandshake() const { + return std::atomic_load(&mWsHandshake); +} + +void WebSocket::closeTransports() { + PLOG_VERBOSE << "Closing transports"; + + if (!changeState(State::Closed)) + return; // already closed + + // Pass the pointers to a thread, allowing to terminate a transport from its own thread + auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr)); + auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr)); + auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr)); + + if (ws) + ws->onRecv(nullptr); + + if (tcp) + tcp->onBufferedAmount(nullptr); + + using array = std::array, 3>; + array transports{std::move(ws), std::move(tls), std::move(tcp)}; + + for (const auto &t : transports) + if (t) + t->onStateChange(nullptr); + + TearDownProcessor::Instance().enqueue( + [transports = std::move(transports), token = Init::Instance().token()]() mutable { + for (const auto &t : transports) { + if (t) { + t->stop(); + break; + } + } + + for (auto &t : transports) + t.reset(); + }); + + triggerClosed(); +} + +void WebSocket::scheduleConnectionTimeout() { + auto defaultTimeout = 30s; + auto timeout = config.connectionTimeout.value_or(milliseconds(defaultTimeout)); + if (timeout > milliseconds::zero()) { + ThreadPool::Instance().schedule(timeout, [weak_this = weak_from_this()]() { + if (auto locked = weak_this.lock()) { + if (locked->state == WebSocket::State::Connecting) { + PLOG_WARNING << "WebSocket connection timed out"; + locked->triggerError("Connection timed out"); + locked->remoteClose(); + } + } + }); + } +} + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/websocket.hpp b/datachannel/src/impl/websocket.hpp new file mode 100644 index 000000000..ef82068d8 --- /dev/null +++ b/datachannel/src/impl/websocket.hpp @@ -0,0 +1,95 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_WEBSOCKET_H +#define RTC_IMPL_WEBSOCKET_H + +#if RTC_ENABLE_WEBSOCKET + +#include "channel.hpp" +#include "common.hpp" +#include "httpproxytransport.hpp" +#include "init.hpp" +#include "message.hpp" +#include "queue.hpp" +#include "tcptransport.hpp" +#include "tlstransport.hpp" +#include "wstransport.hpp" + +#include "rtc/websocket.hpp" + +#include +#include + +namespace rtc::impl { + +struct WebSocket final : public Channel, public std::enable_shared_from_this { + using State = rtc::WebSocket::State; + using Configuration = rtc::WebSocket::Configuration; + + WebSocket(optional optConfig = nullopt, certificate_ptr certificate = nullptr); + ~WebSocket(); + + void open(const string &url); + void close(); + void remoteClose(); + bool outgoing(message_ptr message); + void incoming(message_ptr message); + + optional receive() override; + optional peek() override; + size_t availableAmount() const override; + + bool isOpen() const; + bool isClosed() const; + size_t maxMessageSize() const; + + bool changeState(State state); + + shared_ptr setTcpTransport(shared_ptr transport); + shared_ptr initProxyTransport(); + shared_ptr initTlsTransport(); + shared_ptr initWsTransport(); + shared_ptr getTcpTransport() const; + shared_ptr getTlsTransport() const; + shared_ptr getWsTransport() const; + shared_ptr getWsHandshake() const; + + void closeTransports(); + + const Configuration config; + + std::atomic state = State::Closed; + +private: + static certificate_ptr loadCertificate(const Configuration& config); + + void scheduleConnectionTimeout(); + + const init_token mInitToken = Init::Instance().token(); + + const certificate_ptr mCertificate; + bool mIsSecure; + + optional mHostname; // for TLS SNI and Proxy + optional mService; // for Proxy + + shared_ptr mTcpTransport; + shared_ptr mProxyTransport; + shared_ptr mTlsTransport; + shared_ptr mWsTransport; + shared_ptr mWsHandshake; + + Queue mRecvQueue; +}; + +} // namespace rtc::impl + +#endif + +#endif // RTC_IMPL_WEBSOCKET_H diff --git a/datachannel/src/impl/websocketserver.cpp b/datachannel/src/impl/websocketserver.cpp new file mode 100644 index 000000000..79b1d584f --- /dev/null +++ b/datachannel/src/impl/websocketserver.cpp @@ -0,0 +1,102 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#if RTC_ENABLE_WEBSOCKET + +#include "websocketserver.hpp" +#include "common.hpp" +#include "internals.hpp" +#include "threadpool.hpp" +#include "utils.hpp" + +namespace rtc::impl { + +using namespace std::placeholders; + +const string PemBeginCertificateTag = "-----BEGIN CERTIFICATE-----"; + +WebSocketServer::WebSocketServer(Configuration config_) + : config(std::move(config_)), mStopped(false) { + PLOG_VERBOSE << "Creating WebSocketServer"; + + // Create certificate + if (config.enableTls) { + if (config.certificatePemFile && config.keyPemFile) { + mCertificate = std::make_shared( + config.certificatePemFile->find(PemBeginCertificateTag) != string::npos + ? Certificate::FromString(*config.certificatePemFile, *config.keyPemFile) + : Certificate::FromFile(*config.certificatePemFile, *config.keyPemFile, + config.keyPemPass.value_or(""))); + + } else if (!config.certificatePemFile && !config.keyPemFile) { + mCertificate = std::make_shared( + Certificate::Generate(CertificateType::Default, "localhost")); + } else { + throw std::invalid_argument( + "Either none or both certificate and key PEM files must be specified"); + } + } + + const char *bindAddress = nullptr; + if (config.bindAddress) { + bindAddress = config.bindAddress->c_str(); + } + // Create TCP server + tcpServer = std::make_unique(config.port, bindAddress); + + // Create server thread + mThread = std::thread(&WebSocketServer::runLoop, this); +} + +WebSocketServer::~WebSocketServer() { + PLOG_VERBOSE << "Destroying WebSocketServer"; + stop(); +} + +void WebSocketServer::stop() { + if (mStopped.exchange(true)) + return; + + PLOG_DEBUG << "Stopping WebSocketServer thread"; + tcpServer->close(); + mThread.join(); +} + +void WebSocketServer::runLoop() { + utils::this_thread::set_name("RTC server"); + PLOG_INFO << "Starting WebSocketServer"; + + try { + while (auto incoming = tcpServer->accept()) { + try { + if (!clientCallback) + continue; + + WebSocket::Configuration clientConfig; + clientConfig.connectionTimeout = config.connectionTimeout; + clientConfig.maxMessageSize = config.maxMessageSize; + + auto impl = std::make_shared(std::move(clientConfig), mCertificate); + impl->changeState(WebSocket::State::Connecting); + impl->setTcpTransport(incoming); + clientCallback(std::make_shared(impl)); + + } catch (const std::exception &e) { + PLOG_ERROR << "WebSocketServer: " << e.what(); + } + } + } catch (const std::exception &e) { + PLOG_FATAL << "WebSocketServer: " << e.what(); + } + + PLOG_INFO << "Stopped WebSocketServer"; +} + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/websocketserver.hpp b/datachannel/src/impl/websocketserver.hpp new file mode 100644 index 000000000..09e082a4a --- /dev/null +++ b/datachannel/src/impl/websocketserver.hpp @@ -0,0 +1,55 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_WEBSOCKETSERVER_H +#define RTC_IMPL_WEBSOCKETSERVER_H + +#if RTC_ENABLE_WEBSOCKET + +#include "certificate.hpp" +#include "common.hpp" +#include "init.hpp" +#include "message.hpp" +#include "tcpserver.hpp" +#include "websocket.hpp" + +#include "rtc/websocket.hpp" +#include "rtc/websocketserver.hpp" + +#include +#include + +namespace rtc::impl { + +struct WebSocketServer final : public std::enable_shared_from_this { + using Configuration = rtc::WebSocketServer::Configuration; + + WebSocketServer(Configuration config_); + ~WebSocketServer(); + + void stop(); + + const Configuration config; + unique_ptr tcpServer; + synchronized_callback> clientCallback; + +private: + const init_token mInitToken = Init::Instance().token(); + + void runLoop(); + + certificate_ptr mCertificate; + std::thread mThread; + std::atomic mStopped; +}; + +} // namespace rtc::impl + +#endif + +#endif // RTC_IMPL_WEBSOCKET_H diff --git a/datachannel/src/impl/wshandshake.cpp b/datachannel/src/impl/wshandshake.cpp new file mode 100644 index 000000000..155b94611 --- /dev/null +++ b/datachannel/src/impl/wshandshake.cpp @@ -0,0 +1,254 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "wshandshake.hpp" +#include "http.hpp" +#include "internals.hpp" +#include "sha.hpp" +#include "utils.hpp" + +#if RTC_ENABLE_WEBSOCKET + +#include +#include +#include +#include +#include +#include + +using std::string; + +namespace rtc::impl { + +using std::to_string; +using std::chrono::system_clock; + +WsHandshake::WsHandshake() {} + +WsHandshake::WsHandshake(string host, string path, std::vector protocols) + : mHost(std::move(host)), mPath(std::move(path)), mProtocols(std::move(protocols)) { + + if (mHost.empty()) + throw std::invalid_argument("WebSocket HTTP host cannot be empty"); + + if (mPath.empty()) + throw std::invalid_argument("WebSocket HTTP path cannot be empty"); +} + +string WsHandshake::host() const { + std::unique_lock lock(mMutex); + return mHost; +} + +string WsHandshake::path() const { + std::unique_lock lock(mMutex); + return mPath; +} + +std::vector WsHandshake::protocols() const { + std::unique_lock lock(mMutex); + return mProtocols; +} + +string WsHandshake::generateHttpRequest() { + std::unique_lock lock(mMutex); + mKey = generateKey(); + + string out = "GET " + mPath + + " HTTP/1.1\r\n" + "Host: " + + mHost + + "\r\n" + "Connection: upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Version: 13\r\n" + "Sec-WebSocket-Key: " + + mKey + "\r\n"; + + if (!mProtocols.empty()) + out += "Sec-WebSocket-Protocol: " + utils::implode(mProtocols, ',') + "\r\n"; + + out += "\r\n"; + + return out; +} + +string WsHandshake::generateHttpResponse() { + std::unique_lock lock(mMutex); + const string out = "HTTP/1.1 101 Switching Protocols\r\n" + "Server: libdatachannel\r\n" + "Connection: upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Accept: " + + computeAcceptKey(mKey) + "\r\n\r\n"; + + return out; +} + +namespace { + +string GetHttpErrorName(int responseCode) { + switch (responseCode) { + case 400: + return "Bad Request"; + case 404: + return "Not Found"; + case 405: + return "Method Not Allowed"; + case 426: + return "Upgrade Required"; + case 500: + return "Internal Server Error"; + default: + return "Error"; + } +} + +} // namespace + +string WsHandshake::generateHttpError(int responseCode) { + std::unique_lock lock(mMutex); + + const string error = to_string(responseCode) + " " + GetHttpErrorName(responseCode); + + const string out = "HTTP/1.1 " + error + + "\r\n" + "Server: libdatachannel\r\n" + "Connection: upgrade\r\n" + "Upgrade: websocket\r\n" + "Content-Type: text/plain\r\n" + "Content-Length: " + + to_string(error.size()) + + "\r\n" + "Access-Control-Allow-Origin: *\r\n\r\n" + + error; + + return out; +} + +size_t WsHandshake::parseHttpRequest(const byte *buffer, size_t size) { + if (!isHttpRequest(buffer, size)) + throw RequestError("Invalid HTTP request for WebSocket", 400); + + std::unique_lock lock(mMutex); + std::list lines; + size_t length = parseHttpLines(buffer, size, lines); + if (length == 0) + return 0; + + if (lines.empty()) + throw RequestError("Invalid HTTP request for WebSocket", 400); + + std::istringstream requestLine(std::move(lines.front())); + lines.pop_front(); + + string method, path, protocol; + requestLine >> method >> path >> protocol; + PLOG_DEBUG << "WebSocket request method=\"" << method << "\", path=\"" << path << "\""; + if (method != "GET") + throw RequestError("Invalid request method \"" + method + "\" for WebSocket", 405); + + mPath = std::move(path); + + auto headers = parseHttpHeaders(lines); + + auto h = headers.find("host"); + if (h == headers.end()) + throw RequestError("WebSocket host header missing in request", 400); + + mHost = std::move(h->second); + + h = headers.find("upgrade"); + if (h == headers.end()) + throw RequestError("WebSocket upgrade header missing in request", 426); + + string upgrade; + std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade), + [](char c) { return std::tolower(c); }); + if (upgrade != "websocket") + throw RequestError("WebSocket upgrade header mismatching", 426); + + h = headers.find("sec-websocket-key"); + if (h == headers.end()) + throw RequestError("WebSocket key header missing in request", 400); + + mKey = std::move(h->second); + + h = headers.find("sec-websocket-protocol"); + if (h != headers.end()) + mProtocols = utils::explode(h->second, ','); + + return length; +} + +size_t WsHandshake::parseHttpResponse(const byte *buffer, size_t size) { + std::unique_lock lock(mMutex); + std::list lines; + size_t length = parseHttpLines(buffer, size, lines); + if (length == 0) + return 0; + + if (lines.empty()) + throw Error("Invalid HTTP response for WebSocket"); + + std::istringstream status(std::move(lines.front())); + lines.pop_front(); + + string protocol; + unsigned int code = 0; + status >> protocol >> code; + PLOG_DEBUG << "WebSocket response code=" << code; + if (code != 101) + throw std::runtime_error("Unexpected response code " + to_string(code) + " for WebSocket"); + + auto headers = parseHttpHeaders(lines); + + auto h = headers.find("upgrade"); + if (h == headers.end()) + throw Error("WebSocket update header missing"); + + string upgrade; + std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade), + [](char c) { return std::tolower(c); }); + if (upgrade != "websocket") + throw Error("WebSocket update header mismatching"); + + h = headers.find("sec-websocket-accept"); + if (h == headers.end()) + throw Error("WebSocket accept header missing"); + + if (h->second != computeAcceptKey(mKey)) + throw Error("WebSocket accept header is invalid"); + + return length; +} + +string WsHandshake::generateKey() { + // RFC 6455: The request MUST include a header field with the name Sec-WebSocket-Key. The value + // of this header field MUST be a nonce consisting of a randomly selected 16-byte value that has + // been base64-encoded. [...] The nonce MUST be selected randomly for each connection. + binary key(16); + auto k = reinterpret_cast(key.data()); + std::generate(k, k + key.size(), utils::random_bytes_engine()); + return utils::base64_encode(key); +} + +string WsHandshake::computeAcceptKey(const string &key) { + return utils::base64_encode(Sha1(string(key) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")); +} + +WsHandshake::Error::Error(const string &w) : std::runtime_error(w) {} + +WsHandshake::RequestError::RequestError(const string &w, int responseCode) + : Error(w), mResponseCode(responseCode) {} + +int WsHandshake::RequestError::RequestError::responseCode() const { return mResponseCode; } + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/wshandshake.hpp b/datachannel/src/impl/wshandshake.hpp new file mode 100644 index 000000000..d59237423 --- /dev/null +++ b/datachannel/src/impl/wshandshake.hpp @@ -0,0 +1,68 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_WS_HANDSHAKE_H +#define RTC_IMPL_WS_HANDSHAKE_H + +#include "common.hpp" + +#if RTC_ENABLE_WEBSOCKET + +#include +#include +#include +#include + +namespace rtc::impl { + +class WsHandshake final { +public: + WsHandshake(); + WsHandshake(string host, string path = "/", std::vector protocols = {}); + + string host() const; + string path() const; + std::vector protocols() const; + + string generateHttpRequest(); + string generateHttpResponse(); + string generateHttpError(int responseCode = 400); + + class Error : public std::runtime_error { + public: + explicit Error(const string &w); + }; + + class RequestError : public Error { + public: + explicit RequestError(const string &w, int responseCode = 400); + int responseCode() const; + + private: + const int mResponseCode; + }; + + size_t parseHttpRequest(const byte *buffer, size_t size); + size_t parseHttpResponse(const byte *buffer, size_t size); + +private: + static string generateKey(); + static string computeAcceptKey(const string &key); + + string mHost; + string mPath; + std::vector mProtocols; + string mKey; + mutable std::mutex mMutex; +}; + +} // namespace rtc::impl + +#endif + +#endif diff --git a/datachannel/src/impl/wstransport.cpp b/datachannel/src/impl/wstransport.cpp new file mode 100644 index 000000000..1ddf629d5 --- /dev/null +++ b/datachannel/src/impl/wstransport.cpp @@ -0,0 +1,424 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "wstransport.hpp" +#include "httpproxytransport.hpp" +#include "tcptransport.hpp" +#include "threadpool.hpp" +#include "tlstransport.hpp" +#include "utils.hpp" + +#if RTC_ENABLE_WEBSOCKET + +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#endif + +#ifndef htonll +#define htonll(x) \ + ((uint64_t)(((uint64_t)htonl((uint32_t)(x))) << 32) | (uint64_t)htonl((uint32_t)((x) >> 32))) +#endif +#ifndef ntohll +#define ntohll(x) htonll(x) +#endif + +namespace rtc::impl { + +using std::to_integer; +using std::to_string; +using std::chrono::system_clock; + +WsTransport::WsTransport(LowerTransport lower, shared_ptr handshake, + const WebSocketConfiguration &config, message_callback recvCallback, + state_callback stateCallback) + : Transport(std::visit([](auto l) { return std::static_pointer_cast(l); }, lower), + std::move(stateCallback)), + mHandshake(std::move(handshake)), + mIsClient( + std::visit(rtc::overloaded{[](auto l) { return l->isActive(); }, + [](shared_ptr l) { return l->isClient(); }}, + lower)), + mMaxMessageSize(config.maxMessageSize.value_or(DEFAULT_WS_MAX_MESSAGE_SIZE)), + mMaxOutstandingPings(config.maxOutstandingPings.value_or(0)) { + + onRecv(std::move(recvCallback)); + + PLOG_DEBUG << "Initializing WebSocket transport"; +} + +WsTransport::~WsTransport() { unregisterIncoming(); } + +void WsTransport::start() { + registerIncoming(); + + changeState(State::Connecting); + if (mIsClient) + sendHttpRequest(); +} + +void WsTransport::stop() { close(); } + +bool WsTransport::send(message_ptr message) { + if (state() != State::Connected) + throw std::runtime_error("WebSocket is not open"); + + if (!message) + return false; + + PLOG_VERBOSE << "Send size=" << message->size(); + return sendFrame({message->type == Message::String ? TEXT_FRAME : BINARY_FRAME, message->data(), + message->size(), true, mIsClient}); +} + +void WsTransport::close() { + if (state() != State::Connected) + return; + + if (mCloseSent.exchange(true)) + return; + + PLOG_INFO << "WebSocket closing"; + try { + sendFrame({CLOSE, NULL, 0, true, mIsClient}); + } catch (const std::exception &e) { + // The connection might not be open anymore + PLOG_DEBUG << "Unable to send WebSocket close frame: " << e.what(); + changeState(State::Disconnected); + return; + } + + ThreadPool::Instance().schedule(std::chrono::seconds(10), + [this, weak_this = weak_from_this()]() { + if (auto shared_this = weak_this.lock()) { + PLOG_DEBUG << "WebSocket close timeout"; + changeState(State::Disconnected); + } + }); +} + +void WsTransport::incoming(message_ptr message) { + auto s = state(); + if (s != State::Connecting && s != State::Connected) + return; // Drop + + if (message) { + PLOG_VERBOSE << "Incoming size=" << message->size(); + + try { + mBuffer.insert(mBuffer.end(), message->begin(), message->end()); + + if (state() == State::Connecting) { + if (mIsClient) { + if (size_t len = + mHandshake->parseHttpResponse(mBuffer.data(), mBuffer.size())) { + PLOG_INFO << "WebSocket client-side open"; + changeState(State::Connected); + mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len); + } + } else { + if (size_t len = mHandshake->parseHttpRequest(mBuffer.data(), mBuffer.size())) { + PLOG_INFO << "WebSocket server-side open"; + sendHttpResponse(); + changeState(State::Connected); + mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len); + } + } + } + + if (state() == State::Connected) { + if (message->size() == 0) { + // TCP is idle, send a ping + PLOG_DEBUG << "WebSocket sending ping"; + uint32_t dummy = 0; + sendFrame({PING, reinterpret_cast(&dummy), 4, true, mIsClient}); + addOutstandingPing(); + } else { + if (mIgnoreLength > 0) { + size_t len = std::min(mIgnoreLength, mBuffer.size()); + mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len); + mIgnoreLength -= len; + } + if (mIgnoreLength == 0) { + Frame frame; + while (size_t len = parseFrame(mBuffer.data(), mBuffer.size(), frame)) { + recvFrame(frame); + if (len > mBuffer.size()) { + mIgnoreLength = len - mBuffer.size(); + mBuffer.clear(); + break; + } + mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len); + } + } + } + } + + return; + + } catch (const WsHandshake::RequestError &e) { + PLOG_WARNING << e.what(); + try { + sendHttpError(e.responseCode()); + + } catch (const std::exception &e) { + PLOG_WARNING << e.what(); + } + + } catch (const WsHandshake::Error &e) { + PLOG_WARNING << e.what(); + + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + } + } + + if (state() == State::Connected) { + PLOG_INFO << "WebSocket disconnected"; + changeState(State::Disconnected); + recv(nullptr); + } else { + PLOG_ERROR << "WebSocket handshake failed"; + changeState(State::Failed); + } +} + +bool WsTransport::sendHttpRequest() { + PLOG_DEBUG << "Sending WebSocket HTTP request"; + + const string request = mHandshake->generateHttpRequest(); + auto data = reinterpret_cast(request.data()); + return outgoing(make_message(data, data + request.size())); +} + +bool WsTransport::sendHttpResponse() { + PLOG_DEBUG << "Sending WebSocket HTTP response"; + + const string response = mHandshake->generateHttpResponse(); + auto data = reinterpret_cast(response.data()); + return outgoing(make_message(data, data + response.size())); +} + +bool WsTransport::sendHttpError(int code) { + PLOG_WARNING << "Sending WebSocket HTTP error response " << code; + + const string response = mHandshake->generateHttpError(code); + auto data = reinterpret_cast(response.data()); + return outgoing(make_message(data, data + response.size())); +} + +// RFC6455 5.2. Base Framing Protocol +// https://www.rfc-editor.org/rfc/rfc6455.html#section-5.2 +// +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-------+-+-------------+-------------------------------+ +// |F|R|R|R| opcode|M| Payload len | Extended payload length | +// |I|S|S|S| (4) |A| (7) | (16/64) | +// |N|V|V|V| |S| | (if payload len==126/127) | +// | |1|2|3| |K| | | +// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + +// | Extended payload length continued, if payload len == 127 | +// + - - - - - - - - - - - - - - - +-------------------------------+ +// | | Masking-key, if MASK set to 1 | +// +-------------------------------+-------------------------------+ +// | Masking-key (continued) | Payload Data | +// +-------------------------------+ - - - - - - - - - - - - - - - + +// : Payload Data continued ... : +// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +// | Payload Data continued ... | +// +---------------------------------------------------------------+ + +size_t WsTransport::parseFrame(byte *buffer, size_t size, Frame &frame) { + const byte *end = buffer + size; + if (end - buffer < 2) + return 0; + + byte *cur = buffer; + auto b1 = to_integer(*cur++); + auto b2 = to_integer(*cur++); + + frame.fin = (b1 & 0x80) != 0; + frame.mask = (b2 & 0x80) != 0; + frame.opcode = static_cast(b1 & 0x0F); + frame.length = b2 & 0x7F; + + if (frame.length == 0x7E) { + if (end - cur < 2) + return 0; + frame.length = ntohs(*reinterpret_cast(cur)); + cur += 2; + } else if (frame.length == 0x7F) { + if (end - cur < 8) + return 0; + frame.length = ntohll(*reinterpret_cast(cur)); + cur += 8; + } + + const byte *maskingKey = nullptr; + if (frame.mask) { + if (end - cur < 4) + return 0; + maskingKey = cur; + cur += 4; + } + + const size_t maxControlFrameLength = 125; + const size_t maxFrameLength = std::max(maxControlFrameLength, mMaxMessageSize); + if (size_t(end - cur) < std::min(frame.length, maxFrameLength)) + return 0; + + size_t length = frame.length; + if (frame.length > maxFrameLength) { + PLOG_WARNING << "WebSocket frame is too large (length=" << frame.length + << "), truncating it"; + frame.length = maxFrameLength; + } + + frame.payload = cur; + + if (maskingKey) + for (size_t i = 0; i < frame.length; ++i) + frame.payload[i] ^= maskingKey[i % 4]; + + return frame.payload + length - buffer; // can be more than buffer size +} + +void WsTransport::recvFrame(const Frame &frame) { + PLOG_DEBUG << "WebSocket received frame: opcode=" << int(frame.opcode) + << ", length=" << frame.length; + + switch (frame.opcode) { + case TEXT_FRAME: + case BINARY_FRAME: { + size_t size = frame.length; + if (size > mMaxMessageSize) { + PLOG_WARNING << "WebSocket message is too large, truncating it"; + size = mMaxMessageSize; + } + if (!mPartial.empty()) { + PLOG_WARNING << "WebSocket unfinished message: type=" + << (mPartialOpcode == TEXT_FRAME ? "text" : "binary") + << ", size=" << mPartial.size(); + auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary; + recv(make_message(mPartial.begin(), mPartial.end(), type)); + mPartial.clear(); + } + mPartialOpcode = frame.opcode; + if (frame.fin) { + PLOG_DEBUG << "WebSocket finished message: type=" + << (frame.opcode == TEXT_FRAME ? "text" : "binary") << ", size=" << size; + auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary; + recv(make_message(frame.payload, frame.payload + size, type)); + } else { + mPartial.insert(mPartial.end(), frame.payload, frame.payload + size); + } + break; + } + case CONTINUATION: { + mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length); + if (mPartial.size() > mMaxMessageSize) { + PLOG_WARNING << "WebSocket message is too large, truncating it"; + mPartial.resize(mMaxMessageSize); + } + if (frame.fin) { + PLOG_DEBUG << "WebSocket finished message: type=" + << (frame.opcode == TEXT_FRAME ? "text" : "binary") + << ", size=" << mPartial.size(); + auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary; + recv(make_message(mPartial.begin(), mPartial.end(), type)); + mPartial.clear(); + } + break; + } + case PING: { + PLOG_DEBUG << "WebSocket received ping, sending pong"; + sendFrame({PONG, frame.payload, frame.length, true, mIsClient}); + break; + } + case PONG: { + PLOG_DEBUG << "WebSocket received pong"; + mOutstandingPings = 0; + break; + } + case CLOSE: { + PLOG_INFO << "WebSocket closed"; + close(); + changeState(State::Disconnected); + break; + } + default: { + PLOG_ERROR << "Unknown WebSocket opcode: " + to_string(frame.opcode); + close(); + break; + } + } +} + +bool WsTransport::sendFrame(const Frame &frame) { + std::lock_guard lock(mSendMutex); + + PLOG_DEBUG << "WebSocket sending frame: opcode=" << int(frame.opcode) + << ", length=" << frame.length; + + byte buffer[14]; + byte *cur = buffer; + + *cur++ = byte((frame.opcode & 0x0F) | (frame.fin ? 0x80 : 0)); + + if (frame.length < 0x7E) { + *cur++ = byte((frame.length & 0x7F) | (frame.mask ? 0x80 : 0)); + } else if (frame.length <= 0xFFFF) { + *cur++ = byte(0x7E | (frame.mask ? 0x80 : 0)); + *reinterpret_cast(cur) = htons(uint16_t(frame.length)); + cur += 2; + } else { + *cur++ = byte(0x7F | (frame.mask ? 0x80 : 0)); + *reinterpret_cast(cur) = htonll(uint64_t(frame.length)); + cur += 8; + } + + if (frame.mask) { + byte *maskingKey = reinterpret_cast(cur); + + auto u = reinterpret_cast(maskingKey); + std::generate(u, u + 4, utils::random_bytes_engine()); + cur += 4; + + for (size_t i = 0; i < frame.length; ++i) + frame.payload[i] ^= maskingKey[i % 4]; + } + + const size_t length = cur - buffer; // header length + auto message = make_message(length + frame.length); + std::copy(buffer, buffer + length, message->begin()); // header + std::copy(frame.payload, frame.payload + frame.length, + message->begin() + length); // payload + + return outgoing(std::move(message)); +} + +void WsTransport::addOutstandingPing() { + ++mOutstandingPings; + if (mMaxOutstandingPings > 0 && mOutstandingPings > mMaxOutstandingPings) { + changeState(State::Failed); + } +} + +} // namespace rtc::impl + +#endif diff --git a/datachannel/src/impl/wstransport.hpp b/datachannel/src/impl/wstransport.hpp new file mode 100644 index 000000000..5fee9dfe9 --- /dev/null +++ b/datachannel/src/impl/wstransport.hpp @@ -0,0 +1,91 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#ifndef RTC_IMPL_WS_TRANSPORT_H +#define RTC_IMPL_WS_TRANSPORT_H + +#include "common.hpp" +#include "transport.hpp" +#include "configuration.hpp" +#include "wshandshake.hpp" + +#if RTC_ENABLE_WEBSOCKET + +#include + +namespace rtc::impl { + +class HttpProxyTransport; +class TcpTransport; +class TlsTransport; + +class WsTransport final : public Transport, public std::enable_shared_from_this { +public: + using LowerTransport = + variant, shared_ptr, shared_ptr>; + + WsTransport(LowerTransport lower, shared_ptr handshake, + const WebSocketConfiguration &config, message_callback recvCallback, + state_callback stateCallback); + ~WsTransport(); + + void start() override; + void stop() override; + bool send(message_ptr message) override; + void close(); + void incoming(message_ptr message) override; + + bool isClient() const { return mIsClient; } + +private: + enum Opcode : uint8_t { + CONTINUATION = 0, + TEXT_FRAME = 1, + BINARY_FRAME = 2, + CLOSE = 8, + PING = 9, + PONG = 10, + }; + + struct Frame { + Opcode opcode = BINARY_FRAME; + byte *payload = nullptr; + size_t length = 0; + bool fin = true; + bool mask = true; + }; + + bool sendHttpRequest(); + bool sendHttpError(int code); + bool sendHttpResponse(); + + size_t parseFrame(byte *buffer, size_t size, Frame &frame); + void recvFrame(const Frame &frame); + bool sendFrame(const Frame &frame); + + void addOutstandingPing(); + + const shared_ptr mHandshake; + const bool mIsClient; + const size_t mMaxMessageSize; + const int mMaxOutstandingPings; + + binary mBuffer; + binary mPartial; + Opcode mPartialOpcode; + size_t mIgnoreLength = 0; + std::mutex mSendMutex; + int mOutstandingPings = 0; + std::atomic mCloseSent = false; +}; + +} // namespace rtc::impl + +#endif + +#endif diff --git a/datachannel/src/mediahandler.cpp b/datachannel/src/mediahandler.cpp new file mode 100644 index 000000000..1b956f1e4 --- /dev/null +++ b/datachannel/src/mediahandler.cpp @@ -0,0 +1,80 @@ +/** + * Copyright (c) 2023 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "mediahandler.hpp" + +#include "impl/internals.hpp" + +namespace rtc { + +MediaHandler::MediaHandler() {} + +MediaHandler::~MediaHandler() {} + +void MediaHandler::addToChain(shared_ptr handler) { last()->setNext(handler); } + +void MediaHandler::setNext(shared_ptr handler) { + return std::atomic_store(&mNext, handler); +} + +shared_ptr MediaHandler::next() { return std::atomic_load(&mNext); } + +shared_ptr MediaHandler::next() const { return std::atomic_load(&mNext); } + +shared_ptr MediaHandler::last() { + if (auto handler = next()) + return handler->last(); + else + return shared_from_this(); +} + +shared_ptr MediaHandler::last() const { + if (auto handler = next()) + return handler->last(); + else + return shared_from_this(); +} + +bool MediaHandler::requestKeyframe(const message_callback &send) { + // Default implementation is to call next handler + if (auto handler = next()) + return handler->requestKeyframe(send); + else + return false; +} + +bool MediaHandler::requestBitrate(unsigned int bitrate, const message_callback &send) { + // Default implementation is to call next handler + if (auto handler = next()) + return handler->requestBitrate(bitrate, send); + else + return false; +} + +void MediaHandler::mediaChain(const Description::Media &desc) { + media(desc); + + if (auto handler = next()) + handler->mediaChain(desc); +} + +void MediaHandler::incomingChain(message_vector &messages, const message_callback &send) { + if (auto handler = next()) + handler->incomingChain(messages, send); + + incoming(messages, send); +} + +void MediaHandler::outgoingChain(message_vector &messages, const message_callback &send) { + outgoing(messages, send); + + if (auto handler = next()) + return handler->outgoingChain(messages, send); +} + +} // namespace rtc diff --git a/datachannel/src/message.cpp b/datachannel/src/message.cpp new file mode 100644 index 000000000..c2e582a35 --- /dev/null +++ b/datachannel/src/message.cpp @@ -0,0 +1,79 @@ +/** + * Copyright (c) 2019-2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "message.hpp" + +namespace rtc { + +message_ptr make_message(size_t size, Message::Type type, unsigned int stream, + shared_ptr reliability) { + auto message = std::make_shared(size, type); + message->stream = stream; + message->reliability = reliability; + return message; +} + +message_ptr make_message(binary &&data, Message::Type type, unsigned int stream, + shared_ptr reliability) { + auto message = std::make_shared(std::move(data), type); + message->stream = stream; + message->reliability = reliability; + return message; +} + +message_ptr make_message(size_t size, message_ptr orig) { + if(!orig) + return nullptr; + + auto message = std::make_shared(size, orig->type); + std::copy(orig->begin(), orig->begin() + std::min(size, orig->size()), message->begin()); + message->stream = orig->stream; + message->reliability = orig->reliability; + return message; +} + +message_ptr make_message(message_variant data) { + return std::visit( // + overloaded{ + [&](binary data) { return make_message(std::move(data), Message::Binary); }, + [&](string data) { + auto b = reinterpret_cast(data.data()); + return make_message(b, b + data.size(), Message::String); + }, + }, + std::move(data)); +} + +#if RTC_ENABLE_MEDIA + +message_ptr make_message_from_opaque_ptr(rtcMessage *&&message) { + auto ptr = std::unique_ptr(reinterpret_cast(message)); + return message_ptr(std::move(ptr)); +} + +#endif + +message_variant to_variant(Message &&message) { + switch (message.type) { + case Message::String: + return string(reinterpret_cast(message.data()), message.size()); + default: + return std::move(message); + } +} + +message_variant to_variant(const Message &message) { + switch (message.type) { + case Message::String: + return string(reinterpret_cast(message.data()), message.size()); + default: + return message; + } +} + +} // namespace rtc diff --git a/datachannel/src/nalunit.cpp b/datachannel/src/nalunit.cpp new file mode 100644 index 000000000..64741c1b4 --- /dev/null +++ b/datachannel/src/nalunit.cpp @@ -0,0 +1,99 @@ +/** + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#if RTC_ENABLE_MEDIA + +#include "nalunit.hpp" + +#include "impl/internals.hpp" + +#include + +namespace rtc { + +NalUnitFragmentA::NalUnitFragmentA(FragmentType type, bool forbiddenBit, uint8_t nri, + uint8_t unitType, binary data) + : NalUnit(data.size() + 2) { + setForbiddenBit(forbiddenBit); + setNRI(nri); + fragmentIndicator()->setUnitType(NalUnitFragmentA::nal_type_fu_A); + setFragmentType(type); + setUnitType(unitType); + copy(data.begin(), data.end(), begin() + 2); +} + +std::vector> +NalUnitFragmentA::fragmentsFrom(shared_ptr nalu, uint16_t maxFragmentSize) { + assert(nalu->size() > maxFragmentSize); + auto fragments_count = ceil(double(nalu->size()) / maxFragmentSize); + maxFragmentSize = uint16_t(int(ceil(nalu->size() / fragments_count))); + + // 2 bytes for FU indicator and FU header + maxFragmentSize -= 2; + auto f = nalu->forbiddenBit(); + uint8_t nri = nalu->nri() & 0x03; + uint8_t naluType = nalu->unitType() & 0x1F; + auto payload = nalu->payload(); + vector> result{}; + uint64_t offset = 0; + while (offset < payload.size()) { + vector fragmentData; + FragmentType fragmentType; + if (offset == 0) { + fragmentType = FragmentType::Start; + } else if (offset + maxFragmentSize < payload.size()) { + fragmentType = FragmentType::Middle; + } else { + if (offset + maxFragmentSize > payload.size()) { + maxFragmentSize = uint16_t(payload.size() - offset); + } + fragmentType = FragmentType::End; + } + fragmentData = {payload.begin() + offset, payload.begin() + offset + maxFragmentSize}; + auto fragment = + std::make_shared(fragmentType, f, nri, naluType, fragmentData); + result.push_back(fragment); + offset += maxFragmentSize; + } + return result; +} + +void NalUnitFragmentA::setFragmentType(FragmentType type) { + fragmentHeader()->setReservedBit6(false); + switch (type) { + case FragmentType::Start: + fragmentHeader()->setStart(true); + fragmentHeader()->setEnd(false); + break; + case FragmentType::End: + fragmentHeader()->setStart(false); + fragmentHeader()->setEnd(true); + break; + default: + fragmentHeader()->setStart(false); + fragmentHeader()->setEnd(false); + } +} + +std::vector> NalUnits::generateFragments(uint16_t maxFragmentSize) { + vector> result{}; + for (auto nalu : *this) { + if (nalu->size() > maxFragmentSize) { + std::vector> fragments = + NalUnitFragmentA::fragmentsFrom(nalu, maxFragmentSize); + result.insert(result.end(), fragments.begin(), fragments.end()); + } else { + result.push_back(nalu); + } + } + return result; +} + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ diff --git a/datachannel/src/peerconnection.cpp b/datachannel/src/peerconnection.cpp new file mode 100644 index 000000000..dd1ec7c82 --- /dev/null +++ b/datachannel/src/peerconnection.cpp @@ -0,0 +1,464 @@ +/** + * Copyright (c) 2019 Paul-Louis Ageneau + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "peerconnection.hpp" +#include "common.hpp" +#include "rtp.hpp" + +#include "impl/certificate.hpp" +#include "impl/dtlstransport.hpp" +#include "impl/icetransport.hpp" +#include "impl/internals.hpp" +#include "impl/peerconnection.hpp" +#include "impl/sctptransport.hpp" +#include "impl/threadpool.hpp" +#include "impl/track.hpp" + +#if RTC_ENABLE_MEDIA +#include "impl/dtlssrtptransport.hpp" +#endif + +#include +#include +#include + +using namespace std::placeholders; + +namespace rtc { + +PeerConnection::PeerConnection() : PeerConnection(Configuration()) {} + +PeerConnection::PeerConnection(Configuration config) + : CheshireCat(std::move(config)) {} + +PeerConnection::~PeerConnection() { + try { + impl()->remoteClose(); + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + } +} + +void PeerConnection::close() { impl()->close(); } + +const Configuration *PeerConnection::config() const { return &impl()->config; } + +PeerConnection::State PeerConnection::state() const { return impl()->state; } + +PeerConnection::IceState PeerConnection::iceState() const { return impl()->iceState; } + +PeerConnection::GatheringState PeerConnection::gatheringState() const { + return impl()->gatheringState; +} + +PeerConnection::SignalingState PeerConnection::signalingState() const { + return impl()->signalingState; +} + +optional PeerConnection::localDescription() const { + return impl()->localDescription(); +} + +optional PeerConnection::remoteDescription() const { + return impl()->remoteDescription(); +} + +size_t PeerConnection::remoteMaxMessageSize() const { return impl()->remoteMaxMessageSize(); } + +bool PeerConnection::hasMedia() const { + auto local = localDescription(); + return local && local->hasAudioOrVideo(); +} + +void PeerConnection::setLocalDescription(Description::Type type) { + std::unique_lock signalingLock(impl()->signalingMutex); + PLOG_VERBOSE << "Setting local description, type=" << Description::typeToString(type); + + SignalingState signalingState = impl()->signalingState.load(); + if (type == Description::Type::Rollback) { + if (signalingState == SignalingState::HaveLocalOffer || + signalingState == SignalingState::HaveLocalPranswer) { + impl()->rollbackLocalDescription(); + impl()->changeSignalingState(SignalingState::Stable); + } + return; + } + + // Guess the description type if unspecified + if (type == Description::Type::Unspec) { + if (signalingState == SignalingState::HaveRemoteOffer) + type = Description::Type::Answer; + else + type = Description::Type::Offer; + } + + // Only a local offer resets the negotiation needed flag + if (type == Description::Type::Offer && !impl()->negotiationNeeded.exchange(false)) { + PLOG_DEBUG << "No negotiation needed"; + return; + } + + // Get the new signaling state + SignalingState newSignalingState; + switch (signalingState) { + case SignalingState::Stable: + if (type != Description::Type::Offer) { + std::ostringstream oss; + oss << "Unexpected local desciption type " << type << " in signaling state " + << signalingState; + throw std::logic_error(oss.str()); + } + newSignalingState = SignalingState::HaveLocalOffer; + break; + + case SignalingState::HaveRemoteOffer: + case SignalingState::HaveLocalPranswer: + if (type != Description::Type::Answer && type != Description::Type::Pranswer) { + std::ostringstream oss; + oss << "Unexpected local description type " << type + << " description in signaling state " << signalingState; + throw std::logic_error(oss.str()); + } + newSignalingState = SignalingState::Stable; + break; + + default: { + std::ostringstream oss; + oss << "Unexpected local description in signaling state " << signalingState << ", ignoring"; + LOG_WARNING << oss.str(); + return; + } + } + + auto iceTransport = impl()->initIceTransport(); + if (!iceTransport) + return; // closed + + Description local = iceTransport->getLocalDescription(type); + impl()->processLocalDescription(std::move(local)); + + impl()->changeSignalingState(newSignalingState); + signalingLock.unlock(); + + if (impl()->gatheringState == GatheringState::New) { + iceTransport->gatherLocalCandidates(impl()->localBundleMid()); + } +} + +void PeerConnection::setRemoteDescription(Description description) { + std::unique_lock signalingLock(impl()->signalingMutex); + PLOG_VERBOSE << "Setting remote description: " << string(description); + + if (description.type() == Description::Type::Rollback) { + // This is mostly useless because we accept any offer + PLOG_VERBOSE << "Rolling back pending remote description"; + impl()->changeSignalingState(SignalingState::Stable); + return; + } + + impl()->validateRemoteDescription(description); + + // Get the new signaling state + SignalingState signalingState = impl()->signalingState.load(); + SignalingState newSignalingState; + switch (signalingState) { + case SignalingState::Stable: + description.hintType(Description::Type::Offer); + if (description.type() != Description::Type::Offer) { + std::ostringstream oss; + oss << "Unexpected remote " << description.type() << " description in signaling state " + << signalingState; + throw std::logic_error(oss.str()); + } + newSignalingState = SignalingState::HaveRemoteOffer; + break; + + case SignalingState::HaveLocalOffer: + description.hintType(Description::Type::Answer); + if (description.type() == Description::Type::Offer) { + // The ICE agent will automatically initiate a rollback when a peer that had previously + // created an offer receives an offer from the remote peer + impl()->rollbackLocalDescription(); + impl()->changeSignalingState(SignalingState::Stable); + signalingState = SignalingState::Stable; + newSignalingState = SignalingState::HaveRemoteOffer; + break; + } + if (description.type() != Description::Type::Answer && + description.type() != Description::Type::Pranswer) { + std::ostringstream oss; + oss << "Unexpected remote " << description.type() << " description in signaling state " + << signalingState; + throw std::logic_error(oss.str()); + } + newSignalingState = SignalingState::Stable; + break; + + case SignalingState::HaveRemotePranswer: + description.hintType(Description::Type::Answer); + if (description.type() != Description::Type::Answer && + description.type() != Description::Type::Pranswer) { + std::ostringstream oss; + oss << "Unexpected remote " << description.type() << " description in signaling state " + << signalingState; + throw std::logic_error(oss.str()); + } + newSignalingState = SignalingState::Stable; + break; + + default: { + std::ostringstream oss; + oss << "Unexpected remote description in signaling state " << signalingState; + throw std::logic_error(oss.str()); + } + } + + // Candidates will be added at the end, extract them for now + auto remoteCandidates = description.extractCandidates(); + auto type = description.type(); + + auto iceTransport = impl()->initIceTransport(); + if (!iceTransport) + return; // closed + + iceTransport->setRemoteDescription(description); // ICE transport might reject the description + + impl()->processRemoteDescription(std::move(description)); + impl()->changeSignalingState(newSignalingState); + signalingLock.unlock(); + + if (type == Description::Type::Offer) { + // This is an offer, we need to answer + if (!impl()->config.disableAutoNegotiation) + setLocalDescription(Description::Type::Answer); + } + + for (const auto &candidate : remoteCandidates) + addRemoteCandidate(candidate); +} + +void PeerConnection::addRemoteCandidate(Candidate candidate) { + std::unique_lock signalingLock(impl()->signalingMutex); + PLOG_VERBOSE << "Adding remote candidate: " << string(candidate); + impl()->processRemoteCandidate(std::move(candidate)); +} + +void PeerConnection::setMediaHandler(shared_ptr handler) { + impl()->setMediaHandler(std::move(handler)); +}; + +shared_ptr PeerConnection::getMediaHandler() { return impl()->getMediaHandler(); }; + +optional PeerConnection::localAddress() const { + auto iceTransport = impl()->getIceTransport(); + return iceTransport ? iceTransport->getLocalAddress() : nullopt; +} + +optional PeerConnection::remoteAddress() const { + auto iceTransport = impl()->getIceTransport(); + return iceTransport ? iceTransport->getRemoteAddress() : nullopt; +} + +uint16_t PeerConnection::maxDataChannelId() const { return impl()->maxDataChannelStream(); } + +shared_ptr PeerConnection::createDataChannel(string label, DataChannelInit init) { + auto channelImpl = impl()->emplaceDataChannel(std::move(label), std::move(init)); + auto channel = std::make_shared(channelImpl); + + // Renegotiation is needed iff the current local description does not have application + auto local = impl()->localDescription(); + if (!local || !local->hasApplication()) + impl()->negotiationNeeded = true; + + if (!impl()->config.disableAutoNegotiation) + setLocalDescription(); + + return channel; +} + +void PeerConnection::onDataChannel( + std::function dataChannel)> callback) { + impl()->dataChannelCallback = callback; + impl()->flushPendingDataChannels(); +} + +std::shared_ptr PeerConnection::addTrack(Description::Media description) { + auto trackImpl = impl()->emplaceTrack(std::move(description)); + auto track = std::make_shared(trackImpl); + + // Renegotiation is needed for the new or updated track + impl()->negotiationNeeded = true; + + return track; +} + +void PeerConnection::onTrack(std::function)> callback) { + impl()->trackCallback = callback; + impl()->flushPendingTracks(); +} + +void PeerConnection::onLocalDescription(std::function callback) { + impl()->localDescriptionCallback = callback; +} + +void PeerConnection::onLocalCandidate(std::function callback) { + impl()->localCandidateCallback = callback; +} + +void PeerConnection::onStateChange(std::function callback) { + impl()->stateChangeCallback = callback; +} + +void PeerConnection::onIceStateChange(std::function callback) { + impl()->iceStateChangeCallback = callback; +} + +void PeerConnection::onGatheringStateChange(std::function callback) { + impl()->gatheringStateChangeCallback = callback; +} + +void PeerConnection::onSignalingStateChange(std::function callback) { + impl()->signalingStateChangeCallback = callback; +} + +void PeerConnection::resetCallbacks() { impl()->resetCallbacks(); } + +bool PeerConnection::getSelectedCandidatePair(Candidate *local, Candidate *remote) { + auto iceTransport = impl()->getIceTransport(); + return iceTransport ? iceTransport->getSelectedCandidatePair(local, remote) : false; +} + +void PeerConnection::clearStats() { + if (auto sctpTransport = impl()->getSctpTransport()) + return sctpTransport->clearStats(); +} + +size_t PeerConnection::bytesSent() { + auto sctpTransport = impl()->getSctpTransport(); + return sctpTransport ? sctpTransport->bytesSent() : 0; +} + +size_t PeerConnection::bytesReceived() { + auto sctpTransport = impl()->getSctpTransport(); + return sctpTransport ? sctpTransport->bytesReceived() : 0; +} + +optional PeerConnection::rtt() { + auto sctpTransport = impl()->getSctpTransport(); + return sctpTransport ? sctpTransport->rtt() : nullopt; +} + +std::ostream &operator<<(std::ostream &out, PeerConnection::State state) { + using State = PeerConnection::State; + const char *str; + switch (state) { + case State::New: + str = "new"; + break; + case State::Connecting: + str = "connecting"; + break; + case State::Connected: + str = "connected"; + break; + case State::Disconnected: + str = "disconnected"; + break; + case State::Failed: + str = "failed"; + break; + case State::Closed: + str = "closed"; + break; + default: + str = "unknown"; + break; + } + return out << str; +} + +std::ostream &operator<<(std::ostream &out, PeerConnection::IceState state) { + using IceState = PeerConnection::IceState; + const char *str; + switch (state) { + case IceState::New: + str = "new"; + break; + case IceState::Checking: + str = "checking"; + break; + case IceState::Connected: + str = "connected"; + break; + case IceState::Completed: + str = "completed"; + break; + case IceState::Failed: + str = "failed"; + break; + case IceState::Disconnected: + str = "disconnected"; + break; + case IceState::Closed: + str = "closed"; + break; + default: + str = "unknown"; + break; + } + return out << str; +} + +std::ostream &operator<<(std::ostream &out, PeerConnection::GatheringState state) { + using GatheringState = PeerConnection::GatheringState; + const char *str; + switch (state) { + case GatheringState::New: + str = "new"; + break; + case GatheringState::InProgress: + str = "in-progress"; + break; + case GatheringState::Complete: + str = "complete"; + break; + default: + str = "unknown"; + break; + } + return out << str; +} + +std::ostream &operator<<(std::ostream &out, PeerConnection::SignalingState state) { + using SignalingState = PeerConnection::SignalingState; + const char *str; + switch (state) { + case SignalingState::Stable: + str = "stable"; + break; + case SignalingState::HaveLocalOffer: + str = "have-local-offer"; + break; + case SignalingState::HaveRemoteOffer: + str = "have-remote-offer"; + break; + case SignalingState::HaveLocalPranswer: + str = "have-local-pranswer"; + break; + case SignalingState::HaveRemotePranswer: + str = "have-remote-pranswer"; + break; + default: + str = "unknown"; + break; + } + return out << str; +} + +} // namespace rtc diff --git a/datachannel/src/plihandler.cpp b/datachannel/src/plihandler.cpp new file mode 100644 index 000000000..b4cb4dedc --- /dev/null +++ b/datachannel/src/plihandler.cpp @@ -0,0 +1,45 @@ +/** + * Copyright (c) 2023 Arda Cinar + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "plihandler.hpp" +#include "rtp.hpp" + +#if RTC_ENABLE_MEDIA + +namespace rtc { + +PliHandler::PliHandler(std::function onPli) : mOnPli(onPli) {} + +void PliHandler::incoming(message_vector &messages, [[maybe_unused]] const message_callback &send) { + for (const auto &message : messages) { + size_t offset = 0; + while ((sizeof(RtcpHeader) + offset) <= message->size()) { + auto header = reinterpret_cast(message->data() + offset); + uint8_t payload_type = header->payloadType(); + + if (payload_type == 196) { + // FIR message, call pli handler anyway + mOnPli(); + break; + } else if (payload_type == 206) { + // On a payload specific fb message, there is a "feedback message type" (FMT) in the + // header instead of a report count. PT = 206, FMT = 1 means a PLI message + uint8_t feedback_message_type = header->reportCount(); + if (feedback_message_type == 1) { + mOnPli(); + break; + } + } + offset += header->lengthInBytes(); + } + } +} + +} // namespace rtc + +#endif // RTC_ENABLE_MEDIA diff --git a/datachannel/src/rtcpnackresponder.cpp b/datachannel/src/rtcpnackresponder.cpp new file mode 100644 index 000000000..88d1face2 --- /dev/null +++ b/datachannel/src/rtcpnackresponder.cpp @@ -0,0 +1,114 @@ +/** + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#if RTC_ENABLE_MEDIA + +#include "rtcpnackresponder.hpp" +#include "rtp.hpp" + +#include "impl/internals.hpp" + +#include + +namespace rtc { + +RtcpNackResponder::RtcpNackResponder(size_t maxSize) + : mStorage(std::make_shared(maxSize)) {} + +void RtcpNackResponder::incoming(message_vector &messages, const message_callback &send) { + for (const auto &message : messages) { + if (message->type != Message::Control) + continue; + + size_t p = 0; + while (p + sizeof(RtcpNack) <= message->size()) { + auto nack = reinterpret_cast(message->data() + p); + p += nack->header.header.lengthInBytes(); + if (p > message->size()) + break; + + // check if RTCP is NACK + if (nack->header.header.payloadType() != 205 || nack->header.header.reportCount() != 1) + continue; + + unsigned int fieldsCount = nack->getSeqNoCount(); + std::vector missingSequenceNumbers; + for (unsigned int i = 0; i < fieldsCount; i++) { + auto field = nack->parts[i]; + auto newMissingSeqenceNumbers = field.getSequenceNumbers(); + missingSequenceNumbers.insert(missingSequenceNumbers.end(), + newMissingSeqenceNumbers.begin(), + newMissingSeqenceNumbers.end()); + } + + for (auto sequenceNumber : missingSequenceNumbers) { + if (auto optPacket = mStorage->get(sequenceNumber)) + send(make_message(*optPacket.value())); + } + } + } +} + +void RtcpNackResponder::outgoing(message_vector &messages, + [[maybe_unused]] const message_callback &send) { + for (const auto &message : messages) + if (message->type != Message::Control) + mStorage->store(message); +} + +RtcpNackResponder::Storage::Element::Element(binary_ptr packet, uint16_t sequenceNumber, + shared_ptr next) + : packet(packet), sequenceNumber(sequenceNumber), next(next) {} + +size_t RtcpNackResponder::Storage::size() { return storage.size(); } + +RtcpNackResponder::Storage::Storage(size_t _maxSize) : maxSize(_maxSize) { + assert(maxSize > 0); + storage.reserve(maxSize); +} + +optional RtcpNackResponder::Storage::get(uint16_t sequenceNumber) { + std::lock_guard lock(mutex); + auto position = storage.find(sequenceNumber); + return position != storage.end() ? std::make_optional(storage.at(sequenceNumber)->packet) + : nullopt; +} + +void RtcpNackResponder::Storage::store(binary_ptr packet) { + if (!packet || packet->size() < sizeof(RtpHeader)) + return; + + auto rtp = reinterpret_cast(packet->data()); + auto sequenceNumber = rtp->seqNumber(); + + std::lock_guard lock(mutex); + assert((storage.empty() && !oldest && !newest) || (!storage.empty() && oldest && newest)); + + if (size() == 0) { + newest = std::make_shared(packet, sequenceNumber); + oldest = newest; + } else { + auto current = std::make_shared(packet, sequenceNumber); + newest->next = current; + newest = current; + } + + storage.emplace(sequenceNumber, newest); + + if (size() > maxSize) { + assert(oldest); + if (oldest) { + storage.erase(oldest->sequenceNumber); + oldest = oldest->next; + } + } +} + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ diff --git a/datachannel/src/rtcpreceivingsession.cpp b/datachannel/src/rtcpreceivingsession.cpp new file mode 100644 index 000000000..7fc5a977c --- /dev/null +++ b/datachannel/src/rtcpreceivingsession.cpp @@ -0,0 +1,133 @@ +/** + * Copyright (c) 2020 Staz Modrzynski + * Copyright (c) 2020 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#if RTC_ENABLE_MEDIA + +#include "rtcpreceivingsession.hpp" +#include "track.hpp" + +#include "impl/logcounter.hpp" + +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#endif + +namespace rtc { + +static impl::LogCounter COUNTER_BAD_RTP_HEADER(plog::warning, "Number of malformed RTP headers"); +static impl::LogCounter COUNTER_UNKNOWN_PPID(plog::warning, "Number of Unknown PPID messages"); +static impl::LogCounter COUNTER_BAD_NOTIF_LEN(plog::warning, + "Number of Bad-Lengthed notifications"); +static impl::LogCounter COUNTER_BAD_SCTP_STATUS(plog::warning, + "Number of unknown SCTP_STATUS errors"); + +void RtcpReceivingSession::incoming(message_vector &messages, const message_callback &send) { + message_vector result; + for (auto message : messages) { + switch (message->type) { + case Message::Binary: { + if (message->size() < sizeof(RtpHeader)) { + COUNTER_BAD_RTP_HEADER++; + PLOG_VERBOSE << "RTP packet is too small, size=" << message->size(); + continue; + } + + auto rtp = reinterpret_cast(message->data()); + + // https://www.rfc-editor.org/rfc/rfc3550.html#appendix-A.1 + if (rtp->version() != 2) { + COUNTER_BAD_RTP_HEADER++; + PLOG_VERBOSE << "RTP packet is not version 2"; + continue; + } + + if (rtp->payloadType() == 201 || rtp->payloadType() == 200) { + COUNTER_BAD_RTP_HEADER++; + PLOG_VERBOSE << "RTP packet has a payload type indicating RR/SR"; + continue; + } + + mSsrc = rtp->ssrc(); + result.push_back(std::move(message)); + break; + } + + case Message::Control: { + auto rr = reinterpret_cast(message->data()); + if (rr->header.payloadType() == 201) { // RR + mSsrc = rr->senderSSRC(); + rr->log(); + } else if (rr->header.payloadType() == 200) { // SR + mSsrc = rr->senderSSRC(); + auto sr = reinterpret_cast(message->data()); + mSyncRTPTS = sr->rtpTimestamp(); + mSyncNTPTS = sr->ntpTimestamp(); + sr->log(); + + // TODO For the time being, we will send RR's/REMB's when we get an SR + pushRR(send, 0); + if (unsigned int bitrate = mRequestedBitrate.load(); bitrate > 0) + pushREMB(send, bitrate); + } + break; + } + + default: + break; + } + } + + messages.swap(result); +} + +bool RtcpReceivingSession::requestBitrate(unsigned int bitrate, const message_callback &send) { + PLOG_DEBUG << "Requesting bitrate: " << bitrate << std::endl; + mRequestedBitrate.store(bitrate); + pushREMB(send, bitrate); + return true; +} + +void RtcpReceivingSession::pushREMB(const message_callback &send, unsigned int bitrate) { + auto message = make_message(RtcpRemb::SizeWithSSRCs(1), Message::Control); + auto remb = reinterpret_cast(message->data()); + remb->preparePacket(mSsrc, 1, bitrate); + remb->setSsrc(0, mSsrc); + send(message); +} + +void RtcpReceivingSession::pushRR(const message_callback &send, unsigned int lastSrDelay) { + auto message = make_message(RtcpRr::SizeWithReportBlocks(1), Message::Control); + auto rr = reinterpret_cast(message->data()); + rr->preparePacket(mSsrc, 1); + rr->getReportBlock(0)->preparePacket(mSsrc, 0, 0, uint16_t(mGreatestSeqNo), 0, 0, mSyncNTPTS, + lastSrDelay); + rr->log(); + send(message); +} + +bool RtcpReceivingSession::requestKeyframe(const message_callback &send) { + pushPLI(send); + return true; +} + +void RtcpReceivingSession::pushPLI(const message_callback &send) { + auto message = make_message(RtcpPli::Size(), Message::Control); + auto *pli = reinterpret_cast(message->data()); + pli->preparePacket(mSsrc); + send(message); +} + +} // namespace rtc + +#endif // RTC_ENABLE_MEDIA diff --git a/datachannel/src/rtcpsrreporter.cpp b/datachannel/src/rtcpsrreporter.cpp new file mode 100644 index 000000000..a9d79b5a0 --- /dev/null +++ b/datachannel/src/rtcpsrreporter.cpp @@ -0,0 +1,90 @@ +/** + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#if RTC_ENABLE_MEDIA + +#include "rtcpsrreporter.hpp" + +#include +#include +#include + +namespace { + +// TODO: move to utils +uint64_t ntp_time() { + const auto now = std::chrono::system_clock::now(); + const double secs = std::chrono::duration(now.time_since_epoch()).count(); + // Assume the epoch is 01/01/1970 and adds the number of seconds between 1900 and 1970 + return uint64_t(std::floor((secs + 2208988800.) * double(uint64_t(1) << 32))); +} + +} // namespace + +namespace rtc { + +RtcpSrReporter::RtcpSrReporter(shared_ptr rtpConfig) + : rtpConfig(rtpConfig) { + mLastReportedTimestamp = rtpConfig->timestamp; +} + +void RtcpSrReporter::setNeedsToReport() { mNeedsToReport = true; } + +uint32_t RtcpSrReporter::lastReportedTimestamp() const { return mLastReportedTimestamp; } + +void RtcpSrReporter::outgoing(message_vector &messages, const message_callback &send) { + for (const auto &message : messages) { + if (message->type == Message::Control) + continue; + + if (message->size() < sizeof(RtpHeader)) + continue; + + auto rtp = reinterpret_cast(message->data()); + addToReport(rtp, uint32_t(message->size())); + } + + if (std::exchange(mNeedsToReport, false)) { + auto timestamp = rtpConfig->timestamp; + auto sr = getSenderReport(timestamp); + send(sr); + } +} + +void RtcpSrReporter::addToReport(RtpHeader *rtp, uint32_t rtpSize) { + mPacketCount += 1; + assert(!rtp->padding()); + mPayloadOctets += rtpSize - uint32_t(rtp->getSize()); +} + +message_ptr RtcpSrReporter::getSenderReport(uint32_t timestamp) { + auto srSize = RtcpSr::Size(0); + auto msg = make_message(srSize + RtcpSdes::Size({{uint8_t(rtpConfig->cname.size())}}), + Message::Control); + auto sr = reinterpret_cast(msg->data()); + sr->setNtpTimestamp(ntp_time()); + sr->setRtpTimestamp(timestamp); + sr->setPacketCount(mPacketCount); + sr->setOctetCount(mPayloadOctets); + sr->preparePacket(rtpConfig->ssrc, 0); + + auto sdes = reinterpret_cast(msg->data() + srSize); + auto chunk = sdes->getChunk(0); + chunk->setSSRC(rtpConfig->ssrc); + auto item = chunk->getItem(0); + item->type = 1; + item->setText(rtpConfig->cname); + sdes->preparePacket(1); + + mLastReportedTimestamp = timestamp; + return msg; +} + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ diff --git a/datachannel/src/rtp.cpp b/datachannel/src/rtp.cpp new file mode 100644 index 000000000..4038111e3 --- /dev/null +++ b/datachannel/src/rtp.cpp @@ -0,0 +1,663 @@ +/** + * Copyright (c) 2020 Staz Modrzynski + * Copyright (c) 2020 Paul-Louis Ageneau + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "rtp.hpp" + +#include "impl/internals.hpp" + +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#endif + +#ifndef htonll +#define htonll(x) \ + ((uint64_t)(((uint64_t)htonl((uint32_t)(x))) << 32) | (uint64_t)htonl((uint32_t)((x) >> 32))) +#endif +#ifndef ntohll +#define ntohll(x) htonll(x) +#endif + +namespace rtc { + +bool IsRtcp(const binary &data) { + if (data.size() < 8) + return false; + + uint8_t payloadType = std::to_integer(data[1]) & 0x7F; + PLOG_VERBOSE << "Demultiplexing RTCP and RTP with payload type, value=" << int(payloadType); + + // RFC 5761 Multiplexing RTP and RTCP 4. Distinguishable RTP and RTCP Packets + // https://www.rfc-editor.org/rfc/rfc5761.html#section-4 + // It is RECOMMENDED to follow the guidelines in the RTP/AVP profile for the choice of RTP + // payload type values, with the additional restriction that payload type values in the + // range 64-95 MUST NOT be used. Specifically, dynamic RTP payload types SHOULD be chosen in + // the range 96-127 where possible. Values below 64 MAY be used if that is insufficient + // [...] + return (payloadType >= 64 && payloadType <= 95); // Range 64-95 (inclusive) MUST be RTCP +} + +uint8_t RtpHeader::version() const { return _first >> 6; } +bool RtpHeader::padding() const { return (_first >> 5) & 0x01; } +bool RtpHeader::extension() const { return (_first >> 4) & 0x01; } +uint8_t RtpHeader::csrcCount() const { return _first & 0x0F; } +uint8_t RtpHeader::marker() const { return _payloadType & 0b10000000; } +uint8_t RtpHeader::payloadType() const { return _payloadType & 0b01111111; } +uint16_t RtpHeader::seqNumber() const { return ntohs(_seqNumber); } +uint32_t RtpHeader::timestamp() const { return ntohl(_timestamp); } +uint32_t RtpHeader::ssrc() const { return ntohl(_ssrc); } + +size_t RtpHeader::getSize() const { + return reinterpret_cast(&_ssrc + 1 + csrcCount()) - + reinterpret_cast(this); +} + +size_t RtpHeader::getExtensionHeaderSize() const { + auto header = getExtensionHeader(); + return header ? header->getSize() + sizeof(RtpExtensionHeader) : 0; +} + +const RtpExtensionHeader *RtpHeader::getExtensionHeader() const { + return extension() ? reinterpret_cast(&_ssrc + 1 + csrcCount()) + : nullptr; +} + +RtpExtensionHeader *RtpHeader::getExtensionHeader() { + return extension() ? reinterpret_cast(&_ssrc + 1 + csrcCount()) : nullptr; +} + +const char *RtpHeader::getBody() const { + return reinterpret_cast(&_ssrc + 1 + csrcCount()) + getExtensionHeaderSize(); +} + +char *RtpHeader::getBody() { + return reinterpret_cast(&_ssrc + 1 + csrcCount()) + getExtensionHeaderSize(); +} + +void RtpHeader::preparePacket() { _first |= (1 << 7); } + +void RtpHeader::setSeqNumber(uint16_t newSeqNo) { _seqNumber = htons(newSeqNo); } + +void RtpHeader::setPayloadType(uint8_t newPayloadType) { + _payloadType = (_payloadType & 0b10000000u) | (0b01111111u & newPayloadType); +} + +void RtpHeader::setSsrc(uint32_t in_ssrc) { _ssrc = htonl(in_ssrc); } + +void RtpHeader::setMarker(bool marker) { _payloadType = (_payloadType & 0x7F) | (marker << 7); }; + +void RtpHeader::setTimestamp(uint32_t i) { _timestamp = htonl(i); } + +void RtpHeader::setExtension(bool extension) { _first = (_first & ~0x10) | ((extension & 1) << 4); } + +void RtpHeader::log() const { + PLOG_VERBOSE << "RtpHeader V: " << (int)version() << " P: " << (padding() ? "P" : " ") + << " X: " << (extension() ? "X" : " ") << " CC: " << (int)csrcCount() + << " M: " << (marker() ? "M" : " ") << " PT: " << (int)payloadType() + << " SEQNO: " << seqNumber() << " TS: " << timestamp(); +} + +uint16_t RtpExtensionHeader::profileSpecificId() const { return ntohs(_profileSpecificId); } + +uint16_t RtpExtensionHeader::headerLength() const { return ntohs(_headerLength); } + +size_t RtpExtensionHeader::getSize() const { return headerLength() * 4; } + +const char *RtpExtensionHeader::getBody() const { + return reinterpret_cast((&_headerLength) + 1); +} + +char *RtpExtensionHeader::getBody() { return reinterpret_cast((&_headerLength) + 1); } + +void RtpExtensionHeader::setProfileSpecificId(uint16_t profileSpecificId) { + _profileSpecificId = htons(profileSpecificId); +} + +void RtpExtensionHeader::setHeaderLength(uint16_t headerLength) { + _headerLength = htons(headerLength); +} + +void RtpExtensionHeader::clearBody() { std::memset(getBody(), 0, getSize()); } + +void RtpExtensionHeader::writeOneByteHeader(size_t offset, uint8_t id, const byte *value, + size_t size) { + if ((id == 0) || (id > 14) || (size == 0) || (size > 16) || ((offset + 1 + size) > getSize())) + return; + auto buf = getBody() + offset; + buf[0] = id << 4; + if (size != 1) { + buf[0] |= (uint8_t(size) - 1); + } + std::memcpy(buf + 1, value, size); +} + +void RtpExtensionHeader::writeCurrentVideoOrientation(size_t offset, const uint8_t id, + uint8_t value) { + auto v = std::byte{value}; + writeOneByteHeader(offset, id, &v, 1); +} + +SSRC RtcpReportBlock::getSSRC() const { return ntohl(_ssrc); } + +void RtcpReportBlock::preparePacket(SSRC in_ssrc, [[maybe_unused]] unsigned int packetsLost, + [[maybe_unused]] unsigned int totalPackets, + uint16_t highestSeqNo, uint16_t seqNoCycles, uint32_t jitter, + uint64_t lastSR_NTP, uint64_t lastSR_DELAY) { + setSeqNo(highestSeqNo, seqNoCycles); + setJitter(jitter); + setSSRC(in_ssrc); + + // Middle 32 bits of NTP Timestamp + // _lastReport = lastSR_NTP >> 16u; + setNTPOfSR(uint64_t(lastSR_NTP)); + setDelaySinceSR(uint32_t(lastSR_DELAY)); + + // The delay, expressed in units of 1/65536 seconds + // _delaySinceLastReport = lastSR_DELAY; +} + +void RtcpReportBlock::setSSRC(SSRC in_ssrc) { _ssrc = htonl(in_ssrc); } + +void RtcpReportBlock::setPacketsLost(uint8_t fractionLost, + unsigned int packetsLostCount) { + _fractionLostAndPacketsLost = htonl((uint32_t(fractionLost) << 24) | (packetsLostCount & 0xFFFFFF)); +} + +uint8_t RtcpReportBlock::getFractionLost() const { + // Fraction lost is expressed as 8-bit fixed point number + // In order to get actual lost percentage divide the result by 256 + return _fractionLostAndPacketsLost & 0xFF; +} + +unsigned int RtcpReportBlock::getPacketsLostCount() const { + return ntohl(_fractionLostAndPacketsLost & 0xFFFFFF00); +} + +uint16_t RtcpReportBlock::seqNoCycles() const { return ntohs(_seqNoCycles); } + +uint16_t RtcpReportBlock::highestSeqNo() const { return ntohs(_highestSeqNo); } + +uint32_t RtcpReportBlock::extendedHighestSeqNo() const { return (seqNoCycles() << 16) | highestSeqNo(); } + +uint32_t RtcpReportBlock::jitter() const { return ntohl(_jitter); } + +uint32_t RtcpReportBlock::delaySinceSR() const { return ntohl(_delaySinceLastReport); } + +void RtcpReportBlock::setSeqNo(uint16_t highestSeqNo, uint16_t seqNoCycles) { + _highestSeqNo = htons(highestSeqNo); + _seqNoCycles = htons(seqNoCycles); +} + +void RtcpReportBlock::setJitter(uint32_t jitter) { _jitter = htonl(jitter); } + +void RtcpReportBlock::setNTPOfSR(uint64_t ntp) { _lastReport = htonl((uint32_t)(ntp >> 16)); } + +uint32_t RtcpReportBlock::getNTPOfSR() const { return ntohl(_lastReport) << 16u; } + +void RtcpReportBlock::setDelaySinceSR(uint32_t sr) { + // The delay, expressed in units of 1/65536 seconds + _delaySinceLastReport = htonl(sr); +} + +void RtcpReportBlock::log() const { + PLOG_VERBOSE << "RTCP report block: " + << "ssrc=" + << ntohl(_ssrc) + // TODO: Implement these reports + // << ", fractionLost=" << fractionLost + // << ", packetsLost=" << packetsLost + << ", highestSeqNo=" << highestSeqNo() << ", seqNoCycles=" << seqNoCycles() + << ", jitter=" << jitter() << ", lastSR=" << getNTPOfSR() + << ", lastSRDelay=" << delaySinceSR(); +} + +uint8_t RtcpHeader::version() const { return _first >> 6; } + +bool RtcpHeader::padding() const { return (_first >> 5) & 0x01; } + +uint8_t RtcpHeader::reportCount() const { return _first & 0x1F; } + +uint8_t RtcpHeader::payloadType() const { return _payloadType; } + +uint16_t RtcpHeader::length() const { return ntohs(_length); } + +size_t RtcpHeader::lengthInBytes() const { return (1 + length()) * 4; } + +void RtcpHeader::setPayloadType(uint8_t type) { _payloadType = type; } + +void RtcpHeader::setReportCount(uint8_t count) { + _first = (_first & 0b11100000u) | (count & 0b00011111u); +} + +void RtcpHeader::setLength(uint16_t length) { _length = htons(length); } + +void RtcpHeader::prepareHeader(uint8_t payloadType, uint8_t reportCount, uint16_t length) { + _first = 0b10000000; // version 2, no padding + setReportCount(reportCount); + setPayloadType(payloadType); + setLength(length); +} + +void RtcpHeader::log() const { + PLOG_VERBOSE << "RTCP header: " + << "version=" << unsigned(version()) << ", padding=" << padding() + << ", reportCount=" << unsigned(reportCount()) + << ", payloadType=" << unsigned(payloadType()) << ", length=" << length(); +} + +SSRC RtcpFbHeader::packetSenderSSRC() const { return ntohl(_packetSender); } + +SSRC RtcpFbHeader::mediaSourceSSRC() const { return ntohl(_mediaSource); } + +void RtcpFbHeader::setPacketSenderSSRC(SSRC ssrc) { _packetSender = htonl(ssrc); } + +void RtcpFbHeader::setMediaSourceSSRC(SSRC ssrc) { _mediaSource = htonl(ssrc); } + +void RtcpFbHeader::log() const { + header.log(); + PLOG_VERBOSE << "FB: " + << " packet sender: " << packetSenderSSRC() + << " media source: " << mediaSourceSSRC(); +} + +unsigned int RtcpSr::Size(unsigned int reportCount) { + return sizeof(RtcpHeader) + 24 + reportCount * sizeof(RtcpReportBlock); +} + +void RtcpSr::preparePacket(SSRC senderSSRC, uint8_t reportCount) { + unsigned int length = ((sizeof(header) + 24 + reportCount * sizeof(RtcpReportBlock)) / 4) - 1; + header.prepareHeader(200, reportCount, uint16_t(length)); + this->_senderSSRC = htonl(senderSSRC); +} + +const RtcpReportBlock *RtcpSr::getReportBlock(int num) const { return &_reportBlocks + num; } + +RtcpReportBlock *RtcpSr::getReportBlock(int num) { return &_reportBlocks + num; } + +size_t RtcpSr::getSize() const { + // "length" in packet is one less than the number of 32 bit words in the packet. + return sizeof(uint32_t) * (1 + size_t(header.length())); +} + +uint64_t RtcpSr::ntpTimestamp() const { return ntohll(_ntpTimestamp); } +uint32_t RtcpSr::rtpTimestamp() const { return ntohl(_rtpTimestamp); } +uint32_t RtcpSr::packetCount() const { return ntohl(_packetCount); } +uint32_t RtcpSr::octetCount() const { return ntohl(_octetCount); } +uint32_t RtcpSr::senderSSRC() const { return ntohl(_senderSSRC); } + +void RtcpSr::setNtpTimestamp(uint64_t ts) { _ntpTimestamp = htonll(ts); } +void RtcpSr::setRtpTimestamp(uint32_t ts) { _rtpTimestamp = htonl(ts); } +void RtcpSr::setOctetCount(uint32_t ts) { _octetCount = htonl(ts); } +void RtcpSr::setPacketCount(uint32_t ts) { _packetCount = htonl(ts); } + +void RtcpSr::log() const { + header.log(); + PLOG_VERBOSE << "RTCP SR: " + << " SSRC=" << senderSSRC() << ", NTP_TS=" << ntpTimestamp() + << ", RtpTS=" << rtpTimestamp() << ", packetCount=" << packetCount() + << ", octetCount=" << octetCount(); + + for (unsigned i = 0; i < unsigned(header.reportCount()); i++) { + getReportBlock(i)->log(); + } +} + +unsigned int RtcpSdesItem::Size(uint8_t textLength) { return textLength + 2; } + +std::string RtcpSdesItem::text() const { return std::string(_text, _length); } + +void RtcpSdesItem::setText(std::string text) { + if (text.size() > 0xFF) + throw std::invalid_argument("text is too long"); + + _length = uint8_t(text.size()); + memcpy(_text, text.data(), text.size()); +} + +uint8_t RtcpSdesItem::length() const { return _length; } + +unsigned int RtcpSdesChunk::Size(const std::vector textLengths) { + unsigned int itemsSize = 0; + for (auto length : textLengths) { + itemsSize += RtcpSdesItem::Size(length); + } + auto nullTerminatedItemsSize = itemsSize + 1; + auto words = uint8_t(std::ceil(double(nullTerminatedItemsSize) / 4)) + 1; + return words * 4; +} + +SSRC RtcpSdesChunk::ssrc() const { return ntohl(_ssrc); } + +void RtcpSdesChunk::setSSRC(SSRC ssrc) { _ssrc = htonl(ssrc); } + +const RtcpSdesItem *RtcpSdesChunk::getItem(int num) const { + auto base = &_items; + while (num-- > 0) { + auto itemSize = RtcpSdesItem::Size(base->length()); + base = reinterpret_cast(reinterpret_cast(base) + + itemSize); + } + return reinterpret_cast(base); +} + +RtcpSdesItem *RtcpSdesChunk::getItem(int num) { + auto base = &_items; + while (num-- > 0) { + auto itemSize = RtcpSdesItem::Size(base->length()); + base = reinterpret_cast(reinterpret_cast(base) + itemSize); + } + return reinterpret_cast(base); +} + +unsigned int RtcpSdesChunk::getSize() const { + std::vector textLengths{}; + unsigned int i = 0; + auto item = getItem(i); + while (item->type != 0) { + textLengths.push_back(item->length()); + item = getItem(++i); + } + return Size(textLengths); +} + +long RtcpSdesChunk::safelyCountChunkSize(size_t maxChunkSize) const { + if (maxChunkSize < RtcpSdesChunk::Size({})) { + // chunk is truncated + return -1; + } + + size_t size = sizeof(SSRC); + unsigned int i = 0; + // We can always access first 4 bytes of first item (in case of no items there will be 4 + // null bytes) + auto item = getItem(i); + std::vector textsLength{}; + while (item->type != 0) { + if (size + RtcpSdesItem::Size(0) > maxChunkSize) { + // item is too short + return -1; + } + auto itemLength = item->length(); + if (size + RtcpSdesItem::Size(itemLength) >= maxChunkSize) { + // item is too large (it can't be equal to chunk size because after item there + // must be 1-4 null bytes as padding) + return -1; + } + textsLength.push_back(itemLength); + // safely to access next item + item = getItem(++i); + } + auto realSize = RtcpSdesChunk::Size(textsLength); + if (realSize > maxChunkSize) { + // Chunk is too large + return -1; + } + return realSize; +} + +unsigned int RtcpSdes::Size(const std::vector> lengths) { + unsigned int chunks_size = 0; + for (auto length : lengths) + chunks_size += RtcpSdesChunk::Size(length); + + return 4 + chunks_size; +} + +bool RtcpSdes::isValid() const { + auto chunksSize = header.lengthInBytes() - sizeof(header); + if (chunksSize == 0) { + return true; + } + // there is at least one chunk + unsigned int i = 0; + unsigned int size = 0; + while (size < chunksSize) { + if (chunksSize < size + RtcpSdesChunk::Size({})) { + // chunk is truncated + return false; + } + auto chunk = getChunk(i++); + auto chunkSize = chunk->safelyCountChunkSize(chunksSize - size); + if (chunkSize < 0) { + // chunk is invalid + return false; + } + size += chunkSize; + } + return size == chunksSize; +} + +unsigned int RtcpSdes::chunksCount() const { + if (!isValid()) { + return 0; + } + uint16_t chunksSize = 4 * (header.length() + 1) - sizeof(header); + unsigned int size = 0; + unsigned int i = 0; + while (size < chunksSize) { + size += getChunk(i++)->getSize(); + } + return i; +} + +const RtcpSdesChunk *RtcpSdes::getChunk(int num) const { + auto base = &_chunks; + while (num-- > 0) { + auto chunkSize = base->getSize(); + base = reinterpret_cast(reinterpret_cast(base) + + chunkSize); + } + return reinterpret_cast(base); +} + +RtcpSdesChunk *RtcpSdes::getChunk(int num) { + auto base = &_chunks; + while (num-- > 0) { + auto chunkSize = base->getSize(); + base = reinterpret_cast(reinterpret_cast(base) + chunkSize); + } + return reinterpret_cast(base); +} + +void RtcpSdes::preparePacket(uint8_t chunkCount) { + unsigned int chunkSize = 0; + for (uint8_t i = 0; i < chunkCount; i++) { + auto chunk = getChunk(i); + chunkSize += chunk->getSize(); + } + uint16_t length = uint16_t((sizeof(header) + chunkSize) / 4 - 1); + header.prepareHeader(202, chunkCount, length); +} + +const RtcpReportBlock *RtcpRr::getReportBlock(int num) const { return &_reportBlocks + num; } + +RtcpReportBlock *RtcpRr::getReportBlock(int num) { return &_reportBlocks + num; } + +size_t RtcpRr::SizeWithReportBlocks(uint8_t reportCount) { + return sizeof(header) + 4 + size_t(reportCount) * sizeof(RtcpReportBlock); +} + +SSRC RtcpRr::senderSSRC() const { return ntohl(_senderSSRC); } + +bool RtcpRr::isSenderReport() { return header.payloadType() == 200; } + +bool RtcpRr::isReceiverReport() { return header.payloadType() == 201; } + +size_t RtcpRr::getSize() const { + // "length" in packet is one less than the number of 32 bit words in the packet. + return sizeof(uint32_t) * (1 + size_t(header.length())); +} + +void RtcpRr::preparePacket(SSRC senderSSRC, uint8_t reportCount) { + // "length" in packet is one less than the number of 32 bit words in the packet. + size_t length = (SizeWithReportBlocks(reportCount) / 4) - 1; + header.prepareHeader(201, reportCount, uint16_t(length)); + this->_senderSSRC = htonl(senderSSRC); +} + +void RtcpRr::setSenderSSRC(SSRC ssrc) { this->_senderSSRC = htonl(ssrc); } + +void RtcpRr::log() const { + header.log(); + PLOG_VERBOSE << "RTCP RR: " + << " SSRC=" << ntohl(_senderSSRC); + + for (unsigned i = 0; i < unsigned(header.reportCount()); i++) { + getReportBlock(i)->log(); + } +} + +size_t RtcpRemb::SizeWithSSRCs(int count) { return sizeof(RtcpRemb) + (count - 1) * sizeof(SSRC); } + +unsigned int RtcpRemb::getSize() const { + // "length" in packet is one less than the number of 32 bit words in the packet. + return sizeof(uint32_t) * (1 + header.header.length()); +} + +void RtcpRemb::preparePacket(SSRC senderSSRC, unsigned int numSSRC, unsigned int in_bitrate) { + + // Report Count becomes the format here. + header.header.prepareHeader(206, 15, 0); + + // Always zero. + header.setMediaSourceSSRC(0); + + header.setPacketSenderSSRC(senderSSRC); + + _id[0] = 'R'; + _id[1] = 'E'; + _id[2] = 'M'; + _id[3] = 'B'; + + setBitrate(numSSRC, in_bitrate); +} + +void RtcpRemb::setBitrate(unsigned int numSSRC, unsigned int in_bitrate) { + unsigned int exp = 0; + while (in_bitrate > pow(2, 18) - 1) { + exp++; + in_bitrate /= 2; + } + + // "length" in packet is one less than the number of 32 bit words in the packet. + header.header.setLength(uint16_t((offsetof(RtcpRemb, _ssrc) / sizeof(uint32_t)) - 1 + numSSRC)); + + _bitrate = htonl((numSSRC << (32u - 8u)) | (exp << (32u - 8u - 6u)) | in_bitrate); +} + +void RtcpRemb::setSsrc(int iterator, SSRC newSssrc) { _ssrc[iterator] = htonl(newSssrc); } + +unsigned int RtcpPli::Size() { return sizeof(RtcpFbHeader); } + +void RtcpPli::preparePacket(SSRC messageSSRC) { + header.header.prepareHeader(206, 1, 2); + header.setPacketSenderSSRC(messageSSRC); + header.setMediaSourceSSRC(messageSSRC); +} + +void RtcpPli::log() const { header.log(); } + +unsigned int RtcpFir::Size() { return sizeof(RtcpFbHeader) + sizeof(RtcpFirPart); } + +void RtcpFir::preparePacket(SSRC messageSSRC, uint8_t seqNo) { + header.header.prepareHeader(206, 4, 2 + 2 * 1); + header.setPacketSenderSSRC(messageSSRC); + header.setMediaSourceSSRC(messageSSRC); + parts[0].ssrc = htonl(messageSSRC); + parts[0].seqNo = seqNo; +} + +void RtcpFir::log() const { header.log(); } + +uint16_t RtcpNackPart::pid() { return ntohs(_pid); } +uint16_t RtcpNackPart::blp() { return ntohs(_blp); } + +void RtcpNackPart::setPid(uint16_t pid) { _pid = htons(pid); } +void RtcpNackPart::setBlp(uint16_t blp) { _blp = htons(blp); } + +std::vector RtcpNackPart::getSequenceNumbers() { + std::vector result{}; + result.reserve(17); + uint16_t p = pid(); + result.push_back(p); + uint16_t bitmask = blp(); + uint16_t i = p + 1; + while (bitmask > 0) { + if (bitmask & 0x1) { + result.push_back(i); + } + i += 1; + bitmask >>= 1; + } + return result; +} + +unsigned int RtcpNack::Size(unsigned int discreteSeqNoCount) { + return offsetof(RtcpNack, parts) + sizeof(RtcpNackPart) * discreteSeqNoCount; +} + +unsigned int RtcpNack::getSeqNoCount() { return header.header.length() - 2; } + +void RtcpNack::preparePacket(SSRC ssrc, unsigned int discreteSeqNoCount) { + header.header.prepareHeader(205, 1, 2 + uint16_t(discreteSeqNoCount)); + header.setMediaSourceSSRC(ssrc); + header.setPacketSenderSSRC(ssrc); +} + +bool RtcpNack::addMissingPacket(unsigned int *fciCount, uint16_t *fciPID, uint16_t missingPacket) { + if (*fciCount == 0 || missingPacket < *fciPID || missingPacket > (*fciPID + 16)) { + parts[*fciCount].setPid(missingPacket); + parts[*fciCount].setBlp(0); + *fciPID = missingPacket; + (*fciCount)++; + return true; + } else { + // TODO SPEED! + uint16_t blp = parts[(*fciCount) - 1].blp(); + uint16_t newBit = uint16_t(1u << (missingPacket - (1 + *fciPID))); + parts[(*fciCount) - 1].setBlp(blp | newBit); + return false; + } +} + +uint16_t RtpRtx::getOriginalSeqNo() const { return ntohs(*(uint16_t *)(header.getBody())); } + +const char *RtpRtx::getBody() const { return header.getBody() + sizeof(uint16_t); } + +char *RtpRtx::getBody() { return header.getBody() + sizeof(uint16_t); } + +size_t RtpRtx::getBodySize(size_t totalSize) const { + return totalSize - (getBody() - reinterpret_cast(this)); +} + +size_t RtpRtx::getSize() const { return header.getSize() + sizeof(uint16_t); } + +size_t RtpRtx::normalizePacket(size_t totalSize, SSRC originalSSRC, uint8_t originalPayloadType) { + header.setSeqNumber(getOriginalSeqNo()); + header.setSsrc(originalSSRC); + header.setPayloadType(originalPayloadType); + // TODO, the -12 is the size of the header (which is variable!) + memmove(header.getBody(), getBody(), totalSize - getSize()); + return totalSize - 2; +} + +size_t RtpRtx::copyTo(RtpHeader *dest, size_t totalSize, uint8_t originalPayloadType) { + memmove((char *)dest, (char *)this, header.getSize()); + dest->setSeqNumber(getOriginalSeqNo()); + dest->setPayloadType(originalPayloadType); + memmove(dest->getBody(), getBody(), getBodySize(totalSize)); + return totalSize; +} + +}; // namespace rtc diff --git a/datachannel/src/rtppacketizationconfig.cpp b/datachannel/src/rtppacketizationconfig.cpp new file mode 100644 index 000000000..6f208c2ca --- /dev/null +++ b/datachannel/src/rtppacketizationconfig.cpp @@ -0,0 +1,56 @@ +/** + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#if RTC_ENABLE_MEDIA + +#include "rtppacketizationconfig.hpp" + +#include "impl/utils.hpp" + +#include +#include +#include +#include + +namespace rtc { + +namespace utils = impl::utils; + +RtpPacketizationConfig::RtpPacketizationConfig(SSRC ssrc, string cname, uint8_t payloadType, + uint32_t clockRate, uint8_t videoOrientationId) + : ssrc(ssrc), cname(cname), payloadType(payloadType), clockRate(clockRate), + videoOrientationId(videoOrientationId) { + assert(clockRate > 0); + + // RFC 3550: The initial value of the sequence number SHOULD be random (unpredictable) to make + // known-plaintext attacks on encryption more difficult [...] The initial value of the timestamp + // SHOULD be random, as for the sequence number. + auto uniform = std::bind(std::uniform_int_distribution(), utils::random_engine()); + sequenceNumber = static_cast(uniform()); + timestamp = startTimestamp = uniform(); +} + +double RtpPacketizationConfig::getSecondsFromTimestamp(uint32_t timestamp, uint32_t clockRate) { + return double(timestamp) / double(clockRate); +} + +double RtpPacketizationConfig::timestampToSeconds(uint32_t timestamp) { + return RtpPacketizationConfig::getSecondsFromTimestamp(timestamp, clockRate); +} + +uint32_t RtpPacketizationConfig::getTimestampFromSeconds(double seconds, uint32_t clockRate) { + return uint32_t(int64_t(round(seconds * double(clockRate)))); // convert to integer then cast to u32 +} + +uint32_t RtpPacketizationConfig::secondsToTimestamp(double seconds) { + return RtpPacketizationConfig::getTimestampFromSeconds(seconds, clockRate); +} + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ diff --git a/datachannel/src/rtppacketizer.cpp b/datachannel/src/rtppacketizer.cpp new file mode 100644 index 000000000..ba7048ac5 --- /dev/null +++ b/datachannel/src/rtppacketizer.cpp @@ -0,0 +1,109 @@ +/** + * Copyright (c) 2020 Filip Klembara (in2core) + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#if RTC_ENABLE_MEDIA + +#include "rtppacketizer.hpp" + +#include +#include + +namespace rtc { + +RtpPacketizer::RtpPacketizer(shared_ptr rtpConfig) : rtpConfig(rtpConfig) {} + +RtpPacketizer::~RtpPacketizer() {} + +message_ptr RtpPacketizer::packetize(shared_ptr payload, bool mark) { + size_t rtpExtHeaderSize = 0; + + const bool setVideoRotation = (rtpConfig->videoOrientationId != 0) && + (rtpConfig->videoOrientationId < + 15) && // needs fixing if longer extension headers are supported + mark && + (rtpConfig->videoOrientation != 0); + + if (setVideoRotation) + rtpExtHeaderSize += 2; + + if (rtpConfig->mid.has_value()) + rtpExtHeaderSize += (1 + rtpConfig->mid->length()); + + if (rtpConfig->rid.has_value()) + rtpExtHeaderSize += (1 + rtpConfig->rid->length()); + + if (rtpExtHeaderSize != 0) + rtpExtHeaderSize += 4; + + rtpExtHeaderSize = (rtpExtHeaderSize + 3) & ~3; + + auto message = make_message(RtpHeaderSize + rtpExtHeaderSize + payload->size()); + auto *rtp = (RtpHeader *)message->data(); + rtp->setPayloadType(rtpConfig->payloadType); + rtp->setSeqNumber(rtpConfig->sequenceNumber++); // increase sequence number + rtp->setTimestamp(rtpConfig->timestamp); + rtp->setSsrc(rtpConfig->ssrc); + + if (mark) { + rtp->setMarker(true); + } + + if (rtpExtHeaderSize) { + rtp->setExtension(true); + + auto extHeader = rtp->getExtensionHeader(); + extHeader->setProfileSpecificId(0xbede); + + auto headerLength = static_cast(rtpExtHeaderSize / 4) - 1; + + extHeader->setHeaderLength(headerLength); + extHeader->clearBody(); + + size_t offset = 0; + if (setVideoRotation) { + extHeader->writeCurrentVideoOrientation(offset, rtpConfig->videoOrientationId, + rtpConfig->videoOrientation); + offset += 2; + } + + if (rtpConfig->mid.has_value()) { + extHeader->writeOneByteHeader( + offset, rtpConfig->midId, + reinterpret_cast(rtpConfig->mid->c_str()), + rtpConfig->mid->length()); + offset += (1 + rtpConfig->mid->length()); + } + + if (rtpConfig->rid.has_value()) { + extHeader->writeOneByteHeader( + offset, rtpConfig->ridId, + reinterpret_cast(rtpConfig->rid->c_str()), + rtpConfig->rid->length()); + } + } + + rtp->preparePacket(); + + std::memcpy(message->data() + RtpHeaderSize + rtpExtHeaderSize, payload->data(), + payload->size()); + + return message; +} + +void RtpPacketizer::media([[maybe_unused]] const Description::Media &desc) {} + +void RtpPacketizer::outgoing([[maybe_unused]] message_vector &messages, + [[maybe_unused]] const message_callback &send) { + // Default implementation + for (auto &message : messages) + message = packetize(message, false); +} + +} // namespace rtc + +#endif /* RTC_ENABLE_MEDIA */ diff --git a/datachannel/src/track.cpp b/datachannel/src/track.cpp new file mode 100644 index 000000000..70bda4672 --- /dev/null +++ b/datachannel/src/track.cpp @@ -0,0 +1,73 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#include "track.hpp" + +#include "impl/internals.hpp" +#include "impl/track.hpp" + +namespace rtc { + +Track::Track(impl_ptr impl) + : CheshireCat(impl), Channel(std::dynamic_pointer_cast(impl)) {} + +Track::~Track() {} + +string Track::mid() const { return impl()->mid(); } + +Description::Direction Track::direction() const { return impl()->direction(); } + +Description::Media Track::description() const { return impl()->description(); } + +void Track::setDescription(Description::Media description) { + impl()->setDescription(std::move(description)); +} + +void Track::close() { impl()->close(); } + +bool Track::send(message_variant data) { return impl()->outgoing(make_message(std::move(data))); } + +bool Track::send(const byte *data, size_t size) { return send(binary(data, data + size)); } + +bool Track::isOpen(void) const { return impl()->isOpen(); } + +bool Track::isClosed(void) const { return impl()->isClosed(); } + +size_t Track::maxMessageSize() const { return impl()->maxMessageSize(); } + +void Track::setMediaHandler(shared_ptr handler) { + impl()->setMediaHandler(std::move(handler)); +} + +void Track::chainMediaHandler(shared_ptr handler) { + if (auto first = impl()->getMediaHandler()) + first->addToChain(std::move(handler)); + else + impl()->setMediaHandler(std::move(handler)); +} + +bool Track::requestKeyframe() { + // only push PLI for video + if (description().type() == "video") + if (auto handler = impl()->getMediaHandler()) + return handler->requestKeyframe([this](message_ptr m) { impl()->transportSend(m); }); + + return false; +} + +bool Track::requestBitrate(unsigned int bitrate) { + if (auto handler = impl()->getMediaHandler()) + return handler->requestBitrate(bitrate, + [this](message_ptr m) { impl()->transportSend(m); }); + + return false; +} + +shared_ptr Track::getMediaHandler() { return impl()->getMediaHandler(); } + +} // namespace rtc diff --git a/datachannel/src/websocket.cpp b/datachannel/src/websocket.cpp new file mode 100644 index 000000000..70368192c --- /dev/null +++ b/datachannel/src/websocket.cpp @@ -0,0 +1,96 @@ +/** + * Copyright (c) 2020-2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#if RTC_ENABLE_WEBSOCKET + +#include "websocket.hpp" +#include "common.hpp" + +#include "impl/internals.hpp" +#include "impl/websocket.hpp" + +namespace rtc { + +WebSocket::WebSocket() : WebSocket(Configuration()) {} + +WebSocket::WebSocket(Configuration config) + : CheshireCat(std::move(config)), + Channel(std::dynamic_pointer_cast(CheshireCat::impl())) {} + +WebSocket::WebSocket(impl_ptr impl) + : CheshireCat(std::move(impl)), + Channel(std::dynamic_pointer_cast(CheshireCat::impl())) {} + +WebSocket::~WebSocket() { + try { + impl()->remoteClose(); + impl()->resetCallbacks(); // not done by impl::WebSocket + } catch (const std::exception &e) { + PLOG_ERROR << e.what(); + } +} + +WebSocket::State WebSocket::readyState() const { return impl()->state; } + +bool WebSocket::isOpen() const { return impl()->state.load() == State::Open; } + +bool WebSocket::isClosed() const { return impl()->state.load() == State::Closed; } + +size_t WebSocket::maxMessageSize() const { return impl()->maxMessageSize(); } + +void WebSocket::open(const string &url) { impl()->open(url); } + +void WebSocket::close() { impl()->close(); } + +void WebSocket::forceClose() { impl()->remoteClose(); } + +bool WebSocket::send(message_variant data) { + return impl()->outgoing(make_message(std::move(data))); +} + +bool WebSocket::send(const byte *data, size_t size) { + return impl()->outgoing(make_message(data, data + size, Message::Binary)); +} + +optional WebSocket::remoteAddress() const { + auto tcpTransport = impl()->getTcpTransport(); + return tcpTransport ? make_optional(tcpTransport->remoteAddress()) : nullopt; +} + +optional WebSocket::path() const { + auto state = impl()->state.load(); + auto handshake = impl()->getWsHandshake(); + return state != State::Connecting && handshake ? make_optional(handshake->path()) : nullopt; +} + +std::ostream &operator<<(std::ostream &out, WebSocket::State state) { + using State = WebSocket::State; + const char *str; + switch (state) { + case State::Connecting: + str = "connecting"; + break; + case State::Open: + str = "open"; + break; + case State::Closing: + str = "closing"; + break; + case State::Closed: + str = "closed"; + break; + default: + str = "unknown"; + break; + } + return out << str; +} + +} // namespace rtc + +#endif diff --git a/datachannel/src/websocketserver.cpp b/datachannel/src/websocketserver.cpp new file mode 100644 index 000000000..5e1f6340f --- /dev/null +++ b/datachannel/src/websocketserver.cpp @@ -0,0 +1,36 @@ +/** + * Copyright (c) 2021 Paul-Louis Ageneau + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. + */ + +#if RTC_ENABLE_WEBSOCKET + +#include "websocketserver.hpp" +#include "common.hpp" + +#include "impl/internals.hpp" +#include "impl/websocketserver.hpp" + +namespace rtc { + +WebSocketServer::WebSocketServer() : WebSocketServer(Configuration()) {} + +WebSocketServer::WebSocketServer(Configuration config) + : CheshireCat(std::move(config)) {} + +WebSocketServer::~WebSocketServer() { impl()->stop(); } + +void WebSocketServer::stop() { impl()->stop(); } + +uint16_t WebSocketServer::port() const { return impl()->tcpServer->port(); } + +void WebSocketServer::onClient(std::function)> callback) { + impl()->clientCallback = callback; +} + +} // namespace rtc + +#endif diff --git a/libs.cmake b/libs.cmake index 512afc3aa..b69493ab6 100644 --- a/libs.cmake +++ b/libs.cmake @@ -16,6 +16,7 @@ set(libs umysql uuid exif + datachannel juice usrsctp openssl