### Installation

In [1]:
%%capture
import os, re
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    import torch; v = re.match(r"[0-9]{1,}\.[0-9]{1,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.33.post1" if v=="2.9" else "0.0.32.post2" if v=="2.8" else "0.0.29.post3")
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
!pip install transformers==4.56.2
!pip install --no-deps trl==0.22.2
!pip install jiwer
!pip install einops addict easydict
!pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo transformers timm

In [2]:
# @title Mount from Google Drive
from google.colab import drive
drive.mount('/content/drive')

MessageError: Error: credential propagation was unsuccessful

In [3]:
# @title Environment Inspection
import os
import sys
import platform
import subprocess

# --- Environment Detection and Path Setup ---

def detect_env_and_setup_paths():
    """
    Detects the current environment (Kaggle, Colab, or Local) and
    returns appropriate paths and environment name.
    """
    env_name = "Local/Other"
    input_dir = "input/"
    output_dir = "output/"

    # 1. Detect Kaggle
    if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
        env_name = "Kaggle"
        # Kaggle's standard paths
        input_dir = "/kaggle/input/"
        output_dir = "/kaggle/working/"

    # 2. Detect Colab (Note: Colab also has 'COLAB_GPU' if a GPU is assigned)
    elif 'google.colab' in sys.modules:
        env_name = "Colab"
        # Colab's default paths relative to the content folder
        input_dir = "/content/" # Often you mount Drive or download here
        output_dir = "/content/" # Default working directory
        print("Note: In Colab, you might need to mount Google Drive manually for persistent storage:")
        print("      from google.colab import drive; drive.mount('/content/drive')")

    # 3. Default to Local/Other
    # Paths are set to a simple structure relative to the script location

    return env_name, input_dir, output_dir

# --- System and Package Info ---

def get_system_info(env_name):
    """Prints Python, PyTorch, CUDA, and GPU information."""
    print("\n" + "="*50)
    print(f"       ðŸ’» System and Package Information for {env_name}")
    print("="*50)

    # 1. Python Version
    print(f"**Python Version:** {sys.version.split()[0]} ({platform.python_implementation()})")

    # 2. PyTorch and CUDA Info
    try:
        import torch
        print(f"**PyTorch Version:** {torch.__version__}")

        if torch.cuda.is_available():
            print("\n**CUDA/GPU Information (PyTorch):**")
            # CUDA version
            print(f"  - CUDA is Available: **True**")
            print(f"  - CUDA Version (Runtime): {torch.version.cuda}")
            # GPU details
            gpu_count = torch.cuda.device_count()
            print(f"  - GPU Count: {gpu_count}")
            for i in range(gpu_count):
                print(f"  - Device {i}: {torch.cuda.get_device_name(i)}")
        else:
            print(f"  - CUDA is Available: **False** (Running on CPU)")
    except ImportError:
        print("\n**PyTorch:** Not installed or not found.")
    except Exception as e:
        print(f"\n**PyTorch/CUDA Check Error:** {e}")

    # 3. nvidia-smi (System-level GPU info)
    print("\n**NVIDIA-SMI Output (Raw Driver/System Info):**")
    try:
        # Run nvidia-smi command
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, check=True)
        print(result.stdout)
    except FileNotFoundError:
        # This will happen if nvidia-smi is not in PATH or no NVIDIA driver is installed
        print("  - `nvidia-smi` command not found (No NVIDIA GPU/drivers or not in PATH).")
    except subprocess.CalledProcessError as e:
        # This might happen if the command runs but fails (e.g., driver issues)
        print(f"  - `nvidia-smi` failed to run. Error: {e.stderr.strip()}")

# --- Main Execution ---

if __name__ == "__main__":

    # Get environment details and paths
    env_name, INPUT_DIR, OUTPUT_DIR = detect_env_and_setup_paths()

    # Print detected environment and paths
    print("="*50)
    print(f"       âœ… Environment Detected: **{env_name}**")
    print("="*50)
    print(f"**Input Path (Default):** {INPUT_DIR}")
    print(f"**Output Path (Default):** {OUTPUT_DIR}")

    # Run system checks
    get_system_info(env_name)

    print("="*50)

# Example Usage within the script (simulated)
# train_data_path = os.path.join(INPUT_DIR, "dataset_folder", "train.csv")
# model_save_path = os.path.join(OUTPUT_DIR, "best_model.pth")

Note: In Colab, you might need to mount Google Drive manually for persistent storage:
      from google.colab import drive; drive.mount('/content/drive')
       âœ… Environment Detected: **Colab**
**Input Path (Default):** /content/
**Output Path (Default):** /content/

       ðŸ’» System and Package Information for Colab
**Python Version:** 3.12.12 (CPython)
**PyTorch Version:** 2.9.0+cu126

**CUDA/GPU Information (PyTorch):**
  - CUDA is Available: **True**
  - CUDA Version (Runtime): 12.6
  - GPU Count: 1
  - Device 0: Tesla T4

**NVIDIA-SMI Output (Raw Driver/System Info):**
Fri Dec 12 08:43:18 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   

In [4]:
# @title Unzipping everything
import os
import glob
from zipfile import ZipFile

