# Installation

In [None]:
%%capture
# Install Unsloth and the libraries this notebook uses; %%capture hides the pip output.
import os, re
if "COLAB_" not in "".join(os.environ.keys()):
    # No COLAB_* environment variable found: plain install.
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # Pick the xformers build from the installed torch "major.minor" version.
    import torch; v = re.match(r"[0-9]{1,}\.[0-9]{1,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.33.post1" if v=="2.9" else "0.0.32.post2" if v=="2.8" else "0.0.29.post3")
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
    
# The installs below run in both branches and pin transformers / trl.
!pip install transformers==4.56.2
!pip install --no-deps trl==0.22.2
# jiwer provides the CER metric used in the evaluation section.
!pip install jiwer
# Presumably required by the DeepSeek-OCR remote modeling code -- TODO confirm.
!pip install einops addict easydict

In [None]:
import io
import json
import math
import os
import random
from dataclasses import dataclass
from typing import Dict, List, Any, Tuple

# Third-party. unsloth stays ahead of torch/transformers, as in the original cell order.
from huggingface_hub import snapshot_download
from unsloth import FastVisionModel, is_bf16_supported
import jiwer
import torch
from datasets import Dataset
from PIL import Image, ImageOps
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModel, Trainer, TrainingArguments

# Unsloth
Prepare OCR model

In [None]:
# Fetch the unsloth/DeepSeek-OCR repository into ./deepseek_ocr (loaded from there below).
snapshot_download(repo_id="unsloth/DeepSeek-OCR", local_dir="deepseek_ocr")

In [None]:
# Presumably silences Unsloth's "uninitialized weights" warning -- TODO confirm.
os.environ["UNSLOTH_WARN_UNINITIALIZED"] = "0"

# Reference list of pre-quantized 4-bit checkpoints (faster download, lower memory).
# It is not used by the load call below, which reads the local DeepSeek-OCR snapshot.
fourbit_models = [
    "unsloth/Qwen3-VL-8B-Instruct-bnb-4bit",  # Qwen 3 vision support
    "unsloth/Qwen3-VL-8B-Thinking-bnb-4bit",
    "unsloth/Qwen3-VL-32B-Instruct-bnb-4bit",
    "unsloth/Qwen3-VL-32B-Thinking-bnb-4bit",
]

# Load the locally downloaded checkpoint in 16-bit (load_in_4bit=False) for LoRA finetuning.
model, tokenizer = FastVisionModel.from_pretrained(
    "./deepseek_ocr",
    auto_model=AutoModel,
    trust_remote_code=True,
    load_in_4bit=False,                    # True would load in 4-bit to reduce memory use
    use_gradient_checkpointing="unsloth",  # True or "unsloth" for long context
    unsloth_force_compile=True,
)

# Using Baseline Model

In [None]:
import os
import json
import shutil
from tqdm import tqdm

# Test split root folder for each granularity of the UIT-HWDB dataset.
DATASETS_CONFIG_TEST = {
    "word": "/kaggle/input/uit-hwdb/UIT_HWDB_word/UIT_HWDB_word/test_data",
    "line": "/kaggle/input/uit-hwdb/UIT_HWDB_line/UIT_HWDB_line/test_data",
    "paragraph": "/kaggle/input/uit-hwdb/UIT_HWDB_paragraph/UIT_HWDB_paragraph/test_data",
}

# Train split root folder and the instruction prompt paired with each granularity.
DATASET_CONFIG_TRAIN = {
    "word": {
        "root_path": "/kaggle/input/uit-hwdb/UIT_HWDB_word/UIT_HWDB_word/train_data",
        "prompt": "<image>\nOCR Word.",
    },
    "line": {
        "root_path": "/kaggle/input/uit-hwdb/UIT_HWDB_line/UIT_HWDB_line/train_data",
        "prompt": "<image>\nOCR Line.",
    },
    "paragraph": {
        "root_path": "/kaggle/input/uit-hwdb/UIT_HWDB_paragraph/UIT_HWDB_paragraph/train_data",
        "prompt": "<image>\nFree OCR.",
    },
}

In [None]:
# Baseline inference: run the un-finetuned model over the test images and collect
# the recognised text in `all_results`, keyed by "<dataset_type>_<path relative to root>".
output_working_dir = '/kaggle/working/temp_ocr_process/'
final_json_path = '/kaggle/working/merged_labels.json'

