In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm
import time

# --- 1. Configuration ---
# Hyper-parameters and paths shared by every phase of this script.
config = dict(
    DEVICE="cuda" if torch.cuda.is_available() else "cpu",
    BATCH_SIZE=64,
    NUM_EPOCHS=10,
    LEARNING_RATE=0.001,
    DATA_DIR="/home/kami/Documents/datasets/",
    NUM_CLASSES=101,
    # Width of the flattened VGG16 avgpool output (512 channels x 7 x 7 in
    # torchvision's VGG implementation); sizes both the cache and the head.
    VGG16_FEATURE_SIZE=25088,
)

# Startup banner: script title, selected device, separator rule.
print(
    "--- SCRIPT 1: CACHED FINE-TUNING ---",
    f"Using device: {config['DEVICE']}",
    "-" * 40,
    sep="\n",
)


# --- 2. The Custom Caching Model ---
class CachedFineTuneModel(nn.Module):
    """Classifier head on a frozen backbone whose activations can be cached.

    After ``cache_activations`` has filled the ``frozen_data`` buffer, calling
    the model in training mode takes a tensor of dataset *indices* instead of
    images and runs only the trainable ``classifier`` on the cached rows. In
    eval mode (or before caching) it runs ``features`` -> ``avgpool`` ->
    ``classifier`` on an image batch.

    Args:
        original_model: module exposing ``features`` and ``avgpool``
            submodules (e.g. a torchvision VGG16-BN); both are reused as-is.
        num_classes: number of logits produced by the new classifier head.
        num_records: number of rows reserved in the activation cache; must
            equal the number of samples the caching loader yields.
        feature_size: width of the flattened ``avgpool`` output. Defaults to
            ``config["VGG16_FEATURE_SIZE"]``.
    """

    def __init__(self, original_model, num_classes, num_records, feature_size=None):
        super().__init__()
        if feature_size is None:
            feature_size = config["VGG16_FEATURE_SIZE"]
        self.features = original_model.features
        self.avgpool = original_model.avgpool
        self.classifier = nn.Sequential(
            nn.Linear(feature_size, 4096), nn.ReLU(True), nn.Dropout(0.5),
            nn.Linear(4096, 1024), nn.ReLU(True), nn.Dropout(0.5),
            nn.Linear(1024, num_classes),
        )
        # One float32 row per record (num_records * feature_size values, which
        # is several GB for ~75k records at 25088 features). As a buffer it is
        # saved in state_dict and moved by ``.to()`` together with the model.
        self.register_buffer('frozen_data', torch.zeros(num_records, feature_size))
        self.register_buffer('is_cached', torch.tensor(False))

    @torch.no_grad()
    def cache_activations(self, dataloader: DataLoader, show_progress: bool = True):
        """Run the frozen backbone once over ``dataloader`` and store the results.

        Row ``k`` of ``frozen_data`` receives the activation of the ``k``-th
        sample yielded by the loader. The loader must therefore visit the
        dataset in index order (``shuffle=False``, default sampler) for cache
        rows to line up with dataset indices. Batches may have any size,
        including a loader built from a ``batch_sampler``.

        The module's train/eval mode is restored on exit.

        Args:
            dataloader: yields ``(inputs, targets)`` batches; targets are ignored.
            show_progress: wrap the iteration in a tqdm progress bar.

        Raises:
            RuntimeError: if the loader yields a different number of samples
                than the cache has rows. A partially filled cache would
                otherwise train on zero vectors.
        """
        print("--- Phase 1: Caching Activations (One-Time Cost) ---")
        was_training = self.training
        self.eval()  # BatchNorm/Dropout in inference mode during extraction
        try:
            device = next(self.parameters()).device
            start_time = time.time()
            batches = tqdm(dataloader, desc="Caching Progress") if show_progress else dataloader
            capacity = self.frozen_data.shape[0]
            offset = 0  # running row count; independent of dataloader.batch_size
            for data, _ in batches:
                data = data.to(device)
                activations = torch.flatten(self.avgpool(self.features(data)), 1)
                end = offset + activations.shape[0]
                if end > capacity:
                    raise RuntimeError(
                        f"dataloader yielded more than {capacity} samples; "
                        "num_records does not match the dataset size"
                    )
                # Write directly on the buffer's device (no host round trip).
                self.frozen_data[offset:end] = activations.to(self.frozen_data.device)
                offset = end
            if offset != capacity:
                raise RuntimeError(
                    f"dataloader yielded {offset} samples but the cache has "
                    f"{capacity} rows (is drop_last set?)"
                )
            if device.type == "cuda":
                # Kernels run asynchronously; wait so the reported time is real.
                torch.cuda.synchronize(device)
            self.is_cached.fill_(True)
            print(f"Caching complete in {time.time() - start_time:.2f} seconds.")
            self.frozen_data = self.frozen_data.to(device)
        finally:
            self.train(was_training)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits.

        In training mode after caching, ``x`` must be an integer tensor of row
        indices into ``frozen_data``. Otherwise ``x`` is an image batch that is
        run through the backbone.
        """
        if self.training and self.is_cached:
            return self.classifier(self.frozen_data[x])
        return self.classifier(torch.flatten(self.avgpool(self.features(x)), 1))

# --- 3. Data and Model Setup ---
# Deterministic preprocessing (no random augmentation), so each image maps to
# exactly one cached activation vector. Mean/std are the usual ImageNet stats.
data_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
print("Loading Food101 dataset for caching...")
train_dataset_caching = datasets.Food101(
    root=config["DATA_DIR"],
    split="train",
    download=True,
    transform=data_transforms,
)
# shuffle=False: loader order must equal dataset index order, because cached
# rows are later looked up by dataset index.
caching_loader = DataLoader(
    train_dataset_caching,
    batch_size=config["BATCH_SIZE"],
    shuffle=False,
)

vgg16_bn = models.vgg16_bn(weights=models.VGG16_BN_Weights.DEFAULT)
# Freeze the convolutional backbone; only the new classifier head is optimised.
vgg16_bn.features.requires_grad_(False)
model = CachedFineTuneModel(
    vgg16_bn,
    config["NUM_CLASSES"],
    len(train_dataset_caching),
).to(config["DEVICE"])

# --- 4. Caching Phase ---
model.cache_activations(caching_loader)

# --- 5. Accelerated Training Phase ---
print("\n--- Phase 2: Accelerated Fine-Tuning ---")

# Labels in dataset-index order (matching the cache rows). Iterating the dataset
# would decode and transform every image just to read an int label. torchvision's
# Food101 keeps its labels in `_labels`, so use that list when it is present and
# no target_transform could alter the labels; otherwise fall back to iteration.
fast_labels = getattr(train_dataset_caching, "_labels", None)
if (
    fast_labels is not None
    and getattr(train_dataset_caching, "target_transform", None) is None
    and len(fast_labels) == len(train_dataset_caching)
):
    train_labels = list(fast_labels)
else:
    train_labels = [label for _, label in train_dataset_caching]

# Each training sample is (dataset index, label); the model maps indices to cached features.
training_dataset_indexed = TensorDataset(torch.arange(len(train_dataset_caching)), torch.tensor(train_labels))
training_loader = DataLoader(training_dataset_indexed, batch_size=config["BATCH_SIZE"], shuffle=True)
optimizer = optim.Adam(model.classifier.parameters(), lr=config["LEARNING_RATE"])
criterion = nn.CrossEntropyLoss()

total_training_time = 0
# Initialised before the loop so the report is well-defined even if NUM_EPOCHS == 0.
running_loss = 0.0
correct = 0
total = 0
model.train()  # required: forward() only uses the cache in training mode

for epoch in range(config["NUM_EPOCHS"]):
    epoch_start_time = time.time()
    running_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(training_loader, desc=f"Epoch {epoch+1}/{config['NUM_EPOCHS']}")
    for batch_indices, batch_labels in pbar:
        batch_indices = batch_indices.to(config["DEVICE"])
        batch_labels = batch_labels.to(config["DEVICE"])
        optimizer.zero_grad()
        outputs = model(batch_indices)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()  # single host sync for the loss per step
        running_loss += batch_loss
        predicted = outputs.detach().argmax(dim=1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()
        pbar.set_postfix({"loss": f"{batch_loss:.4f}"})

    epoch_time = time.time() - epoch_start_time
    total_training_time += epoch_time

# --- 6. Final Report ---
# Loss/accuracy come from the last epoch only, measured in train mode (dropout
# active). They are NaN when no epoch ran.
final_train_loss = running_loss / len(training_loader) if total else float("nan")
final_train_accuracy = 100 * correct / total if total else float("nan")
print("-" * 40)
print("Finished Accelerated Training.")
print(f"Total Time for {config['NUM_EPOCHS']} epochs: {total_training_time:.2f} seconds")
print(f"Final Training Loss: {final_train_loss:.4f}")
print(f"Final Training Accuracy: {final_train_accuracy:.2f}%")
print("-" * 40)

--- SCRIPT 1: CACHED FINE-TUNING ---
Using device: cuda
----------------------------------------
Loading Food101 dataset for caching...
--- Phase 1: Caching Activations (One-Time Cost) ---


Caching Progress: 100%|██████████| 1184/1184 [04:36<00:00,  4.28it/s]


Caching complete in 276.60 seconds.

--- Phase 2: Accelerated Fine-Tuning ---


Epoch 1/10: 100%|██████████| 1184/1184 [00:15<00:00, 77.06it/s, loss=2.5550]
Epoch 2/10: 100%|██████████| 1184/1184 [00:15<00:00, 78.36it/s, loss=2.7821]
Epoch 3/10: 100%|██████████| 1184/1184 [00:15<00:00, 78.18it/s, loss=1.9598]
Epoch 4/10: 100%|██████████| 1184/1184 [00:14<00:00, 78.98it/s, loss=2.1403]
Epoch 5/10: 100%|██████████| 1184/1184 [00:14<00:00, 82.42it/s, loss=1.5186]
Epoch 6/10: 100%|██████████| 1184/1184 [00:15<00:00, 77.66it/s, loss=2.2002]
Epoch 7/10: 100%|██████████| 1184/1184 [00:14<00:00, 80.10it/s, loss=1.3070]
Epoch 8/10: 100%|██████████| 1184/1184 [00:14<00:00, 79.67it/s, loss=1.6209]
Epoch 9/10: 100%|██████████| 1184/1184 [00:14<00:00, 81.40it/s, loss=2.0822]
Epoch 10/10: 100%|██████████| 1184/1184 [00:14<00:00, 80.27it/s, loss=2.5154]

----------------------------------------
Finished Accelerated Training.
Total Time for 10 epochs: 149.17 seconds
Final Training Loss: 1.6591
Final Training Accuracy: 59.02%
----------------------------------------



