In [1]:
import csv
import itertools
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import utils
from evaluate import evaluate_model
from VAD_Trainer import VAD_Trainer
from VariationalAutoDecoder import VariationalAutoDecoder as VAD

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(device)

cuda


## Create DataLoaders

In [2]:
# utils.create_dataloaders yields (train_ds, train_dl, test_ds, test_dl) for the data under "dataset".
train_ds, train_dl, test_ds, test_dl = utils.create_dataloaders(
    data_path="dataset",
    batch_size=64,
)

## Train Variational Auto-Decoders (sweep over latent dimension × beta)

In [3]:
# Hyperparameter grid: one VAD and one trainer per (latent_dim, beta) pair.
latent_dims = [16, 32, 64, 128]
betas = [1e5, 5e5, 1e6, 5e6]

# Materialize the grid once so VADs[i] and trainers[i] are guaranteed to share
# the same (latent_dim, beta) configuration (previously the product was rebuilt
# for each list and aligned only by position).
configs = list(itertools.product(latent_dims, betas))

VADs = [VAD(latent_dim=dim, device=device) for dim, _ in configs]
trainers = [
    VAD_Trainer(var_decoder=vad, dataloader=train_dl, latent_dim=dim, beta=beta, device=device, lr=1e-3)
    for vad, (dim, beta) in zip(VADs, configs)
]

In [None]:
# Train every VAD in the sweep, then run evaluate_model on its training latents
# and log per-epoch training losses plus the final train-evaluation loss to a CSV
# (one row per trainer).
NUM_EPOCHS = 500  # used for training, evaluation, and the CSV header width
EVAL_LR = 1e-3
csv_file_path = 'results_VAD.csv'

with open(csv_file_path, mode='w', newline='') as file:
    writer = csv.writer(file)
    # NOTE(review): assumes trainer.train returns one loss per epoch, so the
    # header has NUM_EPOCHS loss columns — confirm against VAD_Trainer.
    header = ['Index'] + [f'Epoch {i+1} Loss' for i in range(NUM_EPOCHS)] + ['Final Train Loss']
    writer.writerow(header)

for index, trainer in enumerate(trainers):
    start_time = time.perf_counter()
    train_loss = trainer.train(num_epochs=NUM_EPOCHS)
    elapsed_time = time.perf_counter() - start_time
    print(f"Trainer {index} has finished training in {elapsed_time:.2f} seconds.")

    # Build the evaluation optimizer only after training, so it is bound to the
    # tensor trainer.latents refers to at evaluation time.
    optimizer = optim.Adam([trainer.latents], lr=EVAL_LR)

    start_time = time.perf_counter()
    train_eval_loss = evaluate_model(model=VADs[index], test_dl=train_dl, opt=optimizer, latents=trainer.latents, epochs=NUM_EPOCHS, device=device)
    elapsed_time = time.perf_counter() - start_time
    print(f"AD {index} has finished train evaluation in {elapsed_time:.2f} seconds.")

    # Unpacking accepts any iterable of per-epoch losses (list, tuple, ...).
    row = [index, *train_loss, train_eval_loss]

    # Re-open in append mode per row so finished results survive a crash mid-sweep.
    with open(csv_file_path, mode='a', newline='') as file:
        csv.writer(file).writerow(row)

print(f"Results saved to {csv_file_path}.")

In [5]:
# t-SNE plot of each sweep model's reparameterized training latents.
for idx, (vad, trainer) in enumerate(zip(VADs, trainers)):
    latents = vad.reparameterize(trainer.latents)
    utils.plot_tsne(train_ds, latents, f"tsne_{idx}")

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

<Figure size 800x600 with 0 Axes>

## Best model

In [3]:
# Retrain the selected configuration for 1000 epochs. The latent dimension is
# defined once so the model and its trainer cannot disagree.
best_latent_dim = 128
best_beta = 1e6
VAD_best = VAD(latent_dim=best_latent_dim, device=device)
trainer_best = VAD_Trainer(var_decoder=VAD_best, dataloader=train_dl, latent_dim=best_latent_dim, beta=best_beta, device=device, lr=1e-3)
_ = trainer_best.train(num_epochs=1000)

Epoch [1/1000], Loss: 154753928.0000
Epoch [2/1000], Loss: 150863846.0000
Epoch [3/1000], Loss: 148002060.0000
Epoch [4/1000], Loss: 145408609.0000
Epoch [5/1000], Loss: 142970976.0000
Epoch [6/1000], Loss: 140679386.0000
Epoch [7/1000], Loss: 138511162.0000
Epoch [8/1000], Loss: 136502866.5000
Epoch [9/1000], Loss: 134545313.0000
Epoch [10/1000], Loss: 132728968.5000
Epoch [11/1000], Loss: 131003507.5000
Epoch [12/1000], Loss: 129384099.5000
Epoch [13/1000], Loss: 127802410.0000
Epoch [14/1000], Loss: 126322324.0000
Epoch [15/1000], Loss: 124925971.0000
Epoch [16/1000], Loss: 123510945.0000
Epoch [17/1000], Loss: 122186918.0000
Epoch [18/1000], Loss: 120894334.5000
Epoch [19/1000], Loss: 119693608.5000
Epoch [20/1000], Loss: 118491207.5000
Epoch [21/1000], Loss: 117327397.5000
Epoch [22/1000], Loss: 116145090.5000
Epoch [23/1000], Loss: 114978266.0000
Epoch [24/1000], Loss: 113842454.5000
Epoch [25/1000], Loss: 112697620.0000
Epoch [26/1000], Loss: 111688411.0000
Epoch [27/1000], Loss

