In [1]:
!nvidia-smi

Tue Nov  1 18:58:46 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.129.06   Driver Version: 470.129.06   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:19:00.0 Off |                  N/A |
| 30%   32C    P8    30W / 350W |      1MiB / 24268MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:65:00.0 Off |                  N/A |
| 30%   32C    P8    20W / 350W |      1MiB / 24268MiB |      0%      Default |
|       

##### %config Completer.use_jedi = False

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# os.environ['http_proxy'] = "http://127.0.0.1:3128"
# os.environ['https_proxy'] = "http://127.0.0.1:3128"

In [3]:
import matplotlib.pyplot as plt
from torch.utils.data import Dataset

import numpy as np
import pandas as pd
import os

In [4]:
from torch.utils.tensorboard import SummaryWriter

In [5]:
import torch
from bert_conv_custom import BertConfig, BertEncoder

In [6]:
from transformers import BertModel

# Architecture

In [7]:
class TransposeCustom(torch.nn.Module):
    def __init__(self):
        super(TransposeCustom, self).__init__()
        
    def forward(self, x):
        return torch.transpose(x, 1, 2)

In [8]:
# config = BertConfig(is_decoder=True, 
#                     add_cross_attention=True,
#                     ff_layer='conv',
#                     conv_kernel=1,
#                     conv_kernel_num=3)

In [9]:
def _make_span_from_seeds(seeds, span, total=None):
    inds = list()
    for seed in seeds:
        for i in range(seed, seed + span):
            if total is not None and i >= total:
                break
            elif i not in inds:
                inds.append(int(i))
    return np.array(inds)

In [10]:
def _make_mask(shape, p, total, span, allow_no_inds=False):
    # num_mask_spans = np.sum(np.random.rand(total) < p)
    # num_mask_spans = int(p * total)
    mask = torch.zeros(shape, requires_grad=False, dtype=torch.bool)

    for i in range(shape[0]):
        mask_seeds = list()
        while not allow_no_inds and len(mask_seeds) == 0 and p > 0:
            mask_seeds = np.nonzero(np.random.rand(total) < p)[0]

        mask[i, _make_span_from_seeds(mask_seeds, span, total=total)] = True

    return mask

