diff --git a/lisa/feature.py b/lisa/feature.py index 62cdfb010f..aa9aebf4a9 100644 --- a/lisa/feature.py +++ b/lisa/feature.py @@ -89,7 +89,7 @@ def create_setting( @classmethod def create_image_requirement( - cls, *args: Any, **kwargs: Any + cls, image: schema.ImageSchema ) -> Optional[schema.FeatureSettings]: """ It's called in the platform to check if an image restricts the feature or not. diff --git a/lisa/schema.py b/lisa/schema.py index 38364bd23f..9bf335d4a6 100644 --- a/lisa/schema.py +++ b/lisa/schema.py @@ -870,10 +870,6 @@ class NodeSpace(search_space.RequirementMixin, TypedSchema, ExtendableSchemaMixi ), ) - # Platform may add image features to - # keep track of platform requirements. - _image_features: Optional[Dict[str, Any]] = None - def __post_init__(self, *args: Any, **kwargs: Any) -> None: # clarify types to avoid type errors in properties. self._features: Optional[search_space.SetSpace[FeatureSettings]] @@ -1624,6 +1620,11 @@ class ImageSchema: pass +class ArchitectureType(str, Enum): + x64 = "x64" + Arm64 = "Arm64" + + def load_by_type(schema_type: Type[T], raw_runbook: Any, many: bool = False) -> T: """ Convert dict, list or base typed schema to specified typed schema. diff --git a/lisa/sut_orchestrator/azure/common.py b/lisa/sut_orchestrator/azure/common.py index 31c0299651..4190d5d11a 100644 --- a/lisa/sut_orchestrator/azure/common.py +++ b/lisa/sut_orchestrator/azure/common.py @@ -11,7 +11,17 @@ from pathlib import Path, PurePath from threading import Lock from time import sleep, time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + Union, +) import requests from assertpy import assert_that @@ -23,7 +33,11 @@ ) from azure.keyvault.secrets import SecretClient from azure.mgmt.compute import ComputeManagementClient -from azure.mgmt.compute.models import VirtualMachine +from azure.mgmt.compute.models import ( + CommunityGalleryImage, + GalleryImage, + VirtualMachine, +) from azure.mgmt.keyvault import KeyVaultManagementClient from azure.mgmt.keyvault.models import ( AccessPolicyEntry, @@ -79,6 +93,7 @@ from lisa import schema, search_space from lisa.environment import Environment, load_environments from lisa.feature import Features +from lisa.features.security_profile import SecurityProfileType from lisa.node import Node, RemoteNode, local from lisa.secret import PATTERN_HEADTAIL, PATTERN_URL, add_secret, replace from lisa.tools import Ls @@ -195,12 +210,67 @@ class AzureVmPurchasePlanSchema: @dataclass_json @dataclass -class AzureImageSchema: +class AzureImageSchema(schema.ImageSchema): + architecture: Union[ + schema.ArchitectureType, search_space.SetSpace[schema.ArchitectureType] + ] = field( # type: ignore + default_factory=partial( + search_space.SetSpace, + is_allow_set=True, + items=[schema.ArchitectureType.x64, schema.ArchitectureType.Arm64], + ), + metadata=field_metadata( + decoder=partial( + search_space.decode_nullable_set_space, + base_type=schema.ArchitectureType, + is_allow_set=True, + default_values=[ + schema.ArchitectureType.x64, + schema.ArchitectureType.Arm64, + ], + ) + ), + ) + disk_controller_type: Optional[ + Union[ + search_space.SetSpace[schema.DiskControllerType], schema.DiskControllerType + ] + ] = field( # type:ignore + default_factory=partial( + search_space.SetSpace, + is_allow_set=True, + items=[schema.DiskControllerType.SCSI, schema.DiskControllerType.NVME], + ), + metadata=field_metadata( + 
decoder=partial( + search_space.decode_nullable_set_space, + base_type=schema.DiskControllerType, + is_allow_set=True, + default_values=[ + schema.DiskControllerType.SCSI, + schema.DiskControllerType.NVME, + ], + ) + ), + ) + hyperv_generation: Optional[ + Union[search_space.SetSpace[int], int] + ] = field( # type:ignore + default_factory=partial( + search_space.SetSpace, + is_allow_set=True, + items=[1, 2], + ), + metadata=field_metadata( + decoder=partial(search_space.decode_set_space_by_type, base_type=int) + ), + ) network_data_path: Optional[ Union[search_space.SetSpace[schema.NetworkDataPath], schema.NetworkDataPath] ] = field( # type: ignore default_factory=partial( search_space.SetSpace, + is_allow_set=True, items=[ schema.NetworkDataPath.Synthetic, schema.NetworkDataPath.Sriov, @@ -208,11 +278,141 @@ class AzureImageSchema: ), metadata=field_metadata( decoder=partial( - search_space.decode_set_space_by_type, base_type=schema.NetworkDataPath + search_space.decode_set_space_by_type, + base_type=schema.NetworkDataPath, + ) + ), + ) + security_profile: Union[ + search_space.SetSpace[SecurityProfileType], SecurityProfileType + ] = field( # type:ignore + default_factory=partial( + search_space.SetSpace, + is_allow_set=True, + items=[ + SecurityProfileType.Standard, + SecurityProfileType.SecureBoot, + SecurityProfileType.CVM, + SecurityProfileType.Stateless, + ], + ), + metadata=field_metadata( + decoder=partial( + search_space.decode_nullable_set_space, + base_type=SecurityProfileType, + is_allow_set=True, + default_values=[ + SecurityProfileType.Standard, + SecurityProfileType.SecureBoot, + SecurityProfileType.CVM, + SecurityProfileType.Stateless, + ], ) ), ) + def load_from_platform(self, platform: "AzurePlatform") -> None: + """ + Load image features from Azure platform. + Relevant image tags will be used to populate the schema. 
+ """ + raw_features = self._get_info(platform) + if raw_features: + self._parse_info(raw_features, platform._log) + + def _get_info(self, platform: "AzurePlatform") -> Dict[str, Any]: + """Get raw image tags from Azure platform.""" + raise NotImplementedError() + + def _parse_info(self, raw_features: Dict[str, Any], log: Logger) -> None: + """Parse raw image tags to AzureImageSchema""" + self._parse_architecture(raw_features, log) + self._parse_disk_controller_type(raw_features, log) + self._parse_hyperv_generation(raw_features, log) + self._parse_network_data_path(raw_features, log) + self._parse_security_profile(raw_features, log) + + def _parse_architecture(self, raw_features: Dict[str, Any], log: Logger) -> None: + arch = raw_features.get("architecture") + if arch == "arm64": + self.architecture = schema.ArchitectureType.Arm64 + elif arch == "x64": + self.architecture = schema.ArchitectureType.x64 + + def _parse_disk_controller_type( + self, raw_features: Dict[str, Any], log: Logger + ) -> None: + disk_controller_type = raw_features.get("DiskControllerTypes") + if ( + isinstance(disk_controller_type, str) + and disk_controller_type.lower() == "scsi" + ): + self.disk_controller_type = schema.DiskControllerType.SCSI + elif ( + isinstance(disk_controller_type, str) + and disk_controller_type.lower() == "nvme" + ): + self.disk_controller_type = schema.DiskControllerType.NVME + + def _parse_hyperv_generation( + self, raw_features: Dict[str, Any], log: Logger + ) -> None: + try: + gen = raw_features.get("hyper_v_generation") + if gen: + self.hyperv_generation = int(gen.strip("V")) + except (TypeError, ValueError, AttributeError): + log.debug( + "Failed to parse Hyper-V generation: " + f"{raw_features.get('hyper_v_generation')}" + ) + + def _parse_network_data_path( + self, raw_features: Dict[str, Any], log: Logger + ) -> None: + network_data_path = raw_features.get("IsAcceleratedNetworkSupported") + if network_data_path == "False": + self.network_data_path = schema.NetworkDataPath.Synthetic + + def _parse_security_profile( + self, raw_features: Dict[str, Any], log: Logger + ) -> None: + security_profile = raw_features.get("SecurityType") + capabilities: List[SecurityProfileType] = [SecurityProfileType.Standard] + if security_profile == "TrustedLaunchSupported": + capabilities.append(SecurityProfileType.SecureBoot) + elif security_profile in ( + "TrustedLaunchAndConfidentialVmSupported", + "ConfidentialVmSupported", + ): + capabilities.append(SecurityProfileType.SecureBoot) + capabilities.append(SecurityProfileType.CVM) + capabilities.append(SecurityProfileType.Stateless) + self.security_profile = search_space.SetSpace(True, capabilities) + + +def _get_image_tags(image: Any) -> Dict[str, Any]: + """ + Marketplace, Shared Image Gallery, and Community Gallery images + have similar structures for image tags. This function extracts + the tags and converts to a dictionary. 
+ """ + image_tags: Dict[str, Any] = {} + if not image: + return image_tags + if hasattr(image, "hyper_v_generation") and image.hyper_v_generation: + image_tags["hyper_v_generation"] = image.hyper_v_generation + if hasattr(image, "architecture") and image.architecture: + image_tags["architecture"] = image.architecture + if ( + hasattr(image, "features") + and image.features + and isinstance(image.features, Iterable) + ): + for feat in image.features: + image_tags[feat.name] = feat.value + return image_tags + @dataclass_json() @dataclass @@ -225,6 +425,13 @@ class AzureVmMarketplaceSchema(AzureImageSchema): def __hash__(self) -> int: return hash(f"{self.publisher}/{self.offer}/{self.sku}/{self.version}") + def _get_info(self, platform: "AzurePlatform") -> Dict[str, Any]: + for location in platform.find_marketplace_image_location(): + image_info = platform.get_image_info(location, self) + if image_info: + return _get_image_tags(image_info) + return {} + @dataclass_json() @dataclass @@ -243,6 +450,67 @@ def __hash__(self) -> int: f"{self.image_version}" ) + def query_platform(self, platform: "AzurePlatform") -> GalleryImage: + assert self.resource_group_name, "'resource_group_name' must not be 'None'" + compute_client = get_compute_client(platform) + sig = compute_client.gallery_images.get( + resource_group_name=self.resource_group_name, + gallery_name=self.image_gallery, + gallery_image_name=self.image_definition, + ) + assert isinstance(sig, GalleryImage), f"actual: {type(sig)}" + return sig + + def _get_info(self, platform: "AzurePlatform") -> Dict[str, Any]: + self.resolve_version(platform) + sig_info = self.query_platform(platform) + return _get_image_tags(sig_info) + + def resolve_version(self, platform: "AzurePlatform") -> None: + compute_client = get_compute_client(platform) + if not self.resource_group_name: + # /subscriptions/xxxx/resourceGroups/xxxx/providers/Microsoft.Compute/ + # galleries/xxxx + rg_pattern = re.compile(r"resourceGroups/(.*)/providers", re.M) + galleries = compute_client.galleries.list() + for gallery in galleries: + if gallery.name and gallery.name.lower() == self.image_gallery: + assert gallery.id, "'gallery.id' must not be 'None'" + self.resource_group_name = get_matched_str(gallery.id, rg_pattern) + break + if not self.resource_group_name: + raise LisaException(f"did not find matched gallery {self.image_gallery}") + + if self.image_version.lower() == "latest": + gallery_images = ( + compute_client.gallery_image_versions.list_by_gallery_image( + resource_group_name=self.resource_group_name, + gallery_name=self.image_gallery, + gallery_image_name=self.image_definition, + ) + ) + time: Optional[datetime] = None + for image in gallery_images: + assert image, "'image' must not be 'None'" + assert image.name, "'image.name' must not be 'None'" + gallery_image = compute_client.gallery_image_versions.get( + resource_group_name=self.resource_group_name, + gallery_name=self.image_gallery, + gallery_image_name=self.image_definition, + gallery_image_version_name=image.name, + expand="ReplicationStatus", + ) + if not time: + time = gallery_image.publishing_profile.published_date + assert image, "'image' must not be 'None'" + assert image.name, "'image.name' must not be 'None'" + self.image_version = image.name + elif gallery_image.publishing_profile.published_date > time: + time = gallery_image.publishing_profile.published_date + assert image, "'image' must not be 'None'" + assert image.name, "'image.name' must not be 'None'" + self.image_version = image.name + 
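The `_get_image_tags` helper above flattens marketplace, Shared Image Gallery, and Community Gallery models into a single tag dictionary that the `_parse_*` methods consume. A minimal sketch of the shape it produces, using a stand-in object rather than a real Azure SDK model (the tag names are the ones the parsers above look up):

    # Illustrative stand-in for the Azure SDK image model, not part of the change.
    from types import SimpleNamespace

    fake_image = SimpleNamespace(
        hyper_v_generation="V2",
        architecture="x64",
        features=[
            SimpleNamespace(name="SecurityType", value="TrustedLaunchSupported"),
            SimpleNamespace(name="DiskControllerTypes", value="SCSI, NVMe"),
            SimpleNamespace(name="IsAcceleratedNetworkSupported", value="True"),
        ],
    )
    # _get_image_tags(fake_image) returns:
    # {
    #     "hyper_v_generation": "V2",
    #     "architecture": "x64",
    #     "SecurityType": "TrustedLaunchSupported",
    #     "DiskControllerTypes": "SCSI, NVMe",
    #     "IsAcceleratedNetworkSupported": "True",
    # }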
@dataclass_json() @dataclass @@ -250,6 +518,13 @@ class VhdSchema(AzureImageSchema): vhd_path: str = "" vmgs_path: Optional[str] = None + def _get_info(self, platform: "AzurePlatform") -> Dict[str, Any]: + # No image tag information is available for VHDs + return {} + + def _parse_info(self, raw_features: Dict[str, Any], log: Logger) -> None: + return + @dataclass_json() @dataclass @@ -264,6 +539,50 @@ def __hash__(self) -> int: f"{self.image_gallery}/{self.image_definition}/{self.image_version}" ) + def query_platform(self, platform: "AzurePlatform") -> CommunityGalleryImage: + compute_client = get_compute_client(platform) + cgi = compute_client.community_gallery_images.get( + location=self.location, + public_gallery_name=self.image_gallery, + gallery_image_name=self.image_definition, + ) + assert isinstance(cgi, CommunityGalleryImage), f"actual: {type(cgi)}" + return cgi + + def _get_info(self, platform: "AzurePlatform") -> Dict[str, Any]: + self.resolve_version(platform) + cgi_info = self.query_platform(platform) + return _get_image_tags(cgi_info) + + def resolve_version(self, platform: "AzurePlatform") -> None: + compute_client = get_compute_client(platform) + if self.image_version.lower() == "latest": + community_gallery_images_list = ( + compute_client.community_gallery_image_versions.list( + location=self.location, + public_gallery_name=self.image_gallery, + gallery_image_name=self.image_definition, + ) + ) + time: Optional[datetime] = None + for image in community_gallery_images_list: + assert image, "'image' must not be 'None'" + assert image.name, "'image.name' must not be 'None'" + community_gallery_image_version = ( + compute_client.community_gallery_image_versions.get( + location=self.location, + public_gallery_name=self.image_gallery, + gallery_image_name=self.image_definition, + gallery_image_version_name=image.name, + ) + ) + if not time: + time = community_gallery_image_version.published_date + self.image_version = image.name + elif community_gallery_image_version.published_date > time: + time = community_gallery_image_version.published_date + self.image_version = image.name + @dataclass_json() @dataclass @@ -568,6 +887,31 @@ def community_gallery_image( ) -> None: self._parse_image_raw("community_gallery_image", value) + @property + def image(self) -> Optional[AzureImageSchema]: + if self.marketplace: + return self.marketplace + elif self.shared_gallery: + return self.shared_gallery + elif self.community_gallery_image: + return self.community_gallery_image + elif self.vhd: + return self.vhd + return None + + @image.setter + def image(self, value: Optional[AzureImageSchema]) -> None: + if isinstance(value, AzureVmMarketplaceSchema): + self.marketplace = value + elif isinstance(value, SharedImageGallerySchema): + self.shared_gallery = value + elif isinstance(value, CommunityGalleryImageSchema): + self.community_gallery_image = value + elif isinstance(value, VhdSchema): + self.vhd = value + else: + raise LisaException(f"unsupported image type: {type(value)}") + def get_image_name(self) -> str: result = "" if self._orignal_vhd_path: @@ -578,7 +922,7 @@ def get_image_name(self) -> str: ), f"actual type: {type(self.shared_gallery_raw)}" if self.shared_gallery.resource_group_name: result = "/".join( - [self.shared_gallery_raw.get(k, "") for k in SIG_IMAGE_KEYS] + [getattr(self.shared_gallery, k, "") for k in SIG_IMAGE_KEYS] ) else: result = ( @@ -602,6 +946,13 @@ def get_image_name(self) -> str: ) return result + def update_raw(self) -> None: + 
self._parse_image_raw("purchase_plan", self.purchase_plan) + self._parse_image_raw("marketplace", self.marketplace) + self._parse_image_raw("shared_gallery", self.shared_gallery) + self._parse_image_raw("vhd", self.vhd) + self._parse_image_raw("community_gallery_image", self.community_gallery_image) + def _parse_image( self, prop_name: str, diff --git a/lisa/sut_orchestrator/azure/features.py b/lisa/sut_orchestrator/azure/features.py index 081407677c..a9e6d71632 100644 --- a/lisa/sut_orchestrator/azure/features.py +++ b/lisa/sut_orchestrator/azure/features.py @@ -6,7 +6,6 @@ import re import string from dataclasses import dataclass, field -from enum import Enum from functools import partial from pathlib import Path from random import randint @@ -90,6 +89,7 @@ from .common import ( AvailabilityArmParameter, AzureArmParameter, + AzureImageSchema, AzureNodeSchema, check_or_create_storage_account, create_update_private_dns_zone_groups, @@ -2298,31 +2298,10 @@ def create_setting( @classmethod def create_image_requirement( - cls, *args: Any, **kwargs: Any + cls, image: schema.ImageSchema ) -> Optional[schema.FeatureSettings]: - raw_capabilities: Any = kwargs.get("raw_capabilities") - assert isinstance(raw_capabilities, dict) - - # Skip checking tags for VHD - if raw_capabilities.get("type") == "vhd": - return None - - capabilities: List[SecurityProfileType] = [SecurityProfileType.Standard] - - value = raw_capabilities.get("SecurityType") - if value == "TrustedLaunchSupported": - capabilities.append(SecurityProfileType.SecureBoot) - elif value in ( - "TrustedLaunchAndConfidentialVmSupported", - "ConfidentialVmSupported", - ): - capabilities.append(SecurityProfileType.SecureBoot) - capabilities.append(SecurityProfileType.CVM) - capabilities.append(SecurityProfileType.Stateless) - - return SecurityProfileSettings( - security_profile=search_space.SetSpace(True, capabilities) - ) + assert isinstance(image, AzureImageSchema), f"actual: {type(image)}" + return SecurityProfileSettings(security_profile=image.security_profile) @classmethod def on_before_deployment(cls, *args: Any, **kwargs: Any) -> None: @@ -3011,14 +2990,10 @@ def create_setting( @classmethod def create_image_requirement( - cls, *args: Any, **kwargs: Any + cls, image: schema.ImageSchema ) -> Optional[schema.FeatureSettings]: - raw_capabilities: Any = kwargs.get("raw_capabilities") - assert isinstance(raw_capabilities, dict), f"actual: {type(raw_capabilities)}" - - value: str = raw_capabilities.get("hyper_v_generation", "V1") - gen = 2 if value == "V2" else 1 - return VhdGenerationSettings(gen=gen) + assert isinstance(image, AzureImageSchema), f"actual: {type(image)}" + return VhdGenerationSettings(gen=image.hyperv_generation) @classmethod def settings_type(cls) -> Type[schema.FeatureSettings]: @@ -3032,27 +3007,26 @@ def enabled(self) -> bool: return True -class ArchitectureType(str, Enum): - x64 = "x64" - Arm64 = "Arm64" - - @dataclass_json() @dataclass() class ArchitectureSettings(schema.FeatureSettings): type: str = "Architecture" # Architecture in hyper-v arch: Union[ - ArchitectureType, search_space.SetSpace[ArchitectureType] + schema.ArchitectureType, search_space.SetSpace[schema.ArchitectureType] ] = field( # type: ignore default_factory=partial( - search_space.SetSpace, items=[ArchitectureType.x64, ArchitectureType.Arm64] + search_space.SetSpace, + items=[schema.ArchitectureType.x64, schema.ArchitectureType.Arm64], ), metadata=field_metadata( decoder=partial( search_space.decode_nullable_set_space, - base_type=ArchitectureType, - 
default_values=[ArchitectureType.x64, ArchitectureType.Arm64],
+                base_type=schema.ArchitectureType,
+                default_values=[
+                    schema.ArchitectureType.x64,
+                    schema.ArchitectureType.Arm64,
+                ],
             )
         ),
     )
@@ -3099,7 +3073,7 @@ def _call_requirement_method(
         value.arch = getattr(search_space, f"{method.value}_setspace_by_priority")(
             self.arch,
             capability.arch,
-            [ArchitectureType.x64, ArchitectureType.Arm64],
+            [schema.ArchitectureType.x64, schema.ArchitectureType.Arm64],
         )
         return value

@@ -3111,21 +3085,15 @@ def create_setting(
     ) -> Optional[schema.FeatureSettings]:
         raw_capabilities: Any = kwargs.get("raw_capabilities")
         return ArchitectureSettings(
             arch=raw_capabilities.get("CpuArchitectureType", "x64")
         )

     @classmethod
     def create_image_requirement(
-        cls, *args: Any, **kwargs: Any
+        cls, image: schema.ImageSchema
     ) -> Optional[schema.FeatureSettings]:
-        raw_capabilities: Any = kwargs.get("raw_capabilities")
-        assert isinstance(raw_capabilities, dict), f"actual: {type(raw_capabilities)}"
-        value: Optional[str] = raw_capabilities.get("architecture")
-        if value == "arm64":
-            return ArchitectureSettings(arch=ArchitectureType.Arm64)
-        elif value == "x64":
-            return ArchitectureSettings(arch=ArchitectureType.x64)
-        return None
+        assert isinstance(image, AzureImageSchema), f"actual: {type(image)}"
+        return ArchitectureSettings(arch=image.architecture)

     @classmethod
     def settings_type(cls) -> Type[schema.FeatureSettings]:
diff --git a/lisa/sut_orchestrator/azure/platform_.py b/lisa/sut_orchestrator/azure/platform_.py
index 8bd5eedf31..842146e27b 100644
--- a/lisa/sut_orchestrator/azure/platform_.py
+++ b/lisa/sut_orchestrator/azure/platform_.py
@@ -784,9 +784,9 @@ def _get_node_information(self, node: Node) -> Dict[str, str]:
             node.log.debug(f"vm generation: {information[KEY_VM_GENERATION]}")
         if node.capture_kernel_config:
             node.log.debug("detecting mana driver enabled...")
-            information[KEY_MANA_DRIVER_ENABLED] = (
-                node.nics.is_mana_driver_enabled()
-            )
+            information[
+                KEY_MANA_DRIVER_ENABLED
+            ] = node.nics.is_mana_driver_enabled()
             node.log.debug(f"mana enabled: {information[KEY_MANA_DRIVER_ENABLED]}")
             node.log.debug("detecting nvme driver enabled...")
             _has_nvme_core = node.tools[KernelConfig].is_built_in(
@@ -1011,13 +1011,13 @@ def _initialize_credential(self) -> None:
             logging.getLogger("azure").setLevel(azure_runbook.log_level)

         if azure_runbook.service_principal_tenant_id:
-            os.environ["AZURE_TENANT_ID"] = (
-                azure_runbook.service_principal_tenant_id
-            )
+            os.environ[
+                "AZURE_TENANT_ID"
+            ] = azure_runbook.service_principal_tenant_id
         if azure_runbook.service_principal_client_id:
-            os.environ["AZURE_CLIENT_ID"] = (
-                azure_runbook.service_principal_client_id
-            )
+            os.environ[
+                "AZURE_CLIENT_ID"
+            ] = azure_runbook.service_principal_client_id
         if azure_runbook.service_principal_key:
             os.environ["AZURE_CLIENT_SECRET"] = azure_runbook.service_principal_key

@@ -1405,26 +1405,21 @@ def _create_node_runbook(
             azure_node_runbook.vhd = vhd
             azure_node_runbook.marketplace = None
             azure_node_runbook.shared_gallery = None
+            azure_node_runbook.community_gallery_image = None
             log.debug(
                 f"current vhd generation is {azure_node_runbook.hyperv_generation}."
) elif azure_node_runbook.shared_gallery: azure_node_runbook.marketplace = None - azure_node_runbook.shared_gallery = self._parse_shared_gallery_image( - azure_node_runbook.shared_gallery - ) - azure_node_runbook.hyperv_generation = _get_gallery_image_generation( - self._get_sig(azure_node_runbook.shared_gallery) - ) + azure_node_runbook.community_gallery_image = None + azure_node_runbook.shared_gallery.resolve_version(self) + azure_node_runbook.update_raw() elif azure_node_runbook.community_gallery_image: azure_node_runbook.marketplace = None - azure_node_runbook.community_gallery_image = ( - self._parse_community_gallery_image( - azure_node_runbook.community_gallery_image - ) - ) + azure_node_runbook.community_gallery_image.resolve_version(self) + azure_node_runbook.update_raw() azure_node_runbook.hyperv_generation = _get_gallery_image_generation( - self._get_cgi(azure_node_runbook.community_gallery_image) + azure_node_runbook.community_gallery_image.query_platform(self) ) elif not azure_node_runbook.marketplace: # set to default marketplace, if nothing specified @@ -1438,7 +1433,7 @@ def _create_node_runbook( azure_node_runbook.marketplace = self._resolve_marketplace_image( azure_node_runbook.location, azure_node_runbook.marketplace ) - image_info = self._get_image_info( + image_info = self.get_image_info( azure_node_runbook.location, azure_node_runbook.marketplace ) # HyperVGenerationTypes return "V1"/"V2", so we need to strip "V" @@ -1496,7 +1491,7 @@ def _create_node_arm_parameters( assert ( arm_parameters.marketplace ), "not set one of marketplace, shared_gallery or vhd." - image_info = self._get_image_info( + image_info = self.get_image_info( arm_parameters.location, arm_parameters.marketplace ) if image_info: @@ -1837,14 +1832,14 @@ def _resource_sku_to_capability( # noqa: C901 azure_raw_capabilities["availability_zones"] = location_info.zones for zone_details in location_info.zone_details: for location_capability in zone_details.capabilities: - azure_raw_capabilities[location_capability.name] = ( - location_capability.value - ) + azure_raw_capabilities[ + location_capability.name + ] = location_capability.value # Zones supporting the feature if zone_details.additional_properties["Name"]: - azure_raw_capabilities["availability_zones"] = ( - zone_details.additional_properties["Name"] - ) + azure_raw_capabilities[ + "availability_zones" + ] = zone_details.additional_properties["Name"] if resource_sku.capabilities: for sku_capability in resource_sku.capabilities: @@ -2077,94 +2072,6 @@ def _resolve_marketplace_image( return new_marketplace - @lru_cache(maxsize=10) # noqa: B019 - def _parse_community_gallery_image( - self, community_gallery_image: CommunityGalleryImageSchema - ) -> CommunityGalleryImageSchema: - new_community_gallery_image = copy.copy(community_gallery_image) - compute_client = get_compute_client(self) - if community_gallery_image.image_version.lower() == "latest": - community_gallery_images_list = ( - compute_client.community_gallery_image_versions.list( - location=community_gallery_image.location, - public_gallery_name=community_gallery_image.image_gallery, - gallery_image_name=community_gallery_image.image_definition, - ) - ) - image: Optional[CommunityGalleryImageVersion] = None - time: Optional[datetime] = None - for image in community_gallery_images_list: - assert image, "'image' must not be 'None'" - assert image.name, "'image.name' must not be 'None'" - community_gallery_image_version = ( - compute_client.community_gallery_image_versions.get( - 
location=community_gallery_image.location, - public_gallery_name=community_gallery_image.image_gallery, - gallery_image_name=community_gallery_image.image_definition, - gallery_image_version_name=image.name, - ) - ) - if not time: - time = community_gallery_image_version.published_date - new_community_gallery_image.image_version = image.name - if community_gallery_image_version.published_date > time: - time = community_gallery_image_version.published_date - new_community_gallery_image.image_version = image.name - return new_community_gallery_image - - @lru_cache(maxsize=10) # noqa: B019 - def _parse_shared_gallery_image( - self, shared_image: SharedImageGallerySchema - ) -> SharedImageGallerySchema: - new_shared_image = copy.copy(shared_image) - compute_client = get_compute_client(self) - rg_name = shared_image.resource_group_name - if not shared_image.resource_group_name: - # /subscriptions/xxxx/resourceGroups/xxxx/providers/Microsoft.Compute/ - # galleries/xxxx - rg_pattern = re.compile(r"resourceGroups/(.*)/providers", re.M) - galleries = compute_client.galleries.list() - for gallery in galleries: - if gallery.name.lower() == shared_image.image_gallery: - rg_name = get_matched_str(gallery.id, rg_pattern) - break - if not rg_name: - raise LisaException( - f"not find matched gallery {shared_image.image_gallery}" - ) - new_shared_image.resource_group_name = rg_name - if shared_image.image_version.lower() == "latest": - gallery_images = ( - compute_client.gallery_image_versions.list_by_gallery_image( - resource_group_name=new_shared_image.resource_group_name, - gallery_name=new_shared_image.image_gallery, - gallery_image_name=new_shared_image.image_definition, - ) - ) - image: Optional[GalleryImageVersion] = None - time: Optional[datetime] = None - for image in gallery_images: - assert image, "'image' must not be 'None'" - assert image.name, "'image.name' must not be 'None'" - gallery_image = compute_client.gallery_image_versions.get( - resource_group_name=new_shared_image.resource_group_name, - gallery_name=new_shared_image.image_gallery, - gallery_image_name=new_shared_image.image_definition, - gallery_image_version_name=image.name, - expand="ReplicationStatus", - ) - if not time: - time = gallery_image.publishing_profile.published_date - assert image, "'image' must not be 'None'" - assert image.name, "'image.name' must not be 'None'" - new_shared_image.image_version = image.name - if gallery_image.publishing_profile.published_date > time: - time = gallery_image.publishing_profile.published_date - assert image, "'image' must not be 'None'" - assert image.name, "'image.name' must not be 'None'" - new_shared_image.image_version = image.name - return new_shared_image - @lru_cache(maxsize=10) # noqa: B019 def _process_marketplace_image_plan( self, @@ -2306,7 +2213,7 @@ def _generate_data_disks( data_disks: List[DataDiskSchema] = [] assert node.capability.disk if azure_node_runbook.marketplace: - marketplace = self._get_image_info( + marketplace = self.get_image_info( azure_node_runbook.location, azure_node_runbook.marketplace ) # some images has data disks by default @@ -2368,7 +2275,7 @@ def _generate_data_disks( return data_disks @lru_cache(maxsize=10) # noqa: B019 - def _get_image_info( + def get_image_info( self, location: str, marketplace: Optional[AzureVmMarketplaceSchema] ) -> Optional[VirtualMachineImage]: # resolve "latest" to specified version @@ -2525,17 +2432,6 @@ def _get_cgi( assert isinstance(cgi, CommunityGalleryImage), f"actual: {type(cgi)}" return cgi - 
@lru_cache(maxsize=10) # noqa: B019 - def _get_sig(self, shared_image: SharedImageGallerySchema) -> GalleryImage: - compute_client = get_compute_client(self) - sig = compute_client.gallery_images.get( - resource_group_name=shared_image.resource_group_name, - gallery_name=shared_image.image_gallery, - gallery_image_name=shared_image.image_definition, - ) - assert isinstance(sig, GalleryImage), f"actual: {type(sig)}" - return sig - def _get_sig_os_disk_size(self, shared_image: SharedImageGallerySchema) -> int: found_image = self._get_sig_version(shared_image) assert found_image.storage_profile, "'storage_profile' must not be 'None'" @@ -2885,7 +2781,7 @@ def _resolve_marketplace_image_version( node_runbook.location, node_runbook.marketplace ) - def _find_marketplace_image_location(self) -> List[str]: + def find_marketplace_image_location(self) -> List[str]: # locations used to query marketplace image information. Some image is not # available in all locations, so try several of them. _marketplace_image_locations = [ @@ -2921,44 +2817,21 @@ def _add_image_features(self, node_space: schema.NodeSpace) -> None: is_allow_set=True ) - image_features: Dict[str, Any] = {} azure_runbook = node_space.get_extended_runbook(AzureNodeSchema, AZURE) - - # Get the image information - if azure_runbook.marketplace: - for location in self._find_marketplace_image_location(): - image_info = self._get_image_info(location, azure_runbook.marketplace) - if image_info: - break - image_features = _get_image_features("marketplace", image_info) - elif azure_runbook.shared_gallery: - azure_runbook.shared_gallery = self._parse_shared_gallery_image( - azure_runbook.shared_gallery - ) - sig_info = self._get_detailed_sig(azure_runbook.shared_gallery) - image_features = _get_image_features("shared_gallery", sig_info) - elif azure_runbook.community_gallery_image: - azure_runbook.community_gallery_image = self._parse_community_gallery_image( - azure_runbook.community_gallery_image - ) - cgi = self._get_cgi(azure_runbook.community_gallery_image) - generation = _get_gallery_image_generation(cgi) - node_space.features.add(features.VhdGenerationSettings(gen=generation)) - node_space.features.add( - features.ArchitectureSettings(arch=cgi.architecture) # type: ignore - ) - elif azure_runbook.vhd: - image_features["type"] = "vhd" - image_features["hyper_v_generation"] = f"V{azure_runbook.hyperv_generation}" - else: + image = azure_runbook.image + if not image: return + # Default to provided hyperv_generation, + # but will be overrriden if the image is tagged + image.hyperv_generation = azure_runbook.hyperv_generation + image.load_from_platform(self) # Create Image requirements for each Feature - node_space._image_features = image_features for feat in self.supported_features(): - image_req = feat.create_image_requirement(raw_capabilities=image_features) + image_req = feat.create_image_requirement(image) if not image_req: continue + # Merge with existing requirements node_cap = node_space._find_feature_by_type( image_req.type, node_space.features ) @@ -2977,7 +2850,7 @@ def _set_disk_features( ) -> None: assert node_space.disk assert node_space.disk.os_disk_type - assert node_space._image_features + assert azure_runbook.image if ( isinstance(node_space.disk.os_disk_type, schema.DiskType) and schema.DiskType.Ephemeral == node_space.disk.os_disk_type @@ -2994,19 +2867,14 @@ def _set_disk_features( node_space.disk.disk_controller_type = search_space.SetSpace[ schema.DiskControllerType ](is_allow_set=True, 
items=[node_space.disk.disk_controller_type]) - disk_controller_types: str = node_space._image_features.get( - "DiskControllerTypes", "SCSI, NVMe" - ) - disk_controller_types = disk_controller_types.lower() - types = disk_controller_types.split(",") - types = [x.strip() for x in types] - allowed_types = search_space.SetSpace[schema.DiskControllerType]( - is_allow_set=True - ) - if "scsi" in types: - allowed_types.add(schema.DiskControllerType.SCSI) - if "nvme" in types: - allowed_types.add(schema.DiskControllerType.NVME) + if isinstance( + azure_runbook.image.disk_controller_type, schema.DiskControllerType + ): + azure_runbook.image.disk_controller_type = search_space.SetSpace[ + schema.DiskControllerType + ](is_allow_set=True, items=[azure_runbook.image.disk_controller_type]) + + allowed_types = azure_runbook.image.disk_controller_type if node_space.disk.disk_controller_type: node_space.disk.disk_controller_type = ( node_space.disk.disk_controller_type.intersect(allowed_types) @@ -3017,8 +2885,8 @@ def _set_disk_features( def _get_os_disk_size(self, azure_runbook: AzureNodeSchema) -> int: assert azure_runbook if azure_runbook.marketplace: - for location in self._find_marketplace_image_location(): - image_info = self._get_image_info(location, azure_runbook.marketplace) + for location in self.find_marketplace_image_location(): + image_info = self.get_image_info(location, azure_runbook.marketplace) if image_info: break if image_info and image_info.os_disk_image: @@ -3029,14 +2897,12 @@ def _get_os_disk_size(self, azure_runbook: AzureNodeSchema) -> int: # if no image info, use default size 30 return 30 elif azure_runbook.shared_gallery: - azure_runbook.shared_gallery = self._parse_shared_gallery_image( - azure_runbook.shared_gallery - ) + azure_runbook.shared_gallery.resolve_version(self) + azure_runbook.update_raw() return self._get_sig_os_disk_size(azure_runbook.shared_gallery) elif azure_runbook.community_gallery_image: - azure_runbook.community_gallery_image = self._parse_community_gallery_image( - azure_runbook.community_gallery_image - ) + azure_runbook.community_gallery_image.resolve_version(self) + azure_runbook.update_raw() return self._get_cgi_os_disk_size(azure_runbook.community_gallery_image) else: assert azure_runbook.vhd @@ -3153,18 +3019,3 @@ def _get_disk_size_in_gb(additional_properties: Dict[str, int]) -> int: ) return osdisk_size_in_gb - - -def _get_image_features(image_type: str, image: Any) -> Dict[str, Any]: - image_features: Dict[str, Any] = {} - if not image: - return image_features - image_features["type"] = image_type - if image.hyper_v_generation: - image_features["hyper_v_generation"] = image.hyper_v_generation - if image.architecture: - image_features["architecture"] = image.architecture - if image.features and isinstance(image.features, Iterable): - for feat in image.features: - image_features[feat.name] = feat.value - return image_features
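Overall, the change routes every image source through `AzureImageSchema`: `azure_runbook.image` returns whichever of marketplace, shared gallery, community gallery, or VHD is configured, `load_from_platform` fetches the image tags, and the `_parse_*` helpers turn them into capabilities that each feature's `create_image_requirement` consumes. A hedged, minimal sketch of that parsing step, driving `_parse_info` directly with a hand-written tag dictionary (the marketplace identifiers are placeholders, the stdlib logger stands in for LISA's logger, and the private method is called here only for illustration):

    import logging

    from lisa import schema
    from lisa.sut_orchestrator.azure.common import AzureVmMarketplaceSchema

    # Hypothetical image reference; any AzureImageSchema subclass behaves the same.
    image = AzureVmMarketplaceSchema(
        publisher="examplepublisher",
        offer="exampleoffer",
        sku="examplesku",
        version="latest",
    )

    # Tags shaped like the output of _get_image_tags for a gen2, arm64,
    # Trusted Launch capable image without accelerated networking.
    raw_tags = {
        "hyper_v_generation": "V2",
        "architecture": "arm64",
        "SecurityType": "TrustedLaunchSupported",
        "IsAcceleratedNetworkSupported": "False",
    }
    image._parse_info(raw_tags, logging.getLogger("example"))

    assert image.hyperv_generation == 2
    assert image.architecture == schema.ArchitectureType.Arm64
    assert image.network_data_path == schema.NetworkDataPath.Synthetic
    # image.security_profile is now a SetSpace allowing Standard and SecureBoot;
    # each supported feature then builds its requirement via
    # feature.create_image_requirement(image).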