In [None]:
# Import

import pandas as pd
import numpy as np
import os
import re
import sys 
import pickle
from tqdm.notebook import tqdm as tqdm
from matplotlib import pyplot as plt

In [None]:
# TPU imports

import warnings
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import warnings

warnings.filterwarnings("ignore")

# Other imports
import torch
import torchvision
import torchvision.transforms as transforms
import time

In [None]:
# Defining the neural network (CNN + LSTM) 
# !pip install torchsummary

import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import random



import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence
from typing import *



class VariationalDropout(nn.Module):
    """
    Applies the same dropout mask across the temporal dimension.
    See https://arxiv.org/abs/1512.05287 for more details.
    Note that this is not applied to the recurrent activations in the LSTM like the above paper.
    Instead, it is applied to the inputs and outputs of the recurrent layer.

    Args:
        dropout: probability of zeroing a (sequence, feature) entry; values <= 0 disable the layer.
        batch_first: if True, padded inputs are (batch, time, features), otherwise
            (time, batch, features).
    """
    def __init__(self, dropout: float, batch_first: Optional[bool]=False):
        super().__init__()
        self.dropout = dropout
        self.batch_first = batch_first

    def forward(self, x: Union[torch.Tensor, PackedSequence]) -> Union[torch.Tensor, PackedSequence]:
        """Return `x` with one Bernoulli mask per sequence, rescaled by 1 / (1 - dropout).

        The input is returned untouched in eval mode or when dropout <= 0.
        """
        if not self.training or self.dropout <= 0.:
            return x

        keep_prob = 1 - self.dropout

        if isinstance(x, PackedSequence):
            data = x.data
            batch_sizes = x.batch_sizes
            max_batch_size = int(batch_sizes[0])
            # One mask row per sequence in the batch.
            mask = data.new_empty(max_batch_size, data.size(-1),
                                  requires_grad=False).bernoulli_(keep_prob)
            # Packed data is laid out time step by time step; within a step the
            # rows are sequences 0..batch_sizes[t]-1, so this maps every packed
            # row to the sequence (and therefore the mask row) it belongs to.
            row_owner = torch.cat([torch.arange(n, device=data.device)
                                   for n in batch_sizes.tolist()])
            data = data.masked_fill(mask[row_owner] == 0, 0) / keep_prob
            return PackedSequence(data, batch_sizes,
                                  x.sorted_indices, x.unsorted_indices)

        # Drop same mask across entire sequence: the time axis of the mask has
        # size 1 so it broadcasts over every step.
        if self.batch_first:
            m = x.new_empty(x.size(0), 1, x.size(2), requires_grad=False).bernoulli_(keep_prob)
        else:
            # Time-major layout: the batch dimension is axis 1, not axis 0.
            m = x.new_empty(1, x.size(1), x.size(2), requires_grad=False).bernoulli_(keep_prob)
        return x.masked_fill(m == 0, 0) / keep_prob

