In [1]:
import numpy as np
import faiss
import copy
import json
import torch
import math
import random
from tqdm import tqdm
from augmentation import TfIdfAugmentation
from tools.utils import ExternalPreprocessor

from modelling.models import DAN, Embedding
from modelling.templates import SequenceTemplate

In [2]:
MAX_LEN = 32

In [3]:
VOCAB_PATH = '../routing/data/sberbank_embeddings/w2v_m5_w3_v300_norm_v48_vocab.txt'
W2V_MATRIX_PATH = '../routing/data/sberbank_embeddings/w2v_m5_w3_v300_norm_v48_vectors.npy'

In [4]:
with open('token2prob.json') as f:
    token2prob = json.loads(f.read())

In [5]:
with open(VOCAB_PATH) as f:
    vocab = f.read().split('\n')

In [6]:
vocab = {key: value for value, key in enumerate(vocab)}

In [7]:
index2prob = {vocab[token]: token2prob[token] for token in token2prob if token in vocab}

In [8]:
aug = TfIdfAugmentation(indexes_matrix=np.load('nearest_matrix.npy'), index2prob=index2prob)

In [9]:
with open('train.jsonl') as f:
    data = [json.loads(sample) for sample in f.read().split('\n')]

In [10]:
random.shuffle(data)

