For the CelebA dataset (both VAE and beta-VAE)

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from typing import List, Callable, Union, Any, TypeVar, Tuple
import pytorch_lightning as pl
from torchvision import datasets, transforms
from torch.autograd import Variable

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np

In [None]:
def plottingGrid(model, test_loader, type):
    """Save and show a 10x10 grid of the model's outputs for the first batch.

    The first batch of ``test_loader`` is passed through ``model``; the first
    element of the returned sequence is treated as a batch of images shaped
    (N, C, H, W). Up to 100 of them are drawn (remaining cells stay empty when
    the batch is smaller), written to ``grid_{type}_thirty_epochs.png`` in the
    current directory, and displayed.

    The forward pass runs in eval mode under ``torch.no_grad()`` so plotting
    neither builds an autograd graph nor updates BatchNorm running statistics;
    the model's previous train/eval mode is restored afterwards.

    Args:
        model: module whose forward returns a sequence with the image batch
            first. It is moved to ``cuda:0`` when CUDA is available.
        test_loader: iterable of batches. A batch may be a bare image tensor
            or a tuple/list whose first item is the image tensor.
        type: tag inserted into the output filename. (The name shadows the
            builtin; it is kept so existing keyword callers keep working.)

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    batch = next(iter(test_loader), None)
    if batch is None:
        raise ValueError('test_loader yielded no batches')
    # Labelled loaders yield (images, labels); unlabelled ones yield images only.
    x = batch[0] if isinstance(batch, (list, tuple)) else batch

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            recon = model(x.to(device))[0]
    finally:
        model.train(was_training)

    # Channels-last layout is what imshow expects.
    images = recon[:100].detach().cpu().permute(0, 2, 3, 1).numpy()

    fig = plt.figure(figsize=(12., 12.))
    grid = ImageGrid(fig, 111,  # similar to subplot(111)
                     nrows_ncols=(10, 10),  # creates a 10x10 grid of axes
                     axes_pad=0.1,  # pad between axes in inch.
                     )

    # Single-channel images are drawn in grayscale; cmap is ignored for RGB.
    cmap = 'gray' if images.shape[-1] == 1 else None
    for ax, im in zip(grid, images):
        # Iterating over the grid returns the Axes.
        ax.imshow(im, cmap=cmap)

    fig.savefig(f'grid_{type}_thirty_epochs.png')
    plt.show()

    return

In [None]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        """Return ``input`` viewed as (input.size(0), -1)."""
        leading = input.shape[0]
        flat_shape = (leading, -1)
        return input.view(flat_shape)

class UnFlatten(nn.Module):
    """Reshape a flat batch into feature maps of spatial size 19 x 14."""

    def forward(self, input, size=256):
        """Return ``input`` viewed as (input.size(0), size, 19, 14)."""
        target_shape = (input.shape[0], size, 19, 14)
        return input.view(target_shape)

class VAE(pl.LightningModule):
    """Convolutional variational autoencoder trained with summed MSE + KL.

    The encoder is four Conv2d / BatchNorm2d / LeakyReLU stages followed by
    ``Flatten``. ``fc1`` and ``fc2`` map the flattened features to the latent
    mean and log-variance, ``fc3`` maps a latent sample back to ``h_dim``
    features, and the decoder mirrors the encoder with transposed
    convolutions ending in BatchNorm2d + Sigmoid.

    Tensors are never pinned to a fixed GPU: inputs are moved to the module's
    own device and noise is sampled where ``mu`` lives, so the model runs on
    whichever device Lightning (or the caller) places it.

    Args:
        image_channels: number of channels of the input/output images.
        h_dim: size of the flattened encoder output. ``UnFlatten`` reshapes to
            (N, 256, 19, 14) by default, which matches the default
            ``19*14*256``.
        z_dim: dimensionality of the latent space.
        lr: learning rate passed to Adam.
    """

    # Number of training steps seen; only used to throttle loss printing.
    counter = 0

    def __init__(self, image_channels=3, h_dim=19*14*256, z_dim=32, lr = 1e-3):
        # Initialise the Module machinery before assigning any attribute.
        super().__init__()
        self.lr = lr
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=5, stride=2, padding = 2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding = 2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding = 2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, kernel_size=10, stride=1, padding = 0),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            Flatten()
        )

        self.fc1 = nn.Linear(h_dim, z_dim)  # latent mean
        self.fc2 = nn.Linear(h_dim, z_dim)  # latent log-variance
        self.fc3 = nn.Linear(z_dim, h_dim)  # latent sample -> decoder features

        self.decoder = nn.Sequential(
            UnFlatten(),
            nn.ConvTranspose2d(256, 128, kernel_size=10, stride=1, padding = 0),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding = 2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2, padding = 2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=5, stride=2, padding = 2, output_padding=1),
            # Must follow image_channels, not a literal 3, or other channel
            # counts fail at the final normalisation.
            nn.BatchNorm2d(image_channels),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + std * eps`` with ``eps ~ N(0, I)``.

        The noise is created with the device and dtype of ``std`` so no
        host-to-device copy is needed.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def bottleneck(self, h):
        """Return ``(z, mu, logvar)`` for flattened encoder features ``h``."""
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def representation(self, x):
        """Return a latent sample for the image batch ``x``."""
        x = x.to(self.device)
        return self.bottleneck(self.encoder(x))[0]

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            list ``[reconstruction, mu, logvar]``.
        """
        # Inputs on another device (e.g. straight from a CPU DataLoader) are
        # moved to wherever the model's parameters live.
        x = x.to(self.device)
        h = self.encoder(x)
        z, mu, logvar = self.bottleneck(h)
        return [self.decoder(self.fc3(z)), mu, logvar]

    def loss_fn(self, recon_x, x, mu, logvar):
        """Summed binary cross-entropy plus KL divergence.

        Alternative objective; ``training_step`` uses ``loss_function``.
        """
        BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
        kl_divergence = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
        return BCE + kl_divergence

    def loss_function(self,recons,x,mu,logvar):
        """Summed MSE reconstruction loss plus 0.5 * KL divergence.

        The KL term is summed over latent dimensions and over the batch. The
        weight is the constant 0.5; no minibatch-size / dataset-size scaling
        is applied.
        """
        kld_weight = 0.5
        recons_loss = F.mse_loss(recons, x, reduction="sum")
        kld_loss = torch.sum(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim = 1), dim = 0)
        return recons_loss + kld_weight * kld_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        """Compute and log the loss for one ``(images, labels)`` batch.

        The labels are ignored. The loss is printed every 100 steps.
        """
        self.counter += 1
        x, _ = train_batch
        x = x.float()
        recons, mu, logvar = self(x)
        loss = self.loss_function(recons, x, mu, logvar)
        if self.counter % 100 == 0:
            print(loss)
        self.log('train_loss', loss)
        return loss

