In [51]:
!nvidia-smi
# !export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512'

Fri Sep  1 20:04:36 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.199.02   Driver Version: 470.199.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:19:00.0 Off |                  N/A |
| 30%   27C    P8    21W / 350W |   4257MiB / 24268MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:65:00.0 Off |                  N/A |
| 34%   45C    P2   114W / 350W |   1717MiB / 24268MiB |      0%      Default |
|       

##### %config Completer.use_jedi = False

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# os.environ['http_proxy'] = "http://127.0.0.1:3128"
# os.environ['https_proxy'] = "http://127.0.0.1:3128"

In [55]:
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
torch.cuda.empty_cache()
import numpy as np
import pandas as pd
import os

from typing import Callable
import functools
from pathlib import Path

In [4]:
from torch.utils.tensorboard import SummaryWriter

In [5]:
import torch
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertEncoder

# Architecture

In [6]:
class TransposeCustom(torch.nn.Module):
    def __init__(self):
        super(TransposeCustom, self).__init__()
        
    def forward(self, x):
        return torch.transpose(x, 1, 2)

In [7]:
# config = BertConfig(is_decoder=True, 
#                     add_cross_attention=True,
#                     ff_layer='conv',
#                     conv_kernel=1,
#                     conv_kernel_num=3)

In [9]:
def _make_span_from_seeds(seeds, span, total=None):
    inds = list()
    for seed in seeds:
        for i in range(seed, seed + span):
            if total is not None and i >= total:
                break
            elif i not in inds:
                inds.append(int(i))
    return np.array(inds)

def _make_mask(shape, p, total, span, allow_no_inds=False):
    # num_mask_spans = np.sum(np.random.rand(total) < p)
    # num_mask_spans = int(p * total)
    mask = torch.zeros(shape, requires_grad=False, dtype=torch.bool)

    for i in range(shape[0]):
        mask_seeds = list()
        while not allow_no_inds and len(mask_seeds) == 0 and p > 0:
            mask_seeds = np.nonzero(np.random.rand(total) < p)[0]

        mask[i, _make_span_from_seeds(mask_seeds, span, total=total)] = True

    return mask