# Number of images to run per dataset (smoke-test size).
# Set to None to process every image found.
MAX_IMAGES_PER_DATASET = 5

all_results = {}
prompt = "<image>\nFree OCR."

for dataset_type, root_dir in DATASETS_CONFIG_TEST.items():
    print(f"\nPROCESSING: {dataset_type.upper()}")

    # Collect (absolute path, path relative to the dataset root) for every image.
    file_list = []
    for root, dirs, files in os.walk(root_dir):
        for file in files:
            # Lower-case first so that ".JPG" / ".PNG" files are not skipped.
            if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                full_path = os.path.join(root, file)
                relative_path = os.path.relpath(full_path, root_dir)
                file_list.append((full_path, relative_path))

    # os.walk order is filesystem dependent; sort so the selected subset is the same on every run.
    file_list.sort(key=lambda entry: entry[1])

    print(f"Found {len(file_list)} pícs")

    # Slicing with None keeps the whole list, so the cap is optional.
    selected_files = file_list[:MAX_IMAGES_PER_DATASET]

    for i, (image_path, relative_name) in enumerate(tqdm(selected_files)):
        # The key records which dataset the image belongs to.
        unique_key = f"{dataset_type}_{relative_name}"

        # Per-image scratch folder that model.infer writes its result files into.
        spec_output_path = os.path.join(output_working_dir, f"{dataset_type}_{i}")
        os.makedirs(spec_output_path, exist_ok=True)

        try:
            # Model inference; results are saved as files in spec_output_path.
            model.infer(
                tokenizer,
                prompt = prompt,
                image_file = image_path,
                output_path = spec_output_path,
                base_size = 1024,
                image_size = 640,
                crop_mode = True,
                save_results = True,
                test_compress = False,
            )

            # Read the first text result (.mmd / .txt); sorted so the choice is deterministic.
            content = ""
            for filename in sorted(os.listdir(spec_output_path)):
                if filename.endswith(('.mmd', '.txt')):
                    with open(os.path.join(spec_output_path, filename), 'r', encoding='utf-8') as f:
                        content = f.read().strip()
                    break

            # Only non-empty predictions are stored.
            if content:
                all_results[unique_key] = content

        except Exception as e:
            # Best effort: report the failure and move on to the next image.
            print(f"Lỗi: {e}")
        finally:
            # Always remove the scratch folder, even when inference failed.
            if os.path.exists(spec_output_path):
                shutil.rmtree(spec_output_path)

In [None]:
# Write all collected predictions to a single UTF-8 JSON file
# (ensure_ascii=False keeps non-ASCII characters readable in the file).
with open(final_json_path, 'w', encoding='utf-8') as f:
    json.dump(all_results, f, ensure_ascii=False, indent=2)

### **CER Metric**

In [None]:
def evaluate_CER(merged_result_path):
    """
    Compute the average character error rate (CER) per test dataset.

    Predictions are loaded from a JSON file mapping
    "<dataset_type>_<image path relative to the dataset root>" to predicted text.
    Ground truth comes from every `label.json` found under the roots listed in
    DATASETS_CONFIG_TEST (a mapping of image file name -> text).

    Args:
        merged_result_path: Path to the JSON file of predictions.

    Returns:
        Dict mapping dataset type to the mean of the per-sample CER values.
        Dataset types with no scored sample are omitted.

    Raises:
        OSError / ValueError: if the predictions file cannot be read or parsed.
    """
    with open(merged_result_path, 'r', encoding='utf-8') as f:
        all_predictions = json.load(f)

    # Average CER per dataset type.
    report = {}

    # Iterate over word, line, paragraph.
    for dataset_type, root_dir in DATASETS_CONFIG_TEST.items():
        print(f"\nScoring dataset: {dataset_type.upper()}")

        total_cer = 0
        count = 0

        # Visit every folder that carries a label file.
        for root, dirs, files in os.walk(root_dir):
            if 'label.json' not in files:
                continue

            label_path = os.path.join(root, 'label.json')

            # Load the ground-truth labels; an unreadable label file is reported and skipped.
            try:
                with open(label_path, 'r', encoding='utf-8') as label_file:
                    local_labels = json.load(label_file)
            except (OSError, ValueError) as e:
                print(f"Lỗi đọc file label tại {label_path}: {e}")
                continue

            for img_name, ground_truth in local_labels.items():
                # Build the key from the path relative to the dataset root, exactly as the
                # inference cell does, so it also matches when label.json sits directly in
                # the root or more than one folder below it.
                relative_path = os.path.relpath(os.path.join(root, img_name), root_dir)
                unique_key = f"{dataset_type}_{relative_path}"

                # Images without a stored prediction are not scored.
                prediction = all_predictions.get(unique_key)
                if prediction is None:
                    continue

                gt_norm = str(ground_truth).strip()
                pred_norm = str(prediction).strip()

                # An empty reference is skipped rather than scored.
                if not gt_norm:
                    continue

                cer = jiwer.cer(gt_norm, pred_norm)
                total_cer += cer
                count += 1

                # Print an occasional badly recognised sample (at most one per 100 scored).
                if cer > 0.5 and count % 100 == 0:
                    print(f" {unique_key} | CER: {cer:.2f}")
                    print(f" GT  : {gt_norm}")
                    print(f" Pred: {pred_norm}")

        # Summary for this dataset type.
        if count > 0:
            report[dataset_type] = total_cer / count

    return report

