# Install packages
*Inspired by Unsloth's DeepSeek-OCR fine-tuning notebook.*

In [None]:
# Install the notebook's dependencies with the %pip magic so they land in the
# environment of the running kernel (a bare `!pip` may target another Python).
%pip install transformers==4.56.2
# trl is installed with --no-deps, presumably so that it cannot replace the
# transformers release pinned above.
%pip install --no-deps trl==0.22.2
# NOTE(review): the packages below are unpinned; pin them once known-good
# versions are confirmed so that Restart & Run All stays reproducible.
%pip install jiwer
%pip install einops addict easydict
%pip install verovio
%pip install faker
%pip install peft

### HF Login for Kaggle

In [None]:
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
# Read the Hugging Face token from the Kaggle secret named "HF_TOKEN" (so the
# token never appears in the notebook source) and pass it to huggingface_hub.
user_secrets = UserSecretsClient()
token = user_secrets.get_secret("HF_TOKEN")
login(token=token)

### Hotfix (Transformers >= 4.46 compatibility)

In [None]:
import transformers.models.llama.modeling_llama

# If this transformers build has no `LlamaFlashAttention2`, expose
# `LlamaAttention` under that name so later lookups of the old name succeed.
_llama_modeling = transformers.models.llama.modeling_llama
if not hasattr(_llama_modeling, "LlamaFlashAttention2"):
    print(">>> Monkeypatching LlamaFlashAttention2 for DeepSeek-OCR compatibility...")
    setattr(_llama_modeling, "LlamaFlashAttention2", _llama_modeling.LlamaAttention)

# Imports

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    PreTrainedModel
)
from transformers.modeling_outputs import CausalLMOutputWithPast
import torchvision.transforms as T
import warnings
import gc
from PIL import Image, ImageDraw, ImageFont
import os
import json
import random
from faker import Faker
from peft import get_peft_model, LoraConfig
import types

# Configure PyTorch's CUDA allocator through its environment variable.
# NOTE(review): presumably meant to limit fragmentation-related OOMs; it is set
# right after the imports and before any model is loaded — confirm nothing has
# initialised CUDA earlier in the session.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Projector MLP

In [None]:
class DeepSeekOCRToGOTProjector(nn.Module):
    def __init__(self, encoder_dim, decoder_dim=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(encoder_dim, decoder_dim * 2),
            nn.GELU(),
            nn.Linear(decoder_dim * 2, decoder_dim),
            nn.LayerNorm(decoder_dim)
        )

    def forward(self, x):
        return self.net(x)

# Fusion Model

In [None]:
# NOTE(review): this redefines the DeepSeekOCRToGOTProjector that the
# "Projector MLP" cell above already defines with the same signature and layers;
# running this cell simply rebinds the name. Keep a single definition so the two
# copies cannot drift apart.
class DeepSeekOCRToGOTProjector(nn.Module):
    # MLP: Linear(encoder_dim -> 2 * decoder_dim) -> GELU ->
    # Linear(2 * decoder_dim -> decoder_dim) -> LayerNorm(decoder_dim).
    def __init__(self, encoder_dim, decoder_dim=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(encoder_dim, decoder_dim * 2),
            nn.GELU(),
            nn.Linear(decoder_dim * 2, decoder_dim),
            nn.LayerNorm(decoder_dim)
        )

    def forward(self, x):
        # Apply the MLP along the last dimension of x.
        return self.net(x)

