In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import os
import matplotlib.image as img
import tqdm
from vae import Basic_VAE
from encoder import Encoder
from decoder import Decoder
from torch.utils.data import DataLoader, TensorDataset

#### Create Dataset


In [None]:
cats = []
directory = "data/cats"
count = 0

for catpic in os.listdir(directory):
    if count < 1000:
        # read from image and convert to tensor
        im = torch.tensor(img.imread(os.path.join(directory, catpic))).float()
        # permute to (channels, height, width) for conv2d layer
        im = torch.permute(im, (2, 0, 1))
        # normalize to range between -1 and 1
        im = im / 128 - 1
        cats.append(im)
        count += 1
    else:
        break
cats = torch.stack(cats)
print(cats.shape)

#### Test/Training/Validation Split


In [None]:
# Split the data into what we use for testing and not testing
training_test_split = 0.75
training_test_cutoff = int(cats.shape[0] * training_test_split + 1)
random_perm = torch.randperm(cats.shape[0])
not_test_tensor = cats[:training_test_cutoff]
testing_tensor = cats[training_test_cutoff:]
# Split the data into what we use for training and cross validation
training_cv_split = 0.8
training_cv_cutoff = int(not_test_tensor.shape[0] * training_cv_split)
training_tensor = not_test_tensor[:training_cv_cutoff]
cv_tensor = not_test_tensor[training_cv_cutoff:]
print(training_tensor.shape)
print(cv_tensor.shape)
print(testing_tensor.shape)

#### Choose Hyperparameters and Build Model


In [None]:
from vae import Basic_VAE

hidden_dims = [16, 32, 64, 128]
latent_dim = 64
in_dim = 3
model = Basic_VAE(in_dim, hidden_dims, latent_dim)
# encoder = Encoder(in_dim, hidden_dims, latent_dim) for testing
# decoder = Decoder(latent_dim, hidden_dims) for testing

In [None]:
# # Testing the model
model.forward(training_tensor)

# # Code to test out the encoder & decoder
# mu, log_var = encoder.forward(training_tensor)
# print("mu: ", mu.shape)
# print("log_var: ", log_var.shape)

# reconstructed_img = decoder.forward(mu, log_var)

[tensor([[[[ 1.0545e-01, -2.7190e-01, -7.4478e-02,  ...,  2.4903e-01,
             7.9888e-02, -4.2045e-02],
           [ 4.4215e-01, -4.0652e-01, -4.1028e-01,  ...,  5.8318e-01,
             1.5132e-01, -1.7852e-01],
           [ 5.8461e-01, -2.4170e-01,  1.7816e-01,  ...,  2.9364e-01,
             8.4248e-02, -2.6859e-02],
           ...,
           [-2.7778e-02, -5.0840e-02,  4.1793e-01,  ..., -1.4100e-01,
             3.3263e-01, -1.6155e-02],
           [ 6.7371e-03, -4.0904e-02, -2.9629e-01,  ..., -1.2480e-01,
             7.1517e-02,  1.1990e-02],
           [-7.6130e-02, -1.0520e-01,  5.1723e-01,  ...,  2.4237e-01,
             1.1980e-02,  1.1654e-01]],
 
          [[ 1.5131e-01, -2.1218e-01,  1.7692e-01,  ...,  2.0670e-02,
             9.7760e-02,  2.1136e-02],
           [ 3.6103e-01, -3.7454e-01,  5.6006e-01,  ...,  4.8172e-01,
             3.7732e-02,  7.5365e-02],
           [-2.3069e-01,  6.8395e-02, -2.3192e-01,  ..., -1.0070e-01,
            -2.5704e-01, -1.0155e-01],


#### Implement Loss Function


In [9]:
def loss_function(reconstructed_img, input_img, mu, log_var, kld_weight=2):
    # print("reconstructed_img: ", reconstructed_img.shape)
    # print("input_img: ", input_img.shape)
    img_loss = F.mse_loss(reconstructed_img, input_img)
    # article on calculating kl divergence between 2 gaussians:
    # https://medium.com/@outerrencedl/variational-autoencoder-and-a-bit-kl-divergence-with-pytorch-ce04fd55d0d7
    # kld_loss = torch.mean(
    #     torch.sum(-log_var + (log_var.exp() ** 2 + mu**2) / 2 - 1 / 2)
        
    # )
    kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()))
    kld_loss *= kld_weight

    return img_loss + kld_loss

#### Implement Gradient Descent


In [10]:
def gradient_descent(model, loss_func, x, y, xvalid, yvalid, lr=0.1, steps=5000, batch_size=32):
    # only really need x or y, they are the same thing
    optimizer = optim.AdamW(model.parameters(), lr)

    losses = []
    valid_losses = []

    train_dataset = TensorDataset(x, y)
    train_batches = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    valid_dataset = TensorDataset(xvalid, yvalid)
    valid_batches = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

    for _ in tqdm.trange(steps):
        model.train()
        total_loss = 0

        for input_batch, label_batch in train_batches:
            reconstructed_img, mu, log_var = model(input_batch)
            # y is the original image I think? yeah
            loss = loss_func(reconstructed_img, label_batch, mu, log_var)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            optimizer.zero_grad()
        mean_loss = total_loss/len(train_batches)
        losses.append(mean_loss)

        model.eval()
        total_valid_loss = 0
        for input_valid_batch, label_valid_batch in valid_batches:
            reconstructed_img, mu, log_var = model(input_valid_batch)
            valid_loss = loss_func(reconstructed_img, label_valid_batch, mu, log_var)
            total_valid_loss += valid_loss.detach()
        mean_valid_loss = total_valid_loss/len(valid_batches)
        valid_losses.append(mean_valid_loss)

    print(f"Final training loss: {losses[-1]}")

    return losses, valid_losses

In [13]:
losses, valid_losses = gradient_descent(
    model, loss_function, training_tensor, training_tensor, cv_tensor, cv_tensor
)

100%|██████████| 5000/5000 [2:50:00<00:00,  2.04s/it]  

Final training loss: nan





In [14]:
torch.save(model.state_dict(), "saved_model")

In [None]:
# after we've trained, we use the decoder to generate new images
# we assume a normal distribution over our latent space
# so we sample from that distribution and feed it into the decoder