In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import os
import matplotlib.image as img
import tqdm
from vae import Basic_VAE
from encoder import Encoder
from decoder import Decoder

#### Create Dataset


In [2]:
cats = []
directory = "data/cats"
count = 0
for catpic in os.listdir(directory):
    if count < 1000:
        # read from image and convert to tensor
        im = torch.tensor(img.imread(os.path.join(directory, catpic))).float()
        # permute to (channels, height, width) for conv2d layer
        im = torch.permute(im, (2, 0, 1))
        # normalize to range between -1 and 1
        im = im / 128 - 1
        cats.append(im)
        count += 1
    else:
        break
cats = torch.stack(cats)
print(cats.shape)

torch.Size([1000, 3, 64, 64])


#### Test/Training/Validation Split


In [3]:
# Split the data into what we use for testing and not testing
training_test_split = 0.75
training_test_cutoff = int(cats.shape[0] * training_test_split + 1)
random_perm = torch.randperm(cats.shape[0])
not_test_tensor = cats[:training_test_cutoff]
testing_tensor = cats[training_test_cutoff:]
# Split the data into what we use for training and cross validation
training_cv_split = 0.8
training_cv_cutoff = int(not_test_tensor.shape[0] * training_cv_split)
training_tensor = not_test_tensor[:training_cv_cutoff]
cv_tensor = not_test_tensor[training_cv_cutoff:]
print(training_tensor.shape)
print(cv_tensor.shape)
print(testing_tensor.shape)

torch.Size([600, 3, 64, 64])
torch.Size([151, 3, 64, 64])
torch.Size([249, 3, 64, 64])


#### Choose Hyperparameters and Build Model


In [4]:
from vae import Basic_VAE

hidden_dims = [16, 32, 64, 128]
latent_dim = 64
in_dim = 3
model = Basic_VAE(in_dim, hidden_dims, latent_dim)
# encoder = Encoder(in_dim, hidden_dims, latent_dim) for testing
# decoder = Decoder(latent_dim, hidden_dims) for testing

[128, 64, 32, 16]


In [5]:
# # Testing the model
model.forward(training_tensor)

# # Code to test out the encoder & decoder
# mu, log_var = encoder.forward(training_tensor)
# print("mu: ", mu.shape)
# print("log_var: ", log_var.shape)

# reconstructed_img = decoder.forward(mu, log_var)

encoded.shape: torch.Size([600, 128, 3, 3])
encoded.shape after flatten: torch.Size([600, 1152])


