In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

In [3]:
def one_hot_encode(labels, output_size=10):
    return torch.nn.functional.one_hot(labels, num_classes=output_size)

# Binarize the MNIST dataset by thresholding at 0.5
transform = transforms.Compose([
    transforms.ToTensor(),transforms.Lambda(lambda x: (x > 0.5).float().view(-1))])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

labeled_size = int(0.1 * len(train_dataset))  # use 10% of the dataset for labels
unlabeled_size = len(train_dataset) - labeled_size
labeled_dataset, unlabeled_dataset = random_split(train_dataset, [labeled_size, unlabeled_size])

labeled_loader = DataLoader(labeled_dataset, batch_size=128, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)


In [4]:
class Encoder_z(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2)
        )
        
    def forward(self, x):
        return self.fc(x)

class Encoder_y(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
            nn.Softmax(dim=1) 
        )
        
    def forward(self, x):
        return self.fc(x)

class Decoder(nn.Module):
    def __init__(self, latent_dim, num_classes, hidden_dim, output_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + num_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        return self.fc(x)

In [5]:
class SSVAE(nn.Module):
    def __init__(self, input_size=784, hidden_size=256, z_dim=20, output_size=10):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.z_dim = z_dim
        self.output_size = output_size
        
        self.encoder_z = Encoder_z(input_dim=input_size + output_size,
                                   hidden_dim=hidden_size,
                                   latent_dim=z_dim)
        
        self.encoder_y = Encoder_y(input_dim=input_size,
                                   hidden_dim=hidden_size,
                                   num_classes=output_size)
        
        self.decoder = Decoder(latent_dim=z_dim,
                               num_classes=output_size,
                               hidden_dim=hidden_size,
                               output_dim=input_size)
        
        self.optimizer = Adam({"lr": 0.0001})
        self.svi = SVI(self.model, self.guide, self.optimizer, loss=Trace_ELBO())

    def model(self, xs, ys=None):
        pyro.module("ss_vae", self)
        batch_size = xs.size(0)
        with pyro.plate("data", batch_size):
            #Prior for z 
            prior_loc = xs.new_zeros([batch_size, self.z_dim])
            prior_scale = xs.new_ones([batch_size, self.z_dim])
            zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))
            
            #Prior for y
            alpha_prior = xs.new_ones([batch_size, self.output_size]) / self.output_size
            ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)
            
            #Decode: concatenate z and y to get x
            x_input = torch.cat([zs, ys], dim=1)
            loc = self.decoder(x_input)
            pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs)

    def guide(self, xs, ys=None):
        pyro.module("ss_vae", self)
        with pyro.plate("data", xs.size(0)):
            if ys is None:
                # Infer y from x using encoder_y
                alpha = self.encoder_y(xs)
                ys = pyro.sample("y", dist.OneHotCategorical(alpha))
            # Concatenate x and y as input to encoder_z
            combined_input = torch.cat([xs, ys], dim=1)
            z_params = self.encoder_z(combined_input)
            z_loc = z_params[:, :self.z_dim]
            z_scale = F.softplus(z_params[:, self.z_dim:]) + 1e-6
            pyro.sample("z", dist.Normal(z_loc, z_scale).to_event(1))

In [6]:
# Initialize the model
ssvae = SSVAE(input_size=784, hidden_size=128, z_dim=10, output_size=10)
num_epochs = 5

for epoch in range(num_epochs):
    ssvae.train()
    total_loss = 0

    # Training on labeled data
    for x, y in labeled_loader:
        x = x.view(-1, 784)
        y = one_hot_encode(y, output_size=10)
        loss = ssvae.svi.step(x, y)
        total_loss += loss

    # Training on unlabeled data
    for x, _ in unlabeled_loader:
        x = x.view(-1, 784)
        loss = ssvae.svi.step(x)
        total_loss += loss

    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(labeled_loader.dataset):.4f}")


Epoch 1, Loss: 3125.6688
Epoch 2, Loss: 2142.7943
Epoch 3, Loss: 1878.8889
Epoch 4, Loss: 1694.8022
Epoch 5, Loss: 1587.9650
