In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import numpy as np
from torchvision import datasets, transforms

In [None]:
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)

# ---------------------------------------------
# Step 1: Prepare Data (creates the MNIST datasets)
# ---------------------------------------------
# ToTensor() converts images to float tensors with pixel values in [0, 1],
# which the binary-cross-entropy reconstruction loss used below requires.
transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
# download=True is a no-op when the files are already present; passing it here
# means this line does not depend on the train split having been loaded first.
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Split into labeled and unlabeled sets
def split_labeled_unlabeled(dataset, num_labels=100, n_classes=10):
    """Split a dataset into a small class-balanced labeled subset and the rest.

    The first ``num_labels // n_classes`` examples of each class (in dataset
    order) form the labeled subset. Every other example goes to the
    "unlabeled" subset; its labels still exist but the training loop ignores them.

    Args:
        dataset: a dataset exposing a ``targets`` sequence/tensor of integer
            class ids, assumed to lie in ``0..n_classes-1``.
        num_labels: total number of labeled examples to keep. Any remainder of
            ``num_labels / n_classes`` is dropped, and a class with fewer
            examples contributes fewer.
        n_classes: number of classes.

    Returns:
        ``(labeled_subset, unlabeled_subset)`` as ``Subset`` objects.
        Labeled indices are grouped by class (class 0 first). Unlabeled
        indices are in ascending order, so the split is deterministic.
    """
    targets = np.asarray(dataset.targets)
    per_class = num_labels // n_classes
    labeled_idx = np.concatenate(
        [np.flatnonzero(targets == c)[:per_class] for c in range(n_classes)]
    ).tolist()
    # Boolean mask instead of set difference: O(n) and order-preserving.
    is_labeled = np.zeros(len(dataset), dtype=bool)
    is_labeled[labeled_idx] = True
    unlabeled_idx = np.flatnonzero(~is_labeled).tolist()
    return Subset(dataset, labeled_idx), Subset(dataset, unlabeled_idx)

labeled_dataset, unlabeled_dataset = split_labeled_unlabeled(train_dataset)

# Training loaders reshuffle every pass; the test loader reads in fixed order.
labeled_loader = DataLoader(labeled_dataset, shuffle=True, batch_size=64)
unlabeled_loader = DataLoader(unlabeled_dataset, shuffle=True, batch_size=128)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=256)

In [None]:
# ---------------------------------------------
# Step 2: VAE Model with Classifier
# ---------------------------------------------
def one_hot(y, num_classes=10):
    """Encode integer labels ``y`` as float32 one-hot rows of width ``num_classes``."""
    encoded = F.one_hot(y, num_classes=num_classes)
    return encoded.to(torch.float32)

class Encoder(nn.Module):
    """Conditional encoder: (flattened x, one-hot y) -> (mu, logvar) of a diagonal Gaussian over z.

    Args:
        latent_dim: size of the latent code z.
        n_classes: width of the one-hot label vector y.
        input_dim: width of the flattened input x (784 = 28*28 for MNIST).
        hidden_dim: width of the single hidden layer.
    """
    def __init__(self, latent_dim=20, n_classes=10, input_dim=784, hidden_dim=400):
        super().__init__()
        # Layers are created in the original order so that, under a fixed seed,
        # parameter initialisation is unchanged.
        self.fc1 = nn.Linear(input_dim + n_classes, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x, y):
        """Return ``(mu, logvar)``, each (batch, latent_dim).

        ``x`` is (batch, input_dim) and ``y`` is (batch, n_classes); they are
        concatenated along dim 1 before the hidden layer.
        """
        h = F.relu(self.fc1(torch.cat([x, y], dim=1)))
        return self.fc_mu(h), self.fc_logvar(h)

class Decoder(nn.Module):
    """Conditional decoder: (latent z, one-hot y) -> per-pixel values in [0, 1].

    Args:
        latent_dim: size of the latent code z.
        n_classes: width of the one-hot label vector y.
        output_dim: width of the reconstructed flattened image (784 = 28*28 for MNIST).
        hidden_dim: width of the single hidden layer.
    """
    def __init__(self, latent_dim=20, n_classes=10, output_dim=784, hidden_dim=400):
        super().__init__()
        # Same creation order as before keeps seeded initialisation unchanged.
        self.fc1 = nn.Linear(latent_dim + n_classes, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z, y):
        """Return reconstructions of shape (batch, output_dim)."""
        h = F.relu(self.fc1(torch.cat([z, y], dim=1)))
        # Sigmoid keeps outputs in [0, 1], as the BCE reconstruction loss requires.
        return torch.sigmoid(self.fc2(h))

