# MNIST Neural Network from Scratch (PyTorch)

A simple neural network for MNIST digit classification, implemented from scratch on top of PyTorch's autograd. It was built as a proof of concept before implementing it in C++. Validation accuracy is around 98%; a deeper network could probably push it higher.

## Architecture
- Input (784) → Hidden (256) → Output (10)
- ReLU + Softmax activations
- Cross-entropy loss

## Setup
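1. Place the MNIST CSV files (`mnist_train.csv`, `mnist_test.csv`, `mnist_test_short.csv`) in `data/` under the project root
2. Set `PROJECT_ROOT` in the first code cell to that project root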
3. Run notebook

In [1]:
import torch
import csv
import os


# Root directory of the project; the "data/..." and "checkpoint/..." paths used
# in the cells below are joined onto it. Machine-specific: edit before running.
PROJECT_ROOT = "/home/nlin/workspace/code/projects/autograd_cpp"

In [2]:
def load_mnist_from_file(csvfile):
    """Parse an MNIST-style CSV stream into scaled pixels and one-hot labels.

    The first row is treated as a header and discarded. In every following
    row the first column is the digit label and the remaining columns are
    pixel values, which are divided by 255 (so 0-255 inputs land in [0, 1]).
    Blank lines are ignored.

    Args:
        csvfile: An open text file, or any iterable of CSV lines that has a
            ``close()`` method. It is always closed before this function
            returns or raises.

    Returns:
        A tuple ``(xs, ys, n)``: ``xs`` holds the scaled pixel values, one
        row per sample; ``ys`` holds the matching one-hot labels with 10
        columns; ``n`` is the number of samples read.

    Raises:
        ValueError: If a cell cannot be parsed as a number.
    """
    data = []
    labels = []
    try:
        reader = csv.reader(csvfile)
        # Skip the header row; the default keeps an empty file from raising
        # StopIteration.
        next(reader, None)
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines such as a trailing
                # newline; there is no label to read from those.
                continue
            label, *pixels = (float(cell) for cell in row)
            data.append([p * 1.0 / 255 for p in pixels])
            labels.append([1 if label == digit else 0 for digit in range(10)])
    finally:
        # This function has always closed the handle it is given; do so on
        # every exit path so a parse error does not leak it.
        csvfile.close()

    n = len(data)
    xs = torch.tensor(data)
    ys = torch.tensor(labels)

    return xs, ys, n


# Read the training split and move both tensors to the GPU.
# load_mnist_from_file closes the handle itself, so it is not closed here.
train_csv_path = os.path.join(PROJECT_ROOT, "data/mnist_train.csv")
csvfile = open(train_csv_path)
xs, ys, n = load_mnist_from_file(csvfile)
xs, ys = xs.cuda(), ys.cuda()

In [3]:
# Set to True to preview the first five training images side by side.
VISUALIZE = False

import matplotlib.pyplot as plt

if VISUALIZE:
    plt.figure(figsize=(28, 28))
    # xs was moved to the GPU above, and matplotlib converts its input through
    # NumPy, which cannot read CUDA tensors. Copy just the previewed rows back
    # to host memory first (.cpu() is a no-op for tensors already on the CPU).
    for idx, pixels in enumerate(xs[:5].cpu()):
        # Each sample is a flat vector of 784 pixels; restore the 28x28 grid.
        image = torch.reshape(pixels, (28, 28))
        plt.subplot(1, 5, idx + 1)
        plt.imshow(image)
    plt.show()

In [4]:
def softmax(X):
    """Return the softmax of ``X`` taken along dimension 1.

    Args:
        X: Tensor of logits whose dimension 1 indexes the classes, e.g. of
            shape ``(batch, classes)``.

    Returns:
        A tensor with the same shape as ``X`` whose entries along
        dimension 1 are non-negative and sum to 1.
    """
    # Softmax is unchanged by subtracting a per-row constant. Shifting by the
    # row maximum makes the largest exponent exactly 0, so exp() cannot
    # overflow to inf and turn the division into inf / inf = NaN.
    row_max = X.max(dim=1, keepdim=True).values.detach()
    eX = torch.exp(X - row_max)
    # keepdim=True leaves the sums as a column so they broadcast across rows.
    return eX / torch.sum(eX, dim=1, keepdim=True)

In [5]:
# Layer weights, drawn uniformly from [-0.5, 0.5) on the CPU and then moved to
# the GPU. requires_grad_() is applied last, after the arithmetic and the
# device copy, so W1 and W2 are leaf tensors and backward() fills their .grad.
W1 = (torch.rand((784, 256)) - 0.5).cuda().requires_grad_()
W2 = (torch.rand((256, 10)) - 0.5).cuda().requires_grad_()
assert W1.is_leaf
assert W2.is_leaf

In [6]:
# Read the short test split, used for the periodic validation check during
# training, and move both tensors to the GPU. load_mnist_from_file closes the
# handle itself, so it is not closed here.
test_short_csv_path = os.path.join(PROJECT_ROOT, "data/mnist_test_short.csv")
tcsvfile = open(test_short_csv_path)
txs, tys, tn = load_mnist_from_file(tcsvfile)
txs, tys = txs.cuda(), tys.cuda()

In [None]:
# Mini-batch SGD hyperparameters.
BATCH_SIZE = 128
NUM_ITER = 100000
learning_rate = 1e-3

for it in range(NUM_ITER):
    # Sample a mini-batch: the first BATCH_SIZE entries of a random
    # permutation of the training indices.
    batch_indices = torch.randperm(n)[:BATCH_SIZE]
    batch_xs = xs[batch_indices]
    batch_ys = ys[batch_indices]

    # Forward pass: linear -> ReLU -> linear -> softmax.
    A1 = torch.mm(batch_xs, W1)
    Z1 = torch.where(A1 > 0, A1, 0)
    A2 = torch.mm(Z1, W2)
    Z2 = softmax(A2)

    # Cross-entropy summed over the batch. batch_ys is one-hot, so the
    # product keeps only the log-probability of each sample's true class.
    loss = -torch.sum(
        torch.log(Z2 + 1e-10) * batch_ys
    )  # Added small constant for numerical stability
    loss.backward()

    # Plain SGD step. The in-place updates run under no_grad so they are not
    # recorded by autograd, and the gradients are cleared afterwards because
    # backward() accumulates into .grad.
    with torch.no_grad():
        W1 -= learning_rate * W1.grad
        W2 -= learning_rate * W2.grad
        W1.grad.zero_()
        W2.grad.zero_()

        assert W1.is_leaf
        assert W2.is_leaf

    if it % 1000 == 0:
        # Batch accuracy is only needed when it is logged, so compute it here
        # rather than on every iteration.
        y_pred = torch.argmax(Z2, dim=1)
        y_true = torch.argmax(batch_ys, dim=1)
        print(
            f"Iteration {it}; Loss: {loss.cpu().float()}, Accuracy: {(torch.sum(y_pred.cpu() == y_true.cpu()) / BATCH_SIZE).float()}"
        )

    if it % 10000 == 0:
        # Evaluation needs no gradients; no_grad avoids building an autograd
        # graph over the whole held-out set.
        with torch.no_grad():
            t_A1 = torch.mm(txs, W1)
            t_Z1 = torch.where(t_A1 > 0, t_A1, 0)
            t_A2 = torch.mm(t_Z1, W2)
            t_Z2 = softmax(t_A2)

            t_y_pred = torch.argmax(t_Z2, dim=1)
            t_y_true = torch.argmax(tys, dim=1)
        print(
            "Validation Accuracy: ",
            (torch.sum(t_y_pred.cpu() == t_y_true.cpu()) / tn).float().item(),
        )

In [9]:
# Persist the trained weights as a (W1, W2) tuple of CPU tensors.
checkpoint_path = os.path.join(PROJECT_ROOT, "checkpoint/mnist_torch_ckpt_final.pt")
# torch.save does not create missing parent directories; make sure the target
# directory exists so a long training run is not lost at the final step.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save((W1.cpu(), W2.cpu()), checkpoint_path)

In [None]:
# Final evaluation on the full test split.
ltcsvfile = open(os.path.join(PROJECT_ROOT, "data/mnist_test.csv"))
ltxs, ltys, ltn = load_mnist_from_file(ltcsvfile)
ltxs = ltxs.cuda()
ltys = ltys.cuda()

# Evaluate on the tensors loaded just above (ltxs / ltys / ltn), not on the
# short split (txs / tys / tn) used for the periodic check during training.
# no_grad skips building an autograd graph for this inference-only pass.
with torch.no_grad():
    lt_A1 = torch.mm(ltxs, W1)
    lt_Z1 = torch.where(lt_A1 > 0, lt_A1, 0)
    lt_A2 = torch.mm(lt_Z1, W2)
    lt_Z2 = softmax(lt_A2)

    lt_y_pred = torch.argmax(lt_Z2, dim=1)
    lt_y_true = torch.argmax(ltys, dim=1)
print(
    "Validation Accuracy: ",
    (torch.sum(lt_y_pred.cpu() == lt_y_true.cpu()) / ltn).float().item(),
)