## Import all required libraries

In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import mne
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import pywt 
from PIL import Image
from datetime import datetime, timezone
from utility import *
from model import *

## Load the classification model

In [2]:
# Rebuild the SateLight architecture and restore the best checkpoint from training.
# NOTE(review): `SateLight`, `device` and `MODEL_FILE_DIRC` are not defined in this notebook;
# presumably they come from the `from utility import *` / `from model import *` star imports.
model = SateLight().to(device)
checkpoint_path = os.path.join(MODEL_FILE_DIRC, "SateLight_synthesized", "SateLight_best.pt")
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
state_dict_loaded = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict_loaded["model"])
# eval() switches BatchNorm/Dropout to inference behaviour; as the last expression it also
# displays the model summary below the cell.
model.eval()

SateLight(
  (two_con2D): Sequential(
    (0): Conv2d(1, 16, kernel_size=(1, 640), stride=(1, 1))
    (1): Conv2d(16, 32, kernel_size=(19, 1), stride=(1, 1), groups=16)
  )
  (batchNorm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.2, inplace=False)
  (maxpooling): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (attention_blocks): ModuleList(
    (0): Sequential(
      (0): SelfAttention(
        (query): Linear(in_features=32, out_features=32, bias=True)
        (key): Linear(in_features=32, out_features=32, bias=True)
        (value): Linear(in_features=32, out_features=32, bias=True)
        (attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
        )
      )
      (1): BatchNorm1d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Dropout(p=0.2, inplace=False)
    )
  

## Predict spikes in each EEG window and build MNE annotations

In [3]:
SECONDS_TO_TRASH = 10
EEG_file_list = ["EEG_Dataset/36946379/369463791.edf"]

# ID handed to process_edf/get_dataloader for the intermediate CSV (the log shows EEG_csv/eeg1000.csv).
# It is reused for every file: each file is processed and loaded before the next one overwrites it.
EEG_TEMP_ID = 1000
# Seconds from the start of a window to its "Spike" marker; presumably DURATION / 2 (window centre)
# -- TODO confirm against how process_edf windows the signal.
SPIKE_MARK_OFFSET = 5

for edf_filename in EEG_file_list:
    num = EEG_TEMP_ID

    ## Process the edf file, trimming SECONDS_TO_TRASH seconds from both ends
    process_edf(edf_filename, num,
                SKIP_FIRST = SECONDS_TO_TRASH,
                SKIP_LAST  = SECONDS_TO_TRASH)

    ## Get the data
    # NOTE(review): the log reports 267 windows for this file, but only 187 reach the model
    # ("Train data has shape [187, ...]"). get_dataloader appears to return a split and only the
    # first part is classified here. TODO confirm: spikes in the other windows are never detected,
    # and the i * DURATION onsets are only correct if the kept windows are contiguous from the start.
    datasets, _, _ = get_dataloader([num], get_dataloader=False, shuffle=False)
    datasets = torch.cat(datasets, dim=0).to(device=device, dtype=torch.float32)

    ## Classify the existence of a spike in each window
    # Inference only: no_grad avoids building an autograd graph over the whole recording.
    with torch.no_grad():
        logits = model(datasets)
    # Sigmoid + round thresholds each window's probability at 0.5 -> 0.0 (no spike) / 1.0 (spike).
    output = torch.round(torch.sigmoid(logits)).flatten().cpu().numpy()
    spike_idx = np.flatnonzero(output == 1)  # indices of windows predicted as spike
    num_spike = len(spike_idx)
    print(f"There is total {num_spike} spike in {edf_filename}")
    print(output, "\n")

    if num_spike >= 1:
        # Only the header (meas_date) is needed here, so don't load the whole signal into memory.
        raw = mne.io.read_raw_edf(edf_filename, preload=False)

        # One zero-length "Spike" marker per positive window. Onsets are in seconds relative to
        # orig_time, shifted by the SECONDS_TO_TRASH seconds that process_edf trimmed from the start.
        annotations = mne.Annotations(onset=SECONDS_TO_TRASH + SPIKE_MARK_OFFSET + spike_idx * DURATION,
                                      duration=[0.0] * num_spike,
                                      description=["Spike"] * num_spike,
                                      orig_time=raw.info["meas_date"])

        # Print the annotations to verify
        print(annotations.to_data_frame())

        # TODO: exporting the annotated recording back to EDF failed previously; intended steps:
        # raw.set_annotations(annotations)
        # new_edf_filename = edf_filename[:-4] + "_Processed.edf"
        # mne.export.export_raw(new_edf_filename, raw, fmt="edf")
        # print(f"The label file have been saved from {edf_filename} to {new_edf_filename}")

For the file:  EEG_Dataset/36946379/369463791.edf
Extracting EDF parameters from c:\Users\xinju\Desktop\Python_Jupyter\Y3_Sem2_Thesis_XAI\EEG_Dataset\36946379\369463791.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 1380863  =      0.000 ...  2696.998 secs...
Before drop some row, shape = (1380864, 39)
After  drop some row, shape = (1370624, 39)
Setting up band-pass filter from 0.5 - 70 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 70.00 Hz
- Upper transition bandwidth: 17.50 Hz (-6 dB cutoff frequency: 78.75 Hz)
- Filter length: 3381 samples (6.604 s)

Creating RawArray with float64 data, n_channels=19, n_times=1370624
    Ra

[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.6s


Applying average reference.
Applying a custom ('EEG',) reference.
EEG_Dataset/36946379/369463791.edf does not contain NULL value, the process will continue
The dataframe of EEG_Dataset/36946379/369463791.edf have been saved to EEG_csv/eeg1000.csv

The data from EEG_csv/eeg1000.csv is loaded 
There is no spike in this eeg file
(267, 1280, 19)
EEG1000 has 267 windows of data 


> > > Train    data  has shape: torch.Size([187, 19, 1280]) when duration = 10 seconds
> > > Data after DWT has shape: torch.Size([187, 19, 1282])
> > > Label              shape: torch.Size([187])
There is total 8 spike in EEG_Dataset/36946379/369463791.edf
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.
 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0.
 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0