# ASOS
## Imports

In [None]:
import random

import matplotlib.pyplot as plt
import torch

from tlib import tlearn
from asos import settings, utils

%load_ext autoreload
%autoreload 2

In [None]:
# Per-file metadata table (presumably a pandas DataFrame): later cells filter it
# on its 'dataset' and 'true_pred' columns and use its index as file ids.
file_infos = settings.load_file_infos().df

## Setup ASOS

In [None]:
# Set up the ASOS performer. Its dimensionality equals the number of U-Net
# maps the trained model outputs (hparams['n_unet_maps']).
dims = utils.load_model().hparams['n_unet_maps']
use_hypercube = True  # hypercube method (True) or expectation maximization (False)
ax_range = (-1, 1)  # axis range handed to the performer

# (use_hypercube, dims) -> performer class. A lookup table replaces the two
# nested if/elif chains. Unsupported dims now fail loudly: the old chains
# left `asos` unbound (NameError below), or, on a notebook re-run, silently
# re-saved a stale performer from an earlier run.
_asos_module = tlearn.interpret.asos
_performer_classes = {
    (True, 1): _asos_module.ASOSPerformer1d,
    (True, 2): _asos_module.ASOSPerformer2d,
    (True, 3): _asos_module.ASOSPerformer3d,
    (False, 1): _asos_module.ASOSPerformerEM1d,
    (False, 2): _asos_module.ASOSPerformerEM2d,
    (False, 3): _asos_module.ASOSPerformerEM3d,
}
try:
    _performer_cls = _performer_classes[(bool(use_hypercube), dims)]
except KeyError:
    raise ValueError(
        f"No ASOS performer for n_unet_maps={dims!r}; expected 1, 2 or 3"
    ) from None

asos = _performer_cls(ax_range=ax_range, output_folder=settings.working_folder)
asos.save()  # save asos with pickle

## Vectorization

In [None]:
# Training files flagged with true_pred in file_infos.
files = file_infos[(file_infos['dataset'] == 'train') & (file_infos['true_pred'])].index.to_list()
print(f"train files with true_pred: {len(files)}")

# Predicting every map at once would overflow memory, so only a random
# fraction of the files is vectorized.
frac_unet_maps = 0.15
seed = None  # set to an int for a reproducible subset; None uses the global RNG as before

n_sampled = int(len(files) * frac_unet_maps)
if n_sampled == 0:
    # Otherwise utils.predict() would be called with no files at all.
    raise ValueError(
        f"frac_unet_maps={frac_unet_maps} of {len(files)} files selects no maps; "
        "increase frac_unet_maps or check the file_infos filter"
    )

# random.sample draws positions the same way whether given the list or
# range(len(files)), so sampling the list directly drops the index bookkeeping.
sampler = random if seed is None else random.Random(seed)
files = sampler.sample(files, n_sampled)
print(f"sampled files: {len(files)}")

unet_maps = utils.predict(*files)

# Vectorize the sampled maps.
frame_size = 10
random_frac = 1/1000
asos.vectorize(maps=unet_maps, map_ids=files, frame_size=frame_size, random_frac=random_frac)
asos.save()  # save asos with pickle

# Free the predicted maps; later cells compute maps lazily instead.
del unet_maps

In [None]:
%matplotlib inline
# Plot the vectors with plot_chspace for 1-d / 2-d performers (inline figure);
# 3-d performers are plotted in the next cell with the widget backend.
if asos.dims == 1 or asos.dims == 2:
    asos.plot_chspace()
    plt.show()

In [None]:
%matplotlib widget
# 3-d performers: plotted with the interactive widget backend.
show_3d = asos.dims == 3
if show_3d:
    # pass colors=None instead to plot the vectors without rgb colouring
    asos.plot_chspace(colors='rgb')
    plt.show()

## Groups

In [None]:
# Partition the vectors into groups with the method chosen in the setup cell.
if not use_hypercube:
    # Expectation-maximization method: fixed number of groups.
    asos.fit_groups(n_groups=3)
else:
    # Hypercube method. 2/10 is one tenth of the width of ax_range=(-1, 1)
    # used in the setup cell (presumably ~10 cubes per axis).
    edge_length = 2/10
    consider_factor = 2
    asos.fit_groups(edge_length=edge_length, consider_factor=consider_factor)

