In [1]:
import torch 
import torchvision 

train_mnist = torchvision.datasets.MNIST(
    "./data",
    train=True,
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # Decimal points, not commas: one mean and one std for the single grayscale channel
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ])
)

test_mnist = torchvision.datasets.MNIST(
    "./data",
    train=False,
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        # Same normalization as the training set
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ])
)




## Class: `OHMNISTGenerator`

This network is not a classifier but a **generator**: it learns to produce MNIST-like digit images from one-hot class codes.

### Structure
- Inherits from `torch.nn.Module`, which allows it to define layers and be trained with PyTorch.
- The model uses a small multilayer perceptron (MLP):
  - Linear layer: input size 10 → hidden size 300  
    (expects a 10-dimensional one-hot vector where each position corresponds to a digit 0–9).
  - LeakyReLU activation: like ReLU, but with a small slope for negative inputs, so the gradient never becomes exactly zero and neurons are less likely to "die" (get stuck outputting zero).
  - Linear layer: hidden size 300 → output size 28×28 = 784  
    (the flattened number of pixels in an MNIST image).

### Forward pass
- Input `x`: shape (batch_size, 10), each row is a one-hot digit vector.
- The MLP produces an output of shape (batch_size, 784).
- This is reshaped into (batch_size, 28, 28), so each output is a 28×28 image.
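
The structure and forward pass described above can be sketched as follows. This is a minimal illustration based only on the description in this section, not necessarily the exact class definition; the attribute name `mlp` and the `hidden_size` argument are assumptions.

```python
import torch


class OHMNISTGenerator(torch.nn.Module):
    """Sketch of the generator described above: one-hot code -> 28x28 image."""

    def __init__(self, hidden_size=300):  # exposing hidden_size is an assumption
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(10, hidden_size),       # one-hot code -> hidden features
            torch.nn.LeakyReLU(),                   # default negative slope of 0.01
            torch.nn.Linear(hidden_size, 28 * 28),  # hidden features -> 784 pixels
        )

    def forward(self, x):
        # x: (batch_size, 10) -> (batch_size, 784) -> (batch_size, 28, 28)
        return self.mlp(x).reshape(-1, 28, 28)


generator = OHMNISTGenerator()
codes = torch.eye(10)      # one one-hot row per digit 0-9
images = generator(codes)
print(images.shape)        # torch.Size([10, 28, 28])
```

Before training, the outputs are just noise; the shape check only confirms that the layers and the reshape fit together.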

### Intuition
- The network learns a mapping from **digit labels → images**.
- Example: input one-hot for digit "3" → the output should be an image resembling the digit 3.
- This setup is essentially the reverse of a classifier: instead of predicting labels from images, it generates images from labels.

### Key idea
- Input: one-hot encoded label (10 values).
- Output: generated MNIST-style image (28×28).
- Purpose: demonstrate how to build a simple generator network that produces images given a class code.


## Custom Dataset: `OneHotMNIST`

This wrapper class takes the standard MNIST dataset and modifies how each sample is returned.

### Purpose
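- In the original MNIST dataset (with the `ToTensor()` transform applied), a sample is `(image, label)`, where:
  - `image` has shape (1, 28, 28) (channel × height × width).
  - `label` is an integer in the range 0–9.
- `OneHotMNIST` converts the label into a **one-hot vector**, drops the image's channel dimension, and returns the pair in reversed order, `(one_hot_label, image)`, so the label code comes first as the generator's input and the image follows as its target.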

### How it works
- `__len__`: returns the number of samples in the dataset (same as the base MNIST).
- `__getitem__(idx)`:
  - Retrieves `(img, cls)` from the original dataset.
  - Creates a zero vector of length 10.
  - Sets the position corresponding to the digit class to 1 → this is the one-hot encoding.
    - Example: if `cls = 7`, the one-hot vector is [0, 0, 0, 0, 0, 0, 0, 1, 0, 0].
  - Returns `(one_hot_label, image)`, where `image` is `img[0, :, :]` with shape (28, 28).
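
The one-hot step can be tried on its own. The sketch below builds the vector for `cls = 7` by hand, as described in the steps above, and compares it with `torch.nn.functional.one_hot`, a built-in alternative that the class itself does not use.

```python
import torch
import torch.nn.functional as F

cls = 7

# Manual construction, as in the steps above
oh = torch.zeros(10)
oh[cls] = 1

# Built-in alternative; one_hot returns int64, so cast to float to match
oh_builtin = F.one_hot(torch.tensor(cls), num_classes=10).float()

print(oh)
print(torch.equal(oh, oh_builtin))  # True
```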

### Why `img[0, :, :]`?
- Original MNIST images are shaped (1, 28, 28).
  - The first dimension is the **channel** (MNIST is grayscale → 1 channel).
- Taking `img[0, :, :]` selects that single channel and drops the redundant dimension, giving a simpler (28, 28) tensor.
- This makes sense here because we know MNIST always has one channel.

### Important note
- If you were working with RGB images (3 channels), you would **not** drop the channel dimension.
- Keeping or removing the channel depends on how you plan to feed the images into your model.
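
A quick shape check illustrates both points. The tensors here are random stand-ins rather than real images, and the 32×32 RGB size is an arbitrary assumption.

```python
import torch

gray = torch.rand(1, 28, 28)   # stand-in for one MNIST image after ToTensor()
flat = gray[0, :, :]
print(flat.shape)                          # torch.Size([28, 28])
print(torch.equal(flat, gray.squeeze(0)))  # True: squeeze(0) gives the same result

rgb = torch.rand(3, 32, 32)    # hypothetical RGB image
red_only = rgb[0, :, :]
print(red_only.shape)          # torch.Size([32, 32]): green and blue are discarded
```

For a single-channel image, indexing with `0` loses nothing; for an RGB image, the same indexing silently throws away two thirds of the data.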


In [2]:
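class OneHotMNIST(torch.utils.data.Dataset):
    """Wraps an MNIST dataset and returns (one-hot label, 28x28 image) pairs."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, cls = self.dataset[idx]  # img: (1, 28, 28), cls: int in 0-9
        oh = torch.zeros(10)          # one slot per digit class
        oh[cls] = 1                   # mark the true class
        return oh, img[0, :, :]       # drop the channel dimension -> (28, 28)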
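
To see what batches from this wrapper look like, here is a small sketch that avoids the MNIST download by using a hypothetical stand-in dataset, `fake_mnist`, made of random tensors. The class is repeated so the block runs on its own, and the batch size of 8 is arbitrary.

```python
import torch


class OneHotMNIST(torch.utils.data.Dataset):
    # Repeated from the cell above so this block is self-contained
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, cls = self.dataset[idx]
        oh = torch.zeros(10)
        oh[cls] = 1
        return oh, img[0, :, :]


# Hypothetical stand-in for MNIST: 32 random (1, 28, 28) images with labels 0-9
fake_mnist = [(torch.rand(1, 28, 28), i % 10) for i in range(32)]

loader = torch.utils.data.DataLoader(OneHotMNIST(fake_mnist), batch_size=8)
codes, images = next(iter(loader))
print(codes.shape)   # torch.Size([8, 10])
print(images.shape)  # torch.Size([8, 28, 28])
```

These shapes match what the generator expects as input, (batch_size, 10), and what it produces as output, (batch_size, 28, 28). With the real data, `fake_mnist` would be replaced by `train_mnist` or `test_mnist`.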
        