class LSTM(nn.LSTM):
    """nn.LSTM with variational dropout on its inputs/outputs and custom initialisation.

    Extra keyword arguments on top of nn.LSTM:
        dropouti: VariationalDropout probability applied to the input sequence.
        dropoutw: dropout probability applied to the recurrent (weight_hh) matrices.
        dropouto: VariationalDropout probability applied to the output sequence.
        batch_first: forwarded to nn.LSTM and to both dropout layers (defaults to True here).
        unit_forget_bias: if True, biases are zeroed and the forget-gate slice is set to 1.
    """
    def __init__(self, *args, dropouti: float=0.,
                 dropoutw: float=0., dropouto: float=0.,
                 batch_first=True, unit_forget_bias=True, **kwargs):
        super().__init__(*args, **kwargs, batch_first=batch_first)
        self.unit_forget_bias = unit_forget_bias
        self.dropoutw = dropoutw
        self.input_drop = VariationalDropout(dropouti,
                                             batch_first=batch_first)
        self.output_drop = VariationalDropout(dropouto,
                                              batch_first=batch_first)
        self._init_weights()

    def _init_weights(self):
        """
        Use orthogonal init for recurrent layers, xavier uniform for input layers.
        When unit_forget_bias is set, biases are 0 except for the forget gate;
        otherwise the biases keep nn.LSTM's own initialisation.
        """
        for name, param in self.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param.data)
            elif "weight_ih" in name:
                nn.init.xavier_uniform_(param.data)
            elif "bias" in name and self.unit_forget_bias:
                nn.init.zeros_(param.data)
                # Second hidden_size-wide slice of the bias vector (the forget gate
                # in nn.LSTM's i, f, g, o layout). Applied to every bias tensor.
                param.data[self.hidden_size:2 * self.hidden_size] = 1

    def _drop_weights(self):
        """Apply dropout to the recurrent weight matrices (training mode only).

        Skipped entirely in eval mode or when dropoutw <= 0: dropout would be the
        identity there, so there is no reason to reassign the weights on every call.

        NOTE(review): when active, this overwrites the stored weight_hh values with
        their dropped-out version; the un-dropped values are not kept anywhere.
        """
        if not self.training or self.dropoutw <= 0.:
            return
        for name, param in self.named_parameters():
            if "weight_hh" in name:
                param.data = \
                    torch.nn.functional.dropout(param.data, p=self.dropoutw,
                                                training=self.training).contiguous()

    def forward(self, input, hx=None):
        """Run weight dropout, input dropout, the wrapped LSTM, then output dropout.

        Returns the (dropped-out) output sequence and the state returned by nn.LSTM.
        """
        self._drop_weights()
        input = self.input_drop(input)
        seq, state = super().forward(input, hx=hx)
        return self.output_drop(seq), state
    
    




# Separable Convs in Pytorch
# https://gist.github.com/iiSeymour/85a5285e00cbed60537241da7c3b8525

class TCSConv1d(nn.Module):
    """Depthwise-separable 1-D convolution.

    A per-channel (grouped) convolution carrying the kernel size, dilation and
    padding is followed by a 1x1 convolution that mixes channels. Neither
    convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, padding):
        super().__init__()
        # groups == in_channels: every input channel is filtered independently.
        self.depthwise = nn.Conv1d(in_channels, in_channels, kernel_size,
                                   padding=padding, dilation=dilation,
                                   groups=in_channels, bias=False)
        # Kernel size 1: only recombines channels, leaves the length untouched.
        self.pointwise = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                                   kernel_size=1, bias=False)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise one."""
        return self.pointwise(self.depthwise(x))



