In [None]:
# Load IPython's autoreload extension; it is switched on in the next cell.
%load_ext autoreload

In [None]:
# Autoreload mode 2: re-import modules before each cell is executed, so edits
# to imported source files are picked up without restarting the kernel.
%autoreload 2

In [None]:
# Detect whether the notebook runs on Google Colab. The resulting flag gates
# the Colab-only dependency installation and dataset download.
try:
    import google.colab  # noqa: F401 -- imported only to probe for Colab
    IN_COLAB = True
except ImportError:
    # Only a missing module means "not on Colab"; anything else (for example
    # KeyboardInterrupt) must not be silently swallowed.
    IN_COLAB = False

In [None]:
import os

# Colab-only bootstrap: install the extra packages, download the dataset
# archive from a public Yandex.Disk link and unpack it. Nothing here runs
# when IN_COLAB is False.
if IN_COLAB:
    # NOTE(review): the packages are installed without version pins.
    ! pip3 install mne
    ! pip3 install yadisk

    import yadisk

    # Skip the download (and the unzip) when the archive is already present.
    if not os.path.exists('dataset.zip'):
        url = 'https://disk.yandex.ru/d/hItiy8VK-lk0gQ' 
        y = yadisk.YaDisk()
        y.download_public(url, 'dataset.zip')

        ! unzip dataset.zip

    # NOTE(review): eegproject.zip is not downloaded anywhere in this cell --
    # presumably it is extracted from dataset.zip or uploaded manually; confirm.
    if not os.path.exists('eegproject'):
        ! unzip eegproject.zip

In [None]:
import torch
import random
import numpy as np

def set_random_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
    (CPU and all CUDA devices), and switches cuDNN to deterministic kernels.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one; the call is silently
    # ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may pick different algorithms between runs, which
    # defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False


set_random_seed(42)

In [None]:
from eegproject.data.sequential_dataset import SequentialEEGDataset, get_short_sequence_dataset, iterate_batches
from eegproject.data.transforms import scale

In [None]:
# Build the train and test datasets from their preprocessed files; `scale`
# is handed to both as the transform.
train_dataset_ = SequentialEEGDataset(
    split='train',
    preprocessed_path='eeg_dataset/train_eeg_dataset.pt',
    transform=scale,
)
test_dataset_ = SequentialEEGDataset(
    split='test',
    preprocessed_path='eeg_dataset/test_eeg_dataset.pt',
    transform=scale,
)

In [None]:
import math

def iterate_batch_multiple_seq_minibatches(inputs, targets, batch_size, seq_length,
                                           max_skips=5, shuffle=True):
    """Yield fixed-size, padded mini-batches cut from variable-length sequences.

    The sequences are grouped ``batch_size`` at a time. Every group is then
    walked in consecutive windows of ``seq_length`` steps, so that a stateful
    model can carry its hidden state from one window to the next.

    Args:
        inputs: Indexable collection of per-sequence arrays; element ``i`` is
            indexed along its first axis as the step axis.
        targets: Collection of per-sequence label arrays, same length as
            ``inputs``.
        batch_size: Number of sequences per batch. Every yielded batch has
            exactly this many rows; unused rows stay fully padded.
        seq_length: Number of steps per yielded window.
        max_skips: Each sequence loses a random number of leading steps drawn
            from ``[0, max_skips)``. Values ``<= 1`` disable skipping.
        shuffle: Whether to visit the sequences in random order.

    Yields:
        ``(batch_x, batch_y, start_loop)`` where ``batch_x`` is a float tensor
        of shape ``(batch_size, seq_length, *sample_shape)`` padded with zeros,
        ``batch_y`` is an integer tensor padded with ``-1``, and ``start_loop``
        is ``True`` for the first window of every group of sequences.
    """
    assert len(inputs) == len(targets)
    n_inputs = len(inputs)
    if n_inputs == 0:
        # Nothing to iterate over; avoids indexing inputs[0] below.
        return

    input_sample_shape = tuple(inputs[0].shape[1:])
    target_sample_shape = tuple(targets[0].shape[1:])

    seq_idx = np.arange(n_inputs)
    if shuffle:
        np.random.shuffle(seq_idx)

    # Number of groups of sequences (the last one may be smaller).
    n_loops = int(math.ceil(n_inputs / batch_size))

    for loop_idx in range(n_loops):
        group = seq_idx[loop_idx * batch_size:(loop_idx + 1) * batch_size]

        # Plain Python lists keep this working for ragged sequences and never
        # write back into the caller's data.
        seq_inputs, seq_targets = [], []
        for i in group:
            n_skips = np.random.randint(max_skips) if max_skips > 0 else 0
            seq_inputs.append(inputs[i][n_skips:])
            seq_targets.append(targets[i][n_skips:])

        # The longest sequence of the group decides how many windows are needed.
        n_max_seq_inputs = max(len(s) for s in seq_inputs)
        n_batch_seqs = int(math.ceil(n_max_seq_inputs / seq_length))

        for b in range(n_batch_seqs):
            start_loop = b == 0
            start_idx = b * seq_length
            end_idx = start_idx + seq_length
            batch_inputs = np.zeros((batch_size, seq_length) + input_sample_shape, dtype=np.float32)
            batch_targets = np.full((batch_size, seq_length) + target_sample_shape, -1, dtype=int)
            for s_idx, (each_inputs, each_targets) in enumerate(zip(seq_inputs, seq_targets)):
                window_inputs = each_inputs[start_idx:end_idx]
                window_targets = each_targets[start_idx:end_idx]
                # Sequences that are already exhausted contribute nothing and
                # keep their padding values.
                batch_inputs[s_idx, :len(window_inputs)] = window_inputs
                batch_targets[s_idx, :len(window_targets)] = window_targets
            batch_x = torch.tensor(batch_inputs).float()
            batch_y = torch.tensor(batch_targets)
            yield batch_x, batch_y, start_loop

