## VAE

In [None]:
import os
import random
import cv2
import numpy as np
import collections

import torch
import torchvision
import torch.nn as nn

import albumentations as A
from albumentations.pytorch import ToTensorV2

import matplotlib.pyplot as plt

from utils import (
    Collector,
    VAELoss,
    trainer,
    plot_latent_tsne,
    generate_samples_between_centers,
    visualize_prediction,
)

from models import (
    deconv_resnet18,
    resnet18,
)

%matplotlib inline

## Set Seeds

In [None]:
# One seed shared by every random number generator used in this notebook.
seed = 0

# Seed PyTorch (CPU and, when present, every CUDA device), NumPy and the
# Python stdlib generator.
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Request deterministic cuDNN kernels and switch off its benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Dedicated generator; it is passed to the DataLoaders via `generator=g`.
g = torch.Generator()
g.manual_seed(seed)

def seed_worker(worker_id):
    """Re-seed NumPy and ``random`` from the current PyTorch initial seed.

    Intended as a ``DataLoader`` ``worker_init_fn``. ``worker_id`` is accepted
    because that hook passes it, but it is not used here.
    """
    # Fold the PyTorch seed into 32 bits before handing it on.
    derived_seed = torch.initial_seed() & 0xFFFFFFFF
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

## Basic Params

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 50  # given to the OneCycleLR scheduler and to trainer()
bs = 256  # DataLoader batch size
num_workers = 8  # DataLoader worker processes
# Per-channel statistics passed to A.Normalize.
mean = [0, 0, 0]
std = [1, 1, 1]
# [0] is used as min_height and [1] as min_width when padding; the model also
# derives a kernel size from these values.
image_size = [128, 128]

## Data loader

In [None]:
class Transforms:
    """Adapter that feeds one image to an albumentations pipeline twice.

    The image is passed under both the ``image`` and ``test_image`` keys,
    so a single call yields two versions of the same picture.
    """

    def __init__(
        self,
        transforms: A.Compose,
    ) -> None:
        """Store the pipeline that ``__call__`` will invoke."""
        self.transforms = transforms

    def __call__(
        self,
        img,
        *args,
        **kwargs,
    ) -> dict:
        """Run the pipeline on ``img`` and return its result unchanged.

        Args:
            img: anything ``np.array`` can convert to an array.
            *args, **kwargs: accepted and ignored.

        Returns:
            Whatever the wrapped pipeline returns. Callers in this notebook
            index it with ``'image'`` and ``'test_image'``, so it is a
            mapping rather than a tensor.
        """
        # Two separate np.array() calls give the two targets independent
        # buffers, so in-place edits on one cannot leak into the other.
        return self.transforms(
            image=np.array(img),
            test_image=np.array(img),
        )

class MyDataset(torch.utils.data.Dataset):
    """Wrap a dataset so each sample becomes a 3-channel array run through ``transforms``.

    ``dataset`` is called with ``**kwargs`` to build the wrapped dataset,
    whose items must be ``(image, label)`` pairs.
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        transforms: Transforms,
        kwargs,
    ) -> None:
        super().__init__()
        # `dataset` is a class/factory here; instantiate it with its settings.
        self.dataset = dataset(**kwargs)
        self.transforms = transforms

    def __len__(
        self,
    ) -> int:
        """Number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(
        self,
        idx: int,
    ) -> dict:
        """Return ``{'images': transforms(array), 'class': label}`` for sample ``idx``."""
        raw, label = self.dataset[idx]

        pixels = raw if isinstance(raw, np.ndarray) else np.array(raw)

        # Give grayscale (H, W) inputs an explicit channel axis.
        if pixels.ndim == 2:
            pixels = pixels[..., np.newaxis]

        # Replicate a single channel three times along the last axis.
        if pixels.shape[2] == 1:
            pixels = np.repeat(pixels, 3, axis=2)

        return {
            'images': self.transforms(pixels),
            'class': label,
        }
    
# Geometry stage: LongestMaxSize with max_size=max(image_size), followed by
# PadIfNeeded up to image_size using constant zeros, positioned TOP_LEFT.
# 'test_image' is registered as an extra image target so this stage is
# applied to it as well as to 'image'.
in_tf = A.Compose([
    A.LongestMaxSize(max_size=max(image_size)),
    A.PadIfNeeded(
        position=A.PadIfNeeded.PositionType.TOP_LEFT,
        min_height=image_size[0],
        min_width=image_size[1],
        value=0,
        border_mode=cv2.BORDER_CONSTANT,
    ),
], additional_targets = {
    'test_image': 'image',
})

