In [None]:
%load_ext autoreload
%autoreload 2

from context import uncertify

In [None]:
import logging
from uncertify.log import setup_logging
# Configure logging through the project helper, then create the logger used by this notebook.
setup_logging()
LOG = logging.getLogger(__name__)

# Raise matplotlib's logger to WARNING: its DEBUG output is very verbose and
# would otherwise drown out this notebook's own log messages.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)

In [None]:
from pathlib import Path

import torch
import torchvision
from torchvision.transforms.transforms import Compose
from tqdm import tqdm
try:
    # Forget progress bars left over from a previous (possibly interrupted) run
    # so that new bars are not rendered on shifted lines.
    tqdm._instances.clear()
except AttributeError:
    # Presumably `_instances` does not exist before the first bar has been created;
    # in that case there is nothing to clear. Any other error (including
    # KeyboardInterrupt) is intentionally no longer swallowed.
    pass
import seaborn as sns
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
plt.rc('font', family='serif')
matplotlib.rcParams.update({'font.size': 18})
from sklearn.metrics import roc_curve, precision_recall_curve, precision_recall_fscore_support
from sklearn.metrics import average_precision_score, roc_auc_score, confusion_matrix

from uncertify.models.vae import load_vae_baur_model
from uncertify.models.vae import VariationalAutoEncoder
from uncertify.models.encoder_decoder_baur2020 import BaurDecoder, BaurEncoder
from uncertify.data.dataloaders import dataloader_factory, DatasetType
from uncertify.evaluation.inference import yield_inference_batches, yield_y_true_y_pred
from uncertify.evaluation.ood_metrics import sample_wise_waic_scores
from uncertify.visualization.grid import imshow_grid
from uncertify.visualization.model_performance import plot_roc_curve, plot_precision_recall_curve, plot_confusion_matrix
from uncertify.visualization.plotting import setup_plt_figure
from uncertify.visualization.histograms import plot_multi_histogram
from uncertify.data.datasets import GaussianNoiseDataset

from uncertify.common import DATA_DIR_PATH

In [None]:
# High-level configuration: where the training runs live and which versions to load.
# NOTE(review): RUN_DIR_PATH is an absolute, machine-specific location — adjust when running elsewhere.
RUN_DIR_PATH = Path('/media/juniors/2TB_internal_HD/lightning_logs/train_vae/')
RUN_VERSIONS = [0, 1, 2]
# One "last" checkpoint per run version.
CHECKPOINT_PATHS = [
    RUN_DIR_PATH / f'version_{run_version}' / 'checkpoints' / 'last.ckpt'
    for run_version in RUN_VERSIONS
]
# NOTE(review): not referenced by any of the visible cells.
USE_N_BATCHES = 3

In [None]:
# Load one model per checkpoint; together they form the ensemble.
ensemble_models = list(map(load_vae_baur_model, CHECKPOINT_PATHS))

In [None]:
batch_size = 64

# Directories holding the preprocessed HDF5 datasets.
# NOTE(review): the HDD path is machine-specific and not used by the visible cells.
HDD_PROCESSED_DIR_PATH = Path('/media/juniors/2TB_internal_HD/datasets/processed/')
SSD_PROCESSED_DIR_PATH = DATA_DIR_PATH / 'processed'

# BraTS validation loaders: T2 in fixed order, T1 shuffled. The first returned
# loader is discarded in both cases.
_, brats_val_t2_dataloader = dataloader_factory(
    DatasetType.BRATS17,
    batch_size=batch_size,
    val_set_path=SSD_PROCESSED_DIR_PATH / 'brats17_t2_hm_bc_std_bv-3.5.hdf5',
    shuffle_val=False,
)
_, brats_val_t1_dataloader = dataloader_factory(
    DatasetType.BRATS17,
    batch_size=batch_size,
    val_set_path=SSD_PROCESSED_DIR_PATH / 'brats17_t1_hm_bc_std_bv-3.5.hdf5',
    shuffle_val=True,
)

# CamCAN train (shuffled) and validation (fixed order) loaders.
camcan_train_dataloader, camcan_val_dataloader = dataloader_factory(
    DatasetType.CAMCAN,
    batch_size=batch_size,
    val_set_path=SSD_PROCESSED_DIR_PATH / 'camcan_val_t2_hm_std_bv3.5_xe.hdf5',
    train_set_path=SSD_PROCESSED_DIR_PATH / 'camcan_train_t2_hm_std_bv3.5_xe.hdf5',
    shuffle_val=False,
    shuffle_train=True,
)

# Gaussian noise dataset wrapped in a plain DataLoader.
noise_set = GaussianNoiseDataset()
noise_loader = DataLoader(noise_set, batch_size=batch_size)

# MNIST validation loader; images are resized to 128x128 before conversion to tensors.
_, mnist_val_dataloader = dataloader_factory(
    DatasetType.MNIST,
    batch_size=batch_size,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize((128, 128)),
        torchvision.transforms.ToTensor(),
    ]),
)
# Display name -> dataloader for every dataset that gets evaluated.
dataloader_dict = {
    'BraTS T2 val': brats_val_t2_dataloader,
    'BraTS T1 val': brats_val_t1_dataloader,
    'CamCAN train': camcan_train_dataloader,
    'Gaussian noise': noise_loader,
    'MNIST': mnist_val_dataloader,
}

