In [None]:
import torch
import pandas as pd
import yaml
from pathlib import Path
from sklearn.model_selection import train_test_split
from transformers import (
    AutoModelForCausalLM, Blip2Model, BlipImageProcessor, 
    AutoTokenizer, BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader, Subset

# From src dir
from model import BLIP2ForPhi, setup_model, select_train_params
from dataset import ImageCaptioningDataset, get_datasets
from trainer import CustomTrainer


In [3]:
import os
os.environ["HF_TOKEN"] = "TOKEN ID"

In [6]:
# End-to-end training driver: load the YAML config, build the model and
# datasets through the project helpers, create the optimizer, and run training.
config_path = '../configs/config.yaml'

# Explicit encoding so the config is read the same way on every platform.
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# setup_model / select_train_params come from the project's `model` module.
# NOTE(review): language_model=False presumably keeps the language model
# frozen — confirm against select_train_params.
model, image_processor, tokenizer = setup_model(config)
num_trainable_params = select_train_params(model, language_model=False)

print(f"training {num_trainable_params} params...")
print("Preparing dataset...")

train_dataset, valid_dataset, train_debug, valid_debug = get_datasets(
    config['dataset']['image_captioning'],
    config,
    image_processor,
    tokenizer,
)

# Materialize the parameters as a list: unlike a one-shot `filter` iterator,
# it can still be inspected after the optimizer has consumed it.
trainable_params = [p for p in model.parameters() if p.requires_grad]

# PyYAML loads exponent notation without a dot (e.g. `1e-4`) as a string, so
# both numeric hyperparameters are coerced to float before reaching AdamW.
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=float(config['training']['learning_rate']),
    weight_decay=float(config['training']['weight_decay']),
)

# NOTE(review): the trainer receives the *_debug splits; train_dataset and
# valid_dataset are not used in this cell. Swap them in if a full run is intended.
trainer = CustomTrainer(
    model=model,
    optimizer=optimizer,
    tokenizer=tokenizer,
    train_dataset=train_debug,
    val_dataset=valid_debug,
    dataset_name=config['dataset']['image_captioning'],
    batch_size=config['training']['batch_size'],
    save_dir=config['path']['save_dir'],
    repo_id=config['hf']['repo_id'],
)

print("Starting training...")
trainer.train(
    num_epochs=config['training']['num_epochs'],
    # `.get` yields None when the config has no `resume_from_checkpoint` entry.
    resume_from_checkpoint=config['path'].get('resume_from_checkpoint'),
)
print("Training finished!")

2025-07-06 22:32:16.629491: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751841136.824421      88 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751841136.881616      88 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
