# Loading the data

Import needed libraries

In [61]:
import pandas as pd
import torch
import numpy as np
import matplotlib.pyplot as plt
import os

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda'

Setup global variables for loading

In [62]:
# when working locally set correct paths from the current directory to

# directory that contains data from kaggle hms
INPUT_DATA_DIR = "data"

# directory in which our npy files are/will be stored
PROCESSED_DATA_DIR = "processed_data"

Load the metadata

In [63]:
train_meta_full = pd.read_csv(INPUT_DATA_DIR + "/train.csv")
train_meta = train_meta_full.loc[train_meta_full["eeg_sub_id"] == 0]
test_meta = pd.read_csv(INPUT_DATA_DIR + "/test.csv")

Create a butterfilter

In [64]:
from scipy.signal import butter, lfilter


def butter_lowpass_filter(data, cutoff_freq=20, sampling_rate=200, order=4):
    nyquist = 0.5 * sampling_rate
    normal_cutoff = cutoff_freq / nyquist
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    filtered_data = lfilter(b, a, data, axis=0)
    return filtered_data

cutoff_freq=20 
sampling_rate=200
order=4
nyquist = 0.5 * sampling_rate
normal_cutoff = cutoff_freq / nyquist
b, a = butter(order, normal_cutoff, btype='low', analog=False)

Extract parquet data

In [65]:
# take a parquet dataframe and compute correct values for each column
# we want columns such as "Fp1-F7" as can be seen in /example_figures
def extract_parquet(parquet_data: torch.tensor):
    parquet_data["Fp1-F7"] = parquet_data["Fp1"] - parquet_data["F7"]
    parquet_data["F7-T3"] = parquet_data["F7"] - parquet_data["T3"]
    parquet_data["T3-T5"] = parquet_data["T3"] - parquet_data["T5"]
    parquet_data["T5-O1"] = parquet_data["T5"] - parquet_data["O1"]

    parquet_data["Fp2-F8"] = parquet_data["Fp2"] - parquet_data["F8"]
    parquet_data["F8-T4"] = parquet_data["F8"] - parquet_data["T4"]
    parquet_data["T4-T6"] = parquet_data["T4"] - parquet_data["T6"]
    parquet_data["T6-O2"] = parquet_data["T6"] - parquet_data["O2"]

    parquet_data["Fp1-F3"] = parquet_data["Fp1"] - parquet_data["F3"]
    parquet_data["F3-C3"] = parquet_data["F3"] - parquet_data["C3"]
    parquet_data["C3-P3"] = parquet_data["C3"] - parquet_data["P3"]
    parquet_data["P3-O1"] = parquet_data["P3"] - parquet_data["O1"]

    parquet_data["Fp2-F4"] = parquet_data["Fp2"] - parquet_data["F4"]
    parquet_data["F4-C4"] = parquet_data["F4"] - parquet_data["C4"]
    parquet_data["C4-P4"] = parquet_data["C4"] - parquet_data["P4"]
    parquet_data["P4-O2"] = parquet_data["P4"] - parquet_data["O2"]

    parquet_data["Fz-Cz"] = parquet_data["Fz"] - parquet_data["Cz"]
    parquet_data["Cz-Pz"] = parquet_data["Cz"] - parquet_data["Pz"]

    parquet_data = parquet_data.drop(
        [
            "Fp1",
            "F3",
            "C3",
            "P3",
            "F7",
            "T3",
            "T5",
            "O1",
            "Fz",
            "Cz",
            "Pz",
            "Fp2",
            "F4",
            "C4",
            "P4",
            "F8",
            "T4",
            "T6",
            "O2",
        ],
        axis=1,
    )
    idx = parquet_data.columns[1:].to_list() + [parquet_data.columns[0]]
    parquet_data = parquet_data[idx].values.T
    # parquet_data = butter_lowpass_filter(parquet_data)
    parquet_data = lfilter(b, a, parquet_data, axis=0)
    parquet_data = torch.from_numpy(parquet_data).type(torch.float32)
    parquet_data = torch.clip(parquet_data, -1024, 1024)
    return parquet_data

Get the data and labels for training

In [66]:
eeg_data = []
faulty_eeg_id = []
nan_rows = 0

if not os.path.exists(f"{PROCESSED_DATA_DIR}/eeg_labels.pt"):
    all_labels = train_meta.loc[~train_meta["eeg_id"].isin(faulty_eeg_id)][
        ["seizure_vote", "lpd_vote", "gpd_vote", "lrda_vote", "grda_vote", "other_vote"]
    ].values
    
    for eeg_id, label, offset in zip(
        train_meta["eeg_id"], all_labels, train_meta["eeg_label_offset_seconds"]
    ):
        parquet_data = pd.read_parquet(
            INPUT_DATA_DIR + f"/train_eegs/{eeg_id}.parquet"
        ).interpolate(method="ffill")[:10000]
        if np.any(parquet_data.isna()):
            faulty_eeg_id.append(eeg_id)
            continue

        eeg = extract_parquet(parquet_data)
        eeg_data.append(eeg)

    # eeg_data.squeeze()[:,None,:,:].shape
    eeg_data = torch.stack(eeg_data)
    torch.save(eeg_data, f"{PROCESSED_DATA_DIR}/eeg_data.pt")

    eeg_labels = all_labels[np.where(1-train_meta["eeg_id"].isin(faulty_eeg_id))]
    eeg_labels = torch.tensor(np.array(eeg_labels), dtype=torch.float32)
    eeg_labels = eeg_labels / eeg_labels.sum(dim=1, keepdims=True)
    torch.save(eeg_labels, f"{PROCESSED_DATA_DIR}/eeg_labels.pt")