In [11]:
import math
class PositionalEncoding(torch.nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        
        position = torch.arange(max_len).unsqueeze(1)
        position = (position.T - position).T / max_len
        self.register_buffer('rel_position', position)
        
        self.conv = torch.nn.Conv1d(max_len, d_model, 25, padding=25 // 2, groups=16)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        rel_pos = self.conv(self.rel_position[:x.size(1), :x.size(1)][None])[0].T
        print(rel_pos.shape)
        x = x + rel_pos
        return self.dropout(x)

In [12]:
class EEGEmbedder(torch.nn.Module):
    def __init__(self):
        super(EEGEmbedder, self).__init__()
        config = BertConfig(is_decoder=False, 
                    add_cross_attention=False,
                    ff_layer='linear',
                    hidden_size=512,
                    num_attention_heads=8,
                    num_hidden_layers=8,
                    conv_kernel=1,
                    conv_kernel_num=1)
        self.model = BertEncoder(config)
        
        self.pos_e = PositionalEncoding(512, max_len=6000)
        self.ch_embedder = torch.nn.Embedding(len(mitsar_chls), 512)
        self.ch_norm = torch.nn.LayerNorm(512)
        
        self.input_norm = torch.nn.LayerNorm(2)
        self.input_embedder = torch.nn.Sequential(
            TransposeCustom(),
            torch.nn.Conv1d(2, 32, 5, 2, padding=0),
            torch.nn.Conv1d(32, 64, 5, 2, padding=0),
            # TransposeCustom(),
            torch.nn.GroupNorm(64 // 2, 64),
            torch.nn.GELU(),
            # TransposeCustom(),
            torch.nn.Conv1d(64, 128, 3, 2, padding=0),
            torch.nn.Conv1d(128, 196, 3, 2, padding=0),
            # TransposeCustom(),
            torch.nn.GroupNorm(196 // 2, 196),
            torch.nn.GELU(),
            # TransposeCustom(),
            torch.nn.Conv1d(196, 256, 5, 1, padding=0),
            torch.nn.Conv1d(256, 384, 5, 1, padding=0),
            # TransposeCustom(),
            torch.nn.GroupNorm(384 // 2, 384),
            torch.nn.GELU(),
            # TransposeCustom(),
            torch.nn.Conv1d(384, 512, 5, 1, padding=0),
            torch.nn.Conv1d(512, 512, 1, 1, padding=0),
            torch.nn.GroupNorm(512 // 2, 512),
            torch.nn.GELU(),
            TransposeCustom(),
        # torch.nn.LeakyReLU(),
        )
        # self.input_norm = torch.nn.LayerNorm(10)
        # self.output_embedder = torch.nn.Conv1d(512, 512, 1)
        self.output_embedder = torch.nn.Linear(512, 512)
        self.transpose = TransposeCustom()
        
        self.mask_embedding = torch.nn.Parameter(torch.normal(0, 512**(-0.5), size=(512,)),
                                                   requires_grad=True)
        
    def single_forward(self, inputs, attention_mask, ch_vector, placeholder):
        embedding = self.input_embedder(inputs)
        # create embedding for two channel indexes and sumup them to a single one
        ch_embedding = self.ch_embedder(ch_vector).sum(1)
        ch_embedding = ch_embedding[:, None]
        # print(embedding.shape, ch_embedding.shape)
        embedding += ch_embedding
        # embedding = self.ch_norm(embedding)
        # we lost some channel specific information
        embedding_unmasked = embedding.clone()
        
        mask = _make_mask((embedding.shape[0], embedding.shape[1]), 0.05, embedding.shape[1], 10)
        embedding[mask] = self.mask_embedding
        
        # for b_i in range(embedding.shape[0]):
        #     embedding_masked[b_i][mask[b_i]] = self.mask_embedding
        
        embedding = torch.cat([placeholder, embedding], 1)
        encoder_output = self.model(embedding, output_hidden_states=True,
                               output_attentions=True)[0]
        encoder_output = self.output_embedder(encoder_output)
        # encoder_output = self.transpose(encoder_output)
        return encoder_output[:, 1:], embedding_unmasked
    
    def forward(self, a, mask, ch_vector, placeholder):
        a_downsampled_embedding, label = self.single_forward(a, mask, ch_vector, placeholder)
        
        return a_downsampled_embedding, label
    
    def infer(self, a, ch_vector, placeholder):
        embedding = self.input_embedder(inputs)

        # create embedding for two channel indexes and sumup them to a single one
        ch_embedding = self.ch_embedder(ch_vector).sum(1)
        ch_embedding = ch_embedding[:, None]
        # print(embedding.shape, ch_embedding.shape)
        embedding += ch_embedding
        # embedding = self.ch_norm(embedding)
            
        embedding = torch.cat([placeholder, embedding], 1)
        encoder_output = self.model(embedding, output_hidden_states=True,
                               output_attentions=True)[0]
        
        encoder_output = self.output_embedder(self.transpose(encoder_output))
        return encoder_output, embedding

# Data

In [13]:
def _generate_negatives(z):
    """Generate negative samples to compare each sequence location against"""
    num_negatives = 20
    batch_size, feat, full_len = z.shape
    z_k = z.permute([0, 2, 1]).reshape(-1, feat)
    with torch.no_grad():
        # candidates = torch.arange(full_len).unsqueeze(-1).expand(-1, self.num_negatives).flatten()
        negative_inds = torch.randint(0, full_len-1, size=(batch_size, full_len * num_negatives))
        # From wav2vec 2.0 implementation, I don't understand
        # negative_inds[negative_inds >= candidates] += 1

        for i in range(1, batch_size):
            negative_inds[i] += i * full_len

    z_k = z_k[negative_inds.view(-1)].view(batch_size, full_len, num_negatives, feat)
    return z_k, negative_inds

In [14]:
def _calculate_similarity( z, c, negatives):
    c = c[..., :].permute([0, 2, 1]).unsqueeze(-2)
    z = z.permute([0, 2, 1]).unsqueeze(-2)

    # In case the contextualizer matches exactly, need to avoid divide by zero errors
    negative_in_target = (c == negatives).all(-1)
    targets = torch.cat([c, negatives], dim=-2)

    logits = torch.nn.functional.cosine_similarity(z, targets, dim=-1) / 0.1
    if negative_in_target.any():
        logits[1:][negative_in_target] = float("-inf")

    return logits.view(-1, logits.shape[-1])

In [15]:
def masking(ts):
    start_shift = np.random.choice(range(10))
    downsampling = 2
    indices = np.random.choice(np.array(list(range(110)))[start_shift::10][::downsampling], 5, replace=False)
    masked_idx = []
    for i in indices:
        masked_idx.extend(range(i, i+10))

    masked_idx = np.array(masked_idx)
    
    # mask = np.ones((6000, 2))
    # # desync some masked channels
    # ts_masked = ts.copy()
    # if np.random.choice([0, 1], p=[0.7, 0.3]):
    #     ts_masked[masked_idx, np.random.choice([0, 1])] *= 0
    # else:
    #     ts_masked[masked_idx] *= 0
        
    return None, masked_idx

In [16]:
# ({0: tensor(2.4250), 1: tensor(1.9767)},
#  {0: tensor(66.8642), 1: tensor(46.8854)})

mean_dict = {0: (0.9631),
 1: (1.0248),
 2: (1.3041),
 3: (0.),
 4: (1.5822),
 5: (1.7250),
 6: (0.9935),
 7: (0.9548),
 8: (0.7488),
 9: (1.3948),
 10: (0.8879),
 11: (1.0527),
 12: (1.3401),
 13: (1.5541),
 14: (1.2600),
 15: (1.0487),
 16: (0.7529),
 17: (1.6566),
 18: (0.9272),
 19: (1.2238),
 20: (1.2619),
 21: (1.5236)}

std_dict = {0: (64.1294),
 1: (64.1984),
 2: (45.9215),
 3: (0.),
 4: (45.1312),
 5: (51.7621),
 6: (43.5150),
 7: (39.7182),
 8: (46.8787),
 9: (49.0797),
 10: (52.2342),
 11: (51.9236),
 12: (50.7353),
 13: (52.1277),
 14: (48.8627),
 15: (42.7040),
 16: (46.5815),
 17: (60.2403),
 18: (41.6082),
 19: (44.6035),
 20: (82.8107),
 21: (53.5717)}

In [17]:
writer = SummaryWriter('./logs')

2022-11-01 18:58:50.882690: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [53]:
class TEST(torch.utils.data.Dataset):
    def __init__(self, path):
        super(TEST, self).__init__()
        self.main_path = path
        self.paths = path
        # self.paths = ['{}/{}'.format(self.main_path, i) for i in os.listdir(self.main_path)]

    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self, idx):
        path = self.paths[idx]
        # take 60s of recording with specified shift
        key = False
        while(key == False):
            try:
                sample = np.load(path, allow_pickle=True).item()['value']
                key = True
            except Exception as e:
                print("Path: {} is broken ".format(path), e)
                path = np.random.choice(self.paths, 1)[0]
                # sample = np.load(path, allow_pickle=True).item()['value']
        real_len = sample.shape[0]
        # if np.random.choice([0, 1], p=[0.9, 0.1]):
        #     real_len = np.random.randint(real_len // 2, real_len)
            
        sample = sample[:real_len]
        # sample = torch.from_numpy(sample[:6000].astype(np.float32)).clone()
        channels_ids = [i for i, val in enumerate(mitsar_chls) if i != 3]
        
        # choose 2 random channels
        channels_to_train = np.random.choice(channels_ids, 2, replace=False)
        channels_vector = torch.tensor((channels_to_train))
        sample = sample[:, channels_to_train]
        
        sample_norm = (sample - tuh_filtered_stat_vals['min_vals_filtered'][channels_vector]) / (tuh_filtered_stat_vals['max_vals_filtered'][channels_vector] - tuh_filtered_stat_vals['min_vals_filtered'][channels_vector] + 1e-6)
        # sample_norm = sample_norm * 2 - 1
        # _, mask = masking(sample_norm)
        
        if sample_norm.shape[0] < 6000:
            sample_norm = np.pad(sample_norm, ((0, 6000 - sample_norm.shape[0]), (0, 0)))
        
        attention_mask = torch.ones(6000)
        attention_mask[real_len:] = 0
        return {'anchor': torch.from_numpy(sample_norm).float(), 
                # 'label': sample_label, 
                # 'anchor_masked': torch.from_numpy(sample_masked).float(), 
                # 'mask': torch.tensor(mask),
                'channels': channels_vector,
                'attention_mask': attention_mask}

In [19]:
tuh_filtered_stat_vals = np.load('/home/data/TUH_pretrain.filtered_1_40/stat_vals.npy', allow_pickle=True).item()

In [20]:
# file_paths = []
# for path1 in os.listdir('/home/data/TUH_pretrain.filtered_1_40/'):
#     if 'npy' not in path1:
#         for path2 in os.listdir('/home/data/TUH_pretrain.filtered_1_40/{}'.format(path1)):
#             for path3 in os.listdir('/home/data/TUH_pretrain.filtered_1_40//{}/{}'.format(path1, path2)):
#                 file_paths.append('/home/data/TUH_pretrain.filtered_1_40//{}/{}/{}'.format(path1, path2, path3))

In [21]:
# len(file_paths)

In [22]:
# tmp = {}

# for path in file_paths:
#     tmp[path] = np.load(path, allow_pickle=True).item()

In [23]:
# sizes = np.load('/home/data/TUH_pretrain.filtered_1_40/sizes.npy', allow_pickle=True).item()

In [24]:
# file_paths = [i for i in file_paths if sizes[i][0] > 10000]

In [25]:
# splitted_paths = []
# for path in file_paths:
#     shape = sizes[path][0]
#     for i in range(0, shape, 6000):
#         if (shape - i > 6000):
#             splitted_paths.append((path, i))

In [26]:
splitted_paths = ['/home/data/TUH_pretrain.filtered_1_40.splited/{}'.format(i) for i in os.listdir('/home/data/TUH_pretrain.filtered_1_40.splited/')]

In [52]:
np.random.choice(splitted_paths, 1)[0]

'/home/data/TUH_pretrain.filtered_1_40.splited/2585_264000.npy'

In [27]:
len(splitted_paths)

365495

In [28]:
mitsar_chls = ['Fp1', 'Fp2', 'FZ', 'FCz', 'Cz', 'Pz', 'O1', 'O2', 'F3', 'F4', 
               'F7', 'F8', 'C3', 'C4', 'T3', 'T4', 'P3', 'P4', 'T5', 'T6', 'A1', 'A2']
mitsar_chls = [i.upper() for i in mitsar_chls]

In [29]:
test_dataset = TEST(splitted_paths)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0, drop_last=True)

# Training

In [30]:
model = EEGEmbedder()

In [31]:

from torch.optim.lr_scheduler import _LRScheduler
class NoamLR(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, d_model=512):
        self.warmup_steps = warmup_steps
        self.d_model = d_model
        super().__init__(optimizer)

    def get_lr(self):
        last_epoch = max(1, self.last_epoch)
        factor = min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5))
        # scale = self.warmup_steps ** 0.5 * min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5))
        # return [base_lr * scale for base_lr in self.base_lrs]
        return [base_lr * self.d_model ** (-0.5) * factor for base_lr in self.base_lrs]

In [32]:
cossim = torch.nn.CosineSimilarity(dim=-1)

def cosloss(anchor, real, negative):
    a = torch.exp(cossim(anchor, real)) / 0.1
    b = sum([torch.exp(cossim(anchor, negative[:, n])) / 0.1 for n in range(negative.shape[1])]) + 1e-6
    return -torch.log(a/b)

In [54]:
import random
def worker_init_fn(worker_id):
    torch_seed = torch.initial_seed()
    random.seed(torch_seed + worker_id)
    np.random.seed((torch_seed + worker_id) % 2**30)

train_dataset = TEST(splitted_paths[:-15000])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0, drop_last=True, worker_init_fn = worker_init_fn)

test_dataset = TEST(splitted_paths[-15000:])
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0, drop_last=True, worker_init_fn = worker_init_fn)


