In [1]:
# import os, re
# if "COLAB_" not in "".join(os.environ.keys()):
#     !pip install unsloth
# else:
#     # Do this only in Colab notebooks! Otherwise use pip install unsloth
#     import torch; v = re.match(r"[0-9]{1,}\.[0-9]{1,}", str(torch.__version__)).group(0)
#     xformers = "xformers==" + ("0.0.33.post1" if v=="2.9" else "0.0.32.post2" if v=="2.8" else "0.0.29.post3")
#     !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
#     !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer
#     !pip install --no-deps unsloth
# !pip install transformers==4.56.2
# !pip install --no-deps trl==0.22.2
# !pip install jiwer
# !pip install einops addict easydict
# !pip install verovio

In [2]:
# from huggingface_hub import snapshot_download, login
# login(token='thanks_secret_scanning')
# snapshot_download("unsloth/DeepSeek-OCR", local_dir = "deepseek_ocr")

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import (
    AutoModel,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    AutoConfig
)
from transformers.modeling_outputs import CausalLMOutputWithPast
import torchvision.transforms as T
from datasets import load_dataset
import warnings
import gc
from PIL import Image
import os

# Must be set before torch is imported to take effect
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# ==============================================================================
# 0. HOTFIX FOR DEEPSEEK-OCR (Transformers >= 4.46 compatibility)
#    DeepSeek-OCR's remote code tries to import 'LlamaFlashAttention2' which
#    was removed in recent transformers versions. We monkeypatch it.
# ==============================================================================
import transformers.models.llama.modeling_llama

# Check if the class is missing and inject a dummy alias if needed
if not hasattr(transformers.models.llama.modeling_llama, "LlamaFlashAttention2"):
    print(">>> Monkeypatching LlamaFlashAttention2 for DeepSeek-OCR compatibility...")
    # Map it to the standard LlamaAttention class as a fallback
    transformers.models.llama.modeling_llama.LlamaFlashAttention2 = transformers.models.llama.modeling_llama.LlamaAttention

# ==============================================================================
# 1. THE PROJECTOR (The "Bridge")
#    Connects DeepSeek's Compressed Vision Tokens to GOT's Decoder
# ==============================================================================
class DeepSeekOCRToGOTProjector(nn.Module):
    def __init__(self, encoder_dim, decoder_dim=1024):
        super().__init__()
        # DeepSeek-OCR tokens are highly compressed (dense information).
        # We use a bottleneck MLP to map them to Qwen's embedding space.
        self.net = nn.Sequential(
            nn.Linear(encoder_dim, decoder_dim * 2),
            nn.GELU(),
            nn.Linear(decoder_dim * 2, decoder_dim),
            nn.LayerNorm(decoder_dim)
        )

    def forward(self, x):
        return self.net(x)

