In [1]:
!pip install -q pycocoevalcap
!pip install -q wandb

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.3/104.3 MB[0m [31m15.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
%%writefile data_loading.py
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset
import torch

class HnMDataset(Dataset):
    """Map-style dataset turning caption records into per-sample model inputs.

    Each record of ``dataset`` must provide an ``"image"`` and a ``"text"``
    field. Both are encoded together by ``processor`` (called with
    ``padding="max_length"`` and ``return_tensors="pt"``), and the leading
    batch axis the processor adds is removed so a DataLoader can re-batch.
    """

    def __init__(self, dataset, processor):
        # dataset: indexable, sized collection of records (e.g. a HF Dataset split)
        # processor: callable accepting images=..., text=... and returning a dict of tensors
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        """Number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Encode record ``idx`` and return a dict of tensors without the batch axis."""
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], text=item["text"], padding="max_length", return_tensors="pt")
        # The processor returns a batch of one sample: drop only that leading axis.
        # A bare squeeze() would also collapse any other size-1 dimension.
        return {key: value.squeeze(0) for key, value in encoding.items()}

def prepare_data(dataset, data_processor, train_size, batch_size):
    """Build a shuffled training loader and an ordered validation loader.

    Args:
        dataset: Hugging Face Hub dataset id; its ``'train'`` split is loaded
            and partitioned into train/validation subsets.
        data_processor: processor applied to every record by ``HnMDataset``.
        train_size: fraction of records used for training, strictly between
            0 and 1; the remainder becomes the validation set.
        batch_size: training batch size; validation uses ``2 * batch_size``.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: if ``train_size`` is not strictly between 0 and 1.
    """
    # Validate before any (network-bound) dataset loading.
    if not 0 < train_size < 1:
        raise ValueError(f"train_size must be a fraction strictly between 0 and 1, got {train_size!r}")

    records = load_dataset(dataset, split='train')
    full_set = HnMDataset(records, data_processor)

    # random_split floors len * fraction per subset, so the validation fraction
    # must be the "clean" complement: float 1 - 0.8 is 0.19999999999999996, which
    # would shift the sizes (and hence the split) whenever len is a multiple of 5.
    # Rounding gives exactly 0.2 for train_size=0.8; if rounding ever pushes the
    # sum above 1 (rejected by random_split), fall back to the exact complement.
    val_size = round(1 - train_size, 12)
    if train_size + val_size > 1:
        val_size = 1 - train_size

    # fixed seed so the partition is identical across runs and resumed trainings
    train_set, val_set = random_split(
        full_set, lengths=[train_size, val_size], generator=torch.Generator().manual_seed(42)
    )

    train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
    val_loader = DataLoader(val_set, shuffle=False, batch_size=2 * batch_size)

    return train_loader, val_loader

Writing data_loading.py


In [3]:
%%writefile metrics.py
from pycocoevalcap.cider.cider import Cider

def cider_score(ground_truth, generated):
    """Corpus-level CIDEr of generated captions against one reference each.

    Args:
        ground_truth: reference captions; element ``i`` is the single
            reference for sample ``i``.
        generated: model captions aligned index-by-index with ``ground_truth``.

    Returns:
        The corpus score returned by pycocoevalcap's ``Cider().compute_score``,
        or ``0.0`` when both sequences are empty (the scorer cannot handle
        an empty corpus).

    Raises:
        ValueError: if the two sequences have different lengths.

    NOTE(review): pycocoevalcap is usually fed PTB-tokenized, lower-cased
    captions; raw decoded strings are passed here. Confirm that is intended.
    """
    ground_truth = list(ground_truth)
    generated = list(generated)
    if len(ground_truth) != len(generated):
        raise ValueError(
            f"ground_truth and generated must be aligned, got {len(ground_truth)} and {len(generated)} captions"
        )
    if not ground_truth:
        return 0.0

    references = {i: [gt] for i, gt in enumerate(ground_truth)}
    candidates = {i: [pred] for i, pred in enumerate(generated)}
    score, _ = Cider().compute_score(references, candidates)
    return score

