From f135cf6d13129bd93643855a4e7bb1ff1a71daf6 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Sat, 15 Mar 2025 09:56:32 -0400 Subject: [PATCH 1/2] jumpstarter_driver_flashers: fmt --- .../jumpstarter_driver_flashers/bundle.py | 10 +- .../jumpstarter_driver_flashers/client.py | 105 ++++++++++-------- .../jumpstarter_driver_flashers/driver.py | 45 ++++---- .../driver_test.py | 25 ++--- .../test_bundle.py | 2 - .../jumpstarter_driver_flashers/uboot.py | 42 +++---- 6 files changed, 116 insertions(+), 113 deletions(-) diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/bundle.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/bundle.py index c61636614..c1e58cabf 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/bundle.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/bundle.py @@ -9,27 +9,26 @@ class FileAddress(BaseModel): file: str address: str + class DtbVariant(BaseModel): default: str address: str variants: dict[str, str] + class FlasherLogin(BaseModel): login_prompt: str username: str | None = None password: str | None = None prompt: str + class FlashBundleSpecV1Alpha1(BaseModel): manufacturer: str link: Optional[str] bootcmd: str shelltype: Literal["busybox"] = Field(default="busybox") - login: FlasherLogin = Field( - default_factory=lambda: FlasherLogin( - login_prompt="login:", - prompt="#") - ) + login: FlasherLogin = Field(default_factory=lambda: FlasherLogin(login_prompt="login:", prompt="#")) default_target: str targets: dict[str, str] kernel: FileAddress @@ -41,6 +40,7 @@ class FlashBundleSpecV1Alpha1(BaseModel): class ObjectMeta(BaseModel): name: str + class FlasherBundleManifestV1Alpha1(BaseModel): apiVersion: Literal["jumpstarter.dev/v1alpha1"] = Field(default="jumpstarter.dev/v1alpha1") kind: Literal["FlashBundleManifest"] = Field(default="FlashBundleManifest") diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py index 6fae6d6dd..f69cc4918 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py @@ -25,6 +25,7 @@ debug_console_option = click.option("--console-debug", is_flag=True, help="Enable console debug mode") + @dataclass(kw_only=True) class BaseFlasherClient(FlasherClient, CompositeClient): """ @@ -99,10 +100,11 @@ def flash( error_queue = Queue() # Start the storage write operation in the background - storage_thread = threading.Thread(target=self._transfer_bg_thread, - args=(path, operator, operator_scheme, - os_image_checksum, self.http.storage, error_queue), - name="storage_transfer") + storage_thread = threading.Thread( + target=self._transfer_bg_thread, + args=(path, operator, operator_scheme, os_image_checksum, self.http.storage, error_queue), + name="storage_transfer", + ) storage_thread.start() # Make the exporter download the bundle contents and set files in the right places @@ -155,7 +157,6 @@ def flash( self.logger.info("Powering off target") self.power.off() - def _flash_with_progress(self, console, manifest, path, image_url, target_path): """Flash image to target device with progress monitoring. @@ -170,8 +171,8 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path): decompress_cmd = _get_decompression_command(path) flash_cmd = ( f'( wget -q -O - "{image_url}" | ' - f'{decompress_cmd} ' - f'dd of={target_path} bs=64k iflag=fullblock oflag=direct) &' + f"{decompress_cmd} " + f"dd of={target_path} bs=64k iflag=fullblock oflag=direct) &" ) console.sendline(flash_cmd) console.expect(manifest.spec.login.prompt, timeout=60) @@ -190,7 +191,7 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path): if "No such file or directory" in console.before.decode(errors="ignore"): break data = console.before.decode(errors="ignore") - match = re.search(r'pos:\s+(\d+)', data) + match = re.search(r"pos:\s+(\d+)", data) if match: current_bytes = int(match.group(1)) current_time = time.time() @@ -198,8 +199,8 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path): if elapsed >= 1.0: # Update speed every second bytes_diff = current_bytes - last_pos - speed_mb = (bytes_diff / (1024*1024)) / elapsed - total_mb = current_bytes/(1024*1024) + speed_mb = (bytes_diff / (1024 * 1024)) / elapsed + total_mb = current_bytes / (1024 * 1024) self.logger.info(f"Flash progress: {total_mb:.2f} MB, Speed: {speed_mb:.2f} MB/s") last_pos = current_bytes @@ -209,7 +210,6 @@ def _flash_with_progress(self, console, manifest, path, image_url, target_path): console.sendline("sync") console.expect(manifest.spec.login.prompt, timeout=1200) - def _get_target_device(self, target: str, manifest: FlasherBundleManifestV1Alpha1, console) -> str: """Get the target device path from the manifest, resolving block devices if needed. @@ -229,15 +229,19 @@ def _get_target_device(self, target: str, manifest: FlasherBundleManifestV1Alpha raise ArgumentError(f"Target {target} not found in manifest") if target_path.startswith("/sys/class/block#"): - target_path = self._lookup_block_device( - console, manifest.spec.login.prompt, target_path.split("#")[1]) + target_path = self._lookup_block_device(console, manifest.spec.login.prompt, target_path.split("#")[1]) return target_path - - def _transfer_bg_thread(self, src_path: PathBuf, src_operator: Operator, src_operator_scheme: str, - known_hash: str | None, - to_storage: OpendalClient, error_queue): + def _transfer_bg_thread( + self, + src_path: PathBuf, + src_operator: Operator, + src_operator_scheme: str, + known_hash: str | None, + to_storage: OpendalClient, + error_queue, + ): """Transfer image to storage in the background Args: @@ -285,7 +289,6 @@ def _transfer_bg_thread(self, src_path: PathBuf, src_operator: Operator, src_ope raise def _sha256_file(self, src_operator, src_path) -> str: - m = hashlib.sha256() with src_operator.open(src_path, "rb") as f: while True: @@ -299,11 +302,13 @@ def _sha256_file(self, src_operator, src_path) -> str: def _create_metadata_and_json(self, src_operator, src_path) -> tuple[Metadata, str]: """Create a metadata json string from a metadata object""" metadata = src_operator.stat(src_path) - return metadata, json.dumps({ - "path": str(src_path), - "content_length": metadata.content_length, - "etag": metadata.etag, - }) + return metadata, json.dumps( + { + "path": str(src_path), + "content_length": metadata.content_length, + "etag": metadata.etag, + } + ) def _lookup_block_device(self, console, prompt, address: str) -> str: """Lookup block device for a given address. @@ -317,7 +322,7 @@ def _lookup_block_device(self, console, prompt, address: str) -> str: # lrwxrwxrwx 1 root root 0 Jan 1 # 00:00 mmcblk1 -> ../../devices/platform/bus@100000/4fb0000.mmc/mmc_host/mmc1/mmc1:aaaa/block/mmcblk1 output = console.before.decode(errors="ignore") - match = re.search(r'\s(\w+)\s->', output) + match = re.search(r"\s(\w+)\s->", output) if match: return "/dev/" + match.group(1) else: @@ -359,7 +364,6 @@ def _services_up(self): self.http.stop() self.tftp.stop() - def _generate_uboot_env(self): """Generate a uboot environment dictionary, may need specific overrides for different targets""" tftp_host = self.tftp.get_host() @@ -367,7 +371,6 @@ def _generate_uboot_env(self): "serverip": tftp_host, } - @contextmanager def _busybox(self): """Start a busybox shell. @@ -406,7 +409,7 @@ def _busybox(self): uboot.run_command(f"tftpboot {dtb_address} {dtb_filename}", timeout=120) self.logger.info(f"Running boot command: {manifest.spec.bootcmd}") - console.send(manifest.spec.bootcmd +"\n") + console.send(manifest.spec.bootcmd + "\n") # if manifest has login details, we need to login if manifest.spec.login.username: @@ -438,7 +441,7 @@ def use_initram(self, path: PathBuf, operator: Operator | None = None): def use_kernel(self, path: PathBuf, operator: Operator | None = None): """Use kernel file""" if operator is None: - path, operator, operator_scheme = operator_for_path(path) + path, operator, operator_scheme = operator_for_path(path) ... @@ -461,16 +464,24 @@ def base(): @base.command() @click.argument("file") @click.option("--partition", type=str) - @click.option('--os-image-checksum', - help='SHA256 checksum of OS image (direct value)') - @click.option('--os-image-checksum-file', - help='File containing SHA256 checksum of OS image', - type=click.Path(exists=True, dir_okay=False)) - @click.option('--force-exporter-http', is_flag=True, help='Force use of exporter HTTP') - @click.option('--force-flash-bundle', type=str, help='Force use of a specific flasher OCI bundle') + @click.option("--os-image-checksum", help="SHA256 checksum of OS image (direct value)") + @click.option( + "--os-image-checksum-file", + help="File containing SHA256 checksum of OS image", + type=click.Path(exists=True, dir_okay=False), + ) + @click.option("--force-exporter-http", is_flag=True, help="Force use of exporter HTTP") + @click.option("--force-flash-bundle", type=str, help="Force use of a specific flasher OCI bundle") @debug_console_option - def flash(file, partition, os_image_checksum, os_image_checksum_file, - console_debug, force_exporter_http, force_flash_bundle): + def flash( + file, + partition, + os_image_checksum, + os_image_checksum_file, + console_debug, + force_exporter_http, + force_flash_bundle, + ): """Flash image to DUT from file""" if os_image_checksum_file and os.path.exists(os_image_checksum_file): with open(os_image_checksum_file) as f: @@ -478,10 +489,12 @@ def flash(file, partition, os_image_checksum, os_image_checksum_file, self.logger.info(f"Read checksum from file: {os_image_checksum}") self.set_console_debug(console_debug) - self.flash(file, - partition=partition, - force_exporter_http=force_exporter_http, - force_flash_bundle=force_flash_bundle) + self.flash( + file, + partition=partition, + force_exporter_http=force_exporter_http, + force_flash_bundle=force_flash_bundle, + ) @base.command() @debug_console_option @@ -522,8 +535,8 @@ def _get_decompression_command(filename_or_url) -> str: filename = urlparse(filename_or_url).path.split("/")[-1] filename = filename.lower() - if filename.endswith(('.gz', '.gzip')): - return 'zcat |' - elif filename.endswith('.xz'): - return 'xzcat |' - return '' + if filename.endswith((".gz", ".gzip")): + return "zcat |" + elif filename.endswith(".xz"): + return "xzcat |" + return "" diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/driver.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/driver.py index 7b36877ff..033e587aa 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/driver.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/driver.py @@ -13,11 +13,12 @@ @dataclass(kw_only=True) class BaseFlasher(Driver): - """ driver for Jumpstarter""" + """driver for Jumpstarter""" + flasher_bundle: str = field(default="quay.io/jumpstarter-dev/jumpstarter-flasher-test:latest") cache_dir: str = field(default="/var/lib/jumpstarter/flasher") - tftp_dir : str = field(default="/var/lib/tftpboot") - http_dir : str = field(default="/var/www/html") + tftp_dir: str = field(default="/var/lib/tftpboot") + http_dir: str = field(default="/var/www/html") def __post_init__(self): if hasattr(super(), "__post_init__"): @@ -25,26 +26,28 @@ def __post_init__(self): # Ensure required children are present if not already instantiated # in configuration - if 'tftp' not in self.children: - self.children['tftp'] = Tftp(root_dir=self.tftp_dir) - self.tftp = self.children['tftp'] + if "tftp" not in self.children: + self.children["tftp"] = Tftp(root_dir=self.tftp_dir) + self.tftp = self.children["tftp"] - if 'http' not in self.children: - self.children['http'] = HttpServer(root_dir=self.http_dir) - self.http = self.children['http'] + if "http" not in self.children: + self.children["http"] = HttpServer(root_dir=self.http_dir) + self.http = self.children["http"] # Ensure required children are present, the following are not auto-created - if 'serial' not in self.children: - raise ConfigurationError("'serial' instance is required for BaseFlasher " - "either via a ref ir a direct child instance") + if "serial" not in self.children: + raise ConfigurationError( + "'serial' instance is required for BaseFlasher either via a ref ir a direct child instance" + ) - if 'power' not in self.children: - raise ConfigurationError("'power' instance is required for BaseFlasher " - "either via a ref ir a direct child instance") + if "power" not in self.children: + raise ConfigurationError( + "'power' instance is required for BaseFlasher either via a ref ir a direct child instance" + ) # bundles that have already been downloaded in the current session self._downloaded = {} - self._use_dtb = None # use default dtb unless set by client + self._use_dtb = None # use default dtb unless set by client @classmethod def client(cls) -> str: @@ -76,7 +79,6 @@ async def setup_flasher_bundle(self, force_flash_bundle: str | None = None): self.logger.info(f"Setting up dtb in tftp: {dtb_path}") await self.tftp.storage.copy_exporter_file(dtb_path, dtb_path.name) - @export def set_dtb(self, handle): """Provide a different dtb from client""" @@ -87,8 +89,10 @@ async def use_dtb_variant(self, variant): """Provide a different dtb reference from the flasher bundle""" manifest = await self.get_flasher_manifest() if manifest.get_dtb_file(variant) is None: - raise ValueError(f"DTB variant {variant} not found in the flasher bundle, " - f"available variants are: {list(manifest.spec.dtb.variants.keys())}") + raise ValueError( + f"DTB variant {variant} not found in the flasher bundle, " + f"available variants are: {list(manifest.spec.dtb.variants.keys())}" + ) self._use_dtb = variant def set_kernel(self, handle): @@ -186,5 +190,6 @@ async def get_initram_address(self) -> str: @dataclass(kw_only=True) class TIJ784S4Flasher(BaseFlasher): - """ driver for Jumpstarter""" + """driver for Jumpstarter""" + flasher_bundle: str = "quay.io/jumpstarter-dev/jumpstarter-flasher-ti-j784s4:latest" diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/driver_test.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/driver_test.py index bd816345c..15d17fa05 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/driver_test.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/driver_test.py @@ -21,7 +21,8 @@ def temp_dirs(): os.mkdir(tftp) yield cache, http, tftp -@pytest.fixture(scope="session") # session to retain cache over time + +@pytest.fixture(scope="session") # session to retain cache over time def complete_flasher(temp_dirs): cache, http, tftp = temp_dirs yield BaseFlasher( @@ -34,39 +35,30 @@ def complete_flasher(temp_dirs): }, ) + def test_missing_serial(temp_dirs): cache, http, tftp = temp_dirs with pytest.raises(ConfigurationError): - BaseFlasher(cache_dir=cache, - http_dir=http, - tftp_dir=tftp, - children={ - "power": MockPower() - } - ) + BaseFlasher(cache_dir=cache, http_dir=http, tftp_dir=tftp, children={"power": MockPower()}) def test_missing_power(temp_dirs): cache, http, tftp = temp_dirs with pytest.raises(ConfigurationError): - BaseFlasher(cache_dir=cache, - http_dir=http, - tftp_dir=tftp, - children = { - "serial": PySerial(url="loop://") - } - ) + BaseFlasher(cache_dir=cache, http_dir=http, tftp_dir=tftp, children={"serial": PySerial(url="loop://")}) + def test_drivers_flashers_setup_flasher_bundle(complete_flasher): with serve(complete_flasher) as client: client.call("setup_flasher_bundle") dtb = client.call("get_dtb_filename") kernel = client.call("get_kernel_filename") - initram = client.call("get_initram_filename") + initram = client.call("get_initram_filename") assert client.tftp.storage.read_bytes(kernel) == b"\x00" * 1024 assert client.tftp.storage.read_bytes(initram) == b"\x00" * 1024 * 2 assert client.tftp.storage.read_bytes(dtb) == b"\x00" * 1024 * 3 + def test_drivers_flashers_manifest(complete_flasher): with serve(complete_flasher) as client: assert client.manifest.spec.kernel.file == "data/kernel" @@ -83,6 +75,7 @@ def test_drivers_flashers_dtb_switching(complete_flasher): with pytest.raises(ValueError): client.call("use_dtb_variant", "noexists") + def test_drivers_flashers_filenames(complete_flasher): with serve(complete_flasher) as client: assert client.call("get_dtb_filename") == "test-dtb.dtb" diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/test_bundle.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/test_bundle.py index 2e484398f..05e9d5a4f 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/test_bundle.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/test_bundle.py @@ -13,5 +13,3 @@ def test_bundle_read(): "usd": "/sys/class/block#4fb0000", "emmc": "/sys/class/block#4f80000", } - - diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/uboot.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/uboot.py index 95928ce2c..927688281 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/uboot.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/uboot.py @@ -15,45 +15,46 @@ class DhcpInfo: @property def cidr(self) -> str: try: - octets = [int(x) for x in self.netmask.split('.')] - binary = ''.join([bin(x)[2:].zfill(8) for x in octets]) - return str(binary.count('1')) + octets = [int(x) for x in self.netmask.split(".")] + binary = "".join([bin(x)[2:].zfill(8) for x in octets]) + return str(binary.count("1")) except Exception: return "24" + class UbootConsole: - def __init__(self, /, console, power, logger, prompt="=>"): + def __init__(self, /, console, power, logger, prompt="=>"): self.console = console self.power = power self.logger = logger self.prompt = prompt def reboot_to_console(self): - """ Trigger U-Boot console + """Trigger U-Boot console Power cycle the target and wait for the U-Boot prompt """ self.power.cycle() self.logger.info("Waiting for U-Boot prompt...") data = b"" for _ in range(100): - self.console.send('\x1b') + self.console.send("\x1b") try: recv = self.console.read_nonblocking(size=4096, timeout=0.1) if recv: data += recv except pexpect.TIMEOUT: pass - #print(data) + # print(data) if self.prompt.encode() in data: - return self.console.send('\x1b') + return self.console.send("\x1b") time.sleep(0.1) raise RuntimeError("Failed to get U-Boot prompt") def run_command(self, cmd: str, timeout: int = 60): self.logger.info(f"Running command: {cmd}") - if not cmd.endswith('\n'): - cmd += '\n' - self.console.send(cmd.encode('utf-8')) + if not cmd.endswith("\n"): + cmd += "\n" + self.console.send(cmd.encode("utf-8")) return self._read_until(self.prompt, timeout) def setup_dhcp(self, timeout: int = 60) -> DhcpInfo: @@ -79,13 +80,9 @@ def setup_dhcp(self, timeout: int = 60) -> DhcpInfo: raise ValueError("Could not extract complete network information") # Get netmask from environment - netmask = self.get_env('netmask') or "255.255.255.0" + netmask = self.get_env("netmask") or "255.255.255.0" - return DhcpInfo( - ip_address=ip_address, - gateway=gateway, - netmask=netmask - ) + return DhcpInfo(ip_address=ip_address, gateway=gateway, netmask=netmask) def wait_for_pattern(self, pattern: str, timeout: int = 300, print_output: bool = False): """Wait for specific pattern in output""" @@ -98,7 +95,7 @@ def get_env(self, var_name: str, timeout: int = 5) -> Optional[str]: buffer = self.run_command(f"printenv {var_name}", timeout) for line in buffer.splitlines(): if f"{var_name}=" in line: - return line.split('=', 1)[1].strip() + return line.split("=", 1)[1].strip() except TimeoutError as err: raise TimeoutError(f"Timed out waiting for {var_name}") from err @@ -115,12 +112,9 @@ def set_env_dict(self, env): # TODO: rewrite this, there is a way to do it just with pexpect # https://github.com/jumpstarter-dev/jumpstarter-devspace/blob/orin-nx-testing/tests/test_on_orin_nx.py#L156 - def _read_until(self, - target: str, - timeout: int = 60, - print_output: bool = False, - error_patterns: list[str] = None) -> str: - + def _read_until( + self, target: str, timeout: int = 60, print_output: bool = False, error_patterns: list[str] = None + ) -> str: saved_logfile = self.console.logfile_read self.logger.debug(f"_read_until {target}") try: From 28308bb638505b3d45bf94306b4733220478f4b0 Mon Sep 17 00:00:00 2001 From: Nick Cao Date: Sat, 15 Mar 2025 10:04:07 -0400 Subject: [PATCH 2/2] Make use of standalone uboot driver --- .../jumpstarter_driver_flashers/client.py | 33 ++--- .../jumpstarter_driver_flashers/driver.py | 9 ++ .../jumpstarter_driver_flashers/uboot.py | 134 ------------------ .../pyproject.toml | 4 + .../jumpstarter_driver_opendal/client.py | 24 ++-- .../jumpstarter_driver_opendal/driver.py | 1 + uv.lock | 2 + 7 files changed, 39 insertions(+), 168 deletions(-) delete mode 100644 packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/uboot.py diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py index f69cc4918..a50e69f9f 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/client.py @@ -20,7 +20,6 @@ from jumpstarter_driver_flashers.bundle import FlasherBundleManifestV1Alpha1 -from .uboot import UbootConsole from jumpstarter.common.exceptions import ArgumentError debug_console_option = click.option("--console-debug", is_flag=True, help="Enable console debug mode") @@ -42,6 +41,7 @@ def __post_init__(self): def set_console_debug(self, debug: bool): """Set console debug mode""" self._console_debug = debug + # TODO: also set console debug on uboot client @contextmanager def busybox_shell(self): @@ -61,12 +61,8 @@ def bootloader_shell(self): self.logger.info("Setting up flasher bundle files in exporter") self.call("setup_flasher_bundle") with self._services_up(): - with self.serial.pexpect() as console: - if self._console_debug: - console.logfile_read = sys.stdout.buffer - uboot = UbootConsole(console=console, power=self.power, logger=self.logger) - uboot.reboot_to_console() - console.sendline("") + with self.uboot.reboot_to_console(): + pass yield self.serial def flash( @@ -377,36 +373,37 @@ def _busybox(self): This is a helper context manager that boots the device into uboot and returns a console object. """ - with self.serial.pexpect() as console: - if self._console_debug: - console.logfile_read = sys.stdout.buffer - uboot = UbootConsole(console=console, power=self.power, logger=self.logger) - # make sure that the device is booted into the uboot console - uboot.reboot_to_console() + + # make sure that the device is booted into the uboot console + with self.uboot.reboot_to_console(): # run dhcp discovery and gather details useful for later - self._dhcp_details = uboot.setup_dhcp() + self._dhcp_details = self.uboot.setup_dhcp() self.logger.info(f"discovered dhcp details: {self._dhcp_details}") # configure the environment necessary env = self._generate_uboot_env() - uboot.set_env_dict(env) + self.uboot.set_env_dict(env) # load any necessary files to RAM from the tftp storage manifest = self.manifest kernel_filename = Path(manifest.get_kernel_file()).name kernel_address = manifest.get_kernel_address() - uboot.run_command(f"tftpboot {kernel_address} {kernel_filename}", timeout=120) + self.uboot.run_command(f"tftpboot {kernel_address} {kernel_filename}", timeout=120) if manifest.get_initram_file(): initram_filename = Path(manifest.get_initram_file()).name initram_address = manifest.get_initram_address() - uboot.run_command(f"tftpboot {initram_address} {initram_filename}", timeout=120) + self.uboot.run_command(f"tftpboot {initram_address} {initram_filename}", timeout=120) if manifest.get_dtb_file(): dtb_filename = Path(manifest.get_dtb_file()).name dtb_address = manifest.get_dtb_address() - uboot.run_command(f"tftpboot {dtb_address} {dtb_filename}", timeout=120) + self.uboot.run_command(f"tftpboot {dtb_address} {dtb_filename}", timeout=120) + + with self.serial.pexpect() as console: + if self._console_debug: + console.logfile_read = sys.stdout.buffer self.logger.info(f"Running boot command: {manifest.spec.bootcmd}") console.send(manifest.spec.bootcmd + "\n") diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/driver.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/driver.py index 033e587aa..c7058d8ca 100644 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/driver.py +++ b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/driver.py @@ -4,6 +4,7 @@ import anyio.to_thread from jumpstarter_driver_http.driver import HttpServer from jumpstarter_driver_tftp.driver import Tftp +from jumpstarter_driver_uboot.driver import UbootConsole from oras.provider import Registry from .bundle import FlasherBundleManifestV1Alpha1 @@ -45,6 +46,14 @@ def __post_init__(self): "'power' instance is required for BaseFlasher either via a ref ir a direct child instance" ) + if "uboot" not in self.children: + self.children["uboot"] = UbootConsole( + children={ + "power": self.children["power"], + "serial": self.children["serial"], + } + ) + # bundles that have already been downloaded in the current session self._downloaded = {} self._use_dtb = None # use default dtb unless set by client diff --git a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/uboot.py b/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/uboot.py deleted file mode 100644 index 927688281..000000000 --- a/packages/jumpstarter-driver-flashers/jumpstarter_driver_flashers/uboot.py +++ /dev/null @@ -1,134 +0,0 @@ -import sys -import time -from dataclasses import dataclass -from typing import Optional - -import pexpect - - -@dataclass -class DhcpInfo: - ip_address: str - gateway: str - netmask: str - - @property - def cidr(self) -> str: - try: - octets = [int(x) for x in self.netmask.split(".")] - binary = "".join([bin(x)[2:].zfill(8) for x in octets]) - return str(binary.count("1")) - except Exception: - return "24" - - -class UbootConsole: - def __init__(self, /, console, power, logger, prompt="=>"): - self.console = console - self.power = power - self.logger = logger - self.prompt = prompt - - def reboot_to_console(self): - """Trigger U-Boot console - Power cycle the target and wait for the U-Boot prompt - """ - self.power.cycle() - self.logger.info("Waiting for U-Boot prompt...") - data = b"" - for _ in range(100): - self.console.send("\x1b") - try: - recv = self.console.read_nonblocking(size=4096, timeout=0.1) - if recv: - data += recv - except pexpect.TIMEOUT: - pass - # print(data) - if self.prompt.encode() in data: - return self.console.send("\x1b") - time.sleep(0.1) - raise RuntimeError("Failed to get U-Boot prompt") - - def run_command(self, cmd: str, timeout: int = 60): - self.logger.info(f"Running command: {cmd}") - if not cmd.endswith("\n"): - cmd += "\n" - self.console.send(cmd.encode("utf-8")) - return self._read_until(self.prompt, timeout) - - def setup_dhcp(self, timeout: int = 60) -> DhcpInfo: - self.logger.info("Running DHCP to obtain network configuration...") - buffer = self.run_command("dhcp", timeout) - - # Extract IP and - ip_address = None - gateway = None - - for line in buffer.splitlines(): - if "DHCP client bound to address" in line: - bind_index = line.find("DHCP client bound to address") + len("DHCP client bound to address") - ip_end = line.find("(", bind_index) - if ip_end != -1: - ip_address = line[bind_index:ip_end].strip() - - if "sending through gateway" in line: - gw_index = line.find("sending through gateway") + len("sending through gateway") - gateway = line[gw_index:].strip() - - if not ip_address or not gateway: - raise ValueError("Could not extract complete network information") - - # Get netmask from environment - netmask = self.get_env("netmask") or "255.255.255.0" - - return DhcpInfo(ip_address=ip_address, gateway=gateway, netmask=netmask) - - def wait_for_pattern(self, pattern: str, timeout: int = 300, print_output: bool = False): - """Wait for specific pattern in output""" - return self._read_until(pattern, timeout, print_output) - - def get_env(self, var_name: str, timeout: int = 5) -> Optional[str]: - """Get U-Boot environment variable value""" - self.logger.debug(f"\nGetting U-Boot env var: {var_name}") - try: - buffer = self.run_command(f"printenv {var_name}", timeout) - for line in buffer.splitlines(): - if f"{var_name}=" in line: - return line.split("=", 1)[1].strip() - except TimeoutError as err: - raise TimeoutError(f"Timed out waiting for {var_name}") from err - - return None - - def set_env(self, key: str, value: str): - cmd = f"setenv {key} '{value}'" - self.logger.debug(f"Sending command to U-Boot: {cmd}") - self.run_command(cmd, timeout=5) - - def set_env_dict(self, env): - for key, value in env.items(): - self.set_env(key, value) - - # TODO: rewrite this, there is a way to do it just with pexpect - # https://github.com/jumpstarter-dev/jumpstarter-devspace/blob/orin-nx-testing/tests/test_on_orin_nx.py#L156 - def _read_until( - self, target: str, timeout: int = 60, print_output: bool = False, error_patterns: list[str] = None - ) -> str: - saved_logfile = self.console.logfile_read - self.logger.debug(f"_read_until {target}") - try: - if print_output: - self.console.logfile_read = sys.stdout.buffer - - self.console.expect(target, timeout=timeout) - buffer = self.console.before.decode().strip() - if error_patterns and any(pattern in buffer.lower() for pattern in error_patterns): - raise RuntimeError(f"Error detected in output: {buffer}") - return buffer - except pexpect.TIMEOUT as err: - raise TimeoutError(f"Timed out waiting for '{target}'") from err - except RuntimeError: - raise - finally: - self.console.logfile_read = saved_logfile diff --git a/packages/jumpstarter-driver-flashers/pyproject.toml b/packages/jumpstarter-driver-flashers/pyproject.toml index de40b2812..50e43027f 100644 --- a/packages/jumpstarter-driver-flashers/pyproject.toml +++ b/packages/jumpstarter-driver-flashers/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "jumpstarter-driver-http", "jumpstarter-driver-tftp", "jumpstarter-driver-power", + "jumpstarter-driver-uboot", ] [tool.hatch.version] @@ -33,6 +34,9 @@ addopts = "--cov --cov-report=html --cov-report=xml" log_cli = true log_cli_level = "INFO" testpaths = ["jumpstarter_driver_flashers"] + +[tool.uv.sources] +jumpstarter-driver-uboot = { workspace = true } #asyncio_mode = "auto" [build-system] diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py index 01eb77e1a..e597e278f 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/client.py @@ -44,23 +44,20 @@ async def aclose(self): def fs_operator_for_path(path: PathBuf) -> tuple[PathBuf, Operator]: return Path(path).resolve(), Operator("fs", root="/") + def operator_for_path(path: PathBuf) -> tuple[PathBuf, Operator, str]: - """ Create an operator for the given path + """Create an operator for the given path Return a tuple of: - the path - the operator for the given path - the scheme of the operator. """ - if type(path) is str and path.startswith(('http://', 'https://')): - parsed_url = urlparse(path) - operator = Operator( - 'http', - root='/', - endpoint=f"{parsed_url.scheme}://{parsed_url.netloc}" - ) - return Path(parsed_url.path), operator, 'http' + if type(path) is str and path.startswith(("http://", "https://")): + parsed_url = urlparse(path) + operator = Operator("http", root="/", endpoint=f"{parsed_url.scheme}://{parsed_url.netloc}") + return Path(parsed_url.path), operator, "http" else: - return *fs_operator_for_path(path), 'fs' + return *fs_operator_for_path(path), "fs" @dataclass(kw_only=True) @@ -529,12 +526,7 @@ def flash( ... @abstractmethod - def dump( - self, - path: PathBuf, - *, - partition: str | None = None, - operator: Operator | None = None): + def dump(self, path: PathBuf, *, partition: str | None = None, operator: Operator | None = None): """Dump image from DUT""" ... diff --git a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py index 81a732f0e..732b9972d 100644 --- a/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py +++ b/packages/jumpstarter-driver-opendal/jumpstarter_driver_opendal/driver.py @@ -195,6 +195,7 @@ async def copy_exporter_file(self, /, source: Path, target: str): break await dst.write(bs=data) + class FlasherInterface(metaclass=ABCMeta): @classmethod def client(cls) -> str: diff --git a/uv.lock b/uv.lock index a78efed66..6f82c07e8 100644 --- a/uv.lock +++ b/uv.lock @@ -1311,6 +1311,7 @@ dependencies = [ { name = "jumpstarter-driver-power" }, { name = "jumpstarter-driver-pyserial" }, { name = "jumpstarter-driver-tftp" }, + { name = "jumpstarter-driver-uboot" }, { name = "oras" }, ] @@ -1329,6 +1330,7 @@ requires-dist = [ { name = "jumpstarter-driver-power", editable = "packages/jumpstarter-driver-power" }, { name = "jumpstarter-driver-pyserial", editable = "packages/jumpstarter-driver-pyserial" }, { name = "jumpstarter-driver-tftp", editable = "packages/jumpstarter-driver-tftp" }, + { name = "jumpstarter-driver-uboot", editable = "packages/jumpstarter-driver-uboot" }, { name = "oras", specifier = ">=0.2.25" }, ]