# Установка нужных библиотек

# Импорт библиотек

In [1]:
import os
import json
import glob
import torch 
import sys
sys.path.append("..")

import matplotlib.pyplot as plt
from src.mylib.models.models import Model
from src.mylib.train import Trainer

# Absolute path of the current working directory. Despite its name, `file`
# holds a directory: the relative path being resolved is the current directory
# itself, so the result is the process's cwd.
file = os.path.abspath(os.curdir)

2024-04-11 02:19:33.672551: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Работа с данными

Начальные параметры

In [2]:
# Fix torch's global RNG so that a fresh "Restart Kernel -> Run All" reproduces
# the same weight initialisation and any other torch-driven randomness.
# Only torch's generator is seeded here; RNGs from other libraries that the
# training code may use are not covered.
seed = 42
torch.manual_seed(seed)

batch_size = 64

# Window length: specified in seconds, converted to samples via sample_rate.
window_length_seconds = 5
sample_rate = 64
window_length = window_length_seconds * sample_rate

# Hop between two consecutive windows: specified in seconds, converted to samples.
hop_length_seconds = 1
hop_length = sample_rate * hop_length_seconds

# Number of mismatched (false) stimuli
number_of_mismatch = 4

In [3]:
# Parent of the directory stored in `file`; the config path below is resolved
# relative to it.
experiment_folder = os.path.dirname(file)

# Load the config file. JSON is read as UTF-8 explicitly so the result does not
# depend on the platform's default locale encoding.
config_path = os.path.join(experiment_folder, "src/mylib/utils/config.json")
with open(config_path, encoding="utf-8") as config_file:
    config = json.load(config_file)

# Path to the dataset, which is already split to train, val, test
data_folder = os.path.join(config["dataset_folder"], config['derivatives_folder'], config["split_folder"])


def _split_files(folder, split_name, features=("eeg", "envelope")):
    """Return the sorted data files belonging to one split.

    A file is kept when its base name starts with ``"<split_name>_-_"`` and the
    last ``"_-_"``-separated field, with everything after the first dot
    removed, is one of ``features``.

    Args:
        folder: Directory that contains the split files.
        split_name: Split prefix, e.g. ``"train"``, ``"val"`` or ``"test"``.
        features: Accepted values of the last file-name field.

    Returns:
        List of matching paths in lexicographic order. ``glob.glob`` gives no
        ordering guarantee, so sorting makes the result identical across
        machines and runs.
    """
    pattern = os.path.join(folder, f"{split_name}_-_*")
    return sorted(
        path
        for path in glob.glob(pattern)
        if os.path.basename(path).split("_-_")[-1].split(".")[0] in features
    )


# Paths to the training, validation and test data
train_files = _split_files(data_folder, "train")
val_files = _split_files(data_folder, "val")
test_files = _split_files(data_folder, "test")

## Обучение 

In [4]:
# Windowing / batching options passed as one dict to every Trainer below.
# Values come from the parameter cell above.
# max_files is None - presumably "no limit on the number of files"; confirm in Trainer.
args = dict(
    window_length=window_length,
    hop_length=hop_length,
    number_of_mismatch=number_of_mismatch,
    batch_size=batch_size,
    max_files=None,
)

Базовый энкодер ЭЭГ + базовый энкодер стимула

In [5]:
# Baseline run: Model() with default arguments, trained with Adam (lr=1e-3).
model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
trainer = Trainer(model, train_files, val_files, test_files, args, optimizer)

# Two epochs, tracked under the run name "Baseline".
# NOTE(review): the meaning of eps=1e-5 is defined inside Trainer.train_model - confirm there.
print("Training")
trainer.train_model(epochs=2, run_name="Baseline", eps=1e-5)

# Evaluate with the same windowing parameters as in `args`.
# The trailing None matches args["max_files"] - presumably the file limit; verify against Trainer.test.
print("Testing")
trainer.test(window_length, hop_length, number_of_mismatch, None)

Training
EPOCH 1:
  batch 100 loss: 0.14436513595159414
  batch 200 loss: 0.0009197634062729498
  batch 300 loss: 0.00022809133013353743
  batch 400 loss: 9.146743850525495e-06
LOSS train 9.146743850525495e-06 valid 0.00019639057730306194
EPOCH 2:
  batch 100 loss: 0.00017933270480170683
  batch 200 loss: 1.6600384697085246e-05
  batch 300 loss: 3.069836762728317e-05
  batch 400 loss: 4.906420497619024e-07
LOSS train 4.906420497619024e-07 valid 0.0002756381290764012
Testing
sub-033
    Mean accuracy per subject: 99.11373707533235
sub-011
    Mean accuracy per subject: 99.27360774818402
sub-014
    Mean accuracy per subject: 99.6319018404908
sub-024
    Mean accuracy per subject: 99.27360774818402
sub-078
    Mean accuracy per subject: 99.08925318761385
sub-064
    Mean accuracy per subject: 99.21752738654148
sub-073
    Mean accuracy per subject: 98.72495446265938
sub-042
    Mean accuracy per subject: 98.53420195439739
sub-021
    Mean accuracy per subject: 99.38650306748467
sub-076
 

Трансформер для ЭЭГ + базовый энкодер стимула

In [5]:
# Same pipeline as the baseline run, but the model is built with use_transformer=True.
model = Model(use_transformer=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
trainer = Trainer(model, train_files, val_files, test_files, args, optimizer)

# Two epochs, tracked under the run name "Transformer".
# NOTE(review): the meaning of eps=1e-5 is defined inside Trainer.train_model - confirm there.
print("Training")
trainer.train_model(epochs=2, run_name="Transformer", eps=1e-5)

# Evaluate with the same windowing parameters as in `args`.
# The trailing None matches args["max_files"] - presumably the file limit; verify against Trainer.test.
print("Testing")
trainer.test(window_length, hop_length, number_of_mismatch, None)

Training
EPOCH 1:
  batch 100 loss: 0.14217073852825343
  batch 200 loss: 0.0021196607220531918
  batch 300 loss: 0.0012117775476999603
  batch 400 loss: 0.00012176215706791283
  batch 500 loss: 0.0006003972515463829
  batch 600 loss: 0.0008171536028385162
  batch 700 loss: 0.0005988744071510155
  batch 800 loss: 1.7175454407531985e-05
  batch 900 loss: 2.066788147097043e-06
LOSS train 2.066788147097043e-06 valid 1.7467377588164068e-10
Testing
sub-013
    Mean accuracy per subject: 100.0
sub-064
    Mean accuracy per subject: 100.0
sub-005
    Mean accuracy per subject: 100.0
sub-021
    Mean accuracy per subject: 99.87730061349693
sub-078
    Mean accuracy per subject: 100.0
sub-024
    Mean accuracy per subject: 99.87893462469734
sub-042
    Mean accuracy per subject: 100.0
sub-085
    Mean accuracy per subject: 100.0
sub-022
    Mean accuracy per subject: 99.87730061349693
sub-036
    Mean accuracy per subject: 100.0
sub-001
    Mean accuracy per subject: 99.87893462469734
sub-062
 

Базовый энкодер ЭЭГ + Эмбеддинги `Wav2Vec`

In [5]:
# Model built with use_embeddings=True; the Trainer receives the same flag and
# embedding_type="wav2vec" (both objects must be told about the embeddings).
model = Model(use_embeddings=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
trainer = Trainer(
    model,
    train_files,
    val_files,
    test_files,
    args,
    optimizer,
    use_embeddings=True,
    embedding_type="wav2vec",
)

# Two epochs, tracked under the run name "Wav2Vec".
# NOTE(review): the meaning of eps=1e-5 is defined inside Trainer.train_model - confirm there.
print("Training")
trainer.train_model(epochs=2, run_name="Wav2Vec", eps=1e-5)

# Evaluate with the same windowing parameters as in `args`.
# The trailing None matches args["max_files"] - presumably the file limit; verify against Trainer.test.
print("Testing")
trainer.test(window_length, hop_length, number_of_mismatch, None)

Training
EPOCH 1:
  batch 100 loss: 0.14670967061872106
  batch 200 loss: 0.0002088138647377491
  batch 300 loss: 0.0004688701720472466
  batch 400 loss: 2.47868995427325e-05
  batch 500 loss: 5.7140075941566334e-05
  batch 600 loss: 2.771553678115879e-08
  batch 700 loss: 5.164515732758446e-07
  batch 800 loss: 0.00010941589002527507
  batch 900 loss: 1.2982928008664629e-06
  batch 1000 loss: 5.7820174260996284e-05
  batch 1100 loss: 0.0002700865108567985
  batch 1200 loss: 0.001713825878876527
  batch 1300 loss: 9.78196553211319e-05
  batch 1400 loss: 6.904328486312395e-07
  batch 1500 loss: 0.0001654549483994916
  batch 1600 loss: 2.2600606383402777e-05
  batch 1700 loss: 2.232706955922481e-05
  batch 1800 loss: 5.270355261099979e-07
  batch 1900 loss: 5.773793950439199e-05
  batch 2000 loss: 1.918154739541933e-06
  batch 2100 loss: 1.1456513675511814e-07
  batch 2200 loss: 0.0
  batch 2300 loss: 9.685748336707433e-10
  batch 2400 loss: 3.917019967047963e-08
  batch 2500 loss: 0.0
 

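Базовый энкодер ЭЭГ + Эмбеддинги `Whisper`

**NB:** в выводе ниже функция потерь равна ровно 0.0 уже на первых 100 батчах. Такой результат выглядит подозрительно (возможна утечка меток или вырожденная функция потерь) и требует проверки, прежде чем делать выводы о качестве модели.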

In [5]:
# Model built with use_embeddings=True; the Trainer receives the same flag and
# embedding_type="whisper" (both objects must be told about the embeddings).
model = Model(use_embeddings=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
trainer = Trainer(
    model,
    train_files,
    val_files,
    test_files,
    args,
    optimizer,
    use_embeddings=True,
    embedding_type="whisper",
)

# Two epochs, tracked under the run name "Whisper".
# NOTE(review): the meaning of eps=1e-5 is defined inside Trainer.train_model - confirm there.
print("Training")
trainer.train_model(epochs=2, run_name="Whisper", eps=1e-5)

# Evaluate with the same windowing parameters as in `args`.
# The trailing None matches args["max_files"] - presumably the file limit; verify against Trainer.test.
print("Testing")
trainer.test(window_length, hop_length, number_of_mismatch, None)

Training
EPOCH 1:
  batch 100 loss: 0.0
  batch 200 loss: 0.0
  batch 300 loss: 0.0
  batch 400 loss: 0.0
  batch 500 loss: 0.0
  batch 600 loss: 0.0
  batch 700 loss: 0.0
  batch 800 loss: 0.0
  batch 900 loss: 0.0
  batch 1000 loss: 0.0
  batch 1100 loss: 0.0
  batch 1200 loss: 0.0
  batch 1300 loss: 0.0
  batch 1400 loss: 0.0
  batch 1500 loss: 0.0
  batch 1600 loss: 0.0
  batch 1700 loss: 0.0
  batch 1800 loss: 0.0
  batch 1900 loss: 0.0
  batch 2000 loss: 0.0
  batch 2100 loss: 0.0
  batch 2200 loss: 0.0
  batch 2300 loss: 0.0
  batch 2400 loss: 0.0
  batch 2500 loss: 0.0
  batch 2600 loss: 0.0
  batch 2700 loss: 0.0
  batch 2800 loss: 0.0
  batch 2900 loss: 0.0
  batch 3000 loss: 0.0
  batch 3100 loss: 0.0
  batch 3200 loss: 0.0
  batch 3300 loss: 0.0
  batch 3400 loss: 0.0
  batch 3500 loss: 0.0
  batch 3600 loss: 0.0
  batch 3700 loss: 0.0
  batch 3800 loss: 0.0
  batch 3900 loss: 0.0
  batch 4000 loss: 0.0
  batch 4100 loss: 0.0
  batch 4200 loss: 0.0
  batch 4300 loss: 0.0
  

Трансформер для ЭЭГ + Эмбеддинги `Wav2Vec`

In [5]:
# Model built with both use_transformer=True and use_embeddings=True; the
# Trainer is configured with use_embeddings=True and embedding_type="wav2vec".
model = Model(use_transformer=True, use_embeddings=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
trainer = Trainer(
    model,
    train_files,
    val_files,
    test_files,
    args,
    optimizer,
    use_embeddings=True,
    embedding_type="wav2vec",
)

# Two epochs, tracked under the run name "Transformer_Wav2Vec".
# NOTE(review): the meaning of eps=1e-5 is defined inside Trainer.train_model - confirm there.
print("Training")
trainer.train_model(epochs=2, run_name="Transformer_Wav2Vec", eps=1e-5)

# Evaluate with the same windowing parameters as in `args`.
# The trailing None matches args["max_files"] - presumably the file limit; verify against Trainer.test.
print("Testing")
trainer.test(window_length, hop_length, number_of_mismatch, None)

Training
EPOCH 1:
  batch 100 loss: 0.0012281817494652746
  batch 200 loss: 1.06697241101672e-05
  batch 300 loss: 0.00010448084200857188
  batch 400 loss: 3.885615478793625e-07
  batch 500 loss: 1.3007353100391583e-07
  batch 600 loss: 1.6763799237651256e-10
  batch 700 loss: 4.1924858962261166e-08
  batch 800 loss: 1.368063263385011e-07
  batch 900 loss: 1.7764142739018497e-07
  batch 1000 loss: 8.501882689415651e-07
  batch 1100 loss: 2.497818298934362e-06
  batch 1200 loss: 1.1581593201981377e-06
  batch 1300 loss: 1.8590069304047496e-07
  batch 1400 loss: 1.5280927270211465e-07
  batch 1500 loss: 9.576736552219245e-08
  batch 1600 loss: 2.9988486283105685e-09
  batch 1700 loss: 5.027000229418377e-08
  batch 1800 loss: 1.111972579792564e-08
  batch 1900 loss: 1.1008106194365652e-08
  batch 2000 loss: 1.1175870007207323e-10
  batch 2100 loss: 0.0
  batch 2200 loss: 0.0
  batch 2300 loss: 0.0
  batch 2400 loss: 1.303851426825986e-10
  batch 2500 loss: 0.0
  batch 2600 loss: 0.0
  bat

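Трансформер для ЭЭГ + Эмбеддинги `Whisper`

**NB:** в выводе ниже функция потерь падает до ровно 0.0 уже к 200-му батчу и больше не меняется. Такой результат выглядит подозрительно (возможна утечка меток или вырожденная функция потерь) и требует проверки, прежде чем делать выводы о качестве модели.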

In [5]:
# Model built with both use_transformer=True and use_embeddings=True; the
# Trainer is configured with use_embeddings=True and embedding_type="whisper".
model = Model(use_transformer=True, use_embeddings=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
trainer = Trainer(
    model,
    train_files,
    val_files,
    test_files,
    args,
    optimizer,
    use_embeddings=True,
    embedding_type="whisper",
)

# Two epochs, tracked under the run name "Transformer_Whisper".
# NOTE(review): the meaning of eps=1e-5 is defined inside Trainer.train_model - confirm there.
print("Training")
trainer.train_model(epochs=2, run_name="Transformer_Whisper", eps=1e-5)

# Evaluate with the same windowing parameters as in `args`.
# The trailing None matches args["max_files"] - presumably the file limit; verify against Trainer.test.
print("Testing")
trainer.test(window_length, hop_length, number_of_mismatch, None)

Training
EPOCH 1:
  batch 100 loss: 0.4832695484161377
  batch 200 loss: 0.0
  batch 300 loss: 0.0
  batch 400 loss: 0.0
  batch 500 loss: 0.0
  batch 600 loss: 0.0
  batch 700 loss: 0.0
  batch 800 loss: 0.0
  batch 900 loss: 0.0
  batch 1000 loss: 0.0
  batch 1100 loss: 0.0
  batch 1200 loss: 0.0
  batch 1300 loss: 0.0
  batch 1400 loss: 0.0
  batch 1500 loss: 0.0
  batch 1600 loss: 0.0
  batch 1700 loss: 0.0
  batch 1800 loss: 0.0
  batch 1900 loss: 0.0
  batch 2000 loss: 0.0
  batch 2100 loss: 0.0
  batch 2200 loss: 0.0
  batch 2300 loss: 0.0
  batch 2400 loss: 0.0
  batch 2500 loss: 0.0
  batch 2600 loss: 0.0
  batch 2700 loss: 0.0
  batch 2800 loss: 0.0
  batch 2900 loss: 0.0
  batch 3000 loss: 0.0
  batch 3100 loss: 0.0
  batch 3200 loss: 0.0
  batch 3300 loss: 0.0
  batch 3400 loss: 0.0
  batch 3500 loss: 0.0
  batch 3600 loss: 0.0
  batch 3700 loss: 0.0
  batch 3800 loss: 0.0
  batch 3900 loss: 0.0
  batch 4000 loss: 0.0
  batch 4100 loss: 0.0
  batch 4200 loss: 0.0
  batch 43