## Import all library

In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import mne
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import pywt 
from PIL import Image
from utility import *
from model import *
from torchsummary import summary
import torch.optim.lr_scheduler as lr_scheduler


MODEL_FILE_DIRC_SateLight = MODEL_FILE_DIRC + "/SateLight"
MODEL_FILE_DIRC_CNN       = MODEL_FILE_DIRC + "/CNN"
os.makedirs(MODEL_FILE_DIRC_CNN, exist_ok=True)
os.makedirs(MODEL_FILE_DIRC_SateLight, exist_ok=True)

torch.manual_seed(3407)

<torch._C.Generator at 0x29beac13d30>

## Data Preparation
* Convert the data from csv to Dataloader

In [2]:
%%time
eeg_num_list = list(range(1,41))

dataloaders, num_data = get_dataloader(eeg_num_list, shuffle=True, get_dataDWT=False)
    

num_train_data, num_valid_data, num_test_data = num_data
train_data, valid_data, test_data             = dataloaders

The data from EEG_csv/eeg1.csv is loaded 
There is no spike in this eeg file
(252, 1280, 19)
EEG1 has 252 windows of data 


The data from EEG_csv/eeg2.csv is loaded 
There is no spike in this eeg file
(43, 1280, 19)
EEG2 has 43 windows of data 


The data from EEG_csv/eeg3.csv is loaded 
There is no spike in this eeg file
(88, 1280, 19)
EEG3 has 88 windows of data 


The data from EEG_csv/eeg4.csv is loaded 
There is no spike in this eeg file
(245, 1280, 19)
EEG4 has 245 windows of data 


The data from EEG_csv/eeg5.csv is loaded 
There is no spike in this eeg file
(236, 1280, 19)
EEG5 has 236 windows of data 


The data from EEG_csv/eeg6.csv is loaded 
There is spike in this eeg file
Data before split : (481920, 19)
Data with spike   : (465280, 19)
Data without spike: (16640, 19)
Data after  split into window: (376, 1280, 19)
Labels: (376,)
Num spike: 13
EEG6 has 376 windows of data 


The data from EEG_csv/eeg7.csv is loaded 
There is no spike in this eeg file
(311, 1280, 19)
EEG7 h

## Function to train the model

