In [3]:
import os
import re
import json
import jsonlines as jl
import joblib
from pathlib import Path
import itertools
from collections import defaultdict

import pickle

import numpy as np
from numpy.linalg import norm

from tqdm import tqdm

from scipy.stats import spearmanr, pearsonr
from sklearn.model_selection import train_test_split

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

from transformers import get_scheduler

from keras.preprocessing.text import Tokenizer

# Useful Constants

In [4]:
###################
# Paths and files #
###################

# Input data: the full triple set and its train / validation / test splits
INPUT_PATH = Path("..") / "input"
ALL_TRIPLES = INPUT_PATH / "triple_set.jsonl"
TRAIN_TRIPLES = INPUT_PATH / "train_triples.jsonl"
VAL_TRIPLES = INPUT_PATH / "val_triples.jsonl"
TEST_TRIPLES = INPUT_PATH / "test_triples.jsonl"

# Model artefacts: pretrained vectors, fitted tokenizer, trained weights
OUTPUT_PATH = Path("..") / "result"
FASTTEXT = "./ft_all_phr_2022_30_03_top_100000.pkl"
FIT_TOKENIZER = OUTPUT_PATH / "ft_tokenizer.pkl"
OUTPUT_MODEL = OUTPUT_PATH / "triplet_loss_model.sav"

#############
# Constants #
#############

# Keys of a triple record and the slot each one occupies in the batch tensor
ANCHOR, POSITIVE, NEGATIVE = "anchor", "positive", "negative"
TYPE2IDX = {ANCHOR: 0, POSITIVE: 1, NEGATIVE: 2}

# Visible CUDA devices
GPU_NUM = torch.cuda.device_count()
GPU_IDS = ["cuda:" + str(gpu_idx) for gpu_idx in range(GPU_NUM)]

# Model hyper-parameters
VOCAB_SIZE = 20000
EMBEDDING_SIZE = 300
LSTM_SIZE = 128
LSTM_NUM_LAYERS = 3
OUT_EMBEDDING_SIZE = 768
TRAIN_BATCH_SIZE = 1024
VAL_BATCH_SIZE = 1024
MAX_SEQ_LEN = 32

# Loading Data

In [5]:
def load_triples(_file):
    """Read every record of a JSON Lines file into a list.

    Args:
        _file: path of the ``.jsonl`` file to read.

    Returns:
        A list with one decoded object per line, in file order.
    """
    with jl.open(_file, mode="r") as reader:
        return list(reader)

In [6]:
# Restore the fitted tokenizer from FIT_TOKENIZER; later cells use its
# `word_index` and `texts_to_sequences` (presumably a Keras Tokenizer — confirm).
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so only
# load a tokenizer file that was produced by a trusted run of this project.
with open(FIT_TOKENIZER, "rb") as infile:
    TOKENIZER = pickle.load(infile)

In [7]:
# Effective vocabulary size: the tokenizer's vocabulary, capped at VOCAB_SIZE
NB_WORDS = min(len(TOKENIZER.word_index), VOCAB_SIZE)
NB_WORDS

20000

In [10]:
# The split is materialised on disk once and reused afterwards, so every run
# of the notebook trains and validates on the same triples.
dataset_is_split = all(
    os.path.exists(split_file)
    for split_file in (TRAIN_TRIPLES, VAL_TRIPLES, TEST_TRIPLES)
)

if not dataset_is_split:
    # Split data: 70% train; the remaining 30% is divided 67/33 into val/test
    all_triples = load_triples(ALL_TRIPLES)
    train_triples, val_test_triples = train_test_split(all_triples, test_size=0.3, random_state=1)
    val_triples, test_triples = train_test_split(val_test_triples, test_size=0.33, random_state=1)
    # Dump sets
    for split_file, split_triples in ((TRAIN_TRIPLES, train_triples),
                                      (VAL_TRIPLES, val_triples),
                                      (TEST_TRIPLES, test_triples)):
        with jl.open(split_file, mode="w") as outfile:
            outfile.write_all(split_triples)
else:
    train_triples = load_triples(TRAIN_TRIPLES)
    val_triples = load_triples(VAL_TRIPLES)
    # Load the test split as well so the same variables exist whichever branch
    # ran (previously `test_triples` was only defined right after a fresh split).
    test_triples = load_triples(TEST_TRIPLES)

False


# FastText + BiLSTM Training

In [11]:
# Utilities

