Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions labgrid/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,22 @@ def wrapper(self, *_args, **_kwargs):

return wrapper

@classmethod
def check_bound(cls, func):
@wraps(func)
def wrapper(self, *_args, **_kwargs):
if self.state is BindingState.active:
raise StateError(
f'{self} is active, but must be deactivated to call {func.__qualname__}' # pylint: disable=line-too-long
)
elif self.state is not BindingState.bound:
raise StateError(
f'{self} has not been bound, {func.__qualname__} cannot be called in state "{self.state.name}"' # pylint: disable=line-too-long
)
return func(self, *_args, **_kwargs)

return wrapper

class NamedBinding:
"""
Marks a binding (or binding set) as requiring an explicit name.
Expand Down
13 changes: 13 additions & 0 deletions labgrid/driver/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ def get_priority(self, protocol):

return 0

def get_export_name(self):
"""Get the name to be used for exported variables.

Falls back to the class name if the driver has no name.
"""
if self.name:
return self.name
return self.__class__.__name__

def get_export_vars(self):
"""Get a dictionary of variables to be exported."""
return {}


def check_file(filename, *, command_prefix=[]):
if subprocess.call(command_prefix + ['test', '-r', filename]) != 0:
Expand Down
7 changes: 7 additions & 0 deletions labgrid/driver/networkinterfacedriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ def on_deactivate(self):
self.wrapper = None
self.proxy = None

@Driver.check_bound
def get_export_vars(self):
return {
"host": self.iface.host,
"ifname": self.iface.ifname or "",
}

# basic
@Driver.check_active
@step()
Expand Down
8 changes: 8 additions & 0 deletions labgrid/driver/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ class BaseProviderDriver(Driver):
def __attrs_post_init__(self):
super().__attrs_post_init__()

@Driver.check_bound
def get_export_vars(self):
return {
"host": self.provider.host,
"internal": self.provider.internal,
"external": self.provider.external,
}

@Driver.check_active
@step(args=['filename'], result=True)
def stage(self, filename):
Expand Down
14 changes: 14 additions & 0 deletions labgrid/driver/serialdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ def on_activate(self):
def on_deactivate(self):
self.close()

@Driver.check_bound
def get_export_vars(self):
vars = {
"speed": str(self.port.speed)
}
if isinstance(self.port, SerialPort):
vars["port"] = self.port.port
else:
host, port = proxymanager.get_host_and_port(self.port)
vars["host"] = host
vars["port"] = str(port)
vars["protocol"] = self.port.protocol
return vars

def _read(self, size: int = 1, timeout: float = 0.0):
"""
Reads 'size' or more bytes from the serialport
Expand Down
5 changes: 5 additions & 0 deletions labgrid/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ class NoResourceFoundError(NoSupplierFoundError):
pass


@attr.s(eq=False)
class NoStrategyFoundError(NoSupplierFoundError):
pass