In [None]:
# Score the baseline predictions and print one line per dataset type.
merged_result_path = '/kaggle/working/merged_labels.json'

report = evaluate_CER(merged_result_path)

divider = "-"*30
print("\n" + divider)
print("FINAL CER REPORT")
print(divider)
for dtype in report:
    score = report[dtype]
    print(f"{dtype:<10}: {score:.4f} ({score*100:.2f}%)")

# Finetune Deepseek-OCR

We now add LoRA adapters for parameter-efficient finetuning - this allows us to efficiently train only about 1% of all parameters.

**[NEW]** We also support finetuning ONLY the vision part of the model, or ONLY the language part. Or you can select both! You can also select to finetune the attention or the MLP layers!

In [None]:
# Wrap the model with LoRA adapters on the attention and MLP projection layers.
model = FastVisionModel.get_peft_model(
    model,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
        "gate_proj", "up_proj", "down_proj",     # MLP projections
    ],
    # target_modules = "all-linear", # Optional now! Can specify a list if needed
    r=16,              # The larger, the higher the accuracy, but might overfit
    lora_alpha=16,     # Recommended alpha == r at least
    lora_dropout=0,
    bias="none",
    random_state=3407,
    use_rslora=False,  # We support rank stabilized LoRA
    loftq_config=None, # And LoftQ
)

## Data Prep

To format the dataset, all vision finetuning tasks should be formatted as follows:

```python
[
{ "role": "<|User|>",
  "content": "",
  "images": []
},
{ "role": "<|Assistant|>",
  "content": ""
},
]
```

### Creating formatted dataset

In [None]:
def create_conversation(image_path, text_label, instruction):
    """
    Build one two-turn training sample.

    Args:
        image_path: Image reference stored in the user turn's "images" list.
        text_label: Target text placed in the assistant turn.
        instruction: Prompt text placed in the user turn.

    Returns:
        {"messages": [user_turn, assistant_turn]}
    """
    user_turn = {
        "role": "<|User|>",
        "content": instruction,
        "images": [image_path],
    }
    assistant_turn = {
        "role": "<|Assistant|>",
        "content": text_label,
    }
    return {"messages": [user_turn, assistant_turn]}

In [None]:
# Build the full list of training conversations from every label.json under the
# train roots, and count how many samples each dataset type contributes.
OUTPUT_FILE = '/kaggle/working/train_dataset.json'

final_dataset = []
stats = {k: 0 for k in DATASET_CONFIG_TRAIN.keys()}

for data_type, config in DATASET_CONFIG_TRAIN.items():
    root_dir = config["root_path"]
    prompt = config["prompt"]

    for root, dirs, files in os.walk(root_dir):
        # Sorting dirs in place fixes the traversal order, so the sample order
        # (and the seeded sampling done on it later) is reproducible.
        dirs.sort()

        if 'label.json' not in files:
            continue

        label_path = os.path.join(root, 'label.json')

        # A folder whose label file cannot be read or parsed is reported and skipped.
        try:
            with open(label_path, 'r', encoding = 'utf-8') as label_file:
                label_data = json.load(label_file)
        except (OSError, ValueError) as e:
            print(f"Error: Can not open file {label_path}: {e}")
            continue

        for img_name in sorted(files):
            # Lower-case first so that ".JPG" / ".PNG" files are not skipped.
            if not img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            # Images without a label entry are ignored.
            if img_name not in label_data:
                continue

            full_img_path = os.path.join(root, img_name)
            text_context = label_data[img_name]

            sample = create_conversation(full_img_path, text_context, prompt)
            final_dataset.append(sample)

            stats[data_type] += 1

