In [None]:
# Automatically re-import edited project modules before each cell runs,
# so changes to the `uncertify` sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Imported for its side effect only: `context` presumably makes the project package
# importable (e.g. by extending sys.path) so the `from uncertify...` imports below work — TODO confirm.
from context import uncertify

In [None]:
import logging

from uncertify.log import setup_logging

# Apply the project's logging configuration, then fetch this notebook's own logger.
setup_logging()
LOG = logging.getLogger(__name__)

# Matplotlib's DEBUG-level logging is extremely noisy; only let its
# warnings (and anything more severe) through.
mpl_logger = logging.getLogger(name='matplotlib')
mpl_logger.setLevel(level=logging.WARNING)

In [None]:
import os
from pathlib import Path

import torch
import torchvision
from torchvision.transforms import Compose
from tqdm import tqdm

from uncertify.data.dataloaders import dataloader_factory, DatasetType
from uncertify.models.vae import VariationalAutoEncoder
from uncertify.models.encoder_decoder_baur2020 import BaurEncoder, BaurDecoder
from uncertify.deploy import sample_from_gauss_prior
from uncertify.deploy import infer_latent_space_samples

from uncertify.visualization.reconstruction import plot_vae_output
from uncertify.common import DATA_DIR_PATH

# Load model and dataloaders

In [None]:
# VAE built from the Baur 2020 encoder/decoder pair (`encoder_decoder_baur2020`).
# `get_batch_fn` presumably extracts the network input from a batch. Here batches are
# expected to be mappings with a 'scan' entry. For tuple-style batches such as MNIST's,
# the input sits at index 0, i.e. pass `get_batch_fn=lambda batch: batch[0]` instead.
model = VariationalAutoEncoder(BaurEncoder(), BaurDecoder(), get_batch_fn=lambda batch: batch['scan'])

In [None]:
# Checkpoint to restore. Defaults to the training run stored under DATA_DIR_PATH.
# To load a checkpoint kept elsewhere (e.g. a synced copy of epoch=261.ckpt),
# set the UNCERTIFY_CHECKPOINT environment variable instead of hard-coding a local path.
CHECKPOINT_PATH = Path(os.environ.get(
    'UNCERTIFY_CHECKPOINT',
    DATA_DIR_PATH / 'lightning_logs/train_vae/version_1/checkpoints/epoch=261.ckpt',
))
assert CHECKPOINT_PATH.exists(), f'Model checkpoint does not exist: {CHECKPOINT_PATH}'

# The model is not moved to a GPU anywhere in this notebook, so map every stored tensor
# onto the CPU. This also lets GPU-trained checkpoints load on machines without CUDA.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
# Switch to inference mode before plotting. This matters for layers like dropout or
# batch norm (if the Baur networks contain them — TODO confirm), which would otherwise
# behave stochastically or update running statistics on every forward pass.
model.eval();

In [None]:
BATCH_SIZE = 8
# Seed torch's global RNG so shuffled batch order and the random latent sampling done by
# the plots below repeat on Restart & Run All. This assumes the loaders are torch
# DataLoaders without their own generator — TODO confirm.
SEED = 0
torch.manual_seed(SEED)
# MNIST digits are upsampled to 128x128, presumably the input size the Baur encoder
# expects — TODO confirm.
MNIST_IMAGE_SIZE = (128, 128)

_, brats_val_dataloader = dataloader_factory(DatasetType.BRATS17, batch_size=BATCH_SIZE, shuffle_val=True)
camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(
    DatasetType.CAMCAN, batch_size=BATCH_SIZE, shuffle_train=True, shuffle_val=True)

mnist_transform = Compose([torchvision.transforms.Resize(MNIST_IMAGE_SIZE),
                           torchvision.transforms.ToTensor()])
mnist_train_dataloader, mnist_val_dataloader = dataloader_factory(
    DatasetType.MNIST, batch_size=BATCH_SIZE, shuffle_train=True, shuffle_val=True,
    transform=mnist_transform)

# Plot latent space behaviour

## Plot the variance captured when varying a single latent space dimension

In [None]:
from uncertify.visualization.latent_space_analysis import plot_reconstructions_one_dim_changing

In [None]:
# Latent dimensions to traverse, one figure per dimension.
change_dim_indices = [1, 80, 108]

# All figures are written to the same directory; create it up front so saving
# does not fail on a fresh checkout where it does not exist yet.
PLOT_DIR = DATA_DIR_PATH / 'plots'
PLOT_DIR.mkdir(parents=True, exist_ok=True)

for dim in change_dim_indices:
    plot_reconstructions_one_dim_changing(trained_model=model, change_dim_idx=dim, n_samples=32,
                                          save_path=PLOT_DIR / f'latent_sample_one_dim_{dim}.png')


## Plot the variance captured by each latent space dimension while all other dimensions are held fixed

In [None]:
from uncertify.visualization.latent_space_analysis import plot_latent_reconstruction_multiple_dims

In [None]:
# Traverse every one of the 128 latent dimensions, drawing 32 samples per dimension,
# and save the resulting figure into the plots directory.
plot_latent_reconstruction_multiple_dims(
    model,
    latent_space_dims=128,
    n_samples_per_dim=32,
    save_path=DATA_DIR_PATH.joinpath('plots', 'latent_reconstruct_all_dims.png'),
)

## Plot a 2D grid varying two latent dimensions from -3 to +3 standard deviations

In [None]:
from uncertify.visualization.latent_space_analysis import plot_latent_reconstructions_2d_grid

In [None]:
# Sweep latent dimensions 80 and 108 jointly over a 2D grid and save the figure.
plot_latent_reconstructions_2d_grid(
    model,
    dim1=80,
    dim2=108,
    save_path=DATA_DIR_PATH.joinpath('plots', 'latent_space_2d_grid.png'),
)