# Experiment 1: Oracle vs Reference

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from typing import Tuple, List
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from torch.nn.functional import cross_entropy
from torch.utils.data import random_split, DataLoader
from tqdm import tqdm

from src.reference import RotatedMNISTClassifier
from src.data import RotatedMNISTDataset, FixedSizeWrapper

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Helper Functions

In [8]:
def do_train_epoch(model, loader, optimizer, epoch) -> float:
    model.train()
    epoch_losses = []
    for idx, (images, rotation_labels, digit_labels) in tqdm(enumerate(loader)):
        images = images.to(device)
        rotation_labels = rotation_labels.to(device)
        digit_labels = digit_labels.to(device)

        optimizer.zero_grad()
        logits = reference_model(images) # (bs, num_digit_classes)
        loss = cross_entropy(logits, digit_labels)
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
    epoch_loss = torch.mean(torch.Tensor(epoch_losses)).item()
    print(f'Train epoch {epoch} loss: {epoch_loss:.4f}')
    return epoch_loss

def do_val_epoch(model, loader, epoch) -> float:
    model.eval()
    epoch_losses = []
    for idx, (images, rotation_labels, digit_labels) in tqdm(enumerate(loader)):
        images = images.to(device)
        rotation_labels = rotation_labels.to(device)
        digit_labels = digit_labels.to(device)

        logits = reference_model(images) # (bs, num_digit_classes)
        loss = cross_entropy(logits, digit_labels)
        epoch_losses.append(loss.item())
    epoch_loss = torch.mean(torch.Tensor(epoch_losses)).item()
    print(f'Val epoch {epoch} loss: {epoch_loss:.4f}') 
    return epoch_loss

def get_accuracy(model, loader):
    total_samples = 0
    total_correct = 0
    for idx, (images, rotation_labels, digit_labels) in tqdm(enumerate(loader)):
        images = images.to(device)
        rotation_labels = rotation_labels.to(device)
        digit_labels = digit_labels.to(device) # (bs,)

        logits = reference_model(images) # (bs, num_digit_classes)
        prediction = torch.argmax(logits, dim=1) # (bs,)
        total_samples += prediction.shape[0]
        total_correct += torch.sum(prediction == digit_labels)
    accuracy = total_correct / total_samples * 100
    return accuracy

