In [None]:
import os
import glob
from PIL import Image
from tqdm import tqdm
import random

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchvision as tv

In [None]:
# Notebook-wide configuration. ROOT_DIR is prepended verbatim (string
# concatenation) to "Train.csv", "Test.csv" and "Train/", so a non-empty value
# must end with a path separator.
ROOT_DIR = ""
IMG_SIZE = 64        # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 128
LATENT_DIMS = 16     # size of the VAE latent vector
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook relies on (train/valid shuffle, weight init,
# reparameterization noise) so that Restart & Run All is repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Read the annotation tables and keep only the two columns used downstream:
# the image path and its class id.
label_columns = ["Path", "ClassId"]

train_csv = pd.read_csv(ROOT_DIR + "Train.csv")
train_files = train_csv[label_columns]

test_csv = pd.read_csv(ROOT_DIR + "Test.csv")
test_files = test_csv[label_columns]

In [None]:
# Preprocessing applied to every image: resize to a square, then convert to a tensor.
tfms = tv.transforms.Compose([
    tv.transforms.Resize((IMG_SIZE, IMG_SIZE)),
    tv.transforms.ToTensor(),
])

# Collect every .png below the training directory (recursively).
filenames = []
for dirpath, _subdirs, names_in_dir in os.walk(ROOT_DIR + "Train/"):
    for name in names_in_dir:
        if name.endswith('.png'):
            filenames.append(os.path.join(dirpath, name))

In [None]:
# Load data into memory

# Build the path -> class id lookup once instead of scanning the DataFrame for
# every image. `train_csv` is used rather than `train_files` because the latter
# is rebound to a plain list by the split cell further down, which would make
# this cell fail when re-run.
class_id_by_path = dict(zip(train_csv["Path"], train_csv["ClassId"].astype(int)))

file_arr = []
for path in tqdm(filenames):
    # The context manager closes the file handle; convert("RGB") guarantees the
    # three channels the model expects even for palette/RGBA PNGs.
    with Image.open(path) as image:
        tens = tfms(image.convert("RGB"))
    # CSV paths are relative to the dataset root, so drop anything up to "gtsrb/".
    conv_filename = path.split("gtsrb/")[-1]
    # A KeyError here means the image on disk has no row in Train.csv.
    class_id = int(class_id_by_path[conv_filename])
    file_arr.append([tens, class_id])

In [None]:
# Shuffle in place so that classes are mixed before the array is split into
# a training set and a validation set.
random.shuffle(file_arr)

# The final 1000 samples form the validation set; everything before is training data.
n_valid = 1000
train_files, valid_files = file_arr[:-n_valid], file_arr[-n_valid:]

In [None]:
class TSDataset(Dataset):
    """In-memory dataset over a sequence of ``[tensor, class_id]`` entries.

    ``transform`` is stored but never applied by ``__getitem__``; items are
    returned exactly as they were stored in ``files``.
    """

    def __init__(self, files, transform=None):
        self.transform = transform
        self.files = files

    def __len__(self):
        """Number of stored samples."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(x, label)`` for position ``idx`` (tensor indices are converted first)."""
        index = idx.tolist() if torch.is_tensor(idx) else idx
        sample = self.files[index]
        return sample[0], sample[1]

In [None]:
# Wrap both splits in the same dataset class (tfms is passed along but the
# dataset does not apply it; the tensors were already transformed on load).
training_data, valid_data = (
    TSDataset(split, tfms) for split in (train_files, valid_files)
)

In [None]:
# Training batches are drawn in a new random order every epoch.
train_dataloader = DataLoader(dataset=training_data, shuffle=True, batch_size=BATCH_SIZE)
# Validation order is fixed (shuffle=False) so reconstructions of the same
# images can be compared across epochs during training.
valid_dataloader = DataLoader(dataset=valid_data, shuffle=False, batch_size=BATCH_SIZE)

In [None]:
# Sanity check: show the first n_rows x n_cols images of the first validation batch.

data, labels = next(iter(valid_dataloader))
n_cols = 8
n_rows = 4

fig = plt.figure(figsize=(25, 16))
for i, img in enumerate(data[: n_cols * n_rows]):
    ax = fig.add_subplot(n_rows, n_cols, i + 1)
    # channels-first tensor -> channels-last array for imshow
    img = img.numpy().transpose(1, 2, 0)
    plt.axis('off')
    plt.imshow(img)

In [None]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, input):
        batch = input.shape[0]
        return input.view(batch, -1)