In [None]:
class betaVAE(pl.LightningModule):
    """Convolutional beta-VAE: summed MSE plus ``beta``-weighted KL.

    The encoder is four Conv2d / BatchNorm2d / LeakyReLU stages followed by
    ``Flatten``. ``fc1`` and ``fc2`` map the flattened features to the latent
    mean and log-variance, ``fc3`` maps a latent sample back to ``h_dim``
    features, and the decoder mirrors the encoder with transposed
    convolutions ending in BatchNorm2d + Sigmoid.

    Tensors are never pinned to a fixed GPU: inputs are moved to the module's
    own device and noise is sampled where ``mu`` lives, so the model runs on
    whichever device Lightning (or the caller) places it.

    Args:
        image_channels: number of channels of the input/output images.
        h_dim: size of the flattened encoder output. ``UnFlatten`` reshapes to
            (N, 256, 19, 14) by default, which matches the default
            ``19*14*256``.
        z_dim: dimensionality of the latent space.
        lr: learning rate passed to Adam.
        beta: multiplier applied to the KL term during training.
    """

    # Number of training steps seen; only used to throttle loss printing.
    counter = 0

    def __init__(self, image_channels=3, h_dim=19*14*256, z_dim=32, lr = 1e-3, beta = 150):
        # Initialise the Module machinery before assigning any attribute.
        super().__init__()
        self.lr = lr
        self.beta = beta
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=5, stride=2, padding = 2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding = 2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding = 2),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, kernel_size=10, stride=1, padding = 0),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            Flatten()
        )

        self.fc1 = nn.Linear(h_dim, z_dim)  # latent mean
        self.fc2 = nn.Linear(h_dim, z_dim)  # latent log-variance
        self.fc3 = nn.Linear(z_dim, h_dim)  # latent sample -> decoder features

        self.decoder = nn.Sequential(
            UnFlatten(),
            nn.ConvTranspose2d(256, 128, kernel_size=10, stride=1, padding = 0),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding = 2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2, padding = 2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=5, stride=2, padding = 2, output_padding=1),
            # Must follow image_channels, not a literal 3, or other channel
            # counts fail at the final normalisation.
            nn.BatchNorm2d(image_channels),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + std * eps`` with ``eps ~ N(0, I)``.

        The noise is created with the device and dtype of ``std`` so no
        host-to-device copy is needed.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def bottleneck(self, h):
        """Return ``(z, mu, logvar)`` for flattened encoder features ``h``."""
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def representation(self, x):
        """Return a latent sample for the image batch ``x``."""
        x = x.to(self.device)
        return self.bottleneck(self.encoder(x))[0]

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            list ``[reconstruction, mu, logvar]``.
        """
        # Inputs on another device (e.g. straight from a CPU DataLoader) are
        # moved to wherever the model's parameters live.
        x = x.to(self.device)
        h = self.encoder(x)
        z, mu, logvar = self.bottleneck(h)
        return [self.decoder(self.fc3(z)), mu, logvar]

    def loss_function(self, recons, x, mu, logvar, beta):
        """Summed MSE reconstruction loss plus ``beta * 0.5 * KL``.

        The KL term is summed over latent dimensions and over the batch. The
        base weight is the constant 0.5; no minibatch-size / dataset-size
        scaling is applied.
        """
        kld_weight = 0.5
        recons_loss = F.mse_loss(recons, x, reduction="sum")
        kld_loss = torch.sum(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim = 1), dim = 0)
        return recons_loss + beta * kld_weight * kld_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        """Compute and log the loss for one ``(images, labels)`` batch.

        The labels are ignored. The loss is printed every 50 steps.
        """
        self.counter += 1
        x, _ = train_batch
        x = x.float()
        recons, mu, logvar = self(x)
        loss = self.loss_function(recons, x, mu, logvar, self.beta)
        if self.counter % 50 == 0:
            print(loss)
        self.log('train_loss', loss)
        return loss

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np

# Run the current global ``model`` on the first batch of ``test_loader`` and
# save a 10x10 grid of the first element of its output.
model = model.to('cuda:0')

# Only one batch is needed, so take it directly instead of looping with a
# counter and breaking.
batch = next(iter(test_loader))
print("*****************************************************************************************************************")
x, y = batch

# Plotting needs no gradients; skip building the autograd graph.
with torch.no_grad():
    z = model(x.cuda(0))

fig = plt.figure(figsize=(12., 12.))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(10, 10),  # creates a 10x10 grid of axes
                 axes_pad=0.1,  # pad between axes in inch.
                 )
print(z[0].shape)

# Slice instead of indexing 0..99 so a batch smaller than 100 leaves the
# remaining cells empty rather than raising IndexError.
latte = [img.permute(1, 2, 0).cpu().detach().numpy() for img in z[0][:100]]

for ax, im in zip(grid, latte):
    # Iterating over the grid returns the Axes.
    ax.imshow(im)

fig.savefig('grid_beta_1_vae_thirty_epochs.png')
plt.show()

In [None]:
if __name__ == '__main__':
    # CelebA experiment: train a vanilla VAE and beta-VAEs with beta in
    # {1, 4, 150}, saving a reconstruction grid after each run.

    # for celeba dataset
    data_path = '../data/celeba/img_align_celeba'

    dataset = datasets.ImageFolder(
        root=data_path,
        transform=transforms.ToTensor()
    )

    # Fixed-seed split: 200000 training images, the rest held out. The
    # hold-out size is derived from the dataset so the lengths always add up.
    n_train = 200000
    sets = torch.utils.data.random_split(
        dataset,
        [n_train, len(dataset) - n_train],
        generator=torch.Generator().manual_seed(2147483647),
    )

    train_loader = torch.utils.data.DataLoader(
        sets[0],
        batch_size=128,
        num_workers=6,
        shuffle=True
    )
    # Unshuffled view of the training subset.
    train_test_loader = torch.utils.data.DataLoader(
        sets[0],
        batch_size=128,
        num_workers=6,
    )
    test_loader = torch.utils.data.DataLoader(
        sets[1],
        batch_size=128,
        num_workers=6,
    )

    def _fit_and_plot(net, tag):
        """Train ``net`` for 30 epochs on one GPU, then save its grid."""
        # ``auto_scale_batch_size`` is intentionally not passed: it only takes
        # effect through ``trainer.tune()``, which was never called, and the
        # argument no longer exists on Trainer in Lightning 2.x.
        fit_trainer = pl.Trainer(max_epochs = 30, devices = 1, accelerator='gpu')
        fit_trainer.fit(net, train_loader)
        plottingGrid(net, test_loader, tag)
        return fit_trainer

    # Vanilla VAE
    model = VAE(image_channels=3, z_dim=32, lr =1e-5 )
    trainer = _fit_and_plot(model, 'vanillaVAE')

    # beta VAE
    # beta = 1
    model1 = betaVAE(image_channels=3, z_dim=32, lr =1e-5, beta = 1)
    trainer = _fit_and_plot(model1, 'betaVAE_1')

    # beta = 4
    model4 = betaVAE(image_channels=3, z_dim=32, lr =1e-5, beta = 4)
    trainer = _fit_and_plot(model4, 'betaVAE_4')

    # beta = 150
    model150 = betaVAE(image_channels=3, z_dim=32, lr =1e-5, beta = 150)
    trainer = _fit_and_plot(model150, 'betaVAE_150')

For the dSprites dataset

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from typing import List, Callable, Union, Any, TypeVar, Tuple
import pytorch_lightning as pl
from torchvision import datasets, transforms
from torch.autograd import Variable
import numpy as np
import abc

class decoder(nn.Module):
    """Map latent codes to images with fc layers and transposed convolutions."""

    def __init__(self, img_size, z_dim = 10):
        """
        Layers created here:
        - 3 fc layers: z_dim -> 256 -> 256 -> 64
        - 4 transposed convolutions (kernel 4x4, stride 2, padding 1) on a
          single hidden channel; conv3 outputs img_size[0] channels.
          conv2 is created but not used by forward.
        """
        super().__init__()

        # hyper-parameters
        hidden_channels, kernel_size, h_dim = 1, 4, 256
        self.img_size = img_size
        self.reshape = (hidden_channels, kernel_size, kernel_size)
        n_channels = self.img_size[0]
        self.z_dim = z_dim

        # fully connected stack
        self.linear1 = nn.Linear(z_dim, h_dim)
        self.linear2 = nn.Linear(h_dim, h_dim)
        self.linear3 = nn.Linear(h_dim, 64)

        # transposed convolutions, each doubling the spatial size
        upsample = {"stride": 2, "padding": 1}
        self.conv1 = nn.ConvTranspose2d(hidden_channels, hidden_channels, kernel_size, **upsample)
        self.conv2 = nn.ConvTranspose2d(hidden_channels, hidden_channels, kernel_size, **upsample)
        self.conv3 = nn.ConvTranspose2d(hidden_channels, n_channels, kernel_size, **upsample)
        self.conv4 = nn.ConvTranspose2d(hidden_channels, hidden_channels, kernel_size, **upsample)

    def forward(self, z):
        """Decode a batch of latent codes ``z`` into sigmoid-activated images."""
        n_samples = z.size(0)

        # fc layers, each followed by ReLU
        hidden = z
        for fc in (self.linear1, self.linear2, self.linear3):
            hidden = torch.relu(fc(hidden))

        # 64 features -> one 8x8 map per sample
        hidden = hidden.view(n_samples, 1, 8, 8)

        # two ReLU upsampling stages (conv4 then conv1); conv2 is skipped
        for deconv in (self.conv4, self.conv1):
            hidden = torch.relu(deconv(hidden))

        # the final layer uses a sigmoid instead of ReLU
        return torch.sigmoid(self.conv3(hidden))

class encoder(nn.Module):
    """Convolutional encoder producing a latent mean and log-variance."""

    def __init__(self, img_size, z_dim = 10):
        """
        Model Architecture:
        - 3 stride-2 convolutions used by forward (conv1 -> conv3 -> conv4),
          kernel 4x4, one hidden channel; conv2 is created but unused
        - 2 fc layers of 256 units; the first expects 64 flattened features
          (a 64x64 input leaves a 1x8x8 map after three stride-2 convolutions)
        - 1 fc layer with 2 * z_dim outputs, split into mean and log variance

        Args:
            img_size: (channels, height, width) of the input images.
            z_dim: number of latent dimensions.
        """
        super().__init__()

        #params
        hidden_channels = 1
        kernel_size = 4
        h_dim = 256
        self.img_size = img_size
        self.reshape = (hidden_channels, kernel_size, kernel_size)
        n_channels = self.img_size[0]
        self.z_dim = z_dim

        # conv layers
        # conv1 is the first layer applied, so it must accept the image's
        # channel count; every later layer works on hidden_channels. With
        # single-channel images this is identical to the previous wiring.
        kwargs = dict(stride = 2, padding = 1)
        self.conv1 = nn.Conv2d(n_channels, hidden_channels, kernel_size, **kwargs)
        self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size, **kwargs)
        self.conv3 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size, **kwargs)
        self.conv4 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size, **kwargs)

        #fc layers
        self.linear1 = nn.Linear(64, h_dim)
        self.linear2 = nn.Linear(h_dim, h_dim)

        self.mu_logvar = nn.Linear(h_dim, self.z_dim * 2)

    def forward(self, x):
        """Encode an image batch.

        Returns:
            tuple ``(mu, logvar)``, each shaped (batch, z_dim).
        """
        batch_size = x.size(0)

        #conv layers with relu (conv2 is intentionally skipped)
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv3(x))
        x = torch.relu(self.conv4(x))

        x = x.view((batch_size, -1))

        #fc layers with relu
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))

        # The 2 * z_dim outputs are read as interleaved (mean, log-variance)
        # pairs, one pair per latent dimension.
        mu_logvar = self.mu_logvar(x)
        mu, logvar = mu_logvar.view(-1, self.z_dim, 2).unbind(-1)

        return mu, logvar

class VAE(pl.LightningModule):
    """(beta-)VAE assembled from caller-supplied encoder and decoder classes.

    Tensors are never pinned to a fixed GPU: inputs are moved to the module's
    own device and noise is sampled where ``mu`` lives, so the model runs on
    whichever device Lightning (or the caller) places it.

    Args:
        img_size: image size tuple forwarded to both sub-networks.
        encoder: class called as ``encoder(img_size, z_dim)``; its forward
            must return ``(mu, logvar)``.
        decoder: class called as ``decoder(img_size, z_dim)``; its forward
            maps a latent sample to a reconstruction.
        z_dim: number of latent dimensions.
        beta: multiplier applied to the KL term of the loss.
        lr: learning rate passed to Adam.
    """

    # Number of training steps seen; only used to throttle loss printing.
    counter = 0

    def __init__(self, img_size, encoder, decoder, z_dim, beta = 1, lr = 1e-4):
        super().__init__()

        self.z_dim = z_dim
        self.img_size = img_size
        self.encoder = encoder(img_size, z_dim)
        self.decoder = decoder(img_size, z_dim)
        self.beta = beta
        self.lr = lr

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + std * eps`` with ``eps ~ N(0, I)``.

        The noise is created with the device and dtype of ``std`` so no
        host-to-device copy is needed.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            list ``[reconstruction, mu, logvar]``.
        """
        # Inputs on another device (e.g. straight from a CPU DataLoader) are
        # moved to wherever the model's parameters live.
        x = x.to(self.device)
        mu, logvar = self.encoder(x)
        z_sample = self.reparameterize(mu, logvar)
        return [self.decoder(z_sample), mu, logvar]

    def loss_function(self, recons, x, mu, logvar):
        """Summed MSE reconstruction loss plus ``beta * 0.5 * KL``.

        The KL term is summed over latent dimensions and over the batch. The
        base weight is the constant 0.5; no minibatch-size / dataset-size
        scaling is applied.
        """
        kld_weight = 0.5
        recons_loss = F.mse_loss(recons, x, reduction="sum")
        kld_loss = torch.sum(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim = 1), dim = 0)
        return recons_loss + kld_weight * kld_loss * self.beta

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        """Compute and log the loss for one batch of images.

        ``train_batch`` is the image tensor itself (no labels). The loss is
        printed every 100 steps.
        """
        self.counter += 1
        x = train_batch.float()
        recons, mu, logvar = self(x)
        loss = self.loss_function(recons, x, mu, logvar)
        if self.counter % 100 == 0:
            print(loss)
        self.log('train_loss', loss)
        return loss

