Skip to content

Commit

Permalink
tests: Move VMs out of test_usbip.py
Browse files Browse the repository at this point in the history
Make the fixture global, so it can be used in multiple tests.
  • Loading branch information
holesch committed Apr 12, 2024
1 parent 69a0bba commit 588539e
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 173 deletions.
28 changes: 28 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

import pytest

import not_my_board._util as util

from .util import ClientVM, ExporterVM, HubVM, VMs


@pytest.fixture(scope="session")
def event_loop():
Expand All @@ -10,3 +14,27 @@ def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()


@pytest.fixture(scope="session")
async def vms():
async with HubVM() as hub:
while True:
try:
async with util.connect("127.0.0.1", 5001):
pass
async with util.connect("127.0.0.1", 5002):
pass
except ConnectionRefusedError:
await asyncio.sleep(0.1)
continue
break
async with ExporterVM() as exporter:
async with ClientVM() as client:
await util.run_concurrently(
hub.configure(),
exporter.configure(),
client.configure(),
)
await exporter.usb_attach()
yield VMs(hub, exporter, client)
173 changes: 0 additions & 173 deletions tests/test_usbip.py
Original file line number Diff line number Diff line change
@@ -1,96 +1,4 @@
import asyncio
import collections
import contextlib
import pathlib
import sys

import pytest

import not_my_board._util as util


class _VM(util.ContextStack):
_name = ""

async def _context_stack(self, stack):
await stack.enter_async_context(
sh_task(f"./scripts/vmctl run {self._name}", f"vm {self._name}")
)

async def configure(self):
await sh(
f"./scripts/vmctl configure {self._name}", prefix=f"configure {self._name}"
)

def ssh_task(self, cmd, *args, **kwargs):
return sh_task(
f"./scripts/vmctl ssh {self._name} while-stdin " + cmd,
*args,
terminate=False,
**kwargs,
)

def ssh_task_root(self, cmd, *args, **kwargs):
return sh_task(
f"./scripts/vmctl ssh {self._name} doas while-stdin " + cmd,
*args,
terminate=False,
**kwargs,
)

async def ssh(self, cmd, *args, **kwargs):
return await sh(f"./scripts/vmctl ssh {self._name} " + cmd, *args, **kwargs)

async def ssh_poll(self, cmd, timeout=None):
return await sh_poll(f"./scripts/vmctl ssh {self._name} " + cmd, timeout)


class HubVM(_VM):
_name = "hub"
ip = "192.168.200.1"


class ExporterVM(_VM):
_name = "exporter"
ip = "192.168.200.2"

async def usb_attach(self):
await sh("./scripts/vmctl usb attach")

async def usb_detach(self):
await sh("./scripts/vmctl usb detach")


class ClientVM(_VM):
_name = "client"
ip = "192.168.200.3"


VMs = collections.namedtuple("VMs", ["hub", "exporter", "client"])


@pytest.fixture(scope="session")
async def vms():
async with HubVM() as hub:
while True:
try:
async with util.connect("127.0.0.1", 5001):
pass
async with util.connect("127.0.0.1", 5002):
pass
except ConnectionRefusedError:
await asyncio.sleep(0.1)
continue
break
async with ExporterVM() as exporter:
async with ClientVM() as client:
await util.run_concurrently(
hub.configure(),
exporter.configure(),
client.configure(),
)
await exporter.usb_attach()
yield VMs(hub, exporter, client)


async def test_raw_usb_forwarding(vms):
Expand Down Expand Up @@ -169,84 +77,3 @@ async def test_usb_forwarding(vms):
result = await vms.exporter.ssh("readlink /sys/bus/usb/devices/2-1/driver")
driver_name = pathlib.Path(result.stdout).name
assert driver_name == "usb"


ShResult = collections.namedtuple("ShResult", ["stdout", "stderr", "returncode"])


async def sh(cmd, check=True, strip=True, prefix=None):
proc = await asyncio.create_subprocess_shell(
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)

stdout, _ = await util.run_concurrently(
proc.stdout.read(), _log_output(proc.stderr, cmd, prefix)
)
await proc.wait()
if check and proc.returncode:
raise RuntimeError(f"{cmd!r} exited with {proc.returncode}")

stdout = stdout.decode("utf-8")
if strip:
stdout = stdout.rstrip()

return ShResult(stdout, None, proc.returncode)


@contextlib.asynccontextmanager
async def sh_task(cmd, prefix=None, terminate=True):
# need to exec, otherwise only the shell process is killed with
# proc.terminate()
proc = await asyncio.create_subprocess_shell(
f"exec {cmd}",
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)

