Skip to content

Commit

Permalink
Merge pull request dolphin-emu#10700 from sepalani/ssl-handshake
Browse files Browse the repository at this point in the history
Socket: Fix some non-blocking connect edge cases
  • Loading branch information
JMC47 authored and dvessel committed Jun 28, 2022
2 parents bbe21a3 + d7135da commit 5483134
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 62 deletions.
18 changes: 18 additions & 0 deletions Source/Core/Common/Network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,4 +186,22 @@ u16 ComputeNetworkChecksum(const void* data, u16 length, u32 initial_value)
checksum = (checksum >> 16) + (checksum & 0xFFFF);
return ~static_cast<u16>(checksum);
}

NetworkErrorState SaveNetworkErrorState()
{
return {
errno,
#ifdef _WIN32
WSAGetLastError(),
#endif
};
}

void RestoreNetworkErrorState(const NetworkErrorState& state)
{
errno = state.error;
#ifdef _WIN32
WSASetLastError(state.wsa_error);
#endif
}
} // namespace Common
10 changes: 10 additions & 0 deletions Source/Core/Common/Network.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,18 @@ struct UDPHeader
};
static_assert(sizeof(UDPHeader) == UDPHeader::SIZE);

struct NetworkErrorState
{
int error;
#ifdef _WIN32
int wsa_error;
#endif
};

MACAddress GenerateMacAddress(MACConsumer type);
std::string MacAddressToString(const MACAddress& mac);
std::optional<MACAddress> StringToMacAddress(std::string_view mac_string);
u16 ComputeNetworkChecksum(const void* data, u16 length, u32 initial_value = 0);
NetworkErrorState SaveNetworkErrorState();
void RestoreNetworkErrorState(const NetworkErrorState& state);
} // namespace Common
202 changes: 177 additions & 25 deletions Source/Core/Core/IOS/Network/Socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
#include <sys/select.h>
#endif

#include "Common/BitUtils.h"
#include "Common/FileUtil.h"
#include "Common/IOFile.h"
#include "Common/Network.h"
#include "Common/ScopeGuard.h"
#include "Core/Config/MainSettings.h"
#include "Core/Core.h"
#include "Core/IOS/Device.h"
Expand Down Expand Up @@ -224,6 +227,7 @@ s32 WiiSocket::CloseFd()
GetIOS()->EnqueueIPCReply(it->request, -SO_ENOTCONN);
it = pending_sockops.erase(it);
}
connecting_state = ConnectingState::None;
return ReturnValue;
}