class Net(nn.Module):
    """Two-branch network: a residual stack of dilated separable convolutions
    followed by a bidirectional LSTM on the first input, a bidirectional LSTM on
    the second input, and a linear head producing 60 outputs.

    Sub-modules are created under the attribute names conv_N / bn_N / drop_N,
    avgPool_a/b/c, fc_N / bn_fc_N / drop_fc_N, lstm_1, lstm_spo2, ln.

    Several registered modules are never called in forward(): the trailing "b"
    block of every stage (conv_4b, conv_8b, conv_12b, conv_16b with their bn/drop),
    conv_17 / bn_17, bn_fc_1, and fc_2..fc_4 with their bn/drop layers.
    """

    # One entry per convolutional stage:
    # (block tags, dilation per block, dropout probability, (pool attribute, pool kernel) or None)
    _STAGES = (
        (("1", "2", "3", "4", "4b"), (1, 2, 4, 8, 16), 0.2, ("avgPool_a", 4)),
        (("5", "6", "7", "8", "8b"), (1, 3, 5, 9, 16), 0.1, ("avgPool_b", 4)),
        (("9", "10", "11", "12", "12b"), (1, 3, 5, 9, 16), 0.3, ("avgPool_c", 5)),
        (("13", "14", "15", "16", "16b"), (1, 2, 4, 8, 16), 0.3, None),
    )

    # (tag, in_features, out_features) of the fully connected layers.
    _FC_LAYERS = (("1", 768, 1024), ("2", 1024, 512), ("3", 256, 256), ("4", 256, 256))

    def __init__(self):
        super().__init__()

        n_filters = 16
        kernel = 3

        def same_padding(dilation, kernel_size):
            # Output length of a conv is (W - F + 2P) / S + 1; with stride 1 the
            # length is preserved when P = dilation * (kernel - 1) / 2.
            return int(dilation * (kernel_size - 1) / 2)

        # 1x1 projection of the single input channel, used as the residual stream.
        self.skip = TCSConv1d(in_channels=1, out_channels=n_filters, kernel_size=1,
                              dilation=1, padding=same_padding(1, 1))

        # Modules are registered in the order conv, bn, drop per block, with the
        # stage's pooling layer (if any) right after its last block.
        channels_in = 1  # only conv_1 sees the raw single-channel input
        for tags, dilations, drop_p, pool in self._STAGES:
            for tag, dilation in zip(tags, dilations):
                setattr(self, "conv_" + tag,
                        TCSConv1d(in_channels=channels_in, out_channels=n_filters,
                                  kernel_size=kernel, dilation=dilation,
                                  padding=same_padding(dilation, kernel)))
                setattr(self, "bn_" + tag, nn.BatchNorm1d(n_filters))
                setattr(self, "drop_" + tag, nn.Dropout2d(drop_p))
                channels_in = n_filters
            if pool is not None:
                pool_name, pool_kernel = pool
                setattr(self, pool_name, nn.AvgPool1d(kernel_size=pool_kernel))

        self.conv_17 = TCSConv1d(in_channels=n_filters, out_channels=1,
                                 kernel_size=1, dilation=1,
                                 padding=same_padding(1, 1))
        self.bn_17 = nn.BatchNorm1d(1)

        # Bidirectional, hidden size 128 -> 256 features per time step.
        self.lstm_1 = LSTM(input_size=n_filters, hidden_size=128, num_layers=1,
                           bidirectional=True, batch_first=True, dropouti=0.1)

        for tag, n_in, n_out in self._FC_LAYERS:
            setattr(self, "fc_" + tag, nn.Linear(n_in, n_out))
            setattr(self, "bn_fc_" + tag, nn.BatchNorm1d(n_out))
            setattr(self, "drop_fc_" + tag, nn.Dropout(0.5))

        # Bidirectional, hidden size 256 -> 512 features per time step.
        self.lstm_spo2 = LSTM(input_size=1, hidden_size=256, num_layers=1,
                              bidirectional=True, batch_first=True,
                              dropouto=0.1, dropouti=0.1)

        self.ln = nn.LayerNorm(1024)

        self.fc_5 = nn.Linear(1024, 60)

    def _conv_block(self, tag, inputs):
        """conv -> batch norm -> ReLU -> dropout for the block named `tag`."""
        conv = getattr(self, "conv_" + tag)
        norm = getattr(self, "bn_" + tag)
        drop = getattr(self, "drop_" + tag)
        return drop(F.relu(norm(conv(inputs))))

    def forward(self, x, y):
        """Compute the 60 output scores.

        x feeds the convolutional branch (conv_1 and skip expect a single channel);
        y feeds lstm_spo2 (batch-first, one feature per step).
        """
        residual = self.skip(x)
        block_input = x  # conv_1 reads the raw input; every later block reads the residual

        for tags, _dilations, _drop_p, pool in self._STAGES:
            # The last tag of each stage (the "b" block) is not applied.
            for tag in tags[:-1]:
                x = self._conv_block(tag, block_input)
                residual = residual.add(x)
                block_input = residual
            if pool is not None:
                residual = getattr(self, pool[0])(residual)
                block_input = residual

        # NOTE(review): the recurrent branch consumes `x` (the output of block 16),
        # not the accumulated residual, so the final residual addition is unused.
        x = x.permute(0, 2, 1)  # (batch, channels, length) -> (batch, length, channels)

        x, _ = self.lstm_1(x)
        x = x[:, -1, :]  # keep the last time step only

        y, _ = self.lstm_spo2(y)
        y = y[:, -1, :]

        # 256 features from lstm_1 + 512 from lstm_spo2 = 768.
        x = torch.cat((x, y), 1)
        x = x.view(-1, 768)

        x = self.drop_fc_1(F.relu(self.ln(self.fc_1(x))))

        return self.fc_5(x)

  
    