In [10]:
import math
class PositionalEncoding(torch.nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        
        position = torch.arange(max_len).unsqueeze(1)
        position = (position.T - position).T / max_len
        self.register_buffer('rel_position', position)
        
        self.conv = torch.nn.Conv1d(max_len, d_model, 25, padding=25 // 2, groups=16)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        rel_pos = self.conv(self.rel_position[:x.size(1), :x.size(1)][None])[0].T
        print(rel_pos.shape)
        x = x + rel_pos
        return self.dropout(x)

In [11]:
channels = ['Fp1', 'Fp2', 'FZ', 'FCz', 'Cz', 'Pz', 'O1', 'O2', 'F3', 'F4', 
               'F7', 'F8', 'C3', 'C4', 'T3', 'T4', 'P3', 'P4', 'T5', 'T6', 'A1', 'A2']
channels = [i.upper() for i in channels]
print(channels)

['FP1', 'FP2', 'FZ', 'FCZ', 'CZ', 'PZ', 'O1', 'O2', 'F3', 'F4', 'F7', 'F8', 'C3', 'C4', 'T3', 'T4', 'P3', 'P4', 'T5', 'T6', 'A1', 'A2']


In [167]:
class EEGEmbedder(torch.nn.Module):
    def __init__(self, num_labels=2, channels=channels):
        super(EEGEmbedder, self).__init__()
        self.config = BertConfig(is_decoder=False, 
                    add_cross_attention=False,
                    ff_layer='linear',
                    hidden_size=512,
                    num_attention_heads=8,
                    num_hidden_layers=8,
                    conv_kernel=1,
                    conv_kernel_num=1)
        self.channels = channels
        self.model = BertEncoder(self.config)

        self.clf = True
        self.masking = True
        
        
        self.pos_e = PositionalEncoding(512, max_len=6000)
        self.ch_embedder = torch.nn.Embedding(22, 512)
        self.ch_norm = torch.nn.LayerNorm(512)
        
        self.input_norm = torch.nn.LayerNorm(2)
        self.input_embedder = torch.nn.Sequential(
            TransposeCustom(),
            torch.nn.Conv1d(21, 32, 5, 2, padding=0),
            torch.nn.Conv1d(32, 64, 5, 2, padding=0),
            # TransposeCustom(),
            torch.nn.GroupNorm(64 // 2, 64),
            torch.nn.GELU(),
            # TransposeCustom(),
            torch.nn.Conv1d(64, 128, 3, 2, padding=0),
            torch.nn.Conv1d(128, 196, 3, 2, padding=0),
            # TransposeCustom(),
            torch.nn.GroupNorm(196 // 2, 196),
            torch.nn.GELU(),
            # TransposeCustom(),
            torch.nn.Conv1d(196, 256, 5, 1, padding=0),
            torch.nn.Conv1d(256, 384, 5, 1, padding=0),
            # TransposeCustom(),
            torch.nn.GroupNorm(384 // 2, 384),
            torch.nn.GELU(),
            # TransposeCustom(),
            torch.nn.Conv1d(384, 512, 5, 1, padding=0),
            torch.nn.Conv1d(512, 512, 1, 1, padding=0),
            torch.nn.GroupNorm(512 // 2, 512),
            torch.nn.GELU(),
            TransposeCustom(),
        # torch.nn.LeakyReLU(),
        )
        
        self.output_embedder = torch.nn.Linear(512, 512)
        self.transpose = TransposeCustom()
        
        self.mask_embedding = torch.nn.Parameter(torch.normal(0, 512**(-0.5), size=(512,)),
                                                   requires_grad=True)
        self.placeholder = torch.nn.Parameter(torch.normal(0, 512**(-0.5), size=(512,)),
                                                   requires_grad=True)
        # self.classifier = torch.nn.Linear()
        
        self.classifier = torch.nn.Sequential(
            # torch.nn.Linear(768, 256),
            torch.nn.Linear(self.config.hidden_size, self.config.num_labels),
            torch.nn.Linear(256, 2),
            torch.nn.Sigmoid()
        )
        
    def forward(self, inputs, attention_mask, ch_vector):
        embedding = self.input_embedder(inputs)
        # create embedding for two channel indexes and sumup them to a single one
        ch_embedding = self.ch_embedder(ch_vector).sum(1)
        ch_embedding = ch_embedding[:, None]
        embedding += ch_embedding
        embedding_unmasked = embedding.clone() # keep for loss calculation
        # perform masking
        if self.masking:
            mask = _make_mask((embedding.shape[0], embedding.shape[1]), 0.05, embedding.shape[1], 10)
            embedding[mask] = self.mask_embedding
        
        # additional vector for classification tasks later
        placeholder = torch.zeros((embedding.shape[0], 1, embedding.shape[2]), device=embedding.device)
        placeholder += self.placeholder
        embedding = torch.cat([placeholder, embedding], 1)
        encoder_output = self.model(embedding, output_hidden_states=True,
                               output_attentions=True)[0]
    
        encoder_output = self.output_embedder(encoder_output)

        if self.clf:
            logits = self.classifier(encoder_output[:, 0]) #sequence classification
            # logits = self.classifier(encoder_output)  # token classification ?
            return logits
            
        
        return encoder_output[:, 1:], embedding_unmasked

In [13]:
input = torch.rand(10, 10000)

In [14]:
model = EEGEmbedder(2)

# Data

In [15]:
def _generate_negatives(z):
    """Generate negative samples to compare each sequence location against"""
    num_negatives = 20
    batch_size, feat, full_len = z.shape
    z_k = z.permute([0, 2, 1]).reshape(-1, feat)
    with torch.no_grad():
        # candidates = torch.arange(full_len).unsqueeze(-1).expand(-1, self.num_negatives).flatten()
        negative_inds = torch.randint(0, full_len-1, size=(batch_size, full_len * num_negatives))
        # From wav2vec 2.0 implementation, I don't understand
        # negative_inds[negative_inds >= candidates] += 1

        for i in range(1, batch_size):
            negative_inds[i] += i * full_len

    z_k = z_k[negative_inds.view(-1)].view(batch_size, full_len, num_negatives, feat)
    return z_k, negative_inds

In [16]:
def _calculate_similarity( z, c, negatives):
    c = c[..., :].permute([0, 2, 1]).unsqueeze(-2)
    z = z.permute([0, 2, 1]).unsqueeze(-2)

    # In case the contextualizer matches exactly, need to avoid divide by zero errors
    negative_in_target = (c == negatives).all(-1)
    targets = torch.cat([c, negatives], dim=-2)

    logits = torch.nn.functional.cosine_similarity(z, targets, dim=-1) / 0.1
    if negative_in_target.any():
        logits[1:][negative_in_target] = float("-inf")

    return logits.view(-1, logits.shape[-1])

In [17]:
splitted_paths = ['/media/hdd/data/TUH_pretrain.filtered_1_40.v2.splited/{}'.format(i) for i in os.listdir('/media/hdd/data/TUH_pretrain.filtered_1_40.v2.splited/')]

In [18]:
sample = np.load(np.random.choice(splitted_paths, 1)[0], allow_pickle=True)

In [19]:
len(splitted_paths)

161710

In [20]:
def masking(ts):
    start_shift = np.random.choice(range(10))
    downsampling = 2
    indices = np.random.choice(np.array(list(range(110)))[start_shift::10][::downsampling], 5, replace=False)
    masked_idx = []
    for i in indices:
        masked_idx.extend(range(i, i+10))

    masked_idx = np.array(masked_idx)
    
    # mask = np.ones((6000, 2))
    # # desync some masked channels
    # ts_masked = ts.copy()
    # if np.random.choice([0, 1], p=[0.7, 0.3]):
    #     ts_masked[masked_idx, np.random.choice([0, 1])] *= 0
    # else:
    #     ts_masked[masked_idx] *= 0
        
    return None, masked_idx

In [21]:
def SelectChannels(sample, channels=channels):
    signal = sample['value_pure']
    
    channels_ids = [i for i, val in enumerate(sample['channels']) if i != 3 and val in channels]
    
    
    # choose 2 random channels
    # channels_to_train = np.random.choice(channels_ids, 2, replace=False)
    # use all available
    channels_to_train = channels_ids
    signal = signal[:, channels_to_train]
    return signal, channels_to_train

def SleepEdf():
    sample_ch = sample['channels']
    return
    


class TEST(torch.utils.data.Dataset):
    def __init__(self, path, preprocess_montage:Callable=SelectChannels) :
        super(TEST, self).__init__()
        self.main_path = path
        self.paths = path
        self.preprocess_montage=preprocess_montage
 

    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self, idx):
        path = self.paths[idx]
        # take 60s of recording with specified shift
        key = False
        while(key == False):
            try:
                sample = np.load(path, allow_pickle=True).item()
                key = True
            except Exception as e:
                print("Path: {} is broken ".format(path), e)
                path = np.random.choice(self.paths, 1)[0]
                

        signal, channels_to_train = self.preprocess_montage(sample)
        real_len = signal.shape[0]

        channels_vector = torch.tensor((channels_to_train))

        
        # remove normalization for now with within channel z-norm
        # sample_norm = (sample - tuh_filtered_stat_vals['min_vals_filtered'][channels_vector]) / (tuh_filtered_stat_vals['max_vals_filtered'][channels_vector] - tuh_filtered_stat_vals['min_vals_filtered'][channels_vector] + 1e-6)
        sample_norm_mean = signal.mean()
        sample_norm_std = np.std(signal)
        
        signal_norm = (signal - sample_norm_mean) / (sample_norm_std)
        
        if signal_norm.shape[0] < 6000:
            signal_norm = np.pad(signal_norm, ((0, 6000 - signal_norm.shape[0]), (0, 0)))
        
        attention_mask = torch.ones(6000)
        attention_mask[real_len:] = 0
        return {'anchor': torch.from_numpy(signal_norm).float(), 
                # 'label': sample_label, 
                # 'anchor_masked': torch.from_numpy(sample_masked).float(), 
                # 'mask': torch.tensor(mask),
                'channels': channels_vector,
                'attention_mask': attention_mask}

In [22]:
test_dataset = TEST(splitted_paths, preprocess_montage=functools.partial(SelectChannels, channels=channels))
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0, drop_last=True)

In [64]:
test_dataset.__getitem__(4)

{'anchor': tensor([[-2.0693e+00, -1.4693e+00, -9.1560e-03,  ..., -4.1035e+00,
          -3.9363e+00, -4.0551e+00],
         [-2.0585e+00, -1.4897e+00,  6.2453e-04,  ..., -4.0526e+00,
          -3.8592e+00, -4.0669e+00],
         [-1.8462e+00, -1.3069e+00,  2.4727e-01,  ..., -3.9245e+00,
          -3.7516e+00, -3.8181e+00],
         ...,
         [-4.8936e-01, -5.6444e-01, -5.1781e-01,  ...,  4.8580e-01,
           3.1735e-01,  7.0941e-01],
         [-9.7947e-02, -9.8642e-02, -1.3797e-01,  ...,  9.3987e-01,
           7.8519e-01,  1.2159e+00],
         [-1.8030e-01, -1.2938e-01, -1.6457e-01,  ...,  7.5456e-01,
           5.9439e-01,  1.1224e+00]]),
 'channels': tensor([ 0,  1,  2,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
         19, 20, 21]),
 'attention_mask': tensor([1., 1., 1.,  ..., 1., 1., 1.])}

# Training

In [24]:
model = EEGEmbedder()
# model.load_state_dict(torch.load(file))


In [25]:

from torch.optim.lr_scheduler import _LRScheduler
class NoamLR(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, d_model=512):
        self.warmup_steps = warmup_steps
        self.d_model = d_model
        super().__init__(optimizer)

    def get_lr(self):
        last_epoch = max(1, self.last_epoch)
        factor = min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5))
        # scale = self.warmup_steps ** 0.5 * min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5))
        # return [base_lr * scale for base_lr in self.base_lrs]
        return [base_lr * self.d_model ** (-0.5) * factor for base_lr in self.base_lrs]

In [26]:
cossim = torch.nn.CosineSimilarity(dim=-1)

def cosloss(anchor, real, negative):
    a = torch.exp(cossim(anchor, real)) / 0.1
    b = sum([torch.exp(cossim(anchor, negative[:, n])) / 0.1 for n in range(negative.shape[1])]) + 1e-6
    return -torch.log(a/b)

In [27]:
import random
def worker_init_fn(worker_id):
    torch_seed = torch.initial_seed()
    random.seed(torch_seed + worker_id)
    np.random.seed((torch_seed + worker_id) % 2**30)

train_dataset = TEST(splitted_paths[:-15000])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0, drop_last=True, worker_init_fn = worker_init_fn)

test_dataset = TEST(splitted_paths[-15000:])
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0, drop_last=True, worker_init_fn = worker_init_fn)


