In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pywt
from pywt import wavedec
import os
import torch
from torch.autograd import Variable
from torch.nn import Parameter, Linear, ReLU, CrossEntropyLoss, Sequential, Conv2d, Conv1d, MaxPool2d, MaxPool1d,AvgPool2d, Module, Softmax, BatchNorm2d, Dropout2d, Sigmoid, BCEWithLogitsLoss, LeakyReLU, BatchNorm1d, PReLU, Dropout, BCELoss, LogSoftmax, NLLLoss
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, MultiplicativeLR
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import StepLR
from sklearn.utils import shuffle
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from scipy.signal import butter, lfilter, iirnotch, filtfilt, sosfilt
import shutil
from scipy import signal
import subprocess as sp
import time
# Reproducibility: seed NumPy's global RNG and PyTorch's RNG with one value.
seed = 3
np.random.seed(seed)
torch.manual_seed(seed)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

Device: cpu


In [None]:
from google.colab import drive
# Mount Google Drive (Colab only) so the preprocessed dataset arrays are reachable.
drive.mount('/content/drive')
# Dataset root on Drive (presumably CHB-MIT, given the folder name). The
# trailing slash matters: later cells build paths by plain string concatenation.
drive_path = '/content/drive/My Drive/chb_dataset/'

Mounted at /content/drive


In [None]:
def apply_filter(segment, fs = 256):
  """Remove mains interference and band-limit an EEG segment.

  Applies, in order, along the last axis:
    1. an IIR notch at 60 Hz (Q = 20),
    2. an IIR notch at 50 Hz (Q = 16),
    3. a 6th-order Butterworth band-pass, 0.5-40 Hz, as second-order sections.

  Every stage is causal (``lfilter`` / ``sosfilt``), so the output is
  phase-shifted relative to the input and begins with a filter transient.

  Args:
    segment: array of samples. Filtering runs along the last axis, so a
      ``(channels, samples)`` array is filtered channel by channel.
    fs: sampling rate of ``segment`` in Hz. Defaults to 256, the rate used
      throughout this notebook. If the segment was decimated, pass the
      *effective* rate; otherwise every cut-off lands at the wrong frequency.
      scipy rejects a notch or band edge at or above ``fs / 2`` with a
      ``ValueError``.

  Returns:
    Filtered array with the same shape as ``segment``.
  """
  nyquist = 0.5 * fs

  b, a = iirnotch(60, 20, fs)
  noise_removed = lfilter(b, a, segment)

  b, a = iirnotch(50, 16, fs)
  noise_removed = lfilter(b, a, noise_removed)

  # butter() without fs= expects band edges as fractions of the Nyquist rate.
  low = 0.5 / nyquist
  high = 40 / nyquist
  sos = butter(6, [low, high], btype = "bandpass", output = 'sos')
  return sosfilt(sos, noise_removed)

In [None]:
def list_to_data(data_list, label, window_size, overlap, downsample = 0):
    """Cut recordings into filtered, labelled, shuffled fixed-length windows.

    Args:
        data_list: iterable of 2-D arrays indexed as ``data[:, start:end]``,
            i.e. axis 1 is time. Assumed to be (channels, samples) at 256 Hz,
            since window lengths and ``apply_filter`` both use 256; TODO
            confirm against the preprocessing that wrote the .npy files.
        label: class label attached to every window produced.
        window_size: window length in seconds (``256 * window_size`` samples).
        overlap: number of hops per window. The hop is
            ``int(256 * window_size / overlap)`` samples, so consecutive
            windows share ``1 - 1/overlap`` of their samples.
        downsample: keep every ``downsample``-th sample of each filtered
            window; 0 or 1 means no decimation.

    Returns:
        ``(x, y)``: ``x`` stacks the windows on a new leading axis and ``y``
        holds the matching labels. Both are shuffled together with
        ``random_state=seed``.
    """
    x, y = list(), list()
    input_interval = 256 * window_size   # window length in samples
    inc = int(input_interval / overlap)  # hop between window starts
    for data in data_list:
        n_samples = data.shape[1]
        # `+ 1` keeps a window that ends exactly at n_samples; the old
        # `end < n_samples` test dropped the last complete window.
        for i in range(0, n_samples - input_interval + 1, inc):
            # Filter at the native rate *before* decimating. apply_filter's
            # notch and band-pass frequencies are designed for 256 Hz, so
            # filtering an already-decimated window shifted every cut-off
            # down by `downsample` (e.g. 40 Hz -> 10 Hz for 40 s windows).
            window = apply_filter(data[:, i : i + input_interval])
            if downsample and downsample > 1:
                # NOTE(review): for downsample >= 4 the new Nyquist frequency
                # (<= 32 Hz) is below the 40 Hz band edge, so some 32-40 Hz
                # content can still alias.
                window = window[:, ::downsample]
            x.append(window)
            y.append(label)

    x, y = np.array(x), np.array(y)
    x, y = shuffle(x, y, random_state = seed)
    return x, y

