# **Install**

In [None]:
!pip install torchmetrics

# **Imports 📢**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset


import numpy as np
import pandas as pd
from PIL import Image
from scipy.io import loadmat
import matplotlib.pyplot as plt
from torchmetrics import Accuracy
from sklearn.model_selection import train_test_split
import seaborn as sns

from tqdm import tqdm

from scipy.io import loadmat
import os

# **Utils 🧰**

In [None]:
def cal_metrics (all_targets, all_outputs):
  """Compute accuracy and macro-averaged precision / recall / F1.

  Args:
    all_targets: tensor of ground-truth class indices.
    all_outputs: tensor of predicted class indices (already arg-maxed).

  Returns:
    Tuple ``(acc, macro_precision, macro_recall, macro_f1)``.
  """
  from sklearn import metrics

  # scikit-learn works on host-side NumPy arrays.
  y_true = all_targets.detach().cpu().numpy()
  y_pred = all_outputs.detach().cpu().numpy()

  acc = metrics.accuracy_score(y_true, y_pred)
  # Only precision overrides zero_division; recall and F1 keep the library default.
  macro_precision = metrics.precision_score(
      y_true, y_pred, average='macro', zero_division=1)
  macro_recall = metrics.recall_score(y_true, y_pred, average='macro')
  macro_f1 = metrics.f1_score(y_true, y_pred, average='macro')

  return acc, macro_precision, macro_recall, macro_f1