In [33]:

model.train()

lr_d = 1
acc_size = 1
training_epochs1 = 150000 // len(train_loader)

optim = torch.optim.AdamW(model.parameters(), lr=lr_d)

model_test = torch.nn.DataParallel(model)
model_test.to('cuda:0')

loss_func = torch.nn.MSELoss()
# scheduler = torch.optim.lr_scheduler.LinearLR(optim, start_factor=1.0, end_factor=0.1, total_iters=training_epochs1*len(train_loader))
scheduler = NoamLR(optim, 100000, 512)

steps = 0

In [34]:
len(train_loader), training_epochs1, training_epochs1 * len(train_loader)

(5476, 27, 147852)

In [35]:
# model.cpu()(batch['anchor'][None], batch['mask'][None], batch['channels'][None])

In [None]:

for epoch in range(training_epochs1):
    mean_loss = 0
    acc_step = 0
    for batch in train_loader:
        # batch = train_dataset.__getitem__(i)
        placeholder = torch.zeros((batch['anchor'].shape[0], 1, 512)) - 5
        ae, label = model_test(
            batch['anchor'],#.to('cuda:0'), 
            None, 
            batch['channels'].long(),
            placeholder)#.to('cuda:0'))
        
        logits = _calculate_similarity(torch.transpose(label, 1, 2), torch.transpose(ae, 1, 2), _generate_negatives(torch.transpose(label, 1, 2))[0])

        fake_labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
        loss = torch.nn.CrossEntropyLoss()(logits, fake_labels) + 0.001 * label.pow(2).mean()

        loss = loss.mean()
        loss.backward()
        mean_loss += loss.item()
        acc_step += 1
        steps += 1
        # raise
        if acc_step != 0 and acc_step % acc_size == 0:
            optim.step()
            scheduler.step()
            optim.zero_grad()
            if steps % 100 == 0:
                print('Loss/train\t{}'.format(mean_loss / acc_size))
            writer.add_scalar('Loss/train', mean_loss / acc_size, steps)
            mean_loss = 0
        if steps != 0 and steps % 1000 == 0:
            der = 0
            try:
                with torch.no_grad():
                    for batch in test_loader:
                        # batch = test_dataset.__getitem__(i)
                        placeholder = torch.zeros((batch['anchor'].shape[0], 1, 512)) - 5
                        ae, label = model_test(
                            batch['anchor'],#.to('cuda:0'), 
                            None, 
                            batch['channels'].long(),#.to('cuda:0'),
                            placeholder)
                        # loss_positive = loss_fct(ae, pe)
                        # loss_negative = loss_fct(ae, ne)
                        logits = _calculate_similarity(torch.transpose(label, 1, 2), torch.transpose(ae, 1, 2), _generate_negatives(torch.transpose(label, 1, 2))[0])

                        fake_labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
                        loss = torch.nn.CrossEntropyLoss()(logits, fake_labels) + 0.001 * label.pow(2).mean()

                        loss = loss.mean() / acc_size
                        der += loss
                der /= len(test_loader)
                writer.add_scalar('Loss/test', der, steps)

                print('Loss: {}\t'.format(der))
            except:
                raise
            torch.save(model_test.module.state_dict(), 'models2/step_{}.pt'.format(steps))