Expand Down Expand Up @@ -278,8 +282,8 @@ void WiiSocket::Update(bool read, bool write, bool except)
case IOCTL_SO_BIND:
{
sockaddr_in local_name;
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(ioctl.buffer_in + 8);
WiiSockMan::Convert(*wii_name, local_name);
const u8* addr = Memory::GetPointer(ioctl.buffer_in + 8);
WiiSockMan::ToNativeAddrIn(addr, &local_name);

int ret = bind(fd, (sockaddr*)&local_name, sizeof(local_name));
ReturnValue = WiiSockMan::GetNetErrorCode(ret, "SO_BIND", false);
Expand All @@ -291,11 +295,12 @@ void WiiSocket::Update(bool read, bool write, bool except)
case IOCTL_SO_CONNECT:
{
sockaddr_in local_name;
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(ioctl.buffer_in + 8);
WiiSockMan::Convert(*wii_name, local_name);
const u8* addr = Memory::GetPointer(ioctl.buffer_in + 8);
WiiSockMan::ToNativeAddrIn(addr, &local_name);

int ret = connect(fd, (sockaddr*)&local_name, sizeof(local_name));
ReturnValue = WiiSockMan::GetNetErrorCode(ret, "SO_CONNECT", false);
UpdateConnectingState(ReturnValue);

INFO_LOG_FMT(IOS_NET, "IOCTL_SO_CONNECT ({:08x}, {}:{}) = {}", wii_fd,
inet_ntoa(local_name.sin_addr), Common::swap16(local_name.sin_port), ret);
Expand All @@ -307,13 +312,13 @@ void WiiSocket::Update(bool read, bool write, bool except)
if (ioctl.buffer_out_size > 0)
{
sockaddr_in local_name;
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(ioctl.buffer_out);
WiiSockMan::Convert(*wii_name, local_name);
u8* addr = Memory::GetPointer(ioctl.buffer_out);
WiiSockMan::ToNativeAddrIn(addr, &local_name);

socklen_t addrlen = sizeof(sockaddr_in);
ret = static_cast<s32>(accept(fd, (sockaddr*)&local_name, &addrlen));

WiiSockMan::Convert(local_name, *wii_name, addrlen);
WiiSockMan::ToWiiAddrIn(local_name, addr, addrlen);
}
else
{
Expand Down Expand Up @@ -341,10 +346,12 @@ void WiiSocket::Update(bool read, bool write, bool except)
{
ReturnValue = -SO_ENETUNREACH;
ResetTimeout();
connecting_state = ConnectingState::Error;
}
break;
case -SO_EISCONN:
ReturnValue = SO_SUCCESS;
connecting_state = ConnectingState::Connected;
[[fallthrough]];
default:
ResetTimeout();
Expand Down Expand Up @@ -392,6 +399,24 @@ void WiiSocket::Update(bool read, bool write, bool except)
{
case IOCTLV_NET_SSL_DOHANDSHAKE:
{
// The Wii allows a socket with an in-progress connection to
// perform the SSL handshake. MbedTLS doesn't support it so
// we have to check it manually.
connecting_state = GetConnectingState();
if (connecting_state == ConnectingState::Connecting)
{
WriteReturnValue(SSL_ERR_RAGAIN, BufferIn);
ReturnValue = SSL_ERR_RAGAIN;
break;
}
else if (connecting_state == ConnectingState::None ||
connecting_state == ConnectingState::Error)
{
WriteReturnValue(SSL_ERR_SYSCALL, BufferIn);
ReturnValue = SSL_ERR_SYSCALL;
break;
}

mbedtls_ssl_context* ctx = &NetSSLDevice::_SSL[sslID].ctx;
const int ret = mbedtls_ssl_handshake(ctx);
if (ret != 0)
Expand Down Expand Up @@ -550,6 +575,16 @@ void WiiSocket::Update(bool read, bool write, bool except)
{
case IOCTLV_SO_SENDTO:
{
// The Wii allows a socket with a connection in progress to use
// sendto(). This might not be supported by the operating system.
// We have to enforce it manually.
connecting_state = GetConnectingState();
if (nonBlock && IsTCP() && connecting_state == ConnectingState::Connecting)
{
ReturnValue = -SO_EAGAIN;
break;
}

u32 flags = Memory::Read_U32(BufferIn2 + 0x04);
u32 has_destaddr = Memory::Read_U32(BufferIn2 + 0x08);

Expand All @@ -564,8 +599,8 @@ void WiiSocket::Update(bool read, bool write, bool except)
sockaddr_in local_name = {0};
if (has_destaddr)
{
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(BufferIn2 + 0x0C);
WiiSockMan::Convert(*wii_name, local_name);
const u8* addr = Memory::GetPointer(BufferIn2 + 0x0C);
WiiSockMan::ToNativeAddrIn(addr, &local_name);
}

auto* to = has_destaddr ? reinterpret_cast<sockaddr*>(&local_name) : nullptr;
Expand All @@ -587,6 +622,16 @@ void WiiSocket::Update(bool read, bool write, bool except)
}
case IOCTLV_SO_RECVFROM:
{
// The Wii allows a socket with a connection in progress to use
// recvfrom(). This might not be supported by the operating system.
// We have to enforce it manually.
connecting_state = GetConnectingState();
if (nonBlock && IsTCP() && connecting_state == ConnectingState::Connecting)
{
ReturnValue = -SO_EAGAIN;
break;
}

u32 flags = Memory::Read_U32(BufferIn + 0x04);
// Not a string, Windows requires a char* for recvfrom
char* data = (char*)Memory::GetPointer(BufferOut);
Expand All @@ -597,8 +642,8 @@ void WiiSocket::Update(bool read, bool write, bool except)

if (BufferOutSize2 != 0)
{
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(BufferOut2);
WiiSockMan::Convert(*wii_name, local_name);
const u8* addr = Memory::GetPointer(BufferOut2);
WiiSockMan::ToNativeAddrIn(addr, &local_name);
}

// Act as non blocking when SO_MSG_NONBLOCK is specified
Expand Down Expand Up @@ -634,8 +679,8 @@ void WiiSocket::Update(bool read, bool write, bool except)

if (BufferOutSize2 != 0)
{
WiiSockAddrIn* wii_name = (WiiSockAddrIn*)Memory::GetPointer(BufferOut2);
WiiSockMan::Convert(local_name, *wii_name, addrlen);
u8* addr = Memory::GetPointer(BufferOut2);
WiiSockMan::ToWiiAddrIn(local_name, addr, addrlen);
}
break;
}
Expand Down Expand Up @@ -672,6 +717,112 @@ void WiiSocket::Update(bool read, bool write, bool except)
}
}

void WiiSocket::UpdateConnectingState(s32 connect_rv)
{
if (connect_rv == -SO_EAGAIN || connect_rv == -SO_EALREADY || connect_rv == -SO_EINPROGRESS)
{
connecting_state = ConnectingState::Connecting;
}
else if (connect_rv >= 0)
{
connecting_state = ConnectingState::Connected;
}
else
{
connecting_state = ConnectingState::Error;
}
}

WiiSocket::ConnectingState WiiSocket::GetConnectingState() const
{
const auto state = Common::SaveNetworkErrorState();
Common::ScopeGuard guard([&state] { Common::RestoreNetworkErrorState(state); });

#ifdef _WIN32
constexpr int (*get_errno)() = &WSAGetLastError;
#else
constexpr int (*get_errno)() = []() { return errno; };
#endif

switch (connecting_state)
{
case ConnectingState::Error:
case ConnectingState::Connected:
case ConnectingState::None:
break;
case ConnectingState::Connecting:
{
const s32 nfds = fd + 1;
fd_set read_fds;
fd_set write_fds;
fd_set except_fds;
struct timeval t = {0, 0};
FD_ZERO(&read_fds);
FD_ZERO(&write_fds);
FD_ZERO(&except_fds);
FD_SET(fd, &write_fds);
FD_SET(fd, &except_fds);

auto& sm = WiiSockMan::GetInstance();
if (select(nfds, &read_fds, &write_fds, &except_fds, &t) < 0)
{
const s32 error = get_errno();
ERROR_LOG_FMT(IOS_SSL, "Failed to get socket (fd={}) connection state (err={}): {}", wii_fd,
error, sm.DecodeError(error));
return ConnectingState::Error;
}

if (FD_ISSET(fd, &write_fds) == 0 && FD_ISSET(fd, &except_fds) == 0)
break;

s32 error = 0;
socklen_t len = sizeof(error);
if (getsockopt(fd, SOL_SOCKET, SO_ERROR, reinterpret_cast<char*>(&error), &len) != 0)
{
error = get_errno();
ERROR_LOG_FMT(IOS_SSL, "Failed to get socket (fd={}) error state (err={}): {}", wii_fd, error,
sm.DecodeError(error));
return ConnectingState::Error;
}

if (error != 0)
{
ERROR_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) failed (err={}): {}", wii_fd, error,
sm.DecodeError(error));
return ConnectingState::Error;
}

// Get peername to ensure the socket is connected
sockaddr_in peer;
socklen_t peer_len = sizeof(peer);
if (getpeername(fd, reinterpret_cast<sockaddr*>(&peer), &peer_len) != 0)
{
error = get_errno();
ERROR_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) failed to get peername (err={}): {}",
wii_fd, error, sm.DecodeError(error));
return ConnectingState::Error;
}

INFO_LOG_FMT(IOS_SSL, "Non-blocking connect (fd={}) succeeded", wii_fd);
return ConnectingState::Connected;
}
}

return connecting_state;
}

