117 changes: 99 additions & 18 deletions Source/Core/Common/TraversalServer.cpp
Expand Up @@ -9,8 +9,10 @@
#include <cstring>
#include <fcntl.h>
#include <netinet/in.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <tuple>
#include <unistd.h>
#include <unordered_map>
#include <utility>
Expand All @@ -26,13 +28,15 @@
#define DEBUG 0
#define NUMBER_OF_TRIES 5
#define PORT 6262
#define PORT_ALT 6226

static u64 currentTime;

struct OutgoingPacketInfo
{
Common::TraversalPacket packet;
Common::TraversalRequestId misc;
bool fromAlt;
sockaddr_in6 dest;
int tries;
u64 sendTime;
Expand Down Expand Up @@ -119,6 +123,7 @@ using ConnectedClients =
using OutgoingPackets = std::unordered_map<Common::TraversalRequestId, OutgoingPacketInfo>;

static int sock;
static int sockAlt;
static OutgoingPackets outgoingPackets;
static ConnectedClients connectedClients;

Expand Down Expand Up @@ -186,25 +191,27 @@ static const char* SenderName(sockaddr_in6* addr)
return buf;
}

static void TrySend(const void* buffer, size_t size, sockaddr_in6* addr)
static void TrySend(const void* buffer, size_t size, sockaddr_in6* addr, bool fromAlt)
{
#if DEBUG
const auto* packet = static_cast<const Common::TraversalPacket*>(buffer);
printf("-> %d %llu %s\n", static_cast<int>(packet->type),
printf("%s-> %d %llu %s\n", fromAlt ? "alt " : "", static_cast<int>(packet->type),
static_cast<long long>(packet->requestId), SenderName(addr));
#endif
if ((size_t)sendto(sock, buffer, size, 0, (sockaddr*)addr, sizeof(*addr)) != size)
if ((size_t)sendto(fromAlt ? sockAlt : sock, buffer, size, 0, (sockaddr*)addr, sizeof(*addr)) !=
size)
{
perror("sendto");
}
}

static Common::TraversalPacket* AllocPacket(const sockaddr_in6& dest,
static Common::TraversalPacket* AllocPacket(const sockaddr_in6& dest, bool fromAlt,
Common::TraversalRequestId misc = 0)
{
Common::TraversalRequestId requestId{};
Common::Random::Generate(&requestId, sizeof(requestId));
OutgoingPacketInfo* info = &outgoingPackets[requestId];
info->fromAlt = fromAlt;
info->dest = dest;
info->misc = misc;
info->tries = 0;
Expand All @@ -219,12 +226,13 @@ static void SendPacket(OutgoingPacketInfo* info)
{
info->tries++;
info->sendTime = currentTime;
TrySend(&info->packet, sizeof(info->packet), &info->dest);
TrySend(&info->packet, sizeof(info->packet), &info->dest, info->fromAlt);
}

static void ResendPackets()
{
std::vector<std::pair<Common::TraversalInetAddress, Common::TraversalRequestId>> todoFailures;
std::vector<std::tuple<Common::TraversalInetAddress, bool, Common::TraversalRequestId>>
todoFailures;
todoFailures.clear();
for (auto it = outgoingPackets.begin(); it != outgoingPackets.end();)
{
Expand All @@ -235,7 +243,8 @@ static void ResendPackets()
{
if (info->packet.type == Common::TraversalPacketType::PleaseSendPacket)
{
todoFailures.push_back(std::make_pair(info->packet.pleaseSendPacket.address, info->misc));
todoFailures.push_back(
std::make_tuple(info->packet.pleaseSendPacket.address, info->fromAlt, info->misc));
}
it = outgoingPackets.erase(it);
continue;
Expand All @@ -250,14 +259,14 @@ static void ResendPackets()

for (const auto& p : todoFailures)
{
Common::TraversalPacket* fail = AllocPacket(MakeSinAddr(p.first));
Common::TraversalPacket* fail = AllocPacket(MakeSinAddr(std::get<0>(p)), std::get<1>(p));
fail->type = Common::TraversalPacketType::ConnectFailed;
fail->connectFailed.requestId = p.second;
fail->connectFailed.requestId = std::get<2>(p);
fail->connectFailed.reason = Common::TraversalConnectFailedReason::ClientDidntRespond;
}
}

static void HandlePacket(Common::TraversalPacket* packet, sockaddr_in6* addr)
static void HandlePacket(Common::TraversalPacket* packet, sockaddr_in6* addr, bool toAlt)
{
#if DEBUG
printf("<- %d %llu %s\n", static_cast<int>(packet->type),
Expand All @@ -276,7 +285,7 @@ static void HandlePacket(Common::TraversalPacket* packet, sockaddr_in6* addr)

if (info->packet.type == Common::TraversalPacketType::PleaseSendPacket)
{
auto* ready = AllocPacket(MakeSinAddr(info->packet.pleaseSendPacket.address));
auto* ready = AllocPacket(MakeSinAddr(info->packet.pleaseSendPacket.address), toAlt);
if (packet->ack.ok)
{
ready->type = Common::TraversalPacketType::ConnectReady;
Expand All @@ -303,7 +312,7 @@ static void HandlePacket(Common::TraversalPacket* packet, sockaddr_in6* addr)
case Common::TraversalPacketType::HelloFromClient:
{
u8 ok = packet->helloFromClient.protoVersion <= Common::TraversalProtoVersion;
Common::TraversalPacket* reply = AllocPacket(*addr);
Common::TraversalPacket* reply = AllocPacket(*addr, toAlt);
reply->type = Common::TraversalPacketType::HelloFromServer;
reply->helloFromServer.ok = ok;
if (ok)
Expand Down Expand Up @@ -336,19 +345,35 @@ static void HandlePacket(Common::TraversalPacket* packet, sockaddr_in6* addr)
auto r = EvictFind(connectedClients, hostId);
if (!r.found)
{
Common::TraversalPacket* reply = AllocPacket(*addr);
Common::TraversalPacket* reply = AllocPacket(*addr, toAlt);
reply->type = Common::TraversalPacketType::ConnectFailed;
reply->connectFailed.requestId = packet->requestId;
reply->connectFailed.reason = Common::TraversalConnectFailedReason::NoSuchClient;
}
else
{
Common::TraversalPacket* please = AllocPacket(MakeSinAddr(*r.value), packet->requestId);
Common::TraversalPacket* please =
AllocPacket(MakeSinAddr(*r.value), toAlt, packet->requestId);
please->type = Common::TraversalPacketType::PleaseSendPacket;
please->pleaseSendPacket.address = MakeInetAddress(*addr);
}
break;
}
case Common::TraversalPacketType::TestPlease:
{
Common::TraversalHostId& hostId = packet->testPlease.hostId;
auto r = EvictFind(connectedClients, hostId);
if (r.found)
{
Common::TraversalPacket ack = {};
ack.type = Common::TraversalPacketType::Ack;
ack.requestId = packet->requestId;
ack.ack.ok = true;
sockaddr_in6 mainAddr = MakeSinAddr(*r.value);
TrySend(&ack, sizeof(ack), &mainAddr, toAlt);
}
break;
}
default:
fprintf(stderr, "received unknown packet type %d from %s\n", static_cast<int>(packet->type),
SenderName(addr));
Expand All @@ -360,7 +385,8 @@ static void HandlePacket(Common::TraversalPacket* packet, sockaddr_in6* addr)
ack.type = Common::TraversalPacketType::Ack;
ack.requestId = packet->requestId;
ack.ack.ok = packetOk;
TrySend(&ack, sizeof(ack), addr);
TrySend(&ack, sizeof(ack), addr,
packet->type != Common::TraversalPacketType::TestPlease ? toAlt : !toAlt);
}
}

Expand All @@ -373,13 +399,25 @@ int main()
perror("socket");
return 1;
}
sockAlt = socket(PF_INET6, SOCK_DGRAM, 0);
if (sockAlt == -1)
{
perror("socket alt");
return 1;
}
int no = 0;
rv = setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, &no, sizeof(no));
if (rv < 0)
{
perror("setsockopt IPV6_V6ONLY");
return 1;
}
rv = setsockopt(sockAlt, IPPROTO_IPV6, IPV6_V6ONLY, &no, sizeof(no));
if (rv < 0)
{
perror("setsockopt IPV6_V6ONLY alt");
return 1;
}
in6_addr any = IN6ADDR_ANY_INIT;
sockaddr_in6 addr;
#ifdef SIN6_LEN
Expand All @@ -397,6 +435,13 @@ int main()
perror("bind");
return 1;
}
addr.sin6_port = htons(PORT_ALT);
rv = bind(sockAlt, (sockaddr*)&addr, sizeof(addr));
if (rv < 0)
{
perror("bind alt");
return 1;
}

timeval tv;
tv.tv_sec = 0;
Expand All @@ -407,19 +452,55 @@ int main()
perror("setsockopt SO_RCVTIMEO");
return 1;
}
rv = setsockopt(sockAlt, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv));
if (rv < 0)
{
perror("setsockopt SO_RCVTIMEO alt");
return 1;
}

#ifdef HAVE_LIBSYSTEMD
sd_notifyf(0, "READY=1\nSTATUS=Listening on port %d", PORT);
sd_notifyf(0, "READY=1\nSTATUS=Listening on port %d (alt port: %d)", PORT, PORT_ALT);
#endif

while (true)
{
tv.tv_sec = 0;
tv.tv_usec = 300000;
fd_set readSet;
FD_ZERO(&readSet);
FD_SET(sock, &readSet);
FD_SET(sockAlt, &readSet);
rv = select(std::max(sock, sockAlt) + 1, &readSet, nullptr, nullptr, &tv);
if (rv < 0)
{
if (errno != EINTR && errno != EAGAIN)
{
perror("recvfrom");
return 1;
}
}

int recvsock;
if (FD_ISSET(sock, &readSet))
{
recvsock = sock;
}
else if (FD_ISSET(sockAlt, &readSet))
{
recvsock = sockAlt;
}
else
{
ResendPackets();
continue;
}
sockaddr_in6 raddr;
socklen_t addrLen = sizeof(raddr);
Common::TraversalPacket packet{};
// note: switch to recvmmsg (yes, mmsg) if this becomes
// expensive
rv = recvfrom(sock, &packet, sizeof(packet), 0, (sockaddr*)&raddr, &addrLen);
rv = recvfrom(recvsock, &packet, sizeof(packet), 0, (sockaddr*)&raddr, &addrLen);
currentTime = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
Expand All @@ -437,7 +518,7 @@ int main()
}
else
{
HandlePacket(&packet, &raddr);
HandlePacket(&packet, &raddr, recvsock == sockAlt);
}
ResendPackets();
#ifdef HAVE_LIBSYSTEMD
Expand Down
1 change: 1 addition & 0 deletions Source/Core/Core/Config/NetplaySettings.cpp
Expand Up @@ -16,6 +16,7 @@ static constexpr u16 DEFAULT_LISTEN_PORT = 2626;
const Info<std::string> NETPLAY_TRAVERSAL_SERVER{{System::Main, "NetPlay", "TraversalServer"},
"stun.dolphin-emu.org"};
const Info<u16> NETPLAY_TRAVERSAL_PORT{{System::Main, "NetPlay", "TraversalPort"}, 6262};
const Info<u16> NETPLAY_TRAVERSAL_PORT_ALT{{System::Main, "NetPlay", "TraversalPortAlt"}, 6226};
const Info<std::string> NETPLAY_TRAVERSAL_CHOICE{{System::Main, "NetPlay", "TraversalChoice"},
"direct"};
const Info<std::string> NETPLAY_INDEX_URL{{System::Main, "NetPlay", "IndexServer"},
Expand Down
1 change: 1 addition & 0 deletions Source/Core/Core/Config/NetplaySettings.h
Expand Up @@ -16,6 +16,7 @@ namespace Config

extern const Info<std::string> NETPLAY_TRAVERSAL_SERVER;
extern const Info<u16> NETPLAY_TRAVERSAL_PORT;
extern const Info<u16> NETPLAY_TRAVERSAL_PORT_ALT;
extern const Info<std::string> NETPLAY_TRAVERSAL_CHOICE;
extern const Info<std::string> NETPLAY_HOST_CODE;
extern const Info<std::string> NETPLAY_INDEX_URL;
Expand Down
2 changes: 2 additions & 0 deletions Source/Core/Core/NetPlayClient.h
Expand Up @@ -72,6 +72,7 @@ class NetPlayUI
virtual void OnTraversalStateChanged(Common::TraversalClient::State state) = 0;
virtual void OnGameStartAborted() = 0;
virtual void OnGolferChanged(bool is_golfer, const std::string& golfer_name) = 0;
virtual void OnTtlDetermined(u8 ttl) = 0;

virtual bool IsRecording() = 0;
virtual std::shared_ptr<const UICommon::GameFile>
Expand Down Expand Up @@ -148,6 +149,7 @@ class NetPlayClient : public Common::TraversalClientClient
void OnTraversalStateChanged() override;
void OnConnectReady(ENetAddress addr) override;
void OnConnectFailed(Common::TraversalConnectFailedReason reason) override;
void OnTtlDetermined(u8 ttl) override {}

bool IsFirstInGamePad(int ingame_pad) const;
int NumLocalPads() const;
Expand Down
6 changes: 4 additions & 2 deletions Source/Core/Core/NetPlayProto.h
Expand Up @@ -116,15 +116,17 @@ struct NetSettings
struct NetTraversalConfig
{
NetTraversalConfig() = default;
NetTraversalConfig(bool use_traversal_, std::string traversal_host_, u16 traversal_port_)
NetTraversalConfig(bool use_traversal_, std::string traversal_host_, u16 traversal_port_,
u16 traversal_port_alt_ = 0)
: use_traversal{use_traversal_}, traversal_host{std::move(traversal_host_)},
traversal_port{traversal_port_}
traversal_port{traversal_port_}, traversal_port_alt{traversal_port_alt_}
{
}

bool use_traversal = false;
std::string traversal_host;
u16 traversal_port = 0;
u16 traversal_port_alt = 0;
};

enum class MessageID : u8
Expand Down
8 changes: 7 additions & 1 deletion Source/Core/Core/NetPlayServer.cpp
Expand Up @@ -133,7 +133,8 @@ NetPlayServer::NetPlayServer(const u16 port, const bool forward_port, NetPlayUI*
if (traversal_config.use_traversal)
{
if (!Common::EnsureTraversalClient(traversal_config.traversal_host,
traversal_config.traversal_port, port))
traversal_config.traversal_port,
traversal_config.traversal_port_alt, port))
{
return;
}
Expand Down Expand Up @@ -1268,6 +1269,11 @@ void NetPlayServer::OnTraversalStateChanged()
m_dialog->OnTraversalStateChanged(state);
}

void NetPlayServer::OnTtlDetermined(u8 ttl)
{
m_dialog->OnTtlDetermined(ttl);
}

// called from ---GUI--- thread
void NetPlayServer::SendChatMessage(const std::string& msg)
{
Expand Down
1 change: 1 addition & 0 deletions Source/Core/Core/NetPlayServer.h
Expand Up @@ -144,6 +144,7 @@ class NetPlayServer : public Common::TraversalClientClient
void OnTraversalStateChanged() override;
void OnConnectReady(ENetAddress) override {}
void OnConnectFailed(Common::TraversalConnectFailedReason) override {}
void OnTtlDetermined(u8 ttl) override;
void UpdatePadMapping();
void UpdateGBAConfig();
void UpdateWiimoteMapping();
Expand Down
8 changes: 5 additions & 3 deletions Source/Core/DolphinQt/MainWindow.cpp
Expand Up @@ -1591,14 +1591,16 @@ bool MainWindow::NetPlayHost(const UICommon::GameFile& game)

const std::string traversal_host = Config::Get(Config::NETPLAY_TRAVERSAL_SERVER);
const u16 traversal_port = Config::Get(Config::NETPLAY_TRAVERSAL_PORT);
const u16 traversal_port_alt = Config::Get(Config::NETPLAY_TRAVERSAL_PORT_ALT);

if (is_traversal)
host_port = Config::Get(Config::NETPLAY_LISTEN_PORT);

// Create Server
Settings::Instance().ResetNetPlayServer(new NetPlay::NetPlayServer(
host_port, use_upnp, m_netplay_dialog,
NetPlay::NetTraversalConfig{is_traversal, traversal_host, traversal_port}));
Settings::Instance().ResetNetPlayServer(
new NetPlay::NetPlayServer(host_port, use_upnp, m_netplay_dialog,
NetPlay::NetTraversalConfig{is_traversal, traversal_host,
traversal_port, traversal_port_alt}));

if (!Settings::Instance().GetNetPlayServer()->is_connected)
{
Expand Down
5 changes: 5 additions & 0 deletions Source/Core/DolphinQt/NetPlay/NetPlayDialog.cpp
Expand Up @@ -1039,6 +1039,11 @@ void NetPlayDialog::OnGolferChanged(const bool is_golfer, const std::string& gol
DisplayMessage(tr("%1 is now golfing").arg(QString::fromStdString(golfer_name)), "");
}

void NetPlayDialog::OnTtlDetermined(u8 ttl)
{
DisplayMessage(tr("Using TTL %1 for probe packet").arg(QString::number(ttl)), "");
}

bool NetPlayDialog::IsRecording()
{
std::optional<bool> is_recording = RunOnObject(m_record_input_action, &QAction::isChecked);
Expand Down
1 change: 1 addition & 0 deletions Source/Core/DolphinQt/NetPlay/NetPlayDialog.h
Expand Up @@ -71,6 +71,7 @@ class NetPlayDialog : public QDialog, public NetPlay::NetPlayUI
void OnTraversalStateChanged(Common::TraversalClient::State state) override;
void OnGameStartAborted() override;
void OnGolferChanged(bool is_golfer, const std::string& golfer_name) override;
void OnTtlDetermined(u8 ttl) override;

void OnIndexAdded(bool success, const std::string error) override;
void OnIndexRefreshFailed(const std::string error) override;
Expand Down