Skip to content

Commit

Permalink
Merge pull request #778 from dorssel/attach
Browse files Browse the repository at this point in the history
Move attach command to top level
  • Loading branch information
dorssel committed Nov 23, 2023
2 parents 94d6f88 + e4388f5 commit ff5aab1
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 194 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace UnitTests;
using ExitCode = Program.ExitCode;

[TestClass]
sealed class Parse_wsl_attach_Tests
sealed class Parse_attach_Tests
: ParseTestBase
{
static readonly BusId TestBusId = BusId.Parse("3-42");
Expand All @@ -21,127 +21,133 @@ sealed class Parse_wsl_attach_Tests
public void BusIdSuccess()
{
var mock = CreateMock();
mock.Setup(m => m.WslAttach(It.Is<BusId>(busId => busId == TestBusId), false, null,
mock.Setup(m => m.AttachWsl(It.Is<BusId>(busId => busId == TestBusId), false, null,
It.IsNotNull<IConsole>(), It.IsAny<CancellationToken>())).Returns(Task.FromResult(ExitCode.Success));

Test(ExitCode.Success, mock, "wsl", "attach", "--busid", TestBusId.ToString());
Test(ExitCode.Success, mock, "attach", "--wsl", "--busid", TestBusId.ToString());
}

[TestMethod]
public void BusIdSuccessWithAutoAttach()
{
var mock = CreateMock();
mock.Setup(m => m.WslAttach(It.Is<BusId>(busId => busId == TestBusId), true, null,
mock.Setup(m => m.AttachWsl(It.Is<BusId>(busId => busId == TestBusId), true, null,
It.IsNotNull<IConsole>(), It.IsAny<CancellationToken>())).Returns(Task.FromResult(ExitCode.Success));

Test(ExitCode.Success, mock, "wsl", "attach", "--busid", TestBusId.ToString(), "--auto-attach");
Test(ExitCode.Success, mock, "attach", "--wsl", "--busid", TestBusId.ToString(), "--auto-attach");
}

[TestMethod]
public void BusIdSuccessWithDistribution()
{
var mock = CreateMock();
mock.Setup(m => m.WslAttach(It.Is<BusId>(busId => busId == TestBusId), false, It.Is<string>(distribution => distribution == TestDistribution),
mock.Setup(m => m.AttachWsl(It.Is<BusId>(busId => busId == TestBusId), false, It.Is<string>(distribution => distribution == TestDistribution),
It.IsNotNull<IConsole>(), It.IsAny<CancellationToken>())).Returns(Task.FromResult(ExitCode.Success));

Test(ExitCode.Success, mock, "wsl", "attach", "--busid", TestBusId.ToString(), "--distribution", TestDistribution);
Test(ExitCode.Success, mock, "attach", "--wsl", TestDistribution, "--busid", TestBusId.ToString());
}

[TestMethod]
public void BusIdFailure()
{
var mock = CreateMock();
mock.Setup(m => m.WslAttach(It.Is<BusId>(busId => busId == TestBusId), false, null,
mock.Setup(m => m.AttachWsl(It.Is<BusId>(busId => busId == TestBusId), false, null,
It.IsNotNull<IConsole>(), It.IsAny<CancellationToken>())).Returns(Task.FromResult(ExitCode.Failure));

Test(ExitCode.Failure, mock, "wsl", "attach", "--busid", TestBusId.ToString());
Test(ExitCode.Failure, mock, "attach", "--wsl", "--busid", TestBusId.ToString());
}

[TestMethod]
public void BusIdCanceled()
{
var mock = CreateMock();
mock.Setup(m => m.WslAttach(It.Is<BusId>(busId => busId == TestBusId), false, null,
mock.Setup(m => m.AttachWsl(It.Is<BusId>(busId => busId == TestBusId), false, null,
It.IsNotNull<IConsole>(), It.IsAny<CancellationToken>())).Throws<OperationCanceledException>();

Test(ExitCode.Canceled, mock, "wsl", "attach", "--busid", TestBusId.ToString());
Test(ExitCode.Canceled, mock, "attach", "--wsl", "--busid", TestBusId.ToString());
}

[TestMethod]
public void HardwareIdSuccess()
{
var mock = CreateMock();
mock.Setup(m => m.WslAttach(It.Is<VidPid>(vidPid => vidPid == TestHardwareId), false, null,
mock.Setup(m => m.AttachWsl(It.Is<VidPid>(vidPid => vidPid == TestHardwareId), false, null,
It.IsNotNull<IConsole>(), It.IsAny<CancellationToken>())).Returns(Task.FromResult(ExitCode.Success));

Test(ExitCode.Success, mock, "wsl", "attach", "--hardware-id", TestHardwareId.ToString());
Test(ExitCode.Success, mock, "attach", "--wsl", "--hardware-id", TestHardwareId.ToString());
}

[TestMethod]
public void HardwareIdFailure()
{
var mock = CreateMock();
mock.Setup(m => m.WslAttach(It.Is<VidPid>(vidPid => vidPid == TestHardwareId), false, null,
mock.Setup(m => m.AttachWsl(It.Is<VidPid>(vidPid => vidPid == TestHardwareId), false, null,
It.IsNotNull<IConsole>(), It.IsAny<CancellationToken>())).Returns(Task.FromResult(ExitCode.Failure));

Test(ExitCode.Failure, mock, "wsl", "attach", "--hardware-id", TestHardwareId.ToString());
Test(ExitCode.Failure, mock, "attach", "--wsl", "--hardware-id", TestHardwareId.ToString());
}

[TestMethod]
public void HardwareIdCanceled()
{
var mock = CreateMock();
mock.Setup(m => m.WslAttach(It.Is<VidPid>(vidPid => vidPid == TestHardwareId), false, null,
mock.Setup(m => m.AttachWsl(It.Is<VidPid>(vidPid => vidPid == TestHardwareId), false, null,
It.IsNotNull<IConsole>(), It.IsAny<CancellationToken>())).Throws<OperationCanceledException>();

Test(ExitCode.Canceled, mock, "wsl", "attach", "--hardware-id", TestHardwareId.ToString());
Test(ExitCode.Canceled, mock, "attach", "--wsl", "--hardware-id", TestHardwareId.ToString());
}

[TestMethod]
public void Help()
{
Test(ExitCode.Success, "wsl", "attach", "--help");
Test(ExitCode.Success, "attach", "--help");
}

[TestMethod]
public void OptionMissing()
public void WslMissing()
{
Test(ExitCode.ParseError, "wsl", "attach");
Test(ExitCode.ParseError, "attach");
}

[TestMethod]
public void DeviceMissing()
{
Test(ExitCode.ParseError, "attach", "--wsl");
}

[TestMethod]
public void BusIdAndHardwareId()
{
Test(ExitCode.ParseError, "wsl", "attach", "--busid", TestBusId.ToString(), "--hardware-id", TestHardwareId.ToString());
Test(ExitCode.ParseError, "attach", "--wsl", "--busid", TestBusId.ToString(), "--hardware-id", TestHardwareId.ToString());
}

[TestMethod]
public void BusIdArgumentMissing()
{
Test(ExitCode.ParseError, "wsl", "attach", "--busid");
Test(ExitCode.ParseError, "attach", "--wsl", "--busid");
}

[TestMethod]
public void HardwareIdArgumentMissing()
{
Test(ExitCode.ParseError, "wsl", "attach", "--hardware-id");
Test(ExitCode.ParseError, "attach", "--wsl", "--hardware-id");
}

[TestMethod]
public void BusIdArgumentInvalid()
{
Test(ExitCode.ParseError, "wsl", "attach", "--busid", "not-a-busid");
Test(ExitCode.ParseError, "attach", "--wsl", "--busid", "not-a-busid");
}

[TestMethod]
public void HardwareIdArgumentInvalid()
{
Test(ExitCode.ParseError, "wsl", "attach", "--hardware-id", "not-a-hardware-id");
Test(ExitCode.ParseError, "attach", "--wsl", "--hardware-id", "not-a-hardware-id");
}

[TestMethod]
public void StrayArgument()
{
Test(ExitCode.ParseError, "wsl", "attach", "stray-argument");
Test(ExitCode.ParseError, "attach", "stray-argument");
}
}
14 changes: 5 additions & 9 deletions UnitTests/Parse_wsl_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,16 @@ sealed class Parse_wsl_Tests
: ParseTestBase
{
[TestMethod]
public void Success()
public void ParseError()
{
Test(ExitCode.Success, "wsl");
// 'wsl' has been removed, so this is an error.
Test(ExitCode.ParseError, "wsl");
}

[TestMethod]
public void Help()
{
Test(ExitCode.Success, "wsl", "--help");
}

[TestMethod]
public void UnknownCommand()
{
Test(ExitCode.ParseError, "wsl", "unknown-command");
// Even --help will give a parse error, just to remind the user the command is entirely gone.
Test(ExitCode.ParseError, "wsl", "--help");
}
}
12 changes: 6 additions & 6 deletions Usbipd/AttachedEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ async Task HandleSubmitIsochronousAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdSu
}

var packetDescriptors = await Stream.ReadUsbIpIsoPacketDescriptorsAsync(submit.number_of_packets, cancellationToken);
if (packetDescriptors.Any((d) => d.length > ushort.MaxValue))
if (packetDescriptors.Any(d => d.length > ushort.MaxValue))
{
// VBoxUSB uses ushort for length, and that is fine as none of the current
// USB standards support larger ISO packets sizes. This is just a sanity check.
throw new ProtocolViolationException("ISO packet too big");
}
if (packetDescriptors.Sum((d) => d.length) != submit.transfer_buffer_length)
if (packetDescriptors.Sum(d => d.length) != submit.transfer_buffer_length)
{
// USBIP requires the packets in the data buffer to be sequential without any padding.
throw new ProtocolViolationException($"cumulative lengths of ISO packets does not match transfer_buffer_length");
Expand Down Expand Up @@ -155,10 +155,10 @@ async Task HandleSubmitIsochronousAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdSu
ret_submit = new()
{
status = -(int)Errno.SUCCESS,
actual_length = (int)packetDescriptors.Sum((pd) => pd.actual_length),
actual_length = (int)packetDescriptors.Sum(pd => pd.actual_length),
start_frame = submit.start_frame,
number_of_packets = submit.number_of_packets,
error_count = packetDescriptors.Count((d) => d.status != -(int)Errno.SUCCESS),
error_count = packetDescriptors.Count(d => d.status != -(int)Errno.SUCCESS),
},
};
Expand Down Expand Up @@ -199,7 +199,7 @@ async Task HandleSubmitIsochronousAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdSu
}
finally
{
_ = Task.WhenAll(ioctls).ContinueWith((task) =>
_ = Task.WhenAll(ioctls).ContinueWith(task =>
{
gcHandle.Free();
}, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
Expand Down Expand Up @@ -327,7 +327,7 @@ public async Task HandleSubmitAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdSubmit
gcHandle.Free();
throw;
}
_ = ioctl.ContinueWith((task) =>
_ = ioctl.ContinueWith(task =>
{
gcHandle.Free();
}, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
Expand Down
61 changes: 15 additions & 46 deletions Usbipd/CommandHandlers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@
using System.CommandLine;
using System.ComponentModel;
using System.Diagnostics;
using System.Security.Principal;
using System.Text;
using System.Text.Encodings.Web;
using System.Text.Json;
using System.Text.RegularExpressions;
using Microsoft.Extensions.Hosting.WindowsServices;
using Usbipd.Automation;
using Windows.Win32.Security;
using static Usbipd.ConsoleTools;
using ExitCode = Usbipd.Program.ExitCode;

namespace Usbipd;

interface ICommandHandlers
{
public Task<ExitCode> AttachWsl(BusId busId, bool autoAttach, string? distribution, IConsole console, CancellationToken cancellationToken);
public Task<ExitCode> AttachWsl(VidPid vidPid, bool autoAttach, string? distribution, IConsole console, CancellationToken cancellationToken);
public Task<ExitCode> Bind(BusId busId, bool force, IConsole console, CancellationToken cancellationToken);
public Task<ExitCode> Bind(VidPid vidPid, bool force, IConsole console, CancellationToken cancellationToken);
public Task<ExitCode> Detach(BusId busId, IConsole console, CancellationToken cancellationToken);
Expand All @@ -33,10 +33,6 @@ interface ICommandHandlers
public Task<ExitCode> Unbind(Guid guid, IConsole console, CancellationToken cancellationToken);
public Task<ExitCode> Unbind(VidPid vidPid, IConsole console, CancellationToken cancellationToken);
public Task<ExitCode> UnbindAll(IConsole console, CancellationToken cancellationToken);

public Task<ExitCode> WslAttach(BusId busId, bool autoAttach, string? distribution, IConsole console, CancellationToken cancellationToken);
public Task<ExitCode> WslAttach(VidPid vidPid, bool autoAttach, string? distribution, IConsole console, CancellationToken cancellationToken);

public Task<ExitCode> State(IConsole console, CancellationToken cancellationToken);
}

Expand Down Expand Up @@ -182,21 +178,18 @@ Task<ExitCode> ICommandHandlers.List(bool usbids, IConsole console, Cancellation
return Task.FromResult(ExitCode.Success);
}

static ExitCode Bind(BusId busId, bool wslAttach, bool force, IConsole console)
static ExitCode Bind(BusId busId, bool force, IConsole console)
{
var device = UsbDevice.GetAll().Where(d => d.BusId.HasValue && d.BusId.Value == busId).SingleOrDefault();
if (device is null)
{
console.ReportError($"There is no device with busid '{busId}'.");
return ExitCode.Failure;
}
if (device.Guid.HasValue && (wslAttach || (force == device.IsForced)))
if (device.Guid.HasValue && (force == device.IsForced))
{
// Not an error, just let the user know they just executed a no-op.
if (!wslAttach)
{
console.ReportInfo($"Device with busid '{busId}' was already shared.");
}
console.ReportInfo($"Device with busid '{busId}' was already shared.");
if (!device.IsForced)
{
console.ReportIfForceNeeded();
Expand All @@ -205,29 +198,6 @@ static ExitCode Bind(BusId busId, bool wslAttach, bool force, IConsole console)
}
if (!CheckWriteAccess(console))
{
if (wslAttach)
{
TOKEN_ELEVATION_TYPE elevationType;
unsafe
{
using var identity = WindowsIdentity.GetCurrent();
var b = Windows.Win32.PInvoke.GetTokenInformation(identity.AccessToken, TOKEN_INFORMATION_CLASS.TokenElevationType, &elevationType, 4, out var returnLength);
if (!b || returnLength != 4)
{
// Assume elevation is not available.
elevationType = TOKEN_ELEVATION_TYPE.TokenElevationTypeDefault;
}
}

if (elevationType == TOKEN_ELEVATION_TYPE.TokenElevationTypeLimited)
{
console.ReportInfo("The first time attaching a device to WSL requires elevated privileges; subsequent attaches will succeed with standard user privileges.");
}
else
{
console.ReportInfo($"To share this device, an administrator will first have to execute 'usbipd bind --busid {busId}'.");
}
}
return ExitCode.AccessDenied;
}
if (!device.Guid.HasValue)
Expand All @@ -254,7 +224,7 @@ static ExitCode Bind(BusId busId, bool wslAttach, bool force, IConsole console)

Task<ExitCode> ICommandHandlers.Bind(BusId busId, bool force, IConsole console, CancellationToken cancellationToken)
{
return Task.FromResult(Bind(busId, false, force, console));
return Task.FromResult(Bind(busId, force, console));
}

Task<ExitCode> ICommandHandlers.Bind(VidPid vidPid, bool force, IConsole console, CancellationToken cancellationToken)
Expand All @@ -263,7 +233,7 @@ Task<ExitCode> ICommandHandlers.Bind(VidPid vidPid, bool force, IConsole console
{
return Task.FromResult(ExitCode.Failure);
}
return Task.FromResult(Bind(busId, false, force, console));
return Task.FromResult(Bind(busId, force, console));
}

async Task<ExitCode> ICommandHandlers.Server(string[] args, IConsole console, CancellationToken cancellationToken)
Expand Down Expand Up @@ -472,14 +442,19 @@ Task<ExitCode> ICommandHandlers.UnbindAll(IConsole console, CancellationToken ca
[GeneratedRegex(@"^[a-zA-Z]:\\")]
private static partial Regex LocalDriveRegex();

async Task<ExitCode> ICommandHandlers.WslAttach(BusId busId, bool autoAttach, string? distribution, IConsole console, CancellationToken cancellationToken)
async Task<ExitCode> ICommandHandlers.AttachWsl(BusId busId, bool autoAttach, string? distribution, IConsole console, CancellationToken cancellationToken)
{
var device = UsbDevice.GetAll().Where(d => d.BusId.HasValue && d.BusId.Value == busId).SingleOrDefault();
if (device is null)
{
console.ReportError($"There is no device with busid '{busId}'.");
return ExitCode.Failure;
}
if (!device.Guid.HasValue)
{
console.ReportError($"Device is not shared; run 'usbipd bind -b {busId}' as administrator first.");
return ExitCode.Failure;
}
// We allow auto-attach on devices that are already attached.
if (!autoAttach && (device.IPAddress is not null))
{
Expand Down Expand Up @@ -622,12 +597,6 @@ async Task<ExitCode> ICommandHandlers.WslAttach(BusId busId, bool autoAttach, st
return ExitCode.Failure;
}

var bindResult = Bind(busId, true, false, console);
if (bindResult != ExitCode.Success)
{
return bindResult;
}

// 6) WSL kernel must be USBIP capable.

{
Expand Down Expand Up @@ -719,13 +688,13 @@ async Task<ExitCode> ICommandHandlers.WslAttach(BusId busId, bool autoAttach, st
return ExitCode.Success;
}

async Task<ExitCode> ICommandHandlers.WslAttach(VidPid vidPid, bool autoAttach, string? distribution, IConsole console, CancellationToken cancellationToken)
async Task<ExitCode> ICommandHandlers.AttachWsl(VidPid vidPid, bool autoAttach, string? distribution, IConsole console, CancellationToken cancellationToken)
{
if (GetBusIdByHardwareId(vidPid, console) is not BusId busId)
{
return ExitCode.Failure;
}
return await ((ICommandHandlers)this).WslAttach(busId, autoAttach, distribution, console, cancellationToken);
return await ((ICommandHandlers)this).AttachWsl(busId, autoAttach, distribution, console, cancellationToken);
}

static ExitCode Detach(IEnumerable<UsbDevice> devices, IConsole console)
Expand Down
Loading

0 comments on commit ff5aab1

Please sign in to comment.