Writing metrics.py


In [4]:
%%writefile model.py
from transformers import AutoProcessor, BlipForConditionalGeneration

def prepare_model(freeze_vit=False, freeze_bert=False):
    """Load the pretrained BLIP base captioning model together with its processor.

    Args:
        freeze_vit: if True, every parameter of the ``vision_model`` submodule
            gets ``requires_grad = False``.
        freeze_bert: if True, the same is done for the ``text_decoder`` submodule.

    Returns:
        ``(model, processor)`` for ``Salesforce/blip-image-captioning-base``.
    """
    hub_id = "Salesforce/blip-image-captioning-base"
    processor = AutoProcessor.from_pretrained(hub_id)
    model = BlipForConditionalGeneration.from_pretrained(hub_id)

    # map each freezable submodule to the flag that controls it
    freeze_plan = (("vision_model", freeze_vit), ("text_decoder", freeze_bert))
    for submodule_name, should_freeze in freeze_plan:
        if should_freeze:
            getattr(model, submodule_name).requires_grad_(False)

    return model, processor

Writing model.py


In [5]:
%%writefile train.py
import torch
from data_loading import prepare_data
from model import prepare_model
from tqdm import tqdm
from accelerate import Accelerator
from transformers import AdamW
from transformers import get_scheduler
import wandb
from metrics import cider_score
from kaggle_secrets import UserSecretsClient
import os
        
def save_checkpoint(save_path, model, epoch, best_cider, optimizer=None, lr_scheduler=None):
    """Serialize a training checkpoint with ``torch.save``.

    Args:
        save_path: destination handed straight to ``torch.save``.
        model: module whose ``state_dict()`` is stored under ``'model_state_dict'``.
        epoch: epoch number stored under ``'epoch'``.
        best_cider: best validation metric so far, stored under ``'best_metric'``.
        optimizer: optional optimizer. Its state dict goes under
            ``'optimizer_state_dict'``, which is ``None`` when omitted.
        lr_scheduler: optional LR scheduler. Its state dict goes under
            ``'scheduler_state_dict'``, which is ``None`` when omitted.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        # optimizer / scheduler default to None, so only serialize them when given
        'optimizer_state_dict': optimizer.state_dict() if optimizer is not None else None,
        'scheduler_state_dict': lr_scheduler.state_dict() if lr_scheduler is not None else None,
        'best_metric': best_cider,
    }
    torch.save(checkpoint, save_path)

def _masked_labels(input_ids, pad_token_id):
    """Return a copy of ``input_ids`` whose padding positions are set to -100.

    Captions are padded to a fixed length, so without masking every padding
    position contributes to the loss. Assumes the model's loss is a PyTorch
    cross-entropy with its default ``ignore_index=-100`` (the case for
    transformers' BLIP text decoder; confirm against the installed version).
    """
    if pad_token_id is None:
        return input_ids
    return input_ids.masked_fill(input_ids == pad_token_id, -100)


def _train_one_epoch(accelerator, model, train_loader, optimizer, lr_scheduler, pad_token_id, desc):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        Mean per-step loss, where each step's loss is first averaged across processes.
    """
    model.train()
    running_loss = 0.0
    for batch in tqdm(train_loader, desc=desc, disable=not accelerator.is_local_main_process):
        # batches from an accelerate-prepared loader are already on accelerator.device
        input_ids = batch["input_ids"]
        outputs = model(
            input_ids=input_ids,
            pixel_values=batch["pixel_values"],
            attention_mask=batch.get("attention_mask"),
            labels=_masked_labels(input_ids, pad_token_id),
        )
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()
        running_loss += accelerator.gather(loss.detach()).mean().item()
    return running_loss / max(len(train_loader), 1)


def _evaluate(accelerator, model, val_loader, processor, pad_token_id, caption_max_length, desc):
    """Compute the validation loss and CIDEr over the whole validation split.

    Decoded captions are collected from every process with ``gather_for_metrics``
    (which also drops samples duplicated to even out the last batch), so every
    process scores the same complete set and reaches the same checkpoint decision.

    Returns:
        ``(val_loss, val_cider)``.
    """
    model.eval()
    # DDP wrappers do not expose ``generate``; the unwrapped module works with or without one
    captioner = accelerator.unwrap_model(model)
    val_loss = 0.0
    ground_truth, generated = [], []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=desc, disable=not accelerator.is_local_main_process):
            input_ids = batch["input_ids"]
            pixel_values = batch["pixel_values"]
            outputs = model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                attention_mask=batch.get("attention_mask"),
                labels=_masked_labels(input_ids, pad_token_id),
            )
            val_loss += accelerator.gather(outputs.loss).mean().item()

            generated_ids = captioner.generate(pixel_values=pixel_values, max_length=caption_max_length)
            gt_captions = processor.batch_decode(input_ids, skip_special_tokens=True)
            generated_captions = processor.batch_decode(generated_ids, skip_special_tokens=True)
            ground_truth.extend(accelerator.gather_for_metrics(gt_captions))
            generated.extend(accelerator.gather_for_metrics(generated_captions))
    val_loss /= max(len(val_loader), 1)
    return val_loss, cider_score(ground_truth, generated)


