## A minimal example showing the code used train the VAE+ model in the deep learning project:
Project 3: *Generative modelling for phenotypic profiling using Variational Autoencoders*

In [1]:
# Load necessary packages
import os
from plotting import make_vae_plots
import re
import random
import time
from collections import defaultdict
from typing import *
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
#import seaborn as sns
import pandas as pd
import math 
import torch
import torchvision.utils as vutils
from torch import nn, Tensor, sigmoid
from torch.nn.functional import softplus
from torch.distributions import Distribution, Bernoulli
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from functools import reduce
from model_VAE_plus import PrintSize, Flatten, UnFlatten
from model_VAE_plus import ReparameterizedDiagonalGaussian
from model_VAE_plus import ReparameterizedDiagonalGaussianWithSigmoid
from model_VAE_plus import VariationalAutoencoder, VariationalInference, Discriminator

In [2]:
print(torch.__version__)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f">> Using device: {device}")

if torch.cuda.is_available():
    print(torch.cuda.is_available())
    print(torch.cuda.current_device())
    print(torch.cuda.device(0))
    print(torch.cuda.device_count())
    print(torch.cuda.get_device_name(0))

print("PyTorch Version {}" .format(torch.__version__))
print("Cuda Version {}" .format(torch.version.cuda))
print("CUDNN Version {}" .format(torch.backends.cudnn.version()))

1.12.1
>> Using device: cpu
PyTorch Version 1.12.1
Cuda Version None
CUDNN Version None


In [3]:
name = 'vae_plus'
discrim_name = "discriminator"
result_dir = 'results_plus_VAE/'
if not(os.path.exists(result_dir)):
    os.mkdir(result_dir)

# Set random seed for reproducibility
#manualSeed = 1570
manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
f = open(result_dir + 'random_seed.txt', "w")
f.write(str(manualSeed))
f.close()
random.seed(manualSeed)
torch.manual_seed(manualSeed)

Random Seed:  9747


<torch._C.Generator at 0x112b3e490>

### Evalutation and plotting functions

In [4]:
### Evalutation and plotting functions
def evaluation(test_loader, name=None, model_best=None, epoch=None,
               device='cpu'):
    # EVALUATION
    if model_best is None:
        # load best performing model
        model_best = torch.load(name + '.model')
        model_best = model_best.to(device)

    model_best.eval()
    loss = 0.
    N = 0.
    for indx_batch, (test_batch, test_target) in enumerate(test_loader):
        test_batch = test_batch.to(device)
        
        #loss_t = model_best.forward(test_batch, reduction='sum')
        loss_t, xhat, diagnostics, outputs = vi(model_best, test_batch)
        loss = loss + loss_t.item()
        N = N + test_batch.shape[0]
    loss = loss / N

    if epoch is None:
        print(f'FINAL LOSS: nll={loss}')
    else:
        print(f'Epoch: {epoch}, val nll={loss}')

    return loss

def samples_real(name, test_loader):
    # REAL-------
    num_x = 3
    num_y = 4
    x = next(iter(test_loader))[0].detach().numpy()

    fig, ax = plt.subplots(num_x, num_y)
    for i, ax in enumerate(ax.flatten()):
        plottable_image = np.transpose(x[i].reshape((3, 68, 68)), (1, 2, 0))
        ax.imshow(plottable_image)
        ax.axis('off')

    plt.savefig(name + '_real_images.png', bbox_inches='tight')
    plt.close()
    

def samples_generated(name, data_loader, extra_name=''):
    x = next(iter(data_loader))[0].detach().numpy()

    # GENERATIONS-------
    model_best = torch.load(name + '.model')
    model_best.cpu()
    model_best.eval()

    num_x = 4
    num_y = 4
    
    px = model_best.sample_from_prior(batch_size=num_x * num_y)['px']
    x_samples = px.sample()

    fig, ax = plt.subplots(num_x, num_y)
    for i, ax in enumerate(ax.flatten()):
        plottable_image = np.reshape(x_samples[i], (3, 68, 68)).permute(1, 2, 0)
        ax.imshow(plottable_image)
        ax.axis('off')

    plt.savefig(name + '_generated_images' + extra_name + '.png', bbox_inches='tight')
    plt.close()
    

