In [1]:
import os
import json
import math
import numpy as np 
import random

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline 
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
from torch.nn.modules.activation import Sigmoid
from torch.optim import optimizer

# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torchvision.datasets import ImageFolder
from sklearn.model_selection import train_test_split
from torchvision.utils import save_image

In [6]:
%%capture
!unzip /content/trafic_32.zip

In [7]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

# Dane

In [8]:
# ToTensor is the only transform applied to the images.
train_transform = transforms.Compose([transforms.ToTensor()])

# Every image found under the extracted archive, wrapped with the transform above.
dataset = ImageFolder(root="/content/trafic_32", transform=train_transform)


In [9]:
batch_size = 64
validation_split = .2
shuffle_dataset = True
random_seed = 42

# NumPy only drives the index shuffle below. The samplers, weight
# initialisation and latent sampling draw from torch's generator, and the
# image/label sampling helpers use Python's `random`, so seed those as well
# to make a full re-run reproducible.
random.seed(random_seed)
torch.manual_seed(random_seed)

dataset_size = len(dataset)
indices = list(range(dataset_size))
# The first `split` (shuffled) indices become the validation set, the rest the training set.
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Both loaders wrap the same dataset; the samplers restrict each one to its own index subset.
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
valid_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           sampler=train_sampler, num_workers=2)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=valid_sampler, num_workers=2)