Loss/train	1.1030877828598022
Loss/train	1.1205166578292847
Loss/train	1.172516942024231
Loss/train	1.148479700088501
Loss/train	1.1080737113952637
Loss/train	1.1164870262145996
Loss/train	1.1132405996322632
Loss: 1.1013708114624023	
Loss/train	1.0995947122573853
Loss/train	1.1270442008972168
Loss/train	1.0850714445114136
Loss/train	1.0480189323425293
Loss/train	1.038840413093567
Loss/train	1.047231674194336
Loss/train	1.0649646520614624
Loss/train	1.0510094165802002
Loss/train	1.071996808052063
Loss/train	1.0528682470321655
Loss: 1.0501459836959839	
Loss/train	1.0457582473754883
Loss/train	1.0784471035003662
Loss/train	0.9939504861831665
Loss/train	1.0439149141311646
Loss/train	0.9971293210983276
Loss/train	1.04259192943573
Loss/train	1.0033913850784302
Loss/train	1.0128108263015747
Loss/train	1.0224108695983887
Loss/train	1.02928626537323
Loss: 1.0030395984649658	
Loss/train	1.0195443630218506
Loss/train	1.0206608772277832
Loss/train	0.9925925135612488
Loss/train	0.9921804666519165
L

In [59]:
der