In [None]:
import matplotlib.pyplot as plt

# Smoke test of the batch iterator: pull a single batch (batch_size=32,
# seq_length=10) from the training data and print the shape of its inputs.
cnt = 0  # NOTE(review): never incremented, so the cell simply echoes 0

for x, y, sl in iterate_batch_multiple_seq_minibatches(train_dataset_.X, train_dataset_.y, 32, 10):
    print(x.shape)
    break
    
cnt

torch.Size([32, 10, 3000])




0

In [None]:
from eegproject.models.cnn_classifier import CNNClassifier
from eegproject.models.lstm_classifier import LSTMClassifier
from eegproject.models.cnn_encoder import CNNEncoder
from torch.nn.utils.rnn import pad_sequence
import torch

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Number of sequences per mini-batch. Read as a global by train_one_epoch and
# predict, and also passed to LSTMClassifier when the model is built.
batch_size = 15
# Number of consecutive steps per window that train_one_epoch and predict
# request from the batch iterator.
sequence_length = 35

In [None]:
from tqdm.auto import tqdm
import numpy as np
from sklearn.metrics import accuracy_score

def _detach_state(state):
    """Return ``state`` with every tensor cut off from the autograd graph."""
    if isinstance(state, torch.Tensor):
        return state.detach()
    if isinstance(state, tuple):
        return tuple(_detach_state(item) for item in state)
    if isinstance(state, list):
        return [_detach_state(item) for item in state]
    return state


def train_one_epoch(model, train_dataloader, criterion, optimizer, device="cuda:0"):
    """Run one training pass and return the mean loss per valid target.

    Args:
        model: Module exposing ``init_state(device)`` and a forward call
            ``model(features, state) -> (predictions, new_state)``.
        train_dataloader: Iterable of ``(features, targets, start_flag)``
            triples. When ``None``, batches are drawn from the global
            ``train_dataset_`` using the global ``batch_size`` and
            ``sequence_length``.
        criterion: Loss applied to the predictions of all non-padded steps.
        optimizer: Optimizer stepped once per batch.
        device: Device the model and the batches are moved to.

    Returns:
        Python float with the loss averaged over all targets different from
        ``-1`` (the padding label); ``nan`` when no such target was seen.
    """
    model.to(device).train()
    if train_dataloader is None:
        batches = iterate_batch_multiple_seq_minibatches(
            train_dataset_.X, train_dataset_.y, batch_size, sequence_length)
    else:
        batches = train_dataloader

    cum_loss = 0.0
    n_objects = 0
    state = None
    for features, y, sl in tqdm(batches):
        if sl:
            # A new group of sequences starts: reset the recurrent state.
            state = model.init_state(device)

        features = features.to(device)
        y = y.to(device)

        valid = y != -1  # padded steps carry the label -1
        cur_obj = int(valid.sum().item())

        preds, new_state = model(features, state)
        # Truncated back-propagation through time: the state is carried to the
        # next window as a constant. Without the detach, the next backward()
        # would reach into this window's already released graph.
        state = _detach_state(new_state)

        if cur_obj == 0:
            # A loss over zero elements is NaN and would poison the weights.
            continue

        optimizer.zero_grad()
        loss = criterion(preds[valid], y[valid])
        loss.backward()
        optimizer.step()

        cum_loss += loss.item() * cur_obj
        n_objects += cur_obj

    return cum_loss / n_objects if n_objects else float("nan")