In [None]:
# data_path = '../data/dsprites'
import numpy as np
# Open the dSprites archive and pull out the three arrays used below.
dataset_zip = np.load('../data/dsprites/dsprites_ndarray.npz')

print('Keys in the dataset:', dataset_zip.keys())
imgs, latents_values, latents_classes = (
    dataset_zip[key]
    for key in ('imgs', 'latents_values', 'latents_classes')
)

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np

# Show a 10x10 grid of raw dSprites images from the first training batch.
# Only one batch is needed, so take it directly instead of looping with a
# counter and breaking.
x = next(iter(train_data_loader))
print("*****************************************************************************************************************")

fig = plt.figure(figsize=(12., 12.))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(10, 10),  # creates a 10x10 grid of axes
                 axes_pad=0.1,  # pad between axes in inch.
                 )

# Plot the batch that was just fetched. The previous code indexed a global
# ``z`` left over from an earlier cell, so the saved "dataset" figure did not
# show the dataset. Slicing keeps batches smaller than 100 from raising.
latte = [img.reshape(64, 64).cpu().detach().numpy() for img in x[:100]]

for ax, im in zip(grid, latte):
    # Iterating over the grid returns the Axes.
    ax.imshow(im, cmap = 'gray')

fig.savefig('dsprites_dataset.png')
plt.show()

