In [None]:
# Re-import edited project modules automatically before each cell is executed,
# so changes to the `uncertify` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# NOTE(review): `context` is presumably a local helper module that makes the
# `uncertify` package importable from this notebook's directory — confirm.
from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
setup_logging()
LOG = logging.getLogger(__name__)

# matplotlib emits a large volume of DEBUG-level messages under this logging setup.
# Raise its threshold so only matplotlib warnings and errors reach the notebook output.
mpl_logger = logging.getLogger(name='matplotlib')
mpl_logger.setLevel(level=logging.WARNING)

In [None]:
from pathlib import Path

import torch
import torchvision
from torchvision.transforms import Compose
from tqdm import tqdm

from uncertify.data.dataloaders import dataloader_factory, DatasetType
from uncertify.models.vae import VariationalAutoEncoder
from uncertify.models.encoder_decoder_baur2020 import BaurEncoder, BaurDecoder
from uncertify.evaluation.latent_space_analysis import sample_from_gauss_prior
from uncertify.evaluation.inference import infer_latent_space_samples
from uncertify.data.default_dataloaders import default_dataloader_dict_factory
from uncertify.io.models import load_vae_baur_model

from uncertify.visualization.reconstruction import plot_vae_output
from uncertify.common import DATA_DIR_PATH

# Load models and dataloaders

In [None]:
# Root of the Lightning logs for the `schedule_mask` experiment. This is a machine-specific
# absolute path: change it here (once) when running on another machine.
SCHEDULE_MASK_LOGS_DIR = Path('/mnt/2TB_internal_HD/lightning_logs/schedule_mask')

# version_1 holds the masked run and version_6 the non-masked run (per the variable names);
# both are loaded from their final ("last") checkpoint.
masked_model = load_vae_baur_model(SCHEDULE_MASK_LOGS_DIR / 'version_1' / 'checkpoints' / 'last.ckpt')
non_masked_model = load_vae_baur_model(SCHEDULE_MASK_LOGS_DIR / 'version_6' / 'checkpoints' / 'last.ckpt')

In [None]:
# Which trained model every cell below analyses. Set to False to analyse the
# non-masked run (`non_masked_model`) instead.
ANALYSE_MASKED_MODEL = True
model = masked_model if ANALYSE_MASKED_MODEL else non_masked_model

In [None]:
# Dictionary of named dataloaders; later cells look loaders up by key (e.g. 'CamCAN train').
dataloader_dict = default_dataloader_dict_factory(
    shuffle_val=True,   # presumably shuffles the validation loaders — confirm in uncertify.data
    num_workers=0,      # NOTE(review): likely forwarded to torch DataLoader (0 = load in main process)
    batch_size=155,
)

# Plot latent space behaviour

## Plot variance captured over one latent space dimension

In [None]:
from uncertify.visualization.latent_space_analysis import plot_latent_reconstructions_one_dim_changing

In [None]:
# Latent dimensions to sweep one at a time; one figure is produced and saved per dimension.
change_dim_indices = [1, 80, 108]

for dim in change_dim_indices:
    one_dim_plot_path = DATA_DIR_PATH / 'plots' / f'latent_sample_one_dim_{dim}.png'
    plot_latent_reconstructions_one_dim_changing(
        trained_model=model,
        change_dim_idx=dim,
        n_samples=32,
        save_path=one_dim_plot_path,
    )


## Plot variance captured by each latent space dimension while keeping all others fixed

In [None]:
from uncertify.visualization.latent_space_analysis import plot_latent_reconstructions_multiple_dims

In [None]:
# One row of reconstructions per latent dimension (128 in total), saved to the plots directory.
plot_latent_reconstructions_multiple_dims(
    model,
    n_samples_per_dim=32,
    latent_space_dims=128,
    save_path=DATA_DIR_PATH / 'plots' / 'latent_reconstruct_all_dims.png',
)