class UnFlatten(nn.Module):
    """Reshape a ``(batch, size)`` tensor into ``(batch, size, 1, 1)``."""

    def forward(self, input, size=1024):
        target_shape = (input.shape[0], size, 1, 1)
        return input.view(*target_shape)

In [None]:
# https://www.kaggle.com/code/muhammad4hmed/anime-vae/notebook

class CVAE(nn.Module):
    """Conditional convolutional VAE.

    The encoder maps an image to a latent Gaussian (``mu``, ``logvar``); the
    decoder reconstructs the image from a latent sample concatenated with a
    learned 10-dimensional embedding of the class label (43 classes).

    Args:
        image_channels: number of channels of the input/output images.
        h_dim: size of the flattened encoder output. With the layers below a
            64x64 input yields 256 * 2 * 2 = 1024 features, so other input
            sizes require a matching ``h_dim``.
        z_dim: size of the latent vector.
    """

    def __init__(self, image_channels=3, h_dim=1024, z_dim=16):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2),
            nn.ReLU(),
            # Parameter-free, so the state_dict keys of the Sequential are unchanged.
            nn.Flatten(),
        )

        self.signclass_embedding = nn.Embedding(43, 10)

        self.h2mu = nn.Linear(h_dim, z_dim)
        self.h2sigma = nn.Linear(h_dim, z_dim)
        # Input is the latent sample concatenated with the 10-d class embedding.
        self.z2h = nn.Linear(z_dim + 10, h_dim)

        self.decoder = nn.Sequential(
            # (batch, h_dim) -> (batch, h_dim, 1, 1); follows h_dim instead of a
            # hard-coded 1024 so non-default h_dim values work on the decoder side.
            nn.Unflatten(1, (h_dim, 1, 1)),
            nn.ConvTranspose2d(h_dim, 128, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=6, stride=2),
            nn.Sigmoid(),
        )

    # Enforce latent space well-formedness by injecting random gaussian noise
    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + std * eps`` with ``eps ~ N(0, I)``.

        The noise is created with ``randn_like`` so it lives on the same
        device and has the same dtype as the inputs, independent of any
        global device setting.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def bottleneck(self, h, label):
        """Return ``(z, mu, logvar)`` for encoder features ``h`` (``label`` is not used here)."""
        mu = self.h2mu(h)
        logvar = self.h2sigma(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def _embed_label(self, label):
        """Look up the class embedding; a ``(batch, 1)`` label yields ``(batch, 10)``."""
        signclass = self.signclass_embedding(label.long())
        return signclass.squeeze(dim=1)

    def encode(self, x, label):
        """Return a latent sample for image batch ``x``."""
        return self.bottleneck(self.encoder(x), label)[0]

    def decode(self, z, label=None):
        """Decode a latent vector into an image.

        Without ``label``, ``z`` must already contain the class embedding
        (width ``z_dim + 10``), as before. With ``label``, ``z`` is the plain
        latent sample and the embedding is concatenated here.
        """
        if label is not None:
            z = torch.cat([z, self._embed_label(label)], dim=1)
        return self.decoder(self.z2h(z))

    def extract_model(self):
        """Return ``(embedding, upscaler, decoder)``: the parts needed to generate images."""
        return self.signclass_embedding, self.z2h, self.decoder

    def forward(self, x, label):
        """Return ``(reconstruction, mu, logvar, z_small, z)``.

        ``z_small`` is the sampled latent vector and ``z`` the upscaled
        decoder input computed from ``z_small`` and the class embedding.
        """
        h = self.encoder(x)
        z_small, mu, logvar = self.bottleneck(h, label)
        z_small_cat = torch.cat([z_small, self._embed_label(label)], dim=1)
        z = self.z2h(z_small_cat)
        return self.decoder(z), mu, logvar, z_small, z

In [None]:
# Build the CVAE with the latent size from the config cell so that changing
# LATENT_DIMS actually changes the model (it was silently ignored before).
model = CVAE(z_dim=LATENT_DIMS)
model.to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=5e-4)