# 2. Use glob to find all files ending in '.zip' within the search directory
# The os.path.join ensures correct path construction
zip_file_paths = glob.glob(os.path.join(INPUT_DIR, '*.zip'))

# Check if any zip files were found
if not zip_file_paths:
    print(f'No .zip files found in {INPUT_DIR}.')
else:
    # 3. Loop through all found zip files and unzip them
    for zip_file_path in zip_file_paths:
        if os.path.exists(zip_file_path):
            print(f'Unzipping {zip_file_path}...')

            # Use the shell command !unzip for simplicity in Colab
            # -q for quiet (less output)
            # -o to overwrite existing files without prompting
            # -d ./ to extract to the current working directory (usually /content/)
            !unzip -q -o {zip_file_path} -d ./

            print(f'Unzipping of {zip_file_path} complete.')
        else:
            # This path is unlikely given glob was just used, but kept for robustness
            print(f'Error: The file {zip_file_path} was not found (post-glob check).')

Unzipping /content/UIT_HWDB_line.zip...
[/content/UIT_HWDB_line.zip]
  End-of-central-directory signature not found.  Either this file is not
  a zipfile, or it constitutes one disk of a multi-part archive.  In the
  latter case the central directory and zipfile comment will be found on
  the last disk(s) of this archive.
unzip:  cannot find zipfile directory in one of /content/UIT_HWDB_line.zip or
        /content/UIT_HWDB_line.zip.zip, and cannot find /content/UIT_HWDB_line.zip.ZIP, period.
Unzipping of /content/UIT_HWDB_line.zip complete.


### Configure environment

### Unsloth

Let's prepare the OCR model to our local first

In [5]:
from huggingface_hub import snapshot_download
snapshot_download("unsloth/DeepSeek-OCR", local_dir = "deepseek_ocr")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 21 files:   0%|          | 0/21 [00:00<?, ?it/s]

README.md: 0.00B [00:00, ?B/s]

README-checkpoint.md: 0.00B [00:00, ?B/s]

LICENSE: 0.00B [00:00, ?B/s]

.gitattributes: 0.00B [00:00, ?B/s]

assets/show3.jpg:   0%|          | 0.00/247k [00:00<?, ?B/s]

assets/show1.jpg:   0%|          | 0.00/117k [00:00<?, ?B/s]

assets/show2.jpg:   0%|          | 0.00/216k [00:00<?, ?B/s]

