diff --git a/src/AzSKPy/azskpy/__init__.py b/src/AzSKPy/azskpy/__init__.py new file mode 100644 index 000000000..15b950d3c --- /dev/null +++ b/src/AzSKPy/azskpy/__init__.py @@ -0,0 +1,7 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License in the project root for +# license information. +# ------------------------------------------------------------------------------ + +from .constants import * diff --git a/src/AzSKPy/azskpy/constants.py b/src/AzSKPy/azskpy/constants.py new file mode 100644 index 000000000..d3fac0353 --- /dev/null +++ b/src/AzSKPy/azskpy/constants.py @@ -0,0 +1,9 @@ +__name__ = "azskpy" +__version__ = "0.9.58" +PASSED = "Passed" +FAILED = "Failed" +MANUAL = "Verify" +ERROR = "Error" +HDI_METADATA_WRITE_PATH = "wasb:///AzSK_Metadata/Meta-{}.json" +HDI_SCANLOG_WRITE_PATH = "wasb:///AzSK_Logs/Scan-{}.csv" +LOG_ANALYTICS_API_VERSION = "2016-04-01" diff --git a/src/AzSKPy/azskpy/kubernetescontrols.py b/src/AzSKPy/azskpy/kubernetescontrols.py new file mode 100644 index 000000000..e77d5aa0f --- /dev/null +++ b/src/AzSKPy/azskpy/kubernetescontrols.py @@ -0,0 +1,479 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License in the project root for +# license information. +# ------------------------------------------------------------------------------ + +from kubernetes import client, config +from kubernetes.client import configuration +from applicationinsights import TelemetryClient +import pandas as pd +import numpy as np +import os +import copy +from .utils import fail_with_manual +from distutils.version import StrictVersion + + +class AKSBootstrap: + __instance = None + + def __init__(self): + if AKSBootstrap.__instance != None: + raise Exception("Something went wrong.") + else: + self.resources = { + "APP_INSIGHT_KEY": None, + "SUBSCRIPTION_ID": None, + "RG_NAME": None, + "RESOURCE_NAME": None, + "pods": [], + "service_accounts": [] + } + try: + config.load_incluster_config() + v1 = client.CoreV1Api() + pod_response = v1.list_pod_for_all_namespaces(watch=False) + pods = list(filter(lambda x: x.metadata.namespace != 'kube-system', pod_response.items)) + self.resources['pods'] = pods + serviceacc_response = v1.list_service_account_for_all_namespaces(watch=False) + service_accounts = list( + filter(lambda x: x.metadata.namespace != 'kube-system', serviceacc_response.items)) + self.resources['service_accounts'] = service_accounts + self.resources['SUBSCRIPTION_ID'] = os.environ.get("SUBSCRIPTION_ID", None) + self.resources['RG_NAME'] = os.environ.get("RG_NAME", None) + self.resources['RESOURCE_NAME'] = os.environ.get("RESOURCE_NAME", None) + self.resources['APP_INSIGHT_KEY'] = os.environ.get("APP_INSIGHT_KEY", None) + except Exception as e: + print(e) + + AKSBootstrap.__instance = self + + @staticmethod + def get_config(): + if AKSBootstrap.__instance == None: + AKSBootstrap() + return AKSBootstrap.__instance + + +class AKSBaseControlTest: + + def __init__(self): + cluster_config = AKSBootstrap.get_config() + self.pods = cluster_config.resources['pods'] + self.service_accounts = cluster_config.resources['service_accounts'] + self.APP_INSIGHT_KEY = cluster_config.resources['APP_INSIGHT_KEY'] + self.SUBSCRIPTION_ID = cluster_config.resources['SUBSCRIPTION_ID'] + self.RG_NAME = cluster_config.resources['RG_NAME'] + 
self.RESOURCE_NAME = cluster_config.resources['RESOURCE_NAME'] + self.detailed_logs = {'desc': "", 'control_id': "", 'non_compliant_containers': [], + 'service_accounts': [], 'non_compliant_pods': [], 'pods_with_secrets': [],'container_images' : [], 'non_compliant_services':[]} + + def test(self) -> str: + raise NotImplementedError + + def CheckSecurityConfig(self, property, expected_value): + result = [] + for pod in self.pods: + security_context = pod.spec.security_context + info = { + "namespace": pod.metadata.namespace, + "pod_name": pod.metadata.name, + "container": None, + "run_as_non_root": None, + "allow_privilege_escalation": None, + "read_only_root_filesystem": None + } + if security_context != None: + pod_run_as_non_root = getattr(security_context, 'run_as_non_root', None) + pod_allow_privilege_escalation = getattr(security_context, 'allow_privilege_escalation', None) + pod_read_only_root_filesystem = getattr(security_context, 'read_only_root_filesystem', None) + for container in pod.spec.containers: + if container.name != 'azsk-ca-job': + security_context = container.security_context + run_as_non_root = getattr(security_context, 'run_as_non_root', pod_run_as_non_root) + if run_as_non_root == None: + run_as_non_root = pod_run_as_non_root + allow_privilege_escalation = getattr(security_context, 'allow_privilege_escalation', + pod_allow_privilege_escalation) + read_only_root_filesystem = getattr(security_context, 'read_only_root_filesystem', + pod_read_only_root_filesystem) + info = copy.deepcopy(info) + info['container'] = container.name + info['run_as_non_root'] = run_as_non_root + info['allow_privilege_escalation'] = allow_privilege_escalation + info['read_only_root_filesystem'] = read_only_root_filesystem + result.append(info) + non_compliant_containers = filter(lambda x: x[property] != expected_value, result) + return non_compliant_containers + + def set_credentials(self, uname, password): + pass + + def send_telemetry(self, event_name, custom_properties): + if self.APP_INSIGHT_KEY != None: + try: + tc = TelemetryClient(self.APP_INSIGHT_KEY) + # Add common properties + custom_properties['SubscriptionId'] = self.SUBSCRIPTION_ID + custom_properties['ResourceGroupName'] = self.RG_NAME + custom_properties['ResourceName'] = self.RESOURCE_NAME + # Send telemetry event + tc.track_event(event_name, custom_properties) + tc.flush() + except: + pass + # No need to break execution, if any exception occurs while sending telemetry + + +@fail_with_manual +class CheckContainerRunAsNonRoot(AKSBaseControlTest): + name = "Dont_Run_Container_As_Root" + + def __init__(self): + super().__init__() + self.desc = ("Container must run as a non-root user") + + def test(self): + non_compliant_containers = [] + if len(self.pods) > 0: + non_compliant_containers = list(self.CheckSecurityConfig('run_as_non_root', True)) + else: + return "Manual" + if len(non_compliant_containers) > 0: + self.detailed_logs[ + 'desc'] = self.desc + "\nFor following container(s), runAsNonRoot is either set to 'False' or 'None':" + self.detailed_logs['non_compliant_containers'] = non_compliant_containers + return ("Failed") + else: + return ("Passed") + + +@fail_with_manual +class CheckContainerPrivilegeEscalation(AKSBaseControlTest): + name = "Restrict_Container_Privilege_Escalation" + + def __init__(self): + super().__init__() + self.desc = ("Container should not allow privilege escalation") + + def test(self): + non_compliant_containers = [] + if len(self.pods) > 0: + non_compliant_containers = 
list(self.CheckSecurityConfig('allow_privilege_escalation', False))
+        else:
+            return ("Manual")
+        if len(non_compliant_containers) > 0:
+            self.detailed_logs[
+                'desc'] = self.desc + "\nFor the following container(s), allowPrivilegeEscalation is either set to 'True' or 'None':"
+            self.detailed_logs['non_compliant_containers'] = non_compliant_containers
+            return ("Failed")
+        else:
+            return ("Passed")
+
+
+@fail_with_manual
+class CheckContainerReadOnlyRootFilesystem(AKSBaseControlTest):
+    name = "Set_Read_Only_Root_File_System"
+
+    def __init__(self):
+        super().__init__()
+        self.desc = ("Container should not be allowed to write to the root/host filesystem")
+
+    def test(self):
+        non_compliant_containers = []
+        if len(self.pods) > 0:
+            non_compliant_containers = list(self.CheckSecurityConfig('read_only_root_filesystem', True))
+        else:
+            return ("Manual")
+        if len(non_compliant_containers) > 0:
+            self.detailed_logs[
+                'desc'] = self.desc + "\nFor the following container(s), readOnlyRootFilesystem is either set to 'False' or 'None':"
+            self.detailed_logs['non_compliant_containers'] = non_compliant_containers
+            return ("Failed")
+        else:
+            return ("Passed")
+
+
+@fail_with_manual
+class CheckInactiveServiceAccounts(AKSBaseControlTest):
+    name = "Remove_Inactive_Service_Accounts"
+
+    def __init__(self):
+        super().__init__()
+        self.desc = ("Cluster should not have any inactive service account")
+
+    def test(self):
+        if len(self.service_accounts) > 0:
+            self.service_accounts = list(filter(lambda x: x.metadata.name != 'default', self.service_accounts))
+        else:
+            return ("Manual")
+        all_svc_accounts = set()
+        pod_svc_accounts = set()
+        for item in self.service_accounts:
+            all_svc_accounts.add((item.metadata.name, item.metadata.namespace))
+        for item in self.pods:
+            pod_svc_accounts.add((item.spec.service_account, item.metadata.namespace))
+        inactive_svc_accounts = all_svc_accounts - pod_svc_accounts
+        if len(inactive_svc_accounts) > 0:
+            self.detailed_logs[
+                'desc'] = self.desc + "\nThe following service account(s) are not referenced by any pod/container:"
+            self.detailed_logs['service_accounts'] = inactive_svc_accounts
+            return ("Failed")
+        else:
+            return ("Passed")
+
+
+@fail_with_manual
+class CheckClusterManagedIdentity(AKSBaseControlTest):
+    name = "Use_Managed_Service_Identity"
+
+    def __init__(self):
+        super().__init__()
+        self.desc = ("Managed Service Identity (MSI) should be used to access Azure resources from the cluster")
+
+    def test(self):
+        if len(self.pods) > 0:
+            result = "Failed"
+        else:
+            return "Manual"
+
+        # Presence of the aad-pod-identity MIC pod indicates pod-managed identity is in use.
+        for item in self.pods:
+            if item.metadata.name.find("mic-") != -1 and item.spec.containers[0].image.find(
+                    "mcr.microsoft.com/k8s/aad-pod-identity/mic") != -1:
+                result = "Verify"
+        return result
+
+
+class CheckDefaultSvcRoleBinding(AKSBaseControlTest):
+    name = "Dont_Bind_Role_To_Default_Svc_Acc"
+
+    def __init__(self):
+        super().__init__()
+        self.desc = ("Default service account should not be assigned any cluster role")
+
+    @fail_with_manual
+    def test(self):
+        is_failed = False
+        config.load_incluster_config()
+        v1beta1 = client.RbacAuthorizationV1beta1Api()
+        clusterRoles = v1beta1.list_cluster_role_binding(watch=False)
+        for item in clusterRoles.items:
+            try:
+                subjects = item.subjects
+                for subject in subjects:
+                    if subject.namespace == "default" and subject.name == "default":
+                        is_failed = True
+            except Exception:
+                # subjects can be None for bindings without subjects
+                pass
+        if is_failed:
+            return "Failed"
+        else:
+            return "Passed"
+
+
+class CheckContainerPrivilegeMode(AKSBaseControlTest):
+    name = "Dont_Run_Privileged_Container"
+
+    def 
__init__(self): + super().__init__() + self.desc = ("Container should not run in the privileged mode") + + @fail_with_manual + def test(self): + if len(self.pods) > 0: + result = [] + for pod in self.pods: + security_context = pod.spec.security_context + info = { + "namespace": pod.metadata.namespace, + "pod_name": pod.metadata.name, + "container": None, + "privileged": None + } + if security_context != None: + pod_privileged = getattr(security_context, 'privileged', None) + for container in pod.spec.containers: + if container.name != 'azsk-ca-job': + security_context = container.security_context + privileged = getattr(security_context, 'privileged', pod_privileged) + info = copy.deepcopy(info) + info['container'] = container.name + info['privileged'] = privileged + result.append(info) + non_compliant_containers = list(filter(lambda x: x["privileged"] == True, result)) + else: + return ("Manual") + if len(non_compliant_containers) > 0: + self.detailed_logs['desc'] = self.desc + "\nFollowing container(s) will run in the privileged mode:" + self.detailed_logs['non_compliant_containers'] = non_compliant_containers + return ("Failed") + else: + return ("Passed") + + +class CheckDefaultNamespaceResources(AKSBaseControlTest): + name = "Dont_Use_Default_Namespace" + + def __init__(self): + super().__init__() + self.desc = ("Do not use the default cluster namespace to deploy applications") + + @fail_with_manual + def test(self): + pods_in_default_namespace = [] + if len(self.pods) > 0: + pods_in_default_namespace = list(filter(lambda x: x.metadata.namespace == 'default', self.pods)) + else: + return ("Manual") + if len(pods_in_default_namespace) > 0: + non_compliant_pods = [] + self.detailed_logs['desc'] = self.desc + "\nFollowing pods(s) are present in default namespace:" + for pod in pods_in_default_namespace: + non_compliant_pods.append(pod.metadata.name) + self.detailed_logs['non_compliant_pods'] = non_compliant_pods + return ("Failed") + else: + return ("Passed") + + +class CheckResourcesWithSecrets(AKSBaseControlTest): + name = "Use_KeyVault_To_Store_Secret" + + def __init__(self): + super().__init__() + self.desc = ("Use Azure Key Vault to store credentials/keys") + + @fail_with_manual + def test(self): + pods_with_secrets = [] + if len(self.pods) > 0: + svc_accounts = [] + + if len(self.service_accounts) > 0: + all_service_accounts = {svc.metadata.name for svc in self.service_accounts} + svc_accounts = dict.fromkeys(all_service_accounts, 0) + + for pod in self.pods: + info = { + "namespace": pod.metadata.namespace, + "pod_name": pod.metadata.name + } + is_secret_mounted = False + for volume in pod.spec.volumes: + if (volume.secret != None): + tokenIndex = volume.secret.secret_name.rfind('-token-') + if tokenIndex == -1 or not (volume.secret.secret_name[:tokenIndex] in svc_accounts): + is_secret_mounted = True + + if not is_secret_mounted: + for container in pod.spec.containers: + secret_key_refs = [] + if container.env != None and len(container.env) > 0: + secret_key_refs = list( + filter(lambda x: x.value_from != None and x.value_from.secret_key_ref != None, + container.env)) + if len(secret_key_refs) > 0: + is_secret_mounted = True + + if is_secret_mounted: + pods_with_secrets.append(info) + else: + return ("Manual") + if len(pods_with_secrets) > 0: + self.detailed_logs[ + 'desc'] = self.desc + "\nFollowing pod(s) are using Kubernetes secret objects to store secrets:" + self.detailed_logs['pods_with_secrets'] = pods_with_secrets + return ("Failed") + else: + return ("Passed") + + +class 
CheckKubernetesVersion(AKSBaseControlTest): + name = "Use_Latest_Kubernetes_Version" + + def __init__(self): + super().__init__() + self.desc = ("The latest version of Kubernetes should be used") + + @fail_with_manual + def test(self): + config.load_incluster_config() + v1 = client.CoreV1Api() + res = v1.list_node(watch=False) + nodes = list(res.items) + if len(nodes) > 0: + try: + node = nodes[0] + cur_version = node.status.node_info.kubelet_version + cur_version = cur_version.replace("v", "") + req_version = '1.14.6' + if StrictVersion(req_version) > StrictVersion(cur_version): + return ("Failed") + else: + return ("Passed") + + except: + return ("Manual") + + else: + return ("Manual") + +class CheckMountedImages(AKSBaseControlTest): + name = "Review_Mounted_Images_Source" + + def __init__(self): + super().__init__() + self.desc = ("Make sure container images deployed in cluster are trustworthy") + + @fail_with_manual + def test(self): + config.load_incluster_config() + v1 = client.CoreV1Api() + res = v1.list_node(watch=False) + nodes = list(res.items) + image_list = [] + whitelisted_sources = ['k8s.gcr.io','microsoft'] + for node in nodes: + images = node.status.images + for image in images: + image_list.append(image.names) + + images = [image for sublist in image_list for image in sublist if '@sha' not in image] + # Filter whitelisted images + images = [image for image in images if image.split('/')[0] not in whitelisted_sources] + + if len(images) > 0: + self.detailed_logs[ + 'desc'] = self.desc + "\nFollowing container images are mounted in Cluster:" + self.detailed_logs['container_images'] = images + + return ("Verify") + + +class CheckExternalServices(AKSBaseControlTest): + name = "Review_Publicly_Exposed_Services" + + def __init__(self): + super().__init__() + self.desc = ("Review services with external IP") + + @fail_with_manual + def test(self): + config.load_incluster_config() + try: + v1 = client.CoreV1Api() + res = v1.list_service_for_all_namespaces(watch=False) + services = list(res.items) + services_with_external_ip = [service for service in services if service.spec.type.lower() == 'loadbalancer'] + if len(services_with_external_ip) == 0: + return('Passed') + else: + non_compliant_services = [(service.metadata.namespace, service.metadata.name) for service in services_with_external_ip] + self.detailed_logs['desc'] = self.desc + "\nFollowing service(s) have external IP configured:" + self.detailed_logs['non_compliant_services'] = non_compliant_services + return("Verify") + except: + return("Manual") + + \ No newline at end of file diff --git a/src/AzSKPy/azskpy/loganalytics.py b/src/AzSKPy/azskpy/loganalytics.py new file mode 100644 index 000000000..fa3a00e3a --- /dev/null +++ b/src/AzSKPy/azskpy/loganalytics.py @@ -0,0 +1,53 @@ +import requests +import datetime +import hashlib +import hmac +import base64 +import json +from .constants import LOG_ANALYTICS_API_VERSION + + +class LogAnalyticsClient: + def __init__(self, workspace_id, shared_key): + self.workspace_id = workspace_id + self.shared_key = shared_key + + def __get_header(self, date, content_length): + sigs = "POST\n{}\napplication/json\nx-ms-date:{}\n/api/logs".format( + str(content_length), date) + utf8_sigs = sigs.encode('utf-8') + decoded_shared_key = base64.b64decode(self.shared_key) + hmac_sha256_sigs = hmac.new( + decoded_shared_key, utf8_sigs, digestmod=hashlib.sha256).digest() + b64bash = base64.b64encode(hmac_sha256_sigs).decode('utf-8') + authorization = "SharedKey {}:{}".format(self.workspace_id, b64bash) + 
return authorization + + def __rfcdate(self): + return datetime.datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT') + + def __post_data(self, log_type, json_records): + if not log_type.isalpha(): + raise Exception( + "ERROR: log_type supports only alpha characters: {}".format(log_type)) + + body = json.dumps(json_records) + rfcdate = self.__rfcdate() + content_length = len(body) + signature = self.__get_header(rfcdate, content_length) + uri = "https://{}.ods.opinsights.azure.com/api/logs?api-version={}".format( + self.workspace_id, LOG_ANALYTICS_API_VERSION) + headers = { + 'content-type': 'application/json', + 'Authorization': signature, + 'Log-Type': log_type, + 'x-ms-date': rfcdate + } + return requests.post(uri, data=body, headers=headers) + + def post_data(self, json_log, log_name): + response = self.__post_data(log_name, json_log) + if response.status_code == 200: + print('Telemetry sent to Log Analytics') + else: + print("Failure in posting to Log Analytics: Error code:{}".format(response.status_code)) diff --git a/src/AzSKPy/azskpy/main.py b/src/AzSKPy/azskpy/main.py new file mode 100644 index 000000000..d6cb55012 --- /dev/null +++ b/src/AzSKPy/azskpy/main.py @@ -0,0 +1,448 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License in the project root for +# license information. +# ------------------------------------------------------------------------------ +from applicationinsights import TelemetryClient +from datetime import datetime +from .constants import __version__ +from .kubernetescontrols import * +from .sparkcontrols import * +from .loganalytics import LogAnalyticsClient + +sct = None + + +class DBSparkControlTester: + def __init__(self, **kwargs): + self.context = None + self.kwargs = kwargs + self.context = "databricks" + self.controls = [DiskEncryptionEnabled, StorageEncryptionKeySize, CheckNumberOfAdminsDB, + CheckGuestAdmin, SameKeyvaultReference, AccessTokenExpiry, NonKeyvaultBackend, + ExternalLibsInstalled, InitScriptPresent, MountPointPresent, TokenNearExpiry] + self.detailed_logs = [] + + def line(self): + print("-" * 123) + + def run_single_cluster(self, spark_context: dict, cluster_name): + df = pd.DataFrame(columns=["ControlName", + "ControlDescription", + "Result"]) + for Control in self.controls: + control = Control(spark_context, **self.kwargs) + result = control.test() + self.detailed_logs.append(result) + df = df.append({ + "ControlName": Control.name, + "ControlDescription": control.desc, + "Result": result.result + }, ignore_index=True) + self.print_report(df) + self.save_report(df) + self.print_detailed_logs() + self.send_telemetry_app_insights(df, cluster_name) + self.send_events_log_analytics(df, cluster_name) + return df + + def print_detailed_logs(self): + print("\n\nDetailed Logs") + self.line() + for result in self.detailed_logs: + print(result) + self.line() + + def run(self): + spark_contexts = self.get_db_clusters_config() + print("DevOps Kit (AzSK) for Cluster Security v", __version__) + scan_result_dfs = [] + for (spark_context, cluster_name) in spark_contexts: + self.line() + print("Running cluster scan for cluster: {}".format(cluster_name)) + df = self.run_single_cluster(spark_context, cluster_name) + scan_result_dfs.append(df) + print("AzSK Scan Completed") + self.update_post_scan_meta() + return scan_result_dfs + + def invoke_rest_api(self, end_point, body=None): + databricks_base_url = 
self.get_secret("DatabricksHostDomain") + pat = self.get_secret("AzSK_CA_Scan_Key") + url = databricks_base_url + "/api/2.0/" + end_point + header = { + "Authorization": "Bearer " + pat, + "Content-type": "application/json" + } + try: + if body: + res = requests.get(url=url, headers=header, json=body) + else: + res = requests.get(url=url, headers=header) + response = res.json() + except Exception as e: + print("Error making GET request") + print(e) + response = {} + return response + + def get_db_clusters_config(self): + config = self.invoke_rest_api("clusters/list") + configs = [] + for x in config["clusters"]: + if x["cluster_source"] != "JOB": + configs.append((x["spark_conf"], x["cluster_name"])) + return configs + + def get_secret(self, key): + dbutils = self.kwargs["dbutils"] + return dbutils.secrets.get(scope="AzSK_CA_Secret_Scope", + key=key) + + def get_ik(self): + try: + ik = self.get_secret("AzSK_AppInsight_Key") + except Exception as e: + ik = None + return ik + + def print_report(self, df): + print("{0: <45}{1: <70}{2: <8}".format("Control ID", + "Control Description", + "Status")) + self.line() + for indx, x in df.iterrows(): + print("{0: <45}{1: <70}{2: <8}".format(x["ControlName"], + x["ControlDescription"], + x["Result"])) + self.line() + + def get_base_telemetry(self): + return { + "ResourceType": "Databricks", + "EventName": "Control Scanned", + "ResourceName": self.get_secret("res_name"), + "SubscriptionId": self.get_secret("sid"), + "ResourceGroupName": self.get_secret("rg_name"), + "ControlID": None, + "VerificationResult": None + } + + def send_telemetry_app_insights(self, df, cluster_name): + ik = self.get_ik() + df = df[["ControlName", "Result"]] + if ik is None: + print("Skipping Telemetry") + else: + df_dict = df.to_dict("list") + df_dict = {x: y for x, y in + zip(df_dict["ControlName"], df_dict["Result"])} + tc = TelemetryClient(ik) + for control_name in df_dict: + tm_dict = self.get_base_telemetry() + tm_dict["ControlID"] = control_name + tm_dict["VerificationResult"] = df_dict[control_name] + tm_dict["ClusterName"] = cluster_name + tc.track_event("Scan Results", tm_dict) + tc.flush() + print("Telemetry Sent") + + def save_report(self, df): + self.dbutils = self.kwargs["dbutils"] + self.dbutils.fs.mkdirs("/AzSK_Logs/") + timestamp = str(datetime.now()).replace(" ", "_").replace("-", "").replace(":", "")[:-7] + df.to_csv("/dbfs/AzSK_Logs/AzSK_Scan_Results_{}.csv" + .format(timestamp)) + + def update_post_scan_meta(self): + dbutils = self.kwargs["dbutils"] + dbutils.fs.mkdirs("/AzSK_Meta/") + data = { + "Last Scan": str(datetime.now()), + "Log Analytics Workspace ID": self.get_secret("LAWorkspaceId"), + "Databricks Cluster": self.get_secret("res_name"), + "Databricks Resource Group": self.get_secret("rg_name"), + "Subscription ID": self.get_secret("sid") + } + dbutils.fs.put("/AzSK_Meta/meta.json", str(data), overwrite=True) + + def send_events_log_analytics(self, df, cluster_name): + oms_workspace_id = self.get_secret("LAWorkspaceId") + shared_key = self.get_secret("LASharedSecret") + la_client = LogAnalyticsClient(oms_workspace_id, shared_key) + df_dict = df.to_dict("list") + df_dict = {x: y for x, y in + zip(df_dict["ControlName"], df_dict["Result"])} + tm_list = [] + for control_name in df_dict: + tm_dict = self.get_base_telemetry() + tm_dict["ControlID"] = control_name + tm_dict["VerificationResult"] = df_dict[control_name] + tm_dict["ClusterName"] = cluster_name + tm_list.append(tm_dict) + la_client.post_data(tm_list, "AzSKInCluster") + + def 
get_recommendations(self):
+        for control in self.controls:
+            print("Control Name:", control.name)
+            print("Control Description:", control.desc)
+            print("Recommendation:", control.recommendation)
+            self.line()
+        print("Note: Controls marked as (+) will need spark.authenticate.password to be set. You can choose the "
+              "password of your choice.")
+
+
+class HDISparkControlTester:
+    def __init__(self, spark_context, **kwargs):
+        self.spark_context = spark_context
+        self.kwargs = kwargs
+        self.context = "hdinsight"
+        self.controls = [DiskEncryptionEnabled,
+                         AuthenticationEnabled, RPCEnabled,
+                         EnableSASLEncryption,
+                         SASLAlwaysEncrypt, StorageEncryptionKeySize]
+        self.detailed_logs = []
+
+    def line(self):
+        print("-" * 123)
+
+    def run(self):
+        print("DevOps Kit (AzSK) for Cluster Security v", __version__)
+        self.line()
+        df = pd.DataFrame(columns=["ControlName",
+                                   "ControlDescription",
+                                   "Result"])
+
+        for Control in self.controls:
+            control = Control(self.spark_context, **self.kwargs)
+            result = control.test()
+            self.detailed_logs.append(result)
+            df = df.append({
+                "ControlName": Control.name,
+                "ControlDescription": control.desc,
+                "Result": result.result
+            }, ignore_index=True)
+        self.print_report(df)
+        self.print_detailed_logs()
+        self.save_report(df)
+        self.send_telemetry(df)
+        return df
+
+    def print_detailed_logs(self):
+        print("\n\nDetailed Logs")
+        self.line()
+        for result in self.detailed_logs:
+            print(result)
+            self.line()
+
+    def get_ik(self):
+        # An empty string means App Insights telemetry is skipped.
+        return self.kwargs["app_insight_key"]
+
+    def print_report(self, df):
+        print("{0: <45}{1: <70}{2: <8}".format("Control ID", "Control Description", "Status"))
+        self.line()
+        for indx, x in df.iterrows():
+            print("{0: <45}{1: <70}{2: <8}".format(x["ControlName"], x["ControlDescription"], x["Result"]))
+        self.line()
+
+    def send_telemetry(self, df):
+        ik = self.get_ik()
+        df = df[["ControlName", "Result"]]
+        if ik == "":
+            print("Skipping Telemetry")
+        else:
+            df_dict = df.to_dict("list")
+            df_dict = {x: y for x, y in zip(df_dict["ControlName"], df_dict["Result"])}
+            tc = TelemetryClient(ik)
+            tc.track_event("Scan Results", df_dict)
+            tc.flush()
+            print("Telemetry Sent to Application Insights")
+
+    def save_report(self, df):
+        current_scan_time = str(datetime.now())
+        gca_metadata = {
+            "CA Version": [__version__],
+            "Last Scan": [current_scan_time],
+            "App Insight Key": [self.get_ik()],
+            "Subscription Id": [self.kwargs.get("sid")],
+            "Resource Group Name": [self.kwargs.get("rg_name")],
+            "Resource Name": [self.kwargs.get("res_name")],
+            "CA Notebook Path": ["/PySpark/AzSK_CA_Note"],
+            "Metadata Store Path": ["/AzSK_Metadata/"],
+            "Scan Logs Store Path": ["/AzSK_Logs/"]
+        }
+        current_scan_time = current_scan_time.replace(" ", "-").replace(":", "-")
+        gca_metadata_df = pd.DataFrame.from_dict(gca_metadata)
+        scanlog_df = pd.DataFrame(df)
+        spark_metadata_rdd = self.spark_context.createDataFrame(gca_metadata_df)
+        spark_metadata_rdd.write.json(HDI_METADATA_WRITE_PATH.format(current_scan_time))
+        scanlog_rdd = self.spark_context.createDataFrame(scanlog_df)
+        scanlog_rdd.write.json(HDI_SCANLOG_WRITE_PATH.format(current_scan_time))
+
+    def get_recommendations(self):
+        for control in self.controls:
+            print("Control Name:", control.name)
+            print("Control Description:", control.desc)
+            print("Recommendation:", control.recommendation)
+            self.line()
+        print("Note: Controls marked as (+) will need spark.authenticate.password to be set. 
You can choose the " + "password of your choice.") + + +def get_databricks_security_scan_status(**kwargs): + global sct + sct = DBSparkControlTester(**kwargs) + sct.run() + + +def get_hdinsight_security_scan_status(spark_context, **kwargs): + global sct + sct = HDISparkControlTester(spark_context, **kwargs) + sct.run() + return sct + + +def get_spark_recommendations(): + global sct + if sct is None: + print("No context found, please run `get_cluster_security_scan_status` first to instantiate the tester") + else: + sct.get_recommendations() + + +class AKSControlTester: + def __init__(self): + self.context = "kubernetes" + self.controls = [CheckContainerRunAsNonRoot, + CheckContainerPrivilegeEscalation, CheckContainerPrivilegeMode, + CheckInactiveServiceAccounts, + CheckClusterManagedIdentity, + CheckContainerReadOnlyRootFilesystem, + CheckDefaultSvcRoleBinding, + CheckDefaultNamespaceResources, + CheckResourcesWithSecrets, + CheckKubernetesVersion, + CheckExternalServices, + CheckMountedImages] + + def run(self): + print("DevOps Kit (AzSK) for Cluster Security v", __version__) + self.line() + df = pd.DataFrame(columns=["ControlName", + "ControlDescription", + "Result"]) + detailed_logs_list = [] + for Control in self.controls: + control = Control() + result = control.test() + control.detailed_logs['control_id'] = Control.name + detailed_logs_list.append(control.detailed_logs) + result_item = { + "ControlName": Control.name, + "ControlDescription": control.desc, + "Result": result + } + df = df.append(result_item, ignore_index=True) + control.send_telemetry("AzSK AKS Control Scanned", result_item) + + self.print_report(df) + self.print_detailed_logs(detailed_logs_list) + self.save_report(df) + + def line(self): + print("-" * 138) + + def send_events_log_analytics(self, df, cluster_name): + oms_workspace_id = self.get_secret("LAWorkspaceId") + shared_key = self.get_secret("LASharedSecret") + la_client = LogAnalyticsClient(oms_workspace_id, shared_key) + df_dict = df.to_dict("list") + df_dict = {x: y for x, y in + zip(df_dict["ControlName"], df_dict["Result"])} + tm_list = [] + for control_name in df_dict: + tm_dict = self.get_base_telemetry() + tm_dict["ControlID"] = control_name + tm_dict["VerificationResult"] = df_dict[control_name] + tm_dict["ClusterName"] = cluster_name + tm_list.append(tm_dict) + la_client.post_data(tm_list, "AzSKInCluster") + + def save_report(self, df): + pass + + def print_report(self, df): + print(" {0: <40}{1: <85}{2: <10}".format("Control ID", "Control Description", "Status")) + self.line() + for indx, x in df.iterrows(): + print("{0: <3}{1: <40}{2: <85}{3: <10}".format(indx + 1, x["ControlName"], x["ControlDescription"], + x["Result"])) + self.line() + + def print_detailed_logs(self, detailed_logs_ls): + print("\n") + self.line() + print("Detailed Logs:") + self.line() + detailed_log_printed = False + for detailed_logs_item in detailed_logs_ls: + if len(detailed_logs_item['non_compliant_containers']) > 0: + detailed_log_printed = True + df = pd.DataFrame(detailed_logs_item['non_compliant_containers']) + print("{0} : {1}".format(detailed_logs_item['control_id'], detailed_logs_item['desc'])) + print(" {0: <25}{1: <50}{2: <25}".format("Namespace", "Pod", "Container")) + for indx, x in df.iterrows(): + print("{0: <3}{1: <25}{2: <50}{3: <25}".format(indx + 1, x["namespace"], x["pod_name"], + x["container"])) + self.line() + elif len(detailed_logs_item['service_accounts']) > 0: + detailed_log_printed = True + print("{0} : {1}".format(detailed_logs_item['control_id'], 
detailed_logs_item['desc'])) + df = pd.DataFrame(list(detailed_logs_item['service_accounts']), + columns=['ServiceAccountName', 'NameSpace']) + print(" {0: <25}{1: <50}".format("Namespace", "ServiceAccount")) + for indx, x in df.iterrows(): + print("{0: <3}{1: <25}{2: <50}".format(indx + 1, x[1], x[0])) + self.line() + elif len(detailed_logs_item['non_compliant_pods']) > 0: + detailed_log_printed = True + print("{0} : {1}".format(detailed_logs_item['control_id'], detailed_logs_item['desc'])) + for pod in detailed_logs_item['non_compliant_pods']: + print(pod) + self.line() + elif len(detailed_logs_item['pods_with_secrets']) > 0: + detailed_log_printed = True + print("{0} : {1}".format(detailed_logs_item['control_id'], detailed_logs_item['desc'])) + df = pd.DataFrame(list(detailed_logs_item['pods_with_secrets'])) + print(" {0: <25}{1: <50}".format("Namespace", "Pod")) + for indx, x in df.iterrows(): + print("{0: <3}{1: <25}{2: <50}".format(indx + 1, x[0], x[1])) + self.line() + elif len(detailed_logs_item['container_images']) > 0: + detailed_log_printed = True + print("{0} : {1}".format(detailed_logs_item['control_id'], detailed_logs_item['desc'])) + df = pd.DataFrame(list(detailed_logs_item['container_images'])) + for indx, x in df.iterrows(): + print("{0: <3}{1: <25}".format(indx + 1, x[0])) + self.line() + elif len(detailed_logs_item['non_compliant_services']) > 0: + detailed_log_printed = True + print("{0} : {1}".format(detailed_logs_item['control_id'], detailed_logs_item['desc'])) + df = pd.DataFrame(list(detailed_logs_item['non_compliant_services']), + columns=['NameSpace','ServiceName']) + print(" {0: <25}{1: <50}".format("Namespace", "Service Name")) + for indx, x in df.iterrows(): + print("{0: <3}{1: <25}{2: <50}".format(indx + 1, x[0], x[1])) + self.line() + if (not detailed_log_printed): + print("No detailed logs to show.") + self.line() + + +def run_aks_cluster_scan(): + AKSControlTester().run() diff --git a/src/AzSKPy/azskpy/sparkcontrols.py b/src/AzSKPy/azskpy/sparkcontrols.py new file mode 100644 index 000000000..f7e75528e --- /dev/null +++ b/src/AzSKPy/azskpy/sparkcontrols.py @@ -0,0 +1,554 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License in the project root for +# license information. +# ------------------------------------------------------------------------------ +import json + +import requests +import pandas as pd +import numpy as np +from datetime import datetime +from .constants import * +from .utils import fail_with_manual, TestResponse + + +class BaseControlTest: + def __init__(self, spark_context, **runtime_params): + self.spark = spark_context + self.runtime_params = runtime_params + # spark_context will be dictionary, if the environment is + # Databricks. 
Else it will be an instance of spark + if isinstance(spark_context, dict): + self.__dictconfig = spark_context + else: + self.__config = self.spark.sparkContext.getConf().getAll() + self.__dictconfig = {} + for setting in self.__config: + self.__dictconfig[setting[0]] = setting[1] + + @property + def config(self): + return self.__dictconfig + + def fail_by_default_check(self, spark_setting, desired_value="true"): + if spark_setting not in self.config: + # name will be set in the inherited class + return TestResponse(self.name, spark_setting + " in spark settings", + spark_setting + " was not set", FAILED) + if self.config[spark_setting] != desired_value: + return TestResponse(self.name, spark_setting + " in spark settings", + spark_setting + " was not set", FAILED) + else: + return TestResponse(self.name, spark_setting + " in spark settings", + spark_setting + " was set as expected", PASSED) + + def pass_by_default_check(self, spark_setting, desired_value="true"): + if spark_setting not in self.config: + return TestResponse(self.name, spark_setting + " in spark settings", + spark_setting + " was set as expected", PASSED) + if self.config[spark_setting] != desired_value: + return TestResponse(self.name, spark_setting + " in spark settings", + spark_setting + " was not set", FAILED) + else: + return TestResponse(self.name, spark_setting + " in spark settings", + spark_setting + " was set as expected", PASSED) + + def set_credentials(self, uname, password): + pass + + def get_secret(self, key): + dbutils = self.runtime_params["dbutils"] + return dbutils.secrets.get(scope="AzSK_CA_Secret_Scope", + key=key) + + def invoke_rest_api(self, end_point, body=None): + databricks_base_url = self.get_secret("DatabricksHostDomain") + pat = self.get_secret("AzSK_CA_Scan_Key") + url = databricks_base_url + "/api/2.0/" + end_point + header = { + "Authorization": "Bearer " + pat, + "Content-type": "application/json" + } + try: + if body: + res = requests.get(url=url, headers=header, json=body) + else: + res = requests.get(url=url, headers=header) + response = res.json() + except Exception as e: + print("Error making GET request") + print(e) + response = {} + return response + + +class CheckNumberOfAdminsHDI(BaseControlTest): + # TODO + name = "" + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + self.__dictconfig = None + self.desc = "Number of admins should not be over 2" + + @fail_with_manual + def test(self): + adminCount = 0 + for item in self.__dictconfig["items"]: + for priv in item["privileges"]: + if priv["PrivilegeInfo"][ + "permission_name"] == "AMBARI.ADMINISTRATOR": + adminCount += 1 + if adminCount > 2: + return FAILED + return PASSED + + +class DiskEncryptionEnabled(BaseControlTest): + name = "Enable_Disk_Encryption" + recommendation = "Set `spark.io.encryption.enabled` to `true` in the Spark configuration" + desc = "Local disk storage encryption should be enabled" + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + self.name = "Enable_Disk_Encryption" + + @fail_with_manual + def test(self): + return self.fail_by_default_check("spark.io.encryption.enabled") + + +class AuthenticationEnabled(BaseControlTest): + name = "Enable_Internal_Authentication" + recommendation = "Set `spark.authenticate` to `true` in the Spark configuration (+)" + desc = "Checks Spark internal connection authentication" + + def __init__(self, spark_context, **runtime_params): + 
super().__init__(spark_context, **runtime_params) + self.name = "Enable_Internal_Authentication" + + @fail_with_manual + def test(self): + return self.fail_by_default_check("spark.authenticate") + + +class RPCEnabled(BaseControlTest): + name = "Enable_RPC" + recommendation = "Set `spark.network.crypto.enabled` to `true` in the Spark configuration" + desc = "Enable AES-based RPC encryption" + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + self.name = "Enable_RPC" + + def test(self): + return self.fail_by_default_check("spark.network.crypto.enabled") + + +class EnableSASLEncryption(BaseControlTest): + name = "Enable_SASL_Encryption" + recommendation = "Set `spark.authenticate.enableSaslEncryption` to `true` in the Spark configuration (+)" + desc = "Enable SASL-based encrypted communication." + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + + @fail_with_manual + def test(self): + return self.fail_by_default_check( + "spark.authenticate.enableSaslEncryption") + + +class SASLAlwaysEncrypt(BaseControlTest): + name = "Enable_Always_Encrypt_In_SASL" + recommendation = "Set `spark.network.sasl.serverAlwaysEncrypt` to `true` in the Spark configuration" + desc = ("Disable unencrypted connections for ports using SASL" + " authentication") + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + + @fail_with_manual + def test(self): + return self.fail_by_default_check( + "spark.network.sasl.serverAlwaysEncrypt") + + +class StorageEncryptionKeySize(BaseControlTest): + name = "Use_Strong_Encryption_Keysize" + recommendation = "Set `spark.io.encryption.keySizeBits` to `256` in the Spark configuration" + desc = "256 bit encryption is recommended" + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + + @fail_with_manual + def test(self): + return self.fail_by_default_check("spark.io.encryption.keySizeBits", + "256") + + +class WASBSProtocol(BaseControlTest): + name = "Enable_WASBS" + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + self.desc = ("SSL supported for WASB") + + @property + def config(self): + return "tomato" + + def get_notebooks_from_cluster(self): + return "tomato" + + @fail_with_manual + def test(self): + for notebook_content in self.get_notebooks_from_cluster(): + if "wasb://" in notebook_content: + return FAILED + return PASSED + + +class SSLEnabled(BaseControlTest): + name = "Enable_SSL" + recommendation = "Set `spark.ssl.enabled` to `true` in the Spark configuration" + desc = "SSL should be enabled" + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + + @fail_with_manual + def test(self): + return self.fail_by_default_check("spark.ssl.enabled") + + +class SSLKeyPassword(BaseControlTest): + name = "Dont_Use_Plaintext_Key_Pwd" + recommendation = "Do not set `spark.ssl.keyPassword` in the Spark configuration" + desc = ("Password to private key in keystore shouldn't be" + " stored in plaintext") + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + + @fail_with_manual + def test(self): + if "spark.ssl.keyPassword" in self.config: + return FAILED + else: + return PASSED + + +class SSLKeyStorePassword(BaseControlTest): + name = "Dont_Use_Plaintext_KeyStore_Pwd" + recommendation = "Do not set 
`spark.ssl.keyStorePassword` in the Spark configuration" + desc = ("Password to key store should not be" + " stored in plaintext") + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + + @fail_with_manual + def test(self): + if "spark.ssl.keyStorePassword" in self.config: + return FAILED + else: + return PASSED + + +class SSLTrustedStorePassword(BaseControlTest): + name = "Dont_Use_Plaintext_TrustedStore_Pwd" + recommendation = " Do not set `spark.ssl.trustStorePassword` in the Spark configuration" + desc = ("Password to the trusted store should not be" + " stored in plaintext") + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + + @fail_with_manual + def test(self): + if "spark.ssl.trustStorePassword" in self.config: + return FAILED + else: + return PASSED + + +class XSSProtectionEnabled(BaseControlTest): + name = "Enable_HTTP_XSS_Protection_Header" + recommendation = "Set `spark.ui.xXssProtection` to `1; mode=block` in the Spark configuration" + desc = ("HTTP X-XSS-Protection response header" + " should be set") + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + + @fail_with_manual + def test(self): + return self.pass_by_default_check("spark.ui.xXssProtection", + "1; mode=block") + + +class CheckNumberOfAdminsDB(BaseControlTest): + name = "Limit_Workspace_Admin_Count" + recommendation = "Limit the number of admins to less than 5 in the Databricks Admin Settings" + desc = ("Number of admins should be less than" + " or equal to 5") + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + body = {"group_name": "admins"} + self.data = self.invoke_rest_api("groups/list-members", body) + + @fail_with_manual + def test(self): + admins = self.data["members"] + expected_response = "Number of admins less than 5" + pass_response = "Number of admins are less than 5" + fail_response = "Number of admins are over 5. List:\n" + for i, admin in enumerate(admins): + fail_response += "\n\t{}. {}".format(i + 1, admin["user_name"]) + if len(admins) <= 5: + return TestResponse(self.name, expected_response, pass_response, PASSED) + else: + return TestResponse(self.name, expected_response, fail_response, FAILED) + + +class CheckGuestAdmin(BaseControlTest): + name = "Prohibit_Guest_Account_Admin_Access" + recommendation = "Disable administrator access to non microsoft email users" + desc = "Disable admin access to guests" + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + body = {"group_name": "admins"} + self.data = self.invoke_rest_api("groups/list-members", body) + + @fail_with_manual + def test(self): + admins = self.data['members'] + non_ms_accounts = [] + expected_response = "No non-MS accounts should have administrator privileges" + for admin in admins: + user_name = admin['user_name'] + domain = user_name.split("@")[1] + if domain != "microsoft.com": + non_ms_accounts.append(user_name) + if non_ms_accounts: + fail_response = "Following non-MS accounts have administrator privileges:" + for i, non_ms_account in enumerate(non_ms_accounts): + fail_response += "\n\t{}. 
{}".format(i + 1, non_ms_account) + return TestResponse(self.name, expected_response, fail_response, FAILED) + else: + pass_response = "No non-MS accounts have administrator privileges" + return TestResponse(self.name, expected_response, pass_response, PASSED) + + +class SameKeyvaultReference(BaseControlTest): + name = "Use_Independent_Keyvault_Per_Scope" + recommendation = "Independent Keyvaults should be used for secrets" + desc = ("Same Keyvault should not be referenced by multiple" + " secret scopes") + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + self.data = self.invoke_rest_api("secrets/scopes/list") + + @fail_with_manual + def test(self): + expected_response = self.desc + pass_response = "Same Keyvault not referenced by multiple scopes" + fail_response = "Found Keyvault with multiple references:" + secretScopeLists = self.data['scopes'] + # todo: optimize this. There's no need of pandas to filter things + keyVaultBackedSecretScope = list( + filter(lambda x: x['backend_type'] == 'AZURE_KEYVAULT', + secretScopeLists)) + if len(keyVaultBackedSecretScope) == 0: + return TestResponse(self.name, expected_response, pass_response, PASSED) + SummarizedList = pd.DataFrame([{'KeyVault_ResourceId': + item['keyvault_metadata']['resource_id'], + 'ScopeName': item['name']} for item in + keyVaultBackedSecretScope]).groupby( + 'KeyVault_ResourceId').agg(np.size) + KeyVaultWithManyReference = SummarizedList[SummarizedList['ScopeName'] > 1] + if KeyVaultWithManyReference.empty: + return TestResponse(self.name, expected_response, pass_response, PASSED) + else: + for idx, (row, _) in enumerate(KeyVaultWithManyReference.iterrows()): + fail_response += "\n\t{}. {}".format(idx + 1, row.split("/")[-1]) + return TestResponse(self.name, expected_response, fail_response, FAILED) + + +class AccessTokenExpiry(BaseControlTest): + name = "Keep_Minimal_Token_Validity" + recommendation = "Personal Access Token (PAT) should have minimum validity" + desc = "Use minimum validity token for PAT" + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + self.data = self.invoke_rest_api("token/list") + + def test(self): + expected_response = "PAT token should be have minimum validity (<90 days)" + pass_response = "PAT tokens have minimum validity" + token_lists = self.data['token_infos'] + failed = False + infinite_pat = list( + filter(lambda x: x['expiry_time'] == -1, token_lists)) + finite_pat = list( + filter(lambda x: x['expiry_time'] != -1, token_lists)) + long_pat = list(filter(lambda x: + (datetime.utcfromtimestamp(x['expiry_time'] / 1000) + - datetime.utcfromtimestamp(x['creation_time'] / 1000)).days > 90, + finite_pat)) + if infinite_pat: + fail_response = "PAT token with indefinite validity:" + failed = True + for i, x in enumerate(infinite_pat): + fail_response += "\n\t{}. {}".format(i + 1, x["comment"]) + if long_pat: + fail_response = "PAT token with > 90 day validity" + failed = True + for i, x in enumerate(long_pat): + fail_response += "\n\t{}. 
{}".format(i + 1, x["comment"]) + if failed: + return TestResponse(self.name, expected_response, fail_response, FAILED) + else: + return TestResponse(self.name, expected_response, pass_response, PASSED) + + +class NonKeyvaultBackend(BaseControlTest): + name = "Use_KeyVault_Backed_Secret_Scope" + recommendation = "Use only Keyvault backed secrets" + desc = "Use Azure Keyvault backed secret scope to hold secrets" + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + self.data = self.invoke_rest_api("secrets/scopes/list") + + @fail_with_manual + def test(self): + secretScopeLists = self.data['scopes'] + expected_response = "All secrets should be backed by Azure KeyVault" + pass_response = "All secrets are backed by Azure KeyVault" + fail_response = "Found following scopes in non-Azure KeyVault backend\n" + DataBricksBackedSecretScope = list( + filter(lambda x: x['backend_type'] != 'AZURE_KEYVAULT', + secretScopeLists)) + if DataBricksBackedSecretScope: + for i, x in enumerate(DataBricksBackedSecretScope): + fail_response += "\n\t{}. {}".format(i + 1, x["name"]) + return TestResponse(self.name, expected_response, fail_response, FAILED) + else: + return TestResponse(self.name, expected_response, pass_response, PASSED) + + +class ExternalLibsInstalled(BaseControlTest): + name = "External_Libs_Installed" + recommendation = "Avoid using external libraries from the internet" + desc = recommendation + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + self.data = self.invoke_rest_api("libraries/all-cluster-statuses") + + def test(self): + expected_response = "External libraries should be absent or verified" + pass_response = "No external libraries are installed" + fail_response = "Following external libraries on cluster:\n" + data = self.data + fails = False + if "statuses" in data and len(data["statuses"]) > 0: + try: + for i, x in enumerate(data["statuses"][0]["library_statuses"]): + # todo: this currently only considers packages installed from PyPi should be extended + fail_response += "\n\t{}. {}".format(i + 1, x["library"]["pypi"]["package"]) + fails = True + except: + pass + + if fails: + return TestResponse(self.name, expected_response, fail_response, "Verify") + else: + return TestResponse(self.name, expected_response, pass_response, PASSED) + + +class InitScriptPresent(BaseControlTest): + name = "Init_Scripts_Present" + recommendation = "Where present, init scripts should be verified" + desc = recommendation + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + self.data = self.invoke_rest_api("clusters/list") + + def test(self): + expected_response = "Init scripts should be absent or verified" + pass_response = "Init scripts absent" + fail_response = "Init scripts at the following location:\n" + for x in self.data["clusters"]: + if x["cluster_source"] != "JOB": + if "init_scripts" in x: + for i, filepath in enumerate(x["init_scripts"]): + fail_response += "\n\t{}. 
{}".format(i + 1, filepath["dbfs"]["destination"]) + return TestResponse(self.name, expected_response, fail_response, MANUAL) + return TestResponse(self.name, expected_response, pass_response, PASSED) + + +class MountPointPresent(BaseControlTest): + name = "Mount_Points_Present" + recommendation = "Where present, mount points should be verified" + desc = recommendation + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + dbutils = self.runtime_params["dbutils"] + self.data = dbutils.fs.mounts() + + def test(self): + expected_response = "Mount points should be absent or verified" + pass_response = "Mount points absent" + fail_response = "Unsafe mount points:\n" + safe_mounts = {"DatabricksRoot", "databricks-datasets", "databricks-results"} + ctr = 1 + verify = False + for mount_point in self.data: + if mount_point.source not in safe_mounts: + fail_response += "\n\t{}. Location: {} Source: {}".format(ctr, mount_point.mountPoint, mount_point.source) + verify = True + ctr += 1 + if verify: + return TestResponse(self.name, expected_response, fail_response, MANUAL) + else: + return TestResponse(self.name, expected_response, pass_response, PASSED) + + +class TokenNearExpiry(BaseControlTest): + name = "Token_Near_Expiry" + recommendation = "Expiry for PAT tokens should be greater than 30 days" + desc = recommendation + + def __init__(self, spark_context, **runtime_params): + super().__init__(spark_context, **runtime_params) + self.data = self.invoke_rest_api("token/list") + + def test(self): + expected_response = "PAT tokens expiry should be >30 days" + pass_response = "PAT tokens are far from expiry" + fail_response = "Following PAT tokens near expiry (<30 days):\n" + assert "token_infos" in self.data + now = datetime.now() + ctr = 1 + fail = False + for tokens in self.data["token_infos"]: + expiry = datetime.fromtimestamp(tokens["expiry_time"] // 1000) + howfar = (expiry - now) + if howfar.days <= 30: + fail_response += "\n\t{}. {}".format(ctr, tokens["comment"]) + fail = True + if fail: + return TestResponse(self.name, expected_response, fail_response, FAILED) + else: + return TestResponse(self.name, expected_response, pass_response, PASSED) \ No newline at end of file diff --git a/src/AzSKPy/azskpy/utils.py b/src/AzSKPy/azskpy/utils.py new file mode 100644 index 000000000..bd0bc096f --- /dev/null +++ b/src/AzSKPy/azskpy/utils.py @@ -0,0 +1,42 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License in the project root for +# license information. 
+# ------------------------------------------------------------------------------
+from functools import wraps
+from .constants import *
+
+
+def fail_with_manual(func):
+    """If a control check raises, degrade to a manual/'Verify' result instead of aborting the scan."""
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except Exception as e:
+            print("Exception in executing {} with"
+                  " args: {} kwargs: {}".format(func.__name__, args, kwargs))
+            print(e)
+            # TestResponse's first positional parameter is control_name, so the
+            # result must be passed by keyword.
+            return TestResponse(control_name=func.__name__, result=MANUAL)
+
+    return wrapper
+
+
+class TestResponse:
+    def __init__(self, control_name=None, expected=None, actual=None, result="Unverified"):
+        self.control_name = control_name
+        self.expected = expected
+        self.actual = actual
+        self.result = result
+
+    def __str__(self):
+        if self.result == PASSED:
+            return """[{}]: {}\nExpected {}\nFound configuration as expected.""".format(self.result,
+                                                                                        self.control_name,
+                                                                                        self.expected)
+        # FAILED, MANUAL and any other result share one format so that __str__
+        # always returns a string (previously it returned None for results other
+        # than Passed/Failed/Verify, which made print() raise a TypeError).
+        return """[{}]: {}\nExpected {}\nFound: {}""".format(self.result,
+                                                             self.control_name,
+                                                             self.expected,
+                                                             self.actual)
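
---

The diff adds four public entry points (`run_aks_cluster_scan`, `get_databricks_security_scan_status`, `get_hdinsight_security_scan_status`, `get_spark_recommendations`) but no usage notes. The sketch below shows how they appear to be invoked, inferred only from the code in this change; the notebook-provided `spark` and `dbutils` handles and all placeholder values are assumptions, not part of this diff.

```python
# Minimal usage sketch (assumptions noted inline).
from azskpy.main import (run_aks_cluster_scan,
                         get_databricks_security_scan_status,
                         get_hdinsight_security_scan_status,
                         get_spark_recommendations)

# AKS: run from a pod inside the cluster (the controls call
# config.load_incluster_config()). SUBSCRIPTION_ID, RG_NAME, RESOURCE_NAME
# and APP_INSIGHT_KEY are read from the pod's environment for telemetry.
run_aks_cluster_scan()

# Databricks: run from a notebook. Secrets (host URL, PAT, workspace ids)
# are read from the "AzSK_CA_Secret_Scope" secret scope via `dbutils`,
# which only exists inside a Databricks notebook.
get_databricks_security_scan_status(dbutils=dbutils)

# HDInsight: pass the SparkSession plus the metadata save_report() expects.
# The identifier values here are illustrative placeholders.
get_hdinsight_security_scan_status(
    spark,                      # notebook-provided SparkSession
    app_insight_key="",         # "" skips App Insights telemetry
    sid="<subscription-id>",
    rg_name="<resource-group>",
    res_name="<cluster-name>")

# After either Spark scan, print per-control remediation guidance.
get_spark_recommendations()
```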