In [38]:
import torch
from tqdm import tqdm
import argparse

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3,4"

import objgraph

import time
import utils
from result_tracking import ThinkerwiseResultTracker

import dn3
from dn3.configuratron import ExperimentConfig
from dn3.data.dataset import Thinker
from dn3.trainable.processes import StandardClassification

from dn3_ext import BENDRClassification, LinearHeadBENDR
from matplotlib import pyplot as plt
import numpy as np

In [9]:
# Load the downstream-task experiment configuration from YAML; the printed output shows it
# contributes the datasets plus extra entries such as train_params, lr and folds.
experiment = ExperimentConfig("configs/downstream.yml")

Adding additional configuration entries: dict_keys(['train_params', 'lr', 'folds'])
Configuratron found 1 datasets.


In [10]:
# Explicit handle to the "mmidb" dataset entry of the config.
# NOTE(review): `data` is not used by the visible cells below, which instead take the first
# configured dataset (next cell) — consider using one selection mechanism consistently.
data = experiment.datasets["mmidb"]

In [11]:
# Take the first (name, dataset) pair from the config. The configuratron reported a single
# dataset, so this is the same entry as `data` above. Raises IndexError if none are configured.
ds_name, ds = [*experiment.datasets.items()][0]

In [12]:
# Build the cross-validation split generator for this dataset; each next(gen) yields a
# (training, validation, test) triple (see the next cell).
# NOTE(review): "lmoso" presumably means leave-multiple-subjects-out — confirm in utils.
gen = utils.get_lmoso_iterator(ds_name, ds)

Scanning data/datasets/eegmmidb. If there are a lot of files, this may take a while...: 100%|██████████| 4/4 [00:00<00:00, 12.89it/s, extension=.gdf]


Creating dataset of 315 Preloaded Epoched recordings from 105 people.


Loading Physionet MMIDB: 100%|██████████| 105/105 [00:13<00:00,  7.53person/s]


>> Physionet MMIDB | DSID: None | 105 people | 4408 trials | 90 channels | 1536 samples/trial | 256Hz | 0 transforms
Constructed 1 channel maps
Used by 315 recordings:
EEG (original(new)): Fc5.(FC5) Fc3.(FC3) Fc1.(FC1) Fcz.(FCZ) Fc2.(FC2) Fc4.(FC4) Fc6.(FC6) C5..(C5) C3..(C3) C1..(C1) Cz..(CZ) C2..(C2) C4..(C4) C6..(C6) Cp5.(CP5) Cp3.(CP3) Cp1.(CP1) Cpz.(CPZ) Cp2.(CP2) Cp4.(CP4) Cp6.(CP6) Fp1.(FP1) Fpz.(FPZ) Fp2.(FP2) Af7.(AF7) Af3.(AF3) Afz.(AFZ) Af4.(AF4) Af8.(AF8) F7..(F7) F5..(F5) F3..(F3) F1..(F1) Fz..(FZ) F2..(F2) F4..(F4) F6..(F6) F8..(F8) Ft7.(FT7) Ft8.(FT8) T7..(T7) T8..(T8) T9..(T9) T10.(T10) Tp7.(TP7) Tp8.(TP8) P7..(P7) P5..(P5) P3..(P3) P1..(P1) Pz..(PZ) P2..(P2) P4..(P4) P6..(P6) P8..(P8) Po7.(PO7) Po3.(PO3) Poz.(POZ) Po4.(PO4) Po8.(PO8) O1..(O1) Oz..(OZ) O2..(O2) Iz..(IZ) 
EOG (original(new)): 
REF (original(new)): 
EXTRA (original(new)): 
Heuristically Assigned: Fc5.(FC5)  Fc3.(FC3)  Fc1.(FC1)  Fcz.(FCZ)  Fc2.(FC2)  Fc4.(FC4)  Fc6.(FC6)  C5..(C5)  C3..(C3)  C1..(C1)  Cz.

In [13]:
# Use only the first fold. Per the printed summaries the splits hold 63/21/21 of the 105 people,
# i.e. they appear to be subject-disjoint — confirm in utils.get_lmoso_iterator.
training, validation, test = next(gen)

