# Mistral Document AI OCR Sample

This notebook demonstrates how to run OCR on a document (a PDF or an image) with the Mistral
Document AI service. It encodes a local file as base64, submits it to the deployed endpoint, and
inspects the OCR response.

## 1. Configure authentication

Make sure you have deployed `mistral-document-ai-2505` and collected the endpoint URL
and API key from the Deployments + Endpoint page. Set them as environment variables before
running the notebook:

```bash
export AZURE_API_KEY="<your-api-key>"
export AZURE_MISTRAL_OCR_ENDPOINT="https://<your-endpoint>/providers/mistral/azure/ocr"
```

Equivalent commands for PowerShell and Windows CMD are shown in the quick-start guide.

In [None]:
%pip install -q requests pillow python-dotenv

## 2. Load configuration

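The next cell loads environment variables (optionally from a `.env` file, if python-dotenv is
installed) and checks that the API key and endpoint are set before we call the API.
`MISTRAL_DOCUMENT_MODEL` is optional and defaults to `mistral-document-ai-2505`.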

In [None]:
import os
import json
import base64
import mimetypes
from pathlib import Path

try:
    from dotenv import load_dotenv
except ImportError:  # python-dotenv is optional
    load_dotenv = None

if load_dotenv:
    load_dotenv()

AZURE_API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
ENDPOINT = os.getenv("AZURE_MISTRAL_OCR_ENDPOINT", "https://<your-endpoint>/providers/mistral/azure/ocr")
MODEL_NAME = os.getenv("MISTRAL_DOCUMENT_MODEL", "mistral-document-ai-2505")

if not AZURE_API_KEY or AZURE_API_KEY == "YOUR_API_KEY":
    raise ValueError("Set the AZURE_API_KEY environment variable (or edit AZURE_API_KEY above) before continuing.")

if ENDPOINT.startswith("https://<your-endpoint>"):
    raise ValueError("Set the AZURE_MISTRAL_OCR_ENDPOINT environment variable (or edit ENDPOINT above) before continuing.")

print(f"Sending requests to: {ENDPOINT}")
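
The placeholder checks above catch an unset variable, but a mistyped endpoint only fails once the request is sent. The sketch below is one way to catch that earlier. It assumes the endpoint is an HTTPS URL whose path ends in `/providers/mistral/azure/ocr`, as in the example from step 1; `check_endpoint` is a hypothetical helper, not part of any SDK.

```python
from typing import List
from urllib.parse import urlparse

# Path suffix taken from the example endpoint in step 1 (an assumption; adjust if yours differs).
EXPECTED_SUFFIX = "/providers/mistral/azure/ocr"


def check_endpoint(url: str) -> List[str]:
    """Return a list of problems found in `url` (empty if it looks usable)."""
    problems = []
    parsed = urlparse(url)
    if parsed.scheme != "https":
        problems.append("endpoint should use https")
    # A missing host, or one still containing the <your-endpoint> placeholder, is unusable.
    if not parsed.netloc or "<" in parsed.netloc:
        problems.append("endpoint host is missing or still a placeholder")
    if not parsed.path.rstrip("/").endswith(EXPECTED_SUFFIX):
        problems.append(f"endpoint path should end with {EXPECTED_SUFFIX}")
    return problems
```

Running `print(check_endpoint(ENDPOINT) or "Endpoint looks OK")` after the configuration cell reports any problems it finds.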


## 3. Encode a local document

Update `sample_file_path` to point to a local PDF or image file. The helper below converts the
file to a base64 data URL that the REST API accepts.

In [None]:
from typing import Dict

sample_file_path = Path("data/sample-document.jpg")  # <-- replace with your own file

if not sample_file_path.exists():
    raise FileNotFoundError("Update sample_file_path to point to a local PDF or image file before continuing.")

mime_type, _ = mimetypes.guess_type(sample_file_path.name)
if mime_type is None:
    mime_type = "application/pdf" if sample_file_path.suffix.lower() == ".pdf" else "image/jpeg"

with sample_file_path.open("rb") as document_file:
    base64_payload = base64.b64encode(document_file.read()).decode("utf-8")

if mime_type == "application/pdf":
    document_payload: Dict[str, str] = {
        "type": "document_url",
        "document_url": f"data:{mime_type};base64,{base64_payload}",
    }
else:
    document_payload = {
        "type": "image_url",
        "image_url": f"data:{mime_type};base64,{base64_payload}",
    }

print(f"Prepared {sample_file_path.name} as {mime_type}.")
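
If you want to OCR several files, the logic above can be wrapped in a function. The sketch below (`build_document_payload` is a hypothetical name) reuses the same `document_url`/`image_url` payload shapes and the same MIME-type fallback as the cell above.

```python
import base64
import mimetypes
from pathlib import Path
from typing import Dict


def build_document_payload(path: Path) -> Dict[str, str]:
    """Encode a local PDF or image as a base64 data-URL payload for the OCR request."""
    mime_type, _ = mimetypes.guess_type(path.name)
    if mime_type is None:
        # Same fallback as the cell above: PDF for .pdf files, JPEG otherwise.
        mime_type = "application/pdf" if path.suffix.lower() == ".pdf" else "image/jpeg"
    encoded = base64.b64encode(path.read_bytes()).decode("ascii")
    data_url = f"data:{mime_type};base64,{encoded}"
    # PDFs are sent as a document_url entry; everything else as an image_url entry.
    key = "document_url" if mime_type == "application/pdf" else "image_url"
    return {"type": key, key: data_url}
```

For example, `[build_document_payload(p) for p in Path("data").glob("*.pdf")]` prepares every PDF in a `data/` folder.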


## 4. Submit the request

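With the payload ready, we can call the REST endpoint. The response is JSON containing the OCR
results (and, optionally, bounding-box annotations if you configured them in your deployment).
Setting `include_image_base64` to `True` asks the service to also return any images it extracts
from the document as base64 strings.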

In [None]:
import requests
from pprint import pprint

payload = {
    "model": MODEL_NAME,
    "document": document_payload,
    "include_image_base64": True,
}

headers = {
    "Authorization": f"Bearer {AZURE_API_KEY}",
    "Content-Type": "application/json",
}

response = requests.post(ENDPOINT, headers=headers, json=payload, timeout=60)
response.raise_for_status()

ocr_result = response.json()
request_id = response.headers.get("x-ms-request-id", "n/a")

print(f"Status: {response.status_code}, request id: {request_id}")
print("Top-level keys returned:", list(ocr_result.keys()))
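
Before digging into the JSON, a quick overview of what came back can help. The sketch below assumes the response follows the shape that step 6 relies on: a `pages` list whose entries carry `markdown` and, when `include_image_base64` is set, an `images` list. `summarize_ocr_response` is a hypothetical helper; any key it does not find is counted as 0.

```python
from typing import Any, Dict


def summarize_ocr_response(result: Dict[str, Any]) -> Dict[str, int]:
    """Count pages, markdown characters, and extracted images in an OCR response."""
    pages = result.get("pages") or []
    return {
        "pages": len(pages),
        # Total length of the per-page markdown text.
        "markdown_chars": sum(len(page.get("markdown") or "") for page in pages),
        # Images listed per page (assumed present when include_image_base64 is enabled).
        "images": sum(len(page.get("images") or []) for page in pages),
    }
```

In the notebook, `print(summarize_ocr_response(ocr_result))` gives a one-line summary of the response.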


## 5. Inspect extracted text

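The structure of the response may vary depending on the deployment options you enabled.
The helper below walks the JSON payload recursively and collects every non-empty string
stored under a key listed in `TARGET_KEYS` (`text`, `content`, `plain_text`, `normalized_text`,
or `markdown`, compared case-insensitively). Mistral OCR returns the text of each page in
`pages[].markdown`. Adjust `TARGET_KEYS` if your deployment uses different field names.

In [None]:
from itertools import islice

# Keys whose string values are treated as OCR text (matched case-insensitively).
TARGET_KEYS = {"text", "content", "plain_text", "normalized_text", "markdown"}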

def collect_text_segments(node):
    segments = []

    def _walk(candidate):
        if isinstance(candidate, dict):
            for key, value in candidate.items():
                lowered = key.lower()
                if lowered in TARGET_KEYS and isinstance(value, str):
                    cleaned = value.strip()
                    if cleaned:
                        segments.append(cleaned)
                if isinstance(value, (dict, list)):
                    _walk(value)
        elif isinstance(candidate, list):
            for item in candidate:
                _walk(item)

    _walk(node)
    return segments

text_segments = collect_text_segments(ocr_result)
print(f"Collected {len(text_segments)} text segments from the response.")
for idx, segment in enumerate(islice(text_segments, 10)):
    print(f"{idx + 1:>2}: {segment}")
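
Because the OCR text arrives as one markdown string per page (the `pages[].markdown` field that step 6 also reads), a simpler alternative to the recursive walk is to read those fields directly, in page order. This sketch assumes each page may carry an integer `index`; `page_markdown` and `full_document_text` are hypothetical helpers.

```python
from typing import Any, Dict, List


def page_markdown(result: Dict[str, Any]) -> List[str]:
    """Return the markdown of each page, ordered by the page's `index` field."""
    pages = result.get("pages") or []
    # Pages missing `index` sort as 0; sorted() is stable, so ties keep their original order.
    ordered = sorted(pages, key=lambda page: page.get("index", 0))
    return [page.get("markdown") or "" for page in ordered]


def full_document_text(result: Dict[str, Any]) -> str:
    """Join the non-blank pages into one string, separated by blank lines."""
    return "\n\n".join(text for text in page_markdown(result) if text.strip())
```

If the fields you need span several pages, `full_document_text(ocr_result)` could stand in for `pages[0]['markdown']` in step 6.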


## 6. Compare against a reference document

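The next cell turns the OCR markdown from the first page into structured fields and
compares them with reference values from the original document. The field labels and
checkbox options it searches for match the sample form used in this notebook, so adapt
them to your own document, and update `expected_fields` with your ground truth before
running the comparison.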


In [None]:
import re
from typing import Dict, List

if 'ocr_result' not in globals():
    raise RuntimeError('Run the request cell first so ocr_result is populated.')

pages = ocr_result.get('pages') or []
if not pages or not pages[0].get('markdown'):
    raise ValueError('No markdown content found in the OCR response.')

doc_text = pages[0]['markdown']
doc_text_lower = doc_text.lower()

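def get_field(label: str) -> str:
    """Return the text that follows `label` (or '' if the label is not found)."""
    pattern = rf"{re.escape(label)}\s*(.+)"
    match = re.search(pattern, doc_text)
    if not match:
        return ''
    # Keep only the first line of the captured value, then turn form-fill underscores into spaces.
    value = match.group(1).split('\n')[0].strip()
    value = value.replace('_', ' ').strip()
    return value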

def selected_options(options: List[str]) -> List[str]:
    selected = []
    for option in options:
        if f"[x] {option.lower()}" in doc_text_lower:
            selected.append(option)
    return selected

def normalize(value: str) -> str:
    return re.sub(r'\s+', '', value.lower())

extracted_fields: Dict[str, str] = {
    'Owner Type': next(iter(selected_options(['Individual', 'Entity (e.g., trust/corporation)'])), ''),
    'Owner Name': get_field('Owner Name:'),
    'Date of Birth': get_field('Date of Birth:'),
}

joint_line = get_field("Joint Owner's Name:")
if 'Date of Birth:' in joint_line:
    joint_name, _ = joint_line.split('Date of Birth:', 1)
    joint_line = joint_name.strip()
extracted_fields['Joint Owner Name'] = joint_line

employment_selected = selected_options(['Employed', 'Self-Employed', 'Retired', 'Not Employed'])
extracted_fields['Employment Status'] = employment_selected[0] if employment_selected else ''
extracted_fields['Business Name'] = get_field('Name of Business (if applicable):')
extracted_fields['Occupation'] = get_field('Occupation or Nature of Business if self-employed (if applicable):')
extracted_fields['Premium Amount'] = get_field('Premium Amount:')
extracted_fields['Surrender Charge Period'] = get_field('Surrender Charge Period Length:')
extracted_fields['Year 1 Surrender Charge'] = get_field('Year 1 Surrender Charge:')
rider_selected = selected_options(['Enhanced Death Benefit', 'Living Benefit', 'No Rider'])
extracted_fields['Rider'] = rider_selected[0] if rider_selected else ''
extracted_fields['Rider Fee'] = get_field('If Yes, Rider Fee:')

income_options = [
    'Current Wages (Owner)',
    'Current Wages (Spouse/Partner)',
    'Social Security',
    'Pension Plan Payments',
    'Guaranteed Annuity Payments',
    'Regular Distributions from Investments',
    'Rental Income',
]
extracted_fields['Income Sources'] = '; '.join(selected_options(income_options))
extracted_fields['Annual Gross Income'] = get_field('Annual Household Gross Income:')
extracted_fields['Annual Living Expenses'] = get_field('Annual Living Expenses (including all household expenses):')
extracted_fields['Disposable Income'] = get_field('Disposable Income (b minus c):')

question = 'do you anticipate a significant change'
if question in doc_text_lower:
    # Only look at checkboxes after the question, so "[x] yes" / "[x] no"
    # marks belonging to earlier questions are ignored.
    answer_text = doc_text_lower[doc_text_lower.index(question):]
    yes_pos = answer_text.find('[x] yes')
    no_pos = answer_text.find('[x] no')
    if yes_pos != -1 and (no_pos == -1 or yes_pos < no_pos):
        extracted_fields['Future Disposable Change'] = 'Yes'
    elif no_pos != -1:
        extracted_fields['Future Disposable Change'] = 'No'

tax_bracket_options = ['0%', '10%', '12%', '22%', '24%', '32%', '35%', '37%']
brackets = selected_options(tax_bracket_options)
extracted_fields['Federal Tax Bracket'] = brackets[0] if brackets else ''

financial_products = ['Annuity', 'Life Insurance', 'Stocks, Bonds & Mutual Funds', "CD's", 'Real Estate', 'None']
extracted_fields['Other Financial Products'] = '; '.join(selected_options(financial_products))

expected_fields = {
    # Example reference values – replace with the ground truth from your document
    'Owner Type': 'Individual',
    'Owner Name': 'Aaron Rodgers',
    'Date of Birth': '2/16/81',
    'Joint Owner Name': '',
    'Employment Status': 'Employed',
    'Business Name': 'Football',
    'Occupation': 'Quarterback',
    'Premium Amount': '$1,000,000.00',
    'Surrender Charge Period': '9 years',
    'Year 1 Surrender Charge': '8.75%',
    'Rider': 'Living Benefit',
    'Rider Fee': '0.95%',
    'Income Sources': 'Current Wages (Owner)',
    'Annual Gross Income': '$8,000,000.00',
    'Annual Living Expenses': '$1,000,000.00',
    'Disposable Income': '$7,000,000.00',
    'Future Disposable Change': 'No',
    'Federal Tax Bracket': '37%',
    'Other Financial Products': 'None',
}

rows = []
for key, expected_value in expected_fields.items():
    actual_value = extracted_fields.get(key, '')
    status = 'MATCH' if normalize(actual_value) == normalize(expected_value) else 'CHECK'
    rows.append((key, actual_value, expected_value, status))

# Named so it does not overwrite the HTTP `headers` dict from step 4.
table_headers = ['Field', 'OCR', 'Reference', 'Status']
# Size each column to its widest cell, header included.
col_widths = [
    max([len(table_headers[i])] + [len(str(row[i])) for row in rows])
    for i in range(4)
]

print(' | '.join(table_headers[i].ljust(col_widths[i]) for i in range(4)))
print('-+-'.join('-' * col_widths[i] for i in range(4)))
for row in rows:
    print(' | '.join(str(row[i]).ljust(col_widths[i]) for i in range(4)))

print('\nExtracted field values:')
for key, value in extracted_fields.items():
    print(f'  {key}: {value}')

mismatches = [row for row in rows if row[-1] != 'MATCH']
if mismatches:
    print(f"\nFound {len(mismatches)} field(s) marked CHECK; review and adjust as needed.")
else:
    print('\nAll fields match the reference document.')

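`selected_options` searches the whole document, so an option such as `None`, or a `[x] yes` mark, can be picked up from an unrelated question. One way to reduce that risk, sketched below, is to search only the text between a question's label and the next label. `section_between` and `selected_in_section` are hypothetical helpers, and treating `[x]` as the checked-box marker is the same assumption `selected_options` makes.

```python
import re
from typing import List, Optional


def section_between(text: str, start_label: str, end_label: Optional[str] = None) -> str:
    """Return the lowercase text after `start_label`, up to `end_label` if it follows."""
    lowered = text.lower()
    label = start_label.lower()
    start = lowered.find(label)
    if start == -1:
        return ""
    start += len(label)
    if end_label is None:
        return lowered[start:]
    end = lowered.find(end_label.lower(), start)
    return lowered[start:] if end == -1 else lowered[start:end]


def selected_in_section(section: str, options: List[str]) -> List[str]:
    """Return the options marked with '[x]' inside `section` (expected lowercase)."""
    return [
        option
        for option in options
        # The (?!\w) guard stops '[x] no' from matching '[x] none' or '[x] not employed'.
        if re.search(rf"\[x\]\s*{re.escape(option.lower())}(?!\w)", section)
    ]
```

Pass the labels that surround each question on your own form as `start_label` and `end_label`.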

## 7. (Optional) Visualise bounding boxes

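If the response includes an inline image (requested via `include_image_base64`) and `bbox`
annotations, the cell below overlays the boxes on that image. It looks for the image under a
top-level `image_base64` (or `image`) key and draws every dictionary that has both `bbox`
(as `[x0, y0, x1, y1]`) and `text` keys. Mistral OCR responses usually nest extracted images
per page under `pages[].images[]`, so update these keys if your payload is structured differently.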

In [None]:
import io

from IPython.display import display  # provided by Jupyter; imported explicitly for clarity
from PIL import Image, ImageDraw

image_base64 = None
if isinstance(ocr_result, dict):
    image_base64 = ocr_result.get("image_base64") or ocr_result.get("image")

if not image_base64:
    print("No base64 image found in the response. Ensure include_image_base64=True if you need a visualisation.")
else:
    if "," in image_base64:
        _, b64_data = image_base64.split(",", 1)
    else:
        b64_data = image_base64

    image_bytes = base64.b64decode(b64_data)
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    draw = ImageDraw.Draw(image)

    def draw_boxes(candidate):
        if isinstance(candidate, dict):
            if {"bbox", "text"}.issubset(candidate.keys()):
                bbox = candidate["bbox"]
                if isinstance(bbox, (list, tuple)) and len(bbox) == 4:
                    draw.rectangle(bbox, outline="red", width=2)
            for value in candidate.values():
                draw_boxes(value)
        elif isinstance(candidate, list):
            for item in candidate:
                draw_boxes(item)

    draw_boxes(ocr_result)
    display(image)
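
If your response has no `bbox` fields, it may still carry coordinates for each extracted image. The sketch below assumes the Mistral OCR layout in which `pages[].images[]` entries have `top_left_x`, `top_left_y`, `bottom_right_x`, and `bottom_right_y`, and each page has `dimensions` with `width` and `height`. It scales those boxes to the size of the image you draw on. `image_boxes` is a hypothetical helper; check the field names against your own response before relying on it.

```python
from typing import Any, Dict, List, Tuple

Box = Tuple[float, float, float, float]


def image_boxes(page: Dict[str, Any], target_size: Tuple[int, int]) -> List[Box]:
    """Scale each extracted image's corners from page pixels to `target_size` (width, height)."""
    dims = page.get("dimensions") or {}
    page_w, page_h = dims.get("width"), dims.get("height")
    target_w, target_h = target_size
    # Without page dimensions, assume the coordinates already match the target image.
    sx = target_w / page_w if page_w else 1.0
    sy = target_h / page_h if page_h else 1.0
    boxes = []
    for img in page.get("images") or []:
        try:
            x0, y0 = img["top_left_x"], img["top_left_y"]
            x1, y1 = img["bottom_right_x"], img["bottom_right_y"]
        except KeyError:
            continue  # skip entries that lack coordinates
        boxes.append((x0 * sx, y0 * sy, x1 * sx, y1 * sy))
    return boxes
```

To visualise them for an image input, open the original file from step 3 with `Image.open(sample_file_path).convert("RGB")`, then call `draw.rectangle(box, outline="blue", width=2)` for each box returned by `image_boxes(ocr_result["pages"][0], image.size)`.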