def predict(model, test_dataloder, criterion, device="cuda:0"):
    """Evaluate the model and collect its predictions.

    Args:
        model: Module exposing ``init_state(device)`` and a forward call
            ``model(features, state) -> (predictions, new_state)``.
        test_dataloder: Iterable of ``(features, targets, start_flag)``
            triples. When ``None``, batches are drawn from the global
            ``test_dataset_`` using the global ``batch_size`` and
            ``sequence_length``. Note that this default iterator shuffles the
            sequences and randomly drops a few leading steps, so the result
            is not fully deterministic.
        criterion: Loss applied to the predictions of all non-padded steps.
        device: Device the model and the batches are moved to.

    Returns:
        ``(mean_loss, predicted, true)``: a Python float with the loss
        averaged over all targets different from ``-1`` (``nan`` if there are
        none), and two 1-D integer CPU tensors with the predicted and the true
        class of every such target.
    """
    model.to(device).eval()
    if test_dataloder is None:
        batches = iterate_batch_multiple_seq_minibatches(
            test_dataset_.X, test_dataset_.y, batch_size, sequence_length)
    else:
        batches = test_dataloder

    # Chunks are gathered in lists and concatenated once at the end instead of
    # growing a tensor with torch.cat on every batch.
    predicted_chunks = []
    target_chunks = []
    cum_loss = 0.0
    n_objects = 0
    state = None
    with torch.no_grad():
        for features, y, sl in tqdm(batches):
            if sl:
                # A new group of sequences starts: reset the recurrent state.
                state = model.init_state(device)

            features = features.to(device)
            y = y.to(device)

            logits, state = model(features, state)

            valid = y != -1  # padded steps carry the label -1
            cur_obj = int(valid.sum().item())
            if cur_obj == 0:
                continue

            logits = logits[valid]
            labels = y[valid]
            predicted_chunks.append(torch.argmax(logits, dim=1).cpu())
            target_chunks.append(labels.cpu())

            cum_loss += criterion(logits, labels).item() * cur_obj
            n_objects += cur_obj

    if n_objects == 0:
        no_labels = torch.tensor([], dtype=torch.long)
        return float("nan"), no_labels, no_labels.clone()

    return cum_loss / n_objects, torch.cat(predicted_chunks), torch.cat(target_chunks)
    

def train(model, train_dataloader, test_dataloader, criterion, optimizer, device="cuda:0", n_epochs=10, scheduler=None):
    """Train for ``n_epochs`` epochs, evaluating and reporting after each one.

    Args:
        model: Model to optimise; moved to ``device`` first.
        train_dataloader: Forwarded to ``train_one_epoch``.
        test_dataloader: Forwarded to ``predict``.
        criterion: Loss function used for both training and evaluation.
        optimizer: Optimizer passed to ``train_one_epoch``.
        device: Device used for training and evaluation.
        n_epochs: Number of train/evaluate rounds.
        scheduler: Optional scheduler whose ``step`` is called with the
            validation loss after every epoch.
    """
    model.to(device)
    for epoch in range(n_epochs):
        print('Train')
        train_loss = train_one_epoch(model, train_dataloader, criterion, optimizer, device)
        print('Evaluate')
        val_loss, predicted, true = predict(model, test_dataloader, criterion, device)
        if scheduler is not None:
            scheduler.step(val_loss)

        # accuracy_score expects (y_true, y_pred) in that order.
        accuracy = accuracy_score(true, predicted)
        print('Epoch {}, val loss {:.3f}, train loss {:.3f}, accuracy {:.3f}'\
              .format(epoch + 1, val_loss, train_loss, accuracy))

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [None]:
# Load the saved CNN encoder onto the CPU (presumably a whole pickled module
# rather than a state_dict, since the result is passed straight into the
# classifier) and wrap it in a bidirectional LSTM classifier.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
encoder = torch.load('cnn_encoder.pt', map_location='cpu')
model = LSTMClassifier(encoder, batch_size=batch_size, bidirectional=True, hidden_size=32)