# Print a one-line size summary per dataloader.
for name, dataloader in dataloader_dict.items():
    n_batches = len(dataloader)
    # Batches times batch size; this may over-count if the final batch is not full.
    n_samples = n_batches * dataloader.batch_size
    print(f'{name:15} dataloader: {n_batches} batches (batch_size: {dataloader.batch_size}) -> {n_samples} samples.')

In [None]:
# Upper limit on the number of batches scored per dataset.
NUM_BATCHES = 10
# Misspelled original name, kept as an alias in case another cell still refers to it.
NUM_BACTHES = NUM_BATCHES

# Dataset display name -> WAIC scores returned by `sample_wise_waic_scores` for that dataset.
waic_dict = {}
for name, data_loader in dataloader_dict.items():
    # The count is an upper bound: a loader may provide fewer than NUM_BATCHES full batches.
    LOG.info(f'WAIC score calculation for {name} (up to {NUM_BATCHES * data_loader.batch_size} samples)...')
    waic_scores = sample_wise_waic_scores(models=ensemble_models, data_loader=data_loader,
                                          max_n_batches=NUM_BATCHES)
    waic_dict[name] = waic_scores

In [None]:
# Draw the WAIC scores of all datasets as one multi-histogram figure, labelled by dataset name.
fig, _ = plot_multi_histogram(waic_dict.values(), list(waic_dict.keys()), plot_density=False,
                              figsize=(12, 6), xlabel='WAIC', ylabel='Slice-wise frequency',
                              hist_kwargs={'bins': 15})

# `savefig` does not create missing directories, so make sure the target folder exists first.
# Assumes DATA_DIR_PATH is a pathlib.Path (it is combined with `/` throughout this notebook).
plots_dir_path = DATA_DIR_PATH / 'plots'
plots_dir_path.mkdir(parents=True, exist_ok=True)
fig.savefig(plots_dir_path / 'waic_scores.png')

In [None]:
from dataclasses import dataclass, field
from typing import List

In [None]:
@dataclass
class Test:
    """Scratch dataclass with a single list-of-floats field that defaults to an empty list."""

    # `default_factory=list` builds a new list for every instance instead of sharing one default.
    a: List[float] = field(default_factory=list)

In [None]:
# Instantiate without arguments so that `a` comes from its default factory (an empty list).
t = Test()

In [None]:
# Extends the list in place, so re-running this cell appends 2 and 3 again.
t.a.extend([2, 3])

In [None]:
# Show the current contents of the list field.
t.a

In [None]:
from uncertify.evaluation.inference import AnomalyInferenceScores, AnomalyScores, SliceWiseAnomalyScores, SliceWiseCriteria

In [None]:
    # Build the score container: one AnomalyScores object plus one
    # SliceWiseAnomalyScores entry for every member of SliceWiseCriteria.
    # NOTE(review): this cell is indented as a whole; that appears to be tolerated
    # by IPython but would be an IndentationError in a plain Python script.
    anomaly_scores = AnomalyInferenceScores(
        AnomalyScores(),
        [SliceWiseAnomalyScores(criterion) for criterion in SliceWiseCriteria],
    )

In [None]:
from pprint import pprint

# Pretty-print the instance attributes of the score container.
pprint(vars(anomaly_scores))

In [None]:
import torch

In [None]:

# Two toy 3x3 binary maps, each wrapped in a singleton dimension, giving a
# (2, 1, 3, 3) tensor — presumably laid out as (batch, channel, height, width).
l = [
    [[[0, 1, 1], [0, 0, 0], [1, 0, 1]]],
    [[[0, 0, 1], [0, 0, 0], [1, 0, 1]]],
]

t = torch.tensor(l)

In [None]:
# Inspect the shape of the tensor.
t.size()

In [None]:
# Count the strictly positive entries of every element along the first dimension
# by summing the boolean mask over the remaining three dimensions.
n_abnormal_pixels = (t > 0).sum(dim=(1, 2, 3))

In [None]:
# Display the resulting counts (one entry per element along the first dimension).
n_abnormal_pixels

In [None]:
from uncertify.evaluation.inference import yield_inference_batches

In [None]:
# Run inference on a single BraTS T2 validation batch with the first ensemble model.
# `s`, `y` and `r` deliberately stay in the notebook namespace for the cells below.
for batch in yield_inference_batches(brats_val_t2_dataloader, ensemble_models[0], max_batches=1):
    # Number of positive segmentation entries per element along the first dimension.
    s = torch.sum(batch.segmentation > 0, dim=(1, 2, 3))
    # Segmentation values selected by the batch mask, flattened to 1-D. Computed once
    # and reused (previously the same selection was evaluated twice).
    r = batch.segmentation[batch.mask].flatten()
    # NumPy view of the same values; it shares memory with `r`.
    y = r.numpy()
    print(r)

In [None]:
# Print the counts as plain Python ints. `list(s.numpy())` would hold NumPy scalars,
# which NumPy >= 2 prints as `np.int64(...)` instead of bare numbers.
print(s.tolist())

In [None]:
# Print the NumPy array of mask-selected segmentation values kept from the last processed batch.
print(y)