In [None]:
import io
import json
import os
import random
import time
from typing import Callable, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.utils
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from tqdm.notebook import tqdm


# Seed every random number generator the notebook draws from so that a
# Restart & Run All reproduces the same shuffling, initial weights and
# fallback image choices.
seed = 123
random.seed(seed)  # stdlib RNG: used when an unreadable training image is replaced
np.random.seed(seed)
_ = torch.manual_seed(seed)
_ = torch.cuda.manual_seed(seed)

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Hyper-parameters shared by the model definition and the training loop
hparams = {
    'batch_size': 64,
    'num_epochs': 8,
    'channels': 32,          # base number of convolution channels
    'latent_dims': 64,       # size of the latent vector z
    'variational_beta': 1,   # weight of the KL term in the VAE loss
    'learning_rate': 1e-3,
    'weight_decay': 1e-5,
}

In [None]:
folder_path = "/kaggle/input/masked10k/masked_images/"
batch_size = hparams["batch_size"]

# os.listdir returns entries in arbitrary order, so the listing is sorted:
# together with the fixed seed this makes the shuffle below (and therefore the
# content of every minibatch) identical from run to run.
images = sorted(
    os.path.join(folder_path, im)
    for im in os.listdir(folder_path)
    if im.endswith(".jpg")
)

n_samples = len(images)
print(f"{n_samples} images")

# One position per image in `images`, visited in random order
indices = np.arange(n_samples, dtype="int64")
np.random.shuffle(indices)

# Cut the shuffled indices into full batches of `batch_size` (a trailing
# remainder smaller than one batch is dropped). Each batch is an int64 array of
# positions into `images`, sorted in ascending order.
minibatchlist = [
    np.sort(indices[start_idx : start_idx + batch_size])
    for start_idx in range(0, len(indices) - batch_size + 1, batch_size)
]

In [None]:
# We define the variational autoencoder: its encoder, its decoder and its loss
class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of q(z|x) for an RGB image.

    Four stages of batch normalisation followed by a stride-2 convolution halve
    the spatial size each time. The two linear heads expect an 8 x 8 feature
    map, i.e. a 128 x 128 input image.

    Args:
        channels: number of output channels of the first convolution; the
            following ones use 2x, 3x and 4x this value.
        latent_dims: size of the latent vector.
    """

    def __init__(self, channels: int, latent_dims: int) -> None:
        super().__init__()

        self.c = channels

        # kernel 4 / stride 2 / padding 1 halves height and width at every stage
        self.bnorm1 = nn.BatchNorm2d(3)
        self.conv1 = nn.Conv2d(3, self.c, kernel_size=4, stride=2, padding=1)
        self.bnorm2 = nn.BatchNorm2d(self.c)
        self.conv2 = nn.Conv2d(self.c, 2 * self.c, kernel_size=4, stride=2, padding=1)
        self.bnorm3 = nn.BatchNorm2d(2 * self.c)
        self.conv3 = nn.Conv2d(2 * self.c, 3 * self.c, kernel_size=4, stride=2, padding=1)
        self.bnorm4 = nn.BatchNorm2d(3 * self.c)
        self.conv4 = nn.Conv2d(3 * self.c, 4 * self.c, kernel_size=4, stride=2, padding=1)

        # Two heads on top of the flattened 4c x 8 x 8 feature map
        flat_features = self.c * 4 * 8 * 8
        self.fc_mu = nn.Linear(flat_features, latent_dims)
        self.fc_logvar = nn.Linear(flat_features, latent_dims)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mu, logvar)``, each of shape Batch x latent_dims."""
        h = self.bnorm1(x)
        h = F.relu(self.conv1(h))  # Batch x channels x 64 x 64
        h = self.bnorm2(h)
        h = F.relu(self.conv2(h))  # Batch x 2*channels x 32 x 32
        h = self.bnorm3(h)
        h = F.relu(self.conv3(h))  # Batch x 3*channels x 16 x 16
        h = self.bnorm4(h)
        h = F.relu(self.conv4(h))  # Batch x 4*channels x 8 x 8

        # One feature vector per image
        h = h.view(h.shape[0], -1)

        # Mean and log-variance of the approximate posterior
        return self.fc_mu(h), self.fc_logvar(h)

    