In [8]:
def start_classification_model_training(EPOCH_START, NUM_EPOCHS_CLASSIFIER, 
                                        model, MODEL_FILE_DIRC, model_name,  
                                        df, prev_best_valid_f1,prev_best_valid_loss,
                                        train_data, num_train_data, 
                                        valid_data, num_valid_data, 
                                        scheduler, optimizer, device):
    count = 0
    for epoch in range(EPOCH_START, NUM_EPOCHS_CLASSIFIER):
        
            ## 1. Training
            model.train()
            train_loss, train_metric = train_classifier(model, train_data, device, num_train_data, optimizer) 
            
            ## 2. Evaluating
            model.eval()
            valid_loss, valid_metric = evaluate_classifier(model, valid_data, device, num_valid_data) 
            
            ## 3. Show the result
            list_data       = [train_loss, valid_loss]
            for key in ["precision", "accuracy", "f1_score", "recall"]:
                list_data.append(train_metric[key])
                list_data.append(valid_metric[key])
            df.loc[len(df)] = list_data
            
            print_log(f"> > > Epoch     : {epoch}", MODEL_FILE_DIRC)
            print_log(f"Train {'loss':<10}: {train_loss}", MODEL_FILE_DIRC)
            print_log(f"Valid {'loss':<10}: {valid_loss}", MODEL_FILE_DIRC)
            for key in ["precision", "accuracy", "f1_score", "recall"]:
                print_log(f"Train {key:<10}: {train_metric[key]}", MODEL_FILE_DIRC)
                print_log(f"Valid {key:<10}: {valid_metric[key]}", MODEL_FILE_DIRC)

            
            ## 3.1 Plot the loss function
            fig,ax = plt.subplots(3,2, figsize=(10,10))
            x_data = range(len(df["Train Loss"]))
            for i, key in enumerate(["Loss","precision", "accuracy", "f1_score", "recall"]):    
                ax[i%3][i//3].plot(x_data, df[f"Train {key}"], label=f"Train {key}")
                ax[i%3][i//3].plot(x_data, df[f"Valid {key}"], label=f"Valid {key}")
                ax[i%3][i//3].legend()
            plt.savefig(f'{MODEL_FILE_DIRC}/Loss.png', transparent=False, facecolor='white')
            plt.close('all')

            ## 3.3. Save model and Stoping criteria
            if prev_best_valid_f1 <= valid_metric["f1_score"]:  # If previous best validation f1-score <= current f1-score 
                state_dict = {
                    "model": model.state_dict(), 
                    "epoch":epoch,
                    "valid_f1_score": valid_metric["f1_score"],
                    "valid_loss": valid_loss
                }
                torch.save(state_dict, f"{MODEL_FILE_DIRC}/{model_name}_best.pt")
                prev_best_valid_f1 = valid_metric["f1_score"]  # Previous validation loss = Current validation loss
                count = 0
            else:
                count += 1
            
            if epoch % 5 == 0:
                state_dict = {
                    "model": model.state_dict(), 
                    "epoch":epoch,
                    "valid_f1_score": valid_metric["f1_score"],
                    "valid_loss": valid_loss
                }
                torch.save(state_dict, f"{MODEL_FILE_DIRC}/{model_name}_{epoch}.pt")
            
            df.to_csv(f"{MODEL_FILE_DIRC}/Loss.csv", index=False)
            
            if count == MAX_COUNT_F1_SCORE:
                print_log(f"The validation f1 score is not increasing for continuous {MAX_COUNT_F1_SCORE} time, so training stop", MODEL_FILE_DIRC)
                break
            
            scheduler.step()

In [9]:
def load_classification_model_dict(model, MODEL_FILE_DIRC, model_name):
    list_model = os.listdir(MODEL_FILE_DIRC) 
    if len(list_model) > 0:    # Load the latest trained model
        if os.path.exists(f"{MODEL_FILE_DIRC}/{model_name}_best.pt"):
            state_dict_loaded    = torch.load(f"{MODEL_FILE_DIRC}/{model_name}_best.pt")
            prev_best_valid_f1   = state_dict_loaded["valid_f1_score"]
            prev_best_valid_loss = state_dict_loaded["valid_loss"]
        list_model.remove(f"{model_name}_best.pt")
        num_list   = [int(model_dir[model_dir.rindex("_") +1: model_dir.rindex(".")]) for model_dir in list_model if model_dir.endswith(".pt")]
        num_max    = np.max(num_list)
        
        state_dict_loaded = torch.load(f"{MODEL_FILE_DIRC}/{model_name}_{num_max}.pt")
        model.load_state_dict(state_dict_loaded["model"])
        EPOCH_START = state_dict_loaded["epoch"] + 1
        
        print(f"The model has been loaded from the file '{model_name}_{num_max}.pt'")

        if os.path.exists(f"{MODEL_FILE_DIRC}/Loss.csv"):
            df = pd.read_csv(f"{MODEL_FILE_DIRC}/Loss.csv")
            df = df.iloc[:EPOCH_START-1, :]
            print(f"The dataframe that record the loss have been loaded from {MODEL_FILE_DIRC}/Loss.csv")

    else:
        EPOCH_START            = 1
        prev_best_valid_f1     = -1
        prev_best_valid_loss   = 10000
        df                     = pd.DataFrame(columns = ["Train Loss", "Valid Loss"] + \
                                                         flatten_concatenation([[f"Train {metric}", f"Valid {metric}"] for metric in ["precision", "accuracy", "f1_score", "recall"]]) )
    return model, df, EPOCH_START, prev_best_valid_f1, prev_best_valid_loss

## Build and Train the SateLight model

### Build the SateLight model

In [9]:
model      = SateLight().to(device)
optimizer  = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Get the following information:
# 1. Previous Trained model (if exist)
# 2. df that store the training/validation loss & metrics
# 3. epoch where the training start
# 4. Previous Highest Validation F1-score
# 5. Previous Lowest  Validation Loss
model, df, EPOCH_START, prev_best_valid_f1, prev_best_valid_loss = load_classification_model_dict(model, MODEL_FILE_DIRC_SateLight, "SateLight")

# Get the summary of the model
print(summary(model, (19,1280)))

The model has been loaded from the file 'SateLight_25.pt'
The dataframe that record the loss have been loaded from Model/SateLight/Loss.csv
Layer (type:depth-idx)                        Output Shape              Param #
├─Sequential: 1-1                             [-1, 32, 1, 641]          --
|    └─Conv2d: 2-1                            [-1, 16, 19, 641]         10,256
|    └─Conv2d: 2-2                            [-1, 32, 1, 641]          640
├─BatchNorm1d: 1-2                            [-1, 32, 641]             64
├─ReLU: 1-3                                   [-1, 32, 641]             --
├─Dropout: 1-4                                [-1, 32, 641]             --
├─MaxPool1d: 1-5                              [-1, 32, 160]             --
├─ModuleList: 1                               []                        --
|    └─Sequential: 2-3                        [-1, 160, 32]             --
|    |    └─SelfAttention: 3-1                [-1, 160, 32]             7,392
|    |    └─BatchNorm1

### Print out the model info

In [10]:
seperate = "\n" + "-" * 100 + "\n"
print(seperate + "Model infomation" + seperate)
print(f"Device used        :", device)
print(f"BATCH SIZE         :", BATCH_SIZE)
print(f"MAX_COUNT_F1_SCORE :", MAX_COUNT_F1_SCORE)
print(f"LEARNING RATE      :", LEARNING_RATE)
print(f"Prev Best f1-score in validation dataset:", prev_best_valid_f1)
print(f"Prev Best validation loss               :", prev_best_valid_loss)
print(f"Number of EPOCH for training   :",NUM_EPOCHS_CLASSIFIER, f"(EPOCH start from {EPOCH_START})")
print(f"Num of epochs of data for train:", num_train_data)
print(f"Num of epochs of data for valid:", num_valid_data)
print(f'Model parameters               : {sum(p.numel() for p in model.parameters()):,}' )


----------------------------------------------------------------------------------------------------
Model infomation
----------------------------------------------------------------------------------------------------

Device used        : cuda
BATCH SIZE         : 32
MAX_COUNT_F1_SCORE : 20
LEARNING RATE      : 0.001
Prev Best f1-score in validation dataset: 0.625
Prev Best validation loss               : 0.07256671259143192
Number of EPOCH for training   : 101 (EPOCH start from 26)
Num of epochs of data for train: 6345
Num of epochs of data for valid: 1358
Model parameters               : 134,521


### Start Training Loop

In [12]:
%%time
start_classification_model_training(EPOCH_START, NUM_EPOCHS_CLASSIFIER, 
                                    model, MODEL_FILE_DIRC_SateLight, "SateLight",
                                    df, prev_best_valid_f1, prev_best_valid_loss, 
                                    train_data, num_train_data, 
                                    valid_data, num_valid_data, 
                                    scheduler, optimizer, device)

> > > Epoch     : 26
Train loss      : 0.01605148559438795
Valid loss      : 0.11045271051580063
Train precision : 0.8823529411764706
Valid precision : 1.0
Train accuracy  : 0.9982663514578408
Valid accuracy  : 0.9948453608247423
Train f1_score  : 0.8910891089108911
Valid f1_score  : 0.5333333333333333
Train recall    : 0.9
Valid recall    : 0.36363636363636365
> > > Epoch     : 27
Train loss      : 0.02467666392001191
Valid loss      : 0.12159415852657035
Train precision : 0.8148148148148148
Valid precision : 1.0
Train accuracy  : 0.997478329393223
Valid accuracy  : 0.9933726067746687
Train f1_score  : 0.8461538461538461
Valid f1_score  : 0.3076923076923077
Train recall    : 0.88
Valid recall    : 0.18181818181818182
> > > Epoch     : 28
Train loss      : 0.003504307770729208
Valid loss      : 0.19716880280761978
Train precision : 0.9038461538461539
Valid precision : 1.0
Train accuracy  : 0.9987391646966115
Valid accuracy  : 0.9933726067746687
Train f1_score  : 0.9215686274509803
Vali

### Test the model performance on testing dataset

In [13]:
# Load the best model and turn to evaluation mode
model.eval()
state_dict_loaded = torch.load(f"{MODEL_FILE_DIRC_SateLight}/SateLight_best.pt")
model.load_state_dict(state_dict_loaded["model"])

test_loss, test_metric = evaluate_classifier(model, test_data, device, num_test_data) 

print("Best model is at epoch:",state_dict_loaded["epoch"])
print("Metric on testing dataset:")
for key, value in test_metric.items():
    print(f"{key:<10}: {value:.4f}")

Best model is at epoch: 7
Metric on testing dataset:
precision : 0.4444
accuracy  : 0.9912
f1_score  : 0.4000
recall    : 0.3636


## Build and Train the Simple CNN model

### Build the CNN model

In [10]:
model      = CNN().to(device)
optimizer  = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  
scheduler  = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Get the following information:
# 1. Previous Trained model (if exist)
# 2. df that store the training/validation loss & metrics
# 3. epoch where the training start
# 4. Previous Highest Validation Recall
# 5. Previous Lowest  Validation Loss
model, df, EPOCH_START, prev_best_valid_f1, prev_best_valid_loss = load_classification_model_dict(model, MODEL_FILE_DIRC_CNN, "CNN")

# Get the summary of the model
print(summary(model, (19,1280)))

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 32, 19]              --
|    └─Conv1dWithInitialization: 2-1     [-1, 8, 1278]             --
|    |    └─Conv1d: 3-1                  [-1, 8, 1278]             464
|    └─BatchNorm1d: 2-2                  [-1, 8, 1278]             16
|    └─ReLU: 2-3                         [-1, 8, 1278]             --
|    └─MaxPool1d: 2-4                    [-1, 8, 319]              --
|    └─Conv1dWithInitialization: 2-5     [-1, 16, 317]             --
|    |    └─Conv1d: 3-2                  [-1, 16, 317]             400
|    └─BatchNorm1d: 2-6                  [-1, 16, 317]             32
|    └─ReLU: 2-7                         [-1, 16, 317]             --
|    └─MaxPool1d: 2-8                    [-1, 16, 79]              --
|    └─Conv1dWithInitialization: 2-9     [-1, 32, 77]              --
|    |    └─Conv1d: 3-3                  [-1, 32, 77]              1,568
|    └─Bat

### Print out the model info

In [11]:
seperate = "\n" + "-" * 100 + "\n"
print(seperate + "Model infomation" + seperate)
print(f"Device used        :", device)
print(f"BATCH SIZE         :", BATCH_SIZE)
print(f"MAX_COUNT_F1_SCORE :", MAX_COUNT_F1_SCORE)
print(f"LEARNING RATE      :", LEARNING_RATE)
print(f"Prev Best recall in validation dataset:", prev_best_valid_f1)
print(f"Prev Best validation loss             :", prev_best_valid_loss)
print(f"Number of EPOCH for training   :",NUM_EPOCHS_CLASSIFIER, f"(EPOCH start from {EPOCH_START})")
print(f"Num of epochs of data for train:", num_train_data)
print(f"Num of epochs of data for valid:", num_valid_data)
print(f'Model parameters               : {sum(p.numel() for p in model.parameters()):,}' )


----------------------------------------------------------------------------------------------------
Model infomation
----------------------------------------------------------------------------------------------------

Device used        : cuda
BATCH SIZE         : 32
MAX_COUNT_F1_SCORE : 20
LEARNING RATE      : 0.001
Prev Best recall in validation dataset: -1
Prev Best validation loss             : 10000
Number of EPOCH for training   : 101 (EPOCH start from 1)
Num of epochs of data for train: 6345
Num of epochs of data for valid: 1358
Model parameters               : 3,153


### Start Training Loop

In [12]:
%%time
start_classification_model_training(EPOCH_START, NUM_EPOCHS_CLASSIFIER, 
                                    model, MODEL_FILE_DIRC_CNN, "CNN",
                                    df, prev_best_valid_f1, prev_best_valid_loss, 
                                    train_data, num_train_data, 
                                    valid_data, num_valid_data, 
                                    scheduler, optimizer, device)

> > > Epoch     : 1
Train loss      : 0.1091184768549822
Valid loss      : 0.07955073616405695
Train precision : 0.01
Valid precision : 1.0
Train accuracy  : 0.9766745468873128
Valid accuracy  : 0.9926362297496318
Train f1_score  : 0.013333333333333334
Valid f1_score  : 0.16666666666666666
Train recall    : 0.02
Valid recall    : 0.09090909090909091
> > > Epoch     : 2
Train loss      : 0.059958299165929194
Valid loss      : 0.08894044124803867
Train precision : 0.5454545454545454
Valid precision : 1.0
Train accuracy  : 0.9922773837667455
Valid accuracy  : 0.9926362297496318
Train f1_score  : 0.19672131147540983
Valid f1_score  : 0.16666666666666666
Train recall    : 0.12
Valid recall    : 0.09090909090909091
> > > Epoch     : 3
Train loss      : 0.04167339899111172
Valid loss      : 0.06574114996409085
Train precision : 0.7391304347826086
Valid precision : 0.6666666666666666
Train accuracy  : 0.9938534278959811
Valid accuracy  : 0.9926362297496318
Train f1_score  : 0.4657534246575342


### Test the model performance on testing dataset

In [13]:
# Load the best model and turn to evaluation mode
model.eval()
state_dict_loaded = torch.load(f"{MODEL_FILE_DIRC_CNN}/CNN_best.pt")
model.load_state_dict(state_dict_loaded["model"])

test_loss, test_metric = evaluate_classifier(model, test_data, device, num_test_data) 

print("Best model is at epoch:",state_dict_loaded["epoch"])
print("Metric on testing dataset:")
for key, value in test_metric.items():
    print(f"{key:<10}: {value:.4f}")

Best model is at epoch: 6
Metric on testing dataset:
precision : 0.3333
accuracy  : 0.9867
f1_score  : 0.4375
recall    : 0.6364
