In [1]:
!nvidia-smi

Sat Aug 26 18:37:45 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.78.01    Driver Version: 525.78.01    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:19:00.0 Off |                  N/A |
| 30%   32C    P8    30W / 350W |      6MiB / 24576MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:65:00.0 Off |                  N/A |
| 30%   33C    P8    21W / 350W |      6MiB / 24576MiB |      0%      Defaul

##### %config Completer.use_jedi = False

In [6]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# os.environ['http_proxy'] = "http://127.0.0.1:3128"
# os.environ['https_proxy'] = "http://127.0.0.1:3128"

In [7]:
import matplotlib.pyplot as plt
from torch.utils.data import Dataset

import numpy as np
import pandas as pd
import os

In [8]:
from torch.utils.tensorboard import SummaryWriter

In [9]:
import torch
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertEncoder

# Architecture

In [10]:
class TransposeCustom(torch.nn.Module):
    def __init__(self):
        super(TransposeCustom, self).__init__()
        
    def forward(self, x):
        return torch.transpose(x, 1, 2)

In [11]:
# config = BertConfig(is_decoder=True, 
#                     add_cross_attention=True,
#                     ff_layer='conv',
#                     conv_kernel=1,
#                     conv_kernel_num=3)

In [12]:
def _make_span_from_seeds(seeds, span, total=None):
    inds = list()
    for seed in seeds:
        for i in range(seed, seed + span):
            if total is not None and i >= total:
                break
            elif i not in inds:
                inds.append(int(i))
    return np.array(inds)

In [13]:
def _make_mask(shape, p, total, span, allow_no_inds=False):
    # num_mask_spans = np.sum(np.random.rand(total) < p)
    # num_mask_spans = int(p * total)
    mask = torch.zeros(shape, requires_grad=False, dtype=torch.bool)

    for i in range(shape[0]):
        mask_seeds = list()
        while not allow_no_inds and len(mask_seeds) == 0 and p > 0:
            mask_seeds = np.nonzero(np.random.rand(total) < p)[0]

        mask[i, _make_span_from_seeds(mask_seeds, span, total=total)] = True

    return mask

