In [None]:
# Automatically reload edited project modules (e.g. decode_tueg) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
sys.path.insert(0, '/home/jovyan/braindecode/')
sys.path.insert(0, '/home/jovyan/mne-python/')
import pickle

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error

from braindecode.datasets.tuh import TUHAbnormal

from decode_tueg import (
    DataScaler, TargetScaler, trial_age_mae, Augmenter, ChannelsSymmetryFB,
    get_datasets, test_name, 
    create_windows, _create_windows,
    create_final_scores, _create_final_scores,
    plot_chronological_vs_predicted_age, plot_age_gap_hist,
    plot_thresh_to_acc,
)

In [None]:
# Run directories under `base_dir`, keyed by an integer index (presumably the
# validation fold, cf. `config.valid_set_i` -- TODO confirm).
# Insertion order is significant: the first pair (3) is the default run.
exps = dict([
    (3, '220824144139015381'),
    (1, '220824191628058941'),
    (4, '220824144141880751'),
    (2, '220824144128724203'),
    (0, '220824144055632350'),
])
# Root folder containing one sub-directory per run id.
base_dir = '/home/jovyan/new_full_runs/'
# Which checkpoint to load (checkpoint/<model>_model.pkl); alternative: 'train_end'.
model = 'valid_best'

In [None]:
# When True, only the validation/test split is evaluated and plotted: the
# 'train' split is skipped and the figure grids get a single row/column for it.
only_valid = True

In [None]:
# Number of evaluated splits: the validation/test split, plus 'train' unless
# only_valid is set.
n_split_axes = 1 if only_valid else 2

# fig: rows = splits, columns = subsets (chronological vs. predicted target).
fig, ax_arr = plt.subplots(
    nrows=n_split_axes, ncols=3, figsize=(20, 6),
    squeeze=False, sharex=True, sharey=True,
)
# fig2: rows = subsets, columns = splits (age-gap histograms).
fig2, ax_arr2 = plt.subplots(
    nrows=3, ncols=n_split_axes, figsize=(20, 6),
    squeeze=False, sharex=True, sharey=True,
)
# fig3: rows = subsets, columns = splits (plot_thresh_to_acc output); y not shared.
fig3, ax_arr3 = plt.subplots(
    nrows=3, ncols=n_split_axes, figsize=(20, 6),
    squeeze=False, sharex=True,
)

In [None]:
# Select the run to evaluate: by default the first entry of `exps` (replace
# with e.g. `exp_i = 0` to pick a specific run). Unlike a `for ...: break`
# loop, this raises immediately on an empty `exps` instead of silently keeping
# a stale `exp` left over from an earlier kernel state.
exp_i = next(iter(exps))
exp = exps[exp_i]

In [None]:
# Load the trained model, the data/target scalers and the run config of the
# selected run. NOTE: pickle executes arbitrary code on load -- only point
# this at run directories you produced yourself.
exp_dir = os.path.join(base_dir, exp)


def _load_run_pickle(rel_path):
    """Unpickle `rel_path` relative to the selected run directory."""
    full_path = os.path.join(exp_dir, rel_path)
    with open(full_path, 'rb') as fh:
        return pickle.load(fh)


clf = _load_run_pickle(f'checkpoint/{model}_model.pkl')
data_scaler = _load_run_pickle('data_scaler.pkl')
target_scaler = _load_run_pickle('target_scaler.pkl')
# config.csv holds one value per row; squeeze() turns the frame into a Series
# so fields can be read as attributes (config.data_path, config.seed, ...).
config = pd.read_csv(os.path.join(exp_dir, 'config.csv'), index_col=0).squeeze()

In [None]:
# Evaluate the loaded model on every subset and fill the figure grids.
# `subset_i` selects the subset axis of ax_arr / ax_arr2 / ax_arr3, `ds_i` the
# split axis (0: validation/test split, 1: train).
subsets = ['normal', 'mixed', 'abnormal']
for subset_i, subset in enumerate(subsets):
    train, valid, mapping = get_datasets(
        config.data_path,
        config.target_name,
        subset,  # maybe skip the subset used for training. could serve as sanity check though
        int(config.n_train_recordings),
        int(config.tmin),
        int(config.tmax),
        int(config.n_jobs),
        int(config.final_eval),
        # NOTE(review): every other numeric config field is cast with int();
        # confirm get_datasets really expects a float for valid_set_i.
        float(config.valid_set_i),
        int(config.seed),
    )
    # Only referenced by the commented-out `dummy` baseline below. It does not
    # depend on the split loop, so compute it once per subset.
    mean_train_age = train.description['age'].mean()

    # Probe the network with a dummy window to find out how many predictions it
    # emits per input window (size of output dim 2). Run the probe in eval mode
    # and without autograd so it cannot update e.g. batch-norm running
    # statistics or build a graph. Put the probe on the model's own device
    # instead of assuming the default CUDA device.
    n_channels = train[0][0].shape[0]
    clf.module.eval()
    device = next(clf.module.parameters()).device
    with torch.no_grad():
        t = torch.ones(1, n_channels, int(config.window_size_samples), 1, device=device)
        n_preds_per_input = clf.module(t).size()[2]

    # order due to indexing ax_arr for plotting
    for ds_i, (ds_name, ds) in enumerate([(test_name(int(config.final_eval)), valid), ('train', train)]):
        if only_valid and ds_name == 'train':
            print("skipping train")
            continue
        # Cut the recordings into model-sized windows. n_preds_per_input
        # presumably controls the window stride -- TODO confirm in _create_windows.
        ds = _create_windows(
            mapping,
            ds,
            int(config.window_size_samples),
            n_channels,
            int(config.n_jobs),
            int(config.preload),
            n_preds_per_input,
        )

        # Apply the training-time scalers when windows/targets are accessed.
        ds.target_transform = target_scaler
        ds.transform = data_scaler

        # One prediction array per trial; average over its last axis to get a
        # single value per trial, then map back to the original target scale.
        preds, targets = clf.predict_trials(ds)
        preds = np.array([p.mean(-1) for p in preds])
        preds = target_scaler.invert(preds)
        targets = target_scaler.invert(targets)
        score = mean_absolute_error(targets, preds)

        df = pd.DataFrame({
            'y_pred': preds.ravel(),
            'y_true': targets.ravel(),
            'pathological': ds.description['pathological'].to_numpy(),
        })

        title = f"{subset}, {ds_name}, {config.target_name}, mae, {score:.2f}"
        ax = plot_chronological_vs_predicted_age(
            df,
            dummy=None,  # mean_train_age,
            ax=ax_arr[ds_i, subset_i],
        )
        ax.set_title(title)

        ax2 = plot_age_gap_hist(
            df,
            ax=ax_arr2[subset_i, ds_i],
        )
        ax2.set_title(title)

        ax3 = plot_thresh_to_acc(
            df,
            ax=ax_arr3[subset_i, ds_i],
        )
        ax3.set_title(title)

In [None]:
# Chronological vs. predicted target, one panel per split/subset.
fig

In [None]:
# Age-gap histograms (plot_age_gap_hist), one panel per subset/split.
fig2

In [None]:
# Output of plot_thresh_to_acc, one panel per subset/split.
fig3