In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# -- Reproducibility -----------------------------------------------------------
# Seed both generators up front so the noise drawn below (and any later random
# initialisation) is repeatable from run to run.
torch.manual_seed(42)
np.random.seed(42)

# -- 1. Dataset ----------------------------------------------------------------
# Few, noisy training points versus a dense, noise-free validation grid over the
# same interval [-3, 3]; the tiny training set is what makes overfitting easy.
N_TRAIN = 30
N_VAL   = 200

# Training inputs as an (N_TRAIN, 1) column; targets are sin(x) plus Gaussian
# noise scaled by 0.3.
X_train = torch.linspace(-3.0, 3.0, steps=N_TRAIN).reshape(-1, 1)
y_train = X_train.sin() + torch.randn_like(X_train) * 0.3

# Validation inputs as an (N_VAL, 1) column; targets are the clean sine values.
X_val = torch.linspace(-3.0, 3.0, steps=N_VAL).reshape(-1, 1)
y_val = X_val.sin()

# -- 2. Deliberately Oversized Model -------------------------------------------
# 5 hidden layers x 512 units trained on 30 data points
# This extreme capacity mismatch guarantees overfitting
class OverparameterizedNetwork(nn.Module):
    """Fully connected regressor: 1 input -> five 512-unit ReLU layers -> 1 output.

    The layers live in ``self.net`` (an ``nn.Sequential``): a ``Linear`` + ``ReLU``
    pair for each of the five hidden layers, followed by a final ``Linear``
    projection down to a single value.
    """

    def __init__(self):
        super().__init__()
        # Feature widths seen by consecutive Linear layers before the output head.
        widths = [1, 512, 512, 512, 512, 512]

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        # Output head: no activation after the final projection.
        stack.append(nn.Linear(widths[-1], 1))

        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Pass ``x`` (last dimension of size 1) through the layer stack."""
        out = self.net(x)
        return out

# Model, loss and optimiser. The model is built first so its random weight
# initialisation draws from the seeded generator at the same point as before.
model     = OverparameterizedNetwork()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# -- 3. Training Loop ----------------------------------------------------------
# Full-batch training: every epoch is a single Adam step on all training points.
EPOCHS = 3000
train_losses = []
val_losses = []

for epoch in range(EPOCHS):
    # --- optimisation step on the noisy training data ---
    model.train()
    optimizer.zero_grad()
    train_pred = model(X_train)
    loss = criterion(train_pred, y_train)
    loss.backward()
    optimizer.step()
    # `loss` was computed before the step, so this records the pre-update value.
    train_losses.append(loss.item())

    # --- validation against the clean targets, using the updated weights ---
    model.eval()
    with torch.no_grad():
        val_pred = model(X_val)
        val_loss = criterion(val_pred, y_val)
    val_losses.append(val_loss.item())

# -- 4. Visualization ----------------------------------------------------------
# Two-panel dark-themed figure: loss histories on the left, the fitted function
# against the noise-free sine on the right. The figure is saved to
# "overfitting_demo.png" and then displayed.
fig = plt.figure(figsize=(14, 6), facecolor="#0d0d0d")
gs  = gridspec.GridSpec(1, 2, figure=fig, wspace=0.35)

# --- Left panel: loss curves ---
ax1 = fig.add_subplot(gs[0])
ax1.set_facecolor("#0d0d0d")

# Epochs are drawn 1-based: entry i of a loss history sits at x = i + 1.
epochs_range = range(1, EPOCHS + 1)
ax1.plot(epochs_range, train_losses, color="#00e5ff", linewidth=2, label="Training Loss")
ax1.plot(epochs_range, val_losses,   color="#ff1744", linewidth=2, label="Validation Loss")
ax1.set_yscale("log")
ax1.set_xlabel("Epochs", color="white", fontsize=12)
ax1.set_ylabel("MSE Loss (log scale)", color="white", fontsize=12)
ax1.set_title("Loss Curves\nClassic Overfitting Signature", color="white", fontsize=13, pad=12)
ax1.tick_params(colors="white")
ax1.spines[:].set_color("#444")

# `best_val_epoch` is the 0-based index of the smallest validation loss. Map it
# through `epochs_range` so the marker (and its label) land on the same 1-based
# x coordinate at which that loss value is actually plotted.
best_val_epoch = int(np.argmin(val_losses))
best_val_x = epochs_range[best_val_epoch]
ax1.axvline(best_val_x, color="#ffea00", linestyle="--", linewidth=1.5,
            label=f"Validation minimum (epoch {best_val_x})")

# Arrow target at roughly 70% of training; x and y are taken from the same
# history index so the arrow tip sits exactly on the validation curve.
annot_idx = int(EPOCHS * 0.7)
ax1.annotate("Training loss continues to decrease\nValidation loss diverges upward",
             xy=(epochs_range[annot_idx], val_losses[annot_idx]),
             xytext=(EPOCHS * 0.35, max(val_losses) * 0.6),
             color="#ff1744", fontsize=10,
             arrowprops=dict(arrowstyle="->", color="#ff1744"))
ax1.legend(fontsize=10, facecolor="#1a1a1a", labelcolor="white")

# --- Right panel: prediction vs ground truth ---
ax2 = fig.add_subplot(gs[1])
ax2.set_facecolor("#0d0d0d")

# Dense grid over the training interval for a smooth prediction curve.
X_plot = torch.linspace(-3, 3, 400).unsqueeze(1)
model.eval()
with torch.no_grad():
    y_pred = model(X_plot).squeeze().numpy()

# Convert to NumPy once, explicitly, instead of handing torch tensors to
# np.sin / matplotlib and relying on implicit tensor-to-array coercion.
x_plot_np = X_plot.squeeze().numpy()
ax2.plot(x_plot_np, np.sin(x_plot_np), color="#00e5ff",
         linewidth=2.5, label="Ground Truth (sine function)", zorder=3)
ax2.plot(x_plot_np, y_pred, color="#ff1744",
         linewidth=2.5, linestyle="--", label="Model Prediction", zorder=3)
ax2.scatter(X_train.squeeze().numpy(), y_train.squeeze().numpy(), color="#ffea00",
            s=60, zorder=5, label=f"Training Data (N={N_TRAIN})")

ax2.set_xlabel("x", color="white", fontsize=12)
ax2.set_ylabel("y", color="white", fontsize=12)
ax2.set_title("Model Fit vs Ground Truth\nThe model memorizes rather than generalizes",
              color="white", fontsize=13, pad=12)
ax2.tick_params(colors="white")
ax2.spines[:].set_color("#444")
ax2.legend(fontsize=10, facecolor="#1a1a1a", labelcolor="white")

# -- Global title --------------------------------------------------------------
fig.suptitle("SEVERE OVERFITTING\n5-layer x 512-unit network trained on 30 data points",
             color="#ffea00", fontsize=15, fontweight="bold", y=1.02)

# Save before show(): some backends clear the figure once the window is closed.
plt.savefig("overfitting_demo.png", dpi=150, bbox_inches="tight", facecolor="#0d0d0d")
print("Figure saved as overfitting_demo.png")
plt.show()