In [1]:
import torch
import torchvision

# ToTensor() scales pixels to [0, 1]; Normalize then standardizes them with the
# commonly used MNIST training-set mean (0.1307) and standard deviation (0.3081).
# Both arguments are one-element tuples because MNIST images have a single channel.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_mnist = torchvision.datasets.MNIST(
    "./data",
    train=True,
    download=True,
    transform=mnist_transform,
)

test_mnist = torchvision.datasets.MNIST(
    "./data",
    train=False,
    download=True,
    transform=mnist_transform,
)




## Class: `OHMNISTGenerator`

This network is not a classifier but a **generator**: it learns to produce MNIST-like digit images from one-hot class codes.

### Structure
- Inherits from `torch.nn.Module`, so the layers assigned in `__init__` are registered as submodules and their parameters are exposed to PyTorch optimizers through `parameters()`.
- The model uses a small multilayer perceptron (MLP):
  - Linear layer: input size 10 → hidden size 300  
    (expects a 10-dimensional one-hot vector where each position corresponds to a digit 0–9).
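  - LeakyReLU activation: like ReLU, except that negative inputs are multiplied by a small slope (0.01 by default in PyTorch) instead of being set to zero. The gradient for negative inputs is therefore never exactly zero, which makes neurons less likely to "die" (get stuck outputting zero and receiving no gradient).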
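  - Linear layer: hidden size 300 → output size 28×28 = 784  
    (the flattened number of pixels in an MNIST image). No activation follows this layer, so the outputs are unbounded. This suits targets that `Normalize` has shifted outside [0, 1].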
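The layer sizes above determine the model's size, and the LeakyReLU description can be checked directly. This sketch rebuilds the same `Sequential` on its own. It counts the parameters (each `Linear(in, out)` holds an out×in weight matrix plus `out` biases) and evaluates LeakyReLU and its gradient at PyTorch's default slope:

```python
import torch

# The same layers as OHMNISTGenerator.mlp, rebuilt so they can be inspected alone.
mlp = torch.nn.Sequential(
    torch.nn.Linear(10, 300),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(300, 28 * 28),
)

# Linear(in, out) stores an (out x in) weight matrix plus `out` biases:
#   Linear(10, 300):  10 * 300 + 300  =   3,300
#   Linear(300, 784): 300 * 784 + 784 = 235,984
num_params = sum(p.numel() for p in mlp.parameters())
print(num_params)  # 239284

# LeakyReLU with the default negative_slope=0.01: positive inputs pass through
# unchanged, negative inputs are multiplied by 0.01 instead of being zeroed.
act = torch.nn.LeakyReLU()
z = torch.tensor([-2.0, 3.0], requires_grad=True)
out = act(z)
out.sum().backward()
print(out.detach())  # values -0.02 and 3.0
print(z.grad)        # gradient 0.01 for the negative input, 1.0 for the positive one
```

Almost all of the parameters sit in the second layer, which has to produce every one of the 784 pixels.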

### Forward pass
- Input `x`: shape (batch_size, 10), where each row is a one-hot digit vector. It must be a floating-point tensor to match the `Linear` weights.
- The MLP produces an output of shape (batch_size, 784).
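- This is reshaped into (batch_size, 28, 28), so each output is a 28×28 image. The output has no channel dimension, whereas MNIST images from `ToTensor()` have shape (1, 28, 28). Real images must therefore be squeezed (or the output unsqueezed) before the two can be compared.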
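The shapes in the forward pass can be traced on a small batch. In this sketch the class is repeated so the cell runs on its own, and the labels are arbitrary examples. `torch.nn.functional.one_hot` builds the one-hot rows, but it returns an integer tensor, so the result is cast to float before it reaches the first `Linear` layer:

```python
import torch
import torch.nn.functional as F


class OHMNISTGenerator(torch.nn.Module):
    """Same architecture as the notebook's generator, repeated so this cell runs alone."""

    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(10, 300),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(300, 28 * 28),
        )

    def forward(self, x):
        return self.mlp(x).view(x.shape[0], 28, 28)


model = OHMNISTGenerator()

# Four example digit labels turned into one-hot rows.
# F.one_hot returns an int64 tensor; Linear needs float input, so cast it.
labels = torch.tensor([0, 3, 3, 9])
x = F.one_hot(labels, num_classes=10).float()

with torch.no_grad():
    flat = model.mlp(x)   # (batch_size, 784)
    images = model(x)     # (batch_size, 28, 28)

print(x.shape)       # torch.Size([4, 10])
print(flat.shape)    # torch.Size([4, 784])
print(images.shape)  # torch.Size([4, 28, 28])
```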

### Intuition
- The network learns a mapping from **digit labels → images**.
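- Example: after training, feeding in the one-hot vector for digit "3" should produce an image resembling a 3.
- There is no random noise input, so the mapping is deterministic: a given label always yields the same image. The model can therefore learn at most one "prototype" image per digit, not many varied samples.
- This setup is essentially the reverse of a classifier: instead of predicting labels from images, it generates images from labels.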
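This section does not show how the generator is trained. The following minimal sketch assumes a pixel-wise MSE loss and the Adam optimizer with a learning rate of 1e-3; none of these choices come from the notebook. Two random tensors stand in for two different images of the digit 3. Because both samples share the same one-hot input, they receive the same output, and MSE pulls that single image toward the per-pixel average of the two targets:

```python
import torch
import torch.nn.functional as F


class OHMNISTGenerator(torch.nn.Module):
    """Same architecture as the notebook's generator, repeated so this cell runs alone."""

    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(10, 300),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(300, 28 * 28),
        )

    def forward(self, x):
        return self.mlp(x).view(x.shape[0], 28, 28)


torch.manual_seed(0)
model = OHMNISTGenerator()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed optimizer and learning rate
loss_fn = torch.nn.MSELoss()                                # assumed pixel-wise loss

# Two different random "images", both labelled 3, stand in for real MNIST samples.
# Batches from a real DataLoader would have shape (batch, 1, 28, 28) and would
# need images.squeeze(1) to match the generator's (batch, 28, 28) output.
labels = torch.tensor([3, 3])
targets = torch.rand(2, 28, 28)
x = F.one_hot(labels, num_classes=10).float()

for step in range(500):
    optimizer.zero_grad()
    loss = loss_fn(model(x), targets)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    generated = model(x)

# Identical inputs give identical outputs, so the best the model can do under MSE
# is the per-pixel average of the two targets.
gap_to_mean = (generated[0] - targets.mean(dim=0)).abs().mean().item()
print(f"final loss {loss.item():.4f}, mean |output - average target| = {gap_to_mean:.4f}")
```

On real MNIST data the same effect applies to every class, so a model trained this way tends to produce a blurry "average" digit for each label.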

### Key idea
- Input: one-hot encoded label (10 values).
- Output: generated MNIST-style image (28×28).
- Purpose: demonstrate how to build a simple generator network that produces images given a class code.
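This sketch produces one image per class code. It assumes the standard MNIST statistics (0.1307, 0.3081) were used in `Normalize`. `torch.eye(10)` supplies all ten one-hot codes at once, and inverting the normalization maps the outputs back to displayable [0, 1] intensities. The model here is untrained, so the images are noise; the point is the shapes and the post-processing:

```python
import torch


class OHMNISTGenerator(torch.nn.Module):
    """Same architecture as the notebook's generator, repeated so this cell runs alone."""

    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(10, 300),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(300, 28 * 28),
        )

    def forward(self, x):
        return self.mlp(x).view(x.shape[0], 28, 28)


MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # assumed normalization statistics

model = OHMNISTGenerator()  # untrained here; in practice, train or load weights first
model.eval()

# torch.eye(10) stacks the ten one-hot codes: row d has a 1 in column d.
codes = torch.eye(10)

with torch.no_grad():
    normalized_images = model(codes)  # (10, 28, 28), one image per digit 0-9

# Undo Normalize and clip to [0, 1] so the images can be shown as grayscale.
display_images = (normalized_images * MNIST_STD + MNIST_MEAN).clamp(0.0, 1.0)
print(display_images.shape)  # torch.Size([10, 28, 28])
```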


In [2]:
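class OHMNISTGenerator(torch.nn.Module):
    """Maps one-hot digit codes (batch_size, 10) to images (batch_size, 28, 28)."""

    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(10, 300),       # one-hot code -> 300 hidden features
            torch.nn.LeakyReLU(),           # small slope for negative inputs
            torch.nn.Linear(300, 28 * 28),  # hidden features -> 784 pixel values
        )

    def forward(self, x):
        flat_output = self.mlp(x)                     # (batch_size, 784)
        return flat_output.view(x.shape[0], 28, 28)  # (batch_size, 28, 28)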