class _dict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    Reading an attribute that is neither a real attribute nor a key yields
    ``None`` (the ``dict.get`` default) instead of raising ``AttributeError``.
    """

    def __getattr__(self, name):
        # Only invoked when regular attribute lookup fails.
        return self.get(name)

    def __setattr__(self, name, value):
        # Attribute assignment is stored as a dictionary item.
        self[name] = value

    def __delattr__(self, name):
        # Raises KeyError when the key is absent, like ``del mapping[key]``.
        del self[name]

    def __getstate__(self):
        # Pickle support: expose the instance ``__dict__`` as the state.
        return vars(self)

    def __setstate__(self, d):
        # Pickle support: merge the saved state back into ``__dict__``.
        return vars(self).update(d)

In [12]:
class TextPreprocessor:
    """Cleans triple texts, converts them to padded token-id tensors and
    builds the initial embedding matrix from pretrained vectors.

    The configuration lives in ``self.args``, an attribute-accessible mapping
    with the entries ``tokenizer`` and ``pretrained_embeddings``.
    """

    def __init__(self, args=None, tokenizer=TOKENIZER, pretrained_embeddings=FASTTEXT):
        """
        Args:
            args: optional namespace-like object whose attributes provide
                ``tokenizer`` and ``pretrained_embeddings``. When given (and
                truthy) it takes precedence over the other two arguments.
            tokenizer: fitted tokenizer exposing ``texts_to_sequences`` and
                ``word_index``.
            pretrained_embeddings: path of the file with pretrained vectors;
                it is loaded with ``joblib.load``.
                NOTE(review): joblib files are pickle-based — only load files
                from a trusted source.
        """
        if args:
            # vars() returns a plain dict; wrap it so that the attribute
            # access used below (self.args.tokenizer, ...) works.
            self.args = _dict(vars(args))
        else:
            self.args = _dict(tokenizer=tokenizer,
                              pretrained_embeddings=joblib.load(pretrained_embeddings))

    def to_matrix(self, data, _dtype=torch.int64, batch_first=True):
        """Convert triples into one zero-padded tensor of token ids.

        Args:
            data: sequence of mappings with the keys ANCHOR, POSITIVE and
                NEGATIVE, each holding a text.
            _dtype: dtype of the returned tensor.
            batch_first: when True the result has shape
                ``(len(data), 3, MAX_SEQ_LEN)``; otherwise it is permuted to
                ``(3, MAX_SEQ_LEN, len(data))``.

        Returns:
            Tensor of token ids. Sequences longer than MAX_SEQ_LEN are
            truncated, shorter ones are right-padded with zeros.
        """
        # Clean the texts, grouped by their role in the triple
        texts_by_type = [[], [], []]
        for item in data:
            texts_by_type[0].append(self.clean_text(item[ANCHOR]))
            texts_by_type[1].append(self.clean_text(item[POSITIVE]))
            texts_by_type[2].append(self.clean_text(item[NEGATIVE]))

        # Texts to token IDs
        indexed_data = [self.args.tokenizer.texts_to_sequences(texts)
                        for texts in texts_by_type]

        # Sequence length dimension size
        max_seq_len = MAX_SEQ_LEN

        # Filling the padded matrix
        _size = len(data)
        out_matrix = torch.zeros((_size, 3, max_seq_len), dtype=_dtype)
        for idx in range(_size):
            for item_type in range(3):
                sequence = indexed_data[item_type][idx][:max_seq_len]
                out_matrix[idx, item_type, :len(sequence)] = torch.tensor(sequence, dtype=_dtype)

        if not batch_first:
            out_matrix = out_matrix.permute(1, 2, 0)

        return out_matrix

    def build_embedding_matrix(self, emb_size=EMBEDDING_SIZE, nb_words=NB_WORDS, _dtype=torch.float32):
        """Assemble the initial embedding matrix from the pretrained vectors.

        Row ``i`` holds the pretrained vector of the word whose tokenizer
        index is ``i``. Row 0 and the rows of words without a pretrained
        vector stay zero.

        Args:
            emb_size: dimensionality of the pretrained vectors.
            nb_words: highest tokenizer index that receives a row.
            _dtype: dtype of the returned tensor.

        Returns:
            Tensor of shape ``(nb_words + 1, emb_size)``.
        """
        embedding_matrix = torch.zeros((nb_words + 1, emb_size), dtype=_dtype)
        oof_tokens = set()

        for word, i in self.args.tokenizer.word_index.items():
            # Bound by the matrix that was actually allocated rather than by
            # the global VOCAB_SIZE, so a smaller nb_words cannot overflow it.
            if i > nb_words:
                continue
            try:
                embedding_vector = self.args.pretrained_embeddings[word]
            except Exception:
                # Lookup failed: count the word as missing from the pretrained
                # vectors. Unlike a bare `except`, this lets KeyboardInterrupt
                # and SystemExit propagate.
                oof_tokens.add(word)
                continue
            if embedding_vector is not None:
                embedding_matrix[i, :] = torch.from_numpy(embedding_vector)

        print(f"{len(oof_tokens)} out of FastText tokens")
        return embedding_matrix

    @staticmethod
    def clean_text(_text):
        """Normalise a raw text before tokenisation.

        Replaces newlines and HTML-like tags with spaces, drops full stops
        that are followed by whitespace, removes two greeting phrases
        (case-sensitively) and collapses whitespace runs.
        """
        # Raw strings keep the regex escapes from being parsed as (invalid)
        # Python string escapes; the patterns themselves are unchanged.
        _text = re.sub(r"\n|(</?[^>]*>)", " ", _text)
        _text = re.sub(r"\.\s", " ", _text)
        _text = re.sub('здравствуйте', '', _text)
        _text = re.sub('добрый день', '', _text)
        _text = re.sub(r"\s+", " ", _text)
        return _text.strip()

In [13]:
# Build the preprocessor with its defaults (the module-level TOKENIZER and the
# pretrained vectors loaded from FASTTEXT), then assemble the initial embedding
# matrix that the model's embedding layer starts from.
text_preprocessor = TextPreprocessor()
embedding_matrix = text_preprocessor.build_embedding_matrix()
# Show the matrix shape as the cell output
embedding_matrix.shape

3153 out of FastText tokens


torch.Size([20001, 300])

In [14]:
class TripletSet:
    """Tokenised train / validation triples wrapped in DataLoaders.

    Attributes set by the constructor:
        train_set, val_set: ``TensorDataset`` objects of (features, labels).
        train_loader, val_loader: ``DataLoader`` objects over those datasets.

    The labels are all ones; they only serve as targets during validation.
    """

    def __init__(self, train_data, val_data,
                 preprocessor=text_preprocessor, batch_size=(TRAIN_BATCH_SIZE, VAL_BATCH_SIZE)):
        """
        Args:
            train_data: triples used for training.
            val_data: triples used for validation.
            preprocessor: object whose ``to_matrix`` turns triples into a
                token-id tensor.
            batch_size: ``(train_batch_size, val_batch_size)``.
        """
        kwargs = self.make_kwargs()

        train_features = preprocessor.to_matrix(train_data)
        train_labels = torch.ones(len(train_data))
        val_features = preprocessor.to_matrix(val_data)
        val_labels = torch.ones(len(val_data))

        self.train_set = TensorDataset(train_features, train_labels)
        # Training drops the last partial batch and reshuffles every epoch.
        self.train_loader = DataLoader(self.train_set, batch_size=batch_size[0],
                                       drop_last=True, shuffle=True, **kwargs)
        self.val_set = TensorDataset(val_features, val_labels)
        # Validation must see every sample exactly once: dropping the last
        # partial batch would silently exclude samples from the metrics (and
        # yield an empty loader for sets smaller than one batch); shuffling
        # adds nothing for evaluation.
        self.val_loader = DataLoader(self.val_set, batch_size=batch_size[1],
                                     drop_last=False, shuffle=False, **kwargs)

    @staticmethod
    def make_kwargs():
        """Return extra DataLoader keyword arguments depending on CUDA availability."""
        use_cuda = torch.cuda.is_available()
        print(f"Use CUDA {use_cuda}")
        if not use_cuda:
            return {}
        # Worker processes scale with the GPU count; pinned memory is requested
        # for the host-to-device copies.
        return {'num_workers': 4 * GPU_NUM, 'pin_memory': True}

In [15]:
# Tokenise the train / validation triples and wrap them in DataLoaders
# (default preprocessor and batch sizes).
_dataset = TripletSet(train_triples, val_triples)

Use CUDA True


In [16]:
class SoftmaxLoss(nn.Module):
    """Contrastive loss that scores each anchor against every negative in the batch.

    For every anchor the loss is
    ``loss_func(logsumexp_j(anchor . negative_j) - anchor . positive)``,
    where ``loss_func`` defaults to ReLU.
    """

    def __init__(self, _loss=F.relu):
        """
        Args:
            _loss: element-wise function applied to the similarity margin.
        """
        super().__init__()
        self.loss_func = _loss

    def forward(self, anchor, positive, negative, is_test=False):
        """Compute the per-anchor loss.

        Args:
            anchor, positive, negative: embedding tensors; assumed to be 2-D
                ``(batch, dim)`` since ``negative`` is transposed over its
                first two axes — TODO confirm with callers.
            is_test: unused; kept for interface compatibility.

        Returns:
            Tensor with the reduced last axis kept (``(batch, 1)`` for 2-D inputs).
        """
        # Dot product of every anchor with its own positive
        positive_similarity = torch.sum(anchor * positive, dim=-1, keepdim=True)
        # Dot products of every anchor with all negatives of the batch
        anchor_vs_negatives = torch.matmul(anchor, negative.transpose(0, 1))
        # logsumexp is the numerically stable form of log(sum(exp(x))): the
        # naive expression overflows to inf once a similarity exceeds ~88 in
        # float32.
        negative_similarity = torch.logsumexp(anchor_vs_negatives, dim=-1, keepdim=True)
        return self.loss_func(negative_similarity - positive_similarity)

    @staticmethod
    def mean_loss(_true, _predicted):
        """Average ``_predicted`` over the batch axis.

        ``_true`` does not contribute to the value (it is multiplied by zero);
        it is kept in the expression so broadcasting and NaN propagation stay
        exactly as before.
        """
        _mean = torch.mean(_predicted - 0 * _true, dim=0)
        return _mean

In [17]:
class TripletLossModel(nn.Module):
    """Embedding + (Bi)LSTM encoder applied to anchor, positive and negative texts."""

    def __init__(self, criterion,
                 init_weights=embedding_matrix,
                 emb_size=EMBEDDING_SIZE, vocab_size=VOCAB_SIZE,
                 lstm_num_layers=LSTM_NUM_LAYERS,
                 lstm_hidden_size=LSTM_SIZE, is_bi=True,
                 out_emb_size=OUT_EMBEDDING_SIZE,
                 batch_first=True, dropout=0.2, norm_eps=1e-12):
        """
        Args:
            criterion: accepted for interface compatibility; not used by the model.
            init_weights: tensor used to initialise the embedding layer.
            emb_size: embedding dimensionality (LSTM input size).
            vocab_size: vocabulary size; the embedding layer has one extra row.
            lstm_num_layers: number of stacked LSTM layers.
            lstm_hidden_size: hidden size of each LSTM direction.
            is_bi: whether the LSTM is bidirectional.
            out_emb_size: stored on the instance; not used by any layer here.
            batch_first: forwarded to ``nn.LSTM``.
            dropout: dropout between LSTM layers.
            norm_eps: accepted but unused.
        """
        super().__init__()

        # Parameters
        self.emb_size = emb_size
        self.vocab_size = vocab_size
        self.lstm_num_layers = lstm_num_layers
        self.lstm_hidden_size = lstm_hidden_size
        self.directions = int(is_bi) + 1
        self.out_emb_size = out_emb_size

        # Layers
        self.embedding = nn.Embedding(self.vocab_size + 1, self.emb_size)
        # Clone the initial weights: wrapping the tensor directly would share
        # its storage with the caller's matrix, so every optimiser step would
        # silently modify that matrix in place as well.
        self.embedding.weight = nn.Parameter(init_weights.detach().clone())
        self.lstm = nn.LSTM(input_size=self.emb_size, hidden_size=self.lstm_hidden_size,
                            num_layers=self.lstm_num_layers, bidirectional=is_bi,
                            batch_first=batch_first, dropout=dropout)
        self.cosine_similarity = nn.CosineSimilarity(dim=1, eps=1e-6)

        self.best = defaultdict(int)

    def forward(self, _input):
        """Encode the three texts of every triple.

        Args:
            _input: token-id tensor indexed as ``[batch, triple slot, position]``.

        Returns:
            ``_dict`` with the entries ``anchor``, ``positive``, ``negative``
            (encoded texts) and ``cos_sim``, a tuple
            ``(cos(anchor, positive), cos(anchor, negative))``.
        """
        encoded = {
            item_type: self.encode(self.fetch_triplet_item(_input, item_type))
            for item_type in (ANCHOR, POSITIVE, NEGATIVE)
        }

        cosine_pos = self.cos(encoded[ANCHOR], encoded[POSITIVE])
        cosine_neg = self.cos(encoded[ANCHOR], encoded[NEGATIVE])

        return _dict(anchor=encoded[ANCHOR],
                     positive=encoded[POSITIVE],
                     negative=encoded[NEGATIVE],
                     cos_sim=(cosine_pos, cosine_neg))

    def encode(self, tokenized_seq, _const=0.5):
        """Embed a batch of token-id sequences and run it through the LSTM.

        Returns the LSTM output at the last position, scaled by ``_const``.
        """
        encoded_seq = self.embedding(tokenized_seq)
        encoded_seq, _ = self.lstm(encoded_seq)
        # NOTE(review): index -1 on axis 1 is the last time step only for
        # batch_first=True. With right-padded inputs (as
        # TextPreprocessor.to_matrix builds them) that position may be a
        # padding token — confirm this is intended.
        encoded_seq = encoded_seq[:, -1, :] * _const
        return encoded_seq

    @staticmethod
    def fetch_triplet_item(_input, item_type):
        """Select one slot (anchor / positive / negative) from a batch of triples.

        Raises:
            ValueError: if ``item_type`` is not a key of TYPE2IDX.
        """
        slot = TYPE2IDX.get(item_type)
        if slot is None:
            # ValueError is a subclass of Exception, so existing handlers still match.
            raise ValueError(f"No item type {item_type}")
        return _input[:, slot, :]

    def cos(self, u, v):
        """Cosine similarity of ``u`` and ``v`` along dimension 1."""
        return self.cosine_similarity(u, v)

In [18]:
class ModelPipeline:
    """Training / validation driver for the triplet model.

    Holds the dataset (with ``train_loader`` / ``val_loader``), the model, the
    criterion, the optimizer and the target device. It also tracks the best
    validation metrics in ``self.best``.
    """

    def __init__(self, dataset, model, criterion, optimizer, device):
        """
        Args:
            dataset: object exposing ``train_loader`` and ``val_loader``.
            model: module whose output exposes ``anchor``, ``positive``,
                ``negative`` and ``cos_sim``.
            criterion: loss callable that also provides ``mean_loss``.
            optimizer: optimizer over the model parameters.
            device: anything accepted by ``torch.device``.
        """
        self.dataset = dataset
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = torch.device(device)
        self.lr_scheduler = None

        self.best = defaultdict(int)
        self.progress_bar = None

    def fit(self, epoch_num, log_interval=10):
        """Train for ``epoch_num`` epochs with a linearly decaying learning rate.

        Validation runs after the first epoch and after every
        ``log_interval``-th epoch.
        """
        # .to() honours an explicit device index (e.g. "cuda:1"); .cuda()
        # would always pick the current default GPU.
        self.model.to(self.device)

        num_steps = epoch_num * len(self.dataset.train_loader)

        self.progress_bar = tqdm(range(num_steps), position=0, leave=True)
        self.lr_scheduler = get_scheduler(name="linear", optimizer=self.optimizer,
                                          num_warmup_steps=0, num_training_steps=num_steps)
        try:
            for epoch in range(epoch_num):
                self.train_epoch(epoch)

                if (epoch + 1) % log_interval == 0 or epoch == 0:
                    print()
                    self.validate()
                    print()
        finally:
            # Release the progress bar even if training is interrupted.
            self.progress_bar.close()

    def train_epoch(self, epoch):
        """Run one pass over the training loader and print the mean batch loss."""
        self.model.train()

        epoch_loss = []

        # The labels of the training batches are not used by the loss.
        for _input, _ in self.dataset.train_loader:
            _input = _input.to(self.device)

            _output = self.model(_input)
            loss = self.criterion(_output.anchor, _output.positive, _output.negative)
            mean_loss = self.criterion.mean_loss(loss, loss)
            epoch_loss.append(mean_loss.item())

            mean_loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()

            self.progress_bar.update(1)

        # An empty loader would otherwise divide by zero.
        mean_epoch_loss = sum(epoch_loss) / len(epoch_loss) if epoch_loss else float("nan")
        print(f"Epoch #{epoch + 1}\nMean epoch loss: {mean_epoch_loss}")

    def validate(self):
        """Correlate cosine similarities with the pair labels on the validation set.

        Anchor-positive similarities are paired with the batch targets and
        anchor-negative similarities with the negated targets; Spearman and
        Pearson correlations are printed. The model weights are saved to
        OUTPUT_MODEL whenever the Spearman coefficient improves.
        """
        self.model.eval()

        cosine_dist, targets = [], []
        with torch.no_grad():
            for _input, _target in self.dataset.val_loader:
                _output = self.model(_input.to(self.device))
                positive_sim, negative_sim = _output.cos_sim
                cosine_dist.extend((positive_sim, negative_sim))
                # Targets stay on the CPU; they are only needed as numpy arrays.
                targets.extend((_target, _target * -1))

        if not cosine_dist:
            # torch.cat would fail on an empty list of batches.
            print("Validation loader produced no batches; skipping validation")
            return

        cosine_dist = torch.cat(cosine_dist).cpu().numpy()
        targets = torch.squeeze(torch.cat(targets)).cpu().numpy()

        for metric, func in [("spearman_r", spearmanr), ("pearson_r", pearsonr)]:
            coef, _ = func(targets, cosine_dist)
            coef = np.round(coef, 4)

            message = f"{metric} = {coef}"
            if coef > self.best[metric]:
                self.best[metric] = coef
                message = "*** New best: " + message
                # Only the Spearman coefficient decides which weights are kept.
                if metric == "spearman_r":
                    torch.save(self.model.state_dict(), OUTPUT_MODEL)

            print(message)

In [19]:
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_criterion = SoftmaxLoss()
# Pass the arguments by keyword: the first positional parameter of
# TripletLossModel is `criterion`, so `TripletLossModel(embedding_matrix)`
# bound the embedding matrix to the wrong parameter and only worked because
# `init_weights` defaults to the same global.
# The model is moved to the target device before the optimizer is built, so
# the optimizer is constructed over parameters that already live there.
_model = TripletLossModel(criterion=_criterion, init_weights=embedding_matrix).to(_device)
_optimizer = Adam(_model.parameters(), lr=1e-5)

evaluation_pipeline = ModelPipeline(dataset=_dataset, model=_model, criterion=_criterion,
                                    optimizer=_optimizer, device=_device)

In [20]:
# Train for 150 epochs. With the default log_interval=10, validation runs after
# the first epoch and after every 10th epoch; the weights with the best Spearman
# coefficient so far are written to OUTPUT_MODEL.
evaluation_pipeline.fit(epoch_num=150)

  1%|          | 683/102450 [01:28<3:29:12,  8.11it/s]

Epoch #1
Mean epoch loss: 6.61628636154912

*** New best: spearman_r = 0.6213
*** New best: pearson_r = 0.5877



  1%|▏         | 1366/102450 [03:08<3:27:43,  8.11it/s] 

Epoch #2
Mean epoch loss: 5.869832014444108


  2%|▏         | 2049/102450 [04:37<3:26:06,  8.12it/s] 

Epoch #3
Mean epoch loss: 5.729963161585795


  3%|▎         | 2732/102450 [06:07<3:24:22,  8.13it/s] 

Epoch #4
Mean epoch loss: 5.232154555760575


  3%|▎         | 3415/102450 [07:37<3:22:33,  8.15it/s] 

Epoch #5
Mean epoch loss: 4.841225570326843


  4%|▍         | 4098/102450 [09:06<3:22:19,  8.10it/s] 

Epoch #6
Mean epoch loss: 4.649304404614785


  5%|▍         | 4781/102450 [10:36<3:20:49,  8.11it/s] 

Epoch #7
Mean epoch loss: 4.397690212394972


  5%|▌         | 5464/102450 [12:06<3:18:49,  8.13it/s] 

Epoch #8
Mean epoch loss: 4.159221550032474


  6%|▌         | 6147/102450 [13:35<3:17:31,  8.13it/s] 

Epoch #9
Mean epoch loss: 3.9760308307575096


  7%|▋         | 6830/102450 [15:05<3:16:20,  8.12it/s] 

Epoch #10
Mean epoch loss: 3.823870537745307

*** New best: spearman_r = 0.8294
*** New best: pearson_r = 0.7907



  7%|▋         | 7513/102450 [16:46<3:15:00,  8.11it/s]  

Epoch #11
Mean epoch loss: 3.6847287477546344


  8%|▊         | 8196/102450 [18:16<3:13:10,  8.13it/s] 

Epoch #12
Mean epoch loss: 3.555432537464237


  9%|▊         | 8879/102450 [19:45<3:11:38,  8.14it/s] 

Epoch #13
Mean epoch loss: 3.4399554956360756


  9%|▉         | 9562/102450 [21:15<3:10:43,  8.12it/s] 

Epoch #14
Mean epoch loss: 3.326565832205084


 10%|█         | 10245/102450 [22:45<3:09:47,  8.10it/s]

Epoch #15
Mean epoch loss: 3.1893108903401766


 11%|█         | 10928/102450 [24:15<3:08:02,  8.11it/s] 

Epoch #16
Mean epoch loss: 3.0322265157238806


 11%|█▏        | 11611/102450 [25:44<3:06:27,  8.12it/s] 

Epoch #17
Mean epoch loss: 2.893772584627523


 12%|█▏        | 12294/102450 [27:14<3:05:05,  8.12it/s] 

Epoch #18
Mean epoch loss: 2.7603778067654434


 13%|█▎        | 12977/102450 [28:45<3:03:30,  8.13it/s] 

Epoch #19
Mean epoch loss: 2.6320635675511284


 13%|█▎        | 13660/102450 [30:15<3:02:08,  8.12it/s] 

Epoch #20
Mean epoch loss: 2.511042615854199

*** New best: spearman_r = 0.8487
*** New best: pearson_r = 0.8433



 14%|█▍        | 14343/102450 [32:00<3:00:50,  8.12it/s]  

Epoch #21
Mean epoch loss: 2.407414848947595


 15%|█▍        | 15026/102450 [33:30<2:59:13,  8.13it/s] 

Epoch #22
Mean epoch loss: 2.312931313521691


 15%|█▌        | 15709/102450 [35:00<2:57:52,  8.13it/s] 

Epoch #23
Mean epoch loss: 2.228655698184213


 16%|█▌        | 16392/102450 [36:30<2:56:20,  8.13it/s] 

Epoch #24
Mean epoch loss: 2.1472087492879974


 17%|█▋        | 17075/102450 [38:00<2:55:03,  8.13it/s] 

Epoch #25
Mean epoch loss: 2.067263014669125


 17%|█▋        | 17758/102450 [39:30<2:53:39,  8.13it/s] 

Epoch #26
Mean epoch loss: 1.9864025325412107


 18%|█▊        | 18441/102450 [41:00<2:52:02,  8.14it/s] 

Epoch #27
Mean epoch loss: 1.9039945389480926


 19%|█▊        | 19124/102450 [42:30<2:50:49,  8.13it/s] 

Epoch #28
Mean epoch loss: 1.81931050507949


 19%|█▉        | 19807/102450 [44:00<2:49:32,  8.12it/s] 

Epoch #29
Mean epoch loss: 1.7394453815974045


 20%|██        | 20490/102450 [45:30<2:48:00,  8.13it/s] 

Epoch #30
Mean epoch loss: 1.6681588800899456

*** New best: spearman_r = 0.8567
*** New best: pearson_r = 0.87



 21%|██        | 21173/102450 [47:12<2:46:26,  8.14it/s]  

Epoch #31
Mean epoch loss: 1.603481118361388


 21%|██▏       | 21856/102450 [48:42<2:45:40,  8.11it/s] 

Epoch #32
Mean epoch loss: 1.5419556035129776


 22%|██▏       | 22539/102450 [50:12<2:43:18,  8.16it/s] 

Epoch #33
Mean epoch loss: 1.4845159711809843


 23%|██▎       | 23222/102450 [51:42<2:42:25,  8.13it/s] 

Epoch #34
Mean epoch loss: 1.4294299203166152


 23%|██▎       | 23905/102450 [53:12<2:40:47,  8.14it/s] 

Epoch #35
Mean epoch loss: 1.380913507013335


 24%|██▍       | 24588/102450 [54:42<2:39:48,  8.12it/s] 

Epoch #36
Mean epoch loss: 1.3349162320965031


 25%|██▍       | 25271/102450 [56:12<2:37:42,  8.16it/s] 

Epoch #37
Mean epoch loss: 1.291623963618523


 25%|██▌       | 25954/102450 [57:41<2:36:38,  8.14it/s] 

Epoch #38
Mean epoch loss: 1.252757324306717


 26%|██▌       | 26637/102450 [59:11<2:35:30,  8.13it/s] 

Epoch #39
Mean epoch loss: 1.2137452245282219


 27%|██▋       | 27320/102450 [1:00:42<2:34:03,  8.13it/s]

Epoch #40
Mean epoch loss: 1.178842186404182

*** New best: spearman_r = 0.8601
*** New best: pearson_r = 0.8855



 27%|██▋       | 28003/102450 [1:02:23<2:32:21,  8.14it/s]  

Epoch #41
Mean epoch loss: 1.146861492191192


 28%|██▊       | 28686/102450 [1:03:53<2:31:07,  8.14it/s] 

Epoch #42
Mean epoch loss: 1.1141331580687057


 29%|██▊       | 29369/102450 [1:05:23<2:29:53,  8.13it/s] 

Epoch #43
Mean epoch loss: 1.0820051127960182


 29%|██▉       | 30052/102450 [1:06:53<2:28:26,  8.13it/s] 

Epoch #44
Mean epoch loss: 1.0550205675260564


 30%|███       | 30735/102450 [1:08:23<2:27:11,  8.12it/s] 

Epoch #45
Mean epoch loss: 1.0274560461107145


 31%|███       | 31418/102450 [1:09:53<2:25:35,  8.13it/s] 

Epoch #46
Mean epoch loss: 0.9999807784183832


 31%|███▏      | 32101/102450 [1:11:25<2:24:31,  8.11it/s] 

Epoch #47
Mean epoch loss: 0.9769540886404224


 32%|███▏      | 32784/102450 [1:12:55<2:22:52,  8.13it/s] 

Epoch #48
Mean epoch loss: 0.9520076624994571


 33%|███▎      | 33467/102450 [1:14:25<2:21:39,  8.12it/s] 

Epoch #49
Mean epoch loss: 0.9298930828777274


 33%|███▎      | 34150/102450 [1:15:55<2:20:13,  8.12it/s] 

Epoch #50
Mean epoch loss: 0.9082617812938606

*** New best: spearman_r = 0.8616
*** New best: pearson_r = 0.8937



 34%|███▍      | 34833/102450 [1:17:36<2:18:36,  8.13it/s] 

Epoch #51
Mean epoch loss: 0.8882354415876799


 35%|███▍      | 35516/102450 [1:19:06<2:17:09,  8.13it/s] 

Epoch #52
Mean epoch loss: 0.8677386048946688


 35%|███▌      | 36199/102450 [1:20:36<2:15:43,  8.14it/s] 

Epoch #53
Mean epoch loss: 0.8474192961417738


 36%|███▌      | 36882/102450 [1:22:05<2:14:27,  8.13it/s] 

Epoch #54
Mean epoch loss: 0.82918285494493


 37%|███▋      | 37565/102450 [1:23:35<2:13:35,  8.09it/s] 

Epoch #55
Mean epoch loss: 0.811225963808258


 37%|███▋      | 38248/102450 [1:25:05<2:11:35,  8.13it/s] 

Epoch #56
Mean epoch loss: 0.7958197887785089


 38%|███▊      | 38931/102450 [1:26:35<2:10:57,  8.08it/s] 

Epoch #57
Mean epoch loss: 0.7787909941477796


 39%|███▊      | 39614/102450 [1:28:05<2:09:00,  8.12it/s] 

Epoch #58
Mean epoch loss: 0.7629334593238104


 39%|███▉      | 40297/102450 [1:29:35<2:07:25,  8.13it/s] 

Epoch #59
Mean epoch loss: 0.7484922728461717


 40%|████      | 40980/102450 [1:31:04<2:06:03,  8.13it/s] 

Epoch #60
Mean epoch loss: 0.7345112670392976

*** New best: spearman_r = 0.8625
*** New best: pearson_r = 0.8996



 41%|████      | 41663/102450 [1:32:46<2:04:33,  8.13it/s] 

Epoch #61
Mean epoch loss: 0.7220601957853158


 41%|████▏     | 42346/102450 [1:34:16<2:03:03,  8.14it/s] 

Epoch #62
Mean epoch loss: 0.7072829032537703


 42%|████▏     | 43029/102450 [1:35:45<2:02:10,  8.11it/s] 

Epoch #63
Mean epoch loss: 0.6958118098546611


 43%|████▎     | 43712/102450 [1:37:15<2:00:42,  8.11it/s] 

Epoch #64
Mean epoch loss: 0.6848058766885872


 43%|████▎     | 44395/102450 [1:38:45<1:59:20,  8.11it/s] 

Epoch #65
Mean epoch loss: 0.6728639798143423


 44%|████▍     | 45078/102450 [1:40:15<1:57:43,  8.12it/s] 

Epoch #66
Mean epoch loss: 0.6606102502887106


 45%|████▍     | 45761/102450 [1:41:45<1:56:38,  8.10it/s] 

Epoch #67
Mean epoch loss: 0.651904144363906


 45%|████▌     | 46444/102450 [1:43:15<1:55:04,  8.11it/s] 

Epoch #68
Mean epoch loss: 0.6419729922910385


 46%|████▌     | 47127/102450 [1:44:45<1:53:38,  8.11it/s] 

Epoch #69
Mean epoch loss: 0.6321810528875969


 47%|████▋     | 47810/102450 [1:46:16<1:54:11,  7.97it/s] 

Epoch #70
Mean epoch loss: 0.6243421876936671

*** New best: spearman_r = 0.8631
*** New best: pearson_r = 0.9043



 47%|████▋     | 48493/102450 [1:48:01<1:51:47,  8.04it/s] 

Epoch #71
Mean epoch loss: 0.6153737394736417


 48%|████▊     | 49176/102450 [1:49:31<1:49:30,  8.11it/s] 

Epoch #72
Mean epoch loss: 0.6064192148724552


 49%|████▊     | 49859/102450 [1:51:01<1:48:17,  8.09it/s] 

Epoch #73
Mean epoch loss: 0.5987139036156248


 49%|████▉     | 50542/102450 [1:52:32<1:47:00,  8.09it/s] 

Epoch #74
Mean epoch loss: 0.5899596520608201


 50%|█████     | 51225/102450 [1:54:02<1:44:55,  8.14it/s] 

Epoch #75
Mean epoch loss: 0.5825240010311105


 51%|█████     | 51908/102450 [1:55:32<1:44:05,  8.09it/s] 

Epoch #76
Mean epoch loss: 0.5749773080317552


 51%|█████▏    | 52591/102450 [1:57:03<1:42:27,  8.11it/s] 

Epoch #77
Mean epoch loss: 0.5692069265626011


 52%|█████▏    | 53274/102450 [1:58:33<1:41:10,  8.10it/s] 

Epoch #78
Mean epoch loss: 0.5621104578584389


 53%|█████▎    | 53957/102450 [2:00:03<1:39:41,  8.11it/s] 

Epoch #79
Mean epoch loss: 0.5561336659407721


 53%|█████▎    | 54640/102450 [2:01:33<1:38:40,  8.08it/s] 

Epoch #80
Mean epoch loss: 0.5495222932597729

*** New best: spearman_r = 0.8634
*** New best: pearson_r = 0.9067



 54%|█████▍    | 55323/102450 [2:03:15<1:36:50,  8.11it/s] 

Epoch #81
Mean epoch loss: 0.5435407332323307


 55%|█████▍    | 56006/102450 [2:04:45<1:35:03,  8.14it/s] 

Epoch #82
Mean epoch loss: 0.5377429000716957


 55%|█████▌    | 56689/102450 [2:06:15<1:34:01,  8.11it/s] 

Epoch #83
Mean epoch loss: 0.5316294932609771


 56%|█████▌    | 57372/102450 [2:07:45<1:32:34,  8.12it/s] 

Epoch #84
Mean epoch loss: 0.5264982862144403


 57%|█████▋    | 58055/102450 [2:09:15<1:31:28,  8.09it/s] 

Epoch #85
Mean epoch loss: 0.5216721748363221


 57%|█████▋    | 58738/102450 [2:10:46<1:29:47,  8.11it/s] 

Epoch #86
Mean epoch loss: 0.5173016725050757


 58%|█████▊    | 59421/102450 [2:12:16<1:28:28,  8.11it/s] 

Epoch #87
Mean epoch loss: 0.5120587936693892


 59%|█████▊    | 60104/102450 [2:13:46<1:27:11,  8.09it/s] 

Epoch #88
Mean epoch loss: 0.5070987664937624


 59%|█████▉    | 60787/102450 [2:15:16<1:25:27,  8.13it/s] 

Epoch #89
Mean epoch loss: 0.5021827065578143


 60%|██████    | 61470/102450 [2:16:47<1:24:25,  8.09it/s] 

Epoch #90
Mean epoch loss: 0.4977820127286115

*** New best: spearman_r = 0.8636
*** New best: pearson_r = 0.9086



 61%|██████    | 62153/102450 [2:18:28<1:22:44,  8.12it/s] 

Epoch #91
Mean epoch loss: 0.49346355633540173


 61%|██████▏   | 62836/102450 [2:19:59<1:21:35,  8.09it/s] 

Epoch #92
Mean epoch loss: 0.48837297705396204


 62%|██████▏   | 63519/102450 [2:21:29<1:20:02,  8.11it/s] 

Epoch #93
Mean epoch loss: 0.4847008907515692


 63%|██████▎   | 64202/102450 [2:22:59<1:18:43,  8.10it/s] 

Epoch #94
Mean epoch loss: 0.4817320037725904


 63%|██████▎   | 64885/102450 [2:24:32<1:17:21,  8.09it/s] 

Epoch #95
Mean epoch loss: 0.4778487602927849


 64%|██████▍   | 65568/102450 [2:26:02<1:15:33,  8.14it/s] 

Epoch #96
Mean epoch loss: 0.4733729559715461


 65%|██████▍   | 66251/102450 [2:27:33<1:14:23,  8.11it/s] 

Epoch #97
Mean epoch loss: 0.46987789672953495


 65%|██████▌   | 66934/102450 [2:29:03<1:12:53,  8.12it/s] 

Epoch #98
Mean epoch loss: 0.4674008544070829


 66%|██████▌   | 67617/102450 [2:30:34<1:11:44,  8.09it/s] 

Epoch #99
Mean epoch loss: 0.463904147266818


 67%|██████▋   | 68300/102450 [2:32:04<1:10:14,  8.10it/s] 

Epoch #100
Mean epoch loss: 0.46061544323211884

*** New best: spearman_r = 0.8638
*** New best: pearson_r = 0.9104



 67%|██████▋   | 68983/102450 [2:33:47<1:08:46,  8.11it/s] 

Epoch #101
Mean epoch loss: 0.4581365369336671


 68%|██████▊   | 69666/102450 [2:35:17<1:07:16,  8.12it/s] 

Epoch #102
Mean epoch loss: 0.454394267122009


 69%|██████▊   | 70349/102450 [2:36:48<1:05:51,  8.12it/s] 

Epoch #103
Mean epoch loss: 0.4514841614059264


 69%|██████▉   | 71032/102450 [2:38:18<1:04:35,  8.11it/s] 

Epoch #104
Mean epoch loss: 0.4479180051240181


 70%|███████   | 71715/102450 [2:39:49<1:03:01,  8.13it/s] 

Epoch #105
Mean epoch loss: 0.44561569626125375


 71%|███████   | 72398/102450 [2:41:19<1:01:47,  8.11it/s] 

Epoch #106
Mean epoch loss: 0.4433179029957616


 71%|███████▏  | 73081/102450 [2:42:50<1:00:17,  8.12it/s] 

Epoch #107
Mean epoch loss: 0.4406870206487999


 72%|███████▏  | 73764/102450 [2:44:20<59:01,  8.10it/s]   

Epoch #108
Mean epoch loss: 0.43707640829058375


 73%|███████▎  | 74447/102450 [2:45:51<57:35,  8.10it/s]   

Epoch #109
Mean epoch loss: 0.4353242855120892


 73%|███████▎  | 75130/102450 [2:47:22<55:54,  8.14it/s]   

Epoch #110
Mean epoch loss: 0.43347047472244477

*** New best: spearman_r = 0.8639
*** New best: pearson_r = 0.9119



 74%|███████▍  | 75813/102450 [2:49:04<54:42,  8.11it/s]   

Epoch #111
Mean epoch loss: 0.43094707015318195


 75%|███████▍  | 76496/102450 [2:50:35<53:13,  8.13it/s]   

Epoch #112
Mean epoch loss: 0.42907725996712814


 75%|███████▌  | 77179/102450 [2:52:05<51:58,  8.10it/s]   

Epoch #113
Mean epoch loss: 0.4275197304272547


 76%|███████▌  | 77862/102450 [2:53:36<50:33,  8.10it/s]   

Epoch #114
Mean epoch loss: 0.42501659706605827


 77%|███████▋  | 78545/102450 [2:55:06<49:06,  8.11it/s]   

Epoch #115
Mean epoch loss: 0.42225801141858976


 77%|███████▋  | 79228/102450 [2:56:37<47:44,  8.11it/s]   

Epoch #116
Mean epoch loss: 0.42118533350538195


 78%|███████▊  | 79911/102450 [2:58:07<46:16,  8.12it/s]   

Epoch #117
Mean epoch loss: 0.41894967803843375


 79%|███████▊  | 80594/102450 [2:59:38<44:56,  8.10it/s]   

Epoch #118
Mean epoch loss: 0.4172940009858563


 79%|███████▉  | 81277/102450 [3:01:09<43:25,  8.13it/s]   

Epoch #119
Mean epoch loss: 0.41561998258921623


 80%|████████  | 81960/102450 [3:02:41<42:08,  8.10it/s]   

Epoch #120
Mean epoch loss: 0.41355357145669697

*** New best: spearman_r = 0.864
*** New best: pearson_r = 0.9127



 81%|████████  | 82643/102450 [3:04:24<40:40,  8.12it/s]   

Epoch #121
Mean epoch loss: 0.4116005410997969


 81%|████████▏ | 83326/102450 [3:05:55<39:22,  8.09it/s]   

Epoch #122
Mean epoch loss: 0.41120644688257185


 82%|████████▏ | 84009/102450 [3:07:26<37:56,  8.10it/s]   

Epoch #123
Mean epoch loss: 0.4090812884347505


 83%|████████▎ | 84692/102450 [3:08:56<36:27,  8.12it/s]   

Epoch #124
Mean epoch loss: 0.40773360285018934


 83%|████████▎ | 85375/102450 [3:10:27<35:01,  8.13it/s]  

Epoch #125
Mean epoch loss: 0.40656081828156465


 84%|████████▍ | 86058/102450 [3:11:58<33:44,  8.10it/s]  

Epoch #126
Mean epoch loss: 0.4056398685732488


 85%|████████▍ | 86741/102450 [3:13:29<32:23,  8.08it/s]  

Epoch #127
Mean epoch loss: 0.40347308526974646


 85%|████████▌ | 87424/102450 [3:15:00<30:50,  8.12it/s]  

Epoch #128
Mean epoch loss: 0.402722597776709


 86%|████████▌ | 88107/102450 [3:16:32<29:36,  8.07it/s]  

Epoch #129
Mean epoch loss: 0.40056508631168664


 87%|████████▋ | 88790/102450 [3:18:03<28:07,  8.10it/s]  

Epoch #130
Mean epoch loss: 0.4008040883268828

*** New best: spearman_r = 0.8641
*** New best: pearson_r = 0.9134



 87%|████████▋ | 89473/102450 [3:19:46<26:38,  8.12it/s]   

Epoch #131
Mean epoch loss: 0.39965073884318897


 88%|████████▊ | 90156/102450 [3:21:17<25:18,  8.09it/s]  

Epoch #132
Mean epoch loss: 0.39835789633879376


 89%|████████▊ | 90839/102450 [3:22:48<23:51,  8.11it/s]  

Epoch #133
Mean epoch loss: 0.3975580867003557


 89%|████████▉ | 91522/102450 [3:24:19<22:31,  8.09it/s]  

Epoch #134
Mean epoch loss: 0.3968283404019357


 90%|█████████ | 92205/102450 [3:25:49<20:59,  8.13it/s]  

Epoch #135
Mean epoch loss: 0.39568749074530984


 91%|█████████ | 92888/102450 [3:27:20<19:43,  8.08it/s]  

Epoch #136
Mean epoch loss: 0.39464123396538364


 91%|█████████▏| 93571/102450 [3:28:51<18:14,  8.11it/s]  

Epoch #137
Mean epoch loss: 0.39394115018460624


 92%|█████████▏| 94254/102450 [3:30:22<16:50,  8.11it/s]  

Epoch #138
Mean epoch loss: 0.39387297063062343


 93%|█████████▎| 94937/102450 [3:31:53<15:29,  8.08it/s]  

Epoch #139
Mean epoch loss: 0.39385493267682076


 93%|█████████▎| 95620/102450 [3:33:24<14:02,  8.11it/s]  

Epoch #140
Mean epoch loss: 0.39244823138151824

spearman_r = 0.8641
*** New best: pearson_r = 0.9137



 94%|█████████▍| 96303/102450 [3:35:07<12:38,  8.11it/s]   

Epoch #141
Mean epoch loss: 0.3920801118299867


 95%|█████████▍| 96986/102450 [3:36:38<11:14,  8.10it/s]  

Epoch #142
Mean epoch loss: 0.39157662284321876


 95%|█████████▌| 97669/102450 [3:38:09<09:50,  8.10it/s]  

Epoch #143
Mean epoch loss: 0.3916330189603663


 96%|█████████▌| 98352/102450 [3:39:42<08:25,  8.11it/s]  

Epoch #144
Mean epoch loss: 0.39002292663251675


 97%|█████████▋| 99035/102450 [3:41:13<07:01,  8.11it/s]  

Epoch #145
Mean epoch loss: 0.3908119797706604


 97%|█████████▋| 99718/102450 [3:42:43<05:37,  8.10it/s]  

Epoch #146
Mean epoch loss: 0.3900010206165286


 98%|█████████▊| 100401/102450 [3:44:14<04:12,  8.10it/s] 

Epoch #147
Mean epoch loss: 0.3906904322423837


 99%|█████████▊| 101084/102450 [3:45:44<02:48,  8.10it/s]  

Epoch #148
Mean epoch loss: 0.3909447586728399


 99%|█████████▉| 101767/102450 [3:47:15<01:24,  8.10it/s]

Epoch #149
Mean epoch loss: 0.38955496345397145


100%|██████████| 102450/102450 [3:48:46<00:00,  8.09it/s]

Epoch #150
Mean epoch loss: 0.3899305074927049

spearman_r = 0.8641
*** New best: pearson_r = 0.9138

