In [None]:
import torchvision
from torchvision import transforms
#from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt # imshow might be faster?
import numpy as np
import math


# ToTensor turns each image into a float tensor (C, H, W) scaled to [0, 1];
# it is the only preprocessing applied.
transform = transforms.Compose([transforms.ToTensor()])

# EMNIST "balanced" training split, stored under ./data and fetched on the
# first run (download=True).
original_dataset_EMNIST = torchvision.datasets.EMNIST(
    './data',
    'balanced',
    train=True,
    download=True,
    transform=transform,
)

# Class index -> character for the EMNIST "balanced" split (47 classes):
# the digits 0-9, the 26 uppercase letters, then the 11 lowercase letters
# that have a class of their own. (Presumably the remaining lowercase letters
# are merged with their uppercase form in this split - confirm against the
# EMNIST mapping file.)
balanced_dict = dict(enumerate(
    '0123456789'
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    'abdefghnqrt'
))

def _collect_class_samples(dataset, num_classes):
    """Return ``{label: image}`` with the first image found for each label.

    The scan starts at ``len(dataset) // 2`` and wraps around to index 0, so
    every item is visited at most once. Only labels in ``range(num_classes)``
    are kept, and the scan stops as soon as all of them have been found.
    """
    samples = {}
    total = len(dataset)
    start = total // 2
    for offset in range(total):
        image, label = dataset[(start + offset) % total]
        if 0 <= label < num_classes and label not in samples:
            samples[label] = image
            if len(samples) == num_classes:
                break
    return samples


def visualize_emnist_samples(dataset, *, num_classes: int = 47, ncols: int = 7):
    """Plot one example image per class of an EMNIST-style dataset in a grid.

    Samples are looked up starting from the middle of ``dataset`` and wrapping
    around to the beginning, so a class that only occurs in the first half is
    still shown. A class with no sample at all gets a blank, labelled cell
    instead of raising ``KeyError``.

    Args:
        dataset: Indexable, sized collection of ``(image, label)`` pairs.
            ``image`` is assumed to be a ``(1, H, W)`` tensor (as produced by
            ``transforms.ToTensor()`` on a grayscale image - TODO confirm for
            other datasets); ``label`` is an int class index.
        num_classes: Number of classes to display (EMNIST "balanced" has 47).
        ncols: Number of grid columns; the row count is derived from it.
    """
    class_samples = _collect_class_samples(dataset, num_classes)

    nrows = math.ceil(num_classes / ncols)
    # squeeze=False keeps ``axes`` 2-D even for a single-row/column grid.
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False
    )
    fig.suptitle('EMNIST - One Sample per Class')

    # Plot the samples in class order.
    for label in range(num_classes):
        ax = axes[label // ncols, label % ncols]
        name = balanced_dict.get(label, str(label))
        image = class_samples.get(label)
        if image is not None:
            # Swap H and W: identical to a 90-degree clockwise rotation
            # followed by a horizontal flip. Presumably EMNIST stores its
            # images transposed relative to the usual reading orientation.
            ax.imshow(image.transpose(1, 2).squeeze(), cmap='gray')
            ax.set_title(f'Label: {name}')
        else:
            ax.set_title(f'Label: {name} (not found)')
        ax.axis('off')

    # Remove the unused cells at the end of the grid.
    for i in range(num_classes, nrows * ncols):
        fig.delaxes(axes[i // ncols, i % ncols])

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.3)
    plt.show()

# Render one example image for each class of the loaded training split.
visualize_emnist_samples(dataset=original_dataset_EMNIST)