In [4]:
# Run evaluate_model on the best model's training latents (with a fresh Adam
# optimizer over them) and report the resulting loss.
train_eval_opt = optim.Adam([trainer_best.latents], lr=1e-3)
evaluate_loss = evaluate_model(model=VAD_best, test_dl=train_dl, opt=train_eval_opt, latents=trainer_best.latents, epochs=1000, device=device)
print(evaluate_loss)

0.21835661493241787


In [5]:
# t-SNE of the best model's reparameterized training latents.
latents = VAD_best.reparameterize(trainer_best.latents)
utils.plot_tsne(train_ds, latents, "tsne_best")

<Figure size 800x600 with 0 Axes>

## Test-set latents and sampled images

In [6]:
# One learnable latent per test sample, stacked as [mu, sigma] along dim 1,
# plus an Adam optimizer over them.
# NOTE(review): sigma is initialized from randn (may be negative); confirm this
# matches how VAD.reparameterize interprets the second component.
num_test_samples = len(test_dl.dataset)
mu_test = torch.randn(num_test_samples, VAD_best.latent_dim, device=device)
sigma_test = torch.randn(num_test_samples, VAD_best.latent_dim, device=device)
# Build the Parameter directly from on-device data. Calling .to(device) after
# wrapping would return a non-leaf copy (which Adam rejects) whenever the device
# actually changed.
test_latents = nn.Parameter(torch.stack([mu_test, sigma_test], dim=1))
opt = optim.Adam([test_latents], lr=1e-3)

In [7]:
# Run evaluate_model on the test set with the optimizer over test_latents.
test_loss = evaluate_model(
    model=VAD_best,
    test_dl=test_dl,
    opt=opt,
    latents=test_latents,
    epochs=1000,
    device=device,
)
print(f"AD has finished test evaluation with a test loss of {test_loss}.")

AD has finished test evaluation with a test loss of 0.2232640078291297.


In [8]:
# Reparameterize the fitted test latents and plot their t-SNE embedding.
final_test_latents = VAD_best.reparameterize(test_latents)
utils.plot_tsne(test_ds, final_test_latents, "tsne_test")

<Figure size 800x600 with 0 Axes>

In [9]:
# Decode images for randomly chosen test latents and for latents drawn from a
# standard normal, then save each set of 28x28 images to a PNG.
num_samples = 5
random.seed(49)
# Sample over all optimized test latents instead of a hard-coded range(1000),
# which could index past the end of a smaller test set or skip most of a larger one.
sampled_indices = random.sample(range(len(test_latents)), num_samples)
# random.seed does not affect torch's RNG; use a locally seeded generator so the
# random latents are reproducible without changing torch's global RNG state.
torch_gen = torch.Generator(device=device).manual_seed(49)
random_latents_tensor = torch.randn(num_samples, VAD_best.latent_dim, device=device, generator=torch_gen)

with torch.no_grad():  # inference only; no autograd graph needed
    sampled_test_images = VAD_best(test_latents[sampled_indices]).view(-1, 1, 28, 28)
    random_test_images = VAD_best.decoder(random_latents_tensor).view(-1, 1, 28, 28)

utils.save_images(sampled_test_images, "sampled_test_images_VAD.png")
utils.save_images(random_test_images, "random_test_images_VAD.png")

## Interpolation

In [10]:
# Linearly interpolate between the reparameterized latents of two test samples
# and decode 7 evenly spaced interpolants into 28x28 images.
sampled_indices = [1, 25]
sampled_latents = [final_test_latents[i] for i in sampled_indices]
# Torch weights on the latents' own device (no numpy/torch scalar mixing).
# w=0 reproduces the first sample and w=1 the second, so the saved strip runs
# in the order of sampled_indices.
weights = torch.linspace(0, 1, 7, device=final_test_latents.device)
interpolated_latents = [(1 - w) * sampled_latents[0] + w * sampled_latents[1] for w in weights]
interpolated_latents_tensor = torch.stack(interpolated_latents)
with torch.no_grad():  # inference only
    interpolated_images = VAD_best.decoder(interpolated_latents_tensor).view(-1, 1, 28, 28)
utils.save_images(interpolated_images, "interpolated_images_VAD.png")