In [None]:
#new
def save_data(window_size, x, y, start_value = 0, del_dir = False, out_dir = 'my_data'):
  """Write each window to its own .npy file and index the files in a DataFrame.

  Files are named ``{out_dir}/x_{window_size}_{n}.npy``, with ``n`` counting up
  from ``start_value``. Two calls with the same ``window_size`` therefore need
  non-overlapping ``start_value`` ranges, or the later call overwrites files.

  Args:
    window_size: window length in seconds; used only in the file name.
    x: iterable of arrays, one per window.
    y: iterable of labels, parallel to ``x``.
    start_value: number used for the first file written.
    del_dir: if True, delete ``out_dir`` and everything in it first.
    out_dir: destination directory, created if missing. Defaults to 'my_data'.

  Returns:
    DataFrame with columns ``file`` (path written) and ``label``.
  """
  if del_dir:
    # A missing directory is fine. Any other failure (e.g. permissions) now
    # surfaces instead of being swallowed, together with KeyboardInterrupt,
    # by a bare `except:`; a failed wipe would leave stale windows mixed in.
    try:
      shutil.rmtree(out_dir)
    except FileNotFoundError:
      pass
  os.makedirs(out_dir, exist_ok = True)

  x_list, y_list = list(), list()
  for count, (x_temp, y_temp) in enumerate(zip(x, y), start = start_value):
    file_name = f"{out_dir}/x_{window_size}_{count}.npy"
    np.save(file_name, x_temp)
    x_list.append(file_name)
    y_list.append(y_temp)

  return pd.DataFrame({'file': x_list, 'label': y_list})
    

In [None]:
def create_list(pre_data, inter_data, window_size, overlap, downsample = 0, start_value = 0):
  """Window both classes, save the windows to disk and split them 80/10/10.

  Preictal windows are labelled 1 and interictal windows 0. The split points
  come from the number of *preictal* windows and apply to both classes.
  Train and validation are therefore class-balanced whenever there are at
  least as many interictal windows, and every surplus interictal window ends
  up in the test split.

  NOTE(review): ``list_to_data`` shuffles windows before this positional
  split, and consecutive windows overlap (hop = window / overlap). Nearly
  identical windows can therefore appear in both train and test, so test
  metrics are probably optimistic. Consider splitting by time or recording
  before windowing.

  Args:
    pre_data, inter_data: iterables of recordings passed to ``list_to_data``.
    window_size, overlap, downsample: forwarded to ``list_to_data``.
    start_value: number of the first file written. Interictal files continue
      immediately after the preictal ones.

  Returns:
    ``(train_list, val_list, test_list)``, DataFrames with ``file``/``label``.
  """
  pre_x, pre_y = list_to_data(pre_data, 1, window_size, overlap, downsample)
  pre_list = save_data(window_size, pre_x, pre_y, start_value)
  del pre_x, pre_y  # windows now live on disk; free the in-memory copies

  # Offset by start_value too. The old code started at len(pre_list), so a
  # non-zero start_value made interictal files overwrite preictal ones.
  inter_x, inter_y = list_to_data(inter_data, 0, window_size, overlap, downsample)
  inter_list = save_data(window_size, inter_x, inter_y, start_value + len(pre_list))
  del inter_x, inter_y

  train_size = int(0.8 * len(pre_list))
  val_size = int(0.9 * len(pre_list))

  train_list = pd.concat([pre_list[0:train_size], inter_list[0:train_size]])
  val_list = pd.concat([pre_list[train_size:val_size], inter_list[train_size:val_size]])
  test_list = pd.concat([pre_list[val_size:], inter_list[val_size:]])

  return train_list, val_list, test_list

In [None]:
patient_num = 6
data_path = drive_path + 'processed_data/'
result_path = drive_path + 'new_results/patient_'+str(patient_num)+'/'
PREICTAL_INTERVAL_IN_MINS = 20
WINDOW_SIZE = 10      # seconds per window (single-window run / first cont. size)
OVERLAP = 6           # hops per window: hop = window / OVERLAP
IS_DWT = True
IS_CONT = True        # continuous learning over several window sizes
WINDOW_SIZE_END = 40  # largest window size (s) when IS_CONT

# Every result artifact shares one stem and differs only in its extension,
# e.g. new_p6_dwt_20_cont_10_40_6.{pt,csv,txt,png}.
_window_tag = f"cont_{WINDOW_SIZE}_{WINDOW_SIZE_END}" if IS_CONT else str(WINDOW_SIZE)
_run_stem = f"new_p{patient_num}_{'dwt_' if IS_DWT else ''}{PREICTAL_INTERVAL_IN_MINS}_{_window_tag}_{OVERLAP}"
model_file = _run_stem + ".pt"
train_history_file = _run_stem + ".csv"
log_file = _run_stem + ".txt"
plot_file = _run_stem + ".png"