class DeepSeekGOTFusion(nn.Module):
    def __init__(self, deepseek_ocr_path, got_path, tokenizer, use_lora=True):
        super().__init__()
        warnings.filterwarnings("ignore")
        frozen_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        self.vision_dtype = frozen_dtype

        # --- 1. Load DeepSeek Encoder ---
        print(f">>> Loading DeepSeek-OCR Encoder...")
        tmp_ds = AutoModel.from_pretrained(
            deepseek_ocr_path,
            trust_remote_code=True,
            # device_map="cpu",
            torch_dtype=frozen_dtype,
            # low_cpu_mem_usage=True
        )
        if hasattr(tmp_ds, "model"):
            base_model = tmp_ds.model
        else:
            base_model = tmp_ds

        if hasattr(base_model, "deep_encoder"):
            self.vision_tower = base_model.deep_encoder
        elif hasattr(base_model, "vision_model"):
            self.vision_tower = base_model.vision_model
        else:
            self.vision_tower = base_model

        self.vision_dim = 1024
        if hasattr(self.vision_tower, "config"):
            self.vision_dim = getattr(self.vision_tower.config, "hidden_size", 1024)
        if self.vision_dim == 1280:
            self.vision_dim = 1024
        
        # self.vision_tower = self.vision_tower.to("cuda")
        del tmp_ds
        gc.collect()

        # --- 2. Load GOT-OCR Decoder ---
        print(f">>> Loading GOT-OCR Decoder (AutoModel)...")
        tmp_got = AutoModel.from_pretrained(
            got_path,
            trust_remote_code=True,
            # device_map="cpu",
            torch_dtype=frozen_dtype,
            # low_cpu_mem_usage=True
        )

        if hasattr(tmp_got, "language_model"):
            self.decoder = tmp_got.language_model
        else:
            self.decoder = tmp_got.model

        self.decoder_dim = self.decoder.config.hidden_size
        # self.decoder = self.decoder.to("cuda")

        # --- 3. Extract LM Head ---
        self.lm_head = None
        if hasattr(tmp_got, "lm_head"):
            self.lm_head = tmp_got.lm_head
        elif hasattr(self.decoder, "lm_head"):
            self.lm_head = self.decoder.lm_head
        
        if self.lm_head is None:
            vocab_size = self.decoder.config.vocab_size
            self.lm_head = nn.Linear(self.decoder_dim, vocab_size, bias=False)
            if hasattr(self.decoder, "embed_tokens"):
                self.lm_head.weight = self.decoder.embed_tokens.weight
            elif hasattr(self.decoder, "wte"):
                self.lm_head.weight = self.decoder.wte.weight
        
        self.lm_head = self.lm_head.to(dtype=frozen_dtype) #, device="cuda")
        del tmp_got
        gc.collect()
        torch.cuda.empty_cache()

        # --- 4. ROBUST MONKEY PATCH ---
        original_forward = self.decoder.forward
        
        def patched_forward(input_ids=None, past_key_values=None, attention_mask=None, 
                            token_type_ids=None, position_ids=None, head_mask=None, 
                            inputs_embeds=None, encoder_hidden_states=None, 
                            encoder_attention_mask=None, use_cache=None, 
                            output_attentions=None, output_hidden_states=None, 
                            return_dict=None, labels=None, **kwargs):
            
            # STRICT FILTERING: Only pass arguments that GOTQwenModel definitely accepts.
            # We explicitly DROP 'token_type_ids', 'labels', 'head_mask', etc.
            
            return original_forward(
                input_ids=input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                # token_type_ids=token_type_ids,  <-- REMOVED (The cause of your error)
                position_ids=position_ids,
                # head_mask=head_mask,            <-- REMOVED
                inputs_embeds=inputs_embeds,
                # encoder_hidden_states=...       <-- REMOVED
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict
            )
        
        self.decoder.forward = patched_forward

        # Patch prepare_inputs to also exclude token_type_ids
        if not hasattr(self.decoder, "prepare_inputs_for_generation"):
            def _prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
                return {
                    "input_ids": input_ids,
                    "past_key_values": past_key_values,
                    "attention_mask": attention_mask,
                    "inputs_embeds": kwargs.get("inputs_embeds", None)
                    # No token_type_ids here either
                }
            self.decoder.prepare_inputs_for_generation = types.MethodType(_prepare_inputs_for_generation, self.decoder)

        # --- 5. LoRA Injection ---
        self.decoder.requires_grad_(False) 
        if use_lora:
            print(">>> Injecting LoRA Adapters...")
            lora_config = LoraConfig(
                r=16, 
                lora_alpha=32, 
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], 
                lora_dropout=0.05, 
                bias="none", 
                task_type="CAUSAL_LM"
            )
            self.decoder = get_peft_model(self.decoder, lora_config)
            self.decoder.print_trainable_parameters() 

        # --- 6. Projector ---
        self.projector = DeepSeekOCRToGOTProjector(self.vision_dim, self.decoder_dim)
        self.projector.apply(self._init_weights)
        
        self.img_start_id = tokenizer.convert_tokens_to_ids("<img>") 
        self.img_end_id = tokenizer.convert_tokens_to_ids("</img>") 

        self.vision_tower.requires_grad_(False)
        self.lm_head.requires_grad_(False)
        self.projector.requires_grad_(True) 
        # self.projector = self.projector.to(device="cuda", dtype=torch.float32)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        if hasattr(self.decoder, "gradient_checkpointing_enable"):
            self.decoder.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def save_pretrained(self, save_directory):
        if not os.path.exists(save_directory): os.makedirs(save_directory)
        torch.save(self.projector.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        if hasattr(self.decoder, "save_pretrained"):
             self.decoder.save_pretrained(save_directory)
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump({
                "vision_dim": self.vision_dim, 
                "decoder_dim": self.decoder_dim,
                "architecture": "DeepSeekOCRToGOTProjector"
            }, f, indent=4)

    def forward(self, pixel_values, input_ids, attention_mask=None, labels=None, **kwargs):
        # 1. Vision Encode
        with torch.no_grad():
            try:
                vision_out = self.vision_tower(pixel_values.to(self.vision_dtype), patch_embeds=None)
            except TypeError:
                vision_out = self.vision_tower(pixel_values.to(self.vision_dtype))
            
            if isinstance(vision_out, (tuple, list)):
                features = vision_out[0]
            elif hasattr(vision_out, "last_hidden_state"):
                features = vision_out.last_hidden_state
            else:
                features = vision_out
            features = features.detach()

        # 2. Project
        vision_embeds = self.projector(features.to(torch.float32)).to(self.vision_dtype)
        B, N, _ = vision_embeds.shape
        device = vision_embeds.device

        # 3. Text Embeds
        # Get embeddings from the underlying model (bypass LoRA if needed to find the embedding layer)
        base_decoder = self.decoder.get_base_model() if hasattr(self.decoder, "get_base_model") else self.decoder
        if hasattr(base_decoder, "get_input_embeddings"):
            input_emb_fn = base_decoder.get_input_embeddings()
        elif hasattr(base_decoder, "model") and hasattr(base_decoder.model, "embed_tokens"):
             input_emb_fn = base_decoder.model.embed_tokens
        else:
             input_emb_fn = base_decoder.embed_tokens

        start_embeds = input_emb_fn(torch.tensor([self.img_start_id], device=device)).expand(B, 1, -1)
        end_embeds = input_emb_fn(torch.tensor([self.img_end_id], device=device)).expand(B, 1, -1)
        text_embeds = input_emb_fn(input_ids)

        inputs_embeds = torch.cat([start_embeds, vision_embeds, end_embeds, text_embeds], dim=1)
        
        vision_len = N + 2
        full_mask = None
        if attention_mask is not None:
            v_mask = torch.ones((B, vision_len), device=device, dtype=attention_mask.dtype)
            full_mask = torch.cat([v_mask, attention_mask], dim=1)

        # 4. Decoder Call
        # We pass 'labels' to self.decoder. The monkey patch above will capture it (satisfying LoRA/Trainer)
        # but NOT pass it to the actual GOTQwenModel (preventing the crash).
        outputs = self.decoder(
            input_ids=torch.zeros((B, inputs_embeds.shape[1]), device=device, dtype=torch.long),
            inputs_embeds=inputs_embeds,
            attention_mask=full_mask,
            labels=labels, 
            return_dict=True,
            use_cache=False 
        )
        
        hidden_states = outputs.last_hidden_state
        relevant_hidden = hidden_states[:, vision_len-1:-1, :].contiguous()
        
        if self.lm_head is not None:
            logits = self.lm_head(relevant_hidden)
        else:
            raise ValueError("LM Head is not defined.")

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, self.decoder.config.vocab_size).to(torch.float32), labels.view(-1))

        return CausalLMOutputWithPast(loss=loss, logits=logits if not self.training else None)