In [28]:
writer = SummaryWriter('logs')

In [29]:

model.train()

lr_d = 1
acc_size = 8
training_epochs1 = 100000 // len(train_loader)

optim = torch.optim.AdamW(model.parameters(), lr=lr_d)

model_test = torch.nn.DataParallel(model)
model_test.to('cuda:0')

loss_func = torch.nn.MSELoss()
# scheduler = torch.optim.lr_scheduler.LinearLR(optim, start_factor=1.0, end_factor=0.1, total_iters=training_epochs1*len(train_loader))
scheduler = NoamLR(optim, 100000, 512)

steps = 0

In [30]:
len(train_loader), training_epochs1, training_epochs1 * len(train_loader)

(2292, 43, 98556)

In [31]:
# model.cpu()(batch['anchor'][None], batch['mask'][None], batch['channels'][None])

In [32]:

# for epoch in range(training_epochs1):
#     mean_loss = 0
#     acc_step = 0
#     for batch in train_loader:
#         logits = model_test(
#             batch['anchor'],#.to('cuda:0'), 
#             None,
#             batch['channels'].long(),
#             clf=False)
#         ae, label = model_test(
#             batch['anchor'],#.to('cuda:0'), 
#             None, 
#             batch['channels'].long())
        
#         logits = _calculate_similarity(torch.transpose(label, 1, 2), torch.transpose(ae, 1, 2), _generate_negatives(torch.transpose(label, 1, 2))[0])