os.makedirs(result_path, exist_ok = True)

print(f"Model filename: {model_file}")
print(f"Train history file: {train_history_file}")
print(f"Log  file: {log_file}")
print(f"Plot file: {plot_file}")

Model filename: new_p6_dwt_20_cont_10_40_6.pt
Train history file: new_p6_dwt_20_cont_10_40_6.csv
Log  file: new_p6_dwt_20_cont_10_40_6.txt
Plot file: new_p6_dwt_20_cont_10_40_6.png


In [None]:
# Pre-computed preictal / interictal recordings for this patient and interval.
pre_name = f"patient_{patient_num}_pre_interval_{PREICTAL_INTERVAL_IN_MINS}.npy"
inter_name = f"patient_{patient_num}_inter_interval_{PREICTAL_INTERVAL_IN_MINS}.npy"

pre, inter = (np.load(data_path + name) for name in (pre_name, inter_name))


Network: fixed DWT feature extraction followed by trainable CNN branches (the DWT itself has no trainable parameters)

In [None]:
#rough
# Scratch shape check for a 4-layer Conv1d stack on a random (32, 18, 1289)
# input, then a check of concatenating four squeezed branch outputs.
# NOTE(review): this stack differs from CoeffNet's six layers (CoeffNet yields
# 299 features per branch, matching CustomDWT.linear_1's 4 * 299 = 1196).
# NOTE(review): creating these layers and calling torch.randn consumes the
# torch RNG, so running or skipping this cell changes the model's initial
# weights further down.
x = torch.randn((32, 18, 1289))
conv_1 = Conv1d(18, 12, kernel_size = 3, stride = 1)
conv_2 = Conv1d(12, 6, kernel_size = 3, stride = 2)
conv_3 = Conv1d(6, 3, kernel_size = 5, stride = 1)
conv_4 = Conv1d(3, 1, kernel_size = 5, stride = 2)
maxpool = MaxPool1d(2, stride = 1)


x = F.relu(conv_1(x))
x = conv_2(x)
x = maxpool(x)
print(f"After block one: {x.size()}")
x = conv_3(x)
x = conv_4(x)
#x = maxpool(x)
print(f"After block two: {x.size()}")
#x = conv_5(x)
#x = conv_6(x)
print(f"Final size: {x.size()}")


# Four stand-in branch outputs with the same shape, squeezed to (batch, length)
# and concatenated along the feature axis as CustomDWT.forward does.
cd1 = torch.randn(x.size()).squeeze(1)
cd2 = torch.randn(x.size()).squeeze(1)
cd3 = torch.randn(x.size()).squeeze(1)
cd4 = torch.randn(x.size()).squeeze(1)
print(cd1.size())
result = torch.cat([cd1, cd2, cd3, cd4], dim = 1)
print(result.size())

After block one: torch.Size([32, 6, 642])
After block two: torch.Size([32, 1, 317])
Final size: torch.Size([32, 1, 317])
torch.Size([32, 317])
torch.Size([32, 1268])


In [None]:
class CoeffNet(Module):
  """1-D CNN branch that turns one band of wavelet coefficients into a feature row.

  Six ``Conv1d`` layers shrink the channel count
  ``in_channels -> 15 -> 12 -> 9 -> 6 -> 3 -> 1``. Batch-norm and a stride-1
  max-pool (which shortens the sequence by one) follow the 2nd and 4th
  convolutions. Every convolution is followed by ReLU, so outputs are >= 0.

  For an input of length 1289 the sequence lengths are
  644 -> 321 -> 320 (pool) -> 316 -> 312 -> 311 (pool) -> 305 -> 299.

  Args:
    in_channels: number of input channels. Defaults to 18, matching the
      18-row windows used in this notebook.

  Shape:
    ``(batch, in_channels, length)`` -> ``(batch, 1, length_out)``.
  """

  def __init__(self, in_channels = 18):
    super(CoeffNet, self).__init__()

    # Registration order is unchanged, so state_dict keys and the RNG draws
    # used for parameter initialisation match earlier checkpoints.
    self.conv_1 = Conv1d(in_channels, 15, kernel_size = 3, stride = 2)
    self.conv_2 = Conv1d(15, 12, kernel_size = 3, stride = 2)
    self.conv_3 = Conv1d(12, 9, kernel_size = 5, stride = 1)
    self.conv_4 = Conv1d(9, 6, kernel_size = 5, stride = 1)
    self.conv_5 = Conv1d(6, 3, kernel_size = 7, stride = 1)
    self.conv_6 = Conv1d(3, 1, kernel_size = 7, stride = 1)
    self.maxpool = MaxPool1d(2, stride = 1)
    self.batchNorm_1 = BatchNorm1d(12)  # normalises conv_2's 12 channels
    self.batchNorm_2 = BatchNorm1d(6)   # normalises conv_4's 6 channels

  def forward(self, x):
    """Run the three conv blocks; returns a ``(batch, 1, length_out)`` tensor."""
    x = F.relu(self.conv_1(x))
    x = F.relu(self.conv_2(x))
    x = self.batchNorm_1(x)
    x = self.maxpool(x)

    x = F.relu(self.conv_3(x))
    x = F.relu(self.conv_4(x))
    x = self.batchNorm_2(x)
    x = self.maxpool(x)

    x = F.relu(self.conv_5(x))
    x = F.relu(self.conv_6(x))

    return x