# Corruption stage, only included in the 'train' pipeline: CoarseDropout with
# probability 0.5, then exactly one of brightness/contrast jitter, blur or
# Gaussian noise (OneOf with p=1).
# No additional_targets are declared here, so presumably 'test_image' passes
# through this stage unchanged - TODO confirm for the installed
# albumentations version.
# NOTE(review): if images are still 0-255 integers at this point, the
# fill_value entries and var_limit below are close to zero on that scale -
# confirm the intended value range.
middle_tf = A.Compose([
    A.CoarseDropout(
        max_holes=4,
        min_holes=1,
        max_height=0.2,
        min_height=0.05,
        max_width=0.2,
        min_width=0.05,
        fill_value=[0, 0.5, 1],
        p=0.5,
    ),
    A.OneOf([
        A.RandomBrightnessContrast(
            brightness_limit=(-0.5, 0.5),
            contrast_limit=(-0.5, 0.5),
            p=1,
        ),
        A.Blur(
            p=1,
        ),
        A.GaussNoise(
            var_limit=5.0 / 255.0,
            p=1,
        ),
    ], p=1)
])

# Output stage: normalise with the notebook-level mean/std, then convert to a
# tensor with ToTensorV2. 'test_image' is registered as an extra image target
# so it goes through the same two steps as 'image'.
out_tf = A.Compose([
    A.Normalize(
        mean=mean,
        std=std,
    ),
    ToTensorV2(),
], additional_targets = {
    'test_image': 'image',
})
    
# Per-phase pipelines built from the three stages. 'train' includes the
# corruption stage; 'test' goes straight from geometry to output.
# The keys of this dict are also used as the phase list for the Collector.
transformations = {
    'train': A.Compose([
        in_tf,
        middle_tf,
        out_tf,
    ]),
    'test': A.Compose([
        in_tf,
        out_tf,
    ]),
}

# STL10
# Keyword arguments for torchvision.datasets.STL10, one set per phase.
# There is deliberately no 'transform' entry: MyDataset applies the
# albumentations pipeline itself. Passing it to the dataset as well would make
# STL10 return the pipeline's dict, which MyDataset.__getitem__ cannot
# convert into an image array.
stl10 = {
    'train': {
        'root': './data/',
        'split': 'unlabeled',
        'download': True,
    },
    'test': {
        'root': './data/',
        'split': 'test',
        'download': True,
    },
}

# MNIST
# Keyword arguments for torchvision.datasets.MNIST, one set per phase; the
# 'train' flag is what distinguishes the two entries.
mnist = {
    'train': {
        'root': './data/',
        'train': True,
        'download': True,
    },
    'test': {
        'root': './data/',
        'train': False,
        'download': True,
    },
}

# Registry: dataset name -> constructor ('class') and its per-phase keyword
# arguments ('phase').
dataset = {
    'mnist': {
        'phase': mnist,
        'class': torchvision.datasets.MNIST,
    },
    'stl10': {
        'phase': stl10,
        'class': torchvision.datasets.STL10,
    },
}

# Which registry entry the rest of the notebook uses.
dataset_name = 'mnist'

# One MyDataset per phase: the constructor, the phase's transform pipeline
# wrapped in Transforms, and the phase's keyword arguments.
datasets = {
    phase: MyDataset(
        dataset[dataset_name]['class'],
        Transforms(transformations[phase]),
        dataset[dataset_name]['phase'][phase]
    )
    for phase in dataset[dataset_name]['phase']
}

# One DataLoader per phase, all sharing the same settings. seed_worker and the
# seeded generator `g` are passed in for reproducible worker seeding/shuffling.
# NOTE(review): shuffle=True and drop_last=True apply to the 'test' loader as
# well, so evaluation sees a shuffled order and skips the final partial
# batch - confirm this is intended.
dataloaders = {
    phase: torch.utils.data.DataLoader(
        dataset=datasets[phase],
        batch_size=bs,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        worker_init_fn=seed_worker,
        generator=g,
    )
    for phase in dataset[dataset_name]['phase']
}

## Let's visualize the batch

In [None]:
# Preview training samples before any model exists (model=None).
# NOTE(review): at this point the name refers to the helper imported from
# utils unless the later cell that redefines it has already been run.
visualize_prediction(
    model=None,
    dataset=datasets['train'],
)

## Define model
![VAE architecture diagram](pics/vae.png "VAE architecture")

In [None]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        """Return ``x`` viewed as ``(x.size(0), -1)``."""
        leading = x.shape[0]
        return x.view(leading, -1)