logging_task = None
try:
logging_task = asyncio.create_task(_log_output(proc.stdout, cmd, prefix))
yield
finally:
proc.stdin.close()
await proc.stdin.wait_closed()

if terminate:
with contextlib.suppress(ProcessLookupError):
proc.terminate()

await proc.wait()
if logging_task:
await logging_task


async def sh_poll(cmd, timeout=None):
if pathlib.Path("/dev/kvm").exists():
if timeout is None:
timeout = 7
interval = 0.1
else:
if timeout is None:
timeout = 60
interval = 1

async def poll_loop():
while True:
result = await sh(cmd, check=False)
if result.returncode == 0:
break
await asyncio.sleep(interval)

await asyncio.wait_for(poll_loop(), timeout)


async def _log_output(stream, cmd, prefix):
if prefix is None:
prefix = f"[{cmd}] ".encode()
else:
prefix = f"[{prefix}] ".encode()

async for line in stream:
sys.stderr.buffer.write(prefix + line)
sys.stderr.buffer.flush()
145 changes: 145 additions & 0 deletions tests/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import asyncio
import collections
import contextlib
import pathlib
import sys

import not_my_board._util as util

VMs = collections.namedtuple("VMs", ["hub", "exporter", "client"])
ShResult = collections.namedtuple("ShResult", ["stdout", "stderr", "returncode"])


class _VM(util.ContextStack):
_name = ""

async def _context_stack(self, stack):
await stack.enter_async_context(
sh_task(f"./scripts/vmctl run {self._name}", f"vm {self._name}")
)

async def configure(self):
await sh(
f"./scripts/vmctl configure {self._name}", prefix=f"configure {self._name}"
)

def ssh_task(self, cmd, *args, **kwargs):
return sh_task(
f"./scripts/vmctl ssh {self._name} while-stdin " + cmd,
*args,
terminate=False,
**kwargs,
)

def ssh_task_root(self, cmd, *args, **kwargs):
return sh_task(
f"./scripts/vmctl ssh {self._name} doas while-stdin " + cmd,
*args,
terminate=False,
**kwargs,
)

async def ssh(self, cmd, *args, **kwargs):
return await sh(f"./scripts/vmctl ssh {self._name} " + cmd, *args, **kwargs)

async def ssh_poll(self, cmd, timeout=None):
return await sh_poll(f"./scripts/vmctl ssh {self._name} " + cmd, timeout)


class HubVM(_VM):
_name = "hub"
ip = "192.168.200.1"


class ExporterVM(_VM):
_name = "exporter"
ip = "192.168.200.2"

async def usb_attach(self):
await sh("./scripts/vmctl usb attach")

async def usb_detach(self):
await sh("./scripts/vmctl usb detach")


class ClientVM(_VM):
_name = "client"
ip = "192.168.200.3"


async def sh(cmd, check=True, strip=True, prefix=None):
proc = await asyncio.create_subprocess_shell(
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)

stdout, _ = await util.run_concurrently(
proc.stdout.read(), _log_output(proc.stderr, cmd, prefix)
)
await proc.wait()
if check and proc.returncode:
raise RuntimeError(f"{cmd!r} exited with {proc.returncode}")

stdout = stdout.decode("utf-8")
if strip:
stdout = stdout.rstrip()

return ShResult(stdout, None, proc.returncode)


@contextlib.asynccontextmanager
async def sh_task(cmd, prefix=None, terminate=True):
# need to exec, otherwise only the shell process is killed with
# proc.terminate()
proc = await asyncio.create_subprocess_shell(
f"exec {cmd}",
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)

logging_task = None
try:
logging_task = asyncio.create_task(_log_output(proc.stdout, cmd, prefix))
yield
finally:
proc.stdin.close()
await proc.stdin.wait_closed()

if terminate:
with contextlib.suppress(ProcessLookupError):
proc.terminate()

await proc.wait()
if logging_task:
await logging_task


async def sh_poll(cmd, timeout=None):
if pathlib.Path("/dev/kvm").exists():
if timeout is None:
timeout = 7
interval = 0.1
else:
if timeout is None:
timeout = 60
interval = 1

async def poll_loop():
while True:
result = await sh(cmd, check=False)
if result.returncode == 0:
break
await asyncio.sleep(interval)

await asyncio.wait_for(poll_loop(), timeout)


async def _log_output(stream, cmd, prefix):
if prefix is None:
prefix = f"[{cmd}] ".encode()
else:
prefix = f"[{prefix}] ".encode()

async for line in stream:
sys.stderr.buffer.write(prefix + line)
sys.stderr.buffer.flush()

0 comments on commit 588539e

Please sign in to comment.