In [None]:
# Classifier: fixed (non-trainable) db10 DWT front end + trainable CNN/MLP.
class CustomDWT(Module):
  """Two-class EEG classifier built on a level-4 ``db10`` wavelet decomposition.

  The DWT has no trainable parameters. It runs in NumPy on a detached copy of
  the input, so no gradient flows back into ``x``. ``wavedec`` returns
  ``[cA4, cD4, cD3, cD2, cD1]``; the approximation cA4 is dropped, and the four
  detail bands are zero-padded on the right to the longest band. Note that
  ``cd1_branch`` receives cD4 and ``cd4_branch`` receives cD1 (the branch names
  follow list position, not decomposition level). The branch outputs are
  concatenated and passed through a 4-layer MLP.

  ``linear_1`` expects 1196 = 4 * 299 features, which is what ``CoeffNet``
  emits for 1289-long bands. That presumably corresponds to 2560-sample
  windows; TODO confirm before changing the window length.

  Input: ``(batch, channels, samples)``. Returns ``(batch, 2)`` raw logits,
  as expected by ``CrossEntropyLoss``.
  """

  def __init__(self):
    super(CustomDWT, self).__init__()

    # Registration order is unchanged so saved state_dicts still load.
    self.batchNorm = BatchNorm1d(512)

    self.linear_1 = Linear(1196, 1024)
    self.linear_2 = Linear(1024, 512)
    self.linear_3 = Linear(512, 256)
    self.linear_4 = Linear(256, 2)

    self.cd1_branch = CoeffNet()
    self.cd2_branch = CoeffNet()
    self.cd3_branch = CoeffNet()
    self.cd4_branch = CoeffNet()

  def forward(self, x):
    # Fixed transform: computed in NumPy, outside autograd.
    signal_np = x.detach().cpu().numpy()

    # One batched call along the time axis replaces the per-sample/per-channel
    # Python loops. Each band is (batch, channels, len_k); skip cA4.
    details = wavedec(signal_np, 'db10', level = 4, axis = -1)[1:]
    max_length = max(band.shape[-1] for band in details)

    # Right-pad with zeros to a common length. Bands are cast to float64,
    # which the old concatenate-with-int-list padding also produced.
    padded = [
      np.pad(band.astype(np.float64, copy = False),
             [(0, 0)] * (band.ndim - 1) + [(0, max_length - band.shape[-1])])
      for band in details
    ]

    branches = (self.cd1_branch, self.cd2_branch, self.cd3_branch, self.cd4_branch)
    features = [
      branch(torch.from_numpy(band).to(device)).squeeze(1)
      for branch, band in zip(branches, padded)
    ]

    x = torch.cat(features, dim = 1)
    x = F.relu(self.linear_1(x))
    x = F.relu(self.linear_2(x))
    x = self.batchNorm(x)
    x = F.relu(self.linear_3(x))
    # No activation on the output. CrossEntropyLoss needs unbounded logits;
    # the former ReLU clamped negative logits to 0 and could stall learning.
    return self.linear_4(x)