class Classifier(nn.Module):
    """Linear classifier: one affine layer followed by log-softmax over classes.

    Args:
        n_classes: number of output classes (default 10, as before).
        input_dim: width of the flattened input (default 784 = 28*28).
    """
    def __init__(self, n_classes=10, input_dim=784):
        super().__init__()
        self.fc = nn.Linear(input_dim, n_classes)

    def forward(self, x):
        """Return log-probabilities of shape (batch, n_classes)."""
        return F.log_softmax(self.fc(x), dim=1)

def kl_divergence(mu, logvar):
    """Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) ).

    The per-dimension terms are summed over dim 1, giving one value per row.
    """
    per_dim = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * per_dim.sum(dim=1)

def reconstruction_loss(x_recon, x_true):
    """Per-example binary cross-entropy between a reconstruction and its target.

    Predictions are clamped to [1e-6, 1 - 1e-6] so the log terms stay finite.
    Targets are clamped to [0, 1]. The elementwise BCE is summed over dim 1,
    giving one value per row.
    """
    eps = 1e-6
    x_recon = x_recon.clamp(eps, 1 - eps)
    x_true = x_true.clamp(0, 1)
    assert x_recon.shape == x_true.shape, f"Shape mismatch: {x_recon.shape} vs {x_true.shape}"
    bce = F.binary_cross_entropy(x_recon, x_true, reduction='none')
    return bce.sum(dim=1)

In [None]:
# ---------------------------------------------
# Step 3: Full VAE-SSL Model
# ---------------------------------------------
class VAESSL(nn.Module):
    """Semi-supervised VAE in the style of Kingma et al. (2014), model "M2".

    Components:
        encoder:    (x, one-hot y) -> (mu, logvar) of q(z | x, y)
        decoder:    (z, one-hot y) -> reconstruction of x
        classifier: x -> log q(y | x)
    """
    def __init__(self, latent_dim=20, n_classes=10):
        super().__init__()
        self.encoder = Encoder(latent_dim, n_classes)
        self.decoder = Decoder(latent_dim, n_classes)
        # NOTE(review): Classifier() is built with its default class count (10)
        # and is not given n_classes, so n_classes != 10 is not supported as-is.
        self.classifier = Classifier()
        self.n_classes = n_classes
        self.latent_dim = latent_dim

    @staticmethod
    def _reparameterize(mu, logvar):
        """Sample z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def forward_labeled(self, x, y):
        """Encode (x, y), sample z, decode (z, y).

        Returns ``(x_recon, mu, logvar)``.
        """
        mu, logvar = self.encoder(x, y)
        z = self._reparameterize(mu, logvar)
        x_recon = self.decoder(z, y)
        return x_recon, mu, logvar

    def forward_unlabeled(self, x):
        """Return the batch-mean unlabeled loss U(x).

        U(x) = sum_y q(y|x) * (L(x, y) + log q(y|x))
             = E_{q(y|x)}[L(x, y)] - H(q(y|x)),
        where L(x, y) = reconstruction + KL is the negative ELBO for a fixed
        label. The constant log p(y) of a uniform label prior is omitted.
        """
        log_probs = self.classifier(x)
        probs = log_probs.exp()
        batch_size = x.size(0)
        # One-hot basis built once, on the same device/dtype as the input.
        labels = torch.eye(self.n_classes, device=x.device, dtype=x.dtype)
        total_loss = 0
        for i in range(self.n_classes):
            y_i = labels[i].expand(batch_size, -1)
            x_recon, mu, logvar = self.forward_labeled(x, y_i)
            neg_elbo = reconstruction_loss(x_recon, x) + kl_divergence(mu, logvar)
            # Sign matters: "+ log q" makes the weighted sum subtract the
            # classifier entropy (as in M2); "- log q" would add it instead.
            total_loss = total_loss + probs[:, i] * (neg_elbo + log_probs[:, i])
        return total_loss.mean()

In [None]:
# ---------------------------------------------
# Step 4: Training & Evaluation
# ---------------------------------------------
# One Adam optimizer updates encoder, decoder and classifier parameters jointly.
model = VAESSL()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=0.001)


def _repeat_forever(loader):
    """Yield batches from ``loader`` endlessly, re-iterating it each time it is exhausted.

    Re-iterating a shuffling DataLoader reshuffles it.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # Guard against spinning forever on an empty loader.
            raise ValueError("labeled_loader produced no batches")