class Decoder(nn.Module):
    """Transposed-convolution decoder turning a latent vector into an RGB image.

    A linear layer expands z to a 4c x 8 x 8 feature map; four stages of batch
    normalisation followed by a stride-2 transposed convolution then double the
    spatial size each time, ending at 3 x 128 x 128 with values in [0, 1].

    Args:
        channels: base channel count ``c`` (must match the encoder's).
        latent_dims: size of the latent vector.
    """

    def __init__(self, channels: int, latent_dims: int) -> None:
        super().__init__()

        self.c = channels
        self.fc = nn.Linear(latent_dims, self.c * 4 * 8 * 8)

        # kernel 4 / stride 2 / padding 1 doubles height and width at every stage
        self.conv4 = nn.ConvTranspose2d(4 * self.c, 3 * self.c, kernel_size=4, stride=2, padding=1)
        self.bnorm4 = nn.BatchNorm2d(4 * self.c)
        self.conv3 = nn.ConvTranspose2d(3 * self.c, 2 * self.c, kernel_size=4, stride=2, padding=1)
        self.bnorm3 = nn.BatchNorm2d(3 * self.c)
        self.conv2 = nn.ConvTranspose2d(self.c * 2, self.c, kernel_size=4, stride=2, padding=1)
        self.bnorm2 = nn.BatchNorm2d(self.c * 2)
        self.conv1 = nn.ConvTranspose2d(self.c, 3, kernel_size=4, stride=2, padding=1)
        self.bnorm1 = nn.BatchNorm2d(self.c)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode a Batch x latent_dims tensor into a Batch x 3 x 128 x 128 image batch."""
        h = self.fc(z)
        # Back from flat vectors to Batch x 4c x 8 x 8 feature maps
        h = h.view(h.shape[0], self.c * 4, 8, 8)

        h = self.bnorm4(h)
        h = F.relu(self.conv4(h))
        h = self.bnorm3(h)
        h = F.relu(self.conv3(h))
        h = self.bnorm2(h)
        h = F.relu(self.conv2(h))
        h = self.bnorm1(h)
        # Sigmoid keeps every pixel in [0, 1], as required by the BCE reconstruction loss
        return torch.sigmoid(self.conv1(h))
    
    
class VariationalAutoencoder(nn.Module):
    """Variational autoencoder built from an ``Encoder`` and a ``Decoder``.

    Args:
        z_dims: size of the latent vector.
        n_ch: base number of convolution channels used by both halves.
    """

    def __init__(self, z_dims: int, n_ch: int) -> None:
        super().__init__()
        self.encoder = Encoder(channels=n_ch, latent_dims=z_dims)
        self.decoder = Decoder(channels=n_ch, latent_dims=z_dims)

    def reparametrize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample z from N(mu, sigma^2) with the reparameterisation trick.

        Instead of sampling from Q(z|X) directly, noise eps ~ N(0, I) is drawn
        and shifted/scaled, which keeps the sample differentiable with respect
        to ``mu`` and ``logvar``.

        Args:
            mu: mean of Q(z|X).
            logvar: log of the variance of Q(z|X).
        """
        # logvar = log(sigma^2), hence sigma = exp(logvar / 2)
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def forward(self, x: torch.Tensor) -> Tuple[
            torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(reconstruction, mu, logvar)`` for the image batch ``x``."""
        mu, logvar = self.encoder(x)
        sample = self.reparametrize(mu, logvar)
        return self.decoder(sample), mu, logvar

    
def vae_loss(
        recon_x: torch.Tensor,
        x: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor,
        variational_beta: float = 1,
        ) -> torch.Tensor:
    """Per-sample VAE loss: reconstruction term plus beta-weighted KL divergence.

    Args:
        recon_x: reconstructed batch with values in [0, 1]; same shape as ``x``.
        x: target batch with values in [0, 1]; the first dimension is the batch.
        mu: mean of Q(z|X).
        logvar: log of the variance of Q(z|X).
        variational_beta: weight applied to the KL term.

    Returns:
        A scalar tensor: (summed binary cross-entropy + beta * KL) divided by
        the batch size.
    """
    # recon_x is read as the parameter of a multivariate Bernoulli distribution,
    # so -log p(x) is the pixel-wise binary cross-entropy. With reduction='sum'
    # the result does not depend on how the tensors are laid out, so no reshape
    # to a hard-coded size is needed and any batch size / image size works.
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # Closed-form KL( N(mu, sigma^2) || N(0, I) ), summed over batch and latent dims
    kldivergence = variational_beta * (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()))

    mean_batch_loss = (recon_loss + kldivergence) / x.shape[0]
    return mean_batch_loss

In [None]:
def train_batch(
        image_batch: torch.Tensor,
        vae: torch.nn.Module,
        vae_loss: Callable[..., torch.Tensor],
        optimizer: torch.optim.Optimizer,
        ) -> float:
    """Run one optimisation step on a batch and return the loss value.

    Args:
        image_batch: batch of input images; it is moved to the model's device.
        vae: model whose forward pass returns ``(reconstruction, mu, logvar)``.
        vae_loss: callable invoked as ``vae_loss(reconstruction, batch, mu, logvar)``
            that returns a scalar tensor.
        optimizer: optimizer holding the parameters of ``vae``.

    Returns:
        The loss of this batch as a Python float.
    """
    # Follow the model's own device instead of relying on a notebook-level global
    model_device = next(vae.parameters()).device
    image_batch = image_batch.to(model_device)

    # Get vae reconstruction and loss
    image_batch_recon, latent_mu, latent_logvar = vae(image_batch)
    loss = vae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar)

    # Backpropagation (gradients are cleared first because PyTorch accumulates them)
    optimizer.zero_grad()
    loss.backward()

    # One step of the optimizer, using the gradients from backpropagation
    optimizer.step()

    return loss.item()

In [None]:
# Preprocessing shared by training and inference: resize to the 128 x 128 input
# expected by the encoder, then convert to a float tensor with values in [0, 1].
# torchvision's InterpolationMode enum is used because passing the PIL integer
# constant to Resize is deprecated.
transf = transforms.Compose([
        transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
    ])

In [None]:
def _load_image_tensor(path):
    """Open the image at ``path``, force RGB and return the tensor produced by ``transf``."""
    # The context manager closes the file handle as soon as the pixels are decoded
    with Image.open(path) as img:
        return transf(img.convert("RGB"))


def _load_batch(batch_indices, max_retries=100):
    """Stack the images at ``batch_indices`` (positions into ``images``) into one batch.

    Best effort: an image that cannot be read is replaced by a randomly drawn
    one so that the batch keeps its size. After ``max_retries`` failed
    replacements for the same slot a RuntimeError is raised instead of looping
    forever.
    """
    tensors = []
    for item in batch_indices:
        path = images[item]
        for _attempt in range(max_retries + 1):
            try:
                tensors.append(_load_image_tensor(path))
                break
            except Exception:
                # Unlike a bare `except:`, this lets KeyboardInterrupt through
                path = images[random.randint(0, len(images) - 1)]
        else:
            raise RuntimeError(
                f"Could not load a replacement image after {max_retries} retries"
            )
    # Single stack instead of repeated torch.cat, which re-copies the batch every time
    return torch.stack(tensors)


start_time = time.time()

# Instantiation of optimizer and model
vae_2z = VariationalAutoencoder(hparams['latent_dims'], hparams['channels']).to(device)
optimizer = torch.optim.Adam(params=vae_2z.parameters(), lr=hparams['learning_rate'], weight_decay=hparams['weight_decay'])

# Number of parameters used in the model
num_params = sum(p.numel() for p in vae_2z.parameters() if p.requires_grad)
print(f'Number of parameters: {num_params}')

# set to training mode
vae_2z.train()

# Tensor -> PIL converter used for the per-epoch preview
trans = torchvision.transforms.ToPILImage()

train_loss_avg = []

print('Training ...')
for epoch in range(hparams['num_epochs']):
    print(f"Epoch {epoch}")
    print(f"Minibatchlist size: {len(minibatchlist)}")

    epoch_loss = 0.0
    num_batches = 0
    for image_batch in tqdm(minibatchlist):
        img_batch = _load_batch(image_batch)
        epoch_loss += train_batch(img_batch, vae_2z, vae_loss, optimizer)
        num_batches += 1

    # Average over the number of batches actually processed (not the last index)
    train_loss_avg.append(epoch_loss / num_batches if num_batches else float('nan'))
    print('Epoch [%d / %d] average reconstruction error: %f' % (epoch+1, hparams['num_epochs'], train_loss_avg[-1]))

    # Preview: reconstruct the first image. Done in eval mode and without autograd
    # so that a single-image batch does not alter the BatchNorm running statistics.
    vae_2z.eval()
    with torch.no_grad():
        inp = _load_image_tensor(images[0]).unsqueeze(0).to(device)
        x, _, _ = vae_2z(inp)
    vae_2z.train()
    display(trans(inp.squeeze()))
    out = trans(x.squeeze())
    display(out)
    print("--- ELAPSED TIME: %s min ---" % (round((time.time() - start_time) / 60, 3)))

print("--- TOTAL TIME: %s min ---" % (round((time.time() - start_time) / 60, 3)))

### We save the trained model

In [None]:
# Compile the trained model to TorchScript so it can be reloaded without the
# Python class definitions, then write it to disk.
m = torch.jit.script(vae_2z)
m.save('model.pt')

### We load the saved models

In [None]:
# map_location loads the weights straight onto the device selected for this
# session, so a model saved on a GPU can also be opened on a CPU-only machine.
# .eval() switches BatchNorm to its running statistics: these models are only
# used for inference from here on.
model_path1 = "/kaggle/input/mod-mask/model_mask.pt"
mod1 = torch.jit.load(model_path1, map_location=device).eval()

model_path2 = "/kaggle/input/modelo/model.pt"
mod2 = torch.jit.load(model_path2, map_location=device).eval()

### We check the distance between the embeddings of some images

In [None]:
folder_path = "/kaggle/input/hackupc/Imatges/"

# os.listdir returns entries in arbitrary order; sorting makes a positional
# index into `images` refer to the same file on every run.
images = sorted(
    os.path.join(folder_path, im)
    for im in os.listdir(folder_path)
    if im.endswith(".jpg")
)

# Tensor -> PIL converter for display, and the distance used to compare latent vectors
trans = torchvision.transforms.ToPILImage()
dist = nn.PairwiseDistance()

def display_vae_reconstr(mod, inp):
    """Display an input image followed by the model's reconstruction of it.

    Args:
        mod: model whose forward pass returns ``(reconstruction, mu, logvar)``.
        inp: image batch holding a single image (1 x C x H x W).
    """
    display(trans(inp.squeeze()))
    # Inference only: skip building the autograd graph
    with torch.no_grad():
        x, _, _ = mod(inp)
    out = trans(x.squeeze())
    display(out)
    
def find_closer(mod, inp, chosen_i, max_candidates=1000):
    """Find the image whose latent code is closest to the one of ``inp``.

    The first ``max_candidates`` entries of ``images`` (or all of them when
    there are fewer) are encoded and compared with ``inp`` using ``dist``.
    The posterior means are compared rather than samples drawn with
    ``reparametrize``: sampling would add fresh random noise to every distance
    and make the result change from call to call.

    Args:
        mod: model exposing an ``encoder`` that returns ``(mu, logvar)``.
        inp: query image batch (1 x C x H x W), already on ``device``.
        chosen_i: position of the query image in ``images``; it is excluded
            from the search.
        max_candidates: upper bound on the number of images scanned.

    Returns:
        ``(min_index, min_dist)``: position of the closest image in ``images``
        and its distance to the query as returned by ``dist``.

    Raises:
        ValueError: if no candidate image could be read.
    """
    n_candidates = min(max_candidates, len(images))
    min_index = None
    min_dist = None
    best = float("inf")

    with torch.no_grad():
        z, _ = mod.encoder(inp)
        for i in range(n_candidates):
            if i == chosen_i:
                continue
            try:
                with Image.open(images[i]) as img:
                    inp2 = transf(img.convert("RGB")).unsqueeze(0).to(device)
            except Exception:
                # Best effort: an unreadable file is simply not a candidate
                continue
            z2, _ = mod.encoder(inp2)
            current = dist(z, z2)
            if current.item() < best:
                best = current.item()
                min_dist = current
                min_index = i

    if min_index is None:
        raise ValueError("find_closer: no readable candidate image to compare with")
    return min_index, min_dist

# Query image: show it together with its reconstruction.
# .convert("RGB") matches the training preprocessing, so grayscale or CMYK
# JPEGs are also turned into the 3-channel input the encoder expects.
chosen_i = 244
inp = transf(Image.open(images[chosen_i]).convert("RGB")).unsqueeze(0).to(device)
display_vae_reconstr(mod2, inp)

# Closest image in latent space according to mod2
i1, _ = find_closer(mod2, inp, chosen_i)

inp_mod1 = transf(Image.open(images[i1]).convert("RGB")).unsqueeze(0).to(device)

display_vae_reconstr(mod2, inp_mod1)

### Saving the embeddings of every image

In [None]:
# One column per image ("embedding_<position + 1>"), one row per latent dimension.
# The posterior mean is stored as the embedding: a sample drawn with
# reparametrize would contain random noise and differ on every run.
embeddings = {}
skipped = []

with torch.no_grad():
    for i, img in enumerate(images):
        try:
            with Image.open(img) as pil_img:
                inp = transf(pil_img.convert("RGB")).unsqueeze(0).to(device)
        except Exception:
            # Best effort: unreadable files are left out, but reported below
            skipped.append(img)
            continue
        latent_mu, latent_logvar = mod2.encoder(inp)
        embeddings["embedding_" + str(i + 1)] = torch.flatten(latent_mu).cpu().numpy()

if skipped:
    print(f"Skipped {len(skipped)} unreadable image(s)")
if not embeddings:
    raise RuntimeError("No embedding could be computed: every image failed to load")

# Build the frame in one go instead of inserting thousands of columns one by one
res = pd.DataFrame(embeddings)
res.to_csv('emb_vae.csv', index=True)