<a href="https://colab.research.google.com/github/data4class/handwrittendigits/blob/main/Toy_NN_with_visualization_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""

Minimal 2×2×1 fully-connected network in PyTorch (2 inputs → 2 hidden units → 1 output)

"""

import torch
import torch.nn as nn
import pandas as pd

# 1. Hyper-parameters
LR = 0.5       # big learning rate so weights visibly move
EPOCHS = 200   # enough to converge on toy data

# 2. Load toy dataset ---------------------------------------------------------
# data.csv must provide the feature columns x1, x2 and the target column y.
df = pd.read_csv('data.csv')
# DataFrame.to_numpy(dtype=...) is pandas' recommended replacement for .values;
# casting to float32 here matches PyTorch's default parameter dtype.
X = torch.tensor(df[['x1', 'x2']].to_numpy(dtype='float32'))      # (N, 2) features
y = torch.tensor(df['y'].to_numpy(dtype='float32')).view(-1, 1)   # (N, 1) targets, one column like the net's output

# 3. Define the network architecture ------------------------------------------
class TinyNet(nn.Module):
    """Fully-connected sigmoid network with layer sizes 2 → 2 → 1.

    Parameters (PyTorch stores Linear weights as (out_features, in_features)):
        hidden.weight  W1 (2×2),  hidden.bias  b1 (2)
        out.weight     W2 (1×2),  out.bias     b2 (1)
    """

    def __init__(self):
        super().__init__()
        # `hidden` is created before `out`; keep this order so a seeded run
        # draws the same initial weights.
        self.hidden = nn.Linear(2, 2)
        self.out = nn.Linear(2, 1)

    def forward(self, x):
        """Return ŷ = σ(W2·σ(W1·x + b1) + b2); x's last dimension must be 2."""
        hidden_act = self.hidden(x).sigmoid()   # h = σ(W1·x + b1)
        return self.out(hidden_act).sigmoid()   # ŷ = σ(W2·h + b2)


net = TinyNet()

# 4. Loss and optimizer -------------------------------------------------------
criterion = nn.MSELoss()   # mean-squared error between sigmoid output and target
optimizer = torch.optim.SGD(net.parameters(), lr=LR)

# 5. Training loop ------------------------------------------------------------
# All three logs have one entry per epoch, and entry i of each refers to the
# SAME parameters: loss_log[i] is the full-batch loss computed with the
# weights stored in W1_log[i] / W2_log[i]. The weights after the final update
# are the ones left in `net` when the loop ends.
loss_log = []            # for the visualizer
W1_log, W2_log = [], []  # hidden-layer and output-layer weight snapshots

for epoch in range(EPOCHS):
    # forward (full batch)
    y_hat = net(X)
    loss = criterion(y_hat, y)

    # bookkeeping -- snapshot BEFORE optimizer.step(); snapshotting after the
    # step would pair each loss with the next epoch's weights.
    loss_value = loss.item()
    loss_log.append(loss_value)
    # .numpy() shares memory with its tensor and optimizer.step() updates the
    # parameters in place, so clone to freeze this epoch's values.
    W1_log.append(net.hidden.weight.detach().clone().numpy())
    W2_log.append(net.out.weight.detach().clone().numpy())

    # backward + update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0 or epoch == EPOCHS - 1:  # also report the last epoch
        print(f'epoch {epoch:3d}  loss {loss_value:.4f}')