# Dataset

In [None]:
class RealTextOCRDataset(Dataset):
    """Synthetic OCR dataset: random Faker text rendered onto a 1024x1024 image.

    Every ``__getitem__`` call generates a fresh random sample; the index is
    not used, so the same index yields different samples on different calls.

    Args:
        tokenizer: Tokenizer used to encode ``"OCR: <text><eos>"``; it must
            return ``input_ids`` and ``attention_mask``.
        num_samples: Value reported by ``__len__`` (default 20000).
    """

    def __init__(self, tokenizer, num_samples=20000):
        self.tokenizer = tokenizer
        self.num_samples = num_samples
        self.fake = Faker('en_US')  # English generator

        # Resize to 1024x1024 (bicubic), convert to a tensor, normalise per channel.
        self.transform = T.Compose([
            T.Resize((1024, 1024), interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                        std=[0.26862954, 0.26130258, 0.27577711])
        ])

    def __len__(self):
        """Return the nominal dataset size given at construction."""
        return self.num_samples

    def generate_content(self):
        """Return one random text: sentence (40%), address (30%) or name + phone (30%)."""
        r = random.random()
        if r < 0.4:
            # Type 1: plain sentence
            return self.fake.sentence(nb_words=10)
        elif r < 0.7:
            # Type 2: address, flattened onto a single line
            return self.fake.address().replace('\n', ', ')
        else:
            # Type 3: name and phone number
            return f"{self.fake.name()} - {self.fake.phone_number()}"

    def generate_image(self, text):
        """Render ``text`` in black on a light-grey 1024x1024 RGB image with naive word wrap."""
        # 1. Random near-white background
        bg_color = random.randint(230, 255)
        img = Image.new('RGB', (1024, 1024), color=(bg_color, bg_color, bg_color))
        draw = ImageDraw.Draw(img)

        # 2. Font: Liberation Sans at a random size, default bitmap font as fallback.
        # font_size is drawn before the try so it is always defined for the
        # wrapping arithmetic below, whichever font ends up being used.
        font_size = random.randint(40, 80)
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf", font_size)
        except OSError:
            font = ImageFont.load_default()

        # 3. Random top-left anchor for the first line
        x = random.randint(50, 100)
        y = random.randint(200, 500)

        # Word wrap: the line width is estimated as characters * font_size / 2
        # and a new line is started once the estimate exceeds 800 pixels.
        current_line = ""
        for word in text.split():
            if (len(current_line) + len(word)) * (font_size * 0.5) > 800:
                draw.text((x, y), current_line, fill=(0, 0, 0), font=font)
                y += font_size + 10
                current_line = word + " "
            else:
                current_line += word + " "

        # Draw the last (possibly only) line
        draw.text((x, y), current_line, fill=(0, 0, 0), font=font)

        return img

    def __getitem__(self, idx):
        """Generate one sample.

        Returns:
            dict with ``pixel_values`` (transformed image), ``input_ids`` and
            ``attention_mask`` (length 512, padded) and ``labels``: a copy of
            ``input_ids`` where the ``"OCR: "`` prefix and all padded positions
            are set to -100.
        """
        text = self.generate_content()
        image = self.generate_image(text)

        pixel_values = self.transform(image)
        prompt = f"OCR: {text}{self.tokenizer.eos_token}"

        prefix_ids = list(self.tokenizer("OCR: ", add_special_tokens=False).input_ids)

        encodings = self.tokenizer(
            prompt,
            max_length=512,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )

        input_ids = encodings.input_ids.squeeze(0)
        attention_mask = encodings.attention_mask.squeeze(0)
        labels = input_ids.clone()

        # Skip a leading BOS token if the tokenizer adds one.
        bos_id = self.tokenizer.bos_token_id
        offset = 1 if (bos_id is not None and int(input_ids[0]) == bos_id) else 0

        # Mask only the tokens that really belong to the prefix. Tokenising
        # "OCR: " on its own can yield a trailing-space token that, inside the
        # full prompt, is merged into the first word of the target; counting
        # the stand-alone length would then hide that first target token.
        prefix_len = 0
        for expected_id, actual_id in zip(prefix_ids, input_ids[offset:].tolist()):
            if expected_id != actual_id:
                break
            prefix_len += 1
        labels[:offset + prefix_len] = -100

        # Mask padding by position rather than by token id: when the pad token
        # is the EOS token, an id-based mask would also hide the real EOS and
        # the model would never be trained to stop.
        labels[attention_mask == 0] = -100

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