#         fake_labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
#         loss = torch.nn.CrossEntropyLoss()(logits, fake_labels) + 0.001 * label.pow(2).mean()

#         loss = loss.mean()
#         loss.backward()
#         mean_loss += loss.item()
#         acc_step += 1
#         steps += 1
#         # raise
#         if acc_step != 0 and acc_step % acc_size == 0:
#             optim.step()
#             scheduler.step()
#             optim.zero_grad()
#             if steps % 100 == 0:
#                 print('Loss/train\t{}'.format(mean_loss / acc_size))
#             writer.add_scalar('Loss/train', mean_loss / acc_size, steps)
#             mean_loss = 0
#         if steps != 0 and steps % 1000 == 0:
#             der = 0
#             try:
#                 with torch.no_grad():
#                     for batch in test_loader:
#                         ae, label = model_test(
#                             batch['anchor'],#.to('cuda:0'), 
#                             None, 
#                             batch['channels'].long())
#                         logits = _calculate_similarity(torch.transpose(label, 1, 2), torch.transpose(ae, 1, 2), _generate_negatives(torch.transpose(label, 1, 2))[0])

#                         fake_labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
#                         loss = torch.nn.CrossEntropyLoss()(logits, fake_labels) + 0.001 * label.pow(2).mean()

