# MNIST Dataset


In [None]:
import torch
import numpy as np
import torchvision
from sklearn.model_selection import train_test_split

# Record which PyTorch build produced the outputs below.
print("torch:", torch.__version__)

In [None]:
# Load the official MNIST train and test sets (downloaded into ./data if absent).
# NOTE: the `transform` is only applied when items are fetched by index
# (e.g. through a DataLoader). The raw `.data` / `.targets` tensors used in this
# cell bypass it, so pixel values here remain unscaled 0-255 intensities.
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=torchvision.transforms.ToTensor())

# Pool both official splits, then re-split below. Images are cast to float32.
# Labels keep their integer dtype: they are class indices, not real values,
# so casting them to float only forces a conversion back later.
X = torch.cat([train_set.data, test_set.data], dim=0).float()
y = torch.cat([train_set.targets, test_set.targets], dim=0)

# Flatten everything after the sample axis into one feature vector per image
# so the data can be fed to a linear model.
X = X.view(X.shape[0], -1)

# Train/val/test split: 70/15/15. `stratify` keeps the class proportions
# identical across splits; `random_state` makes the split reproducible.
X_train, X_tmp, y_train, y_tmp = train_test_split(
    X.numpy(), y.numpy(), test_size=0.30, random_state=42, stratify=y.numpy()
)
X_val, X_test, y_val, y_test = train_test_split(
    X_tmp, y_tmp, test_size=0.50, random_state=42, stratify=y_tmp
)

# Print dimensions
print("Train X:", X_train.shape, "Train y:", y_train.shape)
print("Val   X:", X_val.shape,   "Val   y:", y_val.shape)
print("Test  X:", X_test.shape,  "Test  y:", y_test.shape)

## Train Without Regularization

In [None]:
# Full-batch gradient descent for a linear model with a softmax / cross-entropy
# objective on the training split. No regularization term is added to the loss.
X_train_t = torch.tensor(X_train, dtype=torch.float32)
# Kept as an (N, 1) float tensor for compatibility with any later cells.
y_train_t = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
# Integer class indices of shape (N,), the target form CrossEntropyLoss expects.
# Built once here instead of being re-converted on every epoch. Unlike
# `squeeze()`, `view(-1)` also stays 1-D when there is only one sample.
y_train_idx = y_train_t.long().view(-1)

# MNIST has 10 classes
n_input = X_train.shape[1]
n_output = 10

# Zero initialization is acceptable here: the model is linear, so there is no
# hidden-unit symmetry that random initialization would need to break.
w = torch.zeros((n_input, n_output), requires_grad=True)
b = torch.zeros(n_output, requires_grad=True)

lr = 1e-4
epochs = 100
# Takes raw (unnormalized) logits and applies log-softmax internally.
loss_fn = torch.nn.CrossEntropyLoss()

for i in range(epochs):
    y_pred = X_train_t @ w + b
    loss = loss_fn(y_pred, y_train_idx)
    loss.backward()
    # Manual SGD step. It runs under no_grad so the in-place parameter update
    # is not recorded by autograd. Gradients are reset because .backward()
    # accumulates into .grad.
    with torch.no_grad():
        w -= lr * w.grad
        b -= lr * b.grad
        w.grad.zero_()
        b.grad.zero_()
    if (i+1) % 10 == 0:
        print(f"Epoch {i+1}, Loss: {loss.item():.4f}")

# `loss` was computed before the last parameter update, so this is the loss
# of the parameters as they were at the start of the final epoch.
print("Final loss:", loss.item())

## Test Loss

In [None]:
# Evaluate the trained linear model (w, b) on the held-out test split.
X_test_t = torch.tensor(X_test, dtype=torch.float32)
# Kept as an (N, 1) float tensor for compatibility with any later cells.
y_test_t = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)
# Integer class indices of shape (N,). `view(-1)` stays 1-D even for N == 1,
# where `squeeze()` would collapse the target to a 0-d tensor.
y_test_idx = y_test_t.long().view(-1)

with torch.no_grad():
    y_pred_test = X_test_t @ w + b
    test_loss = loss_fn(y_pred_test, y_test_idx)
    # Fraction of test samples whose highest-scoring class equals the label.
    test_acc = (y_pred_test.argmax(dim=1) == y_test_idx).float().mean()

print("Test loss:", test_loss.item())
print("Test accuracy:", test_acc.item())