else:
    eeg_data = torch.load(f"{PROCESSED_DATA_DIR}/eeg_data.pt")
    eeg_labels = torch.load(f"{PROCESSED_DATA_DIR}/eeg_labels.pt")

In [89]:
print(eeg_data.shape)
print(eeg_labels.shape)

torch.Size([17018, 19, 10000])
torch.Size([17018, 6])


---

# Creating a data loader, preprocessing

Setup global variables for dataloder and preprocessing

In [90]:
SAMPLING_FREQUENCY = 200
SAMPLES_IN_MEASUREMENT = 10000
FOLDS = 5
BATCH_SIZE = 32
NUM_WORKERS = 0

Create a HMS dataset class that will help us load the data during the model training

In [91]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    eeg_data, eeg_labels, test_size=0.1, stratify=eeg_labels.argmax(axis=1)
)

In [92]:
X_train, y_train = X_train.to(DEVICE), y_train.to(DEVICE)
X_test, y_test = X_test.to(DEVICE), y_test.to(DEVICE)

Create custom dataset and dataloater

In [93]:
from torch.utils.data import DataLoader, Dataset


class CustomImageDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


train_loader = DataLoader(
    CustomImageDataset(X_train, y_train), batch_size=BATCH_SIZE, shuffle=True
)
test_loader = DataLoader(
    CustomImageDataset(X_test, y_test), batch_size=BATCH_SIZE, shuffle=True
)

---

# Creating a model

In [118]:
import torch.nn as nn

class SimpleEEGModel(nn.Module):
    def __init__(
        self,
    ):
        super(SimpleEEGModel, self).__init__()
        
        self.lstm = nn.LSTM(
            input_size=19,
            hidden_size=50,
            num_layers=2,
            batch_first=True,
            dropout=0.0
        )
        
        self.clf = nn.Linear(
            in_features=50,
            out_features=6
        )
    
    def forward(self, x):
        x, _ = self.lstm(x.permute(0, 2, 1))
        x = x[:, -1, :]
        x = self.clf(x)
        return x

In [119]:
class ConvLSTM(nn.Module):
    def __init__(self):
        super(ConvLSTM, self).__init__()
        
        self.conv = nn.Conv1d(19, 19, kernel_size=4, stride=4)
        self.lstm = nn.LSTM(
            input_size=19,
            hidden_size=50,
            num_layers=2,
            batch_first=True,
            dropout=0.0
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(50, 6)
        self.softmax = nn.Softmax(dim=1)

    # input in CHW / CW format
    def forward(self, x):
        x = self.conv(x)
        x = x.permute(0, 2, 1)  # (batch_size, seq_length, input_size)
        x, _ = self.lstm(x)
        x = x[:, -1, :]
        x = self.flatten(x)
        x = self.fc(x)
        out = self.softmax(x)
        return out

# Training (TBD)

---

In [120]:
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, SGD


model = SimpleEEGModel().to(DEVICE)
optimizer = Adam(model.parameters(), lr=1e-2)
# scheduler = ReduceLROnPlateau(optimizer, 'min')
loss_function = nn.KLDivLoss(reduction="batchmean", log_target=False)

In [122]:
EPOCHS = 10

total_loss = 0
for epoch in range(EPOCHS):
    iteration = 0
    total_loss = 0
    for batch in train_loader:        
        
        # batch_data = nn.functional.normalize(batch[0], dim=-1)
        batch_data = batch[0]
        
        batch_labels = batch[1]
        
        prediction = model(batch_data)
        
        loss = loss_function(nn.functional.log_softmax(prediction, dim=1), batch_labels)
        print(loss.item())
        
        total_loss += float(loss.item())
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        iteration += 1
        
    # scheduler.step()
    print("==========================")
    print(total_loss / iteration)
    print("==========================")

1.0269769430160522
1.3157577514648438
1.1721874475479126
0.8918254375457764
1.2959089279174805
1.2959293127059937
1.335465908050537
1.3814728260040283
1.018403172492981
1.2768104076385498
1.1708396673202515
1.4860143661499023
1.1796013116836548
1.0468554496765137
1.3494267463684082
1.2096214294433594
0.980535626411438
1.3907415866851807
1.4375603199005127
1.4014742374420166
1.0343414545059204
1.5782907009124756
1.1954923868179321
1.1283328533172607
1.1722311973571777
1.2290353775024414
1.0813007354736328
1.1492975950241089
0.9183259010314941
1.4774136543273926
1.1765300035476685
1.1881678104400635
1.535722255706787
1.6673552989959717
1.156713604927063
1.5948553085327148
1.266348123550415
1.1769192218780518
1.185274362564087
1.4635636806488037
1.3391674757003784
1.473076343536377
1.0483709573745728
1.1125916242599487
1.254305362701416
1.0151498317718506
1.4473389387130737
0.9743373990058899
0.9801393151283264
1.4526491165161133
1.155890703201294
1.4465030431747437
1.0221357345581055
1.1

KeyboardInterrupt: 

---