def _save_on_main_process(accelerator, save_path, model, epoch, best_cider, optimizer, lr_scheduler):
    """Synchronise all processes, then write a checkpoint from the main process only."""
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        save_checkpoint(save_path, accelerator.unwrap_model(model), epoch, best_cider, optimizer, lr_scheduler)


def train_blip(
    num_epochs, batch_size, train_size, lr=5e-5, weight_decay=0.01,
    caption_max_length=35,
    freeze_vit=False, freeze_bert=False,
    checkpoint_path=None, logger=None, output_dir='/kaggle/working'
):
    """Fine-tune BLIP on the H&M caption dataset with accelerate.

    Args:
        num_epochs: total epoch budget of the cosine schedule. When resuming,
            training continues from the epoch stored in the checkpoint up to
            ``num_epochs``.
        batch_size: training batch size forwarded to ``prepare_data``.
        train_size: training fraction forwarded to ``prepare_data``.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        caption_max_length: ``max_length`` used when generating validation captions.
        freeze_vit: freeze the vision encoder (see ``prepare_model``).
        freeze_bert: freeze the text decoder (see ``prepare_model``).
        checkpoint_path: optional checkpoint written by ``save_checkpoint``;
            restores model/optimizer/scheduler, the epoch counter and best CIDEr.
        logger: optional object with a ``log(dict)`` method (e.g. a wandb run),
            called from the main process once per epoch.
        output_dir: directory receiving ``best.pth`` (whenever validation CIDEr
            improves) and ``last.pth`` (after the final epoch).
    """
    accelerator = Accelerator()
    best_cider = 0.0   # overwritten when resuming from a checkpoint
    start_epoch = 0

    model, processor = prepare_model(freeze_vit=freeze_vit, freeze_bert=freeze_bert)
    pad_token_id = processor.tokenizer.pad_token_id

    train_loader, val_loader = prepare_data(
        dataset='tomytjandra/h-and-m-fashion-caption', data_processor=processor,
        train_size=train_size, batch_size=batch_size
    )

    # Registering every parameter (frozen ones included) keeps the optimizer
    # state layout independent of the freeze flags, so saved states stay loadable.
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    num_training_steps = len(train_loader) * num_epochs
    lr_scheduler = get_scheduler(
        "cosine", optimizer=optimizer, num_warmup_steps=int(0.1 * num_training_steps),
        num_training_steps=num_training_steps
    )

    if checkpoint_path is not None:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        if checkpoint.get('optimizer_state_dict') is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if checkpoint.get('scheduler_state_dict') is not None:
            lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # Continue the epoch count: restarting at 0 would push the restored
        # scheduler past num_training_steps.
        start_epoch = checkpoint['epoch']
        best_cider = checkpoint['best_metric']
        accelerator.print(f"Checkpoint loaded: Resuming from Epoch {start_epoch}, best_cider={best_cider}")

    # accelerate handles device placement (model and batches) and multi-GPU wrapping
    train_loader, val_loader, model, optimizer, lr_scheduler = accelerator.prepare(
        train_loader, val_loader, model, optimizer, lr_scheduler
    )

    completed_epochs = start_epoch
    for epoch in range(start_epoch, num_epochs):
        train_loss = _train_one_epoch(
            accelerator, model, train_loader, optimizer, lr_scheduler, pad_token_id,
            desc=f'Epoch {epoch+1}/{num_epochs}:'
        )
        val_loss, val_cider = _evaluate(
            accelerator, model, val_loader, processor, pad_token_id, caption_max_length,
            desc=f'Validating epoch {epoch+1}/{num_epochs}:'
        )
        completed_epochs = epoch + 1

        accelerator.print(f'Epoch {epoch+1}/{num_epochs}: train_loss={train_loss:.6f}, val_loss={val_loss:.6f}, val_cider={val_cider:.6f}')
        if logger is not None and accelerator.is_main_process:
            logger.log({'train_loss': train_loss, 'val_loss': val_loss, 'val_cider': val_cider})

        # val_cider is computed from gathered captions, so all processes agree here
        if val_cider > best_cider:
            best_cider = val_cider
            _save_on_main_process(accelerator, os.path.join(output_dir, 'best.pth'), model,
                                  completed_epochs, best_cider, optimizer, lr_scheduler)
            accelerator.print(f"Best checkpoint saved at epoch {epoch+1}")

    _save_on_main_process(accelerator, os.path.join(output_dir, 'last.pth'), model,
                          completed_epochs, best_cider, optimizer, lr_scheduler)
    accelerator.print("Last checkpoint saved.")
    