# Training configuration

In [None]:
# Model ids of the two checkpoints combined by DeepSeekGOTFusion, and the
# directory used both as the Trainer output_dir and for the final save.
DEEPSEEK_OCR_PATH = "deepseek-ai/DeepSeek-OCR"
GOT_PATH = "stepfun-ai/GOT-OCR2_0"
OUTPUT_DIR = "./deepseek_ocr_got_final"

In [None]:
print(">>> 1. Loading Tokenizer...")
# The tokenizer is loaded from the decoder checkpoint (GOT_PATH).
tokenizer = AutoTokenizer.from_pretrained(GOT_PATH, trust_remote_code=True)
# Fall back to EOS as the padding token when the tokenizer defines none.
# NOTE(review): in that case pad_token_id equals eos_token_id, so any label
# masking keyed on pad_token_id would also mask the EOS token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(">>> 2. Initializing RealTextOCRDataset...")
# Samples are generated on the fly; num_samples only sets the reported length.
train_dataset = RealTextOCRDataset(tokenizer, num_samples=20000)

def collate_fn(batch):
    """Stack the per-sample tensors of ``batch`` into one tensor per field.

    Args:
        batch: Sequence of sample dicts, each holding tensors under the keys
            ``pixel_values``, ``input_ids``, ``labels`` and ``attention_mask``.

    Returns:
        dict with the same four keys; every value gains a leading batch axis.
    """
    field_order = ("pixel_values", "input_ids", "labels", "attention_mask")
    collated = {}
    for field in field_order:
        collated[field] = torch.stack([sample[field] for sample in batch])
    return collated
    