# ==============================================================================
# 2. THE FUSION MODEL (FIXED FOR AMP/FP16 TRAINING)
# ==============================================================================
class DeepSeekGOTFusion(nn.Module):
    def __init__(self, deepseek_ocr_path, got_path, tokenizer):
        super().__init__()
        warnings.filterwarnings("ignore")

        # Frozen backbone dtype (Save VRAM)
        frozen_dtype = torch.float16
        self.vision_dtype = frozen_dtype

        # --- A. LOAD DEEPSEEK-OCR ENCODER ---
        print(f">>> Loading DeepSeek-OCR Encoder from: {deepseek_ocr_path}...")
        try:
            # Load to CPU first to avoid OOM when both models are loaded
            tmp_ds = AutoModel.from_pretrained(
                deepseek_ocr_path,
                trust_remote_code=True,
                device_map="cpu",
                torch_dtype=frozen_dtype,
                low_cpu_mem_usage=True
            )

            if hasattr(tmp_ds, "model"):
                base_model = tmp_ds.model
            else:
                base_model = tmp_ds

            if hasattr(base_model, "deep_encoder"):
                self.vision_tower = base_model.deep_encoder
            elif hasattr(base_model, "vision_model"):
                self.vision_tower = base_model.vision_model
            else:
                raise ValueError("Could not find 'deep_encoder'.")

            # Determine Dimension
            self.vision_dim = 1024
            if hasattr(self.vision_tower, "config"):
                self.vision_dim = getattr(self.vision_tower.config, "hidden_size", 1024)

            print(f">>> Vision Dimension: {self.vision_dim}")
            
            # Move only the needed component to GPU
            self.vision_tower = self.vision_tower.to("cuda")
            
            del tmp_ds
            if 'base_model' in locals(): del base_model
            gc.collect()
            torch.cuda.empty_cache()

        except Exception as e:
            print(f"CRITICAL ERROR Loading DeepSeek-OCR: {e}")
            raise e

        # --- B. LOAD GOT-OCR DECODER ---
        print(f">>> Loading GOT-OCR Decoder from: {got_path}...")
        # Load to CPU first
        tmp_got = AutoModel.from_pretrained(
            got_path,
            trust_remote_code=True,
            device_map="cpu",
            torch_dtype=frozen_dtype,
            low_cpu_mem_usage=True
        )

        if hasattr(tmp_got, "language_model"):
            self.decoder = tmp_got.language_model
        else:
            self.decoder = tmp_got.model

        self.decoder_dim = self.decoder.config.hidden_size
        
        # Move decoder to GPU
        self.decoder = self.decoder.to("cuda")

        # --- C. RECREATE LM HEAD ---
        self.lm_head = None
        if hasattr(tmp_got, "lm_head"):
            self.lm_head = tmp_got.lm_head
        elif hasattr(self.decoder, "lm_head"):
            self.lm_head = self.decoder.lm_head

        if self.lm_head is None:
            vocab_size = self.decoder.config.vocab_size
            self.lm_head = nn.Linear(self.decoder_dim, vocab_size, bias=False)
            if hasattr(self.decoder, "embed_tokens"):
                self.lm_head.weight = self.decoder.embed_tokens.weight
            elif hasattr(self.decoder, "wte"):
                self.lm_head.weight = self.decoder.wte.weight
            
        # Move LM Head to GPU
        self.lm_head = self.lm_head.to(dtype=frozen_dtype, device="cuda")

        del tmp_got
        gc.collect()
        torch.cuda.empty_cache()

        # --- D. INIT PROJECTOR ---
        print(f">>> Initializing Projector: {self.vision_dim} -> {self.decoder_dim}")
        self.projector = DeepSeekOCRToGOTProjector(self.vision_dim, self.decoder_dim)

        # --- E. CONFIG ---
        self.img_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        self.img_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

        # Freezing
        self.vision_tower.requires_grad_(False)
        self.decoder.requires_grad_(False)
        self.lm_head.requires_grad_(False)
        self.projector.requires_grad_(True)

        # --- FIX: TRAINABLE PARAMS MUST BE FLOAT32 FOR AMP ---
        # The backbones stay fp16 (frozen), but the projector is fp32.
        # The Trainer will autocast ops to fp16, but weights stay stable in fp32.
        self.projector = self.projector.to(device="cuda", dtype=torch.float32)

        # Initialize Projector Weights
        self.projector.apply(self._init_weights)

    def save_pretrained(self, save_directory):
        """
        Custom save method because this is a hybrid nn.Module, 
        not a standard transformers PreTrainedModel.
        Focuses on saving the trained projector.
        """
        if not os.path.exists(save_directory):
            os.makedirs(save_directory)
        
        # 1. Save Projector Weights (The only trained part)
        torch.save(self.projector.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        
        # 2. Save a minimal config for reloading
        import json
        config = {
            "vision_dim": self.vision_dim,
            "decoder_dim": self.decoder_dim,
            "architecture": "DeepSeekOCRToGOTProjector"
        }
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump(config, f, indent=4)
            
        print(f">>> Projector weights and config saved to {save_directory}")

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # Use smaller standard deviation for the "bridge" to prevent initial divergence
            nn.init.trunc_normal_(m.weight, std=0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        if hasattr(self.decoder, "gradient_checkpointing_enable"):
            self.decoder.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
            self.decoder.config.use_cache = False

    def forward(self, pixel_values, input_ids, attention_mask=None, labels=None):
        # 1. Vision Forward (DeepEncoder)
        with torch.no_grad():
            try:
                pixel_values = pixel_values.to(dtype=self.vision_dtype)
                vision_out = self.vision_tower(pixel_values, patch_embeds=None)
            except TypeError:
                pixel_values = pixel_values.to(dtype=self.vision_dtype)
                vision_out = self.vision_tower(pixel_values)
            
            if isinstance(vision_out, (tuple, list)):
                features = vision_out[0]
            elif hasattr(vision_out, "last_hidden_state"):
                features = vision_out.last_hidden_state
            else:
                features = vision_out

            features = features.detach()

        # 2. Project Features
        # Scale down vision features initially to match text embedding magnitude
        # and prevent attention spikes that lead to NaNs
        vision_embeds = self.projector(features.to(dtype=torch.float32))
        vision_embeds = vision_embeds.to(dtype=self.vision_dtype)

        # 3. Prepare Embeddings
        B, N, _ = vision_embeds.shape
        device = vision_embeds.device

        input_emb_fn = self.decoder.get_input_embeddings()
        start_embeds = input_emb_fn(torch.tensor([self.img_start_id], device=device)).expand(B, 1, -1)
        end_embeds = input_emb_fn(torch.tensor([self.img_end_id], device=device)).expand(B, 1, -1)
        text_embeds = input_emb_fn(input_ids)

        # Sequence: [Start, Vision(N), End, Text(L)]
        inputs_embeds = torch.cat([start_embeds, vision_embeds, end_embeds, text_embeds], dim=1)

        # 4. Attention Mask
        vision_len = N + 2
        full_attention_mask = None
        if attention_mask is not None:
            vision_mask = torch.ones((B, vision_len), device=device, dtype=attention_mask.dtype)
            full_attention_mask = torch.cat([vision_mask, attention_mask], dim=1)

        # 5. Decoder Forward
        # Create dummy input_ids to satisfy the GOT model's internal checks
        # which access input_ids.shape even when inputs_embeds is provided.
        # Since images=None, it won't actually try to use these ids for vision processing.
        dummy_input_ids = torch.zeros(
            (B, inputs_embeds.shape[1]), 
            device=device, 
            dtype=torch.long
        )

        outputs = self.decoder(
            input_ids=dummy_input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=full_attention_mask,
            return_dict=True
        )
        
        # MEMORY OPTIMIZATION: Only compute logits for the text tokens (loss calculation).
        # vision_len = N + 2. We need logits from vision_len-1 to end-1.
        hidden_states = outputs.last_hidden_state
        relevant_hidden = hidden_states[:, vision_len-1:-1, :].contiguous()
        logits = self.lm_head(relevant_hidden)

        # 6. Loss Calculation (ROBUST VERSION)
        loss = None
        if labels is not None:
            relevant_labels = labels.contiguous()

            loss_fct = nn.CrossEntropyLoss()
            # Cast to float32 for stable loss calculation
            loss = loss_fct(
                logits.view(-1, self.decoder.config.vocab_size).to(torch.float32), 
                relevant_labels.view(-1)
            )

        # To prevent DataParallel from gathering a massive logits tensor on GPU 0
        # (which causes OOM), we only return logits when not training.
        return CausalLMOutputWithPast(loss=loss, logits=logits if not self.training else None)

# ==============================================================================
# 3. THE DATASET (Full Implementation)
# ==============================================================================
class DeepSeekOCRDataset(Dataset):
    def __init__(self, hf_dataset, tokenizer):
        self.data = hf_dataset
        self.tokenizer = tokenizer

        # DeepSeek-OCR / CLIP-Large specific normalization
        # Note: DeepSeek-OCR natively supports 1024x1024
        self.transform = T.Compose([
            T.Resize((1024, 1024), interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                        std=[0.26862954, 0.26130258, 0.27577711])
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # 1. Load Image
        image = item['image']
        if image.mode != "RGB":
            image = image.convert("RGB")

        pixel_values = self.transform(image)

        # 2. Prepare Text
        text = item['text']
        # DeepSeek/GOT style prompt
        prompt = f"OCR: {text}{self.tokenizer.eos_token}"

        # 3. Tokenize
        encodings = self.tokenizer(
            prompt,
            max_length=512,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )

        input_ids = encodings.input_ids.squeeze(0)
        attention_mask = encodings.attention_mask.squeeze(0)

        # 4. Labels (Mask padding)
        labels = input_ids.clone()
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

2026-01-18 12:08:58.919285: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768738138.941274    2674 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768738138.948009    2674 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1768738138.965181    2674 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768738138.965200    2674 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768738138.965203    2674 computation_placer.cc:177] computation placer alr

>>> Monkeypatching LlamaFlashAttention2 for DeepSeek-OCR compatibility...


In [4]:
# 0. CLEANUP FOR RERUNS (Critical for Kaggle/Colab)
# This prevents the "Process has 14.72 GiB memory in use" error by clearing previous model refs.
import gc
import torch

if 'trainer' in locals(): del trainer
if 'model' in locals(): del model
if 'training_args' in locals(): del training_args
if 'train_dataset' in locals(): del train_dataset

gc.collect()
torch.cuda.empty_cache()
gc.collect()

# Final check
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

Allocated: 0.00 GB
Reserved: 0.00 GB


In [5]:
# CONFIGURATION
DEEPSEEK_OCR_PATH = "deepseek-ai/DeepSeek-OCR"
GOT_PATH = "stepfun-ai/GOT-OCR2_0"
OUTPUT_DIR = "./deepseek_ocr_got_final"

print(">>> 1. Loading Tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(GOT_PATH, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(">>> 2. Preparing Dataset...")
dataset = load_dataset("hezarai/parsynth-ocr-200k", split="train[:5000]")
dataset = dataset.rename_column("image_path", "image")
train_dataset = DeepSeekOCRDataset(dataset, tokenizer)

def collate_fn(batch):
    return {
        "pixel_values": torch.stack([x['pixel_values'] for x in batch]),
        "input_ids": torch.stack([x['input_ids'] for x in batch]),
        "labels": torch.stack([x['labels'] for x in batch]),
        "attention_mask": torch.stack([x['attention_mask'] for x in batch])
    }
    
print(">>> 3. Initializing Fusion Model...")
model = DeepSeekGOTFusion(DEEPSEEK_OCR_PATH, GOT_PATH, tokenizer)

>>> 1. Loading Tokenizer...
>>> 2. Preparing Dataset...


You are using a model of type deepseek_vl_v2 to instantiate a model of type DeepseekOCR. This is not supported for all configurations of models and can yield errors.
`torch_dtype` is deprecated! Use `dtype` instead!


>>> 3. Initializing Fusion Model...
>>> Loading DeepSeek-OCR Encoder from: deepseek-ai/DeepSeek-OCR...


Some weights of DeepseekOCRForCausalLM were not initialized from the model checkpoint at deepseek-ai/DeepSeek-OCR and are newly initialized: ['model.vision_model.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


>>> Vision Dimension: 1024
>>> Loading GOT-OCR Decoder from: stepfun-ai/GOT-OCR2_0...
>>> Initializing Projector: 1024 -> 1024


In [6]:
# TRAINING CONFIG
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,  # REDUCED to 1 to fit in T4
    gradient_accumulation_steps=8,  # INCREASED to maintain effective batch size 8
    learning_rate=5e-5,    # Reduced LR for stability
    num_train_epochs=1,
    fp16=True,  # ENABLED for T4
    bf16=False, # DISABLED for T4
    gradient_checkpointing=True, 
    logging_steps=5,
    save_strategy="steps",
    save_steps=50,
    remove_unused_columns=False,
    report_to="none",
    save_safetensors=False,
    dataloader_pin_memory=False,
    prediction_loss_only=True,  # CRITICAL: Do not gather logits (saves HUGE memory)
    max_grad_norm=1.0      # CLIPPING to prevent NaNs
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collate_fn,
)

print(">>> 4. Starting Training...")

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f">>> Trainable Parameters: {trainable_params:,}")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"    - {name}: {param.numel():,} params")

trainer.train()

print(f">>> 5. Saving Projector to {OUTPUT_DIR}...")
model.save_pretrained(OUTPUT_DIR)
print(">>> Done.")

>>> 4. Starting Training...
>>> Trainable Parameters: 4,199,424
    - projector.net.0.weight: 2,097,152 params
    - projector.net.0.bias: 2,048 params
    - projector.net.2.weight: 2,097,152 params
    - projector.net.2.bias: 1,024 params
    - projector.net.3.weight: 1,024 params
    - projector.net.3.bias: 1,024 params


Step,Training Loss
5,13.3444
10,11.4285
15,13.3893
20,12.2692
25,11.0591
30,9.9829
35,9.1216
40,8.2996
45,7.7367
50,7.4479


>>> 5. Saving Projector to ./deepseek_ocr_got_final...
>>> Projector weights and config saved to ./deepseek_ocr_got_final
>>> Done.


In [15]:
!ls ./deepseek_ocr_got_final/checkpoint-313

optimizer.pt	   rng_state.pth  scheduler.pt	      training_args.bin
pytorch_model.bin  scaler.pt	  trainer_state.json


In [17]:
# CLEANUP AND EXPORT ONLY NECESSARY FILES
import shutil
import os

# 1. Clear the checkpoints IMMEDIATELY to free space
# Each checkpoint has optimizer states and massive redundancy.
print(">>> Freeing disk space by removing checkpoints...")
!rm -rf {OUTPUT_DIR}/checkpoint-*
!rm -rf ./checkpoint-313.zip

# 2. Check disk space again
print(">>> Current disk usage:")
!df -h .

# 3. Create a clean export directory for just the weights and config
EXPORT_DIR = "./projector_export"
if not os.path.exists(EXPORT_DIR):
    os.makedirs(EXPORT_DIR)

# 4. Copy just the weights and config
# Note: These were saved in the root of OUTPUT_DIR at the end of training
print(">>> Copying weights and config...")
shutil.copy(os.path.join(OUTPUT_DIR, "pytorch_model.bin"), os.path.join(EXPORT_DIR, "pytorch_model.bin"))
shutil.copy(os.path.join(OUTPUT_DIR, "config.json"), os.path.join(EXPORT_DIR, "config.json"))

# 5. Zip ONLY the export directory (very small, only ~17MB)
print(">>> Zipping export files...")
!zip -r DeepSeek_GOT_Projector.zip {EXPORT_DIR}

print("\n>>> DONE! You can now download DeepSeek_GOT_Projector.zip")


>>> Freeing disk space by removing checkpoints...
>>> Current disk usage:
Filesystem      Size  Used Avail Use% Mounted on
/dev/loop1       20G  6.3G   14G  33% /kaggle/working
>>> Copying weights and config...
>>> Zipping export files...
  adding: projector_export/ (stored 0%)
  adding: projector_export/config.json (deflated 22%)
  adding: projector_export/pytorch_model.bin (deflated 7%)

>>> DONE! You can now download DeepSeek_GOT_Projector.zip
