In [None]:
def _split_encoder_output(encoder_output):
    """Return ``(mu, log_var)`` from an encoder output.

    Accepts either a ``(mu, log_var)`` tuple/list (as returned by the ``Encoder``
    in this notebook) or an object exposing ``.embedding`` / ``.log_covariance``.
    """
    if isinstance(encoder_output, (tuple, list)):
        return encoder_output[0], encoder_output[1]
    return encoder_output.embedding, encoder_output.log_covariance


def _decoder_reconstruction(decoder_output):
    """Return the reconstruction tensor from a decoder output (a dict or a plain tensor)."""
    if isinstance(decoder_output, dict):
        return decoder_output["reconstruction"]
    return decoder_output


def get_nll(self, data, n_samples=1, batch_size=100):
    """
    Estimate the log-likelihood of ``data`` under the model by importance sampling.

    For every data point x, ``n_samples`` latents z_k ~ q(z|x) are drawn and

        log p(x) ~= logsumexp_k [log p(x|z_k) + log p(z_k) - log q(z_k|x)] - log(n_samples)

    where the decoding distribution is a unit-variance Gaussian N(recon(z), I).
    This may take a while.

    NOTE: despite the name, the returned value is the *average log-likelihood*
    (higher is better); negate it to obtain the negative log-likelihood.
    NOTE(review): if the networks contain BatchNorm (the ones in this notebook
    do), call ``self.eval()`` beforehand so running statistics are used.

    Args:
        self: The model. Must provide ``encoder``, ``decoder``, ``_log_p_z`` and
            ``_sample_gauss(mu, std) -> (z, eps)``.
        data (torch.Tensor): The input data, of shape [Batch x n_channels x ...].
        n_samples (int): Number of importance samples per data point (>= 1).
        batch_size (int): Maximum number of samples pushed through the networks
            at once (>= 1), to bound memory use.

    Returns:
        float: Mean over data points of the estimated log p(x).

    Raises:
        ValueError: If ``n_samples`` or ``batch_size`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Split n_samples into chunks of at most batch_size so that exactly
    # n_samples importance samples are drawn per data point.
    n_full, remainder = divmod(n_samples, batch_size)
    chunk_sizes = [batch_size] * n_full + ([remainder] if remainder else [])

    log_p = []

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for i in range(len(data)):
            x = data[i].unsqueeze(0)
            # Normalising constant of N(recon, I) over all D input elements: D/2 * log(2*pi).
            log_norm_const = float(x[0].numel() / 2 * np.log(np.pi * 2))

            log_w = []
            for k in chunk_sizes:
                x_rep = torch.cat(k * [x])

                mu, log_var = _split_encoder_output(self.encoder(x_rep))
                std = torch.exp(0.5 * log_var)
                z, _ = self._sample_gauss(mu, std)

                # Diagonal-Gaussian log q(z|x), summed over ALL latent dimensions
                # (flattened, so spatial/conv latents are handled). The per-dim
                # -0.5*log(2*pi) term is omitted; it cancels only if self._log_p_z
                # omits it as well (VAMPVAE._log_p_z in this notebook does).
                log_q_z_given_x = (
                    -0.5 * (log_var + (z - mu) ** 2 / torch.exp(log_var))
                ).reshape(k, -1).sum(dim=-1)
                log_p_z = self._log_p_z(z)

                recon_x = _decoder_reconstruction(self.decoder(z))

                # Decoding distribution is assumed unit variance: N(recon_x, I).
                log_p_x_given_z = -0.5 * F.mse_loss(
                    recon_x.reshape(k, -1),
                    x_rep.reshape(k, -1),
                    reduction="none",
                ).sum(dim=-1) - log_norm_const

                log_w.append(log_p_x_given_z + log_p_z - log_q_z_given_x)

            log_w = torch.cat(log_w)
            log_p.append((torch.logsumexp(log_w, 0) - np.log(len(log_w))).item())

            if i % 1000 == 0:
                print(f"Current nll at {i}: {np.mean(log_p)}")
    return np.mean(log_p)

In [None]:
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from torchvision.transforms import v2


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as Datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.models as models
import torchvision.utils as vutils
from torch.hub import load_state_dict_from_url
from torchvision import datasets, transforms
from torch.utils.data import random_split

import os
import random
import numpy as np
import math
from IPython.display import clear_output
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import trange, tqdm
from torchvision.transforms import functional as TF
import PIL

# ---- Run / model configuration ---------------------------------------------
latent_channels = 64            # channels of the spatial latent tensor
batch_size = 64
image_size = 128                # images are used at image_size x image_size
lr = 5e-4                       # NOTE(review): the optimizer cell hard-codes lr=0.0001
nepoch = 100
start_epoch = 0
dataset_root = ""               # directory with the *.jpg files; must be set before listing it
model_name = "oxford_pets_VAE_VAMP"
torch.manual_seed(7)  # For reproducibility

# ---- Output locations --------------------------------------------------------
save_dir = os.getcwd()
load_checkpoint = True          # NOTE(review): not read by any visible cell

# ---- Device selection --------------------------------------------------------
use_cuda = torch.cuda.is_available()
gpu_indx = 0
device = torch.device(gpu_indx) if use_cuda else torch.device("cpu")



In [None]:
import torch
from PIL import Image
from torch.utils.data import Dataset
import os

class CustomImageDataset(Dataset):
    """
    Image dataset over files in ``data_dir``, padded with synthetic black and white images.

    Items are returned in order: first the real images, then
    ``num_black_images`` all-zero tensors, then ``num_white_images`` all-one tensors.

    Args:
        data_dir (str): Directory containing the image files.
        file_names (list[str]): File names (relative to ``data_dir``) of the real
            images. The list is copied; the caller's list is not modified.
        transform (callable, optional): Applied to each real image (a PIL RGB image).
        num_black_images (int): Number of all-zero synthetic images to append.
        num_white_images (int): Number of all-one synthetic images to append.
        image_size (tuple[int, int, int]): (C, H, W) shape of the synthetic images.

    NOTE(review): the synthetic images are returned as raw [0, 1] float tensors and
    do NOT pass through ``transform`` (e.g. Normalize), so they are on a different
    scale than transformed real images. Confirm this is intended.
    """

    def __init__(self, data_dir, file_names, transform=None, num_black_images=350, num_white_images=350, image_size=(3, 128, 128)):
        self.data_dir = data_dir
        self.transform = transform
        self.image_size = image_size

        # Placeholder entries stand in for the synthetic images. Build a new list
        # so the caller's file_names list is not mutated.
        self.black_image_placeholder = "<black_image>"
        self.white_image_placeholder = "<white_image>"
        self.file_names = (
            list(file_names)
            + [self.black_image_placeholder] * num_black_images
            + [self.white_image_placeholder] * num_white_images
        )

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        name = self.file_names[idx]

        # Synthetic images are built directly as float tensors in [0, 1]; this gives
        # the same result as the tensor -> PIL -> ToTensor round trip, without the
        # conversions.
        if name == self.black_image_placeholder:
            return torch.zeros(self.image_size)
        if name == self.white_image_placeholder:
            return torch.ones(self.image_size)

        img_path = os.path.join(self.data_dir, name)
        # Context manager closes the underlying file once the RGB copy is made.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image

# ImageNet channel statistics used to normalise the real images.
# NOTE(review): with these statistics pixel values land roughly in [-2.1, 2.6],
# while the Decoder ends in tanh (range [-1, 1]) and the VampPrior pseudo-inputs
# are clamped to [0, 1]. The black/white placeholder images in
# CustomImageDataset also skip this transform. Consider mean=std=0.5 if the
# reconstruction target should lie in [-1, 1]; confirm intent.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: resize to the configured resolution plus random horizontal flips.
train_transform = transforms.Compose([
    transforms.Resize(size=(image_size, image_size)),
    v2.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation: deterministic resize + normalisation only.
test_transform = transforms.Compose([
    transforms.Resize(size=(image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:
#!pip install scikit-learn

from sklearn.model_selection import train_test_split

# Collect every .jpg file in dataset_root. os.listdir returns entries in an
# arbitrary, filesystem-dependent order, so sort first; otherwise random_state=7
# alone does not make the train/test split reproducible across machines.
all_images = sorted(img for img in os.listdir(dataset_root) if img.endswith('.jpg'))  # Adjust for your file type
train_images, test_images = train_test_split(all_images, test_size=0.05, random_state=7)

# NOTE(review): both splits receive 500 black + 350 (default) white synthetic
# images; for the 5% test split these may outnumber the real images. Confirm intent.
train_dataset = CustomImageDataset(data_dir=dataset_root, file_names=train_images, transform=train_transform, num_black_images=500)
test_dataset = CustomImageDataset(data_dir=dataset_root, file_names=test_images, transform=test_transform, num_black_images=500)




In [None]:
# Use the batch size from the config cell instead of repeating the literal 64.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Shuffling the test loader only changes which test batch gets previewed and
# reconstructed in later cells.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Pull one batch from each loader to check tensor shapes. `test_images` is the
# fixed test batch reused for the preview plot and for the per-epoch
# reconstructions saved during training.
# NOTE: this rebinds `train_images` / `test_images`, which previously held the
# file-name lists produced by the train/test split.
# Statement order matters here: creating an iterator over a shuffling DataLoader
# draws from the global torch RNG.
dataiter = iter(train_loader)
train_images = next(dataiter)
train_images.shape

test_dataiter = iter(test_loader)
test_images = next(test_dataiter)
test_images.shape

In [None]:
# Tile up to 64 test images into one grid; normalize=True rescales the grid
# to [0, 1] so normalised tensors display correctly.
out = vutils.make_grid(test_images[:64], normalize=True)
plt.figure(figsize=(20, 10))
# make_grid returns (C, H, W); matplotlib expects (H, W, C).
plt.imshow(out.permute(1, 2, 0).numpy())

In [None]:
import torch
import torch.nn as nn
import torch.utils.data


class ResDown(nn.Module):
    """
    Residual down-sampling block for the encoder.

    Both the main path and the skip path start with a stride-2 convolution, so the
    spatial size is reduced by a factor of 2. The block maps ``channel_in`` to
    ``channel_out`` channels through a ``channel_out // 2`` bottleneck, and
    BatchNorm + SiLU are applied after the residual addition.
    """

    def __init__(self, channel_in, channel_out, kernel_size=3):
        super().__init__()
        mid = channel_out // 2
        pad = kernel_size // 2

        # Main path: strided conv -> BN -> SiLU -> conv
        self.conv1 = nn.Conv2d(channel_in, mid, kernel_size, 2, pad)
        self.bn1 = nn.BatchNorm2d(mid, eps=1e-4)
        self.conv2 = nn.Conv2d(mid, channel_out, kernel_size, 1, pad)
        self.bn2 = nn.BatchNorm2d(channel_out, eps=1e-4)

        # Skip path: a single strided conv so shapes match the main path
        self.conv3 = nn.Conv2d(channel_in, channel_out, kernel_size, 2, pad)

        self.act_fnc = nn.SiLU()

    def forward(self, x):
        residual = self.conv3(x)
        h = self.conv2(self.act_fnc(self.bn1(self.conv1(x))))
        return self.act_fnc(self.bn2(h + residual))


class ResUp(nn.Module):
    """
    Residual up-sampling block for the decoder.

    The input is first nearest-neighbour up-sampled by ``scale_factor``. A residual
    unit then maps ``channel_in`` to ``channel_out`` channels through a
    ``channel_in // 2`` bottleneck, and BatchNorm + ELU are applied after the
    residual addition.
    """

    def __init__(self, channel_in, channel_out, kernel_size=3, scale_factor=2):
        super().__init__()
        mid = channel_in // 2
        pad = kernel_size // 2

        # Main path: conv -> BN -> ELU -> conv
        self.conv1 = nn.Conv2d(channel_in, mid, kernel_size, 1, pad)
        self.bn1 = nn.BatchNorm2d(mid, eps=1e-4)
        self.conv2 = nn.Conv2d(mid, channel_out, kernel_size, 1, pad)
        self.bn2 = nn.BatchNorm2d(channel_out, eps=1e-4)

        # Skip path: one conv to match the output channel count
        self.conv3 = nn.Conv2d(channel_in, channel_out, kernel_size, 1, pad)

        self.up_nn = nn.Upsample(scale_factor=scale_factor, mode="nearest")
        self.act_fnc = nn.ELU()

    def forward(self, x):
        upsampled = self.up_nn(x)
        residual = self.conv3(upsampled)
        h = self.conv2(self.act_fnc(self.bn1(self.conv1(upsampled))))
        return self.act_fnc(self.bn2(h + residual))


class ResBlock(nn.Module):
    """
    Residual block that preserves the spatial size.

    Main path: conv -> BN -> act -> conv (through a ``channel_in // 2`` bottleneck).
    Skip path: identity when ``channel_in == channel_out``, otherwise a conv.
    BatchNorm + activation are applied after the residual addition.

    Args:
        channel_in (int): Input channels.
        channel_out (int): Output channels.
        kernel_size (int): Convolution kernel size (odd sizes keep H and W).
        act (nn.Module, optional): Activation module. Defaults to a new
            ``nn.ELU()`` for each block.
    """

    def __init__(self, channel_in, channel_out, kernel_size=3, act=None):
        super(ResBlock, self).__init__()
        pad = kernel_size // 2

        self.conv1 = nn.Conv2d(channel_in, channel_in // 2, kernel_size, 1, pad)
        self.bn1 = nn.BatchNorm2d(channel_in // 2, eps=1e-4)
        self.conv2 = nn.Conv2d(channel_in // 2, channel_out, kernel_size, 1, pad)
        self.bn2 = nn.BatchNorm2d(channel_out, eps=1e-4)

        if channel_in != channel_out:
            self.conv3 = nn.Conv2d(channel_in, channel_out, kernel_size, 1, pad)
        else:
            self.conv3 = nn.Identity()

        # A default of `act=nn.ELU()` in the signature is evaluated once, so every
        # ResBlock relying on it would share (and register) the same module
        # instance. Create a fresh one per block instead.
        self.act_fnc = nn.ELU() if act is None else act

    def forward(self, x):
        skip = self.conv3(x)
        x = self.act_fnc(self.bn1(self.conv1(x)))
        x = self.conv2(x)

        return self.act_fnc(self.bn2(x + skip))


class Encoder(nn.Module):
    """
    Convolutional encoder that outputs the parameters of a spatial Gaussian posterior.

    Pipeline: ``conv_in`` -> one ``ResDown`` per entry of ``blocks`` (each uses
    stride 2) -> two SiLU ``ResBlock``s -> two 1x1 convolutions giving ``mu`` and
    ``log_var``, each with ``latent_channels`` channels.

    Args:
        channels (int): Input image channels.
        ch (int): Base channel width.
        blocks (tuple[int, ...]): Channel multipliers per level.
        latent_channels (int): Channels of the latent tensor.
    """

    def __init__(self, channels, ch=64, blocks=(1, 2, 4, 8), latent_channels=512):
        super().__init__()
        self.conv_in = nn.Conv2d(channels, blocks[0] * ch, 3, 1, 1)

        mults = list(blocks)
        # Each ResDown goes from one level's width to the next; the last keeps the top width.
        pairs = zip(mults, mults[1:] + mults[-1:])
        layers = [ResDown(m_in * ch, m_out * ch) for m_in, m_out in pairs]

        top_width = blocks[-1] * ch
        layers.extend(
            ResBlock(top_width, top_width, kernel_size=3, act=nn.SiLU()) for _ in range(2)
        )
        self.res_blocks = nn.Sequential(*layers)

        self.conv_mu = nn.Conv2d(top_width, latent_channels, 1, 1)
        self.conv_log_var = nn.Conv2d(top_width, latent_channels, 1, 1)
        self.act_fnc = nn.SiLU()

    def sample(self, mu, log_var):
        """Reparameterised draw mu + eps * exp(0.5 * log_var) with eps ~ N(0, I)."""
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x, sample=False):
        """Return ``(mu, log_var)``. The ``sample`` flag is accepted but ignored."""
        h = self.res_blocks(self.act_fnc(self.conv_in(x)))
        return self.conv_mu(h), self.conv_log_var(h)


class Decoder(nn.Module):
    """
    Convolutional decoder, built as a mirror of the encoder.

    Pipeline: 1x1 ``conv_in`` from the latent -> two ELU ``ResBlock``s -> one
    ``ResUp`` per entry of ``blocks`` (widest level first) -> 3x3 ``conv_out``
    followed by ``tanh``, so outputs lie in [-1, 1].

    Args:
        channels (int): Output image channels.
        ch (int): Base channel width.
        blocks (tuple[int, ...]): Channel multipliers per level (same as the encoder's).
        latent_channels (int): Channels of the latent tensor.
    """

    def __init__(self, channels, ch=64, blocks=(1, 2, 4, 8), latent_channels=512):
        super().__init__()
        top_width = ch * blocks[-1]
        self.conv_in = nn.Conv2d(latent_channels, top_width, 1, 1)

        # e.g. blocks=(1, 2, 4, 8): mults_out=[8, 4, 2, 1], mults_in=[8, 8, 4, 2]
        mults_out = list(blocks)[::-1]
        mults_in = mults_out[:1] + mults_out[:-1]

        layers = [ResBlock(top_width, top_width, kernel_size=3, act=nn.ELU()) for _ in range(2)]
        layers += [ResUp(m_in * ch, m_out * ch) for m_in, m_out in zip(mults_in, mults_out)]
        self.res_blocks = nn.Sequential(*layers)

        self.conv_out = nn.Conv2d(blocks[0] * ch, channels, 3, 1, 1)
        self.act_fnc = nn.ELU()

    def forward(self, x):
        h = self.res_blocks(self.act_fnc(self.conv_in(x)))
        return torch.tanh(self.conv_out(h))

In [None]:
class VAMPVAE(nn.Module):
    """
    Convolutional VAE with a VampPrior (mixture-of-variational-posteriors prior).

    The prior over the spatial latent ``z`` is a uniform mixture of
    ``number_components`` diagonal Gaussians. Each component is obtained by passing
    a learned pseudo-input image through the model's own encoder.

    Args:
        input_channels (int): Number of image channels.
        latent_channels (int): Channels of the latent tensor.
        ch (int): Base channel width of encoder/decoder.
        blocks (tuple[int, ...]): Channel multipliers per resolution level.
        number_components (int): Number of pseudo-inputs / mixture components.
    """

    def __init__(self, input_channels, latent_channels=512, ch=64, blocks=(1, 2, 4, 8), number_components=10):
        super(VAMPVAE, self).__init__()

        self.input_channels = input_channels
        self.latent_channels = latent_channels
        # Spatial size of the inputs the encoder is expected to receive.
        self.expected_height = 128
        self.expected_width = 128

        self.encoder = Encoder(channels=input_channels, ch=ch, blocks=blocks, latent_channels=latent_channels)
        self.decoder = Decoder(channels=input_channels, ch=ch, blocks=blocks, latent_channels=latent_channels)
        self.number_components = number_components

        # Not read by this class; kept for backward compatibility.
        self.latent_height = 8
        self.latent_width = 8
        # Pseudo-inputs are fed through self.encoder, so they must match the real input size.
        self.pseudo_input_height = 128
        self.pseudo_input_width = 128

        # Maps a one-hot component index to a flattened pseudo-input image clamped to [0, 1].
        self.pseudo_inputs = nn.Sequential(
            nn.Linear(number_components, input_channels * self.pseudo_input_height * self.pseudo_input_width),
            nn.Hardtanh(0.0, 1.0)
        )

        # Fixed one-hot rows selecting each pseudo-input. Registered as a
        # non-persistent buffer so .to(device) moves it, while state_dict (and hence
        # existing checkpoints) is unchanged.
        self.register_buffer("idle_input", torch.eye(number_components), persistent=False)

    def forward(self, x, epoch=100):
        """
        Encode, sample ``z`` with the reparameterisation trick, decode, and compute losses.

        ``z`` is always sampled (also in eval mode). ``epoch`` is forwarded to
        ``loss_function`` but currently unused there.

        Returns:
            dict with 'recon_loss' (scalar), 'reg_loss' (per-sample KL term),
            'loss' (per-sample total), 'recon_x' and 'z'.
        """
        mu, log_var = self.encoder(x)
        std = torch.exp(0.5 * log_var)
        z, _ = self._sample_gauss(mu, std)

        recon_x = self.decoder(z)

        loss, recon_loss, kld = self.loss_function(recon_x, x, mu, log_var, z, epoch)

        return {
            'recon_loss': recon_loss,
            'reg_loss': kld,
            'loss': loss,
            'recon_x': recon_x,
            'z': z}

    def loss_function(self, recon_x, x, mu, log_var, z, epoch):
        """
        Return ``(recon_loss + beta * KLD, recon_loss, KLD)``.

        ``recon_loss`` is the summed squared error divided by the batch size (a
        scalar). ``KLD`` is the single-sample estimate log q(z|x) - log p(z), one
        value per batch element, so the total loss is per-sample as well.
        """
        recon_loss = F.mse_loss(recon_x, x, reduction='sum') / x.size(0)

        log_p_z = self._log_p_z(z)
        # NOTE(review): as written, `+ 1e-4` is added to each summand, not to the
        # variance in the denominator. It only shifts KLD by a constant and has no
        # effect on gradients; confirm whether `/ (log_var.exp() + 1e-4)` was meant.
        log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp() + 1e-4)).sum(dim=[1, 2, 3])

        KLD = -(log_p_z - log_q_z)

        # Constant KL weight (no annealing schedule is applied).
        beta = 1.0

        return recon_loss + beta * KLD, recon_loss, KLD

    def _log_p_z(self, z):
        """
        Log-density of ``z`` under the VampPrior, one value per batch element.

        The -0.5*log(2*pi) per-dimension constant is omitted. log q(z|x) in
        ``loss_function`` omits it too, so it cancels in the KL term.
        """
        C = self.number_components
        pseudo_inputs = self.pseudo_inputs(self.idle_input.to(z.device))
        pseudo_inputs = pseudo_inputs.view(
            C, self.input_channels, self.pseudo_input_height, self.pseudo_input_width
        )

        # Each pseudo-input's posterior is one mixture component.
        pseudo_mu, pseudo_log_var = self.encoder(pseudo_inputs)
        pseudo_mu_flat = pseudo_mu.reshape(C, -1)            # (C, D)
        pseudo_log_var_flat = pseudo_log_var.reshape(C, -1)  # (C, D)

        z_expand = z.reshape(z.size(0), 1, -1)               # (B, 1, D)

        # Per-component diagonal-Gaussian log-density, then log-mean-exp over components.
        log_p_z = -0.5 * (pseudo_log_var_flat + (z_expand - pseudo_mu_flat) ** 2 / torch.exp(pseudo_log_var_flat))
        log_p_z = log_p_z.sum(dim=2) - torch.log(torch.tensor(C, dtype=torch.float, device=z.device))
        log_p_z = torch.logsumexp(log_p_z, dim=1)

        return log_p_z

    def _sample_gauss(self, mu, std):
        """Reparameterised sample: return ``(mu + eps * std, eps)`` with eps ~ N(0, I)."""
        eps = torch.randn_like(std)
        return mu + eps * std, eps
    
    

In [None]:
# Build the model from the config cell's values rather than repeating the literals.
vae_net = VAMPVAE(input_channels=3, ch=64, blocks=(1, 2, 4, 8),
                  latent_channels=latent_channels, number_components=50).to(device)
# setup optimizer
# NOTE(review): lr is hard-coded to 1e-4 here; the config cell's `lr` (5e-4) is not used.
optimizer = optim.Adam(vae_net.parameters(), lr=0.0001, betas=(0.5, 0.999))
# Per-step training loss history (also stored in checkpoints)
loss_log = []
from torchsummary import summary
summary(vae_net, (3, image_size, image_size))

In [None]:
# Stand-alone encoder instance for inspecting the architecture
# (not used by the visible training code).
enc = Encoder(channels=3, ch=64, blocks=(1, 2, 4, 8), latent_channels=64)
enc = enc.to(device)

In [None]:
# Smoke test on a random batch of 4 images at the configured resolution.
# (Previously `sample_input` was commented out, so this cell raised NameError.)
sample_input = torch.randn(4, 3, image_size, image_size, device=device)
# NOTE: despite its name, this holds the per-sample loss tensor (shape (4,)), not a reconstruction.
reconstructed_output = vae_net(sample_input)['loss']
print(reconstructed_output.shape)

In [None]:
# Run the smoke-test batch through the model again and display the shape of the
# per-sample loss (relies on `sample_input` being defined by an earlier cell).
loss = vae_net(sample_input)['loss']
loss.shape

In [None]:
nepoch = 300

# save_image / torch.save do not create missing directories, so make sure the
# output folders used below exist.
os.makedirs(os.path.join(save_dir, "Results"), exist_ok=True)
os.makedirs(os.path.join(save_dir, "Models"), exist_ok=True)

for epoch in trange(start_epoch, nepoch, leave=False):
    train_loss = 0
    train_recon_loss = 0
    train_kld_loss = 0

    vae_net.train()
    for i, images in enumerate(tqdm(train_loader, leave=False)):
        images = images.to(device)

        # Exactly ONE forward pass per batch. Previously the model was called four
        # times, which quadrupled compute, updated BatchNorm statistics four times,
        # and logged loss terms from different z samples than the loss that was
        # back-propagated.
        outputs = vae_net(images)
        kl_loss_ = outputs['reg_loss'].mean()
        mse_loss = outputs['recon_loss']
        loss = outputs['loss'].mean()

        loss_value = loss.item()
        train_loss += loss_value
        train_recon_loss += mse_loss.item()
        train_kld_loss += kl_loss_.item()
        loss_log.append(loss_value)

        vae_net.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vae_net.parameters(), max_norm=1.0)

        optimizer.step()

    # Average per-batch losses for the epoch (guard against an empty loader).
    n_batches = max(len(train_loader), 1)
    avg_loss = train_loss / n_batches
    avg_recon_loss = train_recon_loss / n_batches
    avg_kld_loss = train_kld_loss / n_batches

    # Eval mode makes BatchNorm use its running statistics. Note that
    # VAMPVAE.forward still samples z rather than using mu.
    vae_net.eval()
    with torch.no_grad():
        # Save reconstructions stacked above the fixed test batch (concatenated along height).
        recon_img = vae_net(test_images.to(device))['recon_x']
        img_cat = torch.cat((recon_img.cpu(), test_images), 2)

        vutils.save_image(img_cat,
                          "%s/%s/%s_%d.png" % (save_dir, "Results" , model_name, image_size),
                          normalize=True)

        # Save a checkpoint (overwritten every epoch)
        torch.save({
                    'epoch'                         : epoch,
                    'loss_log'                      : loss_log,
                    'model_state_dict'              : vae_net.state_dict(),
                    'optimizer_state_dict'          : optimizer.state_dict()

                     }, save_dir + "/Models/" + model_name + "_" + str(image_size) + ".pt")
    print(f'Epoch {epoch}/{nepoch} - Avg Total Loss: {avg_loss} - Avg Recon Loss: {avg_recon_loss}'
          f' - Avg KLD Loss: {avg_kld_loss}')
