# Install packages

In [None]:
# !pip install transformers==4.56.2
# !pip install --no-deps trl==0.22.2
# !pip install jiwer
# !pip install einops addict easydict
# !pip install verovio
# !pip install faker
# !pip install peft

Collecting transformers==4.56.2
  Downloading transformers-4.56.2-py3-none-any.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.1/40.1 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
Downloading transformers-4.56.2-py3-none-any.whl (11.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m136.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.57.3
    Uninstalling transformers-4.57.3:
      Successfully uninstalled transformers-4.57.3
Successfully installed transformers-4.56.2


Collecting trl==0.22.2
  Downloading trl-0.22.2-py3-none-any.whl.metadata (11 kB)
Downloading trl-0.22.2-py3-none-any.whl (544 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/544.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m542.7/544.8 kB[0m [31m18.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m544.8/544.8 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trl
Successfully installed trl-0.22.2
Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.14.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Downloading jiwer-4.0.0-py3-none-any.whl (23 kB)
Downloading rapidfuzz-3.14.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m

### HF Login for Kaggle

In [2]:
# from huggingface_hub import login
# from kaggle_secrets import UserSecretsClient
# user_secrets = UserSecretsClient()
# token = user_secrets.get_secret("HF_TOKEN")
# login(token=token)

### Hotfix (Transformers >= 4.46 compatibility)

In [1]:
import transformers.models.llama.modeling_llama
if not hasattr(transformers.models.llama.modeling_llama, "LlamaFlashAttention2"):
    print(">>> Monkeypatching LlamaFlashAttention2 for DeepSeek-OCR compatibility...")
    transformers.models.llama.modeling_llama.LlamaFlashAttention2 = transformers.models.llama.modeling_llama.LlamaAttention

>>> Monkeypatching LlamaFlashAttention2 for DeepSeek-OCR compatibility...


# Imports

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    PreTrainedModel
)
from transformers.modeling_outputs import CausalLMOutputWithPast
import torchvision.transforms as T
import warnings
import gc
from PIL import Image, ImageDraw, ImageFont
import os
import json
import random
from faker import Faker
from peft import get_peft_model, LoraConfig
import types

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Projector MLP

In [21]:
class DeepSeekOCRToGOTProjector(nn.Module):
    def __init__(self, encoder_dim, decoder_dim=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(encoder_dim, decoder_dim * 2),
            nn.GELU(),
            nn.Linear(decoder_dim * 2, decoder_dim),
            nn.LayerNorm(decoder_dim)
        )

    def forward(self, x):
        return self.net(x)

# Fusion Model

In [None]:
class DeepSeekFeatureExtractor(nn.Module):
    """
    Wraps the internal DeepSeek-OCR modules to perform the full
    SAM + CLIP + Projection forward pass, treating the output as a 1D sequence.
    """
    def __init__(self, base_model):
        super().__init__()
        # Extract sub-modules from the base DeepSeek-OCR model
        self.sam_model = base_model.sam_model
        self.vision_model = base_model.vision_model
        self.projector = base_model.projector

    def forward(self, images):
        # 1. Run SAM Backbone
        sam_features = self.sam_model(images)

        # 2. Run Vision Backbone (CLIP) with SAM injection
        vision_features = self.vision_model(images, sam_features)

        # 3. Fuse Features
        # SAM: [B, C_sam, H, W] -> [B, H*W, C_sam]
        sam_flat = sam_features.flatten(2).permute(0, 2, 1)
        # CLIP: [B, 1+L, C_clip] -> [B, L, C_clip] (skipping CLS token)
        clip_flat = vision_features[:, 1:]

        fused_features = torch.cat((clip_flat, sam_flat), dim=-1)

        # 4. Run Internal Projector
        # Projects fused features to the cross-modal dimension [B, L, C]
        final_features = self.projector(fused_features)

        # We return the flattened 1D sequence without inserting image newline tokens
        return final_features


In [None]:
class DeepSeekGOTFusion(nn.Module):
    def __init__(self, deepseek_ocr_path, got_path, tokenizer, use_lora=True):
        super().__init__()

        self.tokenizer = tokenizer 

        warnings.filterwarnings("ignore")
        frozen_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        self.vision_dtype = frozen_dtype

        # --- 1. Load DeepSeek Encoder ---
        print(f">>> Loading DeepSeek-OCR Encoder...")
        deepseek_model = AutoModel.from_pretrained(
            deepseek_ocr_path,
            trust_remote_code=True,
            torch_dtype=frozen_dtype,
        )
        base_vision = deepseek_model.model if hasattr(deepseek_model, "model") else deepseek_model
        self.vision_tower = DeepSeekFeatureExtractor(base_vision)
        self.vision_dim = 1280

        for param in self.vision_tower.parameters():
            param.requires_grad = False

        # --- 2. Load GOT-OCR Decoder (Use ForCausalLM to enable internal loss/head) ---
        print(f">>> Loading GOT-OCR Decoder (AutoModelForCausalLM)...")
        self.decoder = AutoModelForCausalLM.from_pretrained(
            got_path,
            trust_remote_code=True,
            torch_dtype=frozen_dtype,
        )
        self.decoder_dim = self.decoder.config.hidden_size

        # --- 3. LoRA Injection ---
        self.decoder.requires_grad_(False)
        if use_lora:
            print(">>> Injecting LoRA Adapters...")
            lora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM"
            )
            self.decoder = get_peft_model(self.decoder, lora_config)
            self.decoder.print_trainable_parameters()

        # --- 4. Projector ---
        self.projector = DeepSeekOCRToGOTProjector(self.vision_dim, self.decoder_dim)
        self.projector.apply(self._init_weights)

        self.img_start_id = tokenizer.convert_tokens_to_ids("<img>")
        self.img_end_id = tokenizer.convert_tokens_to_ids("</img>")

        self.vision_tower.requires_grad_(False)
        self.projector.requires_grad_(True)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        if hasattr(self.decoder, "gradient_checkpointing_enable"):
            self.decoder.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def save_pretrained(self, save_directory):
        if not os.path.exists(save_directory): os.makedirs(save_directory)
        torch.save(self.projector.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        if hasattr(self.decoder, "save_pretrained"):
             self.decoder.save_pretrained(save_directory)
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            json.dump({
                "vision_dim": self.vision_dim,
                "decoder_dim": self.decoder_dim,
                "architecture": "DeepSeekOCRToGOTProjector"
            }, f, indent=4)

    def forward(self, pixel_values, input_ids, attention_mask=None, labels=None, **kwargs):
        # 1. Vision Encode
        with torch.no_grad():
            vision_out = self.vision_tower(pixel_values.to(self.vision_dtype))
            if isinstance(vision_out, (tuple, list)):
                features = vision_out[0]
            elif hasattr(vision_out, "last_hidden_state"):
                features = vision_out.last_hidden_state
            else:
                features = vision_out
            features = features.detach()

        # 2. Project
        vision_embeds = self.projector(features.to(torch.float32)).to(self.vision_dtype)
        B, N, _ = vision_embeds.shape
        device = vision_embeds.device

        # 3. Text Embeds
        base_decoder = self.decoder.get_base_model() if hasattr(self.decoder, "get_base_model") else self.decoder
        input_emb_fn = base_decoder.get_input_embeddings()

        start_embeds = input_emb_fn(torch.tensor([self.img_start_id], device=device)).expand(B, 1, -1)
        end_embeds = input_emb_fn(torch.tensor([self.img_end_id], device=device)).expand(B, 1, -1)
        text_embeds = input_emb_fn(input_ids)

        inputs_embeds = torch.cat([start_embeds, vision_embeds, end_embeds, text_embeds], dim=1)

        # 4. Alignment of Labels and Attention Mask
        vision_token_len = inputs_embeds.shape[1] - text_embeds.shape[1]
        
        full_mask = None
        if attention_mask is not None:
            v_mask = torch.ones((B, vision_token_len), device=device, dtype=attention_mask.dtype)
            full_mask = torch.cat([v_mask, attention_mask], dim=1)

        full_labels = None
        if labels is not None:
            vision_ignore = torch.full((B, vision_token_len), -100, device=device)
            full_labels = torch.cat([vision_ignore, labels], dim=1)

        # 5. Decoder Call (Loss is calculated internally by CausalLM)
        return self.decoder(
            inputs_embeds=inputs_embeds,
            attention_mask=full_mask,
            labels=full_labels,
            return_dict=True,
            use_cache=False
        )


# Dataset

In [None]:
class RealTextOCRDataset(Dataset):
    def __init__(self, tokenizer, num_samples=20000):
        self.tokenizer = tokenizer
        self.num_samples = num_samples
        self.fake = Faker('en_US')  # English generator

        self.transform = T.Compose([
            T.Resize((1024, 1024), interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor()
        ]) # do not normalize, DeepSeek-OCR expects [0,1] range

    def __len__(self):
        return self.num_samples

    def generate_content(self):
        # Generate varied content types to make the model robust
        r = random.random()
        if r < 0.4:
            # Type 1: Standard Sentences (The easiest for LLMs)
            return self.fake.sentence(nb_words=10)
        elif r < 0.7:
            # Type 2: Addresses (Structured data)
            return self.fake.address().replace('\n', ', ')
        else:
            # Type 3: Names and Phone numbers
            return f"{self.fake.name()} - {self.fake.phone_number()}"

    def generate_image(self, text):
        # 1. Random Background (White-ish)
        bg_color = random.randint(230, 255)
        img = Image.new('RGB', (1024, 1024), color=(bg_color, bg_color, bg_color))
        draw = ImageDraw.Draw(img)

        # 2. Font Management
        try:
            # Attempt to use a larger, clearer font size
            font_size = random.randint(40, 80)
            font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf", font_size)
        except IOError:
            font = ImageFont.load_default()

        # 3. Draw Text (Centered-ish)
        x = random.randint(50, 100)
        y = random.randint(200, 500)

        # Simple text wrapping logic
        words = text.split()
        current_line = ""
        for word in words:
            if (len(current_line) + len(word)) * (font_size * 0.5) > 800:
                draw.text((x, y), current_line, fill=(0, 0, 0), font=font)
                y += font_size + 10
                current_line = word + " "
            else:
                current_line += word + " "

        # Draw the last line
        draw.text((x, y), current_line, fill=(0, 0, 0), font=font)

        return img

    def __getitem__(self, idx):
        text = self.generate_content()
        image = self.generate_image(text)

        pixel_values = self.transform(image)
        # prompt = f"OCR: {text}{self.tokenizer.eos_token}"
        eos = self.tokenizer.eos_token if self.tokenizer.eos_token else "<|endoftext|>"
        prompt = f"OCR: {text}{eos}"

        # Consistent prefix masking logic from Stage 1 setup
        prefix_enc = self.tokenizer("OCR: ", add_special_tokens=False)
        prefix_len = len(prefix_enc.input_ids)

        encodings = self.tokenizer(
            prompt,
            max_length=512,
            truncation=True,
            padding="max_length",
            return_tensors="pt"
        )

        input_ids = encodings.input_ids.squeeze(0)
        attention_mask = encodings.attention_mask.squeeze(0)
        labels = input_ids.clone()

        # Mask prefix and padding
        starts_with_bos = (input_ids[0] == self.tokenizer.bos_token_id)
        offset = 1 if starts_with_bos else 0
        labels[:prefix_len + offset] = -100
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

# Training configuration

In [27]:
DEEPSEEK_OCR_PATH = "deepseek-ai/DeepSeek-OCR"
GOT_PATH = "stepfun-ai/GOT-OCR2_0"
OUTPUT_DIR = "./deepseek_ocr_got_final"

In [28]:
print(">>> 1. Loading Tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(GOT_PATH, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(">>> 2. Initializing RealTextOCRDataset...")
train_dataset = RealTextOCRDataset(tokenizer, num_samples=20000)

def collate_fn(batch):
    return {
        "pixel_values": torch.stack([x['pixel_values'] for x in batch]),
        "input_ids": torch.stack([x['input_ids'] for x in batch]),
        "labels": torch.stack([x['labels'] for x in batch]),
        "attention_mask": torch.stack([x['attention_mask'] for x in batch])
    }

print(">>> 3. Initializing Fusion Model...")
model = DeepSeekGOTFusion(
    DEEPSEEK_OCR_PATH,
    GOT_PATH,
    tokenizer,
)

>>> 1. Loading Tokenizer...
>>> 2. Initializing RealTextOCRDataset...
>>> 3. Initializing Fusion Model...
>>> Loading DeepSeek-OCR Encoder...


You are using a model of type deepseek_vl_v2 to instantiate a model of type DeepseekOCR. This is not supported for all configurations of models and can yield errors.
Some weights of DeepseekOCRForCausalLM were not initialized from the model checkpoint at deepseek-ai/DeepSeek-OCR and are newly initialized: ['model.vision_model.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


>>> Loading GOT-OCR Decoder (AutoModel)...
>>> Injecting LoRA Adapters...
trainable params: 7,569,408 || all params: 568,098,048 || trainable%: 1.3324


In [33]:
MAX_STEPS = 500

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=1e-4,
    max_steps=MAX_STEPS,   # Use max_steps for streaming
    fp16=True,
    gradient_checkpointing=True,
    logging_steps=10,
    save_strategy="steps",
    save_steps=50,
    remove_unused_columns=False,
    report_to="none",
    save_safetensors=False,
    dataloader_pin_memory=False,
    prediction_loss_only=True,
    max_grad_norm=0.5,
    lr_scheduler_type="cosine",
    warmup_steps=5,
    ddp_find_unused_parameters=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=collate_fn,
)

# [DEBUG] Sanity Check

In [42]:
# Set the debug flag
model.debug_one_step = True
model.eval() # Optional: set to eval to disable dropout noise for cleaner check

print(">>> Running Sanity Check...")
# Grab one batch from the dataloader
dataloader = trainer.get_train_dataloader()
batch = next(iter(dataloader))

# Move batch to device
device = model.decoder.device # or "cuda"
batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}

# Forward pass
with torch.no_grad():
    model(**batch)

model.train() # Set back to train mode
print(">>> Sanity Check Complete.")

>>> Running Sanity Check...

[DEBUG] === Alignment Check ===
[DEBUG] Vision Tokens (N): 272
[DEBUG] Vision Total (N+2): 274
[DEBUG] Text Input Length: 512
[DEBUG] Full Hidden Shape: torch.Size([1, 786, 1024])
[DEBUG] Relevant Hidden (Sliced): torch.Size([1, 512, 1024])
[DEBUG] Labels Shape: torch.Size([1, 512])
[DEBUG] --- Decoding First Sample ---
[DEBUG] Label Text: ' serious once staff improve break unit way something pattern that husband.'
[DEBUG] Pred  Text: '.. - - - - - - - - -.'

>>> Sanity Check Complete.


# Train

In [34]:
print(">>> 4. Starting Training...")

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f">>> Trainable Parameters: {trainable_params:,}")

trainer.train()

print(f">>> 5. Saving Projector to {OUTPUT_DIR}...")
model.save_pretrained(OUTPUT_DIR)
print(">>> Done.")

>>> 4. Starting Training...
>>> Trainable Parameters: 12,293,120


Step,Training Loss
10,10.5145
20,8.3379
30,7.6773
40,7.8643
50,9.9592
60,8.5185
70,6.8777
80,8.5693
90,7.503
100,7.9878


>>> 5. Saving Projector to ./deepseek_ocr_got_final...
>>> Done.


In [18]:
# import torch
# from transformers import AutoModel

# def list_deepseek_components():
#     model_path = "deepseek-ai/DeepSeek-OCR"

#     print(f">>> Loading {model_path} (this may take a moment)...")
#     try:
#         # trust_remote_code=True is required for DeepSeek-OCR
#         model = AutoModel.from_pretrained(
#             model_path,
#             trust_remote_code=True,
#             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
#         )
#     except Exception as e:
#         print(f"Error loading model: {e}")
#         return

#     print(f"\n>>> Component List for {model_path}:")
#     print("=" * 60)

#     # Iterate through all named modules (layers/components)
#     for name, module in model.named_modules():
#         # Indent based on depth for readability
#         depth = name.count('.')
#         indent = "  " * depth
#         print(f"{indent}{name} ({module.__class__.__name__})")

#     print("=" * 60)

#     # Specific check for the encoder you were debugging
#     print("\n>>> Checking for Vision Encoder attributes:")
#     if hasattr(model, "deep_encoder"):
#         print("  - Found 'deep_encoder' attribute")
#     elif hasattr(model, "vision_model"):
#         print("  - Found 'vision_model' attribute")
#     else:
#         print("  - Could not find standard 'deep_encoder' or 'vision_model' attributes at top level.")

# if __name__ == "__main__":
#     list_deepseek_components()

>>> Loading deepseek-ai/DeepSeek-OCR (this may take a moment)...


You are using a model of type deepseek_vl_v2 to instantiate a model of type DeepseekOCR. This is not supported for all configurations of models and can yield errors.
Some weights of DeepseekOCRForCausalLM were not initialized from the model checkpoint at deepseek-ai/DeepSeek-OCR and are newly initialized: ['model.vision_model.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



>>> Component List for deepseek-ai/DeepSeek-OCR:
 (DeepseekOCRForCausalLM)
model (DeepseekOCRModel)
  model.embed_tokens (Embedding)
  model.layers (ModuleList)
    model.layers.0 (DeepseekV2DecoderLayer)
      model.layers.0.self_attn (LlamaAttention)
        model.layers.0.self_attn.q_proj (Linear)
        model.layers.0.self_attn.k_proj (Linear)
        model.layers.0.self_attn.v_proj (Linear)
        model.layers.0.self_attn.o_proj (Linear)
      model.layers.0.mlp (DeepseekV2MLP)
        model.layers.0.mlp.gate_proj (Linear)
        model.layers.0.mlp.up_proj (Linear)
        model.layers.0.mlp.down_proj (Linear)
        model.layers.0.mlp.act_fn (SiLU)
      model.layers.0.input_layernorm (DeepseekV2RMSNorm)
      model.layers.0.post_attention_layernorm (DeepseekV2RMSNorm)
    model.layers.1 (DeepseekV2DecoderLayer)
      model.layers.1.self_attn (LlamaAttention)
        model.layers.1.self_attn.q_proj (Linear)
        model.layers.1.self_attn.k_proj (Linear)
        model.layer

In [22]:
# import torch
# from transformers import AutoModel
# import inspect

# def inspect_deepseek_components():
#     model_path = "deepseek-ai/DeepSeek-OCR"

#     print(f">>> Loading {model_path} (this may take a moment)...")
#     try:
#         # trust_remote_code=True is required for DeepSeek-OCR
#         model = AutoModel.from_pretrained(
#             model_path,
#             trust_remote_code=True,
#             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
#         )
#     except Exception as e:
#         print(f"Error loading model: {e}")
#         return

#     inner_model = model.model

#     print(f">>> Inspecting {inner_model.__class__.__name__}...")

#     # 2. We are looking for how 'sam_model' and 'vision_model' are connected.
#     # It is likely in 'forward' or a method called 'encode_images' bound to this class.
#     if hasattr(inner_model, "forward"):
#         source = inspect.getsource(inner_model.forward)
#         print(source)
#     else:
#       print("Could not find forward method on inner_model.")

# if __name__ == "__main__":
#     inspect_deepseek_components()

>>> Loading deepseek-ai/DeepSeek-OCR (this may take a moment)...


You are using a model of type deepseek_vl_v2 to instantiate a model of type DeepseekOCR. This is not supported for all configurations of models and can yield errors.
Some weights of DeepseekOCRForCausalLM were not initialized from the model checkpoint at deepseek-ai/DeepSeek-OCR and are newly initialized: ['model.vision_model.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


>>> Inspecting DeepseekOCRModel...
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:




        if inputs_embeds is None:
            # inputs_embeds = self.embed_tokens(input_ids)
            inputs_embeds = self.get_input_embeddings()(input_ids)



        sam_model = getattr(self, 'sam_model', None)
        # sam_mod