if __name__ == '__main__':
    # wandb records the run locally (offline mode); the API key comes from Kaggle's secret store
    os.environ["WANDB_MODE"] = "offline"
    user_secrets = UserSecretsClient()
    api_key = user_secrets.get_secret("wandb api key")
    wandb.login(anonymous='never', key=api_key)

    # checkpoint to resume from; set to None to start from the pretrained BLIP weights
    checkpoint_path = '/kaggle/input/blip_8_epochs/pytorch/default/1/best.pth'

    # hyper-parameters (every one of them is recorded in the wandb run config)
    num_epochs = 9
    batch_size = 4
    train_size = 0.8
    lr = 5e-5
    weight_decay = 0.01
    caption_max_length = 50
    freeze_vit = False
    freeze_bert = False

    run = wandb.init(
        id='unfreeze_vit',
        project="train_blip_on_h&m",
        config={
            "dataset": "H&M captions",
            "epochs": num_epochs,
            "batch_size": batch_size,
            "train_size": train_size,
            "learning_rate": lr,
            'weight_decay': weight_decay,
            'caption_max_length': caption_max_length,
            'freeze_vit': freeze_vit,
            'freeze_bert': freeze_bert,
            'continue_from_checkpoint': checkpoint_path,
            'machine': 'offline cluster'
        },
        resume='allow'
    )
    train_blip(num_epochs=num_epochs, batch_size=batch_size, train_size=train_size, lr=lr,
               weight_decay=weight_decay,
               caption_max_length=caption_max_length,
               freeze_vit=freeze_vit, freeze_bert=freeze_bert,
               checkpoint_path=checkpoint_path, logger=run
    )
    run.finish()

Writing train.py


In [6]:
!accelerate launch train.py

2025-04-06 00:27:36.762761: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-06 00:27:36.762792: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-06 00:27:37.243115: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-06 00:27:37.243132: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-06 00:27:37.374611: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register fac