@attr.s(eq=False)
class RegistrationError(Exception):
msg = attr.ib(validator=attr.validators.instance_of(str))
52 changes: 52 additions & 0 deletions labgrid/remote/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
coordinator, acquire a place and interact with the connected resources"""
import argparse
import asyncio
import atexit
import contextlib
import enum
import os
import subprocess
import traceback
import logging
import signal
import sys
import shlex
import json
from textwrap import indent
from socket import gethostname
from getpass import getuser
Expand All @@ -27,6 +31,7 @@
from .. import Target, target_factory
from ..util.proxy import proxymanager
from ..util.helper import processwrapper
from ..util import atomic_replace
from ..driver import Mode

txaio.use_asyncio()
Expand Down Expand Up @@ -1327,6 +1332,34 @@ async def print_reservations(self):
print(f"Reservation '{res.token}':")
res.show(level=1)

async def export(self, place, target):
exported = target.export()
exported["LG__CLIENT_PID"] = str(os.getpid())
if self.args.format is ExportFormat.SHELL:
lines = []
for k, v in sorted(exported.items()):
lines.append(f"{k}={shlex.quote(v)}")
data = "\n".join(lines)
elif self.args.format is ExportFormat.SHELL_EXPORT:
lines = []
for k, v in sorted(exported.items()):
lines.append(f"export {k}={shlex.quote(v)}")
data = "\n".join(lines)+"\n"
elif self.args.format is ExportFormat.JSON:
data = json.dumps(exported)
if self.args.filename == "-":
sys.stdout.write(data)
else:
atomic_replace(self.args.filename, data.encode())
print(f"Exported to {self.args.filename}", file=sys.stderr)
try:
print("Waiting for CTRL+C or SIGTERM...", file=sys.stderr)
while True:
await asyncio.sleep(1.0)
except GeneratorExit:
print("Exiting...\n", file=sys.stderr)
export.needs_target = True


def start_session(url, realm, extra):
from autobahn.asyncio.wamp import ApplicationRunner
Expand Down Expand Up @@ -1421,6 +1454,16 @@ def __call__(self, parser, namespace, value, option_string):
v.append((local, remote))
setattr(namespace, self.dest, v)


class ExportFormat(enum.Enum):
SHELL = "shell"
SHELL_EXPORT = "shell-export"
JSON = "json"

def __str__(self):
return self.value


def main():
processwrapper.enable_logging()
logging.basicConfig(
Expand Down Expand Up @@ -1756,6 +1799,13 @@ def main():
subparser = subparsers.add_parser('reservations', help="list current reservations")
subparser.set_defaults(func=ClientSession.print_reservations)

subparser = subparsers.add_parser('export', help="export driver information to a file (needs environment with drivers)")
subparser.add_argument('--format', dest='format',
type=ExportFormat, choices=ExportFormat, default=ExportFormat.SHELL_EXPORT,
help="output format (default: %(default)s)")
subparser.add_argument('filename', help='output filename')
subparser.set_defaults(func=ClientSession.export)

# make any leftover arguments available for some commands
args, leftover = parser.parse_known_args()
if args.command not in ['ssh', 'rsync', 'forward']:
Expand Down Expand Up @@ -1806,6 +1856,8 @@ def main():
if args.command and args.command != 'help':
exitcode = 0
try:
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))

session = start_session(args.crossbar, os.environ.get("LG_CROSSBAR_REALM", "realm1"),
extra)
try:
Expand Down
7 changes: 7 additions & 0 deletions labgrid/strategy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,10 @@ def transition(self, status):

def force(self, status):
raise NotImplementedError(f"Strategy.force() is not implemented for {self.__class__.__name__}")

def prepare_export(self):
"""By default, export all drivers bound by the strategy."""
name_map = {}
for name in self.bindings.keys():
name_map[getattr(self, name)] = name
return name_map
53 changes: 52 additions & 1 deletion labgrid/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .binding import BindingError, BindingState
from .driver import Driver
from .exceptions import NoSupplierFoundError, NoDriverFoundError, NoResourceFoundError
from .exceptions import NoSupplierFoundError, NoDriverFoundError, NoResourceFoundError, NoStrategyFoundError
from .resource import Resource
from .strategy import Strategy
from .util import Timeout
Expand Down Expand Up @@ -212,6 +212,24 @@ def get_driver(self, cls, *, name=None, activate=True):
"""
return self._get_driver(cls, name=name, activate=activate)

def get_strategy(self):
"""
Helper function to get the strategy of the target.

Returns the Strategy, if exactly one exists and raises a
NoStrategyFoundError otherwise.
"""
found = []
for drv in self.drivers:
if not isinstance(drv, Strategy):
continue
found.append(drv)
if not found:
raise NoStrategyFoundError(f"no Strategy found in {self}")
elif len(found) > 1:
raise NoStrategyFoundError(f"multiple Strategies found in {self}")
return found[0]