bool WiiSocket::IsTCP() const
{
const auto state = Common::SaveNetworkErrorState();
Common::ScopeGuard guard([&state] { Common::RestoreNetworkErrorState(state); });

int socket_type;
socklen_t option_length = sizeof(socket_type);
return getsockopt(fd, SOL_SOCKET, SO_TYPE, reinterpret_cast<char*>(&socket_type),
&option_length) == 0 &&
socket_type == SOCK_STREAM;
}

const WiiSocket::Timeout& WiiSocket::GetTimeout()
{
if (!timeout.has_value())
Expand Down Expand Up @@ -937,11 +1088,12 @@ void WiiSockMan::UpdatePollCommands()
pending_polls.end());
}

void WiiSockMan::Convert(WiiSockAddrIn const& from, sockaddr_in& to)
void WiiSockMan::ToNativeAddrIn(const u8* addr, sockaddr_in* to)
{
to.sin_addr.s_addr = from.addr.addr;
to.sin_family = from.family;
to.sin_port = from.port;
const WiiSockAddrIn from = Common::BitCastPtr<WiiSockAddrIn>(addr);
to->sin_addr.s_addr = from.addr.addr;
to->sin_family = from.family;
to->sin_port = from.port;
}

s32 WiiSockMan::ConvertEvents(s32 events, ConvertDirection dir)
Expand Down Expand Up @@ -981,15 +1133,15 @@ s32 WiiSockMan::ConvertEvents(s32 events, ConvertDirection dir)
return converted_events;
}

void WiiSockMan::Convert(sockaddr_in const& from, WiiSockAddrIn& to, s32 addrlen)
void WiiSockMan::ToWiiAddrIn(const sockaddr_in& from, u8* to, socklen_t addrlen)
{
to.addr.addr = from.sin_addr.s_addr;
to.family = from.sin_family & 0xFF;
to.port = from.sin_port;
if (addrlen < 0 || addrlen > static_cast<s32>(sizeof(WiiSockAddrIn)))
to.len = sizeof(WiiSockAddrIn);
else
to.len = addrlen;
to[offsetof(WiiSockAddrIn, len)] =
u8(addrlen > sizeof(WiiSockAddrIn) ? sizeof(WiiSockAddrIn) : addrlen);
to[offsetof(WiiSockAddrIn, family)] = u8(from.sin_family & 0xFF);
const u16& from_port = from.sin_port;
memcpy(to + offsetof(WiiSockAddrIn, port), &from_port, sizeof(from_port));
const u32& from_addr = from.sin_addr.s_addr;
memcpy(to + offsetof(WiiSockAddrIn, addr.addr), &from_addr, sizeof(from_addr));
}

void WiiSockMan::DoState(PointerWrap& p)
Expand Down

0 comments on commit 5483134

Please sign in to comment.