In [None]:
# Enable IPython's autoreload so edits to imported project modules (e.g. src.data)
# are picked up before each cell executes, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch.utils.data import random_split

from src.data import RotatedMNISTDataset, CelebADataset

# Use the GPU when PyTorch reports one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Instantiate the CelebA dataset using CelebADataset's default constructor arguments.
# NOTE(review): celeba_dataset is not referenced by any cell visible here (the split and
# plotting cells below use RotatedMNISTDataset). Confirm a later cell needs it; if not,
# drop this cell, since constructing a dataset may load or download data.
celeba_dataset = CelebADataset()

In [None]:
# Build the Rotated MNIST dataset and split it into train / val / test subsets.
dataset = RotatedMNISTDataset()

# Split fractions and seed are named so they can be tuned in one place.
# The test split receives whatever remains after train and val.
TRAIN_FRACTION = 0.7
VAL_FRACTION = 0.2
SPLIT_SEED = 40

dataset_size = len(dataset)
# int() truncates, so any rounding remainder lands in the test split and the
# three sizes always sum exactly to dataset_size (as random_split requires).
train_size = int(TRAIN_FRACTION * dataset_size)
val_size = int(VAL_FRACTION * dataset_size)
test_size = dataset_size - train_size - val_size

# A dedicated, seeded generator makes the split reproducible across runs
# without depending on torch's global RNG state.
# NOTE(review): if RotatedMNISTDataset stores several rotated copies of the same
# underlying MNIST image as separate items, a per-item random split can put copies
# of one image in both train and test. Confirm how the dataset is indexed.
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

assert len(train_dataset) + len(val_dataset) + len(test_dataset) == dataset_size

In [None]:
from matplotlib import pyplot as plt

# Show the first few samples of each split so the rotation/digit labels can be eyeballed.
num_samples_per_set = 10
splits = {'Train': train_dataset, 'Val': val_dataset, 'Test': test_dataset}

# Grid size is derived from the split count and num_samples_per_set instead of being
# hard-coded. The per-panel figsize keeps the figure a sane size (the previous
# 100x30-inch figure was far larger than needed). squeeze=False guarantees a 2-D
# axes array, so axes[row, col] indexing works even for a single row or column.
fig, axes = plt.subplots(
    len(splits),
    num_samples_per_set,
    figsize=(2 * num_samples_per_set, 2.5 * len(splits)),
    squeeze=False,
)

for row, (split_name, split_dataset) in enumerate(splits.items()):
    for col in range(num_samples_per_set):
        ax = axes[row, col]
        ax.axis('off')
        # A split can hold fewer items than num_samples_per_set; leave extra panels blank
        # instead of raising IndexError.
        if col >= len(split_dataset):
            continue
        img, rotation_label, digit_label = split_dataset[col]
        ax.imshow(img.squeeze().numpy(), cmap='gray')
        # NOTE(review): assumes rotation_label is a class index k meaning k * 90 degrees,
        # as the original title did. Confirm against RotatedMNISTDataset.
        ax.set_title(
            f'{split_name}\nRotation: {rotation_label * 90}°, Digit: {digit_label}',
            fontsize=8,
        )

fig.tight_layout()
plt.show()