asos.save()  # persist asos via pickle

In [None]:
%matplotlib inline
# Vectors coloured by their fitted group, 1-d / 2-d performers (inline figure).
if asos.dims in (1, 2):
    asos.plot_chspace(colors='groups')
    plt.show()

In [None]:
%matplotlib widget
# Vectors coloured by their fitted group, 3-d performers (widget backend).
is_three_dimensional = asos.dims == 3
if is_three_dimensional:
    asos.plot_chspace(colors='groups')
    plt.show()

## Sensitivities

In [None]:
# Predicting all U-Net maps up front, e.g.
#     files = file_infos[(file_infos['dataset'] == 'train') & (file_infos['true_pred'])].index.to_list()
#     unet_maps = utils.predict(*files)
# would overflow memory for the full training set. UNetMaps instead behaves
# like a read-only list (via __getitem__/__len__) and computes each map on demand.

class UNetMaps:
    """Lazy, list-like collection of U-Net maps over a dataset.

    Indexing runs the U-Net on one sample, so only the requested map is held
    in memory. Maps are not cached; every access recomputes them.
    """

    def __init__(self, dataset=None, unet=None):
        """Create the view.

        Args:
            dataset: indexable dataset whose items are dicts with an ``'x'``
                tensor. Defaults to the ``train_dataset`` of
                ``settings.load_datamodule(setup_stage='fit', cutmix=None)``.
            unet: callable applied to a single-sample batch ``x.unsqueeze(0)``.
                Defaults to ``utils.load_model().unet``.
        """
        if dataset is None:
            dataset = settings.load_datamodule(setup_stage='fit', cutmix=None).train_dataset
        if unet is None:
            unet = utils.load_model().unet
        self.dataset = dataset
        self.unet = unet

    def __getitem__(self, index):
        """Return the U-Net map of sample ``index`` as a CPU tensor (leading batch axis removed)."""
        x = self.dataset[index]['x']
        # Inference only: without no_grad an autograd graph was built for every
        # sample and only discarded by the later detach().
        # NOTE(review): the unet is not switched to eval() here; confirm
        # utils.load_model() already returns it in eval mode.
        with torch.no_grad():
            unet_map = self.unet(x.unsqueeze(0))
        return unet_map.detach().cpu()[0]

    def __len__(self):
        """Number of samples in the underlying dataset."""
        return len(self.dataset)

# List-like view over the training set; each map is computed when indexed (not cached).
unet_maps = UNetMaps()

In [None]:
%matplotlib inline

# Classifier that operates on U-Net maps.
model = utils.load_model().classify_unet_map

# Fit sensitivities over the lazily computed training maps.
# fill_value=0 is presumably the value written into occluded regions.
asos.fit_sensitivities(
    maps=unet_maps,
    model=model,
    fill_value=0,
    move_data_to_gpu=True,
)
asos.save()  # persist asos via pickle

In [None]:
# Keep only deviations that were computed from at least this many occluded
# pixels when occluding a map; all further calculations use only these.
min_n_occluded_pixels = 100
asos.adapt_valid_deviations(min_n_occluded_pixels=min_n_occluded_pixels)

# Value passed to set_vlim (presumably a quantile for the colour limits — TODO confirm).
q = 0.05
asos.set_vlim(q=q)

asos.save()  # persist asos via pickle

In [None]:
%matplotlib inline
# Plot the histograms. Like every other plotting cell, end with plt.show() so
# the figure is rendered regardless of what plot_histograms returns.
asos.plot_histograms()
plt.show()

In [None]:
%matplotlib inline
# Vectors coloured by their sensitivities, 1-d / 2-d performers (inline figure).
if asos.dims == 1 or asos.dims == 2:
    asos.plot_chspace(colors='sensitivities')
    plt.show()

In [None]:
%matplotlib widget
# Vectors coloured by their sensitivities, 3-d performers (widget backend).
plot_in_3d = asos.dims == 3
if plot_in_3d:
    asos.plot_chspace(colors='sensitivities')
    plt.show()

In [None]:
%matplotlib inline
# Plot the sensitivity map of a single training sample.
index = 100
sample = unet_maps[index].unsqueeze(0)  # add a leading (batch) axis
sensitivity_map = asos.predict_sensitivities(sample)[0]
asos.plot_sensitivity_map(sensitivity_map)
plt.show()