In [None]:
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
import numpy as np

#import F
from torch.nn import functional as F

In [None]:
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Device:', device)

In [None]:
# Preprocessing shared by the train and validation splits:
#   - resize to the 128x128 input size the VAE's layer arithmetic assumes,
#   - convert to a float tensor in [0, 1],
#   - map each of the 3 channels to [-1, 1] so targets match the decoder's Tanh output.
# Tensors are deliberately left on the CPU here: every consumer calls `.to(device)`
# on whole batches. Moving single samples to CUDA inside a Dataset transform breaks
# DataLoader worker processes (num_workers > 0) and pin_memory, and is slower than
# one transfer per batch.
image_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

In [None]:
from torchvision import datasets

In [None]:
# Build both splits from class-per-subdirectory image folders; the folder names
# become the labels. Both splits use the same preprocessing (`image_transform`).
trainFolder, validFolder = 'train', 'val'
trainDataset, validDataset = (
    datasets.ImageFolder(root, transform=image_transform)
    for root in (trainFolder, validFolder)
)

# For quick experiments, swap in a random subset of each split, e.g.:
# trainDataset = torch.utils.data.Subset(trainDataset, np.random.choice(len(trainDataset), 1000, replace=False))
# validDataset = torch.utils.data.Subset(validDataset, np.random.choice(len(validDataset), 100, replace=False))

In [None]:
# Sanity check: confirm both folders were found and contain images.
for caption, split in (('Train dataset:', trainDataset), ('Valid dataset:', validDataset)):
    print(caption, len(split))

In [None]:
# Batch the datasets. Only the training split is shuffled (SGD needs it). The
# validation order stays fixed, so the validation loss and the per-epoch
# reconstruction previews use the same images every epoch and stay comparable.
from torch.utils.data import DataLoader
batchSize = 16
trainLoader = DataLoader(trainDataset, batch_size=batchSize, shuffle=True)
validLoader = DataLoader(validDataset, batch_size=batchSize, shuffle=False)

In [None]:
# Show a few training images to sanity-check loading and preprocessing.
import matplotlib.pyplot as plt
import numpy as np

images, labels = next(iter(trainLoader))
n_preview = min(4, len(images))  # the last batch may hold fewer than 4 images
fig, axes = plt.subplots(figsize=(10, 4), ncols=n_preview, squeeze=False)
for i in range(n_preview):
    # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1]. imshow clips float RGB
    # values outside [0, 1], so the raw normalised tensor would display wrongly.
    img = (images[i].cpu() * 0.5 + 0.5).clamp(0, 1)
    axes[0, i].imshow(img.numpy().transpose(1, 2, 0).squeeze())
    axes[0, i].axis('off')
plt.show()