#                         loss = loss.mean() / acc_size
#                         der += loss
#                 der /= len(test_loader)
#                 writer.add_scalar('Loss/test', der, steps)

#                 print('Loss: {}\t'.format(der))
#             except:
#                 raise
#             # torch.save(model_test.module.state_dict(), 'models/step.pt'.format(steps))

## INFER

In [56]:
path10 = Path('/media/hdd/pretraining/synt_data/data_10')
path12 = Path('/media/hdd/pretraining/synt_data/data_12')

In [163]:
def _SelectChannels(sample, channels=channels):
    signal = sample['value_pure']
    
    channels_ids = [i for i, val in enumerate(sample['channels']) if i != 3 and val in channels]
    
    
    # choose 2 random channels
    # channels_to_train = np.random.choice(channels_ids, 2, replace=False)
    # use all available
    channels_to_train = channels_ids
    signal = signal[:, channels_to_train]
    return signal, channels_to_train


class SYNT(torch.utils.data.Dataset):
    def __init__(self, paths, labels, preprocess_montage:Callable=_SelectChannels) :
        super(SYNT, self).__init__()
        self.paths = paths
        self.labels = labels
        self.preprocess_montage=preprocess_montage
 

    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self, idx):
        path = self.paths[idx]
        # take 60s of recording with specified shift
        sample = np.load(path, allow_pickle=True)
        sample = {'value_pure':sample, 'channels':channels}
        
        # key = False
        # while(key == False):
        #     try:
        #         key = True
        #     except Exception as e:
        #         print("Path: {} is broken ".format(path), e)
        #         path = np.random.choice(self.paths, 1)[0]
                

        signal, channels_to_train = self.preprocess_montage(sample)
        real_len = signal.shape[0]

        channels_vector = torch.tensor((channels_to_train))

        
        # remove normalization for now with within channel z-norm
        # sample_norm = (sample - tuh_filtered_stat_vals['min_vals_filtered'][channels_vector]) / (tuh_filtered_stat_vals['max_vals_filtered'][channels_vector] - tuh_filtered_stat_vals['min_vals_filtered'][channels_vector] + 1e-6)
        sample_norm_mean = signal.mean()
        sample_norm_std = np.std(signal)
        
        signal_norm = (signal - sample_norm_mean) / (sample_norm_std)
        
        if signal_norm.shape[0] < 6000:
            signal_norm = np.pad(signal_norm, ((0, 6000 - signal_norm.shape[0]), (0, 0)))
        
        attention_mask = torch.ones(6000)
        attention_mask[real_len:] = 0
        return {'anchor': torch.from_numpy(signal_norm).float(), 
                'label': self.labels[idx], 
                # 'anchor_masked': torch.from_numpy(sample_masked).float(), 
                # 'mask': torch.tensor(mask),
                'channels': channels_vector,
                'attention_mask': attention_mask}

    
    
    
    