tensor(0.5176, device='cuda:0')

In [57]:
steps

159219

In [47]:
np.load('/home/data/TUH_pretrain.filtered_1_40.splited/2990_174000.npy', allow_pickle=True)

array({'value': array([[ -3.85031008,   2.85196812,   2.07325583, ...,  -7.50037225,
        -10.90961471,  -8.02874688],
       [-14.41600966,  -6.85561872,  -1.99783744, ..., -12.02037489,
        -20.83239401,  -7.28345165],
       [-22.68984145, -16.19871822,  -5.15646806, ..., -17.4048751 ,
        -25.79506833, -21.20112015],
       ...,
       [  0.83328995,   6.0946769 ,  -1.01211029, ...,   5.56320695,
          3.40607593,   9.95040555],
       [ -5.23044887,   2.98819828,  -2.53847765, ...,   1.96995798,
          1.19625315,   6.24915094],
       [ -6.41559363,   0.44634677,  -3.84797447, ...,  -2.32357481,
         -2.30819229,  -0.28246068]]), 'date': 's004_2016_08_25', 'id': '147_00014747', 'ref_type': '01_tcp_ar', 'channels': ['FP1', 'FP2', 'FZ', 'FCZ', 'CZ', 'PZ', 'O1', 'O2', 'F3', 'F4', 'F7', 'F8', 'C3', 'C4', 'T3', 'T4', 'P3', 'P4', 'T5', 'T6', 'A1', 'A2']},
      dtype=object)

