In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import os
import matplotlib.image as img
import tqdm
from vae import Basic_VAE
from encoder import Encoder

#### Create Dataset


In [2]:
cats = []
directory = "data/cats"
count = 0
for catpic in os.listdir(directory):
    if count < 1000:
        # read from image and convert to tensor
        im = torch.tensor(img.imread(os.path.join(directory, catpic))).float()
        # permute to (channels, height, width) for conv2d layer
        im = torch.permute(im, (2, 0, 1))
        # normalize to range between -1 and 1
        im = im / 128 - 1
        cats.append(im)
        count+=1
    else:
        break
cats = torch.stack(cats)
print(cats.shape)

torch.Size([1000, 3, 64, 64])


#### Test/Training/Validation Split


In [3]:
# Split the data into what we use for testing and not testing
traing_test_split = 0.75
training_test_cutoff = int(cats.shape[0] * traing_test_split + 1)
random_perm = torch.randperm(cats.shape[0])
not_test_tensor = cats[:training_test_cutoff]
testing_tensor = cats[training_test_cutoff:]
# Split the data into what we use for training and cross validation
training_cv_split = 0.8
training_cv_cutoff = int(not_test_tensor.shape[0] * training_cv_split)
training_tensor = not_test_tensor[:training_cv_cutoff]
cv_tensor = not_test_tensor[training_cv_cutoff:]
print(training_tensor.shape)
print(cv_tensor.shape)
print(testing_tensor.shape)

torch.Size([600, 3, 64, 64])
torch.Size([151, 3, 64, 64])
torch.Size([249, 3, 64, 64])


#### Choose Hyperparameters and Build Model


In [4]:
from vae import Basic_VAE
from encoder import Encoder
hidden_dims = [16, 32, 64, 128]
latent_dim = 64
in_dim = 3
model = Basic_VAE(in_dim, hidden_dims, latent_dim)
encoder = Encoder(in_dim, hidden_dims, latent_dim)

[128, 64, 32, 16]


In [5]:
# Testing the model
# model.forward(training_tensor)
encoder.forward(training_tensor)

encoded.shape: torch.Size([600, 128, 3, 3])
encoded.shape after flatten: torch.Size([600, 1152])


(tensor([[-1.2788e-01,  3.0896e-01, -9.9229e-02,  ...,  4.6898e-01,
           3.9784e-04, -1.1673e-01],
         [-1.4930e-01,  7.1223e-01, -2.2759e-01,  ..., -6.2941e-03,
          -6.0940e-02, -2.9898e-01],
         [-3.9452e-01,  4.5149e-02, -3.1154e-02,  ...,  9.3408e-01,
           4.0722e-01,  8.8972e-01],
         ...,
         [ 2.1063e-02,  1.0916e-01, -2.0865e-01,  ...,  4.0393e-01,
           3.8704e-01,  1.6069e-01],
         [ 1.6254e-01,  3.4640e-01,  2.8482e-03,  ...,  8.3765e-02,
           4.7057e-01,  4.0400e-01],
         [-3.2551e-01,  5.2595e-02, -1.6702e-01,  ..., -6.3182e-02,
          -7.7393e-02,  5.2902e-01]], grad_fn=<AddmmBackward0>),
 tensor([[-0.0331, -0.1212,  0.6238,  ..., -0.4601, -0.0491, -0.0831],
         [-0.3641, -0.2528,  0.0996,  ..., -0.7813,  0.1304, -0.4069],
         [ 0.0641, -0.6217, -0.2455,  ...,  0.0525,  0.0592, -0.9320],
         ...,
         [-0.3655,  0.0939,  0.0384,  ..., -0.2729,  0.1941, -0.2859],
         [-0.5290,  0.1035,  0

#### Implement Loss Function

In [None]:
def loss_function(self, reconstructed_img, input_img, mu, log_var, kld_weight=2):
    img_loss = F.mse_loss(reconstructed_img, input_img)
    # article on calculating kl divergence between 2 gaussians:
    # https://medium.com/@outerrencedl/variational-autoencoder-and-a-bit-kl-divergence-with-pytorch-ce04fd55d0d7
    kld_loss = torch.mean(
        torch.sum(-log_var + (log_var.exp() ** 2 + mu**2) / 2 - 1 / 2)
    )
    kld_loss *= kld_weight

    return img_loss + kld_loss

#### Implement Gradient Descent


In [None]:
def gradient_descent(model, loss_func, x, y, xvalid, yvalid, lr=0.1, steps=5000):
    optimizer = optim.AdamW(model.parameters(), lr)

    losses = []
    valid_losses = []
    for _ in tqdm.trange(steps):
        model.train()
        loss = loss_func(model(x), y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        model.eval()
        valid_loss = loss_func(model(xvalid), yvalid)
        losses.append(loss.detach().numpy())
        valid_losses.append(valid_loss.detach().numpy())

    print(f"Final training loss: {losses[-1]}")

    return losses, valid_losses