In [None]:
import boto3
import sagemaker
import json
import httpx
import base64
from io import BytesIO
from typing import Dict, Any
from sagemaker import ModelPackage
from IPython.display import Markdown

# Deploying and running Mistral OCR on SageMaker

This notebook contains the basic steps to deploy and run the Mistral OCR model on SageMaker.

> **Prerequisites**
> - The ARN of the Mistral OCR model package
> - An IAM role with sufficient permissions to deploy a SageMaker endpoint and run inference on it

You can run this notebook anywhere (not necessarily in SageMaker), provided that AWS authentication is properly configured in your development environment.
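
To confirm that your credentials are picked up before deploying anything, you can ask AWS STS which identity they resolve to. The helper below is a minimal sketch: `describe_caller_identity` is a hypothetical name, and it expects a boto3 STS client to be passed in (for example `boto3.client("sts")`). `GetCallerIdentity` requires no IAM permissions, so a failure here usually points to missing or expired credentials rather than to a policy problem.

```python
from typing import Any, Dict


def describe_caller_identity(sts_client: Any) -> Dict[str, str]:
    """Return the AWS account and principal ARN the current credentials resolve to.

    `sts_client` is assumed to be a boto3 STS client, e.g. boto3.client("sts").
    """
    identity = sts_client.get_caller_identity()
    return {"account": identity["Account"], "arn": identity["Arn"]}
```

Calling `describe_caller_identity(boto3.client("sts"))` should return your account ID and the ARN of the principal you are authenticated as.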

Fill in all of the following parameters in the code cell below before running this notebook:

| Name | Description |
| --- | --- |
| `MISTRAL_OCR_MODEL_PACKAGE_ARN` | The ARN of the Mistral OCR model package |
| `MISTRAL_OCR_MODEL_CONFIG_INSTANCE_TYPE` | The EC2 instance type used to host the endpoint (e.g. `ml.g6.24xlarge`) |
| `SAGEMAKER_EXECUTION_ROLE_ARN` | The ARN of the IAM role with permissions to deploy and run SageMaker endpoints |
| `MISTRAL_OCR_ENDPOINT_NAME` | The name you want to give to your endpoint |


In [None]:
MISTRAL_OCR_MODEL_PACKAGE_ARN = "" 
MISTRAL_OCR_MODEL_CONFIG_INSTANCE_TYPE = "" 
SAGEMAKER_EXECUTION_ROLE_ARN = "" 
MISTRAL_OCR_ENDPOINT_NAME = "" 
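
Before deploying, a quick sanity check can catch empty or malformed values. The sketch below is a hypothetical helper, not part of the SageMaker SDK. The endpoint-name rule it applies (alphanumerics and hyphens, starting and ending with an alphanumeric character, at most 63 characters) reflects our reading of the SageMaker `CreateEndpoint` API reference, so double-check it there.

```python
import re
from typing import Dict, List

# Assumed endpoint-name rule: alphanumerics and hyphens, must start and end
# with an alphanumeric character; the 63-character limit is checked separately.
_ENDPOINT_NAME_RE = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$")


def check_notebook_parameters(params: Dict[str, str]) -> List[str]:
    """Return a list of problems found; an empty list means every check passed.

    Only obvious mistakes are caught: this does not verify that the ARNs exist
    or that your account has quota for the chosen instance type.
    """
    problems = [f"{name} is empty" for name, value in params.items() if not value.strip()]

    instance_type = params.get("MISTRAL_OCR_MODEL_CONFIG_INSTANCE_TYPE", "")
    if instance_type and not instance_type.startswith("ml."):
        problems.append("MISTRAL_OCR_MODEL_CONFIG_INSTANCE_TYPE should start with 'ml.'")

    for name in ("MISTRAL_OCR_MODEL_PACKAGE_ARN", "SAGEMAKER_EXECUTION_ROLE_ARN"):
        value = params.get(name, "")
        if value and not value.startswith("arn:"):
            problems.append(f"{name} does not look like an ARN")

    endpoint_name = params.get("MISTRAL_OCR_ENDPOINT_NAME", "")
    if endpoint_name and (len(endpoint_name) > 63 or not _ENDPOINT_NAME_RE.match(endpoint_name)):
        problems.append("MISTRAL_OCR_ENDPOINT_NAME contains invalid characters or is too long")

    return problems
```

Pass it a dictionary mapping each parameter name from the table above to its value, and fix anything it reports before moving on.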


## 1. Deploy the model

Start by defining the `deploy_sagemaker_model()` function, which creates the endpoint that we will query for inference:

In [None]:
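def deploy_sagemaker_model(
    model_package_arn: str,
    execution_role: str,
    instance_type: str,
    endpoint_name: str,
) -> None:
    session = sagemaker.Session()
    # Wrap the model package so that it can be deployed to an endpoint
    model_package = ModelPackage(
        role=execution_role,
        model_package_arn=model_package_arn,
        sagemaker_session=session,
    )
    # Create the model, endpoint configuration, and endpoint on a single instance.
    # The generous timeouts (in seconds) leave time for the model weights to
    # download and for the container to pass its startup health check.
    model_package.deploy(
        initial_instance_count=1,
        instance_type=instance_type,
        endpoint_name=endpoint_name,
        model_data_download_timeout=3600,
        container_startup_health_check_timeout=3600,
    )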

Then run this function with the parameters of your choice:

In [None]:
deploy_sagemaker_model(
    model_package_arn=MISTRAL_OCR_MODEL_PACKAGE_ARN,
    execution_role=SAGEMAKER_EXECUTION_ROLE_ARN,
    instance_type=MISTRAL_OCR_MODEL_CONFIG_INSTANCE_TYPE,
    endpoint_name=MISTRAL_OCR_ENDPOINT_NAME
)
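
By default, `deploy()` blocks until the endpoint is ready, so you usually don't need to poll. If you come back to the notebook later (or deploy with `wait=False`), you can check the status yourself. The sketch below assumes a boto3 SageMaker client (`boto3.client("sagemaker")`) is passed in; `wait_for_endpoint` is a hypothetical helper, and boto3's built-in `endpoint_in_service` waiter is a ready-made alternative.

```python
import time
from typing import Any, Callable


def wait_for_endpoint(
    sm_client: Any,
    endpoint_name: str,
    poll_seconds: float = 30.0,
    timeout_seconds: float = 3600.0,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll describe_endpoint until the endpoint is InService; fail fast on Failed."""
    waited = 0.0
    while True:
        description = sm_client.describe_endpoint(EndpointName=endpoint_name)
        status = description["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            reason = description.get("FailureReason", "unknown reason")
            raise RuntimeError(f"Endpoint {endpoint_name} failed: {reason}")
        if waited >= timeout_seconds:
            raise TimeoutError(f"Endpoint {endpoint_name} still {status} after {waited:.0f}s")
        # Still creating or updating: wait before asking again
        sleep(poll_seconds)
        waited += poll_seconds
```

For example, `wait_for_endpoint(boto3.client("sagemaker"), MISTRAL_OCR_ENDPOINT_NAME)` returns `"InService"` once the endpoint can accept requests.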

## 2. Run inference on the model

Now that the endpoint is live, you can start querying it. The OCR API differs quite a bit from the usual chat completion API, so familiarize yourself with its specifics by reading the [documentation](https://docs.mistral.ai/api/#tag/ocr/operation/ocr_v1_ocr_post) before moving on.
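
As a rough illustration of the request shape used in this notebook, the hypothetical helper below assembles an OCR payload. It only covers the fields that appear in this notebook (`model`, `document`, `pages`, `include_image_base64`); check the API reference for the full list. Leaving `pages` unset is assumed to process every page.

```python
from typing import Any, Dict, List, Optional


def build_ocr_payload(
    document_url: Optional[str] = None,
    image_url: Optional[str] = None,
    pages: Optional[List[int]] = None,
    include_image_base64: bool = False,
    model: str = "mistral-ocr-2503",
) -> Dict[str, Any]:
    """Build an OCR request body from exactly one of `document_url` or `image_url`."""
    if (document_url is None) == (image_url is None):
        raise ValueError("Pass exactly one of document_url or image_url")

    # The "type" field must match the key that holds the input
    if document_url is not None:
        document = {"type": "document_url", "document_url": document_url}
    else:
        document = {"type": "image_url", "image_url": image_url}

    payload: Dict[str, Any] = {"model": model, "document": document}
    if pages is not None:
        payload["pages"] = pages  # zero-based page indices
    if include_image_base64:
        payload["include_image_base64"] = True
    return payload
```

The payloads written out by hand in the examples below have the same structure as `build_ocr_payload(document_url=..., pages=[0])` and `build_ocr_payload(image_url=...)`.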

Start by defining a helper function, `run_inference()`, that wraps the endpoint invocation:

In [None]:
def run_inference(client, endpoint_name: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    try:
        # Send the JSON-encoded payload to the endpoint
        inference_out = client.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType="application/json",
            Body=json.dumps(payload),
        )
        # The response body is a stream: read it and parse the JSON it contains
        inference_resp_str = inference_out["Body"].read().decode("utf-8")
        return json.loads(inference_resp_str)
    except Exception as e:
        print(f"Inference error: {e}")
        raise

### 2.1 With URL inputs

The following example runs the [Pixtral technical report](https://arxiv.org/pdf/2410.07073) through the OCR model. The input type is `document_url`, meaning that the payload contains a link to the document, which is downloaded and then parsed. For brevity, we only process the first page by passing `"pages": [0]` (page indices start at 0).

In [None]:
client = boto3.client("sagemaker-runtime")

pixtral_report_pdf_url = "https://arxiv.org/pdf/2410.07073"
payload = {
    "model": "mistral-ocr-2503",
    "document": {
        "type": "document_url",
        "document_url": pixtral_report_pdf_url,
    },
    "pages": [0],
}
pixtral_report_parsed = run_inference(client=client, endpoint_name=MISTRAL_OCR_ENDPOINT_NAME, payload=payload)
Markdown(pixtral_report_parsed["pages"][0]["markdown"])
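
The cell above renders only the first returned page. To render every page that the endpoint returns (for example after removing `"pages"` from the payload), you can join the per-page Markdown. This sketch assumes each entry of `pages` carries an `index` field, as in the hosted Mistral API's response; if no entry has one, the original order is kept. `pages_to_markdown` is a hypothetical helper.

```python
from typing import Any, Dict


def pages_to_markdown(ocr_response: Dict[str, Any], separator: str = "\n\n---\n\n") -> str:
    """Concatenate the Markdown of every returned page, ordered by page index."""
    # sorted() is stable, so pages without an "index" keep their relative order
    pages = sorted(ocr_response.get("pages", []), key=lambda page: page.get("index", 0))
    return separator.join(page["markdown"] for page in pages)
```

You would then display the result with `Markdown(pages_to_markdown(pixtral_report_parsed))`.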

> **Note**
> You can also extract images from documents by setting the `include_image_base64` flag to `True` in the input payload.
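
When that flag is set, the images have to be decoded before you can save them. A minimal sketch, assuming each page in the response has an `images` list whose entries carry an `id` and an `image_base64` string (possibly prefixed with a data URI header such as `data:image/jpeg;base64,`); `extract_page_images` is a hypothetical helper.

```python
import base64
from typing import Any, Dict


def extract_page_images(ocr_response: Dict[str, Any]) -> Dict[str, bytes]:
    """Map each returned image id to its decoded bytes."""
    images: Dict[str, bytes] = {}
    for page in ocr_response.get("pages", []):
        for image in page.get("images", []):
            encoded = image.get("image_base64")
            if not encoded:
                continue  # image detected but its content was not returned
            # Strip a data URI header such as "data:image/jpeg;base64," if present
            if encoded.startswith("data:") and "," in encoded:
                encoded = encoded.split(",", 1)[1]
            images[image["id"]] = base64.b64decode(encoded)
    return images
```

After rerunning a request with the flag set, you could write each entry to disk, for example with `pathlib.Path(name).write_bytes(data)`.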

### 2.2 With base64-encoded inputs

The OCR API endpoint also accepts base64-encoded inputs. This is particularly useful when your documents are stored locally.

The following example encodes an image taken from the [Mistral AI blog](https://mistral.ai/fr/news/pixtral-large) before passing it to the endpoint. To build a base64-encoded input from an image URL, use the following function:

In [None]:
def download_and_encode_image(url: str) -> str:
    try:
        # Send a GET request to the URL
        with httpx.Client() as client:
            response = client.get(url, timeout=10)
            response.raise_for_status()  # Raise an exception for HTTP errors
        # Encode the image content to base64
        image_data = response.content
        base64_encoded_data = base64.b64encode(image_data).decode("utf-8")
        return base64_encoded_data
    except httpx.HTTPStatusError as exc:
        print(f"Error response {exc.response.status_code} while requesting {exc.request.url}")
        raise
    except httpx.RequestError as e:
        # Network-level failures (DNS, connection, timeout, ...)
        print(f"Error downloading image: {e}")
        raise

If you want to test the endpoint with your own images, you can skip the function above and pass your own base64-encoded data directly.
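
For local files, one approach is to read the file and wrap it in a data URI yourself. The hypothetical helper below guesses the MIME type from the file extension and returns the `"document"` part of the payload. Images use `"image_url"`, as in the example that follows; sending PDFs as a `data:application/pdf;base64,...` URI through `"document_url"` is an assumption based on the hosted Mistral API, so verify it against your endpoint.

```python
import base64
import mimetypes
from pathlib import Path
from typing import Any, Dict


def local_file_to_document(path: str) -> Dict[str, Any]:
    """Build the "document" part of an OCR payload from a local image or PDF file."""
    mime_type, _ = mimetypes.guess_type(path)
    if mime_type is None:
        raise ValueError(f"Cannot infer a MIME type for {path}")
    encoded = base64.b64encode(Path(path).read_bytes()).decode("utf-8")
    data_uri = f"data:{mime_type};base64,{encoded}"
    if mime_type.startswith("image/"):
        return {"type": "image_url", "image_url": data_uri}
    # Assumption: non-image files such as PDFs are sent through "document_url"
    return {"type": "document_url", "document_url": data_uri}
```

For example, `{"model": "mistral-ocr-2503", "document": local_file_to_document("my_receipt.jpg")}` builds a complete payload, where `my_receipt.jpg` is a placeholder file name.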

Run the OCR API call as follows, and notice how the `"type"` has changed to `"image_url"`:

In [None]:
receipt_image_url = "https://cms.mistral.ai/assets/1d7df1b8-5caa-47b9-b6a1-666b05d38019"
receipt_image_b64 = download_and_encode_image(url=receipt_image_url)

payload = {
    "model": "mistral-ocr-2503",
    "document": {
        "type": "image_url",
        "image_url": f"data:image/jpeg;base64,{receipt_image_b64}"
    }
}
receipt_parsed = run_inference(client=client, endpoint_name=MISTRAL_OCR_ENDPOINT_NAME, payload=payload)
Markdown(receipt_parsed["pages"][0]["markdown"])

## 3. Cleanup

Once you no longer need the endpoint, delete it to avoid incurring unnecessary costs. Because deploying the endpoint created several resources, delete them in the following order:

1. Delete the endpoint
2. Delete the endpoint configuration
3. Delete the model
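
With the SageMaker API, this sequence can be scripted as follows. It is a minimal sketch: `delete_endpoint_resources` is a hypothetical helper that expects a boto3 SageMaker client (`boto3.client("sagemaker")`), and it looks up the endpoint configuration and model names before deleting anything, since they are not necessarily the same as the endpoint name.

```python
from typing import Any, List


def delete_endpoint_resources(sm_client: Any, endpoint_name: str) -> List[str]:
    """Delete an endpoint, then its endpoint configuration, then the model(s) it served.

    Returns the names of the deleted models.
    """
    # Look up the dependent resource names while the endpoint still exists
    endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
    config_name = endpoint["EndpointConfigName"]
    config = sm_client.describe_endpoint_config(EndpointConfigName=config_name)
    model_names = [variant["ModelName"] for variant in config["ProductionVariants"]]

    # Delete in the order listed above: endpoint, endpoint configuration, model(s)
    sm_client.delete_endpoint(EndpointName=endpoint_name)
    sm_client.delete_endpoint_config(EndpointConfigName=config_name)
    for model_name in model_names:
        sm_client.delete_model(ModelName=model_name)
    return model_names
```

Run it as `delete_endpoint_resources(boto3.client("sagemaker"), MISTRAL_OCR_ENDPOINT_NAME)`.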

## 4. Wrapping up

Here are some useful resources to further explore the topic of the Mistral OCR model:

- [Mistral blog post on the OCR API release](https://mistral.ai/fr/news/mistral-ocr)
- [Mistral documentation on OCR capabilities](https://docs.mistral.ai/capabilities/document/)
- [Mistral OCR API reference documentation](https://docs.mistral.ai/api/#tag/ocr/operation/ocr_v1_ocr_post)