In [1]:
!nvidia-smi

Mon Nov  7 14:34:53 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.129.06   Driver Version: 470.129.06   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:19:00.0 Off |                  N/A |
| 30%   31C    P8    30W / 350W |      1MiB / 24268MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:65:00.0 Off |                  N/A |
| 30%   32C    P8    20W / 350W |   5606MiB / 24268MiB |      0%      Default |
|       

##### %config Completer.use_jedi = False

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# os.environ['http_proxy'] = "http://127.0.0.1:3128"
# os.environ['https_proxy'] = "http://127.0.0.1:3128"

In [3]:
import matplotlib.pyplot as plt
from torch.utils.data import Dataset

import numpy as np
import pandas as pd
import os

In [4]:
from torch.utils.tensorboard import SummaryWriter

In [5]:
import torch
from bert_conv_custom import BertConfig, BertEncoder

In [6]:
from transformers import BertModel

# Architecture

In [7]:
class TransposeCustom(torch.nn.Module):
    def __init__(self):
        super(TransposeCustom, self).__init__()
        
    def forward(self, x):
        return torch.transpose(x, 1, 2)

In [8]:
# config = BertConfig(is_decoder=True, 
#                     add_cross_attention=True,
#                     ff_layer='conv',
#                     conv_kernel=1,
#                     conv_kernel_num=3)

In [9]:
def _make_span_from_seeds(seeds, span, total=None):
    inds = list()
    for seed in seeds:
        for i in range(seed, seed + span):
            if total is not None and i >= total:
                break
            elif i not in inds:
                inds.append(int(i))
    return np.array(inds)

In [10]:
def _make_mask(shape, p, total, span, allow_no_inds=False):
    # num_mask_spans = np.sum(np.random.rand(total) < p)
    # num_mask_spans = int(p * total)
    mask = torch.zeros(shape, requires_grad=False, dtype=torch.bool)

    for i in range(shape[0]):
        mask_seeds = list()
        while not allow_no_inds and len(mask_seeds) == 0 and p > 0:
            mask_seeds = np.nonzero(np.random.rand(total) < p)[0]

        mask[i, _make_span_from_seeds(mask_seeds, span, total=total)] = True

    return mask

