In [122]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms 
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO
from pyro.optim import Adam

## Classical VAE

In [90]:
# VAE encoder
class Encoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        z_loc = self.fc21(hidden) # mean vector
        z_scale = torch.exp(0.5*self.fc22(hidden)) # covariance vector
        return z_loc, z_scale

In [92]:
class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)

    def forward(self, z):
        hidden = torch.relu(self.fc1(z))
        loc_img = torch.sigmoid(self.fc21(hidden))
        return loc_img

In [94]:
class VAE(nn.Module):
    def __init__(self, z_dim=10, hidden_dim=128):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)
        self.z_dim = z_dim

    # define the model p(x|z)p(z)
    def model(self, x):
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            # prior p(z) as standard normal
            z_loc = torch.zeros(x.shape[0], self.z_dim)
            z_scale = torch.ones(x.shape[0], self.z_dim)
            # sample from prior
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z
            loc_img = self.decoder.forward(z)
            
            pyro.sample(
                "obs",
                dist.Bernoulli(loc_img, validate_args=False).to_event(1),
                obs=x.reshape(-1, 784),
            )
            
            return loc_img

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder.forward(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    
    def reconstruct_img(self, x):
        # encode image x
        z_loc, z_scale = self.encoder(x)
        # sample in latent space
        z = dist.Normal(z_loc, z_scale).sample()
        # decode the image
        loc_img = self.decoder(z)
        return loc_img

In [96]:
#input_dim = 784
lr = 0.01
batch_size = 128
epochs = 10

train_data = datasets.MNIST(
    root='./data', 
    train=True, 
    transform=transforms.ToTensor(), 
    download=True)

train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True
)

vae = VAE()
optimizer = Adam({"lr": lr})
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())


In [98]:
#training 
for epoch in range(epochs):
    train_loss = 0
    for x, _ in train_dataloader:
        x = x.view(-1, 28*28)
        train_loss += svi.step(x)
        
    train_loss /= len(train_dataloader.dataset)
    print(f"Epoch {epoch + 1}, Loss: {train_loss}")

Epoch 1, Loss: 546.0059057373047
Epoch 2, Loss: 545.9868314229329
Epoch 3, Loss: 545.990926147461
Epoch 4, Loss: 546.0077961425782
Epoch 5, Loss: 545.9942545776368
Epoch 6, Loss: 546.0172881734212
Epoch 7, Loss: 545.9797055480957
Epoch 8, Loss: 545.9897087402344
Epoch 9, Loss: 545.9915152526855
Epoch 10, Loss: 545.9966286946615


## SS VAE (M2 model)

In [171]:
# VAE encoder for z
class Encoder_Z(nn.Module):
    def __init__(self, z_dim=10, input_dim=784, hidden_dim=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim+num_classes, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        z_loc = self.fc21(hidden) # mean vector
        z_scale = torch.exp(0.5*self.fc22(hidden)) # covariance vector
        return z_loc, z_scale

In [173]:
# VAE encoder for y
class Encoder_Y(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        y_alpha = torch.softmax(self.fc21(hidden), dim=1)
        return y_alpha

In [175]:
class Decoder(nn.Module):
    def __init__(self, z_dim=10, input_dim=784, hidden_dim=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(z_dim + num_classes, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, input_dim)

    def forward(self, z):
        hidden = torch.relu(self.fc1(z))
        loc_img = torch.sigmoid(self.fc21(hidden))
        return loc_img

In [177]:
class SSVAE(nn.Module):
    def __init__(self, z_dim=10, hidden_dim=128, num_classes=10):
        super().__init__()
        # Encoders and decoder
        self.z_dim = z_dim
        self.num_classes = num_classes
        self.output_size = num_classes
        self.encoder_y = Encoder_Y()
        self.encoder_z = Encoder_Z()
        self.decoder = Decoder()

    def model(self, xs, ys=None):
        pyro.module("ss_vae", self)
        batch_size = xs.size(0)

        with pyro.plate("data"):
            # Sample latent variable z
            prior_loc = xs.new_zeros([batch_size, self.z_dim])
            prior_scale = xs.new_ones([batch_size, self.z_dim])
            zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))

            # Sample labels y (if unobserved)
            alpha_prior = xs.new_ones([batch_size, self.output_size]) / self.output_size
            ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

            # Decode (concatenate z and y)
            combined_input = torch.cat([zs, ys], dim=1)
            loc = self.decoder(combined_input)

            # Sample observations x
            pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs)

    def guide(self, xs, ys=None):
        with pyro.plate("data"):
            # If y is not observed, sample it
            if ys is None:
                alpha = self.encoder_y(xs)
                ys = pyro.sample("y", dist.OneHotCategorical(alpha))

            # Sample latent variable z
            combined_input = torch.cat([xs, ys], dim=1)
            loc, scale = self.encoder_z(combined_input)
            pyro.sample("z", dist.Normal(loc, scale).to_event(1))


