In [2]:
# ==============================================================================
# 1. SETUP: INSTALL LIBRARIES AND IMPORTS
# ==============================================================================
# Install required libraries in the Colab environment
!pip install -q kagglehub transformers accelerate bitsandbytes peft tqdm

import os
import torch
import pandas as pd
import kagglehub
import PIL.Image
from tqdm.notebook import tqdm

from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.cuda.amp import autocast, GradScaler

from transformers import (
    ViTModel,
    GPT2LMHeadModel,
    GPT2TokenizerFast,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, TaskType

# ==============================================================================
# 2. CONFIGURATION: CHOOSE DEVICE AND SET PARAMETERS
# ==============================================================================
# <<<<<<<<<<<<<<< CHOOSE YOUR DEVICE HERE >>>>>>>>>>>>>>>>>
DEVICE_CHOICE = "cuda" if torch.cuda.is_available() else "cpu"
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<

device = torch.device(DEVICE_CHOICE)
print(f"Using device: {device}")

# --- Model & Training Parameters ---
GPT2_MODEL_NAME = "gpt2"
VIT_MODEL_NAME = "google/vit-base-patch16-224"
BATCH_SIZE = 20
NUM_EPOCHS = 3
LEARNING_RATE = 5e-5
MAX_TEXT_LENGTH = 128

# ==============================================================================
# 3. MODEL AND TOKENIZER LOADING (WITH CONDITIONAL QUANTIZATION)
# ==============================================================================
print("Loading models and tokenizer...")

tokenizer = GPT2TokenizerFast.from_pretrained(GPT2_MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    print("Added [PAD] token to tokenizer.")

quantization_config = None
if DEVICE_CHOICE == "cuda":
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    print("4-bit quantization is enabled for GPU.")

base_model = GPT2LMHeadModel.from_pretrained(
    GPT2_MODEL_NAME,
    quantization_config=quantization_config,
)
base_model.resize_token_embeddings(len(tokenizer))
vit = ViTModel.from_pretrained(VIT_MODEL_NAME)

# ==============================================================================
# 4. PEFT & LORA CONFIGURATION
# ==============================================================================
print("Configuring LoRA for both models...")
lora_config_gpt2 = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn", "c_proj"], lora_dropout=0.05, bias="none", task_type=TaskType.CAUSAL_LM)
lora_gpt2_model = get_peft_model(base_model, lora_config_gpt2)
print("--- GPT-2 with LoRA ---")
lora_gpt2_model.print_trainable_parameters()

lora_config_vit = LoraConfig(r=8, lora_alpha=16, target_modules=["query", "key", "value", "dense"], lora_dropout=0.05, bias="none")
lora_vit_model = get_peft_model(vit, lora_config_vit)
print("\n--- ViT with LoRA ---")
lora_vit_model.print_trainable_parameters()

# ==============================================================================
# 5. DATASET AND DATALOADER (CORRECTED)
# ==============================================================================
print("\nDownloading and preparing dataset...")
path = kagglehub.dataset_download("hsankesara/flickr-image-dataset")
csv_path = os.path.join(path, "flickr30k_images/results.csv")
img_dir = os.path.join(path, "flickr30k_images/flickr30k_images")
print(f"Dataset downloaded to: {path}")

df = pd.read_csv(csv_path, delimiter='|')
df.columns = [col.strip() for col in df.columns]

# <<<<<<<<<<<<<<<<<<<< THE FIX IS HERE >>>>>>>>>>>>>>>>>>>>
# 1. Remove rows with missing values in the 'comment' column.
df.dropna(subset=['comment'], inplace=True)

# 2. Ensure all comments are strings, stripping any extra whitespace.
df['comment'] = df['comment'].astype(str).str.strip()

# 3. Reset the index of the DataFrame after dropping rows.
df.reset_index(drop=True, inplace=True)
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<


# --- Image Transformations ---
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# --- Custom PyTorch Dataset (No changes needed here) ---
class FlickrDataset(Dataset):
    def __init__(self, dataframe, root_dir, tokenizer, transform=None, max_length=128):
        self.dataframe = dataframe
        self.root_dir = root_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __len__(self):
        return len(self.dataframe)
    def __getitem__(self, idx):
        try:
            img_name = self.dataframe.loc[idx, 'image_name']
            img_path = os.path.join(self.root_dir, img_name)
            caption = self.dataframe.loc[idx, 'comment'] # This will now always be a string
            image = PIL.Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            tokenized_caption = self.tokenizer(caption, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)
            caption_ids = tokenized_caption['input_ids'].squeeze(0)
            return image, caption_ids
        except (PIL.UnidentifiedImageError, FileNotFoundError) as e:
            print(f"Warning: Skipping corrupted or missing image at index {idx}: {e}")
            return self.__getitem__((idx + 1) % len(self))

dataset = FlickrDataset(df, img_dir, tokenizer, transform, MAX_TEXT_LENGTH)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# ==============================================================================
# 6. COMBINED MODEL ARCHITECTURE (Img2GPT)
# ==============================================================================
class Img2GPT(nn.Module):
    def __init__(self, vit_model, gpt2_model):
        super(Img2GPT, self).__init__()
        self.vit = vit_model
        self.gpt2 = gpt2_model
        vit_dim = self.vit.config.hidden_size
        gpt_dim = self.gpt2.config.hidden_size
        self.proj = nn.Linear(vit_dim, gpt_dim)
    def forward(self, pixel_values, labels=None):
        vit_outputs = self.vit(pixel_values=pixel_values)
        image_embeds = vit_outputs.last_hidden_state
        projected_image_embeds = self.proj(image_embeds)
        if labels is not None:
            text_embeds = self.gpt2.get_input_embeddings()(labels.long())
            inputs_embeds = torch.cat([projected_image_embeds, text_embeds], dim=1)
        else:
            inputs_embeds = projected_image_embeds
        output_labels = None
        if labels is not None:
            num_image_tokens = projected_image_embeds.shape[1]
            mask_labels = torch.full(projected_image_embeds.shape[:-1], -100, device=labels.device)
            output_labels = torch.cat([mask_labels, labels], dim=1)
        outputs = self.gpt2(inputs_embeds=inputs_embeds, labels=output_labels)
        return outputs

    @torch.no_grad()
    def generate_caption(self, pixel_values, max_length=50, num_beams=5):
        self.eval()
        pixel_values = pixel_values.to(next(self.parameters()).device)
        vit_outputs = self.vit(pixel_values=pixel_values)
        image_embeds = vit_outputs.last_hidden_state
        projected_image_embeds = self.proj(image_embeds)
        generated_ids = self.gpt2.generate(
            inputs_embeds=projected_image_embeds,
            max_length=max_length + projected_image_embeds.shape[1],
            num_beams=num_beams,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
        captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return [caption.strip() for caption in captions]

model = Img2GPT(lora_vit_model, lora_gpt2_model)
model.to(device)

total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total:,}")
print(f"Trainable parameters: {trainable:,} ({100 * trainable / total:.2f}%)")

# ==============================================================================
# 7. TRAINING LOOP WITH PERIODIC EVALUATION (CORRECTED)
# ==============================================================================
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
use_amp = (DEVICE_CHOICE == "cuda")
scaler = GradScaler(enabled=use_amp)

# <<<<<<<<<<<<<<<<<<<< PREPARE FIXED BATCH FOR EVALUATION >>>>>>>>>>>>>>>>>>
print("\nFetching a fixed batch for periodic evaluation...")
# Set a temporary dataloader without shuffling to get a consistent batch
temp_loader = DataLoader(dataset, batch_size=4, shuffle=False)
fixed_eval_batch = next(iter(temp_loader))
eval_images, eval_captions_ids = fixed_eval_batch

eval_images = eval_images.to(device)
true_eval_captions = tokenizer.batch_decode(eval_captions_ids, skip_special_tokens=True)
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<

print(f"\nStarting training for {NUM_EPOCHS} epochs...")
print(f"Automatic Mixed Precision (AMP) enabled: {use_amp}")

EVAL_INTERVAL = 1000 # Evaluate every 50 batches

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0
    progress_bar = tqdm(enumerate(dataloader), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", total=len(dataloader))

    for batch_idx, (images, captions) in progress_bar:
        images = images.to(device)
        captions = captions.to(device)

        # --- Training Step ---
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(pixel_values=images, labels=captions)
            loss = outputs.loss

        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})

        # --- Periodic Evaluation Step ---
        if (batch_idx + 1) % EVAL_INTERVAL == 0:
            print(f"\n--- Running evaluation at Epoch {epoch+1}, Batch {batch_idx+1} ---")
            torch.save(model.state_dict(), f"img2gpt_epoch_{epoch+1}.pth")
            # <<<<<<<<<<<<<<<<<<<< THE FIX IS HERE >>>>>>>>>>>>>>>>>>>>
            # The generation must also be inside an autocast context to handle float16/half dtypes
            with autocast(enabled=use_amp):
                generated_captions = model.generate_caption(eval_images, max_length=50)
            # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<

            # Display results
            for i in range(len(generated_captions)):
                print(f"  Image {i+1}:")
                print(f"    -> True Caption: {true_eval_captions[i].strip()}")
                print(f"    -> Generated Caption: {generated_captions[i]}")

            print("--- Evaluation finished, resuming training ---")

            # IMPORTANT: Switch back to training mode
            model.train()

    avg_loss = total_loss / len(dataloader)
    print(f"\nEpoch {epoch+1} finished. Average Loss: {avg_loss:.4f}")



print("\nTraining finished.")




# ==============================================================================
# 8. FINAL INFERENCE EXAMPLE
# ==============================================================================
print("\n--- Running Final Inference Example on the evaluation batch ---")
# Generate final captions
final_captions = model.generate_caption(eval_images, max_length=50)

# Display results
for i in range(len(final_captions)):
    print(f"\n--- Image {i+1} ---")
    print(f"  -> True Caption: {true_eval_captions[i].strip()}")
    print(f"  -> Final Generated Caption: {final_captions[i]}")

Using device: cuda
Loading models and tokenizer...
Added [PAD] token to tokenizer.
4-bit quantization is enabled for GPU.


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Configuring LoRA for both models...
--- GPT-2 with LoRA ---
trainable params: 811,008 || all params: 125,251,584 || trainable%: 0.6475

--- ViT with LoRA ---
trainable params: 1,339,392 || all params: 87,728,640 || trainable%: 1.5267

Downloading and preparing dataset...
Dataset downloaded to: /kaggle/input/flickr-image-dataset

Total parameters: 171,103,488
Trainable parameters: 2,740,992 (1.60%)

Fetching a fixed batch for periodic evaluation...

Starting training for 3 epochs...
Automatic Mixed Precision (AMP) enabled: True


  scaler = GradScaler(enabled=use_amp)


Epoch 1/3:   0%|          | 0/7946 [00:00<?, ?it/s]

  with autocast(enabled=use_amp):



--- Running evaluation at Epoch 1, Batch 1000 ---


  with autocast(enabled=use_amp):


  Image 1:
    -> True Caption: Two young guys with shaggy hair look at their hands while hanging out in the yard .
    -> Generated Caption: A man and a woman are walking down a street in the middle of the night in a dark colored van , looking at each other with their eyes closed , while the other man is standing in the middle of the street with his hands in his pockets .
  Image 2:
    -> True Caption: Two young , White males are outside near many bushes .
    -> Generated Caption: A man and a woman are walking down a street in the middle of the night in a dark colored van , looking at each other with their eyes closed , while the other man is standing in the middle of the street with his hands in his pockets .
  Image 3:
    -> True Caption: Two men in green shirts are standing in a yard .
    -> Generated Caption: A man and a woman are walking down a street in the middle of the night in a dark colored van , looking at each other with their eyes closed , while the other man is stand

Epoch 2/3:   0%|          | 0/7946 [00:00<?, ?it/s]


--- Running evaluation at Epoch 2, Batch 1000 ---
  Image 1:
    -> True Caption: Two young guys with shaggy hair look at their hands while hanging out in the yard .
    -> Generated Caption: A man in a blue shirt and blue pants is playing a guitar in front of a tree in front of a house in front of a house in front of a house in front of a house in front of a house in front of a house in front
  Image 2:
    -> True Caption: Two young , White males are outside near many bushes .
    -> Generated Caption: A man in a blue shirt and blue pants is playing a guitar in front of a tree in front of a house in front of a house in front of a house in front of a house in front of a house in front of a house in front
  Image 3:
    -> True Caption: Two men in green shirts are standing in a yard .
    -> Generated Caption: A man in a blue shirt and blue pants is playing a guitar in front of a tree in front of a house in front of a house in front of a house in front of a house in front of a house i

Epoch 3/3:   0%|          | 0/7946 [00:00<?, ?it/s]


--- Running evaluation at Epoch 3, Batch 1000 ---
  Image 1:
    -> True Caption: Two young guys with shaggy hair look at their hands while hanging out in the yard .
    -> Generated Caption: A man in a blue shirt is playing a guitar on a grassy hillside with a man in a red shirt and a woman in a blue shirt . convolvents in front of a tree with a man in a blue shirt and a woman in
  Image 2:
    -> True Caption: Two young , White males are outside near many bushes .
    -> Generated Caption: A man in a blue shirt is playing a guitar on a grassy hillside with a man in a red shirt and a woman in a blue shirt . convolvents in front of a tree with a man in a blue shirt and a woman in
  Image 3:
    -> True Caption: Two men in green shirts are standing in a yard .
    -> Generated Caption: A man in a blue shirt is playing a guitar on a grassy hillside with a man in a red shirt and a woman in a blue shirt . convolvents in front of a tree with a man in a blue shirt and a woman in
  Image 4:


RuntimeError: expected scalar type Float but found Half