diff --git a/packages/jumpstarter-driver-iscsi/examples/exporter.yaml b/packages/jumpstarter-driver-iscsi/examples/exporter.yaml index 35e9af07a..bd0768ace 100644 --- a/packages/jumpstarter-driver-iscsi/examples/exporter.yaml +++ b/packages/jumpstarter-driver-iscsi/examples/exporter.yaml @@ -13,4 +13,4 @@ export: iqn_prefix: "iqn.2024-06.dev.jumpstarter" target_name: "my-target" host: "" - port: 3260 \ No newline at end of file + port: 3260 \ No newline at end of file diff --git a/packages/jumpstarter-driver-iscsi/jumpstarter_driver_iscsi/client.py b/packages/jumpstarter-driver-iscsi/jumpstarter_driver_iscsi/client.py index 548226e6a..66a1039e8 100644 --- a/packages/jumpstarter-driver-iscsi/jumpstarter_driver_iscsi/client.py +++ b/packages/jumpstarter-driver-iscsi/jumpstarter_driver_iscsi/client.py @@ -1,8 +1,13 @@ +import contextlib import hashlib import os from dataclasses import dataclass +from tempfile import NamedTemporaryFile from typing import Any, Dict, List, Optional +from urllib.parse import urlparse +import click +import requests from jumpstarter_driver_composite.client import CompositeClient from jumpstarter_driver_opendal.common import PathBuf from opendal import Operator @@ -64,6 +69,62 @@ def get_target_iqn(self) -> str: """ return self.call("get_target_iqn") + def _normalized_name_from_file(self, path: str) -> str: + base = os.path.basename(path) + for ext in (".gz", ".xz", ".bz2"): + if base.endswith(ext): + base = base[: -len(ext)] + break + if base.endswith(".img"): + base = base[: -len(".img")] + return base or "image" + + def _get_src_and_operator( + self, file: str, headers: tuple[str, ...] + ) -> tuple[str, Optional[Operator], Optional[str]]: + from jumpstarter_driver_opendal.client import operator_for_path + + if file.startswith(("http://", "https://")): + if headers: + header_map: Dict[str, str] = {} + for h in headers: + if ":" not in h: + raise click.ClickException(f"Invalid header format: {h!r}. Expected 'Key: Value'.") + key, value = h.split(":", 1) + key = key.strip() + value = value.strip() + if not key: + raise click.ClickException(f"Invalid header key in: {h!r}") + header_map[key] = value + + parsed = urlparse(file) + tf = NamedTemporaryFile( + prefix="jumpstarter-iscsi-", + suffix=os.path.basename(parsed.path), + delete=False, + ) + temp_path = tf.name + try: + with requests.get(file, stream=True, headers=header_map, timeout=(10, 60)) as resp: + resp.raise_for_status() + for chunk in resp.iter_content(chunk_size=65536): + if chunk: + tf.write(chunk) + tf.close() + return temp_path, None, temp_path + except Exception: + tf.close() + with contextlib.suppress(Exception): + os.unlink(temp_path) + raise + + _, src_operator, _ = operator_for_path(file) + return file, src_operator, None + + file = os.path.abspath(file) + _, src_operator, _ = operator_for_path(file) + return file, src_operator, None + def add_lun(self, name: str, file_path: str, size_mb: int = 0, is_block: bool = False) -> str: """ Add a new LUN to the iSCSI target @@ -112,11 +173,12 @@ def _calculate_file_hash(self, file_path: str, operator: Optional[Operator] = No hash_obj.update(chunk) return hash_obj.hexdigest() else: - from jumpstarter_driver_opendal.client import operator_for_path - - path, op, _ = operator_for_path(file_path) hash_obj = hashlib.sha256() - with op.open(str(path), "rb") as f: + if isinstance(file_path, str) and file_path.startswith(("http://", "https://")): + src_path = urlparse(file_path).path + else: + src_path = str(file_path) + with operator.open(str(src_path), "rb") as f: while chunk := f.read(8192): hash_obj.update(chunk) return hash_obj.hexdigest() @@ -125,6 +187,7 @@ def _files_are_identical(self, src: PathBuf, dst_path: str, operator: Optional[O """Check if source and destination files are identical""" try: if not self.storage.exists(dst_path): + self.logger.info(f"{dst_path} does not exist") return False dst_stat = self.storage.stat(dst_path) @@ -133,22 +196,58 @@ def _files_are_identical(self, src: PathBuf, dst_path: str, operator: Optional[O if operator is None: src_size = os.path.getsize(str(src)) else: - from jumpstarter_driver_opendal.client import operator_for_path - - path, op, _ = operator_for_path(src) - src_size = op.stat(str(path)).content_length + if isinstance(src, str) and src.startswith(("http://", "https://")): + src_path = urlparse(src).path + else: + src_path = str(src) + src_size = operator.stat(str(src_path)).content_length if src_size != dst_size: + self.logger.info(f"Source size {src_size} != destination size {dst_size}") return False + self.logger.info("checking hashes") src_hash = self._calculate_file_hash(str(src), operator) + self.logger.info(f"Source hash: {src_hash}") dst_hash = self.storage.hash(dst_path, "sha256") + self.logger.info(f"Destination hash: {dst_hash}") return src_hash == dst_hash except Exception: return False + def _should_skip_upload( + self, src_path: str, dst_path: str, operator: Optional[Operator], force_upload: bool, algo: Optional[str] + ) -> bool: + if force_upload or algo is not None or not self.storage.exists(dst_path): + return False + + self.logger.info(f"Checking if {src_path} and {dst_path} are identical") + if self._files_are_identical(src_path, dst_path, operator): + self.logger.info(f"File {dst_path} already exists and is identical to source. Skipping upload...") + return True + + self.logger.info(f"File {dst_path} is not identical to source") + return False + + def _upload_file( + self, src_path: str, dst_name: str, dst_path: str, operator: Optional[Operator], algo: Optional[str] + ): + if algo is None: + self.logger.info(f"Uploading {src_path} to {dst_path}...") + self.storage.write_from_path(dst_path, src_path, operator) + else: + ext_to_algo = {".gz": "gz", ".xz": "xz", ".bz2": "bz2"} + ext = next(k for k, v in ext_to_algo.items() if v == algo) + compressed_path = f"{dst_name}.img{ext}" + self.logger.info(f"Uploading {src_path} to {compressed_path}...") + self.storage.write_from_path(compressed_path, src_path, operator) + self.logger.info(f"Decompressing on exporter: {compressed_path} -> {dst_name}.img ...") + self.call("decompress", compressed_path, f"{dst_name}.img", algo) + with contextlib.suppress(Exception): + self.storage.delete(compressed_path) + def upload_image( self, dst_name: str, @@ -176,18 +275,70 @@ def upload_image( size_mb = int(size_mb) dst_path = f"{dst_name}.img" - if not force_upload and self._files_are_identical(src, dst_path, operator): - print(f"File {dst_path} already exists and is identical to source. Skipping upload.") - else: - print(f"Uploading {src} to {dst_path}...") - self.storage.write_from_path(dst_path, src, operator) + src_path = str(src) + if operator is None and not src_path.startswith(("http://", "https://")): + src_path = os.path.abspath(src_path) + + ext_to_algo = {".gz": "gz", ".xz": "xz", ".bz2": "bz2"} + algo = next((v for k, v in ext_to_algo.items() if src_path.endswith(k)), None) + + if not self._should_skip_upload(src_path, dst_path, operator, force_upload, algo): + self._upload_file(src_path, dst_name, dst_path, operator, algo) if size_mb <= 0: - src_path = os.path.join(self.storage._storage.root_dir, dst_path) - size_mb = os.path.getsize(src_path) // (1024 * 1024) - if size_mb <= 0: + try: + dst_stat = self.storage.stat(dst_path) + size_mb = max(1, int(dst_stat.content_length) // (1024 * 1024)) + except Exception: size_mb = 1 self.add_lun(dst_name, dst_path, size_mb) - return self.get_target_iqn() + + def cli(self): + base = super().cli() + + @base.command() + @click.argument("file", type=str) + @click.option("--name", "name", "-n", type=str, help="LUN name (defaults to basename without extension)") + @click.option("--size-mb", type=int, default=0, show_default=True, help="Size in MB if creating a new image") + @click.option( + "--force-upload", + is_flag=True, + default=False, + help="Force uploading even if the file appears identical on the exporter", + ) + @click.option( + "--header", + "headers", + multiple=True, + help="Custom HTTP header in 'Key: Value' format. Repeatable.", + ) + def serve(file: str, name: Optional[str], size_mb: int, force_upload: bool, headers: tuple[str, ...]): + """Serve an image as an iSCSI LUN from a local path or HTTP(S) URL.""" + self.start() + + try: + self.call("clear_all_luns") + except Exception: + pass + + if not name: + candidate = urlparse(file).path if file.startswith(("http://", "https://")) else file + name = self._normalized_name_from_file(candidate) + + src_path, src_operator, temp_cleanup = self._get_src_and_operator(file, headers) + try: + iqn = self.upload_image( + name, src_path, size_mb=size_mb, operator=src_operator, force_upload=force_upload + ) + finally: + if temp_cleanup is not None: + with contextlib.suppress(Exception): + os.remove(temp_cleanup) + host = self.get_host() + port = self.get_port() + + click.echo(f"{host}:{port} {iqn}") + + return base diff --git a/packages/jumpstarter-driver-iscsi/jumpstarter_driver_iscsi/driver.py b/packages/jumpstarter-driver-iscsi/jumpstarter_driver_iscsi/driver.py index 8c5bea4a2..847f8590f 100644 --- a/packages/jumpstarter-driver-iscsi/jumpstarter_driver_iscsi/driver.py +++ b/packages/jumpstarter-driver-iscsi/jumpstarter_driver_iscsi/driver.py @@ -1,6 +1,11 @@ +import bz2 +import gzip +import lzma import os import socket +from contextlib import suppress from dataclasses import dataclass, field +from tempfile import NamedTemporaryFile from typing import Any, Dict, List, Optional from jumpstarter_driver_opendal.driver import Opendal @@ -62,9 +67,7 @@ def __post_init__(self): os.makedirs(self.root_dir, exist_ok=True) self.children["storage"] = Opendal( - scheme="fs", - kwargs={"root": self.root_dir}, - remove_created_on_close=self.remove_created_on_close + scheme="fs", kwargs={"root": self.root_dir}, remove_created_on_close=self.remove_created_on_close ) self.storage = self.children["storage"] @@ -153,6 +156,51 @@ def _configure_tpg_attributes(self): self._tpg.set_attribute("generate_node_acls", "1") # type: ignore[attr-defined] self._tpg.set_attribute("demo_mode_write_protect", "0") # type: ignore[attr-defined] + def _clear_tpg_luns(self): + """Clear all LUNs from the current TPG""" + try: + for lun in list(self._tpg.luns): # type: ignore[attr-defined] + try: + storage_obj = getattr(lun, "storage_object", None) + except Exception: + storage_obj = None + + try: + lun.delete() + finally: + if storage_obj is not None: + with suppress(Exception): + storage_obj.delete() + except Exception as e: + self.logger.warning(f"Failed clearing existing LUNs from TPG: {e}") + + def _cleanup_orphan_storage_objects(self): + """Clean up orphan storage objects under root_dir""" + try: + root_abs = os.path.abspath(self.root_dir) + for so in list(self._rtsroot.storage_objects): # type: ignore[attr-defined] + try: + if isinstance(so, FileIOStorageObject): + udev_path = os.path.abspath(getattr(so, "udev_path", "")) + if udev_path.startswith(root_abs + os.sep) or udev_path == root_abs: + with suppress(Exception): + so.delete() + except Exception: + continue + except Exception as e: + self.logger.debug(f"No orphan storage object cleanup performed: {e}") + + @export + def clear_all_luns(self): + """Remove all existing LUNs and their backstores, including any orphans under root_dir""" + if self._tpg is None: + self._configure_target() + + self._clear_tpg_luns() + self._luns.clear() + self._storage_objects.clear() + self._cleanup_orphan_storage_objects() + @export def start(self): """Start the iSCSI target server @@ -233,7 +281,7 @@ def _get_full_path(self, file_path: str, is_block: bool) -> str: else: normalized_path = os.path.normpath(file_path) - if normalized_path.startswith('..') or os.path.isabs(normalized_path): + if normalized_path.startswith("..") or os.path.isabs(normalized_path): raise ISCSIError(f"Invalid file path: {file_path}") full_path = os.path.join(self.root_dir, normalized_path) @@ -245,8 +293,70 @@ def _get_full_path(self, file_path: str, is_block: bool) -> str: os.makedirs(os.path.dirname(full_path), exist_ok=True) return full_path + def _safe_join_under_root(self, rel_path: str) -> str: + normalized_path = os.path.normpath(rel_path) + if normalized_path.startswith("..") or os.path.isabs(normalized_path): + raise ISCSIError(f"Invalid path: {rel_path}") + full_path = os.path.join(self.root_dir, normalized_path) + resolved_path = os.path.abspath(full_path) + root_path = os.path.abspath(self.root_dir) + if not resolved_path.startswith(root_path + os.sep) and resolved_path != root_path: + raise ISCSIError(f"Path traversal attempt detected: {rel_path}") + os.makedirs(os.path.dirname(full_path), exist_ok=True) + return full_path + + def _check_no_symlinks_in_path(self, path: str) -> None: + """Verify no path component is a symlink to prevent writing outside root.""" + path_to_check = path + root_abs = os.path.abspath(self.root_dir) + while path_to_check != root_abs and path_to_check != os.path.dirname(path_to_check): + if os.path.lexists(path_to_check) and os.path.islink(path_to_check): + raise ISCSIError(f"Destination path contains symlink: {path_to_check}") + path_to_check = os.path.dirname(path_to_check) + + @export + def decompress(self, src_path: str, dst_path: str, algo: str) -> None: + """Decompress a file under storage root into another path under storage root. + + src_path and dst_path are relative to root_dir. + algo is one of: gz, xz, bz2 + """ + src_full = self._safe_join_under_root(src_path) + dst_full = self._safe_join_under_root(dst_path) + self._check_no_symlinks_in_path(dst_full) + + def _copy_stream(read_f, write_f): + while True: + chunk = read_f.read(1024 * 1024) + if not chunk: + break + write_f.write(chunk) + + tmp_path = None + try: + with NamedTemporaryFile(dir=os.path.dirname(dst_full), prefix=".decomp-", delete=False) as tf: + tmp_path = tf.name + if algo == "gz": + with gzip.open(src_full, "rb") as decomp: + _copy_stream(decomp, tf) + elif algo == "xz": + with lzma.open(src_full, "rb") as decomp: + _copy_stream(decomp, tf) + elif algo == "bz2": + with bz2.open(src_full, "rb") as decomp: + _copy_stream(decomp, tf) + else: + raise ISCSIError(f"Unsupported compression algo: {algo}") + tf.flush() + os.fsync(tf.fileno()) + os.replace(tmp_path, dst_full) + except Exception as e: + with suppress(Exception): + if tmp_path is not None: + os.remove(tmp_path) + raise ISCSIError(f"Decompression failed: {e}") from e + def _create_file_storage_object(self, name: str, full_path: str, size_mb: int) -> tuple: - """Create file-backed storage object and return (storage_obj, final_size_mb)""" if not os.path.exists(full_path): if size_mb <= 0: raise ISCSIError("size_mb must be > 0 for new file-backed LUNs") @@ -297,6 +407,10 @@ def add_lun(self, name: str, file_path: str, size_mb: int = 0, is_block: bool = Raises: ISCSIError: On error or if the file_path is invalid. """ + if name in self._luns: + with suppress(Exception): + self.remove_lun(name) + size_mb = self._validate_lun_inputs(name, size_mb) full_path = self._get_full_path(file_path, is_block) diff --git a/packages/jumpstarter-driver-iscsi/pyproject.toml b/packages/jumpstarter-driver-iscsi/pyproject.toml index f5569065d..e8ee9cc7e 100644 --- a/packages/jumpstarter-driver-iscsi/pyproject.toml +++ b/packages/jumpstarter-driver-iscsi/pyproject.toml @@ -10,7 +10,9 @@ dependencies = [ "jumpstarter", "jumpstarter-driver-composite", "jumpstarter-driver-opendal", + "click>=8.1.8", "rtslib-fb", + "requests>=2.32.5" ] [tool.hatch.version]