In [11]:
import math
class PositionalEncoding(torch.nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        
        position = torch.arange(max_len).unsqueeze(1)
        position = (position.T - position).T / max_len
        self.register_buffer('rel_position', position)
        
        self.conv = torch.nn.Conv1d(max_len, d_model, 25, padding=25 // 2, groups=16)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        rel_pos = self.conv(self.rel_position[:x.size(1), :x.size(1)][None])[0].T
        print(rel_pos.shape)
        x = x + rel_pos
        return self.dropout(x)

In [12]:
class EEGEmbedder(torch.nn.Module):
    def __init__(self):
        super(EEGEmbedder, self).__init__()
        config = BertConfig(is_decoder=False, 
                    add_cross_attention=False,
                    ff_layer='linear',
                    hidden_size=512,
                    num_attention_heads=8,
                    num_hidden_layers=8,
                    conv_kernel=1,
                    conv_kernel_num=1)
        self.model = BertEncoder(config)
        
        self.pos_e = PositionalEncoding(512, max_len=6000)
        self.ch_embedder = torch.nn.Embedding(len(mitsar_chls), 512)
        self.ch_norm = torch.nn.LayerNorm(512)
        
        self.input_norm = torch.nn.LayerNorm(2)
        self.input_embedder = torch.nn.Sequential(
            TransposeCustom(),
            torch.nn.Conv1d(2, 32, 5, 2, padding=0),
            torch.nn.Conv1d(32, 64, 5, 2, padding=0),
            # TransposeCustom(),
            torch.nn.GroupNorm(64 // 2, 64),
            torch.nn.GELU(),
            # TransposeCustom(),
            torch.nn.Conv1d(64, 128, 3, 2, padding=0),
            torch.nn.Conv1d(128, 196, 3, 2, padding=0),
            # TransposeCustom(),
            torch.nn.GroupNorm(196 // 2, 196),
            torch.nn.GELU(),
            # TransposeCustom(),
            torch.nn.Conv1d(196, 256, 5, 1, padding=0),
            torch.nn.Conv1d(256, 384, 5, 1, padding=0),
            # TransposeCustom(),
            torch.nn.GroupNorm(384 // 2, 384),
            torch.nn.GELU(),
            # TransposeCustom(),
            torch.nn.Conv1d(384, 512, 5, 1, padding=0),
            torch.nn.Conv1d(512, 512, 1, 1, padding=0),
            torch.nn.GroupNorm(512 // 2, 512),
            torch.nn.GELU(),
            TransposeCustom(),
        # torch.nn.LeakyReLU(),
        )
        # self.input_norm = torch.nn.LayerNorm(10)
        # self.output_embedder = torch.nn.Conv1d(512, 512, 1)
        self.output_embedder = torch.nn.Linear(512, 512)
        self.transpose = TransposeCustom()
        
        self.mask_embedding = torch.nn.Parameter(torch.normal(0, 512**(-0.5), size=(512,)),
                                                   requires_grad=True)
        self.classification = torch.nn.Sequential(
            torch.nn.Linear(512, 2),
            torch.nn.Softmax(-1)
        )
        
    def single_forward(self, inputs, attention_mask, ch_vector, placeholder):
        embedding = self.input_embedder(inputs)
        # create embedding for two channel indexes and sumup them to a single one
        ch_embedding = self.ch_embedder(ch_vector).sum(1)
        ch_embedding = ch_embedding[:, None]
        # print(embedding.shape, ch_embedding.shape)
        embedding += ch_embedding
        # embedding = self.ch_norm(embedding)
        # we lost some channel specific information
        # embedding_unmasked = embedding.clone()
        
        # mask = _make_mask((embedding.shape[0], embedding.shape[1]), 0.05, embedding.shape[1], 10)
        # embedding[mask] = self.mask_embedding
        
        # for b_i in range(embedding.shape[0]):
        #     embedding_masked[b_i][mask[b_i]] = self.mask_embedding
        
        embedding = torch.cat([placeholder, embedding], 1)
        encoder_output = self.model(embedding, output_hidden_states=True,
                               output_attentions=True)[0]
        encoder_output = self.output_embedder(encoder_output)
        # encoder_output = self.transpose(encoder_output)
        return encoder_output[:, 0], None
    
    def forward(self, a, mask, ch_vector, placeholder):
        a_downsampled_embedding, label = self.single_forward(a, mask, ch_vector, placeholder)
        pred = self.classification(a_downsampled_embedding)
        return pred, label
    
    def infer(self, a, ch_vector, placeholder):
        embedding = self.input_embedder(inputs)

        # create embedding for two channel indexes and sumup them to a single one
        ch_embedding = self.ch_embedder(ch_vector).sum(1)
        ch_embedding = ch_embedding[:, None]
        # print(embedding.shape, ch_embedding.shape)
        embedding += ch_embedding
        # embedding = self.ch_norm(embedding)
            
        embedding = torch.cat([placeholder, embedding], 1)
        encoder_output = self.model(embedding, output_hidden_states=True,
                               output_attentions=True)[0]
        
        encoder_output = self.output_embedder(self.transpose(encoder_output))
        return encoder_output, embedding

# Data

In [13]:
def _generate_negatives(z):
    """Generate negative samples to compare each sequence location against"""
    num_negatives = 20
    batch_size, feat, full_len = z.shape
    z_k = z.permute([0, 2, 1]).reshape(-1, feat)
    with torch.no_grad():
        # candidates = torch.arange(full_len).unsqueeze(-1).expand(-1, self.num_negatives).flatten()
        negative_inds = torch.randint(0, full_len-1, size=(batch_size, full_len * num_negatives))
        # From wav2vec 2.0 implementation, I don't understand
        # negative_inds[negative_inds >= candidates] += 1

        for i in range(1, batch_size):
            negative_inds[i] += i * full_len

    z_k = z_k[negative_inds.view(-1)].view(batch_size, full_len, num_negatives, feat)
    return z_k, negative_inds

In [14]:
def _calculate_similarity( z, c, negatives):
    c = c[..., :].permute([0, 2, 1]).unsqueeze(-2)
    z = z.permute([0, 2, 1]).unsqueeze(-2)

    # In case the contextualizer matches exactly, need to avoid divide by zero errors
    negative_in_target = (c == negatives).all(-1)
    targets = torch.cat([c, negatives], dim=-2)

    logits = torch.nn.functional.cosine_similarity(z, targets, dim=-1) / 0.1
    if negative_in_target.any():
        logits[1:][negative_in_target] = float("-inf")

    return logits.view(-1, logits.shape[-1])

In [15]:
def masking(ts):
    start_shift = np.random.choice(range(10))
    downsampling = 2
    indices = np.random.choice(np.array(list(range(110)))[start_shift::10][::downsampling], 5, replace=False)
    masked_idx = []
    for i in indices:
        masked_idx.extend(range(i, i+10))

    masked_idx = np.array(masked_idx)
    
    # mask = np.ones((6000, 2))
    # # desync some masked channels
    # ts_masked = ts.copy()
    # if np.random.choice([0, 1], p=[0.7, 0.3]):
    #     ts_masked[masked_idx, np.random.choice([0, 1])] *= 0
    # else:
    #     ts_masked[masked_idx] *= 0
        
    return None, masked_idx

In [16]:
# ({0: tensor(2.4250), 1: tensor(1.9767)},
#  {0: tensor(66.8642), 1: tensor(46.8854)})

mean_dict = {0: (0.9631),
 1: (1.0248),
 2: (1.3041),
 3: (0.),
 4: (1.5822),
 5: (1.7250),
 6: (0.9935),
 7: (0.9548),
 8: (0.7488),
 9: (1.3948),
 10: (0.8879),
 11: (1.0527),
 12: (1.3401),
 13: (1.5541),
 14: (1.2600),
 15: (1.0487),
 16: (0.7529),
 17: (1.6566),
 18: (0.9272),
 19: (1.2238),
 20: (1.2619),
 21: (1.5236)}

std_dict = {0: (64.1294),
 1: (64.1984),
 2: (45.9215),
 3: (0.),
 4: (45.1312),
 5: (51.7621),
 6: (43.5150),
 7: (39.7182),
 8: (46.8787),
 9: (49.0797),
 10: (52.2342),
 11: (51.9236),
 12: (50.7353),
 13: (52.1277),
 14: (48.8627),
 15: (42.7040),
 16: (46.5815),
 17: (60.2403),
 18: (41.6082),
 19: (44.6035),
 20: (82.8107),
 21: (53.5717)}

In [17]:
writer = SummaryWriter('./logs')

2022-11-07 14:34:57.399139: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [18]:
class TEST(torch.utils.data.Dataset):
    def __init__(self, main, labels, norm):
        super(TEST, self).__init__()
        self.main = main
        self.label = labels
        self.norm = norm

    def __len__(self):
        return len(self.main)
    
    def __getitem__(self, idx):
        # sample = torch.from_numpy(np.load(self.meta.iloc[idx]['path'])[:6000].astype(np.float32)).clone()
        sample = np.copy(self.main[idx])
        # sample = butter_bandpass_filter_v2(sample, 1, 40, 100)
        # sample_label = torch.tensor(np.copy(self.label[idx]))-1#torch.tensor(1 if self.main[idx]['label'] == 'work' else 0)
        sample_label = torch.tensor(0) if self.label[idx] == 1 else torch.tensor(1)
        # sample_label = label_map[sample['label']]
        
        channels = [mitsar_chls.index('T4'), mitsar_chls.index('T6')]
        sample = sample[:, [HSE_chls.index('T7'), HSE_chls.index('T8')]]
        # sample_min, sample_max = sample.min(0), sample.max(0)
        # sample = (sample - sample_min) / (sample_max - sample_min)
        sample[:, 0] -= self.norm['mean'][HSE_chls.index('T7')]
        sample[:, 1] -= self.norm['mean'][HSE_chls.index('T8')]
        sample[:, 0] /= self.norm['std'][HSE_chls.index('T7')]
        sample[:, 1] /= self.norm['std'][HSE_chls.index('T8')]
        
        # sample = butter_bandpass_filter_v2(sample, 1, 40, 100)
        sample = torch.from_numpy(sample[:3000].astype(np.float32)).clone()
        return {'anchor': sample, 
                'label': sample_label,
                'channels': torch.tensor(channels)}

In [19]:
HSE_chls = ['Fp1', 'Fz', 'F3', 'F7', 'FC5', 'FC1', 'C3', 'T7',
    'CP5', 'CP1', 'Pz', 'P3', 'P7', 'O1', 'Oz', 'O2', 'P4', 'P8', 'CP6',
    'CP2', 'Cz', 'C4', 'T8', 'FC6', 'FC2', 'F4', 'F8', 'FP2']

HSE_chls = [i.upper() for i in HSE_chls]

In [20]:
mitsar_chls = ['Fp1', 'Fp2', 'FZ', 'FCz', 'Cz', 'Pz', 'O1', 'O2', 'F3', 'F4', 
               'F7', 'F8', 'C3', 'C4', 'T3', 'T4', 'P3', 'P4', 'T5', 'T6', 'A1', 'A2']
mitsar_chls = [i.upper() for i in mitsar_chls]

In [21]:
train_data = np.load('/home/data/HSE_exp/processed/v2/train_signal.npy', allow_pickle=True)
test_data = np.load('/home/data/HSE_exp/processed/v2/test_signal.npy', allow_pickle=True)

train_label = np.load('/home/data/HSE_exp/processed/v2/train_label.npy', allow_pickle=True)
test_label = np.load('/home/data/HSE_exp/processed/v2/test_label.npy', allow_pickle=True)

In [22]:
train_data = [val for i, val in enumerate(train_data) if train_label[i] not in [0, 2]]
train_label = [val for i, val in enumerate(train_label) if val not in [0, 2]]

test_data = [val for i, val in enumerate(test_data) if test_label[i]  not in [0, 2]]
test_label = [val for i, val in enumerate(test_label) if val  not in [0, 2]]

In [23]:
# train_data = np.array([butter_bandpass_filter_v2(i, 1, 40, 100) for i in train_data])
# test_data = np.array([butter_bandpass_filter_v2(i, 1, 40, 100) for i in test_data])

In [24]:
channels_meta = {'mean': [], 'std': []}
channels_meta['mean'] = (np.concatenate([train_data, test_data]).reshape(np.concatenate([train_data, test_data]).shape[0] * np.concatenate([train_data, test_data]).shape[1], -1).mean(0))
channels_meta['std'] = (np.concatenate([train_data, test_data]).reshape(np.concatenate([train_data, test_data]).shape[0] * np.concatenate([train_data, test_data]).shape[1], -1).std(0))

In [25]:
train_data3 = np.load('/home/data/HSE_exp/processed/v2/train_signal.3.npy', allow_pickle=True)
test_data3 = np.load('/home/data/HSE_exp/processed/v2/test_signal.3.npy', allow_pickle=True)

train_label3 = np.load('/home/data/HSE_exp/processed/v2/train_label.3.npy', allow_pickle=True)
test_label3 = np.load('/home/data/HSE_exp/processed/v2/test_label.3.npy', allow_pickle=True)

In [26]:
train_data3 = [val for i, val in enumerate(train_data3) if train_label3[i] not in [0, 2]]
train_label3 = [val for i, val in enumerate(train_label3) if val not in [0, 2]]

test_data3 = [val for i, val in enumerate(test_data3) if test_label3[i]  not in [0, 2]]
test_label3 = [val for i, val in enumerate(test_label3) if val  not in [0, 2]]

In [27]:
channels_meta3 = {'mean': [], 'std': []}
channels_meta3['mean'] = (np.concatenate([train_data3, test_data3]).reshape(np.concatenate([train_data3, test_data3]).shape[0] * np.concatenate([train_data3, test_data3]).shape[1], -1).mean(0))
channels_meta3['std'] = (np.concatenate([train_data3, test_data3]).reshape(np.concatenate([train_data3, test_data3]).shape[0] * np.concatenate([train_data3, test_data3]).shape[1], -1).std(0))

In [28]:


# train_data = data[:len(data)//100 * 90]
# train_data = [i for i in train_data if i['eeg'].shape[0] > 500]
# test_data = data[len(data)//100 * 90:]
# test_data = [i for i in test_data if i['eeg'].shape[0] > 500]

In [29]:
len(train_data), len(test_data)

(329, 139)

In [30]:
test_dataset = TEST(test_data, test_label, channels_meta)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0, drop_last=True)

In [31]:
mitsar_chls = ['Fp1', 'Fp2', 'FZ', 'FCz', 'Cz', 'Pz', 'O1', 'O2', 'F3', 'F4', 
               'F7', 'F8', 'C3', 'C4', 'T3', 'T4', 'P3', 'P4', 'T5', 'T6', 'A1', 'A2']
mitsar_chls = [i.upper() for i in mitsar_chls]

# Training

In [32]:
model = EEGEmbedder()

In [33]:
model.load_state_dict(torch.load('../../../../Pretraining/v5/models2/step_159000.pt'), strict=False)

_IncompatibleKeys(missing_keys=['classification.0.weight', 'classification.0.bias'], unexpected_keys=[])

In [34]:

from torch.optim.lr_scheduler import _LRScheduler
class NoamLR(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, d_model=512):
        self.warmup_steps = warmup_steps
        self.d_model = d_model
        super().__init__(optimizer)

    def get_lr(self):
        last_epoch = max(1, self.last_epoch)
        factor = min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5))
        # scale = self.warmup_steps ** 0.5 * min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5))
        # return [base_lr * scale for base_lr in self.base_lrs]
        return [base_lr * self.d_model ** (-0.5) * factor for base_lr in self.base_lrs]

In [35]:
cossim = torch.nn.CosineSimilarity(dim=-1)

def cosloss(anchor, real, negative):
    a = torch.exp(cossim(anchor, real)) / 0.1
    b = sum([torch.exp(cossim(anchor, negative[:, n])) / 0.1 for n in range(negative.shape[1])]) + 1e-6
    return -torch.log(a/b)

In [36]:
import random
def worker_init_fn(worker_id):
    torch_seed = torch.initial_seed()
    random.seed(torch_seed + worker_id)
    np.random.seed((torch_seed + worker_id) % 2**30)


train_dataset = TEST(train_data, train_label, channels_meta)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0, drop_last=True, worker_init_fn = worker_init_fn)

test_dataset = TEST(test_data, test_label, channels_meta)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0, drop_last=True, worker_init_fn = worker_init_fn)

test_dataset3 = TEST(test_data3, test_label3, channels_meta3)
test_loader3 = torch.utils.data.DataLoader(test_dataset3, batch_size=32, shuffle=False, num_workers=0, drop_last=True, worker_init_fn = worker_init_fn)

In [37]:

model.train()

lr_d = 5e-6
acc_size = 1
training_epochs1 = 15000 // len(train_loader)

# model_test = torch.nn.DataParallel(model)
# optim = torch.optim.AdamW(model.parameters(), lr=lr_d)
optim = torch.optim.AdamW([{'params': model.model.parameters(), 'lr': 1e-7},
                          {'params': model.classification.parameters(), 'lr': 1e-4}])
# scheduler = NoamLR(optim, 3000, 512)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optim, max_lr=lr_d, total_steps=training_epochs1*len(train_loader))
model.to('cuda:0')

loss_func = torch.nn.CrossEntropyLoss()

steps = 0

In [38]:
len(train_loader), training_epochs1, training_epochs1 * len(train_loader)

(10, 1500, 15000)

In [39]:
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score, balanced_accuracy_score

In [40]:
# model.cpu()(batch['anchor'][None], batch['mask'][None], batch['channels'][None])

In [41]:
for epoch in range(training_epochs1):
    mean_loss = 0
    acc_step = 0
    for batch in train_loader:
        # batch = train_dataset.__getitem__(i)
        optim.zero_grad()
        placeholder = torch.zeros((batch['anchor'].shape[0], 1, 512)) - 5
        ae, _ = model(
            batch['anchor'].to('cuda:0'), 
            None, 
            batch['channels'].long().to('cuda:0'),
            placeholder.to('cuda:0'))
        loss = loss_func(ae.view(-1, 2), batch['label'].to('cuda:0').long())
        # loss = loss.mean() / acc_size
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        mean_loss = loss.item()
        acc_step += 1
        steps += 1
        optim.step()
        # scheduler.step()
        if steps % 500 == 0:
            print('Loss: {}\t'.format(mean_loss))
        
        if steps != 0 and steps % 1000 == 0:
            der = 0
            preds = []
            reals = []
            try:
                with torch.no_grad():
                    for batch in test_loader:
                        # batch = test_dataset.__getitem__(i)
                        placeholder = torch.zeros((batch['anchor'].shape[0], 1, 512)) - 5
                        ae, label = model(
                            batch['anchor'].to('cuda:0'), 
                            None, 
                            batch['channels'].long().to('cuda:0'),
                            placeholder.to('cuda:0'))
                        # loss_positive = loss_fct(ae, pe)
                        # loss_negative = loss_fct(ae, ne)
                        reals.extend(batch['label'])
                        preds.extend(ae.view(-1, 2))
                        loss = loss_func(ae.view(-1, 2), batch['label'].to('cuda:0').long())

                        loss = loss.mean() / acc_size
                        der += loss
                der /= len(test_loader)
                writer.add_scalar('Loss/test', der, steps)
                
                
                reals = np.array([i.tolist() for i in reals])
                preds = np.array([i.tolist() for i in preds])
                # preds[np.where(preds < 0.5)] = 0
                # preds[np.where(preds >= 0.5)] = 1
                print(precision_recall_fscore_support(reals, preds.argmax(-1)))

                print('Loss: {}\t'.format(der))
            except:
                raise
            # torch.save(model_test.module.state_dict(), '{}/step_{}.der_{}.pt'.format(model_path, steps, round(der, 3)))

Loss: 0.6289346218109131	
Loss: 0.5992487668991089	
(array([0.52112676, 0.52631579]), array([0.578125, 0.46875 ]), array([0.54814815, 0.49586777]), array([64, 64]))
Loss: 0.6803186535835266	
Loss: 0.4345097839832306	
Loss: 0.5174616575241089	
(array([0.57575758, 0.58064516]), array([0.59375, 0.5625 ]), array([0.58461538, 0.57142857]), array([64, 64]))
Loss: 0.6768562197685242	
Loss: 0.5666288137435913	
Loss: 0.53565913438797	
(array([0.55172414, 0.54285714]), array([0.5    , 0.59375]), array([0.52459016, 0.56716418]), array([64, 64]))
Loss: 0.7066055536270142	
Loss: 0.5124634504318237	
Loss: 0.5562436580657959	
(array([0.578125, 0.578125]), array([0.578125, 0.578125]), array([0.578125, 0.578125]), array([64, 64]))
Loss: 0.6942340135574341	
Loss: 0.5782976746559143	
Loss: 0.506091296672821	
(array([0.578125, 0.578125]), array([0.578125, 0.578125]), array([0.578125, 0.578125]), array([64, 64]))
Loss: 0.6966123580932617	
Loss: 0.47986263036727905	
Loss: 0.46121659874916077	
(array([0.5517

In [42]:
preds = []
reals = []
with torch.no_grad():
    for batch in test_loader3:
        # batch = test_dataset.__getitem__(i)
        placeholder = torch.zeros((batch['anchor'].shape[0], 1, 512)) - 5
        ae, label = model(
            batch['anchor'].to('cuda:0'), 
            None, 
            batch['channels'].long().to('cuda:0'),
            placeholder.to('cuda:0'))
        # loss_positive = loss_fct(ae, pe)
        # loss_negative = loss_fct(ae, ne)
        reals.extend(batch['label'])
        preds.extend(ae.view(-1, 2))
        loss = loss_func(ae.view(-1, 2), batch['label'].to('cuda:0').long())

        loss = loss.mean() / acc_size
        der += loss
der /= len(test_loader)
writer.add_scalar('Loss/test', der, steps)


reals = np.array([i.tolist() for i in reals])
preds = np.array([i.tolist() for i in preds])
# preds[np.where(preds < 0.5)] = 0
# preds[np.where(preds >= 0.5)] = 1
print(precision_recall_fscore_support(reals, preds.argmax(-1)))

(array([0.50485437, 0.72      ]), array([0.88135593, 0.26086957]), array([0.64197531, 0.38297872]), array([59, 69]))