class UnFlatten(nn.Module):
    """Append two singleton spatial dimensions to a 2-D tensor."""

    def forward(self, x):
        """Return the 2-D input ``x`` viewed as ``(rows, cols, 1, 1)``."""
        # Unpacking into exactly two names rejects inputs that are not 2-D.
        rows, cols = x.shape
        target_shape = (rows, cols, 1, 1)
        return x.view(target_shape)

class VariationalAE(nn.Module):
    """Variational auto-encoder assembled from an encoder and a decoder backbone.

    Encoding path: ``down`` -> global average pooling -> flatten, giving a
    ``hidden_dim``-wide feature vector. Two linear heads map that vector to
    the mean and the log standard deviation of the latent distribution.
    Decoding path: unflatten -> ``middle`` transposed convolution -> ``up``.

    Relies on the notebook-level globals ``image_size`` (kernel size of
    ``middle``) and ``device`` (initial placement of the noise sampler).
    """

    def __init__(
        self,
        downlayers: nn.Module,
        uplayers: nn.Module,
        hidden_dim: int=512,
        latent_dim: int=512,
    ) -> None:
        """Build the network.

        Args:
            downlayers: zero-argument callable returning the encoder module;
                its pooled output must have ``hidden_dim`` channels.
            uplayers: zero-argument callable returning the decoder module.
            hidden_dim: width of the encoder feature vector and channel count
                produced by ``middle``.
            latent_dim: size of the sampled latent vector.
        """
        super().__init__()

        # Both arguments are factories, not instances.
        self.down = downlayers()
        self.up = uplayers()

        self.agg = nn.AdaptiveAvgPool2d((1, 1))
        # Expands the (latent_dim, 1, 1) latent into a small spatial map for
        # the decoder.
        self.middle = nn.ConvTranspose2d(
            in_channels=latent_dim,
            out_channels=hidden_dim,
            kernel_size=(image_size[0] // 32, image_size[1] // 32),
            stride=2,
        )

        self.mu_repr = nn.Linear(hidden_dim, latent_dim)
        self.log_sigma_repr = nn.Linear(hidden_dim, latent_dim)

        self.flatten = Flatten()
        self.unflatten = UnFlatten()
        self.sigmoid = nn.Sigmoid()

        # hack to get sampling on the GPU
        # (attribute name kept as-is for backward compatibility)
        self.samling = torch.distributions.Normal(0, 1)
        self.samling.loc = self.samling.loc.to(device)
        self.samling.scale = self.samling.scale.to(device)

        self.init_weights()

    def init_weights(
        self,
    ) -> None:
        """Xavier-initialise the weights and zero the biases of all Linear layers."""
        for m in self.modules():

            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(
                    tensor=m.weight,
                )
                nn.init.zeros_(
                    tensor=m.bias,
                )

    def _encode(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Encoder backbone, global average pooling, then flatten to 2-D."""
        x = self.down(x)
        x = self.agg(x)
        x = self.flatten(x)

        return x

    def _decode(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Turn a 2-D batch of latent vectors into decoder output (no sigmoid applied)."""
        x = self.unflatten(x)
        x = self.middle(x)
        x = self.up(x)

        return x

    def _reparametrize(
        self,
        mu: torch.Tensor,
        log_sigma: torch.Tensor,
    ) -> torch.Tensor:
        """Sample ``mu + eps * exp(log_sigma)`` with ``eps`` drawn from N(0, 1)."""
        # The sampler lives on the device chosen at construction time. The
        # model may since have been moved (e.g. ``model.cpu()``), so align the
        # noise with ``mu`` instead of assuming both are on the same device.
        eps = self.samling.sample(mu.shape).to(device=mu.device, dtype=mu.dtype)
        x = eps * log_sigma.exp() + mu

        return x

    def forward(
        self,
        x: torch.Tensor,
        embedding: bool=True,
    ) -> 'torch.Tensor | dict':
        """Encode ``x`` and, unless only the embedding is wanted, reconstruct it.

        Args:
            x: input batch for the encoder.
            embedding: when True (the default), stop after the encoder and
                return its flattened features.

        Returns:
            The encoder feature tensor when ``embedding`` is True; otherwise a
            dict with ``'pred_image'`` (decoder output), ``'mu'`` and
            ``'log_sigma'``.
        """
        latent_repr = self._encode(x)

        if embedding:
            return latent_repr

        latent_mu = self.mu_repr(latent_repr)
        latent_log_sigma = self.log_sigma_repr(latent_repr)

        sample = self._reparametrize(latent_mu, latent_log_sigma)

        image = self._decode(sample)

        return {
            'pred_image': image,
            'mu': latent_mu,
            'log_sigma': latent_log_sigma,
        }

In [None]:
# Size of the sampled latent vector; hidden_dim keeps its default of 512.
latent_dim=1024

# resnet18 / deconv_resnet18 are passed as factories and instantiated inside
# VariationalAE; the finished model is then moved to `device`.
model = VariationalAE(
    downlayers=resnet18,
    uplayers=deconv_resnet18,
    latent_dim=latent_dim,
).to(device)

## Define trainer params

In [None]:
lr=5e-4
# Passed to VAELoss; presumably per-term loss weights - verify against
# utils.VAELoss.
weights=[1,100]

optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=lr,
    weight_decay=1e-3,
)
# One-cycle schedule peaking at `lr`, sized from the number of training
# batches per epoch and the epoch count.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=lr,
    steps_per_epoch=len(dataloaders['train']),
    epochs=epochs,
)
loss = VAELoss(weights=weights)