In [11]:
train, test = data[:-len(data) // 10], data[-len(data) // 10:]

In [12]:
len(train), len(test), len(train) + len(test) == len(data)

(11623, 1292, True)

In [13]:
TARGET2INDEX = {
    'ANNA.1.sales': 0,
    'ANNA.1.sbbol': 1,
    'ANNA.1.oper_support': 2
}

In [14]:
def sequence_padding(sequence, max_sequence_length, value) -> np.ndarray:

    sequence = sequence[:max_sequence_length]

    if len(sequence) < max_sequence_length:
        for _ in range((max_sequence_length - len(sequence))):
            sequence.append(value)

    sequence = np.array(sequence)

    return sequence

In [15]:
def indexing_batch(x, vocab, max_sequence_length):
    
    x = [[vocab[tok] for tok in sample if tok in vocab] for sample in x]
    
    x = np.array([sequence_padding(sample, max_sequence_length=max_sequence_length, value=0) for sample in x])
    
    return x

In [16]:
def batch_processing(batch):
    
    x = [sample['tokens'] for sample in batch]
    
    x = indexing_batch(x, vocab, MAX_LEN)
    x_aug = aug.replace_batch(copy.deepcopy(x))
    
    y = np.array([TARGET2INDEX[sample['target']] for sample in batch])
    
    x = torch.LongTensor(x)
    x_aug = torch.LongTensor(x_aug)
    y = torch.LongTensor(y)
    
    return [x, x_aug, y]

In [17]:
def loader(data, batch_size=32):

    for n_batch in range(math.ceil(len(data) / batch_size)):

        batch = data[n_batch * batch_size:(n_batch + 1) * batch_size]

        batch = batch_processing(batch)

        yield batch

In [18]:
from modelling.layers import BaseModule
from modelling.templates import SequenceTemplate

In [19]:
word_matrix = np.load(W2V_MATRIX_PATH)

In [20]:
class Model(BaseModule):
    
    def __init__(self):
        
        super().__init__()
        
        self.embedding = Embedding(vocab_size=word_matrix.shape[0],
                                   embedding_matrix=word_matrix)
        
        self.dan = DAN((300, 256), activation_function_output=torch.nn.ReLU())
        
        self.linear = torch.nn.Linear(256, 256)
        
        self.activation = torch.nn.ReLU()
        
        self.classifier = torch.nn.Linear(256, 3)
        
    def forward(self, x, x_aug):
        
        x_rep = self.embedding(x)
        x_rep = self.dan(x_rep)
        x_rep = self.linear(x_rep)
        x_rep = torch.nn.functional.log_softmax(x_rep, dim=1)
        
        with torch.no_grad():
            
            x_aug_rep = self.embedding(x_aug)
            x_aug_rep = self.dan(x_aug_rep)
            x_aug_rep = self.linear(x_aug_rep)
            x_aug_rep = torch.nn.functional.softmax(x_aug_rep, dim=1)
    
        y_pred = self.classifier(x_rep)
        
        return x_rep, x_aug_rep, y_pred

In [21]:
model = Model()

In [22]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [23]:
kl_div = torch.nn.KLDivLoss(reduction='batchmean')
cross_entropy = torch.nn.CrossEntropyLoss()

In [24]:
from tsa import TrainingSignalAnnealingScheduler

In [25]:
kl_losses = []
ce_losses = []
losses = []

l = 0.5

N_EPOCHS = 75

tsa = TrainingSignalAnnealingScheduler(total_steps=N_EPOCHS * len(train) // 32, n_classes=3, schedule_type='exp')

for n in range(N_EPOCHS):
    
    epoch_kl_losses = []
    epoch_ce_losses = []
    epoch_losses = []

    model.train()
    
    pg = tqdm(total=len(train), desc=f'Epoch: {n}')

    for x, x_aug, y in loader(train):

        optimizer.zero_grad()

        x_rep, x_aug_rep, y_pred = model(x, x_aug)

        kl_loss = kl_div(x_rep, x_aug_rep)
        epoch_kl_losses.append(kl_loss.item())
        
        y_pred, y = tsa(y_pred, y)
        
        if y_pred.size(0) == 0:
            ce_loss = 0
            epoch_ce_losses.append(ce_loss)
        else:
            ce_loss = cross_entropy(y_pred, y)
            epoch_ce_losses.append(ce_loss.item())

        loss = ce_loss + kl_loss

        loss.backward()

        optimizer.step()
        
        epoch_losses.append(loss.item())

        pg.update(x.shape[0])
        pg.set_postfix(kl_loss=epoch_kl_losses[-1], ce_loss=epoch_ce_losses[-1], loss=epoch_losses[-1])

    pg.close()
    
    test_epoch_kl_losses = []
    test_epoch_ce_losses = []
    test_epoch_losses = []

    model.eval()

    for x, x_aug, y in loader(test):

        x_rep, x_aug_rep, y_pred = model(x, x_aug)

        kl_loss = kl_div(x_rep, x_aug_rep)
        ce_loss = cross_entropy(y_pred, y)

        loss = ce_loss + kl_loss

    test_epoch_kl_losses.append(kl_loss.item())
    test_epoch_ce_losses.append(ce_loss.item())
    test_epoch_losses.append(loss.item())
    
    print('KL Train - {:.3f} | Test - {:.3f}'.format(np.mean(epoch_kl_losses), np.mean(test_epoch_kl_losses)))
    print('CE Train - {:.3f} | Test - {:.3f}'.format(np.mean(epoch_ce_losses), np.mean(test_epoch_ce_losses)))
    print('Aggregated Train - {:.3f} | Test - {:.3f}'.format(np.mean(epoch_losses), np.mean(test_epoch_losses)))
    
    kl_losses.extend(copy.deepcopy(epoch_kl_losses))
    ce_losses.extend(copy.deepcopy(epoch_ce_losses))
    losses.extend(copy.deepcopy(epoch_losses))

Epoch: 0: 100%|██████████| 11623/11623 [00:02<00:00, 4425.68it/s, ce_loss=1.1, kl_loss=0.000135, loss=1.1]   
Epoch: 1:   7%|▋         | 864/11623 [00:00<00:02, 4523.23it/s, ce_loss=1.13, kl_loss=0.000166, loss=1.13]

KL Train - 0.000 | Test - 0.000
CE Train - 1.157 | Test - 1.059
Aggregated Train - 1.157 | Test - 1.059


Epoch: 1: 100%|██████████| 11623/11623 [00:02<00:00, 4433.52it/s, ce_loss=1.09, kl_loss=0.000165, loss=1.09]
Epoch: 2:   8%|▊         | 896/11623 [00:00<00:02, 4578.58it/s, ce_loss=1.14, kl_loss=0.000221, loss=1.14]

KL Train - 0.000 | Test - 0.000
CE Train - 1.108 | Test - 1.056
Aggregated Train - 1.109 | Test - 1.056


Epoch: 2: 100%|██████████| 11623/11623 [00:02<00:00, 4484.05it/s, ce_loss=1.14, kl_loss=0.000157, loss=1.14] 
Epoch: 3:   8%|▊         | 896/11623 [00:00<00:02, 4528.47it/s, ce_loss=1.17, kl_loss=0.000253, loss=1.17]

KL Train - 0.000 | Test - 0.000
CE Train - 1.110 | Test - 1.057
Aggregated Train - 1.111 | Test - 1.057


Epoch: 3: 100%|██████████| 11623/11623 [00:02<00:00, 4473.70it/s, ce_loss=1.09, kl_loss=0.000149, loss=1.09] 
Epoch: 4:   8%|▊         | 896/11623 [00:00<00:02, 4554.34it/s, ce_loss=1.17, kl_loss=0.000211, loss=1.17] 

KL Train - 0.000 | Test - 0.000
CE Train - 1.087 | Test - 1.055
Aggregated Train - 1.087 | Test - 1.056


Epoch: 4: 100%|██████████| 11623/11623 [00:02<00:00, 4487.98it/s, ce_loss=0, kl_loss=0.000159, loss=0.000159]
Epoch: 5:   8%|▊         | 896/11623 [00:00<00:02, 4587.25it/s, ce_loss=1.1, kl_loss=0.000265, loss=1.1]  

KL Train - 0.000 | Test - 0.000
CE Train - 1.066 | Test - 1.009
Aggregated Train - 1.066 | Test - 1.010


Epoch: 5: 100%|██████████| 11623/11623 [00:02<00:00, 4493.01it/s, ce_loss=1.14, kl_loss=0.000177, loss=1.14] 
Epoch: 6:   8%|▊         | 896/11623 [00:00<00:02, 4626.63it/s, ce_loss=1.13, kl_loss=0.000229, loss=1.13] 

KL Train - 0.000 | Test - 0.000
CE Train - 1.067 | Test - 1.044
Aggregated Train - 1.067 | Test - 1.044


Epoch: 6: 100%|██████████| 11623/11623 [00:02<00:00, 4460.47it/s, ce_loss=1.11, kl_loss=0.000142, loss=1.11] 
Epoch: 7:   8%|▊         | 896/11623 [00:00<00:02, 4576.87it/s, ce_loss=1.18, kl_loss=0.000363, loss=1.18] 

KL Train - 0.000 | Test - 0.000
CE Train - 1.070 | Test - 1.037
Aggregated Train - 1.070 | Test - 1.037


Epoch: 7: 100%|██████████| 11623/11623 [00:02<00:00, 4478.86it/s, ce_loss=0, kl_loss=0.000219, loss=0.000219]
Epoch: 8:   8%|▊         | 896/11623 [00:00<00:02, 4593.20it/s, ce_loss=1.11, kl_loss=0.000332, loss=1.11] 

KL Train - 0.000 | Test - 0.000
CE Train - 1.061 | Test - 1.023
Aggregated Train - 1.061 | Test - 1.023


Epoch: 8: 100%|██████████| 11623/11623 [00:02<00:00, 4482.22it/s, ce_loss=0, kl_loss=0.000138, loss=0.000138]
Epoch: 9:   8%|▊         | 896/11623 [00:00<00:02, 4577.93it/s, ce_loss=1.1, kl_loss=0.000421, loss=1.1]   

KL Train - 0.000 | Test - 0.000
CE Train - 1.053 | Test - 1.029
Aggregated Train - 1.054 | Test - 1.029


Epoch: 9: 100%|██████████| 11623/11623 [00:02<00:00, 4483.40it/s, ce_loss=0, kl_loss=0.000206, loss=0.000206]
Epoch: 10:   8%|▊         | 896/11623 [00:00<00:02, 4572.77it/s, ce_loss=1.15, kl_loss=0.000304, loss=1.15] 

KL Train - 0.000 | Test - 0.000
CE Train - 1.029 | Test - 0.987
Aggregated Train - 1.029 | Test - 0.987


Epoch: 10: 100%|██████████| 11623/11623 [00:02<00:00, 4485.50it/s, ce_loss=0, kl_loss=0.000238, loss=0.000238]
Epoch: 11:   8%|▊         | 896/11623 [00:00<00:02, 4602.23it/s, ce_loss=1.09, kl_loss=0.000579, loss=1.09]

KL Train - 0.000 | Test - 0.000
CE Train - 1.024 | Test - 0.996
Aggregated Train - 1.024 | Test - 0.996


Epoch: 11: 100%|██████████| 11623/11623 [00:02<00:00, 4469.91it/s, ce_loss=0, kl_loss=0.000191, loss=0.000191]
Epoch: 12:   7%|▋         | 864/11623 [00:00<00:02, 4538.10it/s, ce_loss=1.11, kl_loss=0.000252, loss=1.11] 

KL Train - 0.000 | Test - 0.000
CE Train - 1.011 | Test - 0.990
Aggregated Train - 1.011 | Test - 0.990


Epoch: 12: 100%|██████████| 11623/11623 [00:02<00:00, 4488.54it/s, ce_loss=0, kl_loss=0.000293, loss=0.000293]
Epoch: 13:   8%|▊         | 896/11623 [00:00<00:02, 4546.45it/s, ce_loss=1.15, kl_loss=0.000509, loss=1.15] 

KL Train - 0.000 | Test - 0.000
CE Train - 1.014 | Test - 1.000
Aggregated Train - 1.015 | Test - 1.000


Epoch: 13: 100%|██████████| 11623/11623 [00:02<00:00, 4478.45it/s, ce_loss=1.1, kl_loss=0.000236, loss=1.1]   
Epoch: 14:   8%|▊         | 896/11623 [00:00<00:02, 4636.38it/s, ce_loss=1.09, kl_loss=0.00042, loss=1.09]  

KL Train - 0.000 | Test - 0.000
CE Train - 0.991 | Test - 0.969
Aggregated Train - 0.992 | Test - 0.970


Epoch: 14: 100%|██████████| 11623/11623 [00:02<00:00, 4448.47it/s, ce_loss=0, kl_loss=0.000386, loss=0.000386]
Epoch: 15:   8%|▊         | 896/11623 [00:00<00:02, 4671.62it/s, ce_loss=1.13, kl_loss=0.000504, loss=1.13] 

KL Train - 0.000 | Test - 0.000
CE Train - 0.959 | Test - 0.975
Aggregated Train - 0.959 | Test - 0.976


Epoch: 15: 100%|██████████| 11623/11623 [00:02<00:00, 4469.74it/s, ce_loss=0, kl_loss=0.000212, loss=0.000212]
Epoch: 16:   8%|▊         | 896/11623 [00:00<00:02, 4548.98it/s, ce_loss=1.11, kl_loss=0.000623, loss=1.11] 

KL Train - 0.000 | Test - 0.001
CE Train - 0.972 | Test - 0.958
Aggregated Train - 0.973 | Test - 0.958


Epoch: 16: 100%|██████████| 11623/11623 [00:02<00:00, 4501.21it/s, ce_loss=1.09, kl_loss=0.000213, loss=1.09] 
Epoch: 17:   7%|▋         | 864/11623 [00:00<00:02, 4530.40it/s, ce_loss=1.25, kl_loss=0.000282, loss=1.25] 

KL Train - 0.000 | Test - 0.001
CE Train - 0.986 | Test - 0.955
Aggregated Train - 0.987 | Test - 0.955


Epoch: 17: 100%|██████████| 11623/11623 [00:02<00:00, 4453.46it/s, ce_loss=0, kl_loss=0.000387, loss=0.000387]
Epoch: 18:   7%|▋         | 864/11623 [00:00<00:02, 4531.60it/s, ce_loss=1.17, kl_loss=0.00037, loss=1.17]  

KL Train - 0.000 | Test - 0.000
CE Train - 1.008 | Test - 0.931
Aggregated Train - 1.009 | Test - 0.931


Epoch: 18: 100%|██████████| 11623/11623 [00:02<00:00, 4495.48it/s, ce_loss=0, kl_loss=0.000463, loss=0.000463]
Epoch: 19:   8%|▊         | 896/11623 [00:00<00:02, 4598.98it/s, ce_loss=1.15, kl_loss=0.000913, loss=1.15] 

KL Train - 0.001 | Test - 0.001
CE Train - 0.959 | Test - 0.944
Aggregated Train - 0.959 | Test - 0.945


Epoch: 19: 100%|██████████| 11623/11623 [00:02<00:00, 4508.22it/s, ce_loss=0, kl_loss=0.000174, loss=0.000174]
Epoch: 20:   8%|▊         | 896/11623 [00:00<00:02, 4536.22it/s, ce_loss=1.08, kl_loss=0.000959, loss=1.09] 

KL Train - 0.001 | Test - 0.000
CE Train - 0.960 | Test - 0.939
Aggregated Train - 0.960 | Test - 0.940


Epoch: 20: 100%|██████████| 11623/11623 [00:02<00:00, 4529.03it/s, ce_loss=0, kl_loss=0.000296, loss=0.000296]
Epoch: 21:   8%|▊         | 896/11623 [00:00<00:02, 4611.17it/s, ce_loss=1.14, kl_loss=0.00106, loss=1.14]  

KL Train - 0.001 | Test - 0.001
CE Train - 0.897 | Test - 0.958
Aggregated Train - 0.898 | Test - 0.958


Epoch: 21: 100%|██████████| 11623/11623 [00:02<00:00, 4522.96it/s, ce_loss=0, kl_loss=0.000277, loss=0.000277]
Epoch: 22:   8%|▊         | 896/11623 [00:00<00:02, 4598.75it/s, ce_loss=1.1, kl_loss=0.000718, loss=1.1]   

KL Train - 0.001 | Test - 0.001
CE Train - 0.905 | Test - 0.921
Aggregated Train - 0.905 | Test - 0.922


Epoch: 22: 100%|██████████| 11623/11623 [00:02<00:00, 4516.42it/s, ce_loss=0, kl_loss=0.000184, loss=0.000184]
Epoch: 23:   8%|▊         | 896/11623 [00:00<00:02, 4603.79it/s, ce_loss=1.1, kl_loss=0.000507, loss=1.1]  

KL Train - 0.001 | Test - 0.001
CE Train - 0.941 | Test - 0.904
Aggregated Train - 0.942 | Test - 0.905


Epoch: 23: 100%|██████████| 11623/11623 [00:02<00:00, 4526.34it/s, ce_loss=0, kl_loss=0.000244, loss=0.000244]
Epoch: 24:   8%|▊         | 896/11623 [00:00<00:02, 4510.94it/s, ce_loss=1.07, kl_loss=0.00137, loss=1.07]  

KL Train - 0.001 | Test - 0.001
CE Train - 0.900 | Test - 0.951
Aggregated Train - 0.901 | Test - 0.951


Epoch: 24: 100%|██████████| 11623/11623 [00:02<00:00, 4474.14it/s, ce_loss=0, kl_loss=0.000369, loss=0.000369]
Epoch: 25:   8%|▊         | 896/11623 [00:00<00:02, 4575.34it/s, ce_loss=1.09, kl_loss=0.00162, loss=1.1]   

KL Train - 0.001 | Test - 0.001
CE Train - 0.949 | Test - 0.905
Aggregated Train - 0.949 | Test - 0.906


Epoch: 25: 100%|██████████| 11623/11623 [00:02<00:00, 4518.02it/s, ce_loss=0, kl_loss=0.000926, loss=0.000926]
Epoch: 26:   8%|▊         | 896/11623 [00:00<00:02, 4620.77it/s, ce_loss=1.16, kl_loss=0.000867, loss=1.16] 

KL Train - 0.001 | Test - 0.001
CE Train - 0.931 | Test - 0.909
Aggregated Train - 0.931 | Test - 0.911


Epoch: 26: 100%|██████████| 11623/11623 [00:02<00:00, 4527.45it/s, ce_loss=0, kl_loss=0.000354, loss=0.000354]
Epoch: 27:   8%|▊         | 896/11623 [00:00<00:02, 4581.21it/s, ce_loss=1.17, kl_loss=0.00119, loss=1.17]  

KL Train - 0.001 | Test - 0.001
CE Train - 0.962 | Test - 0.887
Aggregated Train - 0.963 | Test - 0.888


Epoch: 27: 100%|██████████| 11623/11623 [00:02<00:00, 4534.75it/s, ce_loss=0, kl_loss=0.000187, loss=0.000187]
Epoch: 28:   8%|▊         | 896/11623 [00:00<00:02, 4564.21it/s, ce_loss=1.11, kl_loss=0.000998, loss=1.11] 

KL Train - 0.001 | Test - 0.001
CE Train - 0.916 | Test - 0.902
Aggregated Train - 0.917 | Test - 0.903


Epoch: 28: 100%|██████████| 11623/11623 [00:02<00:00, 4507.34it/s, ce_loss=0, kl_loss=0.00151, loss=0.00151]  
Epoch: 29:   8%|▊         | 896/11623 [00:00<00:02, 4641.67it/s, ce_loss=1.12, kl_loss=0.00199, loss=1.13]  

KL Train - 0.001 | Test - 0.001
CE Train - 0.915 | Test - 0.878
Aggregated Train - 0.916 | Test - 0.879


Epoch: 29: 100%|██████████| 11623/11623 [00:02<00:00, 4535.81it/s, ce_loss=1.15, kl_loss=0.000467, loss=1.15] 
Epoch: 30:   7%|▋         | 864/11623 [00:00<00:02, 4489.25it/s, ce_loss=1.25, kl_loss=0.000786, loss=1.25]

KL Train - 0.001 | Test - 0.001
CE Train - 0.922 | Test - 0.907
Aggregated Train - 0.923 | Test - 0.908


Epoch: 30: 100%|██████████| 11623/11623 [00:02<00:00, 4525.63it/s, ce_loss=0, kl_loss=0.000432, loss=0.000432]
Epoch: 31:   8%|▊         | 896/11623 [00:00<00:02, 4577.66it/s, ce_loss=1.07, kl_loss=0.00126, loss=1.07]   

KL Train - 0.001 | Test - 0.001
CE Train - 0.944 | Test - 0.856
Aggregated Train - 0.945 | Test - 0.857


Epoch: 31: 100%|██████████| 11623/11623 [00:02<00:00, 4518.89it/s, ce_loss=0, kl_loss=0.00108, loss=0.00108]   
Epoch: 32:   8%|▊         | 896/11623 [00:00<00:02, 4595.27it/s, ce_loss=1.06, kl_loss=0.00139, loss=1.06]   

KL Train - 0.001 | Test - 0.001
CE Train - 0.907 | Test - 0.861
Aggregated Train - 0.908 | Test - 0.862


Epoch: 32: 100%|██████████| 11623/11623 [00:02<00:00, 4518.70it/s, ce_loss=1.01, kl_loss=0.00085, loss=1.01]  
Epoch: 33:   8%|▊         | 896/11623 [00:00<00:02, 4772.17it/s, ce_loss=1.04, kl_loss=0.00192, loss=1.05]  

KL Train - 0.001 | Test - 0.001
CE Train - 0.890 | Test - 0.893
Aggregated Train - 0.891 | Test - 0.894


Epoch: 33: 100%|██████████| 11623/11623 [00:02<00:00, 4462.72it/s, ce_loss=0.992, kl_loss=0.000678, loss=0.993]
Epoch: 34:   8%|▊         | 896/11623 [00:00<00:02, 4564.30it/s, ce_loss=1.05, kl_loss=0.00265, loss=1.05]   

KL Train - 0.001 | Test - 0.002
CE Train - 0.897 | Test - 0.880
Aggregated Train - 0.898 | Test - 0.882


Epoch: 34: 100%|██████████| 11623/11623 [00:02<00:00, 4478.48it/s, ce_loss=0, kl_loss=0.000475, loss=0.000475] 
Epoch: 35:   7%|▋         | 864/11623 [00:00<00:02, 4517.22it/s, ce_loss=1.12, kl_loss=0.00139, loss=1.12]  

KL Train - 0.001 | Test - 0.001
CE Train - 0.914 | Test - 0.848
Aggregated Train - 0.915 | Test - 0.848


Epoch: 35: 100%|██████████| 11623/11623 [00:02<00:00, 4504.76it/s, ce_loss=0, kl_loss=0.000454, loss=0.000454]
Epoch: 36:   8%|▊         | 896/11623 [00:00<00:02, 4525.31it/s, ce_loss=1.03, kl_loss=0.0029, loss=1.03]   

KL Train - 0.001 | Test - 0.001
CE Train - 0.879 | Test - 0.861
Aggregated Train - 0.880 | Test - 0.863


Epoch: 36: 100%|██████████| 11623/11623 [00:02<00:00, 4496.68it/s, ce_loss=0.968, kl_loss=0.00128, loss=0.969]
Epoch: 37:   8%|▊         | 896/11623 [00:00<00:02, 4641.04it/s, ce_loss=1.11, kl_loss=0.00207, loss=1.12]   

KL Train - 0.002 | Test - 0.001
CE Train - 0.883 | Test - 0.857
Aggregated Train - 0.885 | Test - 0.858


Epoch: 37: 100%|██████████| 11623/11623 [00:02<00:00, 4490.70it/s, ce_loss=0, kl_loss=0.00205, loss=0.00205]  
Epoch: 38:   8%|▊         | 896/11623 [00:00<00:02, 4557.23it/s, ce_loss=1.03, kl_loss=0.00171, loss=1.03]   

KL Train - 0.002 | Test - 0.002
CE Train - 0.960 | Test - 0.844
Aggregated Train - 0.962 | Test - 0.845


Epoch: 38: 100%|██████████| 11623/11623 [00:02<00:00, 4499.88it/s, ce_loss=0.952, kl_loss=0.000936, loss=0.953]
Epoch: 39:   8%|▊         | 896/11623 [00:00<00:02, 4500.41it/s, ce_loss=1.21, kl_loss=0.00135, loss=1.21]  

KL Train - 0.002 | Test - 0.001
CE Train - 0.902 | Test - 0.866
Aggregated Train - 0.904 | Test - 0.867


Epoch: 39: 100%|██████████| 11623/11623 [00:02<00:00, 4467.94it/s, ce_loss=0, kl_loss=0.00137, loss=0.00137]   
Epoch: 40:   8%|▊         | 896/11623 [00:00<00:02, 4521.01it/s, ce_loss=1.14, kl_loss=0.00197, loss=1.14]  

KL Train - 0.002 | Test - 0.002
CE Train - 0.943 | Test - 0.844
Aggregated Train - 0.945 | Test - 0.847


Epoch: 40: 100%|██████████| 11623/11623 [00:02<00:00, 4474.35it/s, ce_loss=0, kl_loss=0.00101, loss=0.00101]  
Epoch: 41:   8%|▊         | 896/11623 [00:00<00:02, 4576.67it/s, ce_loss=0.978, kl_loss=0.00248, loss=0.98] 

KL Train - 0.002 | Test - 0.002
CE Train - 0.925 | Test - 0.845
Aggregated Train - 0.927 | Test - 0.847


Epoch: 41: 100%|██████████| 11623/11623 [00:02<00:00, 4507.40it/s, ce_loss=0.944, kl_loss=0.00126, loss=0.946]
Epoch: 42:   8%|▊         | 896/11623 [00:00<00:02, 4615.65it/s, ce_loss=1.04, kl_loss=0.00381, loss=1.04]   

KL Train - 0.002 | Test - 0.001
CE Train - 0.914 | Test - 0.841
Aggregated Train - 0.916 | Test - 0.842


Epoch: 42: 100%|██████████| 11623/11623 [00:02<00:00, 4472.89it/s, ce_loss=0.909, kl_loss=0.00148, loss=0.911]
Epoch: 43:   8%|▊         | 896/11623 [00:00<00:02, 4608.78it/s, ce_loss=1.06, kl_loss=0.00381, loss=1.07]  

KL Train - 0.002 | Test - 0.002
CE Train - 0.940 | Test - 0.838
Aggregated Train - 0.943 | Test - 0.840


Epoch: 43: 100%|██████████| 11623/11623 [00:02<00:00, 4541.46it/s, ce_loss=0, kl_loss=0.00124, loss=0.00124]  
Epoch: 44:   7%|▋         | 864/11623 [00:00<00:02, 4505.05it/s, ce_loss=1.04, kl_loss=0.00186, loss=1.04]  

KL Train - 0.002 | Test - 0.002
CE Train - 0.909 | Test - 0.838
Aggregated Train - 0.912 | Test - 0.840


Epoch: 44: 100%|██████████| 11623/11623 [00:02<00:00, 4558.47it/s, ce_loss=0.909, kl_loss=0.00243, loss=0.912]
Epoch: 45:   8%|▊         | 896/11623 [00:00<00:02, 4602.64it/s, ce_loss=1.04, kl_loss=0.00405, loss=1.04]  

KL Train - 0.003 | Test - 0.002
CE Train - 0.907 | Test - 0.836
Aggregated Train - 0.909 | Test - 0.839


Epoch: 45: 100%|██████████| 11623/11623 [00:02<00:00, 4597.81it/s, ce_loss=0.858, kl_loss=0.00212, loss=0.86] 
Epoch: 46:   8%|▊         | 928/11623 [00:00<00:02, 4848.21it/s, ce_loss=1.02, kl_loss=0.00633, loss=1.02]  

KL Train - 0.003 | Test - 0.002
CE Train - 0.882 | Test - 0.814
Aggregated Train - 0.885 | Test - 0.816


Epoch: 46: 100%|██████████| 11623/11623 [00:02<00:00, 4713.71it/s, ce_loss=0, kl_loss=0.00204, loss=0.00204]  
Epoch: 47:   8%|▊         | 960/11623 [00:00<00:02, 4874.34it/s, ce_loss=1.04, kl_loss=0.0022, loss=1.04]   

KL Train - 0.003 | Test - 0.002
CE Train - 0.892 | Test - 0.795
Aggregated Train - 0.895 | Test - 0.796


Epoch: 47: 100%|██████████| 11623/11623 [00:02<00:00, 4752.44it/s, ce_loss=0, kl_loss=0.000723, loss=0.000723]
Epoch: 48:   8%|▊         | 928/11623 [00:00<00:02, 4786.45it/s, ce_loss=0.965, kl_loss=0.00241, loss=0.968]

KL Train - 0.003 | Test - 0.001
CE Train - 0.897 | Test - 0.783
Aggregated Train - 0.901 | Test - 0.783


Epoch: 48: 100%|██████████| 11623/11623 [00:02<00:00, 4804.75it/s, ce_loss=0.814, kl_loss=0.000973, loss=0.815]
Epoch: 49:   8%|▊         | 960/11623 [00:00<00:02, 4895.77it/s, ce_loss=0.971, kl_loss=0.0029, loss=0.974] 

KL Train - 0.003 | Test - 0.002
CE Train - 0.883 | Test - 0.810
Aggregated Train - 0.886 | Test - 0.812


Epoch: 49: 100%|██████████| 11623/11623 [00:02<00:00, 4800.25it/s, ce_loss=0.816, kl_loss=0.00134, loss=0.817]
Epoch: 50:   8%|▊         | 960/11623 [00:00<00:02, 4857.36it/s, ce_loss=0.939, kl_loss=0.00454, loss=0.944]

KL Train - 0.004 | Test - 0.002
CE Train - 0.874 | Test - 0.807
Aggregated Train - 0.878 | Test - 0.809


Epoch: 50: 100%|██████████| 11623/11623 [00:02<00:00, 4788.09it/s, ce_loss=0, kl_loss=0.00263, loss=0.00263]  
Epoch: 51:   8%|▊         | 928/11623 [00:00<00:02, 4794.29it/s, ce_loss=0.869, kl_loss=0.00478, loss=0.873]

KL Train - 0.004 | Test - 0.003
CE Train - 0.873 | Test - 0.766
Aggregated Train - 0.877 | Test - 0.769


Epoch: 51: 100%|██████████| 11623/11623 [00:02<00:00, 4630.03it/s, ce_loss=0.764, kl_loss=0.00193, loss=0.766]
Epoch: 52:   7%|▋         | 864/11623 [00:00<00:02, 4524.28it/s, ce_loss=0.972, kl_loss=0.00236, loss=0.974]

KL Train - 0.004 | Test - 0.005
CE Train - 0.854 | Test - 0.782
Aggregated Train - 0.858 | Test - 0.787


Epoch: 52: 100%|██████████| 11623/11623 [00:02<00:00, 4458.47it/s, ce_loss=0, kl_loss=0.00495, loss=0.00495]  
Epoch: 53:   7%|▋         | 864/11623 [00:00<00:02, 4459.88it/s, ce_loss=0.96, kl_loss=0.00211, loss=0.963] 

KL Train - 0.004 | Test - 0.004
CE Train - 0.840 | Test - 0.773
Aggregated Train - 0.845 | Test - 0.777


Epoch: 53: 100%|██████████| 11623/11623 [00:02<00:00, 4481.85it/s, ce_loss=0.702, kl_loss=0.00281, loss=0.705]
Epoch: 54:   8%|▊         | 896/11623 [00:00<00:02, 4560.19it/s, ce_loss=0.837, kl_loss=0.0051, loss=0.842] 

KL Train - 0.005 | Test - 0.004
CE Train - 0.824 | Test - 0.765
Aggregated Train - 0.829 | Test - 0.769


Epoch: 54: 100%|██████████| 11623/11623 [00:02<00:00, 4500.07it/s, ce_loss=0.685, kl_loss=0.00222, loss=0.687]
Epoch: 55:   8%|▊         | 896/11623 [00:00<00:02, 4569.34it/s, ce_loss=0.869, kl_loss=0.00758, loss=0.877]

KL Train - 0.005 | Test - 0.004
CE Train - 0.808 | Test - 0.743
Aggregated Train - 0.813 | Test - 0.748


Epoch: 55: 100%|██████████| 11623/11623 [00:02<00:00, 4494.20it/s, ce_loss=0.67, kl_loss=0.00259, loss=0.673] 
Epoch: 56:   8%|▊         | 896/11623 [00:00<00:02, 4653.51it/s, ce_loss=0.816, kl_loss=0.00587, loss=0.822]

KL Train - 0.005 | Test - 0.004
CE Train - 0.795 | Test - 0.731
Aggregated Train - 0.800 | Test - 0.735


Epoch: 56: 100%|██████████| 11623/11623 [00:02<00:00, 4495.26it/s, ce_loss=0.629, kl_loss=0.00322, loss=0.633]
Epoch: 57:   8%|▊         | 896/11623 [00:00<00:02, 4536.99it/s, ce_loss=0.77, kl_loss=0.00546, loss=0.775] 

KL Train - 0.006 | Test - 0.005
CE Train - 0.779 | Test - 0.734
Aggregated Train - 0.785 | Test - 0.739


Epoch: 57: 100%|██████████| 11623/11623 [00:02<00:00, 4481.25it/s, ce_loss=0.624, kl_loss=0.00528, loss=0.629]
Epoch: 58:   8%|▊         | 896/11623 [00:00<00:02, 4657.22it/s, ce_loss=0.785, kl_loss=0.00855, loss=0.793]

KL Train - 0.006 | Test - 0.004
CE Train - 0.763 | Test - 0.703
Aggregated Train - 0.769 | Test - 0.707


Epoch: 58: 100%|██████████| 11623/11623 [00:02<00:00, 4504.67it/s, ce_loss=0.593, kl_loss=0.00849, loss=0.602]
Epoch: 59:   8%|▊         | 896/11623 [00:00<00:02, 4534.26it/s, ce_loss=0.772, kl_loss=0.0104, loss=0.782] 

KL Train - 0.006 | Test - 0.004
CE Train - 0.750 | Test - 0.684
Aggregated Train - 0.756 | Test - 0.688


Epoch: 59: 100%|██████████| 11623/11623 [00:02<00:00, 4506.22it/s, ce_loss=0.588, kl_loss=0.00603, loss=0.594]
Epoch: 60:   8%|▊         | 896/11623 [00:00<00:02, 4555.51it/s, ce_loss=0.767, kl_loss=0.0146, loss=0.781] 

KL Train - 0.007 | Test - 0.006
CE Train - 0.734 | Test - 0.672
Aggregated Train - 0.741 | Test - 0.678


Epoch: 60: 100%|██████████| 11623/11623 [00:02<00:00, 4494.16it/s, ce_loss=0.549, kl_loss=0.00285, loss=0.552]
Epoch: 61:   8%|▊         | 896/11623 [00:00<00:02, 4570.43it/s, ce_loss=0.715, kl_loss=0.00874, loss=0.724]

KL Train - 0.007 | Test - 0.005
CE Train - 0.717 | Test - 0.651
Aggregated Train - 0.724 | Test - 0.656


Epoch: 61: 100%|██████████| 11623/11623 [00:02<00:00, 4470.28it/s, ce_loss=0.495, kl_loss=0.00288, loss=0.498]
Epoch: 62:   8%|▊         | 896/11623 [00:00<00:02, 4722.60it/s, ce_loss=0.783, kl_loss=0.0085, loss=0.792] 

KL Train - 0.007 | Test - 0.004
CE Train - 0.702 | Test - 0.637
Aggregated Train - 0.709 | Test - 0.640


Epoch: 62: 100%|██████████| 11623/11623 [00:02<00:00, 4496.38it/s, ce_loss=0.493, kl_loss=0.00345, loss=0.496]
Epoch: 63:   8%|▊         | 896/11623 [00:00<00:02, 4521.77it/s, ce_loss=0.681, kl_loss=0.00696, loss=0.687]

KL Train - 0.008 | Test - 0.009
CE Train - 0.688 | Test - 0.613
Aggregated Train - 0.696 | Test - 0.622


Epoch: 63: 100%|██████████| 11623/11623 [00:02<00:00, 4505.61it/s, ce_loss=0.436, kl_loss=0.00605, loss=0.442]
Epoch: 64:   8%|▊         | 896/11623 [00:00<00:02, 4540.86it/s, ce_loss=0.753, kl_loss=0.00725, loss=0.76] 

KL Train - 0.008 | Test - 0.006
CE Train - 0.665 | Test - 0.608
Aggregated Train - 0.673 | Test - 0.614


Epoch: 64: 100%|██████████| 11623/11623 [00:02<00:00, 4508.14it/s, ce_loss=0.406, kl_loss=0.00304, loss=0.409]
Epoch: 65:   8%|▊         | 896/11623 [00:00<00:02, 4608.42it/s, ce_loss=0.633, kl_loss=0.00636, loss=0.639]

KL Train - 0.008 | Test - 0.007
CE Train - 0.645 | Test - 0.599
Aggregated Train - 0.653 | Test - 0.606


Epoch: 65: 100%|██████████| 11623/11623 [00:02<00:00, 4510.47it/s, ce_loss=0.409, kl_loss=0.02, loss=0.429]   
Epoch: 66:   8%|▊         | 896/11623 [00:00<00:02, 4651.16it/s, ce_loss=0.606, kl_loss=0.0101, loss=0.616] 

KL Train - 0.009 | Test - 0.007
CE Train - 0.624 | Test - 0.588
Aggregated Train - 0.632 | Test - 0.595


Epoch: 66: 100%|██████████| 11623/11623 [00:02<00:00, 4507.99it/s, ce_loss=0.379, kl_loss=0.00384, loss=0.383]
Epoch: 67:   7%|▋         | 864/11623 [00:00<00:02, 4463.80it/s, ce_loss=0.896, kl_loss=0.0079, loss=0.904] 

KL Train - 0.009 | Test - 0.008
CE Train - 0.600 | Test - 0.575
Aggregated Train - 0.609 | Test - 0.583


Epoch: 67: 100%|██████████| 11623/11623 [00:02<00:00, 4470.09it/s, ce_loss=0.379, kl_loss=0.00406, loss=0.383]
Epoch: 68:   7%|▋         | 832/11623 [00:00<00:02, 4150.03it/s, ce_loss=0.53, kl_loss=0.0116, loss=0.542]  

KL Train - 0.010 | Test - 0.011
CE Train - 0.577 | Test - 0.562
Aggregated Train - 0.587 | Test - 0.574


Epoch: 68: 100%|██████████| 11623/11623 [00:02<00:00, 4444.67it/s, ce_loss=0.421, kl_loss=0.00593, loss=0.427]
Epoch: 69:   8%|▊         | 960/11623 [00:00<00:02, 4779.97it/s, ce_loss=0.678, kl_loss=0.00857, loss=0.687]

KL Train - 0.010 | Test - 0.005
CE Train - 0.548 | Test - 0.538
Aggregated Train - 0.558 | Test - 0.543


Epoch: 69: 100%|██████████| 11623/11623 [00:02<00:00, 4710.69it/s, ce_loss=0.337, kl_loss=0.00772, loss=0.345]
Epoch: 70:   8%|▊         | 928/11623 [00:00<00:02, 4701.22it/s, ce_loss=0.378, kl_loss=0.00784, loss=0.386]

KL Train - 0.010 | Test - 0.006
CE Train - 0.521 | Test - 0.528
Aggregated Train - 0.531 | Test - 0.534


Epoch: 70: 100%|██████████| 11623/11623 [00:02<00:00, 4383.10it/s, ce_loss=0.274, kl_loss=0.0108, loss=0.285] 
Epoch: 71:   7%|▋         | 864/11623 [00:00<00:02, 4489.38it/s, ce_loss=0.873, kl_loss=0.00699, loss=0.88] 

KL Train - 0.011 | Test - 0.008
CE Train - 0.488 | Test - 0.513
Aggregated Train - 0.499 | Test - 0.521


Epoch: 71: 100%|██████████| 11623/11623 [00:02<00:00, 4359.63it/s, ce_loss=0.252, kl_loss=0.00938, loss=0.261]
Epoch: 72:   8%|▊         | 896/11623 [00:00<00:02, 4524.68it/s, ce_loss=0.475, kl_loss=0.0144, loss=0.49]  

KL Train - 0.012 | Test - 0.007
CE Train - 0.453 | Test - 0.494
Aggregated Train - 0.465 | Test - 0.501


Epoch: 72: 100%|██████████| 11623/11623 [00:02<00:00, 4500.11it/s, ce_loss=0.203, kl_loss=0.00713, loss=0.21] 
Epoch: 73:   7%|▋         | 864/11623 [00:00<00:02, 4459.88it/s, ce_loss=0.812, kl_loss=0.00781, loss=0.82] 

KL Train - 0.012 | Test - 0.015
CE Train - 0.412 | Test - 0.480
Aggregated Train - 0.424 | Test - 0.495


Epoch: 73: 100%|██████████| 11623/11623 [00:02<00:00, 4592.12it/s, ce_loss=0.185, kl_loss=0.0111, loss=0.197] 
Epoch: 74:   8%|▊         | 896/11623 [00:00<00:02, 4584.97it/s, ce_loss=0.321, kl_loss=0.0179, loss=0.339] 

KL Train - 0.013 | Test - 0.005
CE Train - 0.364 | Test - 0.469
Aggregated Train - 0.377 | Test - 0.474


Epoch: 74: 100%|██████████| 11623/11623 [00:02<00:00, 4714.55it/s, ce_loss=0.123, kl_loss=0.0135, loss=0.136] 


KL Train - 0.013 | Test - 0.015
CE Train - 0.301 | Test - 0.459
Aggregated Train - 0.315 | Test - 0.474


In [26]:
tsa.tsa.current_prob

1.0

In [25]:
class Model(BaseModule):
    
    def __init__(self):
        
        super().__init__()
        
        self.embedding = Embedding(vocab_size=word_matrix.shape[0],
                                   embedding_matrix=word_matrix)
        
        self.dan = DAN((300, 256), activation_function_output=torch.nn.ReLU())
        
        self.linear = torch.nn.Linear(256, 256)
        
        self.activation = torch.nn.ReLU()
        
        self.classifier = torch.nn.Linear(256, 3)
        
    def forward(self, x):
        
        x_rep = self.embedding(x)
        x_rep = self.dan(x_rep)
        x_rep = self.linear(x_rep)
        x_rep = torch.nn.functional.log_softmax(x_rep, dim=1)
    
        y_pred = self.classifier(x_rep)
        
        return y_pred

In [26]:
model = Model()

In [27]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [28]:
losses = []

l = 0.5

N_EPOCHS = 75

tsa = TrainingSignalAnnealingScheduler(total_steps=N_EPOCHS * len(train) // 32, n_classes=3, schedule_type='exp')

for n in range(N_EPOCHS):
    
    epoch_losses = []

    model.train()
    
    pg = tqdm(total=len(train), desc=f'Epoch: {n}')

    for x, _, y in loader(train):

        optimizer.zero_grad()

        y_pred = model(x)
        
        y_pred, y = tsa(y_pred, y)
        
        if y_pred.size(0) == 0:
            loss = 0
            epoch_losses.append(loss)
        else:
            loss = cross_entropy(y_pred, y)
            epoch_losses.append(loss.item())
            
            loss.backward()

            optimizer.step()

#         loss = cross_entropy(y_pred, y)

        pg.update(x.shape[0])
        pg.set_postfix(loss=epoch_losses[-1])

    pg.close()

    test_epoch_losses = []

    model.eval()

    for x, x_aug, y in loader(test):

        y_pred = model(x)

        loss = cross_entropy(y_pred, y)

    test_epoch_losses.append(loss.item())
    
    print('Loss Train - {:.3f} | Test - {:.3f}'.format(np.mean(epoch_losses), np.mean(test_epoch_losses)))

    losses.extend(copy.deepcopy(epoch_losses))

Epoch: 0: 100%|██████████| 11623/11623 [00:02<00:00, 4992.84it/s, loss=0]   
Epoch: 1:   9%|▊         | 992/11623 [00:00<00:02, 5091.79it/s, loss=1.1] 

Loss Train - 1.209 | Test - 1.026


Epoch: 1: 100%|██████████| 11623/11623 [00:02<00:00, 5003.40it/s, loss=1.09]
Epoch: 2:   9%|▊         | 992/11623 [00:00<00:02, 5113.96it/s, loss=1.12]

Loss Train - 1.124 | Test - 1.032


Epoch: 2: 100%|██████████| 11623/11623 [00:02<00:00, 5034.19it/s, loss=0]   
Epoch: 3:   9%|▊         | 992/11623 [00:00<00:02, 5146.64it/s, loss=1.12]

Loss Train - 1.099 | Test - 1.013


Epoch: 3: 100%|██████████| 11623/11623 [00:02<00:00, 5035.44it/s, loss=0]   
Epoch: 4:   9%|▉         | 1024/11623 [00:00<00:01, 5367.90it/s, loss=1.15]

Loss Train - 1.085 | Test - 1.004


Epoch: 4: 100%|██████████| 11623/11623 [00:02<00:00, 5046.79it/s, loss=1.16]
Epoch: 5:   9%|▊         | 992/11623 [00:00<00:02, 5144.26it/s, loss=1.18]

Loss Train - 1.066 | Test - 1.023


Epoch: 5: 100%|██████████| 11623/11623 [00:02<00:00, 5057.68it/s, loss=1.08]
Epoch: 6:   9%|▊         | 992/11623 [00:00<00:02, 5181.45it/s, loss=1.14]

Loss Train - 1.074 | Test - 1.026


Epoch: 6: 100%|██████████| 11623/11623 [00:02<00:00, 5059.88it/s, loss=0]   
Epoch: 7:   9%|▊         | 992/11623 [00:00<00:02, 5109.73it/s, loss=1.2] 

Loss Train - 1.074 | Test - 1.008


Epoch: 7: 100%|██████████| 11623/11623 [00:02<00:00, 5038.79it/s, loss=0]   
Epoch: 8:   9%|▊         | 992/11623 [00:00<00:02, 5237.16it/s, loss=1.12]

Loss Train - 1.075 | Test - 0.983


Epoch: 8: 100%|██████████| 11623/11623 [00:02<00:00, 5082.40it/s, loss=0]   
Epoch: 9:   9%|▉         | 1024/11623 [00:00<00:01, 5313.54it/s, loss=1.16]

Loss Train - 1.024 | Test - 0.978


Epoch: 9: 100%|██████████| 11623/11623 [00:02<00:00, 5027.60it/s, loss=1.07]
Epoch: 10:   9%|▉         | 1024/11623 [00:00<00:01, 5446.09it/s, loss=1.11]

Loss Train - 1.043 | Test - 0.981


Epoch: 10: 100%|██████████| 11623/11623 [00:02<00:00, 5082.54it/s, loss=0]   
Epoch: 11:   9%|▉         | 1024/11623 [00:00<00:01, 5355.18it/s, loss=1.09]

Loss Train - 0.995 | Test - 0.992


Epoch: 11: 100%|██████████| 11623/11623 [00:02<00:00, 5100.61it/s, loss=0]   
Epoch: 12:   9%|▉         | 1024/11623 [00:00<00:02, 5222.62it/s, loss=1.1]

Loss Train - 0.969 | Test - 0.963


Epoch: 12: 100%|██████████| 11623/11623 [00:02<00:00, 5139.23it/s, loss=0]   
Epoch: 13:   9%|▉         | 1024/11623 [00:00<00:02, 5141.92it/s, loss=1.12]

Loss Train - 0.977 | Test - 1.012


Epoch: 13: 100%|██████████| 11623/11623 [00:02<00:00, 5221.97it/s, loss=0]   
Epoch: 14:   9%|▉         | 1088/11623 [00:00<00:01, 5426.20it/s, loss=1.09]

Loss Train - 0.995 | Test - 0.994


Epoch: 14: 100%|██████████| 11623/11623 [00:02<00:00, 5294.53it/s, loss=0]   
Epoch: 15:   8%|▊         | 960/11623 [00:00<00:02, 4877.91it/s, loss=1.14]

Loss Train - 0.999 | Test - 0.978


Epoch: 15: 100%|██████████| 11623/11623 [00:02<00:00, 5173.43it/s, loss=0]   
Epoch: 16:   9%|▊         | 992/11623 [00:00<00:01, 5417.68it/s, loss=1.11]

Loss Train - 0.986 | Test - 0.980


Epoch: 16: 100%|██████████| 11623/11623 [00:02<00:00, 5093.60it/s, loss=0]   
Epoch: 17:   9%|▉         | 1056/11623 [00:00<00:01, 5742.21it/s, loss=1.19]

Loss Train - 0.959 | Test - 0.946


Epoch: 17: 100%|██████████| 11623/11623 [00:02<00:00, 5315.98it/s, loss=0]   
Epoch: 18:   9%|▉         | 1056/11623 [00:00<00:01, 5463.96it/s, loss=1.07]

Loss Train - 1.006 | Test - 0.937


Epoch: 18: 100%|██████████| 11623/11623 [00:02<00:00, 5392.96it/s, loss=0]   
Epoch: 19:   9%|▉         | 1088/11623 [00:00<00:01, 5645.57it/s, loss=1.13]

Loss Train - 0.920 | Test - 0.950


Epoch: 19: 100%|██████████| 11623/11623 [00:02<00:00, 5428.72it/s, loss=0]   
Epoch: 20:  10%|▉         | 1120/11623 [00:00<00:01, 5916.08it/s, loss=1.08]

Loss Train - 0.967 | Test - 0.942


Epoch: 20: 100%|██████████| 11623/11623 [00:02<00:00, 5433.75it/s, loss=0]   
Epoch: 21:  10%|▉         | 1120/11623 [00:00<00:01, 5658.33it/s, loss=1.06]

Loss Train - 0.916 | Test - 0.940


Epoch: 21: 100%|██████████| 11623/11623 [00:02<00:00, 5393.96it/s, loss=0]   
Epoch: 22:   9%|▉         | 1056/11623 [00:00<00:01, 5485.81it/s, loss=1.12]

Loss Train - 0.967 | Test - 0.939


Epoch: 22: 100%|██████████| 11623/11623 [00:02<00:00, 5335.15it/s, loss=0]   
Epoch: 23:   9%|▉         | 1088/11623 [00:00<00:01, 5701.84it/s, loss=0]   

Loss Train - 0.928 | Test - 0.927


Epoch: 23: 100%|██████████| 11623/11623 [00:02<00:00, 5384.74it/s, loss=0]   
Epoch: 24:  10%|▉         | 1120/11623 [00:00<00:01, 5627.25it/s, loss=1.06]

Loss Train - 0.909 | Test - 0.934


Epoch: 24: 100%|██████████| 11623/11623 [00:02<00:00, 5543.52it/s, loss=0]   
Epoch: 25:  10%|▉         | 1120/11623 [00:00<00:01, 5709.92it/s, loss=1.04]

Loss Train - 0.878 | Test - 0.885


Epoch: 25: 100%|██████████| 11623/11623 [00:02<00:00, 5429.54it/s, loss=1.03]
Epoch: 26:   9%|▉         | 1088/11623 [00:00<00:01, 5511.86it/s, loss=1.04]

Loss Train - 0.942 | Test - 0.881


Epoch: 26: 100%|██████████| 11623/11623 [00:02<00:00, 5527.99it/s, loss=0]   
Epoch: 27:   9%|▉         | 1088/11623 [00:00<00:01, 5563.49it/s, loss=0]   

Loss Train - 0.943 | Test - 0.891


Epoch: 27: 100%|██████████| 11623/11623 [00:02<00:00, 5468.42it/s, loss=0]   
Epoch: 28:  10%|▉         | 1120/11623 [00:00<00:01, 5656.38it/s, loss=1.12]

Loss Train - 0.945 | Test - 0.884


Epoch: 28: 100%|██████████| 11623/11623 [00:02<00:00, 5519.86it/s, loss=0]   
Epoch: 29:   9%|▉         | 1088/11623 [00:00<00:01, 5724.72it/s, loss=1.03]

Loss Train - 0.901 | Test - 0.874


Epoch: 29: 100%|██████████| 11623/11623 [00:02<00:00, 5485.35it/s, loss=0]   
Epoch: 30:   9%|▉         | 1056/11623 [00:00<00:01, 5439.77it/s, loss=1.09]

Loss Train - 0.896 | Test - 0.820


Epoch: 30: 100%|██████████| 11623/11623 [00:02<00:00, 5482.33it/s, loss=0]   
Epoch: 31:   9%|▉         | 1024/11623 [00:00<00:02, 5291.98it/s, loss=1.12]

Loss Train - 0.895 | Test - 0.854


Epoch: 31: 100%|██████████| 11623/11623 [00:02<00:00, 5378.93it/s, loss=0]    
Epoch: 32:   9%|▉         | 1088/11623 [00:00<00:01, 5617.70it/s, loss=0.994]

Loss Train - 0.908 | Test - 0.843


Epoch: 32: 100%|██████████| 11623/11623 [00:02<00:00, 5479.02it/s, loss=0]    
Epoch: 33:   9%|▉         | 1088/11623 [00:00<00:01, 5791.36it/s, loss=0.987]

Loss Train - 0.904 | Test - 0.834


Epoch: 33: 100%|██████████| 11623/11623 [00:02<00:00, 5458.00it/s, loss=0]    
Epoch: 34:   9%|▉         | 1088/11623 [00:00<00:01, 5684.85it/s, loss=1]   

Loss Train - 0.944 | Test - 0.825


Epoch: 34: 100%|██████████| 11623/11623 [00:02<00:00, 5346.05it/s, loss=0]    
Epoch: 35:   9%|▉         | 1024/11623 [00:00<00:01, 5351.65it/s, loss=1.32]

Loss Train - 0.944 | Test - 0.817


Epoch: 35: 100%|██████████| 11623/11623 [00:02<00:00, 5305.98it/s, loss=0.96] 
Epoch: 36:   9%|▉         | 1024/11623 [00:00<00:02, 5200.77it/s, loss=1.06]

Loss Train - 0.938 | Test - 0.807


Epoch: 36: 100%|██████████| 11623/11623 [00:02<00:00, 5414.48it/s, loss=0.965]
Epoch: 37:   9%|▉         | 1088/11623 [00:00<00:01, 5638.72it/s, loss=0.958]

Loss Train - 0.912 | Test - 0.768


Epoch: 37: 100%|██████████| 11623/11623 [00:02<00:00, 5431.75it/s, loss=0]    
Epoch: 38:   9%|▉         | 1088/11623 [00:00<00:01, 5506.47it/s, loss=0.962]

Loss Train - 0.914 | Test - 0.797


Epoch: 38: 100%|██████████| 11623/11623 [00:02<00:00, 5424.71it/s, loss=0]    
Epoch: 39:   9%|▉         | 1088/11623 [00:00<00:01, 5484.41it/s, loss=0.954]

Loss Train - 0.890 | Test - 0.793


Epoch: 39: 100%|██████████| 11623/11623 [00:02<00:00, 5455.01it/s, loss=0.926]
Epoch: 40:   9%|▉         | 1088/11623 [00:00<00:01, 5511.02it/s, loss=1.02] 

Loss Train - 0.905 | Test - 0.761


Epoch: 40: 100%|██████████| 11623/11623 [00:02<00:00, 5332.55it/s, loss=0]    
Epoch: 41:   9%|▉         | 1056/11623 [00:00<00:01, 5482.33it/s, loss=0.965]

Loss Train - 0.923 | Test - 0.771


Epoch: 41: 100%|██████████| 11623/11623 [00:02<00:00, 5437.80it/s, loss=0.908]
Epoch: 42:   9%|▉         | 1024/11623 [00:00<00:01, 5311.10it/s, loss=1.05]

Loss Train - 0.906 | Test - 0.796


Epoch: 42: 100%|██████████| 11623/11623 [00:02<00:00, 5439.05it/s, loss=0]    
Epoch: 43:   9%|▉         | 1088/11623 [00:00<00:01, 5455.01it/s, loss=0.951]

Loss Train - 0.885 | Test - 0.768


Epoch: 43: 100%|██████████| 11623/11623 [00:02<00:00, 5393.58it/s, loss=0]    
Epoch: 44:   9%|▉         | 1056/11623 [00:00<00:01, 5478.15it/s, loss=0.956]

Loss Train - 0.929 | Test - 0.756


Epoch: 44: 100%|██████████| 11623/11623 [00:02<00:00, 5382.99it/s, loss=0]    
Epoch: 45:   9%|▉         | 1088/11623 [00:00<00:01, 5492.16it/s, loss=0.885]

Loss Train - 0.915 | Test - 0.729


Epoch: 45: 100%|██████████| 11623/11623 [00:02<00:00, 5357.54it/s, loss=0]    
Epoch: 46:   9%|▉         | 1056/11623 [00:00<00:01, 5433.71it/s, loss=0.945]

Loss Train - 0.895 | Test - 0.739


Epoch: 46: 100%|██████████| 11623/11623 [00:02<00:00, 5270.82it/s, loss=0.843]
Epoch: 47:   9%|▉         | 1088/11623 [00:00<00:01, 5452.16it/s, loss=0.881]

Loss Train - 0.893 | Test - 0.736


Epoch: 47: 100%|██████████| 11623/11623 [00:02<00:00, 5209.64it/s, loss=0.844]
Epoch: 48:   9%|▊         | 992/11623 [00:00<00:02, 5098.69it/s, loss=0.967]

Loss Train - 0.901 | Test - 0.703


Epoch: 48: 100%|██████████| 11623/11623 [00:02<00:00, 5108.40it/s, loss=0]    
Epoch: 49:   9%|▉         | 1024/11623 [00:00<00:01, 5317.23it/s, loss=0.916]

Loss Train - 0.890 | Test - 0.705


Epoch: 49: 100%|██████████| 11623/11623 [00:02<00:00, 5120.76it/s, loss=0]    
Epoch: 50:   9%|▉         | 1024/11623 [00:00<00:02, 5298.99it/s, loss=0.944]

Loss Train - 0.883 | Test - 0.689


Epoch: 50: 100%|██████████| 11623/11623 [00:02<00:00, 5282.17it/s, loss=0]    
Epoch: 51:   9%|▉         | 1056/11623 [00:00<00:01, 5349.16it/s, loss=0.909]

Loss Train - 0.880 | Test - 0.680


Epoch: 51: 100%|██████████| 11623/11623 [00:02<00:00, 5162.33it/s, loss=0.773]
Epoch: 52:   9%|▉         | 1056/11623 [00:00<00:01, 5454.76it/s, loss=0.877]

Loss Train - 0.856 | Test - 0.634


Epoch: 52: 100%|██████████| 11623/11623 [00:02<00:00, 5331.90it/s, loss=0]    
Epoch: 53:   9%|▉         | 1024/11623 [00:00<00:02, 5210.53it/s, loss=0.847]

Loss Train - 0.846 | Test - 0.666


Epoch: 53: 100%|██████████| 11623/11623 [00:02<00:00, 5182.49it/s, loss=0]    
Epoch: 54:   9%|▉         | 1056/11623 [00:00<00:01, 5443.93it/s, loss=0.876]

Loss Train - 0.830 | Test - 0.649


Epoch: 54: 100%|██████████| 11623/11623 [00:02<00:00, 5329.79it/s, loss=0.727]
Epoch: 55:   9%|▉         | 1024/11623 [00:00<00:01, 5463.49it/s, loss=0.8] 

Loss Train - 0.823 | Test - 0.641


Epoch: 55: 100%|██████████| 11623/11623 [00:02<00:00, 5308.46it/s, loss=0]    
Epoch: 56:   9%|▉         | 1024/11623 [00:00<00:01, 5435.57it/s, loss=0.782]

Loss Train - 0.796 | Test - 0.624


Epoch: 56: 100%|██████████| 11623/11623 [00:02<00:00, 5327.89it/s, loss=0.679]
Epoch: 57:   9%|▉         | 1056/11623 [00:00<00:01, 5469.16it/s, loss=0.831]

Loss Train - 0.786 | Test - 0.613


Epoch: 57: 100%|██████████| 11623/11623 [00:02<00:00, 5328.81it/s, loss=0.615]
Epoch: 58:   9%|▉         | 1056/11623 [00:00<00:01, 5462.46it/s, loss=0.827]

Loss Train - 0.770 | Test - 0.615


Epoch: 58: 100%|██████████| 11623/11623 [00:02<00:00, 5304.74it/s, loss=0]    
Epoch: 59:   9%|▉         | 1056/11623 [00:00<00:01, 5412.99it/s, loss=0.777]

Loss Train - 0.757 | Test - 0.639


Epoch: 59: 100%|██████████| 11623/11623 [00:02<00:00, 5090.09it/s, loss=0]    
Epoch: 60:   9%|▉         | 1056/11623 [00:00<00:01, 5414.19it/s, loss=0.756]

Loss Train - 0.743 | Test - 0.606


Epoch: 60: 100%|██████████| 11623/11623 [00:02<00:00, 5232.77it/s, loss=0.52] 
Epoch: 61:   9%|▉         | 1056/11623 [00:00<00:01, 5442.17it/s, loss=0.712]

Loss Train - 0.725 | Test - 0.585


Epoch: 61: 100%|██████████| 11623/11623 [00:02<00:00, 5300.52it/s, loss=0.493]
Epoch: 62:   9%|▉         | 1024/11623 [00:00<00:02, 5165.75it/s, loss=0.696]

Loss Train - 0.709 | Test - 0.588


Epoch: 62: 100%|██████████| 11623/11623 [00:02<00:00, 5184.59it/s, loss=0.502]
Epoch: 63:   9%|▉         | 1056/11623 [00:00<00:01, 5352.79it/s, loss=0.688]

Loss Train - 0.691 | Test - 0.574


Epoch: 63: 100%|██████████| 11623/11623 [00:02<00:00, 5282.46it/s, loss=0]    
Epoch: 64:   9%|▉         | 1056/11623 [00:00<00:01, 5318.16it/s, loss=0.643]

Loss Train - 0.670 | Test - 0.587


Epoch: 64: 100%|██████████| 11623/11623 [00:02<00:00, 5216.92it/s, loss=0.401]
Epoch: 65:   9%|▉         | 1024/11623 [00:00<00:01, 5326.44it/s, loss=0.586]

Loss Train - 0.649 | Test - 0.545


Epoch: 65: 100%|██████████| 11623/11623 [00:02<00:00, 5099.33it/s, loss=0.381]
Epoch: 66:   9%|▉         | 1024/11623 [00:00<00:02, 5208.42it/s, loss=0.607]

Loss Train - 0.624 | Test - 0.544


Epoch: 66: 100%|██████████| 11623/11623 [00:02<00:00, 4942.32it/s, loss=0.339]
Epoch: 67:   9%|▊         | 992/11623 [00:00<00:02, 5002.29it/s, loss=0.604]

Loss Train - 0.604 | Test - 0.532


Epoch: 67: 100%|██████████| 11623/11623 [00:02<00:00, 5007.49it/s, loss=0.304]
Epoch: 68:   9%|▊         | 992/11623 [00:00<00:02, 5145.04it/s, loss=0.566]

Loss Train - 0.582 | Test - 0.552


Epoch: 68: 100%|██████████| 11623/11623 [00:02<00:00, 5029.17it/s, loss=0.283]
Epoch: 69:   9%|▊         | 992/11623 [00:00<00:02, 5081.43it/s, loss=0.533]

Loss Train - 0.554 | Test - 0.545


Epoch: 69: 100%|██████████| 11623/11623 [00:02<00:00, 4848.13it/s, loss=0.25] 
Epoch: 70:   9%|▊         | 992/11623 [00:00<00:02, 5144.75it/s, loss=0.45] 

Loss Train - 0.525 | Test - 0.538


Epoch: 70: 100%|██████████| 11623/11623 [00:02<00:00, 4995.44it/s, loss=0.224]
Epoch: 71:   9%|▊         | 992/11623 [00:00<00:02, 5092.30it/s, loss=0.436]

Loss Train - 0.493 | Test - 0.525


Epoch: 71: 100%|██████████| 11623/11623 [00:02<00:00, 4996.75it/s, loss=0.185]
Epoch: 72:   9%|▊         | 992/11623 [00:00<00:02, 5009.26it/s, loss=0.425]

Loss Train - 0.458 | Test - 0.518


Epoch: 72: 100%|██████████| 11623/11623 [00:02<00:00, 4995.00it/s, loss=0.17] 
Epoch: 73:   9%|▊         | 992/11623 [00:00<00:02, 5051.24it/s, loss=0.378]

Loss Train - 0.416 | Test - 0.494


Epoch: 73: 100%|██████████| 11623/11623 [00:02<00:00, 4962.86it/s, loss=0.129]
Epoch: 74:   9%|▊         | 992/11623 [00:00<00:02, 5094.32it/s, loss=0.316]

Loss Train - 0.367 | Test - 0.495


Epoch: 74: 100%|██████████| 11623/11623 [00:02<00:00, 5007.52it/s, loss=0.0707]


Loss Train - 0.306 | Test - 0.494
