In [1]:
# import os, re
# if "COLAB_" not in "".join(os.environ.keys()):
#     !pip install unsloth
# else:
#     # Do this only in Colab notebooks! Otherwise use pip install unsloth
#     import torch; v = re.match(r"[0-9]{1,}\.[0-9]{1,}", str(torch.__version__)).group(0)
#     xformers = "xformers==" + ("0.0.33.post1" if v=="2.9" else "0.0.32.post2" if v=="2.8" else "0.0.29.post3")
#     !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
#     !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer
#     !pip install --no-deps unsloth
# !pip install transformers==4.56.2
# !pip install --no-deps trl==0.22.2
# !pip install jiwer
# !pip install einops addict easydict
# !pip install verovio

In [2]:
# from huggingface_hub import snapshot_download, login
# snapshot_download("unsloth/DeepSeek-OCR", local_dir = "deepseek_ocr")

In [3]:
# !pip install faker

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, IterableDataset
from transformers import (
    AutoModel,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    AutoConfig
)
from transformers.modeling_outputs import CausalLMOutputWithPast
import torchvision.transforms as T
from datasets import load_dataset
import warnings
import gc
from PIL import Image, ImageDraw, ImageFont
import os
import json
import random
from faker import Faker

# Must be set before torch is imported to take effect
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# ==============================================================================
# 0. HOTFIX FOR DEEPSEEK-OCR (Transformers >= 4.46 compatibility)
# ==============================================================================
import transformers.models.llama.modeling_llama
if not hasattr(transformers.models.llama.modeling_llama, "LlamaFlashAttention2"):
    print(">>> Monkeypatching LlamaFlashAttention2 for DeepSeek-OCR compatibility...")
    transformers.models.llama.modeling_llama.LlamaFlashAttention2 = transformers.models.llama.modeling_llama.LlamaAttention

# ==============================================================================
# 1. THE PROJECTOR (The "Bridge")
# ==============================================================================
class DeepSeekOCRToGOTProjector(nn.Module):
    def __init__(self, encoder_dim, decoder_dim=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(encoder_dim, decoder_dim * 2),
            nn.GELU(),
            nn.Linear(decoder_dim * 2, decoder_dim),
            nn.LayerNorm(decoder_dim)
        )

    def forward(self, x):
        return self.net(x)

