<a target="_blank" href="https://colab.research.google.com/github/nirban/pytorch_tutorial/blob/main/VAE/Variational_Auto_Encoder.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Variational Auto Encoder From Scratch

In [None]:
import os
import cv2
import sys
import random
import shutil
import argparse
import glob
from tqdm import tqdm
import numpy as np
import _csv as csv
import torch
import torch.nn as nn


In [None]:
r"""
A very simple VAE which has the following architecture
Encoder
    For Conditional model we stack num_classes empty channels onto the image
        We make the gt_label index channel as `1` (See figure in README)
    N * Conv BN Activation Blocks
    FC layers for mean
    FC layers for variance

Decoder
    For Conditional model we also concat the one hot label feature onto the z input
        (See figure in README)
    FC Layers taking z to higher dimensional feature
    N * ConvTranspose BN Activation Blocks
"""

In [None]:
class VAE(nn.Module):
    r"""
    Config-driven (optionally conditional) Variational Auto Encoder.

    Encoder: ``convbn_blocks`` x (Conv2d -> BatchNorm2d -> activation), flattened and fed to two
    parallel FC stacks that produce the latent mean and either the std or the log-variance
    (selected by ``config['log_variance']``).
    Decoder: FC stack mapping z to a flat feature vector that is reshaped into a square
    (transposebn_channels[0], S, S) map, then ``transpose_bn_blocks`` x
    (ConvTranspose2d -> BatchNorm2d -> activation).

    For a conditional model (``config['conditional']``) the label can be injected as
    ``num_classes`` extra one-hot input channels (``concat_channel``) and/or as a one-hot
    vector concatenated onto z (``decoder_fc_condition``).
    """

    def __init__(self,
                 config
                 ):
        super(VAE, self).__init__()
        activation_map = {
            'relu': nn.ReLU(),
            'leaky': nn.LeakyReLU(),
            'tanh': nn.Tanh(),
            'gelu': nn.GELU(),
            'silu': nn.SiLU()
        }

        ##### Validate the configuration for the model is correctly setup #######
        assert config['transpose_activation_fn'] is None or config['transpose_activation_fn'] in activation_map
        assert config['dec_fc_activation_fn'] is None or config['dec_fc_activation_fn'] in activation_map
        assert config['conv_activation_fn'] is None or config['conv_activation_fn'] in activation_map
        assert config['enc_fc_activation_fn'] is None or config['enc_fc_activation_fn'] in activation_map
        assert config['enc_fc_layers'][-1] == config['dec_fc_layers'][0] == config['latent_dim'], \
            "Latent dimension must be same as fc layers number"

        # Work on a copy (with copied channel / feature lists) so the conditional adjustments
        # below never mutate the caller's config. Mutating it in place made a second model
        # built from the same dict add num_classes twice.
        config = dict(config)
        config['convbn_channels'] = list(config['convbn_channels'])
        config['dec_fc_layers'] = list(config['dec_fc_layers'])
        self.config = config

        def get_activation(name):
            # None (allowed by the asserts above) means "no activation"
            return nn.Identity() if name is None else activation_map[name]

        self.num_classes = config['num_classes']
        self.transposebn_channels = config['transposebn_channels']
        self.latent_dim = config['latent_dim']
        # Preferred device at construction time, kept for callers that read it. The tensors
        # created inside forward/generate/sample follow the inputs / parameters instead, so
        # the model also works when it lives on a different device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Number of input channels will change if its a conditional model
        if config['concat_channel'] and config['conditional']:
            config['convbn_channels'][0] += self.num_classes

        # Encoder is just Conv bn blocks followed by fc for mean and variance
        self.encoder_layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(config['convbn_channels'][i], config['convbn_channels'][i + 1],
                          kernel_size=config['conv_kernel_size'][i], stride=config['conv_kernel_strides'][i]),
                nn.BatchNorm2d(config['convbn_channels'][i + 1]),
                get_activation(config['conv_activation_fn'])
            )
            for i in range(config['convbn_blocks'])
        ])

        encoder_mu_activation = get_activation(config['enc_fc_mu_activation'])
        self.encoder_mu_fc = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config['enc_fc_layers'][i], config['enc_fc_layers'][i + 1]),
                encoder_mu_activation
            )
            for i in range(len(config['enc_fc_layers']) - 1)
        ])
        encoder_var_activation = get_activation(config['enc_fc_var_activation'])
        self.encoder_var_fc = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config['enc_fc_layers'][i], config['enc_fc_layers'][i + 1]),
                encoder_var_activation
            )
            for i in range(len(config['enc_fc_layers']) - 1)
        ])

        # Number of features will change if it's a conditional model
        if config['decoder_fc_condition'] and config['conditional']:
            config['dec_fc_layers'][0] += self.num_classes

        self.decoder_layers = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(config['transposebn_channels'][i], config['transposebn_channels'][i + 1],
                                   kernel_size=config['transpose_kernel_size'][i],
                                   stride=config['transpose_kernel_strides'][i]),
                nn.BatchNorm2d(config['transposebn_channels'][i + 1]),
                get_activation(config['transpose_activation_fn'])
            )
            for i in range(config['transpose_bn_blocks'])
        ])

        self.decoder_fc = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config['dec_fc_layers'][i], config['dec_fc_layers'][i + 1]),
                get_activation(config['dec_fc_activation_fn'])
            )
            for i in range(len(config['dec_fc_layers']) - 1)
        ])

    def _one_hot(self, label, batch_size, device):
        r"""
        Build a (batch_size, num_classes) float32 one-hot matrix from the first batch_size labels.

        :param label: 1-D tensor of integer class indices
        :param batch_size: number of rows to build
        :param device: device on which the matrix is created
        """
        one_hot = torch.zeros((batch_size, self.num_classes), device=device)
        one_hot[torch.arange(batch_size, device=device), label.to(device).long()[:batch_size]] = 1
        return one_hot

    def forward(self, x, label=None):
        r"""
        Encode x, sample z with the reparameterization trick and decode it.

        :param x: image batch of shape (B, C, H, W)
        :param label: integer class indices; required when the model conditions on labels
        :return: dict with 'mean', 'image' and either 'log_variance' or 'std'
            (depending on config['log_variance'])
        """
        out = x
        if self.config['concat_channel'] and self.config['conditional']:
            assert label is not None, "Label cannot be none for conditional encoding"
            # Stack num_classes extra channels onto the input. The channel at the ground-truth
            # label index is all ones and every other extra channel is zero.
            one_hot = self._one_hot(label, x.size(0), x.device).to(x.dtype)
            label_ch_map = one_hot[:, :, None, None].expand(-1, -1, *x.shape[2:])
            out = torch.cat([x, label_ch_map], dim=1)

        for layer in self.encoder_layers:
            out = layer(out)
        out = out.reshape((x.size(0), -1))
        mu = out
        for layer in self.encoder_mu_fc:
            mu = layer(mu)
        std = out
        for layer in self.encoder_var_fc:
            std = layer(std)

        z = self.reparameterize(mu, std)
        generated_out = self.generate(z, label)
        # The second head predicts either log-variance or std depending on the config
        variance_key = 'log_variance' if self.config['log_variance'] else 'std'
        return {
            'mean': mu,
            variance_key: std,
            'image': generated_out,
        }

    def generate(self, z, label=None):
        r"""
        Decode latent vectors z of shape (B, latent_dim) into images.

        :param z: latent batch
        :param label: integer class indices; required when the decoder is conditioned
        :return: decoded image batch
        """
        out = z
        if self.config['decoder_fc_condition'] and self.config['conditional']:
            assert label is not None, "Label cannot be none for conditional generation"
            # Concat the num_classes dimensional one hot feature vector onto z
            # For label 9 this will be [0,0,0,0,0,0,0,0,0,1]
            label_fc_input = self._one_hot(label, z.size(0), z.device).to(z.dtype)
            out = torch.cat([out, label_fc_input], dim=-1)
        for layer in self.decoder_fc:
            out = layer(out)
        # The fc output must be transposebn_channels[0] * S * S features, reshaped to a square map
        channels = self.transposebn_channels[0]
        num_features = out.size(-1)
        spatial = int(round((num_features / channels) ** 0.5))
        assert spatial * spatial * channels == num_features, \
            "Decoder fc output size {} must equal transposebn_channels[0] * S * S".format(num_features)
        out = out.reshape((z.size(0), channels, spatial, spatial))
        for layer in self.decoder_layers:
            out = layer(out)
        return out

    def sample(self, label=None, num_images=1, z=None):
        r"""
        Decode num_images latent points (drawn from N(0, I) if z is not given).

        :param label: class indices (one per image); required for conditional models
        :param num_images: number of images to generate
        :param z: optional (num_images, latent_dim) latent batch; moved to the model's device
        """
        param_device = next(self.parameters()).device
        if z is None:
            # Draw on the CPU (same RNG stream as before), then move to the model's device
            z = torch.randn((num_images, self.latent_dim)).to(param_device)
        else:
            z = z.to(param_device)
        if self.config['conditional']:
            assert label is not None, "Label cannot be none for conditional sampling"
            assert label.size(0) == num_images
        assert z.size(0) == num_images
        return self.generate(z, label)

    def reparameterize(self, mu, std_or_logvariance):
        r"""Return mu + std * eps with eps ~ N(0, I); std is derived from log-variance if configured."""
        if self.config['log_variance']:
            std = torch.exp(0.5 * std_or_logvariance)
        else:
            std = std_or_logvariance
        z = torch.randn_like(std)
        return z * std + mu