In [None]:
from torch.utils.data import Dataset, DataLoader

class dspritesDataset(Dataset):
    """Dataset over an array of images, returned as float tensors.

    Each item is ``X[index]`` passed through ``ToTensor``. Binary uint8
    images (pixel values 0/1) are rescaled to 0/255 first so the resulting
    tensor holds 0.0/1.0.
    """

    # ToTensor adds the channel axis and divides uint8 input by 255.
    transform = transforms.Compose([transforms.ToTensor()])

    def __init__(self, X):
        """Store the indexable image collection ``X``."""
        self.X = X

    def __len__(self):
        """Return the number of images."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the image at ``index`` as a tensor."""
        image = self.X[index]
        # Without this, ToTensor's division by 255 turns a 0/1 uint8 image
        # into values of 0 and 1/255, i.e. an almost black training target.
        # The max() guard leaves genuine 0..255 uint8 images untouched.
        if getattr(image, 'dtype', None) == np.uint8 and image.max() <= 1:
            image = image * 255
        return self.transform(image)

# Wrap the image array, split it with a fixed seed, and build one shuffled
# loader per subset (train first, held-out second).
dataset = dspritesDataset(imgs)
batch_size = 128
sets = torch.utils.data.random_split(
    dataset,
    [730000, 7280],
    generator=torch.Generator().manual_seed(2147483647),
)
train_data_loader, test_data_loader = (
    DataLoader(subset, batch_size, shuffle = True, num_workers = 3)
    for subset in sets
)

