# Securing AI/ML Models With HiddenLayer's AISec Platform

This tutorial describes how to use HiddenLayer's AI Security (AISec) Platform to detect and respond to attacks on machine learning models. It provides examples of using HiddenLayer's Model Scanner, Machine Learning Detection and Response, and Prompt-Analyzer. All detections made by the AISec Platform map to one or more [MITRE ATLAS](https://atlas.mitre.org/) tactics and techniques. All models and data in this demo are open source and available via [Hugging Face 🤗](https://huggingface.co).

### Model Scanner 🎯
HiddenLayer's Model Scanner statically scans and analyzes models to identify threats or vulnerabilities in the model artifact. Best practice is to scan any code or artifact before it is loaded into memory. Much modern tooling can load a model from a third-party repository directly into memory, which opens a threat vector: downloading and loading a model built with malicious intent.
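To make "statically scan" concrete for pickle-based artifacts (plain pickle and joblib files, or the `data.pkl` entry inside a zip-format PyTorch checkpoint): Python's standard `pickletools` module can walk a pickle's opcodes without executing them, which is enough to list every `module.name` the file would import when loaded. This is a minimal sketch of that general technique, not HiddenLayer's scanner; `list_pickle_globals` is a hypothetical helper, and its handling of protocol 4+ `STACK_GLOBAL` is deliberately simplified.

```python
import pickle
import pickletools
from collections import OrderedDict
from typing import List, Tuple

# opcodes that push a str onto the unpickler's stack
STRING_OPS = {"SHORT_BINUNICODE", "BINUNICODE", "BINUNICODE8", "UNICODE"}


def list_pickle_globals(data: bytes) -> List[Tuple[str, str]]:
    """Return the (module, name) pairs a pickle would import, without unpickling it."""
    found: List[Tuple[str, str]] = []
    recent_strings: List[str] = []  # string arguments seen so far, in order
    for opcode, arg, _pos in pickletools.genops(data):
        if opcode.name == "GLOBAL":
            # protocols 0-3: pickletools reports the argument as "module name"
            module, name = arg.split(" ", 1)
            found.append((module, name))
        elif opcode.name == "STACK_GLOBAL":
            # protocol 4+: module and name are normally the two most recently pushed strings.
            # Simplification: strings re-fetched from the memo (BINGET) are not tracked.
            if len(recent_strings) >= 2:
                found.append((recent_strings[-2], recent_strings[-1]))
        elif opcode.name in STRING_OPS:
            recent_strings.append(arg)
    return found


# Usage: parse a pickle built locally; genops only decodes opcodes, it never executes them
example = pickle.dumps(OrderedDict(a=1), protocol=4)
print(list_pickle_globals(example))
```

A production scanner would also resolve memoized strings and inspect what each import is used for; the point here is only that the import list can be recovered without running the file.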

Supported formats:
- Pickle (NumPy, joblib, scikit-learn)
- PyTorch (pickle, zip)
- TensorFlow/Keras (tf, h5, protobuf)
- SafeTensors
- ONNX
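Before scanning, a tool has to decide which format it is looking at. Several of the formats above can be recognized from their leading bytes: HDF5 files (Keras `.h5`) start with a fixed 8-byte signature, zip archives (newer PyTorch checkpoints) start with `PK\x03\x04`, pickles of protocol 2 and above start with `0x80` followed by the protocol number, and SafeTensors files start with an 8-byte little-endian header length followed by a JSON object. The hypothetical `sniff_model_format` below is a rough heuristic built on those assumptions; ONNX and TensorFlow protobuf files have no fixed magic bytes, so it reports them as `unknown`.

```python
import struct

HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"
ZIP_SIGNATURE = b"PK\x03\x04"


def sniff_model_format(head: bytes) -> str:
    """Guess a model file's container format from its first bytes (a heuristic, not a validator)."""
    if head.startswith(HDF5_SIGNATURE):
        return "hdf5"  # e.g. Keras .h5
    if head.startswith(ZIP_SIGNATURE):
        return "zip"  # e.g. zip-format PyTorch checkpoint
    if len(head) >= 9:
        # SafeTensors: 8-byte little-endian header length, then a JSON object
        (header_len,) = struct.unpack("<Q", head[:8])
        if header_len > 0 and head[8:9] == b"{":
            return "safetensors"
    if len(head) >= 2 and head[0] == 0x80 and 2 <= head[1] <= 5:
        return "pickle"  # PROTO opcode followed by protocol 2-5
    return "unknown"  # e.g. ONNX / TensorFlow protobuf, or a protocol 0-1 pickle
```

For a file on disk, passing `f.read(16)` from a file opened in `"rb"` mode is enough for every check above.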

##### Steps 🚀
1. Set up the environment for the demo
    1. Install libraries
    2. Log in and create API credentials for the demo
    3. Initialize the AISec Platform client for the demo
2. Model Scanner Scenarios
    1. Scan LLM 
        - Model - [Microsoft Phi-2](https://huggingface.co/microsoft/phi-2/tree/main)
    2. Scan Production Ready Model 
        - Model - [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32/tree/main)
    3. Scan Non-Production Ready Model (FastAI 3rd Party Tool)
        - Model - [fastai/fastbook_04_mnist_basics](https://huggingface.co/fastai/fastbook_04_mnist_basics/tree/main)
    4. Scan Malicious Model 
        - Model - [ScanMe/Models](https://huggingface.co/ScanMe/Models/tree/main)



##### ❗️ For any issues with the notebook, or feedback on how we could improve the experience, please contact support@hiddenlayer.com

---

# Set up the environment for the demo
1. Install libraries
2. Log in and create API credentials for the demo
3. Initialize the AISec Platform client for the demo

#### Install Libraries

```python
# install libraries needed for the demo (uncomment to run)
#%pip install httpx adversarial-robustness-toolbox scikit-learn numpy numba huggingface_hub requests torch Pillow transformers matplotlib datasets "datasets[vision]" nest-asyncio
```

#### Log in and create API credentials for the demo

## ❗️ Credentials must be created and made available to the notebook to run the demos

Navigate to the [HiddenLayer Console](https://console.hiddenlayer.ai/admin?activeTab=apiKeys) and log in to create new API credentials for this demo. The next cell reads the client ID and client secret from the `HIDDENLAYER_CLIENT_ID` and `HIDDENLAYER_CLIENT_SECRET` environment variables, so set those before starting the notebook (or paste the values into the cell directly).

```python
# read the client credentials for the demo from environment variables
import os
client_id = os.getenv('HIDDENLAYER_CLIENT_ID')
client_secret = os.getenv('HIDDENLAYER_CLIENT_SECRET')
```
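`os.getenv` returns `None` for an unset variable, so a missing credential only surfaces later as a less obvious error when the client is built or first used. A small guard like the following fails fast with a readable message instead. `missing_env_vars` and `check_credentials` are hypothetical helpers, not part of any HiddenLayer SDK; the optional `env` mapping exists only so the check can be exercised without touching the real environment.

```python
import os
from typing import Iterable, List, Mapping, Optional

REQUIRED_VARS = ("HIDDENLAYER_CLIENT_ID", "HIDDENLAYER_CLIENT_SECRET")


def missing_env_vars(names: Iterable[str], env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the names that are unset or empty in `env` (defaults to os.environ)."""
    env = os.environ if env is None else env
    return [name for name in names if not env.get(name)]


def check_credentials() -> None:
    """Raise early, with a readable message, if the demo credentials are not configured."""
    missing = missing_env_vars(REQUIRED_VARS)
    if missing:
        raise RuntimeError("Set these environment variables before running the demo: " + ", ".join(missing))
```

Calling `check_credentials()` at the top of the credentials cell would stop the notebook before any network call is attempted.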

#### Initialize the AISec Platform client for the demo

The following client code drives this demo notebook. If you find any issues or bugs, please contact us at support@hiddenlayer.com.

```python
import asyncio
import base64
import os
import re
from random import randint
from typing import Any, Callable, Optional

import httpx
import huggingface_hub
import nest_asyncio
from huggingface_hub import hf_hub_download

# allow the notebook's already-running event loop to be re-entered (e.g. by asyncio.run)
nest_asyncio.apply()

class AISecClient(object):
    def __init__(self, client_id: str, client_secret: str, base_url: str = "https://api.hiddenlayer.ai", retry_count: int = 3, timeout: Optional[float] = None):
        self.client_id = client_id
        self.client_secret = client_secret
        self.client_auth = httpx.BasicAuth(client_id, client_secret)
        self.base_url = base_url
        self.token_url = "https://auth.hiddenlayer.ai/oauth2/token?grant_type=client_credentials"
        self.retry_count = retry_count
        self.token: Optional[str] = None
        # timeout=None (the default) disables httpx timeouts so large multipart uploads are not cut off
        self.async_client = httpx.AsyncClient(timeout=timeout)

    @staticmethod
    def b64str(b: bytes) -> str:
        """
        Convert bytes to base64 string"""
        return base64.b64encode(b).decode()

    async def _set_token(self) -> None:
        """
        Get an access token from the HiddenLayer API and attach it to every request"""
        err_exc = Exception("Error: Unable to retrieve token")
        resp = await self.async_client.post(self.token_url, auth=self.client_auth)
        if resp.status_code != 200:
            raise err_exc
        content = resp.json()
        if "access_token" not in content:
            raise err_exc
        self.token = content["access_token"]
        self.async_client.headers = {"Authorization": f"Bearer {self.token}"}

    async def _async_request_handler(self, meth: Callable, url: str, **kwargs: Any) -> httpx.Response:
        """
        Handle async requests to HiddenLayer API: refresh the token on 401, retry on 5xx"""
        resp = None
        for i in range(self.retry_count + 1):
            resp = await meth(url, **kwargs)
            if resp.status_code == 401:
                await self._set_token()
                continue
            elif resp.status_code < 500:
                break
            await asyncio.sleep(randint(1, i + 1) / 100)  # nosec
        return resp

    async def create_sensor(self, name: str, tags: dict = None):
        """
        Create a sensor in HiddenLayer"""
        tags = {} if tags is None else tags
        resp = await self._async_request_handler(
            self.async_client.post,
            f"{self.base_url}/api/v2/sensors/create",
            json={"plaintext_name": name, "active": True, "tags": tags}
        )
        return resp.json() if resp.is_success else None

    async def get_sensor_by_name_version(self, name: str, version: int):
        """
        Get a sensor by name and version"""
        sensor = None
        resp = await self._async_request_handler(
            self.async_client.post,
            f"{self.base_url}/api/v2/sensors/query",
            json={"filter": {"plaintext_name": name, "version": version}}
        )
        if resp.is_success:
            content = resp.json()
            if content["total_count"] >= 1:
                sensor = content["results"][0]
        return sensor

    async def get_or_create_sensor(self, name: str, version: int = 1, tags: dict = None):
        """
        Get or create a sensor by name and version"""
        tags = {} if tags is None else tags
        sensor = await self.get_sensor_by_name_version(name, version)
        if sensor is None:
            sensor = await self.create_sensor(name, tags=tags)
        return sensor

    async def _start_multipart_upload(self, sensor_id: str, filesize: int) -> Optional[dict]:
        """
        Start a multipart upload for a sensor"""
        headers = {"X-Content-Length": str(filesize)}
        resp = await self._async_request_handler(
            self.async_client.post,
            f"{self.base_url}/api/v2/sensors/{sensor_id}/upload/begin",
            headers=headers
        )
        return resp.json() if resp.is_success else None

    async def _upload_parts(self, sensor_id: str, multipart: dict, filepath: str) -> None:
        """
        Upload parts for a multipart upload, four parts at a time"""
        chunk = 4
        upload_id = multipart["upload_id"]
        with open(filepath, "rb") as fin:
            for i in range(0, len(multipart["parts"]), chunk):
                upload_tasks = []
                group = multipart["parts"][i:i+chunk]
                for p in group:
                    part_number = p["part_number"]
                    read_amt = p["end_offset"] - p["start_offset"]
                    fin.seek(p["start_offset"])
                    part_data = fin.read(read_amt)
                    t = self._async_request_handler(
                        self.async_client.put,
                        f"{self.base_url}/api/v2/sensors/{sensor_id}/upload/{upload_id}/part/{part_number}", data=part_data
                    )
                    upload_tasks.append(t)
                await asyncio.gather(*upload_tasks)

    async def _complete_multipart_upload(self, sensor_id: str, upload_id: str) -> bool:
        """
        Complete a multipart upload"""
        resp = await self._async_request_handler(
            self.async_client.post,
            f"{self.base_url}/api/v2/sensors/{sensor_id}/upload/{upload_id}/complete",
        )
        return resp.is_success

    async def upload(self, sensor_id: str, filepath: str, verbose: bool = False):
        """
        Upload a model to HiddenLayer"""
        filesize = os.path.getsize(filepath)
        # start multipart upload
        if verbose:
            print(f"Starting upload for {sensor_id}: {filepath}")
        multipart = await self._start_multipart_upload(sensor_id, filesize)
        if multipart is None:
            raise RuntimeError(f"Unable to start upload for {sensor_id}: {filepath}")
        # upload parts
        await self._upload_parts(sensor_id, multipart, filepath)
        # complete multipart upload
        success = await self._complete_multipart_upload(sensor_id, multipart["upload_id"])
        if verbose:
            print(f"Completed upload for {sensor_id}: {filepath}")
        return success

    async def scan_sensor(self, sensor_id: str):
        """
        Scan a sensor in HiddenLayer after upload"""
        # kick off scan for sensor id
        resp = await self._async_request_handler(
            self.async_client.post,
            f"{self.base_url}/api/v2/submit/sensors/{sensor_id}/scan"
        )
        return resp.is_success

    async def upload_and_scan(self, model_name: str, filepath: str, version: int = 1, tags: dict = None, verbose: bool = False):
        """
        Upload and scan a model"""
        tags = {} if tags is None else tags
        sensor = await self.get_or_create_sensor(model_name, version=version, tags=tags)
        if sensor is None:
            raise RuntimeError(f"Unable to get or create a sensor for {model_name}")
        sensor_id = sensor["sensor_id"]
        await self.upload(sensor_id, filepath, verbose=verbose)
        ok = await self.scan_sensor(sensor_id)
        return ok, sensor

    async def get_scan_results(self, sensor_id: str, max_retry_count: int = 160, verbose: bool = False, wait: bool = True):
        """
        Get scan results for a sensor, polling every 5 seconds until the scan is done"""
        retry = 1
        results = None
        while retry < max_retry_count:
            resp = await self._async_request_handler(
                self.async_client.get,
                f"{self.base_url}/api/v2/scan/status/{sensor_id}"
            )
            if not resp.is_success:
                break
            else:
                content = resp.json()
                status = content["status"]
                if not wait or status == "done":
                    results = content
                    break
                else:
                    if verbose:
                        print(f"status, {sensor_id}, {status}")
                    await asyncio.sleep(5)
            retry += 1

        return results

    def pretty_print_scan_results(self, scan_results: dict):
        """
        Pretty print scan results"""
        for sensor_id, scan_result in scan_results.items():
            print(f"{scan_result['sensor']['plaintext_name']} ({sensor_id})")
            print("#"*128)
            if scan_result["scan_results"] is None:
                print("Scan results unavailable\n")
                continue
            verdict = "SAFE" if len(scan_result["scan_results"]["detections"]) == 0 else "UNSAFE"
            print(f"{'Verdict':<25}{verdict:<48}")
            try:
                print(f"{'File':<25}{scan_result['filename']:<48}")
                print(f"{'Type':<25}{scan_result['scan_results']['results']['type']:<48}")
                print(f"{'MD5':<25}{scan_result['scan_results']['results']['md5']:<48}")
                print(f"{'SHA256':<25}{scan_result['scan_results']['results']['sha256']:<48}")
                print(f"{'TLSH':<25}{scan_result['scan_results']['results']['tlsh']:<48}")
                if len(scan_result["scan_results"]["detections"]) > 0:
                    print("Detections: ")
                    print(f"{'Severity':<25}{'Detection':<48}{'Description':<64}")
                    print("-"*128)
                    for d in scan_result["scan_results"]["detections"]:
                        print(f"{d['severity']:<25}{d['message']:<48}{d['description']:<64}")
            except Exception as err:
                print(f"Error printing scan results: {err}")
            print("\n")

    def find_model_files_in_hf_repo(self, repo: str):
        """
        Find model files in a HuggingFace repo"""
        # read .gitattributes
        # create regex patterns
        # match model files and mark others as skipped
        try:
            model_info = huggingface_hub.model_info(repo)
        except Exception:
            return {}

        gitattr = hf_hub_download(repo_id=repo, filename=".gitattributes")
        with open(gitattr, "r") as fin:
            # turn each glob (first token of a non-blank, non-comment line) into an anchored regex
            patterns = [
                re.compile(line.split()[0].replace(".", r"\.").replace("*", ".*") + "$")
                for line in (raw.strip() for raw in fin)
                if line and not line.startswith("#")
            ]

        return {
            s.rfilename: any(p.match(s.rfilename) for p in patterns)
            for s in model_info.siblings
        }

    async def scan_huggingface_repo(self, repo: str, verbose: bool = False, filename: str = None, scan: bool = True):
        """
        Scan a HuggingFace repo"""
        scan_results = {}
        if filename is None:
            model_files = self.find_model_files_in_hf_repo(repo)
        else:
            model_files = {filename: True}

        model_version = 1
        for f, is_model in model_files.items():
            if is_model:
                if verbose:
                    print(f"Downloading {repo}/{f}")
                fn = hf_hub_download(repo_id=repo, filename=f)
                if scan:
                    ok, sensor = await self.upload_and_scan(repo, fn, version=model_version, verbose=verbose, tags={"env": "demo", "source": "huggingface"})
                else:
                    sensor = await self.get_or_create_sensor(repo, version=model_version, tags={"env": "demo", "source": "huggingface"})
                sensor_scan_results = await self.get_scan_results(sensor["sensor_id"], wait=scan)
                scan_results[sensor["sensor_id"]] = {"sensor": sensor, "uploaded": scan, "scan_results": sensor_scan_results, "filename": f}
                model_version += 1
            else:
                if verbose:
                    print(f"Skipping {repo}/{f}")
        return scan_results


client = AISecClient(client_id, client_secret)
```
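Every API call above goes through `_async_request_handler`: a 401 triggers a token refresh and another attempt, any other status below 500 is returned as-is, and 5xx responses are retried after a short random sleep. The standalone sketch below isolates that policy so it can be exercised without network access. It is a simplification under stated assumptions: `send` returns only a status code instead of an `httpx.Response`, the backoff is swapped for full-jitter exponential backoff rather than the notebook's `randint(1, i + 1) / 100`, and `request_with_retry` and `simulate` are hypothetical names.

```python
import asyncio
import random
from typing import Awaitable, Callable, List, Tuple


async def request_with_retry(
    send: Callable[[], Awaitable[int]],
    refresh_token: Callable[[], Awaitable[None]],
    retry_count: int = 3,
    base_delay: float = 0.01,
) -> int:
    """Return the final HTTP status after applying the notebook's retry policy."""
    status = -1
    for attempt in range(retry_count + 1):
        status = await send()
        if status == 401:
            await refresh_token()  # expired or missing token: fetch a new one, then retry
            continue
        if status < 500:
            break  # success or a client error: retrying will not help
        # server error: sleep with full-jitter exponential backoff before the next attempt
        await asyncio.sleep(base_delay * (2 ** attempt) * random.random())
    return status


async def simulate(statuses: List[int], retry_count: int = 3) -> Tuple[int, int]:
    """Hypothetical demo: replay fixed status codes; return (final status, token refreshes)."""
    replay = iter(statuses)
    refreshes = 0

    async def send() -> int:
        return next(replay)

    async def refresh() -> None:
        nonlocal refreshes
        refreshes += 1

    final = await request_with_retry(send, refresh, retry_count=retry_count, base_delay=0.0)
    return final, refreshes
```

For example, `asyncio.run(simulate([401, 200]))` models the very first call the client makes: no token yet, one refresh, then success.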

---

# Model Scanner Scenarios
1. Scan LLM 
    - Model - [Microsoft Phi-2](https://huggingface.co/microsoft/phi-2/tree/main)
2. Scan Production Ready Model 
    - Model - [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32/tree/main)
3. Scan Non-Production Ready Model (FastAI 3rd Party Tool)
    - Model - [fastai/fastbook_04_mnist_basics](https://huggingface.co/fastai/fastbook_04_mnist_basics/tree/main)
4. Scan Malicious Model 
    - Model - [ScanMe/Models](https://huggingface.co/ScanMe/Models/tree/main)

### ❗️ Note: After the models have been uploaded to HiddenLayer, they will show up in the [Model Inventory](https://console.hiddenlayer.ai/model-inventory)



#### Scan LLM ✅

In this section we will scan the [Microsoft Phi-2](https://huggingface.co/microsoft/phi-2/tree/main) LLM. Because the model is roughly 5 GB, downloading, uploading, and scanning it can take 10-15 minutes.

###### ❗️ Note: Because the Phi-2 model consists of multiple artifacts, two artifacts will show in the model overview.
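Which repo files count as model artifacts is decided by `find_model_files_in_hf_repo`, which turns each pattern in the repo's `.gitattributes` into a regex and marks matching files as models. The sketch below expresses the same idea with the standard library's `fnmatch`. It is an approximation: like the notebook's regex, `fnmatch`'s `*` also matches `/`, whereas Git's does not. `parse_gitattributes` and `classify_files` are hypothetical helpers.

```python
from fnmatch import fnmatchcase
from typing import Dict, Iterable, List


def parse_gitattributes(text: str) -> List[str]:
    """Return the path patterns from a .gitattributes file, skipping blank and comment lines."""
    patterns: List[str] = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        patterns.append(line.split()[0])  # the pattern is the first whitespace-separated token
    return patterns


def classify_files(filenames: Iterable[str], patterns: List[str]) -> Dict[str, bool]:
    """Map each repo file to True if any pattern matches it (treated as a model artifact)."""
    return {name: any(fnmatchcase(name, p) for p in patterns) for name in filenames}
```

Under this scheme each matching shard (for example each `.safetensors` file) becomes its own upload, which is consistent with the two Phi-2 entries in the output.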

```python
repo_name = "microsoft/phi-2"
print(f"Scanning the {repo_name} model repo...")
print(f"Model {repo_name} will be available in model-inventory soon...")
scan_results = await client.scan_huggingface_repo(repo_name)
print(f"The {repo_name} scan is complete...\n")
client.pretty_print_scan_results(scan_results)
```

```text
Scanning the microsoft/phi-2 model repo...
Model microsoft/phi-2 will be available in model-inventory soon...
The microsoft/phi-2 scan is complete...

microsoft/phi-2 (e25fb73b-5974-4884-9350-080d3d53a590)
################################################################################################################################
Verdict                  SAFE
File                     model-00001-of-00002.safetensors
Type                     safetensors
MD5                      2568cd93356c436b9c828a240af37a27
SHA256                   7fbcdefa72edf7527bf5da40535b57d9f5bd3d16829b94a9d25d2b457df62e84
TLSH                     ceba23e3b1e1b68f8015dc6e4b19fa3419ebcd275c43e590b188868fd83da615f58fa0


microsoft/phi-2 (c7a9d9a8-0419-436b-8064-efa8ff46824d)
###############################################################################################################################
```
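The Phi-2 artifacts above are SafeTensors files. That format consists of an 8-byte little-endian header length, a JSON header giving each tensor's dtype, shape, and byte offsets, and then raw tensor bytes, with no mechanism for embedding code. The sketch below reads only the header, assuming that layout; `make_demo_safetensors` is a hypothetical helper that builds a tiny in-memory file so the example runs offline.

```python
import io
import json
import struct
from typing import Any, BinaryIO, Dict


def read_safetensors_header(f: BinaryIO) -> Dict[str, Any]:
    """Read only the JSON header of a SafeTensors stream; tensor bytes are never touched."""
    prefix = f.read(8)
    if len(prefix) != 8:
        raise ValueError("file too short to be SafeTensors")
    (header_len,) = struct.unpack("<Q", prefix)  # little-endian unsigned 64-bit length
    raw = f.read(header_len)
    if len(raw) != header_len:
        raise ValueError("header length exceeds file size")
    return json.loads(raw)


def make_demo_safetensors() -> bytes:
    """Hypothetical helper: build a tiny SafeTensors file holding one float32 tensor of shape [2]."""
    header = {"weight": {"dtype": "F32", "shape": [2], "data_offsets": [0, 8]}}
    raw = json.dumps(header).encode("utf-8")
    return struct.pack("<Q", len(raw)) + raw + struct.pack("<2f", 1.0, 2.0)


header = read_safetensors_header(io.BytesIO(make_demo_safetensors()))
# the optional "__metadata__" entry is not a tensor, so skip it when summarizing
print({name: (info["dtype"], info["shape"]) for name, info in header.items() if name != "__metadata__"})
```

On a real shard, passing a file opened in `"rb"` mode reads just the header instead of the whole multi-gigabyte file.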

#### Scan Production Ready Model ✅

This section will scan a specific model file from a model repository: the `pytorch_model.bin` file from the [OpenAI Clip-Vit-Base-Patch32](https://huggingface.co/openai/clip-vit-base-patch32/tree/main) model repo.

Notes:
- This model has only the pickle modules expected of a model that is ready for production
- Only the modules needed to reload the tensors are found
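One way to read "only the modules needed to reload the tensors" is as an allowlist: a plain PyTorch state dict typically imports little beyond `collections.OrderedDict`, helpers such as `torch._utils._rebuild_tensor_v2`, and storage classes such as `torch.FloatStorage`. The allowlist below is illustrative only and is not HiddenLayer's policy; `unexpected_imports` is a hypothetical helper that flags any other `(module, name)` pair extracted from a pickle. PyTorch's own `torch.load(..., weights_only=True)` applies a similar restriction at load time.

```python
from typing import AbstractSet, Iterable, List, Tuple

Import = Tuple[str, str]

# Illustrative allowlist for a plain PyTorch state dict; NOT HiddenLayer's actual policy.
TORCH_STATE_DICT_ALLOWLIST: AbstractSet[Import] = frozenset({
    ("collections", "OrderedDict"),
    ("torch._utils", "_rebuild_tensor_v2"),
    ("torch._utils", "_rebuild_parameter"),
    ("torch", "FloatStorage"),
    ("torch", "HalfStorage"),
    ("torch", "LongStorage"),
})


def unexpected_imports(
    imports: Iterable[Import],
    allowlist: AbstractSet[Import] = TORCH_STATE_DICT_ALLOWLIST,
) -> List[Import]:
    """Return imports outside the allowlist, in first-seen order, without duplicates."""
    flagged: List[Import] = []
    for imp in imports:
        if imp not in allowlist and imp not in flagged:
            flagged.append(imp)
    return flagged
```

An empty result corresponds to the "expected modules only" situation described in the notes; anything returned deserves a closer look.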

```python
repo_name = "openai/clip-vit-base-patch32"
print(f"Scanning the {repo_name} model repo...")
print(f"Model {repo_name} will be available in model-inventory soon...")
scan_results = await client.scan_huggingface_repo(repo_name, filename="pytorch_model.bin")
print(f"The {repo_name} scan is complete...\n")
client.pretty_print_scan_results(scan_results)
```

```text
Scanning the openai/clip-vit-base-patch32 model repo...
Model openai/clip-vit-base-patch32 will be available in model-inventory soon...
The openai/clip-vit-base-patch32 scan is complete...

openai/clip-vit-base-patch32 (1c565544-22da-45e2-b919-c3fe1b76c510)
################################################################################################################################
Verdict                  SAFE
File                     pytorch_model.bin
Type                     pytorch
MD5                      47767ea81d24718fcc0c8923607792a7
SHA256                   a63082132ba4f97a80bea76823f544493bffa8082296d62d71581a4feff1576f
TLSH                     7a597481e1068fd0bca17b7bb8bf5d4e8edbca14d1bb10509726517da35b1d02fa3268
```




#### Scan Non-Production Ready Model ✅

This section will scan a specific model file from a model repository: the `model.pkl` file from the [fastai/fastbook_04_mnist_basics](https://huggingface.co/fastai/fastbook_04_mnist_basics/tree/main) model repo.

Notes:
- This model has a lot of pickle imports, making it riskier to introduce into a production environment
  - Some of these imports have legitimate uses, such as debugging, remote telemetry, and monitoring
- Finding `global.__getattr__` adds risk because it is an unsafe way to execute Python code
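Why extra imports matter: unpickling does not just restore data. An object's `__reduce__` tells pickle which callable to invoke, and with which arguments, at load time, and that callable can be anything the pickle imports. The harmless demo below uses `operator.add` as a stand-in for a payload; a malicious file would name something more dangerous, and attribute-lookup imports such as the `__getattr__` finding above can let a payload reach further callables. `RunsOnLoad` is a hypothetical class.

```python
import operator
import pickle


class RunsOnLoad:
    """Harmless stand-in for a malicious payload: __reduce__ tells pickle what to call on load."""

    def __reduce__(self):
        # (callable, args): pickle stores a reference to operator.add and the tuple (2, 3)
        return (operator.add, (2, 3))


payload = pickle.dumps(RunsOnLoad())
# Unpickling does not rebuild RunsOnLoad; it calls operator.add(2, 3) and returns the result.
loaded = pickle.loads(payload)
print(loaded)
```

This is why scanning the opcodes statically, rather than loading the file to inspect it, is the safe order of operations.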

```python
repo_name = "fastai/fastbook_04_mnist_basics"
print(f"Scanning the {repo_name} model repo...")
print(f"Model {repo_name} will be available in model-inventory soon...")
scan_results = await client.scan_huggingface_repo(repo_name, filename="model.pkl")
print(f"The {repo_name} scan is complete...\n")
client.pretty_print_scan_results(scan_results)
```

```text
Scanning the fastai/fastbook_04_mnist_basics model repo...
Model fastai/fastbook_04_mnist_basics will be available in model-inventory soon...
The fastai/fastbook_04_mnist_basics scan is complete...

fastai/fastbook_04_mnist_basics (290cc15c-e514-4692-9d23-efc3e354ba86)
################################################################################################################################
Verdict                  UNSAFE
File                     model.pkl
Type                     pytorch
MD5                      87e0800ed11fffc1700b31803fbfb8cf
SHA256                   01aad6e71a0e92b73dfa1332688ca991f3b2270b8e81eefbacce648f6b320cc6
TLSH                     7da733c1ab3f614ad83520ae836990c37b48e0ef6b3bd6d716e2fd492c750425ec56c6
Detections:
Severity                 Detection                                       Description
```

#### Scan Malicious Model ❌

❗️ Note: Be careful when handling unsafe models
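The file name `eval_lambda.h5` suggests a Keras `Lambda` layer, which serializes a Python function into the model file so it can be rebuilt, and potentially run, when the model is loaded. Keras HDF5 files typically store the architecture as JSON in a root attribute named `model_config`, which `h5py` can read without building the model. The sketch below assumes you already have that JSON and only searches it for `Lambda` layers; `find_lambda_layers` is a hypothetical helper and says nothing about what a given layer actually does.

```python
import json
from typing import Any, List, Union


def find_lambda_layers(model_config: Union[str, bytes]) -> List[str]:
    """Return the names of all Keras Lambda layers found anywhere in a model-config JSON document."""
    found: List[str] = []

    def walk(node: Any) -> None:
        # recurse through nested dicts/lists, since layers can sit inside nested models
        if isinstance(node, dict):
            if node.get("class_name") == "Lambda":
                config = node.get("config")
                found.append(config.get("name", "<unnamed>") if isinstance(config, dict) else "<unnamed>")
            for value in node.values():
                walk(value)
        elif isinstance(node, list):
            for item in node:
                walk(item)

    walk(json.loads(model_config))
    return found
```

As an example of where the input could come from, `h5py.File(path, "r").attrs["model_config"]` would return the JSON (older `h5py` versions return bytes, which `json.loads` also accepts).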


```python
repo_name = "ScanMe/Models"
print(f"Scanning the {repo_name} model repo...")
print(f"Model {repo_name} will be available in model-inventory soon...")
scan_results = await client.scan_huggingface_repo(repo_name, filename="eval_lambda.h5")
print(f"The {repo_name} scan is complete...\n")
client.pretty_print_scan_results(scan_results)
```

```text
Scanning the ScanMe/Models model repo...
Model ScanMe/Models will be available in model-inventory soon...
The ScanMe/Models scan is complete...

ScanMe/Models (9ad2ae99-809f-49cb-8b11-c29c70ec15ee)
################################################################################################################################
Verdict                  UNSAFE
File                     eval_lambda.h5
Type                     keras
MD5                      139152c3ae4ed27124d4217079e8b6e7
SHA256                   d9ceae4da8037b02280a23368325ddda263889ce11771c4a78301aac1b2254ba
TLSH                     a012ca37ab21dd3fd0b99838048643b92b20df4317c15747a690b92c3eb58585f61cd9
Detections:
Severity                 Detection                                       Description
-----------------------------------------
```
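To use these verdicts outside a notebook, for example as a CI gate before a model is promoted, the dictionary returned by `scan_huggingface_repo` can be reduced to a list of flagged files. This sketch mirrors the `pretty_print_scan_results` rule (any detection means UNSAFE) and additionally treats a missing scan result as unsafe, so the check fails closed; `unsafe_files` is a hypothetical helper.

```python
from typing import Any, Dict, List


def unsafe_files(scan_results: Dict[str, Dict[str, Any]]) -> List[str]:
    """Return "repo/filename" for each scanned file that has detections or no scan result."""
    flagged: List[str] = []
    for entry in scan_results.values():
        result = entry.get("scan_results")
        label = f"{entry['sensor']['plaintext_name']}/{entry['filename']}"
        # same rule as pretty_print_scan_results (any detection -> UNSAFE), but fail closed on missing results
        if result is None or result.get("detections"):
            flagged.append(label)
    return flagged
```

A pipeline could then exit with a non-zero status (e.g. `sys.exit(1)`) whenever `unsafe_files(scan_results)` is non-empty.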