def __getitem__(self, key):
"""
Syntactic sugar to access drivers by class (optionally filtered by
Expand Down Expand Up @@ -469,6 +487,39 @@ def _atexit_cleanup(self):
"method on targets yourself to handle exceptions explictly.")
print(f"Error: {e}")

def export(self):
"""
Export information from drivers.

All drivers are deactivated before being exported.

The Strategy can decide for which driver the export method is called and
with which name. Otherwise, all drivers are exported.
"""
try:
name_map = self.get_strategy().prepare_export()
selection = set(name_map.keys())
except NoStrategyFoundError:
name_map = {}
selection = set(driver for driver in self.drivers if not isinstance(driver, Strategy))

assert len(name_map) == len(set(name_map.values())), "duplicate export name"

# drivers need to be deactivated for export to avoid conflicts
self.deactivate_all_drivers()

export_vars = {}
for driver in selection:
name = name_map.get(driver)
if not name:
name = driver.get_export_name()
for k, v in driver.get_export_vars().items():
assert isinstance(k, str), f"key {k} from {driver} is not a string"
assert isinstance(v, str), f"value {v} for key {k} from {driver} is not a string"
export_vars[f"LG__{name}_{k}".upper()] = v
return export_vars


def cleanup(self):
"""Clean up conntected drivers and resources in reversed order"""
self.deactivate_all_drivers()
Expand Down
100 changes: 100 additions & 0 deletions tests/test_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import pytest

from labgrid.resource import Resource, NetworkSerialPort
from labgrid.resource.remote import RemoteNetworkInterface, RemoteTFTPProvider
from labgrid.driver import Driver, SerialDriver, NetworkInterfaceDriver, TFTPProviderDriver
from labgrid.strategy import Strategy
from labgrid.binding import StateError


class ResourceA(Resource):
pass


class DriverA(Driver):
bindings = {"res": ResourceA}

@Driver.check_bound
def get_export_vars(self):
return {
"a": "b",
}


class StrategyA(Strategy):
bindings = {
"drv": DriverA,
}


def test_export(target):
ra = ResourceA(target, "resource")
d = DriverA(target, "driver")
s = StrategyA(target, "strategy")

exported = target.export()
assert exported == {
"LG__DRV_A": "b",
}

target.activate(d)
with pytest.raises(StateError):
d.get_export_vars()


class StrategyB(Strategy):
bindings = {
"drv": DriverA,
}

def prepare_export(self):
return {
self.drv: "custom_name",
}


def test_export_custom(target):
ra = ResourceA(target, "resource")
d = DriverA(target, "driver")
s = StrategyB(target, "strategy")

exported = target.export()
assert exported == {
"LG__CUSTOM_NAME_A": "b",
}


def test_export_network_serial(target):
NetworkSerialPort(target, None, host='testhost', port=12345, speed=115200)
SerialDriver(target, None)

exported = target.export()
assert exported == {
'LG__SERIALDRIVER_HOST': 'testhost',
'LG__SERIALDRIVER_PORT': '12345',
'LG__SERIALDRIVER_PROTOCOL': 'rfc2217',
'LG__SERIALDRIVER_SPEED': '115200'
}


def test_export_remote_network_interface(target):
RemoteNetworkInterface(target, None, host='testhost', ifname='wlan0')
NetworkInterfaceDriver(target, "netif")

exported = target.export()
assert exported == {
'LG__NETIF_HOST': 'testhost',
'LG__NETIF_IFNAME': 'wlan0'
}


def test_export_remote_tftp_provider(target):
RemoteTFTPProvider(target, None, host='testhost', internal='/srv/tftp/testboard/', external='testboard/')
TFTPProviderDriver(target, "tftp")

exported = target.export()
assert exported == {
'LG__TFTP_HOST': 'testhost',
'LG__TFTP_INTERNAL': '/srv/tftp/testboard/',
'LG__TFTP_EXTERNAL': 'testboard/',
}
Loading