[tensor([[[[-1.5987e-01,  3.4740e-01,  1.1799e-01,  ...,  3.5794e-01,
            -1.3262e-01,  2.5225e-01],
           [ 5.2806e-01,  9.0487e-01,  4.3923e-02,  ...,  1.9971e-01,
             1.6595e-01, -1.8641e-03],
           [ 2.3152e-01,  6.6951e-01, -6.6842e-01,  ...,  9.9181e-02,
            -3.0802e-01, -6.8947e-02],
           ...,
           [-1.8711e-01, -2.2090e-01,  1.9113e-01,  ...,  4.0890e-02,
             1.7868e-01, -3.3205e-02],
           [ 7.4919e-02, -1.2330e-01, -7.1301e-01,  ...,  1.1409e-01,
             2.3821e-01,  1.0046e-01],
           [-1.1839e-01, -1.1355e-01, -2.9124e-01,  ..., -6.8516e-02,
            -1.6792e-02,  6.4432e-02]],
 
          [[-1.1686e-01, -6.5236e-02,  8.2952e-01,  ...,  9.0968e-02,
            -2.6019e-01, -2.2322e-01],
           [ 2.2235e-01,  5.2316e-01, -7.5258e-02,  ..., -5.8584e-01,
            -7.6346e-02,  2.2641e-01],
           [ 6.9051e-01, -7.8847e-01,  9.0353e-02,  ...,  4.6588e-01,
             6.1398e-02, -3.0085e-01],


#### Implement Loss Function


In [6]:
def loss_function(reconstructed_img, input_img, mu, log_var, kld_weight=2):
    print("reconstructed_img: ", reconstructed_img.shape)
    print("input_img: ", input_img.shape)
    img_loss = F.mse_loss(reconstructed_img, input_img)
    # article on calculating kl divergence between 2 gaussians:
    # https://medium.com/@outerrencedl/variational-autoencoder-and-a-bit-kl-divergence-with-pytorch-ce04fd55d0d7
    kld_loss = torch.mean(
        torch.sum(-log_var + (log_var.exp() ** 2 + mu**2) / 2 - 1 / 2)
    )
    kld_loss *= kld_weight

    return img_loss + kld_loss

#### Implement Gradient Descent


In [7]:
def gradient_descent(model, loss_func, x, y, xvalid, yvalid, lr=0.1, steps=5000):
    # only really need x or y, they are the same thing
    optimizer = optim.AdamW(model.parameters(), lr)

    losses = []
    valid_losses = []
    for _ in tqdm.trange(steps):
        model.train()
        reconstructed_img, mu, log_var = model(x)
        # y is the original image I think?
        loss = loss_func(reconstructed_img, y, mu, log_var)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        model.eval()
        reconstructed_img, mu, log_var = model(xvalid)
        valid_loss = loss_func(reconstructed_img, yvalid, mu, log_var)
        losses.append(loss.detach().numpy())
        valid_losses.append(valid_loss.detach().numpy())

    print(f"Final training loss: {losses[-1]}")

    return losses, valid_losses

In [8]:
losses, valid_losses = gradient_descent(
    model, loss_function, training_tensor, training_tensor, cv_tensor, cv_tensor
)

  0%|          | 0/5000 [00:00<?, ?it/s]

encoded.shape: torch.Size([600, 128, 3, 3])
encoded.shape after flatten: torch.Size([600, 1152])
reconstructed_img:  torch.Size([600, 3, 64, 64])
input_img:  torch.Size([600, 3, 64, 64])
encoded.shape: torch.Size([151, 128, 3, 3])
encoded.shape after flatten: torch.Size([151, 1152])


  0%|          | 1/5000 [00:06<9:10:45,  6.61s/it]

reconstructed_img:  torch.Size([151, 3, 64, 64])
input_img:  torch.Size([151, 3, 64, 64])
encoded.shape: torch.Size([600, 128, 3, 3])
encoded.shape after flatten: torch.Size([600, 1152])
reconstructed_img:  torch.Size([600, 3, 64, 64])
input_img:  torch.Size([600, 3, 64, 64])
encoded.shape: torch.Size([151, 128, 3, 3])
encoded.shape after flatten: torch.Size([151, 1152])


  0%|          | 2/5000 [00:12<8:39:02,  6.23s/it]

reconstructed_img:  torch.Size([151, 3, 64, 64])
input_img:  torch.Size([151, 3, 64, 64])
encoded.shape: torch.Size([600, 128, 3, 3])
encoded.shape after flatten: torch.Size([600, 1152])
reconstructed_img:  torch.Size([600, 3, 64, 64])
input_img:  torch.Size([600, 3, 64, 64])
encoded.shape: torch.Size([151, 128, 3, 3])
encoded.shape after flatten: torch.Size([151, 1152])
reconstructed_img:  torch.Size([151, 3, 64, 64])
input_img:  torch.Size([151, 3, 64, 64])


  0%|          | 3/5000 [00:15<6:43:33,  4.85s/it]

encoded.shape: torch.Size([600, 128, 3, 3])
encoded.shape after flatten: torch.Size([600, 1152])
reconstructed_img:  torch.Size([600, 3, 64, 64])
input_img:  torch.Size([600, 3, 64, 64])
encoded.shape: torch.Size([151, 128, 3, 3])
encoded.shape after flatten: torch.Size([151, 1152])


  0%|          | 4/5000 [00:18<5:42:26,  4.11s/it]

reconstructed_img:  torch.Size([151, 3, 64, 64])
input_img:  torch.Size([151, 3, 64, 64])
encoded.shape: torch.Size([600, 128, 3, 3])
encoded.shape after flatten: torch.Size([600, 1152])


  0%|          | 4/5000 [00:20<6:57:08,  5.01s/it]


KeyboardInterrupt: 

In [None]:
# after we've trained, we use the decoder to generate new images
# we assume a normal distribution over our latent space
# so we sample from that distribution and feed it into the decoder