# ==============================================================================
# 2. THE FUSION MODEL (IDENTICAL TO STAGE 1 ROBUST LOADING)
# ==============================================================================
class DeepSeekGOTFusion(nn.Module):
    def __init__(self, deepseek_ocr_path, got_path, tokenizer, projector_path=None):
        super().__init__()
        warnings.filterwarnings("ignore")
        frozen_dtype = torch.float16
        self.vision_dtype = frozen_dtype

        # --- A. LOAD DEEPSEEK-OCR ENCODER ---
        print(f">>> Loading DeepSeek-OCR Encoder...")
        tmp_ds = AutoModel.from_pretrained(
            deepseek_ocr_path,
            trust_remote_code=True,
            device_map="cpu",
            torch_dtype=frozen_dtype,
            low_cpu_mem_usage=True
        )

        if hasattr(tmp_ds, "model"):
            base_model = tmp_ds.model
        else:
            base_model = tmp_ds

        if hasattr(base_model, "deep_encoder"):
            self.vision_tower = base_model.deep_encoder
        elif hasattr(base_model, "vision_model"):
            self.vision_tower = base_model.vision_model
        else:
            self.vision_tower = base_model

        # Determine Dimension (STRICTLY 1024 for Stage 1 Compatibility)
        self.vision_dim = 1024
        if hasattr(self.vision_tower, "config"):
            self.vision_dim = getattr(self.vision_tower.config, "hidden_size", 1024)
        
        # Correct for DeepSeek-OCR specific logic if it erroneously picks up LLM dim
        if self.vision_dim == 1280:
            print(">>> Correcting detected dimension 1280 -> 1024 for DeepEncoder compatibility.")
            self.vision_dim = 1024

        print(f">>> Vision Dimension: {self.vision_dim}")
        
        self.vision_tower = self.vision_tower.to("cuda")
        del tmp_ds
        gc.collect()

        # --- B. LOAD GOT-OCR DECODER ---
        print(f">>> Loading GOT-OCR Decoder...")
        tmp_got = AutoModel.from_pretrained(
            got_path,
            trust_remote_code=True,
            device_map="cpu",
            torch_dtype=frozen_dtype,
            low_cpu_mem_usage=True
        )

        if hasattr(tmp_got, "language_model"):
            self.decoder = tmp_got.language_model
        else:
            self.decoder = tmp_got.model

        self.decoder_dim = self.decoder.config.hidden_size
        self.decoder = self.decoder.to("cuda")

        # LM Head Logic
        self.lm_head = None
        if hasattr(tmp_got, "lm_head"):
            self.lm_head = tmp_got.lm_head
        elif hasattr(self.decoder, "lm_head"):
            self.lm_head = self.decoder.lm_head
        
        if self.lm_head is None:
            vocab_size = self.decoder.config.vocab_size
            self.lm_head = nn.Linear(self.decoder_dim, vocab_size, bias=False)
            if hasattr(self.decoder, "embed_tokens"):
                self.lm_head.weight = self.decoder.embed_tokens.weight
            elif hasattr(self.decoder, "wte"):
                self.lm_head.weight = self.decoder.wte.weight
        
        self.lm_head = self.lm_head.to(dtype=frozen_dtype, device="cuda")
        del tmp_got
        gc.collect()
        torch.cuda.empty_cache()

        # --- C. INIT PROJECTOR ---
        self.projector = DeepSeekOCRToGOTProjector(self.vision_dim, self.decoder_dim)
        
        if projector_path is not None:
            weights_path = os.path.join(projector_path, "pytorch_model.bin")
            if os.path.exists(weights_path):
                print(f">>> SUCCESS: Loading Stage 1 Projector weights from {weights_path}")
                self.projector.load_state_dict(torch.load(weights_path, map_location="cpu"))
            else:
                print(f">>> WARNING: Weights not found at {weights_path}, initializing randomly.")
                self.projector.apply(self._init_weights)
        else:
            self.projector.apply(self._init_weights)

        self.img_start_id = tokenizer.convert_tokens_to_ids("<img>")
        self.img_end_id = tokenizer.convert_tokens_to_ids("</img>")

        self.vision_tower.requires_grad_(False)
        self.decoder.requires_grad_(False)
        self.lm_head.requires_grad_(False)
        self.projector.requires_grad_(True)
        self.projector = self.projector.to(device="cuda", dtype=torch.float32)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        if hasattr(self.decoder, "gradient_checkpointing_enable"):
            self.decoder.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def save_pretrained(self, save_directory):
        if not os.path.exists(save_directory): os.makedirs(save_directory)
        torch.save(self.projector.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump({
                "vision_dim": self.vision_dim, 
                "decoder_dim": self.decoder_dim,
                "architecture": "DeepSeekOCRToGOTProjector"
            }, f, indent=4)

    def forward(self, pixel_values, input_ids, attention_mask=None, labels=None):
        with torch.no_grad():
            try:
                vision_out = self.vision_tower(pixel_values.to(self.vision_dtype), patch_embeds=None)
            except TypeError:
                vision_out = self.vision_tower(pixel_values.to(self.vision_dtype))
            
            if isinstance(vision_out, (tuple, list)):
                features = vision_out[0]
            elif hasattr(vision_out, "last_hidden_state"):
                features = vision_out.last_hidden_state
            else:
                features = vision_out
            features = features.detach()

        vision_embeds = self.projector(features.to(torch.float32)).to(self.vision_dtype)
        B, N, _ = vision_embeds.shape
        device = vision_embeds.device

        input_emb_fn = self.decoder.get_input_embeddings()
        start_embeds = input_emb_fn(torch.tensor([self.img_start_id], device=device)).expand(B, 1, -1)
        end_embeds = input_emb_fn(torch.tensor([self.img_end_id], device=device)).expand(B, 1, -1)
        text_embeds = input_emb_fn(input_ids)

        inputs_embeds = torch.cat([start_embeds, vision_embeds, end_embeds, text_embeds], dim=1)
        
        vision_len = N + 2
        full_mask = None
        if attention_mask is not None:
            v_mask = torch.ones((B, vision_len), device=device, dtype=attention_mask.dtype)
            full_mask = torch.cat([v_mask, attention_mask], dim=1)

        outputs = self.decoder(
            input_ids=torch.zeros((B, inputs_embeds.shape[1]), device=device, dtype=torch.long),
            inputs_embeds=inputs_embeds,
            attention_mask=full_mask,
            return_dict=True
        )
        
        hidden_states = outputs.last_hidden_state
        relevant_hidden = hidden_states[:, vision_len-1:-1, :].contiguous()
        logits = self.lm_head(relevant_hidden)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.decoder.config.vocab_size).to(torch.float32), labels.view(-1))

        return CausalLMOutputWithPast(loss=loss, logits=logits if not self.training else None)

