VAD Part

In [1]:
from utils import create_dataloaders
from evaluate import reconstruction_loss, evaluate_model
from VAD_Expo import VariationalAutoDecoder_Expo
from trainer import VADTrainer_Expo

In [2]:
train_ds, train_dl, test_ds, test_dl = create_dataloaders(data_path='', batch_size=64)


In [3]:
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

latent_dim = 128

model = VariationalAutoDecoder_Expo(latent_dim, device=device)

trainer = VADTrainer_Expo(model=model, dl=train_dl, latent_dim=model.latent_dim, device=device)

trainer.train(num_epochs=800)

Using device: cuda
Epoch [1/800], Loss: 71752.6287
Epoch [2/800], Loss: 28058.2604
Epoch [3/800], Loss: 24743.1710
Epoch [4/800], Loss: 22843.4323
Epoch [5/800], Loss: 21441.1523
Epoch [6/800], Loss: 20299.2112
Epoch [7/800], Loss: 19327.9678
Epoch [8/800], Loss: 18482.1697
Epoch [9/800], Loss: 17733.9772
Epoch [10/800], Loss: 17064.2755
Epoch [11/800], Loss: 16459.1434
Epoch [12/800], Loss: 15908.0663
Epoch [13/800], Loss: 15402.8909
Epoch [14/800], Loss: 14937.1625
Epoch [15/800], Loss: 14505.6815
Epoch [16/800], Loss: 14104.2043
Epoch [17/800], Loss: 13729.2151
Epoch [18/800], Loss: 13377.7745
Epoch [19/800], Loss: 13047.3892
Epoch [20/800], Loss: 12735.9493
Epoch [21/800], Loss: 12441.6302
Epoch [22/800], Loss: 12162.8517
Epoch [23/800], Loss: 11898.2449
Epoch [24/800], Loss: 11646.6062
Epoch [25/800], Loss: 11406.8722
Epoch [26/800], Loss: 11178.1040
Epoch [27/800], Loss: 10959.4708
Epoch [28/800], Loss: 10750.2231
Epoch [29/800], Loss: 10549.6880
Epoch [30/800], Loss: 10357.2714


[71752.62866210938,
 28058.260375976562,
 24743.171020507812,
 22843.432250976562,
 21441.152282714844,
 20299.211181640625,
 19327.967834472656,
 18482.16973876953,
 17733.97723388672,
 17064.275451660156,
 16459.143432617188,
 15908.066345214844,
 15402.890930175781,
 14937.162475585938,
 14505.681457519531,
 14104.204345703125,
 13729.215148925781,
 13377.774475097656,
 13047.389221191406,
 12735.949340820312,
 12441.630249023438,
 12162.85171508789,
 11898.244903564453,
 11646.606201171875,
 11406.87222290039,
 11178.104034423828,
 10959.470825195312,
 10750.22314453125,
 10549.688018798828,
 10357.271423339844,
 10172.42495727539,
 9994.653411865234,
 9823.514556884766,
 9658.594635009766,
 9499.523406982422,
 9345.961181640625,
 9197.589965820312,
 9054.121795654297,
 8915.293579101562,
 8780.853118896484,
 8650.577392578125,
 8524.258239746094,
 8401.690185546875,
 8282.699127197266,
 8167.106872558594,
 8054.763854980469,
 7945.511993408203,
 7839.220458984375,
 7735.7520751953

In [4]:
# Evaluation on the train dataset
latents = torch.randn((len(test_dl.dataset), model.latent_dim)).to(device)
train_latents = torch.nn.parameter.Parameter(latents).to(device)
train_latents = train_latents.detach().requires_grad_()
opt = torch.optim.Adam([train_latents], lr=1e-3)
train_loss = evaluate_model(model=model, test_dl=train_dl, opt=opt, latents=train_latents, epochs=800, device=device)
print(f"AD has finished test evaluation with a train loss of {train_loss}.")


# Evaluation on the test dataset
latents = torch.randn((len(test_dl.dataset), model.latent_dim)).to(device)

# mu_test = torch.randn(len(test_dl.dataset), latent_dim, device=device, requires_grad=True)
# sigma_test = torch.randn(len(test_dl.dataset), latent_dim, device=device, requires_grad=True)
test_latents = torch.nn.parameter.Parameter(latents).to(device)
# test_latents = test_latents.detach().requires_grad_()

opt = torch.optim.Adam([test_latents], lr=1e-3)
test_loss = evaluate_model(model=model, test_dl=test_dl, opt=opt, latents=test_latents, epochs=800, device=device)
print(f"AD has finished test evaluation with a test loss of {test_loss}.")

AD has finished test evaluation with a train loss of 0.33976610004901886.
AD has finished test evaluation with a test loss of 0.3378761559724808.


In [5]:
# Randomly sample 5 indices from the test dataset
import random
import utils

random.seed(6)
sampled_indices = random.sample(range(len(test_latents)), 5)

# Extract the corresponding vectors (input data) and their labels
sampled_latents = [test_latents[i] for i in sampled_indices]  # Only selecting input data, not labels

# Convert to a single tensor (optional)
sampled_latents_tensor = torch.stack(sampled_latents)
random_latents_tensor = torch.randn([5,128], device=device)

print("Sampled Vectors Shape:", sampled_latents_tensor.shape)  # Should be (5, *) depending on your data shape
print("Random Vectors Shape:", random_latents_tensor.shape)  # Should be (5, *) depending on your data shape

sampled_test_images = model(sampled_latents_tensor).view(-1, 1, 28, 28)
random_test_images = model.david_forward(random_latents_tensor).view(-1, 1, 28, 28)

print("Sampled Images Shape:", sampled_test_images.shape)  # Should be (5, *) depending on your data shape
utils.save_images(sampled_test_images, "sampled_test_images.png")
utils.save_images(random_test_images, "random_test_images.png")
# import matplotlib.pyplot as plt
# plt.imshow(sampled_test_images.detach().cpu().numpy().reshape(28, 28), cmap='gray')
# plt.show()

Sampled Vectors Shape: torch.Size([5, 128])
Random Vectors Shape: torch.Size([5, 128])
Sampled Images Shape: torch.Size([5, 1, 28, 28])


In [6]:
from utils import plot_tsne

result_test_latents = model.reparameterize(test_latents)
utils.plot_tsne(test_ds, result_test_latents, f"tsne_test")

<Figure size 800x600 with 0 Axes>

In [7]:
# #INTERPOLATION

#INTERPOLATION CODE !!

# need to check beforehand that the picked pictures will be of diffrenet classes
sampled_latents = [result_test_latents[18],  result_test_latents[69]]

#they asked for 5 differenet equally distrubed values, can probably do this using np but whatever
weights = [0, 0.25, 0.5, 0.75, 1]

interpolated_latents = [w * sampled_latents[0] + (1 - w) * sampled_latents[1] for w in weights]
interpolated_latents_tensor = torch.stack(interpolated_latents)
interpolated_images = model.decoder(interpolated_latents_tensor).view(-1, 1, 28, 28)


utils.save_images(interpolated_images, "interpolted_image_normal_dist.png")