@@ -8,6 +8,7 @@
#include <cstddef>
#include <cstring>

#include "Common/Assert.h"
#include "Common/ChunkFile.h"
#include "Common/Logging/Log.h"
#include "Common/Swap.h"
@@ -31,14 +32,14 @@ V5CtrlMessage::V5CtrlMessage(Kernel& ios, const IOCtlVRequest& ioctlv)
V5BulkMessage::V5BulkMessage(Kernel& ios, const IOCtlVRequest& ioctlv)
: BulkMessage(ios, ioctlv, ioctlv.GetVector(1)->address)
{
length = static_cast<u16>(ioctlv.GetVector(1)->size);
length = ioctlv.GetVector(1)->size;
endpoint = Memory::Read_U8(ioctlv.in_vectors[0].address + 18);
}

V5IntrMessage::V5IntrMessage(Kernel& ios, const IOCtlVRequest& ioctlv)
: IntrMessage(ios, ioctlv, ioctlv.GetVector(1)->address)
{
length = static_cast<u16>(ioctlv.GetVector(1)->size);
length = ioctlv.GetVector(1)->size;
endpoint = Memory::Read_U8(ioctlv.in_vectors[0].address + 14);
}

@@ -48,9 +49,16 @@ V5IsoMessage::V5IsoMessage(Kernel& ios, const IOCtlVRequest& ioctlv)
num_packets = Memory::Read_U8(ioctlv.in_vectors[0].address + 16);
endpoint = Memory::Read_U8(ioctlv.in_vectors[0].address + 17);
packet_sizes_addr = ioctlv.GetVector(1)->address;
u32 total_packet_size = 0;
for (size_t i = 0; i < num_packets; ++i)
packet_sizes.push_back(Memory::Read_U16(static_cast<u32>(packet_sizes_addr + i * sizeof(u16))));
length = static_cast<u16>(ioctlv.GetVector(2)->size);
{
const u32 packet_size = Memory::Read_U16(static_cast<u32>(packet_sizes_addr + i * sizeof(u16)));
packet_sizes.push_back(packet_size);
total_packet_size += packet_size;
}
length = ioctlv.GetVector(2)->size;
ASSERT_MSG(IOS_USB, length == total_packet_size, "Wrong buffer size (0x%x != 0x%x)", length,
total_packet_size);
}
} // namespace USB

@@ -132,7 +140,7 @@ IPCCommandResult USBV5ResourceManager::SetAlternateSetting(USBV5Device& device,
const IOCtlRequest& request)
{
const auto host_device = GetDeviceById(device.host_id);
if (!host_device->Attach(device.interface_number))
if (!host_device->AttachAndChangeInterface(device.interface_number))
return GetDefaultReply(-1);

const u8 alt_setting = Memory::Read_U8(request.buffer_in + 2 * sizeof(s32));
@@ -55,7 +55,7 @@ IPCCommandResult USB_HIDv4::IOCtl(const IOCtlRequest& request)
if (request.buffer_in == 0 || request.buffer_in_size != 32)
return GetDefaultReply(IPC_EINVAL);
const auto device = GetDeviceByIOSID(Memory::Read_U32(request.buffer_in + 16));
if (!device->Attach(0))
if (!device->Attach())
return GetDefaultReply(IPC_EINVAL);
return HandleTransfer(device, request.request,
[&, this]() { return SubmitTransfer(*device, request); });
@@ -67,7 +67,10 @@ IPCCommandResult USB_HIDv5::IOCtlV(const IOCtlVRequest& request)
if (!device)
return GetDefaultReply(IPC_EINVAL);
auto host_device = GetDeviceById(device->host_id);
host_device->Attach(device->interface_number);
if (request.request == USB::IOCTLV_USBV5_CTRLMSG)
host_device->Attach();
else
host_device->AttachAndChangeInterface(device->interface_number);
return HandleTransfer(host_device, request.request,
[&, this]() { return SubmitTransfer(*device, *host_device, request); });
}
@@ -76,7 +76,10 @@ IPCCommandResult USB_VEN::IOCtlV(const IOCtlVRequest& request)
if (!device)
return GetDefaultReply(IPC_EINVAL);
auto host_device = GetDeviceById(device->host_id);
host_device->Attach(device->interface_number);
if (request.request == USB::IOCTLV_USBV5_CTRLMSG)
host_device->Attach();
else
host_device->AttachAndChangeInterface(device->interface_number);
return HandleTransfer(host_device, request.request,
[&, this]() { return SubmitTransfer(*host_device, request); });
}
@@ -104,8 +107,10 @@ s32 USB_VEN::SubmitTransfer(USB::Device& device, const IOCtlVRequest& ioctlv)

IPCCommandResult USB_VEN::CancelEndpoint(USBV5Device& device, const IOCtlRequest& request)
{
const u8 endpoint = static_cast<u8>(Memory::Read_U32(request.buffer_in + 8));
GetDeviceById(device.host_id)->CancelTransfer(endpoint);
const u8 endpoint = Memory::Read_U8(request.buffer_in + 8);
// IPC_EINVAL (-4) is returned when no transfer was cancelled.
if (GetDeviceById(device.host_id)->CancelTransfer(endpoint) < 0)
return GetDefaultReply(IPC_EINVAL);
return GetDefaultReply(IPC_SUCCESS);
}