# Output directory name encodes the run's hyper-parameters; it is reused for
# the Collector and for every artefact saved later in the notebook.
save_path=f'metrics/VAE/dataset={dataset_name}_epoch={epochs}_bs={bs}_lr={lr}_loss={loss.__class__.__name__}_weights={weights}_latent={latent_dim}_norm=max'
visualiser = Collector(
    root_graphics=save_path,
    root_desc=save_path,
    phases=list(transformations.keys()),
)

In [None]:
# Hand everything to utils.trainer, which is defined outside this notebook.
trainer(model, optimizer, scheduler, dataloaders, epochs, device, loss, visualiser, save_path)

## Visualize predictions

In [None]:
# Call the utils helper with the test loader and the run's output directory.
plot_latent_tsne(model, dataloaders['test'], save_path)

In [None]:
@torch.no_grad()
def visualize_prediction(
    model: nn.Module,
    dataset: torch.utils.data.Dataset,
    save_path: str=None,
    n_samples: int=10,
    device: torch.device='cuda',
) -> None:
    """Plot randomly chosen samples as input / target / prediction triplets.

    Args:
        model: network called with ``embedding=False``; its ``'pred_image'``
            output is passed through a sigmoid. When ``None``, an all-zero
            image is shown in the prediction column.
        dataset: indexable dataset whose items provide
            ``sample['images']['image']`` and ``sample['images']['test_image']``.
        save_path: directory that receives ``predictions.jpg``; nothing is
            saved when ``None``.
        n_samples: number of rows (random samples) to draw.
        device: unused; inputs are moved to the device of the model's
            parameters. Kept so existing callers keep working.
    """
    # squeeze=False keeps `axes` two-dimensional even for n_samples == 1, and
    # the height scales with the row count (10 rows -> the former 10x40).
    fig, axes = plt.subplots(
        n_samples,
        3,
        figsize=(10, 4 * n_samples),
        squeeze=False,
    )

    model_device = None
    if model is not None:
        model.eval()
        model_device = next(model.parameters()).device

    titles = ['input', 'output', 'pred']

    for row in axes:
        # randrange excludes len(dataset); randint(0, len(dataset)) could
        # return an out-of-range index.
        index = random.randrange(len(dataset))

        sample = dataset[index]

        inputs = sample['images']['image']
        outputs = sample['images']['test_image']

        if model is not None:
            preds = model(
                x=inputs.to(model_device).unsqueeze(0),
                embedding=False,
            )['pred_image'].squeeze(0).sigmoid()
        else:
            preds = torch.zeros(inputs.shape)

        for ax, data, title in zip(row, [inputs, outputs, preds], titles):
            # Tensors are channel-first; imshow expects channel-last.
            ax.imshow(data.permute(1, 2, 0).cpu().numpy())
            ax.set_title(title)

    if save_path is not None:
        plt.savefig(os.path.join(save_path, 'predictions.jpg'), format='jpg')

    plt.show()

In [None]:
# Show reconstructions for random test samples and save them under save_path.
visualize_prediction(model, datasets['test'], save_path)