In [None]:
def get_model(config):
    r"""
    Build a VAE from a full experiment config.

    :param config: parsed yaml config; only its 'model_params' section is used
    :return: freshly initialised VAE (on the default device of its parameters)
    """
    model_params = config['model_params']
    return VAE(config=model_params)

In [None]:
! mkdir config
! cd config
! pwd
! wget https://raw.githubusercontent.com/explainingai-code/VAE-Pytorch/refs/heads/main/config/vae_kl_latent4.yaml -P config/
! wget https://raw.githubusercontent.com/explainingai-code/VAE-Pytorch/refs/heads/main/config/vae_kl.yaml -P config/
! wget https://raw.githubusercontent.com/explainingai-code/VAE-Pytorch/refs/heads/main/config/vae_kl_latent4_enc_channel_dec_fc_condition.yaml -P config/
! wget https://raw.githubusercontent.com/explainingai-code/VAE-Pytorch/refs/heads/main/config/vae_nokl.yaml -P config/

In [None]:
import sys
import yaml

# Make the parent directory importable (sys itself has no `append`; the list is sys.path)
sys.path.append('../')

# The download cell stores the configs in ./config, relative to the notebook's working dir
config_path = 'config/vae_kl_latent4.yaml'
with open(config_path, 'r') as file:
    # Let a malformed yaml raise: continuing would leave `config` undefined or stale
    config = yaml.safe_load(file)