def plot_curve(name, nll_val, x_label="epochs", y_label="nll"):
    plt.plot(np.arange(len(nll_val)), nll_val, linewidth='3')
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.savefig(name + '.png', bbox_inches='tight')
    plt.close()

### Parameters

In [5]:
# Initialize parameters
# Number of workers for dataloader
workers = 1

# Batch size during training
batch_size = 12

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 68

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector
latent_features = 100

# Size of feature maps in VAE encoder and decoder
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 10

# Max patience for early stopping
max_patience = 100

# Learning rate for optimizers
lr_VAE = 1e-4
lr_D = 1e-5

# SDG momentum
momentum_weight = 0.9

# Regularization
reg_weight = 1e-4

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Beta2 hyperparam for VAE loss
beta2 = 1.0

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

# The value the DMSO category is downsampled to
downsample_value = 16000

# Amount of data used for training, validation and testing
data_prct = 1
train_prct = 0.95

# Slope for delayed, linear, saturated schedulling of representation loss
slope  = 250
scale_repr = 100

# Smoothing of the real labels
smooth = False

### Dataset
The dataset contain 68x68 images of single cells treated with different compounds. For each of the utilized compounds there is an associated mechanism of action (MOA), which describes how the compound it affecting the cell. There are 12 different MOA classes and a control class called DSMO. Here only a small percentage of the full dataset (1000 images) will be utilized.

In [6]:
# Get current working directory
DIR = os.getcwd()

# Load metadata table
start_time = time.time()
metadata = pd.read_csv("../data/metadata_mini.csv")
print("pd.read_csv wiht pyarrow took %s seconds" % (time.time() - start_time))

pd.read_csv wiht pyarrow took 0.2244250774383545 seconds


In [7]:
metadata.groupby("moa").size().reset_index(name='counts').sort_values(by="counts", ascending=False)

Unnamed: 0,moa,counts
10,Microtubule stabilizers,387
4,DNA damage,81
1,Aurora kinase inhibitors,79
3,DMSO,65
6,Eg5 inhibitors,60
8,Kinase inhibitors,60
7,Epithelial,59
9,Microtubule destabilizers,54
12,Protein synthesis,44
0,Actin disruptors,38


In [8]:
# Map from class name to class index
classes = {index: name for name, index in enumerate(metadata["moa"].unique())}
classes_inv = {v: k for k, v in classes.items()}

In [9]:
# Map from class name to class index
classes = {index: name for name, index in enumerate(metadata["moa"].unique())}
classes_inv = {v: k for k, v in classes.items()}
classes

{'Kinase inhibitors': 0,
 'Eg5 inhibitors': 1,
 'Microtubule stabilizers': 2,
 'DMSO': 3,
 'Microtubule destabilizers': 4,
 'Protein synthesis': 5,
 'Actin disruptors': 6,
 'Epithelial': 7,
 'Aurora kinase inhibitors': 8,
 'DNA damage': 9,
 'Protein degradation': 10,
 'Cholesterol-lowering': 11,
 'DNA replication': 12}

In [10]:
metadata