In [37]:
model.train()

lr_d = 1e-5
acc_size = 1
training_epochs1 = 150000 // len(train_loader)

optim = torch.optim.AdamW(model.parameters(), lr=lr_d)

model_test = torch.nn.DataParallel(model)
model_test.to('cuda:0')

loss_func = torch.nn.MSELoss()

In [38]:

for epoch in range(training_epochs1):
    mean_loss = 0
    acc_step = 0
    for batch in train_loader:
        # batch = train_dataset.__getitem__(i)
        placeholder = torch.zeros((batch['anchor'].shape[0], 1, 512)) - 5
        ae, label = model_test(
            batch['anchor'],#.to('cuda:0'), 
            batch['mask'], 
            batch['channels'].long(),
            placeholder)#.to('cuda:0'))
        # loss_value = loss_fct(output, batch['labels'].to('cuda:0'))
        # loss_positive = loss_fct(ae, pe)
        # loss_negative = loss_fct(ae, ne)
        reshaped_indexed_real = []
        reshaped_indexed_pred = []
        reshaped_indexed_negative = []
        for b_i in range(ae.shape[0]):
            reshaped_indexed_pred.append(ae[b_i][batch['mask'][b_i][::10]][None])
            reshaped_indexed_real.append(label[b_i][batch['mask'][b_i][::10]][None])
            reshaped_indexed_negative.append(label[b_i][(batch['mask'][b_i] + 20) % 132][None])
        reshaped_indexed_pred = torch.cat(reshaped_indexed_pred, 0)
        reshaped_indexed_real = torch.cat(reshaped_indexed_real, 0)
        reshaped_indexed_negative = torch.cat(reshaped_indexed_negative, 0)
        # raise
        loss = cosloss(
            reshaped_indexed_pred, 
            reshaped_indexed_real, 
            reshaped_indexed_negative.reshape(
                reshaped_indexed_negative.shape[0], 
                reshaped_indexed_negative.shape[1] // 5, 
                reshaped_indexed_negative.shape[1] // 10, 
                reshaped_indexed_negative.shape[2]))
        # loss = loss_func(reshaped_indexed_pred[indexes_fine][:, :, :2], reshaped_indexed_real[indexes_fine][:, :, :2].to('cuda:0'))

        # hard_loss = loss[loss > 0.0]
        # if (len(hard_loss) == 0):
        #     hard_loss = loss
        # hard_loss = loss
        loss = loss.mean()
        loss.backward()
        mean_loss += loss.item()
        acc_step += 1
        steps += 1
        # raise
        if acc_step != 0 and acc_step % acc_size == 0:
            optim.step()
            scheduler.step()
            optim.zero_grad()
            if steps % 100 == 0:
                print('Loss/train\t{}'.format(mean_loss / acc_size))
            writer.add_scalar('Loss/train', mean_loss / acc_size, steps)
            mean_loss = 0
        if steps != 0 and steps % 1000 == 0:
            der = 0
            try:
                with torch.no_grad():
                    for batch in test_loader:
                        # batch = test_dataset.__getitem__(i)
                        placeholder = torch.zeros((batch['anchor'].shape[0], 1, 512)) - 5
                        ae, label = model_test(
                            batch['anchor'],#.to('cuda:0'), 
                            batch['mask'], 
                            batch['channels'].long(),#.to('cuda:0'),
                            placeholder)
                        # loss_positive = loss_fct(ae, pe)
                        # loss_negative = loss_fct(ae, ne)
                        reshaped_indexed_real = []
                        reshaped_indexed_pred = []
                        reshaped_indexed_negative = []
                        for b_i in range(ae.shape[0]):
                            reshaped_indexed_pred.append(ae[b_i][batch['mask'][b_i][::10]][None])
                            reshaped_indexed_real.append(label[b_i][batch['mask'][b_i][::10]][None])
                            reshaped_indexed_negative.append(label[b_i][(batch['mask'][b_i] + 20) % 132][None])
                        reshaped_indexed_pred = torch.cat(reshaped_indexed_pred, 0)
                        reshaped_indexed_real = torch.cat(reshaped_indexed_real, 0)
                        reshaped_indexed_negative = torch.cat(reshaped_indexed_negative, 0)

                        loss = cosloss(
                            reshaped_indexed_pred, 
                            reshaped_indexed_real, 
                            reshaped_indexed_negative.reshape(
                                reshaped_indexed_negative.shape[0], 
                                reshaped_indexed_negative.shape[1] // 5, 
                                reshaped_indexed_negative.shape[1] // 10, 
                                reshaped_indexed_negative.shape[2]))

                        loss = loss.mean() / acc_size
                        der += loss
                der /= len(test_loader)
                writer.add_scalar('Loss/test', der, steps)

                print('Loss: {}\t'.format(der))
            except:
                raise
            torch.save(model_test.module.state_dict(), 'models/step_{}.pt'.format(steps))