In [None]:
# Train the dSprites VAE (beta = 1, the constructor default) for 20 epochs.
img_size = (1, 64, 64)
model = VAE( img_size, encoder = encoder, decoder = decoder, z_dim = 10, lr = 1e-5 )
# ``auto_scale_batch_size`` is intentionally not passed: it only takes effect
# through ``trainer.tune()``, which is never called here, and the argument no
# longer exists on Trainer in Lightning 2.x.
trainer = pl.Trainer(max_epochs = 20, devices = 1, accelerator='gpu')
trainer.fit(model, train_data_loader)

In [None]:
# Save and show the output grid of ``model`` for the first batch of the
# training loader; the tag 'VAE' becomes part of the output filename.
plottingGrid(model, train_data_loader, 'VAE')

In [None]:
# Train the dSprites beta-VAE with beta = 10 for 10 epochs.
model10 = VAE( img_size, encoder = encoder, decoder = decoder, z_dim = 10, lr = 1e-5, beta = 10)
# ``auto_scale_batch_size`` is intentionally not passed: it only takes effect
# through ``trainer.tune()``, which is never called here, and the argument no
# longer exists on Trainer in Lightning 2.x.
trainer = pl.Trainer(max_epochs = 10, devices = 1, accelerator='gpu')
trainer.fit(model10, train_data_loader)

In [None]:
# Plot the beta = 10 network itself; passing ``model`` here would save the
# beta = 1 network's output under the 'beta_VAE_10' label.
plottingGrid(model10, test_data_loader, 'beta_VAE_10')

In [None]:
# Train the dSprites beta-VAE with beta = 150 for 10 epochs.
model150 = VAE( img_size, encoder = encoder, decoder = decoder, z_dim = 10, lr = 1e-5, beta = 150 )
# ``auto_scale_batch_size`` is intentionally not passed: it only takes effect
# through ``trainer.tune()``, which is never called here, and the argument no
# longer exists on Trainer in Lightning 2.x.
trainer = pl.Trainer(max_epochs = 10, devices = 1, accelerator='gpu')
trainer.fit(model150, train_data_loader)

In [None]:
# Plot the beta = 150 network itself; passing ``model`` here would save the
# beta = 1 network's output under the 'beta_VAE_150' label.
plottingGrid(model150, test_data_loader, 'beta_VAE_150')