In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import random
from tqdm import tqdm
import torchvision
import torchvision.datasets as datasets

#for consistency, all seeds are set to 69420
seed = 69420
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

## Decoder

In [2]:
class Decoder(nn.Module):
    
    def __init__(self, latent_size, img_channel):
        
        super(Decoder, self).__init__()

        self.conv_transpose_block_1 = nn.Sequential(
            nn.ConvTranspose2d(latent_size, 256, 4, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True))

        self.conv_transpose_block_2 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True))

        self.conv_transpose_block_3 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True))

        self.conv_transpose_block_4 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True))

        self.conv_transpose_block_5 = nn.Sequential(
            nn.ConvTranspose2d(32, img_channel, 1, 1, 0, bias=False),
            nn.Sigmoid())

    def forward(self, x):
        x = self.conv_transpose_block_1(x)
        x = self.conv_transpose_block_2(x)
        x = self.conv_transpose_block_3(x)
        x = self.conv_transpose_block_4(x)
        x = self.conv_transpose_block_5(x)
        return x

## Encoder

In [14]:
class Encoder(nn.Module):
    def __init__(self, latent_size, img_channel, components_num):
        super(Encoder, self).__init__()
        self.latent_size = latent_size
        self.components_num = components_num
        
        self.encoder = nn.Sequential(
            nn.Conv2d(img_channel, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
        )

        #compute the mean and covariance for all components
        self.mu_fc = nn.Conv2d(256, latent_size, 2, 1, 0, bias=False)
        self.sigma_fc = nn.Conv2d(256, latent_size, 2, 1, 0, bias=False) #output also * number of components?
        self.weights_fc = nn.Conv2d(256, 1, 2, 1, 0, bias=False)
        self.kl = 0
        
    def forward(self, x, latent_size):
        encoded = self.encoder(x)
        batch_size = encoded.size(0)
        mu = self.mu_fc(encoded).view(batch_size, -1, latent_size) #mean
        sigma = torch.exp(self.sigma_fc(encoded).view(batch_size, -1, latent_size)) #covariance


        weights_map = self.weights_fc(encoded)
        flat_weights = weights_map.view(batch_size, -1)
        weights = torch.softmax(flat_weights, dim=1)

        z = self.reparameterize(mu, sigma)
        self.kl = self.kl_loss(mu, sigma, weights) # kl loss term

        return z, mu, sigma, weights
    
    #reparameterization trick
    def reparameterize(self, mu, sigma):
        sd = torch.sqrt(sigma + 1e-8)
        noise = torch.randn_like(sd)
        z = mu + sd * noise
        return z
    

    # KL divergence loss
    def kl_loss(self, mu, sigma, weights):
        kl_component = 0.5 * torch.sum(sigma**2 + mu**2 - torch.log(sigma) - 1, dim=2)
        kl = torch.mean(torch.sum(weights * kl_component, dim=1))
        return kl

## Train

In [15]:
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

In [16]:
transforms = torchvision.transforms.Compose([
     torchvision.transforms.Resize((32, 32)),
     torchvision.transforms.ToTensor()
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms)
data_loader = torch.utils.data.DataLoader(mnist_trainset,batch_size=128,shuffle=True,num_workers=1)

In [17]:
latent_size = 128
img_channel = 1
components_num = 10
epcochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = Encoder(latent_size, img_channel, components_num).to(device)
decoder = Decoder(latent_size, img_channel).to(device)
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.0002, betas=(0.5, 0.999))

In [18]:
def recon_loss(x, x_recon):
    return nn.BCELoss()(x_recon, x)


def kl_loss(kl):
    return kl


def loss_function(x, x_recon, kl):
    recon = recon_loss(x, x_recon)
    kl = kl_loss(kl)
    return recon + kl


for epoch in range(epcochs):
    encoder.train()
    decoder.train()
    curr_recon , curr_kl = 0, 0

for x,y in train_loader:
    x = x.to(device)
    
    z, mu, sigma, weights = encoder(x, latent_size)
    x_recon = decoder(z)
    loss = loss_function(x, x_recon, encoder.kl)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    curr_recon += recon_loss(x, x_recon).item()
    curr_kl += encoder.kl.item()

    print(f"Epoch [{epoch+1}/{epcochs}], Loss: {loss.item():.4f}, Recon Loss: {curr_recon/len(train_loader):.4f}, KL Loss: {curr_kl/len(train_loader):.4f}")

    

ValueError: expected 4D input (got 3D input)