# Train DSF e iESPnet

In [1]:
import sys
import os
import torch

import torchaudio.transforms    as T
import torch.optim              as optim
import pandas                   as pd

from torchvision       import transforms
from torch.utils.data  import DataLoader
from torch             import nn


sys.path.append(os.path.abspath(os.path.join('..','..','iESPnet_SRC_main','utilities')))
from Generator         import SeizureDatasetLabelTimev2, scale_spec, permute_spec, smoothing_label
from Model             import iESPnet
from TrainEval         import train_model_opt, test_model, train_model, get_thr_output, get_performance_indices
from IO                import get_spectrogram_2

sys.path.append(os.path.abspath(os.path.join('..','05-Train-Test')))
from utilit_train_test import make_weights_for_balanced_classes

sys.path.append(os.path.abspath(os.path.join('../../..','03 Dynamic-Spatial-Filtering')))
from models                         import DynamicSpatialFilter

In [2]:
# Seed torch's global random number generator with a fixed value.
torch.manual_seed(0)

<torch._C.Generator at 0x7372519fd0f0>

In [3]:
# Root directory of the data files and path of the metadata CSV.
# (The original note calls these spectrograms; the batches loaded further down
# have shape (64, 4, 22500).)
# NOTE(review): both are absolute, machine-specific paths — adjust before running elsewhere.

SPE_DIR        = '/media/martin/Disco2/Rns_Data/PITT_PI_EEG/'
meta_data_file = '/media/martin/Disco2/Rns_Data/PITT_PI_EEG/METADATA/allfiles_metadata.csv'

df_meta        = pd.read_csv(meta_data_file)

#### Exploración dataframe

In [4]:
# dataframe exploración

#df_meta.info()

In [5]:
#df_meta.shape

In [6]:
#len(df_meta.rns_id.unique())

In [7]:
# cantidad de datos por pacientes

#df_meta.rns_id.value_counts().head()

In [8]:
# cantidad de datos por pacientes

#df_meta.rns_id.value_counts().tail()

In [9]:
#df_meta[df_meta['rns_id' ]== 'PIT-RNS1713']['label'].value_counts()

## Continuación train

In [10]:
# iESPnet training settings

FREQ_MASK_PARAM = 10   # passed to T.FrequencyMasking when transform_train is built
TIME_MASK_PARAN = 20   # passed to T.TimeMasking; NOTE(review): spelled "PARAN", and referenced with that spelling
N_CLASSES       = 1
learning_rate   = 1e-3
batch_size      = 64 #128
epochs          = 20
num_workers     = 4


save_path       = 'models_DSF_iESPnet/'   # root of the per-patient output folders
patients        = df_meta['rns_id'].unique().tolist()   # distinct patient ids found in the metadata

In [11]:
# Dynamic Spatial Filter (DSF) settings.
# mlp_input, dsf_soft_thresh, dsf_n_out_channels and n_channels are passed to
# DynamicSpatialFilter when model1 is built; denoising, model and dsf_type are not
# read by any of the cells visible in this notebook section.

denoising          = 'autoreject'   # 'autoreject' 'data_augm' 
model              = 'stager_net'
dsf_type           = 'dsfd'         # 'dsfd' 'dsfm_st'
mlp_input          = 'log_diag_cov'
dsf_soft_thresh    = False
dsf_n_out_channels = None
n_channels         = 4

In [12]:
# DynamicSpatialFilter hyperparameters.
# The dict below is wrapped in a string literal, so it is never executed.

'''

hparams1 = {
           "n_channels"        : 4,
           "mlp_input"         : mlp_input,
           "n_out_channels"    : dsf_n_out_channels,
           "apply_soft_thresh" : dsf_soft_thresh,
          }

'''

'\n\nhparams1 = {\n           "n_channels"        : 4,\n           "mlp_input"         : mlp_input,\n           "n_out_channels"    : dsf_n_out_channels,\n           "apply_soft_thresh" : dsf_soft_thresh,\n          }\n\n'

In [13]:
# iESPnet hyperparameters: architecture entries first, training-loop entries last.

hparams = dict(
    n_cnn_layers=3,
    n_rnn_layers=3,
    rnn_dim=[150, 100, 50],
    n_class=N_CLASSES,
    out_ch=[8, 8, 16],
    dropout=0.3,
    learning_rate=learning_rate,
    batch_size=batch_size,
    num_workers=num_workers,
    epochs=epochs,
)

In [14]:
# Example for a single patient: s indexes into `patients`; s = 0 --- patient = PIT-RNS1603