In [179]:
# Transform: Flatten images and normalize to [0, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: torch.flatten(x)),
])

mnist_train = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root="./data", train=False, transform=transform, download=True)

# Split train data into labeled and unlabeled subsets
num_labeled = 1000  # Number of labeled samples
num_unlabeled = len(mnist_train) - num_labeled
labeled_data, unlabeled_data = random_split(mnist_train, [num_labeled, num_unlabeled])

batch_size = 128
labeled_loader = DataLoader(labeled_data, batch_size=batch_size, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

In [181]:
def one_hot_encode(labels, num_classes=10):
    return torch.eye(num_classes)[labels]

In [185]:
# Initialize the model
z_dim = 10
hidden_dim = 128
ssvae = SSVAE(z_dim=z_dim, hidden_dim=hidden_dim)

optimizer = optim.Adam({"lr": 0.01})

svi = SVI(ssvae.model, ssvae.guide, optimizer, loss=Trace_ELBO())

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    ssvae.train()
    total_loss = 0

    # Train on labeled data
    for x, y in labeled_loader:
        x = x.to(dtype=torch.float32)  
        y = one_hot_encode(y).to(dtype=torch.float32)  
        loss = svi.step(x, y)  
        total_loss += loss

    # Train on unlabeled data
    for x, _ in unlabeled_loader:
        x = x.to(dtype=torch.float32)
        loss = svi.step(x)  
        total_loss += loss

    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(labeled_loader.dataset):.4f}")


ValueError: Error while computing log_prob at site 'x':
Expected value argument (Tensor of shape (128, 784)) to be within the support (Boolean()) of the distribution Bernoulli(probs: torch.Size([128, 784])), but found invalid values:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
                 Trace Shapes:            
                  Param Sites:            
 ss_vae$$$encoder_y.fc1.weight 128 784    
   ss_vae$$$encoder_y.fc1.bias     128    
ss_vae$$$encoder_y.fc21.weight  10 128    
  ss_vae$$$encoder_y.fc21.bias      10    
 ss_vae$$$encoder_z.fc1.weight 128 794    
   ss_vae$$$encoder_z.fc1.bias     128    
ss_vae$$$encoder_z.fc21.weight  10 128    
  ss_vae$$$encoder_z.fc21.bias      10    
ss_vae$$$encoder_z.fc22.weight  10 128    
  ss_vae$$$encoder_z.fc22.bias      10    
   ss_vae$$$decoder.fc1.weight 128  20    
     ss_vae$$$decoder.fc1.bias     128    
  ss_vae$$$decoder.fc21.weight 784 128    
    ss_vae$$$decoder.fc21.bias     784    
                 Sample Sites:            
                        z dist 128   |  10
                         value 128   |  10
                      log_prob 128   |    
                        y dist 128   |  10
                         value 128   |  10
                      log_prob 128   |    
                        x dist 128   | 784
                         value 128   | 784

In [None]:
import matplotlib.pyplot as plt

ssvae.eval()
test_iter = iter(test_loader)
x, y = next(test_iter)
x = x.to(dtype=torch.float32)

# Reconstruct images
with torch.no_grad():
    loc_img = ssvae.decoder(torch.cat([ssvae.encoder_y(x), x], dim=1))

def visualize_reconstruction(original, reconstructed):
    plt.figure(figsize=(8, 4))
    for i in range(8):
        # Original images
        plt.subplot(2, 8, i + 1)
        plt.imshow(original[i].view(28, 28), cmap="gray")
        plt.axis("off")
        # Reconstructed images
        plt.subplot(2, 8, i + 9)
        plt.imshow(reconstructed[i].view(28, 28), cmap="gray")
        plt.axis("off")
    plt.show()

visualize_reconstruction(x[:8], loc_img[:8])

In [None]:
correct = 0
total = 0

with torch.no_grad():
    for x, y in test_loader:
        x = x.to(dtype=torch.float32)
        y_onehot = ssvae.encoder_y(x)
        predicted = torch.argmax(y_onehot, dim=1)
        correct += (predicted == y).sum().item()
        total += y.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")