In [None]:
def vae_loss(recon_x, x, mu, logvar):
    """Summed VAE loss: reconstruction BCE plus KL divergence to N(0, I).

    Args:
        recon_x: reconstructed batch with values in [0, 1]; same shape as ``x``.
        x: target batch with values in [0, 1].
        mu: latent means.
        logvar: latent log-variances.

    Returns:
        Scalar tensor, summed over every element of the batch (not averaged).
    """
    # With reduction='sum' the result does not depend on how the tensors are
    # flattened, so no reshape tied to the image size or channel count is needed.
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

In [None]:
def num_params(model):
    """Count the scalar parameters of ``model`` (everything yielded by ``model.parameters()``)."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

# Display the total number of parameters of the CVAE instance.
num_params(model)

In [None]:
def plot_interm_results():
    """Plot reconstructions of the first validation batch in a 4 x 8 grid.

    Reads the notebook globals ``valid_dataloader``, ``model`` and ``DEVICE``.
    The validation loader is not shuffled, so the same images are shown on
    every call and progress can be compared between epochs.
    """
    x, label = next(iter(valid_dataloader))

    # Same preparation on CPU and GPU: move to the model's device and give the
    # labels a trailing dimension, as in the training loop.
    x = x.to(DEVICE)
    label = label.to(DEVICE).unsqueeze(dim=1)

    # Visualisation only: no autograd graph is needed.
    with torch.no_grad():
        imgs, _, _, _, _ = model(x, label)
    imgs = imgs.cpu()

    n_cols = 8
    n_rows = 4

    fig = plt.figure(figsize=(25, 16))
    # Never index past the batch, which may hold fewer images than the grid.
    for i in range(min(n_cols * n_rows, len(imgs))):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        # channels-first -> channels-last for imshow
        ax.imshow(imgs[i].permute(1, 2, 0))
        ax.axis('off')

In [None]:
# Train the CVAE, recording the mean per-sample loss of every epoch for both
# splits, checkpointing every 100 epochs and plotting progress after each epoch.
epochs = 2000
epoch_train_losses = []
epoch_valid_losses = []

for epoch in tqdm(range(epochs)):
    batch_train_losses = []
    batch_valid_losses = []

    model.train()
    for data, label in train_dataloader:
        # Device-agnostic batch preparation (identical on CPU and GPU).
        data = data.to(DEVICE)
        label = label.to(DEVICE).unsqueeze(dim=1)

        optimizer.zero_grad()
        recon_batch, mu, logvar, _, _ = model(data, label)
        loss = vae_loss(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()

        # vae_loss is summed over the batch; divide to get a per-sample value.
        batch_train_losses.append(loss.item() / data.shape[0])
    epoch_train_losses.append(np.mean(batch_train_losses))

    model.eval()
    # No gradients are needed for validation; skipping the graph saves memory.
    with torch.no_grad():
        for data, label in valid_dataloader:
            data = data.to(DEVICE)
            label = label.to(DEVICE).unsqueeze(dim=1)

            recon_x, mu, logvar, _, _ = model(data, label)
            loss = vae_loss(recon_x, data, mu, logvar)

            batch_valid_losses.append(loss.item() / data.shape[0])
    epoch_valid_losses.append(np.mean(batch_valid_losses))

    if epoch % 100 == 0:
        torch.save(model.state_dict(), "cvae_epoch_{}.pth".format(epoch))

    # Report the loss of THIS epoch (previously the mean over all epochs so far
    # was printed, which hides the current state of training).
    print(f'Epoch {epoch+1} \t\t Training Loss: {epoch_train_losses[-1]} \t\t Validation Loss: {epoch_valid_losses[-1]}')

    plt.plot(epoch_train_losses, label="train_loss")
    plt.plot(epoch_valid_losses, label="valid_loss")
    plt.legend()

    plot_interm_results()
    plt.show()

In [None]:
# Save the final model weights. NOTE(review): "..." looks like a placeholder
# path; replace it with the real checkpoint destination before running.
torch.save(model.state_dict(), "...")

# Validation loss comparison

In [None]:
# Save the valid loss array and start training with new latent space value from scratch
#
# NOTE(review): dim64 ... dim4 are not defined anywhere in the visible notebook.
# They are presumably copies of `epoch_valid_losses` saved by hand after each
# run with the corresponding latent size; define them before running this cell.

# One curve per latent size (the first plot call was missing a comma, which
# made the whole cell a syntax error).
for dims, valid_losses in (("64", dim64), ("32", dim32), ("16", dim16), ("8", dim8), ("4", dim4)):
    plt.plot(valid_losses, label=dims)

plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(title="Dimensions")
plt.show()

# Test error comparison (Benchmark)

In [None]:
# https://github.com/poojahira/gtsrb-pytorch

class Net(nn.Module):
    """CNN classifier with a spatial transformer front end.

    Inputs of any spatial size are resized to 32x32 before classification.

    Args:
        nclasses: number of output classes. Defaults to 43, the same class
            count as the CVAE's label embedding in this notebook. (The
            original code read an undefined global ``nclasses``.)
    """

    def __init__(self, nclasses=43):
        super(Net, self).__init__()

        # CNN layers
        self.conv1 = nn.Conv2d(3, 100, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(100)
        self.conv2 = nn.Conv2d(100, 150, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(150)
        self.conv3 = nn.Conv2d(150, 250, kernel_size=3)
        self.bn3 = nn.BatchNorm2d(250)
        self.conv_drop = nn.Dropout2d()
        # A 32x32 input is reduced to 2x2 by the three conv + max-pool stages.
        self.fc1 = nn.Linear(250*2*2, 350)
        self.fc2 = nn.Linear(350, nclasses)

        # Localization network of the spatial transformer (32x32 -> 10 x 4 x 4).
        self.localization = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
            )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 4 * 4, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
            )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))


    # Spatial transformer network forward function
    def stn(self, x):
        """Predict an affine transform for ``x`` and resample ``x`` with it."""
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 4 * 4)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)
        # NOTE(review): align_corners is left at the library default for both
        # calls; confirm it matches the setting the loaded weights were trained with.
        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        return x

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, nclasses)``."""
        # transform the input
        x = F.interpolate(x, size=(32,32), mode='bilinear')
        x = self.stn(x)

        # Perform forward pass
        x = self.bn1(F.max_pool2d(F.leaky_relu(self.conv1(x)),2))
        x = self.conv_drop(x)
        x = self.bn2(F.max_pool2d(F.leaky_relu(self.conv2(x)),2))
        x = self.conv_drop(x)
        x = self.bn3(F.max_pool2d(F.leaky_relu(self.conv3(x)),2))
        x = self.conv_drop(x)
        x = x.view(-1, 250*2*2)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [None]:
