In [21]:
## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()


import pandas as pd
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torch.optim as optim
import os
import random
import pickle
import torchvision
import shutil
from torchvision import transforms
from tqdm import tqdm

  set_matplotlib_formats('svg', 'pdf') # For export


In [22]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs (slowly) on machines without a usable GPU instead of failing here.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def set_all_seeds(seed):
    """Seed every random number generator used in this notebook.

    Exports ``seed`` as the ``PL_GLOBAL_SEED`` environment variable
    (presumably for PyTorch Lightning - confirm whether anything reads it),
    then seeds Python's ``random``, NumPy's global RNG, PyTorch's default
    generator and the generators of all CUDA devices.
    """
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)


set_all_seeds(42)

# Read data

## Oversampling

In [23]:
def over_sample(path):
    """Balance the class folders under ``path`` by duplicating entries on disk.

    Every sub-folder of ``path`` is treated as one class. Each class is grown
    to the size of the largest class: all of its entries are copied
    ``largest // size - 1`` times, and ``largest % size`` randomly chosen
    entries (drawn with ``random.sample``) are copied once more.

    Copies are written next to their source with a numeric suffix. For
    regular files the suffix goes in front of the extension
    (``img.png`` -> ``img_0.png``) so the copy is still recognised as an
    image; directory entries keep the ``name_0`` naming.

    Args:
        path: Directory containing one sub-folder per class.

    Empty class folders are skipped (there is nothing to duplicate), and an
    empty ``path`` is a no-op.
    """
    class_sizes = {
        class_name: len(os.listdir(os.path.join(path, class_name)))
        for class_name in os.listdir(path)
    }
    if not class_sizes:
        return

    largest = max(class_sizes.values())

    for class_name, size in class_sizes.items():
        if size == 0:
            # Avoid dividing by zero below; an empty class cannot be grown.
            continue

        class_dir = os.path.join(path, class_name)
        full_rounds = largest // size - 1
        remainder = largest % size

        # Snapshot the entries before copying so copies are never re-copied.
        entries = os.listdir(class_dir)
        extra_entries = random.sample(entries, k=remainder)

        for round_index in range(full_rounds):
            for entry in entries:
                _duplicate_entry(class_dir, entry, round_index)
        # ``full_rounds`` is the first suffix the full rounds did not use.
        for entry in extra_entries:
            _duplicate_entry(class_dir, entry, full_rounds)


def _duplicate_entry(folder, entry, suffix):
    """Copy ``folder/entry`` to a sibling whose name carries ``_<suffix>``."""
    source = os.path.join(folder, entry)
    if os.path.isdir(source):
        # copytree only works on directories.
        shutil.copytree(source, os.path.join(folder, f"{entry}_{suffix}"))
    else:
        stem, extension = os.path.splitext(entry)
        shutil.copy2(source, os.path.join(folder, f"{stem}_{suffix}{extension}"))


## Create val folder

In [24]:
def create_val_folder(train_dir="./trafic_32", val_dir="./val_trafic_32", val_fraction=0.2):
    """Move a random share of every class from the training folder to a new validation folder.

    ``val_dir`` is created with one sub-folder per entry of ``train_dir``.
    From every class folder, ``int(len(images) * val_fraction)`` randomly
    chosen entries (drawn with ``random.sample``) are moved - not copied -
    into the matching validation sub-folder. With the defaults this is the
    original one-in-five split of ``./trafic_32`` into ``./val_trafic_32``.

    Args:
        train_dir: Directory with one sub-folder per class.
        val_dir: Directory to create for the held-out entries.
        val_fraction: Share of each class to move, between 0 and 1.

    Raises:
        ValueError: If ``val_fraction`` is outside ``[0, 1]``.
        FileExistsError: If ``val_dir`` already exists.
    """
    if not 0 <= val_fraction <= 1:
        raise ValueError(f"val_fraction must be within [0, 1], got {val_fraction}")

    class_names = os.listdir(train_dir)

    # os.mkdir (rather than makedirs(exist_ok=True)) on purpose: failing when
    # the split already exists prevents a second run from silently moving
    # another share of the training data.
    os.mkdir(val_dir)
    for class_name in class_names:
        os.mkdir(os.path.join(val_dir, class_name))

    for class_name in class_names:
        class_dir = os.path.join(train_dir, class_name)
        images = os.listdir(class_dir)
        held_out = random.sample(images, k=int(len(images) * val_fraction))
        for image in held_out:
            shutil.move(
                os.path.join(class_dir, image),
                os.path.join(val_dir, class_name, image),
            )