In [None]:
class VariationalAutoEncoder(nn.Module):
    """Convolutional VAE for 3-channel 128x128 images with a 128-dimensional latent space.

    The encoder shrinks the input with unpadded, stride-1 convolutions
    (128 -> 12 spatially), flattens the result and projects it to a 100-d
    feature vector. Two linear heads map that vector to the latent mean and
    log-variance. The decoder mirrors the encoder with transposed convolutions
    and ends in ``Tanh``, so reconstructions lie in [-1, 1].
    """
    
    def __init__(self):
        super(VariationalAutoEncoder, self).__init__()
        
        # Encoder accepting 128x128 images across 3 channels.
        # Every conv is stride 1 / padding 0, so each layer shrinks H and W by
        # (kernel_size - 1); the trailing comments track the spatial size.
        # NOTE(review): the two convs in each stage have no activation between
        # them, so each pair composes into a single linear map.
        
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5,stride=1,padding=0), # 128 -> 124
            nn.Conv2d(in_channels=6,out_channels=9,kernel_size=5,stride=1,padding=0), # 124 -> 120
            nn.BatchNorm2d(9),
            nn.ReLU(),
            
            nn.Conv2d(in_channels=9,out_channels=12,kernel_size=10,stride=1,padding=0), # 120 -> 111
            nn.Conv2d(in_channels=12,out_channels=15,kernel_size=10,stride=1,padding=0), # 111 -> 102
            nn.BatchNorm2d(15),
            nn.ReLU(),
            
            nn.Conv2d(in_channels=15,out_channels=18,kernel_size=15,stride=1,padding=0), # 102 -> 88
            nn.Conv2d(in_channels=18,out_channels=21,kernel_size=15,stride=1,padding=0), # 88 -> 74
            nn.BatchNorm2d(21),
            nn.ReLU(),
            
            nn.Conv2d(in_channels=21,out_channels=24,kernel_size=20,stride=1,padding=0), # 74 -> 55
            nn.Conv2d(in_channels=24,out_channels=27,kernel_size=20,stride=1,padding=0), # 55 -> 36
            nn.BatchNorm2d(27),
            nn.ReLU(),
            
            nn.Conv2d(in_channels=27,out_channels=30,kernel_size=25,stride=1,padding=0), # 36 -> 12
            nn.Flatten(),  # (batch, 30, 12, 12) -> (batch, 4320)
            # No activation after this projection: it feeds the linear mean/logvar heads directly.
            nn.Linear(30*12*12, 100),
        )
        
        # Latent heads: 100-d encoder features -> 128-d mean and log-variance.
        self.mean = nn.Linear(100, 128)
        self.logvar = nn.Linear(100, 128)
        
        # Decoder: 128-d latent -> (30, 12, 12) feature map, then transposed convs
        # that grow H and W by (kernel_size - 1) each, mirroring the encoder back to 128x128.
        self.decoder = nn.Sequential(
            nn.Linear(128, 512),
            nn.ReLU(),
            nn.Linear(512, 30*12*12),
            nn.Unflatten(1, (30, 12, 12)),
            
            nn.ConvTranspose2d(in_channels=30,out_channels=27,kernel_size=25,stride=1,padding=0), # 12 -> 36
            nn.ConvTranspose2d(in_channels=27,out_channels=24,kernel_size=20,stride=1,padding=0), # 36 -> 55
            nn.BatchNorm2d(24),
            nn.ReLU(),
            
            nn.ConvTranspose2d(in_channels=24,out_channels=21,kernel_size=20,stride=1,padding=0), # 55 -> 74
            nn.ConvTranspose2d(in_channels=21,out_channels=18,kernel_size=15,stride=1,padding=0), # 74 -> 88
            nn.BatchNorm2d(18),
            nn.ReLU(),
            
            nn.ConvTranspose2d(in_channels=18,out_channels=15,kernel_size=15,stride=1,padding=0), # 88 -> 102
            nn.ConvTranspose2d(in_channels=15,out_channels=12,kernel_size=10,stride=1,padding=0), # 102 -> 111
            nn.BatchNorm2d(12),
            nn.ReLU(),
            
            nn.ConvTranspose2d(in_channels=12,out_channels=9,kernel_size=10,stride=1,padding=0), # 111 -> 120
            nn.ConvTranspose2d(in_channels=9,out_channels=6,kernel_size=5,stride=1,padding=0), # 120 -> 124
            nn.BatchNorm2d(6),
            nn.ReLU(),
            
            nn.ConvTranspose2d(in_channels=6,out_channels=3,kernel_size=5,stride=1,padding=0), # 124 -> 128
            nn.Tanh()  # outputs in [-1, 1], matching inputs normalised with mean=0.5, std=0.5
        )
        
        
    def reparameterize(self, mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw z = mean + eps * std with eps ~ N(0, I) (the reparameterization trick).

        The sample stays differentiable w.r.t. ``mean`` and ``logvar``. Fresh noise
        is drawn on every call; there is no special case for eval mode.
        """
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2), so sigma = exp(logvar / 2)
        eps = torch.randn_like(std)
        return mean + eps * std
    
    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: image batch; the layer sizes assume shape (batch, 3, 128, 128).

        Returns:
            Tuple ``(reconstruction, mean, logvar)``. ``reconstruction`` has shape
            (batch, 3, 128, 128) with values in [-1, 1]. ``mean`` and ``logvar``
            have shape (batch, 128).
        """
        x = self.encoder(x)
        mean = self.mean(x)
        logvar = self.logvar(x)
        z = self.reparameterize(mean, logvar)
        x = self.decoder(z)
        return x, mean, logvar

In [None]:
from torchsummary import summary

# Print a layer-by-layer shape/parameter table for one (3, 128, 128) input image.
# This instance is only for inspection; training builds a fresh model later.
input_shape = (3, 128, 128)
model = VariationalAutoEncoder().to(device)
summary(model, input_shape)

In [None]:
def loss_function(recon_x, x, mu, log_var, beta=0.1):
    """VAE objective: summed squared reconstruction error plus a beta-weighted KL term.

    Args:
        recon_x: decoder output, e.g. (batch, 3, 128, 128).
        x: input batch; reshaped to ``recon_x``'s shape before comparison, so a
            flattened batch is also accepted.
        mu: latent means, shape (batch, latent_dim).
        log_var: latent log-variances, same shape as ``mu``.
        beta: weight on the KL term (default 0.1, the previously hard-coded value).

    Returns:
        Scalar tensor ``MSE_sum + beta * KLD``. Both terms are summed over the batch.
    """
    # Sum (not mean) over pixels and batch, so this term is on the same scale as
    # the KL term, which is also summed over latent dims and batch.
    mse = F.mse_loss(recon_x, x.reshape(recon_x.shape), reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) summed over all elements.
    # The previous torch.mean() around this already-scalar sum was a no-op.
    kld = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return mse + beta * kld


In [None]:
# Build a fresh model for training (the instance created for the summary is
# replaced), and an Adam optimiser over all of its parameters.
model = VariationalAutoEncoder().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
from torch_snippets import Report, show
from torchvision.utils import make_grid

n_epochs = 200
n_preview = 6      # validation reconstructions shown per epoch
n_samples = 32     # images generated from the prior per epoch
preview_every = 1  # raise to show the per-epoch figures less often
# The decoder's first layer fixes the latent size; read it instead of hard-coding 128.
latent_dim = model.decoder[0].in_features


def _to_display(img):
    """Map a (C, H, W) tensor normalised to [-1, 1] into an (H, W, C) array in [0, 1] for imshow."""
    return (img.detach().cpu() * 0.5 + 0.5).clamp(0, 1).numpy().transpose(1, 2, 0).squeeze()


log = Report(n_epochs)
for epoch in range(n_epochs):
    # ---- training ----
    model.train()
    n_batches = len(trainLoader)
    for batchIndex, (imageSet, _labels) in enumerate(trainLoader):
        imageSet = imageSet.to(device)
        optimizer.zero_grad()
        recon_imageSet, mu, log_var = model(imageSet)
        loss = loss_function(recon_imageSet, imageSet, mu, log_var)
        loss.backward()
        # Clip only after backward() has produced this batch's gradients and before
        # step() applies them. Clipping before zero_grad()/backward() has no effect.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()

        log.record(epoch + (batchIndex + 1) / n_batches, train_loss=loss.item(), end='\r')

    # ---- validation ----
    # eval(): BatchNorm uses its running stats and does not update them on validation data.
    # no_grad(): no autograd graph is built, saving memory and time.
    model.eval()
    with torch.no_grad():
        n_batches = len(validLoader)
        for batchIndex, (imageSet, _labels) in enumerate(validLoader):
            imageSet = imageSet.to(device)
            recon_imageSet, mu, log_var = model(imageSet)
            loss = loss_function(recon_imageSet, imageSet, mu, log_var)

            log.record(epoch + (batchIndex + 1) / n_batches, valid_loss=loss.item(), end='\r')

    log.report_avgs(epoch + 1)

    if (epoch + 1) % preview_every:
        continue

    # ---- previews (model is still in eval mode) ----
    with torch.no_grad():
        # Newly generated images: decode latent vectors drawn from the N(0, I) prior.
        z = torch.randn(n_samples, latent_dim, device=device)
        samples = model.decoder(z).cpu()
        grid = make_grid(samples, nrow=8, normalize=True)
        show(grid, title=f'Epoch {epoch+1}')

        # Reconstructions: top row = validation inputs, bottom row = VAE outputs.
        images, _labels = next(iter(validLoader))
        recon_images, _, _ = model(images.to(device))

    n_show = min(n_preview, len(images))
    fig, axes = plt.subplots(2, n_show, figsize=(20, 5), squeeze=False)
    for i in range(n_show):
        axes[0, i].imshow(_to_display(images[i]))
        axes[1, i].imshow(_to_display(recon_images[i]))
    plt.show()
    plt.close(fig)  # avoid accumulating hundreds of open figures over a long run
            

In [None]:
# Plot the training and validation loss curves recorded by `log` during training.
loss_curves = ['train_loss', 'valid_loss']
log.plot_epochs(loss_curves)

In [None]:
# See reconstructed images: top row = validation inputs, bottom row = VAE reconstructions.
# eval(): BatchNorm uses its running statistics instead of this one batch's statistics.
model.eval()


def _as_image(t):
    """Convert a (C, H, W) tensor normalised to [-1, 1] into an imshow-ready array in [0, 1]."""
    return (t.detach().cpu() * 0.5 + 0.5).clamp(0, 1).numpy().transpose(1, 2, 0).squeeze()


with torch.no_grad():
    images, labels = next(iter(validLoader))
    recon_images, _, _ = model(images.to(device))

N = min(6, len(images))  # a small final batch may hold fewer than 6 images
fig, axes = plt.subplots(2, N, figsize=(20, 5), squeeze=False)
for i in range(N):
    axes[0, i].imshow(_as_image(images[i]))
    axes[1, i].imshow(_as_image(recon_images[i]))
plt.show()

In [None]:
# Plot newly generated images: decode latent vectors drawn from the N(0, I) prior.
# eval(): BatchNorm uses running statistics rather than statistics of the sampled batch.
model.eval()
with torch.no_grad():
    # z must match the decoder's input size (128 for this model). The previous
    # hard-coded 50 raised a shape-mismatch error in the decoder's first Linear layer.
    latent_dim = model.decoder[0].in_features
    z = torch.randn(32, latent_dim, device=device)
    samples = model.decoder(z).cpu()
    grid = make_grid(samples, nrow=8, normalize=True)
    show(grid, title=f'Epoch {epoch+1}')