# Load the benchmark classifier and freeze it in inference mode.
# NOTE(review): "..." looks like a placeholder checkpoint path.
classifier = Net()
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
classifier.load_state_dict(torch.load("...", map_location=DEVICE))
classifier.to(DEVICE)
classifier.eval()

In [None]:
# Benchmark: reconstruct every test image with the CVAE and measure how often
# the classifier still predicts the true label from the reconstruction.
#
# NOTE(review): `test_dataloader` and `data_transforms` are not defined in the
# visible notebook; both must exist before this cell runs. `data_transforms`
# is presumably the classifier's input preprocessing; confirm.
correct = 0
nr = 0
model.eval()
with torch.no_grad():  # evaluation only; no autograd graph needed
    for data, label in test_dataloader:
        # Device-agnostic batch preparation (identical on CPU and GPU).
        data = data.to(DEVICE)
        label = label.to(DEVICE).unsqueeze(dim=1)
        recon_x, mu, logvar, _, _ = model(data, label)
        recon_x = data_transforms(recon_x)
        output = classifier(recon_x)
        output = torch.argmax(output, dim=1)
        label = label.squeeze(dim=1)
        correct += (output == label).sum().item()
        nr += data.shape[0]

# Exact sample count in the denominator (it used to start at 0.01, which
# skewed the result); guard against an empty loader instead.
accuracy = 100 * correct / nr if nr else float("nan")
print(accuracy)

# Plot ranges

In [None]:
def plot_images(X, y, yp, M, N):
    """Show the first ``M * N`` images in an M x N grid, titled with the true-class probability.

    Args:
        X: indexable collection of images accepted by ``imshow``.
        y: integer class labels of shape ``(n, 1)`` (the shape ``gather`` needs).
        yp: classifier scores of shape ``(n, n_classes)``; a softmax over the
            class dimension turns them into probabilities.
        M: number of grid rows.
        N: number of grid columns.

    Titles are green when the arg-max prediction equals the label, red otherwise.
    """
    # squeeze=False keeps `ax` two-dimensional even when M or N is 1.
    f, ax = plt.subplots(M, N, sharex=True, sharey=True, figsize=(N, M*1.3), squeeze=False)
    yp = yp.detach()
    y = y.to(yp.device).long()
    # Explicit class dimension; the implicit-dim form of softmax is deprecated.
    prob = F.softmax(yp, dim=1).gather(1, y)
    pred = yp.argmax(dim=1)
    for i in range(M):
        for j in range(N):
            k = i*N + j
            ax[i][j].imshow(X[k])
            title = ax[i][j].set_title("{:.2f}".format(prob[k].item()))
            plt.setp(title, color=('g' if pred[k].item() == y[k].item() else 'r'))
            ax[i][j].set_axis_off()
    plt.tight_layout()