In [None]:
@torch.no_grad()
def generate_samples_between_centers(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    save_path: str=None,
    from_class: int=9,
    to_class: int=7,
    rows_cols: int=2,
    device: torch.device='cuda',
) -> None:
    """Decode points on the line between two class centres in latent space.

    The latent means (``'mu'``) of every sample in ``dataloader`` are
    averaged per class. ``rows_cols * rows_cols`` evenly spaced points from
    the ``from_class`` centre to the ``to_class`` centre are then decoded and
    shown in a square grid.

    Args:
        model: network exposing ``forward(x=..., embedding=False)`` and
            ``_decode``.
        dataloader: yields batches with ``['images']['test_image']`` and
            ``['class']``.
        save_path: directory that receives ``samples_between_centers.jpg``;
            nothing is saved when ``None``.
        from_class: label of the class the interpolation starts at.
        to_class: label of the class the interpolation ends at.
        rows_cols: side length of the image grid.
        device: unused; tensors follow the device of the model's parameters.
            Kept so existing callers keep working.

    Raises:
        ValueError: if ``dataloader`` contains no sample of ``from_class`` or
            of ``to_class``.
    """
    reference = next(model.parameters())
    was_training = model.training
    # Evaluation mode: a train-mode forward pass over the test data would
    # update normalisation statistics and apply dropout.
    model.eval()

    try:
        centroids = []
        classes = []

        for gt_batch in dataloader:
            preds_batch = model(
                x=gt_batch['images']['test_image'].to(reference.device),
                embedding=False,
            )

            centroids.append(preds_batch['mu'].cpu().numpy())
            classes.append(gt_batch['class'].numpy())

        centroids = np.concatenate(centroids)
        classes = np.concatenate(classes)

        from_mask = classes == from_class
        to_mask = classes == to_class
        if not from_mask.any() or not to_mask.any():
            # Averaging an empty selection would silently give NaN images.
            raise ValueError(
                f'no samples found for class {from_class} and/or class {to_class}'
            )

        center_from = centroids[from_mask].mean(axis=0)
        center_to = centroids[to_mask].mean(axis=0)

        # t = 0 is the `from` centre and t = 1 is the `to` centre.
        z = torch.stack([
            torch.from_numpy((1 - t) * center_from + t * center_to)
            for t in np.linspace(0, 1, rows_cols * rows_cols)
        ])
        # NumPy may promote the mix to float64; match the model's parameters.
        z = z.to(device=reference.device, dtype=reference.dtype)

        images = model._decode(z).sigmoid().permute(0, 2, 3, 1).cpu().numpy()
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    fig = plt.figure(figsize=(10, 10))

    for i, img in enumerate(images):
        fig.add_subplot(rows_cols, rows_cols, i + 1)
        plt.imshow(img)

    if save_path is not None:
        plt.savefig(os.path.join(save_path, 'samples_between_centers.jpg'), format='jpg')
    plt.show()

In [None]:
# 4x4 grid of decoded latents between the class-9 and class-1 centres.
generate_samples_between_centers(model, dataloaders['test'], save_path, from_class=9, to_class=1, rows_cols=4)

In [None]:
def convert_embeddings(
    path: str,
    transforms,
) -> torch.Tensor:
    """Load a saved image batch and push every image through ``transforms``.

    Args:
        path: file readable by ``torch.load`` that holds a 4-D tensor; it is
            permuted with ``(0, 2, 3, 1)``, i.e. treated as channel-first.
        transforms: callable invoked as ``transforms(image=array)`` whose
            result is indexed with ``'image'``; each result must be a tensor
            so the outputs can be stacked.

    Returns:
        The transformed images stacked along a new first dimension.
    """
    # map_location='cpu' lets a batch that was saved from the GPU be
    # converted to NumPy below.
    batch = torch.load(path, map_location='cpu')
    # Scale by 255 and cast to 8-bit (fractional parts are dropped by the
    # cast), in the channel-last layout handed to `transforms`.
    batch = batch.mul_(255).permute(0, 2, 3, 1).numpy().astype(np.uint8)

    return torch.stack([transforms(image=image)['image'] for image in batch])

In [None]:
# Export on the CPU in eval mode. Tracing a train-mode model would record
# train-time layer behaviour and let the random example input update the
# normalisation statistics. forward() defaults to embedding=True, so the
# traced module returns the encoder embedding.
traced_module = torch.jit.trace(model.cpu().eval(), torch.randn(1, 3, image_size[0], image_size[1]))
torch.jit.save(traced_module, os.path.join(save_path, 'model_traced.pt'))

In [None]:
# Reload the traced model and prepare the external image batch with the
# 'test' transform pipeline.
loaded = torch.jit.load(os.path.join(save_path, 'model_traced.pt'))
batch = convert_embeddings('health_dataset.pth', transformations['test'])

# Inference only: compute the embeddings on the CPU without tracking gradients.
with torch.no_grad():
    embeddings = loaded(batch.cpu()).cpu()
    
# Store the embeddings next to the other artefacts of this run.
torch.save(embeddings, os.path.join(save_path, 'embeddings.pth'))