In [None]:
class AverageMeter(object):
    """Keep the most recent value and the running weighted mean of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
def num_params(model):
  """Return the total number of parameters of ``model``, in millions."""
  total = 0
  for param in model.parameters():
    total += param.numel()
  return total / 1e6

In [None]:
def num_trainable_params(model):
  """Return the number of parameters with ``requires_grad`` set, in millions."""
  total = 0
  for param in model.parameters():
    if param.requires_grad:
      total += param.numel()
  return total / 1e6

In [None]:
def calculate_metrics(predictions, targets):
    """Binary F1 / precision / recall with class 1 as the positive class.

    Args:
        predictions: per-class scores; the arg-max over ``dim=1`` is the
            predicted label.
        targets: ground-truth labels; only the values 0 and 1 are counted.

    Returns:
        Tuple ``(f1_score, precision, recall)`` of Python floats.
    """
    eps = 1e-7  # keeps every ratio finite when a denominator would be zero
    predicted_labels = predictions.argmax(dim=1)

    pred_pos = predicted_labels == 1
    pred_neg = predicted_labels == 0
    actual_pos = targets == 1
    actual_neg = targets == 0

    tp = (pred_pos & actual_pos).sum().item()
    fp = (pred_pos & actual_neg).sum().item()
    fn = (pred_neg & actual_pos).sum().item()

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1_score = 2 * (precision * recall) / (precision + recall + eps)

    return f1_score, precision, recall

In [None]:
from sklearn.metrics import confusion_matrix
def save_confusion_matrix(targets, predicted_labels, classes, save_path):
    """Save a row-normalised confusion-matrix plot; return sensitivity and specificity.

    Args:
        targets: tensor of ground-truth labels.
        predicted_labels: per-class scores; arg-maxed over ``dim=1`` here.
        classes: tick labels for both axes.
        save_path: destination of the PNG image.

    Returns:
        Tuple ``(sensitivity, specificity)``. The final unpacking needs a
        2x2 matrix, so this only works for two-class problems.
    """
    hard_preds = torch.argmax(predicted_labels, dim=1)
    counts = confusion_matrix(targets.cpu().numpy(), hard_preds.cpu().numpy())
    # Divide each row by its total so cells are per-true-class fractions.
    row_totals = counts.sum(axis=1)[:, np.newaxis]
    cm = counts.astype('float') / row_totals

    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.colorbar()
    positions = np.arange(len(classes))
    plt.xticks(positions, classes, rotation=45)
    plt.yticks(positions, classes)

    # Annotate every cell; switch to white text on dark cells.
    half_max = cm.max() / 2.
    for (row, col), value in np.ndenumerate(cm):
        text_color = 'white' if value > half_max else 'black'
        plt.text(col, row, format(value, '.2f'), ha='center', va='center',
                 color=text_color)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig(save_path, format='png')
    plt.close()

    # Row-major order of a 2x2 matrix: [[tn, fp], [fn, tp]].
    tn, fp, fn, tp = cm.ravel()
    return tp / (tp + fn), tn / (tn + fp)

In [None]:
from sklearn import metrics
def plot_ROC(targets, predicted_labels, save_path):
  """Plot the ROC curve of the class-1 scores against a constant-score baseline.

  Args:
    targets: tensor of ground-truth binary labels.
    predicted_labels: 2-D tensor of per-class scores; column 1 is the score.
    save_path: destination of the PNG image.

  Returns:
    Always ``0``.
  """
  y_true = targets.cpu().numpy()
  positive_scores = predicted_labels[:,1].cpu().numpy()
  fpr, tpr, _ = metrics.roc_curve(y_true, positive_scores)

  # A constant score for every sample gives the reference curve.
  baseline_scores = [0] * len(y_true)
  fprno, tprno, _ = metrics.roc_curve(y_true, baseline_scores)

  plt.plot(fprno, tprno, 'b--')
  plt.plot(fpr, tpr, 'r')
  plt.ylabel('True Positive Rate')
  plt.xlabel('False Positive Rate')
  plt.savefig(save_path, format='png')
  plt.close()
  return 0

# **Device ⚙️**

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device  # cell output: show the selected device

# **Dataset 🗂️**

In [None]:
# Experiment configuration read by the loading cells below.
# `channels` indexes the channel axis of the loaded arrays. Electrode groups:
#   Frontal  = [0, 1, 2, 3, 4, 5]
#   Central  = [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
#   Parietal = [18, 19, 20, 21]
#   All      = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
channels = [0, 1, 2, 3, 4, 5]
task = 'right' # one of: left, right, foot, tongue (selects the .mat file and its key)
duration = 2 # window length in seconds; the loading cell handles 1, 2 and 4
apply_filter = False # when True the windows are band-pass filtered with (fl, fh)
fl, fh = [0.5, 4] # band edges: Delta = [0.5, 4], Theta = [4, 8], Alpha = [8, 13], Beta = [13, 30], Gamma = [30, 100]

## Load dataset

In [None]:
from scipy.signal import butter, filtfilt
# Design the Butterworth band-pass that is applied when `apply_filter` is True.
fs = 250  # Sampling frequency

order = 5  # Filter order

# Create bandpass filter coefficients.
# The band edges are normalised by the Nyquist frequency (fs / 2).
nyq = 0.5 * fs
low = fl / nyq
high = fh / nyq
b, a = butter(order, [low, high], btype='band')

In [None]:
# Build the dataset: for each subject 1-9 load the two .mat recordings for
# `task` (the "subjects1000" and "subjects1000_val" folders), keep `channels`,
# cut the last axis into windows and stack the windows along axis 0.
# Window width is 250 samples per second of `duration` (fs = 250 above).
#
# NOTE(review): `label` is rebuilt on every iteration and only the value from
# the last subject survives. That is only correct if every subject contributes
# the same number of windows (np.array(df) below needs equal shapes as well).
# NOTE(review): if `duration` is not 1, 2 or 4 nothing is appended and `label`
# is never defined.
df = []
for i in range(1,10):
  data = loadmat(f'/gdrive/MyDrive/Motor_Imagery/BCI2a/subjects1000/sub{i}/data_{task}_sub{i}.mat')
  data_val = loadmat(f'/gdrive/MyDrive/Motor_Imagery/BCI2a/subjects1000_val/sub{i}/data_{task}_sub{i}.mat')
  if duration == 4:
    # Keep the full last axis of every trial (no windowing).
    data1 = data[f'data_{task}'][:,channels,:]
    data_val = data_val[f'data_{task}'][:,channels,:]
    data = np.concatenate((data1, data_val), axis=0)
    if apply_filter == True:
      data = filtfilt(b, a, data) #frequency filter (forward-backward, along the last axis)
    # Subject ids 1..9, each repeated once per window. The comprehension's `i`
    # is local to it and does not affect the loop variable.
    label = [i for i in range(1, 10) for _ in range(data.shape[0])]
    label = np.array(label).reshape((9, data.shape[0]))
    df.append(data)
  if duration == 2:
    # Two 500-sample windows per trial, from both recordings.
    data1 = data[f'data_{task}'][:,channels,:500]
    data2 = data[f'data_{task}'][:,channels,500:1000]
    data1_val = data_val[f'data_{task}'][:,channels,:500]
    data2_val = data_val[f'data_{task}'][:,channels,500:1000]
    data = np.concatenate((data1, data2, data1_val, data2_val), axis=0)
    if apply_filter == True:
      data = filtfilt(b, a, data) #frequency filter
    label = [i for i in range(1, 10) for _ in range(data.shape[0])]
    label = np.array(label).reshape((9, data.shape[0]))
    df.append(data)
  if duration == 1:
    # Four 250-sample windows per trial, from both recordings.
    data1 = data[f'data_{task}'][:,channels,:250]
    data2 = data[f'data_{task}'][:,channels,250:500]
    data3 = data[f'data_{task}'][:,channels,500:750]
    data4 = data[f'data_{task}'][:,channels,750:1000]
    data1_val = data_val[f'data_{task}'][:,channels,:250]
    data2_val = data_val[f'data_{task}'][:,channels,250:500]
    data3_val = data_val[f'data_{task}'][:,channels,500:750]
    data4_val = data_val[f'data_{task}'][:,channels,750:1000]
    data = np.concatenate((data1, data2, data3, data4, data1_val, data2_val, data3_val, data4_val), axis=0)
    # data = np.concatenate((data1, data2, data3), axis=0) #for data with 75 sample
    if apply_filter == True:
      data = filtfilt(b, a, data) #frequency filter
    label = [i for i in range(1, 10) for _ in range(data.shape[0])]
    label = np.array(label).reshape((9, data.shape[0]))
    df.append(data)
# Stack the nine per-subject arrays; axes 1, 2 and 3 are read below as
# windows, channels and samples.
df = np.array(df)
print(df.shape)
num_trial = df.shape[1]
num_ch = df.shape[2]
num_smaple = df.shape[3]
# Merge the subject and window axes so each row is one example.
df = df.reshape((9*num_trial,num_ch,num_smaple))
label = np.array(label)
label = label.reshape((9*num_trial,))

In [None]:
# df = []
# for i in range(1,10):
#   data = loadmat(f'/gdrive/MyDrive/Motor_Imagery/BCI2a/subjects/sub{i}/data_{task}_sub{i}.mat')
#   data = data[f'data_{task}'][:,channels,:250]
#   # data = filtfilt(b, a, data) #frequency filter
#   label = [i for i in range(1, 10) for _ in range(72)]
#   label = np.array(label).reshape((9, 72))
#   df.append(data)
# df = np.array(df)
# print(df.shape)
# num_ch = df.shape[2]
# num_smaple = df.shape[3]
# df = df.reshape((9*72,num_ch,num_smaple))
# label = np.array(label)
# label = label.reshape((9*72,))

In [None]:
print(df.shape)
print(label.shape)

In [None]:
# Shift the subject ids from 1..9 to 0..8 so they can be used as class indices.
label = label -1

In [None]:
# Hold out 20% of all windows for validation, then 10% of the remaining
# training windows for testing. Both splits use a fixed seed.
# NOTE(review): windows cut from the same trial can land in different splits.
x_train, x_valid, y_train, y_valid = train_test_split(df, label, test_size=0.2, random_state=23)
x_train, x_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.1, random_state=23)

In [None]:
x_train.shape

In [None]:
x_test.shape

In [None]:
# Convert every split to tensors. Inputs gain a singleton axis at position 1
# (the input-channel axis of the Conv2d network); targets become int64.
x_train = torch.FloatTensor(x_train).unsqueeze(1)
y_train = torch.LongTensor(y_train).squeeze()

x_valid = torch.FloatTensor(x_valid).unsqueeze(1)
y_valid = torch.LongTensor(y_valid).squeeze()

x_test = torch.FloatTensor(x_test).unsqueeze(1)
y_test = torch.LongTensor(y_test).squeeze()

# Standardise all splits with statistics taken over the training examples only
# (reduction over dim 0, so one mean/std per remaining position).
mu = x_train.mean(dim=0)
std = x_train.std(dim=0)

x_train, x_valid, x_test = (
    (split - mu) / std for split in (x_train, x_valid, x_test)
)

In [None]:
x_train.shape, y_train.shape

In [None]:
x_valid.shape, y_valid.shape

In [None]:
y_train

In [None]:
torch.unique(y_train)

## TensorDataset

In [None]:
# Pair each split's inputs with its targets so a DataLoader can batch them.
train_dataset = TensorDataset(x_train, y_train)
valid_dataset = TensorDataset(x_valid, y_valid)
test_dataset = TensorDataset(x_test, y_test)

## DataLoader

In [None]:
# Mini-batch iterators for the three splits.
# NOTE(review): the validation and test loaders are shuffled as well, and the
# test loader uses a batch size of 130 rather than 128. Confirm both are intended.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=130, shuffle=True)

In [None]:
# Pull one training batch and print the input / target shapes as a sanity check.
x, y = next(iter(train_loader))
print(x.shape)
print(y.shape)

# CNN

In [None]:
# class CNN(nn.Module):
#     def __init__(self, num_ch):
#         super().__init__()

#         self.conv1 = nn.Conv2d(1, 64, kernel_size=(num_ch-1, 1), padding=1)
#         self.batchnorm1 = nn.BatchNorm2d(64)
#         self.relu1 = nn.ReLU()

#         self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
#         self.batchnorm2 = nn.BatchNorm2d(64)
#         self.relu2 = nn.ReLU()

#         self.maxpool1 = nn.MaxPool2d(2, 2)

#         self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
#         self.batchnorm3 = nn.BatchNorm2d(128)
#         self.relu3 = nn.ReLU()

#         self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
#         self.batchnorm4 = nn.BatchNorm2d(128)
#         self.relu4 = nn.ReLU()

#         self.maxpool2 = nn.MaxPool2d(2, 2)

#         self.conv5 = nn.Conv2d(128, 256, 3, padding=1)
#         self.batchnorm5 = nn.BatchNorm2d(256)
#         self.relu5 = nn.ReLU()

#         self.conv6 = nn.Conv2d(256, 256, 3, padding=1)
#         self.batchnorm6 = nn.BatchNorm2d(256)
#         self.relu6 = nn.ReLU()

#         self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
#         self.flatten = nn.Flatten()
#         self.fc = nn.Linear(256, 9)

#     def forward(self, x):
#         print(x.shape)
#         x = self.relu1(self.batchnorm1(self.conv1(x)))
#         print(x.shape)
#         x = self.relu2(self.batchnorm2(self.conv2(x)))
#         print(x.shape)
#         x = self.maxpool1(x)
#         print(x.shape)
#         x = self.relu3(self.batchnorm3(self.conv3(x)))
#         print(x.shape)
#         x = self.relu4(self.batchnorm4(self.conv4(x)))
#         print(x.shape)
#         x = self.maxpool2(x)
#         print(x.shape)
#         x = self.relu5(self.batchnorm5(self.conv5(x)))
#         print(x.shape)
#         x = self.relu6(self.batchnorm6(self.conv6(x)))
#         print(x.shape)
#         x = self.avgpool(x)
#         print(x.shape)
#         x = self.flatten(x)
#         print(x.shape)
#         print('-------------------------------------------------------------------------------------------------------')
#         x = self.fc(x)
#         return x


In [None]:
def CNN():
  """Build the six-convolution network used as both teacher and student.

  Layout: three stages of two Conv-BatchNorm-ReLU units (64, 128 and 256
  filters) with 2x2 max-pooling after the first two stages, then global
  average pooling and a linear layer with 9 outputs.

  The first convolution uses a ``(num_ch-1, 1)`` kernel, so the module-level
  ``num_ch`` must be defined before this is called.

  Returns:
    An ``nn.Sequential`` with the layers in the order described above.
  """
  def conv_unit(c_in, c_out, kernel):
    # Conv -> BatchNorm -> ReLU, always with padding=1.
    return [nn.Conv2d(c_in, c_out, kernel, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU()]

  layers = []

  # Stage 1: the first kernel spans num_ch-1 rows and a single column.
  layers += conv_unit(1, 64, (num_ch-1, 1))
  layers += conv_unit(64, 64, 3)
  layers.append(nn.MaxPool2d(2, 2))

  # Stage 2
  layers += conv_unit(64, 128, 3)
  layers += conv_unit(128, 128, 3)
  layers.append(nn.MaxPool2d(2, 2))

  # Stage 3
  layers += conv_unit(128, 256, 3)
  layers += conv_unit(256, 256, 3)

  # Head: collapse the spatial axes to 1x1, flatten to 256 features, classify.
  layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
  layers.append(nn.Flatten())
  layers.append(nn.Linear(256, 9))

  return nn.Sequential(*layers)

In [None]:
# model = CNN().to(device)

# # Print the model architecture
# print(model)
# num_params(model)

# Functions

In [None]:
def train_one_epoch(model, train_loader, loss_fn, optimizer, epoch=None):
  """Train ``model`` for one pass over ``train_loader``.

  Args:
    model: network to optimise; it is switched to training mode.
    train_loader: iterable of ``(inputs, targets)`` batches.
    loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar loss.
    optimizer: optimiser bound to ``model``'s parameters.
    epoch: optional epoch index, shown in the progress-bar description.

  Returns:
    Tuple ``(model, mean_batch_loss, accuracy)``. The loss is the unweighted
    mean of the per-batch losses; accuracy is accumulated over the whole pass
    with a 9-class ``Accuracy`` metric.
  """
  model.train()
  loss_train = AverageMeter()
  acc_train = Accuracy(task="multiclass", num_classes=9).to(device)
  with tqdm(train_loader, unit="batch") as tepoch:
    # The description does not change between batches, so set it once.
    if epoch is not None:
      tepoch.set_description(f"Epoch {epoch}")
    for inputs, targets in tepoch:
      inputs = inputs.to(device)
      targets = targets.to(device)

      outputs = model(inputs)
      loss = loss_fn(outputs, targets)

      # A fresh graph is built on every forward pass, so there is no reason to
      # retain it after backward; retaining it only holds on to memory.
      loss.backward()

      optimizer.step()
      optimizer.zero_grad()

      loss_train.update(loss.item())
      acc_train(outputs, targets.int())
      tepoch.set_postfix(loss=loss_train.avg,
                         accuracy=100.*acc_train.compute().item())
  return model, loss_train.avg, acc_train.compute().item()

In [None]:
def train_one_epoch_kd(student, teacher, train_loader, loss_fn, optimizer, epoch=None):
  """Train ``student`` for one pass using knowledge distillation from ``teacher``.

  Args:
    student: network to optimise; it is switched to training mode.
    teacher: network providing soft targets; it is switched to eval mode and
      run without gradient tracking.
    train_loader: iterable of ``(inputs, targets)`` batches.
    loss_fn: accepted for signature compatibility but not used. The loss is
      always the module-level ``loss_fn_kd`` with ``T=10`` and ``alpha=0.6``.
    optimizer: optimiser bound to ``student``'s parameters.
    epoch: optional epoch index, shown in the progress-bar description.

  Returns:
    Tuple ``(student, mean_batch_loss, accuracy)``, computed as in
    ``train_one_epoch``.
  """
  student.train()
  # The teacher only supplies targets. In training mode its BatchNorm layers
  # would use batch statistics and update their running stats even under
  # no_grad, so force eval mode.
  teacher.eval()
  loss_train = AverageMeter()
  acc_train = Accuracy(task="multiclass", num_classes=9).to(device)
  with tqdm(train_loader, unit="batch") as tepoch:
    # The description does not change between batches, so set it once.
    if epoch is not None:
      tepoch.set_description(f"Epoch {epoch}")
    for inputs, targets in tepoch:
      inputs = inputs.to(device)
      targets = targets.to(device)

      outputs = student(inputs)

      with torch.no_grad():
        teacher_outputs = teacher(inputs)

      loss = loss_fn_kd(outputs, targets, teacher_outputs, T=10, alpha=0.6)

      loss.backward()

      optimizer.step()
      optimizer.zero_grad()

      loss_train.update(loss.item())
      acc_train(outputs, targets.int())
      tepoch.set_postfix(loss=loss_train.avg,
                         accuracy=100.*acc_train.compute().item())
  return student, loss_train.avg, acc_train.compute().item()

In [None]:
def validation(model, test_loader, loss_fn):
  """Evaluate ``model`` on ``test_loader`` without gradient tracking.

  Args:
    model: network to evaluate; it is switched to eval mode.
    test_loader: iterable of ``(inputs, targets)`` batches.
    loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar loss.

  Returns:
    Tuple ``(mean_batch_loss, accuracy, all_targets, all_outputs)``.
    ``all_targets`` and ``all_outputs`` are the concatenated targets and
    arg-max predictions of every batch, in loader order.
  """
  model.eval()
  loss_meter = AverageMeter()
  acc_metric = Accuracy(task="multiclass", num_classes=9).to(device)
  targets_per_batch, preds_per_batch = [], []

  with torch.no_grad():
    for inputs, targets in test_loader:
      inputs, targets = inputs.to(device), targets.to(device)

      logits = model(inputs)
      loss_meter.update(loss_fn(logits, targets).item())
      acc_metric(logits, targets.int())

      targets_per_batch.append(targets)
      preds_per_batch.append(torch.argmax(logits, dim=1))

  all_targets = torch.cat(targets_per_batch, dim=0)
  all_outputs = torch.cat(preds_per_batch, dim=0)

  return loss_meter.avg, acc_metric.compute().item(), all_targets, all_outputs

# 5-fold (combine)

### Step 1: check forward path

Calculate loss for one batch

In [None]:
from openpyxl import load_workbook
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler
from scipy.signal import butter, filtfilt

In [None]:
def loss_fn_kd(outputs, labels, teacher_outputs, T, alpha):
  """Knowledge-distillation loss: softened KL term plus hard-label cross-entropy.

  Args:
    outputs: student logits, classes along ``dim=1``.
    labels: ground-truth class indices.
    teacher_outputs: teacher logits, same shape as ``outputs``.
    T: temperature dividing both sets of logits before the softmax.
    alpha: weight of the distillation term; ``1 - alpha`` weights the
      cross-entropy term.

  Returns:
    Scalar tensor
    ``KL(teacher_T || student_T) * alpha * T**2 + CE(outputs, labels) * (1 - alpha)``.
  """
  student_log_probs = F.log_softmax(outputs/T, dim=1)
  teacher_probs = F.softmax(teacher_outputs/T, dim=1)
  # 'batchmean' divides the summed KL by the batch size.
  distill_term = F.kl_div(student_log_probs, teacher_probs,
                          reduction='batchmean') * (alpha * T**2)
  hard_term = F.cross_entropy(outputs, labels) * (1 - alpha)
  return distill_term + hard_term

In [None]:
# Configuration for the forward/backward sanity checks below.
# `channels` indexes the channel axis of the loaded arrays. Electrode groups:
#   Frontal  = [0, 1, 2, 3, 4, 5]
#   Central  = [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
#   Parietal = [18, 19, 20, 21]
#   All      = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
channels = [0, 1, 2, 3, 4, 5]
# chan = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], [18, 19, 20, 21], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]]
task = 'right' # left, right, foot, tongue
apply_filter = True
time = [4] # window lengths to loop over; the loading loop handles 1, 2 and 4
band = [[30, 100]] # (fl, fh) band edges to loop over; a first element of 'a' disables filtering
num_epochs = 80

In [None]:
# For every band in `band` and every window length in `time`: design the
# band-pass, load and window all nine subjects, split into train / valid /
# test, standardise with training statistics and build the training loader.
# With several bands or lengths, only the last combination's variables remain.
#
# NOTE(review): once a band with fl == 'a' sets apply_filter to False it is
# never set back to True, so any later numeric band would not be filtered.
for fl, fh in band:
  if fl == 'a':
    apply_filter = False
  else:
    # ------------------------------------------------------------------ Train and Validation Data -------------------------------------------------
    fs = 250  # Sampling frequency
    order = 5  # Filter order
    # Create bandpass filter coefficients
    nyq = 0.5 * fs
    low = fl / nyq
    high = fh / nyq
    b, a = butter(order, [low, high], btype='band')
  for t in time:
    df = []
    for i in range(1,10):
      data = loadmat(f'/gdrive/MyDrive/Motor_Imagery/BCI2a/subjects1000/sub{i}/data_{task}_sub{i}.mat')
      data_val = loadmat(f'/gdrive/MyDrive/Motor_Imagery/BCI2a/subjects1000_val/sub{i}/data_{task}_sub{i}.mat')
      if t == 4:
        # Keep the full last axis of every trial (no windowing).
        data1 = data[f'data_{task}'][:,channels,:]
        data_val = data_val[f'data_{task}'][:,channels,:]
        data = np.concatenate((data1, data_val), axis=0)
        if apply_filter == True:
          data = filtfilt(b, a, data) #frequency filter
        # Subject ids 1..9, each repeated once per window. Rebuilt on every
        # iteration; only the last value is used after the loop.
        label = [i for i in range(1, 10) for _ in range(data.shape[0])]
        label = np.array(label).reshape((9, data.shape[0]))
        df.append(data)
      if t == 2:
        # Two 500-sample windows per trial, from both recordings.
        data1 = data[f'data_{task}'][:,channels,:500]
        data2 = data[f'data_{task}'][:,channels,500:1000]
        data1_val = data_val[f'data_{task}'][:,channels,:500]
        data2_val = data_val[f'data_{task}'][:,channels,500:1000]
        data = np.concatenate((data1, data2, data1_val, data2_val), axis=0)
        if apply_filter == True:
          data = filtfilt(b, a, data) #frequency filter
        label = [i for i in range(1, 10) for _ in range(data.shape[0])]
        label = np.array(label).reshape((9, data.shape[0]))
        df.append(data)
      if t == 1:
        # Four 250-sample windows per trial, from both recordings.
        data1 = data[f'data_{task}'][:,channels,:250]
        data2 = data[f'data_{task}'][:,channels,250:500]
        data3 = data[f'data_{task}'][:,channels,500:750]
        data4 = data[f'data_{task}'][:,channels,750:1000]
        data1_val = data_val[f'data_{task}'][:,channels,:250]
        data2_val = data_val[f'data_{task}'][:,channels,250:500]
        data3_val = data_val[f'data_{task}'][:,channels,500:750]
        data4_val = data_val[f'data_{task}'][:,channels,750:1000]
        data = np.concatenate((data1, data2, data3, data4, data1_val, data2_val, data3_val, data4_val), axis=0)
        # data = np.concatenate((data1, data2, data3), axis=0) #for data with 75 sample
        if apply_filter == True:
          data = filtfilt(b, a, data) #frequency filter
        label = [i for i in range(1, 10) for _ in range(data.shape[0])]
        label = np.array(label).reshape((9, data.shape[0]))
        df.append(data)
    # Stack the nine per-subject arrays, then merge the subject and window axes.
    df = np.array(df)
    print(df.shape)
    num_trial = df.shape[1]
    num_ch = df.shape[2]
    num_smaple = df.shape[3]
    df = df.reshape((9*num_trial,num_ch,num_smaple))
    label = np.array(label)
    label = label.reshape((9*num_trial,))
    label = label -1  # subject ids 1..9 -> class indices 0..8
    # 20% of all windows for validation, then 10% of the remainder for testing.
    x_train, x_valid, y_train, y_valid = train_test_split(df, label, test_size=0.2, random_state=23)
    x_train, x_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.1, random_state=23)

    # To tensors; inputs gain a singleton axis at position 1.
    x_train = torch.FloatTensor(x_train)
    x_train = x_train.unsqueeze(1)
    y_train = torch.LongTensor(y_train)
    y_train = y_train.squeeze()
    x_valid = torch.FloatTensor(x_valid)
    x_valid = x_valid.unsqueeze(1)
    y_valid = torch.LongTensor(y_valid)
    y_valid = y_valid.squeeze()
    x_test = torch.FloatTensor(x_test)
    x_test = x_test.unsqueeze(1)
    y_test = torch.LongTensor(y_test)
    y_test = y_test.squeeze()

    # Standardise every split with statistics computed over the training examples.
    mu = x_train.mean(dim=0)
    std = x_train.std(dim=0)
    x_train = (x_train - mu) / std
    x_valid = (x_valid - mu) / std
    x_test = (x_test - mu) / std

    train_dataset = TensorDataset(x_train, y_train)
    valid_dataset = TensorDataset(x_valid, y_valid)
    test_dataset = TensorDataset(x_test, y_test)
    # Only the training loader is built here; incomplete final batches are dropped.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)

In [None]:
# Fresh model, multi-class margin loss and Adam with weight decay for the
# forward-pass check.
model = CNN().to(device)
loss_fn = nn.MultiMarginLoss()
lr = 0.00005
wd = 3e-4
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

In [None]:
# model = Multimodal(model1, model2).to(device)
# loss_fn = nn.CrossEntropyLoss()

# Run a single batch through the untrained model and print its loss.
x_batch, y_batch = next(iter(train_loader))
outputs = model(x_batch.to(device))
loss = loss_fn(outputs, y_batch.to(device))
print(loss)

In [None]:
outputs.shape, y_batch.shape

### Step 2: check backward path

Select a random subset of 5000 training samples (batch size 50) and train the model on it to check that the loss decreases.

In [None]:
# Carve a random subset of at most 5000 samples out of the training set for the
# backward-pass check. The size is clamped so the first split length cannot go
# negative when the training set has fewer than 5000 samples.
mini_size = min(5000, len(train_dataset))
_, mini_train_dataset = random_split(train_dataset, (len(train_dataset) - mini_size, mini_size))
mini_train_loader = DataLoader(mini_train_dataset, 50)

In [None]:
# model = RNNModel(nn.LSTM, 1, 16, 1, False, 2).to(device)
# model = CNNModel([64, 64], [3, 3], 2).to(device)
# model = CNNLSTM(1, 32, 128, 3, 2).to(device)


# loss_fn = nn.CrossEntropyLoss()

In [None]:
# optimizer = optim.Adam(model.parameters(), lr=0.01)
# Re-initialise the model and optimiser for the backward-pass check. The
# learning rate here is 1e-4; `wd` is reused from the forward-check cell.
model = CNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=wd)
loss_fn = nn.MultiMarginLoss()

In [None]:
# Train on the mini subset for a few epochs; only the progress-bar output is
# used, the returned loss and accuracy are discarded.
# NOTE: this overwrites the module-level num_epochs set in the config cell.
num_epochs = 15
for epoch in range(num_epochs):
  model, _, _ = train_one_epoch(model, mini_train_loader, loss_fn, optimizer, epoch)

In [None]:
torch.cuda.empty_cache()

## save test

In [None]:
from openpyxl import load_workbook
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler
from scipy.signal import butter, filtfilt

In [None]:
def loss_fn_kd(outputs, labels, teacher_outputs, T, alpha):
  """Knowledge-distillation loss: softened KL term plus hard-label cross-entropy.

  Args:
    outputs: student logits, classes along ``dim=1``.
    labels: ground-truth class indices.
    teacher_outputs: teacher logits, same shape as ``outputs``.
    T: temperature dividing both sets of logits before the softmax.
    alpha: weight of the distillation term; ``1 - alpha`` weights the
      cross-entropy term.

  Returns:
    Scalar tensor
    ``KL(teacher_T || student_T) * alpha * T**2 + CE(outputs, labels) * (1 - alpha)``.
  """
  student_log_probs = F.log_softmax(outputs/T, dim=1)
  teacher_probs = F.softmax(teacher_outputs/T, dim=1)
  # 'batchmean' divides the summed KL by the batch size.
  distill_term = F.kl_div(student_log_probs, teacher_probs,
                          reduction='batchmean') * (alpha * T**2)
  hard_term = F.cross_entropy(outputs, labels) * (1 - alpha)
  return distill_term + hard_term

In [None]:
# Configuration for the K-fold / knowledge-distillation run below.
# `channels` indexes the channel axis of the loaded arrays. Electrode groups:
#   Frontal  = [0, 1, 2, 3, 4, 5]
#   Central  = [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
#   Parietal = [18, 19, 20, 21]
#   All      = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
channels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
# chan = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], [18, 19, 20, 21], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]]
task = 'left' # left, right, foot, tongue
apply_filter = True
time = [4] #[4, 2]  window lengths to loop over
# (fl, fh) band edges to loop over; ['a', 'a'] means "no filtering".
band = [['a', 'a']] #[[0.5, 4], [4, 8], [8, 13], [13, 30], [30, 100], ['a', 'a']]
num_epochs = 50

In [None]:
# Full experiment. For every band in `band` and every window length in `time`:
#   1. load, window, split and standardise the data (as in the cells above);
#   2. run a 3-fold cross-validation over the training split, where each fold
#      trains a teacher CNN, reloads its best checkpoint, scores it on the test
#      split, then trains a student CNN by distillation and scores it too;
#   3. average the per-fold test metrics into a one-row DataFrame.
#
# NOTE(review): once a band with fl == 'a' sets apply_filter to False it is
# never set back to True, so any later numeric band would not be filtered.
for fl, fh in band:
  if fl == 'a':
    apply_filter = False
  else:
    # ------------------------------------------------------------------ Train and Validation Data -------------------------------------------------
    fs = 250  # Sampling frequency
    order = 5  # Filter order
    # Create bandpass filter coefficients
    nyq = 0.5 * fs
    low = fl / nyq
    high = fh / nyq
    b, a = butter(order, [low, high], btype='band')
  for t in time:
    df = []
    for i in range(1,10):
      data = loadmat(f'/gdrive/MyDrive/Motor_Imagery/BCI2a/subjects1000/sub{i}/data_{task}_sub{i}.mat')
      data_val = loadmat(f'/gdrive/MyDrive/Motor_Imagery/BCI2a/subjects1000_val/sub{i}/data_{task}_sub{i}.mat')
      if t == 4:
        # Keep the full last axis of every trial (no windowing).
        data1 = data[f'data_{task}'][:,channels,:]
        data_val = data_val[f'data_{task}'][:,channels,:]
        data = np.concatenate((data1, data_val), axis=0)
        if apply_filter == True:
          data = filtfilt(b, a, data) #frequency filter
        # Subject ids 1..9, each repeated once per window. Rebuilt on every
        # iteration; only the last value is used after the loop.
        label = [i for i in range(1, 10) for _ in range(data.shape[0])]
        label = np.array(label).reshape((9, data.shape[0]))
        df.append(data)
      if t == 2:
        # Two 500-sample windows per trial, from both recordings.
        data1 = data[f'data_{task}'][:,channels,:500]
        data2 = data[f'data_{task}'][:,channels,500:1000]
        data1_val = data_val[f'data_{task}'][:,channels,:500]
        data2_val = data_val[f'data_{task}'][:,channels,500:1000]
        data = np.concatenate((data1, data2, data1_val, data2_val), axis=0)
        if apply_filter == True:
          data = filtfilt(b, a, data) #frequency filter
        label = [i for i in range(1, 10) for _ in range(data.shape[0])]
        label = np.array(label).reshape((9, data.shape[0]))
        df.append(data)
      if t == 1:
        # Four 250-sample windows per trial, from both recordings.
        data1 = data[f'data_{task}'][:,channels,:250]
        data2 = data[f'data_{task}'][:,channels,250:500]
        data3 = data[f'data_{task}'][:,channels,500:750]
        data4 = data[f'data_{task}'][:,channels,750:1000]
        data1_val = data_val[f'data_{task}'][:,channels,:250]
        data2_val = data_val[f'data_{task}'][:,channels,250:500]
        data3_val = data_val[f'data_{task}'][:,channels,500:750]
        data4_val = data_val[f'data_{task}'][:,channels,750:1000]
        data = np.concatenate((data1, data2, data3, data4, data1_val, data2_val, data3_val, data4_val), axis=0)
        # data = np.concatenate((data1, data2, data3), axis=0) #for data with 75 sample
        if apply_filter == True:
          data = filtfilt(b, a, data) #frequency filter
        label = [i for i in range(1, 10) for _ in range(data.shape[0])]
        label = np.array(label).reshape((9, data.shape[0]))
        df.append(data)
    # Stack the nine per-subject arrays, then merge the subject and window axes.
    df = np.array(df)
    print(df.shape)
    num_trial = df.shape[1]
    num_ch = df.shape[2]
    num_smaple = df.shape[3]
    df = df.reshape((9*num_trial,num_ch,num_smaple))
    label = np.array(label)
    label = label.reshape((9*num_trial,))
    label = label -1  # subject ids 1..9 -> class indices 0..8
    # 20% of all windows for validation, then 10% of the remainder for testing.
    x_train, x_valid, y_train, y_valid = train_test_split(df, label, test_size=0.2, random_state=23)
    x_train, x_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.1, random_state=23)
    # print(x_train.shape, x_valid.shape, x_test.shape)
    # break
    # To tensors; inputs gain a singleton axis at position 1.
    x_train = torch.FloatTensor(x_train)
    x_train = x_train.unsqueeze(1)
    y_train = torch.LongTensor(y_train)
    y_train = y_train.squeeze()
    x_valid = torch.FloatTensor(x_valid)
    x_valid = x_valid.unsqueeze(1)
    y_valid = torch.LongTensor(y_valid)
    y_valid = y_valid.squeeze()
    x_test = torch.FloatTensor(x_test)
    x_test = x_test.unsqueeze(1)
    y_test = torch.LongTensor(y_test)
    y_test = y_test.squeeze()

    # Standardise every split with statistics computed over the training examples.
    mu = x_train.mean(dim=0)
    std = x_train.std(dim=0)
    x_train = (x_train - mu) / std
    x_valid = (x_valid - mu) / std
    x_test = (x_test - mu) / std

    train_dataset = TensorDataset(x_train, y_train)
    valid_dataset = TensorDataset(x_valid, y_valid)
    test_dataset = TensorDataset(x_test, y_test)

    # --------------------------------------------------------------- K-Fold cross-validation -------------------------------------------------------
    # Three folds over the training split. `valid_dataset` built above is not
    # used here: each fold validates on its held-out part of `train_dataset`.
    kf = KFold(n_splits=3, shuffle=True, random_state=42)
    # Per-fold test metrics: teacher lists first, student ("_s") lists second.
    all_loss_test_hist = []
    all_acc_test_hist = []
    all_precision_test_hist = []
    all_recall_test_hist = []
    all_f1_test_hist = []
    all_loss_test_hist_s = []
    all_acc_test_hist_s = []
    all_precision_test_hist_s = []
    all_recall_test_hist_s = []
    all_f1_test_hist_s = []
    # Student test targets / predictions of every fold (used for the confusion matrix).
    all_targests_test_hist = []
    all_outputs_test_hist = []

    for fold, (train_idx, valid_idx) in enumerate(kf.split(x_train)):
      print(f"Fold {fold+1}, fl = {fl}, t = {t}")
      train_sampler = SubsetRandomSampler(train_idx)
      valid_sampler = SubsetRandomSampler(valid_idx)
      train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=32)
      valid_loader = DataLoader(train_dataset, sampler=valid_sampler, batch_size=32)
      test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

      # Teacher: a fresh CNN trained with the multi-class margin loss.
      model = CNN().to(device)
      loss_fn = nn.MultiMarginLoss()
      lr = 0.00005
      wd = 3e-4
      optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

      best_loss_valid = float('inf')

      for epoch in range(num_epochs):
        # Train
        model, loss_train, acc_train = train_one_epoch(model, train_loader, loss_fn, optimizer, epoch)

        # Validation
        loss_valid, acc_valid, _, _ = validation(model, valid_loader, loss_fn)

        # Checkpoint the whole model whenever the fold-validation loss improves.
        # NOTE(review): the file is named "model_5_fold" although n_splits is 3.
        if loss_valid < best_loss_valid:
            path = '/gdrive/MyDrive/Motor_Imagery'
            torch.save(model, path + '/model_5_fold.pt')
            best_loss_valid = loss_valid
            print('Model Saved!')

        print(f'Valid: Loss = {loss_valid:.4}, Acc = {acc_valid:.4}')
        print()

      # Reload the best teacher checkpoint of this fold and score it on the test split.
      # NOTE(review): this unpickles a full model object; depending on the
      # installed PyTorch version torch.load may need weights_only=False - confirm.
      model = torch.load('/gdrive/MyDrive/Motor_Imagery/model_5_fold.pt')
      final_loss_test, final_acc_test, all_targets_test, all_outputs_test = validation(model, test_loader, loss_fn)
      acc, macro_precision, macro_recall, macro_f1 = cal_metrics(all_targets_test, all_outputs_test)

      all_loss_test_hist.append(final_loss_test)
      all_acc_test_hist.append(final_acc_test)
      all_precision_test_hist.append(macro_precision)
      all_recall_test_hist.append(macro_recall)
      all_f1_test_hist.append(macro_f1)

      #------------------------------------------------------------KD----------------------------------------------------------------------

      # Student: same architecture, trained with loss_fn_kd against the reloaded teacher.
      teacher = torch.load('/gdrive/MyDrive/Motor_Imagery/model_5_fold.pt')
      # teacher.eval()
      student = CNN().to(device)
      lr = 0.00005
      wd = 3e-4
      optimizer = optim.Adam(student.parameters(), lr=lr, weight_decay=wd)
      # The margin loss is used only to score the student in validation();
      # training uses the distillation loss.
      loss_fn = nn.MultiMarginLoss()

      best_loss_valid_s = torch.inf
      epoch_counter = 0


      for epoch in range(num_epochs):
        # Train
        student, loss_train, acc_train = train_one_epoch_kd(student,
                                                            teacher,
                                                            train_loader,
                                                            loss_fn_kd,
                                                            optimizer,
                                                            epoch)
        # Validation
        loss_valid, acc_valid, _, _ = validation(student,
                                                 valid_loader,
                                                 loss_fn)

        # The best validation loss is tracked, but no student checkpoint is saved.
        if loss_valid < best_loss_valid_s:
          # path = '/gdrive/MyDrive/Motor_Imagery'
          # torch.save(model, path + '/model_6_fold.pt')
          best_loss_valid_s = loss_valid
          print('best')

        print(f'Valid: Loss = {loss_valid:.4}, Acc = {acc_valid:.4}')
        print()

        epoch_counter += 1

      # NOTE(review): because no checkpoint is reloaded, the student scored here
      # is the one from the last epoch, not the best-validation one.
      # student = torch.load('/gdrive/MyDrive/Motor_Imagery/model_6_fold.pt')
      # student.eval()
      final_loss_test, final_acc_test, all_targets_test, all_outputs_test = validation(student, test_loader, loss_fn)
      acc, macro_precision, macro_recall, macro_f1 = cal_metrics(all_targets_test, all_outputs_test)

      all_loss_test_hist_s.append(final_loss_test)
      all_acc_test_hist_s.append(final_acc_test)
      all_precision_test_hist_s.append(macro_precision)
      all_recall_test_hist_s.append(macro_recall)
      all_f1_test_hist_s.append(macro_f1)
      all_targests_test_hist.append(all_targets_test)
      all_outputs_test_hist.append(all_outputs_test)

  # --------------------------------------------Save Results----------------------------------------------------------

    # Mean of each metric over the folds: a = loss, b = accuracy, c = precision,
    # d = recall, e = F1; suffix 1 = teacher, 2 = student.
    # NOTE(review): the divisor 3 is hard-coded and must match n_splits above.
    a1 = sum(all_loss_test_hist)/3
    a2 = sum(all_loss_test_hist_s)/3
    b1 = sum(all_acc_test_hist)/3
    b2 = sum(all_acc_test_hist_s)/3
    c1 = sum(all_precision_test_hist)/3
    c2 = sum(all_precision_test_hist_s)/3
    d1 = sum(all_recall_test_hist)/3
    d2 = sum(all_recall_test_hist_s)/3
    e1 = sum(all_f1_test_hist)/3
    e2 = sum(all_f1_test_hist_s)/3


    # One-row summary table. This rebinds `df`, which held the data array until now.
    df = pd.DataFrame([[e1, d1, c1, b1*100, a1, e2, d2, c2, b2*100, a2]],
                      columns=['f1', 'recall', 'precision', 'acc', 'loss', 'f1_s', 'recall_s', 'precision_s', 'acc_s', 'loss_s'])

    # # Path to the Excel file
    # excel_file_path = '/gdrive/MyDrive/Motor_Imagery/resultsyyyyyyyy.xlsx'

    # if os.path.exists(excel_file_path):
    #     # If the file exists, read the existing data
    #     existing_df = pd.read_excel(excel_file_path)

    #     # Append the new data
    #     updated_df = pd.concat([existing_df, df], ignore_index=True)
    # else:
    #     # If the file does not exist, create a new DataFrame
    #     updated_df = df

    # # Write the updated DataFrame back to the Excel file
    # with pd.ExcelWriter(excel_file_path, engine='openpyxl', mode='w') as writer:
    #     updated_df.to_excel(writer, index=False)

In [None]:
len(train_dataset)

In [None]:
df

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
# Each entry of the two history lists is a tensor of labels / predictions from
# one evaluation run. Move every tensor to host memory, convert it to NumPy and
# join the pieces so the confusion matrix covers all recorded runs at once.
target_chunks = []
for run_targets in all_targests_test_hist:
    target_chunks.append(run_targets.cpu().numpy())
all_targets_test_hists = np.concatenate(target_chunks)

output_chunks = []
for run_outputs in all_outputs_test_hist:
    output_chunks.append(run_outputs.cpu().numpy())
all_outputs_test_hists = np.concatenate(output_chunks)

# Rows of `cm` follow the true labels, columns the predicted labels.
cm = confusion_matrix(all_targets_test_hists, all_outputs_test_hists)

In [None]:
# Uncomment to plot row-normalised rates instead of raw counts:
# cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]  # Normalize confusion matrix

# Draw `cm` as an annotated heat map.
# The number of ticks is taken from the matrix itself so the tick labels always
# match its size (a hard-coded 9 breaks if a class is absent from the data).
n_classes = cm.shape[0]
class_names = [str(k + 1) for k in range(n_classes)]  # axes show classes as 1..N

plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(n_classes)
plt.xticks(tick_marks, class_names, rotation=45)
plt.yticks(tick_marks, class_names)

# Raw counts are integers and read best without decimals; the normalised
# (float) matrix keeps two decimals.
fmt = '.2f' if np.issubdtype(cm.dtype, np.floating) else 'd'
thresh = cm.max() / 2.
for i, j in np.ndindex(cm.shape):
   # White text on dark cells, black text on light cells.
   plt.text(j, i, format(cm[i, j], fmt), ha='center', va='center',
            color='white' if cm[i, j] > thresh else 'black')

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
# plt.savefig(save_path, format='png')
# plt.close()

## save valid


In [None]:
from openpyxl import load_workbook
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

In [None]:
def loss_fn_kd(outputs, labels, teacher_outputs, T, alpha):
  """Return the knowledge-distillation loss for one batch.

  The loss is a weighted sum of two terms:
    * a soft-target term: KL divergence between the temperature-softened
      teacher distribution and the temperature-softened student distribution,
      scaled by ``alpha * T**2``;
    * a hard-label term: cross-entropy of the raw student outputs against
      ``labels``, scaled by ``1 - alpha``.

  Args:
    outputs: student scores; softmax is taken over dim 1.
    labels: target class indices passed to ``F.cross_entropy``.
    teacher_outputs: teacher scores, same layout as ``outputs``.
    T: temperature dividing both score tensors before softmax.
    alpha: weight of the soft-target term (``1 - alpha`` goes to the
      hard-label term).

  Returns:
    Scalar tensor with the combined loss.
  """
  student_log_probs = F.log_softmax(outputs/T, dim=1)
  teacher_probs = F.softmax(teacher_outputs/T, dim=1)
  soft_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
  hard_term = F.cross_entropy(outputs, labels)
  return soft_term * (alpha * T**2) + hard_term * (1 - alpha)

In [None]:
# ---- Sweep configuration ----
# Channel indices selected from every recording. Named groups:
#   Frontal  = [0, 1, 2, 3, 4, 5]
#   Central  = [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
#   Parietal = [18, 19, 20, 21]
#   All      = [0, 1, ..., 21]
channels = list(range(22))

# Which recording to load for every subject: left, right, foot, tongue
task = 'tongue'

# Master switch for band-pass filtering.
apply_filter = True

# Window-length settings iterated by the sweep (the sweep handles 4 and 1).
time = [4,1]

# [low, high] band edges; the ['a', 'a'] entry is the "no filtering" sentinel.
band = [
    [0.5, 4],
    [4, 8],
    [8, 13],
    [13, 30],
    [30, 100],
    ['a', 'a'],
]

# Training epochs per model.
num_epochs = 80

In [None]:
# Band x window-length sweep.
# For every frequency band and every window length: load all nine subjects,
# build a dataset whose class label is the subject index, run 5-fold
# cross-validation of a baseline CNN and of a student distilled from it, and
# append the fold-averaged best validation metrics as one spreadsheet row.

# Imported here so this cell does not depend on another cell having brought the
# filter helpers into scope.
from scipy.signal import butter, filtfilt

for fl, fh in band:
  # ['a', 'a'] is the "no band-pass" sentinel. The decision is held in a
  # per-band local rather than written back to the global `apply_filter`;
  # overwriting the global made every re-run of this cell (and any band listed
  # after the sentinel) silently skip filtering.
  use_filter = apply_filter and fl != 'a'
  if fl != 'a':
    fs = 250  # Sampling frequency
    order = 5  # Filter order
    # Create bandpass filter coefficients (band edges normalised by Nyquist)
    nyq = 0.5 * fs
    low = fl / nyq
    high = fh / nyq
    b, a = butter(order, [low, high], btype='band')
  for t in time:
    if t not in (4, 1):
      # Fail early with a clear message instead of an IndexError further down.
      raise ValueError(f'unsupported window setting t={t!r}; expected 4 or 1')
    df = []
    for i in range(1,10):
      mat = loadmat(f'/gdrive/MyDrive/Motor_Imagery/BCI2a/subjects1000/sub{i}/data_{task}_sub{i}.mat')
      trials = mat[f'data_{task}'][:,channels,:]
      if t == 4:
        # Use every trial at its full recorded length.
        data = trials
      else:
        # t == 1: cut the first 1000 samples of every trial into four
        # consecutive 250-sample windows and stack them along the trial axis.
        # data = np.concatenate((data1, data2, data3), axis=0) #for data with 75 sample
        data = np.concatenate([trials[:, :, start:start + 250] for start in range(0, 1000, 250)], axis=0)
      if use_filter:
        data = filtfilt(b, a, data) #frequency filter
      df.append(data)
    df = np.array(df)
    print(df.shape)
    num_trial = df.shape[1]
    num_ch = df.shape[2]
    num_smaple = df.shape[3]
    df = df.reshape((9*num_trial,num_ch,num_smaple))
    # Class label = zero-based subject index, repeated once per trial, in the
    # same subject-major order as the reshaped data.
    label = np.repeat(np.arange(9), num_trial)
    # NOTE(review): for t == 1 the four windows of one trial are split at
    # random, so windows of the same trial can fall on both sides of the
    # split - confirm this is intended.
    x_train, x_valid, y_train, y_valid = train_test_split(df, label, test_size=0.2, random_state=23)
    x_train, x_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.1, random_state=23)

    # float32 inputs with an extra axis inserted at dim 1; int64 labels.
    x_train = torch.FloatTensor(x_train).unsqueeze(1)
    y_train = torch.LongTensor(y_train).squeeze()
    x_valid = torch.FloatTensor(x_valid).unsqueeze(1)
    y_valid = torch.LongTensor(y_valid).squeeze()
    x_test = torch.FloatTensor(x_test).unsqueeze(1)
    y_test = torch.LongTensor(y_test).squeeze()

    # Standardise with statistics of the training split only.
    mu = x_train.mean(dim=0)
    std = x_train.std(dim=0)
    x_train = (x_train - mu) / std
    x_valid = (x_valid - mu) / std
    x_test = (x_test - mu) / std

    # The splits are already tensors of the right dtype; clone().detach() is
    # the copy PyTorch recommends over re-wrapping them with torch.tensor().
    x_train_tensor = x_train.clone().detach()
    y_train_tensor = y_train.clone().detach()
    x_test_tensor = x_test.clone().detach()
    y_test_tensor = y_test.clone().detach()
    train_dataset = TensorDataset(x_train_tensor, y_train_tensor)

    # KFold cross-validation
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    # One list per metric; each fold appends its per-epoch history.
    # Names ending in `_s` belong to the distilled student.
    all_loss_valid_hist = []
    all_acc_valid_hist = []
    all_precision_valid_hist = []
    all_recall_valid_hist = []
    all_f1_valid_hist = []
    all_loss_valid_hist_s = []
    all_acc_valid_hist_s = []
    all_precision_valid_hist_s = []
    all_recall_valid_hist_s = []
    all_f1_valid_hist_s = []

    for fold, (train_idx, valid_idx) in enumerate(kf.split(x_train)):
      print(f"Fold {fold+1}, fl = {fl}, t = {t}")
      train_sampler = SubsetRandomSampler(train_idx)
      valid_sampler = SubsetRandomSampler(valid_idx)
      train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=32)
      valid_loader = DataLoader(train_dataset, sampler=valid_sampler, batch_size=32)

      model = CNN().to(device)
      loss_fn = nn.MultiMarginLoss()
      lr = 0.00005
      wd = 3e-4
      optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

      # Histories for this fold
      loss_valid_hist = []
      acc_valid_hist = []
      precision_valid_hist = []
      recall_valid_hist = []
      f1_valid_hist = []

      best_loss_valid = float('inf')

      for epoch in range(num_epochs):
        # Train
        model, loss_train, acc_train = train_one_epoch(model, train_loader, loss_fn, optimizer, epoch)

        # Validation
        loss_valid, acc_valid, all_targets, all_outputs = validation(model, valid_loader, loss_fn)

        acc, macro_precision, macro_recall, macro_f1 = cal_metrics(all_targets, all_outputs)

        loss_valid_hist.append(loss_valid)
        acc_valid_hist.append(acc_valid)
        precision_valid_hist.append(macro_precision)
        recall_valid_hist.append(macro_recall)
        f1_valid_hist.append(macro_f1)

        # Keep the checkpoint with the lowest validation loss; it becomes the
        # teacher for this fold.
        if loss_valid < best_loss_valid:
          path = '/gdrive/MyDrive/Motor_Imagery'
          torch.save(model, path + '/model_5_fold.pt')
          best_loss_valid = loss_valid
          print('Model Saved!')
          print(f'macro_precision = {macro_precision:.4}, macro_recall = {macro_recall:.4}, macro_f1 = {macro_f1:.4}')

        print(f'Valid: Loss = {loss_valid:.4}, Acc = {acc_valid:.4}')
        print()

      all_loss_valid_hist.append(loss_valid_hist)
      all_acc_valid_hist.append(acc_valid_hist)
      all_precision_valid_hist.append(precision_valid_hist)
      all_recall_valid_hist.append(recall_valid_hist)
      all_f1_valid_hist.append(f1_valid_hist)

      #------------------------------------------------------------KD----------------------------------------------------------------------

      loss_valid_hist = []
      acc_valid_hist = []
      precision_valid_hist = []
      recall_valid_hist = []
      f1_valid_hist = []

      # The checkpoint is a whole pickled module written by this notebook, so
      # it must be loaded with weights_only=False (torch.load refuses
      # full-model pickles when weights_only is True).
      teacher = torch.load('/gdrive/MyDrive/Motor_Imagery/model_5_fold.pt', weights_only=False)
      teacher.eval()
      student = CNN().to(device)
      lr = 0.00005
      wd = 3e-4
      optimizer = optim.Adam(student.parameters(), lr=lr, weight_decay=wd)
      loss_fn = nn.MultiMarginLoss()

      best_loss_valid_s = torch.inf
      epoch_counter = 0

      for epoch in range(num_epochs):
        # Train
        student, loss_train, acc_train = train_one_epoch_kd(student,
                                                            teacher,
                                                            train_loader,
                                                            loss_fn_kd,
                                                            optimizer,
                                                            epoch)
        # Validation
        loss_valid, acc_valid, all_targets, all_outputs = validation(student,
                                                                    valid_loader,
                                                                    loss_fn)

        acc, macro_precision, macro_recall, macro_f1 = cal_metrics(all_targets, all_outputs)

        loss_valid_hist.append(loss_valid)
        acc_valid_hist.append(acc_valid)
        precision_valid_hist.append(macro_precision)
        recall_valid_hist.append(macro_recall)
        f1_valid_hist.append(macro_f1)

        # The student is not checkpointed; only its best loss is tracked.
        if loss_valid < best_loss_valid_s:
          best_loss_valid_s = loss_valid
          print('best')
          print(f'macro_precision = {macro_precision:.4}, macro_recall = {macro_recall:.4}, macro_f1 = {macro_f1:.4}')

        print(f'Valid: Loss = {loss_valid:.4}, Acc = {acc_valid:.4}')
        print()

        epoch_counter += 1

      all_loss_valid_hist_s.append(loss_valid_hist)
      all_acc_valid_hist_s.append(acc_valid_hist)
      all_precision_valid_hist_s.append(precision_valid_hist)
      all_recall_valid_hist_s.append(recall_valid_hist)
      all_f1_valid_hist_s.append(f1_valid_hist)

    # --------------------------------------------Save Results----------------------------------------------------------

    # For every metric take each fold's best epoch (lowest loss, highest
    # score - chosen independently per metric) and average over the folds.
    # 1 = baseline model, 2 = distilled student.
    n_folds = kf.get_n_splits()
    a1 = sum(min(h) for h in all_loss_valid_hist) / n_folds
    a2 = sum(min(h) for h in all_loss_valid_hist_s) / n_folds
    b1 = sum(max(h) for h in all_acc_valid_hist) / n_folds
    b2 = sum(max(h) for h in all_acc_valid_hist_s) / n_folds
    c1 = sum(max(h) for h in all_precision_valid_hist) / n_folds
    c2 = sum(max(h) for h in all_precision_valid_hist_s) / n_folds
    d1 = sum(max(h) for h in all_recall_valid_hist) / n_folds
    d2 = sum(max(h) for h in all_recall_valid_hist_s) / n_folds
    e1 = sum(max(h) for h in all_f1_valid_hist) / n_folds
    e2 = sum(max(h) for h in all_f1_valid_hist_s) / n_folds

    # One summary row; accuracy is stored as a percentage.
    df = pd.DataFrame([[e1, d1, c1, b1*100, a1, e2, d2, c2, b2*100, a2]],
                      columns=['f1', 'recall', 'precision', 'acc', 'loss', 'f1_s', 'recall_s', 'precision_s', 'acc_s', 'loss_s'])

    # Path to the Excel file
    excel_file_path = '/gdrive/MyDrive/Motor_Imagery/results16.xlsx'

    if os.path.exists(excel_file_path):
      # If the file exists, read the existing data
      existing_df = pd.read_excel(excel_file_path)

      # Append the new data
      updated_df = pd.concat([existing_df, df], ignore_index=True)
    else:
      # If the file does not exist, create a new DataFrame
      updated_df = df

    # Write the updated DataFrame back to the Excel file
    with pd.ExcelWriter(excel_file_path, engine='openpyxl', mode='w') as writer:
      updated_df.to_excel(writer, index=False)

# split 5-fold (solo)

## save test

In [None]:
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

# Materialise the train / test splits as tensors: float32 for the inputs,
# int64 for the labels.
x_train_tensor, x_test_tensor = (
    torch.tensor(split, dtype=torch.float32) for split in (x_train, x_test)
)
y_train_tensor, y_test_tensor = (
    torch.tensor(split, dtype=torch.long) for split in (y_train, y_test)
)

# Dataset the cross-validation folds are sampled from (training split only).
train_dataset = TensorDataset(x_train_tensor, y_train_tensor)

In [None]:
def loss_fn_kd(outputs, labels, teacher_outputs, T, alpha):
  """Return the knowledge-distillation loss for one batch.

  Combines a soft-target term (KL divergence between the teacher and student
  distributions, both softened by temperature ``T``, weighted by
  ``alpha * T**2``) with a hard-label term (cross-entropy of the raw student
  outputs against ``labels``, weighted by ``1 - alpha``).

  Args:
    outputs: student scores; softmax is taken over dim 1.
    labels: target class indices passed to ``F.cross_entropy``.
    teacher_outputs: teacher scores, same layout as ``outputs``.
    T: temperature dividing both score tensors before softmax.
    alpha: weight of the soft-target term.

  Returns:
    Scalar tensor with the combined loss.
  """
  student_log_probs = F.log_softmax(outputs/T, dim=1)
  teacher_probs = F.softmax(teacher_outputs/T, dim=1)
  soft_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
  hard_term = F.cross_entropy(outputs, labels)
  return soft_term * (alpha * T**2) + hard_term * (1 - alpha)

In [None]:
# Epochs per model in the cross-validation run below.
num_epochs = 50

# Five shuffled folds with a fixed seed so the split is reproducible.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Per-fold test metrics of the baseline model (one value appended per fold).
(all_loss_test_hist,
 all_acc_test_hist,
 all_precision_test_hist,
 all_recall_test_hist,
 all_f1_test_hist) = ([] for _ in range(5))

# Same metrics for the distilled student (`_s` suffix).
(all_loss_test_hist_s,
 all_acc_test_hist_s,
 all_precision_test_hist_s,
 all_recall_test_hist_s,
 all_f1_test_hist_s) = ([] for _ in range(5))

In [None]:
# Per fold:
#   1. train a baseline CNN, checkpointing it whenever validation loss improves;
#   2. reload the best checkpoint and record its metrics on the test set;
#   3. distil that checkpoint (teacher) into a fresh CNN (student);
#   4. record the student's test-set metrics after its last epoch.
for fold, (train_idx, valid_idx) in enumerate(kf.split(x_train)):
    print(f"Fold {fold+1}")
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=32)
    valid_loader = DataLoader(train_dataset, sampler=valid_sampler, batch_size=32)
    test_loader = DataLoader(TensorDataset(x_test_tensor, y_test_tensor), batch_size=32, shuffle=False)

    # ---------------- baseline ----------------
    model = CNN().to(device)
    loss_fn = nn.MultiMarginLoss()
    lr, wd = 0.00005, 3e-4
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    best_loss_valid = float('inf')

    for epoch in range(num_epochs):
        model, loss_train, acc_train = train_one_epoch(model, train_loader, loss_fn, optimizer, epoch)
        loss_valid, acc_valid, _, _ = validation(model, valid_loader, loss_fn)

        improved = loss_valid < best_loss_valid
        if improved:
            # New best validation loss: overwrite the fold checkpoint.
            path = '/gdrive/MyDrive/Motor_Imagery'
            torch.save(model, path + '/model_5_fold.pt')
            best_loss_valid = loss_valid
            print('Model Saved!')

        print(f'Valid: Loss = {loss_valid:.4}, Acc = {acc_valid:.4}')
        print()

    # Score the best checkpoint (not the last-epoch weights) on the test set.
    model = torch.load('/gdrive/MyDrive/Motor_Imagery/model_5_fold.pt', weights_only=False)
    final_loss_test, final_acc_test, all_targets_test, all_outputs_test = validation(model, test_loader, loss_fn)
    acc, macro_precision, macro_recall, macro_f1 = cal_metrics(all_targets_test, all_outputs_test)

    baseline_records = (
        (all_loss_test_hist, final_loss_test),
        (all_acc_test_hist, final_acc_test),
        (all_precision_test_hist, macro_precision),
        (all_recall_test_hist, macro_recall),
        (all_f1_test_hist, macro_f1),
    )
    for history, value in baseline_records:
        history.append(value)

    # ---------------- knowledge distillation ----------------
    teacher = torch.load('/gdrive/MyDrive/Motor_Imagery/model_5_fold.pt', weights_only=False)
    teacher.eval()
    student = CNN().to(device)
    lr, wd = 0.00005, 3e-4
    optimizer = optim.Adam(student.parameters(), lr=lr, weight_decay=wd)
    loss_fn = nn.MultiMarginLoss()

    best_loss_valid_s = torch.inf
    epoch_counter = 0

    for epoch in range(num_epochs):
        student, loss_train, acc_train = train_one_epoch_kd(
            student, teacher, train_loader, loss_fn_kd, optimizer, epoch)
        loss_valid, acc_valid, _, _ = validation(student, valid_loader, loss_fn)

        # The student is not checkpointed; only its best loss is tracked.
        if loss_valid < best_loss_valid_s:
            best_loss_valid_s = loss_valid
            print('best')

        print(f'Valid: Loss = {loss_valid:.4}, Acc = {acc_valid:.4}')
        print()

        epoch_counter += 1

    final_loss_test, final_acc_test, all_targets_test, all_outputs_test = validation(student, test_loader, loss_fn)
    acc, macro_precision, macro_recall, macro_f1 = cal_metrics(all_targets_test, all_outputs_test)

    student_records = (
        (all_loss_test_hist_s, final_loss_test),
        (all_acc_test_hist_s, final_acc_test),
        (all_precision_test_hist_s, macro_precision),
        (all_recall_test_hist_s, macro_recall),
        (all_f1_test_hist_s, macro_f1),
    )
    for history, value in student_records:
        history.append(value)

# --------------------------------------------Save Results----------------------------------------------------------

# Average every per-fold test metric over the five folds.
# Suffix 1 = baseline model, suffix 2 = distilled student;
# a = loss, b = accuracy, c = precision, d = recall, e = F1.
n_folds = 5
a1, b1, c1, d1, e1 = (
    sum(history) / n_folds
    for history in (all_loss_test_hist,
                    all_acc_test_hist,
                    all_precision_test_hist,
                    all_recall_test_hist,
                    all_f1_test_hist)
)
a2, b2, c2, d2, e2 = (
    sum(history) / n_folds
    for history in (all_loss_test_hist_s,
                    all_acc_test_hist_s,
                    all_precision_test_hist_s,
                    all_recall_test_hist_s,
                    all_f1_test_hist_s)
)

# Single summary row; accuracy is reported as a percentage.
summary_columns = ['f1', 'recall', 'precision', 'acc', 'loss', 'f1_s', 'recall_s', 'precision_s', 'acc_s', 'loss_s']
summary_row = [e1, d1, c1, b1*100, a1, e2, d2, c2, b2*100, a2]
df = pd.DataFrame([summary_row], columns=summary_columns)

# # Path to the Excel file
# excel_file_path = '/gdrive/MyDrive/Motor_Imagery/results5.xlsx'

# if os.path.exists(excel_file_path):
#     # If the file exists, read the existing data
#     existing_df = pd.read_excel(excel_file_path)

#     # Append the new data
#     updated_df = pd.concat([existing_df, df], ignore_index=True)
# else:
#     # If the file does not exist, create a new DataFrame
#     updated_df = df

# # Write the updated DataFrame back to the Excel file
# with pd.ExcelWriter(excel_file_path, engine='openpyxl', mode='w') as writer:
#     updated_df.to_excel(writer, index=False)

In [None]:
# Cell output: display the summary row held in `df`.
df

In [None]:
# Re-evaluate the current `student` on `test_loader`; `validation` returns the
# loss, the accuracy and the collected targets / outputs.
final_loss_test, final_acc_test, all_targets_test, all_outputs_test = validation(student, test_loader, loss_fn)

In [None]:
# Cell output: the most recently computed test accuracy.
final_acc_test

## save valid


In [None]:
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

# Turn the train / test splits into tensors: float32 inputs, int64 labels.
x_train_tensor, x_test_tensor = (
    torch.tensor(split, dtype=torch.float32) for split in (x_train, x_test)
)
y_train_tensor, y_test_tensor = (
    torch.tensor(split, dtype=torch.long) for split in (y_train, y_test)
)

# The cross-validation folds are drawn from this dataset (training split only).
train_dataset = TensorDataset(x_train_tensor, y_train_tensor)

In [None]:
# Epochs per model in the cross-validation run below.
num_epochs = 80

# Five shuffled folds with a fixed seed so the split is reproducible.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Validation histories of the baseline model: every fold appends one list
# holding that fold's per-epoch values.
(all_loss_valid_hist,
 all_acc_valid_hist,
 all_precision_valid_hist,
 all_recall_valid_hist,
 all_f1_valid_hist) = ([] for _ in range(5))

# Same histories for the distilled student (`_s` suffix).
(all_loss_valid_hist_s,
 all_acc_valid_hist_s,
 all_precision_valid_hist_s,
 all_recall_valid_hist_s,
 all_f1_valid_hist_s) = ([] for _ in range(5))

# Per-fold test loss / accuracy.
all_loss_test_hist, all_acc_test_hist = [], []

# Initialize best loss to a large number

In [None]:
def loss_fn_kd(outputs, labels, teacher_outputs, T, alpha):
  """Return the knowledge-distillation loss for one batch.

  Soft-target term: KL divergence between the teacher and student
  distributions after both score tensors are divided by temperature ``T``,
  weighted by ``alpha * T**2``. Hard-label term: cross-entropy of the raw
  student outputs against ``labels``, weighted by ``1 - alpha``.

  Args:
    outputs: student scores; softmax is taken over dim 1.
    labels: target class indices passed to ``F.cross_entropy``.
    teacher_outputs: teacher scores, same layout as ``outputs``.
    T: softmax temperature.
    alpha: weight of the soft-target term.

  Returns:
    Scalar tensor with the combined loss.
  """
  student_log_probs = F.log_softmax(outputs/T, dim=1)
  teacher_probs = F.softmax(teacher_outputs/T, dim=1)
  soft_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
  hard_term = F.cross_entropy(outputs, labels)
  return soft_term * (alpha * T**2) + hard_term * (1 - alpha)

In [None]:
# Per fold:
#   1. train a baseline CNN, logging validation metrics every epoch and
#      checkpointing the model whenever validation loss improves;
#   2. reload the best checkpoint as teacher and distil it into a fresh student,
#      again logging validation metrics every epoch;
#   3. evaluate the final student on the test set.
for fold, (train_idx, valid_idx) in enumerate(kf.split(x_train)):
    print(f"Fold {fold+1}")
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=32)
    valid_loader = DataLoader(train_dataset, sampler=valid_sampler, batch_size=32)

    model = CNN().to(device)
    loss_fn = nn.MultiMarginLoss()
    lr = 0.00005
    wd = 3e-4
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    # Histories for this fold
    loss_valid_hist = []
    acc_valid_hist = []
    precision_valid_hist = []
    recall_valid_hist = []
    f1_valid_hist = []

    best_loss_valid = float('inf')

    for epoch in range(num_epochs):
        # Train
        model, loss_train, acc_train = train_one_epoch(model, train_loader, loss_fn, optimizer, epoch)

        # Validation
        loss_valid, acc_valid, all_targets, all_outputs = validation(model, valid_loader, loss_fn)

        acc, macro_precision, macro_recall, macro_f1 = cal_metrics(all_targets, all_outputs)

        loss_valid_hist.append(loss_valid)
        acc_valid_hist.append(acc_valid)
        precision_valid_hist.append(macro_precision)
        recall_valid_hist.append(macro_recall)
        f1_valid_hist.append(macro_f1)

        # Keep the checkpoint with the lowest validation loss; it becomes the
        # teacher for this fold.
        if loss_valid < best_loss_valid:
            path = '/gdrive/MyDrive/Motor_Imagery'
            torch.save(model, path + '/model_5_fold.pt')
            best_loss_valid = loss_valid
            print('Model Saved!')
            print(f'macro_precision = {macro_precision:.4}, macro_recall = {macro_recall:.4}, macro_f1 = {macro_f1:.4}')

        print(f'Valid: Loss = {loss_valid:.4}, Acc = {acc_valid:.4}')
        print()

    all_loss_valid_hist.append(loss_valid_hist)
    all_acc_valid_hist.append(acc_valid_hist)
    all_precision_valid_hist.append(precision_valid_hist)
    all_recall_valid_hist.append(recall_valid_hist)
    all_f1_valid_hist.append(f1_valid_hist)

    #------------------------------------------------------------KD----------------------------------------------------------------------

    loss_valid_hist = []
    acc_valid_hist = []
    precision_valid_hist = []
    recall_valid_hist = []
    f1_valid_hist = []

    # The checkpoint is a whole pickled module written by this notebook, so it
    # must be loaded with weights_only=False (torch.load refuses full-model
    # pickles when weights_only is True).
    teacher = torch.load('/gdrive/MyDrive/Motor_Imagery/model_5_fold.pt', weights_only=False)
    teacher.eval()
    student = CNN().to(device)
    lr = 0.00005
    wd = 3e-4
    optimizer = optim.Adam(student.parameters(), lr=lr, weight_decay=wd)
    loss_fn = nn.MultiMarginLoss()

    best_loss_valid_s = torch.inf
    epoch_counter = 0

    for epoch in range(num_epochs):
        # Train
        student, loss_train, acc_train = train_one_epoch_kd(student,
                                                            teacher,
                                                            train_loader,
                                                            loss_fn_kd,
                                                            optimizer,
                                                            epoch)
        # Validation
        loss_valid, acc_valid, all_targets, all_outputs = validation(student,
                                                                    valid_loader,
                                                                    loss_fn)

        acc, macro_precision, macro_recall, macro_f1 = cal_metrics(all_targets, all_outputs)

        loss_valid_hist.append(loss_valid)
        acc_valid_hist.append(acc_valid)
        precision_valid_hist.append(macro_precision)
        recall_valid_hist.append(macro_recall)
        f1_valid_hist.append(macro_f1)

        # The student is not checkpointed; only its best loss is tracked.
        if loss_valid < best_loss_valid_s:
            best_loss_valid_s = loss_valid
            print('best')
            print(f'macro_precision = {macro_precision:.4}, macro_recall = {macro_recall:.4}, macro_f1 = {macro_f1:.4}')

        print(f'Valid: Loss = {loss_valid:.4}, Acc = {acc_valid:.4}')
        print()

        epoch_counter += 1

    all_loss_valid_hist_s.append(loss_valid_hist)
    all_acc_valid_hist_s.append(acc_valid_hist)
    all_precision_valid_hist_s.append(precision_valid_hist)
    all_recall_valid_hist_s.append(recall_valid_hist)
    all_f1_valid_hist_s.append(f1_valid_hist)

    # Test-set score of the student as it stands after its last epoch.
    test_loader = DataLoader(TensorDataset(x_test_tensor, y_test_tensor), batch_size=32, shuffle=False)
    final_loss_test, final_acc_test, all_targets_test, all_outputs_test = validation(student, test_loader, loss_fn)
    all_loss_test_hist.append(final_loss_test)
    all_acc_test_hist.append(final_acc_test)

# --------------------------------------------Save Results----------------------------------------------------------

# Fold-averaged best validation metrics.
# For each metric the best epoch of every fold is taken (lowest loss, highest
# score - chosen independently per metric) and the five values are averaged.
# Suffix 1 = baseline model, suffix 2 = distilled student (`_s` histories);
# a = loss, b = accuracy, c = precision, d = recall, e = F1.
n_folds = 5
a1 = sum(min(all_loss_valid_hist[i]) for i in range(n_folds)) / n_folds
a2 = sum(min(all_loss_valid_hist_s[i]) for i in range(n_folds)) / n_folds

b1 = sum(max(all_acc_valid_hist[i]) for i in range(n_folds)) / n_folds
b2 = sum(max(all_acc_valid_hist_s[i]) for i in range(n_folds)) / n_folds

c1 = sum(max(all_precision_valid_hist[i]) for i in range(n_folds)) / n_folds
c2 = sum(max(all_precision_valid_hist_s[i]) for i in range(n_folds)) / n_folds

d1 = sum(max(all_recall_valid_hist[i]) for i in range(n_folds)) / n_folds
d2 = sum(max(all_recall_valid_hist_s[i]) for i in range(n_folds)) / n_folds

e1 = sum(max(all_f1_valid_hist[i]) for i in range(n_folds)) / n_folds
e2 = sum(max(all_f1_valid_hist_s[i]) for i in range(n_folds)) / n_folds


# The first five values come from the baseline histories and the last five
# from the student (`_s`) histories, so the unsuffixed column names go first.
# (The labels were previously the other way round, attributing every metric to
# the wrong model.) Accuracy is reported as a percentage.
df = pd.DataFrame([[e1, d1, c1, b1*100, a1, e2, d2, c2, b2*100, a2]],
                  columns=['f1', 'recall', 'precision', 'acc', 'loss', 'f1_s', 'recall_s', 'precision_s', 'acc_s', 'loss_s'])

# # Path to the Excel file
# excel_file_path = '/gdrive/MyDrive/Motor_Imagery/results5.xlsx'

# if os.path.exists(excel_file_path):
#     # If the file exists, read the existing data
#     existing_df = pd.read_excel(excel_file_path)

#     # Append the new data
#     updated_df = pd.concat([existing_df, df], ignore_index=True)
# else:
#     # If the file does not exist, create a new DataFrame
#     updated_df = df

# # Write the updated DataFrame back to the Excel file
# with pd.ExcelWriter(excel_file_path, engine='openpyxl', mode='w') as writer:
#     updated_df.to_excel(writer, index=False)

In [None]:
# Cell output: display the summary row held in `df`.
df

In [None]:
# b1 / b2: averaged accuracy values computed in the aggregation cell.
print(b1)
print(b2)
# Mean of the recorded test accuracies (divisor fixed at 5).
print(sum(all_acc_test_hist)/5)

In [None]:
from openpyxl import load_workbook
# Fold-averaged best validation metrics, appended as one row to the results
# spreadsheet.
# For each metric the best epoch of every fold is taken (lowest loss, highest
# score - chosen independently per metric) and the five values are averaged.
# Suffix 1 = baseline model, suffix 2 = distilled student (`_s` histories);
# a = loss, b = accuracy, c = precision, d = recall, e = F1.
n_folds = 5
a1 = sum(min(all_loss_valid_hist[i]) for i in range(n_folds)) / n_folds
a2 = sum(min(all_loss_valid_hist_s[i]) for i in range(n_folds)) / n_folds

b1 = sum(max(all_acc_valid_hist[i]) for i in range(n_folds)) / n_folds
b2 = sum(max(all_acc_valid_hist_s[i]) for i in range(n_folds)) / n_folds

c1 = sum(max(all_precision_valid_hist[i]) for i in range(n_folds)) / n_folds
c2 = sum(max(all_precision_valid_hist_s[i]) for i in range(n_folds)) / n_folds

d1 = sum(max(all_recall_valid_hist[i]) for i in range(n_folds)) / n_folds
d2 = sum(max(all_recall_valid_hist_s[i]) for i in range(n_folds)) / n_folds

e1 = sum(max(all_f1_valid_hist[i]) for i in range(n_folds)) / n_folds
e2 = sum(max(all_f1_valid_hist_s[i]) for i in range(n_folds)) / n_folds


# The first five values come from the baseline histories and the last five
# from the student (`_s`) histories, so the unsuffixed column names go first.
# (The labels were previously the other way round.)
# NOTE(review): rows already stored in the spreadsheet by the earlier version
# of this cell carry the swapped labels - confirm and correct them by hand.
df = pd.DataFrame([[e1, d1, c1, b1*100, a1, e2, d2, c2, b2*100, a2]],
                  columns=['f1', 'recall', 'precision', 'acc', 'loss', 'f1_s', 'recall_s', 'precision_s', 'acc_s', 'loss_s'])

# Path to the Excel file
excel_file_path = '/gdrive/MyDrive/Motor_Imagery/results5.xlsx'

if os.path.exists(excel_file_path):
    # If the file exists, read the existing data
    existing_df = pd.read_excel(excel_file_path)

    # Append the new data
    updated_df = pd.concat([existing_df, df], ignore_index=True)
else:
    # If the file does not exist, create a new DataFrame
    updated_df = df

# Write the updated DataFrame back to the Excel file (the whole file is rewritten)
with pd.ExcelWriter(excel_file_path, engine='openpyxl', mode='w') as writer:
    updated_df.to_excel(writer, index=False)
print('---------------------------------------------------------------------------------------------------------')

In [None]:
# Evaluate the current student on the held-out test split, in fixed order.
test_dataset = TensorDataset(x_test_tensor, y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
final_loss_test, final_acc_test, all_targets_test, all_outputs_test = validation(
    student, test_loader, loss_fn)
print(f'Test: Loss = {final_loss_test:.4}, Acc = {final_acc_test:.4}')

# Train

In [None]:
# Training criterion: multi-class margin loss.
# Alternative kept for reference: loss_fn = nn.CrossEntropyLoss()
loss_fn = nn.MultiMarginLoss()

# Fresh network placed on the configured device.
model = CNN().to(device)

In [None]:
# Adam hyper-parameters: learning rate and weight decay.
lr, wd = 0.00005, 3e-4
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

In [None]:
# Per-epoch loss curves (train / validation).
loss_train_hist, loss_valid_hist = [], []

# Per-epoch accuracy curves.
acc_train_hist, acc_valid_hist, acc3_valid_hist = [], [], []

# Lowest validation loss seen so far; starts at +inf so the first epoch wins.
best_loss_valid = torch.inf
# Number of epochs completed.
epoch_counter = 0

In [None]:
# Train the baseline model, logging the loss / accuracy curves and saving a
# checkpoint every time the validation loss reaches a new minimum.
num_epochs = 70

for epoch in range(num_epochs):
  # One optimisation pass over the training loader.
  model, loss_train, acc_train = train_one_epoch(
      model, train_loader, loss_fn, optimizer, epoch)

  # Score the updated model on the validation loader.
  loss_valid, acc_valid, all_targets, all_outputs = validation(
      model, valid_loader, loss_fn)

  loss_train_hist.append(loss_train)
  loss_valid_hist.append(loss_valid)

  acc_train_hist.append(acc_train)
  acc_valid_hist.append(acc_valid)

  improved = loss_valid < best_loss_valid
  if improved:
    # New best: persist the whole model and report its macro metrics.
    path = '/gdrive/MyDrive/Motor_Imagery'
    torch.save(model, path + '/model' + '.pt')
    best_loss_valid = loss_valid
    acc, macro_precision, macro_recall, macro_f1 = cal_metrics(all_targets, all_outputs)
    print('Model Saved!')
    print(f'macro_precision = {macro_precision:.4}, macro_recall = {macro_recall:.4}, macro_f1 = {macro_f1:.4}')

  print(f'Valid: Loss = {loss_valid:.4}, Acc = {acc_valid:.4}')
  print()

  epoch_counter += 1

# Test-set score of the model as it stands after the last epoch
# (not the saved best checkpoint).
loss_test, acc_test, all_targets, all_outputs = validation(
    model, test_loader, loss_fn)
print(f'Test: Loss = {loss_test:.4}, Acc = {acc_test:.4}')


In [None]:
# Reload the best checkpoint and report its validation metrics.
# The file holds a whole pickled module written by this notebook, so it must be
# loaded with weights_only=False (torch.load refuses full-model pickles when
# weights_only is True).
model = torch.load('/gdrive/MyDrive/Motor_Imagery/model.pt', weights_only=False)
# Evaluated on `valid_loader`; the `*_test` variable names are kept as they are.
loss_test, acc_test, all_targets, all_outputs = validation(model,
                                           valid_loader,
                                           loss_fn)

acc, macro_precision, macro_recall, macro_f1 = cal_metrics(all_targets, all_outputs)

print(f'Valid: Loss = {loss_test:.4}, Acc = {acc_test:.4}')
print(f'macro_precision = {macro_precision:.4}, macro_recall = {macro_recall:.4}, macro_f1 = {macro_f1:.4}')

In [None]:
# Collect predicted and true class indices over the validation loader.
all_preds = []
all_labels = []

# Put the model in inference mode explicitly, so the predictions do not depend
# on which mode an earlier cell left it in.
model.eval()
with torch.no_grad():
    for inputs, labels in valid_loader:
        outputs = model(inputs.to(device))
        # argmax over the raw scores; softmax is monotonic, so applying it
        # first cannot change which class scores highest.
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

In [None]:
# Confusion matrix of the validation predictions (rows: true, columns: predicted).
conf_matrix = confusion_matrix(all_labels, all_preds)

# Render it as an annotated heat map with integer counts in every cell.
class_ticks = range(9)
plt.figure(figsize=(8, 6))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_ticks,
    yticklabels=class_ticks,
)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()

# Knowledge distillation loss

In [None]:
def loss_fn_kd(outputs, labels, teacher_outputs, T, alpha):
  """Knowledge-distillation loss: soft-target KL term plus hard-label cross-entropy.

  Args:
    outputs: student logits; the softmax is taken over dim 1.
    labels: target class indices handed to `F.cross_entropy`.
    teacher_outputs: teacher logits, softened with the same temperature.
    T: temperature that divides both sets of logits before the softmax.
    alpha: weight of the soft (KL) term; the hard term is weighted `1 - alpha`.

  Returns:
    Scalar tensor `KL * (alpha * T**2) + CE * (1 - alpha)`.
  """
  student_log_probs = F.log_softmax(outputs / T, dim=1)
  teacher_probs = F.softmax(teacher_outputs / T, dim=1)

  # 'batchmean' sums the divergence over classes and divides by the batch size.
  soft_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
  hard_term = F.cross_entropy(outputs, labels)

  return soft_term * (alpha * T**2) + hard_term * (1 - alpha)

In [None]:
# The teacher is the checkpoint stored at model.pt; eval() switches it to
# inference mode before it is used for distillation.
teacher = torch.load('/gdrive/MyDrive/Motor_Imagery/model.pt')
teacher.eval()

In [None]:
# Fresh `CNN` instance, moved to `device`, to be trained as the student.
student = CNN().to(device)

In [None]:
# Adam over the student's parameters only (the teacher is not passed in).
lr = 0.00005
wd = 3e-4
optimizer = optim.Adam(student.parameters(), lr=lr, weight_decay=wd)
# loss_fn = nn.CrossEntropyLoss()
# Multi-class margin loss; this is what `validation` receives when the
# student is scored (training itself is given `loss_fn_kd`).
loss_fn = nn.MultiMarginLoss()

In [None]:
# Per-epoch histories for the student run (these replace any earlier lists
# bound to the same names).
loss_train_hist = []
loss_valid_hist = []

acc_train_hist = []
acc_valid_hist = []

# Lowest validation loss seen so far for the student, and the number of
# epochs completed (used as the x-range when plotting).
best_loss_valid_s = torch.inf
epoch_counter = 0

In [None]:
# Distillation training: each epoch trains the student against the teacher
# with `loss_fn_kd`, then scores it on the validation loader with `loss_fn`.
num_epochs = 80

for epoch in range(num_epochs):
  # Train
  student, loss_train, acc_train = train_one_epoch_kd(student,
                                                      teacher,
                                                      train_loader,
                                                      loss_fn_kd,
                                                      optimizer,
                                                      epoch)
  # Validation
  loss_valid, acc_valid, all_targets, all_outputs = validation(student,
                                                              valid_loader,
                                                              loss_fn)

  loss_train_hist.append(loss_train)
  loss_valid_hist.append(loss_valid)

  acc_train_hist.append(acc_train)
  acc_valid_hist.append(acc_valid)

  # Track the best validation loss and remember that epoch's metrics.
  # NOTE(review): no checkpoint is written here (the save is commented out),
  # so `student` after the loop is the last-epoch model, not necessarily the
  # one these "best" metrics belong to.
  if loss_valid < best_loss_valid_s:
    # path = '/gdrive/MyDrive/Motor_Imagery'
    # torch.save(model, path + '/model' + '.pt')
    best_loss_valid_s = loss_valid
    acc_s, macro_precision_s, macro_recall_s, macro_f1_s = cal_metrics(all_targets, all_outputs)
    print('best')
    print(f'macro_precision = {macro_precision_s:.4}, macro_recall = {macro_recall_s:.4}, macro_f1 = {macro_f1_s:.4}')

  print(f'Valid: Loss = {loss_valid:.4}, Acc = {acc_valid:.4}')
  print()

  epoch_counter += 1


# result

In [None]:
# First line: whatever was last stored in `acc`, `best_loss_valid` and the
# un-suffixed macro_* names. Second line: the `_s` values recorded during the
# student run. Both lines use the same text labels, so they are told apart
# only by their order.
print(f'best_acc = {acc:.4}, best_loss = {best_loss_valid:.4}, best_precision = {macro_precision:.4}, best_recall = {macro_recall:.4}, best_f1 = {macro_f1:.4}')
print(f'best_acc = {acc_s:.4}, best_loss = {best_loss_valid_s:.4}, best_precision = {macro_precision_s:.4}, best_recall = {macro_recall_s:.4}, best_f1 = {macro_f1_s:.4}')

# Plot

In [None]:
# Loss per epoch: training curve in red, validation curve in blue.
for _curve, _fmt, _name in ((loss_train_hist, 'r-', 'Train'),
                            (loss_valid_hist, 'b-', 'Validation')):
  plt.plot(range(epoch_counter), _curve, _fmt, label=_name)

plt.xlabel('Epoch')
plt.ylabel('loss')
plt.grid(True)
plt.legend()

In [None]:
# Accuracy per epoch: training curve in red, validation curve in blue.
for _curve, _fmt, _name in ((acc_train_hist, 'r-', 'Train'),
                            (acc_valid_hist, 'b-', 'Validation')):
  plt.plot(range(epoch_counter), _curve, _fmt, label=_name)

plt.xlabel('Epoch')
plt.ylabel('Acc')
plt.grid(True)
plt.legend()

In [None]:
# Two stacked panels sharing the same layout: loss on top, accuracy below.
_panels = (
    (1, loss_train_hist, loss_valid_hist, 'Loss'),
    (2, acc_train_hist, acc_valid_hist, 'Accuracy'),
)

for _position, _train_curve, _valid_curve, _ylabel in _panels:
  plt.subplot(2, 1, _position)

  plt.plot(range(epoch_counter), _train_curve, 'r-', label='Train')
  plt.plot(range(epoch_counter), _valid_curve, 'b-', label='Validation')

  plt.xlabel('Epoch')
  plt.ylabel(_ylabel)
  plt.grid(True)
  plt.legend()

# Keep axis labels of the two panels from overlapping, then render.
plt.tight_layout()
plt.show()

# **Dataset_val 🗂️**

## Load dataset

In [None]:
from scipy.signal import butter, filtfilt

# Band-pass design parameters.
fs, order = 250, 5   # sampling frequency, filter order
f1, f2 = 13, 30      # lower / upper cut-off frequency

# apply_filter = True
# `butter` is given the cut-offs as fractions of the Nyquist frequency (fs / 2).
nyq = 0.5 * fs
low, high = f1 / nyq, f2 / nyq
b, a = butter(order, [low, high], btype='band')

In [None]:
# Build the evaluation set: for each of the 9 subjects, load the trials of
# `task`, keep the selected `channels`, cut every trial into windows that match
# `duration`, optionally band-pass filter, and label every window with the
# subject id (1..9).

# Sample ranges of the windows for each supported `duration`. (None, None)
# keeps the full recording; the shorter settings stack their windows along
# the trial axis.
_window_bounds = {
    4: [(None, None)],
    2: [(None, 500), (500, 1000)],
    1: [(None, 250), (250, 500), (500, 750), (750, 1000)],
}
if duration not in _window_bounds:
  # Previously an unsupported value fell through every branch and only
  # surfaced later as an IndexError on an empty array.
  raise ValueError(f'Unsupported duration {duration!r}; expected one of {sorted(_window_bounds)}')

df = []
for i in range(1, 10):
  mat = loadmat(f'/gdrive/MyDrive/Motor_Imagery/BCI2a/subjects1000_val/sub{i}/data_{task}_sub{i}.mat')
  trials = mat[f'data_{task}']
  data = np.concatenate(
      [trials[:, channels, start:stop] for start, stop in _window_bounds[duration]],
      axis=0)
  if apply_filter:
    data = filtfilt(b, a, data)  # zero-phase band-pass along the last axis
  df.append(data)

# np.stack requires every subject to contribute the same number of windows and
# fails with an explicit shape error otherwise.
df = np.stack(df)
print(df.shape)
num_trial = df.shape[1]
num_ch = df.shape[2]
num_smaple = df.shape[3]
# Merge the subject and trial axes; windows of one subject stay contiguous.
df = df.reshape((9*num_trial, num_ch, num_smaple))
# One label per window: subject id repeated `num_trial` times, in subject order.
label = np.repeat(np.arange(1, 10), num_trial)

In [None]:
# df = []
# for i in range(1,10):
#   data = loadmat(f'/gdrive/MyDrive/Motor_Imagery/BCI2a/subjects/sub{i}/data_{task}_sub{i}.mat')
#   data = data[f'data_{task}'][:,channels,:250]
#   # data = filtfilt(b, a, data) #frequency filter
#   label = [i for i in range(1, 10) for _ in range(72)]
#   label = np.array(label).reshape((9, 72))
#   df.append(data)
# df = np.array(df)
# print(df.shape)
# num_ch = df.shape[2]
# num_smaple = df.shape[3]
# df = df.reshape((9*72,num_ch,num_smaple))
# label = np.array(label)
# label = label.reshape((9*72,))

In [None]:
# Sanity check: `df` and `label` should have the same length along axis 0.
print(df.shape)
print(label.shape)

In [None]:
# Shift the subject ids from 1..9 to 0..8.
label = label -1

In [None]:
# Seeded shuffle-split: the 90% partition becomes the evaluation set and the
# 10% partition is thrown away (bound to `_`).
# NOTE(review): confirm that dropping 10% of this data is intended.
x_test, _, y_test, _ = train_test_split(df, label, test_size=0.1, random_state=23)
# x_train, x_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.1, random_state=23)

In [None]:
# Display the shape of the evaluation inputs.
x_test.shape

In [None]:
# Display the shape of the evaluation labels.
y_test.shape

In [None]:
# Convert to tensors: inputs gain a singleton axis at position 1, labels are
# flattened of any singleton axes.
x_test = torch.FloatTensor(x_test).unsqueeze(1)
y_test = torch.LongTensor(y_test).squeeze()

# Standardise with statistics taken over the first axis of `x_train`.
# NOTE(review): if `x_train` was already standardised in an earlier cell, these
# statistics are ~0 / ~1 and the evaluation data is left effectively unscaled
# -- confirm which `x_train` is in scope here.
mu, std = x_train.mean(dim=0), x_train.std(dim=0)

x_test = (x_test - mu) / std

In [None]:
# List the distinct label values present in the evaluation set.
torch.unique(y_test)

## TensorDataset

In [None]:
# Pair each evaluation input with its label.
test_dataset = TensorDataset(x_test, y_test)

## DataLoader

In [None]:
# Batches of 128 in fixed order (no shuffling).
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [None]:
# Pull one batch and print its shapes.
# NOTE(review): this reads from `train_loader`, not the `test_loader` created
# in the previous cell -- confirm that is intended.
x, y = next(iter(train_loader))
print(x.shape)
print(y.shape)

In [None]:
# Score the in-memory `student` on `test_loader` and compute macro metrics.
# The printed line is labelled 'Valid' although the loader is `test_loader`.
# model = torch.load('/gdrive/MyDrive/Motor_Imagery/model.pt')
loss_test, acc_test, all_targets, all_outputs = validation(student,
                                           test_loader,
                                           loss_fn)

acc, macro_precision, macro_recall, macro_f1 = cal_metrics(all_targets, all_outputs)
print(f'Valid: Loss = {loss_test:.4}, Acc = {acc_test:.4}')
print(f'macro_precision = {macro_precision:.4}, macro_recall = {macro_recall:.4}, macro_f1 = {macro_f1:.4}')

In [None]:
# Run `model` on the batch `x` and display the entry at index 1 of its output.
model(x.to(device))[1]

In [None]:
# Label at index 1 of the same batch, for comparison with the output above.
y[1]

# test

In [None]:
# Configuration for the sweep cell below.
# `channels`: indices selected along the channel axis of every recording.
channels = [0, 1, 2, 3, 4, 5] # Frontal = [0, 1, 2, 3, 4, 5], Central = [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], Parietal = [18, 19, 20, 21],
# All = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
# chan = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], [18, 19, 20, 21], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]]
# `task`: used to build both the .mat file name and the key read from it.
task = 'tongue' # left, right, foot, tongue
apply_filter = True
# Window settings iterated by the sweep; it has a branch for 4 and one for 1.
time = [4,1]
# [low, high] cut-off pairs for the band-pass; the ['a', 'a'] entry makes the
# sweep switch filtering off instead of designing a filter.
band = [[0.5, 4], [4, 8], [8, 13], [13, 30], [30, 100], ['a', 'a']]
from openpyxl import load_workbook

In [None]:
def _append_results_row(excel_path, values):
  """Append one row of metrics to the results workbook, creating it if needed.

  Values are converted to float and rounded to four decimals, which matches
  the "%.4f" format used when the file is first created. The append goes
  through openpyxl directly rather than through pandas writer attributes that
  differ between pandas versions.
  """
  row = [round(float(v), 4) for v in values]
  try:
    workbook = load_workbook(excel_path)
  except FileNotFoundError:
    # First run: no workbook yet, so let pandas create it with this single row.
    pd.DataFrame([row]).to_excel(excel_path, header=False, index=False, float_format="%.4f")
    return
  # Worksheet.append writes directly below the last used row of the sheet.
  workbook.active.append(row)
  workbook.save(excel_path)


# Sample ranges of the windows for each supported value of `t`. (None, None)
# keeps the full recording; shorter settings stack their windows along the
# trial axis.
_window_bounds = {
    4: [(None, None)],
    2: [(None, 500), (500, 1000)],
    1: [(None, 250), (250, 500), (500, 750), (750, 1000)],
}

# Sweep: for every frequency band and every window setting, rebuild the
# dataset, train a teacher, distil a student from it, and append both sets of
# best-epoch validation metrics as one row of the results workbook.
for fl, fh in band:
  # The ['a', 'a'] entry means "no filtering". The decision is made per band
  # and never written back to `apply_filter`, so the order of `band` no longer
  # affects which runs are filtered.
  unfiltered_band = fl == 'a'
  if not unfiltered_band:
    fs = 250  # Sampling frequency
    order = 5  # Filter order
    # Create bandpass filter coefficients
    nyq = 0.5 * fs
    low = fl / nyq
    high = fh / nyq
    b, a = butter(order, [low, high], btype='band')
  use_filter = apply_filter and not unfiltered_band

  for t in time:
    if t not in _window_bounds:
      # Previously an unsupported value fell through every branch and only
      # surfaced later as an IndexError on an empty array.
      raise ValueError(f'Unsupported window setting {t!r}; expected one of {sorted(_window_bounds)}')

    # ---- dataset ----
    df = []
    for i in range(1, 10):
      mat = loadmat(f'/gdrive/MyDrive/Motor_Imagery/BCI2a/subjects1000/sub{i}/data_{task}_sub{i}.mat')
      trials = mat[f'data_{task}']
      data = np.concatenate(
          [trials[:, channels, start:stop] for start, stop in _window_bounds[t]],
          axis=0)
      if use_filter:
        data = filtfilt(b, a, data)  # zero-phase band-pass along the last axis
      df.append(data)
    # np.stack requires every subject to contribute the same number of windows.
    df = np.stack(df)
    print(df.shape)
    num_trial = df.shape[1]
    num_ch = df.shape[2]
    num_smaple = df.shape[3]
    df = df.reshape((9*num_trial, num_ch, num_smaple))
    # Subject index 0..8, repeated once per window, in subject order.
    label = np.repeat(np.arange(9), num_trial)

    # 80/20 train/validation split, then 10% of the training part becomes test.
    x_train, x_valid, y_train, y_valid = train_test_split(df, label, test_size=0.2, random_state=23)
    x_train, x_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.1, random_state=23)

    x_train = torch.FloatTensor(x_train).unsqueeze(1)
    y_train = torch.LongTensor(y_train).squeeze()
    x_valid = torch.FloatTensor(x_valid).unsqueeze(1)
    y_valid = torch.LongTensor(y_valid).squeeze()
    x_test = torch.FloatTensor(x_test).unsqueeze(1)
    y_test = torch.LongTensor(y_test).squeeze()

    # Standardise every split with statistics of the training split only.
    mu = x_train.mean(dim=0)
    std = x_train.std(dim=0)
    x_train = (x_train - mu) / std
    x_valid = (x_valid - mu) / std
    x_test = (x_test - mu) / std

    train_dataset = TensorDataset(x_train, y_train)
    valid_dataset = TensorDataset(x_valid, y_valid)
    test_dataset = TensorDataset(x_test, y_test)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=130, shuffle=True)

    # ---- teacher ----
    model = CNN().to(device)
    loss_fn = nn.MultiMarginLoss()
    lr = 0.00005
    wd = 3e-4
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    loss_train_hist = []
    loss_valid_hist = []

    acc_train_hist = []
    acc_valid_hist = []
    acc3_valid_hist = []

    best_loss_valid = torch.inf
    epoch_counter = 0

    num_epochs = 80
    for epoch in range(num_epochs):
      # Train
      model, loss_train, acc_train = train_one_epoch(model,
                                                    train_loader,
                                                    loss_fn,
                                                    optimizer,
                                                    epoch)
      # Validation
      loss_valid, acc_valid, all_targets, all_outputs = validation(model,
                                                                  valid_loader,
                                                                  loss_fn)

      loss_train_hist.append(loss_train)
      loss_valid_hist.append(loss_valid)

      acc_train_hist.append(acc_train)
      acc_valid_hist.append(acc_valid)

      # Checkpoint on every improvement of the validation loss and remember
      # that epoch's metrics for the results row.
      if loss_valid < best_loss_valid:
        path = '/gdrive/MyDrive/Motor_Imagery'
        torch.save(model, path + '/model' + '.pt')
        best_loss_valid = loss_valid
        acc, macro_precision, macro_recall, macro_f1 = cal_metrics(all_targets, all_outputs)
        print('Model Saved!')
        print(f'macro_precision = {macro_precision:.4}, macro_recall = {macro_recall:.4}, macro_f1 = {macro_f1:.4}')

      print(f'Valid: Loss = {loss_valid:.4}, Acc = {acc_valid:.4}')
      print()

      epoch_counter += 1

    # ---- student (knowledge distillation) ----
    # The teacher is the best checkpoint written by the loop above.
    # NOTE(review): if no epoch ever improved the loss (e.g. NaN losses), this
    # loads a checkpoint left over from an earlier run -- confirm acceptable.
    teacher = torch.load('/gdrive/MyDrive/Motor_Imagery/model.pt')
    teacher.eval()
    student = CNN().to(device)
    lr = 0.00005
    wd = 3e-4
    optimizer = optim.Adam(student.parameters(), lr=lr, weight_decay=wd)
    loss_fn = nn.MultiMarginLoss()
    loss_train_hist = []
    loss_valid_hist = []

    acc_train_hist = []
    acc_valid_hist = []

    best_loss_valid_s = torch.inf
    epoch_counter = 0

    num_epochs = 80
    for epoch in range(num_epochs):
      # Train
      student, loss_train, acc_train = train_one_epoch_kd(student,
                                                          teacher,
                                                          train_loader,
                                                          loss_fn_kd,
                                                          optimizer,
                                                          epoch)
      # Validation
      loss_valid, acc_valid, all_targets, all_outputs = validation(student,
                                                                  valid_loader,
                                                                  loss_fn)

      loss_train_hist.append(loss_train)
      loss_valid_hist.append(loss_valid)

      acc_train_hist.append(acc_train)
      acc_valid_hist.append(acc_valid)

      # The student is not checkpointed; only its best-epoch metrics are kept.
      if loss_valid < best_loss_valid_s:
        best_loss_valid_s = loss_valid
        acc_s, macro_precision_s, macro_recall_s, macro_f1_s = cal_metrics(all_targets, all_outputs)
        print('best')
        print(f'macro_precision = {macro_precision_s:.4}, macro_recall = {macro_recall_s:.4}, macro_f1 = {macro_f1_s:.4}')

      print(f'Valid: Loss = {loss_valid:.4}, Acc = {acc_valid:.4}')
      print()

      epoch_counter += 1

    # ---- results ----
    # Column order: teacher f1, recall, precision, accuracy (%), loss; then
    # the same five values for the student.
    excel_file_path = '/gdrive/MyDrive/Motor_Imagery/results5.xlsx'
    _append_results_row(
        excel_file_path,
        [macro_f1, macro_recall, macro_precision, acc*100, best_loss_valid,
         macro_f1_s, macro_recall_s, macro_precision_s, acc_s*100, best_loss_valid_s])
    print('---------------------------------------------------------------------------------------------------------')