assets/fig1.png:   0%|          | 0.00/396k [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

conversation.py: 0.00B [00:00, ?B/s]

configuration_deepseek_v2.py: 0.00B [00:00, ?B/s]

assets/show4.jpg:   0%|          | 0.00/269k [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

deepencoder.py: 0.00B [00:00, ?B/s]

model-00001-of-000001.safetensors:   0%|          | 0.00/6.67G [00:00<?, ?B/s]

modeling_deepseekocr.py: 0.00B [00:00, ?B/s]

modeling_deepseekv2.py: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/801 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/460 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

'/content/deepseek_ocr'

In [6]:
# @title Download original DeepSeek OCR model
import unsloth
from unsloth import FastVisionModel # FastLanguageModel for LLMs
import torch
from transformers import AutoModel
import os
os.environ["UNSLOTH_WARN_UNINITIALIZED"] = '0'
# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/Qwen3-VL-8B-Instruct-bnb-4bit", # Qwen 3 vision support
    "unsloth/Qwen3-VL-8B-Thinking-bnb-4bit",
    "unsloth/Qwen3-VL-32B-Instruct-bnb-4bit",
    "unsloth/Qwen3-VL-32B-Thinking-bnb-4bit",
] # More models at https://huggingface.co/unsloth

original_model, original_tokenizer = FastVisionModel.from_pretrained(
    "./deepseek_ocr",
    load_in_4bit = False, # Use 4bit to reduce memory use. False for 16bit LoRA.
    auto_model = AutoModel,
    trust_remote_code=True,
    unsloth_force_compile=True,
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
)

ðŸ¦¥ Unsloth: Will patch your computer to enable 2x faster free finetuning.
ðŸ¦¥ Unsloth Zoo will now patch everything to make training faster!


Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
You are using a model of type deepseek_vl_v2 to instantiate a model of type DeepseekOCR. This is not supported for all configurations of models and can yield errors.


Are you certain you want to do remote code execution?
==((====))==  Unsloth 2025.12.4: Fast Deepseekocr patching. Transformers: 4.57.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


You are using a model of type deepseek_vl_v2 to instantiate a model of type DeepseekOCR. This is not supported for all configurations of models and can yield errors.
You are using a model of type deepseek_vl_v2 to instantiate a model of type DeepseekOCR. This is not supported for all configurations of models and can yield errors.


Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.


Some weights of DeepseekOCRForCausalLM were not initialized from the model checkpoint at ./deepseek_ocr and are newly initialized: ['model.vision_model.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


To format the dataset, all vision finetuning tasks should be formatted as follows:

```python
[
{ "role": "<|User|>",
  "content": "",
  "images": []
},
{ "role": "<|Assistant|>",
  "content": ""
},
]
```

In [7]:
# @title Convert data to DeepSeek chat template
import pandas as pd
from datasets import Dataset, Image, DatasetDict
import os
import json

# --- Configuration ---
# Update these paths according to your setup
TRAIN_DATA_DIR = os.path.join(INPUT_DIR, "UIT_HWDB_line", "train_data")
TEST_DATA_DIR = os.path.join(INPUT_DIR, "UIT_HWDB_line", "test_data")
TRAIN_LABELS_FILE = os.path.join(TRAIN_DATA_DIR, "labels.json")  # Assuming train has labels.json too
TEST_LABELS_FILE = os.path.join(TEST_DATA_DIR, "labels.json")
DEEPSEEK_PROMPT = "<image>\nFree OCR."

def convert_to_conversation(sample):
    """Convert dataset sample to conversation format"""
    conversation = [
        {
            "role": "<|User|>",
            "content": DEEPSEEK_PROMPT,
            "images": [sample['image']]
        },
        {
            "role": "<|Assistant|>",
            "content": sample["text"]
        },
    ]
    return {"messages": conversation}

def load_and_format_dataset(data_dir, labels_file):
    """Load and format a single dataset split"""
    print(f"Loading from {data_dir}...")

    # Load JSON labels
    try:
        with open(labels_file, 'r', encoding='utf-8') as f:
            labels_data = pd.read_json(f, typ='series')
    except FileNotFoundError:
        print(f"Warning: {labels_file} not found. Trying alternative loading method...")
        with open(labels_file, 'r', encoding='utf-8') as f:
            labels_data = json.load(f)
        labels_data = pd.Series(labels_data)

    label_list = []
    missing_count = 0

    for file_name, text in labels_data.items():
        full_path = os.path.join(data_dir, file_name)

        if os.path.exists(full_path):
            label_list.append({
                'file_name': file_name,
                'text': text,
                'image_path': full_path
            })
        else:
            missing_count += 1
            if missing_count <= 5:  # Print first 5 missing files
                print(f"  Missing: {full_path}")

    if not label_list:
        raise ValueError(
            f"No images found in {data_dir}!\n"
            f"Missing files count: {missing_count}\n"
            "Please verify paths and filename matches."
        )

    print(f"  Successfully loaded {len(label_list)} samples. ({missing_count} files missing)")

    # Convert to Hugging Face Dataset
    df = pd.DataFrame(label_list)
    dataset = Dataset.from_pandas(df)
    dataset = dataset.cast_column("image_path", Image())
    dataset = dataset.rename_column("image_path", "image")

    # Format for DeepSeek
    dataset = dataset.map(convert_to_conversation)

    return dataset

# --- Main execution ---
def main():
    # Load train and test datasets
    print("1. Loading train dataset...")
    train_dataset = load_and_format_dataset(TRAIN_DATA_DIR, TRAIN_LABELS_FILE)

    print("\n2. Loading test dataset...")
    test_dataset = load_and_format_dataset(TEST_DATA_DIR, TEST_LABELS_FILE)

    # Create DatasetDict for easy access
    dataset_dict = DatasetDict({
        "train": train_dataset,
        "test": test_dataset
    })

    # --- Inspection ---
    print("\n--- Dataset Summary ---")
    print(f"Train samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    print("\n--- Example of Train Data ---")
    sample = train_dataset[0]
    print(f"Keys in sample: {sample.keys()}")
    print(f"Number of messages: {len(sample['messages'])}")
    print(f"User role: {sample['messages'][0]['role']}")
    print(f"User content: {sample['messages'][0]['content'][:50]}...")
    print(f"Assistant content: {sample['messages'][1]['content'][:50]}...")
    print(f"Image type: {type(sample['messages'][0]['images'][0])}")

    print("\n--- Example of Test Data ---")
    sample = test_dataset[0]
    print(f"Keys in sample: {sample.keys()}")
    print(f"Number of messages: {len(sample['messages'])}")
    print(f"User role: {sample['messages'][0]['role']}")

    return dataset_dict

# Alternative: If you want to keep the original function signature for compatibility
def load_and_format_all_datasets():
    """Wrapper function that returns both datasets"""
    print("1. Loading train dataset...")
    train_dataset = load_and_format_dataset(TRAIN_DATA_DIR, TRAIN_LABELS_FILE)

    print("\n2. Loading test dataset...")
    test_dataset = load_and_format_dataset(TEST_DATA_DIR, TEST_LABELS_FILE)

    return train_dataset, test_dataset

In [8]:
# @title Load train and test dataset
train_dataset, test_dataset = load_and_format_all_datasets()

1. Loading train dataset...
Loading from /content/UIT_HWDB_line/train_data...


FileNotFoundError: [Errno 2] No such file or directory: '/content/UIT_HWDB_line/train_data/labels.json'

### Let's Evaluate Deepseek-OCR Baseline Performance on UIT Handwritten Dataset

In [None]:
# Save an image that will not be used during training for evaluation purposes
train_dataset[1415]['image'].save("your_image.jpg")

In [None]:
train_dataset[1415]['image']

# Let's finetune Deepseek-OCR !

We now add LoRA adapters for parameter efficient finetuning - this allows us to only efficiently train 1% of all parameters.

**[NEW]** We also support finetuning ONLY the vision part of the model, or ONLY the language part. Or you can select both! You can also select to finetune the attention or the MLP layers!

In [None]:
# @title Get finetuned model initialization with PEFT
finetuned_model = FastVisionModel.get_peft_model(
    original_model,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],

    r = 16,           # The larger, the higher the accuracy, but might overfit
    lora_alpha = 16,  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
    # target_modules = "all-linear", # Optional now! Can specify a list if needed
)

In [None]:
train_dataset[0]

In [None]:
# @title Create datacollator

import torch
import math
from dataclasses import dataclass
from typing import Dict, List, Any, Tuple
from PIL import Image, ImageOps
from torch.nn.utils.rnn import pad_sequence
import io

from deepseek_ocr.modeling_deepseekocr import (
    format_messages,
    text_encode,
    BasicImageTransform,
    dynamic_preprocess,
)

@dataclass
class DeepSeekOCRDataCollator:
    """
    Args:
        tokenizer: Tokenizer
        model: Model
        image_size: Size for image patches (default: 640)
        base_size: Size for global view (default: 1024)
        crop_mode: Whether to use dynamic cropping for large images
        train_on_responses_only: If True, only train on assistant responses (mask user prompts)
    """
    tokenizer: Any
    model: Any
    image_size: int = 640
    base_size: int = 1024
    crop_mode: bool = True
    image_token_id: int = 128815
    train_on_responses_only: bool = True

    def __init__(
        self,
        tokenizer,
        model,
        image_size: int = 640,
        base_size: int = 1024,
        crop_mode: bool = True,
        train_on_responses_only: bool = True,
    ):
        self.tokenizer = tokenizer
        self.model = model
        self.image_size = image_size
        self.base_size = base_size
        self.crop_mode = crop_mode
        self.image_token_id = 128815
        self.dtype = model.dtype  # Get dtype from model
        self.train_on_responses_only = train_on_responses_only

        self.image_transform = BasicImageTransform(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
            normalize=True
        )
        self.patch_size = 16
        self.downsample_ratio = 4

        # Get BOS token ID from tokenizer
        if hasattr(tokenizer, 'bos_token_id') and tokenizer.bos_token_id is not None:
            self.bos_id = tokenizer.bos_token_id
        else:
            self.bos_id = 0
            print(f"Warning: tokenizer has no bos_token_id, using default: {self.bos_id}")

    def deserialize_image(self, image_data) -> Image.Image:
        """Convert image data (bytes dict or PIL Image) to PIL Image in RGB mode"""
        if isinstance(image_data, Image.Image):
            return image_data.convert("RGB")
        elif isinstance(image_data, dict) and 'bytes' in image_data:
            image_bytes = image_data['bytes']
            image = Image.open(io.BytesIO(image_bytes))
            return image.convert("RGB")
        else:
            raise ValueError(f"Unsupported image format: {type(image_data)}")

    def calculate_image_token_count(self, image: Image.Image, crop_ratio: Tuple[int, int]) -> int:
        """Calculate the number of tokens this image will generate"""
        num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
        num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)

        width_crop_num, height_crop_num = crop_ratio

        if self.crop_mode:
            img_tokens = num_queries_base * num_queries_base + 1
            if width_crop_num > 1 or height_crop_num > 1:
                img_tokens += (num_queries * width_crop_num + 1) * (num_queries * height_crop_num)
        else:
            img_tokens = num_queries * num_queries + 1

        return img_tokens

    def process_image(self, image: Image.Image) -> Tuple[List, List, List, List, Tuple[int, int]]:
        """
        Process a single image based on crop_mode and size thresholds

        Returns:
            Tuple of (images_list, images_crop_list, images_spatial_crop, tokenized_image, crop_ratio)
        """
        images_list = []
        images_crop_list = []
        images_spatial_crop = []

        if self.crop_mode:
            # Determine crop ratio based on image size
            if image.size[0] <= 640 and image.size[1] <= 640:
                crop_ratio = (1, 1)
                images_crop_raw = []
            else:
                images_crop_raw, crop_ratio = dynamic_preprocess(
                    image, min_num=2, max_num=9,
                    image_size=self.image_size, use_thumbnail=False
                )

            # Process global view with padding
            global_view = ImageOps.pad(
                image, (self.base_size, self.base_size),
                color=tuple(int(x * 255) for x in self.image_transform.mean)
            )
            images_list.append(self.image_transform(global_view).to(self.dtype))

            width_crop_num, height_crop_num = crop_ratio
            images_spatial_crop.append([width_crop_num, height_crop_num])

            # Process local views (crops) if applicable
            if width_crop_num > 1 or height_crop_num > 1:
                for crop_img in images_crop_raw:
                    images_crop_list.append(
                        self.image_transform(crop_img).to(self.dtype)
                    )

            # Calculate image tokens
            num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
            num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)

            tokenized_image = ([self.image_token_id] * num_queries_base + [self.image_token_id]) * num_queries_base
            tokenized_image += [self.image_token_id]

            if width_crop_num > 1 or height_crop_num > 1:
                tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
                    num_queries * height_crop_num)

        else:  # crop_mode = False
            crop_ratio = (1, 1)
            images_spatial_crop.append([1, 1])

            # For smaller base sizes, resize; for larger, pad
            if self.base_size <= 640:
                resized_image = image.resize((self.base_size, self.base_size), Image.LANCZOS)
                images_list.append(self.image_transform(resized_image).to(self.dtype))
            else:
                global_view = ImageOps.pad(
                    image, (self.base_size, self.base_size),
                    color=tuple(int(x * 255) for x in self.image_transform.mean)
                )
                images_list.append(self.image_transform(global_view).to(self.dtype))

            num_queries = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)
            tokenized_image = ([self.image_token_id] * num_queries + [self.image_token_id]) * num_queries
            tokenized_image += [self.image_token_id]

        return images_list, images_crop_list, images_spatial_crop, tokenized_image, crop_ratio

    def process_single_sample(self, messages: List[Dict]) -> Dict[str, Any]:
            """
            Process a single conversation into model inputs.
            """

            # --- 1. Setup ---
            images = []
            for message in messages:
                if "images" in message and message["images"]:
                    for img_data in message["images"]:
                        if img_data is not None:
                            pil_image = self.deserialize_image(img_data)
                            images.append(pil_image)

            if not images:
                raise ValueError("No images found in sample. Please ensure all samples contain images.")

            tokenized_str = []
            images_seq_mask = []
            images_list, images_crop_list, images_spatial_crop = [], [], []

            prompt_token_count = -1 # Index to start training
            assistant_started = False
            image_idx = 0

            # Add BOS token at the very beginning
            tokenized_str.append(self.bos_id)
            images_seq_mask.append(False)

            for message in messages:
                role = message["role"]
                content = message["content"]

                # Check if this is the assistant's turn
                if role == "<|Assistant|>":
                    if not assistant_started:
                        # This is the split point. All tokens added *so far*
                        # are part of the prompt.
                        prompt_token_count = len(tokenized_str)
                        assistant_started = True

                    # Append the EOS token string to the *end* of assistant content
                    content = f"{content.strip()} {self.tokenizer.eos_token}"

                # Split this message's content by the image token
                text_splits = content.split('<image>')

                for i, text_sep in enumerate(text_splits):
                    # Tokenize the text part
                    tokenized_sep = text_encode(self.tokenizer, text_sep, bos=False, eos=False)
                    tokenized_str.extend(tokenized_sep)
                    images_seq_mask.extend([False] * len(tokenized_sep))

                    # If this text is followed by an <image> tag
                    if i < len(text_splits) - 1:
                        if image_idx >= len(images):
                            raise ValueError(
                                f"Data mismatch: Found '<image>' token but no corresponding image."
                            )

                        # Process the image
                        image = images[image_idx]
                        img_list, crop_list, spatial_crop, tok_img, _ = self.process_image(image)

                        images_list.extend(img_list)
                        images_crop_list.extend(crop_list)
                        images_spatial_crop.extend(spatial_crop)

                        # Add image placeholder tokens
                        tokenized_str.extend(tok_img)
                        images_seq_mask.extend([True] * len(tok_img))

                        image_idx += 1 # Move to the next image

            # --- 3. Validation and Final Prep ---
            if image_idx != len(images):
                raise ValueError(
                    f"Data mismatch: Found {len(images)} images but only {image_idx} '<image>' tokens were used."
                )

            # If we never found an assistant message, we're in a weird state
            # (e.g., user-only prompt). We mask everything.
            if not assistant_started:
                print("Warning: No assistant message found in sample. Masking all tokens.")
                prompt_token_count = len(tokenized_str)

            # Prepare image tensors
            images_ori = torch.stack(images_list, dim=0)
            images_spatial_crop_tensor = torch.tensor(images_spatial_crop, dtype=torch.long)

            if images_crop_list:
                images_crop = torch.stack(images_crop_list, dim=0)
            else:
                images_crop = torch.zeros((1, 3, self.base_size, self.base_size), dtype=self.dtype)

            return {
                "input_ids": torch.tensor(tokenized_str, dtype=torch.long),
                "images_seq_mask": torch.tensor(images_seq_mask, dtype=torch.bool),
                "images_ori": images_ori,
                "images_crop": images_crop,
                "images_spatial_crop": images_spatial_crop_tensor,
                "prompt_token_count": prompt_token_count, # This is now accurate
            }

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate batch of samples"""
        batch_data = []

        # Process each sample
        for feature in features:
            try:
                processed = self.process_single_sample(feature['messages'])
                batch_data.append(processed)
            except Exception as e:
                print(f"Error processing sample: {e}")
                continue

        if not batch_data:
            raise ValueError("No valid samples in batch")

        # Extract lists
        input_ids_list = [item['input_ids'] for item in batch_data]
        images_seq_mask_list = [item['images_seq_mask'] for item in batch_data]
        prompt_token_counts = [item['prompt_token_count'] for item in batch_data]

        # Pad sequences
        input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        images_seq_mask = pad_sequence(images_seq_mask_list, batch_first=True, padding_value=False)

        # Create labels
        labels = input_ids.clone()

        # Mask padding tokens
        labels[labels == self.tokenizer.pad_token_id] = -100

        # Mask image tokens (model shouldn't predict these)
        labels[images_seq_mask] = -100

        # Mask user prompt tokens when train_on_responses_only=True (only train on assistant responses)
        if self.train_on_responses_only:
            for idx, prompt_count in enumerate(prompt_token_counts):
                if prompt_count > 0:
                    labels[idx, :prompt_count] = -100

        # Create attention mask
        attention_mask = (input_ids != self.tokenizer.pad_token_id).long()

        # Prepare images batch (list of tuples)
        images_batch = []
        for item in batch_data:
            images_batch.append((item['images_crop'], item['images_ori']))

        # Stack spatial crop info
        images_spatial_crop = torch.cat([item['images_spatial_crop'] for item in batch_data], dim=0)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "images": images_batch,
            "images_seq_mask": images_seq_mask,
            "images_spatial_crop": images_spatial_crop,
        }

## Define Character Error Rate (CER) Metric

### Subtask:
Implement a function to calculate the Character Error Rate (CER), which will quantify the difference between the ground truth text and the model's predicted text. This metric will be crucial for evaluating the transcription accuracy of the models.


In [None]:
from jiwer import cer

def calculate_cer(ground_truth, hypothesis):
    """Calculates the Character Error Rate (CER) between two strings."""
    return cer(ground_truth, hypothesis)

print("CER calculation function defined.")

<a name="Train"></a>
### Train the model
Now let's train our model. We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`. We also support TRL's `DPOTrainer`!

We use our new `DeepSeekOCRDataCollator` which will help in our vision finetuning setup.

In [None]:
from transformers import Trainer, TrainingArguments
from unsloth import is_bf16_supported
FastVisionModel.for_training(finetuned_model) # Enable for training!
data_collator = DeepSeekOCRDataCollator(
    tokenizer = original_tokenizer,
    model = finetuned_model,
    image_size = 640,
    base_size = 1024,
    crop_mode = True,
    train_on_responses_only = True,
)
trainer = Trainer(
    model = finetuned_model,
    tokenizer = original_tokenizer,
    data_collator = data_collator, # Must use!
    train_dataset = train_dataset,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 5,
        warmup_steps = 5,
        # max_steps = 60,
        num_train_epochs = 1, # Set this instead of max_steps for full training runs
        learning_rate = 2e-4,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.001,
        lr_scheduler_type = "linear",
        seed = 3407,
        fp16 = not is_bf16_supported(),  # Use fp16 if bf16 is not supported
        bf16 = is_bf16_supported(),  # Use bf16 if supported
        output_dir = "outputs",
        report_to = "none",     # For Weights and Biases
        dataloader_num_workers=2,
        # You MUST put the below items for vision finetuning:
        remove_unused_columns = False,
    ),
)

In [None]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

In [None]:
trainer_stats = trainer.train()

In [None]:
# @title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

In [None]:
# @title Assign original tokenizer -> finetuned tokenizer after training
finetuned_tokenizer = original_tokenizer

<a name="Inference"></a>
### Inference
Let's run the model!

In [None]:
prompt = DEEPSEEK_PROMPT
image_file = 'your_image.jpg'
output_path = 'your/output/dir'
# Tiny: base_size = 512, image_size = 512, crop_mode = False
# Small: base_size = 640, image_size = 640, crop_mode = False
# Base: base_size = 1024, image_size = 1024, crop_mode = False
# Large: base_size = 1280, image_size = 1280, crop_mode = False

# Gundam: base_size = 1024, image_size = 640, crop_mode = True

res = finetuned_model.infer(original_tokenizer, prompt=prompt, image_file=image_file,
    output_path = output_path,
    image_size=640,
    base_size=1024,
    crop_mode=True,
    save_results = True,
    test_compress = False,
    eval_mode=False) # no need to return anything at this place


<a name="Save"></a>
### Saving, loading finetuned models
To save the final model as LoRA adapters, either use Huggingface's `push_to_hub` for an online save or `save_pretrained` for a local save.

**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!

In [None]:
finetuned_model.save_pretrained("lora_model")  # Local saving
finetuned_tokenizer.save_pretrained("lora_model")
# model.push_to_hub("your_name/lora_model", token = "...") # Online saving
# tokenizer.push_to_hub("your_name/lora_model", token = "...") # Online saving

Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:

In [None]:
if False:
    from unsloth import FastVisionModel
    model, tokenizer = FastVisionModel.from_pretrained(
        model_name = "lora_model", # YOUR MODEL YOU USED FOR TRAINING
        load_in_4bit = False, # Use 4bit to reduce memory use. False for 16bit LoRA.
        auto_model = AutoModel,
        trust_remote_code=True,
        unsloth_force_compile=True,
        use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
    )
    FastVisionModel.for_inference(model) # Enable for inference!

prompt = DEEPSEEK_PROMPT
image_file = 'your_image.jpg'
output_path = 'your/output/dir'

# Tiny: base_size = 512, image_size = 512, crop_mode = False
# Small: base_size = 640, image_size = 640, crop_mode = False
# Base: base_size = 1024, image_size = 1024, crop_mode = False
# Large: base_size = 1280, image_size = 1280, crop_mode = False

# Gundam: base_size = 1024, image_size = 640, crop_mode = True

res = finetuned_model.infer(original_tokenizer, prompt=prompt, image_file=image_file,
    output_path = output_path,
    image_size=640,
    base_size=1024,
    crop_mode=True,
    save_results = True,
    test_compress = False,
    eval_mode=False)


### Saving to float16 for VLLM

We also support saving to `float16` directly. Select `merged_16bit` for float16. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.

In [None]:
# Select ONLY 1 to save! (Both not needed!)

# Save locally to 16bit
if False: model.save_pretrained_merged("unsloth_finetune", tokenizer,)

# To export and save to your Hugging Face account
if False: model.push_to_hub_merged("YOUR_USERNAME/unsloth_finetune", tokenizer, token = "PUT_HERE")

In [None]:
import os
# Define the folder to zip and the name of the zip file
folder_to_zip = "lora_model"
output_zip_file = f"{folder_to_zip}.zip"

# Use the !zip command to compress the folder
!zip -r {output_zip_file} {folder_to_zip}

print(f"Successfully created {output_zip_file}")
print(f"You can download it from the files tab on the left (if using Colab) or locate it in the current directory.")

In [None]:
!unzip lora_model.zip
print("Unzipping successfully!")

In [None]:
from unsloth import FastVisionModel
from transformers import AutoModel

# Reload the fine-tuned DeepSeek-OCR model with 4-bit quantization
finetuned_model, finetuned_tokenizer = FastVisionModel.from_pretrained(
    model_name = "lora_model",
    load_in_4bit = True, # Use 4bit quantization
    auto_model = AutoModel,
    trust_remote_code=True,
    unsloth_force_compile=True,
    use_gradient_checkpointing = False, # Not needed for inference
)

print("Fine-tuned DeepSeek-OCR model and tokenizer reloaded with 4-bit quantization.")

In [None]:
import os

test_samples_for_inference = []

for i, sample in enumerate(test_dataset):
    ground_truth_text = sample['messages'][1]['content']
    image_pil = sample['messages'][0]['images'][0]

    # Create a unique temporary filename
    temp_filename = f'temp_image_{i}.png'

    # Save the PIL image to the temporary file
    image_pil.save(temp_filename)

    test_samples_for_inference.append({
        'image_path': temp_filename,
        'ground_truth': ground_truth_text,
        'image_pil': image_pil # Keep the PIL object for potential future use or display
    })

print(f"Prepared {len(test_samples_for_inference)} test samples for inference, saving images to temporary files.")

### Evaluate Baseline and Finetuned Model

In [None]:
import numpy as np
import os
import collections
# The following variables are assumed to be defined in your notebook environment:
# original_model, original_tokenizer
# finetuned_model, finetuned_tokenizer
# calculate_cer(ground_truth, prediction)
# test_samples_for_inference (list of dicts, each having image_path, ground_truth, and data_type)


# --- 1. Configuration and Setup ---
# Create the output directory for illustrative results (input/output)
EVALUATION_DIR = os.path.join(OUTPUT_DIR, 'evaluation_results')
os.makedirs(EVALUATION_DIR, exist_ok=True)

# Define the data types for stratified evaluation (assumed)
DATA_TYPES = ["printed_text", "handwriting", "tables", "forms"]

# Initialize structure to store results and CERs for both models
results = {
    'original': {'predictions': [], 'cers': [], 'metrics_by_type': collections.defaultdict(lambda: {'cers': [], 'samples': []})},
    'finetuned': {'predictions': [], 'cers': [], 'metrics_by_type': collections.defaultdict(lambda: {'cers': [], 'samples': []})}
}


# --- 2. Utility Function: Run Inference ---
def run_inference(model, tokenizer, sample, prompt=DEEPSEEK_PROMPT, output_path=EVALUATION_DIR):
    """Performs inference using the custom model.infer() method."""
    image_file = sample['image_path']
    try:
        prediction = model.infer(
            tokenizer,
            prompt=prompt,
            image_file=image_file,
            output_path=output_path,
            eval_mode=True # Assumes infer() is modified to return the output when eval_mode=True
        )
    except Exception as e:
        print(f"Inference error: {e}")
        prediction = "" # Fallback

    # Normalize output
    if not isinstance(prediction, str):
        prediction = ""
    if hasattr(tokenizer, 'eos_token'):
        prediction = prediction.replace(tokenizer.eos_token, "").strip()
    else:
        prediction = prediction.strip()

    return prediction


# --- 3. Main Loop: Run Evaluation ---
print("Starting comprehensive evaluation...")

for i, sample in enumerate(test_samples_for_inference):

    # 3.1. Sample Classification
    data_type = sample.get('data_type', 'unspecified')
    if data_type not in DATA_TYPES:
        data_type = 'unspecified' # Handle case of unknown data type

    # 3.2. Run Original Model
    orig_pred = run_inference(original_model, original_tokenizer, sample)
    orig_cer = calculate_cer(sample['ground_truth'], orig_pred)

    results['original']['predictions'].append(orig_pred)
    results['original']['cers'].append(orig_cer)
    results['original']['metrics_by_type'][data_type]['cers'].append(orig_cer)

    # 3.3. Run Fine-tuned Model
    ft_pred = run_inference(finetuned_model, finetuned_tokenizer, sample)
    ft_cer = calculate_cer(sample['ground_truth'], ft_pred)

    results['finetuned']['predictions'].append(ft_pred)
    results['finetuned']['cers'].append(ft_cer)
    results['finetuned']['metrics_by_type'][data_type]['cers'].append(ft_cer)

    # 3.4. Store Sample for Error Analysis and Illustration
    sample_result = {
        'ground_truth': sample['ground_truth'],
        'original_prediction': orig_pred,
        'finetuned_prediction': ft_pred,
        'original_cer': orig_cer,
        'finetuned_cer': ft_cer,
        'image_path': sample['image_path'],
        'index': i
    }

    # Store sample for analysis (error analysis)
    # We store the same result object under both models for comparison
    results['original']['metrics_by_type'][data_type]['samples'].append(sample_result)
    results['finetuned']['metrics_by_type'][data_type]['samples'].append(sample_result)

    if (i + 1) % 50 == 0:
        print(f"Processed {i + 1}/{len(test_samples_for_inference)} samples.")


# --- 4. Calculate and Print Overall Results ---
print("\n" + "="*50)
print("             OVERALL EVALUATION SUMMARY")
print("="*50)

# Overall CER
avg_orig_cer = np.mean(results['original']['cers']) if results['original']['cers'] else 0
avg_ft_cer = np.mean(results['finetuned']['cers']) if results['finetuned']['cers'] else 0

print(f"Total evaluated samples: {len(test_samples_for_inference)}")
print(f"Average CER (Original Model): {avg_orig_cer:.4f}")
print(f"Average CER (Fine-tuned Model): {avg_ft_cer:.4f}")

# Comparison
improvement = avg_orig_cer - avg_ft_cer
print(f"Absolute CER Improvement: {improvement:+.4f}")
print(f"The Fine-tuned model {'Improved' if improvement > 0 else 'Degraded'} performance." if improvement != 0 else "Performance is similar.")


# --- 5. Evaluation by Data Type (Stratified Analysis) ---
print("\n" + "="*50)
print("             EVALUATION BY DATA TYPE")
print("="*50)

for dt in DATA_TYPES:
    orig_cers = results['original']['metrics_by_type'][dt]['cers']
    ft_cers = results['finetuned']['metrics_by_type'][dt]['cers']

    if orig_cers:
        avg_orig_dt_cer = np.mean(orig_cers)
        avg_ft_dt_cer = np.mean(ft_cers)
        dt_improvement = avg_orig_dt_cer - avg_ft_dt_cer

        print(f"\nData Type: {dt.upper()} (Samples: {len(orig_cers)})")
        print(f"  - Original CER: {avg_orig_dt_cer:.4f}")
        print(f"  - Fine-tuned CER: {avg_ft_dt_cer:.4f}")
        print(f"  - Change: {dt_improvement:+.4f} ({'Improved' if dt_improvement > 0 else 'Degraded'})")


# --- 6. Error Analysis and Illustration ---
print("\n" + "="*50)
print("         ERROR ANALYSIS AND EXAMPLES")
print("="*50)

# Find examples of significant improvement and degradation
significant_improvement_samples = []
significant_degradation_samples = []

for dt in DATA_TYPES:
    for sample_result in results['finetuned']['metrics_by_type'][dt]['samples']:
        # Significant improvement (FT CER is at least 0.5 lower than Original CER)
        if sample_result['original_cer'] > 0.5 and sample_result['original_cer'] - sample_result['finetuned_cer'] >= 0.5:
            significant_improvement_samples.append((dt, sample_result))

        # Significant degradation (FT CER is at least 0.5 higher than Original CER)
        if sample_result['finetuned_cer'] > 0.5 and sample_result['finetuned_cer'] - sample_result['original_cer'] >= 0.5:
            significant_degradation_samples.append((dt, sample_result))

# Print illustrative examples
def print_sample_details(title, samples):
    print(f"\n--- {title} (Total: {len(samples)} samples) ---")
    if not samples:
        print("No notable examples found.")
        return

    # Select the first 3 examples (or randomly)
    selected_samples = samples[:3]

    for dt, sample in selected_samples:
        print(f"\nSample #{sample['index']} ({dt.upper()}) | Image: {os.path.basename(sample['image_path'])}")
        print(f"  - Ground Truth (GT): {sample['ground_truth']}")
        print(f"  - Original (CER {sample['original_cer']:.4f}): {sample['original_prediction']}")
        print(f"  - Fine-tuned (CER {sample['finetuned_cer']:.4f}): {sample['finetuned_prediction']}")

        # Error analysis notes
        if sample['original_cer'] > 0 and sample['finetuned_cer'] == 0:
            print("  >>> Analysis: FT completely resolved the error.")
        elif sample['original_cer'] == 0 and sample['finetuned_cer'] > 0:
            print("  >>> Analysis: FT introduced an error (regression).")
        elif sample['original_cer'] > sample['finetuned_cer']:
            print("  >>> Analysis: FT improved accuracy, reducing errors.")
        elif sample['finetuned_cer'] > sample['original_cer']:
            print("  >>> Analysis: FT degraded accuracy compared to Original.")


print_sample_details("Significant Improvement Examples (FT > Original)", significant_improvement_samples)
print_sample_details("Significant Degradation Examples (FT < Original)", significant_degradation_samples)


# --- 7. Cleanup ---
# Note: If you want to keep the image files to view the error analysis, skip this section
# for sample in test_samples_for_inference:
#     if os.path.exists(sample['image_path']):
#         os.remove(sample['image_path'])
# print("\nCleaned up temporary image files.")

print("\nComprehensive evaluation complete.")