In [10]:
%%capture
!mkdir /content/validation_data
!rm /content/validation_data/*

for batch, (images, labels) in enumerate(validation_loader):
    for cnt, image in enumerate(images):
        if batch*batch_size + cnt < 1000:
            save_image(image, '/content/validation_data/val_' + str(batch*batch_size + cnt) + '.png')
        else:
            break

In [11]:
# Class list exposed by the dataset object; len(classes) later sets the number of rows of the label embedding.
classes = dataset.classes

In [12]:
# Count how often each class index occurs in the validation split.
classes_len = {}
for _, batch_labels in validation_loader:
    for class_idx in batch_labels.tolist():
        classes_len[class_idx] = classes_len.get(class_idx, 0) + 1

# Model

In [13]:
class View(nn.Module):
    """Module wrapper around ``Tensor.view`` so a reshape can be placed inside ``nn.Sequential``."""

    def __init__(self, shape):
        super().__init__()
        # Target shape; unpacked into ``Tensor.view`` on every forward call.
        self.shape = shape

    def forward(self, x):
        target_shape = self.shape
        return x.view(*target_shape)

In [14]:
class HamsterConcurentLayer(nn.Module):
    """Residual block: two Linear -> ReLU -> BatchNorm1d stages plus a skip connection.

    Input and output width are both ``channels`` so the input can be added
    to the transformed features.
    """

    def __init__(self, channels):
        super().__init__()
        stages = []
        for _ in range(2):
            stages += [nn.Linear(channels, channels), nn.ReLU(), nn.BatchNorm1d(channels)]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        # Skip connection: transformed features plus the untouched input.
        return self.layers(x) + x

In [15]:
class Encoder(nn.Module):
    """Maps an input batch (plus an optional condition vector) to a latent mean and log-variance.

    Three alternative feature extractors are constructed (``conv_layers``,
    ``res_layers`` and ``layers``); ``forward`` currently uses ``layers`` only.

    Args:
        input_dim: input width of the fully connected branches; also used as
            ``in_channels`` of the first convolution in ``conv_layers``.
        hidden_dim: width of the feature vector each branch produces.
        latent_dim: size of the mean and log-variance outputs.
        cond_dim: width of the condition vector concatenated to the features
            before the two output heads.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, cond_dim):
        super().__init__()

        # Convolutional branch (built, but not used by forward()).
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=input_dim, out_channels=32, kernel_size=5, stride=(2, 2)),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=(2, 2)),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(64, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
        )

        # Residual fully connected branch (built, but not used by forward()).
        self.res_layers = nn.Sequential(
            nn.Flatten(),
            HamsterConcurentLayer(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            HamsterConcurentLayer(hidden_dim),
        )

        # Plain fully connected branch -- the one forward() runs.
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
        )

        # Output heads; their input width includes the condition vector.
        head_in = hidden_dim + cond_dim
        self.fc_mean = nn.Linear(head_in, latent_dim)
        self.fc_var = nn.Linear(head_in, latent_dim)

        self.training = True
        self.init_weights()

    def init_weights(self):
        """Kaiming-normal init for Linear/Conv2d weights; BatchNorm weight set to 1 and bias to 0."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x, y=None):
        """Return ``(mean, log_var)`` for input ``x`` and optional condition ``y``.

        When ``y`` is given it is concatenated to the extracted features along
        dim 1 before the heads are applied.
        """
        # To switch branch, replace self.layers with self.conv_layers or self.res_layers.
        features = self.layers(x)
        if y is not None:
            features = torch.cat((features, y), dim=1)
        return self.fc_mean(features), self.fc_var(features)

In [16]:
class Decoder(nn.Module):
    """Maps a latent vector (plus an optional condition vector) to an image batch.

    Three alternative generators are constructed (``conv_layers``, ``layers``
    and ``res_layers``); ``forward`` currently uses ``layers`` only and
    reshapes its output to ``(-1, 3, 32, 32)``.

    Args:
        latent_dim: width of the latent vector.
        cond_dim: width of the condition vector concatenated to the latent.
        hidden_dim: width of the hidden fully connected layers.
        output_dim: number of output features of the fully connected branches.
    """

    def __init__(self, latent_dim, cond_dim, hidden_dim, output_dim):
        super().__init__()
        in_features = latent_dim + cond_dim

        # Transposed-convolution branch (built, but not used by forward()).
        self.conv_layers = nn.Sequential(
            nn.Linear(in_features, 192),
            nn.ReLU(),
            View([-1, 3, 8, 8]),
            nn.ConvTranspose2d(in_channels=3, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1),
            nn.Flatten(),
            nn.Sigmoid(),
        )

        # Plain fully connected branch -- the one forward() runs.
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),
        )

        # Residual fully connected branch (built, but not used by forward()).
        self.res_layers = nn.Sequential(
            HamsterConcurentLayer(in_features),
            nn.Linear(in_features, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            HamsterConcurentLayer(hidden_dim),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),
        )

        self.init_weights()

    def init_weights(self):
        """Kaiming-normal init for Linear weights; BatchNorm weight set to 1 and bias to 0."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x, y=None):
        """Decode latent ``x`` (concatenated with ``y`` along dim 1 when given) into images."""
        decoder_input = x if y is None else torch.cat((x, y), dim=1)

        # To switch branch, replace self.layers with self.conv_layers or self.res_layers.
        flat_images = self.layers(decoder_input)

        return flat_images.view(-1, 3, 32, 32)

In [17]:
class cVAE(nn.Module):
    """Conditional VAE: encoder and decoder both receive an embedding of the class label.

    Args:
        x_dim: number of image channels; images are treated as 32x32.
        nclass: number of distinct labels (rows of the label embedding).
        hidden_dim: hidden width passed to the encoder and decoder.
        latent_dim: size of the latent vector.
        cond_dim: size of the label embedding.
    """

    def __init__(self, x_dim, nclass, hidden_dim, latent_dim, cond_dim):
        super().__init__()
        self.latent_dim = latent_dim
        # Flattened image size. If the encoder's convolutional branch is used
        # instead, its input_dim should be x_dim rather than x_dim*32*32.
        flat_dim = x_dim * 32 * 32
        self.encoder = Encoder(
            input_dim=flat_dim,
            cond_dim=cond_dim,
            hidden_dim=hidden_dim,
            latent_dim=latent_dim,
        )
        self.decoder = Decoder(
            latent_dim=latent_dim,
            cond_dim=cond_dim,
            hidden_dim=hidden_dim,
            output_dim=flat_dim,
        )
        self.label_embedding = nn.Embedding(nclass, cond_dim)

    def _embed_label(self, y):
        """Return the embedding of ``y``, or ``None`` when no label is given."""
        if y is None:
            return None
        return self.label_embedding(y)

    def sampling(self, mean, var):
        """Reparameterised sample ``mean + var * eps`` with ``eps ~ N(0, I)``.

        Despite its name, ``var`` is used as a scale: ``forward`` passes
        ``exp(0.5 * log_var)``, i.e. the standard deviation.
        """
        noise = torch.randn_like(mean)
        return mean + var * noise

    def forward(self, x, y=None):
        """Encode ``x``, sample a latent and decode it; returns ``(x_hat, mean, log_var)``."""
        cond = self._embed_label(y)
        mean, log_var = self.encoder(x, cond)
        std = torch.exp(0.5 * log_var)
        z = self.sampling(mean, std)
        return self.decoder(z, cond), mean, log_var

    def generate(self, z, y=None):
        """Decode a given latent ``z`` conditioned on the optional label ``y``."""
        return self.decoder(z, self._embed_label(y))

In [18]:
# Build the conditional VAE with one embedding row per dataset class and move it to the selected device.
model = cVAE(
    x_dim=3,
    nclass=len(classes),
    hidden_dim=1024,
    latent_dim=256,
    cond_dim=128,
)
model = model.to(device)

# Trenowanie

In [19]:
def vae_loss_function(x, x_hat, mean, log_var, b):
    """Sum-reduced reconstruction error plus a ``b``-weighted KL term.

    Args:
        x: target batch.
        x_hat: reconstruction of ``x``.
        mean: latent means produced by the encoder.
        log_var: latent log-variances produced by the encoder.
        b: weight applied to the KL term.

    Returns:
        Scalar tensor ``MSE + b * KLD``; both terms are summed (not averaged)
        over all elements.
    """
    reconstruction = nn.functional.mse_loss(x_hat, x, reduction='sum')
    kl_elementwise = 1 + log_var - mean.pow(2) - log_var.exp()
    kl_divergence = -0.5 * torch.sum(kl_elementwise)
    return reconstruction + b * kl_divergence

In [20]:
def frange_cycle_sigmoid(start, stop, n_epoch, n_cycle=4, ratio=0.5):
    """Build a cyclical sigmoid annealing schedule with one weight per epoch.

    The ``n_epoch`` epochs are divided into ``n_cycle`` periods. At the start
    of each period a value ``v`` ramps linearly from ``start`` towards ``stop``
    in steps of ``(stop - start) / (period * ratio)`` and the schedule entry
    is set to ``sigmoid(12 * v - 6)``; entries not written by a ramp stay 1.0.

    Args:
        start: initial value of the ramp variable in every cycle.
        stop: ramp variable value at which a cycle's ramp ends (inclusive).
        n_epoch: length of the returned schedule.
        n_cycle: number of cycles.
        ratio: fraction of each period spent ramping.

    Returns:
        ``numpy.ndarray`` of length ``n_epoch``.
    """
    L = np.ones(n_epoch)
    period = n_epoch / n_cycle
    step = (stop - start) / (period * ratio)

    for c in range(n_cycle):
        v, i = start, 0
        while v <= stop:
            idx = int(i + c * period)
            if idx >= n_epoch:
                # A ramp may run past the end of the schedule (e.g. ratio=1.0
                # in the last cycle); stop instead of indexing out of bounds.
                break
            L[idx] = 1.0 / (1.0 + np.exp(-(v * 12. - 6.)))
            v += step
            i += 1
    return L

In [21]:
# AdamW over all model parameters; every scheduler.step() multiplies the learning rate by 0.99.
# NOTE: this rebinds the name `optimizer`, which the import cell bound to the torch.optim.optimizer module.
optimizer = optim.AdamW(params=model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

In [22]:
num_epochs = 20
# Per-epoch KL weight: four sigmoid warm-up cycles between 0.0 and 1.0 spread over the run.
beta_np_cycle = frange_cycle_sigmoid(start=0.0, stop=1.0, n_epoch=num_epochs, n_cycle=4)

In [23]:
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    losses_epoch = []
    for x, label in train_loader:
        x, label = x.to(device), label.to(device)

        out, means, log_var = model(x, label)
        # The KL weight for this epoch comes from the cyclical schedule.
        loss = vae_loss_function(x, out, means, log_var, beta_np_cycle[epoch])
        losses_epoch.append(loss.item())

        loss.backward()
        optimizer.step()
        # Gradients are cleared right after the update so the next batch starts from zero.
        optimizer.zero_grad()

    # ---- validation pass: mean absolute reconstruction error per batch ----
    model.eval()
    L1_list = []
    with torch.no_grad():
        for x, label in validation_loader:
            x, label = x.to(device), label.to(device)
            out, _, _ = model(x, label)
            L1_list.append((out - x).abs().mean().item())

    print(f"Epoch [{epoch+1}/{num_epochs}] loss {np.mean(np.array(losses_epoch))}, validation L1 = {np.mean(L1_list)}")
    scheduler.step()

Epoch [1/20] loss 14123.838140483786, validation L1 = 0.09207791338364284
Epoch [2/20] loss 5510.24810306215, validation L1 = 0.10148066336788782
Epoch [3/20] loss 6345.050029933332, validation L1 = 0.09510272238554994
Epoch [4/20] loss 5492.674318395176, validation L1 = 0.10518599053223927
Epoch [5/20] loss 5473.298645143839, validation L1 = 0.09052932795470323
Epoch [6/20] loss 2911.526459175312, validation L1 = 0.07468495651231549
Epoch [7/20] loss 3610.434245123154, validation L1 = 0.07671164866627717
Epoch [8/20] loss 4757.163626428049, validation L1 = 0.0863970933648629
Epoch [9/20] loss 4498.943192926784, validation L1 = 0.08984702613537873
Epoch [10/20] loss 4509.377809357497, validation L1 = 0.08176777269539794
Epoch [11/20] loss 2431.1147518119114, validation L1 = 0.06767891189916347
Epoch [12/20] loss 3063.9861560829304, validation L1 = 0.07186688649339404
Epoch [13/20] loss 4239.726305804282, validation L1 = 0.09106506074104852
Epoch [14/20] loss 4307.136294487294, validati

# Ocena

In [24]:
def get_train_images(num):
    """Draw ``num`` random samples (with replacement) from the module-level ``dataset``.

    Samples are taken from the whole ``dataset`` object, not only from the
    training indices.

    Args:
        num: number of samples to draw.

    Returns:
        Tuple ``(images, labels)``: the sampled images stacked along a new
        first dimension, and a tensor built from their labels.
    """
    images = []
    labels = []
    for _ in range(num):
        # randrange excludes len(dataset); randint's upper bound is inclusive
        # and could therefore produce an out-of-range index.
        idx = random.randrange(len(dataset))
        # Index the dataset once so each sample is loaded a single time.
        sample = dataset[idx]
        images.append(sample[0])
        labels.append(sample[1])
    return torch.stack(images, dim=0), torch.tensor(labels)

In [25]:
def visualize_reconstructions(model, input_imgs, labels, device):
    """Plot input images next to the model's reconstructions of them.

    Puts ``model`` into eval mode, reconstructs ``input_imgs`` conditioned on
    ``labels`` without tracking gradients, and shows a grid in which every
    input image is immediately followed by its reconstruction.

    Args:
        model: callable returning ``(reconstruction, mean, log_var)`` for
            ``(images, labels)``.
        input_imgs: batch of images to reconstruct.
        labels: labels matching ``input_imgs``.
        device: device the forward pass runs on.
    """
    # Reconstruct images
    model.eval()
    with torch.no_grad():
        reconst_imgs, _, _ = model(input_imgs.to(device), labels.to(device))
    reconst_imgs = reconst_imgs.cpu()

    # Interleave along a new dim and flatten: input_0, recon_0, input_1, recon_1, ...
    imgs = torch.stack([input_imgs, reconst_imgs], dim=1).flatten(0, 1)
    # The former `range=(-1,1)` keyword is not passed: it only affects
    # normalisation (disabled here) and newer torchvision releases reject it.
    grid = torchvision.utils.make_grid(imgs, nrow=4, normalize=False)
    grid = grid.permute(1, 2, 0)  # channels-first -> channels-last for imshow
    if len(input_imgs) == 4:
        plt.figure(figsize=(10,10))
    else:
        plt.figure(figsize=(15,10))
    plt.title("Reconstructions")
    plt.imshow(grid)
    plt.axis('off')
    plt.show()

### Rekonstrukcja obrazów

In [None]:
# Reconstruct eight randomly drawn dataset images and plot them next to the originals.
input_imgs, labels = get_train_images(num=8)
visualize_reconstructions(model=model, input_imgs=input_imgs, labels=labels, device=device)

# Generacje

In [27]:
def generate_images(model, n_imgs, device):
    """Sample ``n_imgs`` images from the model.

    Labels are drawn from the keys of the module-level ``classes_len`` with
    its values as weights; latents are drawn from a standard normal of width
    ``model.latent_dim``. The model is put into eval mode and no gradients
    are tracked.

    Returns:
        The generated batch, moved to the CPU.
    """
    model.eval()
    with torch.no_grad():
        class_ids = list(classes_len.keys())
        class_counts = list(classes_len.values())
        sampled_ids = random.choices(class_ids, weights=class_counts, k=n_imgs)
        label = torch.tensor(sampled_ids).to(device)
        latent = torch.randn([n_imgs, model.latent_dim]).to(device)
        generated_imgs = model.generate(latent, label)
    return generated_imgs.cpu()


In [28]:
def show_images(images):
    """Display a batch of images as a grid with four images per row.

    Args:
        images: batch of images accepted by ``torchvision.utils.make_grid``.
    """
    # The former `range=(-1,1)` keyword is not passed: it only affects
    # normalisation (disabled here) and newer torchvision releases reject it.
    grid = torchvision.utils.make_grid(images, nrow=4, normalize=False)
    grid = grid.permute(1, 2, 0)  # channels-first -> channels-last for imshow
    if len(images) == 4:
        plt.figure(figsize=(10,10))
    else:
        plt.figure(figsize=(15,10))
    plt.title("Generations")
    plt.imshow(grid)
    plt.axis('off')
    plt.show()

In [None]:
# Preview a small batch of sampled images.
generated_images = generate_images(model=model, n_imgs=16, device=device)
show_images(images=generated_images)

In [30]:
# 1000 samples, the same number as the cap on exported validation images.
generated_images = generate_images(model=model, n_imgs=1000, device=device)


In [31]:
%%capture
!mkdir /content/generated_data
!rm /content/generated_data/*

for cnt, image in enumerate(generated_images):
    save_image(image, '/content/generated_data/gen_' + str(cnt) + '.png')

### Liczenie wartości Fréchet Inception Distance (FID)

In [32]:
%%capture
# Installs the package that provides the `pytorch_fid` module invoked in the next cell.
!pip install pytorch-fid

In [33]:
!python -m pytorch_fid --device cuda:0 '/content/validation_data' '/content/generated_data' 


100% 20/20 [00:04<00:00,  4.82it/s]
100% 20/20 [00:04<00:00,  4.92it/s]
FID:  79.51691877536854