# Element-wise gradient clipping: a hook on every parameter clamps each
# gradient value to [-clip_value, clip_value] during the backward pass.
clip_value = 1.0
for p in model.parameters():
    # The lambda looks up the global clip_value when it is called, so
    # re-assigning clip_value later changes the clipping range.
    p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))

In [None]:
# Optimisation setup: cross-entropy loss, Adam over all model parameters, and
# a scheduler that halves the learning rate once the monitored metric has not
# improved for more than `patience` epochs.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    patience=3,
    factor=0.5,
    verbose=True,
)

In [None]:
# Baseline: evaluate the freshly built model before any training step.
# The dataloader argument is None; predict then draws its batches from the
# global test_dataset_.
val_loss, predicted, true = predict(model, None, criterion, device)
accuracy_score(predicted, true)

0it [00:00, ?it/s]



0.36182561146534536

In [None]:
# Train for 35 epochs. Both dataloader arguments are None because the training
# and evaluation loops draw their batches from the global train_dataset_ and
# test_dataset_.
train(model, None, None, criterion, optimizer, device, 35, scheduler)

Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 1, val loss 1.303, train loss 0.901, accuracy 0.595
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 2, val loss 0.969, train loss 0.804, accuracy 0.659
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 3, val loss 1.573, train loss 0.731, accuracy 0.540
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 4, val loss 1.317, train loss 0.718, accuracy 0.523
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 5, val loss 1.175, train loss 0.705, accuracy 0.572
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 6, val loss 0.912, train loss 0.686, accuracy 0.655
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 7, val loss 1.189, train loss 0.658, accuracy 0.582
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 8, val loss 0.922, train loss 0.657, accuracy 0.681
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 9, val loss 0.889, train loss 0.651, accuracy 0.669
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 10, val loss 0.673, train loss 0.636, accuracy 0.753
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 11, val loss 1.089, train loss 0.623, accuracy 0.607
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 12, val loss 1.004, train loss 0.630, accuracy 0.645
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 13, val loss 1.179, train loss 0.593, accuracy 0.598
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 00014: reducing learning rate of group 0 to 5.0000e-05.
Epoch 14, val loss 0.879, train loss 0.596, accuracy 0.672
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 15, val loss 0.877, train loss 0.612, accuracy 0.676
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 16, val loss 1.450, train loss 0.608, accuracy 0.563
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 17, val loss 0.942, train loss 0.606, accuracy 0.664
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 00018: reducing learning rate of group 0 to 2.5000e-05.
Epoch 18, val loss 0.871, train loss 0.569, accuracy 0.684
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 19, val loss 1.288, train loss 0.586, accuracy 0.595
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 20, val loss 0.908, train loss 0.571, accuracy 0.667
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 21, val loss 1.112, train loss 0.585, accuracy 0.635
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 00022: reducing learning rate of group 0 to 1.2500e-05.
Epoch 22, val loss 1.434, train loss 0.565, accuracy 0.585
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 23, val loss 0.911, train loss 0.567, accuracy 0.673
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 24, val loss 0.919, train loss 0.560, accuracy 0.666
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 25, val loss 0.988, train loss 0.561, accuracy 0.664
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 00026: reducing learning rate of group 0 to 6.2500e-06.
Epoch 26, val loss 1.113, train loss 0.567, accuracy 0.628
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 27, val loss 1.120, train loss 0.546, accuracy 0.620
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 28, val loss 1.068, train loss 0.572, accuracy 0.629
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 29, val loss 0.881, train loss 0.582, accuracy 0.670
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 00030: reducing learning rate of group 0 to 3.1250e-06.
Epoch 30, val loss 1.180, train loss 0.568, accuracy 0.612
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 31, val loss 1.023, train loss 0.555, accuracy 0.642
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 32, val loss 0.997, train loss 0.546, accuracy 0.661
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 33, val loss 0.906, train loss 0.545, accuracy 0.667
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 00034: reducing learning rate of group 0 to 1.5625e-06.
Epoch 34, val loss 1.003, train loss 0.548, accuracy 0.655
Train


0it [00:00, ?it/s]



Evaluate


0it [00:00, ?it/s]

Epoch 35, val loss 0.980, train loss 0.551, accuracy 0.648


In [None]:
# Persist the whole trained model object (not only its state_dict).
torch.save(model, 'lstm_classifier.pt')