In [1]:
from func_utils.plot_utils import show_image
import matplotlib.pyplot as plt 
from glob import glob
import pandas as pd 
import numpy as np 
import json
import os 

import torch 
from func_utils.pydataloader import SynthDogDataset
from func_utils.trainer_utils import *
from encoder_decoder_model import (
    init_dit_mbert_models_fixed, init_dit_dbart_models, 
    print_model_layer_sizes, load_pretrained_enc_dec_model, load_pretrained_iprocessor_tokenizer
    )

from func_utils.trainer_utils import unfreeze_last_n_encoder
from transformers import TrainingArguments, Seq2SeqTrainingArguments
from trl import SFTTrainer

import wandb
import gc

# Debug aid: request synchronous CUDA kernel launches so that errors are
# reported at the call that caused them. Set before any CUDA work is issued so
# the setting is in place when the CUDA context gets created.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Collect unreachable Python objects first so the tensors they own are freed,
# then hand the now-unused cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

wandb.login()


  from .autonotebook import tqdm as notebook_tqdm
[34m[1mwandb[0m: Currently logged in as: [33mbeasted90[0m ([33mbeasted90-comudel[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

# Quantisation settings: 4-bit NF4 weights with double quantisation enabled
# and bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="bfloat16",
)

# Processor and quantised model are both loaded from the same checkpoint name.
model_name = "facebook/Perception-LM-1B"
processor = AutoProcessor.from_pretrained(
    model_name,
    use_fast=True,
)
model = AutoModelForImageTextToText.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)

In [3]:
# Run PEFT's k-bit preparation step on the 4-bit model before LoRA adapters are attached.
# The result rebinds `model`, so re-running this cell applies the step to an already-prepared model.
# NOTE(review): per the PEFT docs this presumably freezes the base weights and enables
# gradient checkpointing by default — confirm against the installed PEFT version.
model = prepare_model_for_kbit_training(model)

In [4]:
# Round-trip check: encode a Portuguese sentence full of accented capitals and
# decode the first sequence again, to see whether the diacritics survive.
text = 'ÁGUA É ESSENCIAL PARA A COMPREENSÃO E AÇÃO; CÂNCER, ÓRGÃOS, EMOÇÃO, TÊM INFLUÊNCIA, E ÍNDICES MOSTRAM EVOLUÇÃO.'
processor.decode(processor(text=text).input_ids[0])

'<|begin_of_text|>ÁGUA É ESSENCIAL PARA A COMPREENSÃO E AÇÃO; CÂNCER, ÓRGÃOS, EMOÇÃO, TÊM INFLUÊNCIA, E ÍNDICES MOSTRAM EVOLUÇÃO.'

In [5]:
# Display the tokenizer's special-token mapping (the pad token id is used by the collate functions below).
processor.tokenizer.special_tokens_map

{'bos_token': '<|begin_of_text|>',
 'eos_token': '<|eot_id|>',
 'pad_token': '<|end_of_text|>',
 'image_token': '<|image|>',
 'video_token': '<|video|>'}

In [6]:
# LoRA adapters on the self-attention projections only; the MLP projections
# are left commented out below and therefore receive no adapters.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # self-attention
        # "gate_proj", "up_proj", "down_proj"      # MLP
    ],
    lora_dropout=0.05,
    bias="none",
    # Perception-LM generates text with a decoder-only language model, so the
    # PEFT wrapper must be the causal-LM one. "SEQ_2_SEQ_LM" selects the
    # encoder-decoder wrapper, which works with decoder_* inputs this model
    # does not have.
    task_type="CAUSAL_LM",
)

# Wrap the (already k-bit prepared) model with the adapters, then make sure
# gradient checkpointing is active to trade compute for activation memory.
model = get_peft_model(model, lora_config)
model.gradient_checkpointing_enable()

In [7]:
# Print the frozen/trainable status of the model's parameters.
# NOTE(review): `print_trainable_prams` is not imported by name — presumably it comes from the
# `from func_utils.trainer_utils import *` star import; confirm. The per-parameter listing is long.
print_trainable_prams(model)

⛔ Frozen: base_model.model.model.vision_tower.timm_model.cls_token
⛔ Frozen: base_model.model.model.vision_tower.timm_model.pos_embed
⛔ Frozen: base_model.model.model.vision_tower.timm_model.patch_embed.proj.weight
⛔ Frozen: base_model.model.model.vision_tower.timm_model.norm_pre.weight
⛔ Frozen: base_model.model.model.vision_tower.timm_model.norm_pre.bias
⛔ Frozen: base_model.model.model.vision_tower.timm_model.blocks.0.gamma_1
⛔ Frozen: base_model.model.model.vision_tower.timm_model.blocks.0.gamma_2
⛔ Frozen: base_model.model.model.vision_tower.timm_model.blocks.0.norm1.weight
⛔ Frozen: base_model.model.model.vision_tower.timm_model.blocks.0.norm1.bias
⛔ Frozen: base_model.model.model.vision_tower.timm_model.blocks.0.attn.qkv.weight
⛔ Frozen: base_model.model.model.vision_tower.timm_model.blocks.0.attn.qkv.bias
⛔ Frozen: base_model.model.model.vision_tower.timm_model.blocks.0.attn.proj.weight
⛔ Frozen: base_model.model.model.vision_tower.timm_model.blocks.0.attn.proj.bias
⛔ Frozen: b

In [8]:
# Name of this experiment; reused below for the output directory and the trainer's run name.
run_name = "overfit_testing_plm"
# Start a Weights & Biases run under the "ocr model" project.
wandb.init(project="ocr model", name=run_name)

In [9]:
def get_synth_images_json_path(data_root= os.path.join('synthdog','outputs'), split='train'):
    """
    Find the image files and metadata files of one dataset split.

    Every immediate sub-directory of `data_root` is searched for a `split`
    folder.

    Args:
        data_root: Directory whose sub-directories each contain split folders.
        split: Name of the split folder to look in (e.g. 'train').

    Returns:
        Tuple `(image_paths, metadata_paths)`: all `*.jpg` files and all
        `metadata.jsonl` files found under `<data_root>/*/<split>/`, in the
        order `glob` yields them.
    """
    split_dir_pattern = os.path.join(data_root, '*', split)
    image_files = glob(os.path.join(split_dir_pattern, '*.jpg'))
    metadata_files = glob(os.path.join(split_dir_pattern, 'metadata.jsonl'))
    return image_files, metadata_files


# Collect image and metadata paths for both splits from the SynthDoG output tree.
root_path = os.path.join('synthdog', 'outputs')
train_ipath, train_json_metadata = get_synth_images_json_path(data_root=root_path, split='train')
val_ipath, val_json_metadata = get_synth_images_json_path(data_root=root_path, split='validation')

# `max_memory_allocated()` is the peak number of bytes held by tensors so far,
# not the current usage; report it as such and with its unit.
peak_mem = torch.cuda.max_memory_allocated()
print(f"Peak GPU memory allocated so far: {peak_mem / 1024**3:.2f} GiB")

max_token_size = 1056
sample_size = 32

# The processor is passed as the image feature extractor and no separate text
# tokenizer is given. The validation set is built with a smaller sample size.
train_synthdataset = SynthDogDataset(
    train_ipath,
    train_json_metadata,
    image_feature_extractor=processor,
    text_tokenizer=None,
    max_token_size=max_token_size,
    sample_size=sample_size,
)
val_synthdataset = SynthDogDataset(
    val_ipath,
    val_json_metadata,
    image_feature_extractor=processor,
    text_tokenizer=None,
    max_token_size=max_token_size,
    sample_size=4,
)

The model as is is holding: 2.10 of GPU RAM
['synthdog\\outputs\\SynthDoG-en\\train\\image_0.jpg', 'synthdog\\outputs\\SynthDoG-en\\train\\image_1.jpg']
Sampled lang counter: {'SynthDoG-en': 16, 'SynthDoG-pt': 16}
Length of _.images: 32 | Length of _.json_metadata: 91
['synthdog\\outputs\\SynthDoG-en\\validation\\image_43.jpg']
Sampled lang counter: {'SynthDoG-en': 1}
Length of _.images: 1 | Length of _.json_metadata: 1


In [10]:
# Display the loaded model's configuration (architecture, quantisation settings, text config).
model.config

PerceptionLMConfig {
  "architectures": [
    "PerceptionLMForConditionalGeneration"
  ],
  "dtype": "float16",
  "image_token_id": 128002,
  "model_type": "perception_lm",
  "projector_pooling_ratio": 2,
  "quantization_config": {
    "_load_in_4bit": true,
    "_load_in_8bit": false,
    "bnb_4bit_compute_dtype": "bfloat16",
    "bnb_4bit_quant_storage": "uint8",
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": true,
    "llm_int8_enable_fp32_cpu_offload": false,
    "llm_int8_has_fp16_weight": false,
    "llm_int8_skip_modules": null,
    "llm_int8_threshold": 6.0,
    "load_in_4bit": true,
    "load_in_8bit": false,
    "quant_method": "bitsandbytes"
  },
  "text_config": {
    "attention_bias": false,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "dtype": "float16",
    "eos_token_id": [
      128001,
      128009
    ],
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8

In [None]:
num_epochs = 1
training_args = Seq2SeqTrainingArguments(
    # --- run identity / reporting ---
    output_dir=f"./{run_name}",
    run_name=run_name,
    report_to=["wandb"],
    logging_strategy="steps",
    logging_steps=10,

    # --- batching / data loading ---
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    dataloader_pin_memory=True,
    # dataloader_num_workers=2,

    # --- optimisation ---
    optim='adamw_torch',
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,
    max_grad_norm=1.0,
    num_train_epochs=num_epochs,
    bf16=True,

    # --- evaluation / checkpointing ---
    # Evaluate and save once per epoch; keep at most three checkpoints and
    # reload the one with the lowest eval loss when training ends.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    prediction_loss_only=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    load_best_model_at_end=True,
)

In [18]:
def plm_collate_fn(batch, processor, max_length=512):
    """
    Collate function for Perception-LM fine-tuning.

    Images are run through `processor`; the target strings are tokenized with
    `processor.tokenizer` and attached to the processor output as `labels`,
    with padding positions replaced by -100.

    Args:
        batch: List of dicts with 'pixel_values' (PIL Image or tensor) and 'labels' (str)
        processor: AutoProcessor for Perception-LM
        max_length: Maximum number of label tokens per sample; longer targets
            are truncated. Defaults to 512, the previously hard-coded value.

    Returns:
        The processor output for the images with an added 'labels' tensor.
        NOTE(review): no prompt 'input_ids' / 'attention_mask' are produced
        here — confirm that the model's forward pass does not need them.
    """
    # Split the batch into images and target strings.
    images = [item['pixel_values'] for item in batch]
    target_texts = [item['labels'] for item in batch]

    # Image side: the processor receives only the images.
    inputs = processor(
        images=images,
        return_tensors="pt",
        padding=True
    )

    # Text side: pad to the longest target in the batch, truncate at max_length.
    text_inputs = processor.tokenizer(
        target_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length
    )

    # Work on a copy so the tokenizer output is left untouched, and replace
    # pad token ids with -100 so those positions are ignored by the loss.
    labels_tensor = text_inputs['input_ids'].clone()
    labels_tensor[labels_tensor == processor.tokenizer.pad_token_id] = -100

    inputs['labels'] = labels_tensor

    return inputs

def collate_fn_with_prompt(batch, max_length =1056, pad_token_id=None):
    """
    Collates already-processed outputs with instruction, [IMG], image tensors, and label token ids.
    Does not re-tokenize prompt—just pads and batches.

    Args:
        batch: List of dicts, each providing 'pixel_values', 'input_ids',
            'attention_mask' and 'labels' tensors.
        max_length: Accepted for interface compatibility; not used here
            (sequences are padded to the longest item in the batch and are
            never truncated).
        pad_token_id: Id used to pad 'input_ids' and 'labels'. When None, it
            is read from the notebook-level `processor` at call time, so this
            function can be defined before the processor is loaded.

    Returns:
        Dict with batched 'input_ids', 'attention_mask', 'pixel_values' and
        'labels'. Every label position equal to `pad_token_id` is set to -100.
    """
    # Local import so the function does not depend on a star import for this name.
    from torch.nn.utils.rnn import pad_sequence

    if pad_token_id is None:
        pad_token_id = processor.tokenizer.pad_token_id

    # Right-pad every sequence field to the longest item in the batch.
    input_ids_padded = pad_sequence(
        [item['input_ids'] for item in batch], batch_first=True, padding_value=pad_token_id
    )
    attention_mask_padded = pad_sequence(
        [item['attention_mask'] for item in batch], batch_first=True, padding_value=0
    )
    labels_padded = pad_sequence(
        [item['labels'] for item in batch], batch_first=True, padding_value=pad_token_id
    )

    # Mask pad positions so they do not contribute to the loss. This also
    # masks any pad token id already present in an item's own labels.
    labels_for_loss = labels_padded.masked_fill(labels_padded == pad_token_id, -100)

    # Stacking requires every item's pixel_values to have the same shape.
    pixel_values_batch = torch.stack([item['pixel_values'] for item in batch])

    return {
        'input_ids': input_ids_padded,
        'attention_mask': attention_mask_padded,
        'pixel_values': pixel_values_batch,
        'labels': labels_for_loss,
    }

In [19]:
# Early-stopping callback with a patience of 5, handed to the trainer below.
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=5)
# trainer = setup_dit_bart_training(
#         train_synthdataset, val_synthdataset, training_args=training_args, model=model, text_tokenizer=processor.tokenizer,
#         run_name = run_name, 
#         callbacks=[early_stopping_callback],
#         max_length=max_token_size,
#         custom_collate_fn=collate_fn_with_prompt
#     )

In [20]:
# `max_memory_allocated()` is the peak number of bytes held by tensors so far,
# not the current usage; report it as such and with its unit.
peak_mem = torch.cuda.max_memory_allocated()
print(f"Peak GPU memory allocated so far: {peak_mem / 1024**3:.2f} GiB")

The model is holding: 2.10 of GPU RAM


In [21]:
# collate_fn = lambda batch: collate_fn_with_prompt(batch, processor.tokenizer)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_synthdataset,
    # NOTE(review): evaluation runs on the training set — presumably
    # intentional for this overfit sanity run; switch to `val_synthdataset`
    # for a real evaluation.
    eval_dataset=train_synthdataset,
    data_collator=collate_fn_with_prompt,
    # `tokenizer=` is deprecated on Trainer in favour of `processing_class=`;
    # the same tokenizer object is passed, so behaviour is unchanged.
    processing_class=processor.tokenizer,
    # compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback],
)

  trainer = Seq2SeqTrainer(


In [22]:
# Start fine-tuning with the trainer configured above.
# NOTE(review): the recorded run of this cell ended with "DataLoader worker ... exited unexpectedly",
# which implies it ran with dataloader workers enabled — presumably before `dataloader_num_workers`
# was commented out in the training arguments; re-run from a fresh kernel to confirm.
trainer.train()

RuntimeError: DataLoader worker (pid(s) 18144, 8988) exited unexpectedly

In [None]:
# Decode what the processor produces for the first training image plus an instruction prompt.
# NOTE(review): the recorded output contains only the BOS token and the instruction text, with no
# `<|image|>` placeholder — presumably the prompt must include the image token (or be built with a
# chat template) for the image to be referenced; confirm against the processor documentation.
processor.decode(processor(images=train_synthdataset[0]['image'], text = "Extract all the texts from the document.").input_ids[0])

'<|begin_of_text|>Extract all the texts from the document.'