In [None]:
#from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    A call counts as an improvement when ``-val_loss >= best_score + delta``
    (ties count when ``delta == 0``; a NaN loss never counts). Each improvement
    saves ``model.state_dict()`` to ``path``. ``early_stop`` becomes True after
    ``patience`` consecutive non-improving calls.

    Args:
        patience: non-improving calls tolerated before ``early_stop`` is set.
        verbose: print a message whenever a checkpoint is saved.
        delta: minimum decrease in validation loss that counts as improvement.
        path: checkpoint destination. Defaults to 'checkpoint.pt'.
    """
    def __init__(self, patience=7, verbose=False, delta= 0, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works in every version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        # Outcome of the most recent __call__, reported by is_best().
        self._last_val_loss = None
        self._last_improved = False

    def __call__(self, val_loss, model):
        """Record one validation loss and save a checkpoint if it improved."""
        score = -val_loss

        # `>=` (rather than negating `<`) makes a NaN score a non-improvement.
        improved = self.best_score is None or score >= self.best_score + self.delta
        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

        self._last_val_loss = val_loss
        self._last_improved = improved

    def is_best(self, val_loss):
        """Return True if ``val_loss`` produced the checkpoint currently saved.

        Meant to be called right after ``__call__`` with the same loss. It then
        reports exactly whether that call saved a checkpoint. The old version
        re-applied ``delta`` against the freshly updated best score, so it
        returned False after every improvement when ``delta > 0``. For any
        other loss it compares against the stored best score. If nothing has
        been recorded yet, it seeds ``best_score`` and returns True.
        """
        if self.best_score is None:
            self.best_score = -val_loss
            return True
        if self._last_val_loss is not None and val_loss == self._last_val_loss:
            return self._last_improved
        return -val_loss >= self.best_score + self.delta

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
# SGD settings; L2 is passed to the optimizer as weight_decay.
LEARNING_RATE, MOMENTUM, L2 = 1e-4, 0.6, 0.01

In [None]:
# Model in float64 on `device` (the filtered windows are presumably float64),
# cross-entropy loss on the same device, and SGD with momentum + weight decay.
model = CustomDWT().to(device).double()
criterion = CrossEntropyLoss().to(device)
optimizer = SGD(model.parameters(), lr = LEARNING_RATE, momentum = MOMENTUM, weight_decay = L2)

In [None]:
class PytorchDataGen(Dataset):
  """Map-style dataset that loads one saved window from disk per index.

  Args:
    data: DataFrame with a ``file`` column (paths to .npy arrays) and a
      parallel ``label`` column.
  """

  def __init__(self, data):
    self.file_list = data['file'].tolist()
    self.labels = data['label'].tolist()

  def __len__(self):
    return len(self.file_list)

  def __getitem__(self, index):
    # Lazy load: only the requested window is read into memory.
    return np.load(self.file_list[index]), self.labels[index]

In [None]:
# Start clean: delete window files and any checkpoint left by a previous run.
!rm -rf my_data
!rm -rf checkpoint.pt

In [None]:
# Decimate by window_size // 10 so a multiple-of-10 s window holds 2560 samples.
downsample_factor = WINDOW_SIZE // 10
train_list, val_list, test_list = create_list(pre, inter, WINDOW_SIZE, OVERLAP, downsample = downsample_factor)

In [None]:
# Sanity-check one saved window. `.iloc[0]` is positional: train_list's index
# repeats labels (it concatenates two frames), so `train_list['file'][0]` is a
# Series or a plain string depending on the data.
t = np.load(train_list['file'].iloc[0])
print(t.shape)

(18, 2560)


In [None]:
# ONLY FOR CONTINUOUS LEARNING: build the splits from several window sizes
# (10 s, 20 s, ..., WINDOW_SIZE_END s). Each size is decimated by
# window_size // 10, so every window keeps 256 * 10 = 2560 samples per channel.
# The IS_CONT guard keeps a non-continuous run on the single-window splits
# built earlier, instead of silently overwriting them here.
if IS_CONT:
  window_sizes = list(range(10, WINDOW_SIZE_END + 10, 10))
  assert window_sizes, "WINDOW_SIZE_END must be at least 10 for continuous learning"

  train_parts, val_parts, test_parts = list(), list(), list()
  for window_size in window_sizes:
    train_part, val_part, test_part = create_list(pre, inter, window_size, OVERLAP, downsample = (window_size // 10))
    train_parts.append(train_part)
    val_parts.append(val_part)
    test_parts.append(test_part)

  # Concatenate once instead of growing a frame inside the loop. The explicit
  # random_state makes the row order independent of how much of the global
  # NumPy RNG earlier cells consumed.
  train_list = pd.concat(train_parts, axis = 0).sample(frac = 1, random_state = seed).reset_index(drop = True)
  val_list = pd.concat(val_parts, axis = 0).sample(frac = 1, random_state = seed).reset_index(drop = True)
  test_list = pd.concat(test_parts, axis = 0).sample(frac = 1, random_state = seed).reset_index(drop = True)

In [None]:
# (rows, columns) of each split; the columns are `file` and `label`.
split_summary = f"Train list: {train_list.shape} Val list: {val_list.shape} Test list: {test_list.shape}"
print(split_summary)

Train list: (18942, 2) Val list: (2368, 2) Test list: (2370, 2)


In [None]:
batch_size = 32
# Shared DataLoader settings; only shuffling differs between train and validation.
train_params = dict(batch_size = batch_size, shuffle = True, num_workers = 2)
val_params = {**train_params, 'shuffle': False}

In [None]:
train_data = PytorchDataGen(train_list)
val_data = PytorchDataGen(val_list)

# Training keeps drop_last: BatchNorm1d in train mode rejects a final batch of
# one sample. Validation runs in eval mode, so it keeps every sample instead of
# silently dropping up to batch_size - 1 of them.
train_gen = DataLoader(train_data, **train_params, drop_last = True)
val_gen = DataLoader(val_data, **val_params, drop_last = False)

In [None]:
def calc_acc(y_pred, y_true):
  """Fraction of rows whose highest-probability class equals the target.

  Intended for outputs trained with CrossEntropyLoss.

  Args:
    y_pred: ``(batch, classes)`` score tensor; softmax is taken over dim 1.
    y_true: ``(batch,)`` tensor of integer class labels.

  Returns:
    Accuracy as a NumPy float in [0, 1].
  """
  probs = F.softmax(y_pred, dim = 1).detach().cpu().numpy()
  predicted = probs.argmax(axis = 1)
  targets = y_true.detach().cpu().numpy()
  n_correct = (predicted == targets).sum()
  return n_correct / predicted.shape[0]

Training

In [None]:
# Train with SGD, evaluate on the validation split after every epoch, and stop
# once validation loss has not improved for 20 consecutive epochs.
# EarlyStopping writes the best model's state_dict to 'checkpoint.pt'.
early_stopping = EarlyStopping(patience = 20)
EPOCHS = 100
# Per-epoch history (one entry per completed epoch); saved to CSV later.
train_loss_list, val_loss_list = list(), list()
train_acc_list, val_acc_list = list(), list()
# 0-based index of the epoch whose checkpoint was saved most recently.
early_stop_epoch = 0

# Metrics of that best epoch, printed after training.
train_best_loss = 0
val_best_loss = 0
train_best_acc = 0
val_best_acc = 0
lr_list = []  # NOTE(review): never filled; no LR scheduler is used in this loop.
start_time = time.time()
for epoch in range(EPOCHS):
  train_running_loss, train_running_acc = list(), list()
  val_running_loss, val_running_acc = list(), list()
  #training loop
  model.train()  
  for (x_train, y_train) in train_gen:
    x_train = x_train.to(device)
    y_train = y_train.to(device)

    optimizer.zero_grad()
    pred = model(x_train)  
    train_loss_func = criterion(pred, y_train)
    train_loss_func.backward()
    optimizer.step()
        
    # Accuracy of predictions made *before* this optimizer step.
    acc = calc_acc(pred, y_train)
    train_running_acc.append(acc)
    train_running_loss.append(train_loss_func.item())
    
  # Epoch metrics = mean of per-batch values.
  train_loss = np.average(train_running_loss)
  train_acc = np.average(train_running_acc)
  train_acc_list.append(train_acc)
  train_loss_list.append(train_loss)


  #validation loop
  # eval() makes BatchNorm use running statistics; no_grad skips autograd.
  with torch.no_grad():
    model.eval()
    for (x_val, y_val) in val_gen:
      x_val = x_val.to(device)
      y_val = y_val.to(device)
  
      pred = model(x_val)
      val_loss_func = criterion(pred, y_val)
      acc = calc_acc(pred, y_val)
      val_running_acc.append(acc)
      val_running_loss.append(val_loss_func.item())
    
    val_loss = np.average(val_running_loss)
    val_acc = np.average(val_running_acc)
    val_acc_list.append(val_acc)
    val_loss_list.append(val_loss)
  
    
  print(f"Epoch: {epoch + 1}/{EPOCHS}   Train Acc = {train_acc}  Train Loss = {train_loss}    Val Acc = {val_acc}  Val Loss = {val_loss}")  
  
  #Earlystopping
  # Must run before is_best(): __call__ updates best_score / saves the checkpoint.
  early_stopping(val_loss, model)

  # Remember the metrics of the epoch whose checkpoint was just saved.
  if early_stopping.is_best(val_loss):
    early_stop_epoch = epoch
    train_best_loss = train_loss
    train_best_acc = train_acc
    val_best_loss = val_loss
    val_best_acc = val_acc
  
  if early_stopping.early_stop:
    print("EARLY STOPPING")
    break

    
print(f"Train loss: {train_best_loss}")
print(f"Train acc: {train_best_acc}")
print(f"Val loss: {val_best_loss}")
print(f"Val acc: {val_best_acc}")
print(f"Time taken: {time.time() - start_time}")

Epoch: 1/100   Train Acc = 0.49746192893401014  Train Loss = 0.697828190049869    Val Acc = 0.5206925675675675  Val Loss = 0.6933083144989904
Epoch: 2/100   Train Acc = 0.5377009306260575  Train Loss = 0.6823935020244912    Val Acc = 0.5388513513513513  Val Loss = 0.6647295040611558
Epoch: 3/100   Train Acc = 0.5864530456852792  Train Loss = 0.6427568552915937    Val Acc = 0.6258445945945946  Val Loss = 0.6178596454894187
Epoch: 4/100   Train Acc = 0.6692047377326565  Train Loss = 0.6033156912273057    Val Acc = 0.7115709459459459  Val Loss = 0.5763694200096655
Epoch: 5/100   Train Acc = 0.7214467005076142  Train Loss = 0.5734840363684802    Val Acc = 0.7508445945945946  Val Loss = 0.5482316085347244
Epoch: 6/100   Train Acc = 0.7384200507614214  Train Loss = 0.5527318068671767    Val Acc = 0.7584459459459459  Val Loss = 0.5305489069291696
Epoch: 7/100   Train Acc = 0.7462986463620981  Train Loss = 0.540079623016015    Val Acc = 0.7668918918918919  Val Loss = 0.5199109560117997
Epoch: 

Testing and plotting

In [None]:
# Copy the best checkpoint (written by EarlyStopping) into the results folder.
# shutil raises on failure. The old `cp` + `mv` subprocess calls returned exit
# codes that nobody checked, so a failed copy went unnoticed.
shutil.copyfile('checkpoint.pt', result_path + model_file)

In [None]:
# One row per completed epoch; `early_stop` repeats the best epoch's index.
history_columns = {
  'train_loss': train_loss_list,
  'train_acc': train_acc_list,
  'val_loss': val_loss_list,
  'val_acc': val_acc_list,
  'early_stop': [early_stop_epoch] * len(val_acc_list),
}
result = pd.DataFrame(history_columns)
result.to_csv(result_path + train_history_file)

In [None]:
# Rebuild the architecture and load the best weights saved during training.
test_model = CustomDWT()
# map_location remaps tensors saved on another device (e.g. a GPU runtime)
# onto `device`; without it a CUDA checkpoint cannot be loaded on CPU.
test_model.load_state_dict(torch.load('checkpoint.pt', map_location = device))
test_model = test_model.to(device).double().eval()

In [None]:
test_params = {
    'shuffle':False,
    'batch_size': 1
}
test_data = PytorchDataGen(test_list)
test_gen = DataLoader(test_data, **test_params)

total_acc = []
total_loss = []
y_pred = []
y_true = []

# Inference only: no_grad stops autograd from recording a graph for each test
# batch, which the old loop did for no benefit.
with torch.no_grad():
  for x_test, y_test in test_gen:
    pred = test_model(x_test.to(device))
    y_test = y_test.to(device)
    loss = criterion(pred, y_test)
    acc = calc_acc(pred, y_test)

    total_loss.append(loss.item())
    total_acc.append(acc)

    # The model outputs logits, so the arg-max logit is the predicted class.
    # Taking exp() first adds nothing and can overflow for large logits.
    y_pred.append(pred.argmax(dim = 1)[0].item())
    y_true.append(y_test[0].item())

print(f"Average accuracy {np.average(total_acc)}")
print(f"Average loss {np.average(total_loss)}")

In [None]:
#Negative = interictal (0)
#Positive = preictal (1)
# labels=[0, 1] pins the matrix to 2x2 (rows = true class, columns = predicted).
# Without it the four-way unpacking fails whenever one class is absent from
# both y_true and y_pred.
TN, FP, FN, TP = confusion_matrix(y_true, y_pred, labels = [0, 1]).ravel()
sensitivity = TP/(TP + FN)
specificity = TN/(TN + FP)
fpr = FP/(FP + TN)
print(f"Sensitivity: {sensitivity}")
print(f"Specificity: {specificity}")
print(f"False Positive rate: {fpr}")

In [None]:
# Raw confusion-matrix counts, one per line.
confusion_lines = (
  f"True Positive: {TP}",
  f"True Negative: {TN}",
  f"False Positive: {FP}",
  f"False Negative: {FN}",
)
print(*confusion_lines, sep = "\n")

In [None]:
# Loss/accuracy curves for every completed epoch, with a dashed line at the
# epoch whose checkpoint was kept (early_stop_epoch, 0-based). The old slice
# [0:early_stop_epoch] dropped that best epoch itself, and plotted nothing when
# the first epoch was best. Its legend also listed an 'Early Stop' entry that
# was never drawn.
fig, ax = plt.subplots()
ax.plot(train_loss_list, color = 'black', label = 'Train loss')
ax.plot(val_loss_list, color = 'red', label = 'Val loss')
ax.plot(train_acc_list, color = 'blue', label = 'Train Acc')
ax.plot(val_acc_list, color = 'green', label = 'Val Acc')
ax.axvline(early_stop_epoch, color = 'gray', linestyle = '--', label = 'Early Stop')
ax.set(xlabel = 'Epoch (0-based)', ylabel = 'Loss / accuracy', title = plot_file)
ax.legend(loc = 'upper right')
fig.savefig(result_path + plot_file, bbox_inches = "tight")
plt.show()

In [None]:

# Persist the run configuration and test-set metrics as plain text.
log_path = result_path + log_file
log_text = f" Continuos Learning: {IS_CONT}\n Window Size Limit: {WINDOW_SIZE_END if IS_CONT else 0}\nLearning Rate: {LEARNING_RATE}\nMomentum: {MOMENTUM}\nL2: {L2}\nSensitivity: {sensitivity}\nSpecificity: {specificity}\nFPR: {fpr}\nTrue Positive: {TP}\nTrue Negative: {TN}\nFalse Positive: {FP}\nFalse Negative: {FN}"
with open(log_path, "w") as log_handle:
  print(log_text, file = log_handle)



## Old Code (superseded; running these cells redefines `CustomDWT`)

In [None]:
class DWTLayer(Module):
  """(Old code) Fixed db10 DWT front end with one trainable Linear per detail band.

  Each channel of each sample is decomposed with
  ``wavedec(channel, 'db10', level = 4)``, giving ``[cA4, cD4, cD3, cD2, cD1]``;
  cA4 is unused. Each detail band is projected to 512 features by its own
  Linear layer. The Linear input sizes (1289, 654, 336, 177) are hard-coded and
  must equal the band lengths, presumably for 2560-sample inputs; TODO confirm
  before reuse.

  Input ``(batch, channels, samples)`` -> output ``(batch, 4, channels, 512)``
  after ``BatchNorm2d(4)``. Along dim 1 the bands are ordered cD1, cD2, cD3, cD4.
  """

  def __init__(self):
    super(DWTLayer, self).__init__()

    self.linear_1 = Linear(1289, 512)  # cD1
    self.linear_2 = Linear(654, 512)   # cD2
    self.linear_3 = Linear(336, 512)   # cD3
    self.linear_4 = Linear(177, 512)   # cD4

    self.batch_norm = BatchNorm2d(4)

  def forward(self, x):
    samples = x.cpu().detach().numpy()
    # coeffs[i][ch] is the wavedec output list for sample i, channel ch.
    coeffs = [[wavedec(channel, 'db10', level = 4) for channel in sample] for sample in samples]

    def gather(position):
      # Stack one coefficient band over all samples/channels: (batch, channels, len).
      band = np.array([[per_channel[position] for per_channel in sample] for sample in coeffs])
      return torch.from_numpy(band).to(device)

    # wavedec position 4 holds cD1, position 1 holds cD4.
    cd_1, cd_2, cd_3, cd_4 = (gather(position) for position in (4, 3, 2, 1))

    stacked = torch.stack([
      self.linear_1(cd_1),
      self.linear_2(cd_2),
      self.linear_3(cd_3),
      self.linear_4(cd_4),
    ], dim = 1)
    return self.batch_norm(stacked)

In [None]:
# (Old code) DWTLayer + 2-D CNN classifier.
# NOTE(review): running this cell rebinds the name CustomDWT, shadowing the
# CoeffNet-based model defined earlier; checkpoints of one do not load into the other.
class CustomDWT(Module):
  """(Old code) DWT features treated as a 4-channel image, then a CNN and an MLP.

  Pipeline: ``DWTLayer`` -> two tanh Conv2d + BatchNorm2d + 2x2 max-pool ->
  two tanh Conv2d + BatchNorm2d -> flatten -> (Linear, BatchNorm1d, tanh) x 2
  -> Linear to 2 outputs. ``linear_1`` expects 624 flattened features.
  """

  def __init__(self):
    super(CustomDWT, self).__init__()
    self.conv_1 = Conv2d(4, 4, kernel_size = (3,3), stride = (1, 2))
    self.conv_2 = Conv2d(4, 4, kernel_size = (3,3), stride = (1, 2))
    self.conv_3 = Conv2d(4, 16, kernel_size = (5,5), stride = (1, 2), padding = (1,0))
    self.conv_4 = Conv2d(16, 16, kernel_size = (5,5), stride = (1, 2), padding = (1, 0))

    self.batch_norm_1 = BatchNorm2d(4)
    self.batch_norm_2 = BatchNorm2d(16)

    self.mpl = MaxPool2d(kernel_size = 2, stride = 2)

    self.linear_1 = Linear(624, 256)
    self.linear_2 = Linear(256, 128)
    self.linear_3 = Linear(128, 2)

    self.batch_norm_l1 = BatchNorm1d(256)
    self.batch_norm_l2 = BatchNorm1d(128)

    self.dwt_layer = DWTLayer()

  def forward(self, x):
    features = self.dwt_layer(x)

    # Conv block 1: two tanh convolutions, batch-norm, 2x2 max-pool.
    for conv in (self.conv_1, self.conv_2):
      features = torch.tanh(conv(features))
    features = self.mpl(self.batch_norm_1(features))

    # Conv block 2: two tanh convolutions, batch-norm (no pooling).
    for conv in (self.conv_3, self.conv_4):
      features = torch.tanh(conv(features))
    features = self.batch_norm_2(features)

    flat = features.view(features.size(0), -1)

    # Fully connected head: Linear -> BatchNorm1d -> tanh, twice, then logits.
    for linear, norm in ((self.linear_1, self.batch_norm_l1), (self.linear_2, self.batch_norm_l2)):
      flat = torch.tanh(norm(linear(flat)))
    return self.linear_3(flat)