## Plot a 2D grid varying two latent dimensions from -3 to 3 standard deviations

In [None]:
from uncertify.visualization.latent_space_analysis import plot_latent_reconstructions_2d_grid

In [None]:
# Vary latent dimensions 18 and 50 jointly over a grid and save the resulting reconstructions.
plot_latent_reconstructions_2d_grid(
    model,
    dim1=18,
    dim2=50,
    save_path=DATA_DIR_PATH / 'plots' / 'latent_space_2d_grid.png',
)

# Plot reconstructions of random Gaussian latent space samples

In [None]:
from uncertify.visualization.latent_space_analysis import plot_random_latent_space_samples

In [None]:
# Decode 16 random latent samples; with nrow equal to n_samples they presumably render in a single row.
plot_random_latent_space_samples(
    model,
    cmap='gray',
    n_samples=16,
    nrow=16,
)

# Plot the annulus distribution of Gaussian samples

In [None]:
from uncertify.visualization.latent_space_analysis import plot_gaussian_annulus_distribution

In [None]:
# Distribution of 1000 Gaussian samples in a 128-dimensional space (same size as the model's latent space above).
plot_gaussian_annulus_distribution(
    n_samples=1000,
    latent_space_dims=128,
)

# Plot reconstructions of latent samples drawn from rings at different radii

In [None]:
from uncertify.visualization.latent_space_analysis import plot_latent_samples_from_ring

# (inner_radius, outer_radius) bands to sample latent vectors from, ordered from the origin outwards.
# NOTE(review): assuming a 128-dim standard-normal prior, most prior mass lies near radius
# sqrt(128) ~= 11.3, so the (10, 12) band should look most "in-distribution" — confirm.
radii = [(0, 1), (2, 3), (4, 5), (7, 9), (10, 12), (15, 17), (20, 30), (50, 60), (200, 210)]

for inner_radius, outer_radius in radii:
    fig = plot_latent_samples_from_ring(
        model,
        n_samples=16,
        inner_radius=inner_radius,
        outer_radius=outer_radius,
        cmap='gray',
    )

# Plot a UMAP embedding of the latent space representations

In [None]:
from uncertify.visualization.latent_space_analysis import plot_umap_latent_embedding
from uncertify.evaluation.inference import yield_inference_batches

In [None]:
# Names of the available dataloaders; the UMAP cell below selects a subset of them by key.
# Shown as the cell's last expression (rich display) rather than via print().
list(dataloader_dict)

In [None]:
# Number of batches to run through the model for each selected dataset.
max_n_batches = 6
# Passed through to yield_inference_batches; presumably the residual threshold used to
# flag abnormal pixels — confirm against uncertify.evaluation.inference.
residual_threshold = 0.67

# Keys into dataloader_dict; the same list is used as legend labels for the UMAP plot.
select_dataloaders = ['CamCAN train', 'MNIST', 'Gaussian noise']

# One inference generator per selected dataset, labelled in the progress bar by its name.
output_generators = [
    yield_inference_batches(dataloader_dict[dataloader_name], model, max_n_batches, residual_threshold,
                            progress_bar_suffix=dataloader_name)
    for dataloader_name in select_dataloaders
]

umap_fig = plot_umap_latent_embedding(output_generators, select_dataloaders, figsize=(14, 10))


In [None]:
# Tag the file with the model that was actually analysed, so switching `model` to the
# non-masked run does not overwrite (and mislabel) the masked run's plot.
if model is masked_model:
    umap_model_tag = 'masked'
elif model is non_masked_model:
    umap_model_tag = 'non_masked'
else:
    umap_model_tag = 'other'

# savefig does not create missing directories, so make sure the plots folder exists.
umap_plots_dir = Path(DATA_DIR_PATH) / 'plots'
umap_plots_dir.mkdir(parents=True, exist_ok=True)
umap_fig.savefig(umap_plots_dir / f'umap_latent_embedding_{umap_model_tag}.png')