## Model training on spectral line Mg II h&k - Example notebook


Solar flares release their energy when twisted coronal magnetic field loops reconnect into an energetically more favorable configuration. The twisting of these loops is driven by convective motions below the $\beta=1$ surface, where the gas pressure dominates over the magnetic pressure ($\beta>1$). Because of the frozen-in condition, the magnetic field can only reconnect and reconfigure above the $\beta=1$ surface. However, it is poorly understood how magnetic energy builds up over time in emerging field loops without undergoing early reconnection, and which conditions are necessary for reconnection to trigger a solar flare. To date, no reliable data-driven three-dimensional magnetohydrodynamic (3D MHD) model exists that can follow the evolution of an active region and predict flares from observed initial conditions, and the computational cost of such a model would exceed current capabilities. We therefore search for early signs of an imminent eruption in spectra of Mg II h&k, Si IV and C II recorded by the Small Explorer satellite IRIS (Interface Region Imaging Spectrograph). Spectra contain height-dependent information about the physical state of the atmospheric layer in which they form. Mg II h&k forms between the lower and the upper chromosphere and, in explosive events, even in the transition region. Si IV and C II form in the transition region, which lies between the chromosphere and the corona.

### Fully connected neural networks

<img src="NN_scheme.png" width="600" height="450">

Artificial neural networks (ANNs) were first proposed in 1943 by Warren McCulloch and Walter Pitts \citep{fitch_1944}, but only became widely used after backpropagation was popularized \citep{Rumelhart_1986} and, in 2012, advances in computing power made it possible to train deep networks that reliably classified the images of the ImageNet dataset \citep{ImageNet_2012}. The classical deep neural network is a fully connected network \citep{bishop_1995}. It consists of an input layer containing the data to be classified or approximated, at least one hidden layer with an arbitrary number of neurons, and an output layer whose number of neurons matches the classification problem. Fully connected means that each neuron is connected to every neuron in the previous and the subsequent layer. For a binary classification task the output layer consists of one or two neurons. Each neuron is equipped with an activation function, typically a rectified linear unit (ReLU)
\begin{equation}  \label{int_num_test}  
\phi(x)  = \begin{cases} 
0 & x \le 0\\
x & x > 0\\ 
\end{cases}
\end{equation}
for hidden layers and a sigmoid function 
\begin{equation}
   \phi(x) = \frac{1}{1+e^{-x}}
\end{equation}

for the output neurons. According to the universal approximation theorem \citep{ZHOU_2020}, a neural network with a single hidden layer of sufficiently many neurons with non-linear activation functions can approximate any continuous function on a compact domain to arbitrary accuracy. In practice, however, training a neural network is difficult: one has to carefully select the architecture, optimizer and loss function to balance convergence, training time and generalization, i.e. the performance of the network on unseen data. This selection is empirical, and to our knowledge no general solution to this ambiguity has been found yet. In classification, a neural network essentially warps and projects the input space through non-linear transformations into lower-dimensional representations in which the data can be separated by a hyperplane or a set of hyperplanes into the desired number of classes. The weights and biases are initialized randomly and updated after each forward pass through backpropagation: the output of the forward pass is compared to the desired output, or label, using a loss function. The loss depends on the weights and biases, i.e. the parameters of the network, and therefore changes when they are adjusted. To improve the prediction, we compute the gradient of the loss with respect to the weights and biases and update them in the direction that decreases the loss. Since computing the gradient over the whole dataset at every step is too expensive, we estimate it from a representative random sample of the data, called a minibatch, and update the parameters after each minibatch. The minibatch size can be tuned by monitoring the noise in the loss; we chose sizes between 2400 and 6400 spectra per forward pass.
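A minimal PyTorch sketch of such a fully connected binary classifier: ReLU activations in the hidden layers, one sigmoid output neuron, and a single minibatch update with binary cross entropy and plain stochastic gradient descent. The hidden-layer sizes, the learning rate, the small batch size and the random data are illustrative assumptions; the input length of 960 mirrors `n_breaks` in the line parameters of this notebook. This is not the architecture used in this work.

```python
import torch
from torch import nn

torch.manual_seed(0)

n_inputs = 960    # assumed spectrum length (one value per wavelength bin)
batch_size = 64   # illustrative; the text uses 2400-6400 spectra per minibatch

# Fully connected network: every neuron is connected to all neurons of the adjacent layers.
model = nn.Sequential(
    nn.Linear(n_inputs, 128),
    nn.ReLU(),                # ReLU activation in the hidden layers
    nn.Linear(128, 32),
    nn.ReLU(),
    nn.Linear(32, 1),
    nn.Sigmoid(),             # sigmoid output: score for the positive class
)

criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# One random minibatch of synthetic "spectra" with binary labels.
x = torch.randn(batch_size, n_inputs)
y = torch.randint(0, 2, (batch_size,)).float()

optimizer.zero_grad()
y_hat = model(x).squeeze(1)   # forward pass, shape (batch_size,)
loss = criterion(y_hat, y)    # compare the output with the labels
loss.backward()               # backpropagation: gradients w.r.t. weights and biases
optimizer.step()              # step the parameters against the gradient
```

Repeating the last five lines for every minibatch of the training set gives one epoch of minibatch gradient descent.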
To decide how the parameters are updated so that training converges towards a minimum of the loss, we need to select an appropriate optimizer. We used variants of Adam \citep{Adam}, an algorithm with adaptive per-parameter learning rates that shows fast and stable convergence, with and without weight decay. Weight decay is controlled by a parameter set a priori and is meant to prevent the model from overfitting. Overfitting is a state in which the neural network "memorizes" the training data: it performs well on them but does not generalize well to new data. To monitor training, we compare the loss on the training set with the loss on the validation set, using about 75% of the data for training and 25% for validation. Once the validation loss starts to increase relative to the training loss while the training loss keeps decreasing or stays constant, the network has started to overfit. To follow the validation performance in parallel with training, we split the validation set into as many minibatches as the training set, evaluate the model on the next validation minibatch after every training step, and record the resulting learning curve. By monitoring the learning curve we can detect overfitting and define criteria for early stopping, i.e. falling back on the best model state before overfitting set in. The learning curve shows the loss per minibatch: the black line represents the loss on the training set, the orange line the loss on the validation set, and the blue line the accuracy on the current validation minibatch. If the minibatches are not representative enough, the learning curve becomes noisy and shows outliers. The same happens if the learning rate is too high, since the optimizer then takes steps that are too large in parameter space and keeps overshooting the minimum.
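The early-stopping idea can be sketched in plain Python. The `EarlyStopping` class, its `patience` rule and the synthetic loss sequence are hypothetical illustrations, not the criterion used in this notebook. In PyTorch, the stored state would be a copy of `model.state_dict()`, restored with `model.load_state_dict(...)`, and the optimizer could be `torch.optim.Adam(model.parameters(), lr=..., weight_decay=...)` or `torch.optim.AdamW`, which applies decoupled weight decay.

```python
import copy


class EarlyStopping:
    """Track the validation loss and keep a copy of the best model state.

    Hypothetical criterion: stop once the validation loss has not improved
    by more than `min_delta` for `patience` consecutive checks.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_state = None
        self.bad_checks = 0

    def step(self, val_loss, model_state):
        """Record one validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            # Deep copy so later parameter updates do not overwrite the stored state.
            self.best_state = copy.deepcopy(model_state)
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


# Synthetic learning curve: the validation loss decreases, then rises (overfitting).
val_losses = [0.69, 0.55, 0.47, 0.44, 0.45, 0.48, 0.52, 0.57]

stopper = EarlyStopping(patience=3)
for check, val_loss in enumerate(val_losses):
    fake_state = {"step": check}  # stands in for model.state_dict()
    if stopper.step(val_loss, fake_state):
        break

print(check, stopper.best_loss, stopper.best_state)  # -> 6 0.44 {'step': 3}
```

With `patience=3`, training stops three checks after the lowest validation loss, and the state from that check is kept as the best model.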
One complete pass of the model through the entire training set is called an epoch. After each epoch the minibatches are reshuffled and the next epoch starts.
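In PyTorch, this reshuffling is what `torch.utils.data.DataLoader` does when `shuffle=True`: each epoch visits every sample exactly once, in a new random order and grouped into new minibatches. A toy sketch (the sizes are arbitrary; the sample index doubles as the label so that the order can be tracked):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Ten toy "spectra" of length 4; the second tensor stores each sample's index.
spectra = torch.randn(10, 4)
indices = torch.arange(10)
loader = DataLoader(TensorDataset(spectra, indices), batch_size=4, shuffle=True)

epoch_orders = []
for epoch in range(3):
    seen = []
    # shuffle=True draws a new permutation at the start of every epoch
    for batch_spectra, batch_indices in loader:
        seen.extend(batch_indices.tolist())
    epoch_orders.append(seen)

for order in epoch_orders:
    print(order)
```

With 10 samples and `batch_size=4`, each epoch consists of three minibatches; the last one holds the remaining two samples unless `drop_last=True` is set.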
It is better still to reduce overfitting in the first place by applying techniques such as batch normalization, dropout and the previously mentioned weight decay. Batch normalization normalizes each feature over a minibatch according to
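\begin{equation}
    \hat{x}_{i}^{(k)}=\frac{x_{i}^{(k)}-\mu_{B}^{(k)}}{\sqrt{\left(\sigma_{B}^{(k)}\right)^{2}+\epsilon}},
\end{equation}
where $k \in[1, d]$ indexes the $d$ features, $i \in[1, m]$ the $m$ samples of the minibatch $B$, $\epsilon$ is a small constant for numerical stability, and
\begin{equation}
    \mu_{B}^{(k)}=\frac{1}{m} \sum_{i=1}^{m} x_{i}^{(k)}, \text { and } \left(\sigma_{B}^{(k)}\right)^{2}=\frac{1}{m} \sum_{i=1}^{m}\left(x_{i}^{(k)}-\mu_{B}^{(k)}\right)^{2}.
\end{equation}
In the standard formulation, the normalized values are then scaled and shifted by two learnable parameters per feature. Batch normalization keeps the gradients from exploding due to outliers and typically leads to faster and smoother convergence. Dropout randomly sets a fraction of the units (or channels) to zero during each training pass, so that they do not contribute to the output; this keeps the network from relying on a few units and spreads the training over all hidden units. Weight decay adds the regularization term $\frac{\lambda}{2}\|\mathbf{w}\|^{2}$ to the loss, where $\mathbf{w}$ denotes the network weights and $\lambda$ the weight-decay parameter. We use binary cross entropy as our loss function, averaged over each minibatch: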
\begin{equation}
    L= \frac{1}{m} \sum_{i=1}^{m}-w_{i}\left[y_{i} \cdot \log x_{i}+\left(1-y_{i}\right) \cdot \log \left(1-x_{i}\right)\right],
\end{equation}
with $x_i$ the model output, $y_i$ the correct label and $w_i$ an optional weight of spectrum $i$. To keep track of the best version of the network during training, we measure the model's performance on the testing set after each training epoch and store the model if it scores higher than the previously stored one. For very large testing sets this can increase the training time significantly; we therefore estimate the score from at most $200{,}000$ samples and compute the score on the full testing set only for the final model.
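The batch-normalization and loss equations can be checked numerically. This NumPy sketch evaluates them on random data and compares the results with PyTorch's `torch.nn.functional.batch_norm` (called without learnable scale and shift) and `torch.nn.functional.binary_cross_entropy` with per-sample weights; the array sizes and random values are arbitrary.

```python
import numpy as np
import torch
import torch.nn.functional as F

rng = np.random.default_rng(0)
m, d = 8, 3                      # minibatch size m, number of features d
x = rng.normal(size=(m, d))
eps = 1e-5

# Batch normalization: per-feature mean and (biased) variance over the minibatch.
mu_B = x.mean(axis=0)
var_B = ((x - mu_B) ** 2).mean(axis=0)
x_hat = (x - mu_B) / np.sqrt(var_B + eps)

# Reference: PyTorch batch norm in training mode, no running statistics, no scale/shift.
x_hat_torch = F.batch_norm(torch.from_numpy(x), None, None, training=True, eps=eps).numpy()

# Weighted binary cross entropy averaged over the minibatch.
p = rng.uniform(0.05, 0.95, size=m)            # model outputs x_i in (0, 1)
y = rng.integers(0, 2, size=m).astype(float)   # labels y_i
w = rng.uniform(0.5, 2.0, size=m)              # per-spectrum weights w_i
bce = np.mean(-w * (y * np.log(p) + (1 - y) * np.log(1 - p)))

bce_torch = F.binary_cross_entropy(
    torch.from_numpy(p), torch.from_numpy(y), weight=torch.from_numpy(w)
).item()

print(np.allclose(x_hat, x_hat_torch), np.isclose(bce, bce_torch))
```

With the default `reduction='mean'`, PyTorch divides the weighted sum by $m$, which matches the $\frac{1}{m}$ prefactor in the loss above.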

### Convolutional neural networks

With the advance of image classification, a different type of neural network became popular: the convolutional neural network (CNN). Inspired by the receptive fields of the visual cortex described by \citep{hubel_1968}, the idea is to derive features from the input automatically by sliding learnable filters over local groups of pixels. Each filter combines a local set of pixels into a single output and thus exploits the correlation between adjacent pixels; since the same filter weights are reused across the whole input, the network is sparsely connected, easier to train and generalizes better to new data. The last part is usually still a fully connected network that takes the output of the convolutional layers. In this way one can train filters that, for instance, recognize edges or round objects in images. Each filter produces a feature map that highlights the areas of the input image where a certain pattern is present. The feature maps also make it possible to reconstruct what kind of information the network learned to base its decision on.
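For spectra, the convolution runs along the wavelength axis (`torch.nn.Conv1d`). The following is a hypothetical architecture sketch, not the model used here: the number of layers, the kernel sizes, the dropout rate and the 960-bin input length are assumptions chosen for illustration. It combines a convolutional feature extractor with a fully connected head, batch normalization and dropout.

```python
import torch
from torch import nn

torch.manual_seed(0)


class SpectrumCNN(nn.Module):
    """Hypothetical 1D CNN for single spectra of length n_bins."""

    def __init__(self, n_bins=960):
        super().__init__()
        self.features = nn.Sequential(
            # Each filter spans 7 neighbouring wavelength bins and shares its weights along the spectrum.
            nn.Conv1d(1, 8, kernel_size=7, padding=3),
            nn.BatchNorm1d(8),
            nn.ReLU(),
            nn.MaxPool1d(4),                        # n_bins -> n_bins // 4
            nn.Conv1d(8, 16, kernel_size=5, padding=2),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.MaxPool1d(4),                        # -> n_bins // 16
        )
        # Fully connected head on the flattened feature maps.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(16 * (n_bins // 16), 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # x: (batch, n_bins) -> add a channel axis for Conv1d: (batch, 1, n_bins)
        feature_maps = self.features(x.unsqueeze(1))
        return self.classifier(feature_maps)


model = SpectrumCNN()
model.eval()                    # disable dropout, use running batch-norm statistics
with torch.no_grad():
    scores = model(torch.randn(5, 960))
print(scores.shape)             # torch.Size([5, 1])
```

The output of `model.features` holds the feature maps (16 maps of 60 bins each in this sketch), which can be plotted to inspect which parts of a spectrum a filter responds to.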

In this notebook we show how we trained and evaluated models on preprocessed, VAE-cleaned spectra of the Mg II h&k lines.


First let's import all the relevant libraries.

In [None]:
import sys
import os
import traceback

from IPython.core import display
from IPython.core.display import Image

import astropy.units as u
from astropy.time import Time, TimeDelta
from sunpy.net import Fido
from sunpy.net import attrs as a
from sunpy.timeseries import TimeSeries
from sunpy.time import parse_time
from tqdm.notebook import tqdm
import torch
import pickle
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib import rcParams
from matplotlib.ticker import MultipleLocator
from IPython.display import HTML
import matplotlib.gridspec as gridspec
from sklearn.model_selection import train_test_split
from scipy.stats import norm

import torch.multiprocessing

import h5py

from utils_features import *
import utils as utils
import utils_models_MgIIk as mdls
from sklearn.model_selection import StratifiedKFold

from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix

## Collect the different obs ids and group them accordingly

In [None]:
obs_ids_Brandon_AR = ['20150518_143915_3860256971',
                    '20150518_161415_3860256971',
                    '20150521_185917_3800507454',
                    '20150703_165917_3620006130',
                    '20150704_100921_3860108354',
                    '20150704_165917_3620006130',
                    '20150728_151849_3660109122',
                    '20150807_221421_3860259180',
                    '20150809_061551_3860009180',
                    '20150916_181744_3600101141',
                    '20151017_003115_3660105403',
                    '20150724_053523_3620109103',
                    '20150130_112715_3860607366',
                    '20150408_045717_3860107054',
                    '20141128_210538_3860009154',
                    '20141201_154438_3800008053',
                    '20140329_201426_3820011652',
                    '20140313_093521_3820109554']


obs_ids_Brandon_PF = ['20141026_185250_3864111353',
                    '20141107_093726_3860602088',
                    '20140906_112339_3820259253',
                    '20150821_160115_3660104044',
                    '20140612_110933_3863605329',
                    '20150312_054519_3860107053',
                    '20140212_215458_3860257280',
                    '20150311_044603_3860259280',
                    '20141109_151704_3860258971',
                    '20141027_205655_3864111353',
                    '20140611_181927_3863605329',
                    '20150622_170003_3660100039',
                    '20141021_181052_3860261353',
                    '20141025_145828_3880106953',
                    '20140329_140938_3860258481',
                    '20140910_112825_3860259453',
                    '20141022_081850_3860261381',
                    '20141027_140420_3860354980',
                    '20150311_151947_3860107071']

## Parameters for each line
MgIIk = {'lambda_min' : 2794,
         'lambda_max' : 2806,
         'n_breaks' : 960,
         'line' : "Mg II k",
         'field' : "NUV",
         'threshold' : 10
        }


"""Create the different sets for the k-fold cross validation"""

num_of_repetitions = 5
K = 5

# Fix line here
line = 'MgIIk'
line_params = MgIIk

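# Set up your own paths (os.path.expanduser resolves '~', which os.listdir, np.load and h5py do not)

path_cleaned_AR = os.path.expanduser('~/cleaned/AR/')
path_cleaned = os.path.expanduser('~/cleaned/')
path_cleaned_PF = os.path.expanduser('~/cleaned/PF/')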


## Create the different sets for the k-fold cross validation

In [None]:
# %%time # Careful! If you rerun this cell, you will create a new set of splits.

y_neg = np.zeros([len(obs_ids_Brandon_AR)])
y_pos_ = np.ones([len(obs_ids_Brandon_PF)])

y_ = np.concatenate((y_pos_, y_neg), axis = 0)

###########################################################################################################

#Processing of green + yellow flare observations

X_obs_ids = np.array(obs_ids_Brandon_PF + obs_ids_Brandon_AR)
y_obs_ids = y_ # careful! choose well, since balance of labels is generated from this one


"""Part to build the sets of labels for k-fold crossvalidation"""

Label_set_all = {}

obs_files = []
splits = {}

#Find all files with obs_ids in X_obs_ids

for obs_np_file in os.listdir(path_cleaned_PF): # only takes the flares that could be processed and stored
    obs_label = obs_np_file.split('.')[0][3:29] # IRIS obs ID
    if obs_label in X_obs_ids:
        obs_files.append(obs_np_file)


for obs_np_file in os.listdir(path_cleaned_AR): # only takes the times that could be processed and stored
    obs_label = obs_np_file.split('.')[0][3:29] # IRIS obs ID
    if obs_label in X_obs_ids:
        obs_files.append(obs_np_file)


#create k-fold cross validation sets
Label_set = {}

for itter in tqdm(range(num_of_repetitions)):

    skf = StratifiedKFold(n_splits=K, shuffle=True, random_state=None)

    k = 0

    for train_index, test_index in skf.split(X_obs_ids, y_obs_ids):

        # for a single split and rep make train and test sets (train cleaned, test uncleaned in general):

        print("TRAIN set:", X_obs_ids[train_index], "TEST set:", X_obs_ids[test_index])
        train_set_, test_set_ = X_obs_ids[train_index], X_obs_ids[test_index]

        files_train = []
        files_test = []

        for file_name in obs_files:

            if np.any(file_name[3:29] == np.array(test_set_)):

                files_test.append(file_name)

            else:

                files_train.append(file_name)

        Label_set[str(k) + '_' + str(itter) + '_testing_obs_list'] = test_set_
        Label_set[str(k) + '_' + str(itter) + '_training'] = files_train
        Label_set[str(k) + '_' + str(itter) + '_testing'] = files_test

        k += 1

np.savez(path_cleaned + 'Label_set_5_5.npz', Label_set) # Careful! If you rerun this cell, you will create a new set of splits.
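
The cell above splits by observation rather than by individual spectrum, so that spectra from one observation never appear in both the training and the test set of a fold. A self-contained sketch of that idea with made-up observation IDs and file names (the `PF_`/`AR_` prefixes mirror the naming used above, but the IDs, the `_partJ` suffix and the helper `obs_of` are hypothetical), using a fixed `random_state` so that the splits are reproducible:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical observation IDs: 6 flaring (label 1) and 9 non-flaring (label 0).
obs_ids = np.array([f"obs{i:02d}" for i in range(15)])
labels = np.array([1] * 6 + [0] * 9)

# Each observation contributes several files.
files = [f"{'PF' if lab else 'AR'}_{obs}_part{j}.h5"
         for obs, lab in zip(obs_ids, labels) for j in range(2)]


def obs_of(file_name):
    """Extract the observation ID from a file name of the form CLS_OBSID_partJ.h5."""
    return file_name.split("_")[1]


# Stratify on observation labels so each fold keeps roughly the class balance.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
folds = []
for train_idx, test_idx in skf.split(obs_ids, labels):
    test_obs = set(obs_ids[test_idx])
    files_test = [f for f in files if obs_of(f) in test_obs]
    files_train = [f for f in files if obs_of(f) not in test_obs]
    folds.append((files_train, files_test))

for files_train, files_test in folds:
    print(len(files_train), len(files_test))
```

Because the split is made on observations, the number of spectra per fold can vary when observations differ in length; scikit-learn's `StratifiedGroupKFold` offers a related built-in alternative.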


In [None]:
"""Test, Training, Scoring"""

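def train(decision_model, optimizer, X):
    """Run one optimization step on the minibatch X = (spectra, labels) and return the training loss.

    Uses the global `device` and `criterion`.
    """
    # test() switches the model to eval mode, so switch back to training mode
    # (active dropout, batch statistics in batch normalization) before each step.
    decision_model.train()

    local_batch = X[0]
    local_labels = X[1]

    # Transfer to GPU
#     local_batch = local_batch.view(-1,1,line_params['n_breaks'])

    local_batch, local_labels = local_batch.to(device, dtype=torch.float), local_labels.to(device, dtype=torch.float)

    # Model computations
    optimizer.zero_grad()
    # Forward pass
    y_hat = decision_model(local_batch)
    # Calculate loss
    loss = criterion(y_hat.squeeze(), local_labels)

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()
    train_loss = loss.detach().cpu().numpy()

    return train_loss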


def test(decision_model, V):
    """Evaluate the model on the minibatch V = (spectra, labels); return accuracy and loss."""

    decision_model.eval()

    local_batch = V[0]
    local_labels = V[1]

    local_batch, local_labels = local_batch.to(device, dtype=torch.float), local_labels.to(device, dtype=torch.float)

    # Forward pass without building the autograd graph
    with torch.no_grad():
        y_hat = decision_model(local_batch)
        # Calculate loss
        loss = criterion(y_hat.squeeze(), local_labels)
    valid_loss = loss.cpu().numpy()

    acc = accuracy_score(local_labels.cpu().numpy().squeeze(), torch.round(y_hat).cpu().numpy().squeeze())

    return acc, valid_loss


def compute_score(decision_models, validation_generator):

    if not isinstance(decision_models, list):
        decision_models = [decision_models]

    [decision_model.eval() for decision_model in decision_models]

    yhat_classes_list = [[] for n in range(len(decision_models))]

    yhat_probs_list = [[] for n in range(len(decision_models))]

    y_clean_obs_test_labels = []

    for V in tqdm(validation_generator):

        with torch.no_grad():
            local_batch = V[0]
            local_labels = V[1]

            # Transfer to GPU
#             local_batch = local_batch.view(-1,1,line_params['n_breaks'])

            local_batch, local_labels = local_batch.to(device, dtype= torch.float), local_labels.to(device, dtype= torch.float)

            # predict probabilities for test set
            yhat_list = [decision_model(local_batch) for decision_model in decision_models]
            # store rounded classes (threshold 0.5) and probabilities for each model, flattened to 1D;
            # reshape(-1) also handles a final minibatch containing a single spectrum
            for yhat_prob, yhat_classes, yhat_probs in zip(yhat_list, yhat_classes_list, yhat_probs_list):
                yhat_classes.extend(torch.round(yhat_prob).cpu().numpy().reshape(-1))
                yhat_probs.extend(yhat_prob.cpu().numpy().reshape(-1))

            y_clean_obs_test_labels.extend(local_labels.cpu().numpy().reshape(-1))

    yhat_classes_list = [np.hstack(yhat_classes) for yhat_classes in yhat_classes_list]
    yhat_probs_list = [np.hstack(yhat_probs) for yhat_probs in yhat_probs_list]

    # accuracy: (tp + tn) / (p + n)
    accuracy = [accuracy_score(y_clean_obs_test_labels, yhat_classes) for yhat_classes in yhat_classes_list]
    print('Accuracy: ', accuracy)
    # precision tp / (tp + fp)
    precision = [precision_score(y_clean_obs_test_labels, yhat_classes) for yhat_classes in yhat_classes_list]
    print('Precision: ', precision)
    # recall: tp / (tp + fn)
    recall = [recall_score(y_clean_obs_test_labels, yhat_classes) for yhat_classes in yhat_classes_list]
    print('Recall: ', recall)
    # f1: 2 tp / (2 tp + fp + fn)
    f1 = [f1_score(y_clean_obs_test_labels, yhat_classes) for yhat_classes in yhat_classes_list]
    print('F1 score: ', f1)
    # kappa
    kappa = [cohen_kappa_score(y_clean_obs_test_labels, yhat_classes) for yhat_classes in yhat_classes_list]
    print('Cohens kappa: ', kappa)
    # ROC AUC (computed from the predicted probabilities, not the rounded classes)
    auc = [roc_auc_score(y_clean_obs_test_labels, yhat_probs) for yhat_probs in yhat_probs_list]
    print('ROC AUC: ', auc)
    # confusion matrix
    matrices = [confusion_matrix(y_clean_obs_test_labels, yhat_classes) for yhat_classes in yhat_classes_list]
    print(matrices)
    # true skill score
    TN = [matrix[0,0] for matrix in matrices]
    FP = [matrix[0,1] for matrix in matrices]
    FN = [matrix[1,0] for matrix in matrices]
    TP = [matrix[1,1] for matrix in matrices]
    tss_eval = [TP[n]/(TP[n]+FN[n]) - FP[n]/(FP[n]+TN[n]) for n in range(len(decision_models))]
    print('TSS: ', tss_eval)

    return accuracy, precision, recall, f1, kappa, auc, tss_eval

    #trainer = Trainer(callbacks=[EarlyStopping(monitor="tss")])



def ROC_plot(decision_model, validation_generator, model_name):

    decision_model.eval()

    yhat_classes = []
    yhat_probs = []

    pos_class_dist = []
    neg_class_dist = []
    y_clean_obs_test_labels = []
    for V in tqdm(validation_generator):

        with torch.no_grad():
            local_batch =V[0]
            local_labels = V[1]

            # Transfer to GPU
#             local_batch = local_batch.view(-1,1,line_params['n_breaks'])

            local_batch, local_labels = local_batch.to(device, dtype= torch.float), local_labels.to(device, dtype= torch.float)

            # predict probabilities for test set
            yhat_prob = decision_model(local_batch)
            # predict classes for test set
            yhat_classes.extend(np.array([ torch.round(yhat_prob[i]).cpu().detach().numpy() for i in range(len(yhat_prob)) ]).squeeze())

            yhat_probs.extend(yhat_prob.cpu().detach().numpy().squeeze())
            y_clean_obs_test_labels.extend(local_labels.cpu().detach().numpy().squeeze())
            pos_class_dist.extend([ yhat_prob[j].cpu().detach().numpy().squeeze() for j in range(len(local_labels)) if local_labels[j]==1 ])
            neg_class_dist.extend([ yhat_prob[j].cpu().detach().numpy().squeeze() for j in range(len(local_labels)) if local_labels[j]==0 ])

    yhat_classes = np.hstack(yhat_classes)
    yhat_probs = np.hstack(yhat_probs)

    # accuracy: (tp + tn) / (p + n)
    accuracy = accuracy_score(y_clean_obs_test_labels, yhat_classes)
    print('Accuracy: %f' % accuracy)
    # precision tp / (tp + fp)
    precision = precision_score(y_clean_obs_test_labels, yhat_classes)
    print('Precision: %f' % precision)
    # recall: tp / (tp + fn)
    recall = recall_score(y_clean_obs_test_labels, yhat_classes)
    print('Recall: %f' % recall)
    # f1: 2 tp / (2 tp + fp + fn)
    f1 = f1_score(y_clean_obs_test_labels, yhat_classes)
    print('F1 score: %f' % f1)
    # kappa
    kappa = cohen_kappa_score(y_clean_obs_test_labels, yhat_classes)
    print('Cohens kappa: %f' % kappa)
    # ROC AUC
    auc = roc_auc_score(y_clean_obs_test_labels, yhat_probs)
    print('ROC AUC: %f' % auc)
    # confusion matrix
    matrix = confusion_matrix(y_clean_obs_test_labels, yhat_classes)
    print(matrix)
    # true skill score
    try:
        # true skill score
        TN = matrix[0,0]
        FP = matrix[0,1]
        FN = matrix[1,0]
        TP = matrix[1,1]

        if (FP == 0) and (TN == 0):
            tss_eval = TP/(TP+FN)
        elif (TP == 0) and (FN == 0):
            tss_eval = -FP/(FP+TN)
        else:
            tss_eval = TP/(TP+FN) - FP/(FP+TN)
    except Exception as exc:
        if np.all(yhat_classes==1):
            tss_eval = 1
        elif np.all(yhat_classes==0):
            tss_eval = 0
        else:
            print('something went wrong with the present observation', exc)
            tss_eval = 0
    print('TSS: %f' % tss_eval)

    pos_class_dist = np.asarray( pos_class_dist )
    neg_class_dist = np.asarray( neg_class_dist )

    pos_class_dist = np.squeeze( pos_class_dist )
    neg_class_dist = np.squeeze( neg_class_dist )

    sns.set_style("white")

    plt.clf()
    plt.cla()
    fig = plt.figure(figsize=(15, 8))
    gs = gridspec.GridSpec(1, 2)
    clrs='orange'

    #------------------------------------------------------------------------------------------------------------------
    ax1 = plt.subplot(gs[0:, 0])
    ax = sns.kdeplot( pos_class_dist, color='coral', label='pos', fill=True, zorder=1 )
    ax = sns.kdeplot( neg_class_dist, color='powderblue', label='neg', fill=True, zorder=2 )
    plt.locator_params(axis='y', nbins=4)

    plt.axvline(x=.5, c='k', linestyle='--')
    plt.axvline(x=.2, c='k', linestyle='--')
    plt.axvline(x=.8, c='k', linestyle='--')

    plt.arrow(.2, 5.5, 0.2, 0, length_includes_head=True,
              head_width=0.3, head_length=0.02, color='k')
    plt.arrow(.5, 5.5, 0.2, 0, length_includes_head=True,
              head_width=0.3, head_length=0.02, color='k')
    plt.arrow(.8, 5.5, 0.2, 0, length_includes_head=True,
              head_width=0.3, head_length=0.02, color='k')

    ax.text(.25, 3, 'A', style='italic',
            bbox={'facecolor': 'b', 'alpha': 0.3, 'pad': 8})
    ax.text(.55, 3, 'B', style='italic',
            bbox={'facecolor': 'orange', 'alpha': 0.3, 'pad': 8})
    ax.text(.85, 3, 'C', style='italic',
            bbox={'facecolor': 'green', 'alpha': 0.3, 'pad': 8})

    plt.ylabel('density function')
    plt.xlabel('score')
    plt.legend(loc='upper left')
    plt.title('Class score distribution for the best model')

    ax1 = plt.subplot(gs[0:, 1])
    from sklearn.metrics import roc_curve
    # import under a different name so the function is not shadowed by the score variable
    from sklearn.metrics import auc as auc_from_curve

    fpr, tpr, thresholds = roc_curve(y_clean_obs_test_labels, yhat_probs)
    auc = auc_from_curve(fpr, tpr)
    plt.plot(fpr, tpr, label = f'{model_name} (area = {auc:.3f})',
             c='k')

    plt.plot([0, 1], [0, 1], 'k--')

    p1 = abs(thresholds - .2).argmin()
    p2 = abs(thresholds - .5).argmin()
    p3 = abs(thresholds - .8).argmin()

    posA = [fpr[p1], tpr[p1]]
    posB = [fpr[p2], tpr[p2]]
    posC = [fpr[p3], tpr[p3]]

    plt.scatter(posA[0], posA[1], s=80)
    plt.scatter(posB[0], posB[1], s=80)
    plt.scatter(posC[0], posC[1], s=80)

    plt.locator_params(axis='x', nbins=2)
    plt.locator_params(axis='y', nbins=2)
    plt.xlabel('false positive rate')
    plt.ylabel('true positive rate')
    plt.title('ROC curve')
    plt.legend(loc='upper left')

    plt.tight_layout()
    plt.show()
#     fig.savefig(f'{save_path_stats}_{model_name}.png')

    return accuracy, precision, recall, f1, kappa, auc, tss_eval

# Functions for plotting and animation:
def anim_intensity_slit_movie(obs_cls, plot_arr, max_intensity = None, save_path=None):

    # get irisreader observation instance
    obs_id = obs_cls.obs_id
    year = obs_id[:4]
    month = obs_id[4:6]
    day = obs_id[6:8]
    pth = f'/sml/iris/{year}/{month}/{day}/{obs_id}'
    obs = observation( pth, keep_null=True )

    try:
        sji = obs.sji('Si IV')
        att = .25
        cm = 'binary_r'
    except Exception:
        try:
            sji = obs.sji('Mg II h/k')
            att = .25
            cm = 'binary_r'
        except Exception:
            sji = obs.sji('C II')
            att = .4
            cm = 'binary_r'

    sji_times = sji.get_timestamps()
    obs_time_0 = obs_cls.times[np.where(obs_cls.times>0)][0]
    obs_time_end = obs_cls.times[np.where(obs_cls.times>0)][-1]
    sj_0 = np.argmin(np.abs(obs_time_0 - sji_times))
    sj_end = np.argmin(np.abs(obs_time_end - sji_times))
    rast_ind = np.argmin(np.abs(obs_cls.times - sji_times[sj_0]), axis=1)  # compare with the SJI timestamp, not its index

    plt.cla()
    plt.clf()
    plt.close()
    obs.close()

    exptimes = np.array([hdr['EXPTIME'] for hdr in sji.headers]).reshape(-1,1,1)

    # Generate initial figure for the animation with overplotted slit colors.
    sji_exped = (sji[:,:,:].clip(min=0)/exptimes[:,:,:])**att
    sji_exped[np.where(sji_exped == 0.)] = 5.0

    fig = plt.figure(figsize=(20,20))
    im = plt.imshow(sji_exped[sj_0,:,:], cmap=cm, vmax=10)
    slpos = raster_x_coord(obs_cls, sji, sj_0)
    slpos_grid = [ [sl]*plot_arr.shape[2] for sl in slpos ] # slit x-coords for each spectra for a single raster
    xcoords = np.asarray( [item for sublist in slpos_grid for item in sublist] ) # flatten nested list
    ycoords = np.asarray( list(range(plot_arr.shape[2]))*plot_arr.shape[0] ) # y-coord for each spectra for a single raster
    date_obs = parse_time(sji_times[sj_0], format='unix').to_datetime()
    im.axes.set_title( "Flare: {}  Frame: {}  Date-Time: {}".format( obs_cls.obs_id, 0, date_obs ), fontsize=25, alpha=.8)
    im.axes.set_xlabel( 'camera x', fontsize=25, alpha=.8 )
    im.axes.set_ylabel( 'camera y', fontsize=25, alpha=.8 )
    cmap = plt.get_cmap('jet')  # plt.cm.get_cmap is deprecated in recent Matplotlib versions


    if not max_intensity:
        max_intensity = np.max(plot_arr)*.2
    colors = cmap(np.asarray([plot_arr[m,r_ind] for m, r_ind in enumerate(rast_ind)]).squeeze()/max_intensity)
    if obs_cls.num_of_raster_pos != 1:
        scat = plt.scatter(xcoords, ycoords, marker='s', s=15, c=colors.reshape(colors.shape[0]*colors.shape[1], 4), alpha=.3)
    else:
        scat = plt.scatter(xcoords, ycoords, marker='s', s=15, c=colors, alpha=.3)

    plt.close()

    def init():
        return im,

    # animation function
    def update(i, sji_times):
        '''
        Update the data for each successive raster.
        '''
        sjiind = sj_0 + i

        sj_t = sji_times[sjiind]

        rast_ind = np.argmin(np.abs(obs_cls.times - sj_t), axis=1)

        slpos = raster_x_coord( obs_cls, sji, sjiind )
        im.set_data(sji_exped[sjiind,:,:])
        date_obs = parse_time(sj_t, format='unix').to_datetime()
        im.axes.set_title( "Flare: {}  Frame: {}  Date-Time: {}".format( obs_cls.obs_id, i, date_obs ), fontsize=25, alpha=.8)
        im.axes.set_xlabel( 'camera x', fontsize=25, alpha=.8 )
        im.axes.set_ylabel( 'camera y', fontsize=25, alpha=.8 )
        colors = cmap(np.asarray([plot_arr[m,r_ind] for m, r_ind in enumerate(rast_ind)]).squeeze()/max_intensity)
        scat.set_offsets(np.asarray([xcoords, ycoords]).T)
        if obs_cls.num_of_raster_pos != 1:
            scat.set_color(colors.reshape(colors.shape[0]*colors.shape[1], 4))
        else:
            scat.set_color(colors)

        return im


    anim = animation.FuncAnimation(fig, lambda i: update(i, sji_times), init_func=init, frames=sj_end-sj_0, interval=400)

    if save_path:
        anim.save(save_path)

    return HTML(anim.to_html5_video())


def raster_x_coord( obs_cls, sji, sjiind ):
    '''
    Returns the n_raster slit positions in pixels for a given SJI index.
    '''
    slitcoord_primary_x = sji.get_slit_pos(sjiind)-1
    slit_offset_secondary = sji.headers[sjiind]['PZTX']/sji.headers[sjiind]['CDELT2'] # in pix coords
    # This assumes that the PZT offsets remain constant for the entire observation
    raster_offsets_secondary = np.asarray( [ obs_cls.hdrs.loc[i]['PZTX']/obs_cls.hdrs.loc[i]['CDELT2'] for i in range(obs_cls.num_of_raster_pos) ] )
    # slit x position for each raster pos = position of primary - wedge tilt + fine scale secondary pztx
    slpos = slitcoord_primary_x - slit_offset_secondary + raster_offsets_secondary

    return slpos
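
`compute_score` and `ROC_plot` report the true skill statistic (TSS), the difference between the true positive rate (recall) and the false positive rate. A compact sketch of the same formula: passing `labels=[0, 1]` to `sklearn.metrics.confusion_matrix` always yields a 2x2 matrix, even when a set contains only one class. The helper name `true_skill_statistic` is hypothetical, and the convention of setting an undefined rate to 0 is an assumption that differs from the fallback in `ROC_plot`.

```python
import numpy as np
from sklearn.metrics import confusion_matrix


def true_skill_statistic(y_true, y_pred):
    """TSS = TP/(TP+FN) - FP/(FP+TN), i.e. true positive rate minus false positive rate.

    An undefined rate (empty class) is set to 0 (assumed convention).
    """
    # With labels=[0, 1] the matrix is [[TN, FP], [FN, TP]].
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    tpr = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0
    return tpr - fpr


y_true = [1, 1, 1, 0, 0, 0, 0]
y_pred = [1, 1, 0, 0, 0, 1, 0]
print(true_skill_statistic(y_true, y_pred))  # 2/3 - 1/4 = 5/12
```

Unlike accuracy, the TSS does not depend on the ratio of positive to negative samples, which makes it a common choice for imbalanced flare-prediction data.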




## Functions to get the labels of the different splits and load the indices from subsamples to be used later by the Dataloader
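
`Lazy_filehandler` balances the two classes by keeping every spectrum of the minority class and drawing from each majority-class file a number of spectra proportional to its size, with factor $q = N_\text{minority}/N_\text{majority}$. A sketch of that arithmetic with made-up file sizes (the file names and counts are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical number of spectra per file for the two classes.
pf_counts = {"PF_a.h5": 1000, "PF_b.h5": 3000}   # majority class (4000 spectra)
ar_counts = {"AR_c.h5": 1500, "AR_d.h5": 500}    # minority class (2000 spectra)

n_major = sum(pf_counts.values())
n_minor = sum(ar_counts.values())
q = n_minor / n_major                            # fraction of each majority file to keep

sampled = {}
for name, n in pf_counts.items():
    size = int(round(q * n))                     # per-file sample size proportional to file size
    # Like np.random.randint in Lazy_filehandler, this draws indices with replacement.
    sampled[name] = rng.integers(0, n, size)

print({name: len(idx) for name, idx in sampled.items()})  # {'PF_a.h5': 500, 'PF_b.h5': 1500}
```

Because each file is rounded separately, the balanced total can in general differ from the minority count by a few spectra; drawing without replacement would instead use `rng.choice(n, size, replace=False)`.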

In [None]:
def Load_data_labels(Label_set_filename, line):
    path_cleaned = os.path.expanduser('~/cleaned/')  # adjust to your own path; np.load does not expand '~'

    Label_set_ = np.load(path_cleaned + Label_set_filename, allow_pickle=True)['arr_0'][()]

    return Label_set_


def Lazy_filehandler(files, line, max_samples=None):

    X_PF_dict = {}
    X_AR_dict = {}
    label_file_IDs = []
    label_indexs = []
    y_labels = []

    for file in files:
        if file[:2] == 'PF':
#             print(path_cleaned_PF + file)
            with h5py.File(path_cleaned_PF + file, 'r') as f:

                X_PF_dict[file] = f["im_arr_cleaned_QS"].shape[0]

        else:
#             print(path_cleaned_AR + file)
            with h5py.File(path_cleaned_AR + file, 'r') as f:

                X_AR_dict[file] = f["im_arr_cleaned_QS"].shape[0]

    PF_samples = np.sum(np.array(list((X_PF_dict.values()))))
    AR_samples = np.sum(np.array(list((X_AR_dict.values()))))

    print("number of PF samples in total: ", PF_samples, "number of AR samples in total: ", AR_samples)

    if PF_samples > AR_samples:
        upper_limit = AR_samples
        major_limit = PF_samples

        q = upper_limit/major_limit # factor between number of obs samples and sample size from each observation.

        for file, nsamples in X_PF_dict.items():

            sample_size = int(np.round(q*nsamples))  # np.int was removed in NumPy 1.24
            # draws sample_size indices with replacement
            sample_indices = np.random.randint(0, nsamples, sample_size)
            y_labels.extend(np.ones([sample_size]))
            label_file_IDs.extend(['PF/' + file] * sample_size)
            label_indexs.extend(sample_indices)

        nsamples_PF = len(y_labels)
        print("number of samples PF: ", nsamples_PF)

        for file, nsamples in X_AR_dict.items():

            sample_indices = np.arange(0, nsamples)
            y_labels.extend(np.zeros([nsamples]))
            label_file_IDs.extend(['AR/'+file for sample_index in sample_indices])
            label_indexs.extend([sample_index for sample_index in sample_indices])

        nsamples_AR = len(y_labels) - nsamples_PF
        print("number of samples AR: ", nsamples_AR)

    else:
        upper_limit = PF_samples
        major_limit = AR_samples

        q = upper_limit/major_limit # factor between number of obs samples and sample size from each observation.

        for file, nsamples in X_AR_dict.items():

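            sample_size = int(np.round(q*nsamples)) # number of spectra drawn from this file
            sample_indices = np.random.choice(nsamples, sample_size, replace=False) # without replacement, so no spectrum is duplicated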
            y_labels.extend(np.zeros([sample_size]))
            label_file_IDs.extend(['AR/'+file for sample_index in sample_indices])
            label_indexs.extend([sample_index for sample_index in sample_indices])

        nsamples_AR = len(y_labels)
        print("number of samples AR: ", nsamples_AR)

        for file, nsamples in X_PF_dict.items():

            sample_indices = np.arange(0, nsamples)
            y_labels.extend(np.ones([nsamples]))
            label_file_IDs.extend(['PF/'+file for sample_index in sample_indices])
            label_indexs.extend([sample_index for sample_index in sample_indices])

        nsamples_PF = len(y_labels) - nsamples_AR
        print("number of samples PF: ", nsamples_PF)

    y_labels = np.array(y_labels)
    label_file_IDs = np.array(label_file_IDs)
    label_indexs = np.array(label_indexs)

    if max_samples:
        if max_samples > len(y_labels):
            # more samples requested than available: draw with replacement
            rand_inds = np.random.randint(0, len(y_labels), max_samples)
        else:
            # otherwise subsample without replacement to avoid duplicates
            rand_inds = np.random.choice(len(y_labels), max_samples, replace=False)
        y_labels = y_labels[rand_inds]
        label_file_IDs = label_file_IDs[rand_inds]
        label_indexs = label_indexs[rand_inds]

    print("number of samples used (both PF + AR) label_IDs and y_labels: ", len(label_indexs), len(y_labels))

    return label_file_IDs, label_indexs, y_labels

def Lazy_filehandler_no_sampling(files, line):
    """Like Lazy_filehandler, but keeps every spectrum of every file (no class balancing)."""

    X_PF_dict = {}
    X_AR_dict = {}
    label_file_IDs = []
    label_indexs = []
    y_labels = []

    for file in files:
        if file[:2] == 'PF':
#             print(path_cleaned_PF + file)
            with h5py.File(path_cleaned_PF + file, 'r') as f:

                X_PF_dict[file] = f["im_arr_cleaned_QS"].shape[0]

        else:
#             print(path_cleaned_AR + file)
            with h5py.File(path_cleaned_AR + file, 'r') as f:

                X_AR_dict[file] = f["im_arr_cleaned_QS"].shape[0]


    PF_samples = np.sum(np.array(list((X_PF_dict.values()))))
    AR_samples = np.sum(np.array(list((X_AR_dict.values()))))

    print("number of PF samples in total: ", PF_samples, "number of AR samples in total: ", AR_samples)

    for file, nsamples in X_PF_dict.items():

        sample_indices = np.arange(0, nsamples)
        y_labels.extend(np.ones([nsamples]))
        label_file_IDs.extend(['PF/'+file for sample_index in sample_indices])
        label_indexs.extend(sample_indices)


    for file, nsamples in X_AR_dict.items():

        sample_indices = np.arange(0, nsamples)
        y_labels.extend(np.zeros([nsamples]))
        label_file_IDs.extend(['AR/'+file for sample_index in sample_indices])
        label_indexs.extend(sample_indices)

    y_labels = np.array(y_labels)
    label_file_IDs = np.array(label_file_IDs)
    label_indexs = np.array(label_indexs)

    print("number of samples used (both PF + AR) label_IDs and y_labels: ", len(label_indexs), len(y_labels))

    return label_file_IDs, label_indexs, y_labels
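`Lazy_filehandler` balances the two classes by undersampling: with $q = N_\text{minority}/N_\text{majority}$, every majority-class file contributes $\mathrm{round}(q\,n_\text{file})$ spectra. The sketch below isolates that rule so it can be checked on toy numbers. The helper names `undersample_counts` and `undersample_indices` are hypothetical, the file counts are made up, and the draw is done without replacement so that no spectrum appears twice.

```python
import numpy as np


def undersample_counts(minority_total, majority_counts):
    """Per-file sample sizes: round(q * n) with q = minority_total / majority_total."""
    majority_total = sum(majority_counts.values())
    q = minority_total / majority_total  # q <= 1, so round(q * n) never exceeds n
    return {name: int(np.round(q * n)) for name, n in majority_counts.items()}


def undersample_indices(majority_counts, sizes, seed=0):
    """Draw the row indices of each file without replacement."""
    rng = np.random.default_rng(seed)
    return {name: np.sort(rng.choice(n, size=sizes[name], replace=False))
            for name, n in majority_counts.items()}


# Made-up example: 150 PF spectra in two files versus 30 AR spectra in total.
pf_counts = {'PF_a.h5': 100, 'PF_b.h5': 50}
sizes = undersample_counts(minority_total=30, majority_counts=pf_counts)
print(sizes)  # {'PF_a.h5': 20, 'PF_b.h5': 10}
indices = undersample_indices(pf_counts, sizes)
```

Because of the per-file rounding, the undersampled total can differ from the minority total by a few spectra when there are many files.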



## Main body to train the models

This cell trains and tests four models at a time. To avoid the "Too many open files" error raised by the PyTorch DataLoader, I set the multiprocessing sharing strategy to `file_system`. This strategy is prone to shared-memory leaks, so if you run the cell for a long time or for many models, check the memory usage regularly. The `train` and `test` functions are called once per model and `compute_score` accepts a list of any number of models, whereas `ROC_plot` handles only one model per call. The model architectures used in this analysis are:


Model architectures of the fully connected neural networks for the single-line experiments (each entry gives the number of neurons and the activation function of the layer):

| Architecture | Line      | 1st hidden layer | 2nd hidden layer | 3rd hidden layer |
|--------------|-----------|------------------|------------------|------------------|
| Single Layer | Mg II h&k | 10, Sigmoid      |                  |                  |
| Single Layer | Si IV     | 10, Sigmoid      |                  |                  |
| Single Layer | C II      | 10, Sigmoid      |                  |                  |
| Two Layers   | Mg II h&k | 12, ReLU         | 8, Sigmoid       |                  |
| Two Layers   | Si IV     | 12, ReLU         | 8, Sigmoid       |                  |
| Two Layers   | C II      | 12, ReLU         | 8, Sigmoid       |                  |
| Three Layers | Mg II h&k | 10, ReLU         | 12, ReLU         | 8, Sigmoid       |
| Three Layers | Si IV     | 12, ReLU         | 10, ReLU         | 8, Sigmoid       |
| Three Layers | C II      | 10, ReLU         | 12, ReLU         | 8, Sigmoid       |
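The sizes of these networks follow directly from the table: a dense layer from $a$ to $b$ neurons has $a\,b$ weights and $b$ biases. The sketch below counts trainable parameters for the Mg II h&k hidden-layer sizes. Two things are assumptions: each network is taken to end in a single sigmoid output neuron (consistent with the `torch.nn.BCELoss` used for training, but not shown in the table), and the input length of 100 is hypothetical, since the notebook uses `line_params['n_breaks']`.

```python
def dense_param_count(layer_sizes):
    """Trainable weights + biases of a stack of fully connected layers.

    layer_sizes = [n_in, h1, ..., n_out]; a layer from a to b neurons has a*b weights and b biases.
    """
    return sum(a * b + b for a, b in zip(layer_sizes[:-1], layer_sizes[1:]))


n_in = 100  # hypothetical number of spectral points
# Hidden sizes for Mg II h&k from the table, plus an assumed single output neuron.
architectures = {
    'SingleLayer': [n_in, 10, 1],
    'TwoLayers': [n_in, 12, 8, 1],
    'ThreeLayers': [n_in, 10, 12, 8, 1],
}
for name, sizes in architectures.items():
    print(name, dense_param_count(sizes))
```

With such small hidden layers, the first layer ($n_\text{in} \times h_1$ weights) dominates the count, so the model size scales almost linearly with the number of spectral points.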


Model architectures of the convolutional neural networks for the single-line experiments (hidden-layer entries give the number of neurons and the activation function):

| Line      | 1st convolutional layer                                   | 2nd convolutional layer                                   | 1st hidden layer | 2nd hidden layer |
|-----------|-----------------------------------------------------------|-----------------------------------------------------------|------------------|------------------|
| Mg II h&k | in ch.: 1, out ch.: 10, kernel size: 32, stride: 16, ReLU | in ch.: 10, out ch.: 20, kernel size: 10, stride: 5, ReLU | 6, Sigmoid       |                  |
| Si IV     | in ch.: 1, out ch.: 2, kernel size: 20, stride: 4, ReLU   | in ch.: 2, out ch.: 4, kernel size: 16, stride: 3, ReLU   | 12, Sigmoid      | 8, Sigmoid       |
| C II      | in ch.: 1, out ch.: 2, kernel size: 20, stride: 4, ReLU   | in ch.: 2, out ch.: 4, kernel size: 16, stride: 3, ReLU   | 6, Sigmoid       |                  |
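The number of features that reach the first hidden layer depends on how much each convolution shortens the spectrum. The sketch below applies the output-length formula documented for `torch.nn.Conv1d`, $L_\text{out} = \lfloor (L_\text{in} + 2p - d(k-1) - 1)/s \rfloor + 1$, with the PyTorch defaults $p = 0$ and $d = 1$. Whether `mdls.ConvNet` uses padding or pooling is not visible here, so both are assumed absent; that the flattened output (channels × positions) feeds the first hidden layer is also an assumption, and the input length of 400 is hypothetical.

```python
def conv1d_out_len(length, kernel_size, stride, padding=0, dilation=1):
    """Output length of a 1D convolution (formula from the torch.nn.Conv1d documentation)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def conv_stack_out_len(length, layers):
    """Propagate the length through a list of (kernel_size, stride) pairs."""
    for kernel_size, stride in layers:
        length = conv1d_out_len(length, kernel_size, stride)
        if length < 1:
            raise ValueError('input too short for this convolution stack')
    return length


# (kernel size, stride) per convolutional layer, taken from the table above.
stacks = {
    'Mg II h&k': [(32, 16), (10, 5)],
    'Si IV': [(20, 4), (16, 3)],
    'C II': [(20, 4), (16, 3)],
}
out_channels = {'Mg II h&k': 20, 'Si IV': 4, 'C II': 4}

n_in = 400  # hypothetical input length
for name, layers in stacks.items():
    n_out = conv_stack_out_len(n_in, layers)
    print(name, n_out, 'positions,', n_out * out_channels[name], 'flattened features')
```

For the Mg II h&k stack the large strides shrink a 400-point spectrum to only 3 positions, which is why the following hidden layer can stay small.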


Model architecture of the convolutional neural network for the line-combination experiments:

| Layer                   | Configuration                                              |
|-------------------------|------------------------------------------------------------|
| 1st convolutional layer | in ch.: 1, out ch.: 10, kernel size: 32, stride: 16, ReLU  |
| 2nd convolutional layer | in ch.: 10, out ch.: 20, kernel size: 16, stride: 8, ReLU  |
| 3rd convolutional layer | in ch.: 20, out ch.: 40, kernel size: 4, stride: 2, ReLU   |
| 1st hidden layer        | 10, ReLU or Sigmoid                                        |
| 2nd hidden layer        | 12, ReLU or Sigmoid                                        |
| 3rd hidden layer        | 8, Sigmoid                                                 |
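Three strided convolutions put a lower bound on the input length: an unpadded layer needs $(n_\text{out} - 1)\,s + k$ inputs to produce $n_\text{out}$ outputs. Working backwards from a single output position gives the shortest 1D input that the line-combination stack accepts. This is a minimal sketch assuming no padding, no dilation and no pooling (PyTorch `Conv1d` defaults); how the lines are concatenated into that input is not shown here.

```python
def min_input_len(layers):
    """Smallest input length for which unpadded 1D convolutions yield >= 1 output position.

    layers: list of (kernel_size, stride) pairs, first layer first.
    """
    needed = 1  # at least one output position after the last layer
    for kernel_size, stride in reversed(layers):
        needed = (needed - 1) * stride + kernel_size
    return needed


combo_layers = [(32, 16), (16, 8), (4, 2)]  # from the line-combination table
print(min_input_len(combo_layers))  # 656
```

For comparison, the two-layer Mg II h&k stack of the single-line ConvNet, `[(32, 16), (10, 5)]`, needs only 176 input points.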


In [None]:
%%time
#all PF

sharing_strategy = "file_system"
torch.multiprocessing.set_sharing_strategy(sharing_strategy)

def set_worker_sharing_strategy(worker_id: int) -> None:
    torch.multiprocessing.set_sharing_strategy(sharing_strategy)

importlib.reload(utils)
importlib.reload(mdls)


N_EPOCHS = 10

line = 'MgIIk'
line_params = MgIIk

# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
torch.backends.cudnn.benchmark = True # only works if input shape does not vary, optimizes the use of your hardware to run the training

# keep track of the factor between train and validation batches

batch_size_train = 6400
batch_size_val = 6400


Label_set = Load_data_labels('Label_set_5_5.npz', line)
for itter in range(num_of_repetitions):
    for k in range(K):

        if os.path.exists(os.path.expanduser(f'~/models/decision_model_{k}_{itter}_TwoLayers.pt')): # skip splits whose models are already stored
            print(f'split k={k}, repetition {itter}: models already stored, skipping')
        else:
            label_train = Label_set[str(k) + '_' + str(itter) + '_training']
            label_test = Label_set[str(k) + '_' + str(itter) + '_testing']

            label_file_IDs_train, label_indexs_train, y_labels_train = Lazy_filehandler_no_sampling(label_train, line)
            training_set = mdls.Dataset_NN(label_file_IDs_train, label_indexs_train, y_labels_train, path_cleaned)

            label_file_IDs_test, label_indexs_test, y_labels_test = Lazy_filehandler_no_sampling(label_test, line)
            validation_set = mdls.Dataset_NN(label_file_IDs_test, label_indexs_test, y_labels_test, path_cleaned)
            # second Dataset over the same test split, read with a fixed batch size by compute_score and ROC_plot
            validation_set_fast = mdls.Dataset_NN(label_file_IDs_test, label_indexs_test, y_labels_test, path_cleaned)

            # torch.save and open() do not expand '~', so expand it explicitly
            save_path_models = os.path.expanduser(f'~/models/decision_model_{k}_{itter}')
            save_path_stats = os.path.expanduser(f'~/stats/decision_model_{k}_{itter}')

            # Parameters and Generators
            params_train = {'batch_size': batch_size_train,
                      'shuffle': True,
                      'num_workers': 16,
                      'worker_init_fn': set_worker_sharing_strategy,
                      'persistent_workers': True}

            training_generator = torch.utils.data.DataLoader(training_set, **params_train)

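            # validation batch size chosen so that the validation loader yields about one batch per training batch (the two loaders are zipped below)
            params_val = {'batch_size': max(1, len(validation_set) // len(training_generator)),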
                      'shuffle': True,
                      'num_workers': 16,
                      'worker_init_fn': set_worker_sharing_strategy,
                      'persistent_workers': True}

            params_fast_val = {'batch_size': batch_size_val,
                      'shuffle': True,
                      'num_workers': 16,
                      'worker_init_fn': set_worker_sharing_strategy,
                      'persistent_workers': True}

            validation_generator = torch.utils.data.DataLoader(validation_set, **params_val)

            validation_fast_generator = torch.utils.data.DataLoader(validation_set_fast, **params_fast_val)

            optimizer = []
            torch.cuda.empty_cache()
            decision_model_SingleLayer = mdls.SingleLayer(line_params['n_breaks'], 10).to(device) # Initiate network
            decision_model_TwoLayers = mdls.TwoLayers(line_params['n_breaks'], 12, 8).to(device) # Initiate network
            decision_model_ThreeLayers = mdls.ThreeLayers(line_params['n_breaks'], 10, 12, 8).to(device) # Initiate network
            decision_model_convnet = mdls.ConvNet(line_params['n_breaks'], 6).to(device) # Initiate network

            decision_models = [decision_model_SingleLayer, decision_model_TwoLayers, decision_model_ThreeLayers, decision_model_convnet]

            criterion = torch.nn.BCELoss(reduction='mean') # Use Binary cross-entropy as the loss function
            # one Adam optimizer per model, with the same hyperparameters for all four
            optimizer = [optim.Adam(model.parameters(), lr=0.001, weight_decay=.001) for model in decision_models]


            # Per-mini-batch loss and accuracy histories of the four models
            training_loss1 = []
            validation_loss1 = []
            accuracies_valid1 = []

            training_loss2 = []
            validation_loss2 = []
            accuracies_valid2 = []

            training_loss3 = []
            validation_loss3 = []
            accuracies_valid3 = []

            training_loss4 = []
            validation_loss4 = []
            accuracies_valid4 = []

            Mean_loss_train1 = []
            Mean_loss_valid1 = []
            Mean_loss_acc1 = []

            Mean_loss_train2 = []
            Mean_loss_valid2 = []
            Mean_loss_acc2 = []

            Mean_loss_train3 = []
            Mean_loss_valid3 = []
            Mean_loss_acc3 = []

            Mean_loss_train4 = []
            Mean_loss_valid4 = []
            Mean_loss_acc4 = []

            stats_current = {}
            best_TSS0 = -1
            best_TSS1 = -1
            best_TSS2 = -1
            best_TSS3 = -1


            for decision_model in decision_models:
                decision_model.train() # put the model in training mode
                decision_model.to(device) # no-op if the model already lives on `device`

            break_ = False
            n = 0
            catch_ = 0
            for epoch in tqdm(range(N_EPOCHS)):

                counter = 0

                gc.collect()

                # Training
                for X, V in tqdm(zip(training_generator, validation_generator), total=len(training_generator)):

                    train_loss1 = train(decision_model_SingleLayer, optimizer[0], X)
                    train_loss2 = train(decision_model_TwoLayers, optimizer[1], X)
                    train_loss3 = train(decision_model_ThreeLayers, optimizer[2], X)
                    train_loss4 = train(decision_model_convnet, optimizer[3], X)

                    training_loss1.append(train_loss1)
                    training_loss2.append(train_loss2)
                    training_loss3.append(train_loss3)
                    training_loss4.append(train_loss4)



                    acc_valid1, valid_loss1 = test(decision_model_SingleLayer, V)
                    acc_valid2, valid_loss2 = test(decision_model_TwoLayers, V)
                    acc_valid3, valid_loss3 = test(decision_model_ThreeLayers, V)
                    acc_valid4, valid_loss4 = test(decision_model_convnet, V)

                    validation_loss1.append(valid_loss1)
                    accuracies_valid1.append(acc_valid1)
                    validation_loss2.append(valid_loss2)
                    accuracies_valid2.append(acc_valid2)
                    validation_loss3.append(valid_loss3)
                    accuracies_valid3.append(acc_valid3)
                    validation_loss4.append(valid_loss4)
                    accuracies_valid4.append(acc_valid4)


                    counter += 1


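                    # Every 200 mini-batches after epoch 5, evaluate all models and checkpoint the best ones
                    # while the SingleLayer validation accuracy is still below 0.6 (could be tied to the best architecture instead)
                    if (counter % 200 == 199) and (accuracies_valid1[-1] < .6) and (epoch > 5):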


                        accuracy, precision, recall, f1, kappa, auc, tss_eval = compute_score(decision_models, validation_fast_generator)

                        stats_current = {'best epoch':epoch, 'accuracy':accuracy[0], 'precision':precision[0], 'recall':recall[0], 'f1':f1[0], 'kappa':kappa[0], 'auc':auc[0], 'TSS':tss_eval[0]}
                        current_TSS = tss_eval[0]
                        if current_TSS > best_TSS0: # only best model is stored
                            best_TSS0 = current_TSS
                            best_epoch0 = epoch
                            torch.save(decision_model_SingleLayer.state_dict(), f'{save_path_models}_SingleLayer.pt')
                            with open(f'{save_path_stats}_SingleLayer.p', 'wb') as f: pickle.dump(stats_current,f)


                        stats_current = {'best epoch':epoch, 'accuracy':accuracy[1], 'precision':precision[1], 'recall':recall[1], 'f1':f1[1], 'kappa':kappa[1], 'auc':auc[1], 'TSS':tss_eval[1]}
                        current_TSS = tss_eval[1]
                        if current_TSS > best_TSS1: # only best model is stored
                            best_TSS1 = current_TSS
                            best_epoch1 = epoch
                            torch.save(decision_model_TwoLayers.state_dict(), f'{save_path_models}_TwoLayers.pt')
                            with open(f'{save_path_stats}_TwoLayers.p', 'wb') as f: pickle.dump(stats_current,f)


                        stats_current = {'best epoch':epoch, 'accuracy':accuracy[2], 'precision':precision[2], 'recall':recall[2], 'f1':f1[2], 'kappa':kappa[2], 'auc':auc[2], 'TSS':tss_eval[2]}
                        current_TSS = tss_eval[2]
                        if current_TSS > best_TSS2: # only best model is stored
                            best_TSS2 = current_TSS
                            best_epoch2 = epoch
                            torch.save(decision_model_ThreeLayers.state_dict(), f'{save_path_models}_ThreeLayers.pt')
                            with open(f'{save_path_stats}_ThreeLayers.p', 'wb') as f: pickle.dump(stats_current,f)


                        stats_current = {'best epoch':epoch, 'accuracy':accuracy[3], 'precision':precision[3], 'recall':recall[3], 'f1':f1[3], 'kappa':kappa[3], 'auc':auc[3], 'TSS':tss_eval[3]}
                        current_TSS = tss_eval[3]
                        if current_TSS > best_TSS3: # only best model is stored
                            best_TSS3 = current_TSS
                            best_epoch3 = epoch
                            torch.save(decision_model_convnet.state_dict(), f'{save_path_models}_convnet.pt')
                            with open(f'{save_path_stats}_convnet.p', 'wb') as f: pickle.dump(stats_current,f)

#                         if epoch >= 5:
#                             break_ = True
#                             pass


                #After each epoch evaluate the model stats
                accuracy, precision, recall, f1, kappa, auc, tss_eval = compute_score(decision_models, validation_fast_generator)

                stats_current = {'best epoch':epoch, 'accuracy':accuracy[0], 'precision':precision[0], 'recall':recall[0], 'f1':f1[0], 'kappa':kappa[0], 'auc':auc[0], 'TSS':tss_eval[0]}
                current_TSS = tss_eval[0]
                if current_TSS > best_TSS0: # only best model is stored
                    best_TSS0 = current_TSS
                    best_epoch0 = epoch
                    torch.save(decision_model_SingleLayer.state_dict(), f'{save_path_models}_SingleLayer.pt')
                    with open(f'{save_path_stats}_SingleLayer.p', 'wb') as f: pickle.dump(stats_current,f)


                stats_current = {'best epoch':epoch, 'accuracy':accuracy[1], 'precision':precision[1], 'recall':recall[1], 'f1':f1[1], 'kappa':kappa[1], 'auc':auc[1], 'TSS':tss_eval[1]}
                current_TSS = tss_eval[1]
                if current_TSS > best_TSS1: # only best model is stored
                    best_TSS1 = current_TSS
                    best_epoch1 = epoch
                    torch.save(decision_model_TwoLayers.state_dict(), f'{save_path_models}_TwoLayers.pt')
                    with open(f'{save_path_stats}_TwoLayers.p', 'wb') as f: pickle.dump(stats_current,f)


                stats_current = {'best epoch':epoch, 'accuracy':accuracy[2], 'precision':precision[2], 'recall':recall[2], 'f1':f1[2], 'kappa':kappa[2], 'auc':auc[2], 'TSS':tss_eval[2]}
                current_TSS = tss_eval[2]
                if current_TSS > best_TSS2: # only best model is stored
                    best_TSS2 = current_TSS
                    best_epoch2 = epoch
                    torch.save(decision_model_ThreeLayers.state_dict(), f'{save_path_models}_ThreeLayers.pt')
                    with open(f'{save_path_stats}_ThreeLayers.p', 'wb') as f: pickle.dump(stats_current,f)


                stats_current = {'best epoch':epoch, 'accuracy':accuracy[3], 'precision':precision[3], 'recall':recall[3], 'f1':f1[3], 'kappa':kappa[3], 'auc':auc[3], 'TSS':tss_eval[3]}
                current_TSS = tss_eval[3]
                if current_TSS > best_TSS3: # only best model is stored
                    best_TSS3 = current_TSS
                    best_epoch3 = epoch
                    torch.save(decision_model_convnet.state_dict(), f'{save_path_models}_convnet.pt')
                    with open(f'{save_path_stats}_convnet.p', 'wb') as f: pickle.dump(stats_current,f)

                # SingleLayer

                # summarize history for loss and accuracy
                fig1, ax1 = plt.subplots(figsize=(15,4))
                plt.title('model loss')
                plt.xlabel('training mini batches')
                plt.ylabel('loss')

                l1, = plt.plot(np.arange(0, len(training_loss1)), np.asarray(training_loss1), color='black', alpha=.5)
                l2, = plt.plot(np.arange(0, len(validation_loss1)), np.asarray(validation_loss1), color='coral', alpha=.5)
                ax2 = ax1.twinx()
                l3, = plt.plot(np.arange(0, len(accuracies_valid1)), np.asarray(accuracies_valid1), color='blue', alpha=.5)
                plt.legend([l1,l2,l3],['loss', 'valid', 'acc'])
                # fig1.savefig(f'{save_path_stats}_SingleLayer.pdf')
                plt.show()


                # TwoLayers

                # summarize history for loss and accuracy
                fig2, ax2 = plt.subplots(figsize=(15,4))
                plt.title('model loss')
                plt.xlabel('training mini batches')
                plt.ylabel('loss')

                l1, = ax2.plot(np.arange(0, len(training_loss2)), np.asarray(training_loss2), color='black', alpha=.5)
                l2, = ax2.plot(np.arange(0, len(validation_loss2)), np.asarray(validation_loss2), color='coral', alpha=.5)
                ax2_acc = ax2.twinx() # secondary y-axis for the accuracy, attached to this figure
                l3, = ax2_acc.plot(np.arange(0, len(accuracies_valid2)), np.asarray(accuracies_valid2), color='blue', alpha=.5)
                ax2_acc.set_ylim([0,1])
                ax2_acc.legend([l1,l2,l3],['loss', 'valid', 'acc'])
                # fig2.savefig(f'{save_path_stats}_TwoLayers.pdf')
                plt.show()


                # ThreeLayers

                # summarize history for loss and accuracy
                fig3, ax3 = plt.subplots(figsize=(15,4))
                plt.title('model loss')
                plt.xlabel('training mini batches')
                plt.ylabel('loss')

                l1, = ax3.plot(np.arange(0, len(training_loss3)), np.asarray(training_loss3), color='black', alpha=.5)
                l2, = ax3.plot(np.arange(0, len(validation_loss3)), np.asarray(validation_loss3), color='coral', alpha=.5)
                ax3_acc = ax3.twinx() # secondary y-axis for the accuracy, attached to this figure
                l3, = ax3_acc.plot(np.arange(0, len(accuracies_valid3)), np.asarray(accuracies_valid3), color='blue', alpha=.5)
                ax3_acc.legend([l1,l2,l3],['loss', 'valid', 'acc'])
                # fig3.savefig(f'{save_path_stats}_ThreeLayers.pdf')
                plt.show()

                # ConvNet

                # summarize history for loss and accuracy
                fig4, ax4 = plt.subplots(figsize=(15,4))
                plt.title('model loss')
                plt.xlabel('training mini batches')
                plt.ylabel('loss')

                l1, = ax4.plot(np.arange(0, len(training_loss4)), np.asarray(training_loss4), color='black', alpha=.5)
                l2, = ax4.plot(np.arange(0, len(validation_loss4)), np.asarray(validation_loss4), color='coral', alpha=.5)
                ax4_acc = ax4.twinx() # secondary y-axis for the accuracy, attached to this figure
                l3, = ax4_acc.plot(np.arange(0, len(accuracies_valid4)), np.asarray(accuracies_valid4), color='blue', alpha=.5)
                ax4_acc.legend([l1,l2,l3],['loss', 'valid', 'acc'])
                # fig4.savefig(f'{save_path_stats}_convnet.pdf')
                plt.show()

    ######################################################################################################################################################################################

            # Last figure of learning curves

            # SingleLayer

            # summarize history for loss and accuracy
            fig1, ax1 = plt.subplots(figsize=(15,4))
            plt.title('model loss')
            plt.xlabel('training mini batches')
            plt.ylabel('loss')

            l1, = plt.plot(np.arange(0, len(training_loss1)), np.asarray(training_loss1), color='black', alpha=.5)
            l2, = plt.plot(np.arange(0, len(validation_loss1)), np.asarray(validation_loss1), color='coral', alpha=.5)
            ax2 = ax1.twinx()
            l3, = plt.plot(np.arange(0, len(accuracies_valid1)), np.asarray(accuracies_valid1), color='blue', alpha=.5)
            plt.legend([l1,l2,l3],['loss', 'valid', 'acc'])
            fig1.savefig(f'{save_path_stats}_SingleLayer.pdf')
            plt.show()


            # TwoLayers

            # summarize history for loss and accuracy
            fig2, ax2 = plt.subplots(figsize=(15,4))
            plt.title('model loss')
            plt.xlabel('training mini batches')
            plt.ylabel('loss')

            l1, = ax2.plot(np.arange(0, len(training_loss2)), np.asarray(training_loss2), color='black', alpha=.5)
            l2, = ax2.plot(np.arange(0, len(validation_loss2)), np.asarray(validation_loss2), color='coral', alpha=.5)
            ax2_acc = ax2.twinx() # secondary y-axis for the accuracy, attached to this figure
            l3, = ax2_acc.plot(np.arange(0, len(accuracies_valid2)), np.asarray(accuracies_valid2), color='blue', alpha=.5)
            ax2_acc.set_ylim([0,1])
            ax2_acc.legend([l1,l2,l3],['loss', 'valid', 'acc'])
            fig2.savefig(f'{save_path_stats}_TwoLayers.pdf')
            plt.show()


            # ThreeLayers

            # summarize history for loss and accuracy
            fig3, ax3 = plt.subplots(figsize=(15,4))
            plt.title('model loss')
            plt.xlabel('training mini batches')
            plt.ylabel('loss')

            l1, = ax3.plot(np.arange(0, len(training_loss3)), np.asarray(training_loss3), color='black', alpha=.5)
            l2, = ax3.plot(np.arange(0, len(validation_loss3)), np.asarray(validation_loss3), color='coral', alpha=.5)
            ax3_acc = ax3.twinx() # secondary y-axis for the accuracy, attached to this figure
            l3, = ax3_acc.plot(np.arange(0, len(accuracies_valid3)), np.asarray(accuracies_valid3), color='blue', alpha=.5)
            ax3_acc.legend([l1,l2,l3],['loss', 'valid', 'acc'])
            fig3.savefig(f'{save_path_stats}_ThreeLayers.pdf')
            plt.show()

            # ConvNet

            # summarize history for loss and accuracy
            fig4, ax4 = plt.subplots(figsize=(15,4))
            plt.title('model loss')
            plt.xlabel('training mini batches')
            plt.ylabel('loss')

            l1, = ax4.plot(np.arange(0, len(training_loss4)), np.asarray(training_loss4), color='black', alpha=.5)
            l2, = ax4.plot(np.arange(0, len(validation_loss4)), np.asarray(validation_loss4), color='coral', alpha=.5)
            ax4_acc = ax4.twinx() # secondary y-axis for the accuracy, attached to this figure
            l3, = ax4_acc.plot(np.arange(0, len(accuracies_valid4)), np.asarray(accuracies_valid4), color='blue', alpha=.5)
            ax4_acc.legend([l1,l2,l3],['loss', 'valid', 'acc'])
            fig4.savefig(f'{save_path_stats}_convnet.pdf')
            plt.show()

            try:
                # reload the best (highest-TSS) checkpoint of each model before plotting its ROC curve
                decision_model_SingleLayer.load_state_dict(torch.load(f'{save_path_models}_SingleLayer.pt'))
                decision_model_TwoLayers.load_state_dict(torch.load(f'{save_path_models}_TwoLayers.pt'))
                decision_model_ThreeLayers.load_state_dict(torch.load(f'{save_path_models}_ThreeLayers.pt'))
                decision_model_convnet.load_state_dict(torch.load(f'{save_path_models}_convnet.pt'))

                ROC_plot(decision_model_SingleLayer, validation_fast_generator, 'SingleLayer')
                ROC_plot(decision_model_TwoLayers, validation_fast_generator, 'TwoLayers')
                ROC_plot(decision_model_ThreeLayers, validation_fast_generator, 'ThreeLayers')
                ROC_plot(decision_model_convnet, validation_fast_generator, 'convnet')

            except (RuntimeError, FileNotFoundError):
                pass # models that have not reached a complete epoch before going into overfitting are not stored
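The loop stores, for each architecture, only the checkpoint with the highest true skill statistic (TSS) returned by `compute_score`. The standard definition is the hit rate minus the false-alarm rate, $\mathrm{TSS} = \frac{TP}{TP+FN} - \frac{FP}{FP+TN}$, which ranges from $-1$ to $1$ and, unlike accuracy, does not reward always predicting the majority class. The sketch below assumes `compute_score` follows this definition with PF as the positive class (label 1, as in `Lazy_filehandler`) and a 0.5 threshold on the sigmoid output; neither detail is confirmed by the code shown here. The probabilities and TSS values are made up.

```python
import numpy as np


def true_skill_statistic(y_true, y_pred):
    """TSS = TP/(TP+FN) - FP/(FP+TN) for 0/1 labels (1 = positive class)."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)
    tp = np.sum(y_true & y_pred)
    fn = np.sum(y_true & ~y_pred)
    fp = np.sum(~y_true & y_pred)
    tn = np.sum(~y_true & ~y_pred)
    return tp / (tp + fn) - fp / (fp + tn)


# Made-up sigmoid outputs, thresholded at 0.5.
y_true = np.array([1, 1, 1, 1, 0, 0, 0, 0])
probs = np.array([0.9, 0.8, 0.4, 0.7, 0.2, 0.6, 0.1, 0.3])
y_pred = (probs >= 0.5).astype(int)
print(true_skill_statistic(y_true, y_pred))  # 0.5

# Keep-the-best rule used for checkpointing, on made-up TSS values per evaluation.
best_tss, best_epoch = -1.0, None
for epoch, tss in enumerate([0.31, 0.52, 0.47]):
    if tss > best_tss:  # the notebook saves the state_dict and stats at this point
        best_tss, best_epoch = tss, epoch
print(best_epoch, best_tss)  # 1 0.52
```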



## Validation of models over the whole set of splits
This cell computes the statistics for every split. It assumes exactly four models; if you train more or fewer, add or remove the corresponding blocks here and in the subsequent cells.
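For each split, the cell below rebuilds the test set at the observation level: every observation not referenced by a training file becomes a test observation, and each cleaned file is then assigned by the observation ID in characters 3 to 28 of its name, so no observation can appear in both sets. The sketch below reproduces that logic in isolation. The helper names are hypothetical, and the file names are made up, assuming names of the form `PF_<26-character observation ID>.h5`.

```python
def obs_id_from_filename(filename):
    """Observation ID as sliced in the notebook: characters 3..28 of the name before the first dot."""
    return filename.split('.')[0][3:29]


def split_files_by_observation(train_files, all_obs_ids, candidate_files):
    """Assign files to train/test so that no observation is shared between the sets.

    Test observations are all observations not referenced by any training file;
    files whose observation is in neither set are skipped.
    """
    train_ids = {obs_id_from_filename(f) for f in train_files}
    test_ids = set(all_obs_ids) - train_ids
    train, test = [], []
    for f in candidate_files:
        obs_id = obs_id_from_filename(f)
        if obs_id in test_ids:
            test.append(f)
        elif obs_id in train_ids:
            train.append(f)
    return train, test


all_obs = ['20140101_000000_0000000001', '20140101_000000_0000000002', '20140101_000000_0000000003']
train_files = ['PF_20140101_000000_0000000001.h5']
candidates = ['PF_20140101_000000_0000000001.h5',
              'PF_20140101_000000_0000000002.h5',
              'AR_20140101_000000_0000000003.h5']
train, test = split_files_by_observation(train_files, all_obs, candidates)
print(train, test)
```

Splitting by observation rather than by spectrum keeps spectra from the same observation out of both sets, which would otherwise inflate the test scores.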

In [None]:
# Make overall statistics
importlib.reload(mdls)

Label_set = Load_data_labels('Label_set_5_5.npz', line)

keys = ['k_itter',
        'best_epoch',
        'test_obs',
        'n_train',
        'n_test',
        'TSS',
        'ACC',
        'recall',
        'precision',
        'kappa',
        'f1',
        'train_test_ratio']

Failed = []

sharing_strategy = "file_system"
torch.multiprocessing.set_sharing_strategy(sharing_strategy)

def set_worker_sharing_strategy(worker_id: int) -> None:
    torch.multiprocessing.set_sharing_strategy(sharing_strategy)

# CUDA for PyTorch

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
torch.backends.cudnn.benchmark = True # only works if input shape does not vary, optimizes the use of your hardware to run the training

settype = 'test' # 'test' or 'train'

all_stats_dic1 = {}
all_stats_dic2 = {}
all_stats_dic3 = {}
all_stats_dic4 = {}



train_obs_list = []
test_obs_list = []
n_train_list = []
n_test_list = []
kitter_list = []
TSS_list1 = []
ACC_list1 = []
Recall_list1 = []
Precision_list1 = []
AUC_list1 = []
Kappa_list1 = []
F1_list1 = []
train_test_ratio_list = []
best_epoch_list1 = []


best_epoch_list2 = []
TSS_list2 = []
ACC_list2 = []
Recall_list2 = []
Precision_list2 = []
AUC_list2 = []
Kappa_list2 = []
F1_list2 = []


best_epoch_list3 = []
TSS_list3 = []
ACC_list3 = []
Recall_list3 = []
Precision_list3 = []
AUC_list3 = []
Kappa_list3 = []
F1_list3 = []


best_epoch_list4 = []
TSS_list4 = []
ACC_list4 = []
Recall_list4 = []
Precision_list4 = []
AUC_list4 = []
Kappa_list4 = []
F1_list4 = []


for itter in range(num_of_repetitions):
    for k in range(K):

        all_stats_list1 = []
        all_stats_list2 = []
        all_stats_list3 = []
        all_stats_list4 = []

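        # torch.load and open() do not expand '~', so expand it explicitly
        save_path_stats = os.path.expanduser(f'~/stats/decision_model_{k}_{itter}')
        save_path_models = os.path.expanduser(f'~/models/decision_model_{k}_{itter}')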

        all_obs_ids = obs_ids_Brandon_PF + obs_ids_Brandon_AR

        label_train = Label_set[str(k) + '_' + str(itter) + '_training']
        label_test = []
        obs_ids_train = []
        obs_ids_test = []

        for obs_np_file in label_train:
            obs_ids_train.append(obs_np_file.split('.')[0][3:29])

        # test observations: every observation that is not part of the training split
        obs_ids_test = list({obs_id for obs_id in all_obs_ids if obs_id not in obs_ids_train})
        label_train = []

        # assign every cleaned PF and AR file to the training or test split by its observation ID
        for obs_np_file in os.listdir(path_cleaned_PF) + os.listdir(path_cleaned_AR):
            obs_id = obs_np_file.split('.')[0][3:29]
            if obs_id in obs_ids_test:
                label_test.append(obs_np_file)
            elif obs_id in obs_ids_train:
                label_train.append(obs_np_file)


        label_file_IDs_train, label_indexs_train, y_labels_train = Lazy_filehandler_no_sampling(label_train, line)
        label_file_IDs_test, label_indexs_test, y_labels_test = Lazy_filehandler_no_sampling(label_test, line)

        failed=False
        try:

            validation_set = mdls.Dataset_NN(label_file_IDs_test, label_indexs_test, y_labels_test, path_cleaned)

            # Parameters
            params_val = {'batch_size': 6400,
                      'shuffle': True,
                      'num_workers': 16}

            # Generators

            validation_generator = torch.utils.data.DataLoader(validation_set, **params_val)

            torch.cuda.empty_cache()
            # Instantiate the four network architectures
            decision_model_SingleLayer = mdls.SingleLayer(line_params['n_breaks'], 10).to(device)
            decision_model_TwoLayers = mdls.TwoLayers(line_params['n_breaks'], 12, 8).to(device)
            decision_model_ThreeLayers = mdls.ThreeLayers(line_params['n_breaks'], 10, 12, 8).to(device)
            decision_model_convnet = mdls.ConvNet(line_params['n_breaks'], 6).to(device)

            # Load the trained weights of this fold (tensors are first mapped to the CPU)
            decision_model_SingleLayer.load_state_dict(torch.load(f'{save_path_models}_SingleLayer.pt', map_location=torch.device('cpu')))
            decision_model_TwoLayers.load_state_dict(torch.load(f'{save_path_models}_TwoLayers.pt', map_location=torch.device('cpu')))
            decision_model_ThreeLayers.load_state_dict(torch.load(f'{save_path_models}_ThreeLayers.pt', map_location=torch.device('cpu')))
            decision_model_convnet.load_state_dict(torch.load(f'{save_path_models}_convnet.pt', map_location=torch.device('cpu')))

            decision_models = [decision_model_SingleLayer, decision_model_TwoLayers, decision_model_ThreeLayers, decision_model_convnet]
            # Make sure every model lives on the target device
            for decision_model in decision_models:
                decision_model.to(device)


            # Reuse previously saved evaluation statistics of this fold if they exist
            try:
                with open(f'{save_path_stats}_SingleLayer_full_{settype}.p', 'rb') as f:
                    stats1 = pickle.load(f)

                with open(f'{save_path_stats}_TwoLayers_full_{settype}.p', 'rb') as f:
                    stats2 = pickle.load(f)

                with open(f'{save_path_stats}_ThreeLayers_full_{settype}.p', 'rb') as f:
                    stats3 = pickle.load(f)

                with open(f'{save_path_stats}_convnet_full_{settype}.p', 'rb') as f:
                    stats4 = pickle.load(f)

            except Exception as exc:
                # No (complete) cache: evaluate the models and write the statistics to disk
                print(exc)



                accuracy, precision, recall, f1, kappa, auc, tss_eval = compute_score(decision_models, validation_generator)

                n_train_spectra = len(label_file_IDs_train)
                n_test_spectra = len(label_file_IDs_test)
                model_names = ['SingleLayer', 'TwoLayers', 'ThreeLayers', 'convnet']  # same order as decision_models
                full_stats = []
                for i, model_name in enumerate(model_names):
                    # Training statistics saved for this model (contain the best epoch)
                    with open(f'{save_path_stats}_{model_name}.p', 'rb') as f:
                        stats_single = pickle.load(f)

                    stats = {'k_itter': f'{k}_{itter}',
                             'test_obs': Label_set[f'{k}_{itter}_testing'],
                             'best_epoch': stats_single['best epoch'] + 1,
                             'n_train': n_train_spectra,
                             'n_test': n_test_spectra,
                             'TSS': tss_eval[i],
                             'ACC': accuracy[i],
                             'recall': recall[i],
                             'precision': precision[i],
                             'auc': auc[i],
                             'kappa': kappa[i],
                             'f1': f1[i],
                             # fraction of spectra in the test set
                             'train_test_ratio': n_test_spectra / (n_train_spectra + n_test_spectra)}

                    with open(f'{save_path_stats}_{model_name}_full_{settype}.p', 'wb') as f:
                        pickle.dump(stats, f)
                    full_stats.append(stats)

                stats1, stats2, stats3, stats4 = full_stats


        except Exception as error:
            Failed.append((itter,k))

            print(error)
            failed=True

        try:
            if not failed:

                kitter_list.append(stats1['k_itter'])
                best_epoch_list1.append(stats1['best_epoch'])
                test_obs_list.append(stats1['test_obs'])
                n_train_list.append(stats1['n_train'])
                n_test_list.append(stats1['n_test'])
                TSS_list1.append(stats1['TSS'])
                ACC_list1.append(stats1['ACC'])
                Recall_list1.append(stats1['recall'])
                Precision_list1.append(stats1['precision'])
                AUC_list1.append(stats1['auc'])
                Kappa_list1.append(stats1['kappa'])
                F1_list1.append(stats1['f1'])
                train_test_ratio_list.append(stats1['train_test_ratio'])

                best_epoch_list2.append(stats2['best_epoch'])
                TSS_list2.append(stats2['TSS'])
                ACC_list2.append(stats2['ACC'])
                Recall_list2.append(stats2['recall'])
                Precision_list2.append(stats2['precision'])
                AUC_list2.append(stats2['auc'])
                Kappa_list2.append(stats2['kappa'])
                F1_list2.append(stats2['f1'])

                best_epoch_list3.append(stats3['best_epoch'])
                TSS_list3.append(stats3['TSS'])
                ACC_list3.append(stats3['ACC'])
                Recall_list3.append(stats3['recall'])
                Precision_list3.append(stats3['precision'])
                AUC_list3.append(stats3['auc'])
                Kappa_list3.append(stats3['kappa'])
                F1_list3.append(stats3['f1'])

                best_epoch_list4.append(stats4['best_epoch'])
                TSS_list4.append(stats4['TSS'])
                ACC_list4.append(stats4['ACC'])
                Recall_list4.append(stats4['recall'])
                Precision_list4.append(stats4['precision'])
                AUC_list4.append(stats4['auc'])
                Kappa_list4.append(stats4['kappa'])
                F1_list4.append(stats4['f1'])

        except Exception as exc:
            print(exc)


        # Store the per-fold statistics only if this fold was evaluated successfully;
        # otherwise stats1 to stats4 are undefined or left over from the previous fold
        if not failed:
            for key in keys:
                all_stats_list1.append(stats1[key])
                all_stats_list2.append(stats2[key])
                all_stats_list3.append(stats3[key])
                all_stats_list4.append(stats4[key])

            all_stats_dic1[f'{k}_{itter}'] = all_stats_list1
            all_stats_dic2[f'{k}_{itter}'] = all_stats_list2
            all_stats_dic3[f'{k}_{itter}'] = all_stats_list3
            all_stats_dic4[f'{k}_{itter}'] = all_stats_list4


sorted_itters1 = sorted(all_stats_dic1.keys(), key = lambda x: all_stats_dic1[x][3], reverse = True)
sorted_itters2 = sorted(all_stats_dic2.keys(), key = lambda x: all_stats_dic2[x][3], reverse = True)
sorted_itters3 = sorted(all_stats_dic3.keys(), key = lambda x: all_stats_dic3[x][3], reverse = True)
sorted_itters4 = sorted(all_stats_dic4.keys(), key = lambda x: all_stats_dic4[x][3], reverse = True)
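The loop above splits the data by observation rather than by individual file, so that no IRIS observation contributes spectra to both the training and the test set. The sketch below isolates that logic in plain Python. `split_files_by_obs_id`, `obs_id_from_filename` and the file names are hypothetical, and the sketch assumes, like the notebook code, that the observation ID sits in characters 3 to 28 of each file stem.

```python
def obs_id_from_filename(filename):
    """Observation ID = characters 3-28 of the file stem (same slice as in the evaluation loop)."""
    return filename.split('.')[0][3:29]


def split_files_by_obs_id(filenames, all_obs_ids, train_obs_ids):
    """Hypothetical helper: split files into train/test sets that share no observation.

    Every listed observation that is not used for training becomes a test observation;
    files from observations missing in all_obs_ids are skipped, as in the loop above.
    """
    train_obs_ids = set(train_obs_ids)
    test_obs_ids = set(all_obs_ids) - train_obs_ids
    train_files, test_files = [], []
    for filename in filenames:
        obs_id = obs_id_from_filename(filename)
        if obs_id in test_obs_ids:
            test_files.append(filename)
        elif obs_id in train_obs_ids:
            train_files.append(filename)
    return train_files, test_files


# Made-up file names: a 3-character prefix followed by a 26-character observation ID
demo_files = ['PF_20140910_112825_3860259453_0.npy',
              'PF_20140910_112825_3860259453_1.npy',
              'AR_20150311_063912_3820259253_0.npy',
              'AR_20160101_000000_3600000000_0.npy']
demo_all_obs_ids = ['20140910_112825_3860259453', '20150311_063912_3820259253']
demo_train, demo_test = split_files_by_obs_id(demo_files, demo_all_obs_ids,
                                              ['20140910_112825_3860259453'])
print(demo_train)  # the two PF files
print(demo_test)   # the 2015 AR file; the 2016 observation is not listed and is skipped
```

Splitting by observation rather than by spectrum matters because spectra from the same raster are strongly correlated; a spectrum-level split would let near-duplicates leak into the test set.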


In [None]:
kitter = np.array(kitter_list)
best_epoch1 = np.array(best_epoch_list1)
test_obs = np.array(test_obs_list)
n_train = np.array(n_train_list)
n_test = np.array(n_test_list)
TSS_1 = np.array(TSS_list1)
ACC_1 = np.array(ACC_list1)
Recall_1 = np.array(Recall_list1)
Precision_1 = np.array(Precision_list1)
AUC_1 = np.array(AUC_list1)
Kappa_1 = np.array(Kappa_list1)
F1_1 = np.array(F1_list1)
train_test_ratio = np.array(train_test_ratio_list)

best_epoch2 = np.array(best_epoch_list2)
TSS_2 = np.array(TSS_list2)
ACC_2 = np.array(ACC_list2)
Recall_2 = np.array(Recall_list2)
Precision_2 = np.array(Precision_list2)
AUC_2 = np.array(AUC_list2)
Kappa_2 = np.array(Kappa_list2)
F1_2 = np.array(F1_list2)


best_epoch3 = np.array(best_epoch_list3)
TSS_3 = np.array(TSS_list3)
ACC_3 = np.array(ACC_list3)
Recall_3 = np.array(Recall_list3)
Precision_3 = np.array(Precision_list3)
AUC_3 = np.array(AUC_list3)
Kappa_3 = np.array(Kappa_list3)
F1_3 = np.array(F1_list3)


best_epoch4 = np.array(best_epoch_list4)
TSS_4 = np.array(TSS_list4)
ACC_4 = np.array(ACC_list4)
Recall_4 = np.array(Recall_list4)
Precision_4 = np.array(Precision_list4)
AUC_4 = np.array(AUC_list4)
Kappa_4 = np.array(Kappa_list4)
F1_4 = np.array(F1_list4)


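# Keep only folds whose test set holds between 5 % and 50 % of all spectra
# (train_test_ratio is n_test / (n_train + n_test), see the evaluation loop)
restr_ratio = np.where((train_test_ratio >= 0.05) & (train_test_ratio <= .5))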

kitter_reduced = kitter[restr_ratio]
best_epoch_reduced1 = best_epoch1[restr_ratio]
test_obs_reduced = test_obs[restr_ratio]
n_train_reduced = n_train[restr_ratio]
n_test_reduced = n_test[restr_ratio]
TSS_reduced_1 = TSS_1[restr_ratio]
ACC_reduced_1 = ACC_1[restr_ratio]
Recall_reduced_1 = Recall_1[restr_ratio]
Precision_reduced_1 = Precision_1[restr_ratio]
AUC_reduced_1 = AUC_1[restr_ratio]
Kappa_reduced_1 = Kappa_1[restr_ratio]
F1_reduced_1 = F1_1[restr_ratio]
train_test_ratio_reduced = train_test_ratio[restr_ratio]

best_epoch_reduced2 = best_epoch2[restr_ratio]
TSS_reduced_2 = TSS_2[restr_ratio]
ACC_reduced_2 = ACC_2[restr_ratio]
Recall_reduced_2 = Recall_2[restr_ratio]
Precision_reduced_2 = Precision_2[restr_ratio]
AUC_reduced_2 = AUC_2[restr_ratio]
Kappa_reduced_2 = Kappa_2[restr_ratio]
F1_reduced_2 = F1_2[restr_ratio]


best_epoch_reduced3 = best_epoch3[restr_ratio]
TSS_reduced_3 = TSS_3[restr_ratio]
ACC_reduced_3 = ACC_3[restr_ratio]
Recall_reduced_3 = Recall_3[restr_ratio]
Precision_reduced_3 = Precision_3[restr_ratio]
AUC_reduced_3 = AUC_3[restr_ratio]
Kappa_reduced_3 = Kappa_3[restr_ratio]
F1_reduced_3 = F1_3[restr_ratio]


best_epoch_reduced4 = best_epoch4[restr_ratio]
TSS_reduced_4 = TSS_4[restr_ratio]
ACC_reduced_4 = ACC_4[restr_ratio]
Recall_reduced_4 = Recall_4[restr_ratio]
Precision_reduced_4 = Precision_4[restr_ratio]
AUC_reduced_4 = AUC_4[restr_ratio]
Kappa_reduced_4 = Kappa_4[restr_ratio]
F1_reduced_4 = F1_4[restr_ratio]

print(len(kitter), len(kitter_reduced))

## Plotting the results

In [None]:
fig = plt.figure(figsize=(25,5), constrained_layout=False)
rcParams['font.size'] = 16
rcParams['font.family'] = 'serif'
plt.rcParams['lines.linewidth'] = 2
gs = fig.add_gridspec(1, 5, hspace=0, wspace=0)
ax1, ax2, ax3, ax4, ax5 = gs.subplots(sharex='col')#, sharey='row')

["s","P", "D", "*", "o", "v", "X"]


ax1.set_xlabel('Single Layer', fontsize=23)

ax1.errorbar([0], [np.mean(TSS_reduced_1)], yerr=[np.std(TSS_reduced_1)], lw=1.5, fmt="s", c="red", capsize=5, elinewidth=3, alpha=1, label='TSS')
ax1.errorbar([1], [np.mean(ACC_reduced_1)], yerr=[np.std(ACC_reduced_1)], lw=1.5, fmt="P", c="orange", capsize=5, elinewidth=3, alpha=1, label='ACC')
ax1.errorbar([2], [np.mean(Recall_reduced_1)], yerr=[np.std(Recall_reduced_1)], lw=1.5, fmt="D", c="dodgerblue", capsize=5, elinewidth=3, alpha=1, label='Recall')
ax1.errorbar([3], [np.mean(Precision_reduced_1)], yerr=[np.std(Precision_reduced_1)], lw=1.5, fmt="*", c="green", capsize=5, elinewidth=3, alpha=1, label='Precision')
ax1.errorbar([4], [np.mean(AUC_reduced_1)], yerr=[np.std(AUC_reduced_1)], lw=1.5, fmt="o", c="royalblue", capsize=5, elinewidth=3, alpha=1, label='AUC')
ax1.errorbar([5], [np.mean(Kappa_reduced_1)], yerr=[np.std(Kappa_reduced_1)], lw=1.5, fmt="v", c="mediumblue", capsize=5, elinewidth=3, alpha=1, label='Kappa')
ax1.errorbar([6], [np.mean(F1_reduced_1)], yerr=[np.std(F1_reduced_1)], lw=1.5, fmt="X", c="brown", capsize=5, elinewidth=3, alpha=1, label='F1')


ax1.set_xticks([])
ax1.set_yticks([])
ax1.yaxis.set_major_locator(MultipleLocator(0.1))
ax1.yaxis.set_minor_locator(MultipleLocator(.05))
ax1.tick_params(which='major', length=8,width=1.5, labelsize=23)
ax1.tick_params(which='minor', length=5,width=1.5, labelsize=23)
ax1.set_ylim(0,1)
ax1.set_xlim(-1,9)

ax2.set_xlabel('Two Layers', fontsize=23)

l1 = ax2.errorbar([0], [np.mean(TSS_reduced_2)], yerr=[np.std(TSS_reduced_2)], lw=1.5, fmt="s", c="red", capsize=5, elinewidth=3, alpha=1, label='TSS')
l2 = ax2.errorbar([1], [np.mean(ACC_reduced_2)], yerr=[np.std(ACC_reduced_2)], lw=1.5, fmt="P", c="orange", capsize=5, elinewidth=3, alpha=1, label='ACC')
l3 = ax2.errorbar([2], [np.mean(Recall_reduced_2)], yerr=[np.std(Recall_reduced_2)], lw=1.5, fmt="D", c="dodgerblue", capsize=5, elinewidth=3, alpha=1, label='Recall')
l4 = ax2.errorbar([3], [np.mean(Precision_reduced_2)], yerr=[np.std(Precision_reduced_2)], lw=1.5, fmt="*", c="green", capsize=5, elinewidth=3, alpha=1, label='Precision')
l5 = ax2.errorbar([4], [np.mean(AUC_reduced_2)], yerr=[np.std(AUC_reduced_2)], lw=1.5, fmt="o", c="royalblue", capsize=5, elinewidth=3, alpha=1, label='AUC')
l6 = ax2.errorbar([5], [np.mean(Kappa_reduced_2)], yerr=[np.std(Kappa_reduced_2)], lw=1.5, fmt="v", c="mediumblue", capsize=5, elinewidth=3, alpha=1, label='Kappa')
l7 = ax2.errorbar([6], [np.mean(F1_reduced_2)], yerr=[np.std(F1_reduced_2)], lw=1.5, fmt="X", c="brown", capsize=5, elinewidth=3, alpha=1, label='F1')
# plt.legend()

ax2.set_xticks([])
ax2.set_yticklabels([])
ax2.yaxis.set_major_locator(MultipleLocator(0.1))
ax2.yaxis.set_minor_locator(MultipleLocator(.05))
ax2.tick_params(which='major', length=8,width=1.5)
ax2.tick_params(which='minor', length=5,width=1.5)
ax2.set_ylim(0,1)
ax2.set_xlim(-1,9)

ax3.set_xlabel('Three Layers', fontsize=23)

ax3.errorbar([0], [np.mean(TSS_reduced_3)], yerr=[np.std(TSS_reduced_3)], lw=1.5, fmt="s", c="red", capsize=5, alpha=1, elinewidth=3, label='TSS')
ax3.errorbar([1], [np.mean(ACC_reduced_3)], yerr=[np.std(ACC_reduced_3)], lw=1.5, fmt="P", c="orange", capsize=5, alpha=1, elinewidth=3, label='ACC')
ax3.errorbar([2], [np.mean(Recall_reduced_3)], yerr=[np.std(Recall_reduced_3)], lw=1.5, fmt="D", c="dodgerblue", capsize=5, alpha=1, elinewidth=3, label='Recall')
ax3.errorbar([3], [np.mean(Precision_reduced_3)], yerr=[np.std(Precision_reduced_3)], lw=1.5, fmt="*", c="green", capsize=5, alpha=1, elinewidth=3, label='Precision')
ax3.errorbar([4], [np.mean(AUC_reduced_3)], yerr=[np.std(AUC_reduced_3)], lw=1.5, fmt="o", c="royalblue", capsize=5, alpha=1, elinewidth=3, label='AUC')
ax3.errorbar([5], [np.mean(Kappa_reduced_3)], yerr=[np.std(Kappa_reduced_3)], lw=1.5, fmt="v", c="mediumblue", capsize=5, alpha=1, elinewidth=3, label='Kappa')
ax3.errorbar([6], [np.mean(F1_reduced_3)], yerr=[np.std(F1_reduced_3)], lw=1.5, fmt="X", c="brown", capsize=5, alpha=1, elinewidth=3, label='F1')
# plt.legend()

ax3.set_xticks([])
ax3.set_yticklabels([])
ax3.yaxis.set_major_locator(MultipleLocator(0.1))
ax3.yaxis.set_minor_locator(MultipleLocator(.05))
ax3.tick_params(which='major', length=8,width=1.5)
ax3.tick_params(which='minor', length=5,width=1.5)
ax3.set_ylim(0,1)
ax3.set_xlim(-1,9)

ax4.set_xlabel('Convolutional Network', fontsize=23)

ax4.errorbar([0], [np.mean(TSS_reduced_4)], yerr=[np.std(TSS_reduced_4)], lw=1.5, fmt="s", c="red", capsize=5, elinewidth=3, alpha=1, label='TSS')
ax4.errorbar([1], [np.mean(ACC_reduced_4)], yerr=[np.std(ACC_reduced_4)], lw=1.5, fmt="P", c="orange", capsize=5, elinewidth=3, alpha=1, label='ACC')
ax4.errorbar([2], [np.mean(Recall_reduced_4)], yerr=[np.std(Recall_reduced_4)], lw=1.5, fmt="D", c="dodgerblue", capsize=5, elinewidth=3, alpha=1, label='Recall')
ax4.errorbar([3], [np.mean(Precision_reduced_4)], yerr=[np.std(Precision_reduced_4)], lw=1.5, fmt="*", c="green", capsize=5, elinewidth=3, alpha=1, label='Precision')
ax4.errorbar([4], [np.mean(AUC_reduced_4)], yerr=[np.std(AUC_reduced_4)], lw=1.5, fmt="o", c="royalblue", capsize=5, elinewidth=3, alpha=1, label='AUC')
ax4.errorbar([5], [np.mean(Kappa_reduced_4)], yerr=[np.std(Kappa_reduced_4)], lw=1.5, fmt="v", c="mediumblue", capsize=5, elinewidth=3, alpha=1, label='Kappa')
ax4.errorbar([6], [np.mean(F1_reduced_4)], yerr=[np.std(F1_reduced_4)], lw=1.5, fmt="X", c="brown", capsize=5, elinewidth=3, alpha=1, label='F1')
# plt.legend()

ax4.set_xticks([])
ax4.set_yticklabels([])
ax4.yaxis.set_major_locator(MultipleLocator(0.1))
ax4.yaxis.set_minor_locator(MultipleLocator(.05))
ax4.tick_params(which='major', length=8, width=1.5)
ax4.tick_params(which='minor', length=5, width=1.5)
ax4.set_ylim(0,1)
ax4.set_xlim(-1,9)

# The fifth panel only holds the legend shared by all four networks
ax5.legend([l1, l2, l3, l4, l5, l6, l7], ['TSS', 'ACC', 'Recall', 'Precision', 'AUC', 'Kappa', 'F1'], prop={'size': 20})
ax5.set_xticks([])
ax5.set_yticks([])
fig.savefig(f'/sml/zbindenj/{line}/stats/stats_validation.png')
plt.show()
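The panels above compare seven scores per architecture. `compute_score` is defined elsewhere in the notebook, so the sketch below only shows the standard definitions of six of these scores from the binary confusion matrix; it is not necessarily how `compute_score` computes them. AUC is left out because it needs continuous model outputs rather than hard labels. The helper `binary_scores` is hypothetical and has no guard against empty classes (division by zero).

```python
import numpy as np


def binary_scores(y_true, y_pred):
    """Confusion-matrix based scores for 0/1 labels (1 = positive class); no zero-division guards."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)
    tp = np.sum(y_true & y_pred)
    tn = np.sum(~y_true & ~y_pred)
    fp = np.sum(~y_true & y_pred)
    fn = np.sum(y_true & ~y_pred)
    n = tp + tn + fp + fn

    recall = tp / (tp + fn)          # true positive rate
    false_alarm = fp / (fp + tn)     # false positive rate
    precision = tp / (tp + fp)
    accuracy = (tp + tn) / n
    f1 = 2 * precision * recall / (precision + recall)
    tss = recall - false_alarm       # true skill statistic
    # Cohen's kappa: observed agreement (accuracy) corrected for agreement expected by chance
    p_chance = ((tp + fp) * (tp + fn) + (fn + tn) * (fp + tn)) / n**2
    kappa = (accuracy - p_chance) / (1 - p_chance)
    return {'TSS': tss, 'ACC': accuracy, 'recall': recall,
            'precision': precision, 'kappa': kappa, 'f1': f1}


demo_scores = binary_scores([1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 0, 1, 0, 0, 0, 0])
# TP = 2, FN = 1, FP = 1, TN = 4: TSS = 2/3 - 1/5 = 7/15, ACC = 0.75
```

TSS is commonly used for flare prediction because, unlike accuracy and precision, it does not depend on the ratio of positive to negative examples in the test set.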


fig2 = plt.figure(figsize=(20, 5), constrained_layout=False)
rcParams['font.size'] = 16
rcParams['font.family'] = 'serif'
rcParams['lines.linewidth'] = 2
gs = fig2.add_gridspec(1, 6)


def plot_fold_spread(ax, values, ylabel):
    """Mean +- standard deviation over the folds (black) with every individual fold overlaid (red)."""
    ax.errorbar([0], [np.mean(values)], yerr=[np.std(values)], lw=1.5, fmt="s", c="k",
                capsize=3, elinewidth=3, alpha=.7, label=ylabel)
    sns.swarmplot(data=values, color="r", ax=ax)
    ax.set_ylabel(ylabel, fontsize=23)
    ax.set_xticks([])


# train_test_ratio holds the test fraction n_test / (n_train + n_test)
plot_fold_spread(fig2.add_subplot(gs[0, 0]), train_test_ratio_reduced, 'Test fraction')
plot_fold_spread(fig2.add_subplot(gs[0, 1]), n_test_reduced, 'Number of test spectra')

# Best epoch of every fold, one panel per architecture
best_epochs_per_model = [(best_epoch_reduced1, 'Single Layer'), (best_epoch_reduced2, 'Two Layers'),
                         (best_epoch_reduced3, 'Three Layers'), (best_epoch_reduced4, 'ConvNet')]
for col, (best_epochs, model_name) in enumerate(best_epochs_per_model, start=2):
    ax = fig2.add_subplot(gs[0, col])
    plot_fold_spread(ax, best_epochs, 'Number of epochs')
    ax.set_xlabel(model_name, fontsize=16)
    ax.yaxis.set_major_locator(MultipleLocator(5))
    ax.set_ylim([0, 10])
    ax.tick_params(which='major', length=8, width=1.5)

plt.tight_layout()
plt.show()