# Smoke test: push a tiny random batch through the model and check the output shapes.
# Model and inputs are placed on the same device so this also works on a GPU runtime.
smoke_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model(config).to(smoke_device)
labels = torch.tensor([0, 2, 0], dtype=torch.long, device=smoke_device)
out = model(torch.rand((3, 1, 28, 28), device=smoke_device), labels)
print(out['mean'].shape)
print(out['image'].shape)

# Get Dataset

In [4]:
! mkdir -p data/train/images
! mkdir -p data/test/images

In [5]:
#!/bin/bash
! curl -L -o archive.zip https://www.kaggle.com/api/v1/datasets/download/oddrationale/mnist-in-csv

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 15.2M  100 15.2M    0     0  3988k      0  0:00:03  0:00:03 --:--:-- 6773k


In [None]:
def extract_images(save_dir, csv_fname):
    r"""
    Convert an MNIST-in-CSV file into one folder of png images per label.

    Each data row is ``label, p0, ..., p783``. The first row is treated as a header and
    skipped. Row number ``idx`` (1-based, counting data rows) is written to
    ``save_dir/<label>/<idx>.png``.
    :param save_dir: existing directory under which one sub-folder per label is created
    :param csv_fname: path to the csv file
    """
    assert os.path.exists(save_dir), "Directory {} to save images does not exist".format(save_dir)
    assert os.path.exists(csv_fname), "Csv file {} does not exist".format(csv_fname)
    # newline='' is how the csv module expects files to be opened
    with open(csv_fname, newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for idx, row in enumerate(reader, start=1):
            label_dir = os.path.join(save_dir, row[0])
            os.makedirs(label_dir, exist_ok=True)
            # Pixels are 0-255 integers: store them as uint8 so a plain 8-bit png is written
            im = np.asarray(list(map(int, row[1:])), dtype=np.uint8).reshape((28, 28))
            cv2.imwrite(os.path.join(label_dir, '{}.png'.format(idx)), im)
            if idx % 1000 == 0:
                # Rows 1..idx have been written, i.e. exactly idx images
                print('Finished creating {} images in {}'.format(idx, save_dir))

In [None]:
# Materialise both splits as png folders laid out as <images_dir>/<label>/<idx>.png
for images_dir, csv_path in (('data/train/images', 'data/mnist_train.csv'),
                             ('data/test/images', 'data/mnist_test.csv')):
    extract_images(images_dir, csv_path)

# Dataset Loader

In [None]:
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader

In [None]:
class MnistDataset(Dataset):
    r"""
    MNIST dataset read from a folder laid out as ``im_path/<digit label>/<name>.<im_ext>``.

    Items are ``(image, label)``: image is a (1, H, W) float64 tensor scaled to [-1, 1],
    and label is a 0-d integer tensor.
    """
    def __init__(self, split, im_path, im_ext='png'):
        self.split = split
        self.im_ext = im_ext
        self.images, self.labels = self.load_images(im_path)

    def load_images(self, im_path):
        r"""
        Collect image paths and their integer labels (taken from the sub-folder names).

        :param im_path: root folder containing one digit-named sub-folder per class
        :return: (list of image paths, list of int labels)
        """
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
        ims = []
        labels = []
        # Sorted for a deterministic order. Entries that are not digit-named folders
        # (e.g. '.ipynb_checkpoints' created by Jupyter) are skipped instead of crashing int().
        for d_name in tqdm(sorted(os.listdir(im_path))):
            d_path = os.path.join(im_path, d_name)
            if not (d_name.isdigit() and os.path.isdir(d_path)):
                continue
            for fname in sorted(glob.glob(os.path.join(d_path, '*.{}'.format(self.im_ext)))):
                ims.append(fname)
                labels.append(int(d_name))
        print('Found {} images for split {}'.format(len(ims), self.split))
        return ims, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # Flag 0 (cv2.IMREAD_GRAYSCALE) -> 2-D uint8 array, or None if the file can't be read
        im = cv2.imread(self.images[index], 0)
        if im is None:
            raise IOError('Could not read image {}'.format(self.images[index]))
        label = self.labels[index]
        # Convert 0 to 255 into -1 to 1
        im = 2 * (im / 255) - 1
        # Add a channel axis: (H, W) -> (1, H, W)
        im_tensor = torch.from_numpy(im)[None, :]
        return im_tensor, torch.as_tensor(label)

In [None]:
# Sanity check of the test split: print the shapes of a single shuffled batch
mnist = MnistDataset('test', 'data/test/images')
mnist_loader = DataLoader(mnist, batch_size=16, shuffle=True, num_workers=0)
first_batch = next(iter(mnist_loader), None)
if first_batch is not None:
    im, label = first_batch
    print('Image dimension', im.shape)
    print('Label dimension: {}'.format(label.shape))

In [None]:
# Device shared by the training / visualisation helpers below: GPU when available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Utilities

In [None]:
import pickle
import torchvision
from einops import rearrange
from torchvision.utils import make_grid
from matplotlib import pyplot as plt

In [None]:
def visualize_latent_space(config, model, data_loader, save_fig_path):
    r"""
    Method to visualize the latent dimension by simply plotting the means for each of the images.

    If the latent dimension is not 2, the means are projected onto the first two directions
    returned by torch.pca_lowrank. The projection matrix V is pickled (on the CPU) to
    ``<task_name>/pca_matrix.pkl`` so visualize_manifold can map 2-D points back.
    Intended to be called under torch.no_grad(), since outputs of the whole loader are kept.
    :param config: Config file used to create the model
    :param model: VAE model (callers put it in eval mode)
    :param data_loader: loader yielding (image, label) batches
    :param save_fig_path: Path where the latent space image will be saved
    :return:
    """
    labels = []
    means = []

    for im, label in tqdm(data_loader):
        im = im.float().to(device)
        label = label.long().to(device)
        output = model(im, label)
        labels.append(label)
        means.append(output['mean'])

    labels = torch.cat(labels, dim=0).reshape(-1)
    means = torch.cat(means, dim=0)
    if model.latent_dim != 2:
        print('Latent dimension > 2 and hence projecting')
        _, _, V = torch.pca_lowrank(means, center=True, niter=2)
        task_dir = config['train_params']['task_name']
        os.makedirs(task_dir, exist_ok=True)
        # Store V on the CPU (so it can be reloaded without a GPU) and close the file
        with open(os.path.join(task_dir, 'pca_matrix.pkl'), 'wb') as f:
            pickle.dump(V.cpu(), f)
        # NOTE: the projection is not mean-centred; visualize_manifold inverts it the same way
        means = torch.matmul(means, V[:, :2])

    fig, ax = plt.subplots()
    for num in range(10):
        idxs = torch.where(labels == num)[0]
        ax.scatter(means[idxs, 0].cpu().numpy(), means[idxs, 1].cpu().numpy(), s=10, label=str(num),
                   alpha=1.0, edgecolors='none')
    ax.legend()
    ax.grid(True)
    fig.savefig(save_fig_path)
    # This runs once per epoch: close the figure so open figures do not pile up in memory
    plt.close(fig)


In [None]:
def reconstruct(config, model, dataset, num_images=100):
    r"""
    Randomly sample points from the dataset and visualize image and its reconstruction.

    Writes ``<task_name>/<output_train_dir>/reconstruction.png``. Each grid cell shows an input
    (white digit on black) next to its reconstruction (colours flipped to black on white).
    :param config: Config file used to create the model
    :param model: Trained model
    :param dataset: Mnist dataset(not the data loader)
    :param num_images: Number of images to visualize
    :return:
    """
    print('Generating reconstructions')
    output_dir = os.path.join(config['train_params']['task_name'], config['train_params']['output_train_dir'])
    os.makedirs(output_dir, exist_ok=True)

    model_device = next(model.parameters()).device
    # randint's upper bound is exclusive, so len(dataset) makes every index reachable
    idxs = torch.randint(0, len(dataset), (num_images,))
    # Fetch every sample once (each access reads an image from disk)
    samples = [dataset[idx] for idx in idxs]
    ims = torch.stack([im for im, _ in samples]).float()
    labels = torch.stack([label for _, label in samples]).long()

    # Run on the model's device, bring the result back to the CPU for plotting
    output = model(ims.to(model_device), labels.to(model_device))
    generated_im = output['image'].detach().cpu()

    # Dataset generates -1 to 1 we convert it to 0-1
    ims = (ims + 1) / 2
    # For reconstruction, we specifically flip it(white digit on black background -> black digit on white background)
    # for easier visualization
    generated_im = 1 - (generated_im + 1) / 2
    # Single-channel images: (B, 1, H, W) x 2 -> (B, 2, H, W) -> side by side (B, 1, H, 2W)
    out = torch.hstack([ims, generated_im])
    output = rearrange(out, 'b c h w -> b () h (c w)')
    grid = make_grid(output, nrow=10)
    img = torchvision.transforms.ToPILImage()(grid)
    img.save(os.path.join(output_dir, 'reconstruction.png'))

In [None]:
def visualize_interpolation(config, model, dataset, interpolation_steps=500, save_dir='interp'):
    r"""
    We randomly fetch two points and linearly interpolate between them.

    Only the mean values are interpolated. For a conditional model, the two images are
    encoded with their own labels. All interpolated points are then decoded with a single
    randomly chosen class in [0, num_classes). One png per step is written to
    ``<task_name>/<output_train_dir>/<save_dir>``, which is emptied first.
    :param config: Config file used to create the model
    :param model: Trained model
    :param dataset: Mnist dataset (not the data loader)
    :param interpolation_steps: We will interpolate these many points between start and end
    :param save_dir: sub-folder (of the output dir) that receives the frames
    :return:
    """
    print('Interpolating between images')
    out_dir = os.path.join(config['train_params']['task_name'],
                           config['train_params']['output_train_dir'],
                           save_dir)
    # Start from an empty folder so frames of an earlier run are not mixed in
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

    model_device = next(model.parameters()).device
    # randint's upper bound is exclusive: len(dataset) makes every index reachable
    idxs = torch.randint(0, len(dataset), (2,))
    samples = [dataset[idx] for idx in idxs]
    ims = torch.stack([im for im, _ in samples]).float().to(model_device)
    if model.config['conditional']:
        # Encode each image with its own ground-truth label
        labels = torch.stack([label for _, label in samples]).long().to(model_device)
    else:
        labels = None
    means = model(ims, labels)['mean']
    means_start = means[0]
    means_end = means[1]

    factors = torch.linspace(0, 1.0, steps=interpolation_steps, device=means.device)
    if model.config['conditional']:
        # Decode every interpolated point with the same random class (upper bound exclusive)
        label_val = torch.randint(0, model.num_classes, (1,))
        labels = label_val.long().to(model_device).repeat(interpolation_steps)
    else:
        labels = None
    means = factors[:, None] * means_end[None, :] + (1 - factors[:, None]) * means_start[None, :]
    out = model.generate(means, labels).detach()
    for idx in tqdm(range(out.shape[0])):
        # Convert generated output from -1 to 1 range to 0-255 uint8
        im = (255 * (out[idx, 0] + 1) / 2).clamp(0, 255).round().to(torch.uint8)
        cv2.imwrite('{}/{}.png'.format(out_dir, idx), im.cpu().numpy())

In [None]:
def visualize_manifold(config, model):
    r"""
    Decode a 30 x 30 grid of latent points spanning [-10, 10]^2 and save it as an image grid.

    For latent_dim != 2, the 2-D grid points are mapped back to the latent space with the PCA
    matrix pickled by visualize_latent_space. The function returns early if that file is
    missing. A conditional model gets one grid per class (manifold_<label>.png); otherwise a
    single manifold.png is written. Output goes to ``<task_name>/<output_train_dir>``.
    :param config: Config file used to create the model
    :param model: Trained model
    """
    print('Generating the manifold')
    task_dir = config['train_params']['task_name']
    output_dir = os.path.join(task_dir, config['train_params']['output_train_dir'])
    os.makedirs(output_dir, exist_ok=True)

    # The latent grid is the same for every label, so build it once
    num_images = 900
    # For values below use the latent images to get a sense of what ranges we need to plot
    xs = torch.linspace(-10, 10, 30)
    ys = torch.linspace(-10, 10, 30)
    xs, ys = torch.meshgrid([xs, ys])
    zs = torch.cat([xs.reshape(-1, 1), ys.reshape(-1, 1)], dim=-1)
    if model.latent_dim != 2:
        pca_path = os.path.join(task_dir, 'pca_matrix.pkl')
        if not os.path.exists(pca_path):
            print('Latent dimension > 2 but no pca info available. '
                  'Call visualize_latent_space first. Skipping visualize_manifold')
            # Really skip: 2-D points cannot be decoded by a model with another latent size
            return
        # NOTE: unpickling is only safe because this file is written by visualize_latent_space
        with open(pca_path, 'rb') as f:
            V = pickle.load(f)
        # V may have been pickled on the GPU by older runs; align it with the CPU grid
        zs = torch.matmul(zs, V[:, :2].to(zs.device).T)

    model_device = next(model.parameters()).device
    zs = zs.to(model_device)

    # For conditional model we can generate all numbers for all points in the space.
    # This because the condition introduces the variance even if the point (z) is the same
    # But for non-conditional only one output is possible for one z hence progress bar range is 1
    pbar_range = model.num_classes if model.config['conditional'] else 1
    for label_val in tqdm(range(pbar_range)):
        label = torch.full((num_images,), label_val, dtype=torch.long, device=model_device)
        generated_ims = model.sample(label, num_images, z=zs)
        generated_ims = ((generated_ims + 1) / 2).detach().cpu()
        grid = make_grid(generated_ims, nrow=30)
        img = torchvision.transforms.ToPILImage()(grid)
        img.save(os.path.join(output_dir,
                              'manifold_{}.png'.format(label_val) if model.config['conditional'] else 'manifold.png'))

# Training

In [None]:
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
def train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer, crtierion, config):
    r"""
    Run one optimisation pass over the loader with loss = reconstruction + kl_weight * KL.

    Labels are forwarded to the model (only conditional models use them). The KL term is
    computed from the log-variance or the std head, matching config['model_params']['log_variance'].
    :param epoch_idx: zero-based index of the current epoch (printed as epoch_idx + 1)
    :param model: VAE model
    :param mnist_loader: Data loader yielding (image, label) batches
    :param optimizer: optimizer stepping the model parameters
    :param crtierion: reconstruction loss, called as crtierion(generated, target)
    :param config: configuration for the current run
    :return: mean total loss over all batches of the epoch
    """
    recon_history, kl_history, total_history = [], [], []
    for im, label in tqdm(mnist_loader):
        im = im.float().to(device)
        label = label.long().to(device)
        optimizer.zero_grad()
        output = model(im, label)
        mean = output['mean']
        generated_im = output['image']
        uses_log_variance = config['model_params']['log_variance']

        if config['train_params']['save_training_image']:
            # Debug snapshot of the first image of the batch, mapped from [-1, 1] to [0, 255]
            cv2.imwrite('input.jpeg', (255 * (im.detach() + 1) / 2).cpu().numpy()[0, 0])
            cv2.imwrite('output.jpeg', (255 * (generated_im.detach() + 1) / 2).cpu().numpy()[0, 0])

        # Closed-form KL(N(mean, std^2) || N(0, 1)) per element, summed over the latent dims
        if uses_log_variance:
            log_variance = output['log_variance']
            kl_terms = torch.exp(log_variance) + mean ** 2 - 1 - log_variance
        else:
            std = output['std']
            kl_terms = std ** 2 + mean ** 2 - 1 - torch.log(std ** 2)
        kl_loss = torch.mean(0.5 * torch.sum(kl_terms, dim=-1))

        recon_loss = crtierion(generated_im, im)
        loss = recon_loss + config['train_params']['kl_weight'] * kl_loss
        recon_history.append(recon_loss.item())
        total_history.append(loss.item())
        kl_history.append(kl_loss.item())
        loss.backward()
        optimizer.step()

    print('Finished epoch: {} | Recon Loss : {:.4f} | KL Loss : {:.4f}'.format(
        epoch_idx + 1, np.mean(recon_history), np.mean(kl_history)))
    return np.mean(total_history)