In [None]:
# Ensemble architecture (combining cvae and classifier)

class Ensemble(nn.Module):
    """Chain label embedding -> upscaler -> decoder -> classifier.

    Args:
        embeddings: module mapping integer class ids to embedding vectors.
        upscaler: module applied to the concatenated ``[z, embedding]`` vector.
        decoder: module turning the upscaled vector into an image.
        classifier: module classifying the decoded image.
    """

    def __init__(self, embeddings, upscaler, decoder, classifier):
        super(Ensemble, self).__init__()
        self.embeddings = embeddings
        self.upscaler = upscaler
        self.decoder = decoder
        self.classifier = classifier

    def _decode(self, z, label):
        """Shared generation path: embed the label, join it with ``z`` and decode."""
        enc_label = self.embeddings(label.long())
        # Labels of shape (batch, 1) produce (batch, 1, emb); drop the extra
        # dimension so both (batch,) and (batch, 1) labels are accepted.
        if enc_label.dim() == 3:
            enc_label = enc_label.squeeze(dim=1)
        x = torch.cat((z, enc_label), dim=1)
        x = self.upscaler(x)
        return self.decoder(x)

    def forward(self, z, label):
        """Return the classifier output for the image generated from ``z`` and ``label``."""
        return self.classifier(self._decode(z, label))

    def get_img(self, z, label):
        """Return the generated image itself, without classifying it."""
        return self._decode(z, label)

In [None]:
# Load cvae and classifier into ensemble

# The trained CVAE lives in `model` (there is no `cvae` variable in this
# notebook); take the generation submodules straight from it.
embeddings, upscaler, decoder = model.signclass_embedding, model.z2h, model.decoder
ensemble = Ensemble(embeddings, upscaler, decoder, classifier)
ensemble.to(DEVICE);

In [None]:
# Fetch one training batch, move it to the compute device and give the labels
# a trailing dimension; the last line displays both shapes.
cvae_batch = next(iter(train_dataloader))
cvae_data = cvae_batch[0].to(DEVICE)
cvae_labels = cvae_batch[1].to(DEVICE).unsqueeze(dim=1)
cvae_data.shape, cvae_labels.shape

In [None]:
# Latent traversal grid: one group of N_STEPS rows per latent dimension. Within
# group i only dimension i varies (from -3.25 to 3.5 in steps of 0.75); every
# other dimension stays at 0.
N_STEPS = 10
# linspace replaces the deprecated torch.range and yields the same 10 values;
# it is built once instead of on every loop iteration.
traversal = torch.linspace(-3.25, 3.5, N_STEPS, device=DEVICE)

mu_range = torch.zeros((LATENT_DIMS * N_STEPS, LATENT_DIMS), device=DEVICE)
for i in range(LATENT_DIMS):
    mu_range[i * N_STEPS:(i + 1) * N_STEPS, i] = traversal

In [None]:
# Plot all dimensions (here: 16 along specified range)

# Every row of mu_range needs a class label. The `labels` left over from the
# sanity-check cell has one entry per validation image on the CPU, which does
# not match mu_range, so build the labels explicitly here.
# NOTE(review): this conditions on the class of the first sample of the
# (shuffled) training batch fetched above; replace with a fixed class id to
# visualise a specific sign.
labels = cvae_labels[:1].long().expand(mu_range.size(0), 1).contiguous()

# Inference only: no autograd graph needed.
with torch.no_grad():
    yp = ensemble(mu_range, labels)
    imgs = ensemble.get_img(mu_range, labels.squeeze(dim=1))

# channels-first -> channels-last for imshow
imgs = imgs.cpu().numpy().transpose(0, 2, 3, 1)
# One grid row per latent dimension, 10 traversal steps per row.
plot_images(imgs, labels, yp, mu_range.size(0) // 10, 10)