diff --git a/labgrid/binding.py b/labgrid/binding.py index 80c7c0901..90f159aa9 100644 --- a/labgrid/binding.py +++ b/labgrid/binding.py @@ -97,6 +97,22 @@ def wrapper(self, *_args, **_kwargs): return wrapper + @classmethod + def check_bound(cls, func): + @wraps(func) + def wrapper(self, *_args, **_kwargs): + if self.state is BindingState.active: + raise StateError( + f'{self} is active, but must be deactivated to call {func.__qualname__}' # pylint: disable=line-too-long + ) + elif self.state is not BindingState.bound: + raise StateError( + f'{self} has not been bound, {func.__qualname__} cannot be called in state "{self.state.name}"' # pylint: disable=line-too-long + ) + return func(self, *_args, **_kwargs) + + return wrapper + class NamedBinding: """ Marks a binding (or binding set) as requiring an explicit name. diff --git a/labgrid/driver/common.py b/labgrid/driver/common.py index aee93a003..49ec36c49 100644 --- a/labgrid/driver/common.py +++ b/labgrid/driver/common.py @@ -45,6 +45,19 @@ def get_priority(self, protocol): return 0 + def get_export_name(self): + """Get the name to be used for exported variables. + + Falls back to the class name if the driver has no name. + """ + if self.name: + return self.name + return self.__class__.__name__ + + def get_export_vars(self): + """Get a dictionary of variables to be exported.""" + return {} + def check_file(filename, *, command_prefix=[]): if subprocess.call(command_prefix + ['test', '-r', filename]) != 0: diff --git a/labgrid/driver/networkinterfacedriver.py b/labgrid/driver/networkinterfacedriver.py index bc0b6ae60..d5d47c898 100644 --- a/labgrid/driver/networkinterfacedriver.py +++ b/labgrid/driver/networkinterfacedriver.py @@ -35,6 +35,13 @@ def on_deactivate(self): self.wrapper = None self.proxy = None + @Driver.check_bound + def get_export_vars(self): + return { + "host": self.iface.host, + "ifname": self.iface.ifname or "", + } + # basic @Driver.check_active @step() diff --git a/labgrid/driver/provider.py b/labgrid/driver/provider.py index 9e8b4b683..0a0637aec 100644 --- a/labgrid/driver/provider.py +++ b/labgrid/driver/provider.py @@ -14,6 +14,14 @@ class BaseProviderDriver(Driver): def __attrs_post_init__(self): super().__attrs_post_init__() + @Driver.check_bound + def get_export_vars(self): + return { + "host": self.provider.host, + "internal": self.provider.internal, + "external": self.provider.external, + } + @Driver.check_active @step(args=['filename'], result=True) def stage(self, filename): diff --git a/labgrid/driver/serialdriver.py b/labgrid/driver/serialdriver.py index 1a78e37d2..9f3f45d4a 100644 --- a/labgrid/driver/serialdriver.py +++ b/labgrid/driver/serialdriver.py @@ -69,6 +69,20 @@ def on_activate(self): def on_deactivate(self): self.close() + @Driver.check_bound + def get_export_vars(self): + vars = { + "speed": str(self.port.speed) + } + if isinstance(self.port, SerialPort): + vars["port"] = self.port.port + else: + host, port = proxymanager.get_host_and_port(self.port) + vars["host"] = host + vars["port"] = str(port) + vars["protocol"] = self.port.protocol + return vars + def _read(self, size: int = 1, timeout: float = 0.0): """ Reads 'size' or more bytes from the serialport diff --git a/labgrid/exceptions.py b/labgrid/exceptions.py index b771bd0a3..d5d8356c1 100644 --- a/labgrid/exceptions.py +++ b/labgrid/exceptions.py @@ -30,6 +30,11 @@ class NoResourceFoundError(NoSupplierFoundError): pass +@attr.s(eq=False) +class NoStrategyFoundError(NoSupplierFoundError): + pass + + @attr.s(eq=False) class RegistrationError(Exception): msg = attr.ib(validator=attr.validators.instance_of(str)) diff --git a/labgrid/remote/client.py b/labgrid/remote/client.py index b340c6d5c..f01beb126 100755 --- a/labgrid/remote/client.py +++ b/labgrid/remote/client.py @@ -2,13 +2,17 @@ coordinator, acquire a place and interact with the connected resources""" import argparse import asyncio +import atexit import contextlib +import enum import os import subprocess import traceback import logging import signal import sys +import shlex +import json from textwrap import indent from socket import gethostname from getpass import getuser @@ -27,6 +31,7 @@ from .. import Target, target_factory from ..util.proxy import proxymanager from ..util.helper import processwrapper +from ..util import atomic_replace from ..driver import Mode txaio.use_asyncio() @@ -1327,6 +1332,34 @@ async def print_reservations(self): print(f"Reservation '{res.token}':") res.show(level=1) + async def export(self, place, target): + exported = target.export() + exported["LG__CLIENT_PID"] = str(os.getpid()) + if self.args.format is ExportFormat.SHELL: + lines = [] + for k, v in sorted(exported.items()): + lines.append(f"{k}={shlex.quote(v)}") + data = "\n".join(lines) + elif self.args.format is ExportFormat.SHELL_EXPORT: + lines = [] + for k, v in sorted(exported.items()): + lines.append(f"export {k}={shlex.quote(v)}") + data = "\n".join(lines)+"\n" + elif self.args.format is ExportFormat.JSON: + data = json.dumps(exported) + if self.args.filename == "-": + sys.stdout.write(data) + else: + atomic_replace(self.args.filename, data.encode()) + print(f"Exported to {self.args.filename}", file=sys.stderr) + try: + print("Waiting for CTRL+C or SIGTERM...", file=sys.stderr) + while True: + await asyncio.sleep(1.0) + except GeneratorExit: + print("Exiting...\n", file=sys.stderr) + export.needs_target = True + def start_session(url, realm, extra): from autobahn.asyncio.wamp import ApplicationRunner @@ -1421,6 +1454,16 @@ def __call__(self, parser, namespace, value, option_string): v.append((local, remote)) setattr(namespace, self.dest, v) + +class ExportFormat(enum.Enum): + SHELL = "shell" + SHELL_EXPORT = "shell-export" + JSON = "json" + + def __str__(self): + return self.value + + def main(): processwrapper.enable_logging() logging.basicConfig( @@ -1756,6 +1799,13 @@ def main(): subparser = subparsers.add_parser('reservations', help="list current reservations") subparser.set_defaults(func=ClientSession.print_reservations) + subparser = subparsers.add_parser('export', help="export driver information to a file (needs environment with drivers)") + subparser.add_argument('--format', dest='format', + type=ExportFormat, choices=ExportFormat, default=ExportFormat.SHELL_EXPORT, + help="output format (default: %(default)s)") + subparser.add_argument('filename', help='output filename') + subparser.set_defaults(func=ClientSession.export) + # make any leftover arguments available for some commands args, leftover = parser.parse_known_args() if args.command not in ['ssh', 'rsync', 'forward']: @@ -1806,6 +1856,8 @@ def main(): if args.command and args.command != 'help': exitcode = 0 try: + signal.signal(signal.SIGTERM, lambda *_: sys.exit(0)) + session = start_session(args.crossbar, os.environ.get("LG_CROSSBAR_REALM", "realm1"), extra) try: diff --git a/labgrid/strategy/common.py b/labgrid/strategy/common.py index 627374fc5..6263f7b96 100644 --- a/labgrid/strategy/common.py +++ b/labgrid/strategy/common.py @@ -48,3 +48,10 @@ def transition(self, status): def force(self, status): raise NotImplementedError(f"Strategy.force() is not implemented for {self.__class__.__name__}") + + def prepare_export(self): + """By default, export all drivers bound by the strategy.""" + name_map = {} + for name in self.bindings.keys(): + name_map[getattr(self, name)] = name + return name_map diff --git a/labgrid/target.py b/labgrid/target.py index b4875f1da..212e0b92e 100644 --- a/labgrid/target.py +++ b/labgrid/target.py @@ -7,7 +7,7 @@ from .binding import BindingError, BindingState from .driver import Driver -from .exceptions import NoSupplierFoundError, NoDriverFoundError, NoResourceFoundError +from .exceptions import NoSupplierFoundError, NoDriverFoundError, NoResourceFoundError, NoStrategyFoundError from .resource import Resource from .strategy import Strategy from .util import Timeout @@ -212,6 +212,24 @@ def get_driver(self, cls, *, name=None, activate=True): """ return self._get_driver(cls, name=name, activate=activate) + def get_strategy(self): + """ + Helper function to get the strategy of the target. + + Returns the Strategy, if exactly one exists and raises a + NoStrategyFoundError otherwise. + """ + found = [] + for drv in self.drivers: + if not isinstance(drv, Strategy): + continue + found.append(drv) + if not found: + raise NoStrategyFoundError(f"no Strategy found in {self}") + elif len(found) > 1: + raise NoStrategyFoundError(f"multiple Strategies found in {self}") + return found[0] + def __getitem__(self, key): """ Syntactic sugar to access drivers by class (optionally filtered by @@ -469,6 +487,39 @@ def _atexit_cleanup(self): "method on targets yourself to handle exceptions explictly.") print(f"Error: {e}") + def export(self): + """ + Export information from drivers. + + All drivers are deactivated before being exported. + + The Strategy can decide for which driver the export method is called and + with which name. Otherwise, all drivers are exported. + """ + try: + name_map = self.get_strategy().prepare_export() + selection = set(name_map.keys()) + except NoStrategyFoundError: + name_map = {} + selection = set(driver for driver in self.drivers if not isinstance(driver, Strategy)) + + assert len(name_map) == len(set(name_map.values())), "duplicate export name" + + # drivers need to be deactivated for export to avoid conflicts + self.deactivate_all_drivers() + + export_vars = {} + for driver in selection: + name = name_map.get(driver) + if not name: + name = driver.get_export_name() + for k, v in driver.get_export_vars().items(): + assert isinstance(k, str), f"key {k} from {driver} is not a string" + assert isinstance(v, str), f"value {v} for key {k} from {driver} is not a string" + export_vars[f"LG__{name}_{k}".upper()] = v + return export_vars + + def cleanup(self): """Clean up conntected drivers and resources in reversed order""" self.deactivate_all_drivers() diff --git a/tests/test_export.py b/tests/test_export.py new file mode 100644 index 000000000..132115956 --- /dev/null +++ b/tests/test_export.py @@ -0,0 +1,100 @@ +import pytest + +from labgrid.resource import Resource, NetworkSerialPort +from labgrid.resource.remote import RemoteNetworkInterface, RemoteTFTPProvider +from labgrid.driver import Driver, SerialDriver, NetworkInterfaceDriver, TFTPProviderDriver +from labgrid.strategy import Strategy +from labgrid.binding import StateError + + +class ResourceA(Resource): + pass + + +class DriverA(Driver): + bindings = {"res": ResourceA} + + @Driver.check_bound + def get_export_vars(self): + return { + "a": "b", + } + + +class StrategyA(Strategy): + bindings = { + "drv": DriverA, + } + + +def test_export(target): + ra = ResourceA(target, "resource") + d = DriverA(target, "driver") + s = StrategyA(target, "strategy") + + exported = target.export() + assert exported == { + "LG__DRV_A": "b", + } + + target.activate(d) + with pytest.raises(StateError): + d.get_export_vars() + + +class StrategyB(Strategy): + bindings = { + "drv": DriverA, + } + + def prepare_export(self): + return { + self.drv: "custom_name", + } + + +def test_export_custom(target): + ra = ResourceA(target, "resource") + d = DriverA(target, "driver") + s = StrategyB(target, "strategy") + + exported = target.export() + assert exported == { + "LG__CUSTOM_NAME_A": "b", + } + + +def test_export_network_serial(target): + NetworkSerialPort(target, None, host='testhost', port=12345, speed=115200) + SerialDriver(target, None) + + exported = target.export() + assert exported == { + 'LG__SERIALDRIVER_HOST': 'testhost', + 'LG__SERIALDRIVER_PORT': '12345', + 'LG__SERIALDRIVER_PROTOCOL': 'rfc2217', + 'LG__SERIALDRIVER_SPEED': '115200' + } + + +def test_export_remote_network_interface(target): + RemoteNetworkInterface(target, None, host='testhost', ifname='wlan0') + NetworkInterfaceDriver(target, "netif") + + exported = target.export() + assert exported == { + 'LG__NETIF_HOST': 'testhost', + 'LG__NETIF_IFNAME': 'wlan0' + } + + +def test_export_remote_tftp_provider(target): + RemoteTFTPProvider(target, None, host='testhost', internal='/srv/tftp/testboard/', external='testboard/') + TFTPProviderDriver(target, "tftp") + + exported = target.export() + assert exported == { + 'LG__TFTP_HOST': 'testhost', + 'LG__TFTP_INTERNAL': '/srv/tftp/testboard/', + 'LG__TFTP_EXTERNAL': 'testboard/', + } diff --git a/tests/test_target.py b/tests/test_target.py index 1f83cb42d..197be1617 100644 --- a/tests/test_target.py +++ b/tests/test_target.py @@ -9,7 +9,7 @@ from labgrid.resource import Resource from labgrid.driver import Driver from labgrid.strategy import Strategy -from labgrid.exceptions import NoSupplierFoundError, NoDriverFoundError, NoResourceFoundError +from labgrid.exceptions import NoSupplierFoundError, NoDriverFoundError, NoResourceFoundError, NoStrategyFoundError # test basic construction @@ -252,6 +252,27 @@ def test_suppliers_optional_named_a_missing(target): +class StrategyA(Strategy): + bindings = { + "drv": DriverWithA, + } + + +def test_get_strategy(target): + ra = ResourceA(target, "resource") + d = DriverWithA(target, "driver") + + with pytest.raises(NoStrategyFoundError): + target.get_strategy() + + s1 = StrategyA(target, "s1") + assert target.get_strategy() is s1 + + s2 = StrategyA(target, "s2") + with pytest.raises(NoStrategyFoundError): + target.get_strategy() + + # test nested resource creation @attr.s(eq=False) class DiscoveryResource(Resource):