# paths = list(path10.glob("*.npy"))
# synt_dataset = SYNT(paths, labels=[0 for a in paths])
# synt_loader = torch.utils.data.DataLoader(synt_dataset, batch_size=64, shuffle=True, num_workers=0, drop_last=True, worker_init_fn = worker_init_fn)

# paths = list(path10.glob("*.npy"))
# synt_dataset12 = SYNT(paths, labels=np.ones(len(paths)))
# synt_loader12 = torch.utils.data.DataLoader(synt_dataset12, batch_size=64, shuffle=True, num_workers=0, drop_last=True, worker_init_fn = worker_init_fn)

paths_train = list(path10.glob("*.npy"))[:1000] + list(path12.glob("*.npy"))[:1000]
synt_dataset_train = SYNT(paths_train, labels=np.r_[np.ones(len(paths_train)), np.zeros(len(paths))])
synt_loader_train = torch.utils.data.DataLoader(synt_dataset_train, batch_size=64, shuffle=True, num_workers=0, drop_last=True, worker_init_fn = worker_init_fn)


paths_test = list(path10.glob("*.npy"))[-1000:] + list(path12.glob("*.npy"))[-1000:]
synt_dataset_test = SYNT(paths_test, labels=np.r_[np.ones(len(paths_test)), np.zeros(len(paths))])
synt_loader_test = torch.utils.data.DataLoader(synt_dataset_test, batch_size=64, shuffle=True, num_workers=0, drop_last=True, worker_init_fn = worker_init_fn)



In [164]:
len(list(path12.glob("*.npy")))

10000

In [162]:
class ClassifierEEGEmbedder(torch.nn.Module):
    def __init__(self, my_pretrained_model):
        super(ClassifierEEGEmbedder, self).__init__()
        self.pretrained = my_pretrained_model
        self.masking = False
        # self.my_new_layers = torch.nn.Sequential(torch.nn.Linear(self.pretrained.config.hidden_size, self.pretrained.config.num_labels),
        #                                         torch.nn.Linear(self.pretrained.config.hidden_size, self.pretrained.config.num_labels))
    
    
        self.my_new_layers = torch.nn.Sequential(
            # torch.nn.Linear(768, 256),
            torch.nn.Linear(512, 256),
            torch.nn.Linear(256, 2),
            torch.nn.Sigmoid()
        )
    def forward(self, inputs, attention_mask, ch_vector,):
        # x, label= self.pretrained(inputs, attention_mask, ch_vector,)
        x, label = self.pretrained(inputs, attention_mask, ch_vector)
        
        # print(x)
        x = self.my_new_layers(label)
        return x
    
    
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device = 'cpu'

    
model2 = ClassifierEEGEmbedder(model).to(device)
output = torch.tensor((0,0,2))
for synt in [synt_loader, synt_loader12]:
    cc = 0
    
    for batch in synt:
        cc +=1
        if cc > 1:
            break
        print(cc)

        sample=batch
        output = model2(sample['anchor'].to(device), 
               None, 
               sample['channels'].long().to(device)
              )
        output = torch.cat((output, output), 0)

