In [1]:
import pandas as pd
from models import *
from tqdm import tqdm
tqdm.pandas()
from torch import nn
import json
import numpy as np
import pickle
import os
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from argparse import Namespace
from transformers import *
import torch
import matplotlib.pyplot as plt
import torch.utils.data
import torch.nn.functional as F
import argparse
from transformers.modeling_utils import * 
from fairseq.data.encoders.fastbpe import fastBPE
from fairseq.data import Dictionary
from vncorenlp import VnCoreNLP
from utils import *

In [2]:
# phobert_large => 1 cuda out of memory
args = Namespace(
    train_path = './data/train.csv',
    dict_path = "./phobert_base/dict.txt",
    config_path = "./phobert_base/config.json",
    rdrsegmenter_path = '/home/tuna/FDM/MarketSentiment/PhoBert-Sentiment-Classification/VnCoreNLP-master/VnCoreNLP-1.1.1.jar',
    pretrained_path = './phobert_base/model.bin',
    max_sequence_length = 256,
    batch_size = 24,
    accumulation_steps = 5,
    epochs = 5,
    fold = 0,
    seed = 69,
    lr = 3e-5,
    ckpt_path = './models',
    bpe_codes = "./phobert_base/bpe.codes"
)

In [3]:
seed_everything(69)
bpe = fastBPE(args)
rdrsegmenter = VnCoreNLP(args.rdrsegmenter_path, annotators="wseg", max_heap_size='-Xmx500m') 

In [4]:
# Load model
config = RobertaConfig.from_pretrained(
    args.config_path,
    output_hidden_states=True,
    num_labels=1
)
print(config)
model_bert = RobertaForAIViVN.from_pretrained(args.pretrained_path, config=config)
model_bert.cuda()
tsfm = model_bert.roberta