# ==============================================================================
# 3. THE DATASET (ON-THE-FLY SYNTHETIC GENERATION)
# ==============================================================================
class RealTextOCRDataset(Dataset):
    def __init__(self, tokenizer, num_samples=20000):
        self.tokenizer = tokenizer
        self.num_samples = num_samples
        self.fake = Faker('en_US')  # English generator
        
        self.transform = T.Compose([
            T.Resize((1024, 1024), interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                        std=[0.26862954, 0.26130258, 0.27577711])
        ])

    def __len__(self):
        return self.num_samples

    def generate_content(self):
        # Generate varied content types to make the model robust
        r = random.random()
        if r < 0.4:
            # Type 1: Standard Sentences (The easiest for LLMs)
            return self.fake.sentence(nb_words=10)
        elif r < 0.7:
            # Type 2: Addresses (Structured data)
            return self.fake.address().replace('\n', ', ')
        else:
            # Type 3: Names and Phone numbers
            return f"{self.fake.name()} - {self.fake.phone_number()}"

    def generate_image(self, text):
        # 1. Random Background (White-ish)
        bg_color = random.randint(230, 255)
        img = Image.new('RGB', (1024, 1024), color=(bg_color, bg_color, bg_color))
        draw = ImageDraw.Draw(img)

        # 2. Font Management
        try:
            # Attempt to use a larger, clearer font size
            font_size = random.randint(40, 80) 
            font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf", font_size)
        except IOError:
            font = ImageFont.load_default()

        # 3. Draw Text (Centered-ish)
        x = random.randint(50, 100)
        y = random.randint(200, 500)
        
        # Simple text wrapping logic
        words = text.split()
        current_line = ""
        for word in words:
            if (len(current_line) + len(word)) * (font_size * 0.5) > 800:
                draw.text((x, y), current_line, fill=(0, 0, 0), font=font)
                y += font_size + 10
                current_line = word + " "
            else:
                current_line += word + " "
        
        # Draw the last line
        draw.text((x, y), current_line, fill=(0, 0, 0), font=font)

        return img

    def __getitem__(self, idx):
        text = self.generate_content()
        image = self.generate_image(text)

        pixel_values = self.transform(image)
        prompt = f"OCR: {text}{self.tokenizer.eos_token}"
        
        # Consistent prefix masking logic from Stage 1 setup
        prefix_enc = self.tokenizer("OCR: ", add_special_tokens=False)
        prefix_len = len(prefix_enc.input_ids)

        encodings = self.tokenizer(
            prompt,
            max_length=512,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )
        
        input_ids = encodings.input_ids.squeeze(0)
        attention_mask = encodings.attention_mask.squeeze(0)
        labels = input_ids.clone()
        
        # Mask prefix and padding
        starts_with_bos = (input_ids[0] == self.tokenizer.bos_token_id)
        offset = 1 if starts_with_bos else 0
        labels[:prefix_len + offset] = -100
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }


2026-01-19 16:15:26.691634: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768839326.713017    5075 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768839326.719668    5075 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1768839326.736779    5075 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768839326.736801    5075 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768839326.736804    5075 computation_placer.cc:177] computation placer alr

>>> Monkeypatching LlamaFlashAttention2 for DeepSeek-OCR compatibility...


In [5]:
# 0. CLEANUP FOR RERUNS (Critical for Kaggle/Colab)
# This prevents the "Process has 14.72 GiB memory in use" error by clearing previous model refs.
import gc
import torch

if 'trainer' in locals(): del trainer
if 'model' in locals(): del model
if 'training_args' in locals(): del training_args
if 'train_dataset' in locals(): del train_dataset

gc.collect()
torch.cuda.empty_cache()
gc.collect()

# Final check
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

Allocated: 0.00 GB
Reserved: 0.00 GB