def train_epoch(model, labeled_loader, unlabeled_loader, alpha=0.1, beta=1.0, opt=None):
    """Run one epoch of semi-supervised VAE training.

    One epoch is one full pass over ``unlabeled_loader``. The much smaller
    ``labeled_loader`` is restarted whenever it runs out, so every unlabeled
    batch is paired with a labeled batch. Zipping the two loaders would stop
    after the labeled loader's few batches and leave most unlabeled data unused.

    Per-step loss:
        labeled VAE loss + alpha * unlabeled loss + beta * classification NLL

    Args:
        model: object exposing ``forward_labeled``, ``forward_unlabeled``,
            ``classifier``, ``n_classes`` and ``parameters()`` (as ``VAESSL`` does).
        labeled_loader: yields ``(images, integer labels)``.
        unlabeled_loader: yields ``(images, labels)``; the labels are ignored.
        alpha: weight of the unlabeled loss.
        beta: weight of the supervised classification loss.
        opt: optimizer to step. Defaults to the notebook-level ``optimizer``.

    Returns:
        ``(mean labeled loss, mean unlabeled loss)`` per step over the epoch.
        The labeled loss is the labeled VAE loss plus ``beta`` times the
        classification loss. Returns ``(0.0, 0.0)`` if ``unlabeled_loader`` is empty.
    """
    if opt is None:
        opt = optimizer  # notebook-global optimizer created alongside `model`
    dev = next(model.parameters()).device
    model.train()

    labeled_batches = _repeat_forever(labeled_loader)
    total_labeled_loss = 0.0
    total_unlabeled_loss = 0.0
    n_steps = 0

    for x_u, _ in unlabeled_loader:
        x_l, y_l = next(labeled_batches)
        x_l = x_l.flatten(1).to(dev)
        x_u = x_u.flatten(1).to(dev)
        y_l = y_l.to(dev)
        y_l_1h = one_hot(y_l, num_classes=model.n_classes)

        # Labeled VAE loss
        x_recon, mu, logvar = model.forward_labeled(x_l, y_l_1h)
        recon_loss = reconstruction_loss(x_recon, x_l)
        kl_loss = kl_divergence(mu, logvar)
        labeled_vae_loss = (recon_loss + kl_loss).mean()

        # Supervised classifier loss (classifier returns log-probabilities)
        class_loss = F.nll_loss(model.classifier(x_l), y_l)

        # Unlabeled VAE loss
        unlabeled_loss = model.forward_unlabeled(x_u)

        loss = labeled_vae_loss + alpha * unlabeled_loss + beta * class_loss

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_labeled_loss += (labeled_vae_loss + beta * class_loss).item()
        total_unlabeled_loss += unlabeled_loss.item()
        n_steps += 1

    if n_steps == 0:
        return 0.0, 0.0
    return total_labeled_loss / n_steps, total_unlabeled_loss / n_steps


def evaluate_classifier(model, loader):
    """Return the top-1 accuracy of ``model.classifier`` over ``loader``'s dataset.

    Inputs are flattened to (batch, features) and moved to the device of the
    model's parameters, so no global device is needed. The model is put in
    eval mode and left there; ``train_epoch`` switches back with ``model.train()``.

    Raises:
        ValueError: if ``loader.dataset`` is empty (accuracy is undefined).
    """
    n_total = len(loader.dataset)
    if n_total == 0:
        raise ValueError("evaluate_classifier: loader.dataset is empty")
    dev = next(model.parameters()).device
    model.eval()
    n_correct = 0
    with torch.no_grad():
        for x, y in loader:
            preds = model.classifier(x.flatten(1).to(dev)).argmax(dim=1)
            n_correct += (preds == y.to(dev)).sum().item()
    return n_correct / n_total

In [None]:
# ---------------------------------------------
# Step 5: Training Loop
# ---------------------------------------------
# Train, then report held-out accuracy and the epoch's training losses.
N_EPOCHS = 20
for epoch in range(1, N_EPOCHS + 1):
    l_loss, u_loss = train_epoch(model, labeled_loader, unlabeled_loader, alpha=0.1)
    acc = evaluate_classifier(model, test_loader)
    print(f"Epoch {epoch:02d} | Test Accuracy: {acc:.4f} | Labeled Loss: {l_loss:.2f} | Unlabeled Loss: {u_loss:.2f}")