Unnamed: 0.2,Unnamed: 0.1,level_0,index,Unnamed: 0,Multi_Cell_Image_Id,Multi_Cell_Image_Name,Single_Cell_Image_Id,Single_Cell_Image_Name,TableNumber,ImageNumber,...,Image_FileName_Tubulin,Image_PathName_Tubulin,Image_FileName_Actin,Image_PathName_Actin,Image_Metadata_Plate_DAPI,Image_Metadata_Well_DAPI,Replicate,Image_Metadata_Compound,Image_Metadata_Concentration,moa
0,0,191846,406091,406091,3080,Week7_7__F04_s2_w1C7A8F9CA-F54B-40DE-9EEE-7E71...,125,Week7_7__F04_s2_w1C7A8F9CA-F54B-40DE-9EEE-7E71...,7,3050,...,Week7_7__F04_s2_w2CAF44A0C-1EDB-41CE-8480-91F8...,Week7_34661,Week7_7__F04_s2_w4F9D05EDC-B012-4F3F-B558-5C56...,Week7_34661,Week7_34661,F04,2,PD-169316,10.00,Kinase inhibitors
1,1,2787,5545,5545,525,Week10_200907_C09_s3_w19437640F-29D0-42B8-9C85...,83,Week10_200907_C09_s3_w19437640F-29D0-42B8-9C85...,0,71,...,Week10_200907_C09_s3_w29B6DE609-DB82-47FD-A103...,Week10_40111,Week10_200907_C09_s3_w44DE7F152-E698-48C6-87F2...,Week10_40111,Week10_40111,C09,1,AZ138,0.03,Eg5 inhibitors
2,2,159666,309483,309483,2387,Week5_130707_E02_s4_w1CD0139A5-C58F-4E4E-BE44-...,36,Week5_130707_E02_s4_w1CD0139A5-C58F-4E4E-BE44-...,5,3004,...,Week5_130707_E02_s4_w208EA0654-74FC-40AB-800F-...,Week5_29301,Week5_130707_E02_s4_w495146015-FBA6-406D-BD4F-...,Week5_29301,Week5_29301,E02,1,taxol,0.30,Microtubule stabilizers
3,3,153731,292273,292273,2348,Week5_130707_D11_s2_w16C0AA106-F223-4CA3-976B-...,48,Week5_130707_D11_s2_w16C0AA106-F223-4CA3-976B-...,5,358,...,Week5_130707_D11_s2_w21FCEC64A-77E4-4257-8E5D-...,Week5_28921,Week5_130707_D11_s2_w46E781881-AAA6-4F65-9D1F-...,Week5_28921,Week5_28921,D11,2,taxol,0.30,Microtubule stabilizers
4,4,169436,343661,343661,2634,Week6_200607_D02_s4_w1021438BB-B36F-48B2-98DE-...,26,Week6_200607_D02_s4_w1021438BB-B36F-48B2-98DE-...,6,564,...,Week6_200607_D02_s4_w2E5B36F5D-7732-4B5B-86B1-...,Week6_31681,Week6_200607_D02_s4_w47065B016-72C5-4F2B-BDE2-...,Week6_31681,Week6_31681,D02,3,DMSO,0.00,DMSO
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,995,97113,183560,183560,2092,Week3_290607_F05_s2_w16BE8E365-50DA-470B-8D9B-...,150,Week3_290607_F05_s2_w16BE8E365-50DA-470B-8D9B-...,3,654,...,Week3_290607_F05_s2_w2185B8544-3831-4245-A00A-...,Week3_25461,Week3_290607_F05_s2_w47FBB3922-9DBE-450F-B484-...,Week3_25461,Week3_25461,F05,3,etoposide,3.00,DNA damage
996,996,171410,353345,353345,2593,Week6_200607_C11_s1_w10463464A-3744-4A1F-9083-...,38,Week6_200607_C11_s1_w10463464A-3744-4A1F-9083-...,6,2957,...,Week6_200607_C11_s1_w21D4F2DF2-F959-4696-B1DC-...,Week6_32121,Week6_200607_C11_s1_w4EBBCB968-9FAE-4424-97C4-...,Week6_32121,Week6_32121,C11,2,taxol,0.30,Microtubule stabilizers
997,997,27835,54679,54679,745,Week1_150607_B11_s1_w129C9B1A2-75C6-44AE-9E1D-...,8,Week1_150607_B11_s1_w129C9B1A2-75C6-44AE-9E1D-...,1,277,...,Week1_150607_B11_s1_w284256C54-8558-4EDF-8C81-...,Week1_22141,Week1_150607_B11_s1_w49853504B-D04D-45BE-945A-...,Week1_22141,Week1_22141,B11,2,taxol,0.30,Microtubule stabilizers
998,998,40876,75823,75823,736,Week1_150607_B06_s2_w12134F829-2C5E-4ED7-BA2B-...,149,Week1_150607_B06_s2_w12134F829-2C5E-4ED7-BA2B-...,1,2898,...,Week1_150607_B06_s2_w23FEAC940-8D8C-4735-A14E-...,Week1_22361,Week1_150607_B06_s2_w4899D2801-D3BC-45B8-8D35-...,Week1_22361,Week1_22361,B06,1,latrunculin B,1.00,Actin disruptors


In [11]:
# Dataloader class. Using the metadata table, images are sampled and 
# passed to VAE duing training
class SingleCellDataset(torch.utils.data.Dataset):
    def __init__(self, annotation_file, images_folder, class_map, 
                 mode='train', transform = None):
        self.df = annotation_file
        self.images_folder = images_folder
        self.transform = transform
        self.class2index = class_map
            

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        filename = self.df.loc[index, "Single_Cell_Image_Name"]
        label = self.class2index[self.df.loc[index, "moa"]]
        #subfolder = re.search("(.*)_", filename).group(1)
        image = np.load(os.path.join(self.images_folder, filename))
        if self.transform is not None:
            image = self.transform(image.astype(np.float32))
        return image, label

