# In [None]:
# CSCI-485 — Assignment 3

import random, numpy as np
import torch, torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA


# Experiment hyper-parameters.
SEED = 0             # one seed shared by Python, NumPy and PyTorch
EPOCHS = 15          # training epochs per latent size
BATCH = 128          # batch size used by every DataLoader
LATS = (16, 32, 64)  # latent sizes swept in main()
LR = 1e-3            # Adam learning rate

# Seed every RNG source the script touches, in a fixed order.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 1) data: ToTensor scales pixels to [0, 1]; train/val split + test set
tf = transforms.Compose([transforms.ToTensor()])


def make_loaders(root="./data", val_frac=0.1, batch_size=BATCH, generator=None):
    """Build FashionMNIST train / validation / test DataLoaders.

    The torchvision training split is divided into train and validation
    subsets with ``random_split``. The torchvision test split is used
    unchanged. Every dataset uses the module-level transform ``tf``.

    Args:
        root: directory where torchvision stores (and downloads) the data.
        val_frac: fraction of the training split held out for validation.
            The validation size is ``int(val_frac * len(train_split))``.
        batch_size: batch size for all three loaders.
        generator: optional ``torch.Generator`` used only for the
            train/val split, making it reproducible regardless of how much
            of the global RNG has been consumed. ``None`` (the default)
            keeps the original behaviour of using PyTorch's global RNG.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles.

    Raises:
        ValueError: if ``val_frac`` is not strictly between 0 and 1. An
            empty validation set would make every validation pass divide
            by zero.
    """
    if not 0.0 < val_frac < 1.0:
        raise ValueError(f"val_frac must be in (0, 1), got {val_frac!r}")

    train_full = datasets.FashionMNIST(root, train=True, download=True, transform=tf)
    test_ds = datasets.FashionMNIST(root, train=False, download=True, transform=tf)

    n = len(train_full)
    n_val = int(val_frac * n)
    # Only pass `generator` when given, so the default path is byte-for-byte
    # the original call (global RNG).
    split_kwargs = {} if generator is None else {"generator": generator}
    train_ds, val_ds = random_split(train_full, [n - n_val, n_val], **split_kwargs)

    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True),
        DataLoader(val_ds, batch_size=batch_size, shuffle=False),
        DataLoader(test_ds, batch_size=batch_size, shuffle=False),
    )

# 2) model: FC AE, ReLU hidden, Sigmoid output
class AE(nn.Module):
    """Fully-connected autoencoder for 28x28 single-channel images.

    Encoder: Flatten -> Linear(784, hidden) -> ReLU -> Linear(hidden, latent)
    Decoder: Linear(latent, hidden) -> ReLU -> Linear(hidden, 784) -> Sigmoid

    ``forward`` returns ``(reconstruction, code)``. The reconstruction is
    reshaped to ``(-1, 1, 28, 28)`` and the code has shape
    ``(batch, latent)``.
    """

    def __init__(self, latent, hidden=256):
        """Create the encoder/decoder pair.

        Args:
            latent: size of the bottleneck code.
            hidden: width of the single hidden layer on each side. The
                default of 256 matches the original fixed architecture.

        Raises:
            ValueError: if ``latent`` or ``hidden`` is smaller than 1.
        """
        super().__init__()
        if latent < 1 or hidden < 1:
            raise ValueError(f"latent and hidden must be >= 1, got {latent!r}, {hidden!r}")
        # Layer construction order is kept identical so that, for a given
        # RNG state, the parameter initialisation matches the original.
        self.enc = nn.Sequential(
            nn.Flatten(), nn.Linear(784, hidden), nn.ReLU(), nn.Linear(hidden, latent)
        )
        self.dec = nn.Sequential(
            nn.Linear(latent, hidden), nn.ReLU(), nn.Linear(hidden, 784), nn.Sigmoid()
        )

    def forward(self, x):
        """Encode ``x`` and decode it back.

        Returns:
            ``(y, z)``. ``y`` is the Sigmoid reconstruction viewed as
            ``(-1, 1, 28, 28)``. ``z`` is the latent code.
        """
        z = self.enc(x)
        y = self.dec(z)
        return y.view(-1, 1, 28, 28), z

# Reconstruction criterion shared by training and evaluation: squared error
# averaged over every element of the batch.
crit = nn.MSELoss(reduction="mean")

def train_epoch(m, loader, opt):
    """Run one optimisation epoch of the autoencoder ``m`` over ``loader``.

    Each batch ``x`` is moved to ``device`` and reconstructed by ``m``. The
    ``crit`` loss against ``x`` is back-propagated and ``opt`` takes one
    step. Labels yielded by the loader are ignored.

    Returns:
        The mean training loss per sample. Each batch's loss is weighted by
        its batch size, so a short final batch does not count as much as
        a full one.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    m.train()
    total, count = 0.0, 0
    for x, _ in loader:
        x = x.to(device)
        opt.zero_grad()
        y, _ = m(x)
        loss = crit(y, x)
        loss.backward()
        opt.step()
        batch_size = x.size(0)
        total += loss.item() * batch_size
        count += batch_size
    if count == 0:
        raise ValueError("train_epoch: loader yielded no batches")
    return total / count

@torch.no_grad()
def eval_epoch(m, loader):
    """Evaluate the reconstruction loss of ``m`` over ``loader`` without gradients.

    ``crit`` averages over the elements of a batch. Weighting each batch
    mean by its batch size therefore makes the return value the mean over
    every element of every sample. The previous plain average of batch
    means over-weighted the short final batch, so the reported val/test
    MSE was slightly off.

    Returns:
        The per-sample (and per-element) mean loss over the whole loader.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    m.eval()
    total, count = 0.0, 0
    for x, _ in loader:
        x = x.to(device)
        y, _ = m(x)
        batch_size = x.size(0)
        total += crit(y, x).item() * batch_size
        count += batch_size
    if count == 0:
        raise ValueError("eval_epoch: loader yielded no batches")
    return total / count

@torch.no_grad()
def collect_latents(m, loader, max_batches=8):
    """Encode the first ``max_batches`` batches of ``loader`` with ``m``.

    Args:
        m: model whose forward returns ``(reconstruction, code)``.
        loader: iterable of ``(inputs, labels)`` batches.
        max_batches: number of batches to encode (default 8). ``None``
            encodes the whole loader. ``0`` now really means zero batches
            and raises below; previously it still collected one batch.

    Returns:
        ``(Z, Y)``: latent codes and labels as NumPy arrays, concatenated
        along axis 0.

    Raises:
        ValueError: if no batches were collected, or (from ``islice``) if
            ``max_batches`` is negative.
    """
    from itertools import islice

    m.eval()
    codes, labels = [], []
    for x, y in islice(loader, max_batches):
        _, z = m(x.to(device))
        codes.append(z.cpu().numpy())
        labels.append(np.asarray(y))
    if not codes:
        raise ValueError("collect_latents: no batches collected")
    return np.concatenate(codes, 0), np.concatenate(labels, 0)

# ---- visuals (SAVE + SHOW) ----
def recon_grid(x_true_np, x_pred_np, path, max_images=10):
    """Save and show a 2-row grid: inputs on top, reconstructions below.

    Both arrays are indexed as ``arr[i, 0]`` to obtain one 2-D image, i.e.
    batch-first with only channel 0 drawn. Images are drawn on a fixed
    [0, 1] gray scale. This assumes pixel values in [0, 1], as produced
    by this file's ToTensor inputs and Sigmoid decoder. A fixed scale
    keeps inputs and reconstructions visually comparable; per-image auto
    contrast would stretch a faint, blurry reconstruction to full
    contrast.

    Args:
        x_true_np: original images.
        x_pred_np: reconstructions aligned with ``x_true_np``.
        path: output file for ``savefig``.
        max_images: maximum number of columns (default 10).

    Raises:
        ValueError: if either array holds no images.
    """
    # Bound by BOTH arrays so a shorter prediction batch cannot index out of range.
    n_cols = min(max_images, x_true_np.shape[0], x_pred_np.shape[0])
    if n_cols < 1:
        raise ValueError("recon_grid: no images to plot")
    fig, axes = plt.subplots(2, n_cols, figsize=(n_cols * 1.1, 2.2), squeeze=False)
    for col in range(n_cols):
        for row, batch in enumerate((x_true_np, x_pred_np)):
            ax = axes[row, col]
            ax.imshow(batch[col, 0], cmap="gray", vmin=0.0, vmax=1.0)
            ax.axis("off")
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.show()
    plt.close(fig)  # release the figure; repeated calls would otherwise accumulate open figures

def loss_curve(tr, va, path, title):
    """Save and show training (dashed) and validation loss per epoch.

    Epochs are plotted starting at 1, matching the "epoch N" log lines
    printed during training. The original plot started at 0.

    Args:
        tr: sequence of training losses, one per epoch.
        va: sequence of validation losses, one per epoch.
        path: output file for ``savefig``.
        title: axes title.
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(range(1, len(tr) + 1), tr, "--", label="train")
    ax.plot(range(1, len(va) + 1), va, label="val")
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.set_title(title)
    ax.legend()
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.show()
    plt.close(fig)  # release the figure once saved/shown

def latent_scatter(Z, y, path, title, n_classes=10):
    """Save and show a 2-D scatter of latent codes coloured by class label.

    Codes with more than 2 dimensions are projected with PCA (seeded with
    ``SEED``). 2-D codes are plotted as-is. 1-D codes are plotted against
    a zero y-coordinate.

    The colour scale is pinned to ``[-0.5, n_classes - 0.5]``, so label k
    always gets the k-th ``tab10`` colour. This holds even when the plotted
    subset is missing some classes. Previously the scale was normalised to
    the subset's own min/max label, which could shift colours between
    plots.

    Args:
        Z: array of shape ``(n_samples, latent_dim)``.
        y: integer labels, one per row of ``Z``.
        path: output file for ``savefig``.
        title: axes title.
        n_classes: number of label values 0..n_classes-1 (default 10).
            ``tab10`` only has 10 distinct colours.

    Raises:
        ValueError: if ``Z`` is not 2-D with at least one column, or if its
            length differs from ``y``.
    """
    Z = np.asarray(Z)
    y = np.asarray(y)
    if Z.ndim != 2 or Z.shape[1] < 1 or Z.shape[0] != y.shape[0]:
        raise ValueError(f"latent_scatter: bad shapes Z={Z.shape}, y={y.shape}")

    if Z.shape[1] > 2:
        Z2 = PCA(n_components=2, random_state=SEED).fit_transform(Z)
    elif Z.shape[1] == 2:
        Z2 = Z
    else:
        Z2 = np.column_stack([Z[:, 0], np.zeros(Z.shape[0])])

    fig, ax = plt.subplots(figsize=(6, 5))
    sc = ax.scatter(
        Z2[:, 0], Z2[:, 1], c=y, s=6, cmap="tab10",
        vmin=-0.5, vmax=n_classes - 0.5,
    )
    fig.colorbar(sc, ax=ax, ticks=range(n_classes))
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.show()
    plt.close(fig)  # release the figure once saved/shown

# 3–4) train / eval / visualize
def _fit(lat, train_loader, val_loader):
    """Train one AE with latent size ``lat``; return ``(model, train_hist, val_hist)``."""
    # Re-seed per latent size so every run starts from the same RNG state.
    # Each model is then reproducible on its own (not dependent on which
    # sizes were trained before it), and all sizes see the same batch order.
    torch.manual_seed(SEED)
    m = AE(lat).to(device)
    opt = torch.optim.Adam(m.parameters(), lr=LR)
    train_hist, val_hist = [], []
    for ep in range(1, EPOCHS + 1):
        tr = train_epoch(m, train_loader, opt)
        va = eval_epoch(m, val_loader)
        train_hist.append(tr)
        val_hist.append(va)
        print("epoch", ep, "train_loss=", round(tr, 4), "val_loss=", round(va, 4))
    return m, train_hist, val_hist


def _visualize(m, lat, test_loader, train_hist, val_hist):
    """Save/show the reconstruction grid, loss curve and latent scatter for one model."""
    x_true, _ = next(iter(test_loader))
    x_true = x_true.to(device)
    with torch.no_grad():  # no graph is built, so no .detach() is needed
        x_pred, _ = m(x_true)
    recon_grid(x_true.cpu().numpy(), x_pred.cpu().numpy(), f"fashion_recon_lat{lat}.png")
    loss_curve(train_hist, val_hist, f"fashion_loss_lat{lat}.png", f"Loss (latent={lat})")
    Z, y = collect_latents(m, test_loader, 8)
    latent_scatter(Z, y, f"fashion_latent_lat{lat}.png", f"Latent (latent={lat})")


def main():
    """Train, evaluate and visualise one autoencoder per latent size in ``LATS``.

    For each size, ``main`` prints per-epoch losses, writes three PNG
    figures and prints the test MSE. It finishes with a summary list of
    ``(latent, test_mse)`` pairs.
    """
    train_loader, val_loader, test_loader = make_loaders()
    summary = []
    for lat in LATS:
        print("latent =", lat)
        m, train_hist, val_hist = _fit(lat, train_loader, val_loader)
        test_mse = eval_epoch(m, test_loader)
        summary.append((lat, test_mse))
        _visualize(m, lat, test_loader, train_hist, val_hist)
        print("test_mse =", round(test_mse, 4))
    print("Summary (test MSE):", [(lat, round(mse, 4)) for lat, mse in summary])

# Run the experiment only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