# Helper function that is used to initialize the weights of the model
def init_weights(m):
    """Initialise a single module; meant to be used with `model.apply(init_weights)`.

    Only nn.Linear and nn.Conv1d modules are touched: the weight gets Kaiming-normal
    initialisation (Xavier-uniform when "fc_5" appears in the module's string
    representation) and the bias, when the layer has one, is set to 0.
    """
    if not isinstance(m, (nn.Linear, nn.Conv1d)):
        return

    # NOTE(review): str(m) of a standard Linear/Conv1d shows its constructor
    # arguments, not the attribute name it is stored under, so this branch is
    # presumably never taken -- confirm how the output layer should be detected.
    if "fc_5" in str(m):
        nn.init.xavier_uniform_(m.weight)
    else:
        nn.init.kaiming_normal_(m.weight)

    # Layers built with bias=False (e.g. the convolutions in TCSConv1d) have no bias.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.)
    
    
# Helper function that is used to (re)define the model and the optimizer from scratch
# https://pytorch.org/docs/stable/notes/randomness.html
def reinit_model():
    """Seed torch, numpy and `random` with 42 and return a freshly constructed Net.

    The network keeps the initialisation performed by its own constructors;
    `init_weights` is not applied here.
    """
    seed = 42
    for set_seed in (torch.manual_seed, np.random.seed, random.seed):
        set_seed(seed)
    return Net()



# Build the seeded network once at notebook level and print its layer structure.
model = reinit_model()

print(model)


In [None]:
# Weighted hinge loss function
# input: specify the weight of class 1, with respect to class 0
# the idea is that the weight is 1 for instances of class 0, and w_1_class for instances of class 1
# inspired from: https://stackoverflow.com/questions/55754976/weighted-hinge-loss-function
def weightedSquaredHingeLoss(inp, tar, device, w_1_class=1):
    """Class-weighted squared hinge loss.

    Args:
        inp: raw model scores.
        tar: targets, broadcastable against `inp`. Positive targets are weighted by
            tar * (w_1_class - 1) + 1 (i.e. w_1_class for a target of +1); targets
            <= 0 get weight 1.
        device: unused; kept so existing callers that pass it keep working.
        w_1_class: weight of the positive class relative to the other class.

    Returns:
        A scalar tensor: the weighted squared hinge max(0, 1 - tar * inp)**2,
        averaged over the last dimension and summed over the remaining ones.
    """
    # clamp(min=0) is max(., 0) without allocating a zeros tensor on the host and
    # copying it to the device on every call; it also works for inputs of any rank.
    class_weight = torch.clamp(tar, min=0) * (w_1_class - 1) + 1
    squared_hinge = torch.clamp(1. - tar * inp, min=0) ** 2
    return torch.sum(torch.mean(class_weight * squared_hinge, dim=-1))


In [None]:
# Defining some helper functions for the model that guide the training
import math


def reduce_fn_avg(vals):
    """Reduction callback: arithmetic mean of the collected values."""
    total = sum(vals)
    count = len(vals)
    return total / count


def reduce_fn_sum(vals):
    """Reduction callback: total of the collected values (0 for an empty input)."""
    total = 0
    for value in vals:
        total = total + value
    return total




def _confusion_scores(counts):
    """Derive accuracy, precision, recall, specificity and F1 from confusion counts.

    `counts` is a mapping with numeric 'tp', 'tn', 'fp' and 'fn' entries. Precision
    and recall fall back to 0 when there are no true positives, specificity when
    there are no true negatives, and F1 when precision * recall is 0.
    """
    tp, tn, fp, fn = (counts[key] for key in ("tp", "tn", "fp", "fn"))
    precision = tp / (tp + fp) if tp > 0 else 0
    recall = tp / (tp + fn) if tp > 0 else 0
    specificity = tn / (tn + fp) if tn > 0 else 0
    f1 = 2*precision*recall / (precision + recall) if precision*recall > 0 else 0
    return {"accuracy": (tp + tn) / (tp + tn + fp + fn),
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "specificity": specificity}