print(">>> 3. Initializing Fusion Model...")
# use_lora is left at its default (True), so the decoder gets LoRA adapters.
model = DeepSeekGOTFusion(
    DEEPSEEK_OCR_PATH, 
    GOT_PATH, 
    tokenizer, 
)

In [None]:
# Number of optimizer steps. With batch size 1 and no gradient accumulation
# this consumes 63 generated samples out of the 20000 the dataset reports.
MAX_STEPS = 63

# NOTE(review): fp16=True is fixed here, while DeepSeekGOTFusion stores its
# frozen weights in bfloat16 whenever the GPU reports bf16 support - confirm
# that this combination is intended.
# With save_steps=50 and 63 steps, exactly one intermediate checkpoint (step 50)
# is written by the Trainer.
# remove_unused_columns=False keeps "pixel_values" in the batch that reaches
# the model's forward.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    max_steps=MAX_STEPS,   # Train for a fixed step budget (the dataset is map-style, not streaming)
    fp16=True,
    gradient_checkpointing=True,
    logging_steps=1,
    save_strategy="steps",
    save_steps=50,      
    remove_unused_columns=False,
    report_to="none",
    save_safetensors=False,
    dataloader_pin_memory=False,
    prediction_loss_only=True,
    max_grad_norm=0.5,
    lr_scheduler_type="cosine",
    warmup_steps=5,
    ddp_find_unused_parameters=False,
)

# No eval dataset is given: this Trainer is used for training only.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collate_fn,
)

# Train

In [None]:
print(">>> 4. Starting Training...")

# Count only parameters with requires_grad=True (the projector plus, when LoRA
# is enabled, the adapter weights).
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f">>> Trainable Parameters: {trainable_params:,}")

trainer.train()

# DeepSeekGOTFusion.save_pretrained writes the projector weights and its
# config.json, and also calls the decoder's own save_pretrained when it has one.
print(f">>> 5. Saving Projector to {OUTPUT_DIR}...")
model.save_pretrained(OUTPUT_DIR)
print(">>> Done.")