s = 0

In [15]:
# model1: dynamic spatial filter, configured from the DSF settings defined above.
model1 = DynamicSpatialFilter(
    n_channels,
    mlp_input=mlp_input,
    n_out_channels=dsf_n_out_channels,
    apply_soft_thresh=dsf_soft_thresh,
)

# model2: iESPnet. Its constructor receives these hparams positionally, in this order.
model2 = iESPnet(*[
    hparams[key]
    for key in ('n_cnn_layers', 'n_rnn_layers', 'rnn_dim', 'n_class', 'out_ch', 'dropout')
])

In [16]:
# Per-patient output folders, all rooted at save_path/<patient id>/.
# Each path keeps its trailing '/' because later cells build file names by plain
# string concatenation (e.g. save_models + 'model').
save_runs        = save_path + patients[s] + '/runs/'
save_models      = save_path + patients[s] + '/models/'
save_predictions = save_path + patients[s] + '/results/'
save_figs        = save_path + patients[s] + '/figs/'

# exist_ok=True makes the cell safe to re-run and removes the window between
# "check that it exists" and "create it" that an exists()/makedirs() pair leaves open.
for out_dir in (save_path, save_runs, save_models, save_predictions, save_figs):
    os.makedirs(out_dir, exist_ok=True)

print('Running training for subject ' + patients[s] + ' [s]: ' + str(s))

Running training for subject PIT-RNS1603 [s]: 0


In [17]:
# Leave-one-patient-out split of df_meta: every row of patient `patients[s]` goes to the
# test set, every other row to the training set.
#
# Built from a single boolean mask with non-mutating calls, so no in-place operation is
# applied to a filtered slice of df_meta (which triggers pandas' SettingWithCopyWarning)
# and the cell gives the same result when re-run.

is_test_patient = df_meta['rns_id'] == patients[s]

# train_df keeps df_meta's original index labels (it is not re-indexed).
train_df = df_meta[~is_test_patient].copy()
# test_df is re-indexed from 0.
test_df  = df_meta[is_test_patient].reset_index(drop=True)

In [18]:
# Training dataset (the DataLoader that wraps it is created in a later cell).
# No input transform is applied here; labels go through smoothing_label().

train_data = SeizureDatasetLabelTimev2(
                                       file=train_df,
                                       root_dir=SPE_DIR,
                                       transform=None, 
                                       target_transform=smoothing_label(),
                                      )

In [19]:
# Augmentation pipeline used on the training spectrograms:
# frequency masking, then time masking, then permute_spec().
transform_train = transforms.Compose(
    [
        T.FrequencyMasking(freq_mask_param=FREQ_MASK_PARAM),
        T.TimeMasking(time_mask_param=TIME_MASK_PARAN),
        permute_spec(),
    ]
)

Acá está el cambio

In [20]:
# Data augmentation only on the training data.
# Disabled: the code below sits inside a string literal and never runs. It refers to
# transform_train1, train_data_ori and train_data_trf1, none of which are defined in
# the cells above.

'''

train_data_trf = SeizureDatasetLabelTimev2(
                                            file=train_df,
                                            root_dir=SPE_DIR,
                                            transform=transform_train1, 
                                            target_transform=smoothing_label() 
                                           )

train_data = torch.utils.data.ConcatDataset([train_data_ori, train_data_trf1])

'''

'\n\ntrain_data_trf = SeizureDatasetLabelTimev2(\n                                            file=train_df,\n                                            root_dir=SPE_DIR,\n                                            transform=transform_train1, \n                                            target_transform=smoothing_label() \n                                           )\n\ntrain_data = torch.utils.data.ConcatDataset([train_data_ori, train_data_trf1])\n\n'

Hasta acá

In [21]:
# Test data is NOT re-balanced: it is used "as it is" (no sampler is built for it),
# with no input transform and the same label smoothing as the training set.

test_data = SeizureDatasetLabelTimev2(
                                      file=test_df,
                                      root_dir=SPE_DIR,
                                      transform=None,
                                      target_transform=smoothing_label()  
                                     )

In [22]:
# train_df has to be balanced: compute per-sample weights for classes [0, 1] and feed them
# to a weighted random sampler that draws len(weights) samples.
weights = make_weights_for_balanced_classes(train_df, [0,1], n_concat=1)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

In [23]:
# Path prefix inside the per-patient models folder (presumably for saved checkpoints — confirm where it is used).
outputfile = save_models + 'model'