def get_rotated_mnist_loaders(downsample_factor: int = 20) -> Tuple[DataLoader, DataLoader, DataLoader]:
    # Initialize dataset
    dataset = RotatedMNISTDataset()
    
    # Assuming `dataset` is your PyTorch Dataset
    dataset_size = len(dataset)
    train_size = int(0.7 * dataset_size)
    val_size = int(0.2 * dataset_size)
    test_size = dataset_size - train_size - val_size
    
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size], 
        generator=torch.Generator().manual_seed(40)
    )
    
    train_dataset = FixedSizeWrapper(dataset = train_dataset, size = train_size // downsample_factor)
    val_dataset = FixedSizeWrapper(dataset = val_dataset, size = val_size // downsample_factor)
    test_dataset = FixedSizeWrapper(dataset = test_dataset, size = test_size // downsample_factor)
    
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    return train_loader, val_loader, test_loader

def training_loop(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, test_loader: DataLoader, num_epochs: int = 100, lr: float = 0.005, epochs_per_accuracy: int = 5) -> List[float]:
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    val_losses = []
    for epoch in range(num_epochs):
        do_train_epoch(model, train_loader, optimizer, epoch)
        val_loss = do_val_epoch(model, val_loader, epoch)
        val_losses.append(val_loss)
        if epoch % epochs_per_accuracy == 0:
            accuracy = get_accuracy(model, test_loader)
            print(f'Test accuracy: {accuracy:.3f}%')
    return val_losses

In [None]:
train_loader, val_loader, test_loader = get_rotated_mnist_loaders()

# Training Loop for Reference Model

In [9]:
# Training loop
reference_model = RotatedMNISTClassifier().to(device)
val_losses = training_loop(
    model=reference_model, 
    train_loader=train_loader, 
    val_loader=val_loader, 
    test_loader=test_loader, 
    num_epochs=50, 
    lr = 0.005, 
    epochs_per_accuracy=5)

263it [00:02, 122.53it/s]


Train epoch 0 loss: 1.1081


75it [00:00, 177.44it/s]


Val epoch 0 loss: 0.6816


38it [00:00, 177.33it/s]


Test accuracy: 77.750%


263it [00:02, 126.77it/s]


Train epoch 1 loss: 0.5383


75it [00:00, 175.72it/s]


Val epoch 1 loss: 0.4768


263it [00:02, 125.34it/s]


Train epoch 2 loss: 0.4238


75it [00:00, 175.07it/s]


Val epoch 2 loss: 0.3677


263it [00:02, 124.81it/s]


Train epoch 3 loss: 0.3794


75it [00:00, 176.57it/s]


Val epoch 3 loss: 0.3947


263it [00:02, 124.77it/s]


Train epoch 4 loss: 0.3516


75it [00:00, 173.93it/s]


Val epoch 4 loss: 0.3289


263it [00:02, 124.04it/s]


Train epoch 5 loss: 0.2949


75it [00:00, 176.59it/s]


Val epoch 5 loss: 0.3225


38it [00:00, 176.27it/s]


Test accuracy: 91.833%


263it [00:02, 126.34it/s]


Train epoch 6 loss: 0.2923


75it [00:00, 175.74it/s]


Val epoch 6 loss: 0.2725


263it [00:02, 125.96it/s]


Train epoch 7 loss: 0.2725


75it [00:00, 176.21it/s]


Val epoch 7 loss: 0.2996


263it [00:02, 125.67it/s]


Train epoch 8 loss: 0.2734


75it [00:00, 178.18it/s]


Val epoch 8 loss: 0.3048


263it [00:02, 125.98it/s]


Train epoch 9 loss: 0.2761


75it [00:00, 176.12it/s]


Val epoch 9 loss: 0.2422


263it [00:02, 125.43it/s]


Train epoch 10 loss: 0.2548


75it [00:00, 175.85it/s]


Val epoch 10 loss: 0.2723


38it [00:00, 176.76it/s]


Test accuracy: 93.250%


263it [00:02, 126.12it/s]


Train epoch 11 loss: 0.2295


75it [00:00, 177.22it/s]


Val epoch 11 loss: 0.2173


263it [00:02, 125.93it/s]


Train epoch 12 loss: 0.2320


75it [00:00, 177.09it/s]


Val epoch 12 loss: 0.2266


263it [00:02, 116.75it/s]


Train epoch 13 loss: 0.2276


75it [00:00, 157.50it/s]


Val epoch 13 loss: 0.2224


263it [00:02, 115.11it/s]


Train epoch 14 loss: 0.2323


75it [00:00, 156.78it/s]


Val epoch 14 loss: 0.2589


263it [00:02, 115.16it/s]


Train epoch 15 loss: 0.2233


75it [00:00, 156.55it/s]


Val epoch 15 loss: 0.2548


38it [00:00, 157.88it/s]


Test accuracy: 94.417%


263it [00:02, 115.62it/s]


Train epoch 16 loss: 0.2129


75it [00:00, 157.81it/s]


Val epoch 16 loss: 0.2277


263it [00:02, 115.06it/s]


Train epoch 17 loss: 0.2112


75it [00:00, 156.45it/s]


Val epoch 17 loss: 0.2217


263it [00:02, 113.47it/s]


Train epoch 18 loss: 0.2038


75it [00:00, 156.36it/s]


Val epoch 18 loss: 0.2502


263it [00:02, 114.90it/s]


Train epoch 19 loss: 0.2076


75it [00:00, 156.73it/s]


Val epoch 19 loss: 0.2259


263it [00:02, 120.14it/s]


Train epoch 20 loss: 0.2017


75it [00:00, 177.94it/s]


Val epoch 20 loss: 0.2118


38it [00:00, 177.52it/s]


Test accuracy: 94.000%


263it [00:02, 126.67it/s]


Train epoch 21 loss: 0.2055


75it [00:00, 177.91it/s]


Val epoch 21 loss: 0.2122


263it [00:02, 126.27it/s]


Train epoch 22 loss: 0.1951


75it [00:00, 179.23it/s]


Val epoch 22 loss: 0.1830


263it [00:02, 126.56it/s]


Train epoch 23 loss: 0.1869


75it [00:00, 177.05it/s]


Val epoch 23 loss: 0.2209


263it [00:02, 126.87it/s]


Train epoch 24 loss: 0.1703


75it [00:00, 173.73it/s]


Val epoch 24 loss: 0.1884


263it [00:02, 125.34it/s]


Train epoch 25 loss: 0.1854


75it [00:00, 176.76it/s]


Val epoch 25 loss: 0.2198


38it [00:00, 177.36it/s]


Test accuracy: 94.000%


263it [00:02, 119.53it/s]


Train epoch 26 loss: 0.1849


75it [00:00, 156.27it/s]


Val epoch 26 loss: 0.2330


263it [00:02, 114.98it/s]


Train epoch 27 loss: 0.1849


75it [00:00, 155.86it/s]


Val epoch 27 loss: 0.2047


263it [00:02, 115.36it/s]


Train epoch 28 loss: 0.1867


75it [00:00, 157.24it/s]


Val epoch 28 loss: 0.1919


263it [00:02, 114.86it/s]


Train epoch 29 loss: 0.1666


75it [00:00, 156.14it/s]


Val epoch 29 loss: 0.2029


263it [00:02, 114.83it/s]


Train epoch 30 loss: 0.1892


75it [00:00, 156.09it/s]


Val epoch 30 loss: 0.2200


38it [00:00, 156.98it/s]


Test accuracy: 94.500%


263it [00:02, 114.96it/s]


Train epoch 31 loss: 0.1651


75it [00:00, 156.77it/s]


Val epoch 31 loss: 0.1599


263it [00:02, 114.84it/s]


Train epoch 32 loss: 0.1783


75it [00:00, 156.39it/s]


Val epoch 32 loss: 0.1777


263it [00:02, 115.67it/s]


Train epoch 33 loss: 0.1701


75it [00:00, 157.89it/s]


Val epoch 33 loss: 0.2063


263it [00:02, 115.35it/s]


Train epoch 34 loss: 0.1837


75it [00:00, 158.14it/s]


Val epoch 34 loss: 0.1757


263it [00:02, 115.56it/s]


Train epoch 35 loss: 0.1711


75it [00:00, 157.34it/s]


Val epoch 35 loss: 0.2180


38it [00:00, 158.83it/s]


Test accuracy: 92.583%


263it [00:02, 115.43it/s]


Train epoch 36 loss: 0.1654


75it [00:00, 157.54it/s]


Val epoch 36 loss: 0.1777


263it [00:02, 114.71it/s]


Train epoch 37 loss: 0.1720


75it [00:00, 157.35it/s]


Val epoch 37 loss: 0.1662


263it [00:02, 114.69it/s]


Train epoch 38 loss: 0.1834


75it [00:00, 156.81it/s]


Val epoch 38 loss: 0.1727


263it [00:02, 114.46it/s]


Train epoch 39 loss: 0.1518


75it [00:00, 157.36it/s]


Val epoch 39 loss: 0.1931


263it [00:02, 114.75it/s]


Train epoch 40 loss: 0.1583


75it [00:00, 157.68it/s]


Val epoch 40 loss: 0.2542


38it [00:00, 158.60it/s]


Test accuracy: 93.583%


263it [00:02, 114.51it/s]


Train epoch 41 loss: 0.1659


75it [00:00, 153.70it/s]


Val epoch 41 loss: 0.1983


263it [00:02, 115.16it/s]


Train epoch 42 loss: 0.1607


75it [00:00, 177.79it/s]


Val epoch 42 loss: 0.1748


263it [00:02, 126.88it/s]


Train epoch 43 loss: 0.1778


75it [00:00, 179.19it/s]


Val epoch 43 loss: 0.1985


263it [00:02, 126.97it/s]


Train epoch 44 loss: 0.1703


75it [00:00, 177.32it/s]


Val epoch 44 loss: 0.1668


263it [00:02, 126.47it/s]


Train epoch 45 loss: 0.1670


75it [00:00, 177.70it/s]


Val epoch 45 loss: 0.1809


38it [00:00, 177.99it/s]


Test accuracy: 95.333%


263it [00:02, 126.86it/s]


Train epoch 46 loss: 0.1658


75it [00:00, 177.61it/s]


Val epoch 46 loss: 0.1981


263it [00:02, 126.76it/s]


Train epoch 47 loss: 0.1600


75it [00:00, 177.61it/s]


Val epoch 47 loss: 0.2079


263it [00:02, 127.33it/s]


Train epoch 48 loss: 0.1654


75it [00:00, 177.94it/s]


Val epoch 48 loss: 0.2231


263it [00:02, 126.62it/s]


Train epoch 49 loss: 0.1475


75it [00:00, 178.66it/s]

Val epoch 49 loss: 0.1893





In [13]:
get_accuracy(reference_model, test_loader)

38it [00:00, 144.18it/s]


tensor(94.4167, device='cuda:0')

# Training Loop for GoE with Oracle Router

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.binary_tree import BinaryTreeGoE, MNISTOracleRouter

In [15]:
# Initialize 
class LayerOne(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 32, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor):
        """
        Args:
        - x: (batch_size, 32, 28, 28)
        """
        return F.relu(F.max_pool2d(self.conv(x), 2))  # (batch_size, 32, 14, 14)

class LayerTwo(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(32, 64, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor):
        """
        Args:
        - x: (batch_size, 32, 14, 14)
        """
        return F.relu(F.max_pool2d(self.conv(x), 2))  # (batch_size, 64, 7, 7)

class LayerThree(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor):
        """
        Args:
        - x: (batch_size, 32, 7, 7)
        """
        x = x.view(-1, 64 * 7 * 7)  # (batch_size, 64 * 7 * 7)
        x = F.relu(self.fc1(x))  # (batch_size, 128)
        logits = self.fc2(x)  # (batch_size, 10)
        return logits

In [16]:
# Training loop
router = MNISTOracleRouter()
goe_model = BinaryTreeGoE(modules_by_depth = [LayerOne, LayerTwo, LayerThree], router=router)
val_losses = training_loop(
    model=goe_model, 
    train_loader=train_loader, 
    val_loader=val_loader, 
    test_loader=test_loader, 
    num_epochs=50, 
    lr = 0.005, 
    epochs_per_accuracy=5)

263it [00:02, 115.48it/s]


Train epoch 0 loss: 0.1537


75it [00:00, 176.35it/s]


Val epoch 0 loss: 0.1424


38it [00:00, 179.00it/s]


Test accuracy: 93.833%


263it [00:02, 126.74it/s]


Train epoch 1 loss: 0.1499


75it [00:00, 178.72it/s]


Val epoch 1 loss: 0.1765


263it [00:02, 126.74it/s]


Train epoch 2 loss: 0.1629


75it [00:00, 177.42it/s]


Val epoch 2 loss: 0.1664


263it [00:02, 128.16it/s]


Train epoch 3 loss: 0.1528


75it [00:00, 176.15it/s]


Val epoch 3 loss: 0.1742


263it [00:02, 127.36it/s]


Train epoch 4 loss: 0.1532


75it [00:00, 177.14it/s]


Val epoch 4 loss: 0.1574


263it [00:02, 122.30it/s]


Train epoch 5 loss: 0.1521


75it [00:00, 177.68it/s]


Val epoch 5 loss: 0.1843


38it [00:00, 177.84it/s]


Test accuracy: 94.083%


263it [00:02, 127.70it/s]


Train epoch 6 loss: 0.1692


75it [00:00, 180.85it/s]


Val epoch 6 loss: 0.1858


263it [00:02, 127.65it/s]


Train epoch 7 loss: 0.1416


75it [00:00, 177.87it/s]


Val epoch 7 loss: 0.1479


263it [00:02, 126.33it/s]


Train epoch 8 loss: 0.1556


75it [00:00, 175.26it/s]


Val epoch 8 loss: 0.2004


263it [00:02, 128.24it/s]


Train epoch 9 loss: 0.1721


75it [00:00, 182.31it/s]


Val epoch 9 loss: 0.1704


263it [00:02, 127.70it/s]


Train epoch 10 loss: 0.1579


75it [00:00, 171.08it/s]


Val epoch 10 loss: 0.1666


38it [00:00, 177.26it/s]


Test accuracy: 94.833%


263it [00:02, 128.15it/s]


Train epoch 11 loss: 0.1466


75it [00:00, 182.62it/s]


Val epoch 11 loss: 0.2060


263it [00:02, 129.20it/s]


Train epoch 12 loss: 0.1570


75it [00:00, 183.18it/s]


Val epoch 12 loss: 0.1852


263it [00:02, 129.54it/s]


Train epoch 13 loss: 0.1592


75it [00:00, 183.75it/s]


Val epoch 13 loss: 0.1825


263it [00:02, 127.54it/s]


Train epoch 14 loss: 0.1442


75it [00:00, 180.86it/s]


Val epoch 14 loss: 0.1813


263it [00:02, 126.96it/s]


Train epoch 15 loss: 0.1516


75it [00:00, 181.83it/s]


Val epoch 15 loss: 0.1633


38it [00:00, 182.74it/s]


Test accuracy: 95.583%


263it [00:02, 123.37it/s]


Train epoch 16 loss: 0.1536


75it [00:00, 182.61it/s]


Val epoch 16 loss: 0.1593


263it [00:02, 129.77it/s]


Train epoch 17 loss: 0.1517


75it [00:00, 182.29it/s]


Val epoch 17 loss: 0.1825


263it [00:01, 137.13it/s]


Train epoch 18 loss: 0.1601


75it [00:00, 182.21it/s]


Val epoch 18 loss: 0.1535


263it [00:01, 146.02it/s]


Train epoch 19 loss: 0.1432


75it [00:00, 179.24it/s]


Val epoch 19 loss: 0.1879


263it [00:01, 147.63it/s]


Train epoch 20 loss: 0.1534


75it [00:00, 181.92it/s]


Val epoch 20 loss: 0.1602


38it [00:00, 182.81it/s]


Test accuracy: 95.917%


263it [00:01, 148.39it/s]


Train epoch 21 loss: 0.1473


75it [00:00, 183.35it/s]


Val epoch 21 loss: 0.1538


263it [00:01, 149.07it/s]


Train epoch 22 loss: 0.1539


75it [00:00, 182.11it/s]


Val epoch 22 loss: 0.1701


263it [00:01, 148.55it/s]


Train epoch 23 loss: 0.1479


75it [00:00, 183.68it/s]


Val epoch 23 loss: 0.1681


263it [00:01, 148.98it/s]


Train epoch 24 loss: 0.1503


75it [00:00, 183.37it/s]


Val epoch 24 loss: 0.1953


263it [00:02, 128.44it/s]


Train epoch 25 loss: 0.1714


75it [00:00, 180.27it/s]


Val epoch 25 loss: 0.1570


38it [00:00, 178.95it/s]


Test accuracy: 94.333%


263it [00:02, 128.76it/s]


Train epoch 26 loss: 0.1468


75it [00:00, 181.78it/s]


Val epoch 26 loss: 0.1983


263it [00:02, 129.36it/s]


Train epoch 27 loss: 0.1532


75it [00:00, 182.18it/s]


Val epoch 27 loss: 0.1572


263it [00:02, 129.45it/s]


Train epoch 28 loss: 0.1567


75it [00:00, 182.22it/s]


Val epoch 28 loss: 0.1810


263it [00:02, 129.48it/s]


Train epoch 29 loss: 0.1528


75it [00:00, 183.54it/s]


Val epoch 29 loss: 0.1655


263it [00:02, 129.22it/s]


Train epoch 30 loss: 0.1548


75it [00:00, 183.35it/s]


Val epoch 30 loss: 0.1769


38it [00:00, 183.35it/s]


Test accuracy: 94.667%


263it [00:02, 129.51it/s]


Train epoch 31 loss: 0.1574


75it [00:00, 184.16it/s]


Val epoch 31 loss: 0.1849


263it [00:02, 130.18it/s]


Train epoch 32 loss: 0.1487


75it [00:00, 183.73it/s]


Val epoch 32 loss: 0.1566


263it [00:02, 128.86it/s]


Train epoch 33 loss: 0.1517


75it [00:00, 179.63it/s]


Val epoch 33 loss: 0.1587


263it [00:02, 128.08it/s]


Train epoch 34 loss: 0.1517


75it [00:00, 183.64it/s]


Val epoch 34 loss: 0.1458


263it [00:02, 128.91it/s]


Train epoch 35 loss: 0.1480


75it [00:00, 181.83it/s]


Val epoch 35 loss: 0.1379


38it [00:00, 184.67it/s]


Test accuracy: 94.167%


263it [00:02, 129.35it/s]


Train epoch 36 loss: 0.1548


75it [00:00, 182.16it/s]


Val epoch 36 loss: 0.1901


263it [00:02, 129.08it/s]


Train epoch 37 loss: 0.1679


75it [00:00, 183.01it/s]


Val epoch 37 loss: 0.1655


263it [00:02, 129.41it/s]


Train epoch 38 loss: 0.1411


75it [00:00, 182.66it/s]


Val epoch 38 loss: 0.1609


263it [00:02, 127.45it/s]


Train epoch 39 loss: 0.1477


75it [00:00, 159.11it/s]


Val epoch 39 loss: 0.1825


263it [00:02, 117.31it/s]


Train epoch 40 loss: 0.1538


75it [00:00, 159.47it/s]


Val epoch 40 loss: 0.1668


38it [00:00, 159.54it/s]


Test accuracy: 93.417%


263it [00:02, 122.16it/s]


Train epoch 41 loss: 0.1515


75it [00:00, 181.75it/s]


Val epoch 41 loss: 0.1869


263it [00:02, 129.49it/s]


Train epoch 42 loss: 0.1614


75it [00:00, 184.01it/s]


Val epoch 42 loss: 0.1602


263it [00:02, 129.46it/s]


Train epoch 43 loss: 0.1391


75it [00:00, 182.19it/s]


Val epoch 43 loss: 0.1730


263it [00:02, 129.88it/s]


Train epoch 44 loss: 0.1567


75it [00:00, 181.29it/s]


Val epoch 44 loss: 0.1600


263it [00:02, 127.69it/s]


Train epoch 45 loss: 0.1257


75it [00:00, 179.72it/s]


Val epoch 45 loss: 0.1621


38it [00:00, 182.44it/s]


Test accuracy: 95.083%


263it [00:02, 128.76it/s]


Train epoch 46 loss: 0.1537


75it [00:00, 181.27it/s]


Val epoch 46 loss: 0.1684


263it [00:02, 129.26it/s]


Train epoch 47 loss: 0.1612


75it [00:00, 181.67it/s]


Val epoch 47 loss: 0.2116


263it [00:02, 129.52it/s]


Train epoch 48 loss: 0.1614


75it [00:00, 181.53it/s]


Val epoch 48 loss: 0.1788


263it [00:02, 129.55it/s]


Train epoch 49 loss: 0.1603


75it [00:00, 183.13it/s]

Val epoch 49 loss: 0.1580





In [17]:
get_accuracy(goe_model, test_loader)

38it [00:00, 107.82it/s]


tensor(95.1667, device='cuda:0')

In [22]:
from src.metrics import get_model_flops

get_model_flops(goe_model, next(iter(train_loader))[0])

TypeError: get_path() missing 1 required positional argument: 'rotation_class'

In [None]:
from matplotlib import pyplot as plt

# Sample images from each set
num_samples_per_set = 10
_, axes = plt.subplots(3, 10, figsize=(10 * 10, 10 * 3))
axes = axes.reshape((3,10))

for curr_idx, curr_dataset in enumerate((train_dataset, val_dataset, test_dataset)):
    for sample_idx in range(num_samples_per_set):
        img, rotation_label, digit_label = curr_dataset[sample_idx]
        ax = axes[curr_idx, sample_idx]
        img = img.squeeze().numpy()
        ax.imshow(img, cmap='gray')
        ax.set_title(f'Rotation: {rotation_label * 90}°, Digit: {digit_label}')
        ax.axis('off')

# Adjust layout and display
# plt.tight_layout()
plt.show()