## 1. Check GPU and Install Dependencies

In [None]:
# Report the installed PyTorch build and, when CUDA is usable, describe GPU 0
import torch

print(f"PyTorch version: {torch.__version__}")

cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {cuda_ready}")

if cuda_ready:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is divided by 1e9 so the figure is shown in (decimal) gigabytes
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Install required packages
!pip install -q nltk gensim pillow tqdm

In [None]:
# Download NLTK data
import nltk

# 'punkt' and 'averaged_perceptron_tagger' are the resource names used by older
# NLTK releases; NLTK 3.9+ looks up 'punkt_tab' and
# 'averaged_perceptron_tagger_eng' instead. The NLTK version is not pinned in
# this notebook, so both spellings are fetched.
NLTK_RESOURCES = (
    'punkt',
    'punkt_tab',
    'wordnet',
    'averaged_perceptron_tagger',
    'averaged_perceptron_tagger_eng',
)

# nltk.download() reports failure through its return value rather than by
# raising, so collect the resources that could not be fetched and say so.
failed_resources = [
    resource for resource in NLTK_RESOURCES
    if not nltk.download(resource, quiet=True)
]
if failed_resources:
    print(f"⚠ Could not download NLTK resources: {', '.join(failed_resources)}")

## 2. Mount Google Drive and Setup Data

In [None]:
# Mount Google Drive at /content/drive so the project folder kept in Drive can
# be read and written through ordinary file paths in the cells below.
# The google.colab package is provided by the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Set up paths - MODIFY THIS to match your Drive folder
import os
from pathlib import Path

# Root of the project folder uploaded to Google Drive
DRIVE_DATA_PATH = '/content/drive/MyDrive/artemis-captioning'

# Folder layout shown to the user when the Drive folder cannot be found
REQUIRED_LAYOUT = (
    "  artemis-captioning/",
    "    ├── data/",
    "    │   ├── processed/",
    "    │   │   ├── images/      # Pre-resized 128x128 images (~57 MB)",
    "    │   │   ├── splits/",
    "    │   │   └── vocabulary.json",
    "    │   └── raw/wikiart/     # Original images (if not using preprocessed)",
    "    ├── models/",
    "    ├── utils/",
    "    ├── scripts/",
    "    └── train.py",
)

if not os.path.exists(DRIVE_DATA_PATH):
    # Nothing uploaded yet: explain what is expected instead of listing
    print(f"✗ Data folder not found: {DRIVE_DATA_PATH}")
    print("Please upload your data to Google Drive first.")
    print("\nRequired structure:")
    for layout_line in REQUIRED_LAYOUT:
        print(layout_line)
else:
    print(f"✓ Data folder found: {DRIVE_DATA_PATH}")
    print("Contents:")
    for entry in os.listdir(DRIVE_DATA_PATH):
        entry_path = os.path.join(DRIVE_DATA_PATH, entry)
        if not os.path.isdir(entry_path):
            print(f"  - {entry}")
            continue
        # For sub-folders, also show how many direct children they hold
        child_count = len(os.listdir(entry_path))
        print(f"  - {entry}/ ({child_count} items)")

In [None]:
# Copy data to local storage for faster access
import os
import subprocess
from pathlib import Path

LOCAL_PROJECT_DIR = '/content/artemis'
os.makedirs(LOCAL_PROJECT_DIR, exist_ok=True)

# Copy each project item with `cp -r` and keep the exit status, so the summary
# at the end of the cell reflects what actually happened instead of always
# claiming success.
failed_items = []
for item_name in ('data', 'utils', 'models', 'scripts', 'train.py'):
    copy_result = subprocess.run(
        ['cp', '-r', f"{DRIVE_DATA_PATH}/{item_name}", f"{LOCAL_PROJECT_DIR}/"],
        capture_output=True,
        text=True,
    )
    if copy_result.returncode != 0:
        failed_items.append(f"{item_name}: {copy_result.stderr.strip()}")

# Check if preprocessed images exist
processed_images_path = "/content/artemis/data/processed/images"
if os.path.exists(processed_images_path) and os.listdir(processed_images_path):
    num_images = sum(1 for _ in Path(processed_images_path).rglob("*.jpg"))
    print(f"✓ Found {num_images} preprocessed images (128x128)")
else:
    print("⚠ Preprocessed images not found. Will use raw images with on-the-fly resizing.")
    print("  For faster training, run preprocess_images.py locally first.")

if failed_items:
    print(f"✗ Some items could not be copied from {DRIVE_DATA_PATH}:")
    for failure in failed_items:
        print(f"  - {failure}")
    print("  Later cells that import or read these items may fail.")
else:
    print("✓ Data copied to /content/artemis/")

In [None]:
# Change to project directory
import os
import sys

PROJECT_DIR = '/content/artemis'
os.chdir(PROJECT_DIR)

# Make the project's packages importable. Guarded so that re-running this cell
# does not keep prepending duplicate entries to sys.path.
if PROJECT_DIR not in sys.path:
    sys.path.insert(0, PROJECT_DIR)

print(f"Working directory: {os.getcwd()}")

## 3. Import Modules and Setup

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import json
import time
from datetime import datetime
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Import project modules
from utils.text_preprocessing import TextPreprocessor
from utils.data_loader import create_data_loaders
from models.cnn_lstm import create_model as create_cnn_lstm
from models.vision_transformer import create_vit_model
from train import Trainer

print("✓ All modules imported successfully")

In [None]:
# Configuration
import random

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 32  # Larger batch size for GPU
NUM_WORKERS = 2
MAX_BATCHES = 469  # Cap on batches per epoch: 469 * 32 = 15008 (~15000 images)
EPOCHS = 50
SEED = 42

# Seed the Python and PyTorch random number generators so that weight
# initialisation and random sampling start from the same state on every run.
# (This makes runs more repeatable; it does not by itself guarantee
# bit-identical results on GPU.)
random.seed(SEED)
torch.manual_seed(SEED)

print(f"Device: {DEVICE}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Max batches: {MAX_BATCHES} (~{MAX_BATCHES * BATCH_SIZE} images)")
print(f"Epochs: {EPOCHS}")
print(f"Seed: {SEED}")

In [None]:
# Create the text preprocessor and load the vocabulary file from the
# processed-data folder (path is relative to the project directory)
VOCAB_PATH = 'data/processed/vocabulary.json'

text_proc = TextPreprocessor()
text_proc.load_vocabulary(VOCAB_PATH)

vocab_size = text_proc.vocab_size
print(f"Vocabulary size: {vocab_size}")

In [None]:
# Build the train/val data loaders from the shared configuration values
loader_options = dict(
    text_preprocessor=text_proc,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    splits=['train', 'val'],
)
loaders = create_data_loaders(**loader_options)

# Report how many samples each split's dataset holds
train_sample_count = len(loaders['train'].dataset)
print(f"Train samples: {train_sample_count}")
val_sample_count = len(loaders['val'].dataset)
print(f"Val samples: {val_sample_count}")

In [None]:
# Limited loader for controlled training
from itertools import islice


class LimitedLoader:
    """Wrap a batch loader so that one pass yields at most ``max_batches`` batches.

    The wrapper exposes the ``batch_size`` and ``dataset`` attributes of the
    wrapped loader, and supports iteration and ``len()``.

    Args:
        loader: Iterable of batches that also provides ``__len__``,
            ``batch_size`` and ``dataset``.
        max_batches: Maximum number of batches to yield per pass. Must be a
            non-negative integer.

    Raises:
        ValueError: If ``max_batches`` is negative.
    """

    def __init__(self, loader, max_batches):
        if max_batches < 0:
            # A negative cap would make __len__ return a negative number,
            # which only fails later, inside len(), with an obscure message.
            raise ValueError(f"max_batches must be non-negative, got {max_batches}")
        self.loader = loader
        self.max_batches = max_batches
        self.batch_size = loader.batch_size
        self.dataset = loader.dataset

    def __iter__(self):
        """Yield batches from the wrapped loader, stopping at the cap."""
        # islice stops as soon as the cap is reached, without first pulling
        # (and then discarding) one more batch from the wrapped loader.
        yield from islice(self.loader, self.max_batches)

    def __len__(self):
        """Return the number of batches one pass will yield."""
        return min(len(self.loader), self.max_batches)

# Wrap both splits; validation is capped at a quarter of the training budget
VAL_MAX_BATCHES = MAX_BATCHES // 4  # Smaller validation
train_loader = LimitedLoader(loaders['train'], MAX_BATCHES)
val_loader = LimitedLoader(loaders['val'], VAL_MAX_BATCHES)

train_batch_count = len(train_loader)
print(f"Train batches: {train_batch_count} (~{train_batch_count * BATCH_SIZE} images)")
print(f"Val batches: {len(val_loader)}")

## 4. Train CNN+LSTM Model

In [None]:
# Build the CNN+LSTM model and move it to the selected device
cnn_lstm_model = create_cnn_lstm(vocab_size=text_proc.vocab_size, embed_dim=256)
cnn_lstm_model = cnn_lstm_model.to(DEVICE)

# Add up the element counts of every parameter tensor
total_params = 0
for param in cnn_lstm_model.parameters():
    total_params += param.numel()
print(f"CNN+LSTM Parameters: {total_params:,}")

In [None]:
# Make sure the checkpoint and output folders exist (same effect as `mkdir -p`)
CNN_CHECKPOINT_DIR = 'checkpoints/colab_cnn_lstm'
CNN_OUTPUT_DIR = 'outputs/colab_cnn_lstm'
for folder in (CNN_CHECKPOINT_DIR, CNN_OUTPUT_DIR):
    os.makedirs(folder, exist_ok=True)

# Banner so the training log is easy to find in the cell output
banner = "="*70
print(banner)
print("TRAINING CNN+LSTM")
print(banner)

cnn_trainer = Trainer(
    model=cnn_lstm_model,
    train_loader=train_loader,
    val_loader=val_loader,
    text_preprocessor=text_proc,
    device=DEVICE,
    checkpoint_dir=CNN_CHECKPOINT_DIR,
    output_dir=CNN_OUTPUT_DIR,
)

# Wall-clock time of the whole training run, in seconds
train_started = time.time()
cnn_history = cnn_trainer.train(num_epochs=EPOCHS)
cnn_time = time.time() - train_started

best_cnn_bleu = max(cnn_history['val_bleu'])
print(f"\nCNN+LSTM Training complete in {cnn_time/60:.1f} minutes")
print(f"Best BLEU: {best_cnn_bleu:.4f}")

In [None]:
# Plot CNN+LSTM training history: losses on the left, validation BLEU on the right
fig, (loss_ax, bleu_ax) = plt.subplots(1, 2, figsize=(12, 4))

for history_key, curve_label in (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')):
    loss_ax.plot(cnn_history[history_key], label=curve_label)
loss_ax.set(xlabel='Epoch', ylabel='Loss', title='CNN+LSTM Loss')
loss_ax.legend()

bleu_ax.plot(cnn_history['val_bleu'], label='Val BLEU', color='green')
bleu_ax.set(xlabel='Epoch', ylabel='BLEU Score', title='CNN+LSTM BLEU Score')
bleu_ax.legend()

# Save the figure next to the other CNN+LSTM outputs before showing it
fig.tight_layout()
fig.savefig('outputs/colab_cnn_lstm/training_curves.png', dpi=150)
plt.show()

## 5. Train Vision Transformer Model

In [None]:
# Build the Vision Transformer model and move it to the selected device
vit_options = {
    'vocab_size': text_proc.vocab_size,
    'embed_dim': 256,
    'encoder_layers': 6,  # More layers for GPU
    'decoder_layers': 6,
}
vit_model = create_vit_model(**vit_options).to(DEVICE)

# Add up the element counts of every parameter tensor
total_params = 0
for param in vit_model.parameters():
    total_params += param.numel()
print(f"Vision Transformer Parameters: {total_params:,}")

In [None]:
# Make sure the checkpoint and output folders exist (same effect as `mkdir -p`)
VIT_CHECKPOINT_DIR = 'checkpoints/colab_vit'
VIT_OUTPUT_DIR = 'outputs/colab_vit'
for folder in (VIT_CHECKPOINT_DIR, VIT_OUTPUT_DIR):
    os.makedirs(folder, exist_ok=True)

# Banner so the training log is easy to find in the cell output
banner = "="*70
print(banner)
print("TRAINING VISION TRANSFORMER")
print(banner)

vit_trainer = Trainer(
    model=vit_model,
    train_loader=train_loader,
    val_loader=val_loader,
    text_preprocessor=text_proc,
    device=DEVICE,
    checkpoint_dir=VIT_CHECKPOINT_DIR,
    output_dir=VIT_OUTPUT_DIR,
)

# Wall-clock time of the whole training run, in seconds
train_started = time.time()
vit_history = vit_trainer.train(num_epochs=EPOCHS)
vit_time = time.time() - train_started

best_vit_bleu = max(vit_history['val_bleu'])
print(f"\nViT Training complete in {vit_time/60:.1f} minutes")
print(f"Best BLEU: {best_vit_bleu:.4f}")

In [None]:
# Plot ViT training history: losses on the left, validation BLEU on the right
fig, (loss_ax, bleu_ax) = plt.subplots(1, 2, figsize=(12, 4))

for history_key, curve_label in (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')):
    loss_ax.plot(vit_history[history_key], label=curve_label)
loss_ax.set(xlabel='Epoch', ylabel='Loss', title='Vision Transformer Loss')
loss_ax.legend()

bleu_ax.plot(vit_history['val_bleu'], label='Val BLEU', color='green')
bleu_ax.set(xlabel='Epoch', ylabel='BLEU Score', title='Vision Transformer BLEU Score')
bleu_ax.legend()

# Save the figure next to the other ViT outputs before showing it
fig.tight_layout()
fig.savefig('outputs/colab_vit/training_curves.png', dpi=150)
plt.show()

## 6. Compare Models

In [None]:
# Comparison summary
import os

print("="*70)
print("MODEL COMPARISON")
print("="*70)


def _summarize_run(history, elapsed_seconds):
    """Reduce one training history to the headline numbers reported below.

    Args:
        history: Mapping with 'train_loss', 'val_loss' and 'val_bleu'
            sequences holding one value per epoch.
        elapsed_seconds: Wall-clock training time in seconds.

    Returns:
        Dict with the last train/val loss, the highest validation BLEU and the
        training time in minutes.
    """
    # float() keeps the summary JSON-serialisable even if the history holds
    # non-builtin scalar types.
    # NOTE(review): the element type returned by Trainer.train is not visible
    # here — the coercion is defensive.
    return {
        'final_train_loss': float(history['train_loss'][-1]),
        'final_val_loss': float(history['val_loss'][-1]),
        'best_bleu': float(max(history['val_bleu'])),
        'training_time_min': elapsed_seconds / 60,
    }


comparison = {
    'CNN+LSTM': _summarize_run(cnn_history, cnn_time),
    'ViT': _summarize_run(vit_history, vit_time),
}

for model_name, metrics in comparison.items():
    print(f"\n{model_name}:")
    print(f"  Final Train Loss: {metrics['final_train_loss']:.4f}")
    print(f"  Final Val Loss: {metrics['final_val_loss']:.4f}")
    print(f"  Best BLEU: {metrics['best_bleu']:.4f}")
    print(f"  Training Time: {metrics['training_time_min']:.1f} minutes")

# Save comparison (create the folder first so this cell does not depend on an
# earlier cell having made it)
os.makedirs('outputs', exist_ok=True)
with open('outputs/model_comparison.json', 'w') as f:
    json.dump(comparison, f, indent=2)
print("\n✓ Comparison saved to outputs/model_comparison.json")

In [None]:
# Plot comparison: validation loss, validation BLEU, and best-BLEU bars
fig, (loss_ax, bleu_ax, bar_ax) = plt.subplots(1, 3, figsize=(15, 4))

# Both models are drawn in the same order on every panel
runs = (('CNN+LSTM', cnn_history), ('ViT', vit_history))

# Loss comparison
for run_label, run_history in runs:
    loss_ax.plot(run_history['val_loss'], label=run_label)
loss_ax.set(xlabel='Epoch', ylabel='Validation Loss', title='Validation Loss Comparison')
loss_ax.legend()

# BLEU comparison
for run_label, run_history in runs:
    bleu_ax.plot(run_history['val_bleu'], label=run_label)
bleu_ax.set(xlabel='Epoch', ylabel='BLEU Score', title='BLEU Score Comparison')
bleu_ax.legend()

# Bar chart
model_names = [run_label for run_label, _ in runs]
best_scores = [max(run_history['val_bleu']) for _, run_history in runs]
bar_ax.bar(model_names, best_scores, color=['steelblue', 'coral'])
bar_ax.set_ylabel('Best BLEU Score')
bar_ax.set_title('Best BLEU Score Comparison')
# Write each score just above its bar
for bar_index, score in enumerate(best_scores):
    bar_ax.text(bar_index, score + 0.001, f'{score:.4f}', ha='center')

fig.tight_layout()
fig.savefig('outputs/model_comparison.png', dpi=150)
plt.show()

## 7. Save Models to Google Drive

In [None]:
# Copy checkpoints and outputs to Google Drive
import os
import subprocess

# (local path, Drive sub-folder it is copied into)
ARTIFACTS = (
    ('checkpoints/colab_cnn_lstm', 'checkpoints'),
    ('checkpoints/colab_vit', 'checkpoints'),
    ('outputs/colab_cnn_lstm', 'outputs'),
    ('outputs/colab_vit', 'outputs'),
    ('outputs/model_comparison.json', 'outputs'),
    ('outputs/model_comparison.png', 'outputs'),
)

# Copy with `cp -r` and keep every exit status: the Colab disk is temporary,
# so a copy that silently failed would mean the trained models are lost.
failed_copies = []
for local_path, drive_subdir in ARTIFACTS:
    target_dir = f"{DRIVE_DATA_PATH}/{drive_subdir}"
    os.makedirs(target_dir, exist_ok=True)
    copy_result = subprocess.run(
        ['cp', '-r', local_path, f"{target_dir}/"],
        capture_output=True,
        text=True,
    )
    if copy_result.returncode != 0:
        failed_copies.append(f"{local_path}: {copy_result.stderr.strip()}")

if failed_copies:
    print("✗ Some files could not be saved to Google Drive:")
    for failure in failed_copies:
        print(f"  - {failure}")
else:
    print("✓ Models and outputs saved to Google Drive!")
print(f"Location: {DRIVE_DATA_PATH}")

## 8. Test Caption Generation

In [None]:
# Generate sample captions with the CNN+LSTM model currently in memory.
# No checkpoint is loaded here: the model carries whatever weights
# `cnn_trainer.train` left in it.
import json
import os
import random
from PIL import Image
from utils.image_preprocessing import ImagePreprocessor

NUM_SAMPLES = 5

# Load the validation split records
with open('data/processed/splits/val.json', 'r') as f:
    val_data = json.load(f)

# Pick random samples; random.sample raises ValueError when asked for more
# items than exist, so cap the request at the number of records available.
samples = random.sample(val_data, min(NUM_SAMPLES, len(val_data)))

# Image preprocessor
img_proc = ImagePreprocessor()

# Generate captions
cnn_lstm_model.eval()

print("Sample Caption Generation (CNN+LSTM):\n")
generated_count = 0
for sample in samples:
    relative_path = f"{sample['style']}/{sample['painting']}.jpg"
    candidate_paths = [
        f"data/raw/wikiart/{relative_path}",
        # NOTE(review): assumes the preprocessed folder mirrors the raw
        # <style>/<painting>.jpg layout — confirm against preprocess_images.py
        f"data/processed/images/{relative_path}",
    ]
    img_path = next((path for path in candidate_paths if os.path.exists(path)), None)
    if img_path is None:
        # Say so instead of skipping silently, otherwise the cell can finish
        # with no captions and no explanation.
        print(f"⚠ Image not found, skipping: {relative_path}")
        print("-" * 50)
        continue

    # Load and preprocess image; the context manager closes the file handle
    with Image.open(img_path) as image_file:
        image = image_file.convert('RGB')
    img_tensor = img_proc.val_transform(image).unsqueeze(0).to(DEVICE)

    # Generate caption
    with torch.no_grad():
        caption, _ = cnn_lstm_model.generate_caption(img_tensor)

    # Decode caption
    generated = text_proc.decode(caption.cpu().numpy().tolist(), skip_special_tokens=True)

    print(f"Image: {sample['painting']}")
    print(f"Ground Truth: {sample['utterance']}")
    print(f"Generated: {generated}")
    print("-" * 50)
    generated_count += 1

if generated_count == 0:
    print("No sample images were found, so no captions were generated.")

---
## Training Complete!

Your trained models are saved to Google Drive. To use them:
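1. Download the checkpoints from the `checkpoints/` folder inside your Drive project folder (by default `MyDrive/artemis-captioning/checkpoints/`, i.e. the `DRIVE_DATA_PATH` set in section 2)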
2. Use `scripts/predict.py` on your local machine