Training:   >> Physionet MMIDB | DSID: None | 63 people | 2646 trials | 20 channels | 1536 samples/trial | 256Hz | 1 transforms
Validation: >> Physionet MMIDB | DSID: None | 21 people | 880 trials | 20 channels | 1536 samples/trial | 256Hz | 1 transforms
Test:       >> Physionet MMIDB | DSID: None | 21 people | 882 trials | 20 channels | 1536 samples/trial | 256Hz | 1 transforms


In [14]:
# Materialise the whole training split and convert both arrays to (default float32) tensors.
# NOTE(review): y holds class labels but is stored as float32 here; cast to torch.long if it is
# ever fed to a classification loss.
x, y = map(torch.Tensor, training.to_numpy())

Loading Batches: 100%|██████████| 42/42 [00:05<00:00,  7.72it/s]


In [16]:
class BENDR(BENDRClassification):
    """Inference-oriented wrapper around ``BENDRClassification``.

    Adds per-sample input validation against a fixed ``(channels, samples)`` shape, helpers that
    save/load the encoder, contextualizer and classifier state dicts as three separate files,
    a softmax ``forward_probs`` and a batched ``evaluate`` helper.

    Note: ``forward`` runs under ``torch.no_grad()`` unless ``grad=True`` is passed explicitly,
    so gradients are disabled by default.
    """

    def __init__(self, targets, samples, channels, encoder_h=512, contextualizer_hidden=3076, projection_head=False,
                 new_projection_layers=0, dropout=0., trial_embeddings=None, layer_drop=0, keep_layers=None,
                 mask_p_t=0.01, mask_p_c=0.005, mask_t_span=0.1, mask_c_span=0.1, multi_gpu=False):
        super().__init__(targets, samples, channels, encoder_h, contextualizer_hidden, projection_head,
                 new_projection_layers, dropout, trial_embeddings, layer_drop, keep_layers,
                 mask_p_t, mask_p_c, mask_t_span, mask_c_span, multi_gpu)

        # Per-sample shapes (batch dimension excluded), used for input validation below.
        self.input_shape = torch.Size((channels, samples))
        self.output_shape = torch.Size((targets, ))

    def load_classifier(self, classifier_file, freeze=False, strict=True):
        """Load the classification head's weights from ``classifier_file``.

        Args:
            classifier_file: path to a state dict saved from ``self.classifier`` (see ``save_all``).
            freeze: if True, set ``requires_grad=False`` on every classifier parameter,
                otherwise ``requires_grad=True``.
            strict: forwarded to ``load_state_dict``.
        """
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines;
        # load_state_dict then copies the values onto whatever device the module lives on.
        classifier_state_dict = torch.load(classifier_file, map_location="cpu")
        self.classifier.load_state_dict(classifier_state_dict, strict=strict)

        for param in self.classifier.parameters():
            param.requires_grad = not freeze

    def load_all(self, encoder_file: str, contextualizer_file: str, classifier_file: str, strict=True,
                 freeze_encoder=False, freeze_contextualizer=False, freeze_classifier=False):
        """Load encoder, contextualizer and classifier weights from three state-dict files.

        Args:
            encoder_file, contextualizer_file, classifier_file: paths as written by ``save_all``.
            strict: forwarded to ``load_state_dict`` for all three modules.
            freeze_encoder, freeze_contextualizer: passed as ``unfreeze=not flag`` to the module's
                ``freeze_features`` (presumably toggles ``requires_grad`` — see dn3_ext).
            freeze_classifier: see ``load_classifier``'s ``freeze``.
        """
        encoder_state_dict = torch.load(encoder_file, map_location="cpu")
        self.encoder.load_state_dict(encoder_state_dict, strict=strict)
        self.encoder.freeze_features(unfreeze=not freeze_encoder)

        contextualizer_state_dict = torch.load(contextualizer_file, map_location="cpu")
        # Honour the caller's `strict` flag here too (it used to be hard-coded to True).
        self.contextualizer.load_state_dict(contextualizer_state_dict, strict=strict)
        self.contextualizer.freeze_features(unfreeze=not freeze_contextualizer)

        self.load_classifier(classifier_file, freeze=freeze_classifier, strict=strict)

    def save_all(self, encoder_file: str, contextualizer_file: str, classifier_file: str):
        """Save encoder, contextualizer and classifier state dicts to three separate files."""
        torch.save(self.encoder.state_dict(), encoder_file)
        torch.save(self.contextualizer.state_dict(), contextualizer_file)
        torch.save(self.classifier.state_dict(), classifier_file)

    def forward(self, x: torch.Tensor, return_features: bool = False, grad: bool = False):
        """Run the classifier on ``x``.

        Args:
            x: tensor of shape ``(batch, channels, samples)``; a single unbatched trial of shape
                ``(channels, samples)`` is accepted and given a leading batch dimension.
            return_features: stored on ``self.return_features`` for the parent's forward; when
                True, the first two elements of the parent's output are returned as a tuple.
            grad: when False (default) the forward pass runs under ``torch.no_grad()``.

        Returns:
            The parent's output (presumably logits of shape ``(batch, targets)``), or an
            ``(output[0], output[1])`` pair when ``return_features`` is True.

        Raises:
            AssertionError: if ``x`` is not a tensor or its per-sample shape is not ``input_shape``.
        """
        assert isinstance(x, torch.Tensor), "Input has to be of instance torch.Tensor"

        if x.shape == self.input_shape: x = torch.unsqueeze(x, dim=0)

        assert x.shape[1:] == self.input_shape, "Input has to be of shape (*, {}, {})".format(*self.input_shape)

        # The parent class reads this flag to decide whether to also return features.
        self.return_features = return_features

        if grad:
            output = super().forward(x)
        else:
            with torch.no_grad(): output = super().forward(x)

        if return_features:
            return output[0], output[1]
        else:
            return output

    def forward_probs(self, x: torch.Tensor, return_features: bool = False, grad: bool = False):
        """Like ``forward`` but with a softmax over dim 1 applied to the (first) output."""
        output = self.forward(x, return_features, grad)

        if return_features:
            return output[0].softmax(dim=1), output[1]
        else:
            return output.softmax(dim=1)

    def evaluate(self, X, batch_size = 8, return_probs = False):
        """Predict class indices for ``X`` in mini-batches (no gradients).

        Args:
            X: tensor of shape ``(N, channels, samples)`` on the same device as the model.
            batch_size: number of trials per forward pass.
            return_probs: if True, also return the ``(N, targets)`` softmax probabilities.

        Returns:
            ``predictions`` (argmax over classes), or ``(predictions, probs)`` when
            ``return_probs`` is True. Both are CPU tensors regardless of ``X``'s device.
        """
        assert isinstance(X, torch.Tensor), "Input has to be of instance torch.Tensor"
        assert X.shape[1:] == self.input_shape, "Input has to be of shape (*, {}, {})".format(*self.input_shape)

        N = len(X)

        # Results are collected on the CPU; self.targets is set by the parent constructor.
        probs = torch.empty((N, self.targets))

        # Exact integer ceil(N / batch_size) — avoids a float32 round-trip.
        n_batches = (N + batch_size - 1) // batch_size

        # Context manager guarantees the progress bar is closed even if a batch raises.
        with tqdm(total=n_batches, desc="Evaluating", unit="batches") as prog_bar:
            for start, batch in zip(range(0, N, batch_size), X.split(batch_size)):
                probs[start:(start + batch_size)] = self.forward_probs(batch).to(probs.device)
                prog_bar.update(1)

        predictions = probs.argmax(1)

        if return_probs:
            return predictions, probs
        else:
            return predictions


In [17]:
# Pretrained BENDR checkpoint paths (encoder, contextualizer), stored on the config object.
experiment.encoder_weights, experiment.context_weights = (
    'encoder_BENDR.pt',
    'contextualizer_BENDR.pt',
)

In [18]:
# 20 channels x 1536 samples matches the split summaries printed above.
# NOTE(review): targets=2 is hard-coded — confirm it matches the dataset's number of classes.
model = BENDR(targets=2, samples=1536, channels=20)
model.load_all(
    experiment.encoder_weights,
    experiment.context_weights,
    "classifier_BENDR.pt",
)
# eval() is equivalent to train(False); keep the assignment so the cell displays nothing.
model = model.eval()

Receptive field: 143 samples | Downsampled by 96 | Overlap of 47 samples | 16 encoded samples/trial


In [19]:
# Prefer the GPU when one is visible, but fall back to CPU so the notebook still runs
# end-to-end on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [22]:
# Batched inference; evaluate() returns argmax class indices.
# NOTE(review): this scores the *training* split — use `test` for a held-out estimate.
preds = model.evaluate(X=x.to(device), batch_size=8).cpu()

Evaluating: 100%|██████████| 331/331 [00:08<00:00, 37.28batch/s]
