In [5]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Subset
from torchvision import transforms, models
from tqdm.auto import tqdm

from vcs2425 import ApplyColormap, ImageNetDepth


# 1. Reproducibility
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)

# 2. Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3. Transforms
RESIZE_SIZE = (256, 256)
CROP_SIZE = (224, 224)
# Standard ImageNet channel mean/std, the normalisation torchvision's
# pretrained ImageNet weights expect.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# ONE ApplyColormap instance is shared by both pipelines below, so assigning
# `color_map_transform.cmap` later changes the colormap of train AND val data.
color_map_transform = ApplyColormap(cmap='viridis')

train_transform = transforms.Compose([
    transforms.Resize(RESIZE_SIZE, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(CROP_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),  # train-only augmentation
    color_map_transform,
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

val_transform = transforms.Compose([
    transforms.Resize(RESIZE_SIZE, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(CROP_SIZE),
    color_map_transform,
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [6]:
# 4. Load full dataset (used here only for its length and its class mapping)
data_dir = '../ILSVRC2012_depth'
full_dataset = ImageNetDepth(root_dir=data_dir)
class_to_idx = full_dataset.class_to_idx

# 5. Manual 80/20 split.
# Use a dedicated RNG seeded with SEED rather than the global `random` state:
# the global state advances on every re-run, so re-executing this cell would
# reshuffle the split and leak validation images into the training set of a
# model that was already trained on the previous split.
TRAIN_FRACTION = 0.8
indices = list(range(len(full_dataset)))
random.Random(SEED).shuffle(indices)
split = int(TRAIN_FRACTION * len(full_dataset))
train_indices, val_indices = indices[:split], indices[split:]

# 6. Separate train and val datasets, each with its own transform pipeline
train_base = ImageNetDepth(root_dir=data_dir, transform=train_transform)
val_base = ImageNetDepth(root_dir=data_dir, transform=val_transform)
# Share one label mapping across both datasets.
# NOTE(review): if ImageNetDepth resolves labels in __init__, reassigning
# class_to_idx afterwards may have no effect — confirm against vcs2425.
train_base.class_to_idx = class_to_idx
val_base.class_to_idx = class_to_idx

train_dataset = Subset(train_base, train_indices)
val_dataset = Subset(val_base, val_indices)

BATCH_SIZE = 256

# 7. DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,          # page-locked host memory for faster host->GPU copies
    persistent_workers=True,  # keep worker processes alive between epochs
)
# Keep the validation loader in the main process (num_workers=0, the default).
# The colormap sweep assigns `color_map_transform.cmap` between passes. Worker
# processes hold their own copy of the dataset, and persistent workers would
# keep using a stale colormap.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

FileNotFoundError: [Errno 2] No such file or directory: '../ILSVRC2012_depth'

In [None]:
def calculate_accuracy(prediction, ground_truth):
    """Return the top-1 accuracy of a single batch as a 0-dim float tensor.

    Args:
        prediction: class scores with the class axis at dim 1.
        ground_truth: integer class labels, one per sample of `prediction`.

    Returns:
        Fraction of samples whose arg-max class equals the label.
    """
    predicted_classes = prediction.argmax(dim=1)
    matches = predicted_classes == ground_truth.view_as(predicted_classes)
    n_samples = ground_truth.shape[0]
    return matches.sum().float() / n_samples

def train(model, loader, criterion, optimizer):
    """Run one training epoch and return per-sample averaged metrics.

    Each batch contributes in proportion to its size, so a smaller final
    batch does not skew the epoch averages.

    Args:
        model: network to train; switched to train mode here.
        loader: iterable of (images, labels) batches.
        criterion: loss function; assumed to average over the batch
            (reduction='mean', the nn.CrossEntropyLoss default).
        optimizer: optimizer updating `model`'s parameters.

    Returns:
        (train_loss, train_accuracy) as Python floats.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0.0
    total_samples = 0

    for images, labels in tqdm(loader, desc="Training", leave=False):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        predictions = model(images)

        loss = criterion(predictions, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Accuracy uses the predictions made before this optimizer step.
        batch_accuracy = calculate_accuracy(predictions, labels).item()

        # Turn the batch means back into sums before accumulating.
        total_loss += loss.item() * batch_size
        total_correct += batch_accuracy * batch_size
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError("train() received a loader that yielded no samples")

    return total_loss / total_samples, total_correct / total_samples

def evaluate(model, loader, criterion):
    """Compute loss and top-1 accuracy of `model` on `loader` without updating it.

    Metrics are averaged per sample (each batch weighted by its size), so a
    smaller final batch does not skew the result.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable of (images, labels) batches.
        criterion: loss function; assumed to average over the batch
            (reduction='mean', the nn.CrossEntropyLoss default).

    Returns:
        (loss, accuracy) as Python floats.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluation", leave=False):
            images, labels = images.to(device), labels.to(device)
            predictions = model(images)
            loss = criterion(predictions, labels)
            batch_accuracy = calculate_accuracy(predictions, labels).item()

            batch_size = labels.size(0)
            # Turn the batch means back into sums before accumulating.
            total_loss += loss.item() * batch_size
            total_correct += batch_accuracy * batch_size
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("evaluate() received a loader that yielded no samples")

    return total_loss / total_samples, total_correct / total_samples

def evaluate_topk(model, loader, k=5):
    """Return the top-k accuracy of `model` over every sample in `loader`.

    A sample counts as correct when its label is among the k highest-scoring
    classes. The result is correct/total over all samples, so a smaller final
    batch is not over-weighted.

    Args:
        model: network to evaluate; switched to eval mode here.
        loader: iterable of (images, labels) batches.
        k: number of top-scoring classes that may contain the label.

    Returns:
        Top-k accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc=f"Top-{k} Evaluation", leave=False):
            images, labels = images.to(device), labels.to(device)
            predictions = model(images)
            # Indices of the k highest-scoring classes along dim 1.
            _, topk_preds = predictions.topk(k, dim=1)
            # Indices within a row are distinct, so each sample matches at most
            # once. Broadcasting the (batch, 1) labels makes expand_as unnecessary.
            total_correct += topk_preds.eq(labels.view(-1, 1)).sum().item()
            total_samples += labels.size(0)

    if total_samples == 0:
        raise ValueError("evaluate_topk() received a loader that yielded no samples")

    return total_correct / total_samples


# Evaluating the pretrained VGG19 "as-is" (no fine-tuning) on colormapped depth images

In [None]:
# VGG19 with torchvision's pretrained ImageNet weights (IMAGENET1K_V1).
model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)

# Assign instead of ending the cell on `model.to(device)`: a bare expression
# would render the full VGG19 module repr as cell output.
# nn.Module.to returns the module itself.
model = model.to(device)

In [None]:
top_1s = []
top_5s = []

colormaps = ['stacked', 'gray', 'viridis', 'plasma', 'magma', 'Spectral']
# Not used by evaluate_topk below; train() and evaluate() take a criterion.
criterion = nn.CrossEntropyLoss()

# `color_map_transform` is shared by train_transform and val_transform.
# Remember its colormap and restore it afterwards, so later cells do not
# silently keep using the last colormap of this sweep.
original_cmap = color_map_transform.cmap
try:
    for cmap in tqdm(colormaps, desc="Evaluating on different colormaps"):
        # Takes effect immediately because val_loader loads in the main process.
        color_map_transform.cmap = cmap

        top_1 = evaluate_topk(model, val_loader, k=1)
        top_5 = evaluate_topk(model, val_loader, k=5)

        top_1s.append(top_1)
        top_5s.append(top_5)

        print(f"Colormap: {cmap}, Top-1 Accuracy: {top_1:.4f}, Top-5 Accuracy: {top_5:.4f}")
finally:
    color_map_transform.cmap = original_cmap



In [None]:
# Grouped bar chart: one group per colormap, Top-1 and Top-5 side by side.
fig, ax = plt.subplots(figsize=(12, 6))

# Bar positions: each pair of bars is centred on its colormap's tick.
x = np.arange(len(colormaps))
width = 0.35

rects1 = ax.bar(x - width/2, top_1s, width, label='Top-1 Accuracy')
rects2 = ax.bar(x + width/2, top_5s, width, label='Top-5 Accuracy')

ax.set_xlabel('Colormap')
ax.set_ylabel('Accuracy')
ax.set_title('Top-1 and Top-5 Accuracies by Colormap')
ax.set_xticks(x)
ax.set_xticklabels(colormaps)
ax.legend()


def autolabel(rects):
    """Write each bar's height (3 decimals) just above that bar.

    The Axes is taken from each bar (`rect.axes`), so the helper does not
    depend on whichever global `ax` exists when it is called.
    """
    for rect in rects:
        height = rect.get_height()
        rect.axes.annotate(f'{height:.3f}',
                           xy=(rect.get_x() + rect.get_width() / 2, height),
                           xytext=(0, 3),  # 3 points vertical offset
                           textcoords="offset points",
                           ha='center', va='bottom')


autolabel(rects1)
autolabel(rects2)

fig.tight_layout()
plt.show()