def _append_epoch(history, prefix, raw, scores):
    """Append one epoch's loss, confusion counts and derived scores to `history`."""
    history[prefix + "loss"].append(raw["loss"])
    for key in ("tp", "tn", "fp", "fn"):
        history[prefix + key].append(raw[key])
    for key in ("accuracy", "precision", "recall", "f1", "specificity"):
        history[prefix + key].append(scores[key])


def _run(model, EPOCHS, param, training_data_in, validation_data_in=None):
    """Train `model` on the XLA device for EPOCHS epochs and save weights and metrics.

    Args:
        model: network called as model(ecg, spo2).
        EPOCHS: number of passes over the training data.
        param: unused.
        training_data_in: dataset yielding (ecg, spo2, labels) triples.
        validation_data_in: optional dataset with the same layout; when given, it
            is evaluated after every training epoch.

    Reads the module-level names BATCH_SIZE and imbalanced_ratio.
    Writes 'xm_seed' (RNG state carried between epochs), './nnet_model_physio.pt'
    (model state dict) and 'training_history' (per-epoch metrics) via xm.save.
    """
    xm.set_rng_state(42)
    xm.save(xm.get_rng_state(), 'xm_seed')

    def reduced_confusion(outputs, labels, one, zero, tag_prefix):
        """Binarise scores/labels at 0, count tp/tn/fp/fn and sum them over replicas."""
        predicted = torch.where(outputs.data > 0, one, zero)
        binary_labels = torch.where(labels > 0, one, zero)

        fp = ((predicted - binary_labels) == 1.).sum().item()
        fn = ((predicted - binary_labels) == -1.).sum().item()
        tp = ((predicted + binary_labels) == 2.).sum().item()
        tn = ((predicted + binary_labels) == 0.).sum().item()
        # Every replica must issue these reductions in the same order.
        fp = xm.mesh_reduce(tag_prefix + 'fp_reduce', fp, reduce_fn_sum)
        fn = xm.mesh_reduce(tag_prefix + 'fn_reduce', fn, reduce_fn_sum)
        tp = xm.mesh_reduce(tag_prefix + 'tp_reduce', tp, reduce_fn_sum)
        tn = xm.mesh_reduce(tag_prefix + 'tn_reduce', tn, reduce_fn_sum)
        return tp, tn, fp, fn

    def train_fn(train_dataloader, model, optimizer, criterion, device, lr_scheduler=None):
        """Run one training epoch; return the reduced loss and summed confusion counts."""
        # Resume the RNG stream where the previous epoch left it.
        xm.set_rng_state(torch.load('xm_seed'), device=device)
        xm.master_print(xm.get_rng_state())

        running_loss = 0.
        running_instances = 0.
        totals = {'tp': 0., 'tn': 0., 'fp': 0., 'fn': 0.}

        # Constant class indicators, moved to the device once per epoch instead of
        # once per batch.
        one = torch.from_numpy(np.asarray([1])).to(device)
        zero = torch.from_numpy(np.asarray([0])).to(device)

        # Switch dropout / batch-norm layers to training behaviour.
        model.train()

        for batch_idx, (ecg, spo2, labels) in enumerate(train_dataloader, 1):

            optimizer.zero_grad() # gradients accumulate unless cleared every step
            ecg = ecg.to(device)
            spo2 = spo2.to(device)
            labels = labels.to(device)

            outputs = model(ecg, spo2)

            loss = criterion(outputs, labels, device, imbalanced_ratio)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            xm.optimizer_step(optimizer) # update the network weights

            # Weight the batch loss by its size so the epoch value is per instance.
            running_loss += loss.item()*len(labels)
            running_instances += len(labels)

            tp, tn, fp, fn = reduced_confusion(outputs, labels, one, zero, '')
            totals['tp'] += tp
            totals['fp'] += fp
            totals['tn'] += tn
            totals['fn'] += fn

            if lr_scheduler is not None:
                lr_scheduler.step()

        running_loss /= running_instances
        loss_reduced = xm.mesh_reduce('loss_reduced', running_loss, reduce_fn_avg)

        xm.save(xm.get_rng_state(), 'xm_seed')

        return {'loss': loss_reduced,
                'tp': totals['tp'],
                'tn': totals['tn'],
                'fp': totals['fp'],
                'fn': totals['fn']}

    def valid_fn(valid_dataloader, model, criterion, device):
        """Evaluate one pass over the validation loader; same return layout as train_fn."""
        running_loss = 0.
        running_instances = 0.
        totals = {'tp': 0., 'tn': 0., 'fp': 0., 'fn': 0.}

        one = torch.from_numpy(np.asarray([1])).to(device)
        zero = torch.from_numpy(np.asarray([0])).to(device)

        # Switch dropout / batch-norm layers to inference behaviour.
        model.eval()

        for batch_idx, (ecg, spo2, labels) in enumerate(valid_dataloader, 1):

            ecg = ecg.to(device)
            spo2 = spo2.to(device)
            labels = labels.to(device)

            outputs = model(ecg, spo2)

            loss = criterion(outputs, labels, device, imbalanced_ratio)

            running_loss += loss.item()*len(labels)
            running_instances += len(labels)

            tp, tn, fp, fn = reduced_confusion(outputs, labels, one, zero, 'val_')
            totals['tp'] += tp
            totals['fp'] += fp
            totals['tn'] += tn
            totals['fn'] += fn

        running_loss /= running_instances
        loss_reduced = xm.mesh_reduce('loss_reduced', running_loss, reduce_fn_avg)

        return {'loss': loss_reduced,
                'tp': totals['tp'],
                'tn': totals['tn'],
                'fp': totals['fp'],
                'fn': totals['fn']}

    # Defining distributed samplers and data loaders
    train_sampler = torch.utils.data.distributed.DistributedSampler(training_data_in,
                                                                    num_replicas=xm.xrt_world_size(), #numcores
                                                                    rank=xm.get_ordinal(),
                                                                    shuffle=True)

    train_dataloader = torch.utils.data.DataLoader(training_data_in, batch_size=BATCH_SIZE, sampler=train_sampler, num_workers=1)

    if validation_data_in is not None:
        validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_data_in,
                                                                             num_replicas=xm.xrt_world_size(),
                                                                             rank=xm.get_ordinal(),
                                                                             shuffle=True)

        validation_dataloader = torch.utils.data.DataLoader(validation_data_in, batch_size=BATCH_SIZE, sampler=validation_sampler, num_workers=1)

    # Defining the handle to the TPU device
    device = xm.xla_device()
    # Transferring the model to the computing device
    model.to(device)
    # Defining the loss function
    criterion = weightedSquaredHingeLoss

    # Defining optimizer
    import torch.optim as optim
    optimizer = optim.Adam(model.parameters(), lr=3e-4, amsgrad=False, eps=1e-07)
    # The schedule length is derived from the dataset actually passed in rather
    # than from a notebook-level variable.
    # NOTE(review): the cycle is laid out for 200 epochs of
    # ceil(len(dataset) / BATCH_SIZE) steps regardless of EPOCHS and of the number
    # of replicas sharing the data -- confirm this is the intended schedule.
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-2, epochs=200, steps_per_epoch=math.ceil(len(training_data_in) / BATCH_SIZE), base_momentum=0.78, max_momentum=0.99)

    # Training code

    metric_names = ("loss", "accuracy", "precision", "recall", "f1", "specificity", "tp", "tn", "fp", "fn")
    metrics_history = {prefix + name: [] for prefix in ("", "val_") for name in metric_names}
    score_names = ("accuracy", "precision", "recall", "f1", "specificity")

    train_begin = time.time()
    for epoch in range(EPOCHS):
        start = time.time()
        # Without this the distributed sampler reuses the same shuffled order in
        # every epoch; it must be set before the loader iterator is created.
        train_sampler.set_epoch(epoch)
        para_loader = pl.ParallelLoader(train_dataloader, [device], fixed_batch_size=True) # needed for parallel training

        xm.master_print("EPOCH:", epoch+1)

        train_metrics = train_fn(train_dataloader=para_loader.per_device_loader(device),
                                 model=model,
                                 optimizer=optimizer,
                                 criterion=criterion,
                                 device=device,
                                 lr_scheduler=lr_scheduler)

        train_scores = _confusion_scores(train_metrics)
        _append_epoch(metrics_history, "", train_metrics, train_scores)

        if validation_data_in is not None:
            # Calculate the metrics on the validation data, in the same way as done for training
            with torch.no_grad(): # don't keep track of the info necessary to calculate the gradients
                para_loader = pl.ParallelLoader(validation_dataloader, [device], fixed_batch_size=True)

                val_metrics = valid_fn(valid_dataloader=para_loader.per_device_loader(device),
                                       model=model,
                                       criterion=criterion,
                                       device=device)

            val_scores = _confusion_scores(val_metrics)
            _append_epoch(metrics_history, "val_", val_metrics, val_scores)

            xm.master_print("  > Training/validation loss:", round(train_metrics['loss'], 4), round(val_metrics['loss'], 4))
            for name in score_names:
                xm.master_print("  > Training/validation " + name + ":", round(train_scores[name], 4), round(val_scores[name], 4))
            xm.master_print("  > TRAIN tp tn fp fn :", train_metrics["tp"], train_metrics["tn"], train_metrics["fp"], train_metrics["fn"])
            xm.master_print("  > VAL tp tn fp fn :", val_metrics["tp"], val_metrics["tn"], val_metrics["fp"], val_metrics["fn"])
        else:
            xm.master_print("  > Training loss:", round(train_metrics['loss'], 4))
            for name in score_names:
                xm.master_print("  > Training " + name + ":", round(train_scores[name], 4))
            xm.master_print("  > TRAIN tp tn fp fn :", train_metrics["tp"], train_metrics["tn"], train_metrics["fp"], train_metrics["fn"])

        xm.master_print("Completed in:", round(time.time() - start, 1), "seconds \n")

    xm.master_print("Training completed in:", round((time.time()- train_begin)/60, 1), "minutes")

    # Save the model weights
    xm.save(
        model.state_dict(), './nnet_model_physio.pt'
    )

    # Save the metrics history
    xm.save(metrics_history, 'training_history')
    
    
    

# Prediction function, it runs on the CPU
def predict_(model, data_in):
    """Load the saved weights into `model` and return its raw outputs for `data_in`.

    Args:
        model: network called as model(ecg, spo2); its parameters are overwritten
            with the contents of './nnet_model_physio.pt'.
        data_in: dataset yielding (ecg, spo2, labels) triples; labels are ignored.

    Returns:
        numpy array with one row of model outputs per instance, in dataset order.

    Reads the module-level name BATCH_SIZE.
    """
    import math

    data_dataloader = torch.utils.data.DataLoader(data_in, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)

    # Prediction runs on the CPU, so map the checkpoint there explicitly.
    model.load_state_dict(torch.load('./nnet_model_physio.pt', map_location='cpu'))

    model.eval()

    predictions = []
    n_batches = math.ceil(len(data_in)/BATCH_SIZE)
    with torch.no_grad(): # don't keep track of the info necessary to calculate the gradients
        # The context manager closes the progress bar even if a batch raises.
        with tqdm(desc="Minibatches: ", total=n_batches) as pbar:
            for ecg, spo2, labels in data_dataloader:
                outputs = model(ecg, spo2)
                predictions.extend(outputs.data.numpy())
                pbar.update(1)

    return np.asarray(predictions)
