## Imports and Hyperparameters

In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import torch.optim as optim
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, models, transforms
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from tqdm.notebook import tqdm
import pandas as pd
import numpy as np
from typing import Tuple
from tqdm import tqdm
# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Width of the fully connected hidden layers (set again, to the same value, in the next cell).
hidden_width = 64

# Reproducibility: seed Python's `random`, NumPy and torch.
# NOTE(review): the seeds are not uniform (42 vs 0). torch.manual_seed also
# seeds CUDA, so the CUDA seed that actually applies is the later 0.
random.seed(0)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(0)  # must stay after torch.manual_seed, which would overwrite it

# Use deterministic cuDNN kernels instead of benchmarking for the fastest ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [2]:
# Model sizes shared by every cell below.
hidden_channels = 12  # channels of the first conv layer (the second uses twice as many)
hidden_width = 64     # width of the fully connected hidden layers
latent_dims = 32      # dimensionality of the latent code z
num_classes = 10      # number of label classes (MNIST digits)

## Utils

In [3]:
def bayes_classifier(x, enc, dec, dimY, lowerbound, K = 1, beta=1.0):
    """Classify ``x`` with Bayes' rule using a model's bound on log p(x, y).

    For every candidate label ``c``, the importance-weighted bound
    ``lowerbound(x, onehot(c), enc, dec, K, IS=True, beta=beta)`` is used as a
    score for log p(x, y=c). A softmax across labels turns the scores into
    class probabilities.

    Args:
        x: input batch; ``x.shape[0]`` is the batch size N.
        enc, dec: model components, forwarded unchanged to ``lowerbound``.
        dimY: number of classes.
        lowerbound: callable ``(x, y, enc, dec, K, IS=..., beta=...)`` returning
            one value per example (shape (N,), or a 0-d tensor when N == 1).
        K: number of importance samples per label.
        beta: forwarded to ``lowerbound``.

    Returns:
        Tensor of shape (N, dimY) whose rows are probabilities summing to 1.
    """
    N = x.shape[0]
    scores = []
    for label in range(dimY):
        # Build the one-hot labels on x's device rather than a global device.
        y = torch.zeros(N, dimY, device=x.device)
        y[:, label] = 1
        bound = lowerbound(x, y, enc, dec, K, IS=True, beta=beta)
        # reshape (not unsqueeze) also copes with a bound squeezed to 0-d for N == 1.
        scores.append(bound.reshape(N, 1))
    logpxy = torch.cat(scores, dim=1)
    return F.softmax(logpxy, dim=1)
def log_gaussian_prob(x, mu, log_sig):
    """Log-density of ``x`` under a diagonal Gaussian N(mu, exp(log_sig)**2).

    ``mu`` and ``log_sig`` broadcast against ``x``. Per-element log-densities
    are summed over every axis of ``x`` from index 2 onward, so an input of
    shape (K, N, D) yields a result of shape (K, N).
    """
    half_log_two_pi = 0.5 * np.log(2 * np.pi)
    z_score = (x - mu) / torch.exp(log_sig)
    elementwise = -(half_log_two_pi + log_sig) - 0.5 * z_score ** 2
    event_axes = list(range(2, x.dim()))
    return torch.sum(elementwise, event_axes)


def encoding(enc, x, y, K):
    """Draw K reparameterised samples z ~ q(z | x, y) and their log-density.

    ``enc(x, y)`` must return ``(mu, log_sig)`` of a diagonal Gaussian; the
    second output is used as a log standard deviation.

    Args:
        enc: encoder callable.
        x, y: inputs forwarded to ``enc``.
        K: number of samples to draw.

    Returns:
        samples: tensor of shape (K, *mu.shape), e.g. (K, N, D) for an (N, D) mu.
        logq: log q(samples | x, y), summed over axes 2 onward
            (shape (K, N) for a 2-D mu).
    """
    mu_qz, log_sig_qz = enc(x, y)
    # Draw the noise directly on mu's device and dtype: this avoids allocating
    # a CPU tensor and copying it over, and no longer relies on a global device.
    eps = torch.randn((K, *mu_qz.shape), device=mu_qz.device, dtype=mu_qz.dtype)
    samples = mu_qz + log_sig_qz.exp().unsqueeze(0) * eps
    logq = log_gaussian_prob(samples, mu_qz.unsqueeze(0), log_sig_qz.unsqueeze(0))
    return samples, logq

class Encoder(nn.Module):
    """Amortised posterior q(z | x, y) as a diagonal Gaussian.

    Architecture:
      * two stride-2 convolutions downsample the image;
      * the flattened features are concatenated with the label vector ``y``;
      * a fully connected layer of width ``hidden_width`` (a module-level
        setting) feeds two linear heads.

    The hard-coded 7 * 7 feature size assumes 28 x 28 single-channel inputs
    (28 -> 14 -> 7).

    Args:
        hidden_channels: channels of the first conv layer; the second has twice as many.
        latent_dim: size of z.
        num_labels: length of the label vector ``y``.

    ``forward(x, y)`` returns ``(mu, logvar)``, each of shape (N, latent_dim).
    Despite its name, the second output is used by ``encoding`` as a log
    standard deviation.
    """

    def __init__(self, hidden_channels: int, latent_dim: int, num_labels: int) -> None:
        super().__init__()
        halve = dict(kernel_size=4, stride=2, padding=1)  # each conv halves H and W
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=hidden_channels, **halve)
        self.conv2 = nn.Conv2d(in_channels=hidden_channels, out_channels=2 * hidden_channels, **halve)

        conv_features = 2 * hidden_channels * 7 * 7
        self.fc = nn.Linear(conv_features + num_labels, hidden_width)
        self.fc_mu = nn.Linear(in_features=hidden_width, out_features=latent_dim)
        self.fc_logvar = nn.Linear(in_features=hidden_width, out_features=latent_dim)

        self.activation = nn.ReLU()

    def forward(
        self, x: torch.Tensor, y: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.activation(conv(features))
        flat = features.view(features.size(0), -1)
        joint = torch.cat((flat, y), dim=1)
        hidden = self.activation(self.fc(joint))
        return self.fc_mu(hidden), self.fc_logvar(hidden)
    

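## GBZ

Generative classifier with factorisation $p(z)\,p(x \mid z)\,p(y \mid z)$, trained with the bound below.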

In [4]:
def lowerbound_gbz(x, y, enc, dec, K=1, IS=False, beta=1.0):
    """Monte-Carlo lower bound on log p(x, y) for the GBZ model p(z) p(x|z) p(y|z).

    For each of K samples z_k ~ q(z | x, y):

        bound_k = beta * log p(x|z_k) + log p(y|z_k) + log p(z_k) - log q(z_k | x, y)

    Here p(z) is a standard normal, and log p(x|z) is the negative squared
    reconstruction error (a Gaussian with variance 1/2, up to a constant).

    Args:
        x: image batch whose leading axis is the batch size N.
        y: one-hot label batch, shape (N, num_labels).
        enc: encoder, ``enc(x, y) -> (mu, log_sig)``.
        dec: decoder, ``dec(z_flat) -> (x_recon, y_logits)`` for a flat
            (K*N, D) batch of latent codes.
        K: number of latent samples per example.
        IS: if True and K > 1, return the importance-weighted bound
            ``logsumexp_k(bound_k) - log K``.
        beta: weight on the reconstruction term log p(x|z).

    Returns:
        Shape (N,) when K == 1 or when IS is used with K > 1; otherwise (K, N).
    """
    z, logq = encoding(enc, x, y, K)  # z: (K, N, D), logq: (K, N)
    log_prior_z = log_gaussian_prob(z, torch.zeros_like(z), torch.zeros_like(z))

    # Decode all K*N codes in one pass. flatten(end_dim=1) is K-major, so
    # reshaping to (K, -1, ...) restores the per-sample layout.
    x_recon, y_logits = dec(torch.flatten(z, end_dim=1))
    x_recon = x_recon.reshape(K, -1, *x_recon.shape[1:])

    recon_axes = list(range(2, x.dim() + 1))
    logp = -torch.sum((x.unsqueeze(0) - x_recon) ** 2, dim=recon_axes)

    # F.cross_entropy applies log_softmax itself, so it must receive the raw
    # decoder logits; passing softmax probabilities would normalise twice.
    y_rep = torch.stack([y] * K, dim=0)  # (K, N, C), K-major like y_logits
    log_pyz = -F.cross_entropy(
        y_logits, y_rep.flatten(end_dim=1), reduction="none"
    ).reshape(y_rep.shape[:-1])  # (K, N)

    bound = logp * beta + log_pyz + (log_prior_z - logq)  # (K, N)
    if IS and K > 1:
        return torch.logsumexp(bound, dim=0) - np.log(float(K))
    # Drop only the sample axis so that a batch of one keeps its (1,) shape.
    return bound.squeeze(0) if K == 1 else bound


class Decoder_gbz(nn.Module):
    """Decoder for the GBZ model: maps a latent code z to label logits and an image.

    Two heads share the input z:
      * label head: z -> 500 -> ``num_labels`` unnormalised logits for p(y | z);
      * image head: z -> (2*hidden_channels, 7, 7) feature map -> two stride-2
        transposed convolutions -> a (1, 28, 28) image squashed into (0, 1).

    Args:
        hidden_channels: channels fed to the last transposed conv; the first
            transposed conv takes twice as many.
        latent_dim: size of z.
        num_labels: number of classes.

    ``forward(z)`` returns ``(x_recon, y_logits)``.
    """

    def __init__(self, hidden_channels: int, latent_dim: int, num_labels: int) -> None:
        super().__init__()
        self.hidden_channels = hidden_channels
        wide = 2 * hidden_channels

        # Label head for p(y | z).
        self.fc_py_z = nn.Linear(latent_dim, 500)
        self.fc_py_z1 = nn.Linear(500, num_labels)

        # Image head for p(x | z): each transposed conv doubles H and W (7 -> 14 -> 28).
        self.fc_px_z = nn.Linear(latent_dim, wide * 7 * 7)
        upsample = dict(kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose2d(wide, hidden_channels, **upsample)
        self.conv1 = nn.ConvTranspose2d(hidden_channels, 1, **upsample)

        self.activation = nn.ReLU()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        logits = self.fc_py_z1(self.activation(self.fc_py_z(z)))

        grid = self.activation(self.fc_px_z(z))
        grid = grid.view(grid.size(0), 2 * self.hidden_channels, 7, 7)
        grid = self.activation(self.conv2(grid))
        image = torch.sigmoid(self.conv1(grid))

        return image, logits


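## DBX

Discriminative classifier with factorisation $p(x)\,p(z \mid x)\,p(y \mid z)$; the "decoder" here models $p(z \mid x)$ and $p(y \mid z)$ rather than reconstructing $x$.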

In [5]:

class Decoder_dbx(nn.Module):
    """Model components for the DBX factorisation p(x) p(z|x) p(y|z).

    Despite the name, this module does not reconstruct ``x``. It provides:
      * the conditional prior p(z | x): a diagonal Gaussian whose parameters are
        computed from the image by two stride-2 convolutions and a fully
        connected layer of width ``hidden_width`` (a module-level setting);
      * the classifier p(y | z): an MLP producing unnormalised label logits.

    The hard-coded 7 * 7 feature size assumes 28 x 28 single-channel inputs.

    Args:
        latent_dim: size of z.
        num_labels: number of classes.
        hidden_channels: channels of the first conv layer; the second has twice as many.
    """

    def __init__(self, latent_dim: int, num_labels: int, hidden_channels: int) -> None:
        super().__init__()
        halve = dict(kernel_size=4, stride=2, padding=1)  # each conv halves H and W

        # p(z | x): image -> conv features -> hidden -> (mu, log_sig).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=hidden_channels, **halve)
        self.conv2 = nn.Conv2d(in_channels=hidden_channels, out_channels=2 * hidden_channels, **halve)
        self.fc = nn.Linear(2 * hidden_channels * 7 * 7, hidden_width)
        self.fc_mu = nn.Linear(in_features=hidden_width, out_features=latent_dim)
        self.fc_logvar = nn.Linear(in_features=hidden_width, out_features=latent_dim)

        # p(y | z): MLP producing label logits.
        self.fc_y = nn.Sequential(
            nn.Linear(latent_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, num_labels),
        )

        self.activation = nn.ReLU()

    def forward(
        self, x: torch.Tensor, z: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(mu, log_sig, y_logits)``.

        ``mu`` and ``log_sig`` parameterise the Gaussian p(z | x) and depend on
        the images ``x`` only; each has shape (N, latent_dim) for an (N, C, H, W)
        ``x``. The second value comes from ``fc_logvar`` but is used as a log
        standard deviation by ``lowerbound_dbx``. ``y_logits`` holds unnormalised
        label scores for every row of ``z``.
        """
        h = x
        for conv in (self.conv1, self.conv2):
            h = self.activation(conv(h))
        h = self.activation(self.fc(h.view(h.size(0), -1)))
        return self.fc_mu(h), self.fc_logvar(h), self.fc_y(z)


def lowerbound_dbx(x, y, enc, dec, K=1, IS=False, beta=1.0):
    """Monte-Carlo bound for the DBX model p(x) p(z|x) p(y|z).

    For each of K samples z_k ~ q(z | x, y):

        bound_k = log p(z_k | x) + log p(y | z_k) - beta * log q(z_k | x, y)

    Args:
        x: image batch whose leading axis is the batch size N.
        y: one-hot label batch, shape (N, num_labels).
        enc: encoder, ``enc(x, y) -> (mu, log_sig)``.
        dec: ``dec(x, z_flat) -> (mu, log_sig, y_logits)``. The first two are the
            (N, D) parameters of the Gaussian p(z|x); ``y_logits`` holds raw label
            scores for the flat (K*N, D) batch of latent codes.
        K: number of latent samples per example.
        IS: if True and K > 1, return the importance-weighted bound
            ``logsumexp_k(bound_k) - log K``.
        beta: weight on the entropy term -log q(z|x, y).

    Returns:
        Shape (N,) when K == 1 or when IS is used with K > 1; otherwise (K, N).
    """
    z, logq = encoding(enc, x, y, K)  # z: (K, N, D), logq: (K, N)
    mu_pz, logsig_pz, y_logits = dec(x, torch.flatten(z, end_dim=1))

    # mu_pz / logsig_pz are (N, D) and broadcast over the K sample axis of z.
    log_pzx = log_gaussian_prob(z, mu_pz, logsig_pz)  # (K, N)

    # F.cross_entropy applies log_softmax itself, so it must receive the raw
    # logits; passing softmax probabilities would normalise twice. y_logits is
    # K-major, matching y_rep.flatten(end_dim=1).
    y_rep = torch.stack([y] * K, dim=0)  # (K, N, C)
    log_pyz = -F.cross_entropy(
        y_logits, y_rep.flatten(end_dim=1), reduction="none"
    ).reshape(y_rep.shape[:-1])  # (K, N)

    bound = log_pzx + log_pyz - beta * logq  # (K, N)
    if IS and K > 1:
        return torch.logsumexp(bound, dim=0) - np.log(float(K))
    # Drop only the sample axis so that a batch of one keeps its (1,) shape.
    return bound.squeeze(0) if K == 1 else bound


## Main

In [6]:

batch_size = 128
learning_rate = 1e-5
beta = 1
num_epochs = 50
train_samples = 2        # latent samples per example (K) for the training bound
classifier_samples = 10  # importance samples per label (K) for the Bayes classifier

img_transform = transforms.Compose([
    transforms.ToTensor()
])

# Model selection: switch both names together to train the GBZ model instead
# (lowerbound = lowerbound_gbz, Decoder = Decoder_gbz).
lowerbound = lowerbound_dbx
Decoder = Decoder_dbx

encoder = Encoder(hidden_channels=hidden_channels, latent_dim=latent_dims, num_labels=num_classes).to(device)
decoder = Decoder(hidden_channels=hidden_channels, latent_dim=latent_dims, num_labels=num_classes).to(device)
optimizer1 = torch.optim.Adam(params=encoder.parameters(), lr=learning_rate)
optimizer2 = torch.optim.Adam(params=decoder.parameters(), lr=learning_rate)

train_dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=img_transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# NOTE(review): the test loader (batch size of at least 10000) is built but not used below.
test_dataset = MNIST(root='./data/MNIST', download=True, train=False, transform=img_transform)
test_dataloader = DataLoader(test_dataset, batch_size=max(10000, batch_size), shuffle=True)

for epoch in tqdm(range(1, num_epochs + 1)):
    loss = []  # per-batch negative bound
    Acc = []   # per-batch training accuracy of the Bayes classifier
    for image_batch, labels in train_dataloader:
        image_batch, labels = image_batch.to(device), labels.to(device)
        y_onehot = F.one_hot(labels, num_classes).float()  # already on `device`, like labels

        # Minimise the negative bound, averaged over samples and batch.
        bound = torch.mean(-lowerbound(image_batch, y_onehot, encoder, decoder,
                                       K=train_samples, beta=beta))

        optimizer1.zero_grad()
        optimizer2.zero_grad()
        bound.backward()
        optimizer1.step()
        optimizer2.step()

        # Accuracy is only monitored, so do not build an autograd graph for the
        # num_classes * classifier_samples extra forward passes.
        with torch.no_grad():
            preds = bayes_classifier(image_batch, encoder, decoder, num_classes, lowerbound,
                                     K=classifier_samples, beta=beta)

        loss.append(bound.item())
        Acc.append((preds.argmax(dim=1) == labels).float().mean().item())
    print(sum(loss) / len(loss))
    print(sum(Acc) / len(Acc))

        

  0%|          | 0/50 [00:00<?, ?it/s]