In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm
import time
import os

# --- 1. Configuration ---
# Use a dictionary for easy management of hyperparameters
config = {
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    "BATCH_SIZE": 64,        # Use a larger batch size for caching and training
    "NUM_EPOCHS": 10,
    "LEARNING_RATE": 0.001,
    "DATA_DIR": "/home/kami/Documents/datasets/",    # Directory to store Food101 dataset
    "NUM_CLASSES": 101,      # Food101 has 101 classes
    "VGG16_FEATURE_SIZE": 25088 # VGG16 output size after features and avgpool
}

print(f"Using device: {config['DEVICE']}")
print("-" * 30)


# --- 2. The Custom Caching Model ---
class CachedFineTuneModel(nn.Module):
    """
    A model that caches activations from frozen layers to accelerate fine-tuning.
    """
    def __init__(self, original_model, num_classes, num_records):
        super().__init__()
        # Separate the frozen part of the model
        self.features = original_model.features
        self.avgpool = original_model.avgpool

        # Create a new trainable classifier
        self.classifier = nn.Sequential(
            nn.Linear(config["VGG16_FEATURE_SIZE"], 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

        # Initialize the cache buffer on the CPU first
        initial_cache = torch.zeros(num_records, config["VGG16_FEATURE_SIZE"])
        # Register as a buffer: PyTorch will manage its device and state
        self.register_buffer('frozen_data', initial_cache)
        self.register_buffer('is_cached', torch.tensor(False))

    @torch.no_grad()
    def cache_activations(self, dataloader: DataLoader):
        """
        Performs the one-time forward pass to populate the activation cache.
        MUST be run in eval mode to ensure deterministic outputs from BatchNorm.
        """
        print("--- Phase 1: Caching Activations ---")
        self.eval() # CRITICAL: Use eval mode for deterministic caching

        device = next(self.parameters()).device
        start_time = time.time()

        for i, (data, _) in enumerate(tqdm(dataloader, desc="Caching Progress")):
            data = data.to(device)
            batch_size = data.shape[0]
            start_index = i * dataloader.batch_size
            end_index = start_index + batch_size

            activations = self.avgpool(self.features(data))
            activations = torch.flatten(activations, 1)
            self.frozen_data[start_index:end_index] = activations.cpu() # Store on CPU

        self.is_cached.fill_(True)
        caching_time = time.time() - start_time
        print(f"Caching complete in {caching_time:.2f} seconds.")
        # Move the entire cache to the target device after population
        self.frozen_data = self.frozen_data.to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        - If cached, `x` is a tensor of indices.
        - If not cached (e.g., during validation), `x` is an image tensor.
        """
        if self.training and self.is_cached:
            # Training phase: x is a batch of indices
            # Retrieve cached data directly from the buffer
            cached_batch = self.frozen_data[x]
            return self.classifier(cached_batch)
        else:
            # Validation phase or pre-caching: x is a batch of images
            # Perform the full forward pass
            x = self.avgpool(self.features(x))
            x = torch.flatten(x, 1)
            return self.classifier(x)

# --- 3. Data Loading ---
# Define transforms. No random augmentations for caching/validation.
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

print("Loading Food101 dataset. This may take a while on the first run...")
# Training set for caching (split='train')
train_dataset_caching = datasets.Food101(
    root=config["DATA_DIR"], split='train', download=True, transform=data_transforms
)
caching_loader = DataLoader(
    train_dataset_caching, batch_size=config["BATCH_SIZE"], shuffle=False
)

# Validation set
val_dataset = datasets.Food101(
    root=config["DATA_DIR"], split='test', download=True, transform=data_transforms
)
val_loader = DataLoader(val_dataset, batch_size=config["BATCH_SIZE"], shuffle=False)
print("Dataset loaded successfully.")

# --- 4. Model Initialization ---
# Load pretrained VGG16 with Batch Normalization
vgg16_bn = models.vgg16_bn(weights=models.VGG16_BN_Weights.DEFAULT)

# Freeze the feature extraction layers
for param in vgg16_bn.features.parameters():
    param.requires_grad = False

# Instantiate our custom model
num_train_records = len(train_dataset_caching)
model = CachedFineTuneModel(
    original_model=vgg16_bn,
    num_classes=config["NUM_CLASSES"],
    num_records=num_train_records
).to(config["DEVICE"])

# --- 5. Caching and Training ---

# Run the caching phase first
model.cache_activations(caching_loader)

# --- Prepare for the accelerated training phase ---
print("\n--- Phase 2: Accelerated Fine-Tuning ---")
# Create a new dataset that provides indices and original labels
train_labels = [label for _, label in train_dataset_caching]
training_dataset_indexed = TensorDataset(
    torch.arange(num_train_records),
    torch.tensor(train_labels)
)
training_loader = DataLoader(
    training_dataset_indexed, batch_size=config["BATCH_SIZE"], shuffle=True
)

# Optimizer should only see the parameters of the trainable classifier
optimizer = optim.Adam(model.classifier.parameters(), lr=config["LEARNING_RATE"])
criterion = nn.CrossEntropyLoss()

total_training_time = 0

for epoch in range(config["NUM_EPOCHS"]):
    epoch_start_time = time.time()
    # --- Training Loop ---
    model.train() # Set to training mode
    running_loss = 0.0

    pbar = tqdm(training_loader, desc=f"Epoch {epoch+1}/{config['NUM_EPOCHS']} [Training]")
    for batch_indices, batch_labels in pbar:
        batch_indices = batch_indices.to(config["DEVICE"])
        batch_labels = batch_labels.to(config["DEVICE"])

        optimizer.zero_grad()
        outputs = model(batch_indices) # Pass indices to the model
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    train_loss = running_loss / len(training_loader)

    # --- Validation Loop ---
    model.eval() # Set to evaluation mode
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        pbar_val = tqdm(val_loader, desc=f"Epoch {epoch+1}/{config['NUM_EPOCHS']} [Validation]")
        for images, labels in pbar_val:
            images = images.to(config["DEVICE"])
            labels = labels.to(config["DEVICE"])

            outputs = model(images) # Pass images to the model
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss /= len(val_loader)
    val_accuracy = 100 * correct / total

    epoch_time = time.time() - epoch_start_time
    total_training_time += epoch_time

    print(
        f"Epoch {epoch+1}/{config['NUM_EPOCHS']} | "
        f"Train Loss: {train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | "
        f"Val Accuracy: {val_accuracy:.2f}% | "
        f"Time: {epoch_time:.2f}s"
    )

print("-" * 30)
print(f"Finished Accelerated Training in {total_training_time:.2f} seconds.")

Using device: cuda
------------------------------
Loading Food101 dataset. This may take a while on the first run...
Dataset loaded successfully.
--- Phase 1: Caching Activations ---


Caching Progress: 100%|██████████| 1184/1184 [04:40<00:00,  4.22it/s]


Caching complete in 280.33 seconds.

--- Phase 2: Accelerated Fine-Tuning ---


Epoch 1/10 [Training]: 100%|██████████| 1184/1184 [00:14<00:00, 79.59it/s, loss=2.7841]
Epoch 1/10 [Validation]: 100%|██████████| 395/395 [01:50<00:00,  3.56it/s]


Epoch 1/10 | Train Loss: 2.9777 | Val Loss: 1.9860 | Val Accuracy: 50.61% | Time: 125.76s


Epoch 2/10 [Training]: 100%|██████████| 1184/1184 [00:15<00:00, 76.75it/s, loss=2.8392]
Epoch 2/10 [Validation]: 100%|██████████| 395/395 [01:30<00:00,  4.36it/s]


Epoch 2/10 | Train Loss: 2.4156 | Val Loss: 1.9788 | Val Accuracy: 53.44% | Time: 105.93s


Epoch 3/10 [Training]: 100%|██████████| 1184/1184 [00:15<00:00, 77.69it/s, loss=1.8287]
Epoch 3/10 [Validation]: 100%|██████████| 395/395 [01:30<00:00,  4.35it/s]


Epoch 3/10 | Train Loss: 2.1997 | Val Loss: 1.8834 | Val Accuracy: 55.48% | Time: 106.15s


Epoch 4/10 [Training]: 100%|██████████| 1184/1184 [00:14<00:00, 81.66it/s, loss=2.1120]
Epoch 4/10 [Validation]: 100%|██████████| 395/395 [01:28<00:00,  4.47it/s]


Epoch 4/10 | Train Loss: 2.0653 | Val Loss: 1.8434 | Val Accuracy: 56.45% | Time: 102.91s


Epoch 5/10 [Training]: 100%|██████████| 1184/1184 [00:15<00:00, 77.77it/s, loss=1.6309]
Epoch 5/10 [Validation]: 100%|██████████| 395/395 [01:32<00:00,  4.29it/s]


Epoch 5/10 | Train Loss: 1.9370 | Val Loss: 1.7803 | Val Accuracy: 57.36% | Time: 107.26s


Epoch 6/10 [Training]: 100%|██████████| 1184/1184 [00:15<00:00, 76.93it/s, loss=1.9928]
Epoch 6/10 [Validation]: 100%|██████████| 395/395 [01:35<00:00,  4.13it/s]


Epoch 6/10 | Train Loss: 1.8728 | Val Loss: 1.7852 | Val Accuracy: 56.84% | Time: 110.96s


Epoch 7/10 [Training]: 100%|██████████| 1184/1184 [00:15<00:00, 78.89it/s, loss=1.9704]
Epoch 7/10 [Validation]: 100%|██████████| 395/395 [01:31<00:00,  4.33it/s]


Epoch 7/10 | Train Loss: 1.7967 | Val Loss: 1.7354 | Val Accuracy: 57.74% | Time: 106.19s


Epoch 8/10 [Training]: 100%|██████████| 1184/1184 [00:14<00:00, 79.32it/s, loss=1.4379]
Epoch 8/10 [Validation]: 100%|██████████| 395/395 [01:29<00:00,  4.40it/s]


Epoch 8/10 | Train Loss: 1.7395 | Val Loss: 1.7713 | Val Accuracy: 57.30% | Time: 104.66s


Epoch 9/10 [Training]: 100%|██████████| 1184/1184 [00:14<00:00, 79.49it/s, loss=2.4841]
Epoch 9/10 [Validation]: 100%|██████████| 395/395 [01:27<00:00,  4.49it/s]


Epoch 9/10 | Train Loss: 1.6883 | Val Loss: 1.7741 | Val Accuracy: 57.55% | Time: 102.84s


Epoch 10/10 [Training]: 100%|██████████| 1184/1184 [00:15<00:00, 78.90it/s, loss=1.8750]
Epoch 10/10 [Validation]: 100%|██████████| 395/395 [01:31<00:00,  4.32it/s]

Epoch 10/10 | Train Loss: 1.6471 | Val Loss: 1.7608 | Val Accuracy: 57.28% | Time: 106.45s
------------------------------
Finished Accelerated Training in 1079.10 seconds.