Loss/train	1.8929529190063477
Loss/train	1.9144232273101807
Loss/train	1.8872368335723877
Loss/train	1.898066759109497
Loss/train	1.9088211059570312
Loss/train	1.896066665649414
Loss/train	1.903018593788147
Loss/train	1.8914568424224854
Loss/train	1.903407096862793
Loss: 1.7999563217163086	
Loss/train	1.8804712295532227
Loss/train	1.8972877264022827
Loss/train	1.8831760883331299
Loss/train	1.9031308889389038
Loss/train	1.8904520273208618
Loss/train	1.9119493961334229
Loss/train	1.9052908420562744
Loss/train	1.896553874015808
Loss/train	1.9152153730392456
Loss/train	1.869958758354187
Loss: 1.7969849109649658	
Loss/train	1.8917416334152222
Loss/train	1.9191652536392212
Loss/train	1.8704347610473633
Loss/train	1.8648254871368408
Loss/train	1.8790819644927979
Loss/train	1.910555124282837
Loss/train	1.8742694854736328
Loss/train	1.905527949333191
Loss/train	1.8932743072509766
Loss/train	1.8955141305923462
Loss: 1.7965463399887085	
Loss/train	1.860345482826233
Loss/train	1.887194037437439
Lo

KeyboardInterrupt: 