In [12]:
# The loaders perform the actual work
#images_folder = '/zhome/70/5/14854/nobackup/deeplearningf22/bbbc021/singlecell/singh_cp_pipeline_singlecell_images/'
images_folder = "../data/"
train_transforms = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Lambda(lambda x: x/x.max()),
    ]
)

train_set = SingleCellDataset(images_folder=images_folder, 
                              annotation_file=metadata, 
                              transform=train_transforms,
                              class_map=classes)

# Define the size of the train, validation and test datasets
data_amount = int(len(metadata) * data_prct)
train_size = int(train_prct * data_amount)
val_size = (data_amount - train_size) // 2
test_size = (data_amount - train_size) // 2

indicies = torch.randperm(len(metadata))
train_indices = indicies[:train_size]
val_indicies = indicies[train_size:train_size+val_size]
test_indicies = indicies[train_size+val_size:train_size+val_size+test_size]

training_set = torch.utils.data.Subset(train_set, train_indices.tolist())
val_set = torch.utils.data.Subset(train_set, val_indicies.tolist())
testing_set = torch.utils.data.Subset(train_set, test_indicies.tolist())

training_loader = torch.utils.data.DataLoader(training_set, batch_size=batch_size, 
                                             shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(testing_set, batch_size=batch_size, shuffle=True)

print(len(training_loader.dataset))
print(len(val_loader.dataset))
print(len(test_loader.dataset))

# Load a batch of images into memory
images, labels = next(iter(training_loader))

950
25
25


In [13]:
vae = VariationalAutoencoder(latent_features)
mse_loss = nn.MSELoss(reduction='none')
print(vae)

VariationalAutoencoder(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(6, 6), stride=(2, 2), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): PrintSize()
    (4): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): PrintSize()
    (8): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): PrintSize()
    (12): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): ReLU(inplace=True)
    (15): PrintSize()
    (16): Conv2d(512, 200, kernel_size=(4, 4), stride=

In [14]:
def reduce(x:Tensor) -> Tensor:
    """for each datapoint: sum over all dimensions"""
    return x.view(x.size(0), -1).sum(dim=1)

In [15]:
vi_test = VariationalInference(beta=1)
loss, xhat, diagnostics, outputs = vi_test(vae, images)
print(f"{'xhat':6} | shape: {list(xhat.shape)}")
print(f"{'loss':6} | mean = {loss:10.3f}, shape: {list(loss.shape)}")
for key, tensor in diagnostics.items():
    print(f"{key:6} | mean = {tensor.mean():10.3f}, shape: {list(tensor.shape)}")

Size: torch.Size([12, 64, 32, 32])
Size: torch.Size([12, 128, 16, 16])
Size: torch.Size([12, 256, 8, 8])
Size: torch.Size([12, 512, 4, 4])
Size: torch.Size([12, 200, 1, 1])
Size: torch.Size([12, 200])
Size: torch.Size([12, 512, 4, 4])
Size: torch.Size([12, 256, 8, 8])
Size: torch.Size([12, 128, 16, 16])
Size: torch.Size([12, 64, 32, 32])
Size: torch.Size([12, 6, 68, 68])
xhat   | shape: [12, 3, 68, 68]
loss   | mean =   2498.102, shape: []
elbo   | mean =  -2498.102, shape: [12]
log_px | mean =  -2492.245, shape: [12]
kl     | mean =      5.856, shape: [12]


In [16]:
discrim_test = Discriminator()
print(discrim_test)

Discriminator(
  (activation): LeakyReLU(negative_slope=0.2, inplace=True)
  (sigmoid): Sigmoid()
  (batchnorm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (batchnorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (batchnorm3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_1): Conv2d(3, 64, kernel_size=(6, 6), stride=(2, 2), bias=False)
  (conv_2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv_3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv_4): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv_out): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (max_pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)


In [17]:
output, intermediate_rep = discrim_test(images)
print(f"{'output':6} | shape: {list(output.shape)}")
for i, tensor in enumerate(intermediate_rep):
    print(f"x_img_{i+1} | shape: {list(tensor.shape)}")

output | shape: [12, 1, 1, 1]
x_img_1 | shape: [12, 64, 32, 32]
x_img_2 | shape: [12, 128, 16, 16]
x_img_3 | shape: [12, 256, 8, 8]
x_img_4 | shape: [12, 512, 4, 4]


In [18]:
# VAE
vae = VariationalAutoencoder(latent_features)

# Evaluator: Variational Inference
vi = VariationalInference(beta=beta2)

# Discriminator
discrim = Discriminator()

In [19]:
# Setup Adam optimizers for both VAE and D
discriminator_optim = torch.optim.SGD(discrim.parameters(), lr=lr_D, 
                                      weight_decay=reg_weight, momentum=momentum_weight)
vae_optimizer = torch.optim.Adam(vae.parameters(), lr=lr_VAE, weight_decay=reg_weight)

In [20]:
# define dictionary to store the training curves
training_data = defaultdict(list)
validation_data = defaultdict(list)

def schedule_weight(delay):
    """ Defines a delayed, linear, saturated schedulling function.
    """
    step_norm = max(0.0, step - delay)
    w = step_norm / slope
    w = max(0.0, min(1.0, w)) #-- Bounded weight
    return w

epoch = 0
best_nll = 1000000.
patience = 0
nll_val = []
loss_repr_func = nn.MSELoss()
bce_loss = nn.BCELoss()

# Lists to keep track of progress
train_img_list = []
val_img_list = []
D_losses = []
D_losses_real = []
D_losses_fake = []
VAE_losses = []
elbo_losses = []
repr_losses = []
representation_loss = []
discriminator_real_loss = []
discriminator_fake_loss = []
discriminator_avg_loss = []
discriminator_sum_loss = []
total_loss_list = []
real_label = 1.0
fake_label = 0.0

# move the model to the device
vae = vae.to(device)
discrim = discrim.to(device)

step = 0

train_fixed, train_label_fixed = next(iter(val_loader))
train_fixed_b1, train_fixed_b2 = torch.split(train_fixed, split_size_or_sections=batch_size//2)
train_fixed_b1 = train_fixed_b1.reshape(batch_size//2, 3, 68, 68)
train_fixed_b1 = train_fixed_b1.to(device)

val_fixed, val_label_fixed = next(iter(val_loader))
val_fixed_b1, val_fixed_b2 = torch.split(val_fixed, split_size_or_sections=batch_size//2)
val_fixed_b1 = val_fixed_b1.reshape(batch_size//2, 3, 68, 68)
val_fixed_b1 = val_fixed_b1.to(device)

### Training procedure

In [21]:
while epoch < num_epochs:
    
    epoch += 1
    training_epoch_data = defaultdict(list)
    batch_discrim_avg_loss = []
    batch_discrim_sum_loss = []
    batch_discrim_real_loss = []
    batch_discrim_fake_loss = []
    batch_total_loss = []
    batch_repr_loss = []
    vae.train()
    discrim.train()

    # Go through each batch in the training dataset using the loader
    # Note that y is not necessarily known as it is here
    for i, (x, y) in enumerate(training_loader):
            
            step += 1

            ############################
            # (Step 0) Prepare data:
            ###########################
            #tmp_batch_size = x.size(0)
            b1, b2 = torch.split(x, split_size_or_sections=batch_size//2)
            batch_size_half = b1.size(0)
            
            b1 = b1.to(device)
            b2 = b2.to(device)

            # Reshape to in order to be used as input to conv layers
            #b1_reshaped = b1.reshape(batch_size_half, 3, 68, 68)
            #b2_reshaped = b2.reshape(batch_size_half, 3, 68, 68)

            
            ############################
            # (Step 1) Update Discriminator network:
            ###########################
            ## Train with all-real batch
            discriminator_optim.zero_grad()
            
            # Format labels
            label = torch.full((batch_size_half,), real_label,
                                dtype=torch.float, device=device) 
            if smooth:
                label = label*0.9
            
            # Forward pass real batch through D
            #b2_reshaped = b2.reshape(batch_size_half, 3, 68, 68)
            output_real_b2, inter_repr_real_b2 = discrim(b2)
            output_real_b2 = output_real_b2.view(-1)
            
            # Calculate loss on all-real batch
            errD_real = bce_loss(output_real_b2, label)
            
            # Calculate gradients for D in backward pass
            errD_real.backward()

            ## Train with all-fake batch
            ### Pass real images from batch 1 (b1) through VAE
            loss_elbo, xhat, diagnostics, outputs = vi(vae, b1)
            #xhat_reshaped = xhat.reshape(batch_size_half, 3, 68, 68).to(device)

            # Get the output from the discriminator
            output_hat, inter_repr_fake = discrim(xhat.detach())
            output_hat = output_hat.view(-1)

            # Calculate D's loss on the all-fake batch
            label.fill_(fake_label)
            errD_fake = bce_loss(output_hat, label)

            # Calculate the gradients for this batch, accumulated (summed) with previous gradients
            errD_fake.backward()

            # Compute error of D as sum over the fake and the real batches
            errD = errD_real + errD_fake
            #print((errD_real/(errD_real + errD_fake)).item())
            #print(f"errD: {errD.item()}")

            # Update weights
            discriminator_optim.step()
            
            ############################
            # (Step 2) Update VAE network:
            ###########################

            vae_optimizer.zero_grad()

            # Pass real images from batch 1 (b1) through VAE:
            #loss_elbo, xhat, diagnostics, outputs = vi(vae, b1)
            #xhat_reshaped = xhat.reshape(batch_size_half, 3, 68, 68).to(device)
            
            # Pass x-hat (reconstructions) through Discriminator
            # without calculating gradients
            output_hat, inter_repr_fake = discrim(xhat)
            
            # Pass real images from batch 1 (b1) through Discriminator
            # without calculating gradients
            output_real_b1, inter_repr_real_b1 = discrim(b1)
            
            # Calculate the loss between the representations
            # of the real images and the reconstructions 
            loss_repr = 0
            loss_repr_list = []

            delays = [slope * (k+1) for k in range(len(inter_repr_fake))]

            for j, (repr_fake, repr_real) in enumerate(zip(inter_repr_fake, inter_repr_real_b1)):
                loss_batch = loss_repr_func(repr_fake, repr_real)
                loss_repr_list.append(loss_batch)

                loss_weight = schedule_weight(delays[j])
                loss_repr += loss_weight * loss_batch #-- Schedule-based weighted average
            
            loss_total = loss_elbo + scale_repr * loss_repr
            
            # Backpropagate the gradients for the VAE
            loss_total.backward()
            vae_optimizer.step()
                        
            # gather data for the current bach
            for k, v in diagnostics.items():
                training_epoch_data[k] += [v.mean().item()]
            
            batch_discrim_real_loss.append(errD_real.item())
            batch_discrim_fake_loss.append(errD_fake.item())
            batch_discrim_avg_loss.append((errD_real/(errD_real + errD_fake)).item())
            batch_discrim_sum_loss.append(errD.item())   
            batch_repr_loss.append(loss_repr.item())
            batch_total_loss.append(loss_total.item()) 

            # Output training stats
            if step % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_VAE: %.4f'
                  % (epoch, num_epochs, i, len(training_loader),
                     errD.item(), loss_total.item()))
                #print(f"Training loss: {loss_total}")

            # Save Losses for plotting later
            D_losses.append(errD.item())
            D_losses_real.append(errD_real.item())
            D_losses_fake.append(errD_fake.item())
            VAE_losses.append(loss_total.item())
            elbo_losses.append(loss_elbo.item())
            repr_losses.append(loss_repr.item())

            # Check how the generator is doing by saving G's output on fixed_noise
            if (step % 100 == 0) or ((epoch == num_epochs-1) and (i == len(training_loader)-1)):
                with torch.no_grad():
                    loss_elbo, xhat, diagnostics, outputs = vi(vae, train_fixed_b1)
                    xhat = xhat.detach().cpu()
                train_img_list.append(vutils.make_grid(xhat, padding=2, normalize=True))

                with torch.no_grad():
                    loss_elbo, xhat, diagnostics, outputs = vi(vae, val_fixed_b1)
                    xhat = xhat.detach().cpu()
                val_img_list.append(vutils.make_grid(xhat, padding=2, normalize=True))


    # gather data for the full epoch
    for k, v in training_epoch_data.items():
        training_data[k] += [np.mean(training_epoch_data[k])]
    
    discriminator_real_loss.append(np.mean(batch_discrim_real_loss))
    discriminator_fake_loss.append(np.mean(batch_discrim_fake_loss))
    discriminator_avg_loss.append(np.mean(batch_discrim_avg_loss))
    discriminator_sum_loss.append(np.mean(batch_discrim_sum_loss))
    representation_loss.append(np.mean(batch_repr_loss))
    total_loss_list.append(np.mean(batch_total_loss))
        
    # Evaluate on a single batch, do not propagate gradients
    with torch.no_grad():
        vae.eval()
        
        # Just load a single batch from the validation loader
        x, y = next(iter(val_loader))
        x = x.to(device)

        b1, b2 = torch.split(x, split_size_or_sections=batch_size//2)
        batch_size_half = b1.size(0)
        b1 = b1.to(device)
        b2 = b2.to(device)

        # perform a forward pass through the model and compute the ELBO
        loss_val_elbo, xhat, diagnostics, outputs = vi(vae, b1)

        # Pass x-hat (reconstructions) through Discriminator
        # without calculating gradients
        output_hat, inter_repr_fake = discrim(xhat)
        
        # Pass real images from batch 1 (b1) through Discriminator
        # without calculating gradients
        output_real_b1, inter_repr_real_b1 = discrim(b1)

        # Calculate the loss between the representations
        # of the real images and the reconstructions 
        loss_repr = 0

        delays = [slope * (k+1) for k in range(len(inter_repr_fake))]

        for i, (repr_fake, repr_real) in enumerate(zip(inter_repr_fake, inter_repr_real_b1)):
            loss_batch = loss_repr_func(repr_fake, repr_real)

            loss_weight = schedule_weight(delays[i])
            loss_repr += loss_weight * loss_batch #-- Schedule-based weighted average
        
        loss_total_val = loss_val_elbo + scale_repr * loss_repr
        
        print(f"Loss_val: {loss_total_val}")
        nll_val.append(loss_total_val.cpu())  # save for plotting
        
        # gather data for the validation step
        for k, v in diagnostics.items():
            validation_data[k] += [v.mean().item()]
    
    # Reproduce the figure from the begining of the notebook, plot the training curves and show latent samples
    # make_vae_plots(vae, x, y, outputs, training_data, validation_data)

    if epoch == 1:
            print('saved!')
            torch.save(vae, result_dir + name + '.model')
            torch.save(discrim, result_dir + discrim_name + '.model')
            best_nll = loss_total_val
    else:
        samples_generated(result_dir + name, val_loader, extra_name="_epoch_" + str(epoch))
        if loss_total_val < best_nll:
            print('saved!')
            torch.save(vae, result_dir + name + '.model')
            torch.save(discrim, result_dir + discrim_name + '.model')
            best_nll = loss_total_val
            patience = 0
        else:
            patience = patience + 1
        
    if patience > max_patience:
        print("Max patience reached! Performing early stopping!")
        break

Size: torch.Size([6, 64, 32, 32])
Size: torch.Size([6, 128, 16, 16])
Size: torch.Size([6, 256, 8, 8])
Size: torch.Size([6, 512, 4, 4])
Size: torch.Size([6, 200, 1, 1])
Size: torch.Size([6, 200])
Size: torch.Size([6, 512, 4, 4])
Size: torch.Size([6, 256, 8, 8])
Size: torch.Size([6, 128, 16, 16])
Size: torch.Size([6, 64, 32, 32])
Size: torch.Size([6, 6, 68, 68])
[1/10][49/79]	Loss_D: 0.9076	Loss_VAE: 458.0801
Loss_val: 480.1374206542969
saved!
[2/10][20/79]	Loss_D: 0.7131	Loss_VAE: 445.7473
[2/10][70/79]	Loss_D: 0.3758	Loss_VAE: 331.9757
Loss_val: 354.2762145996094
saved!
[3/10][41/79]	Loss_D: 0.4797	Loss_VAE: 373.3381
Loss_val: 743.4060668945312
[4/10][12/79]	Loss_D: 0.4929	Loss_VAE: 281.9650
[4/10][62/79]	Loss_D: 0.4499	Loss_VAE: 326.4587
Loss_val: 286.6507263183594
saved!
[5/10][33/79]	Loss_D: 0.2638	Loss_VAE: 331.0581
Loss_val: 441.69635009765625
[6/10][4/79]	Loss_D: 0.3204	Loss_VAE: 298.1320
[6/10][54/79]	Loss_D: 0.2895	Loss_VAE: 233.0574
Loss_val: 288.4593505859375
[7/10][25/79]	Lo

In [22]:
print('saved final model!')
torch.save(vae, result_dir + name + '_final.model')
torch.save(discrim, result_dir + discrim_name + '_final.model')

make_vae_plots(vae, x, y, outputs, training_data, validation_data,
               save_img= result_dir + "vae_out.png", save=True)

saved final model!
Could not generate the plot of the latent sanples because of exception
name 'sns' is not defined


### Evaluation

In [23]:
# Evaluation
test_loss = evaluation(name=result_dir + name, test_loader=test_loader)
f = open(result_dir + name + '_test_loss.txt', "w")
f.write(str(test_loss))
f.close()

samples_real(result_dir + name, test_loader)

FINAL LOSS: nll=33.036627197265624


In [24]:
plt.figure(figsize=(10,5))
plt.title("VAE and Discriminator Loss During Training")
plt.plot(VAE_losses, label="VAE")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig(result_dir + 'VAE_training_loss.png', bbox_inches='tight')
plt.close()

plt.figure(figsize=(10,5))
plt.title("Discriminator Loss During Training")
plt.plot(D_losses, label="D")
plt.plot(D_losses_fake, label="D_fake")
plt.plot(D_losses_real, label="D_real")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig(result_dir + 'Discrim_training_loss.png', bbox_inches='tight')

plt.close()

plt.figure(figsize=(10,5))
plt.title("Elbo Loss During Training")
plt.plot(elbo_losses, label="Elbo")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig(result_dir + 'Elbo_training_loss.png', bbox_inches='tight')
plt.close()

plt.figure(figsize=(10,5))
plt.title("Representation Loss During Training")
plt.plot(repr_losses, label="Repr")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig(result_dir + 'repr_training_loss.png', bbox_inches='tight')
plt.close()

#fig = plt.figure(figsize=(8,8))
#plt.axis("off")
#ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in train_img_list]
#ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
# saving to m4 using ffmpeg writer
#writervideo = animation.FFMpegWriter(fps=60)
#ani.save(result_dir + 'train_VAE_progression.mp4', writer=writervideo)
#plt.close()

#fig = plt.figure(figsize=(8,8))
#plt.axis("off")
#ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in val_img_list]
#ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
# saving to m4 using ffmpeg writer
#writervideo = animation.FFMpegWriter(fps=60)
#ani.save(result_dir + 'val_VAE_progression.mp4', writer=writervideo)
#plt.close()


plot_curve(result_dir + name + "_nll_val_curve", nll_val)
plot_curve(result_dir + name + "_discriminator_real_loss", discriminator_real_loss, y_label="discriminator_real_loss")
plot_curve(result_dir + name + "_discriminator_fake_loss", discriminator_fake_loss, y_label="discriminator_fake_loss")
plot_curve(result_dir + name + "_discriminator_avg_loss", discriminator_avg_loss, y_label="discriminator_avg_loss")
plot_curve(result_dir + name + "_discriminator_sum_loss", discriminator_sum_loss, y_label="discriminator_sum_loss")
plot_curve(result_dir + name + "_representation_loss", representation_loss, y_label="representation_loss")
plot_curve(result_dir + name + "_total_loss", total_loss_list, y_label="total_loss")

np.save(result_dir + 'nll_val.npy', nll_val)
np.save(result_dir +'_discriminator_real_loss.npy',  discriminator_real_loss)
np.save(result_dir +'discriminator_fake_loss.npy',  discriminator_fake_loss)
np.save(result_dir +'_discriminator_avg_loss.npy',  discriminator_avg_loss)
np.save(result_dir +'_discriminator_sum_loss.npy',  discriminator_sum_loss)
np.save(result_dir +'_representation_loss.npy',  representation_loss)
np.save(result_dir +'_total_loss.npy',  total_loss_list)
np.save(result_dir +'_repr_losses.npy',  repr_losses)
np.save(result_dir +'_elbo_losses.npy',  elbo_losses)
np.save(result_dir +'_D_losses_fake.npy',  D_losses_fake)
np.save(result_dir +'_D_losses_real.npy',  D_losses_real)
np.save(result_dir +'_D_losses.npy',  D_losses)
np.save(result_dir +'_VAE_losses.npy',  VAE_losses)