In [25]:
def check_classes_amounts(path):
    """Print the total number of entries under ``path`` and the count per class folder.

    Each sub-folder of ``path`` is counted as one class; the per-class counts
    are listed in the order ``os.listdir`` returns the folders. Nothing is
    returned.
    """
    per_class = [
        len(os.listdir(f"{path}/{class_dir}")) for class_dir in os.listdir(path)
    ]
    total = sum(per_class)
    print(f"Number of all examples: {total}\nNumbers of every class: {per_class}")

In [26]:
check_classes_amounts("./trafic_32")

Number of all examples: 77400
Numbers of every class: [1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800]


In [27]:
# create_val_folder()

In [28]:
check_classes_amounts("./val_trafic_32")

Number of all examples: 7841
Numbers of every class: [42, 444, 450, 282, 396, 372, 84, 288, 282, 294, 402, 264, 420, 432, 156, 126, 84, 222, 240, 42, 72, 66, 78, 102, 54, 300, 120, 48, 108, 54, 90, 156, 48, 137, 84, 240, 78, 42, 414, 60, 72, 48, 48]


In [29]:
check_classes_amounts("./trafic_32")

Number of all examples: 77400
Numbers of every class: [1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800]


In [30]:
# over_sample("./trafic_32")

In [31]:
check_classes_amounts("./trafic_32")

Number of all examples: 77400
Numbers of every class: [1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800]


## Data preparation

In [32]:
# The only preprocessing step is the conversion of each image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Training data: reshuffled every epoch, incomplete last batch dropped.
train_dataset = datasets.ImageFolder(root="trafic_32/", transform=transform)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=128,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=4,
)

# Validation data: fixed order, every sample kept.
val_dataset = datasets.ImageFolder(root="val_trafic_32/", transform=transform)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=128,
    shuffle=False,
    drop_last=False,
    num_workers=4,
)

In [33]:
# does not work after the transformation from image format to tensor
# from IPython.display import display
#
# image_num = 31367
# scale = 4
# display(train_dataset[image_num][0].resize(( int(train_dataset[image_num][0].width * scale), int(train_dataset[image_num][0].height * scale))))

# Net preparation

In [34]:
class Encoder(nn.Module):
    """Fully connected encoder producing the parameters of the latent Gaussian.

    The input batch is flattened to ``(batch, input_dim)``, passed through two
    hidden layers with LeakyReLU(0.2) activations, and fed to two parallel
    linear heads.

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dim: Width of both hidden layers.
        latent_dim: Size of each output head.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        # Shared trunk.
        self.fc_1 = nn.Linear(input_dim, hidden_dim)
        self.fc_2 = nn.Linear(hidden_dim, hidden_dim)
        # Output heads: latent mean and latent log-variance.
        self.fc_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Return ``(mean, log_var)``, each of shape ``(batch, latent_dim)``."""
        flat = x.flatten(start_dim=1)
        hidden = self.LeakyReLU(self.fc_1(flat))
        hidden = self.LeakyReLU(self.fc_2(hidden))
        return self.fc_mean(hidden), self.fc_var(hidden)

