In [1]:
# ============================================================================
# CELL 1: Install Dependencies
# ============================================================================
!pip install -q transformers==4.57.0 datasets accelerate sentencepiece safetensors \
    einops ftfy regex pillow torch torchvision peft bitsandbytes

# ============================================================================
# CELL 2: Setup and Imports
# ============================================================================
import torch
import os
from PIL import Image
from transformers import (
    AutoProcessor,
    Kosmos2ForConditionalGeneration,
    TrainingArguments,
    Trainer,
    default_data_collator
)
from datasets import load_dataset
import numpy as np
from torch.utils.data import Dataset
import gc

# Check GPU
print(f"GPU Available: {torch.cuda.is_available()}")
print(f"GPU Name: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")
print(f"CUDA Version: {torch.version.cuda}")


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.4/41.4 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Reason for being yanked: Error in the setup causing installation issues[0m[33m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m73.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[?25hGPU Available: True
GPU Name: Tesla T4
CUDA Version: 12.6


In [2]:
# ============================================================================
# CELL 3: Load HF Token (Optional but recommended)
# ============================================================================
from google.colab import userdata
try:
    HF_TOKEN = userdata.get('HF_TOKEN')
    print("HF token loaded successfully")
except:
    HF_TOKEN = None
    print("No HF token found - proceeding without authentication")


HF token loaded successfully


In [3]:
# ============================================================================
# CELL 4: Load Model and Processor
# ============================================================================
model_id = "microsoft/kosmos-2-patch14-224"
print(f"Loading processor from {model_id}...")
processor = AutoProcessor.from_pretrained(
    model_id,
    use_auth_token=HF_TOKEN if HF_TOKEN else None
)

print(f"Loading model from {model_id}...")
model = Kosmos2ForConditionalGeneration.from_pretrained(
    model_id,
    use_auth_token=HF_TOKEN if HF_TOKEN else None,
    torch_dtype=torch.float16,  # Use FP16 to save memory
    device_map="auto"
)

print(f"Model loaded on device: {model.device}")
print(f"Model dtype: {model.dtype}")


Loading processor from microsoft/kosmos-2-patch14-224...




preprocessor_config.json:   0%|          | 0.00/534 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json: 0.00B [00:00, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


Loading model from microsoft/kosmos-2-patch14-224...




config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/6.66G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

Model loaded on device: cuda:0
Model dtype: torch.float16


In [None]:
# ============================================================================
# CELL 5: Load and Explore Flickr30k Dataset
# ============================================================================
print("Loading Flickr30k dataset from raw files...")

import pandas as pd
from huggingface_hub import hf_hub_download
import zipfile
import os
from PIL import Image
import ast

# Download the CSV file with annotations
print("Downloading annotations CSV...")
csv_path = hf_hub_download(
    repo_id="nlphuji/flickr30k",
    filename="flickr_annotations_30k.csv",
    repo_type="dataset"
)

# Load annotations
print("Loading annotations...")
df = pd.read_csv(csv_path)
print(f"Loaded {len(df)} rows")
print(f"Columns: {df.columns.tolist()}")

# Parse the data - 'raw' column contains list of captions as string
dataset_list = []
for idx, row in df.iterrows():
    try:
        # The 'raw' column contains a string representation of a list
        captions_str = row['raw']
        captions = ast.literal_eval(captions_str)  # Convert string to list

        filename = row['filename']

        if captions and filename and isinstance(captions, list):
            dataset_list.append({
                'filename': filename,
                'captions': captions
            })

        if idx % 5000 == 0 and idx > 0:
            print(f"Processed {idx}/{len(df)} rows, found {len(dataset_list)} valid images")

    except Exception as e:
        if idx < 5:  # Show errors for first few rows for debugging
            print(f"Error at row {idx}: {e}")
        continue

print(f"\n✅ Processed: {len(dataset_list)} images with captions")
if dataset_list:
    print(f"Sample: {dataset_list[0]['filename']} has {len(dataset_list[0]['captions'])} captions")
    print(f"First caption: {dataset_list[0]['captions'][0][:80]}...")

# Download and extract images
print("\nDownloading images zip file (4.4GB - this may take several minutes)...")
zip_path = hf_hub_download(
    repo_id="nlphuji/flickr30k",
    filename="flickr30k-images.zip",
    repo_type="dataset"
)

# Extract images
images_dir = "./flickr30k_images"
if not os.path.exists(images_dir):
    print("Extracting images...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(images_dir)
    print(f"Images extracted to {images_dir}")
else:
    print(f"Images already extracted in {images_dir}")

# Find the actual images directory
actual_images_dir = "./flickr30k_images/flickr30k-images"
if not os.path.exists(actual_images_dir):
    actual_images_dir = images_dir

print(f"Images directory: {actual_images_dir}")

# Check image files
image_files = [f for f in os.listdir(actual_images_dir) if f.endswith('.jpg')]
print(f"Found {len(image_files)} .jpg files")
if image_files:
    print(f"Sample files: {image_files[:3]}")

# Create dataset dictionary
from datasets import Dataset

dataset_dict = {
    'image_path': [],
    'caption': []
}

print("\nMatching images with captions...")
matched = 0
for item in dataset_list:
    image_filename = item['filename']
    image_path = os.path.join(actual_images_dir, image_filename)

    if os.path.exists(image_path):
        dataset_dict['image_path'].append(image_path)
        dataset_dict['caption'].append(item['captions'])
        matched += 1

    if matched % 5000 == 0 and matched > 0:
        print(f"Matched {matched} images...")

print(f"\n✅ Matched {len(dataset_dict['image_path'])} images with captions")

# Create HuggingFace dataset
dataset = Dataset.from_dict(dataset_dict)

# Load images
def load_image(example):
    example['image'] = Image.open(example['image_path']).convert('RGB')
    return example

print("Loading images into dataset...")
dataset = dataset.map(load_image, num_proc=1)

print(f"\n✅ Dataset ready: {len(dataset)} examples")
print(f"Features: {dataset.features}")

# Show first example
print("\n📸 First example:")
example = dataset[0]
print(f"  Filename: {example['image_path'].split('/')[-1]}")
print(f"  Captions: {len(example['caption'])}")
print(f"  Caption 1: {example['caption'][0]}")
print(f"  Image size: {example['image'].size}")

# Display first image
from IPython.display import display
print("\nDisplaying first image:")
display(example['image'].resize((400, 400)))

Loading Flickr30k dataset from raw files...
Downloading annotations CSV...
Loading annotations...
Loaded 31014 rows
Columns: ['raw', 'sentids', 'split', 'filename', 'img_id']
Processed 5000/31014 rows, found 5001 valid images
Processed 10000/31014 rows, found 10001 valid images
Processed 15000/31014 rows, found 15001 valid images
Processed 20000/31014 rows, found 20001 valid images
Processed 25000/31014 rows, found 25001 valid images
Processed 30000/31014 rows, found 30001 valid images

✅ Processed: 31014 images with captions
Sample: 1000092795.jpg has 5 captions
First caption: Two young guys with shaggy hair look at their hands while hanging out in the yar...

Downloading images zip file (4.4GB - this may take several minutes)...
Images already extracted in ./flickr30k_images
Images directory: ./flickr30k_images/flickr30k-images
Found 31783 .jpg files
Sample files: ['4713532955.jpg', '4559349547.jpg', '3827180184.jpg']

Matching images with captions...
Matched 5000 images...
Matched 1

Map:   0%|          | 0/31014 [00:00<?, ? examples/s]

In [None]:
# ============================================================================
# CELL 6: Create Custom Dataset Class for Flickr30k
# ============================================================================
class Flickr30kKosmos2Dataset(Dataset):
    def __init__(self, hf_dataset, processor, max_length=128):
        self.dataset = hf_dataset
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        # Get image
        image = item['image']
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Get caption - Flickr30k has a list of 5 captions per image
        # Randomly select one caption for training diversity
        captions = item['caption']
        if isinstance(captions, list):
            import random
            caption = random.choice(captions).strip()
        else:
            caption = str(captions).strip()

        # Format prompt for KOSMOS-2
        prompt = "<grounding> An image of"
        full_text = f"{prompt} {caption}"

        # Process inputs
        encoding = self.processor(
            text=full_text,
            images=image,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True
        )

        # Remove batch dimension and prepare labels
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}

        # Create labels (same as input_ids for causal LM)
        encoding['labels'] = encoding['input_ids'].clone()

        return encoding


In [None]:
# ============================================================================
# CELL 7: Prepare Train/Validation Split - FAST TRAINING MODE
# ============================================================================
print("Preparing dataset splits for FAST TRAINING...")

# REDUCED for 2-hour training on T4
train_size = 1000   # Reduced from 5000 to 1000
val_size = 100      # Reduced from 500 to 100

total_size = len(dataset)
print(f"Total dataset size: {total_size}")
print(f"⚡ FAST MODE: Using only {train_size} training images")

# Create train/val split
dataset_split = dataset.train_test_split(test_size=val_size, seed=42)

# Select subset for training
if len(dataset_split['train']) > train_size:
    train_data = dataset_split['train'].select(range(train_size))
else:
    train_data = dataset_split['train']

val_data = dataset_split['test']

# Create wrapped datasets
train_dataset = Flickr30kKosmos2Dataset(
    train_data,
    processor,
    max_length=128
)

val_dataset = Flickr30kKosmos2Dataset(
    val_data,
    processor,
    max_length=128
)

print(f"✅ Train dataset size: {len(train_dataset)}")
print(f"✅ Validation dataset size: {len(val_dataset)}")
print(f"⏱️ Expected training time: ~1.5-2 hours")

In [None]:
# ============================================================================
# CELL 8: Test Dataset Loading
# ============================================================================
print("\nTesting dataset loading...")
sample = train_dataset[0]
print(f"Sample keys: {sample.keys()}")
print(f"Input IDs shape: {sample['input_ids'].shape}")
print(f"Pixel values shape: {sample['pixel_values'].shape}")
print(f"Labels shape: {sample['labels'].shape}")


In [None]:
# ============================================================================
# CELL 9: Setup Training Arguments - FAST TRAINING MODE (2 HOURS)
# ============================================================================
output_dir = "./kosmos2-flickr30k-finetuned"

training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=2,                    # Changed: 3 → 2
    per_device_train_batch_size=2,         # Changed: 1 → 2
    per_device_eval_batch_size=2,          # Changed: 1 → 2
    gradient_accumulation_steps=4,         # Changed: 8 → 4
    learning_rate=8e-5,                    # Changed: 5e-5 → 8e-5
    weight_decay=0.01,
    warmup_steps=50,                       # Changed: 100 → 50
    logging_steps=25,                      # Changed: 50 → 25
    eval_strategy="steps",
    eval_steps=125,                        # Changed: 200 → 125
    save_steps=125,                        # Changed: 200 → 125
    save_total_limit=2,
    fp16=True,
    dataloader_num_workers=2,
    remove_unused_columns=False,
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    gradient_checkpointing=True,
    optim="adamw_torch",
)

print("⚡ FAST TRAINING MODE - 2 HOUR TARGET")
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"Total steps: ~{(1000 * 2) // 8}")
print(f"Validations: 2 times (100 samples each)")
print(f"Expected time: ~1.5 hours")

In [None]:
# ============================================================================
# CELL 10: Enable Gradient Checkpointing
# ============================================================================
if training_args.gradient_checkpointing:
    model.gradient_checkpointing_enable()
    print("Gradient checkpointing enabled")


In [None]:
# ============================================================================
# CELL 11: Define Data Collator
# ============================================================================
def kosmos2_data_collator(features):
    """Custom collator for KOSMOS-2 that handles all required inputs"""
    batch = {}

    # Collect all keys from features
    keys = features[0].keys()

    for key in keys:
        batch[key] = torch.stack([f[key] for f in features])

    return batch


In [None]:
# ============================================================================
# CELL 12: Initialize Trainer
# ============================================================================
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=kosmos2_data_collator,
)

print("Trainer initialized successfully")


In [None]:
# ============================================================================
# CELL 13: Start Training
# ============================================================================
print("\n" + "="*50)
print("STARTING TRAINING")
print("="*50 + "\n")

# Clear cache before training
torch.cuda.empty_cache()
gc.collect()

# Train the model
train_result = trainer.train()

print("\n" + "="*50)
print("TRAINING COMPLETED")
print("="*50)
print(f"Training loss: {train_result.training_loss:.4f}")
print(f"Training time: {train_result.metrics['train_runtime']:.2f} seconds")


In [None]:
# ============================================================================
# CELL 14: Save the Fine-tuned Model
# ============================================================================
print("\nSaving fine-tuned model...")
trainer.save_model(output_dir)
processor.save_pretrained(output_dir)
print(f"Model saved to {output_dir}")


In [None]:
# ============================================================================
# CELL 15: Evaluate the Model
# ============================================================================
print("\nEvaluating model on validation set...")
eval_results = trainer.evaluate()
print("\nEvaluation Results:")
for key, value in eval_results.items():
    print(f"{key}: {value:.4f}")


In [None]:
# ============================================================================
# CELL 16: Test Inference on Sample Images
# ============================================================================
print("\n" + "="*50)
print("TESTING INFERENCE")
print("="*50 + "\n")

# Load a test image from dataset
test_idx = 0
test_sample = dataset[test_idx]
test_image = test_sample['image']
true_captions = test_sample['caption']

print(f"True captions: {true_captions[:2]}")  # Show first 2 captions

# Display image
from IPython.display import display
display(test_image.resize((400, 400)))

# Generate caption
prompt = "<grounding> An image of"
inputs = processor(text=prompt, images=test_image, return_tensors="pt")

# Move to device
device = model.device
inputs = {k: v.to(device) for k, v in inputs.items()}

# Generate
with torch.no_grad():
    generated_ids = model.generate(
        pixel_values=inputs["pixel_values"],
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        image_embeds_position_mask=inputs.get("image_embeds_position_mask"),
        max_new_tokens=64,
        num_beams=3,
    )

# Decode
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(f"\nGenerated text: {generated_text}")

try:
    caption, entities = processor.post_process_generation(generated_text)
    print(f"\nExtracted caption: {caption}")
    print(f"Entities: {entities}")
except Exception as e:
    print(f"Post-processing note: {e}")


In [None]:

# ============================================================================
# CELL 17: Test on Your Own Image (Optional)
# ============================================================================
print("\n" + "="*50)
print("TEST ON YOUR OWN IMAGE")
print("="*50 + "\n")

# Uncomment to upload your own image
"""
from google.colab import files
uploaded = files.upload()
fname = next(iter(uploaded))
user_image = Image.open(fname).convert('RGB')
display(user_image.resize((400, 400)))

# Generate caption
prompt = "<grounding> An image of"
inputs = processor(text=prompt, images=user_image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    generated_ids = model.generate(
        pixel_values=inputs["pixel_values"],
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        image_embeds_position_mask=inputs.get("image_embeds_position_mask"),
        max_new_tokens=64,
        num_beams=3,
    )

generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(f"Generated: {generated_text}")
"""

# ============================================================================
# CELL 18: Save to Google Drive (Optional)
# ============================================================================
print("\n" + "="*50)
print("SAVE TO GOOGLE DRIVE (OPTIONAL)")
print("="*50 + "\n")

# Uncomment to save to Google Drive
"""
from google.colab import drive
drive.mount('/content/drive')

import shutil
drive_output = '/content/drive/MyDrive/kosmos2-flickr30k-finetuned'
shutil.copytree(output_dir, drive_output, dirs_exist_ok=True)
print(f"Model saved to Google Drive: {drive_output}")
"""

print("\n✅ Fine-tuning pipeline completed successfully!")