# Семинар № 8 - VAE

# Imports 

In [None]:
# Notebook shell cell: pin albumentations to 0.4.6 and print the installed
# version as confirmation.
# Unpinned alternative (installs the latest release instead):
# !pip install -q -U albumentations
!pip install --upgrade -q albumentations==0.4.6
!echo "$(pip freeze | grep albumentations) is successfully installed"

In [None]:
import os
from pathlib import Path
import random
import typing as tp
from time import gmtime, strftime

import yaml
from tqdm import tqdm
from cv2 import erode
import numpy as np
import matplotlib.pyplot as plt
from skimage.io import imread
from skimage.transform import rotate, rescale, resize

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

In [None]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Value passed unchanged to the Python, NumPy, PyTorch CPU and
            PyTorch CUDA seeding functions.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

In [None]:
# Fix all RNG seeds before any data or model is created.
set_seed(42)

# Use the first CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# To force CPU execution, use instead:
# DEVICE = torch.device('cpu')

# Набор данных


In [None]:
# Load the MNIST train and test splits, downloading them into ./data when
# requested by download=True. transform=None leaves samples exactly as
# torchvision returns them; conversion happens later in the notebook.
mnist_trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=None)
mnist_validset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=None)

In [None]:
# Number of samples in the training and validation splits.
len(mnist_trainset), len(mnist_validset)

In [None]:
# Take one (image, label) pair from the training split for inspection.
img, label = mnist_trainset[4]

In [None]:
# Display the inspected sample after converting it to a NumPy array.
plt.imshow(np.array(img))

## Create dataset

In [None]:
class MNISTDataset(Dataset):
    """Dataset adapter that yields ``(image, label)`` pairs ready for a DataLoader.

    Each item of ``data`` must be an ``(image, label)`` pair. The image is
    converted to a NumPy array, given a trailing channel axis when it is 2-D,
    and optionally passed through ``transforms``.

    Args:
        data: Indexable, sized collection of ``(image, label)`` pairs.
        transforms: Optional callable invoked as ``transforms(image=array)``
            that returns a mapping with an ``"image"`` entry.
    """

    def __init__(self, data,
                 transforms: tp.Optional[A.BasicTransform] = None):
        self.data = data
        self.transforms = transforms

    def __len__(self) -> int:
        """Return the number of samples in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx: int) -> tp.Tuple[tp.Any, torch.Tensor]:
        """Return the (optionally transformed) image and its label tensor.

        Returns:
            A tuple ``(image, label)`` where ``image`` is whatever the
            transform pipeline produced (or the raw array when no transforms
            are set) and ``label`` is an int64 tensor.
        """
        raw_image, raw_label = self.data[idx]
        pixels = np.array(raw_image)

        # Grayscale images arrive as (H, W); add the channel axis -> (H, W, 1).
        if pixels.ndim == 2:
            pixels = pixels[..., np.newaxis]

        sample = {"image": pixels}
        if self.transforms is not None:
            sample = self.transforms(**sample)

        image = sample['image']
        label = torch.tensor(raw_label).long()
        return image, label

## Create augs

In [None]:
def pre_transform() -> A.BasicTransform:
    """Return the pipeline stage applied before augmentation (currently empty)."""
    return A.Compose([])


def augmentations() -> A.BasicTransform:
    """Return the training-only augmentation stage (currently empty)."""
    # Candidate augmentation, left disabled: A.GaussNoise()
    enabled = []
    return A.Compose(enabled)


def post_transform() -> A.BasicTransform:
    """Return the final stage: normalisation followed by tensor conversion.

    The Normalize step uses mean 0, std 1 and max_pixel_value 255.
    """
    steps = [
        A.Normalize(mean=0, std=1, max_pixel_value=255),
        ToTensorV2(),
    ]
    return A.Compose(steps)

In [None]:
# Training pipeline: pre-processing, augmentations, then normalise + to-tensor.
train_transformation = A.Compose(
    [pre_transform(), augmentations(), post_transform()]
)

# Validation pipeline: the same without the augmentation stage.
valid_transformation = A.Compose(
    [pre_transform(), post_transform()]
)

### Check augs

In [None]:
# Preview four randomly chosen training images after the training pipeline.
_, ax = plt.subplots(2, 2, figsize=(6.4 * 1.5, 4.8 * 1.5))

# Column-major traversal fills the grid as (0, 0), (1, 0), (0, 1), (1, 1).
for panel in ax.flatten(order="F"):
    sample_idx = np.random.randint(0, len(mnist_trainset), 1)[0]
    sample_image, _ = mnist_trainset[sample_idx]
    # Add the channel axis before handing the image to the pipeline.
    sample_image = np.array(sample_image)[..., np.newaxis]
    aug_image = train_transformation(image=sample_image)["image"].numpy()
    # Index 0 selects the first entry along the leading axis for display.
    panel.imshow(aug_image[0])

plt.show()

## Make DataLoader 

In [None]:
batch_size = 128
# Keep 0 workers when running a local notebook on Windows.
num_workers = 0

# Training data: shuffled, incomplete final batch dropped.
train_dataset = MNISTDataset(mnist_trainset, train_transformation)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, drop_last=True)

# Validation data: fixed order, incomplete final batch dropped as well.
valid_dataset = MNISTDataset(mnist_validset, valid_transformation)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False,
                          num_workers=num_workers, drop_last=True)

In [None]:
# Fetch a single batch from the training loader to check its shape.
image, label = next(iter(train_loader))

print('Image batch shape:', image.shape)

# Create model 

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of a diagonal Gaussian.

    The layer sizes assume single-channel 28x28 inputs (MNIST): two stride-2
    convolutions reduce the image to 7x7, and the flattened 128 * 7 * 7
    feature vector feeds two linear heads.

    Args:
        latent_size: Dimensionality of the latent space.
    """

    def __init__(self, latent_size: int = 2):
        super().__init__()
        self.latent_size = latent_size

        # Spatial sizes: 1x28x28 -> 32x14x14 -> 64x7x7 -> 128x7x7 -> flat.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        n_features = 128 * 7 * 7
        self.fc_mean = nn.Linear(n_features, latent_size)
        # Despite the attribute name, this head outputs the log-variance.
        self.fc_var = nn.Linear(n_features, latent_size)

    def forward(self, images: torch.Tensor):
        """Encode a batch of images.

        Args:
            images: Tensor of shape ``(batch, 1, 28, 28)``.

        Returns:
            Tuple ``(mean, log_var)``, each of shape ``(batch, latent_size)``.
        """
        features = self.feature_extractor(images)
        mean = self.fc_mean(features)
        log_var = self.fc_var(features)

        return mean, log_var

