# Coursework 2: Generative Models

## Instructions

### Submission
Please submit one zip file on cate - *CW2.zip* containing the following:
1. A version of this notebook containing your answers. Write your answers in the cells below each question. **Please deliver the notebook including the outputs of the cells**
2. Your trained VAE model as *VAE_model.pth*
3. Your trained Generator and Discriminator: *DCGAN_model_D.pth and DCGAN_model_G.pth*


### Training
Training the GAN will take quite a long time (multiple hours). Please refer to the 4 GPU options detailed in the logistics lecture. Some additional useful pointers:
* PaperSpace [guide if you need more compute](https://hackmd.io/@afspies/S1stL8Qnt)
* Lab GPUs via SSH. The VSCode Remote Development extension is recommended for this. For general Imperial remote working instructions see [this post](https://www.doc.ic.ac.uk/~nuric/teaching/remote-working-for-imperial-computing-students.html). You'll also want to [set up your environment as outlined here](https://hackmd.io/@afspies/Bkd7Zq60K).
* Use Colab and add checkpointing to the model training code; this handles the case where Colab stops a free-GPU kernel after a certain number of hours (~4).
* Use Colab Pro - If you do not wish to use PaperSpace then you can pay for Colab Pro. We cannot pay for this on your behalf (this is Google's fault).
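
The checkpointing suggested above could look roughly like the sketch below. It is only an illustration: the helper names `save_checkpoint` and `load_checkpoint`, the checkpoint file name, and the toy `nn.Linear` model are assumptions made for this example and are not part of the coursework code.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, epoch):
    """Store everything needed to resume training after `epoch`."""
    torch.save(
        {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        },
        path,
    )


def load_checkpoint(path, model, optimizer, device="cpu"):
    """Restore a checkpoint if one exists; return the epoch to start from."""
    if not os.path.exists(path):
        return 0  # no checkpoint yet, start from scratch
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"] + 1


# Toy stand-ins (hypothetical) for the real model and optimizer
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")

start_epoch = load_checkpoint(ckpt_path, model, optimizer)
for epoch in range(start_epoch, 3):
    # ... one epoch of training would go here ...
    save_checkpoint(ckpt_path, model, optimizer, epoch)

# Simulate a restarted kernel: training would resume after the last saved epoch
resumed_epoch = load_checkpoint(ckpt_path, model, optimizer)
```

On Colab, pointing the checkpoint path at the mounted Google Drive folder means the file survives a kernel reset.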


### Testing
TAs will run a testing cell (at the end of this notebook), so you must copy your data ```transform``` and ```denorm``` functions to the demarcated cell near the bottom of the document. You are advised to check that your implementations pass these tests (in particular, the JIT saving and loading may not work for certain niche functions).
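
A quick way to check the JIT round trip before submitting is sketched below. The small `nn.Sequential` model and the file name are placeholders chosen for this example; the idea is simply to trace, save, reload, and compare outputs on the same input. Note that tracing only records the operations executed for the example input, so data-dependent control flow is not captured.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Placeholder model (hypothetical); substitute your trained network
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU()).eval()
example = torch.randn(2, 8)

# Trace the model with an example input of the right shape
with torch.no_grad():
    traced = torch.jit.trace(model, example, check_trace=False)

# Save the traced graph, then load it back as the testing cell would
path = os.path.join(tempfile.mkdtemp(), "toy_model.pth")
torch.jit.save(traced, path)
reloaded = torch.jit.load(path, map_location="cpu")

# The reloaded model should reproduce the original outputs
with torch.no_grad():
    same = torch.allclose(model(example), reloaded(example))
```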

### General
Feel free to add architectural alterations / custom functions outside of the pre-defined code blocks, but if you manipulate the model's inputs in some way, please include the same code in the TA test cell so that our tests run easily.

<font color="orange">**The deadline for submission is Monday, 26 Feb by 6 pm** </font>

## Setting up working environment
You will need to install pytorch and import some utilities by running the following cell:

In [23]:
!pip install -q torch torchvision altair==4.2.2 seaborn tqdm
!git clone -q https://github.com/afspies/icl_dl_cw2_utils
from icl_dl_cw2_utils.utils.plotting import plot_tsne
from pathlib import Path
import tqdm

fatal: destination path 'icl_dl_cw2_utils' already exists and is not an empty directory.


Here we have some default pathing options which vary depending on the environment you are using. You can of course change these as you please.

In [24]:
# Initialization Cell
WORKING_ENV = 'COLAB' # Can be LABS, COLAB, PAPERSPACE, SAGEMAKER
USERNAME = '' # If working on Lab Machines - Your college username
assert WORKING_ENV in ['LABS', 'COLAB', 'PAPERSPACE', 'SAGEMAKER']

if WORKING_ENV == 'COLAB':
    from google.colab import drive
    %load_ext google.colab.data_table
    dl_cw2_repo_path = 'dl_cw2/' # path in your gdrive to the repo
    content_path = f'/content/drive/MyDrive/{dl_cw2_repo_path}' # path to gitrepo in gdrive after mounting
    data_path = './data/' # save the data locally
    drive.mount('/content/drive/') # Outputs will be saved in your google drive

elif WORKING_ENV == 'LABS':
    content_path = f'/vol/bitbucket/{USERNAME}/dl/dl_cw2/' # You may want to change this
    data_path = f'/vol/bitbucket/{USERNAME}/dl/'
    # Your python env and training data should be on bitbucket
    if 'vol/bitbucket' not in content_path or 'vol/bitbucket' not in data_path:
        import warnings
        warnings.warn(
           'It is best to create a dir in /vol/bitbucket/ otherwise you will quickly run into memory issues'
           )
elif WORKING_ENV == 'PAPERSPACE': # Using Paperspace
    # Paperspace does not properly render animated progress bars
    # Strongly recommend using the JupyterLab UI instead of theirs
    !pip install ipywidgets
    content_path = '/notebooks/'
    data_path = './data/'

elif WORKING_ENV == 'SAGEMAKER':
    content_path = '/home/studio-lab-user/sagemaker-studiolab-notebooks/dl/'
    data_path = f'{content_path}data/'

else:
    raise NotImplementedError()

content_path = Path(content_path)

The google.colab.data_table extension is already loaded. To reload it, use:
  %reload_ext google.colab.data_table
Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


## Introduction

For this coursework, you are asked to implement two commonly used generative models:
1. A **Variational Autoencoder (VAE)**
2. A **Deep Convolutional Generative Adversarial Network (DCGAN)**

For the first part you will use the [MNIST dataset](https://en.wikipedia.org/wiki/MNIST_database) and for the second the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html).

Each part is worth 50 points.

The emphasis of both parts lies in understanding how the models behave and learn; however, some points will be available for getting good results with your GAN (though you should not spend too long on this).

# Part 1 - Variational Autoencoder

## Part 1.1 (25 points)
**Your Task:**

a. Implement the VAE architecture with accompanying hyperparameters. More marks are awarded for using a Convolutional Encoder and Decoder.

b. Design an appropriate loss function and train the model.


In [25]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, sampler
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import torch.nn.functional as F
import matplotlib.pyplot as plt

def show(img):
    npimg = img.cpu().numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)))

if not os.path.exists(content_path/'CW_VAE/'):
    os.makedirs(content_path/'CW_VAE/')

if not os.path.exists(data_path):
    os.makedirs(data_path)

# We set a random seed to ensure that your results are reproducible.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
torch.manual_seed(0)

GPU = True # Choose whether to use GPU
if GPU:
    device = torch.device("cuda"  if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cpu")
print(f'Using {device}')

Using cuda


---
## Part 1.1a: Implement VAE (25 Points)
### Hyper-parameter selection


In [26]:
# Necessary Hyperparameters
num_epochs = 30
learning_rate = 0.0001
batch_size = 128
latent_dim = 50 # Choose a value for the size of the latent space

# Additional Hyperparameters
beta = 1

# (Optionally) Modify transformations on input
transform = transforms.Compose([
    transforms.ToTensor(),
])

# (Optionally) Modify the network's output for visualizing your images
def denorm(x):
    return x

### Data loading


In [27]:
train_dat = datasets.MNIST(
    data_path, train=True, download=True, transform=transform
)
test_dat = datasets.MNIST(data_path, train=False, transform=transform)

loader_train = DataLoader(train_dat, batch_size, shuffle=True)
loader_test = DataLoader(test_dat, batch_size, shuffle=False)

# Don't change
sample_inputs, _ = next(iter(loader_test))
fixed_input = sample_inputs[:32, :, :, :]
save_image(fixed_input, content_path/'CW_VAE/image_original.png')

### Model Definition

<figure>
  <img src="https://blog.bayeslabs.co/assets/img/vae-gaussian.png" style="width:60%">
  <figcaption>
    Fig.1 - VAE Diagram (with a Gaussian prior), taken from <a href="https://blog.bayeslabs.co/2019/06/04/All-you-need-to-know-about-Vae.html">1</a>.
  </figcaption>
</figure>


You will need to define:
* The hyperparameters
* The constructor
* encode
* reparametrize
* decode
* forward



Hints:
- It is common practice to encode the log of the variance, rather than the variance
- You might try using BatchNorm
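
To illustrate the first hint: a network output is an unconstrained real number, whereas a variance must be positive. Treating the output as $\log \sigma^2$ and exponentiating guarantees a valid standard deviation. The sketch below uses plain Python with made-up values; the function names are illustrative only.

```python
import math
import random


def std_from_logvar(logvar):
    """Map an unconstrained log-variance to a strictly positive std."""
    return math.exp(0.5 * logvar)


def sample_latent(mu, logvar, rng):
    """Reparameterised sample: z = mu + eps * std, with eps ~ N(0, 1)."""
    eps = rng.gauss(0.0, 1.0)
    return mu + eps * std_from_logvar(logvar)


# Any real-valued output, negative or positive, gives a valid std
raw_outputs = [-6.0, -1.0, 0.0, 2.5]
stds = [std_from_logvar(v) for v in raw_outputs]

rng = random.Random(0)
z = sample_latent(mu=0.3, logvar=-1.0, rng=rng)
```

Because the randomness enters only through `eps`, the sample is a differentiable function of `mu` and `logvar`, which is what allows gradients to flow through the sampling step.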

In [28]:
# *CODE FOR PART 1.1a IN THIS CELL*

class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        #######################################################################
        #                       ** START OF YOUR CODE **
        #######################################################################
        self.image_channel = 1 # MNIST
        self.encoder = nn.Sequential(
            nn.Conv2d(self.image_channel, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        # 64*7*7 is number of channels times intermediate dimension of image
        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_decoder = nn.Linear(latent_dim, 64 * 7 * 7)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, self.image_channel, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
        )
        #######################################################################
        #                       ** END OF YOUR CODE **
        #######################################################################

    def encode(self, x):
        #######################################################################
        #                       ** START OF YOUR CODE **
        #######################################################################
        x_encoded = self.encoder(x)
        mu, logvar = self.fc_mu(x_encoded), self.fc_logvar(x_encoded)
        z = self.reparametrize(mu, logvar)
        return z, mu, logvar
        #######################################################################
        #                       ** END OF YOUR CODE **
        #######################################################################

    def reparametrize(self, mu, logvar):
        #######################################################################
        #                       ** START OF YOUR CODE **
        #######################################################################
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
        #######################################################################
        #                       ** END OF YOUR CODE **
        #######################################################################

    def decode(self, z):
        #######################################################################
        #                       ** START OF YOUR CODE **
        #######################################################################
        z = self.fc_decoder(z)
        z_reshaped = z.view(-1, 64, 7, 7)
        z_decoded = self.decoder(z_reshaped)
        return z_decoded
        #######################################################################
        #                       ** END OF YOUR CODE **
        #######################################################################

    def forward(self, x):
        #######################################################################
        #                       ** START OF YOUR CODE **
        #######################################################################
        z, mu, logvar = self.encode(x)
        x_reconstructed = self.decode(z)
        return x_reconstructed, mu, logvar
        #######################################################################
        #                       ** END OF YOUR CODE **
        #######################################################################

model = VAE(latent_dim).to(device)
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total number of parameters is: {}".format(params))
print(model)
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

Total number of parameters is: 511205
VAE(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Flatten(start_dim=1, end_dim=-1)
  )
  (fc_mu): Linear(in_features=3136, out_features=50, bias=True)
  (fc_logvar): Linear(in_features=3136, out_features=50, bias=True)
  (fc_decoder): Linear(in_features=50, out_features=3136, bias=True)
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): ReLU()
    (2): ConvTranspose2d(32, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (3): ReLU()
  )
)
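
The shapes and the parameter count printed above can be cross-checked by hand. The sketch below uses the standard output-size formulas for `Conv2d` and `ConvTranspose2d`, assuming square inputs and dilation 1; the helper names are illustrative and not part of the coursework code.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of Conv2d (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride, padding, output_padding):
    """Spatial output size of ConvTranspose2d (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


def conv_params(c_in, c_out, kernel):
    """Weights plus biases of a (transposed) convolution with square kernel."""
    return c_in * c_out * kernel * kernel + c_out


def linear_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


# Encoder: 28x28 MNIST image through two stride-2 convolutions
h = conv_out(28, kernel=3, stride=2, padding=1)   # 14
h = conv_out(h, kernel=3, stride=2, padding=1)    # 7
flat = 64 * h * h                                 # 3136 features after Flatten

# Decoder: two stride-2 transposed convolutions back to the input size
up = conv_transpose_out(h, 3, 2, 1, 1)            # 14
up = conv_transpose_out(up, 3, 2, 1, 1)           # 28

latent = 50
total = (
    conv_params(1, 32, 3) + conv_params(32, 64, 3)   # encoder convolutions
    + 2 * linear_params(flat, latent)                # fc_mu and fc_logvar
    + linear_params(latent, flat)                    # fc_decoder
    + conv_params(64, 32, 3) + conv_params(32, 1, 3) # decoder convolutions
)
```

The `output_padding=1` in the decoder is what makes each transposed convolution exactly double the spatial size, so the reconstruction matches the 28x28 input.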


---

## Part 1.1b: Training the Model (5 Points)

### Defining a Loss
Recall the Beta VAE loss, with an encoder $q$ and decoder $p$:
$$ \mathcal{L}=\mathbb{E}_{q_\phi(z \mid X)}[\log p_\theta(X \mid z)]-\beta D_{K L}[q_\phi(z \mid X) \| p_\theta(z)]$$

In order to implement this loss you will need to think carefully about your model's outputs and the choice of prior.

There are multiple accepted solutions. Explain your design choices based on the assumptions you make regarding the distribution of your data.
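
As one worked instance of such a design choice, assume (this is only one of the accepted options) a standard normal prior $p_\theta(z) = \mathcal{N}(0, I)$ and a diagonal Gaussian encoder $q_\phi(z \mid X) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ with latent dimension $d$. The KL term then has the closed form given in Appendix B of Kingma and Welling (2014):

$$ D_{K L}[q_\phi(z \mid X) \| p_\theta(z)] = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right) $$

Since $\mathcal{L}$ is a lower bound to be maximised, an optimiser that minimises would work with its negative:

$$ -\mathcal{L} = -\mathbb{E}_{q_\phi(z \mid X)}[\log p_\theta(X \mid z)] + \beta D_{K L}[q_\phi(z \mid X) \| p_\theta(z)] $$

As a sanity check, the KL term is zero when $\mu_j = 0$ and $\sigma_j = 1$ for every $j$, i.e. when the encoder output equals the prior.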

* Hint: this refers to the log likelihood as mentioned in the tutorial. Make sure these assumptions are consistent with the values of your input data, i.e. depending on your choice you might need to add a simple preprocessing step.

* You are encouraged to experiment with the weighting coefficient $\beta$ and observe how it affects your training
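
To make the hint about input values concrete, the sketch below pairs three possible likelihood assumptions with a matching output activation and input range. The random tensors stand in for real images and decoder outputs, and the variable names are hypothetical; it is meant as an illustration of the pairing, not as the required solution.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(4, 1, 28, 28)        # ToTensor() yields pixel values in [0, 1]
logits = torch.randn(4, 1, 28, 28)  # stand-in for raw decoder outputs

# Bernoulli likelihood: targets must lie in [0, 1]; the negative
# log-likelihood is the binary cross-entropy
bce = F.binary_cross_entropy_with_logits(logits, x, reduction="sum")

# Gaussian likelihood with fixed variance: the negative log-likelihood is,
# up to scale and a constant, the summed squared error
mse = F.mse_loss(torch.sigmoid(logits), x, reduction="sum")

# If inputs are rescaled to [-1, 1] (the preprocessing step), a tanh output
# keeps the reconstructions in the same range as the targets
x_pm1 = (x - 0.5) / 0.5
mse_pm1 = F.mse_loss(torch.tanh(logits), x_pm1, reduction="sum")
```

Using `reduction="sum"` keeps the reconstruction term on the same per-sample scale as a KL term that is summed over latent dimensions, so $\beta$ retains its usual meaning.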

In [None]:
# *CODE FOR PART 1.1b IN THIS CELL*

def loss_function_VAE(recon_x, x, mu, logvar, beta):
    #######################################################################
    #                       ** START OF YOUR CODE **
    #######################################################################
    # Find loss of reconstructed vs input
    mse = F.mse_loss(recon_x, x, reduction='sum')

    # KL component. Used reduced analytical answer from Appendix B in the paper
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    KL_loss = - 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return mse + beta * KL_loss, mse, KL_loss
    #######################################################################
    #                       ** END OF YOUR CODE **
    #######################################################################


def train(model, beta_values):

    def train_model(model, beta):
        model.train()
        total_loss_list = []
        mse_loss_list = []
        kl_loss_list = []
        test_total_loss_list = []
        test_mse_loss_list = []
        test_kl_loss_list = []
        for epoch in range(num_epochs):
            model.train()  # evaluate_model() switches to eval mode, so re-enable training mode each epoch
            running_loss = running_mse = running_kl = 0.0
            with tqdm.tqdm(loader_train, unit="batch") as tepoch:
                for batch_idx, (data, _) in enumerate(tepoch):
                    #######################################################################
                    #                       ** START OF YOUR CODE **
                    #######################################################################
                    # Need at least one batch/random data with the right shape -
                    # this is required to initialize the model properly below
                    # when we save the computational graph for testing (torch.jit.save)
                    data = data.to(device)
                    optimizer.zero_grad()
                    recon_x, mu, logvar = model(data)

                    # Compute loss
                    loss, mse, KL_loss = loss_function_VAE(recon_x, data, mu, logvar, beta=beta)
                    running_loss += loss.item()
                    running_mse += mse.item()
                    running_kl += KL_loss.item()

                    # Compute gradient
                    loss = loss / len(data)
                    loss.backward()
                    loss = loss * len(data)
                    optimizer.step()

                    #######################################################################
                    #                       ** END OF YOUR CODE **
                    #######################################################################
                    if batch_idx % 20 == 0:
                        tepoch.set_description(f"Epoch {epoch}")
                        tepoch.set_postfix(loss=loss.item()/len(data))
            # Logging for the end of each epoch
            avg_loss, avg_mse, avg_kl = running_loss / len(loader_train.dataset), running_mse / len(loader_train.dataset), running_kl / len(loader_train.dataset)
            test_loss, test_mse, test_kl = evaluate_model(model)
            # Save training statistics for plotting
            total_loss_list.append(avg_loss)
            mse_loss_list.append(avg_mse)
            kl_loss_list.append(avg_kl)
            test_total_loss_list.append(test_loss)
            test_mse_loss_list.append(test_mse)
            test_kl_loss_list.append(test_kl)
            print(f"Train Loss: {avg_loss}, MSE: {avg_mse}, KL: {avg_kl}")
            print(f"Test Loss: {test_loss}, MSE: {test_mse}, KL: {test_kl}")
            # save the model
            if epoch == num_epochs - 1:
                with torch.no_grad():
                    torch.jit.save(torch.jit.trace(model, (data), check_trace=False),
                        content_path/'CW_VAE/VAE_model.pth')

        training_loss = {'total': total_loss_list, 'mse': mse_loss_list, 'kl': kl_loss_list}
        test_loss = {'total': test_total_loss_list, 'mse': test_mse_loss_list, 'kl': test_kl_loss_list}
        return model, training_loss, test_loss

    def evaluate_model(model):
        model.eval()  # Set the model to evaluation mode

        running_loss = running_mse = running_kl = 0.0
        with torch.no_grad():
            for batch_idx, (data, _) in enumerate(loader_test):
                # Move data to the device if needed
                data = data.to(device)

                # Perform forward pass
                recon_x, mu, logvar = model(data)

                # Compute loss
                loss, mse, KL_loss = loss_function_VAE(recon_x, data, mu, logvar, beta=beta)
                running_loss += loss.item()
                running_mse += mse.item()
                running_kl += KL_loss.item()
            avg_loss, avg_mse, avg_kl = running_loss / len(loader_test.dataset), running_mse / len(loader_test.dataset), running_kl / len(loader_test.dataset)
        return avg_loss, avg_mse, avg_kl

    train_loss_dict, test_loss_dict = {}, {}
    for beta in beta_values:
        model, train_loss_dict[beta], test_loss_dict[beta]  = train_model(model, beta)
    print(train_loss_dict, test_loss_dict)
    return model, train_loss_dict,  test_loss_dict

beta_values = [0, 0.5, 1, 2, 4, 10]
model, training_loss_dict, test_loss_dict = train(model, beta_values)




Epoch 0: 100%|██████████| 469/469 [00:17<00:00, 26.55batch/s, loss=16.4]


Train Loss: 37.39318949381511, MSE: 37.39318949381511, KL: 184.41281998581886
Test Loss: 15.532876293945312, MSE: 15.532876293945312, KL: 279.0158973144531


Epoch 1: 100%|██████████| 469/469 [00:08<00:00, 55.86batch/s, loss=8.72]


Train Loss: 11.578221759033203, MSE: 11.578221759033203, KL: 298.2704716796875
Test Loss: 8.7460679145813, MSE: 8.7460679145813, KL: 311.60205043945314


Epoch 2: 100%|██████████| 469/469 [00:08<00:00, 52.90batch/s, loss=6.51]


Train Loss: 7.83524477335612, MSE: 7.83524477335612, KL: 312.08559095052084
Test Loss: 6.706895062255859, MSE: 6.706895062255859, KL: 319.76863212890623


Epoch 3: 100%|██████████| 469/469 [00:09<00:00, 50.54batch/s, loss=5.92]


Train Loss: 6.355860698445638, MSE: 6.355860698445638, KL: 320.3513564453125
Test Loss: 5.7350517227172855, MSE: 5.7350517227172855, KL: 327.6733946777344


Epoch 4: 100%|██████████| 469/469 [00:09<00:00, 51.06batch/s, loss=5.28]


Train Loss: 5.566883193969726, MSE: 5.566883193969726, KL: 328.9190053385417
Test Loss: 5.149116046142578, MSE: 5.149116046142578, KL: 335.78845576171875


Epoch 5: 100%|██████████| 469/469 [00:09<00:00, 50.45batch/s, loss=4.92]


Train Loss: 5.064958971659342, MSE: 5.064958971659342, KL: 337.0481318359375
Test Loss: 4.763085287475586, MSE: 4.763085287475586, KL: 343.86013876953126


Epoch 6: 100%|██████████| 469/469 [00:08<00:00, 55.56batch/s, loss=4.78]


Train Loss: 4.71522216389974, MSE: 4.71522216389974, KL: 344.5316986979167
Test Loss: 4.452337252044678, MSE: 4.452337252044678, KL: 349.4258470214844


Epoch 7: 100%|██████████| 469/469 [00:08<00:00, 53.71batch/s, loss=4.35]


Train Loss: 4.448969413757324, MSE: 4.448969413757324, KL: 350.89785748697915
Test Loss: 4.229877867126465, MSE: 4.229877867126465, KL: 355.49791494140624


Epoch 8: 100%|██████████| 469/469 [00:09<00:00, 50.82batch/s, loss=4.21]


Train Loss: 4.2338560129801435, MSE: 4.2338560129801435, KL: 355.4787164713542
Test Loss: 4.055838513946533, MSE: 4.055838513946533, KL: 360.9049118652344


Epoch 9: 100%|██████████| 469/469 [00:10<00:00, 45.42batch/s, loss=3.84]


Train Loss: 4.05469320119222, MSE: 4.05469320119222, KL: 361.2536486328125
Test Loss: 3.888896356201172, MSE: 3.888896356201172, KL: 366.1276408203125


Epoch 10: 100%|██████████| 469/469 [00:09<00:00, 51.08batch/s, loss=3.67]


Train Loss: 3.8989595138549804, MSE: 3.8989595138549804, KL: 366.60111165364583
Test Loss: 3.7569026531219483, MSE: 3.7569026531219483, KL: 371.12548046875


Epoch 11: 100%|██████████| 469/469 [00:08<00:00, 52.46batch/s, loss=3.45]


Train Loss: 3.7618163584391278, MSE: 3.7618163584391278, KL: 371.886037890625
Test Loss: 3.610596829223633, MSE: 3.610596829223633, KL: 376.2583831542969


Epoch 12: 100%|██████████| 469/469 [00:08<00:00, 55.62batch/s, loss=3.74]


Train Loss: 3.6367635528564453, MSE: 3.6367635528564453, KL: 376.6535381510417
Test Loss: 3.511346549987793, MSE: 3.511346549987793, KL: 381.07271162109373


Epoch 13: 100%|██████████| 469/469 [00:09<00:00, 51.26batch/s, loss=3.34]


Train Loss: 3.5256523646036784, MSE: 3.5256523646036784, KL: 381.4419689453125
Test Loss: 3.411068501281738, MSE: 3.411068501281738, KL: 385.5632251464844


Epoch 14: 100%|██████████| 469/469 [00:09<00:00, 50.51batch/s, loss=3.22]


Train Loss: 3.4271933385213216, MSE: 3.4271933385213216, KL: 386.3200701171875
Test Loss: 3.311630223083496, MSE: 3.311630223083496, KL: 390.1316354492188


Epoch 15: 100%|██████████| 469/469 [00:09<00:00, 50.99batch/s, loss=3.34]


Train Loss: 3.339833248392741, MSE: 3.339833248392741, KL: 389.73886477864585
Test Loss: 3.2257138931274416, MSE: 3.2257138931274416, KL: 395.80040927734376


Epoch 16: 100%|██████████| 469/469 [00:08<00:00, 52.83batch/s, loss=3.11]


Train Loss: 3.261981652832031, MSE: 3.261981652832031, KL: 393.5458130859375
Test Loss: 3.1703421211242677, MSE: 3.1703421211242677, KL: 396.04423525390627


Epoch 17: 100%|██████████| 469/469 [00:08<00:00, 55.19batch/s, loss=3.01]


Train Loss: 3.1904916997273762, MSE: 3.1904916997273762, KL: 397.4095408854167
Test Loss: 3.0910390785217285, MSE: 3.0910390785217285, KL: 401.4847748535156


Epoch 18: 100%|██████████| 469/469 [00:09<00:00, 51.58batch/s, loss=3.11]


Train Loss: 3.1259870412190756, MSE: 3.1259870412190756, KL: 402.592102734375
Test Loss: 3.0373538623809813, MSE: 3.0373538623809813, KL: 408.4692088867188


Epoch 19: 100%|██████████| 469/469 [00:09<00:00, 50.74batch/s, loss=3.04]


Train Loss: 3.0682131830851236, MSE: 3.0682131830851236, KL: 408.00616399739584
Test Loss: 2.9843935066223146, MSE: 2.9843935066223146, KL: 411.59914643554686


Epoch 20: 100%|██████████| 469/469 [00:09<00:00, 50.95batch/s, loss=2.78]


Train Loss: 3.0166355829874676, MSE: 3.0166355829874676, KL: 411.06753658854166
Test Loss: 2.9384383953094484, MSE: 2.9384383953094484, KL: 415.15706142578125


Epoch 21: 100%|██████████| 469/469 [00:08<00:00, 52.52batch/s, loss=3.18]


Train Loss: 2.971884627787272, MSE: 2.971884627787272, KL: 415.7669619140625
Test Loss: 2.8924538719177244, MSE: 2.8924538719177244, KL: 419.74763515625


Epoch 22: 100%|██████████| 469/469 [00:08<00:00, 55.00batch/s, loss=3.15]


Train Loss: 2.926272543334961, MSE: 2.926272543334961, KL: 419.8358033854167
Test Loss: 2.857215622329712, MSE: 2.857215622329712, KL: 424.4498045898437


Epoch 23: 100%|██████████| 469/469 [00:09<00:00, 51.26batch/s, loss=2.82]


Train Loss: 2.8855845123291015, MSE: 2.8855845123291015, KL: 423.5783944661458
Test Loss: 2.8131628795623778, MSE: 2.8131628795623778, KL: 428.0173146484375


Epoch 24: 100%|██████████| 469/469 [00:09<00:00, 50.15batch/s, loss=2.75]


Train Loss: 2.846974147542318, MSE: 2.846974147542318, KL: 427.7475251953125
Test Loss: 2.795458045578003, MSE: 2.795458045578003, KL: 432.64064428710935


Epoch 25: 100%|██████████| 469/469 [00:09<00:00, 50.23batch/s, loss=2.77]


Train Loss: 2.81449575398763, MSE: 2.81449575398763, KL: 430.45081614583336
Test Loss: 2.749939838409424, MSE: 2.749939838409424, KL: 434.40512465820314


Epoch 26: 100%|██████████| 469/469 [00:09<00:00, 49.77batch/s, loss=2.78]


Train Loss: 2.778476714070638, MSE: 2.778476714070638, KL: 432.6942651041667
Test Loss: 2.7182415477752686, MSE: 2.7182415477752686, KL: 435.93045834960935


Epoch 27: 100%|██████████| 469/469 [00:08<00:00, 55.82batch/s, loss=2.58]


Train Loss: 2.746215449523926, MSE: 2.746215449523926, KL: 434.5833347005208
Test Loss: 2.69857794342041, MSE: 2.69857794342041, KL: 436.84547177734373


Epoch 28: 100%|██████████| 469/469 [00:09<00:00, 51.22batch/s, loss=2.99]


Train Loss: 2.7178102966308595, MSE: 2.7178102966308595, KL: 437.02032141927083
Test Loss: 2.6659300140380857, MSE: 2.6659300140380857, KL: 442.9739553710937


Epoch 29: 100%|██████████| 469/469 [00:09<00:00, 50.10batch/s, loss=2.65]


Train Loss: 2.6902562149047853, MSE: 2.6902562149047853, KL: 443.12846647135416
Test Loss: 2.6509976833343507, MSE: 2.6509976833343507, KL: 447.0042458984375


Epoch 0: 100%|██████████| 469/469 [00:09<00:00, 49.77batch/s, loss=31]


Train Loss: 39.58796359863281, MSE: 16.83450423787435, KL: 45.50691877034505
Test Loss: 31.494943322753905, MSE: 15.666110340881348, KL: 31.657665942382813


Epoch 1: 100%|██████████| 469/469 [00:09<00:00, 50.18batch/s, loss=30.5]


Train Loss: 30.757734745279947, MSE: 15.37591327311198, KL: 30.76364295654297
Test Loss: 29.890558868408203, MSE: 15.14562924194336, KL: 29.489859033203125


Epoch 2: 100%|██████████| 469/469 [00:08<00:00, 53.52batch/s, loss=29.6]


Train Loss: 29.45259880777995, MSE: 14.885003208414714, KL: 29.135191198730467
Test Loss: 28.72218952636719, MSE: 14.373838481140137, KL: 28.696702044677735


Epoch 3: 100%|██████████| 469/469 [00:08<00:00, 54.99batch/s, loss=27.6]


Train Loss: 28.4096529296875, MSE: 14.559909653727214, KL: 27.699486568196615
Test Loss: 27.816363903808593, MSE: 14.408765045166016, KL: 26.815197622680664


Epoch 4: 100%|██████████| 469/469 [00:09<00:00, 50.45batch/s, loss=27.4]


Train Loss: 27.570826123046874, MSE: 14.284520985921224, KL: 26.572610331217447
Test Loss: 27.049202911376952, MSE: 13.943089724731445, KL: 26.212226348876953


Epoch 5: 100%|██████████| 469/469 [00:09<00:00, 50.16batch/s, loss=27.5]


Train Loss: 26.925824865722657, MSE: 14.041927996826171, KL: 25.767793815104167
Test Loss: 26.536169091796875, MSE: 13.871551159667968, KL: 25.329235961914062


Epoch 6: 100%|██████████| 469/469 [00:09<00:00, 50.67batch/s, loss=26]


Train Loss: 26.454814225260417, MSE: 13.85311724650065, KL: 25.203394079589845
Test Loss: 26.040576568603516, MSE: 13.550616549682617, KL: 24.97991983947754


Epoch 7: 100%|██████████| 469/469 [00:08<00:00, 53.91batch/s, loss=26]


Train Loss: 26.041469868977863, MSE: 13.671297908528645, KL: 24.74034394938151
Test Loss: 25.697098428344727, MSE: 13.537455264282226, KL: 24.319286328125


Epoch 8: 100%|██████████| 469/469 [00:08<00:00, 55.52batch/s, loss=25.7]


Train Loss: 25.716650158691408, MSE: 13.538484788004558, KL: 24.356330692545573
Test Loss: 25.44479109802246, MSE: 13.486040454101563, KL: 23.91750140991211


Epoch 9: 100%|██████████| 469/469 [00:09<00:00, 50.98batch/s, loss=25.5]


Train Loss: 25.45474969889323, MSE: 13.415652917480468, KL: 24.078193416341147
Test Loss: 25.183693005371094, MSE: 13.461565447998046, KL: 23.44425511779785


Epoch 10: 100%|██████████| 469/469 [00:09<00:00, 50.64batch/s, loss=25.4]


Train Loss: 25.222040063476562, MSE: 13.30700225016276, KL: 23.830075520833333
Test Loss: 24.9596362701416, MSE: 13.279096606445313, KL: 23.361079400634765


Epoch 11: 100%|██████████| 469/469 [00:09<00:00, 50.36batch/s, loss=25.2]


Train Loss: 25.020058459472658, MSE: 13.210249869791667, KL: 23.61961720377604
Test Loss: 24.74826508178711, MSE: 12.935175828552246, KL: 23.626178509521484


Epoch 12: 100%|██████████| 469/469 [00:08<00:00, 52.65batch/s, loss=25]


Train Loss: 24.851326021321615, MSE: 13.122391866048178, KL: 23.45786835530599
Test Loss: 24.588415374755858, MSE: 12.725506176757813, KL: 23.725818518066408


Epoch 13: 100%|██████████| 469/469 [00:08<00:00, 56.00batch/s, loss=25.4]


Train Loss: 24.683144278971355, MSE: 13.040542329915365, KL: 23.28520390218099
Test Loss: 24.478740145874024, MSE: 12.885851593017579, KL: 23.18577700805664


Epoch 14: 100%|██████████| 469/469 [00:09<00:00, 50.35batch/s, loss=24.6]


Train Loss: 24.513212927246094, MSE: 12.967382826741536, KL: 23.091660205078124
Test Loss: 24.261954583740234, MSE: 13.053722590637207, KL: 22.41646413574219


Epoch 15: 100%|██████████| 469/469 [00:09<00:00, 50.75batch/s, loss=23.9]


Train Loss: 24.37763465983073, MSE: 12.897056760660806, KL: 22.961155851236978
Test Loss: 24.11801439819336, MSE: 12.55837621459961, KL: 23.1192763671875


Epoch 16: 100%|██████████| 469/469 [00:09<00:00, 50.41batch/s, loss=25.2]


Train Loss: 24.28959617919922, MSE: 12.82937852376302, KL: 22.920435270182292
Test Loss: 24.016111791992188, MSE: 12.74570234375, KL: 22.54081882019043


Epoch 17: 100%|██████████| 469/469 [00:08<00:00, 52.25batch/s, loss=23.9]


Train Loss: 24.18734130452474, MSE: 12.779763413492839, KL: 22.815155887858072
Test Loss: 23.961761096191406, MSE: 12.71820732421875, KL: 22.487107470703126


Epoch 18: 100%|██████████| 469/469 [00:08<00:00, 56.12batch/s, loss=24.1]


Train Loss: 24.080496175130207, MSE: 12.71192032063802, KL: 22.737151700846354
Test Loss: 23.869071539306642, MSE: 12.595038667297363, KL: 22.54806589355469


Epoch 19: 100%|██████████| 469/469 [00:09<00:00, 51.12batch/s, loss=24.4]


Train Loss: 23.984478002929688, MSE: 12.66524252319336, KL: 22.638470939127604
Test Loss: 23.717202008056642, MSE: 12.568722183227539, KL: 22.29695984802246


Epoch 20: 100%|██████████| 469/469 [00:09<00:00, 49.74batch/s, loss=23]


Train Loss: 23.88629665120443, MSE: 12.614499881998698, KL: 22.543593497721353
Test Loss: 23.64220224914551, MSE: 12.262526391601563, KL: 22.759351861572267


Epoch 21: 100%|██████████| 469/469 [00:09<00:00, 49.96batch/s, loss=23.1]


Train Loss: 23.82541319173177, MSE: 12.570116288248698, KL: 22.51059385579427
Test Loss: 23.569684173583983, MSE: 12.246130339050293, KL: 22.64710771789551


Epoch 22: 100%|██████████| 469/469 [00:09<00:00, 51.22batch/s, loss=23.5]


Train Loss: 23.770869750976562, MSE: 12.531486505126953, KL: 22.478766459147135
Test Loss: 23.516445031738282, MSE: 12.30428250427246, KL: 22.424324856567385


Epoch 23: 100%|██████████| 469/469 [00:08<00:00, 55.06batch/s, loss=22.9]


Train Loss: 23.684712064615887, MSE: 12.490494685872395, KL: 22.38843475748698
Test Loss: 23.47795545654297, MSE: 12.200113046264649, KL: 22.55568494567871


Epoch 24: 100%|██████████| 469/469 [00:09<00:00, 51.73batch/s, loss=23]


Train Loss: 23.61569999186198, MSE: 12.443212225341798, KL: 22.3449755086263
Test Loss: 23.416773590087892, MSE: 12.379571929931641, KL: 22.07440329589844


Epoch 25: 100%|██████████| 469/469 [00:09<00:00, 50.53batch/s, loss=23.5]


Train Loss: 23.55879814046224, MSE: 12.410618041992187, KL: 22.296360319010418
Test Loss: 23.356112747192384, MSE: 12.151917913818359, KL: 22.408389666748047


Epoch 26: 100%|██████████| 469/469 [00:09<00:00, 51.02batch/s, loss=22.6]


Train Loss: 23.500770141601564, MSE: 12.388326405843099, KL: 22.224887471516926
Test Loss: 23.231247833251953, MSE: 12.263984983825683, KL: 21.934525579833984


Epoch 27: 100%|██████████| 469/469 [00:09<00:00, 49.39batch/s, loss=23]


Train Loss: 23.47669735107422, MSE: 12.348443589274089, KL: 22.25650762125651
Test Loss: 23.227639822387694, MSE: 12.014565335083008, KL: 22.42614921875


Epoch 28: 100%|██████████| 469/469 [00:08<00:00, 53.39batch/s, loss=23.5]


Train Loss: 23.388370572916667, MSE: 12.309642934163412, KL: 22.15745526529948
Test Loss: 23.250928125, MSE: 12.110572494506837, KL: 22.28071128540039


Epoch 29: 100%|██████████| 469/469 [00:08<00:00, 53.20batch/s, loss=23.3]


Train Loss: 23.34925548095703, MSE: 12.281399279785155, KL: 22.13571240641276
Test Loss: 23.111776538085937, MSE: 12.073372299194336, KL: 22.076808380126952


Epoch 0: 100%|██████████| 469/469 [00:09<00:00, 49.74batch/s, loss=32.4]


Train Loss: 32.24961182047526, MSE: 17.33269488728841, KL: 14.91691689453125
Test Loss: 31.904610342407228, MSE: 17.142559074401856, KL: 14.762051402282715


Epoch 1: 100%|██████████| 469/469 [00:09<00:00, 50.17batch/s, loss=32.5]


Train Loss: 31.99304090169271, MSE: 17.743564634195963, KL: 14.249476204427083
Test Loss: 31.746206719970704, MSE: 17.873185861206053, KL: 13.873020919799805


Epoch 2: 100%|██████████| 469/469 [00:09<00:00, 50.74batch/s, loss=31.7]


Train Loss: 31.875671130371092, MSE: 17.874499934895834, KL: 14.00117123006185
Test Loss: 31.61842283325195, MSE: 17.77085405578613, KL: 13.847568519592285


Epoch 3: 100%|██████████| 469/469 [00:08<00:00, 52.85batch/s, loss=31.7]


Train Loss: 31.78014728190104, MSE: 17.92522428588867, KL: 13.85492299601237
Test Loss: 31.531048440551757, MSE: 17.807333850097656, KL: 13.723714663696288


Epoch 4: 100%|██████████| 469/469 [00:08<00:00, 53.41batch/s, loss=31.3]


Train Loss: 31.700006408691408, MSE: 17.949893064371746, KL: 13.750113297526042
Test Loss: 31.42521866455078, MSE: 17.970184204101564, KL: 13.455034390258788


Epoch 5: 100%|██████████| 469/469 [00:09<00:00, 50.81batch/s, loss=32]


Train Loss: 31.604534354654948, MSE: 17.946452779134116, KL: 13.658081540934244
Test Loss: 31.346937576293946, MSE: 17.65589701538086, KL: 13.691040646362305


Epoch 6: 100%|██████████| 469/469 [00:09<00:00, 50.87batch/s, loss=31.6]


Train Loss: 31.543873262532554, MSE: 17.95312919108073, KL: 13.590744091796875
Test Loss: 31.242897631835937, MSE: 17.848927020263673, KL: 13.393970623779296


Epoch 7: 100%|██████████| 469/469 [00:09<00:00, 50.85batch/s, loss=30.9]


Train Loss: 31.482578926595053, MSE: 17.926252176920574, KL: 13.556326770019531
Test Loss: 31.17500473022461, MSE: 17.85877096862793, KL: 13.31623367614746


Epoch 8: 100%|██████████| 469/469 [00:08<00:00, 53.05batch/s, loss=31.2]


Train Loss: 31.44003652750651, MSE: 17.946825783284506, KL: 13.49321075032552
Test Loss: 31.163227294921874, MSE: 17.406120809936525, KL: 13.757106497192384


Epoch 9: 100%|██████████| 469/469 [00:08<00:00, 56.18batch/s, loss=31.2]


Train Loss: 31.36894766438802, MSE: 17.913494521077475, KL: 13.45545313313802
Test Loss: 31.027810955810548, MSE: 17.742687310791016, KL: 13.285123683166503


Epoch 10: 100%|██████████| 469/469 [00:09<00:00, 50.68batch/s, loss=30.5]


Train Loss: 31.324917028808592, MSE: 17.901180454508463, KL: 13.423736610921225
Test Loss: 31.09332130126953, MSE: 17.870608422851564, KL: 13.222712913513183


Epoch 11: 100%|██████████| 469/469 [00:09<00:00, 50.54batch/s, loss=32.5]


Train Loss: 31.281222277832033, MSE: 17.889143843587238, KL: 13.392078440348307
Test Loss: 31.030167388916016, MSE: 17.440809768676758, KL: 13.58935754699707


Epoch 12: 100%|██████████| 469/469 [00:08<00:00, 52.39batch/s, loss=30.5]


Train Loss: 31.203329895019532, MSE: 17.873472330729168, KL: 13.329857568359374
Test Loss: 30.94891717529297, MSE: 17.636774426269533, KL: 13.312142578125


Epoch 13: 100%|██████████| 469/469 [00:08<00:00, 52.87batch/s, loss=31.1]


Train Loss: 31.18979678141276, MSE: 17.860781125895183, KL: 13.32901562906901
Test Loss: 30.94989303588867, MSE: 17.568782397460936, KL: 13.381110676574707


Epoch 14: 100%|██████████| 469/469 [00:08<00:00, 56.73batch/s, loss=31.7]


Train Loss: 31.12446571451823, MSE: 17.842755818684896, KL: 13.28170987141927
Test Loss: 30.827269470214844, MSE: 17.599919482421875, KL: 13.227350012207031


Epoch 15: 100%|██████████| 469/469 [00:08<00:00, 52.71batch/s, loss=31.2]


Train Loss: 31.10468505045573, MSE: 17.82420611979167, KL: 13.28047891438802
Test Loss: 30.86561697692871, MSE: 17.525436236572265, KL: 13.340180776977538


Epoch 16: 100%|██████████| 469/469 [00:09<00:00, 52.08batch/s, loss=30.9]


Train Loss: 31.046917810058595, MSE: 17.78493088175456, KL: 13.26198686319987
Test Loss: 30.84536459350586, MSE: 17.850880966186523, KL: 12.994483714294434


Epoch 17: 100%|██████████| 469/469 [00:08<00:00, 52.22batch/s, loss=31.1]


Train Loss: 30.989971028645833, MSE: 17.764473120117188, KL: 13.225497875976563
Test Loss: 30.780661199951172, MSE: 17.591163700866698, KL: 13.189497570800782


Epoch 18: 100%|██████████| 469/469 [00:08<00:00, 54.89batch/s, loss=31.8]


Train Loss: 30.961630623372397, MSE: 17.75515947265625, KL: 13.206471130371094
Test Loss: 30.70225884399414, MSE: 17.693064492797852, KL: 13.00919426574707


Epoch 19: 100%|██████████| 469/469 [00:08<00:00, 57.98batch/s, loss=31]


Train Loss: 30.922310213216146, MSE: 17.724948685709634, KL: 13.1973615234375
Test Loss: 30.66401030883789, MSE: 17.54648977050781, KL: 13.117520573425294


Epoch 20: 100%|██████████| 469/469 [00:08<00:00, 52.84batch/s, loss=32.2]


Train Loss: 30.88926016031901, MSE: 17.714651985677083, KL: 13.174608221435546
Test Loss: 30.584843383789064, MSE: 17.491799206542968, KL: 13.093044299316407


Epoch 21: 100%|██████████| 469/469 [00:08<00:00, 52.44batch/s, loss=30.4]


Train Loss: 30.825701538085937, MSE: 17.681007535807293, KL: 13.144693977864584
Test Loss: 30.62773142089844, MSE: 17.58594861755371, KL: 13.041782876586915


Epoch 22: 100%|██████████| 469/469 [00:09<00:00, 51.83batch/s, loss=32]


Train Loss: 30.838011250813803, MSE: 17.67924876098633, KL: 13.15876250406901
Test Loss: 30.59886078491211, MSE: 17.505080648803713, KL: 13.093780111694336


Epoch 23: 100%|██████████| 469/469 [00:08<00:00, 55.71batch/s, loss=31.7]


Train Loss: 30.764909216308595, MSE: 17.641761315917968, KL: 13.1231478902181
Test Loss: 30.514152523803713, MSE: 17.485748822021485, KL: 13.028403701782226


Epoch 24: 100%|██████████| 469/469 [00:08<00:00, 56.86batch/s, loss=30.9]


Train Loss: 30.764111848958333, MSE: 17.649200567626952, KL: 13.114911305745443
Test Loss: 30.53463414001465, MSE: 17.52740015258789, KL: 13.007233914184571


Epoch 25: 100%|██████████| 469/469 [00:08<00:00, 52.47batch/s, loss=30.3]


Train Loss: 30.72399999186198, MSE: 17.62749034830729, KL: 13.096509641520182
Test Loss: 30.447591320800782, MSE: 17.372365913391114, KL: 13.075225430297852


Epoch 26: 100%|██████████| 469/469 [00:08<00:00, 52.60batch/s, loss=31.6]


Train Loss: 30.673220174153645, MSE: 17.593584737141928, KL: 13.079635406494141
Test Loss: 30.502516317749024, MSE: 17.36363534088135, KL: 13.138880964660645


Epoch 27: 100%|██████████| 469/469 [00:08<00:00, 52.82batch/s, loss=30.6]


Train Loss: 30.666431750488282, MSE: 17.58229940999349, KL: 13.084132342529298
Test Loss: 30.39754451904297, MSE: 17.461715371704102, KL: 12.935829109191895


Epoch 28: 100%|██████████| 469/469 [00:08<00:00, 57.49batch/s, loss=29.5]


Train Loss: 30.649811263020833, MSE: 17.564152380371095, KL: 13.085658878580729
Test Loss: 30.395695245361328, MSE: 17.348431036376954, KL: 13.047264135742187


Epoch 29: 100%|██████████| 469/469 [00:08<00:00, 55.29batch/s, loss=29.3]


Train Loss: 30.60164706624349, MSE: 17.546951403808595, KL: 13.054695658365885
Test Loss: 30.370185006713868, MSE: 17.365131042480467, KL: 13.005054013061523


Epoch 0: 100%|██████████| 469/469 [00:08<00:00, 52.90batch/s, loss=39.8]


Train Loss: 40.76912349853516, MSE: 24.376459029134114, KL: 8.19633224995931
Test Loss: 40.35489344482422, MSE: 25.032086529541015, KL: 7.6614035125732425


Epoch 1: 100%|██████████| 469/469 [00:08<00:00, 52.41batch/s, loss=40.3]


Train Loss: 40.392353039550784, MSE: 25.108919104003906, KL: 7.64171694946289
Test Loss: 40.01686075439453, MSE: 24.66950705871582, KL: 7.673676870727539


Epoch 2: 100%|██████████| 469/469 [00:08<00:00, 52.46batch/s, loss=40.9]


Train Loss: 40.245947086588544, MSE: 25.35485329996745, KL: 7.445546930948893
Test Loss: 40.064394073486326, MSE: 25.485282229614256, KL: 7.289555903625488


Epoch 3: 100%|██████████| 469/469 [00:08<00:00, 57.85batch/s, loss=40.7]


Train Loss: 40.16713915608724, MSE: 25.462703771972656, KL: 7.352217729695638
Test Loss: 39.95761512451172, MSE: 25.554897882080077, KL: 7.201358641052246


Epoch 4: 100%|██████████| 469/469 [00:08<00:00, 53.80batch/s, loss=40.8]


Train Loss: 40.0552900024414, MSE: 25.55032127685547, KL: 7.25248437093099
Test Loss: 39.79473189086914, MSE: 25.188465795898438, KL: 7.303133139038086


Epoch 5: 100%|██████████| 469/469 [00:08<00:00, 52.51batch/s, loss=38.5]


Train Loss: 40.06604835611979, MSE: 25.63731084798177, KL: 7.214368756103515
Test Loss: 39.80888272705078, MSE: 25.533383959960936, KL: 7.13774951171875


Epoch 6: 100%|██████████| 469/469 [00:08<00:00, 52.98batch/s, loss=40.7]


Train Loss: 39.9534876953125, MSE: 25.61121151529948, KL: 7.1711381195068356
Test Loss: 39.66963979492188, MSE: 24.91445714111328, KL: 7.377591264343262


Epoch 7: 100%|██████████| 469/469 [00:08<00:00, 54.79batch/s, loss=40.3]


Train Loss: 39.91929084472656, MSE: 25.65845987955729, KL: 7.1304154917399085
Test Loss: 39.70016665039063, MSE: 25.599889776611327, KL: 7.050138271331787


Epoch 8: 100%|██████████| 469/469 [00:08<00:00, 56.99batch/s, loss=38.2]


Train Loss: 39.818135689290365, MSE: 25.66050205078125, KL: 7.0788168009440104
Test Loss: 39.61157675170899, MSE: 25.568439544677734, KL: 7.021568542480469


Epoch 9: 100%|██████████| 469/469 [00:08<00:00, 53.03batch/s, loss=39.7]


Train Loss: 39.792804467773436, MSE: 25.65983127034505, KL: 7.066486577351888
Test Loss: 39.581052825927735, MSE: 25.333165585327148, KL: 7.123943698883057


Epoch 10: 100%|██████████| 469/469 [00:08<00:00, 52.92batch/s, loss=39.2]


Train Loss: 39.78304881591797, MSE: 25.62713047281901, KL: 7.077959190877278
Test Loss: 39.53293757324219, MSE: 25.173229608154298, KL: 7.179854000854492


Epoch 11: 100%|██████████| 469/469 [00:08<00:00, 52.29batch/s, loss=39.8]


Train Loss: 39.71915702311198, MSE: 25.619697806803387, KL: 7.0497296518961585
Test Loss: 39.54514415893555, MSE: 25.203432614135743, KL: 7.17085560836792


Epoch 12: 100%|██████████| 469/469 [00:08<00:00, 55.97batch/s, loss=39.7]


Train Loss: 39.68151991373698, MSE: 25.635180480957033, KL: 7.023169724527995
Test Loss: 39.51583131103516, MSE: 25.404188165283202, KL: 7.055821629333496


Epoch 13: 100%|██████████| 469/469 [00:08<00:00, 57.26batch/s, loss=39.9]


Train Loss: 39.703326440429684, MSE: 25.645933386230467, KL: 7.028696538289388
Test Loss: 39.471426654052735, MSE: 25.12069016418457, KL: 7.175368281555175


Epoch 14: 100%|██████████| 469/469 [00:08<00:00, 52.33batch/s, loss=39.6]


Train Loss: 39.60228563232422, MSE: 25.577834838867187, KL: 7.012225356038411
Test Loss: 39.41673536376953, MSE: 25.37540212097168, KL: 7.020666528320312


Epoch 15: 100%|██████████| 469/469 [00:08<00:00, 52.51batch/s, loss=39.3]


Train Loss: 39.59336182454427, MSE: 25.560224938964843, KL: 7.016568404134115
Test Loss: 39.369629986572264, MSE: 25.35753741455078, KL: 7.006046224975586


Epoch 16: 100%|██████████| 469/469 [00:08<00:00, 52.53batch/s, loss=40.5]


Train Loss: 39.58536083170573, MSE: 25.56550488688151, KL: 7.00992798461914
Test Loss: 39.34230503540039, MSE: 25.35430302734375, KL: 6.994000985717774


Epoch 17: 100%|██████████| 469/469 [00:08<00:00, 56.59batch/s, loss=40]


Train Loss: 39.528254052734376, MSE: 25.550715193684894, KL: 6.988769417317708
Test Loss: 39.284391662597656, MSE: 25.372125875854493, KL: 6.9561328643798825


Epoch 18: 100%|██████████| 469/469 [00:08<00:00, 54.96batch/s, loss=38.2]


Train Loss: 39.481899198404946, MSE: 25.53205975748698, KL: 6.974919713338216
Test Loss: 39.25091356811524, MSE: 25.225579370117188, KL: 7.012667074584961


Epoch 19: 100%|██████████| 469/469 [00:08<00:00, 52.82batch/s, loss=41.2]


Train Loss: 39.48975603841146, MSE: 25.515740889485677, KL: 6.987007582600912
Test Loss: 39.255635888671875, MSE: 25.61976330871582, KL: 6.8179362167358395


Epoch 20: 100%|██████████| 469/469 [00:08<00:00, 52.62batch/s, loss=40.4]


Train Loss: 39.449635465494794, MSE: 25.50425196126302, KL: 6.972691757202148
Test Loss: 39.27186352539063, MSE: 25.193032412719727, KL: 7.039415496826172


Epoch 21: 100%|██████████| 469/469 [00:08<00:00, 52.79batch/s, loss=39.8]


Train Loss: 39.459067325846355, MSE: 25.50254307454427, KL: 6.978262111409506
Test Loss: 39.19869075317383, MSE: 25.005695932006837, KL: 7.0964973854064946


Epoch 22: 100%|██████████| 469/469 [00:08<00:00, 57.51batch/s, loss=39.7]


Train Loss: 39.412641068522134, MSE: 25.48336119791667, KL: 6.964639935302734
Test Loss: 39.32020014038086, MSE: 25.037191983032226, KL: 7.1415040542602535


Epoch 23: 100%|██████████| 469/469 [00:08<00:00, 54.20batch/s, loss=38.3]


Train Loss: 39.38570125732422, MSE: 25.451875899251302, KL: 6.966912650553385
Test Loss: 39.28016213378906, MSE: 25.255195806884764, KL: 7.012483113098145


Epoch 24: 100%|██████████| 469/469 [00:09<00:00, 51.32batch/s, loss=39.4]


Train Loss: 39.38186635335286, MSE: 25.451315104166667, KL: 6.965275633748372
Test Loss: 39.1837485534668, MSE: 25.271585174560546, KL: 6.956081658172607


Epoch 25: 100%|██████████| 469/469 [00:08<00:00, 53.00batch/s, loss=39.6]


Train Loss: 39.3148146891276, MSE: 25.413550435384114, KL: 6.95063215230306
Test Loss: 39.173536627197265, MSE: 25.2834443359375, KL: 6.9450461891174315


Epoch 26: 100%|██████████| 469/469 [00:08<00:00, 52.42batch/s, loss=39.4]


Train Loss: 39.36002167561849, MSE: 25.416264831542968, KL: 6.971878415934245
Test Loss: 39.205634741210936, MSE: 25.01838171386719, KL: 7.093626507568359


Epoch 27: 100%|██████████| 469/469 [00:08<00:00, 57.82batch/s, loss=37.9]


Train Loss: 39.335030716959636, MSE: 25.393145621744793, KL: 6.970942514038086
Test Loss: 39.17552883300781, MSE: 25.370436526489257, KL: 6.902546182250976


Epoch 28: 100%|██████████| 469/469 [00:08<00:00, 53.74batch/s, loss=38.9]


Train Loss: 39.279489803059896, MSE: 25.35505928141276, KL: 6.962215256754558
Test Loss: 39.03417277832031, MSE: 25.152099328613282, KL: 6.941036683654785


Epoch 29: 100%|██████████| 469/469 [00:08<00:00, 52.15batch/s, loss=39.6]


Train Loss: 39.26953229573568, MSE: 25.36277520751953, KL: 6.9533784993489585
Test Loss: 39.05296807250976, MSE: 25.062585369873048, KL: 6.995191315460205


Epoch 0: 100%|██████████| 469/469 [00:08<00:00, 52.39batch/s, loss=48.6]


Train Loss: 49.307912972005205, MSE: 34.41334130859375, KL: 3.723642923482259
Test Loss: 48.72554296875, MSE: 34.98532455444336, KL: 3.4350546165466307


Epoch 1: 100%|██████████| 469/469 [00:08<00:00, 54.82batch/s, loss=49]


Train Loss: 48.763460546875, MSE: 35.57455182291667, KL: 3.2972271896362306
Test Loss: 48.5691493347168, MSE: 35.87483209228515, KL: 3.1735793224334716


Epoch 2: 100%|██████████| 469/469 [00:08<00:00, 58.27batch/s, loss=50.3]


Train Loss: 48.5740281656901, MSE: 35.96499178873698, KL: 3.1522590911865236
Test Loss: 48.39095819091797, MSE: 35.7787453125, KL: 3.1530532257080077


Epoch 3: 100%|██████████| 469/469 [00:08<00:00, 52.67batch/s, loss=48.5]


Train Loss: 48.50039903971354, MSE: 36.21165215657552, KL: 3.072186708577474
Test Loss: 48.36319068603515, MSE: 36.76634200439453, KL: 2.899212102508545


Epoch 4: 100%|██████████| 469/469 [00:08<00:00, 53.02batch/s, loss=49.4]


Train Loss: 48.40378494466146, MSE: 36.344069982910156, KL: 3.0149287638346354
Test Loss: 48.20559784545898, MSE: 36.133595642089844, KL: 3.0180005752563477


Epoch 5: 100%|██████████| 469/469 [00:08<00:00, 52.59batch/s, loss=45.7]


Train Loss: 48.323890071614585, MSE: 36.45168983968099, KL: 2.968050040181478
Test Loss: 48.209878125, MSE: 36.70311365966797, KL: 2.876691125488281


Epoch 6: 100%|██████████| 469/469 [00:08<00:00, 56.53batch/s, loss=47.5]


Train Loss: 48.283277490234376, MSE: 36.56546649169922, KL: 2.9294527572631837
Test Loss: 48.12105595703125, MSE: 36.05425049438477, KL: 3.0167013103485107


Epoch 7: 100%|██████████| 469/469 [00:08<00:00, 56.92batch/s, loss=50]


Train Loss: 48.255189534505206, MSE: 36.63672685546875, KL: 2.9046156733194985
Test Loss: 48.101666400146485, MSE: 36.663531762695314, KL: 2.859533631896973


Epoch 8: 100%|██████████| 469/469 [00:08<00:00, 52.83batch/s, loss=49.7]


Train Loss: 48.18934097493489, MSE: 36.65698412679036, KL: 2.883089215596517
Test Loss: 48.165552734375, MSE: 37.02455983886719, KL: 2.785248239135742


Epoch 9: 100%|██████████| 469/469 [00:08<00:00, 52.69batch/s, loss=48.4]


Train Loss: 48.1616601155599, MSE: 36.66818472086589, KL: 2.873368835449219
Test Loss: 48.12314625244141, MSE: 36.45274066772461, KL: 2.9176013679504393


Epoch 10: 100%|██████████| 469/469 [00:08<00:00, 53.01batch/s, loss=48.6]


Train Loss: 48.18982923990885, MSE: 36.70490319010417, KL: 2.871231526184082
Test Loss: 47.99813339233398, MSE: 36.25348001708984, KL: 2.9361633438110353


Epoch 11: 100%|██████████| 469/469 [00:08<00:00, 57.56batch/s, loss=49.5]


Train Loss: 48.127463321940105, MSE: 36.723139200846354, KL: 2.8510810287475588
Test Loss: 47.95446378173828, MSE: 36.594381665039066, KL: 2.8400204654693604


Epoch 12: 100%|██████████| 469/469 [00:08<00:00, 54.51batch/s, loss=47.8]


Train Loss: 48.0778615641276, MSE: 36.732800447591146, KL: 2.8362652893066405
Test Loss: 48.10379887695313, MSE: 36.66128679199219, KL: 2.860628057098389


Epoch 13: 100%|██████████| 469/469 [00:08<00:00, 53.28batch/s, loss=47.6]


Train Loss: 48.118647534179686, MSE: 36.77249473470052, KL: 2.836538182067871
Test Loss: 48.03187406616211, MSE: 37.15503551635742, KL: 2.7192096862792967


Epoch 14: 100%|██████████| 469/469 [00:08<00:00, 52.74batch/s, loss=49.4]


Train Loss: 48.09778484700521, MSE: 36.73569637858073, KL: 2.840522095743815
Test Loss: 48.04034587402344, MSE: 36.7228921081543, KL: 2.829363481903076


Epoch 15: 100%|██████████| 469/469 [00:08<00:00, 53.76batch/s, loss=48.5]


Train Loss: 48.064946459960936, MSE: 36.79216183268229, KL: 2.8181961558024087
Test Loss: 47.91089328613281, MSE: 36.49584663696289, KL: 2.853761612701416


Epoch 16: 100%|██████████| 469/469 [00:08<00:00, 57.90batch/s, loss=48]


Train Loss: 48.02283763834635, MSE: 36.75022296549479, KL: 2.818153667195638
Test Loss: 47.89414760131836, MSE: 36.57295337524414, KL: 2.830298565673828


Epoch 17: 100%|██████████| 469/469 [00:08<00:00, 53.24batch/s, loss=47.7]


Train Loss: 48.0593888671875, MSE: 36.77631189778646, KL: 2.8207692774454753
Test Loss: 47.90153023071289, MSE: 36.61885177001953, KL: 2.820669618225098


Epoch 18: 100%|██████████| 469/469 [00:08<00:00, 52.47batch/s, loss=46.6]


Train Loss: 47.985808032226565, MSE: 36.73456687825521, KL: 2.8128102844238283
Test Loss: 48.01518960571289, MSE: 36.894150939941404, KL: 2.780259672164917


Epoch 19: 100%|██████████| 469/469 [00:08<00:00, 52.45batch/s, loss=49.2]


Train Loss: 47.98545489095052, MSE: 36.75778094889323, KL: 2.8069184778849285
Test Loss: 47.931939508056644, MSE: 36.451908410644535, KL: 2.8700077194213867


Epoch 20: 100%|██████████| 469/469 [00:08<00:00, 55.81batch/s, loss=49.1]


Train Loss: 47.96131904296875, MSE: 36.73708907063802, KL: 2.8060574918111167
Test Loss: 47.95092199707031, MSE: 36.83297637939453, KL: 2.7794864685058593


Epoch 21: 100%|██████████| 469/469 [00:08<00:00, 57.44batch/s, loss=48.5]


Train Loss: 47.927236686197915, MSE: 36.74027830810547, KL: 2.7967396006266276
Test Loss: 47.87885272827148, MSE: 36.835097912597654, KL: 2.7609386852264404


Epoch 22: 100%|██████████| 469/469 [00:08<00:00, 53.01batch/s, loss=47.1]


Train Loss: 47.94532442220052, MSE: 36.74948284098307, KL: 2.798960371398926
Test Loss: 47.88908765258789, MSE: 36.619441455078125, KL: 2.817411570739746


Epoch 23: 100%|██████████| 469/469 [00:08<00:00, 52.57batch/s, loss=48.3]


Train Loss: 47.936272517903646, MSE: 36.724899080403645, KL: 2.802843366495768
Test Loss: 47.882257299804685, MSE: 36.44989864501953, KL: 2.858089663696289


Epoch 24: 100%|██████████| 469/469 [00:09<00:00, 51.59batch/s, loss=46.1]


Train Loss: 47.94118627929687, MSE: 36.73691830240885, KL: 2.8010669743855794
Test Loss: 47.78366616210938, MSE: 36.496870092773435, KL: 2.8216990089416503


Epoch 25: 100%|██████████| 469/469 [00:08<00:00, 56.46batch/s, loss=46.8]


Train Loss: 47.88565962727865, MSE: 36.73085710042318, KL: 2.788700626118978
Test Loss: 47.886470703125, MSE: 36.57354844970703, KL: 2.8282305702209474


Epoch 26: 100%|██████████| 469/469 [00:08<00:00, 55.85batch/s, loss=48.7]


Train Loss: 47.87411522623698, MSE: 36.698224979654945, KL: 2.7939725463867187
Test Loss: 47.67808137207031, MSE: 36.187736669921875, KL: 2.872586173248291


Epoch 27: 100%|██████████| 469/469 [00:08<00:00, 52.64batch/s, loss=47.7]


Train Loss: 47.875248518880205, MSE: 36.62784338785807, KL: 2.811851271057129
Test Loss: 47.787535113525394, MSE: 36.54942791748047, KL: 2.80952674369812


Epoch 28: 100%|██████████| 469/469 [00:08<00:00, 53.03batch/s, loss=45.9]


Train Loss: 47.8721638264974, MSE: 36.619896626790364, KL: 2.8130668019612632
Test Loss: 47.8451363647461, MSE: 36.50385394897461, KL: 2.8353205848693848


Epoch 29: 100%|██████████| 469/469 [00:08<00:00, 52.98batch/s, loss=48.6]


Train Loss: 47.83685437825521, MSE: 36.63748267008464, KL: 2.7998429489135743
Test Loss: 47.708402526855465, MSE: 36.33524600830078, KL: 2.8432891082763674


Epoch 0: 100%|██████████| 469/469 [00:08<00:00, 57.24batch/s, loss=52.5]


Train Loss: 54.89463288574219, MSE: 51.73260876464844, KL: 0.3162024162610372
Test Loss: 54.034765856933596, MSE: 52.97729559326172, KL: 0.1057470081448555


Epoch 1: 100%|██████████| 469/469 [00:08<00:00, 53.82batch/s, loss=53.9]


Train Loss: 53.565237556966146, MSE: 52.793949275716145, KL: 0.07712883319854737
Test Loss: 53.62551640625, MSE: 53.077147888183596, KL: 0.05483684155941009


Epoch 2: 100%|██████████| 469/469 [00:08<00:00, 52.55batch/s, loss=52]


Train Loss: 53.291609724934894, MSE: 52.836714412434894, KL: 0.045489521396160124
Test Loss: 53.339071472167966, MSE: 52.95894644775391, KL: 0.03801249703168869


Epoch 3: 100%|██████████| 469/469 [00:08<00:00, 53.00batch/s, loss=52.6]


Train Loss: 53.139110083007814, MSE: 52.772983642578126, KL: 0.036612644966443376
Test Loss: 53.26529778442383, MSE: 52.930685705566404, KL: 0.03346120125055313


Epoch 4: 100%|██████████| 469/469 [00:08<00:00, 54.13batch/s, loss=52.7]


Train Loss: 53.05197044270833, MSE: 52.71048727213542, KL: 0.034148315024375916
Test Loss: 53.20487866821289, MSE: 52.86723950805664, KL: 0.03376389615535736


Epoch 5: 100%|██████████| 469/469 [00:08<00:00, 57.24batch/s, loss=54.8]


Train Loss: 53.01131364746094, MSE: 52.6875333984375, KL: 0.03237802092234294
Test Loss: 53.1616746887207, MSE: 52.858297314453125, KL: 0.03033773738145828


Epoch 6: 100%|██████████| 469/469 [00:08<00:00, 52.36batch/s, loss=51.3]


Train Loss: 52.97177676595052, MSE: 52.655877872721355, KL: 0.03158988466262817
Test Loss: 53.11126590576172, MSE: 52.77132731933594, KL: 0.033993864107131955


Epoch 7: 100%|██████████| 469/469 [00:08<00:00, 52.37batch/s, loss=52.2]


Train Loss: 52.961009236653645, MSE: 52.66910383300781, KL: 0.029190543528397877
Test Loss: 53.08378814086914, MSE: 52.81605072021485, KL: 0.02677373217344284


Epoch 8: 100%|██████████| 469/469 [00:08<00:00, 52.81batch/s, loss=52.7]


Train Loss: 52.933383268229164, MSE: 52.64682098795573, KL: 0.028656223185857137
Test Loss: 53.07566451416016, MSE: 52.762107446289065, KL: 0.0313557128071785


Epoch 9: 100%|██████████| 469/469 [00:08<00:00, 55.73batch/s, loss=51.4]


Train Loss: 52.90934765625, MSE: 52.58947888997396, KL: 0.0319868758559227
Test Loss: 53.06326168823242, MSE: 52.797245135498045, KL: 0.02660165295600891


Epoch 10: 100%|██████████| 469/469 [00:08<00:00, 57.82batch/s, loss=52.7]


Train Loss: 52.91377552083333, MSE: 52.62400760091146, KL: 0.028976788139343262
Test Loss: 53.05133720703125, MSE: 52.78280718994141, KL: 0.026853039073944093


Epoch 11: 100%|██████████| 469/469 [00:08<00:00, 52.75batch/s, loss=53.7]


Train Loss: 52.891749780273436, MSE: 52.615599348958334, KL: 0.027615036396185556
Test Loss: 53.01918726196289, MSE: 52.67859649658203, KL: 0.0340590980052948


Epoch 12: 100%|██████████| 469/469 [00:08<00:00, 52.43batch/s, loss=52.3]


Train Loss: 52.88594315592448, MSE: 52.62029823404948, KL: 0.026564489090442657
Test Loss: 53.00935514526367, MSE: 52.71913276367187, KL: 0.029022247970104217


Epoch 13: 100%|██████████| 469/469 [00:08<00:00, 52.26batch/s, loss=54.3]


Train Loss: 52.866977587890624, MSE: 52.55543455403646, KL: 0.0311543026248614
Test Loss: 53.01741689453125, MSE: 52.71415314941406, KL: 0.03032637597322464


Epoch 14: 100%|██████████| 469/469 [00:08<00:00, 57.45batch/s, loss=52.4]


Train Loss: 52.863254345703126, MSE: 52.53779534505208, KL: 0.03254590129852295
Test Loss: 53.017187475585935, MSE: 52.720182263183595, KL: 0.02970053334236145


Epoch 15: 100%|██████████| 469/469 [00:08<00:00, 56.01batch/s, loss=53.1]


Train Loss: 52.84490580240885, MSE: 52.51904600423177, KL: 0.032585980212688444
Test Loss: 52.999325, MSE: 52.65580155029297, KL: 0.03435232458710671


Epoch 16: 100%|██████████| 469/469 [00:08<00:00, 52.12batch/s, loss=53.8]


Train Loss: 52.855169213867185, MSE: 52.5182269124349, KL: 0.033694233600298565
Test Loss: 52.97054895629883, MSE: 52.63291617431641, KL: 0.033763288223743436


Epoch 17: 100%|██████████| 469/469 [00:08<00:00, 52.80batch/s, loss=52.8]


Train Loss: 52.846455615234376, MSE: 52.5284582438151, KL: 0.031799724380175275
Test Loss: 53.00669678955078, MSE: 52.683984210205075, KL: 0.03227128247022629


Epoch 18: 100%|██████████| 469/469 [00:09<00:00, 51.94batch/s, loss=52.6]


Train Loss: 52.830813614908855, MSE: 52.497670963541665, KL: 0.03331426432132721
Test Loss: 53.01683491821289, MSE: 52.581277947998046, KL: 0.04355567327737808


Epoch 19: 100%|██████████| 469/469 [00:08<00:00, 56.95batch/s, loss=53.3]


Train Loss: 52.84242753092448, MSE: 52.50444381510417, KL: 0.03379837466478348
Test Loss: 52.995996771240236, MSE: 52.60768493041992, KL: 0.03883119099736214


Epoch 20: 100%|██████████| 469/469 [00:08<00:00, 53.93batch/s, loss=52.8]


Train Loss: 52.818324259440104, MSE: 52.45620725097656, KL: 0.03621170253753662
Test Loss: 53.01282297973633, MSE: 52.593835998535155, KL: 0.041898703593015674


Epoch 21: 100%|██████████| 469/469 [00:08<00:00, 52.69batch/s, loss=51.6]


Train Loss: 52.809476928710936, MSE: 52.4052185139974, KL: 0.040425843433539076
Test Loss: 52.991166827392576, MSE: 52.51441789550781, KL: 0.0476749022603035


Epoch 22: 100%|██████████| 469/469 [00:09<00:00, 52.05batch/s, loss=53]


Train Loss: 52.82233488769531, MSE: 52.422893969726566, KL: 0.03994409365653992
Test Loss: 52.992002459716794, MSE: 52.54985165405274, KL: 0.04421508212685585


Epoch 23: 100%|██████████| 469/469 [00:08<00:00, 53.82batch/s, loss=52.5]


Train Loss: 52.81315309244792, MSE: 52.4064890625, KL: 0.04066639450391134
Test Loss: 52.94325035400391, MSE: 52.49280100097656, KL: 0.04504491262435913


Epoch 24: 100%|██████████| 469/469 [00:08<00:00, 57.70batch/s, loss=52.4]


Train Loss: 52.82043920084635, MSE: 52.406809464518226, KL: 0.041362977107365924
Test Loss: 52.939683239746095, MSE: 52.49905670166016, KL: 0.04406266952753067


Epoch 25: 100%|██████████| 469/469 [00:08<00:00, 52.33batch/s, loss=54.3]


Train Loss: 52.78478673502604, MSE: 52.31490764973958, KL: 0.04698790261348089
Test Loss: 52.97579633789063, MSE: 52.480028295898435, KL: 0.04957681760787964


Epoch 26: 100%|██████████| 469/469 [00:08<00:00, 52.58batch/s, loss=51.4]


Train Loss: 52.795563932291664, MSE: 52.31948157552083, KL: 0.04760823881228765
Test Loss: 52.990991284179685, MSE: 52.50873500366211, KL: 0.048225608348846434


Epoch 27: 100%|██████████| 469/469 [00:08<00:00, 52.39batch/s, loss=51.1]


Train Loss: 52.82491772460938, MSE: 52.350183821614586, KL: 0.047473388675848646
Test Loss: 52.98148928222656, MSE: 52.535504711914065, KL: 0.044598477125167846


Epoch 28: 100%|██████████| 469/469 [00:08<00:00, 54.69batch/s, loss=55]


Train Loss: 52.80657236328125, MSE: 52.35801435546875, KL: 0.044855806545416516
Test Loss: 52.97677858886719, MSE: 52.448650146484376, KL: 0.052812851595878604


Epoch 29: 100%|██████████| 469/469 [00:08<00:00, 57.46batch/s, loss=49.6]


Train Loss: 52.782580151367185, MSE: 52.26455440266927, KL: 0.051802574666341146
Test Loss: 52.949864819335936, MSE: 52.42713148193359, KL: 0.05227335637807846
{0: {'total': [37.39318949381511, 11.578221759033203, 7.83524477335612, 6.355860698445638, 5.566883193969726, 5.064958971659342, 4.71522216389974, 4.448969413757324, 4.2338560129801435, 4.05469320119222, 3.8989595138549804, 3.7618163584391278, 3.6367635528564453, 3.5256523646036784, 3.4271933385213216, 3.339833248392741, 3.261981652832031, 3.1904916997273762, 3.1259870412190756, 3.0682131830851236, 3.0166355829874676, 2.971884627787272, 2.926272543334961, 2.8855845123291015, 2.846974147542318, 2.81449575398763, 2.778476714070638, 2.746215449523926, 2.7178102966308595, 2.6902562149047853], 'mse': [37.39318949381511, 11.578221759033203, 7.83524477335612, 6.355860698445638, 5.566883193969726, 5.064958971659342, 4.71522216389974, 4.448969413757324, 4.2338560129801435, 4.05469320119222, 3.8989595138549804, 3.7618163584391278, 3.63676

### Loss Explanation
Explain your choice of loss and how this relates to:

* The VAE Prior
* The output data domain
* Disentanglement in the latent space


In [None]:
# Any code for your explanation here (you may not need to use this cell)
print(training_loss_dict)
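
For reference when reading the logged numbers, the `Train Loss`, `MSE`, and `KL` values printed above appear consistent with a $\beta$-weighted objective of the following form. This is a sketch under stated assumptions, not necessarily the exact loss implemented in this notebook. It assumes a diagonal-Gaussian encoder with mean $\mu$ and variance $\sigma^2$, a standard normal prior (consistent with drawing `z` from `torch.randn` in Part 1.2b), and a latent dimension $d$; the symbols $\mu$, $\sigma$, $d$ and $\hat{x}$ are introduced here for illustration only. How the squared error is reduced over pixels and over the batch is not shown in this section, so the reconstruction term is left generic.

$$
\mathcal{L}_\beta(x) = \mathrm{MSE}(x, \hat{x}) + \beta \, D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, \mathcal{N}(0, I)\big),
\qquad
D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, \mathcal{N}(0, I)\big) = \frac{1}{2} \sum_{j=1}^{d} \left( \mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1 \right)
$$

As a check against the logs, the final training epoch of the first run shown above reports MSE 12.2814 and KL 22.1357, and $12.2814 + 0.5 \times 22.1357 \approx 23.3493$ matches its reported Train Loss, which suggests that run used $\beta = 0.5$.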

**YOUR ANSWER**



<h2>Part 1.2 (9 points)</h2>

a. Plot your loss curves

b. Show reconstructions and samples

c. Discuss your results from parts (a) and (b)

## Part 1.2a: Loss Curves (3 Points)
Plot your loss curves (6 in total, 3 for the training set and 3 for the test set): total loss, reconstruction log likelihood loss, KL loss (x-axis: epochs, y-axis: loss). If you experimented with different values of $\beta$, you may wish to display multiple plots (worth 1 point).

In [None]:
# *CODE FOR PART 1.2a IN THIS CELL*
import matplotlib.pyplot as plt


# (dictionary key, legend name, panel title) for each of the three panels
panels = [
    ('total', 'Total', 'Total Loss Over Epochs'),
    ('mse', 'MSE', 'Reconstruction Log Likelihood Loss Over Epochs'),
    ('kl', 'KL', 'KL Loss Over Epochs'),
]
epochs = range(1, num_epochs + 1)

for beta, training_loss in training_loss_dict.items():
    test_loss = test_loss_dict[beta]

    # One figure per beta value: 1 row, 3 columns (total, reconstruction, KL)
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(20, 6))

    for ax, (key, name, title) in zip(axes, panels):
        # Training and test curves share a panel so they can be compared directly
        ax.plot(epochs, training_loss[key], label=f'Training {name} Loss (Beta={beta})', marker='o')
        ax.plot(epochs, test_loss[key], label=f'Test {name} Loss (Beta={beta})', marker='o')
        ax.set_title(title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(f'{name} Loss')
        ax.legend()

plt.show()


## Part 1.2b: Samples and Reconstructions (6 Points)
Visualize a subset of the images of the test set and their reconstructions **as well as** a few generated samples. Most of the code for this part is provided. You only need to call the forward pass of the model for the given inputs (might vary depending on your implementation).

For reference, here's [some samples from our VAE](https://imgur.com/NwNMuG3).


In [22]:
# *CODE FOR PART 1.2b IN THIS CELL*

# load the model
# train the model (beta = 0.5)
model,_ ,_ = train(model, [0.5])
print('Input images')
print('-'*50)

sample_inputs, _ = next(iter(loader_test))
fixed_input = sample_inputs[0:32, :, :, :]
# visualize the original images of the first batch of the test set
img = make_grid(denorm(fixed_input), nrow=8, padding=2, normalize=False,
                value_range=None, scale_each=False, pad_value=0)
plt.figure()
show(img)

print('Reconstructed images')
print('-'*50)
with torch.no_grad():
    # visualize the reconstructed images of the first batch of the test set

    #######################################################################
    #                       ** START OF YOUR CODE **
    #######################################################################
    recon_batch, _, _ = model(fixed_input.to(device))
    #######################################################################
    #                       ** END OF YOUR CODE **
    #######################################################################

    recon_batch = recon_batch.cpu()
    recon_batch = make_grid(denorm(recon_batch), nrow=8, padding=2, normalize=False,
                            value_range=None, scale_each=False, pad_value=0)
    plt.figure()
    show(recon_batch)

print('Generated Images')
print('-'*50)
model.eval()
n_samples = 256
z = torch.randn(n_samples,latent_dim).to(device)
with torch.no_grad():
    #######################################################################
    #                       ** START OF YOUR CODE **
    #######################################################################
    samples = model.decode(z)
    #######################################################################
    #                       ** END OF YOUR CODE **
    #######################################################################

    samples = samples.cpu()
    samples = make_grid(denorm(samples), nrow=16, padding=2, normalize=False,
                            value_range=None, scale_each=False, pad_value=0)
    plt.figure(figsize = (8,8))
    show(samples)



Epoch 0: 100%|██████████| 469/469 [00:10<00:00, 46.59batch/s, loss=38.5]


Train Loss: 55.350765250651044, MSE: 48.86738577067057, KL: 12.966759188818932
Test Loss: 36.9231267578125, MSE: 27.72489732055664, KL: 18.396458654785157


Epoch 1: 100%|██████████| 469/469 [00:08<00:00, 55.75batch/s, loss=31.5]


Train Loss: 34.034542411295575, MSE: 24.373987890625, KL: 19.321109014892578
Test Loss: 31.69088910522461, MSE: 21.67261285095215, KL: 20.036552389526367


Epoch 2: 100%|██████████| 469/469 [00:08<00:00, 53.16batch/s, loss=29.6]


Train Loss: 30.870833923339845, MSE: 20.851866479492188, KL: 20.037934822591147
Test Loss: 29.65763132019043, MSE: 19.511975399780273, KL: 20.291312060546876


Epoch 3: 100%|██████████| 469/469 [00:09<00:00, 49.91batch/s, loss=28.9]


Train Loss: 29.299053678385416, MSE: 19.248919915771484, KL: 20.100267533365887
Test Loss: 28.475511279296875, MSE: 18.303525099182128, KL: 20.343972534179688


Epoch 4: 100%|██████████| 469/469 [00:10<00:00, 45.96batch/s, loss=27.3]


Train Loss: 28.276862154134115, MSE: 18.254730924479166, KL: 20.044262420654295
Test Loss: 27.59487717285156, MSE: 17.46202449645996, KL: 20.265705303955077


Epoch 5: 100%|██████████| 469/469 [00:09<00:00, 48.40batch/s, loss=27.6]


Train Loss: 27.529283825683592, MSE: 17.509138564046225, KL: 20.0402904683431
Test Loss: 26.965926837158204, MSE: 17.00142375488281, KL: 19.929006115722657


Epoch 6: 100%|██████████| 469/469 [00:08<00:00, 52.13batch/s, loss=27.3]


Train Loss: 26.948568868001303, MSE: 16.91156735229492, KL: 20.07400303548177
Test Loss: 26.436574963378906, MSE: 16.33914562225342, KL: 20.194858679199218


Epoch 7: 100%|██████████| 469/469 [00:08<00:00, 55.76batch/s, loss=26.5]


Train Loss: 26.49717268066406, MSE: 16.409248388671877, KL: 20.17584853922526
Test Loss: 26.068557092285157, MSE: 15.942263815307617, KL: 20.252586602783204


Epoch 8: 100%|██████████| 469/469 [00:09<00:00, 48.01batch/s, loss=25.9]


Train Loss: 26.107050903320314, MSE: 16.005530399576823, KL: 20.20304098510742
Test Loss: 25.68888666381836, MSE: 15.54913717956543, KL: 20.279499017333983


Epoch 9:  58%|█████▊    | 273/469 [00:05<00:04, 47.63batch/s, loss=25.8]


KeyboardInterrupt: 

### Discussion
Provide a brief analysis of your loss curves and reconstructions:
* What do you observe in the behaviour of the log-likelihood loss and the KL loss (increasing/decreasing)?
* Can you intuitively explain if this behaviour is desirable?
* What is posterior collapse and did you observe it during training (i.e. when the KL is too small during the early stages of training)?
    * If yes, how did you mitigate it? How did this phenomenon reflect on your output samples?
    * If no, why do you think that is?
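
To make the posterior-collapse question concrete: for a diagonal Gaussian posterior and a standard normal prior, the KL term decomposes over latent dimensions, and a dimension whose KL stays near zero carries almost no information about the input. The sketch below assumes the encoder outputs a mean and a log-variance per dimension (as the `mu` and `logvar` names in Part 1.3a suggest). The function names and the 0.01 threshold are illustrative choices, not part of the coursework.

```python
import math

def kl_per_dim(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, 1) ) for each latent dimension."""
    return [0.5 * (m * m + math.exp(lv) - 1.0 - lv) for m, lv in zip(mu, logvar)]

def inactive_dims(mu, logvar, threshold=0.01):
    """Indices whose KL is below `threshold` (an arbitrary, illustrative cutoff)."""
    return [i for i, kl in enumerate(kl_per_dim(mu, logvar)) if kl < threshold]

# Toy posterior for one input: dimensions 0 and 1 are informative,
# dimension 2 has collapsed onto the prior (mu = 0, logvar = 0).
mu = [1.5, -0.8, 0.0]
logvar = [-2.0, -1.0, 0.0]

kls = kl_per_dim(mu, logvar)
print([round(k, 3) for k in kls])
print("inactive:", inactive_dims(mu, logvar))
```

In practice the per-dimension KL would be averaged over a batch of test inputs before thresholding, since a single input says little about whether a dimension is used at all.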

**YOUR ANSWER**

---
<h2> Part 1.3 (11 points) </h2>

Qualitative analysis of the learned representations

In this question you are asked to qualitatively assess the representations that your model has learned. In particular:

a. Dimensionality Reduction of learned embeddings

b. Interpolating in the latent space

## Part 1.3a: T-SNE on Embeddings (7 Points)
Extract the latent representations of the test set and visualize them using [T-SNE](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding)  [(see implementation)](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html). You can use a T-SNE implementation from a library such as scikit-learn.

We've provided a function to visualize a subset of the data, but you are encouraged to also produce a matplotlib plot (please use different colours for each digit class).
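
A minimal matplotlib version of such a plot could look like the following. It is a sketch under assumptions: the random `latents` and `labels` arrays stand in for the encoder outputs and digit labels collected from the test loader, `plot_latent_tsne` is a hypothetical helper, and the `perplexity` value is only a starting point.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_latent_tsne(latents, labels, perplexity=30.0, seed=0):
    """Embed latents in 2-D with t-SNE and colour each point by its class label."""
    embedded = TSNE(n_components=2, perplexity=perplexity,
                    init="pca", random_state=seed).fit_transform(latents)
    fig, ax = plt.subplots(figsize=(7, 6))
    points = ax.scatter(embedded[:, 0], embedded[:, 1], c=labels, cmap="tab10", s=8)
    fig.colorbar(points, ax=ax, ticks=np.unique(labels), label="class")
    ax.set_title(f"t-SNE of latent vectors (perplexity={perplexity})")
    return embedded

# Placeholder data: 300 fake "latent vectors" in 3 well-separated groups.
rng = np.random.default_rng(0)
labels = np.repeat(np.arange(3), 100)
latents = rng.normal(size=(300, 16)) + 5.0 * labels[:, None]

embedded = plot_latent_tsne(latents, labels, perplexity=20.0)
```

Calling the helper several times with different `perplexity` and `seed` values gives a direct way to see how much of the apparent structure depends on the t-SNE settings rather than on the latent space itself.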

In [10]:
# *CODE FOR PART 1.3a IN THIS CELL*
model.eval()

z_embedded = []

with torch.no_grad():
    for batch_idx, (data, _) in enumerate(loader_test):
        data = data.to(device)

        # Perform forward pass
        z, mu, logvar = model.encode(data)

        z_embedded.append(z.cpu().numpy())

z_embedded = np.concatenate(z_embedded, axis=0)


In [11]:
# Interactive Visualization - Code Provided
test_dataloader = DataLoader(test_dat, 10000, shuffle=False)
""" Inputs to the function are
        z_embedded - Embedded X, Y positions for every point in test_dataloader
        test_dataloader - dataloader with batchsize set to 10000
        num_points - number of points plotted (will slow down with >1k)
"""
plot_tsne(z_embedded, test_dataloader, num_points=1000, darkmode=False)


Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Custom Visualizations

### Discussion
What do you observe? Discuss the structure of the visualized representations.
* What role do the KL loss term and $\beta$ have, if any, in what you observe (multiple matplotlib plots may be desirable here)?
    * Consider Outliers
    * Consider Boundaries
    * Consider Clusters
* Is T-SNE reliable? What happens if you change the parameters (don't worry about being particularly thorough). [This link](https://distill.pub/2016/misread-tsne/) may be helpful.

Note - If you created multiple plots and want to include them in your discussion, the best option is to upload them to (e.g.) google drive and then embed them via a **public** share link. If you reference local files, please include these in your submission zip, and use relative pathing if you are embedding them (with the notebook in the base directory)

**YOUR ANSWER**

## Part 1.3b: Interpolating in $z$ (4 Points)
Perform a linear interpolation in the latent space of the autoencoder by choosing any two digits from the test set. What do you observe regarding the transition from one digit to the other?

_hint: Locate the positions in latent space of 2 data points (maybe a one and an eight). Then sample multiple latent space vectors along the line which joins the 2 points and pass them through the decoder._
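
The hint can be sketched as follows. The two latent vectors below are random placeholders for the encodings of two test digits, `interpolate_latents` is a hypothetical helper, and the decoding step is left as a comment because it depends on your model's interface.

```python
import numpy as np

def interpolate_latents(z_start, z_end, num_steps=10):
    """Return `num_steps` points evenly spaced on the line from z_start to z_end."""
    alphas = np.linspace(0.0, 1.0, num_steps)[:, None]  # shape (num_steps, 1)
    return (1.0 - alphas) * z_start[None, :] + alphas * z_end[None, :]

# Placeholders for the latent codes of two test images (e.g. a "1" and an "8").
rng = np.random.default_rng(0)
z_one, z_eight = rng.normal(size=20), rng.normal(size=20)

path = interpolate_latents(z_one, z_eight, num_steps=8)
print(path.shape)  # (8, 20)

# Each row of `path` would then be converted to a tensor and passed through
# the decoder, e.g. model.decode(...), to render the intermediate digits.
```

The first and last rows reproduce the two endpoints exactly, so the decoded sequence starts and ends at the reconstructions of the chosen digits.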


In [None]:
# CODE FOR PART 1.3b IN THIS CELL


### Discussion
What did you observe in the interpolation? Is this what you expected?
* Can you relate the interpolation to your T-SNE visualization?

**YOUR ANSWER**

# Part 2 - Deep Convolutional GAN

In this task, your main objective is to train a DCGAN (https://arxiv.org/abs/1511.06434) on the CIFAR-10 dataset. You should experiment with different architectures and tricks for stability in training (such as using different activation functions, batch normalization, different values for the hyper-parameters, etc.). In the end, you should provide us with:

- your best trained model (which we will be able to load and run),
- some generations for the fixed latent vectors $\mathbf{z}\sim \mathcal{N}\left(\mathbf{0}, \mathbf{I}\right)$ we have provided you with (train for a number of epochs and make sure there is no mode collapse),
- plots with the losses for the discriminator $D$ and the generator $G$ as the training progresses and explain whether your produced plots are theoretically sensible and why this is (or not) the case.
- a discussion on whether you noticed any mode collapse, what this behaviour may be attributed to, and explanations of what you did in order to cope with mode collapse.

## Part 2.1 (30 points)
**Your Task**:

a. Implement the DCGAN architecture.

b. Define a loss and implement the Training Loop

c. Visualize images sampled from your best model's generator ("Extension" Assessed on quality)

d. Discuss the experimentations which led to your final architecture. You can plot losses or generated results by other architectures that you tested to back your arguments (but this is not necessary to get full marks).


_Clarification: You should not be worrying too much about getting an "optimal" performance on your trained GAN. We want you to demonstrate to us that you experimented with different types of DCGAN variations, report what difficulties transpired throughout the training process, etc. In other words, if we see that you provided us with a running implementation, that you detail different experimentations that you did before providing us with your best one, and that you have grasped the concepts, you can still get good marks. The attached model does not have to be perfect, and the extension marks for performance are only worth 10 points._

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.optim.lr_scheduler import StepLR, MultiStepLR
import torch.nn.functional as F
import matplotlib.pyplot as plt

mean = torch.Tensor([0.4914, 0.4822, 0.4465])
std = torch.Tensor([0.247, 0.243, 0.261])
unnormalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())

def denorm(x, channels=None, w=None ,h=None, resize = False):

    x = unnormalize(x)
    if resize:
        if channels is None or w is None or h is None:
            # Fail early with a clear message instead of crashing in view() below
            raise ValueError('Number of channels, width and height must be provided for resize.')
        x = x.view(x.size(0), channels, w, h)
    return x

def show(img):
    npimg = img.cpu().numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)))

if not os.path.exists(content_path/'CW_GAN'):
    os.makedirs(content_path/'CW_GAN')

GPU = True # Choose whether to use GPU
if GPU:
    device = torch.device("cuda"  if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cpu")
print(f'Using {device}')

# We set a random seed to ensure that your results are reproducible.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
torch.manual_seed(0)
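
The `unnormalize` transform above works because `transforms.Normalize(m, s)` computes `(x - m) / s` per channel; with `m = -mean / std` and `s = 1 / std` this becomes `x * std + mean`, the inverse of the original normalisation. A plain-Python check of that arithmetic, using the CIFAR-10 statistics from the cell above (the helper names are illustrative):

```python
mean = [0.4914, 0.4822, 0.4465]
std = [0.247, 0.243, 0.261]

def normalize(pixel, m, s):
    """Per-channel operation applied by transforms.Normalize(m, s)."""
    return [(p - mi) / si for p, mi, si in zip(pixel, m, s)]

# Parameters passed to the inverse transform in the notebook
inv_mean = [-m / s for m, s in zip(mean, std)]
inv_std = [1.0 / s for s in std]

pixel = [0.2, 0.5, 0.9]                        # one RGB pixel in [0, 1]
normalized = normalize(pixel, mean, std)       # what the networks see
recovered = normalize(normalized, inv_mean, inv_std)

print([round(v, 6) for v in recovered])        # matches `pixel` up to rounding
```

The same reasoning applies to the `gan_denorm` function requested in the TA test cell: it should undo whatever normalisation your `transform` applies.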

### Part 2.1a: Implement DCGAN (8 Points)
Fill in the missing parts in the cells below in order to complete the Generator and Discriminator classes. You will need to define:

- The hyperparameters
- The constructors
- `decode`
- `discriminator`

Recommendations for experimentation:
- use the architecture that you implemented for the Autoencoder of Part 1 (encoder as discriminator, decoder as generator).
- use the architecture described in the DCGAN paper (https://arxiv.org/abs/1511.06434).

Some general recommendations:
- add several convolutional layers (3-4).
- accelerate training with batch normalization after every convolutional layer.
- use the appropriate activation functions.
- Generator module: the upsampling can be done with various methods, such as nearest neighbor upsampling (`torch.nn.Upsample`) or transposed convolutions(`torch.nn.ConvTranspose2d`).
- Discriminator module: Experiment with batch normalization (`torch.nn.BatchNorm2d`) and leaky ReLU (`torch.nn.LeakyReLU`) units after each convolutional layer.

Try to follow the common practices for CNNs (e.g. small kernels, max pooling, ReLU activations), in order to narrow down your possible choices.

<font color="red">**Your model should not have more than 25 Million Parameters**</font>
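
To keep an eye on both the output resolution and the parameter budget before training, the layer arithmetic can be checked by hand. The sketch below uses the standard output-size formula for `torch.nn.ConvTranspose2d` with dilation 1 and no output padding, `(in - 1) * stride - 2 * padding + kernel`, on a hypothetical four-layer generator. The channel widths and the latent size of 100 are assumptions chosen for illustration, not a recommended design.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Spatial output size of ConvTranspose2d (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel

def conv_params(c_in, c_out, kernel, bias=False):
    """Weight count of a 2-D (transposed) convolution with a square kernel."""
    return c_in * c_out * kernel * kernel + (c_out if bias else 0)

# Hypothetical generator: (in_channels, out_channels, kernel, stride, padding)
layers = [
    (100, 256, 4, 1, 0),  # latent vector (100 x 1 x 1) -> 4 x 4
    (256, 128, 4, 2, 1),  # 4 x 4 -> 8 x 8
    (128, 64, 4, 2, 1),   # 8 x 8 -> 16 x 16
    (64, 3, 4, 2, 1),     # 16 x 16 -> 32 x 32 RGB image
]

size, sizes, total = 1, [], 0
for i, (c_in, c_out, k, s, p) in enumerate(layers):
    size = conv_transpose_out(size, k, s, p)
    sizes.append(size)
    total += conv_params(c_in, c_out, k)
    if i < len(layers) - 1:
        total += 2 * c_out  # BatchNorm2d scale and shift after hidden layers

print(sizes)   # [4, 8, 16, 32]
print(total)   # 1068928
```

A generator of this size uses roughly one million parameters, so the 25 million limit leaves considerable room for wider layers in both networks.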

The number of epochs needed to train the network will vary depending on your choices. As a rule of thumb, we recommend allowing around 20 epochs while experimenting; if the loss doesn't drop sufficiently, restart the training with a more powerful architecture. You don't need to train the network to an extreme if you don't have the time.

#### Data loading

In [None]:
batch_size =   # change that

transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize(mean=mean, std=std),
])
# note - data_path was initialized at the top of the notebook
cifar10_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)
cifar10_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)
loader_train = DataLoader(cifar10_train, batch_size=batch_size)
loader_test = DataLoader(cifar10_test, batch_size=batch_size)

We'll visualize a subset of the test set:

In [None]:
samples, _ = next(iter(loader_test))

samples = samples.cpu()
samples = make_grid(denorm(samples), nrow=8, padding=2, normalize=False,
                        value_range=None, scale_each=False, pad_value=0)
plt.figure(figsize = (15,15))
plt.axis('off')
show(samples)

#### Model Definition
Define hyperparameters and the model

In [None]:
# *CODE FOR PART 2.1 IN THIS CELL*

# Choose the number of epochs, the learning rate
# and the size of the Generator's input noise vector.

num_epochs =
learning_rate =
latent_vector_size =

# Other hyperparams


In [None]:
# *CODE FOR PART 2.1 IN THIS CELL*


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        #######################################################################
        #                       ** START OF YOUR CODE **
        #######################################################################

        #######################################################################
        #                       ** END OF YOUR CODE **
        #######################################################################

    # You can modify the arguments of this function if needed
    def forward(self, z):
        #######################################################################
        #                       ** START OF YOUR CODE **
        #######################################################################

        #######################################################################
        #                       ** END OF YOUR CODE **
        #######################################################################
        return out


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        #######################################################################
        #                       ** START OF YOUR CODE **
        #######################################################################

        #######################################################################
        #                       ** END OF YOUR CODE **
        #######################################################################

    # You can modify the arguments of this function if needed
    def forward(self, x):
        #######################################################################
        #                       ** START OF YOUR CODE **
        #######################################################################

        #######################################################################
        #                       ** END OF YOUR CODE **
        #######################################################################

        return out


<h2> Initialize Model and print number of parameters </h2>

You can use the `weights_init` function to initialize the weights of the Generator and Discriminator networks. Alternatively, implement your own initialization or skip custom initialization altogether. You will not be penalized for not using initialization.

In [None]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [None]:
use_weights_init = True

model_G = Generator().to(device)
if use_weights_init:
    model_G.apply(weights_init)
params_G = sum(p.numel() for p in model_G.parameters() if p.requires_grad)
print("Total number of parameters in Generator is: {}".format(params_G))
print(model_G)
print('\n')

model_D = Discriminator().to(device)
if use_weights_init:
    model_D.apply(weights_init)
params_D = sum(p.numel() for p in model_D.parameters() if p.requires_grad)
print("Total number of parameters in Discriminator is: {}".format(params_D))
print(model_D)
print('\n')

print("Total number of parameters is: {}".format(params_G + params_D))

### Part 2.1b: Training the Model (12 Points)

#### Defining a Loss

In [None]:
# You can modify the arguments of this function if needed
def loss_function(out):
    loss =
    return loss

<h3>Choose and initialize optimizers</h3>

In [None]:
# setup optimizer
# You are free to add a scheduler or change the optimizer if you want. We chose one for you for simplicity.
beta1 = 0.5
optimizerD = torch.optim.Adam(model_D.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = torch.optim.Adam(model_G.parameters(), lr=learning_rate, betas=(beta1, 0.999))


<h3> Define fixed input vectors to monitor training and mode collapse. </h3>

In [None]:
fixed_noise = torch.randn(batch_size, latent_vector_size, 1, 1, device=device)
# Additional input variables should be defined here

#### Training Loop

Complete the training loop below. We've defined some variables to keep track of things during training:
* errD: Loss of Discriminator after being trained on real and fake instances
* errG: Loss of Generator
* D_x: Output of Discriminator for real images
* D_G_z1: Output of Discriminator for fake images (When Generator is not being trained)
* D_G_z2: Output of Discriminator for fake images (When Generator is being trained)
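
One common way to structure a single iteration, shown here as a sketch rather than the required solution, is to update $D$ on a real batch and a detached fake batch, then update $G$ with the non-saturating objective (maximise $\log D(G(z))$, implemented as binary cross-entropy against "real" labels). The tiny linear `toy_G` and `toy_D`, their optimizers and the random "real" batch are hypothetical stand-ins so that the snippet runs on its own; in the notebook they correspond to `model_G`, `model_D`, `optimizerG`, `optimizerD` and a batch from `loader_train`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real networks and optimizers.
latent_dim = 8
toy_G = nn.Sequential(nn.Linear(latent_dim, 16), nn.Tanh())
toy_D = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())
opt_G = torch.optim.Adam(toy_G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_D = torch.optim.Adam(toy_D.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion = nn.BCELoss()

def gan_step(real):
    batch = real.size(0)
    ones = torch.ones(batch, 1)    # label for "real"
    zeros = torch.zeros(batch, 1)  # label for "fake"

    # (1) Update D: real batch plus a fake batch detached from G's graph
    opt_D.zero_grad()
    out_real = toy_D(real)
    fake = toy_G(torch.randn(batch, latent_dim))
    out_fake = toy_D(fake.detach())
    errD = criterion(out_real, ones) + criterion(out_fake, zeros)
    errD.backward()
    opt_D.step()

    # (2) Update G: make the just-updated D label the fakes as real
    opt_G.zero_grad()
    out_fake_G = toy_D(fake)
    errG = criterion(out_fake_G, ones)
    errG.backward()
    opt_G.step()

    return {
        "errD": errD.item(), "errG": errG.item(),
        "D_x": out_real.mean().item(),
        "D_G_z1": out_fake.mean().item(),
        "D_G_z2": out_fake_G.mean().item(),
    }

stats = gan_step(torch.randn(32, 16))
print(stats)
```

`D_G_z1` and `D_G_z2` differ because the discriminator is updated between the two forward passes on the same fake batch.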

In [None]:
train_losses_G = []
train_losses_D = []

# <- You may wish to add logging info here
for epoch in range(num_epochs):
    # <- You may wish to add logging info here
    # Running sums of the per-batch losses, reset once per epoch
    train_loss_D = 0
    train_loss_G = 0
    with tqdm.tqdm(loader_train, unit="batch") as tepoch:
        for i, data in enumerate(tepoch):

            #######################################################################
            #                       ** START OF YOUR CODE **
            #######################################################################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))

            # train with real

            # train with fake

            # (2) Update G network: maximize log(D(G(z)))

            #######################################################################
            #                       ** END OF YOUR CODE **
            #######################################################################
            # Logging
            if i % 50 == 0:
                tepoch.set_description(f"Epoch {epoch}")
                tepoch.set_postfix(D_G_z=f"{D_G_z1:.3f}/{D_G_z2:.3f}", D_x=D_x,
                                  Loss_D=errD.item(), Loss_G=errG.item())

    if epoch == 0:
        save_image(denorm(real_cpu.cpu()).float(), content_path/'CW_GAN/real_samples.png')
    with torch.no_grad():
        fake = model_G(fixed_noise)
        save_image(denorm(fake.cpu()).float(), str(content_path/'CW_GAN/fake_samples_epoch_%03d.png') % epoch)
    train_losses_D.append(train_loss_D / len(loader_train))
    train_losses_G.append(train_loss_G / len(loader_train))

# save  models
# if your discriminator/generator are conditional you'll want to change the inputs here
torch.jit.save(torch.jit.trace(model_G, (fixed_noise)), content_path/'CW_GAN/GAN_G_model.pth')
torch.jit.save(torch.jit.trace(model_D, (fake)), content_path/'CW_GAN/GAN_D_model.pth')
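
Because the TA test cell loads the traced models with `torch.jit.load`, it can be worth checking that a traced model survives a save/load round trip and still produces the same outputs. The sketch below does this for a hypothetical stand-in generator and an in-memory buffer; with your own models you would trace `model_G` with `fixed_noise` and load from the saved `.pth` file instead.

```python
import io
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for model_G: maps (N, 16, 1, 1) noise to (N, 3, 4, 4) images.
toy_G = nn.Sequential(
    nn.ConvTranspose2d(16, 3, kernel_size=4, stride=1, padding=0, bias=False),
    nn.Tanh(),
).eval()
example_noise = torch.randn(2, 16, 1, 1)

# Trace with an example input, save, and reload (a BytesIO buffer replaces the file).
traced = torch.jit.trace(toy_G, example_noise)
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

with torch.no_grad():
    same = torch.allclose(toy_G(example_noise), reloaded(example_noise))
print("round trip preserves outputs:", same)
```

Tracing records the operations executed for the example input, so models with data-dependent control flow are the ones most likely to behave differently after reloading.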

## Part 2.1c: Results (10 Points)
This part is fairly open-ended, but not worth too much so do not go crazy. The table below shows examples of what are considered good samples. Level 3 and above will get you 10/10 points, level 2 will roughly get you 5/10 points and level 1 and below will get you 0/10 points.

<table><tr>
<td>
  <p align="center">
    <img alt="Routing" src="https://drive.google.com/uc?id=18aWqRAnAVTRDY52y1yHSCdqSxUFRKOS9" width="30%">
    <br>
    <em style="color: grey">Level 1</em>
  </p>
</td>
<td>
  <p align="center">
    <img alt="Routing" src="https://drive.google.com/uc?id=1ymO2-jGAvWeUR2kaj_LxQcGYF1RWNRnw" width="30%">
    <br>
    <em style="color: grey">Level 2</em>
  </p>
</td>
<td>
  <p align="center">
    <img alt="Routing" src="https://drive.google.com/uc?id=13SW62ekW32NMYtfcdm_dCJJ3ZMOZEZAJ" width="30%">
    <br>
    <em style="color: grey">Level 3</em>
  </p>
</td>
</tr></table>

### Generator samples

In [None]:
input_noise = torch.randn(100, latent_vector_size, 1, 1, device=device)
with torch.no_grad():
    # visualize the generated images
    generated = model_G(input_noise).cpu()
    generated = make_grid(denorm(generated)[:100], nrow=10, padding=2, normalize=False,
                        value_range=None, scale_each=False, pad_value=0)
    plt.figure(figsize=(15,15))
    save_image(generated, content_path/'CW_GAN/Teaching_final.png')
    show(generated) # note these are now class conditional images columns rep classes 1-10

it = iter(loader_test)
sample_inputs, _ = next(it)
fixed_input = sample_inputs[0:64, :, :, :]
# visualize the original images of the first batch of the test set for comparison
img = make_grid(denorm(fixed_input), nrow=8, padding=2, normalize=False,
                value_range=None, scale_each=False, pad_value=0)
plt.figure(figsize=(15,15))
show(img)

## Part 2.1d: Engineering Choices (10 Points)

Discuss the process you took to arrive at your final architecture. This should include:

* Which empirically useful methods did you utilize?
* What didn't work, what worked and what mattered most?
* Are there any tricks you came across in the literature etc. which you suspect would be helpful here?

**Your Answer**

## Part 2.2: Understanding GAN Training (5 points)


### Loss Curves
**Your task:**


Plot the loss curves for the discriminator $D$ and the generator $G$ as the training progresses and explain whether the produced curves are theoretically sensible and why this is (or is not) the case (x-axis: epochs, y-axis: loss).

Make sure that the version of the notebook you deliver includes these results.

In [None]:
# *ANSWER FOR PART 2.2 IN THIS CELL*

### Discussion

Do your loss curves look sensible? What would you expect to see and why?
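
As a reference point for this discussion: if the losses are binary cross-entropy and the discriminator loss is the sum of its real and fake terms (an assumption about your implementation), then at the theoretical equilibrium the discriminator outputs 0.5 everywhere, which pins both losses to constants. The helper `bce` below is illustrative.

```python
import math

def bce(prob, target):
    """Binary cross-entropy for a single predicted probability."""
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))

d_out = 0.5  # discriminator output at the ideal equilibrium, for real and fake alike

loss_D = bce(d_out, 1) + bce(d_out, 0)  # real term + fake term
loss_G = bce(d_out, 1)                  # non-saturating generator loss

print(f"Loss_D = {loss_D:.4f}")  # 1.3863 (= 2 ln 2)
print(f"Loss_G = {loss_G:.4f}")  # 0.6931 (= ln 2)
```

In practice GAN losses tend to fluctuate around such values rather than settle on them, so these numbers are a point of comparison and not a convergence criterion.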

**YOUR ANSWER**

## Part 2.3: Understanding Mode Collapse (5 points)
**Your task:**

Describe what causes the phenomenon of Mode Collapse and how it may manifest in the samples from a GAN.

Based on the images created by your generator using the `fixed_noise` vector during training, did you notice any mode collapse? What might this behaviour be attributed to, and what did you try in order to eliminate or reduce it?
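
One rough way to complement visual inspection is to measure how different the generated images are from each other: if the generator has collapsed onto a few modes, the average pairwise distance within a batch shrinks. This is an informal heuristic rather than a standard metric, and `mean_pairwise_distance` and the toy arrays below are illustrative; in the notebook the input would be a batch produced from `fixed_noise`.

```python
import numpy as np

def mean_pairwise_distance(samples):
    """Mean Euclidean distance between all distinct pairs of flattened samples."""
    flat = samples.reshape(len(samples), -1).astype(np.float64)
    sq_norms = (flat ** 2).sum(axis=1)
    # Squared distances via the Gram matrix; clip tiny negatives from rounding.
    sq_dists = np.clip(sq_norms[:, None] + sq_norms[None, :] - 2.0 * flat @ flat.T, 0.0, None)
    off_diagonal = ~np.eye(len(flat), dtype=bool)
    return float(np.sqrt(sq_dists[off_diagonal]).mean())

rng = np.random.default_rng(0)
diverse = rng.normal(size=(16, 3, 8, 8))                          # varied "images"
collapsed = diverse[:1] + 0.01 * rng.normal(size=(16, 3, 8, 8))   # near-copies of one image

print(f"diverse:   {mean_pairwise_distance(diverse):.2f}")
print(f"collapsed: {mean_pairwise_distance(collapsed):.2f}")
```

Tracking this number per epoch for the `fixed_noise` samples, alongside the saved image grids, can help pinpoint when diversity starts to drop.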

In [None]:
# Any additional code

### Discussion


**YOUR ANSWER**



# TA Test Cell
TAs will run this cell to ensure that your results are reproducible, and that your models have been defined suitably.

<font color="orange"> <b> Please provide the input and output transformations required to make your VAE and GANs work. If your GAN generator requires more than just noise as input, also specify this below (there are two marked cells for you to inspect) </b></font>


In [None]:
# If you want to run these tests yourself, change directory:
# %cd '.../dl_cw2/'
ta_data_path = "../data" # You can change this to = data_path when testing

In [None]:
!pip install -q torch torchvision

In [None]:
# Do not remove anything here
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, sampler
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import torch.nn.functional as F
import matplotlib.pyplot as plt

show = lambda img: plt.imshow(np.transpose(img.cpu().numpy(), (1,2,0)))

device = torch.device("cuda"  if torch.cuda.is_available() else "cpu")

# Do not change this cell!
torch.backends.cudnn.deterministic = True
torch.manual_seed(0)

In [None]:
############# CHANGE THESE (COPY AND PASTE FROM YOUR OWN CODE) #############
vae_transform = transforms.Compose([
    transforms.ToTensor(),
])

def vae_denorm(x):
    return x

def gan_denorm(x):
    return x

gan_latent_size =

# If your generator requires something other than noise as input, please specify
# two cells down from here

In [None]:
# Load VAE Dataset
test_dat = datasets.MNIST(ta_data_path, train=False, transform=vae_transform,
                          download=True)
vae_loader_test = DataLoader(test_dat, batch_size=32, shuffle=False)

In [None]:
############# MODIFY IF NEEDED #############
vae_input, _ = next(iter(vae_loader_test))

# If your generator is conditional, then please modify this input suitably
input_noise = torch.randn(100, gan_latent_size, 1, 1, device=device)
gan_input = [input_noise] # In case you want to provide a tuple, we wrap ours

In [None]:
# VAE Tests
# TAs will change these paths as you will have provided the model files manually
"""To TAs: you should have created a folder named after the student uid,
   with the .ipynb and models in its root. The path is then './VAE_model.pth' etc.
"""
vae = torch.jit.load('./CW_VAE/VAE_model.pth')
vae.eval()

# Check if VAE is convolutional
def recurse_cnn_check(parent, flag):
    if flag:
        return flag
    children = list(parent.children())
    if len(children) > 0:
        for child in children:
            flag = flag or recurse_cnn_check(child, flag)
    else:
        params = parent._parameters
        if 'weight' in params.keys():
            flag = params['weight'].ndim == 4
    return flag

has_cnn = recurse_cnn_check(vae, False)
print("Used CNN" if has_cnn else "Didn't Use CNN")

vae_in = make_grid(vae_denorm(vae_input), nrow=8, padding=2, normalize=False,
                value_range=None, scale_each=False, pad_value=0)
plt.figure()
plt.axis('off')
show(vae_in)

vae_test = vae(vae_input.to(device))[0].detach()
vae_reco = make_grid(vae_denorm(vae_test), nrow=8, padding=2, normalize=False,
                value_range=None, scale_each=False, pad_value=0)
plt.figure()
plt.axis('off')
show(vae_reco)

In [None]:
# GAN Tests
model_G = torch.jit.load('./CW_GAN/GAN_G_model.pth')
model_D = torch.jit.load('./CW_GAN/GAN_D_model.pth')
[model.eval() for model in (model_G, model_D)]

# Check that GAN doesn't have too many parameters
num_param = sum(p.numel() for p in [*model_G.parameters(),*model_D.parameters()])

print(f"Number of Parameters is {num_param} which is", "ok" if num_param<25E+6 else "not ok")

# visualize the generated images
generated = model_G(*gan_input).cpu()
generated = make_grid(gan_denorm(generated)[:100].detach(), nrow=10, padding=2, normalize=False,
                    value_range=None, scale_each=False, pad_value=0)
plt.figure(figsize=(15,15))
plt.axis('off')
show(generated)