1


ValueError: too many values to unpack (expected 2)

In [169]:
model = EEGEmbedder()
model.masking=False
model.clf=True
model.load_state_dict(torch.load('/media/hdd/pretraining/models/step.pt'), strict=False)


_IncompatibleKeys(missing_keys=['classifier.0.weight', 'classifier.0.bias', 'classifier.1.weight', 'classifier.1.bias'], unexpected_keys=[])

In [170]:
model = EEGEmbedder()
model.masking=False
model.clf=True
model.load_state_dict(torch.load('/media/hdd/pretraining/models/step.pt'), strict=False)

model = model.to(device)
output = torch.tensor((0,0,2))
for synt in [synt_loader, synt_loader12]:
    cc = 0
    
    for batch in synt:
        cc +=1
        if cc > 1:
            break
        print(cc)

        sample=batch
        output = model2(sample['anchor'].to(device), 
               None, 
               sample['channels'].long().to(device)
              )
        output = torch.cat((output, output), 0)

1


ValueError: too many values to unpack (expected 2)

In [171]:
model2(sample['anchor'].to(device), 
               None, 
               sample['channels'].long().to(device)
              )

ValueError: too many values to unpack (expected 2)

In [None]:
plt.plot(torch.median(output, axis=-2).values.detach().numpy())
plt.axvline(output.shape[0]/2, linestyle='--')


In [None]:

for epoch in range(training_epochs1):
    mean_loss = 0
    acc_step = 0
    for batch in train_loader:
        logits = model_test(
            batch['anchor'],#.to('cuda:0'), 
            None,
            batch['channels'].long(),
            clf=False)
        ae, label = model_test(
            batch['anchor'],#.to('cuda:0'), 
            None, 
            batch['channels'].long())
        
        logits = _calculate_similarity(torch.transpose(label, 1, 2), torch.transpose(ae, 1, 2), _generate_negatives(torch.transpose(label, 1, 2))[0])

        fake_labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
        loss = torch.nn.CrossEntropyLoss()(logits, fake_labels) + 0.001 * label.pow(2).mean()

        loss = loss.mean()
        loss.backward()
        mean_loss += loss.item()
        acc_step += 1
        steps += 1
        # raise
        if acc_step != 0 and acc_step % acc_size == 0:
            optim.step()
            scheduler.step()
            optim.zero_grad()
            if steps % 100 == 0:
                print('Loss/train\t{}'.format(mean_loss / acc_size))
            writer.add_scalar('Loss/train', mean_loss / acc_size, steps)
            mean_loss = 0
        if steps != 0 and steps % 1000 == 0:
            der = 0
            try:
                with torch.no_grad():
                    for batch in test_loader:
                        ae, label = model_test(
                            batch['anchor'],#.to('cuda:0'), 
                            None, 
                            batch['channels'].long())
                        logits = _calculate_similarity(torch.transpose(label, 1, 2), torch.transpose(ae, 1, 2), _generate_negatives(torch.transpose(label, 1, 2))[0])

                        fake_labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
                        loss = torch.nn.CrossEntropyLoss()(logits, fake_labels) + 0.001 * label.pow(2).mean()

                        loss = loss.mean() / acc_size
                        der += loss
                der /= len(test_loader)
                writer.add_scalar('Loss/test', der, steps)

                print('Loss: {}\t'.format(der))
            except:
                raise
            # torch.save(model_test.module.state_dict(), 'models/step.pt'.format(steps))