Skip to content

Commit

Permalink
Parse TCP host using hostname:port format
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenCWills authored and AJenbo committed Jun 11, 2024
1 parent bae4030 commit 78eb3c7
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 21 deletions.
4 changes: 2 additions & 2 deletions Source/dvlnet/abstract_net.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ using provider_t = unsigned long;

class abstract_net {
public:
virtual int create(std::string addrstr) = 0;
virtual int join(std::string addrstr) = 0;
virtual int create(std::string_view addrstr) = 0;
virtual int join(std::string_view addrstr) = 0;
virtual bool SNetReceiveMessage(uint8_t *sender, void **data, size_t *size) = 0;
virtual bool SNetSendMessage(uint8_t dest, void *data, size_t size) = 0;
virtual bool SNetReceiveTurns(char **data, size_t *size, uint32_t *status) = 0;
Expand Down
8 changes: 4 additions & 4 deletions Source/dvlnet/base_protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ namespace devilution::net {
template <class P>
class base_protocol : public base {
public:
int create(std::string addrstr) override;
int join(std::string addrstr) override;
int create(std::string_view addrstr) override;
int join(std::string_view addrstr) override;
tl::expected<void, PacketError> poll() override;
tl::expected<void, PacketError> send(packet &pkt) override;
void DisconnectNet(plr_t plr) override;
Expand Down Expand Up @@ -161,7 +161,7 @@ tl::expected<void, PacketError> base_protocol<P>::wait_join()
}

template <class P>
int base_protocol<P>::create(std::string addrstr)
int base_protocol<P>::create(std::string_view addrstr)
{
gamename = addrstr;
isGameHost_ = true;
Expand All @@ -183,7 +183,7 @@ int base_protocol<P>::create(std::string addrstr)
}

template <class P>
int base_protocol<P>::join(std::string addrstr)
int base_protocol<P>::join(std::string_view addrstr)
{
gamename = addrstr;
isGameHost_ = false;
Expand Down
4 changes: 2 additions & 2 deletions Source/dvlnet/cdwrap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ void cdwrap::reset()
dvlnet_wrap->SNetRegisterEventHandler(eventType, eventHandler);
}

int cdwrap::create(std::string addrstr)
int cdwrap::create(std::string_view addrstr)
{
reset();
return dvlnet_wrap->create(addrstr);
}

int cdwrap::join(std::string addrstr)
int cdwrap::join(std::string_view addrstr)
{
game_init_info = buffer_t();
reset();
Expand Down
4 changes: 2 additions & 2 deletions Source/dvlnet/cdwrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ class cdwrap : public abstract_net {
reset();
}

int create(std::string addrstr) override;
int join(std::string addrstr) override;
int create(std::string_view addrstr) override;
int join(std::string_view addrstr) override;
bool SNetReceiveMessage(uint8_t *sender, void **data, size_t *size) override;
bool SNetSendMessage(uint8_t dest, void *data, size_t size) override;
bool SNetReceiveTurns(char **data, size_t *size, uint32_t *status) override;
Expand Down
4 changes: 2 additions & 2 deletions Source/dvlnet/loopback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

namespace devilution::net {

int loopback::create(std::string /*addrstr*/)
int loopback::create(std::string_view /*addrstr*/)
{
IsLoopback = true;
return plr_single;
}

int loopback::join(std::string /*addrstr*/)
int loopback::join(std::string_view /*addrstr*/)
{
ABORT();
}
Expand Down
4 changes: 2 additions & 2 deletions Source/dvlnet/loopback.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class loopback : public abstract_net {
public:
loopback() = default;

int create(std::string addrstr) override;
int join(std::string addrstr) override;
int create(std::string_view addrstr) override;
int join(std::string_view addrstr) override;
bool SNetReceiveMessage(uint8_t *sender, void **data, size_t *size) override;
bool SNetSendMessage(uint8_t dest, void *data, size_t size) override;
bool SNetReceiveTurns(char **data, size_t *size, uint32_t *status) override;
Expand Down
41 changes: 36 additions & 5 deletions Source/dvlnet/tcp_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,56 @@
#include "options.h"
#include "utils/language.h"
#include "utils/str_cat.hpp"
#include "utils/str_split.hpp"

namespace devilution::net {

int tcp_client::create(std::string addrstr)
int tcp_client::create(std::string_view addrstr)
{
auto port = *sgOptions.Network.port;
local_server = std::make_unique<tcp_server>(ioc, addrstr, port, *pktfty);
local_server = std::make_unique<tcp_server>(ioc, std::string(addrstr), port, *pktfty);
return join(local_server->LocalhostSelf());
}

int tcp_client::join(std::string addrstr)
int tcp_client::join(std::string_view addrstr)
{
constexpr int MsSleep = 10;
constexpr int NoSleep = 250;

std::string port = StrCat(*sgOptions.Network.port);
const char *defaultPort = "6112";
std::string_view host;
std::string_view port = defaultPort;
if (!addrstr.empty() && addrstr[0] == '[') {
// Assume IPv6 address in square brackets, followed by port
// Example: [::1]:6113
size_t pos = addrstr.find(']', 1);
pos = pos != std::string::npos ? pos + 1 : addrstr.length();
host = addrstr.substr(0, pos);

if (pos != addrstr.length()) {
if (addrstr[pos] != ':') {
SDL_SetError("Invalid hostname: expected colon after square brackets");
return -1;
}
if (++pos != addrstr.length())
port = addrstr.substr(pos);
}
} else {
// Assume "hostname:port"
SplitByChar splithost(addrstr, ':');
auto it = splithost.begin();
if (it != splithost.end()) host = *it++;
if (it != splithost.end()) port = *it++;

// If there is more than one colon, assume it's just a plain IPv6 address
if (it != splithost.end()) {
host = addrstr;
port = defaultPort;
}
}

asio::error_code errorCode;
asio::ip::basic_resolver_results<asio::ip::tcp> range = resolver.resolve(addrstr, port, errorCode);
asio::ip::basic_resolver_results<asio::ip::tcp> range = resolver.resolve(host, port, errorCode);
if (errorCode) {
SDL_SetError("%s", errorCode.message().c_str());
return -1;
Expand Down
4 changes: 2 additions & 2 deletions Source/dvlnet/tcp_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ namespace devilution::net {

class tcp_client : public base {
public:
int create(std::string addrstr) override;
int join(std::string addrstr) override;
int create(std::string_view addrstr) override;
int join(std::string_view addrstr) override;

tl::expected<void, PacketError> poll() override;
tl::expected<void, PacketError> send(packet &pkt) override;
Expand Down

0 comments on commit 78eb3c7

Please sign in to comment.