In [None]:
def train(args):
    r"""
    Train a VAE described by the yaml config at ``args.config_path``.

    Deletes and recreates ``<task_name>/<output_train_dir>`` and plots the latent space
    before training (and after every epoch if save_latent_plot is set). It steps a
    ReduceLROnPlateau scheduler on the mean epoch loss and saves the state_dict to
    ``<task_name>/<ckpt_name>`` whenever that loss improves.
    :param args: namespace with a ``config_path`` attribute
    :raises ValueError: if train_params.crit is not 'l1' or 'l2'
    """
    ######## Read the config file #######
    # No try/except: a malformed yaml must stop here instead of leaving `config` undefined
    with open(args.config_path, 'r') as file:
        config = yaml.safe_load(file)
    print(config)
    #######################################

    ######## Set the desired seed value #######
    seed = config['train_params']['seed']
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # `device` is a torch.device, so compare its .type (comparing to the str 'cuda' never
    # matches). Use the config seed, because args has no `seed` attribute.
    if device.type == 'cuda':
        torch.cuda.manual_seed_all(seed)
    #######################################

    # Resolve the reconstruction criterion first so a typo in the config fails fast
    criterion = {
        'l1': torch.nn.L1Loss(),
        'l2': torch.nn.MSELoss()
    }.get(config['train_params']['crit'])
    if criterion is None:
        raise ValueError("Unknown train_params.crit {!r}; expected 'l1' or 'l2'".format(
            config['train_params']['crit']))

    # Create the model and dataset
    model = get_model(config).to(device)
    mnist = MnistDataset('train', config['train_params']['train_path'])
    mnist_test = MnistDataset('test', config['train_params']['test_path'])
    mnist_loader = DataLoader(mnist, batch_size=config['train_params']['batch_size'], shuffle=True, num_workers=4)
    mnist_test_loader = DataLoader(mnist_test, batch_size=config['train_params']['batch_size'], shuffle=False,
                                   num_workers=0)
    num_epochs = config['train_params']['epochs']
    optimizer = Adam(model.parameters(), lr=config['train_params']['lr'])
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=1, verbose=True)

    # Deleting old outputs for this task
    # Create output directories
    task_dir = config['train_params']['task_name']
    output_dir = os.path.join(task_dir, config['train_params']['output_train_dir'])
    if os.path.exists(task_dir):
        shutil.rmtree(task_dir)
    os.makedirs(output_dir)

    best_loss = np.inf
    latent_im_path = os.path.join(output_dir, 'latent_epoch_{}.jpeg')
    # Latent space of the untrained model, as a baseline for the later epochs
    with torch.no_grad():
        model.eval()
        visualize_latent_space(config, model, mnist_test_loader, save_fig_path=latent_im_path.format(0))
        model.train()
    for epoch_idx in range(num_epochs):
        mean_loss = train_for_one_epoch(epoch_idx, model, mnist_loader, optimizer, criterion, config)
        if config['train_params']['save_latent_plot']:
            model.eval()
            with torch.no_grad():
                print('Generating latent plot on test set')
                visualize_latent_space(config, model, mnist_test_loader,
                                       save_fig_path=latent_im_path.format(epoch_idx + 1))
            model.train()
        scheduler.step(mean_loss)
        # Simply update checkpoint if found better version
        if mean_loss < best_loss:
            print('Improved Loss to {:.4f} .... Saving Model'.format(mean_loss))
            torch.save(model.state_dict(), os.path.join(task_dir, config['train_params']['ckpt_name']))
            best_loss = mean_loss
        else:
            print('No Loss Improvement')