In [6]:
# CONFIGURATION
DEEPSEEK_OCR_PATH = "deepseek-ai/DeepSeek-OCR"
GOT_PATH = "stepfun-ai/GOT-OCR2_0"
OUTPUT_DIR = "./deepseek_ocr_got_final"

# Corrected path to stage 1 weights
# STAGE1_PROJECTOR_PATH = "/kaggle/input/temp-projector/projector_export" 

STAGE1_PROJECTOR_PATH = ""

print(">>> 1. Loading Tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(GOT_PATH, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 2. Initialize RealTextOCRDataset
print(">>> 2. Initializing RealTextOCRDataset...")
train_dataset = RealTextOCRDataset(tokenizer, num_samples=20000)

def collate_fn(batch):
    return {
        "pixel_values": torch.stack([x['pixel_values'] for x in batch]),
        "input_ids": torch.stack([x['input_ids'] for x in batch]),
        "labels": torch.stack([x['labels'] for x in batch]),
        "attention_mask": torch.stack([x['attention_mask'] for x in batch])
    }
    
print(">>> 3. Initializing Fusion Model (Checking Stage 1 Weights)...")
model = DeepSeekGOTFusion(
    DEEPSEEK_OCR_PATH, 
    GOT_PATH, 
    tokenizer, 
    projector_path=STAGE1_PROJECTOR_PATH
)


>>> 1. Loading Tokenizer...
>>> 2. Initializing RealTextOCRDataset...
>>> 3. Initializing Fusion Model (Checking Stage 1 Weights)...
>>> Loading DeepSeek-OCR Encoder...


You are using a model of type deepseek_vl_v2 to instantiate a model of type DeepseekOCR. This is not supported for all configurations of models and can yield errors.
`torch_dtype` is deprecated! Use `dtype` instead!
Some weights of DeepseekOCRForCausalLM were not initialized from the model checkpoint at deepseek-ai/DeepSeek-OCR and are newly initialized: ['model.vision_model.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


>>> Vision Dimension: 1024
>>> Loading GOT-OCR Decoder...


In [8]:
# TRAINING CONFIG (Optimized for Speed and 12h Limit)
# Total samples = MAX_STEPS * per_device_train_batch_size * gradient_accumulation_steps
# With 1000 steps and accum 8, we process 8000 samples.
# At 0.03 it/s (33s per step), 1000 steps = ~9 hours.
MAX_STEPS = 1000
   
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=4,
    # gradient_accumulation_steps=8,
    learning_rate=5e-5,
    max_steps=MAX_STEPS,   # Use max_steps for streaming
    fp16=True,
    gradient_checkpointing=True, 
    logging_steps=1,
    save_strategy="steps",
    save_steps=200,      
    remove_unused_columns=False,
    report_to="none",
    save_safetensors=False,
    dataloader_pin_memory=False,
    prediction_loss_only=True,
    max_grad_norm=1.0,
    lr_scheduler_type="cosine",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collate_fn,
)

print(">>> 4. Starting Training...")

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f">>> Trainable Parameters: {trainable_params:,}")

trainer.train()

print(f">>> 5. Saving Projector to {OUTPUT_DIR}...")
model.save_pretrained(OUTPUT_DIR)
print(">>> Done.")


>>> 4. Starting Training...
>>> Trainable Parameters: 4,199,424


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
1,18.6644
2,8.203
3,23.8152
4,21.3906
5,12.1144
6,14.3214
7,12.1424
8,11.0335
9,9.2214
10,9.9134


KeyboardInterrupt: 

In [9]:
# Diagnostic: Check special tokens
test_tokens = ["<|im_start|>", "<|im_end|>", "<img>", "</img>", "<image>", "</image>", "<|img_start|>", "<|img_end|>"]
print(f"Checking tokens for: {GOT_PATH}")
for t in test_tokens:
    token_id = tokenizer.convert_tokens_to_ids(t)
    print(f"Token '{t}': {token_id}")

# Also check if they are in special_tokens_map
print(f"\nSpecial tokens map: {tokenizer.special_tokens_map}")


Checking tokens for: stepfun-ai/GOT-OCR2_0
Token '<|im_start|>': 151644
Token '<|im_end|>': 151645
Token '<img>': 151857
Token '</img>': 151858
Token '<image>': None
Token '</image>': None
Token '<|img_start|>': None
Token '<|img_end|>': None

Special tokens map: {'pad_token': '<|endoftext|>'}