In [35]:
class Decoder(nn.Module):
    """Fully connected decoder mapping latent vectors back to images.

    Two hidden layers with LeakyReLU(0.2) activations are followed by a linear
    output layer and a sigmoid, so every output value lies in ``(0, 1)``. The
    flat output is then reshaped to ``(-1, *output_shape)``.

    Args:
        latent_dim: Size of the latent vectors fed to the decoder.
        hidden_dim: Width of both hidden layers.
        output_dim: Number of values produced per sample. It should equal the
            product of ``output_shape``; this is not checked, and a mismatch
            either fails in ``view`` or silently changes the batch dimension.
        output_shape: Per-sample shape of the returned images. Defaults to
            ``(3, 32, 32)``, the shape that was previously hard-coded.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim, output_shape=(3, 32, 32)):
        super(Decoder, self).__init__()
        self.fc_1 = nn.Linear(latent_dim, hidden_dim)
        self.fc_2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_3 = nn.Linear(hidden_dim, output_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

        self.output_shape = tuple(output_shape)

    def forward(self, x):
        """Decode latent vectors ``x`` into a batch shaped ``(-1, *output_shape)``."""
        h = self.LeakyReLU(self.fc_1(x))
        h = self.LeakyReLU(self.fc_2(h))

        x_hat = torch.sigmoid(self.fc_3(h))
        return x_hat.view(-1, *self.output_shape)


In [36]:
class VAE(nn.Module):
    """Variational autoencoder built from the notebook's ``Encoder`` and ``Decoder``.

    Args:
        x_dim: Number of values per flattened input sample; used as the
            encoder input size and the decoder output size.
        hidden_dim: Width of the hidden layers in both sub-networks.
        latent_dim: Size of the latent space (also stored as ``self.latent_dim``).
    """

    def __init__(self, x_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
        self.decoder = Decoder(latent_dim=latent_dim, hidden_dim=hidden_dim, output_dim=x_dim)

    def reparameterization(self, mean, var):
        """Sample ``z = mean + var * epsilon`` with ``epsilon ~ N(0, I)``.

        Note that ``forward`` passes the standard deviation
        (``exp(0.5 * log_var)``) as ``var``; the parameter name is kept for
        backward compatibility.
        """
        # randn_like creates epsilon with the device and dtype of ``var``, so
        # the sampling works wherever the model lives and does not depend on
        # the notebook-level ``device`` variable.
        epsilon = torch.randn_like(var)
        return mean + var * epsilon

    def forward(self, x):
        """Return ``(x_hat, mean, log_var)`` for the input batch ``x``."""
        mean, log_var = self.encoder(x)
        std = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
        z = self.reparameterization(mean, std)
        x_hat = self.decoder(z)
        return x_hat, mean, log_var

In [37]:
def vae_loss_function(x, x_hat, mean, log_var, max_term=99999.0):
    """Return the VAE loss: summed reconstruction error plus KL divergence.

    The reconstruction term is the squared error between ``x_hat`` and ``x``
    summed over all elements. The KL term is
    ``-0.5 * sum(1 + log_var - mean**2 - exp(log_var))``.

    Each term is capped at ``max_term`` separately, and a NaN term is replaced
    by ``max_term``, so a single diverging batch cannot produce an unbounded
    or NaN loss value. Both terms stay tensors, so the result is always a
    tensor.

    Args:
        x: Batch of original inputs.
        x_hat: Batch of reconstructions, same shape as ``x``.
        mean: Latent means produced by the encoder.
        log_var: Latent log-variances produced by the encoder.
        max_term: Upper bound applied to each of the two loss terms.

    Returns:
        Scalar tensor holding the sum of the two capped terms.
    """

    def _cap(term):
        # Cap with tensor ops: assigning to ``tensor.item`` has no effect on
        # the value, and Python's ``min`` would replace the tensor by an int.
        capped = torch.clamp(term, max=max_term)
        return torch.where(torch.isnan(capped), torch.full_like(capped, max_term), capped)

    reproduction_loss = nn.functional.mse_loss(x_hat, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

    return _cap(reproduction_loss) + _cap(kld)

In [38]:
# x_dim = 3 * 32 * 32 = 3072, matching the (3, 32, 32) shape the decoder emits.
vae = VAE(x_dim=3072, hidden_dim=2048, latent_dim=256)
vae = vae.to(device)
# vae.load_state_dict(torch.load("SSNE-lab-5/vae_40"))

In [39]:
# Adam over all VAE parameters; every scheduler.step() multiplies the learning rate by 0.99.
optimizer = optim.Adam(params=vae.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

In [40]:
num_epochs = 40
for n in range(num_epochs):
    # --- training pass ---
    vae.train()
    losses_epoch = []
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        out, means, log_var = vae(x)
        loss = vae_loss_function(x, out, means, log_var)
        losses_epoch.append(loss.item())
        loss.backward()               # backward pass (compute parameter updates)
        optimizer.step()              # make the updates for each parameter
        optimizer.zero_grad()

    # --- validation pass ---
    # Evaluation needs no gradients: no_grad() avoids building an autograd
    # graph for every validation batch, which saves memory and time.
    vae.eval()
    L1_list = []
    with torch.no_grad():
        for x, _ in val_loader:
            x = x.to(device)
            out, _, _ = vae(x)
            # Mean absolute error between reconstruction and input for this batch.
            L1_list.append(torch.mean(torch.abs(out - x)).item())
    print(f"Epoch {n} loss {np.mean(np.array(losses_epoch))}, test L1 = {np.mean(L1_list)}")
    scheduler.step()                  # decay the learning rate once per epoch



100%|██████████| 604/604 [00:42<00:00, 14.37it/s]


Epoch 0 loss 12536.639407530525, test L1 = 0.0957288850219019


100%|██████████| 604/604 [01:03<00:00,  9.51it/s]


Epoch 1 loss 8758.861055689931, test L1 = 0.08753058644792726


100%|██████████| 604/604 [02:22<00:00,  4.24it/s]


Epoch 2 loss 7970.396509435793, test L1 = 0.08333800402620146


100%|██████████| 604/604 [02:24<00:00,  4.18it/s]


Epoch 3 loss 7624.981972397558, test L1 = 0.08375255595291814


100%|██████████| 604/604 [01:08<00:00,  8.80it/s]


Epoch 4 loss 7439.638868319278, test L1 = 0.08114912108548226


100%|██████████| 604/604 [01:00<00:00,  9.99it/s]


Epoch 5 loss 7288.908440798324, test L1 = 0.08032411907709414


100%|██████████| 604/604 [00:51<00:00, 11.63it/s]


Epoch 6 loss 7066.786050354408, test L1 = 0.07862043080310668


100%|██████████| 604/604 [00:44<00:00, 13.61it/s]


Epoch 7 loss 6970.71653899136, test L1 = 0.07850980848794983


100%|██████████| 604/604 [00:43<00:00, 13.99it/s]


Epoch 8 loss 6848.880058238048, test L1 = 0.0780489795630978


100%|██████████| 604/604 [00:42<00:00, 14.07it/s]


Epoch 9 loss 6742.996506848872, test L1 = 0.07671683419856333


100%|██████████| 604/604 [00:42<00:00, 14.36it/s]


Epoch 10 loss 6664.314410279128, test L1 = 0.07831259907012986


100%|██████████| 604/604 [00:42<00:00, 14.34it/s]


Epoch 11 loss 6601.4814251021835, test L1 = 0.07655440270900726


100%|██████████| 604/604 [00:41<00:00, 14.47it/s]


Epoch 12 loss 6511.604538267022, test L1 = 0.07557087271444259


100%|██████████| 604/604 [00:41<00:00, 14.58it/s]


Epoch 13 loss 6424.886035641298, test L1 = 0.07465192699624647


100%|██████████| 604/604 [00:41<00:00, 14.62it/s]


Epoch 14 loss 6359.080216363566, test L1 = 0.07491350636607216


100%|██████████| 604/604 [00:41<00:00, 14.40it/s]


Epoch 15 loss 6285.944228418615, test L1 = 0.07460497908534543


100%|██████████| 604/604 [00:41<00:00, 14.58it/s]


Epoch 16 loss 6248.840140437448, test L1 = 0.07344555632481652


100%|██████████| 604/604 [00:41<00:00, 14.69it/s]


Epoch 17 loss 6198.424942117653, test L1 = 0.07341398004322283


100%|██████████| 604/604 [00:43<00:00, 14.04it/s]


Epoch 18 loss 6144.714996539994, test L1 = 0.07319206648295926


100%|██████████| 604/604 [00:48<00:00, 12.57it/s]


Epoch 19 loss 6106.54782731012, test L1 = 0.07297324713680052


100%|██████████| 604/604 [00:40<00:00, 15.07it/s]


Epoch 20 loss 6064.559413480443, test L1 = 0.07363467339065767


100%|██████████| 604/604 [00:40<00:00, 14.91it/s]


Epoch 21 loss 6021.598134021885, test L1 = 0.07233879420786135


100%|██████████| 604/604 [00:42<00:00, 14.26it/s]


Epoch 22 loss 5987.927432028663, test L1 = 0.07255481033315582


100%|██████████| 604/604 [00:47<00:00, 12.80it/s]


Epoch 23 loss 5936.750140663804, test L1 = 0.07243777372904363


100%|██████████| 604/604 [00:57<00:00, 10.53it/s]


Epoch 24 loss 5903.566806414269, test L1 = 0.07171318746141848


100%|██████████| 604/604 [01:48<00:00,  5.58it/s]


Epoch 25 loss 5880.92331235772, test L1 = 0.07250361841532492


100%|██████████| 604/604 [01:07<00:00,  8.92it/s]


Epoch 26 loss 5845.090495330608, test L1 = 0.0722417731679255


100%|██████████| 604/604 [00:46<00:00, 12.92it/s]


Epoch 27 loss 5812.223879378363, test L1 = 0.07247098385085983


100%|██████████| 604/604 [00:47<00:00, 12.69it/s]


Epoch 28 loss 5805.252916752897, test L1 = 0.07126161708466468


100%|██████████| 604/604 [00:43<00:00, 13.96it/s]


Epoch 29 loss 5761.280879746999, test L1 = 0.07093090449850406


100%|██████████| 604/604 [00:42<00:00, 14.16it/s]


Epoch 30 loss 5732.611616728322, test L1 = 0.07153999889569898


100%|██████████| 604/604 [00:42<00:00, 14.33it/s]


Epoch 31 loss 5706.929414256519, test L1 = 0.07191630741280894


100%|██████████| 604/604 [00:41<00:00, 14.43it/s]


Epoch 32 loss 5695.382758336352, test L1 = 0.07212682392808699


100%|██████████| 604/604 [00:41<00:00, 14.49it/s]


Epoch 33 loss 5671.5975762171465, test L1 = 0.0714337445435024


100%|██████████| 604/604 [00:41<00:00, 14.58it/s]


Epoch 34 loss 5633.138523935482, test L1 = 0.07028157047687038


100%|██████████| 604/604 [00:42<00:00, 14.33it/s]


Epoch 35 loss 5626.849047528198, test L1 = 0.07032159505592238


100%|██████████| 604/604 [00:40<00:00, 14.81it/s]


Epoch 36 loss 5598.308038370499, test L1 = 0.07127598318601808


100%|██████████| 604/604 [00:48<00:00, 12.37it/s]


Epoch 37 loss 5577.022880503673, test L1 = 0.07002479288606875


100%|██████████| 604/604 [01:06<00:00,  9.05it/s]


Epoch 38 loss 5563.856310307585, test L1 = 0.07081371059100475


100%|██████████| 604/604 [02:01<00:00,  4.96it/s]


Epoch 39 loss 5533.431806349597, test L1 = 0.0706334749777471


In [41]:
# Store only the model's state_dict under the run-specific file name.
checkpoint_path = "vae_40_os_256_2048_bs128_40e"
torch.save(vae.state_dict(), checkpoint_path)

In [None]:
# vae.load_state_dict(torch.load("vae_40_after_os"))
# vae.state_dict()

In [None]:
def get_train_images(start, num):
    """Stack ``num`` consecutive images, beginning at index ``start``, into one batch.

    Despite the name, the images are read from the notebook-level
    ``val_dataset``. Only element 0 of every sample is kept; the result is
    stacked along a new leading dimension.
    """
    images = []
    for index in range(start, start + num):
        sample = val_dataset[index]
        images.append(sample[0])
    return torch.stack(images, dim=0)

In [None]:
def visualize_reconstructions(model, input_imgs, device):
    """Plot input images next to their reconstructions by ``model``.

    The model is switched to eval mode (and left in it), the inputs are
    reconstructed without gradient tracking, and inputs and reconstructions
    are interleaved so every original is directly followed by its
    reconstruction in a grid with 4 images per row.

    Args:
        model: Module whose forward pass returns ``(reconstruction, mean, log_var)``.
        input_imgs: Batch of images to reconstruct.
        device: Device on which the forward pass is run.
    """
    # Reconstruct images
    model.eval()
    with torch.no_grad():
        reconst_imgs, means, log_var = model(input_imgs.to(device))
    reconst_imgs = reconst_imgs.cpu()

    # Plotting
    imgs = torch.stack([input_imgs, reconst_imgs], dim=1).flatten(0, 1)
    # The former ``range=(-1, 1)`` argument is dropped: make_grid renamed it to
    # ``value_range``, and it only matters when ``normalize=True``.
    grid = torchvision.utils.make_grid(imgs, nrow=4, normalize=False)
    grid = grid.permute(1, 2, 0)  # (C, H, W) -> (H, W, C), the layout imshow expects
    if len(input_imgs) == 4:
        plt.figure(figsize=(10, 10))
    else:
        plt.figure(figsize=(15, 10))
    plt.title("Reconstructions")
    plt.imshow(grid)
    plt.axis('off')
    plt.show()

In [None]:
# Show reconstructions for 13 groups of 8 consecutive images, one group every 500 indices.
for group in range(13):
    input_imgs = get_train_images(start=500 * group, num=8)
    visualize_reconstructions(vae, input_imgs, device)

In [None]:
def generate_images(model, n_imgs, device):
    """Decode ``n_imgs`` random latent vectors and plot the results in a grid.

    Latent vectors are drawn from a standard normal distribution of size
    ``model.latent_dim`` and passed through ``model.decoder`` without gradient
    tracking. The model is switched to eval mode (and left in it).

    Args:
        model: Module exposing ``latent_dim`` and ``decoder``.
        n_imgs: Number of images to generate.
        device: Device on which the decoder is run.
    """
    # Generate images
    model.eval()
    with torch.no_grad():
        generated_imgs = model.decoder(torch.randn([n_imgs, model.latent_dim]).to(device))
    generated_imgs = generated_imgs.cpu()

    # The former ``range=(-1, 1)`` argument is dropped: make_grid renamed it to
    # ``value_range``, and it only matters when ``normalize=True``.
    grid = torchvision.utils.make_grid(generated_imgs, nrow=4, normalize=False)
    grid = grid.permute(1, 2, 0)  # (C, H, W) -> (H, W, C), the layout imshow expects
    # Size the figure from the number of generated images; the notebook-level
    # ``input_imgs`` belongs to the reconstruction cells and may not exist.
    if n_imgs == 4:
        plt.figure(figsize=(10, 10))
    else:
        plt.figure(figsize=(15, 10))
    plt.title("Generations")
    plt.imshow(grid)
    plt.axis('off')
    plt.show()

In [None]:
generate_images(vae, 16 , device)