In [None]:
# Print the number of collected samples per dataset type and the grand total.
banner = "="*30
print("\n" + banner)
print("REPORT")
print(banner)
for dtype, count in stats.items():
    print(f"- {dtype.upper():<10}: {count} mẫu")
total_count = sum(stats.values())
print("-" * 30)
print(f"TỔNG CỘNG : {total_count} mẫu")

In [None]:
# Persist every collected conversation as UTF-8 JSON (non-ASCII text kept as-is).
with open(OUTPUT_FILE, 'w', encoding='utf-8') as dataset_file:
    json.dump(final_dataset, dataset_file, ensure_ascii=False, indent=2)

In [None]:
# Turn the JSON file into a Dataset, sampling word / line / paragraph items
# according to the target ratios below.
TARGET_TOTAL_SAMPLES = 10000

# Seed for the sampling and shuffling below; same value as the LoRA random_state
# and the TrainingArguments seed used in this notebook.
SEED = 3407

RATIOS = {
    "word": 0.2,
    "line": 0.7,
    "paragraph": 0.1
}

with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
    all_data = json.load(f)

# --- 1. GROUP SAMPLES BY TYPE ---
buckets = {
    "word": [],
    "line": [],
    "paragraph": []
}

for item in all_data:
    try:
        # The image path of the user turn tells us which dataset the sample came from.
        image_path = item['messages'][0]['images'][0]
        path_lower = image_path.lower()
    except (KeyError, IndexError, TypeError, AttributeError) as e:
        # Malformed sample: report it and move on.
        print(f"Lỗi mẫu dữ liệu: {e}")
        continue

    # First keyword found in the path wins (checked in this order);
    # anything unrecognised is provisionally treated as a paragraph.
    for keyword in ("word", "line", "paragraph"):
        if keyword in path_lower:
            buckets[keyword].append(item)
            break
    else:
        buckets["paragraph"].append(item)

for k, v in buckets.items():
    print(f"- {k.upper()}: {len(v)} mẫu")

# --- 2. SAMPLE ACCORDING TO THE RATIOS ---
final_train_list = []
random.seed(SEED)

for dtype, ratio in RATIOS.items():
    # Number of samples wanted for this type.
    n_needed = int(TARGET_TOTAL_SAMPLES * ratio)

    # Number of samples actually available.
    n_available = len(buckets[dtype])

    if n_available == 0:
        print(f"Không có dữ liệu loại {dtype}!")
        continue

    if n_available < n_needed:
        # Fewer samples than requested: take everything there is.
        print(f"{dtype}: Cần {n_needed}, chỉ có {n_available}")
        selected_items = buckets[dtype]
    else:
        # Enough samples: draw the requested amount at random.
        print(f"{dtype}: Lấy ngẫu nhiên {n_needed} mẫu từ {n_available}.")
        selected_items = random.sample(buckets[dtype], n_needed)

    final_train_list.extend(selected_items)

random.shuffle(final_train_list)  # Mix the three types together

train_dataset = Dataset.from_list(final_train_list)

## Create Datacollator

In [None]:
from deepseek_ocr.modeling_deepseekocr import (
    format_messages,
    text_encode,
    BasicImageTransform,
    dynamic_preprocess,
)

