# Experiment 1: Oracle vs Reference

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules (e.g. under `src/`) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
from torch.nn.functional import cross_entropy
from torch.utils.data import random_split
from tqdm import tqdm

from src.reference import RotatedMNISTClassifier
from src.data import RotatedMNISTDataset

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Reference classifier, moved onto the selected device.
reference_model = RotatedMNISTClassifier().to(device)

# Full dataset, partitioned 70% / 20% / remainder into train / val / test.
dataset = RotatedMNISTDataset()
dataset_size = len(dataset)
train_size = int(0.7 * dataset_size)
val_size = int(0.2 * dataset_size)
test_size = dataset_size - (train_size + val_size)

# A dedicated, seeded generator keeps the split identical across runs.
split_generator = torch.Generator().manual_seed(40)
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Only the training loader shuffles; evaluation loaders keep a fixed order.
train_loader, val_loader, test_loader = (
    torch.utils.data.DataLoader(subset, batch_size=32, shuffle=reshuffle)
    for subset, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [None]:
def do_train_epoch(model, loader, optimizer, epoch):
    """Train `model` for one pass over `loader` on digit classification.

    Args:
        model: Module mapping an image batch to digit logits; it is put in
            train mode and its parameters are updated through `optimizer`.
        loader: Iterable yielding `(images, rotation_labels, digit_labels)`
            batches.
        optimizer: Optimizer used for the update step (expected to have been
            built over `model`'s parameters).
        epoch: Epoch index, used only in the printed summary.

    Prints the sum of the per-batch cross-entropy losses for the epoch.
    """
    model.train()
    epoch_loss = 0
    # Iterate the loader directly so tqdm can read its length for the progress
    # bar. Rotation labels are part of each batch but unused by this objective.
    for images, _rotation_labels, digit_labels in tqdm(loader):
        images = images.to(device)
        digit_labels = digit_labels.to(device)

        optimizer.zero_grad()
        # Use the `model` argument, not the notebook-global `reference_model`,
        # so the function trains whichever model the caller passes in.
        logits = model(images)  # (bs, num_digit_classes)
        loss = cross_entropy(logits, digit_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f'Train epoch {epoch} loss: {epoch_loss:.4f}')

def do_val_epoch(model, loader, epoch):
    """Evaluate `model` for one pass over `loader` on digit classification.

    Args:
        model: Module mapping an image batch to digit logits; it is put in
            eval mode and is not updated.
        loader: Iterable yielding `(images, rotation_labels, digit_labels)`
            batches.
        epoch: Epoch index, used only in the printed summary.

    Prints the sum of the per-batch cross-entropy losses for the epoch.
    """
    model.eval()
    epoch_loss = 0
    # No parameters are updated here, so skip building autograd graphs.
    with torch.no_grad():
        # Rotation labels are part of each batch but unused by this objective.
        for images, _rotation_labels, digit_labels in tqdm(loader):
            images = images.to(device)
            digit_labels = digit_labels.to(device)

            # Use the `model` argument, not the notebook-global
            # `reference_model`, so any model passed in is the one evaluated.
            logits = model(images)  # (bs, num_digit_classes)
            loss = cross_entropy(logits, digit_labels)
            epoch_loss += loss.item()
    print(f'Val epoch {epoch} loss: {epoch_loss:.4f}')

In [None]:
# Training loop: one training pass followed by one validation pass per epoch.
num_epochs = 2
lr = 0.005
optimizer = torch.optim.Adam(reference_model.parameters(), lr=lr)

for epoch in range(num_epochs):
    do_train_epoch(reference_model, train_loader, optimizer, epoch)
    # Validate the same model that was just trained (the previous spelling
    # `referenc_model` was an undefined name and raised NameError here).
    do_val_epoch(reference_model, val_loader, epoch)