# Experiment 1: Oracle vs Reference on RotatedMNIST

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from typing import Tuple, List, Type
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from torch.nn.functional import cross_entropy
from torch.utils.data import random_split, DataLoader
from tqdm import tqdm
from torch.functional import F

from src.reference import ReferenceModel
from src.data import RotatedMNISTDataset, FixedSizeWrapper
from src.binary_tree import MNISTOracleRouter, BinaryTreeGoE, LatentVariableRouter, RandomBinaryTreeRouter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Helper Functions

In [None]:
def build_modules_sqA(param_factor: int) -> Tuple[Type[nn.Module], Type[nn.Module], Type[nn.Module]]:
    """
    Architecture factory for the "Status Quo A" (SQ-A) reference network on RotatedMNIST.

    The second conv layer and the first fully-connected layer are twice as wide
    as the ones produced by `build_modules`; per the original author's note this
    is intended to give the same number of *total* parameters as the Graph of
    Experts. All widths scale linearly in `param_factor`.

    Args:
    - param_factor: width multiplier shared by every layer.

    Returns:
    - The tuple of (uninstantiated) classes (LayerOne, LayerTwo, LayerThree).
    """
    # Every width is derived once here so the three classes stay in sync.
    stem_channels = param_factor
    wide_channels = 2 * 2 * param_factor
    flat_features = wide_channels * 7 * 7
    hidden_features = 2 * 4 * param_factor

    class LayerOne(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, stem_channels, kernel_size=3, padding=1)

        def forward(self, x: torch.Tensor):
            """
            Args:
            - x: (batch_size, 1, 28, 28)
            """
            pooled = F.max_pool2d(self.conv(x), 2)  # (batch_size, param_factor, 14, 14)
            return F.relu(pooled)

    class LayerTwo(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(stem_channels, wide_channels, kernel_size=3, padding=1)

        def forward(self, x: torch.Tensor):
            """
            Args:
            - x: (batch_size, param_factor, 14, 14)
            """
            pooled = F.max_pool2d(self.conv(x), 2)  # (batch_size, 4 * param_factor, 7, 7)
            return F.relu(pooled)

    class LayerThree(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(flat_features, hidden_features)
            self.fc2 = nn.Linear(hidden_features, 10)

        def forward(self, x: torch.Tensor):
            """
            Args:
            - x: (batch_size, 4 * param_factor, 7, 7)
            """
            flat = x.view(-1, flat_features)  # (batch_size, 4 * param_factor * 7 * 7)
            hidden = F.relu(self.fc1(flat))  # (batch_size, 8 * param_factor)
            return self.fc2(hidden)  # (batch_size, 10) digit logits

    return LayerOne, LayerTwo, LayerThree

In [None]:
def build_modules(param_factor: int) -> Tuple[Type[nn.Module], Type[nn.Module], Type[nn.Module]]:
    """
    Architecture factory for the RotatedMNIST task.

    Produces three layer classes (conv -> conv -> two-layer MLP head) whose
    widths all scale linearly in `param_factor`.

    Args:
    - param_factor: width multiplier shared by every layer.

    Returns:
    - The tuple of (uninstantiated) classes (LayerOne, LayerTwo, LayerThree).
    """
    # Every width is derived once here so the three classes stay in sync.
    stem_channels = param_factor
    mid_channels = 2 * param_factor
    flat_features = mid_channels * 7 * 7
    hidden_features = 4 * param_factor

    class LayerOne(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(1, stem_channels, kernel_size=3, padding=1)

        def forward(self, x: torch.Tensor):
            """
            Args:
            - x: (batch_size, 1, 28, 28)
            """
            pooled = F.max_pool2d(self.conv(x), 2)  # (batch_size, param_factor, 14, 14)
            return F.relu(pooled)

    class LayerTwo(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(stem_channels, mid_channels, kernel_size=3, padding=1)

        def forward(self, x: torch.Tensor):
            """
            Args:
            - x: (batch_size, param_factor, 14, 14)
            """
            pooled = F.max_pool2d(self.conv(x), 2)  # (batch_size, 2 * param_factor, 7, 7)
            return F.relu(pooled)

    class LayerThree(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(flat_features, hidden_features)
            self.fc2 = nn.Linear(hidden_features, 10)

        def forward(self, x: torch.Tensor):
            """
            Args:
            - x: (batch_size, 2 * param_factor, 7, 7)
            """
            flat = x.view(-1, flat_features)  # (batch_size, 2 * param_factor * 7 * 7)
            hidden = F.relu(self.fc1(flat))  # (batch_size, 4 * param_factor)
            return self.fc2(hidden)  # (batch_size, 10) digit logits

    return LayerOne, LayerTwo, LayerThree

In [None]:
def do_train_epoch(model, loader, optimizer, epoch) -> float:
    """
    Run one optimisation pass over `loader` and return the mean batch loss.

    Args:
    - model: module called as `model(images, rotation_labels=...)` and expected
      to return digit logits.
    - loader: iterable yielding `(images, rotation_labels, digit_labels)` batches.
    - optimizer: optimizer stepping `model`'s parameters.
    - epoch: epoch index, used only in the printed summary line.

    Returns:
    - The unweighted mean of the per-batch cross-entropy losses (NaN if the
      loader yields no batches).
    """
    model.train()
    epoch_losses = []
    # Wrap the loader itself rather than enumerate(loader): tqdm can then read
    # len(loader) and display a real progress total; the index was unused anyway.
    for images, rotation_labels, digit_labels in tqdm(loader):
        images = images.to(device)
        rotation_labels = rotation_labels.to(device)
        digit_labels = digit_labels.to(device)

        optimizer.zero_grad()
        logits = model(images, rotation_labels=rotation_labels) # (bs, num_digit_classes)
        loss = cross_entropy(logits, digit_labels)
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
    epoch_loss = torch.mean(torch.Tensor(epoch_losses)).item()
    print(f'Train epoch {epoch} loss: {epoch_loss:.4f}')
    return epoch_loss

def do_val_epoch(model, loader, epoch) -> float:
    """
    Evaluate `model` on `loader` and return the mean batch loss.

    Args:
    - model: module called as `model(images, rotation_labels=...)` and expected
      to return digit logits.
    - loader: iterable yielding `(images, rotation_labels, digit_labels)` batches.
    - epoch: epoch index, used only in the printed summary line.

    Returns:
    - The unweighted mean of the per-batch cross-entropy losses (NaN if the
      loader yields no batches).
    """
    model.eval()
    epoch_losses = []
    # No parameter update happens here, so skip autograd bookkeeping entirely:
    # this avoids building (and holding memory for) graphs that are never used.
    with torch.no_grad():
        # Wrap the loader itself so tqdm can show a total; the index was unused.
        for images, rotation_labels, digit_labels in tqdm(loader):
            images = images.to(device)
            rotation_labels = rotation_labels.to(device)
            digit_labels = digit_labels.to(device)

            logits = model(images, rotation_labels=rotation_labels) # (bs, num_digit_classes)
            loss = cross_entropy(logits, digit_labels)
            epoch_losses.append(loss.item())
    epoch_loss = torch.mean(torch.Tensor(epoch_losses)).item()
    print(f'Val epoch {epoch} loss: {epoch_loss:.4f}')
    return epoch_loss

def get_accuracy(model, loader):
    """
    Compute the top-1 digit classification accuracy of `model` over `loader`.

    Args:
    - model: module called as `model(images, rotation_labels=...)` and expected
      to return digit logits.
    - loader: iterable yielding `(images, rotation_labels, digit_labels)` batches.

    Returns:
    - Accuracy as a percentage (0-100).

    Raises:
    - ZeroDivisionError: if `loader` yields no samples.
    """
    model.eval()
    total_samples = 0
    total_correct = 0
    # Pure inference: disable autograd so no graph is built per batch.
    with torch.no_grad():
        # Wrap the loader itself so tqdm can show a total; the index was unused.
        for images, rotation_labels, digit_labels in tqdm(loader):
            images = images.to(device)
            rotation_labels = rotation_labels.to(device)
            digit_labels = digit_labels.to(device) # (bs,)

            logits = model(images, rotation_labels=rotation_labels) # (bs, num_digit_classes)
            prediction = torch.argmax(logits, dim=1) # (bs,)
            total_samples += prediction.shape[0]
            total_correct += torch.sum(prediction == digit_labels).item()
    accuracy = total_correct / total_samples * 100
    return accuracy

def get_rotated_mnist_loaders(downsample_factor: int = 20) -> "Tuple[FixedSizeWrapper, FixedSizeWrapper, FixedSizeWrapper, DataLoader, DataLoader, DataLoader]":
    """
    Build train/val/test datasets and loaders for RotatedMNIST.

    The full dataset is split 70% / 20% / remainder with a fixed seed, and each
    split is then wrapped in a `FixedSizeWrapper` whose size is the split size
    divided (integer division) by `downsample_factor`.

    Args:
    - downsample_factor: positive integer divisor applied to each split's size.

    Returns:
    - `(train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader)`.

    Raises:
    - ValueError: if `downsample_factor` is smaller than 1.
    """
    # Fail fast: 0 would raise ZeroDivisionError only after the dataset has been
    # built, and a negative factor would silently request negative sizes.
    if downsample_factor < 1:
        raise ValueError(f"downsample_factor must be >= 1, got {downsample_factor}")

    # Initialize dataset
    dataset = RotatedMNISTDataset()

    # 70 / 20 / remainder split; the remainder absorbs integer rounding.
    dataset_size = len(dataset)
    train_size = int(0.7 * dataset_size)
    val_size = int(0.2 * dataset_size)
    test_size = dataset_size - train_size - val_size

    # Fixed generator seed keeps the split identical across runs.
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(40)
    )

    train_dataset = FixedSizeWrapper(dataset=train_dataset, size=train_size // downsample_factor)
    val_dataset = FixedSizeWrapper(dataset=val_dataset, size=val_size // downsample_factor)
    test_dataset = FixedSizeWrapper(dataset=test_dataset, size=test_size // downsample_factor)

    # NOTE(review): the training loader is not shuffled; kept as in the original
    # so results stay reproducible — confirm this is intentional.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    return train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader

def training_loop(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, test_loader: DataLoader, num_epochs: int = 100, lr: float = 0.005, epochs_per_accuracy: int = 5) -> List[float]:
    """
    Train `model` with Adam for `num_epochs` epochs.

    Each epoch runs one training pass followed by one validation pass. On every
    epoch whose index is a multiple of `epochs_per_accuracy`, the test accuracy
    is also computed and printed.

    Returns:
    - The validation loss of each epoch, in order.
    """
    adam = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for epoch in range(num_epochs):
        do_train_epoch(model, train_loader, adam, epoch)
        history.append(do_val_epoch(model, val_loader, epoch))
        # Only report test accuracy on the periodic checkpoints.
        if epoch % epochs_per_accuracy != 0:
            continue
        test_accuracy = get_accuracy(model, test_loader)
        print(f'Test accuracy: {test_accuracy:.3f}%')
    return history

def save_model(model: nn.Module, save_path: str):
    """
    Save `model`'s state dict to `save_path`, creating parent directories as needed.

    Args:
    - model: module whose `state_dict()` is serialised.
    - save_path: destination file; may be a bare filename (saved in the current
      working directory) or include directories that do not exist yet.
    """
    save_dir = os.path.dirname(save_path)
    # dirname is '' for a bare filename, and os.makedirs('') raises
    # FileNotFoundError — only create a directory when there is one to create.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)

def load_model(model: nn.Module, save_path: str):
    """
    Load a state dict saved at `save_path` into `model`.

    Args:
    - model: module that receives the weights (its parameters stay on whatever
      device they already live on).
    - save_path: checkpoint file containing a state dict.

    Returns:
    - The result of `model.load_state_dict(...)` (missing / unexpected keys).
    """
    # Deserialise onto CPU first: a checkpoint written on a CUDA machine would
    # otherwise fail to load on a CPU-only one. load_state_dict then copies the
    # values into the model's existing parameters on their current device.
    state_dict = torch.load(save_path, map_location='cpu')
    return model.load_state_dict(state_dict)

# Experiment

In [None]:
# Smoke-test cell: build the loaders with a minimal configuration and walk the
# training set once to check that every sample can be fetched.
# Build loaders
(
    train_dataset, val_dataset, test_dataset,
    train_loader, val_loader, test_loader
) = get_rotated_mnist_loaders()
# param_factors = [2, 4, 8, 12, 16, 24] # The LOST numbers :D
# Single small width and a single epoch keep this configuration cheap.
param_factors = [2]
num_epochs = 1
# Trained models, keyed by param_factor.
reference_models = {}
goe_oracle_models = {}

print(f"Number of training samples: {len(train_dataset)}")
# print(train_dataset[-1])
# from tqdm import tqdm
# Iterate the whole training set without using the items: this only exercises
# the dataset's item access and shows progress.
for el in tqdm(train_dataset):
    pass

In [None]:
# Build loaders
(
    train_dataset, val_dataset, test_dataset,
    train_loader, val_loader, test_loader
) = get_rotated_mnist_loaders()
param_factors = [2, 4, 8, 12, 16, 24] # The LOST numbers :D
# param_factors = [2]
num_epochs = 100
# Trained models, keyed by param_factor.
reference_models = {}
reference_sqA_models = {}
goe_oracle_models = {}
goe_random_models = {}
goe_latent_models = {}

# routers
oracle_router = MNISTOracleRouter()

random_router = RandomBinaryTreeRouter(depth=3)
random_router.compute_codebook(train_dataset)

latent_router = LatentVariableRouter(depth=3)
latent_router.compute_codebook(train_dataset)


def _train_and_save(model: nn.Module, description: str, checkpoint_tag: str, param_factor: int) -> List[float]:
    """Train `model` with the cell's loaders/epoch count, checkpoint it, and return its validation losses."""
    print(f'Training {description} at param_factor={param_factor}')
    val_losses = training_loop(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        num_epochs=num_epochs,
        lr=0.005,
        epochs_per_accuracy=5
    )
    save_model(model, f'checkpoints/rotated_mnist_oracle/param_{param_factor}_{checkpoint_tag}.pt')
    return val_losses


# train models
for param_factor in param_factors:
    # Build architectures
    modules_by_depth = build_modules(param_factor)
    modules_by_depth_sqA = build_modules_sqA(param_factor)
    # TODO: Compute FLOPS, etc.

    # Train reference model + save
    reference_model = ReferenceModel(modules_by_depth=modules_by_depth).to(device)
    _ = _train_and_save(reference_model, 'reference', 'reference', param_factor)
    reference_models[param_factor] = reference_model

    # Train SQ-A reference model + save.
    # The SQ-A model itself is trained here; the previous code passed
    # `reference_model` to the training loop, so the SQ-A checkpoint held
    # untrained weights and the plain reference model was trained twice.
    reference_sqA_model = ReferenceModel(modules_by_depth=modules_by_depth_sqA).to(device)
    _ = _train_and_save(reference_sqA_model, 'SQ-A Reference Model', 'reference_sqA', param_factor)
    reference_sqA_models[param_factor] = reference_sqA_model

    # Train GoE model + save
    goe_oracle_model = BinaryTreeGoE(modules_by_depth=modules_by_depth, router=oracle_router).to(device)
    val_losses = _train_and_save(goe_oracle_model, 'GoE', 'goe_oracle', param_factor)
    goe_oracle_models[param_factor] = goe_oracle_model

    # Train random GoE model + save
    goe_random_model = BinaryTreeGoE(modules_by_depth=modules_by_depth, router=random_router).to(device)
    val_losses = _train_and_save(goe_random_model, 'Random GoE', 'goe_random', param_factor)
    goe_random_models[param_factor] = goe_random_model

    # Train latent GoE model + save
    goe_latent_model = BinaryTreeGoE(modules_by_depth=modules_by_depth, router=latent_router).to(device)
    val_losses = _train_and_save(goe_latent_model, 'Latent GoE', 'goe_latent', param_factor)
    goe_latent_models[param_factor] = goe_latent_model
    

In [None]:
# Graph results
reference_accuracies = [get_accuracy(model, test_loader) for (_, model) in reference_models.items()]
goe_oracle_accuracies = [get_accuracy(model, test_loader) for (_, model) in goe_oracle_models.items()]
goe_random_accuracies = [get_accuracy(model, test_loader) for (_, model) in goe_random_models.items()]
goe_latent_accuracies = [get_accuracy(model, test_loader) for (_, model) in goe_latent_models.items()]


In [None]:
# Show the raw test accuracies of the reference models (one entry per trained model).
print(reference_accuracies)

In [None]:
# Test accuracy as a function of param_factor, one curve per model family.
# Start from a fresh figure so the curves never land on a previously drawn plot
# when this code runs outside an inline notebook backend.
plt.figure()
plt.plot(param_factors, reference_accuracies, label="Reference")
plt.plot(param_factors, goe_oracle_accuracies, label="GoE Oracle")
plt.plot(param_factors, goe_random_accuracies, label="GoE Random")
plt.plot(param_factors, goe_latent_accuracies, label="GoE Latent")
plt.xlabel("param_factor")
plt.ylabel("Test accuracy (%)")
plt.title("Reference vs GoE # Parameters vs Accuracy")
plt.legend()
# plt.show()
plt.savefig('rotated_mnist_oracle.png')

In [None]:
# Transposed view of the same data: accuracy on the x-axis, param_factor on the y-axis.
# Start from a fresh figure so the curves never land on a previously drawn plot
# when this code runs outside an inline notebook backend.
plt.figure()
plt.plot(reference_accuracies, param_factors, label="Reference")
# Legend label matches the non-transposed plot ("GoE Oracle", not just "GoE").
plt.plot(goe_oracle_accuracies, param_factors, label="GoE Oracle")
plt.plot(goe_random_accuracies, param_factors, label="GoE Random")
plt.plot(goe_latent_accuracies, param_factors, label="GoE Latent")
plt.xlabel("Test accuracy (%)")
plt.ylabel("param_factor")
plt.title("Reference vs GoE # Parameters vs Accuracy")
plt.legend()
# plt.show()
plt.savefig('rotated_mnist_oracle_T.png')