@dataclass
class DeepSeekOCRDataCollator:
    """
    Collate chat-style OCR samples into padded model inputs.

    Each feature is expected to carry a "messages" list whose entries have a
    "role", a "content" string (with one "<image>" tag per image) and, optionally,
    an "images" list. Images may be file paths, PIL images or {"bytes": ...} dicts.

    Args:
        tokenizer: Tokenizer
        model: Model (only its dtype is read, to cast image tensors)
        image_size: Size for image patches (default: 640)
        base_size: Size for global view (default: 1024)
        crop_mode: Whether to use dynamic cropping for large images
        train_on_responses_only: If True, only train on assistant responses (mask user prompts)
    """
    tokenizer: Any
    model: Any
    image_size: int = 640
    base_size: int = 1024
    crop_mode: bool = True
    image_token_id: int = 128815
    train_on_responses_only: bool = True

    def __init__(
        self,
        tokenizer,
        model,
        image_size: int = 640,
        base_size: int = 1024,
        crop_mode: bool = True,
        train_on_responses_only: bool = True,
    ):
        self.tokenizer = tokenizer
        self.model = model
        self.image_size = image_size
        self.base_size = base_size
        self.crop_mode = crop_mode
        # Token id written at every image placeholder position.
        self.image_token_id = 128815
        self.dtype = model.dtype  # Get dtype from model
        self.train_on_responses_only = train_on_responses_only

        self.image_transform = BasicImageTransform(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
            normalize=True
        )
        self.patch_size = 16
        self.downsample_ratio = 4

        # Get BOS token ID from tokenizer
        if hasattr(tokenizer, 'bos_token_id') and tokenizer.bos_token_id is not None:
            self.bos_id = tokenizer.bos_token_id
        else:
            self.bos_id = 0
            print(f"Warning: tokenizer has no bos_token_id, using default: {self.bos_id}")

    def deserialize_image(self, image_data) -> "Image.Image":
        """
        Convert image data to a PIL Image in RGB mode.

        Args:
            image_data: A file path (str or os.PathLike), a PIL Image, or a dict
                with a 'bytes' entry holding the encoded image.

        Raises:
            ValueError: for any other input type.
        """
        if isinstance(image_data, (str, os.PathLike)):
            # convert() loads the pixels and returns a new image, so the file
            # can be closed as soon as the conversion is done.
            with Image.open(image_data) as opened:
                return opened.convert("RGB")

        if isinstance(image_data, Image.Image):
            return image_data.convert("RGB")
        elif isinstance(image_data, dict) and 'bytes' in image_data:
            image_bytes = image_data['bytes']
            image = Image.open(io.BytesIO(image_bytes))
            return image.convert("RGB")
        else:
            raise ValueError(f"Unsupported image format: {type(image_data)}")

    def _image_placeholder_tokens(self, crop_ratio: Tuple[int, int]) -> List[int]:
        """Build the placeholder token ids one image occupies in the input sequence."""
        num_queries = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
        num_queries_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)

        # Global view: num_queries_base rows of (num_queries_base + 1) tokens
        # (each row carries one extra token), followed by one trailing token.
        tokens = ([self.image_token_id] * num_queries_base + [self.image_token_id]) * num_queries_base
        tokens += [self.image_token_id]

        if self.crop_mode:
            width_crop_num, height_crop_num = crop_ratio
            # Local views are only present when the image was split into several crops.
            if width_crop_num > 1 or height_crop_num > 1:
                tokens += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
                    num_queries * height_crop_num)

        return tokens

    def calculate_image_token_count(self, image: "Image.Image", crop_ratio: Tuple[int, int]) -> int:
        """
        Calculate the number of tokens this image will generate.

        The count is taken from the same placeholder layout that `process_image`
        emits, so the two always agree.

        Args:
            image: Unused; kept for interface compatibility.
            crop_ratio: (width_crop_num, height_crop_num) of the image.
        """
        return len(self._image_placeholder_tokens(crop_ratio))

    def process_image(self, image: "Image.Image") -> Tuple[List, List, List, List, Tuple[int, int]]:
        """
        Process a single image based on crop_mode and size thresholds

        Returns:
            Tuple of (images_list, images_crop_list, images_spatial_crop, tokenized_image, crop_ratio)
        """
        images_list = []
        images_crop_list = []
        images_spatial_crop = []

        if self.crop_mode:
            # Determine crop ratio based on image size.
            # NOTE(review): the 640 threshold is a literal, not self.image_size -- confirm intended.
            if image.size[0] <= 640 and image.size[1] <= 640:
                crop_ratio = (1, 1)
                images_crop_raw = []
            else:
                images_crop_raw, crop_ratio = dynamic_preprocess(
                    image, min_num=2, max_num=9,
                    image_size=self.image_size, use_thumbnail=False
                )

            # Process global view with padding (pad colour = transform mean scaled to 0-255)
            global_view = ImageOps.pad(
                image, (self.base_size, self.base_size),
                color=tuple(int(x * 255) for x in self.image_transform.mean)
            )
            images_list.append(self.image_transform(global_view).to(self.dtype))

            width_crop_num, height_crop_num = crop_ratio
            images_spatial_crop.append([width_crop_num, height_crop_num])

            # Process local views (crops) if applicable
            if width_crop_num > 1 or height_crop_num > 1:
                for crop_img in images_crop_raw:
                    images_crop_list.append(
                        self.image_transform(crop_img).to(self.dtype)
                    )

        else:  # crop_mode = False
            crop_ratio = (1, 1)
            images_spatial_crop.append([1, 1])

            # For smaller base sizes, resize; for larger, pad
            if self.base_size <= 640:
                resized_image = image.resize((self.base_size, self.base_size), Image.LANCZOS)
                images_list.append(self.image_transform(resized_image).to(self.dtype))
            else:
                global_view = ImageOps.pad(
                    image, (self.base_size, self.base_size),
                    color=tuple(int(x * 255) for x in self.image_transform.mean)
                )
                images_list.append(self.image_transform(global_view).to(self.dtype))

        # Placeholder tokens for this image (global view, plus local views when cropped)
        tokenized_image = self._image_placeholder_tokens(crop_ratio)

        return images_list, images_crop_list, images_spatial_crop, tokenized_image, crop_ratio

    def process_single_sample(self, messages: List[Dict]) -> Dict[str, Any]:
        """
        Process a single conversation into model inputs.

        Returns:
            Dict with "input_ids", "images_seq_mask", "images_ori", "images_crop",
            "images_spatial_crop" and "prompt_token_count" (number of leading
            tokens that belong to the prompt).

        Raises:
            ValueError: if the sample has no image, or the number of '<image>'
                tags does not match the number of images.
        """

        # --- 1. Setup: collect every image referenced by the conversation ---
        images = []
        for message in messages:
            if "images" in message and message["images"]:
                for img_data in message["images"]:
                    if img_data is not None:
                        pil_image = self.deserialize_image(img_data)
                        images.append(pil_image)

        if not images:
            raise ValueError("No images found in sample. Please ensure all samples contain images.")

        tokenized_str = []
        images_seq_mask = []
        images_list, images_crop_list, images_spatial_crop = [], [], []

        prompt_token_count = -1 # Index to start training
        assistant_started = False
        image_idx = 0

        # Add BOS token at the very beginning
        tokenized_str.append(self.bos_id)
        images_seq_mask.append(False)

        # --- 2. Tokenize the turns, expanding each '<image>' tag into placeholders ---
        for message in messages:
            role = message["role"]
            content = message["content"]

            # Check if this is the assistant's turn
            if role == "<|Assistant|>":
                if not assistant_started:
                    # This is the split point. All tokens added *so far*
                    # are part of the prompt.
                    prompt_token_count = len(tokenized_str)
                    assistant_started = True

                # Append the EOS token string to the *end* of assistant content
                content = f"{content.strip()} {self.tokenizer.eos_token}"

            # Split this message's content by the image token
            text_splits = content.split('<image>')

            for i, text_sep in enumerate(text_splits):
                # Tokenize the text part
                tokenized_sep = text_encode(self.tokenizer, text_sep, bos=False, eos=False)
                tokenized_str.extend(tokenized_sep)
                images_seq_mask.extend([False] * len(tokenized_sep))

                # If this text is followed by an <image> tag
                if i < len(text_splits) - 1:
                    if image_idx >= len(images):
                        raise ValueError(
                            "Data mismatch: Found '<image>' token but no corresponding image."
                        )

                    # Process the image
                    image = images[image_idx]
                    img_list, crop_list, spatial_crop, tok_img, _ = self.process_image(image)

                    images_list.extend(img_list)
                    images_crop_list.extend(crop_list)
                    images_spatial_crop.extend(spatial_crop)

                    # Add image placeholder tokens
                    tokenized_str.extend(tok_img)
                    images_seq_mask.extend([True] * len(tok_img))

                    image_idx += 1 # Move to the next image

        # --- 3. Validation and Final Prep ---
        if image_idx != len(images):
            raise ValueError(
                f"Data mismatch: Found {len(images)} images but only {image_idx} '<image>' tokens were used."
            )

        # If we never found an assistant message, we're in a weird state
        # (e.g., user-only prompt). We mask everything.
        if not assistant_started:
            print("Warning: No assistant message found in sample. Masking all tokens.")
            prompt_token_count = len(tokenized_str)

        # Prepare image tensors
        images_ori = torch.stack(images_list, dim=0)
        images_spatial_crop_tensor = torch.tensor(images_spatial_crop, dtype=torch.long)

        if images_crop_list:
            images_crop = torch.stack(images_crop_list, dim=0)
        else:
            # No local crops: an all-zero placeholder tensor is passed instead.
            images_crop = torch.zeros((1, 3, self.base_size, self.base_size), dtype=self.dtype)

        return {
            "input_ids": torch.tensor(tokenized_str, dtype=torch.long),
            "images_seq_mask": torch.tensor(images_seq_mask, dtype=torch.bool),
            "images_ori": images_ori,
            "images_crop": images_crop,
            "images_spatial_crop": images_spatial_crop_tensor,
            "prompt_token_count": prompt_token_count,
        }

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Collate batch of samples.

        Samples that fail to process are reported and dropped from the batch.

        Returns:
            Dict with "input_ids", "attention_mask", "labels" (-100 on padding,
            image placeholders and, when train_on_responses_only is set, the
            prompt), "images" (list of (crops, global view) tuples),
            "images_seq_mask" and "images_spatial_crop".

        Raises:
            ValueError: if no sample in the batch could be processed.
        """
        batch_data = []

        # Process each sample
        for feature in features:
            try:
                processed = self.process_single_sample(feature['messages'])
                batch_data.append(processed)
            except Exception as e:
                print(f"Error processing sample: {e}")
                continue

        if not batch_data:
            raise ValueError("No valid samples in batch")

        # Extract lists
        input_ids_list = [item['input_ids'] for item in batch_data]
        images_seq_mask_list = [item['images_seq_mask'] for item in batch_data]
        prompt_token_counts = [item['prompt_token_count'] for item in batch_data]

        # Padded positions are identified by sequence length below, so the fill value
        # only needs to be a valid id; fall back to 0 when the tokenizer defines no pad token.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 0

        # Pad sequences
        input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=pad_id)
        images_seq_mask = pad_sequence(images_seq_mask_list, batch_first=True, padding_value=False)

        # True for real tokens, False for padding. Derived from the true lengths so that
        # a genuine token whose id equals the pad id is not mistaken for padding.
        lengths = torch.tensor([ids.size(0) for ids in input_ids_list], dtype=torch.long)
        positions = torch.arange(input_ids.size(1)).unsqueeze(0)
        is_real_token = positions < lengths.unsqueeze(1)

        # Create labels
        labels = input_ids.clone()

        # Mask padding tokens
        labels[~is_real_token] = -100

        # Mask image tokens (model shouldn't predict these)
        labels[images_seq_mask] = -100

        # Mask user prompt tokens when train_on_responses_only=True (only train on assistant responses)
        if self.train_on_responses_only:
            for idx, prompt_count in enumerate(prompt_token_counts):
                if prompt_count > 0:
                    labels[idx, :prompt_count] = -100

        # Create attention mask
        attention_mask = is_real_token.long()

        # Prepare images batch (list of tuples)
        images_batch = []
        for item in batch_data:
            images_batch.append((item['images_crop'], item['images_ori']))

        # Stack spatial crop info
        images_spatial_crop = torch.cat([item['images_spatial_crop'] for item in batch_data], dim=0)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "images": images_batch,
            "images_seq_mask": images_seq_mask,
            "images_spatial_crop": images_spatial_crop,
        }

## Train model

In [None]:
FastVisionModel.for_training(model) # Switch the model to training mode

data_collator = DeepSeekOCRDataCollator(
    tokenizer=tokenizer,
    model = model,
    image_size=640,
    base_size=1024,
    crop_mode=True,
    train_on_responses_only=True,
)

trainer = Trainer(
    model = model,
    processing_class = tokenizer,
    data_collator = data_collator,
    # The sampled dataset built in the data-prep section is named `train_dataset`.
    train_dataset = train_dataset,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,  # effective batch size per device: 2 * 4 = 8
        warmup_steps = 5,
        #max_steps = 60,
        num_train_epochs = 3, # Three passes over the training data
        learning_rate = 2e-4,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.001,
        lr_scheduler_type = "linear",
        seed = 3407,
        fp16 = not is_bf16_supported(),  # Use fp16 if bf16 is not supported
        bf16 = is_bf16_supported(),  # Use bf16 if supported
        output_dir = "outputs",
        report_to = "none",     # No external logging integration (e.g. Weights and Biases)
        dataloader_num_workers=2, # Two worker processes for data loading
        # You MUST put the below items for vision finetuning:
        remove_unused_columns = False,
    ),
)

In [None]:
# @title Show current memory stats
# Record the reserved-memory peak before training; the final-stats cell subtracts it.
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 3)  # bytes -> GB
max_memory = round(gpu_stats.total_memory / 1024**3, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

In [None]:
# Run the finetuning; the returned object's `metrics` are read by the stats cell that follows.
trainer_stats = trainer.train()

In [None]:
# @title Show final memory and time stats
# Peak reserved memory overall, and the part added since the pre-training snapshot.
train_seconds = trainer_stats.metrics['train_runtime']
used_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 3)  # bytes -> GB
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

print(f"{train_seconds} seconds used for training.")
print(f"{round(train_seconds/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

<a name="Inference"></a>
# Inference
Let's run the model!

In [None]:
# Run the finetuned model on one test image.
prompt = "<image>\nFree OCR. "
# NOTE(review): this path uses the dataset root "uit-hwdb-word", while DATASETS_CONFIG_TEST
# points at "uit-hwdb/UIT_HWDB_word/UIT_HWDB_word" -- confirm the file exists.
image_file = '/kaggle/input/uit-hwdb-word/UIT_HWDB_word/test_data/250/1.jpg'
output_path = '/kaggle/working/results/'

# Resolution presets (base_size, image_size, crop_mode):
#   Tiny:   512,  512,  False
#   Small:  640,  640,  False
#   Base:   1024, 1024, False
#   Large:  1280, 1280, False
#   Gundam: 1024, 640,  True   <- used here
res = model.infer(
    tokenizer,
    prompt=prompt,
    image_file=image_file,
    output_path=output_path,
    base_size=1024,
    image_size=640,
    crop_mode=True,
    save_results=True,
    test_compress=False,
)


<a name="Save"></a>
### Saving, loading finetuned models
To save the final model as LoRA adapters, either use Huggingface's `push_to_hub` for an online save or `save_pretrained` for a local save.

**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!

In [None]:
# Save the model and the tokenizer side by side in ./lora_model.
model.save_pretrained("lora_model")  # Local saving
tokenizer.save_pretrained("lora_model")
# model.push_to_hub("your_name/lora_model", token = "...") # Online saving
# tokenizer.push_to_hub("your_name/lora_model", token = "...") # Online saving

Now, if you want to load the LoRA adapters we just saved for inference, change `False` to `True` in the cell below:

In [None]:
# Flip the condition to True to reload the saved adapters instead of reusing the in-memory model.
if False:
    from unsloth import FastVisionModel
    model, tokenizer = FastVisionModel.from_pretrained(
        model_name="lora_model",  # YOUR MODEL YOU USED FOR TRAINING
        auto_model=AutoModel,
        trust_remote_code=True,
        load_in_4bit=False,  # Use 4bit to reduce memory use. False for 16bit LoRA.
        use_gradient_checkpointing="unsloth",  # True or "unsloth" for long context
        unsloth_force_compile=True,
    )
    FastVisionModel.for_inference(model)  # Enable for inference!

# Placeholders: replace with a real image and output directory before running.
prompt = "<image>\nFree OCR. "
image_file = 'your_image.jpg'
output_path = 'your/output/dir'

# Resolution presets (base_size, image_size, crop_mode):
#   Tiny:   512,  512,  False
#   Small:  640,  640,  False
#   Base:   1024, 1024, False
#   Large:  1280, 1280, False
#   Gundam: 1024, 640,  True   <- used here
res = model.infer(
    tokenizer,
    prompt=prompt,
    image_file=image_file,
    output_path=output_path,
    base_size=1024,
    image_size=640,
    crop_mode=True,
    save_results=True,
    test_compress=False,
)


### Saving to float16 for VLLM

We also support saving to `float16` directly. Select `merged_16bit` for float16. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.

In [None]:
# Select ONLY 1 to save! (Both not needed!)

# Save locally to 16bit
if False:
    model.save_pretrained_merged("unsloth_finetune", tokenizer)

# To export and save to your Hugging Face account
if False:
    model.push_to_hub_merged("YOUR_USERNAME/unsloth_finetune", tokenizer, token="PUT_HERE")