##### Función make_weights_for_balanced_classes

In [24]:
# import numpy as np

#classes = [0, 1]

In [25]:
#class_sample_count = np.zeros(len(classes,), dtype=int)

#for n, cl in enumerate(classes):
#    class_sample_count[n] = sum(train_df.label==cl)
#class_sample_count

In [26]:
#weights = (1 / class_sample_count)
#weights

In [27]:
#target = train_df.label.to_numpy()
#target

In [28]:
#samples_weight = weights[target]
#samples_weight

In [29]:
#n_concat = 2
#for i in range(n_concat):
#    if i == 0:
#        sampler = samples_weight
#        print(sampler.shape)
#    else:
#        sampler = np.hstack((sampler, samples_weight))
#        print(sampler.shape)

### Función train_model_opt

In [30]:
# Histories to be filled during training: one list for losses, one for accuracy-type scores.
avg_train_losses, avg_train_accs = [], []

In [31]:
# Prefer the GPU whenever CUDA is available, otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('Using {} device'.format(device))

Using cuda device


In [32]:
# Following the PyTorch suggestion to speed up training: enable cuDNN benchmark mode.

torch.backends.cudnn.benchmark = True

In [33]:
# variables o atributos de la funcion train_model_opt

# model1, model2, hparams, epochs, train_data, transform_train, sampler, save_path

# model1: DSF
# model2: iESPnet
# hparams: 
# epochs:
# train_data: dataloader de iEEG con smoothing label
# transform_train: transformación (augmentation) de espectrogramas 
# sampler:
# save_path:


In [34]:
'''El dataloader es un iterador y se debe interpretar como tal, para acceder se debe utilizar o enumerate o next'''

'El dataloader es un iterador y se debe interpretar como tal, para acceder se debe utilizar o enumerate o next'

In [35]:
# Worker processes and pinned memory are only requested when CUDA is in use.
if use_cuda:
    kwargs = {'num_workers': hparams["num_workers"], 'pin_memory': True}
else:
    kwargs = {}

# Batches are drawn through the weighted sampler built from train_df.
train_loader = DataLoader(
    train_data,
    batch_size=hparams["batch_size"],
    sampler=sampler,
    **kwargs,
)

In [36]:
#move model1 to device

model1.to(device)

DynamicSpatialFilter(
  (feat_extractor): SpatialFeatureExtractor()
  (mlp): Sequential(
    (0): Linear(in_features=4, out_features=4, bias=True)
    (1): ReLU()
    (2): Linear(in_features=4, out_features=20, bias=True)
  )
)

In [37]:
# move model2 (iESPnet) to device

model2.to(device)