In [14]:
import math
class PositionalEncoding(torch.nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        
        position = torch.arange(max_len).unsqueeze(1)
        position = (position.T - position).T / max_len
        self.register_buffer('rel_position', position)
        
        self.conv = torch.nn.Conv1d(max_len, d_model, 25, padding=25 // 2, groups=16)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        rel_pos = self.conv(self.rel_position[:x.size(1), :x.size(1)][None])[0].T
        print(rel_pos.shape)
        x = x + rel_pos
        return self.dropout(x)

In [15]:
class EEGEmbedder(torch.nn.Module):
    def __init__(self):
        super(EEGEmbedder, self).__init__()
        config = BertConfig(is_decoder=False, 
                    add_cross_attention=False,
                    ff_layer='linear',
                    hidden_size=512,
                    num_attention_heads=8,
                    num_hidden_layers=8,
                    conv_kernel=1,
                    conv_kernel_num=1)
        self.model = BertEncoder(config)
        
        self.pos_e = PositionalEncoding(512, max_len=6000)
        self.ch_embedder = torch.nn.Embedding(len(mitsar_chls), 512)
        self.ch_norm = torch.nn.LayerNorm(512)
        
        self.input_norm = torch.nn.LayerNorm(2)
        self.input_embedder = torch.nn.Sequential(
            TransposeCustom(),
            torch.nn.Conv1d(21, 32, 5, 2, padding=0),
            torch.nn.Conv1d(32, 64, 5, 2, padding=0),
            # TransposeCustom(),
            torch.nn.GroupNorm(64 // 2, 64),
            torch.nn.GELU(),
            # TransposeCustom(),
            torch.nn.Conv1d(64, 128, 3, 2, padding=0),
            torch.nn.Conv1d(128, 196, 3, 2, padding=0),
            # TransposeCustom(),
            torch.nn.GroupNorm(196 // 2, 196),
            torch.nn.GELU(),
            # TransposeCustom(),
            torch.nn.Conv1d(196, 256, 5, 1, padding=0),
            torch.nn.Conv1d(256, 384, 5, 1, padding=0),
            # TransposeCustom(),
            torch.nn.GroupNorm(384 // 2, 384),
            torch.nn.GELU(),
            # TransposeCustom(),
            torch.nn.Conv1d(384, 512, 5, 1, padding=0),
            torch.nn.Conv1d(512, 512, 1, 1, padding=0),
            torch.nn.GroupNorm(512 // 2, 512),
            torch.nn.GELU(),
            TransposeCustom(),
        # torch.nn.LeakyReLU(),
        )
        
        
        self.output_embedder = torch.nn.Linear(512, 512)
        self.transpose = TransposeCustom()
        
        self.mask_embedding = torch.nn.Parameter(torch.normal(0, 512**(-0.5), size=(512,)),
                                                   requires_grad=True)
        self.placeholder = torch.nn.Parameter(torch.normal(0, 512**(-0.5), size=(512,)),
                                                   requires_grad=True)
        
    def forward(self, inputs, attention_mask, ch_vector):
        embedding = self.input_embedder(inputs)
        # create embedding for two channel indexes and sumup them to a single one
        ch_embedding = self.ch_embedder(ch_vector).sum(1)
        ch_embedding = ch_embedding[:, None]
        embedding += ch_embedding
        
        # perform masking
        embedding_unmasked = embedding.clone() # keep for loss calculation
        mask = _make_mask((embedding.shape[0], embedding.shape[1]), 0.05, embedding.shape[1], 10)
        embedding[mask] = self.mask_embedding
        
        # additional vector for classification tasks later
        placeholder = torch.zeros((embedding.shape[0], 1, embedding.shape[2]), device=embedding.device)
        placeholder += self.placeholder
        embedding = torch.cat([placeholder, embedding], 1)
        encoder_output = self.model(embedding, output_hidden_states=True,
                               output_attentions=True)[0]
        
        encoder_output = self.output_embedder(encoder_output)
        
        return encoder_output[:, 1:], embedding_unmasked

# Data

In [16]:
def _generate_negatives(z):
    """Generate negative samples to compare each sequence location against"""
    num_negatives = 20
    batch_size, feat, full_len = z.shape
    z_k = z.permute([0, 2, 1]).reshape(-1, feat)
    with torch.no_grad():
        # candidates = torch.arange(full_len).unsqueeze(-1).expand(-1, self.num_negatives).flatten()
        negative_inds = torch.randint(0, full_len-1, size=(batch_size, full_len * num_negatives))
        # From wav2vec 2.0 implementation, I don't understand
        # negative_inds[negative_inds >= candidates] += 1

        for i in range(1, batch_size):
            negative_inds[i] += i * full_len

    z_k = z_k[negative_inds.view(-1)].view(batch_size, full_len, num_negatives, feat)
    return z_k, negative_inds

In [17]:
def _calculate_similarity( z, c, negatives):
    c = c[..., :].permute([0, 2, 1]).unsqueeze(-2)
    z = z.permute([0, 2, 1]).unsqueeze(-2)

    # In case the contextualizer matches exactly, need to avoid divide by zero errors
    negative_in_target = (c == negatives).all(-1)
    targets = torch.cat([c, negatives], dim=-2)

    logits = torch.nn.functional.cosine_similarity(z, targets, dim=-1) / 0.1
    if negative_in_target.any():
        logits[1:][negative_in_target] = float("-inf")

    return logits.view(-1, logits.shape[-1])

In [18]:
def masking(ts):
    start_shift = np.random.choice(range(10))
    downsampling = 2
    indices = np.random.choice(np.array(list(range(110)))[start_shift::10][::downsampling], 5, replace=False)
    masked_idx = []
    for i in indices:
        masked_idx.extend(range(i, i+10))

    masked_idx = np.array(masked_idx)
    
    # mask = np.ones((6000, 2))
    # # desync some masked channels
    # ts_masked = ts.copy()
    # if np.random.choice([0, 1], p=[0.7, 0.3]):
    #     ts_masked[masked_idx, np.random.choice([0, 1])] *= 0
    # else:
    #     ts_masked[masked_idx] *= 0
        
    return None, masked_idx

In [51]:
class TEST(torch.utils.data.Dataset):
    def __init__(self, path):
        super(TEST, self).__init__()
        self.main_path = path
        self.paths = path

    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self, idx):
        path = self.paths[idx]
        # take 60s of recording with specified shift
        key = False
        while(key == False):
            try:
                sample = np.load(path, allow_pickle=True).item()
                key = True
            except Exception as e:
                print("Path: {} is broken ".format(path), e)
                path = np.random.choice(self.paths, 1)[0]
                
        signal = sample['value_pure']
        real_len = signal.shape[0]
        channels_ids = [i for i, val in enumerate(sample['channels']) if i != 3 and val in mitsar_chls]
        print('sample[channels]', sample['channels'])
        print('mitsar_chls', mitsar_chls)
        print('channels_ids', channels_ids)
        # choose 2 random channels
        # channels_to_train = np.random.choice(channels_ids, 2, replace=False)
        # channels_vector = torch.tensor((channels_to_train))
        
        # use all available
        channels_to_train = channels_ids
        channels_vector = torch.tensor((channels_to_train))
        signal = signal[:, channels_to_train]

        print('channels_to_train', channels_to_train)

        # remove normalization for now with within channel z-norm
        # sample_norm = (sample - tuh_filtered_stat_vals['min_vals_filtered'][channels_vector]) / (tuh_filtered_stat_vals['max_vals_filtered'][channels_vector] - tuh_filtered_stat_vals['min_vals_filtered'][channels_vector] + 1e-6)
        sample_norm_mean = signal.mean()
        sample_norm_std = np.std(signal)
        
        signal_norm = (signal - sample_norm_mean) / (sample_norm_std)
        
        if signal_norm.shape[0] < 6000:
            signal_norm = np.pad(signal_norm, ((0, 6000 - signal_norm.shape[0]), (0, 0)))
        
        attention_mask = torch.ones(6000)
        attention_mask[real_len:] = 0
        return {'anchor': torch.from_numpy(signal_norm).float(), 
                # 'label': sample_label, 
                # 'anchor_masked': torch.from_numpy(sample_masked).float(), 
                # 'mask': torch.tensor(mask),
                'channels': channels_vector,
                'attention_mask': attention_mask}

In [52]:
# 56 server
# splitted_paths = ['/home/data/TUH_pretrain.filtered_1_40.v2.splited/{}'.format(i) for i in os.listdir('/home/data/TUH_pretrain.filtered_1_40.v2.splited/')]
# 34 server
splitted_paths = ['/media/hdd/data/TUH_pretrain.filtered_1_40.v2.splited/{}'.format(i) for i in os.listdir('/media/hdd/data/TUH_pretrain.filtered_1_40.v2.splited/')]


In [53]:
sample = np.load(np.random.choice(splitted_paths, 1)[0], allow_pickle=True)

In [54]:
sample

array({'value': array([[  2.9273543 ,   6.77061426,  -2.80408903, ...,   5.15022475,
         -0.17297896, -19.82115969],
       [  6.93326314,  10.97923026,   0.67440623, ...,   9.75331321,
          4.75569813, -23.13059152],
       [ 12.67457194,  16.60059825,   5.67962452, ...,  16.59343865,
         12.71718032,  -8.59118442],
       ...,
       [ 10.86360004,   8.23089587,   5.4728705 , ...,   0.85300305,
          3.27642359, -11.6886919 ],
       [ 13.62825731,   9.50276507,   6.48051907, ...,   0.15708415,
          2.7950165 , -11.35647126],
       [ 13.76536792,   8.60568661,   5.90039718, ...,   0.38953789,
          0.76735608, -11.34078283]]), 'value_pure': array([[ -2.36025874,   3.68367735, -12.96940723, ..., -16.00250418,
        -17.24266316, -19.82115969],
       [ -6.86395286,  -0.67825386, -17.96700513, ..., -20.11534848,
        -21.16590104, -23.13059152],
       [  6.73416014,  12.99882206,  -4.87992788, ...,  -5.30467718,
         -4.96709112,  -8.59118442],
  

In [55]:
len(splitted_paths)

161710

In [56]:
mitsar_chls = ['Fp1', 'Fp2', 'FZ', 'FCz', 'Cz', 'Pz', 'O1', 'O2', 'F3', 'F4', 
               'F7', 'F8', 'C3', 'C4', 'T3', 'T4', 'P3', 'P4', 'T5', 'T6', 'A1', 'A2']
mitsar_chls = [i.upper() for i in mitsar_chls]

In [57]:
test_dataset = TEST(splitted_paths)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0, drop_last=True)

In [58]:
test_dataset.__getitem__(4)

sample[channels] ['FP1', 'FP2', 'FZ', 'FCZ', 'CZ', 'PZ', 'O1', 'O2', 'F3', 'F4', 'F7', 'F8', 'C3', 'C4', 'T3', 'T4', 'P3', 'P4', 'T5', 'T6', 'A1', 'A2', 'EKG1']
mitsar_chls ['FP1', 'FP2', 'FZ', 'FCZ', 'CZ', 'PZ', 'O1', 'O2', 'F3', 'F4', 'F7', 'F8', 'C3', 'C4', 'T3', 'T4', 'P3', 'P4', 'T5', 'T6', 'A1', 'A2']
channels_ids [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
channels_to_train [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]


{'anchor': tensor([[ 0.3678,  0.0749,  0.0408,  ..., -0.0528,  0.1377,  0.0048],
         [ 0.3655,  0.0216,  0.0055,  ...,  0.0235,  0.1583,  0.0343],
         [ 0.3018, -0.0476,  0.0187,  ..., -0.0155,  0.0765, -0.0356],
         ...,
         [-0.3121, -0.1127, -0.2664,  ..., -0.3250, -0.4115, -0.3452],
         [-0.3218, -0.1485, -0.2783,  ..., -0.2441, -0.4151, -0.2793],
         [-0.2981, -0.1423, -0.2441,  ..., -0.1679, -0.3665, -0.2556]]),
 'channels': tensor([ 0,  1,  2,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
         19, 20, 21]),
 'attention_mask': tensor([1., 1., 1.,  ..., 1., 1., 1.])}

# Training

In [43]:
model = EEGEmbedder()

In [44]:

from torch.optim.lr_scheduler import _LRScheduler
class NoamLR(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, d_model=512):
        self.warmup_steps = warmup_steps
        self.d_model = d_model
        super().__init__(optimizer)

    def get_lr(self):
        last_epoch = max(1, self.last_epoch)
        factor = min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5))
        # scale = self.warmup_steps ** 0.5 * min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5))
        # return [base_lr * scale for base_lr in self.base_lrs]
        return [base_lr * self.d_model ** (-0.5) * factor for base_lr in self.base_lrs]

In [45]:
cossim = torch.nn.CosineSimilarity(dim=-1)

def cosloss(anchor, real, negative):
    a = torch.exp(cossim(anchor, real)) / 0.1
    b = sum([torch.exp(cossim(anchor, negative[:, n])) / 0.1 for n in range(negative.shape[1])]) + 1e-6
    return -torch.log(a/b)

In [46]:
import random
def worker_init_fn(worker_id):
    torch_seed = torch.initial_seed()
    random.seed(torch_seed + worker_id)
    np.random.seed((torch_seed + worker_id) % 2**30)

train_dataset = TEST(splitted_paths[:-15000])
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0, drop_last=True, worker_init_fn = worker_init_fn)

test_dataset = TEST(splitted_paths[-15000:])
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0, drop_last=True, worker_init_fn = worker_init_fn)


In [47]:
writer = SummaryWriter('logs')

In [48]:

model.train()

lr_d = 1
acc_size = 8
training_epochs1 = 100000 // len(train_loader)

optim = torch.optim.AdamW(model.parameters(), lr=lr_d)

model_test = torch.nn.DataParallel(model)
model_test.to('cuda:0')

loss_func = torch.nn.MSELoss()
# scheduler = torch.optim.lr_scheduler.LinearLR(optim, start_factor=1.0, end_factor=0.1, total_iters=training_epochs1*len(train_loader))
scheduler = NoamLR(optim, 100000, 512)

steps = 0

In [49]:
len(train_loader), training_epochs1, training_epochs1 * len(train_loader)

(2292, 43, 98556)

In [71]:
train_dataset[0]['anchor'].shape[:]
aa = ['f1','c2','c3','d4']
bb = ['c2','d4','d5']

  # channels_ids = [i for i, val in enumerate(sample['channels']) if i != 3 and val in mitsar_chls]

# qq = [i for i, val in enumerate(aa) if i != 1 and val in bb]
qq = [i for i, val in enumerate(aa) if i != 1 and val in bb]
print('ch: ', (qq))


sample[channels] ['FP1', 'FP2', 'FZ', 'FCZ', 'CZ', 'PZ', 'O1', 'O2', 'F3', 'F4', 'F7', 'F8', 'C3', 'C4', 'T3', 'T4', 'P3', 'P4', 'T5', 'T6', 'A1', 'A2', 'EKG1']
channels_ids [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
channels_to_train [0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
ch:  [3]


In [73]:
train_dataset[0]['channels'].shape[:]

torch.Size([21])

In [30]:
# model.cpu()(batch['anchor'][None], batch['mask'][None], batch['channels'][None])

In [37]:
epoch, training_epochs1

(25, 43)

In [38]:

for epoch in range(25, training_epochs1):
    mean_loss = 0
    acc_step = 0
    for batch in train_loader:
        ae, label = model_test(
            batch['anchor'],#.to('cuda:0'), 
            None, 
            batch['channels'].long())
        
        logits = _calculate_similarity(torch.transpose(label, 1, 2), torch.transpose(ae, 1, 2), _generate_negatives(torch.transpose(label, 1, 2))[0])

        fake_labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
        loss = torch.nn.CrossEntropyLoss()(logits, fake_labels) + 0.001 * label.pow(2).mean()

        loss = loss.mean()
        loss.backward()
        mean_loss += loss.item()
        acc_step += 1
        steps += 1
        # raise
        if acc_step != 0 and acc_step % acc_size == 0:
            optim.step()
            scheduler.step()
            optim.zero_grad()
            if steps % 100 == 0:
                print('Loss/train\t{}'.format(mean_loss / acc_size))
            writer.add_scalar('Loss/train', mean_loss / acc_size, steps)
            mean_loss = 0
        if steps != 0 and steps % 1000 == 0:
            der = 0
            try:
                with torch.no_grad():
                    for batch in test_loader:
                        ae, label = model_test(
                            batch['anchor'],#.to('cuda:0'), 
                            None, 
                            batch['channels'].long())
                        logits = _calculate_similarity(torch.transpose(label, 1, 2), torch.transpose(ae, 1, 2), _generate_negatives(torch.transpose(label, 1, 2))[0])

                        fake_labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
                        loss = torch.nn.CrossEntropyLoss()(logits, fake_labels) + 0.001 * label.pow(2).mean()

                        loss = loss.mean() / acc_size
                        der += loss
                der /= len(test_loader)
                writer.add_scalar('Loss/test', der, steps)

                print('Loss: {}\t'.format(der))
            except:
                raise
            torch.save(model_test.module.state_dict(), 'models/step.pt'.format(steps))

Loss/train	2.707833170890808
Loss/train	2.7044030725955963
Loss/train	2.7018640637397766
Loss/train	2.702863395214081
Loss/train	2.7006741762161255
Loss: 0.33758121728897095	
Loss/train	2.6986430883407593
Loss/train	2.6988857984542847
Loss/train	2.6954116821289062
Loss/train	2.693866729736328
Loss/train	2.693163961172104
Loss: 0.3367648124694824	
Loss/train	2.6922953128814697
Loss/train	2.692641854286194
Loss/train	2.690729022026062
Loss/train	2.6889248192310333
Loss/train	2.687345862388611
Loss: 0.33587104082107544	
Loss/train	2.6858376562595367
Loss/train	2.685423821210861
Loss/train	2.6850564181804657
Loss/train	2.681473135948181
Loss/train	2.681659013032913
Loss: 0.3349754810333252	
Loss/train	2.6786417961120605
Loss/train	2.6785849034786224
Loss/train	2.676129937171936
Loss/train	2.67608904838562
Loss/train	2.6727095246315002
Loss/train	2.67220139503479
Loss: 0.3341512382030487	
Loss/train	2.6699333786964417
Loss/train	2.6697755455970764
Loss/train	2.6685484647750854
Loss/train	2.

In [None]:

for epoch in range(training_epochs1):
    mean_loss = 0
    acc_step = 0
    for batch in train_loader:
        ae, label = model_test(
            batch['anchor'],#.to('cuda:0'),
            None,
            batch['channels'].long())

        logits = _calculate_similarity(torch.transpose(label, 1, 2), torch.transpose(ae, 1, 2), _generate_negatives(torch.transpose(label, 1, 2))[0])

        fake_labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
        loss = torch.nn.CrossEntropyLoss()(logits, fake_labels) + 0.001 * label.pow(2).mean()

        loss = loss.mean()
        loss.backward()
        mean_loss += loss.item()
        acc_step += 1
        steps += 1
        # raise
        if acc_step != 0 and acc_step % acc_size == 0:
            optim.step()
            scheduler.step()
            optim.zero_grad()
            if steps % 100 == 0:
                print('Loss/train\t{}'.format(mean_loss / acc_size))
            writer.add_scalar('Loss/train', mean_loss / acc_size, steps)
            mean_loss = 0
        if steps != 0 and steps % 1000 == 0:
            der = 0
            try:
                with torch.no_grad():
                    for batch in test_loader:
                        ae, label = model_test(
                            batch['anchor'],#.to('cuda:0'),
                            None,
                            batch['channels'].long())
                        logits = _calculate_similarity(torch.transpose(label, 1, 2), torch.transpose(ae, 1, 2), _generate_negatives(torch.transpose(label, 1, 2))[0])

                        fake_labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
                        loss = torch.nn.CrossEntropyLoss()(logits, fake_labels) + 0.001 * label.pow(2).mean()

                        loss = loss.mean() / acc_size
                        der += loss
                der /= len(test_loader)
                writer.add_scalar('Loss/test', der, steps)

                print('Loss: {}\t'.format(der))
            except:
                raise
            torch.save(model_test.module.state_dict(), 'models/step.pt'.format(steps))

Loss/train	12.755858182907104
Loss/train	12.643582820892334
Loss/train	12.454817771911621
Loss/train	12.185070753097534
Loss/train	11.838916063308716
Loss: 1.477887749671936	
Loss/train	11.409443974494934
Loss/train	10.90146255493164
Loss/train	10.320974111557007
Loss/train	9.67824900150299
Loss/train	8.992188572883606
Loss: 1.1207029819488525	
Loss/train	8.31241750717163
Loss/train	7.99466860294342
Loss/train	7.372100114822388
Loss/train	6.819593012332916
Loss/train	6.336482048034668
Loss: 0.7635732293128967	
Loss/train	5.912988305091858
Loss/train	5.541968882083893
Loss/train	5.225658714771271
Loss/train	4.952653110027313
Loss/train	4.719955623149872
Loss: 0.576254665851593	
Loss/train	4.516627132892609
Loss/train	4.346154153347015
Loss/train	4.199998915195465
Loss/train	4.136499226093292
Loss/train	4.017586529254913
Loss/train	3.9136194586753845
Loss: 0.48888149857521057	
Loss/train	3.8248919248580933
Loss/train	3.7474300265312195
Loss/train	3.6797587275505066
Loss/train	3.618002593