### Colab setup

<a href="https://colab.research.google.com/github/e-cremente/physioex/blob/main/examples/freq_bands_importance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# colab setup
from google.colab import drive
drive.mount("/content/drive")

import os
working_dir = "/content/drive/MyDrive/Thesis"
os.chdir( working_dir )

!git clone https://github.com/e-cremente/physioex.git
%cd physioex

!git pull origin main
!pip install -e .

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from physioex.explain.freq_bands_explainer import FreqBandsExplainer
from loguru import logger

# Checkpoint directory handed to the explainer; the plotting cell further down
# also reads the generated importance images from this same location.
ckp_path = "models/cel/chambon2018/seqlen=3/dreem/dodh/"

# Keyword arguments for the explainer, gathered in one place for readability.
# NOTE (e.c.): `n_jobs` is intentionally absent here. Passing it raises an error
# because the FreqBandsExplainer constructor does not declare that parameter.
# NOTE(review): `sequence_lenght` is spelled exactly as in the original call;
# presumably it matches the constructor's parameter name - confirm before renaming.
explainer_kwargs = dict(
    model_name="chambon2018",
    dataset_name="dreem",
    version="dodh",
    use_cache=True,
    sequence_lenght=3,
    ckp_path=ckp_path,
    batch_size=32,
)

expl = FreqBandsExplainer(**explainer_kwargs)

In [None]:
# Frequency bands to explain, keyed by band name; each value is the [low, high]
# pair that is passed to `expl.explain`.
# Gamma was changed from [30, 50] to [30, 49.5]: per the original author's note
# (e.c.), with a 100 Hz sampling rate a filter cutting off 50 Hz could not be applied.
sleep_bands = {
    "Alpha": [8, 12],
    "Beta": [12, 30],
    "Delta": [0.5, 4],
    "Theta": [4, 8],
    "Gamma": [30, 49.5],
}

for band_name, band_range in sleep_bands.items():
    logger.info("Explaining band: {}".format(band_name))

    # The number of jobs is specified here, as an argument of explain(),
    # rather than in the FreqBandsExplainer constructor (e.c.).
    expl.explain(
        band_range,
        band_name=band_name,
        save_csv=True,
        plot_pred=True,
        plot_true=True,
        n_jobs=1,
    )

    # Show the fold-0 importance images, first "true" then "pred".
    # The file name embeds str(band_range) (e.g. "[8, 12]"), not the band name:
    # the original author switched to this to load the correct images.
    for target in ("true", "pred"):
        img = mpimg.imread(
            ckp_path + "fold=0_" + target + "_band=" + str(band_range) + "_importance.png"
        )
        plt.imshow(img)
        # Label each figure so the two images per band can be told apart.
        plt.title("{} band {} - fold 0 ({})".format(band_name, band_range, target))
        # The image is an already rendered plot; pixel-coordinate ticks add only noise.
        plt.axis("off")
        plt.show()