iESPnet(
  (freqcnn): Sequential(
    (0): Conv2d(4, 8, kernel_size=(120, 1), stride=(1, 1), padding=(119, 0), dilation=(2, 1), bias=False)
    (1): ReLU()
    (2): InstanceNorm2d(8, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  )
  (timecnn): Sequential(
    (0): Conv2d(4, 8, kernel_size=(1, 181), stride=(1, 1), padding=(0, 180), dilation=(1, 2), bias=False)
    (1): ReLU()
    (2): InstanceNorm2d(8, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  )
  (cnn_ori): Conv2d(4, 16, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=False)
  (cnn): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=False)
  (rescnn_layers): Sequential(
    (0): ResidualCNNbatch(
      (cnn1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (cnn2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (dropout1): Dropout(p=0.3, inplace=False)
      (dropout2): Dropout(p=0.3, inpla

In [38]:
# Total number of elements across all parameter tensors of model1.
n_params_model1 = sum(weight.nelement() for weight in model1.parameters())
print('Num Model Parameters', n_params_model1)

Num Model Parameters 120


In [39]:
# Total number of elements across all parameter tensors of model2.
n_params_model2 = sum(weight.nelement() for weight in model2.parameters())
print('Num Model Parameters', n_params_model2)

Num Model Parameters 1654837


In [40]:
# One AdamW optimizer per network, both with the same learning rate and weight decay.
optimizer1 = optim.AdamW(
    model1.parameters(),
    lr=hparams['learning_rate'],
    weight_decay=1e-4,
)

optimizer2 = optim.AdamW(
    model2.parameters(),
    lr=hparams['learning_rate'],
    weight_decay=1e-4,
)

In [41]:
# One-cycle learning-rate schedule for optimizer1, sized for one step per training batch
# over all epochs, with linear annealing.
scheduler1 = optim.lr_scheduler.OneCycleLR(
    optimizer1,
    max_lr=hparams['learning_rate'],
    steps_per_epoch=len(train_loader),
    epochs=hparams['epochs'],
    anneal_strategy='linear',
)

In [42]:
# One-cycle learning-rate schedule for optimizer2, sized for one step per training batch
# over all epochs, with linear annealing.
scheduler2 = optim.lr_scheduler.OneCycleLR(
    optimizer2,
    max_lr=hparams['learning_rate'],
    steps_per_epoch=len(train_loader),
    epochs=hparams['epochs'],
    anneal_strategy='linear',
)

In [43]:
# Binary cross-entropy that takes raw logits; model2's outputs are passed to it directly
# (the sigmoid used for the prediction scores is applied separately).
criterion = nn.BCEWithLogitsLoss().to(device)

In [44]:
'''

for epoch in range(1, epochs + 1):
    # agregar todos los argumentos necesarios
    train_losses, train_aucpr = training_DSF_iESPnet(model, device, train_loader, transform_train, criterion, optimizer, scheduler, epoch)
    
    train_loss = np.average(train_losses)

    avg_train_losses.append(train_loss)
    
    avg_train_accs.append(train_aucpr)
    
'''

'\n\nfor epoch in range(1, epochs + 1):\n    # agregar todos los argumentos necesarios\n    train_losses, train_aucpr = training_DSF_iESPnet(model, device, train_loader, transform_train, criterion, optimizer, scheduler, epoch)\n    \n    train_loss = np.average(train_losses)\n\n    avg_train_losses.append(train_loss)\n    \n    avg_train_accs.append(train_aucpr)\n    \n'

##### Función training_DSF_iESPnet

In [45]:
# variables o atributos de la funcion training_DSF_iESPnet

# model1, model2, device, train_loader, transform_train, criterion, optimizer, scheduler, epoch

In [46]:
# State for the manual walk-through of one training epoch:
# running loss, per-batch loss list, and number of batches seen so far.
train_loss, train_losses, cont = 0.0, [], 0

In [47]:
model1.train()

DynamicSpatialFilter(
  (feat_extractor): SpatialFeatureExtractor()
  (mlp): Sequential(
    (0): Linear(in_features=4, out_features=4, bias=True)
    (1): ReLU()
    (2): Linear(in_features=4, out_features=20, bias=True)
  )
)

In [48]:
model2.train()

iESPnet(
  (freqcnn): Sequential(
    (0): Conv2d(4, 8, kernel_size=(120, 1), stride=(1, 1), padding=(119, 0), dilation=(2, 1), bias=False)
    (1): ReLU()
    (2): InstanceNorm2d(8, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  )
  (timecnn): Sequential(
    (0): Conv2d(4, 8, kernel_size=(1, 181), stride=(1, 1), padding=(0, 180), dilation=(1, 2), bias=False)
    (1): ReLU()
    (2): InstanceNorm2d(8, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  )
  (cnn_ori): Conv2d(4, 16, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=False)
  (cnn): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=False)
  (rescnn_layers): Sequential(
    (0): ResidualCNNbatch(
      (cnn1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (cnn2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (dropout1): Dropout(p=0.3, inplace=False)
      (dropout2): Dropout(p=0.3, inpla

In [49]:
# Pull a single batch out of the loader (the loop breaks after its first iteration)
# so that one training step can be walked through by hand.
for batch_idx, _data in enumerate(train_loader):

    # cont counts the batches taken; the accumulation cell further down checks cont == 1.
    # NOTE: re-running this cell increments cont again.
    cont+=1
    eeg, labels = _data
    break

In [50]:
eeg.shape

torch.Size([64, 4, 22500])

In [51]:
labels.shape

torch.Size([64, 181])

In [52]:
eeg, labels = eeg.to(device), labels.to(device)

In [53]:
optimizer1.zero_grad(set_to_none=True)
optimizer2.zero_grad(set_to_none=True)

In [54]:
outputs1 = model1(eeg)

---------------------------------------------
DynamicSpatialFilter input shape: torch.Size([64, 4, 22500])
After reshaping: torch.Size([64, 1, 4, 22500])
SpatialFeatureExtractor input shape: torch.Size([64, 1, 4, 22500])
SpatialFeatureExtractor output shape: torch.Size([64, 1, 4])
MLP output shape: torch.Size([64, 1, 20])
W shape: torch.Size([64, 1, 4, 4])
Bias shape: torch.Size([64, 1, 4, 1])
Output shape: torch.Size([64, 1, 4, 22500])


In [55]:
# Drop the singleton dimension at index 1 so the DSF output has the layout needed
# for the spectrogram computation: (64, 1, 4, 22500) -> (64, 4, 22500).

outputs1 = outputs1.squeeze(1)
print(outputs1.shape)

torch.Size([64, 4, 22500])


In [56]:
# Bring the DSF output back to the CPU before it is handed to get_spectrogram_2.
outputs1 = outputs1.to('cpu')

In [57]:
# Settings used to build the spectrograms.
# get_spectrogram_2 is called below with: signal, sample rate, n_fft, window length,
# hop length, top_db.

ECOG_SAMPLE_RATE = 250
ECOG_CHANNELS = 4

TT = 1000          # window length (converted to samples via the /1000 below)
overlap = 500      # overlap between consecutive windows, same unit as TT
SPEC_NFFT = 500    # to see changes in 0.5 reso
top_db = 40.0

# Window size and hop between windows, both expressed in samples.
SPEC_WIN_LEN = int(TT * ECOG_SAMPLE_RATE / 1000)
SPEC_HOP_LEN = int((TT - overlap) * ECOG_SAMPLE_RATE / 1000)

In [58]:
spectrograms = get_spectrogram_2(outputs1, ECOG_SAMPLE_RATE, SPEC_NFFT, SPEC_WIN_LEN, SPEC_HOP_LEN, top_db)

In [59]:
print(spectrograms.shape)

(64, 4, 120, 181)


In [60]:
# get_spectrogram_2 returned a NumPy array (its .shape printed as a plain tuple above),
# so wrap it as a torch tensor again.
# NOTE(review): a tensor created with torch.from_numpy carries no autograd history, so a
# loss computed from these spectrograms cannot back-propagate into model1 (the DSF).
# Confirm whether model1 is meant to be trained through this path.
spectrograms = torch.from_numpy(spectrograms)
print(spectrograms.shape)

torch.Size([64, 4, 120, 181])


In [61]:
# visualización de espectrogramas

In [62]:
spectrograms_transformed =  transform_train(spectrograms)

In [63]:
print(spectrograms_transformed.shape)

torch.Size([64, 4, 120, 181])


In [64]:
# Stack the original and the augmented spectrograms along the batch dimension (dim 0),
# so the batch fed to model2 is twice the loader's batch size.
spectrograms2train = torch.cat([spectrograms, spectrograms_transformed], dim=0)

In [65]:
spectrograms2train = spectrograms2train.to(device)

In [66]:
# Duplicate the labels along dim 0 so they line up with the original + augmented batch.
labels2train = torch.cat([labels, labels], dim=0)

In [67]:
output2 = model2(spectrograms2train)

In [68]:
# Turn model2's logits into probabilities for metric bookkeeping only.
# The logits are detached first: y_pred is appended to Y_pred batch after batch, and
# without the detach each stored score would keep its whole autograd graph alive
# (growing memory) and could not be converted to NumPy for metric computation.
# The loss is still computed from the non-detached output2, so training is unaffected.
m = nn.Sigmoid()
probs = m(output2.detach())

# Per-sample summary: the maximum over dim 1 of the label and of the
# predicted probabilities. torch.max(..., dim=1) returns (values, indices); [0] keeps the values.
y_true  = torch.max(labels2train, dim=1)[0]
y_pred  = torch.max(probs, dim=1)[0]

In [69]:
# Running collection of per-sample targets and scores: the first batch (cont == 1)
# starts the tensors, later batches are appended along dim 0.
first_batch = (cont == 1)
Y_true = y_true if first_batch else torch.cat((Y_true, y_true), dim=0)
Y_pred = y_pred if first_batch else torch.cat((Y_pred, y_pred), dim=0)

In [70]:
# Compute loss on model2's raw outputs (criterion is BCEWithLogitsLoss, so no sigmoid here)
loss = criterion(output2, labels2train)
# Perform backward pass
loss.backward()
# Keep a running sum of the batch losses as a plain Python float
train_loss += loss.item()
# Perform optimization
# NOTE(review): the spectrograms fed to model2 were rebuilt with torch.from_numpy, so
# verify that model1's parameters actually have gradients at this point
# (e.g. any(p.grad is not None for p in model1.parameters())).
optimizer1.step()
optimizer2.step()
# Schedulers advance once per batch, after the optimizers; this matches the
# steps_per_epoch=len(train_loader) they were configured with.
scheduler1.step()
scheduler2.step()