In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent points to 28x28 images.

    The output is a single-channel map of raw logits (no final squashing
    non-linearity): callers apply ``sigmoid`` themselves before the binary
    cross-entropy loss. A bounded activation here followed by that sigmoid
    would confine reconstructions to a narrow band around 0.5.

    Args:
        latent_size: Dimensionality of the latent space.
    """

    def __init__(self, latent_size: int = 2):
        super().__init__()
        self.latent_size = latent_size

        # Project the latent point to a flat 128 * 7 * 7 feature vector.
        self.map_generator = nn.Sequential(
            nn.Linear(latent_size, 128 * 49),
        )
        # Up-sample 128x7x7 -> 64x14x14 -> 32x28x28.
        self.deconv = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
        )
        # Final convolution down to one channel of logits.
        self.output = nn.Sequential(
            nn.Conv2d(32, 1, 3, padding=1),
        )

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        """Decode latent points.

        Args:
            points: Tensor of shape ``(batch, latent_size)``.

        Returns:
            Logits of shape ``(batch, 1, 28, 28)``.
        """
        feature_map = self.map_generator(points)
        # The linear layer emits a flat vector; restore the 128x7x7 layout
        # expected by the transposed convolutions.
        feature_map = feature_map.view(feature_map.size(0), 128, 7, 7)
        feature_map = self.deconv(feature_map)

        return self.output(feature_map)

In [None]:
# Upper and lower bounds used by VAE.forward to clamp the encoder's
# log-variance output before it is exponentiated for sampling.
LOG_SCALE_MAX = 2
LOG_SCALE_MIN = -10

def normal_sample(loc: torch.Tensor, log_scale: torch.Tensor) -> torch.Tensor:
    """Draw a reparameterised sample from a diagonal Gaussian.

    Args:
        loc: Mean of the distribution.
        log_scale: Log-variance; the standard deviation used is
            ``exp(0.5 * log_scale)``.

    Returns:
        ``loc + std * eps`` with ``eps`` drawn from a standard normal of the
        same shape as the standard deviation.
    """
    std = (0.5 * log_scale).exp()
    noise = torch.randn_like(std)
    return loc + std * noise


class VAE(nn.Module):
    """Variational auto-encoder built from an ``Encoder`` and a ``Decoder``.

    Args:
        latent_size: Dimensionality of the latent space shared by both parts.
    """

    def __init__(self, latent_size: int = 2):
        super().__init__()

        self.encoder = Encoder(latent_size)
        self.decoder = Decoder(latent_size)

    def forward(self, x: torch.Tensor):
        """Encode ``x``, pick a latent point and decode it.

        In training mode the latent point is sampled from the predicted
        Gaussian; in eval mode the predicted mean is decoded directly.

        Returns:
            Tuple ``(reconstruction, mean, log_var)`` where ``log_var`` is the
            clamped log-variance.
        """
        mean, raw_log_var = self.encoder(x)
        # Keep the log-variance inside the configured bounds before sampling.
        log_var = torch.clamp(raw_log_var, LOG_SCALE_MIN, LOG_SCALE_MAX)

        if self.training:
            latent = normal_sample(mean, log_var)
        else:
            latent = mean

        reconstruction = self.decoder(latent)
        return reconstruction, mean, log_var

In [None]:
# Build the VAE with a 16-dimensional latent space and move it to DEVICE.
latent_size = 16
model = VAE(latent_size).to(DEVICE)

In [None]:
# Smoke test: run a dummy batch of 256 all-ones 1x28x28 images through the model.
x = torch.ones((256, 1, 28, 28)).to(DEVICE)

x_hat, mean, log_var = model(x)

In [None]:
# Shapes of the reconstruction and of the latent Gaussian parameters.
x_hat.shape, mean.shape, log_var.shape

## Define loss

In [None]:
# Reconstruction loss: binary cross-entropy on probabilities, created with
# the default arguments (default reduction is the mean over all elements).
BCE_loss = nn.BCELoss()

def KLD_loss(mean, log_var):
    """Return the batch-averaged KL divergence to a standard normal prior.

    Args:
        mean: Tensor of posterior means, latent dimensions along axis 1.
        log_var: Tensor of posterior log-variances with the same shape.

    Returns:
        Scalar tensor: the divergence summed over axis 1, then averaged over
        the remaining entries.
    """
    per_dimension = 1 + log_var - mean.pow(2) - log_var.exp()
    per_sample = -0.5 * per_dimension.sum(dim=1)
    return per_sample.mean()

In [None]:
# Sanity check: KL term on the outputs of the dummy forward pass.
KLD_loss(mean, log_var)

In [None]:
# Sanity check: reconstruction term on the dummy batch. The sigmoid maps the
# decoder output into (0, 1) before it is passed to BCELoss.
BCE_loss(x_hat.sigmoid(), x)

In [None]:
# Adam over all model parameters, with L2 weight decay of 1e-3.
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

## Train loop

In [None]:
# Number of full passes over the training loader.
epochs = 2

In [None]:
print("Start training VAE...")
model.train()

# Number of batches per epoch; used to average the per-batch losses.
num_batches = len(train_loader)

for epoch in range(epochs):
    overall_loss = 0.0
    pbar = tqdm(train_loader, total=num_batches)
    for x, _ in pbar:
        x = x.to(DEVICE)

        optimizer.zero_grad()

        x_hat, mean, log_var = model(x)
        # NOTE(review): BCE_loss averages over every pixel while KLD_loss sums
        # over the latent dimensions, so the KL term weighs far more than in
        # the standard ELBO -- confirm this weighting is intended.
        loss_bce = BCE_loss(x_hat.sigmoid(), x)
        loss_kld = KLD_loss(mean, log_var)
        loss = loss_bce + loss_kld

        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        overall_loss += loss_value
        pbar.set_description(f'Loss: {loss_value:.4f}')

    # Each accumulated value is already a batch mean, so the epoch average is
    # the sum divided by the number of batches (guarded for an empty loader).
    average_loss = overall_loss / max(num_batches, 1)
    print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", average_loss)

print("Finish!!")

## Generate images

In [None]:
# Eval mode: VAE.forward decodes the posterior mean instead of sampling.
model.eval()

# Reconstruct only the first validation batch; gradients are not needed.
with torch.no_grad():
    for batch_idx, (x, _) in enumerate(valid_loader):
        x = x.to(DEVICE)
    
        x_hat, _, _ = model(x)

        # Stop after one batch; x and x_hat stay available to later cells.
        break

In [None]:
def show_image(x, idx):
    """Plot the first channel of the ``idx``-th image of batch ``x``.

    Args:
        x: Batch tensor indexed as ``x[idx][0]``.
        idx: Position of the image inside the batch.
    """
    plt.figure()
    channel = x[idx][0]
    plt.imshow(channel.cpu().numpy())

In [None]:
# Show the first input image of the validation batch fetched above.
show_image(x, idx=0)