In [1]:
# If PyTorch / torchvision are not installed locally, uncomment the next line and run it once:
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

import os, random, math, time 
from pathlib import Path 

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import DataLoader 

import torchvision 
from torchvision import datasets, transforms 

import numpy as np 
import matplotlib.pyplot as plt 

# Reproducibility: seed Python's `random`, NumPy, and PyTorch (default and
# CUDA generators) from one shared value.
SEED = 42
for _seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seeder(SEED)

# Device selection: use CUDA when PyTorch reports it as available, else the CPU.
# `device` keeps the plain string; `DEVICE` is the corresponding torch.device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
DEVICE = torch.device(device)

# Bare last expression so the notebook displays the chosen device.
DEVICE

device(type='cpu')

## Data: download, transform, and inspect
We download MNIST (70,000 images of handwritten digits — 60,000 for training and 10,000 for testing — each 28×28 grayscale) using `torchvision.datasets.MNIST`.

**Transforms used**:
* `ToTensor()`: converts PIL images to PyTorch tensors with values in [0, 1]
* `Normalize((0.1307,), (0.3081,))`: standardizes pixel values using the dataset's mean and standard deviation, for faster and more stable training

In [2]:
# Root directory handed to the MNIST dataset objects (download=True below).
DATA_DIR = Path("./data")

# Preprocessing applied to every image, in order:
#   1. ToTensor  - [H,W] image -> [1,H,W] tensor with values in [0,1]
#   2. Normalize - standardize with mean 0.1307 and std 0.3081 (MNIST stats)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Build both splits with the same settings; only the `train` flag differs.
# The generator is consumed left to right: train split first, then test split.
train_ds, test_ds = (
    datasets.MNIST(root=DATA_DIR, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Print each dataset's summary, one after the other.
print(train_ds, test_ds, sep="\n")

100%|██████████| 9.91M/9.91M [00:00<00:00, 11.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 336kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.16MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.55MB/s]

Dataset MNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )
Dataset MNIST
    Number of datapoints: 10000
    Root location: data
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )



