In [None]:
# SSL : Ladder Network

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader,Subset

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# -------------------------
# Step 1: Data Preparation
# -------------------------

# Transform: convert to tensor, then normalize with mean 0.1307 / std 0.3081
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load full MNIST dataset
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# Create labeled subset (e.g., 100 examples evenly distributed across classes)
num_labels = 100
num_classes = 10
labels_per_class = num_labels // num_classes
labels = np.array(train_dataset.targets)
labeled_idx = []

for i in range(num_classes):
    # Take the first `labels_per_class` training examples of class i.
    idx = np.where(labels == i)[0][:labels_per_class]
    # Store plain Python ints rather than numpy scalars.
    labeled_idx.extend(idx.tolist())

# Everything that is not labeled is unlabeled. np.setdiff1d returns sorted unique
# values, so the index order does not depend on set iteration order.
unlabeled_idx = np.setdiff1d(np.arange(len(train_dataset)), labeled_idx).tolist()

# Sanity checks on the split.
assert len(labeled_idx) == labels_per_class * num_classes, "a class has too few examples"
assert len(labeled_idx) + len(unlabeled_idx) == len(train_dataset), "split does not cover the dataset"

labeled_dataset = Subset(train_dataset, labeled_idx)
unlabeled_dataset = Subset(train_dataset, unlabeled_idx)

labeled_loader = DataLoader(labeled_dataset, batch_size=64, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [None]:
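# -------------------------
# Step 2: Model Definition (noise layer, encoder, decoder)
# -------------------------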

class GaussianNoise(nn.Module):
    """Additive zero-mean Gaussian noise that is only active in training mode.

    Args:
        stddev: Standard deviation of the noise added to the input.
    """

    def __init__(self, stddev):
        super().__init__()
        self.stddev = stddev

    def forward(self, x):
        """Return ``x`` plus freshly sampled noise while training; return ``x`` itself otherwise."""
        if not self.training:
            return x
        return x + torch.randn_like(x) * self.stddev

class Encoder(nn.Module):
    """Fully connected encoder (784 -> 1000 -> 500 -> 250 -> 250 -> 10).

    The input is flattened to 784 features and passed through ``GaussianNoise``
    before the first linear layer. ReLU is applied between linear layers, and
    the pre-activation output of every linear layer is returned.

    Args:
        noise_std: Standard deviation handed to the input ``GaussianNoise`` layer.
    """

    def __init__(self, noise_std):
        super().__init__()
        self.noise = GaussianNoise(noise_std)
        self.fc1 = nn.Linear(784, 1000)
        self.fc2 = nn.Linear(1000, 500)
        self.fc3 = nn.Linear(500, 250)
        self.fc4 = nn.Linear(250, 250)
        self.fc5 = nn.Linear(250, 10)

    def forward(self, x):
        """Return ``[z1, z2, z3, z4, z5]``, the outputs of ``fc1`` .. ``fc5`` before any ReLU.

        The last element (``z5``, 10 values per sample) is used as the class scores.
        """
        flat = self.noise(x.view(-1, 784))
        pre_activations = [self.fc1(flat)]
        # Each later layer consumes the ReLU of the previous layer's pre-activation.
        for layer in (self.fc2, self.fc3, self.fc4, self.fc5):
            pre_activations.append(layer(F.relu(pre_activations[-1])))
        return pre_activations

class Decoder(nn.Module):
    """Top-down fully connected decoder (10 -> 250 -> 250 -> 500 -> 1000).

    Only the last element of the encoder output list is consumed. Each linear
    layer's output has the width of one encoder pre-activation, from z4 (250)
    down to z1 (1000).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 250)    # output has the width of z4 (250)
        self.fc2 = nn.Linear(250, 250)   # output has the width of z3 (250)
        self.fc3 = nn.Linear(250, 500)   # output has the width of z2 (500)
        self.fc4 = nn.Linear(500, 1000)  # output has the width of z1 (1000)

    def forward(self, z_corr):
        """Decode from ``z_corr[-1]`` and return ``[d4, d3, d2, d1]``.

        The returned list is ordered to line up with z1, z2, z3, z4:
        widths 1000, 500, 250, 250.
        """
        stages = [self.fc1(z_corr[-1])]
        # ReLU is applied to each stage before it feeds the next linear layer.
        for layer in (self.fc2, self.fc3, self.fc4):
            stages.append(layer(F.relu(stages[-1])))
        # stages is [d1, d2, d3, d4]; callers expect the z1-first order.
        stages.reverse()
        return stages

In [None]:
# -------------------------
# Step 3: Training Functions
# -------------------------

def supervised_loss(output, target):
    """Cross-entropy between the class scores ``output`` and ``target``.

    Uses ``F.cross_entropy`` with its default settings.
    """
    return F.cross_entropy(input=output, target=target)

def reconstruction_loss(z_clean, z_recon):
    """Sum of MSE losses between paired target and reconstructed activations.

    Pairs ``z_clean[0..3]`` with ``z_recon[0..3]`` in order. The targets are
    detached, so no gradient flows into ``z_clean``. Returns the integer ``0``
    when there is nothing to pair.
    """
    layer_pairs = zip(z_clean[:4], z_recon)
    return sum(F.mse_loss(recon, clean.detach()) for clean, recon in layer_pairs)

def _repeat_batches(loader):
    """Yield batches from ``loader`` indefinitely, re-iterating it each time it runs out."""
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # An empty loader would otherwise loop forever without yielding.
            return


def _clean_forward(encoder, x):
    """Run ``encoder(x)`` in eval mode without gradients, then restore the previous mode."""
    was_training = encoder.training
    encoder.eval()
    try:
        with torch.no_grad():
            return encoder(x)
    finally:
        encoder.train(was_training)


def train_epoch(encoder, decoder, optimizer, labeled_loader, unlabeled_loader, alpha):
    """Train the encoder and decoder for one pass over ``unlabeled_loader``.

    Each step combines one unlabeled batch with one labeled batch and minimizes
    ``supervised_loss + alpha * reconstruction_loss``.

    Args:
        encoder: Module returning a list of activations; the last entry is used
            as the class scores.
        decoder: Module mapping the encoder's output list to reconstructions.
        optimizer: Optimizer over the encoder and decoder parameters.
        labeled_loader: Yields ``(inputs, targets)`` batches.
        unlabeled_loader: Yields ``(inputs, _)`` batches; the second item is ignored.
        alpha: Weight of the reconstruction term.

    Returns:
        None.
    """
    encoder.train()
    decoder.train()

    # The unlabeled loader drives the epoch and labeled batches are recycled.
    # Zipping the two loaders directly would stop after the last labeled batch
    # and skip the rest of the unlabeled data.
    labeled_batches = _repeat_batches(labeled_loader)

    for (x_u, _), (x_l, y_l) in zip(unlabeled_loader, labeled_batches):
        x_l, y_l = x_l.to(device), y_l.to(device)
        x_u = x_u.to(device)

        # Corrupted pass: the encoder is in train mode, so its training-only
        # noise is applied.
        z_corr_l = encoder(x_l)
        z_corr_u = encoder(x_u)

        # Clean pass: eval mode switches the training-only noise off, so the
        # targets are genuinely noise-free. reconstruction_loss detaches its
        # targets, so running this pass without gradients loses nothing.
        z_clean_l = _clean_forward(encoder, x_l)
        z_clean_u = _clean_forward(encoder, x_u)

        output = z_corr_l[-1]
        loss_sup = supervised_loss(output, y_l)

        recon_l = decoder(z_corr_l)
        recon_u = decoder(z_corr_u)

        loss_unsup = reconstruction_loss(z_clean_l, recon_l) + reconstruction_loss(z_clean_u, recon_u)

        loss = loss_sup + alpha * loss_unsup

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def evaluate(encoder, loader):
    """Return the fraction of samples in ``loader`` that ``encoder`` classifies correctly.

    The encoder is switched to eval mode and left there. Predictions are the
    argmax of the last element of the encoder output. The count of correct
    predictions is divided by ``len(loader.dataset)``.
    """
    encoder.eval()
    num_correct = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = encoder(inputs)[-1]
            predictions = torch.argmax(logits, dim=1)
            num_correct += torch.eq(predictions, targets).sum().item()
    return num_correct / len(loader.dataset)

In [None]:
# -------------------------
# Step 4: Training Loop
# -------------------------

# Seed the torch RNG so weight initialization, noise sampling and loader
# shuffling from this cell onward are repeatable.
torch.manual_seed(0)

encoder = Encoder(noise_std=0.3).to(device)
decoder = Decoder().to(device)
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

epochs = 30
alpha = 0.5  # weight of the reconstruction loss relative to the supervised loss

# Built once: the evaluation loader over the unlabeled split does not change
# between epochs, so there is no need to re-create it inside the loop.
unlabeled_eval_loader = DataLoader(unlabeled_dataset, batch_size=256, shuffle=False)

for epoch in range(1, epochs + 1):
    train_epoch(encoder, decoder, optimizer, labeled_loader, unlabeled_loader, alpha)
    test_acc = evaluate(encoder, test_loader)
    labeled_acc = evaluate(encoder, labeled_loader)
    # The "unlabeled" split still carries its true MNIST targets, so this is
    # accuracy against ground truth.
    unlabeled_acc = evaluate(encoder, unlabeled_eval_loader)
    print(f"Epoch {epoch:02d} | Test Acc: {test_acc:.4f} | Labeled Acc: {labeled_acc:.4f} | Unlabeled Pseudo Acc: {unlabeled_acc:.4f}")