RobertaConfig {
  "_num_labels": 1,
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": 0,
  "decoder_start_token_id": null,
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": 2,
  "eos_token_ids": 0,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "is_encoder_decoder": false,
  "label2id": {
    "LABEL_0": 0
  },
  "layer_norm_eps": 1e-05,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 258,
  "min_length": 0,
  "model_type": "roberta",
  "no_repeat_ngram_size": 0,
  "num_attention_heads": 12,
  "num_beams": 1,
  "num_hidden_layers": 12,
  "num_return_sequences": 1,
  "output_attentions": false,
  "output_hidden_states": true,
  "output_past": true,
  "pad_token_id": 0,
  "prefix": null,
 

In [5]:
# Load the dictionary  
vocab = Dictionary()
vocab.add_from_file(args.dict_path)
print(vocab.indices)

{'<s>': 0, '<pad>': 1, '</s>': 2, '<unk>': 3, ',': 4, '.': 5, 'và': 6, 'của': 7, 'là': 8, 'các': 9, 'có': 10, 'được': 11, 'trong': 12, 'cho': 13, 'đã': 14, 'với': 15, 'một': 16, 'không': 17, 'người': 18, ')': 19, '(': 20, 'những': 21, '"': 22, 'này': 23, 'để': 24, 'ở': 25, 'khi': 26, ':': 27, 'về': 28, 'năm': 29, 'đến': 30, '-': 31, 'cũng': 32, 'vào': 33, 'trên': 34, 'tại': 35, 'nhiều': 36, 'đó': 37, 'sẽ': 38, 'từ': 39, 'ra': 40, 'phải': 41, 'như': 42, 'ngày': 43, 'lại': 44, 'bị': 45, 'ông': 46, 'làm': 47, 'hơn': 48, 'việc': 49, 'còn': 50, 'nhưng': 51, 'đang': 52, 'sau': 53, 'thì': 54, 'biết': 55, 'Việt_Nam': 56, 'đi': 57, 'nước': 58, 'rất': 59, 'mới': 60, 'sự': 61, 'có_thể': 62, 'theo': 63, 'mà': 64, ';': 65, 'chỉ': 66, 'nhất': 67, 'mình': 68, 'nhà': 69, 'tôi': 70, 'trước': 71, 'lên': 72, 'con': 73, 'vẫn': 74, 'tới': 75, '2': 76, 'nên': 77, 'tháng': 78, 'Theo': 79, 'đồng': 80, 'cùng': 81, 'hai': 82, 'anh': 83, 'cao': 84, 'khác': 85, 'họ': 86, 'rằng': 87, 'bạn': 88, 'qua': 89, 'vì': 90

In [7]:
# Load training data
train_df = pd.read_csv(args.train_path,sep='\t').fillna("###")
print(train_df)
train_df.text = train_df.text.progress_apply(lambda x: ' '.join([' '.join(sent) for sent in rdrsegmenter.tokenize(x)]))
print(train_df)
y = train_df.label.values
X_train = convert_lines(train_df, vocab, bpe,args.max_sequence_length)
print(X_train)
print(X_train.shape)

  0%|          | 37/16087 [00:00<00:43, 366.16it/s]

                 id                                               text  label
0      train_000000  Dung dc sp tot cam on shop Đóng gói sản phẩm r...      0
1      train_000001  Chất lượng sản phẩm tuyệt vời . Son mịn nhưng ...      0
2      train_000002  Chất lượng sản phẩm tuyệt vời nhưng k có hộp k...      0
3      train_000003  :(( Mình hơi thất vọng 1 chút vì mình đã kỳ vọ...      1
4      train_000004  Lần trước mình mua áo gió màu hồng rất ok mà đ...      1
...             ...                                                ...    ...
16082  train_016082  Chẳng biết là Shop có biết đọc hay không mua ố...      1
16083  train_016083  Cuốn này mỏng. Đọc một buổi sáng là hết. Thú t...      1
16084  train_016084                                  Mang êm chân. Đẹp      0
16085  train_016085  Tôi đã nhận đc hàng.Sau đây là vài lời muốn nó...      1
16086  train_016086              Hình vậy mà túi xấu qá kém chất lg qá      1

[16087 rows x 3 columns]


100%|██████████| 16087/16087 [01:07<00:00, 237.46it/s]
  2%|▏         | 286/16087 [00:00<00:05, 2855.61it/s]

                 id                                               text  label
0      train_000000  Dung dc sp tot cam on shop Đóng_gói sản_phẩm r...      0
1      train_000001  Chất_lượng sản_phẩm tuyệt_vời . _Son mịn nhưng...      0
2      train_000002  Chất_lượng sản_phẩm tuyệt_vời nhưng k có hộp k...      0
3      train_000003  : ( ( Mình hơi thất_vọng 1 chút vì mình đã kỳ_...      1
4      train_000004  Lần trước mình mua áo_gió màu hồng rất ok mà đ...      1
...             ...                                                ...    ...
16082  train_016082  Chẳng biết là Shop có biết đọc hay không mua ố...      1
16083  train_016083  Cuốn này mỏng . Đọc một buổi sáng là hết . Thú...      1
16084  train_016084                                 Mang êm chân . Đẹp      0
16085  train_016085  Tôi đã nhận đc hàng.Sau đây là vài lời muốn nó...      1
16086  train_016086              Hình vậy_mà túi xấu qá kém chất lg qá      1

[16087 rows x 3 columns]


100%|██████████| 16087/16087 [00:04<00:00, 3333.30it/s]

[[6.3117e+04 1.3020e+03 8.8400e+02 ... 1.0000e+00 1.0000e+00 1.0000e+00]
 [6.3117e+04 1.3020e+03 8.8400e+02 ... 1.0000e+00 1.0000e+00 1.0000e+00]
 [6.3117e+04 1.3020e+03 8.8400e+02 ... 1.0000e+00 1.0000e+00 1.0000e+00]
 ...
 [6.3117e+04 1.3020e+03 8.8400e+02 ... 1.0000e+00 1.0000e+00 1.0000e+00]
 [6.3117e+04 1.3020e+03 8.8400e+02 ... 1.0000e+00 1.0000e+00 1.0000e+00]
 [6.3117e+04 1.3020e+03 8.8400e+02 ... 1.0000e+00 1.0000e+00 1.0000e+00]]





In [12]:
# Creating optimizer and lr schedulers
param_optimizer = list(model_bert.named_parameters())
print(len(param_optimizer))
print('param_optimizer', param_optimizer)
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
print('optimizer_grouped_parameters', optimizer_grouped_parameters)
num_train_optimization_steps = int(args.epochs*len(train_df)/args.batch_size/args.accumulation_steps)
optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, correct_bias=False)  # To reproduce BertAdam specific behavior set correct_bias=False
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=num_train_optimization_steps)  # PyTorch scheduler
scheduler0 = get_constant_schedule(optimizer)  # PyTorch scheduler

201
param_optimizer [('roberta.embeddings.word_embeddings.weight', Parameter containing:
tensor([[ 0.0472, -0.0314,  0.0343,  ...,  0.0144, -0.0032, -0.0680],
        [ 0.0186, -0.0044,  0.0190,  ...,  0.0169, -0.0055, -0.0302],
        [-0.0074, -0.0078, -0.0142,  ...,  0.0177, -0.0011,  0.0023],
        ...,
        [ 0.0374, -0.0152,  0.0123,  ...,  0.0210, -0.0214, -0.0185],
        [ 0.0266,  0.0122, -0.0232,  ...,  0.0253,  0.0002, -0.0303],
        [ 0.0106, -0.0063, -0.0064,  ...,  0.0092, -0.0129, -0.0153]],
       device='cuda:0', requires_grad=True)), ('roberta.embeddings.position_embeddings.weight', Parameter containing:
tensor([[ 0.0150, -0.0027, -0.0069,  ..., -0.0021,  0.0077,  0.0108],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0401, -0.0189, -0.0934,  ..., -0.0437,  0.0582,  0.0378],
        ...,
        [ 0.0010, -0.0455, -0.0163,  ..., -0.0149,  0.0101,  0.0913],
        [-0.0101, -0.0156,  0.0074,  ..., -0.0072, -0.0033,  0.132

optimizer_grouped_parameters [{'params': [Parameter containing:
tensor([[ 0.0472, -0.0314,  0.0343,  ...,  0.0144, -0.0032, -0.0680],
        [ 0.0186, -0.0044,  0.0190,  ...,  0.0169, -0.0055, -0.0302],
        [-0.0074, -0.0078, -0.0142,  ...,  0.0177, -0.0011,  0.0023],
        ...,
        [ 0.0374, -0.0152,  0.0123,  ...,  0.0210, -0.0214, -0.0185],
        [ 0.0266,  0.0122, -0.0232,  ...,  0.0253,  0.0002, -0.0303],
        [ 0.0106, -0.0063, -0.0064,  ...,  0.0092, -0.0129, -0.0153]],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([[ 0.0150, -0.0027, -0.0069,  ..., -0.0021,  0.0077,  0.0108],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0401, -0.0189, -0.0934,  ..., -0.0437,  0.0582,  0.0378],
        ...,
        [ 0.0010, -0.0455, -0.0163,  ..., -0.0149,  0.0101,  0.0913],
        [-0.0101, -0.0156,  0.0074,  ..., -0.0072, -0.0033,  0.1323],
        [-0.0528,  0.0500, -0.1011,  ...,  0.0354, -0.0190,  0.1747]],
 

In [16]:
model_bert

RobertaForAIViVN(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(64001, 768, padding_idx=1)
      (position_embeddings): Embedding(258, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-

In [8]:
if not os.path.exists(args.ckpt_path):
    os.mkdir(args.ckpt_path)

splits = list(StratifiedKFold(n_splits=5, shuffle=True, random_state=123).split(X_train, y))
for fold, (train_idx, val_idx) in enumerate(splits):
    print("Training for fold {}".format(fold))
    best_score = 0
    if fold != args.fold:
        continue
    train_dataset = torch.utils.data.TensorDataset(torch.tensor(X_train[train_idx],dtype=torch.long), torch.tensor(y[train_idx],dtype=torch.long))
    valid_dataset = torch.utils.data.TensorDataset(torch.tensor(X_train[val_idx],dtype=torch.long), torch.tensor(y[val_idx],dtype=torch.long))
    tq = tqdm(range(args.epochs + 1))
    for child in tsfm.children():
        for param in child.parameters():
            if not param.requires_grad:
                print("whoopsies")
            param.requires_grad = False
    frozen = True
    for epoch in tq:

        if epoch > 0 and frozen:
            for child in tsfm.children():
                for param in child.parameters():
                    param.requires_grad = True
            frozen = False
            del scheduler0
            torch.cuda.empty_cache()
        val_preds = []
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False)
        avg_loss = 0.
        avg_accuracy = 0.
        optimizer.zero_grad()
        pbar = tqdm(enumerate(train_loader),total=len(train_loader),leave=False)
        for i,(x_batch, y_batch) in pbar:
            model_bert.train()
            y_pred = model_bert(x_batch.cuda(), attention_mask=(x_batch>0).cuda())
            loss =  F.binary_cross_entropy_with_logits(y_pred.view(-1).cuda(),y_batch.float().cuda())
            loss = loss.mean()
            loss.backward()
            if i % args.accumulation_steps == 0 or i == len(pbar) - 1:
                optimizer.step()
                optimizer.zero_grad()
                if not frozen:
                    scheduler.step()
                else:
                    scheduler0.step()
            lossf = loss.item()
            pbar.set_postfix(loss = lossf)
            avg_loss += loss.item() / len(train_loader)
        model_bert.eval()
        pbar = tqdm(enumerate(valid_loader),total=len(valid_loader),leave=False)
        for i,(x_batch, y_batch) in pbar:
            y_pred = model_bert(x_batch.cuda(), attention_mask=(x_batch>0).cuda())
            y_pred = y_pred.squeeze().detach().cpu().numpy()
            val_preds = np.concatenate([val_preds, np.atleast_1d(y_pred)])
        val_preds = sigmoid(val_preds)
        best_th = 0
        score = f1_score(y[val_idx], val_preds > 0.5)
        print(f"\nAUC = {roc_auc_score(y[val_idx], val_preds):.4f}, F1 score @0.5 = {score:.4f}")
        if score >= best_score:
            torch.save(model_bert.state_dict(),os.path.join(args.ckpt_path, f"model_{fold}.bin"))
            best_score = score

  0%|          | 0/6 [00:00<?, ?it/s]
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha)

  0%|          | 0/537 [00:00<?, ?it/s, loss=0.731][A
  0%|          | 1/537 [00:00<01:32,  5.79it/s, loss=0.731][A

Training for fold 0



  0%|          | 1/537 [00:00<01:32,  5.79it/s, loss=0.687][A
  0%|          | 2/537 [00:00<01:28,  6.03it/s, loss=0.687][A
  0%|          | 2/537 [00:00<01:28,  6.03it/s, loss=0.765][A
  1%|          | 3/537 [00:00<01:23,  6.39it/s, loss=0.765][A
  1%|          | 3/537 [00:00<01:23,  6.39it/s, loss=0.716][A
  1%|          | 4/537 [00:00<01:20,  6.64it/s, loss=0.716][A
  1%|          | 4/537 [00:00<01:20,  6.64it/s, loss=0.721][A
  1%|          | 5/537 [00:00<01:17,  6.85it/s, loss=0.721][A
  1%|          | 5/537 [00:00<01:17,  6.85it/s, loss=0.715][A
  1%|          | 6/537 [00:00<01:15,  7.01it/s, loss=0.715][A
  1%|          | 6/537 [00:00<01:15,  7.01it/s, loss=0.692][A
  1%|▏         | 7/537 [00:01<01:14,  7.10it/s, loss=0.692][A
  1%|▏         | 7/537 [00:01<01:14,  7.10it/s, loss=0.746][A
  1%|▏         | 8/537 [00:01<01:13,  7.19it/s, loss=0.746][A
  1%|▏         | 8/537 [00:01<01:13,  7.19it/s, loss=0.721][A
  2%|▏         | 9/537 [00:01<01:12,  7.25it/s, loss=0

 24%|██▍       | 129/537 [00:17<00:55,  7.29it/s, loss=0.615][A
 24%|██▍       | 129/537 [00:17<00:55,  7.29it/s, loss=0.705][A
 24%|██▍       | 130/537 [00:17<00:55,  7.29it/s, loss=0.705][A
 24%|██▍       | 130/537 [00:17<00:55,  7.29it/s, loss=0.755][A
 24%|██▍       | 131/537 [00:17<00:55,  7.30it/s, loss=0.755][A
 24%|██▍       | 131/537 [00:18<00:55,  7.30it/s, loss=0.676][A
 25%|██▍       | 132/537 [00:18<00:55,  7.30it/s, loss=0.676][A
 25%|██▍       | 132/537 [00:18<00:55,  7.30it/s, loss=0.723][A
 25%|██▍       | 133/537 [00:18<00:55,  7.30it/s, loss=0.723][A
 25%|██▍       | 133/537 [00:18<00:55,  7.30it/s, loss=0.641][A
 25%|██▍       | 134/537 [00:18<00:55,  7.31it/s, loss=0.641][A
 25%|██▍       | 134/537 [00:18<00:55,  7.31it/s, loss=0.611][A
 25%|██▌       | 135/537 [00:18<00:54,  7.31it/s, loss=0.611][A
 25%|██▌       | 135/537 [00:18<00:54,  7.31it/s, loss=0.668][A
 25%|██▌       | 136/537 [00:18<00:54,  7.31it/s, loss=0.668][A
 25%|██▌       | 136/537 

 47%|████▋     | 255/537 [00:35<00:39,  7.22it/s, loss=0.709][A
 47%|████▋     | 255/537 [00:35<00:39,  7.22it/s, loss=0.643][A
 48%|████▊     | 256/537 [00:35<00:38,  7.21it/s, loss=0.643][A
 48%|████▊     | 256/537 [00:35<00:38,  7.21it/s, loss=0.684][A
 48%|████▊     | 257/537 [00:35<00:38,  7.21it/s, loss=0.684][A
 48%|████▊     | 257/537 [00:35<00:38,  7.21it/s, loss=0.692][A
 48%|████▊     | 258/537 [00:35<00:38,  7.20it/s, loss=0.692][A
 48%|████▊     | 258/537 [00:35<00:38,  7.20it/s, loss=0.644][A
 48%|████▊     | 259/537 [00:35<00:38,  7.20it/s, loss=0.644][A
 48%|████▊     | 259/537 [00:35<00:38,  7.20it/s, loss=0.691][A
 48%|████▊     | 260/537 [00:35<00:38,  7.20it/s, loss=0.691][A
 48%|████▊     | 260/537 [00:35<00:38,  7.20it/s, loss=0.682][A
 49%|████▊     | 261/537 [00:35<00:38,  7.21it/s, loss=0.682][A
 49%|████▊     | 261/537 [00:36<00:38,  7.21it/s, loss=0.739][A
 49%|████▉     | 262/537 [00:36<00:38,  7.22it/s, loss=0.739][A
 49%|████▉     | 262/537 

 71%|███████   | 381/537 [00:52<00:21,  7.16it/s, loss=0.68][A
 71%|███████   | 381/537 [00:52<00:21,  7.16it/s, loss=0.699][A
 71%|███████   | 382/537 [00:52<00:21,  7.17it/s, loss=0.699][A
 71%|███████   | 382/537 [00:52<00:21,  7.17it/s, loss=0.638][A
 71%|███████▏  | 383/537 [00:52<00:21,  7.18it/s, loss=0.638][A
 71%|███████▏  | 383/537 [00:52<00:21,  7.18it/s, loss=0.643][A
 72%|███████▏  | 384/537 [00:53<00:21,  7.17it/s, loss=0.643][A
 72%|███████▏  | 384/537 [00:53<00:21,  7.17it/s, loss=0.662][A
 72%|███████▏  | 385/537 [00:53<00:21,  7.18it/s, loss=0.662][A
 72%|███████▏  | 385/537 [00:53<00:21,  7.18it/s, loss=0.694][A
 72%|███████▏  | 386/537 [00:53<00:21,  7.17it/s, loss=0.694][A
 72%|███████▏  | 386/537 [00:53<00:21,  7.17it/s, loss=0.667][A
 72%|███████▏  | 387/537 [00:53<00:20,  7.18it/s, loss=0.667][A
 72%|███████▏  | 387/537 [00:53<00:20,  7.18it/s, loss=0.719][A
 72%|███████▏  | 388/537 [00:53<00:20,  7.18it/s, loss=0.719][A
 72%|███████▏  | 388/537 [

 94%|█████████▍| 507/537 [01:10<00:04,  7.09it/s, loss=0.696][A
 94%|█████████▍| 507/537 [01:10<00:04,  7.09it/s, loss=0.688][A
 95%|█████████▍| 508/537 [01:10<00:04,  7.10it/s, loss=0.688][A
 95%|█████████▍| 508/537 [01:10<00:04,  7.10it/s, loss=0.719][A
 95%|█████████▍| 509/537 [01:10<00:03,  7.11it/s, loss=0.719][A
 95%|█████████▍| 509/537 [01:10<00:03,  7.11it/s, loss=0.712][A
 95%|█████████▍| 510/537 [01:10<00:03,  7.11it/s, loss=0.712][A
 95%|█████████▍| 510/537 [01:10<00:03,  7.11it/s, loss=0.678][A
 95%|█████████▌| 511/537 [01:10<00:03,  7.12it/s, loss=0.678][A
 95%|█████████▌| 511/537 [01:10<00:03,  7.12it/s, loss=0.677][A
 95%|█████████▌| 512/537 [01:10<00:03,  7.13it/s, loss=0.677][A
 95%|█████████▌| 512/537 [01:11<00:03,  7.13it/s, loss=0.655][A
 96%|█████████▌| 513/537 [01:11<00:03,  7.13it/s, loss=0.655][A
 96%|█████████▌| 513/537 [01:11<00:03,  7.13it/s, loss=0.719][A
 96%|█████████▌| 514/537 [01:11<00:03,  7.14it/s, loss=0.719][A
 96%|█████████▌| 514/537 


AUC = 0.5446, F1 score @0.5 = 0.1779


 17%|█▋        | 1/6 [01:31<07:39, 91.91s/it]
  0%|          | 0/537 [00:00<?, ?it/s][A
  0%|          | 0/537 [00:00<?, ?it/s, loss=0.738][A
  0%|          | 1/537 [00:00<03:42,  2.41it/s, loss=0.738][A
  0%|          | 1/537 [00:00<03:42,  2.41it/s, loss=0.748][A
  0%|          | 2/537 [00:00<03:36,  2.47it/s, loss=0.748][A
  0%|          | 2/537 [00:01<03:36,  2.47it/s, loss=0.773][A
  1%|          | 3/537 [00:01<03:33,  2.50it/s, loss=0.773][A
  1%|          | 3/537 [00:01<03:33,  2.50it/s, loss=0.801][A
  1%|          | 4/537 [00:01<03:30,  2.53it/s, loss=0.801][A
  1%|          | 4/537 [00:01<03:30,  2.53it/s, loss=0.798][A
  1%|          | 5/537 [00:01<03:28,  2.55it/s, loss=0.798][A
  1%|          | 5/537 [00:02<03:28,  2.55it/s, loss=0.644][A
  1%|          | 6/537 [00:02<03:31,  2.51it/s, loss=0.644][A
  1%|          | 6/537 [00:02<03:31,  2.51it/s, loss=0.804][A
  1%|▏         | 7/537 [00:02<03:28,  2.54it/s, loss=0.804][A
  1%|▏         | 7/537 [00:03<03:28, 

 24%|██▎       | 127/537 [00:50<02:40,  2.55it/s, loss=0.608][A
 24%|██▍       | 128/537 [00:50<02:39,  2.56it/s, loss=0.608][A
 24%|██▍       | 128/537 [00:50<02:39,  2.56it/s, loss=0.594][A
 24%|██▍       | 129/537 [00:50<02:38,  2.57it/s, loss=0.594][A
 24%|██▍       | 129/537 [00:50<02:38,  2.57it/s, loss=0.619][A
 24%|██▍       | 130/537 [00:50<02:37,  2.58it/s, loss=0.619][A
 24%|██▍       | 130/537 [00:51<02:37,  2.58it/s, loss=0.715][A
 24%|██▍       | 131/537 [00:51<02:41,  2.52it/s, loss=0.715][A
 24%|██▍       | 131/537 [00:51<02:41,  2.52it/s, loss=0.682][A
 25%|██▍       | 132/537 [00:51<02:39,  2.55it/s, loss=0.682][A
 25%|██▍       | 132/537 [00:52<02:39,  2.55it/s, loss=0.75] [A
 25%|██▍       | 133/537 [00:52<02:37,  2.56it/s, loss=0.75][A
 25%|██▍       | 133/537 [00:52<02:37,  2.56it/s, loss=0.709][A
 25%|██▍       | 134/537 [00:52<02:36,  2.57it/s, loss=0.709][A
 25%|██▍       | 134/537 [00:52<02:36,  2.57it/s, loss=0.655][A
 25%|██▌       | 135/537 [

 47%|████▋     | 253/537 [01:39<01:51,  2.55it/s, loss=0.641][A
 47%|████▋     | 254/537 [01:39<01:50,  2.56it/s, loss=0.641][A
 47%|████▋     | 254/537 [01:39<01:50,  2.56it/s, loss=0.663][A
 47%|████▋     | 255/537 [01:39<01:49,  2.57it/s, loss=0.663][A
 47%|████▋     | 255/537 [01:40<01:49,  2.57it/s, loss=0.666][A
 48%|████▊     | 256/537 [01:40<01:51,  2.52it/s, loss=0.666][A
 48%|████▊     | 256/537 [01:40<01:51,  2.52it/s, loss=0.596][A
 48%|████▊     | 257/537 [01:40<01:50,  2.54it/s, loss=0.596][A
 48%|████▊     | 257/537 [01:41<01:50,  2.54it/s, loss=0.593][A
 48%|████▊     | 258/537 [01:41<01:49,  2.56it/s, loss=0.593][A
 48%|████▊     | 258/537 [01:41<01:49,  2.56it/s, loss=0.628][A
 48%|████▊     | 259/537 [01:41<01:48,  2.57it/s, loss=0.628][A
 48%|████▊     | 259/537 [01:41<01:48,  2.57it/s, loss=0.626][A
 48%|████▊     | 260/537 [01:41<01:47,  2.57it/s, loss=0.626][A
 48%|████▊     | 260/537 [01:42<01:47,  2.57it/s, loss=0.722][A
 49%|████▊     | 261/537 

 71%|███████   | 379/537 [02:28<01:01,  2.58it/s, loss=0.411][A
 71%|███████   | 380/537 [02:28<01:00,  2.58it/s, loss=0.411][A
 71%|███████   | 380/537 [02:29<01:00,  2.58it/s, loss=0.3]  [A
 71%|███████   | 381/537 [02:29<01:01,  2.54it/s, loss=0.3][A
 71%|███████   | 381/537 [02:29<01:01,  2.54it/s, loss=0.456][A
 71%|███████   | 382/537 [02:29<01:00,  2.56it/s, loss=0.456][A
 71%|███████   | 382/537 [02:29<01:00,  2.56it/s, loss=0.35] [A
 71%|███████▏  | 383/537 [02:29<00:59,  2.57it/s, loss=0.35][A
 71%|███████▏  | 383/537 [02:30<00:59,  2.57it/s, loss=0.526][A
 72%|███████▏  | 384/537 [02:30<00:59,  2.58it/s, loss=0.526][A
 72%|███████▏  | 384/537 [02:30<00:59,  2.58it/s, loss=0.252][A
 72%|███████▏  | 385/537 [02:30<00:58,  2.58it/s, loss=0.252][A
 72%|███████▏  | 385/537 [02:30<00:58,  2.58it/s, loss=0.152][A
 72%|███████▏  | 386/537 [02:30<00:59,  2.54it/s, loss=0.152][A
 72%|███████▏  | 386/537 [02:31<00:59,  2.54it/s, loss=0.167][A
 72%|███████▏  | 387/537 [02

 94%|█████████▍| 505/537 [03:18<00:12,  2.56it/s, loss=0.217][A
 94%|█████████▍| 506/537 [03:18<00:12,  2.51it/s, loss=0.217][A
 94%|█████████▍| 506/537 [03:18<00:12,  2.51it/s, loss=0.369][A
 94%|█████████▍| 507/537 [03:18<00:11,  2.53it/s, loss=0.369][A
 94%|█████████▍| 507/537 [03:18<00:11,  2.53it/s, loss=0.284][A
 95%|█████████▍| 508/537 [03:18<00:11,  2.55it/s, loss=0.284][A
 95%|█████████▍| 508/537 [03:19<00:11,  2.55it/s, loss=0.333][A
 95%|█████████▍| 509/537 [03:19<00:10,  2.55it/s, loss=0.333][A
 95%|█████████▍| 509/537 [03:19<00:10,  2.55it/s, loss=0.348][A
 95%|█████████▍| 510/537 [03:19<00:10,  2.56it/s, loss=0.348][A
 95%|█████████▍| 510/537 [03:20<00:10,  2.56it/s, loss=0.219][A
 95%|█████████▌| 511/537 [03:20<00:10,  2.51it/s, loss=0.219][A
 95%|█████████▌| 511/537 [03:20<00:10,  2.51it/s, loss=0.216][A
 95%|█████████▌| 512/537 [03:20<00:09,  2.54it/s, loss=0.216][A
 95%|█████████▌| 512/537 [03:20<00:09,  2.54it/s, loss=0.236][A
 96%|█████████▌| 513/537 


AUC = 0.9584, F1 score @0.5 = 0.8895


 33%|███▎      | 2/6 [05:19<08:50, 132.71s/it]
  0%|          | 0/537 [00:00<?, ?it/s][A
  0%|          | 0/537 [00:00<?, ?it/s, loss=0.114][A
  0%|          | 1/537 [00:00<03:42,  2.41it/s, loss=0.114][A
  0%|          | 1/537 [00:00<03:42,  2.41it/s, loss=0.28] [A
  0%|          | 2/537 [00:00<03:37,  2.46it/s, loss=0.28][A
  0%|          | 2/537 [00:01<03:37,  2.46it/s, loss=0.0348][A
  1%|          | 3/537 [00:01<03:34,  2.49it/s, loss=0.0348][A
  1%|          | 3/537 [00:01<03:34,  2.49it/s, loss=0.17]  [A
  1%|          | 4/537 [00:01<03:31,  2.51it/s, loss=0.17][A
  1%|          | 4/537 [00:01<03:31,  2.51it/s, loss=0.226][A
  1%|          | 5/537 [00:01<03:30,  2.53it/s, loss=0.226][A
  1%|          | 5/537 [00:02<03:30,  2.53it/s, loss=0.334][A
  1%|          | 6/537 [00:02<03:33,  2.48it/s, loss=0.334][A
  1%|          | 6/537 [00:02<03:33,  2.48it/s, loss=0.306][A
  1%|▏         | 7/537 [00:02<03:31,  2.51it/s, loss=0.306][A
  1%|▏         | 7/537 [00:03<03:31

 24%|██▎       | 127/537 [00:50<02:41,  2.53it/s, loss=0.181][A
 24%|██▎       | 127/537 [00:50<02:41,  2.53it/s, loss=0.165][A
 24%|██▍       | 128/537 [00:50<02:40,  2.55it/s, loss=0.165][A
 24%|██▍       | 128/537 [00:50<02:40,  2.55it/s, loss=0.22] [A
 24%|██▍       | 129/537 [00:50<02:39,  2.56it/s, loss=0.22][A
 24%|██▍       | 129/537 [00:51<02:39,  2.56it/s, loss=0.508][A
 24%|██▍       | 130/537 [00:51<02:38,  2.56it/s, loss=0.508][A
 24%|██▍       | 130/537 [00:51<02:38,  2.56it/s, loss=0.227][A
 24%|██▍       | 131/537 [00:51<02:41,  2.51it/s, loss=0.227][A
 24%|██▍       | 131/537 [00:51<02:41,  2.51it/s, loss=0.144][A
 25%|██▍       | 132/537 [00:51<02:39,  2.53it/s, loss=0.144][A
 25%|██▍       | 132/537 [00:52<02:39,  2.53it/s, loss=0.182][A
 25%|██▍       | 133/537 [00:52<02:38,  2.55it/s, loss=0.182][A
 25%|██▍       | 133/537 [00:52<02:38,  2.55it/s, loss=0.21] [A
 25%|██▍       | 134/537 [00:52<02:37,  2.56it/s, loss=0.21][A
 25%|██▍       | 134/537 [0

 47%|████▋     | 252/537 [01:39<01:53,  2.52it/s, loss=0.373][A
 47%|████▋     | 252/537 [01:39<01:53,  2.52it/s, loss=0.415][A
 47%|████▋     | 253/537 [01:39<01:52,  2.53it/s, loss=0.415][A
 47%|████▋     | 253/537 [01:39<01:52,  2.53it/s, loss=0.188][A
 47%|████▋     | 254/537 [01:39<01:51,  2.54it/s, loss=0.188][A
 47%|████▋     | 254/537 [01:40<01:51,  2.54it/s, loss=0.27] [A
 47%|████▋     | 255/537 [01:40<01:50,  2.55it/s, loss=0.27][A
 47%|████▋     | 255/537 [01:40<01:50,  2.55it/s, loss=0.265][A
 48%|████▊     | 256/537 [01:40<01:52,  2.49it/s, loss=0.265][A
 48%|████▊     | 256/537 [01:41<01:52,  2.49it/s, loss=0.295][A
 48%|████▊     | 257/537 [01:41<01:50,  2.52it/s, loss=0.295][A
 48%|████▊     | 257/537 [01:41<01:50,  2.52it/s, loss=0.0774][A
 48%|████▊     | 258/537 [01:41<01:50,  2.54it/s, loss=0.0774][A
 48%|████▊     | 258/537 [01:41<01:50,  2.54it/s, loss=0.453] [A
 48%|████▊     | 259/537 [01:41<01:49,  2.54it/s, loss=0.453][A
 48%|████▊     | 259/53

 70%|███████   | 377/537 [02:28<01:03,  2.53it/s, loss=0.483][A
 70%|███████   | 377/537 [02:28<01:03,  2.53it/s, loss=0.285][A
 70%|███████   | 378/537 [02:28<01:02,  2.55it/s, loss=0.285][A
 70%|███████   | 378/537 [02:29<01:02,  2.55it/s, loss=0.178][A
 71%|███████   | 379/537 [02:29<01:01,  2.55it/s, loss=0.178][A
 71%|███████   | 379/537 [02:29<01:01,  2.55it/s, loss=0.535][A
 71%|███████   | 380/537 [02:29<01:01,  2.56it/s, loss=0.535][A
 71%|███████   | 380/537 [02:30<01:01,  2.56it/s, loss=0.337][A
 71%|███████   | 381/537 [02:30<01:02,  2.51it/s, loss=0.337][A
 71%|███████   | 381/537 [02:30<01:02,  2.51it/s, loss=0.176][A
 71%|███████   | 382/537 [02:30<01:01,  2.53it/s, loss=0.176][A
 71%|███████   | 382/537 [02:30<01:01,  2.53it/s, loss=0.349][A
 71%|███████▏  | 383/537 [02:30<01:00,  2.55it/s, loss=0.349][A
 71%|███████▏  | 383/537 [02:31<01:00,  2.55it/s, loss=0.185][A
 72%|███████▏  | 384/537 [02:31<00:59,  2.55it/s, loss=0.185][A
 72%|███████▏  | 384/537 

 93%|█████████▎| 502/537 [03:17<00:13,  2.53it/s, loss=0.409][A
 93%|█████████▎| 502/537 [03:18<00:13,  2.53it/s, loss=0.19] [A
 94%|█████████▎| 503/537 [03:18<00:13,  2.55it/s, loss=0.19][A
 94%|█████████▎| 503/537 [03:18<00:13,  2.55it/s, loss=0.071][A
 94%|█████████▍| 504/537 [03:18<00:12,  2.55it/s, loss=0.071][A
 94%|█████████▍| 504/537 [03:18<00:12,  2.55it/s, loss=0.258][A
 94%|█████████▍| 505/537 [03:18<00:12,  2.56it/s, loss=0.258][A
 94%|█████████▍| 505/537 [03:19<00:12,  2.56it/s, loss=0.0815][A
 94%|█████████▍| 506/537 [03:19<00:12,  2.51it/s, loss=0.0815][A
 94%|█████████▍| 506/537 [03:19<00:12,  2.51it/s, loss=0.101] [A
 94%|█████████▍| 507/537 [03:19<00:11,  2.53it/s, loss=0.101][A
 94%|█████████▍| 507/537 [03:20<00:11,  2.53it/s, loss=0.188][A
 95%|█████████▍| 508/537 [03:20<00:11,  2.55it/s, loss=0.188][A
 95%|█████████▍| 508/537 [03:20<00:11,  2.55it/s, loss=0.192][A
 95%|█████████▍| 509/537 [03:20<00:10,  2.55it/s, loss=0.192][A
 95%|█████████▍| 509/53


AUC = 0.9683, F1 score @0.5 = 0.8602



  0%|          | 0/537 [00:00<?, ?it/s, loss=0.241][A
  0%|          | 1/537 [00:00<03:40,  2.43it/s, loss=0.241][A
  0%|          | 1/537 [00:00<03:40,  2.43it/s, loss=0.102][A
  0%|          | 2/537 [00:00<03:36,  2.48it/s, loss=0.102][A
  0%|          | 2/537 [00:01<03:36,  2.48it/s, loss=0.229][A
  1%|          | 3/537 [00:01<03:33,  2.50it/s, loss=0.229][A
  1%|          | 3/537 [00:01<03:33,  2.50it/s, loss=0.322][A
  1%|          | 4/537 [00:01<03:31,  2.52it/s, loss=0.322][A
  1%|          | 4/537 [00:01<03:31,  2.52it/s, loss=0.137][A
  1%|          | 5/537 [00:01<03:29,  2.54it/s, loss=0.137][A
  1%|          | 5/537 [00:02<03:29,  2.54it/s, loss=0.121][A
  1%|          | 6/537 [00:02<03:32,  2.50it/s, loss=0.121][A
  1%|          | 6/537 [00:02<03:32,  2.50it/s, loss=0.161][A
  1%|▏         | 7/537 [00:02<03:29,  2.52it/s, loss=0.161][A
  1%|▏         | 7/537 [00:03<03:29,  2.52it/s, loss=0.24] [A
  1%|▏         | 8/537 [00:03<03:28,  2.54it/s, loss=0.24][A


 24%|██▎       | 127/537 [00:50<02:41,  2.53it/s, loss=0.219][A
 24%|██▎       | 127/537 [00:50<02:41,  2.53it/s, loss=0.173][A
 24%|██▍       | 128/537 [00:50<02:40,  2.55it/s, loss=0.173][A
 24%|██▍       | 128/537 [00:50<02:40,  2.55it/s, loss=0.22] [A
 24%|██▍       | 129/537 [00:50<02:39,  2.55it/s, loss=0.22][A
 24%|██▍       | 129/537 [00:51<02:39,  2.55it/s, loss=0.142][A
 24%|██▍       | 130/537 [00:51<02:38,  2.56it/s, loss=0.142][A
 24%|██▍       | 130/537 [00:51<02:38,  2.56it/s, loss=0.303][A
 24%|██▍       | 131/537 [00:51<02:40,  2.53it/s, loss=0.303][A
 24%|██▍       | 131/537 [00:52<02:40,  2.53it/s, loss=0.0956][A
 25%|██▍       | 132/537 [00:52<02:38,  2.55it/s, loss=0.0956][A
 25%|██▍       | 132/537 [00:52<02:38,  2.55it/s, loss=0.112] [A
 25%|██▍       | 133/537 [00:52<02:37,  2.56it/s, loss=0.112][A
 25%|██▍       | 133/537 [00:52<02:37,  2.56it/s, loss=0.189][A
 25%|██▍       | 134/537 [00:52<02:36,  2.57it/s, loss=0.189][A
 25%|██▍       | 134/53

 47%|████▋     | 252/537 [01:39<01:52,  2.53it/s, loss=0.0922][A
 47%|████▋     | 252/537 [01:39<01:52,  2.53it/s, loss=0.223] [A
 47%|████▋     | 253/537 [01:39<01:51,  2.55it/s, loss=0.223][A
 47%|████▋     | 253/537 [01:39<01:51,  2.55it/s, loss=0.197][A
 47%|████▋     | 254/537 [01:39<01:50,  2.55it/s, loss=0.197][A
 47%|████▋     | 254/537 [01:40<01:50,  2.55it/s, loss=0.253][A
 47%|████▋     | 255/537 [01:40<01:50,  2.56it/s, loss=0.253][A
 47%|████▋     | 255/537 [01:40<01:50,  2.56it/s, loss=0.0731][A
 48%|████▊     | 256/537 [01:40<01:51,  2.51it/s, loss=0.0731][A
 48%|████▊     | 256/537 [01:41<01:51,  2.51it/s, loss=0.1]   [A
 48%|████▊     | 257/537 [01:41<01:50,  2.53it/s, loss=0.1][A
 48%|████▊     | 257/537 [01:41<01:50,  2.53it/s, loss=0.365][A
 48%|████▊     | 258/537 [01:41<01:49,  2.55it/s, loss=0.365][A
 48%|████▊     | 258/537 [01:41<01:49,  2.55it/s, loss=0.244][A
 48%|████▊     | 259/537 [01:41<01:48,  2.55it/s, loss=0.244][A
 48%|████▊     | 259/5

 70%|███████   | 377/537 [02:28<01:03,  2.54it/s, loss=0.185][A
 70%|███████   | 377/537 [02:28<01:03,  2.54it/s, loss=0.217][A
 70%|███████   | 378/537 [02:28<01:02,  2.55it/s, loss=0.217][A
 70%|███████   | 378/537 [02:29<01:02,  2.55it/s, loss=0.103][A
 71%|███████   | 379/537 [02:29<01:01,  2.55it/s, loss=0.103][A
 71%|███████   | 379/537 [02:29<01:01,  2.55it/s, loss=0.134][A
 71%|███████   | 380/537 [02:29<01:01,  2.56it/s, loss=0.134][A
 71%|███████   | 380/537 [02:29<01:01,  2.56it/s, loss=0.068][A
 71%|███████   | 381/537 [02:29<01:02,  2.50it/s, loss=0.068][A
 71%|███████   | 381/537 [02:30<01:02,  2.50it/s, loss=0.362][A
 71%|███████   | 382/537 [02:30<01:01,  2.53it/s, loss=0.362][A
 71%|███████   | 382/537 [02:30<01:01,  2.53it/s, loss=0.0723][A
 71%|███████▏  | 383/537 [02:30<01:00,  2.54it/s, loss=0.0723][A
 71%|███████▏  | 383/537 [02:31<01:00,  2.54it/s, loss=0.0712][A
 72%|███████▏  | 384/537 [02:31<01:00,  2.55it/s, loss=0.0712][A
 72%|███████▏  | 384/

 93%|█████████▎| 502/537 [03:17<00:13,  2.53it/s, loss=0.245][A
 93%|█████████▎| 502/537 [03:18<00:13,  2.53it/s, loss=0.0677][A
 94%|█████████▎| 503/537 [03:18<00:13,  2.54it/s, loss=0.0677][A
 94%|█████████▎| 503/537 [03:18<00:13,  2.54it/s, loss=0.118] [A
 94%|█████████▍| 504/537 [03:18<00:12,  2.55it/s, loss=0.118][A
 94%|█████████▍| 504/537 [03:18<00:12,  2.55it/s, loss=0.207][A
 94%|█████████▍| 505/537 [03:18<00:12,  2.56it/s, loss=0.207][A
 94%|█████████▍| 505/537 [03:19<00:12,  2.56it/s, loss=0.356][A
 94%|█████████▍| 506/537 [03:19<00:12,  2.51it/s, loss=0.356][A
 94%|█████████▍| 506/537 [03:19<00:12,  2.51it/s, loss=0.143][A
 94%|█████████▍| 507/537 [03:19<00:11,  2.53it/s, loss=0.143][A
 94%|█████████▍| 507/537 [03:20<00:11,  2.53it/s, loss=0.0736][A
 95%|█████████▍| 508/537 [03:20<00:11,  2.55it/s, loss=0.0736][A
 95%|█████████▍| 508/537 [03:20<00:11,  2.55it/s, loss=0.0709][A
 95%|█████████▍| 509/537 [03:20<00:10,  2.55it/s, loss=0.0709][A
 95%|█████████▍| 5


AUC = 0.9709, F1 score @0.5 = 0.9037


 67%|██████▋   | 4/6 [12:57<06:03, 181.79s/it]
  0%|          | 0/537 [00:00<?, ?it/s][A
  0%|          | 0/537 [00:00<?, ?it/s, loss=0.166][A
  0%|          | 1/537 [00:00<03:42,  2.41it/s, loss=0.166][A
  0%|          | 1/537 [00:00<03:42,  2.41it/s, loss=0.153][A
  0%|          | 2/537 [00:00<03:37,  2.46it/s, loss=0.153][A
  0%|          | 2/537 [00:01<03:37,  2.46it/s, loss=0.127][A
  1%|          | 3/537 [00:01<03:33,  2.50it/s, loss=0.127][A
  1%|          | 3/537 [00:01<03:33,  2.50it/s, loss=0.124][A
  1%|          | 4/537 [00:01<03:31,  2.52it/s, loss=0.124][A
  1%|          | 4/537 [00:01<03:31,  2.52it/s, loss=0.0475][A
  1%|          | 5/537 [00:01<03:29,  2.54it/s, loss=0.0475][A
  1%|          | 5/537 [00:02<03:29,  2.54it/s, loss=0.0867][A
  1%|          | 6/537 [00:02<03:33,  2.49it/s, loss=0.0867][A
  1%|          | 6/537 [00:02<03:33,  2.49it/s, loss=0.0806][A
  1%|▏         | 7/537 [00:02<03:30,  2.52it/s, loss=0.0806][A
  1%|▏         | 7/537 [00:03<

 23%|██▎       | 126/537 [00:50<02:44,  2.50it/s, loss=0.144][A
 24%|██▎       | 127/537 [00:50<02:42,  2.53it/s, loss=0.144][A
 24%|██▎       | 127/537 [00:50<02:42,  2.53it/s, loss=0.113][A
 24%|██▍       | 128/537 [00:50<02:40,  2.54it/s, loss=0.113][A
 24%|██▍       | 128/537 [00:50<02:40,  2.54it/s, loss=0.212][A
 24%|██▍       | 129/537 [00:50<02:39,  2.55it/s, loss=0.212][A
 24%|██▍       | 129/537 [00:51<02:39,  2.55it/s, loss=0.152][A
 24%|██▍       | 130/537 [00:51<02:39,  2.56it/s, loss=0.152][A
 24%|██▍       | 130/537 [00:51<02:39,  2.56it/s, loss=0.0541][A
 24%|██▍       | 131/537 [00:51<02:42,  2.51it/s, loss=0.0541][A
 24%|██▍       | 131/537 [00:52<02:42,  2.51it/s, loss=0.0668][A
 25%|██▍       | 132/537 [00:52<02:39,  2.53it/s, loss=0.0668][A
 25%|██▍       | 132/537 [00:52<02:39,  2.53it/s, loss=0.125] [A
 25%|██▍       | 133/537 [00:52<02:38,  2.54it/s, loss=0.125][A
 25%|██▍       | 133/537 [00:52<02:38,  2.54it/s, loss=0.204][A
 25%|██▍       | 134

 47%|████▋     | 251/537 [01:39<01:54,  2.51it/s, loss=0.132][A
 47%|████▋     | 252/537 [01:39<01:52,  2.52it/s, loss=0.132][A
 47%|████▋     | 252/537 [01:39<01:52,  2.52it/s, loss=0.0802][A
 47%|████▋     | 253/537 [01:39<01:51,  2.54it/s, loss=0.0802][A
 47%|████▋     | 253/537 [01:40<01:51,  2.54it/s, loss=0.113] [A
 47%|████▋     | 254/537 [01:40<01:51,  2.55it/s, loss=0.113][A
 47%|████▋     | 254/537 [01:40<01:51,  2.55it/s, loss=0.178][A
 47%|████▋     | 255/537 [01:40<01:50,  2.56it/s, loss=0.178][A
 47%|████▋     | 255/537 [01:41<01:50,  2.56it/s, loss=0.0761][A
 48%|████▊     | 256/537 [01:41<01:52,  2.50it/s, loss=0.0761][A
 48%|████▊     | 256/537 [01:41<01:52,  2.50it/s, loss=0.0455][A
 48%|████▊     | 257/537 [01:41<01:50,  2.53it/s, loss=0.0455][A
 48%|████▊     | 257/537 [01:41<01:50,  2.53it/s, loss=0.249] [A
 48%|████▊     | 258/537 [01:41<01:49,  2.55it/s, loss=0.249][A
 48%|████▊     | 258/537 [01:42<01:49,  2.55it/s, loss=0.0876][A
 48%|████▊     |

 70%|███████   | 376/537 [02:28<01:04,  2.51it/s, loss=0.0387][A
 70%|███████   | 377/537 [02:28<01:03,  2.53it/s, loss=0.0387][A
 70%|███████   | 377/537 [02:29<01:03,  2.53it/s, loss=0.113] [A
 70%|███████   | 378/537 [02:29<01:02,  2.55it/s, loss=0.113][A
 70%|███████   | 378/537 [02:29<01:02,  2.55it/s, loss=0.101][A
 71%|███████   | 379/537 [02:29<01:01,  2.55it/s, loss=0.101][A
 71%|███████   | 379/537 [02:29<01:01,  2.55it/s, loss=0.0916][A
 71%|███████   | 380/537 [02:29<01:01,  2.56it/s, loss=0.0916][A
 71%|███████   | 380/537 [02:30<01:01,  2.56it/s, loss=0.036] [A
 71%|███████   | 381/537 [02:30<01:02,  2.50it/s, loss=0.036][A
 71%|███████   | 381/537 [02:30<01:02,  2.50it/s, loss=0.236][A
 71%|███████   | 382/537 [02:30<01:01,  2.53it/s, loss=0.236][A
 71%|███████   | 382/537 [02:31<01:01,  2.53it/s, loss=0.145][A
 71%|███████▏  | 383/537 [02:31<01:00,  2.54it/s, loss=0.145][A
 71%|███████▏  | 383/537 [02:31<01:00,  2.54it/s, loss=0.25] [A
 72%|███████▏  | 38

 93%|█████████▎| 501/537 [03:17<00:14,  2.50it/s, loss=0.205] [A
 93%|█████████▎| 502/537 [03:17<00:13,  2.53it/s, loss=0.205][A
 93%|█████████▎| 502/537 [03:18<00:13,  2.53it/s, loss=0.0909][A
 94%|█████████▎| 503/537 [03:18<00:13,  2.54it/s, loss=0.0909][A
 94%|█████████▎| 503/537 [03:18<00:13,  2.54it/s, loss=0.186] [A
 94%|█████████▍| 504/537 [03:18<00:12,  2.55it/s, loss=0.186][A
 94%|█████████▍| 504/537 [03:19<00:12,  2.55it/s, loss=0.0277][A
 94%|█████████▍| 505/537 [03:19<00:12,  2.56it/s, loss=0.0277][A
 94%|█████████▍| 505/537 [03:19<00:12,  2.56it/s, loss=0.0751][A
 94%|█████████▍| 506/537 [03:19<00:12,  2.50it/s, loss=0.0751][A
 94%|█████████▍| 506/537 [03:19<00:12,  2.50it/s, loss=0.524] [A
 94%|█████████▍| 507/537 [03:19<00:11,  2.53it/s, loss=0.524][A
 94%|█████████▍| 507/537 [03:20<00:11,  2.53it/s, loss=0.105][A
 95%|█████████▍| 508/537 [03:20<00:11,  2.54it/s, loss=0.105][A
 95%|█████████▍| 508/537 [03:20<00:11,  2.54it/s, loss=0.185][A
 95%|█████████▍|


AUC = 0.9707, F1 score @0.5 = 0.9004



  0%|          | 0/537 [00:00<?, ?it/s, loss=0.286][A
  0%|          | 1/537 [00:00<03:41,  2.43it/s, loss=0.286][A
  0%|          | 1/537 [00:00<03:41,  2.43it/s, loss=0.0785][A
  0%|          | 2/537 [00:00<03:36,  2.48it/s, loss=0.0785][A
  0%|          | 2/537 [00:01<03:36,  2.48it/s, loss=0.0594][A
  1%|          | 3/537 [00:01<03:33,  2.51it/s, loss=0.0594][A
  1%|          | 3/537 [00:01<03:33,  2.51it/s, loss=0.0597][A
  1%|          | 4/537 [00:01<03:30,  2.53it/s, loss=0.0597][A
  1%|          | 4/537 [00:01<03:30,  2.53it/s, loss=0.149] [A
  1%|          | 5/537 [00:01<03:29,  2.54it/s, loss=0.149][A
  1%|          | 5/537 [00:02<03:29,  2.54it/s, loss=0.217][A
  1%|          | 6/537 [00:02<03:32,  2.50it/s, loss=0.217][A
  1%|          | 6/537 [00:02<03:32,  2.50it/s, loss=0.101][A
  1%|▏         | 7/537 [00:02<03:29,  2.53it/s, loss=0.101][A
  1%|▏         | 7/537 [00:03<03:29,  2.53it/s, loss=0.11] [A
  1%|▏         | 8/537 [00:03<03:27,  2.54it/s, loss=0.

 23%|██▎       | 126/537 [00:49<02:43,  2.51it/s, loss=0.207] [A
 24%|██▎       | 127/537 [00:49<02:41,  2.54it/s, loss=0.207][A
 24%|██▎       | 127/537 [00:50<02:41,  2.54it/s, loss=0.0235][A
 24%|██▍       | 128/537 [00:50<02:40,  2.55it/s, loss=0.0235][A
 24%|██▍       | 128/537 [00:50<02:40,  2.55it/s, loss=0.0793][A
 24%|██▍       | 129/537 [00:50<02:39,  2.56it/s, loss=0.0793][A
 24%|██▍       | 129/537 [00:51<02:39,  2.56it/s, loss=0.103] [A
 24%|██▍       | 130/537 [00:51<02:38,  2.56it/s, loss=0.103][A
 24%|██▍       | 130/537 [00:51<02:38,  2.56it/s, loss=0.0289][A
 24%|██▍       | 131/537 [00:51<02:41,  2.51it/s, loss=0.0289][A
 24%|██▍       | 131/537 [00:51<02:41,  2.51it/s, loss=0.171] [A
 25%|██▍       | 132/537 [00:51<02:39,  2.53it/s, loss=0.171][A
 25%|██▍       | 132/537 [00:52<02:39,  2.53it/s, loss=0.261][A
 25%|██▍       | 133/537 [00:52<02:38,  2.55it/s, loss=0.261][A
 25%|██▍       | 133/537 [00:52<02:38,  2.55it/s, loss=0.0437][A
 25%|██▍       

 47%|████▋     | 250/537 [01:38<01:51,  2.56it/s, loss=0.0731][A
 47%|████▋     | 251/537 [01:38<01:53,  2.51it/s, loss=0.0731][A
 47%|████▋     | 251/537 [01:39<01:53,  2.51it/s, loss=0.111] [A
 47%|████▋     | 252/537 [01:39<01:52,  2.54it/s, loss=0.111][A
 47%|████▋     | 252/537 [01:39<01:52,  2.54it/s, loss=0.351][A
 47%|████▋     | 253/537 [01:39<01:51,  2.55it/s, loss=0.351][A
 47%|████▋     | 253/537 [01:39<01:51,  2.55it/s, loss=0.156][A
 47%|████▋     | 254/537 [01:39<01:50,  2.56it/s, loss=0.156][A
 47%|████▋     | 254/537 [01:40<01:50,  2.56it/s, loss=0.05] [A
 47%|████▋     | 255/537 [01:40<01:49,  2.56it/s, loss=0.05][A
 47%|████▋     | 255/537 [01:40<01:49,  2.56it/s, loss=0.0761][A
 48%|████▊     | 256/537 [01:40<01:51,  2.51it/s, loss=0.0761][A
 48%|████▊     | 256/537 [01:41<01:51,  2.51it/s, loss=0.0372][A
 48%|████▊     | 257/537 [01:41<01:50,  2.54it/s, loss=0.0372][A
 48%|████▊     | 257/537 [01:41<01:50,  2.54it/s, loss=0.0211][A
 48%|████▊     | 2

 70%|██████▉   | 374/537 [02:27<01:04,  2.54it/s, loss=0.247] [A
 70%|██████▉   | 375/537 [02:27<01:03,  2.55it/s, loss=0.247][A
 70%|██████▉   | 375/537 [02:27<01:03,  2.55it/s, loss=0.0509][A
 70%|███████   | 376/537 [02:27<01:04,  2.49it/s, loss=0.0509][A
 70%|███████   | 376/537 [02:28<01:04,  2.49it/s, loss=0.13]  [A
 70%|███████   | 377/537 [02:28<01:03,  2.52it/s, loss=0.13][A
 70%|███████   | 377/537 [02:28<01:03,  2.52it/s, loss=0.147][A
 70%|███████   | 378/537 [02:28<01:02,  2.53it/s, loss=0.147][A
 70%|███████   | 378/537 [02:29<01:02,  2.53it/s, loss=0.0735][A
 71%|███████   | 379/537 [02:29<01:02,  2.54it/s, loss=0.0735][A
 71%|███████   | 379/537 [02:29<01:02,  2.54it/s, loss=0.142] [A
 71%|███████   | 380/537 [02:29<01:01,  2.55it/s, loss=0.142][A
 71%|███████   | 380/537 [02:29<01:01,  2.55it/s, loss=0.0439][A
 71%|███████   | 381/537 [02:29<01:02,  2.50it/s, loss=0.0439][A
 71%|███████   | 381/537 [02:30<01:02,  2.50it/s, loss=0.0178][A
 71%|███████   |

 93%|█████████▎| 498/537 [03:16<00:15,  2.54it/s, loss=0.04] [A
 93%|█████████▎| 499/537 [03:16<00:14,  2.54it/s, loss=0.04][A
 93%|█████████▎| 499/537 [03:16<00:14,  2.54it/s, loss=0.0431][A
 93%|█████████▎| 500/537 [03:16<00:14,  2.54it/s, loss=0.0431][A
 93%|█████████▎| 500/537 [03:17<00:14,  2.54it/s, loss=0.0358][A
 93%|█████████▎| 501/537 [03:17<00:14,  2.49it/s, loss=0.0358][A
 93%|█████████▎| 501/537 [03:17<00:14,  2.49it/s, loss=0.159] [A
 93%|█████████▎| 502/537 [03:17<00:13,  2.52it/s, loss=0.159][A
 93%|█████████▎| 502/537 [03:17<00:13,  2.52it/s, loss=0.0219][A
 94%|█████████▎| 503/537 [03:17<00:13,  2.53it/s, loss=0.0219][A
 94%|█████████▎| 503/537 [03:18<00:13,  2.53it/s, loss=0.258] [A
 94%|█████████▍| 504/537 [03:18<00:12,  2.54it/s, loss=0.258][A
 94%|█████████▍| 504/537 [03:18<00:12,  2.54it/s, loss=0.123][A
 94%|█████████▍| 505/537 [03:18<00:12,  2.55it/s, loss=0.123][A
 94%|█████████▍| 505/537 [03:19<00:12,  2.55it/s, loss=0.0352][A
 94%|█████████▍| 


AUC = 0.9716, F1 score @0.5 = 0.8978
Training for fold 1
Training for fold 2
Training for fold 3
Training for fold 4





In [15]:
for child in tsfm.children():
    for param in child.parameters():
        print(param)

Parameter containing:
tensor([[ 0.0472, -0.0314,  0.0343,  ...,  0.0144, -0.0032, -0.0680],
        [ 0.0186, -0.0044,  0.0190,  ...,  0.0169, -0.0055, -0.0302],
        [-0.0074, -0.0078, -0.0142,  ...,  0.0177, -0.0011,  0.0023],
        ...,
        [ 0.0374, -0.0152,  0.0123,  ...,  0.0210, -0.0214, -0.0185],
        [ 0.0266,  0.0122, -0.0232,  ...,  0.0253,  0.0002, -0.0303],
        [ 0.0106, -0.0063, -0.0064,  ...,  0.0092, -0.0129, -0.0153]],
       device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[ 0.0150, -0.0027, -0.0069,  ..., -0.0021,  0.0077,  0.0108],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0401, -0.0189, -0.0934,  ..., -0.0437,  0.0582,  0.0378],
        ...,
        [ 0.0010, -0.0455, -0.0163,  ..., -0.0149,  0.0101,  0.0913],
        [-0.0101, -0.0156,  0.0074,  ..., -0.0072, -0.0033,  0.1323],
        [-0.0528,  0.0500, -0.1011,  ...,  0.0354, -0.0190,  0.1747]],
       device='cuda:0', requires_grad=True)


Parameter containing:
tensor([-1.0633e-03,  4.9353e-04,  1.6012e-03,  8.6880e-04, -1.1683e-03,
        -9.1493e-05, -5.3787e-04,  4.5300e-05, -3.2783e-04,  8.3447e-05,
        -2.1629e-03,  4.4632e-04,  7.9775e-04, -3.2258e-04, -8.9121e-04,
         7.3051e-04, -2.2316e-04, -3.0422e-03, -1.6060e-03, -7.1621e-04,
        -1.8466e-04, -3.3116e-04,  9.7418e-04, -1.9894e-03, -1.6785e-03,
        -2.0409e-04, -1.5926e-04,  1.1692e-03, -1.2083e-03,  1.0443e-03,
        -1.9398e-03,  2.8443e-04, -5.0354e-04, -2.3079e-03, -2.3389e-04,
        -1.6928e-03,  2.0313e-03, -1.5287e-03, -1.1768e-03, -1.1511e-03,
        -1.1292e-03,  1.7967e-03,  2.3918e-03, -2.3308e-03, -6.5506e-05,
         1.0138e-03, -1.8120e-03,  3.4084e-03, -2.8181e-04,  8.5926e-04,
         1.1282e-03,  4.0293e-04,  3.0851e-04, -1.1215e-03, -2.9659e-04,
        -3.2759e-04, -6.5851e-04, -9.6273e-04, -9.9277e-04,  1.2541e-03,
        -1.8234e-03,  9.9659e-04,  3.2711e-04, -3.7146e-04,  2.0561e-03,
        -1.5554e-03, -1.7319e

Parameter containing:
tensor([ 1.4496e-02,  1.6068e-02,  1.8967e-02,  1.3879e-01,  8.0078e-02,
         9.3872e-02,  1.4050e-01,  5.3418e-01,  8.0322e-02,  3.5950e-02,
         1.1914e-01,  6.5430e-02, -1.7731e-02,  7.2021e-02,  6.1890e-02,
        -1.3557e-02, -3.9764e-02,  1.4084e-02,  9.6619e-02, -3.7140e-02,
         1.4519e-02,  1.5030e-02,  8.5693e-02,  1.4636e-01,  1.8661e-02,
        -6.4941e-02,  6.7627e-02,  5.1331e-02,  1.0565e-01,  3.8986e-03,
        -2.9419e-02,  1.0663e-01,  4.9133e-03, -1.5991e-02, -3.0502e-02,
        -9.7733e-03,  2.4826e-02, -1.2097e-01,  3.8635e-02, -3.1281e-02,
        -2.7115e-02, -2.5162e-02,  8.4534e-03, -4.1504e-02,  2.5360e-02,
         7.4158e-02, -1.1151e-01,  3.3569e-02, -1.3908e-02, -1.3603e-02,
         7.7454e-02, -2.3560e-02, -1.0614e-01,  3.7323e-02, -1.1804e-01,
         5.3833e-02,  2.4521e-02,  2.9373e-02,  5.9631e-02, -5.5542e-02,
        -3.1586e-02,  2.0349e-01, -5.0934e-02,  9.2224e-02,  2.0111e-02,
         1.5442e-01,  5.2704e

Parameter containing:
tensor([-8.6746e-03, -4.4891e-02, -5.9540e-02,  4.4281e-02, -1.9302e-02,
        -3.7170e-02,  6.6681e-03, -2.0020e-01,  2.2308e-02,  4.6631e-02,
         2.5589e-02,  5.2490e-02, -1.9867e-02, -1.0689e-02,  3.5767e-02,
         1.4023e-02,  2.9083e-02,  3.3813e-02, -3.9429e-02, -7.4646e-02,
         1.2164e-01,  5.0934e-02, -8.4290e-02,  2.8824e-02,  2.8744e-03,
         1.3229e-02,  3.0502e-02, -1.9836e-03,  1.2199e-02,  1.2108e-02,
        -3.2318e-02,  2.0233e-02,  2.8625e-02,  4.5166e-02, -1.7365e-02,
        -3.2593e-02, -4.3732e-02,  1.9928e-02, -4.5044e-02, -2.1118e-02,
        -2.2659e-02,  1.4618e-02, -6.6223e-02,  1.1145e-01,  5.2719e-03,
        -4.3762e-02,  5.4260e-02,  3.5950e-02,  5.0995e-02,  7.4280e-02,
         6.9397e-02,  1.5091e-02,  9.6054e-03, -5.8929e-02,  4.1351e-02,
        -3.4454e-02,  2.8305e-02,  8.5878e-04,  7.7393e-02,  6.8481e-02,
         4.7112e-03, -3.8727e-02, -5.9738e-03,  6.0913e-02, -4.0558e-02,
        -6.1951e-03,  2.1271e

Parameter containing:
tensor([ 5.3741e-02,  3.5583e-02, -1.2878e-01, -2.1460e-01, -4.2084e-02,
        -7.9590e-02, -1.5979e-01, -1.7578e-01,  6.5613e-02,  2.0462e-02,
        -1.0419e-01,  6.8237e-02,  5.3711e-02,  2.1912e-02, -8.7891e-02,
        -7.0457e-03,  1.3931e-02, -2.2690e-02, -1.1560e-01,  3.4766e-01,
         1.3318e-01, -6.0883e-02, -5.0964e-02,  1.6211e-01, -3.0396e-02,
        -7.1716e-02,  2.9617e-02,  7.8796e-02, -2.5317e-01,  1.3330e-01,
         3.0853e-02, -6.6406e-02, -1.5190e-02, -1.3794e-01,  7.8857e-02,
         5.6976e-02, -4.4785e-03,  1.6162e-01, -4.4708e-02,  1.3440e-01,
         7.0190e-02,  1.8433e-01, -2.3376e-02, -1.0376e-01,  2.8735e-01,
        -1.9421e-01,  2.8491e-01,  5.7007e-02,  7.4158e-03,  1.5030e-02,
        -9.0515e-02,  1.3908e-02,  5.6824e-02,  5.2124e-02,  8.0225e-01,
        -2.5467e-02, -8.0688e-02, -2.9984e-02,  1.7822e-01,  1.9019e-01,
         1.7883e-01, -6.3379e-01,  4.7798e-03, -9.9182e-02, -1.0046e-01,
        -1.1761e-01, -2.2156e

Parameter containing:
tensor([0.4812, 0.4543, 0.3337, 0.6255, 0.5303, 0.5571, 0.6479, 0.3899, 0.4551,
        0.5518, 0.6201, 0.4656, 0.6230, 0.5977, 0.4138, 0.4534, 0.6567, 0.5513,
        0.6211, 0.3164, 0.3643, 0.4138, 0.4775, 0.4280, 0.6616, 0.3784, 0.5269,
        0.4590, 0.5620, 0.3398, 0.4478, 0.3708, 0.4148, 0.6406, 0.6079, 0.5493,
        0.4785, 0.5303, 0.5439, 0.5322, 0.5059, 0.6226, 0.5762, 0.5068, 0.4324,
        0.3872, 0.5073, 0.4922, 0.6006, 0.4836, 0.4268, 0.5469, 0.6074, 0.6143,
        0.8018, 0.5063, 0.5640, 0.6621, 0.4976, 0.3679, 0.6323, 0.6865, 0.5977,
        0.3000, 0.3047, 0.6323, 0.6655, 0.4277, 0.4041, 0.5493, 0.2988, 0.6743,
        0.6714, 0.6450, 0.4080, 0.5122, 0.5225, 0.5054, 0.5459, 0.6309, 0.6602,
        0.6289, 0.4019, 0.4685, 0.6328, 0.6167, 0.4885, 0.4355, 0.4587, 0.4070,
        0.3401, 0.6138, 0.5596, 0.5151, 0.6582, 0.4746, 0.3870, 0.5957, 0.4414,
        0.6401, 0.4797, 0.4578, 0.6328, 0.6206, 0.4272, 0.6313, 0.4189, 0.6294,
        0.4653, 0.

Parameter containing:
tensor([ 2.4765e-02, -3.2349e-02,  3.3539e-02,  4.8187e-02,  2.9785e-02,
         4.7577e-02,  6.7139e-03,  3.2318e-02, -6.5369e-02, -9.9335e-03,
        -1.3710e-02, -2.4300e-03, -2.6047e-02, -2.5238e-02,  3.7109e-02,
         2.8229e-03,  2.8259e-02, -5.9113e-02,  1.3771e-03,  7.2632e-02,
        -2.5803e-02, -2.0569e-02, -4.1962e-02, -8.0200e-02,  3.2539e-03,
        -3.8940e-02, -3.3081e-02,  6.6338e-03,  3.5095e-02, -3.5645e-02,
         1.9272e-02,  5.9814e-02,  3.8788e-02, -6.0349e-03, -2.6398e-03,
        -8.3557e-02,  4.1199e-02, -2.7023e-02, -3.1097e-02,  5.7945e-03,
         3.8696e-02,  2.1648e-03,  7.4829e-02,  7.3792e-02, -3.9825e-02,
         1.3092e-02, -3.0045e-02, -2.9564e-03,  7.2975e-03, -2.7176e-02,
        -3.1433e-02, -1.3908e-02, -2.1896e-02, -4.4495e-02, -2.8793e-02,
         1.6541e-02,  4.6509e-02,  2.1744e-02, -3.4973e-02, -5.2414e-03,
         9.3079e-03,  2.1957e-02, -5.8479e-03, -4.5654e-02, -2.7985e-02,
         2.3682e-02,  3.4363e