# Perform the training

In [None]:
parser = argparse.ArgumentParser(description='Arguments for conditional vae training')
parser.add_argument('--config', dest='config_path',
                    default='config/vae_kl.yaml', type=str)
# parse_known_args: inside Jupyter/Colab sys.argv carries the kernel's own flags
# (e.g. `-f <kernel>.json`), which parse_args() would reject with a SystemExit.
args, _ = parser.parse_known_args()
train(args)

# Inference

In [None]:
def inference(args):
    r"""
    Load the trained checkpoint for the config at ``args.config_path`` and write all visualizations.

    Writes the latent space plot, the interpolation frames, the reconstructions and the
    manifold under ``<task_name>/<output_train_dir>``, using the test split.
    :param args: namespace with a ``config_path`` attribute
    """
    # No try/except: a malformed yaml must stop here instead of leaving `config` undefined
    with open(args.config_path, 'r') as file:
        config = yaml.safe_load(file)
    print(config)

    task_dir = config['train_params']['task_name']
    output_dir = os.path.join(task_dir, config['train_params']['output_train_dir'])

    model = get_model(config).to(device)
    model.load_state_dict(torch.load(os.path.join(task_dir, config['train_params']['ckpt_name']),
                                     map_location='cpu'))
    model.eval()
    # Same test split as used during training (falls back to the previous hard-coded path)
    mnist = MnistDataset('test', config['train_params'].get('test_path', 'data/test/images'))
    mnist_loader = DataLoader(mnist, batch_size=config['train_params']['batch_size'], shuffle=True, num_workers=4)

    # visualize_latent_space saves into output_dir but only creates task_dir itself
    os.makedirs(output_dir, exist_ok=True)
    with torch.no_grad():
        latent_im_path = os.path.join(output_dir, 'latent_inference.jpeg')
        visualize_latent_space(config, model, mnist_loader, latent_im_path)
        visualize_interpolation(config, model, mnist)
        reconstruct(config, model, mnist)
        visualize_manifold(config, model)

In [None]:
parser = argparse.ArgumentParser(description='Arguments for vae inference')
parser.add_argument('--config', dest='config_path',
                    default='config/vae_kl.yaml', type=str)
# parse_known_args: inside Jupyter/Colab sys.argv carries the kernel's own flags
# (e.g. `-f <kernel>.json`), which parse_args() would reject with a SystemExit.
args, _ = parser.parse_known_args()
inference(args)