# InPars - Finetuning

Author: Monique Monteiro (moniquelouise@gmail.com)

In [None]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [1]:
main_dir = "/content/gdrive/MyDrive/Unicamp-aula-9"

## Libraries installation

In [2]:
!pip install transformers -q

In [3]:
!pip install jsonlines -q

In [4]:
!pip install evaluate -q

In [5]:
!pip install trectools -q

In [11]:
import random
import torch
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from statistics import mean, stdev

Random seeds definition, to enable replication of results.

In [7]:
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

<torch._C.Generator at 0x7fa350879210>

## Dataset processing

In [None]:
dataset_file = f"{main_dir}/trec-covid/monique_monteiro_1000_queries.jsonl"

In [None]:
!head {dataset_file}

{"query": "What is the most suitable protein for a diagnostic approach for Salmonella Enteritidis and why?", "positive_doc_id": "m1cmkkw3", "negative_doc_ids": ["0o3mryu1", "qpq7i1ya", "j5mkparg", "auoo0dm5", "dqfvrerw"]}
{"query": "What is Cryptosporidium parvum and why is it a major cause of disease in both humans and animals?", "positive_doc_id": "ukbl0svm", "negative_doc_ids": ["k3u2nvpe", "gc11fyms", "xoq9qblv", "20o4ufa3", "3huo5nf0"]}
{"query": "What is the role of the renin-angiotensin-aldosterone system in the context of SARS-CoV-2 infection?", "positive_doc_id": "12o4zey2", "negative_doc_ids": ["6gd6nwpu", "dt4t2wos", "8zwfkken", "sv7xpi4f", "6pf73z08"]}
{"query": "What are the functions of individual endolysosomal proteases in cellular processes such as autophagy and lipoprotein particle degradation?", "positive_doc_id": "eqv6a7tj", "negative_doc_ids": ["gmrty2uu", "uzn214j6", "032utjfh", "efet3ozc", "0muwl6oc"]}
{"query": "What is the prevalence of olfactory dysfunction in 

In [8]:
import jsonlines

id_to_text = dict()

with jsonlines.open(f"{main_dir}/trec-covid/corpus.jsonl") as reader:
  for item in reader:
    id = item["_id"]
    text = item["title"] + ' ' + item["text"]
    id_to_text[id] = text

In [9]:
import json
import pandas as pd
from sklearn.model_selection import train_test_split

def generate_training_data(dataset_file, model_name):
  dataset = []
  dataset_ids = []
  i = 0

  with open(dataset_file, 'r') as f:
    for line in f:
      data = json.loads(line)
      query = data["query"]
      positive_doc_id = data["positive_doc_id"]
      negative_doc_ids = data["negative_doc_ids"]
      
      #Chooses a random negative document
      negative_doc_id = random.choice(negative_doc_ids)

      #Gets the documents texts
      positive_doc = id_to_text[positive_doc_id]
      negative_doc = id_to_text[negative_doc_id]

      dataset.append((query, positive_doc, negative_doc))
      dataset_ids.append((i, positive_doc_id, negative_doc_id))

      i+=1

  df = pd.DataFrame(dataset, columns=['query', 'pos', 'neg'])

  df_pos = pd.DataFrame()
  df_neg = pd.DataFrame()

  for index, row in df.iterrows():
    if model_name == 'microsoft/MiniLM-L12-H384-uncased':
      df_pos = df_pos.append({"query":row[0], "passage":row[1], "score":1.0}, 
                             ignore_index=True)
      df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":0.0}, 
                             ignore_index=True)
    elif model_name == 'cross-encoder/ms-marco-MiniLM-L-6-v2':
      df_pos = df_pos.append({"query":row[0], "passage":row[1], "score":True}, 
                             ignore_index=True)
      df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":False}, 
                             ignore_index=True)

  print(model_name)
  print(df_pos.head())
  X_train_pos = df_pos.drop("score", axis=1)
  Y_train_pos = df_pos["score"]

  X_train_pos, X_val_pos, Y_train_pos, Y_val_pos = train_test_split(X_train_pos, 
                                                                    Y_train_pos, 
                                                                    test_size=0.1, 
                                                                    random_state=42)

  X_train_neg = df_neg.drop("score", axis=1)
  Y_train_neg = df_neg["score"]

  X_train_neg, X_val_neg, Y_train_neg, Y_val_neg = train_test_split(X_train_neg, 
                                                                    Y_train_neg, 
                                                                    test_size=0.1, 
                                                                    random_state=42)

  X_train = pd.concat([X_train_pos, X_train_neg], axis=0, ignore_index=True)
  Y_train = pd.concat([Y_train_pos, Y_train_neg], axis=0, ignore_index=True)
  X_val = pd.concat([X_val_pos, X_val_neg], axis=0, ignore_index=True)
  Y_val = pd.concat([Y_val_pos, Y_val_neg], axis=0, ignore_index=True)

  return X_train, Y_train, X_val, Y_val


## Finetuning with microsoft/MiniLM-L12-H384-uncased

Data preparation, with tokenization and dataset/dataloaders construction.

In [None]:
model_name = 'microsoft/MiniLM-L12-H384-uncased'

In [None]:
X_train, Y_train, X_val, Y_val = generate_training_data(dataset_file, model_name)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [None]:


lengths = [len(tokens) for tokens in tokenizer(list(X_train["passage"])[:1_000])['input_ids']]
print(f'Mean length in tokens: {mean(lengths):0.2f}')
print(f'Stdev length in tokens: {stdev(lengths):0.2f}')

In [12]:
max_length = 512
max_length_query = int(max_length/4)
max_length_passage = int(3*max_length/4)
max_length_query, max_length_passage

(128, 384)

In [None]:
train_queries = list(X_train["query"])
train_passages = list(X_train["passage"])
val_queries = list(X_val["query"])
val_passages = list(X_val["passage"])

train_queries_tokenized = tokenizer(train_queries, truncation=True, max_length=max_length_query)
train_passages_tokenized = tokenizer(train_passages, truncation=True, max_length=max_length_passage)
val_queries_tokenized = tokenizer(val_queries, truncation=True, max_length=max_length_query)
val_passages_tokenized = tokenizer(val_passages, truncation=True, max_length=max_length_passage)

In [13]:
from torch.utils import data

class Dataset(data.Dataset):
    def __init__(self, queries, passages, targets):
        self.queries = queries
        self.passages = passages
        self.targets = targets
    
    def __len__(self):
        return len(self.queries['input_ids'])
    
    def __getitem__(self, idx):
        return {
            'input_ids': self.queries['input_ids'][idx] + self.passages['input_ids'][idx],
            'attention_mask': self.queries['attention_mask'][idx] + self.passages['attention_mask'][idx],
            #'labels': float(self.targets[idx])
            'labels': int(self.targets[idx])
        }


In [None]:
dataset_train = Dataset(train_queries_tokenized, train_passages_tokenized, Y_train)
assert len(dataset_train[0]['input_ids']) > 0
assert len(dataset_train[1]['attention_mask']) > 0

In [None]:
dataset_valid = Dataset(val_queries_tokenized, val_passages_tokenized, Y_val)

In [14]:
from transformers import BatchEncoding

# Tokens do tipo "pad" para textos com tamanho inferior ao máximo suportado.
def collate_fn(batch):
  return BatchEncoding(tokenizer.pad(batch, return_tensors='pt'))


In [None]:
dataloader_train = data.DataLoader(dataset_train, batch_size=32, shuffle=True, collate_fn=collate_fn)
dataloader_valid = data.DataLoader(dataset_valid, batch_size=32, shuffle=False, collate_fn=collate_fn)

for batch in dataloader_train:
    assert batch['input_ids'].shape[0] <= dataloader_train.batch_size
    assert batch['input_ids'].shape[1] <= max_length
    break

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


### Training Loop

In [15]:
from tqdm import tqdm 

def evaluate(model, dataloader, set_name, model_name):
    losses = []
    correct = 0
    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, mininterval=0.5, desc=set_name, disable=False):
            outputs = model(**batch.to(device))
            loss_val = outputs.loss
            losses.append(loss_val.cpu().item())
            preds = outputs.logits.argmax(dim=1)
            
            if model_name == 'cross-encoder/ms-marco-MiniLM-L-6-v2':
              preds  = (outputs.logits.view(-1)>0).float()
            
            correct += (preds == batch['labels']).sum().item()

    print(f'{set_name} loss: {mean(losses):0.3f}; {set_name} accuracy: {correct / len(dataloader.dataset):0.3f}')
    return correct / len(dataloader.dataset)
     

In [16]:
from torch import nn
from torch import optim
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [17]:
from transformers.optimization import get_constant_schedule
import os
import shutil

def train(model_name, epochs = 5, lr=5e-5):
  model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
  print('Parameters', model.num_parameters())

  optimizer = optim.AdamW(model.parameters(), lr)
  num_training_steps = epochs * len(dataloader_train)

  num_warmup_steps = int(num_training_steps * 0.1)
  scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, 
                                              num_training_steps)
  #scheduler = get_constant_schedule(optimizer)

  evaluate(model=model, dataloader=dataloader_valid, set_name='Valid', 
           model_name=model_name)
  best_acc = 0

  # Training loop
  for epoch in tqdm(range(epochs), desc='Epochs'):
      model.train()
      train_losses = []
      for batch in tqdm(dataloader_train, mininterval=0.5, desc='Train', 
                        disable=False):
          optimizer.zero_grad()
          outputs = model(**batch.to(device))
          loss = outputs.loss
          loss.backward()
          optimizer.step()
          scheduler.step()
          train_losses.append(loss.cpu().item())

      print(f'Epoch: {epoch + 1} Training loss: {mean(train_losses):0.2f}')
      acc = evaluate(model=model, dataloader=dataloader_valid, set_name='Valid', 
                     model_name=model_name)
      #Saves the best checkpoint
      if acc > best_acc:
        best_acc = acc
        if os.path.exists(f'{MODELS_PATH}/best_checkpoint'):
          shutil.rmtree(f'{MODELS_PATH}/best_checkpoint')
        model.save_pretrained(f'{MODELS_PATH}/best_checkpoint')
  
  return model

In [18]:
MODELS_PATH = '/content/gdrive/MyDrive/Unicamp-aula-9'

In [None]:
model = train(model_name, 5)
model_name = model_name.replace('/','_')
model.save_pretrained(f'{MODELS_PATH}/models_ranker_{model_name}')
tokenizer.save_pretrained(f'{MODELS_PATH}/tokenizer_ranker')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/MiniLM-L12-H384-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Parameters 33360770


Valid:   0%|          | 0/7 [00:00<?, ?it/s]

Valid loss: 0.697; Valid accuracy: 0.500


Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Train:   0%|          | 0/57 [00:00<?, ?it/s]

Epoch: 1 Training loss: 0.63


Valid:   0%|          | 0/7 [00:00<?, ?it/s]

Valid loss: 0.237; Valid accuracy: 0.950


Train:   0%|          | 0/57 [00:00<?, ?it/s]

Epoch: 2 Training loss: 0.15


Valid:   0%|          | 0/7 [00:00<?, ?it/s]

Valid loss: 0.087; Valid accuracy: 0.980


Train:   0%|          | 0/57 [00:00<?, ?it/s]

Epoch: 3 Training loss: 0.06


Valid:   0%|          | 0/7 [00:00<?, ?it/s]

Valid loss: 0.081; Valid accuracy: 0.980


Train:   0%|          | 0/57 [00:00<?, ?it/s]

Epoch: 4 Training loss: 0.04


Valid:   0%|          | 0/7 [00:00<?, ?it/s]

Valid loss: 0.088; Valid accuracy: 0.975


Train:   0%|          | 0/57 [00:00<?, ?it/s]

Epoch: 5 Training loss: 0.02


Valid:   0%|          | 0/7 [00:00<?, ?it/s]

Valid loss: 0.077; Valid accuracy: 0.980


('/content/gdrive/MyDrive/Unicamp-aula-9/tokenizer_ranker/tokenizer_config.json',
 '/content/gdrive/MyDrive/Unicamp-aula-9/tokenizer_ranker/special_tokens_map.json',
 '/content/gdrive/MyDrive/Unicamp-aula-9/tokenizer_ranker/vocab.txt',
 '/content/gdrive/MyDrive/Unicamp-aula-9/tokenizer_ranker/added_tokens.json',
 '/content/gdrive/MyDrive/Unicamp-aula-9/tokenizer_ranker/tokenizer.json')

### Reranking

Evaluation on TREC-COVID

In [None]:
id_to_query = dict()

with jsonlines.open(f"{main_dir}/trec-covid/queries.jsonl") as reader:
  for item in reader:
    id = item["_id"]
    text = item["text"]
    id_to_query[id] = text

In [None]:
import jsonlines

id_to_doc = dict()

with jsonlines.open(f"{main_dir}/trec-covid/corpus.jsonl") as reader:
  for item in reader:
    id = item["_id"]
    text = item["title"] + ' ' + item["text"]
    id_to_doc[id] = text

In [None]:
!head {main_dir}/trec-covid/run.trec-covid.bm25tuned.txt

1 Q0 dv9m19yk 1 4.158100 Anserini
1 Q0 kgifmjvb 2 3.338900 Anserini
1 Q0 wmfcey6f 3 3.338899 Anserini
1 Q0 safr9z37 4 3.220100 Anserini
1 Q0 0paafp5j 5 3.207300 Anserini
1 Q0 96zsd27n 6 3.207299 Anserini
1 Q0 4dtk1kyh 7 3.184800 Anserini
1 Q0 lhd0jn0z 8 2.903200 Anserini
1 Q0 55dihml5 9 2.899800 Anserini
1 Q0 qtx0d5f8 10 2.888800 Anserini


In [181]:
import pickle
import os

def tokenize_test_queries_and_passages():
  tokenized_queries = None
  tokenized_passages = None

  if os.path.exists(f"{main_dir}/trec-covid/tok_queries_test.pickle"):
    with open(f"{main_dir}/trec-covid/tok_queries_test.pickle", "rb") as f:
      print("Loading test queries...")
      tokenized_queries = pickle.load(f) 

  if os.path.exists(f"{main_dir}/trec-covid/tok_passages_test.pickle"):
    with open(f"{main_dir}/trec-covid/tok_passages_test.pickle", "rb") as f:
      print("Loading test passages...")
      tokenized_passages = pickle.load(f) 

  query_ids = []
  queries = []
  passage_ids = []
  passages = []

  with open(f'{main_dir}/trec-covid/run.trec-covid.bm25tuned.txt') as f:
    for line in f:
        fields = line.strip().split()
        query_id = fields[0]
        query_ids.append(query_id)
        passage_id = fields[2]
        passage_ids.append(passage_id)
        
        if not tokenized_queries:
          query_text = id_to_query[query_id]
          queries.append(query_text)

        if not tokenized_passages:
          passage_text = id_to_doc[passage_id]
          passages.append(passage_text)

  if not tokenized_queries:
    tokenized_queries = tokenizer(queries, max_length=max_length_query, truncation=True)

    with open(f"{main_dir}/trec-covid/tok_queries_test.pickle", 'wb') as f:
      pickle.dump(tokenized_queries, f)

  if not tokenized_passages:
    tokenized_passages = tokenizer(passages, max_length=max_length_passage, truncation=True)

    with open(f"{main_dir}/trec-covid/tok_passages_test.pickle", 'wb') as f:
      pickle.dump(tokenized_passages, f)

  return tokenized_queries, tokenized_passages

In [None]:
tokenized_queries, tokenized_passages = tokenize_test_queries_and_passages()
dataset_test = Dataset(tokenized_queries, tokenized_passages, [1]*len(tokenized_queries['input_ids']))

In [None]:
dataloader_test = data.DataLoader(dataset_test, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [None]:

def evaluate_test_dataset(model, dataloader, set_name, use_logits=False):
    scores = []
    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, mininterval=0.5, desc=set_name, disable=False):
            outputs = model(**batch.to(device))
            if use_logits:
              # Usa os logits brutos
              pos_score = outputs.logits[:,1]
            else:
              # Usa os logits normalizados pelo softmax (por default)
              pos_score = torch.softmax(outputs.logits,1)[:,1]
            scores = scores + pos_score.tolist()
    return scores

In [162]:
def evaluate_ndcg_10(scores, model_name, eval_desc):
  zipped_results = []

  #Por alguma razão misteriosa, o zip do Python não funcionou, deixou a lista 
  #vazia ou impossível de ser iterada.
  for i, query_id in enumerate(query_ids):
    zipped_results.append((query_id, passage_ids[i], scores[i]))

  #Quebra a lista em sublistas por query
  prev_query_id = -1
  sublists = []
  current_list = []

  for query_id, passage_id, score in zipped_results:
    if query_id != prev_query_id:
      if len(current_list) > 0:
        sublists.append(current_list)
        current_list = []
    current_list.append((query_id, passage_id, score))
    prev_query_id = query_id

  if len(current_list) > 0:
    sublists.append(current_list)

  # Ordena cada sublista
  sorted_list = []

  for sublist in sublists:
    sorted_sublist = sorted(sublist, key=lambda x: x[2], reverse=True)
    sorted_list += sorted_sublist

  # Gera o arquivo de run no formato TREC
  trec_run_file = f"{main_dir}/trec-covid/run.trec-covid.bert_reranked_{model_name.replace('/', '_')}_{eval_desc}.trec"
  with open(trec_run_file, "w") as f:
    for i, (query_id, passage_id, score) in enumerate(sorted_list):
      f.write(f'{query_id}\t{passage_id}\t{i+1}\t{score}\tbert_reranked_{model_name}\n')

  return trec_run_file

In [None]:

softmax_scores = evaluate_test_dataset(model=model, dataloader=dataloader_test, set_name='Test')
logit_scores = evaluate_test_dataset(model=model, dataloader=dataloader_test, set_name='Test', use_logits=True)

Test:   0%|          | 0/1563 [00:00<?, ?it/s]

Test:   0%|          | 0/1563 [00:00<?, ?it/s]

In [None]:

trec_file_softmax = evaluate_ndcg_10(softmax_scores, model_name, "softmax")
trec_file_logits = evaluate_ndcg_10(logit_scores, model_name, "logits")

In [None]:
!head {trec_file_softmax}

1	ews32six	1	0.9927167296409607	bert_reranked_microsoft_MiniLM-L12-H384-uncased
1	k4l70wf6	2	0.9926959276199341	bert_reranked_microsoft_MiniLM-L12-H384-uncased
1	fraczoxu	3	0.9926958084106445	bert_reranked_microsoft_MiniLM-L12-H384-uncased
1	zpb9jicw	4	0.9926846027374268	bert_reranked_microsoft_MiniLM-L12-H384-uncased
1	dt0b87me	5	0.9926679730415344	bert_reranked_microsoft_MiniLM-L12-H384-uncased
1	8f5m0gej	6	0.9926672577857971	bert_reranked_microsoft_MiniLM-L12-H384-uncased
1	67fyu0i5	7	0.9926584362983704	bert_reranked_microsoft_MiniLM-L12-H384-uncased
1	o7kl258h	8	0.9926577210426331	bert_reranked_microsoft_MiniLM-L12-H384-uncased
1	1ew0p6x7	9	0.9926537275314331	bert_reranked_microsoft_MiniLM-L12-H384-uncased
1	2cwvga0k	10	0.9926504492759705	bert_reranked_microsoft_MiniLM-L12-H384-uncased


In [None]:
import pandas as pd

qrel = pd.read_csv(f"{main_dir}/trec-covid/test.tsv", sep="\t", header=None, 
                   skiprows=1, names=["query", "docid", "rel"])
qrel["q0"] = "q0"
qrel = qrel.to_dict(orient="list")

In [None]:
!head {main_dir}/trec-covid/test.tsv

query-id	corpus-id	score
1	005b2j4b	2
1	00fmeepz	1
1	g7dhmyyo	2
1	0194oljo	1
1	021q9884	1
1	02f0opkr	1
1	047xpt2c	0
1	04ftw7k9	0
1	pl9ht0d0	0


In [None]:
from evaluate import load

def eval_ndcg10(run):
  trec_eval = load("trec_eval")
  results = trec_eval.compute(predictions=[run], references=[qrel])
  return results['NDCG@10'] 

In [None]:
import pandas as pd

run_trec_file_softmax = pd.read_csv(trec_file_softmax, sep="\t", header=None, 
                   skiprows=1, names=["query", "docid", "rank", "score", "system"])
run_trec_file_softmax["q0"] = "q0"
run_trec_file_softmax = run_trec_file_softmax.to_dict(orient="list")

In [None]:
import pandas as pd

run_trec_file_logits = pd.read_csv(trec_file_logits, sep="\t", header=None, 
                   skiprows=1, names=["query", "docid", "rank", "score", "system"])
run_trec_file_logits["q0"] = "q0"
run_trec_file_logits = run_trec_file_logits.to_dict(orient="list")

In [None]:
eval_ndcg10(run_trec_file_softmax)

0.6110263574494792

In [None]:
eval_ndcg10(run_trec_file_logits)

0.6125945845401235

## Finetuning with cross-encoder/ms-marco-MiniLM-L-6-v2

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils import data
from transformers import BatchEncoding

In [None]:
model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"


In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:


class MSMARCODataset(data.Dataset):
    def __init__(self, tokenizer, query, passages, targets, max_lenght=356):
        self.tokenizer = tokenizer
        self.query = query
        self.passages = passages
        self.targets = targets
        self.max_lenght = max_lenght
    
    def __len__(self):
        return len(self.query)

    def __getitem__(self, idx):
        instruction_token = self.tokenizer(self.query[idx],self.passages[idx],
                                           max_length=self.max_lenght, 
                                           truncation=True,
                                           padding="max_length", 
                                           return_tensors='pt')


        return {'input_ids':torch.squeeze(instruction_token['input_ids']).long().to(device),\
               'attention_mask':torch.squeeze(instruction_token['attention_mask']).long().to(device), \
               'labels':torch.tensor(self.targets[idx], dtype=torch.float16)}
              
def collate_fn(batch):
    return BatchEncoding(tokenizer.pad(batch, return_tensors='pt'))
    

In [None]:
train_queries = list(X_train["query"])
train_passages = list(X_train["passage"])
val_queries = list(X_val["query"])
val_passages = list(X_val["passage"])

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
dataset_train = MSMARCODataset(tokenizer, train_queries, train_passages, Y_train)
assert len(dataset_train[0]['input_ids']) > 0
assert len(dataset_train[1]['attention_mask']) > 0

In [None]:
dataset_val = MSMARCODataset(tokenizer, val_queries, val_passages, Y_val)
assert len(dataset_val[0]['input_ids']) > 0
assert len(dataset_val[1]['attention_mask']) > 0

In [None]:
dataloader_train = data.DataLoader(dataset_train, batch_size=32, shuffle=True, collate_fn=collate_fn)
dataloader_valid = data.DataLoader(dataset_val, batch_size=32, shuffle=False, collate_fn=collate_fn)


In [None]:
model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"


In [None]:
model = train(model_name, 40, lr=1e-3)

Parameters 22713601


Valid: 100%|██████████| 7/7 [00:01<00:00,  4.64it/s]


Valid loss: 35.093; Valid accuracy: 0.510


Epochs:   0%|          | 0/40 [00:00<?, ?it/s]
Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:22,  2.46it/s][A
Train:   7%|▋         | 4/57 [00:02<00:32,  1.64it/s][A
Train:   9%|▉         | 5/57 [00:02<00:30,  1.69it/s][A
Train:  12%|█▏        | 7/57 [00:03<00:25,  1.98it/s][A
Train:  16%|█▌        | 9/57 [00:04<00:22,  2.14it/s][A
Train:  19%|█▉        | 11/57 [00:05<00:20,  2.23it/s][A
Train:  23%|██▎       | 13/57 [00:06<00:20,  2.17it/s][A
Train:  26%|██▋       | 15/57 [00:07<00:18,  2.24it/s][A
Train:  30%|██▉       | 17/57 [00:08<00:20,  1.99it/s][A
Train:  33%|███▎      | 19/57 [00:09<00:18,  2.02it/s][A
Train:  37%|███▋      | 21/57 [00:10<00:17,  2.02it/s][A
Train:  40%|████      | 23/57 [00:11<00:16,  2.04it/s][A
Train:  44%|████▍     | 25/57 [00:12<00:15,  2.09it/s][A
Train:  47%|████▋     | 27/57 [00:13<00:13,  2.17it/s][A
Train:  51%|█████     | 29/57 [00:14<00:12,  2.24it/s][A
Train:  54%|█████▍    | 31/57 [00:15<00

Epoch: 1 Training loss: 6.17



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.82it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  6.10it/s]


Valid loss: 0.345; Valid accuracy: 0.505


Epochs:   2%|▎         | 1/40 [00:29<18:59, 29.22s/it]
Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:23,  2.39it/s][A
Train:   7%|▋         | 4/57 [00:02<00:28,  1.87it/s][A
Train:  11%|█         | 6/57 [00:03<00:24,  2.08it/s][A
Train:  14%|█▍        | 8/57 [00:04<00:26,  1.83it/s][A
Train:  18%|█▊        | 10/57 [00:05<00:23,  1.99it/s][A
Train:  21%|██        | 12/57 [00:05<00:21,  2.07it/s][A
Train:  25%|██▍       | 14/57 [00:07<00:20,  2.15it/s][A
Train:  28%|██▊       | 16/57 [00:08<00:21,  1.91it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:19,  2.00it/s][A
Train:  35%|███▌      | 20/57 [00:10<00:17,  2.11it/s][A
Train:  39%|███▊      | 22/57 [00:11<00:18,  1.88it/s][A
Train:  40%|████      | 23/57 [00:11<00:18,  1.88it/s][A
Train:  44%|████▍     | 25/57 [00:13<00:15,  2.02it/s][A
Train:  47%|████▋     | 27/57 [00:14<00:18,  1.64it/s][A
Train:  51%|█████     | 29/57 [00:15<00:15,  1.82it/s][A
Train:  54%|█████▍    | 31/57 

Epoch: 2 Training loss: 0.20



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:01<00:00,  5.82it/s][A
Valid: 100%|██████████| 7/7 [00:02<00:00,  3.13it/s]


Valid loss: 0.369; Valid accuracy: 0.570


Epochs:   5%|▌         | 2/40 [00:59<19:00, 30.01s/it]
Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.41it/s][A
Train:   7%|▋         | 4/57 [00:01<00:23,  2.23it/s][A
Train:  11%|█         | 6/57 [00:02<00:23,  2.16it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:21,  2.25it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:23,  2.00it/s][A
Train:  21%|██        | 12/57 [00:05<00:21,  2.09it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:19,  2.18it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:21,  1.93it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:18,  2.07it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:17,  2.17it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:18,  1.89it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:16,  1.98it/s][A
Train:  46%|████▌     | 26/57 [00:12<00:14,  2.11it/s][A
Train:  49%|████▉     | 28/57 [00:13<00:15,  1.90it/s][A
Train:  53%|█████▎    | 30/57 [00:14<00:13,  1.99it/s][A
Train:  56%|█████▌    | 32/57 

Epoch: 3 Training loss: 0.14



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.73it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  4.93it/s]


Valid loss: 0.137; Valid accuracy: 0.865


Epochs:   8%|▊         | 3/40 [01:30<18:38, 30.23s/it]
Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:24,  2.27it/s][A
Train:   7%|▋         | 4/57 [00:02<00:22,  2.34it/s][A
Train:  11%|█         | 6/57 [00:03<00:27,  1.88it/s][A
Train:  12%|█▏        | 7/57 [00:03<00:26,  1.88it/s][A
Train:  16%|█▌        | 9/57 [00:04<00:23,  2.07it/s][A
Train:  19%|█▉        | 11/57 [00:05<00:21,  2.19it/s][A
Train:  23%|██▎       | 13/57 [00:06<00:19,  2.26it/s][A
Train:  26%|██▋       | 15/57 [00:06<00:19,  2.21it/s][A
Train:  30%|██▉       | 17/57 [00:07<00:18,  2.21it/s][A
Train:  33%|███▎      | 19/57 [00:09<00:16,  2.27it/s][A
Train:  37%|███▋      | 21/57 [00:09<00:17,  2.00it/s][A
Train:  40%|████      | 23/57 [00:10<00:16,  2.08it/s][A
Train:  44%|████▍     | 25/57 [00:12<00:14,  2.17it/s][A
Train:  47%|████▋     | 27/57 [00:13<00:15,  1.92it/s][A
Train:  51%|█████     | 29/57 [00:13<00:13,  2.01it/s][A
Train:  54%|█████▍    | 31/57 [

Epoch: 4 Training loss: 0.10



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.79it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  4.61it/s]
Epochs:  10%|█         | 4/40 [02:00<18:01, 30.05s/it]

Valid loss: 0.130; Valid accuracy: 0.830



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:22,  2.43it/s][A
Train:   7%|▋         | 4/57 [00:02<00:31,  1.70it/s][A
Train:  11%|█         | 6/57 [00:03<00:26,  1.92it/s][A
Train:  14%|█▍        | 8/57 [00:04<00:23,  2.09it/s][A
Train:  18%|█▊        | 10/57 [00:05<00:25,  1.85it/s][A
Train:  21%|██        | 12/57 [00:06<00:22,  1.96it/s][A
Train:  25%|██▍       | 14/57 [00:07<00:20,  2.08it/s][A
Train:  28%|██▊       | 16/57 [00:08<00:21,  1.90it/s][A
Train:  32%|███▏      | 18/57 [00:09<00:19,  1.99it/s][A
Train:  35%|███▌      | 20/57 [00:10<00:17,  2.10it/s][A
Train:  39%|███▊      | 22/57 [00:11<00:18,  1.90it/s][A
Train:  42%|████▏     | 24/57 [00:12<00:16,  2.00it/s][A
Train:  46%|████▌     | 26/57 [00:13<00:14,  2.11it/s][A
Train:  49%|████▉     | 28/57 [00:14<00:15,  1.90it/s][A
Train:  53%|█████▎    | 30/57 [00:15<00:13,  1.99it/s][A
Train:  56%|█████▌    | 32/57 [00:16<00:11,  2.10it/s][A
Train:  60%|█████▉    | 34

Epoch: 5 Training loss: 0.06



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.86it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  4.11it/s]
Epochs:  12%|█▎        | 5/40 [02:29<17:24, 29.85s/it]

Valid loss: 0.255; Valid accuracy: 0.645



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.45it/s][A
Train:   7%|▋         | 4/57 [00:01<00:23,  2.30it/s][A
Train:  11%|█         | 6/57 [00:02<00:21,  2.34it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:24,  2.04it/s][A
Train:  18%|█▊        | 10/57 [00:05<00:21,  2.16it/s][A
Train:  21%|██        | 12/57 [00:05<00:23,  1.90it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:21,  2.04it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:19,  2.11it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:17,  2.19it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:19,  1.93it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:17,  2.02it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:15,  2.12it/s][A
Train:  46%|████▌     | 26/57 [00:12<00:16,  1.91it/s][A
Train:  49%|████▉     | 28/57 [00:13<00:14,  2.00it/s][A
Train:  53%|█████▎    | 30/57 [00:14<00:12,  2.11it/s][A
Train:  56%|█████▌    | 32/57 [00:15<00:13,  1.91it/s][A
Train:  60%|█████▉    | 34

Epoch: 6 Training loss: 0.05



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.81it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  5.30it/s]
Epochs:  15%|█▌        | 6/40 [02:58<16:47, 29.63s/it]

Valid loss: 0.128; Valid accuracy: 0.740



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.45it/s][A
Train:   7%|▋         | 4/57 [00:01<00:21,  2.43it/s][A
Train:  11%|█         | 6/57 [00:02<00:22,  2.27it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:22,  2.18it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:22,  2.11it/s][A
Train:  21%|██        | 12/57 [00:05<00:20,  2.21it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:20,  2.05it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:19,  2.15it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:19,  1.99it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:17,  2.06it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:16,  2.15it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:17,  1.92it/s][A
Train:  44%|████▍     | 25/57 [00:12<00:16,  1.91it/s][A
Train:  47%|████▋     | 27/57 [00:12<00:14,  2.06it/s][A
Train:  51%|█████     | 29/57 [00:13<00:13,  2.15it/s][A
Train:  54%|█████▍    | 31/57 [00:14<00:11,  2.22it/s][A
Train:  58%|█████▊    | 33

Epoch: 7 Training loss: 0.03



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.34it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  4.60it/s]
Epochs:  18%|█▊        | 7/40 [03:28<16:13, 29.51s/it]

Valid loss: 0.128; Valid accuracy: 0.785



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.41it/s][A
Train:   7%|▋         | 4/57 [00:01<00:22,  2.38it/s][A
Train:  11%|█         | 6/57 [00:02<00:23,  2.20it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:21,  2.26it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:20,  2.32it/s][A
Train:  21%|██        | 12/57 [00:05<00:19,  2.30it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:18,  2.34it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:20,  2.05it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:18,  2.10it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:16,  2.19it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:17,  1.94it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:16,  2.00it/s][A
Train:  46%|████▌     | 26/57 [00:12<00:14,  2.11it/s][A
Train:  49%|████▉     | 28/57 [00:13<00:15,  1.91it/s][A
Train:  51%|█████     | 29/57 [00:13<00:14,  1.91it/s][A
Train:  54%|█████▍    | 31/57 [00:14<00:12,  2.05it/s][A
Train:  58%|█████▊    | 33

Epoch: 8 Training loss: 0.02



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.82it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  4.13it/s]
Epochs:  20%|██        | 8/40 [03:56<15:34, 29.20s/it]

Valid loss: 0.108; Valid accuracy: 0.725



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.44it/s][A
Train:   7%|▋         | 4/57 [00:01<00:22,  2.33it/s][A
Train:  11%|█         | 6/57 [00:02<00:21,  2.38it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:23,  2.06it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:21,  2.15it/s][A
Train:  21%|██        | 12/57 [00:05<00:20,  2.24it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:22,  1.93it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:20,  2.02it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:18,  2.13it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:19,  1.90it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:17,  2.00it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:15,  2.11it/s][A
Train:  46%|████▌     | 26/57 [00:12<00:16,  1.91it/s][A
Train:  49%|████▉     | 28/57 [00:13<00:14,  2.01it/s][A
Train:  53%|█████▎    | 30/57 [00:14<00:12,  2.12it/s][A
Train:  56%|█████▌    | 32/57 [00:15<00:13,  1.90it/s][A
Train:  60%|█████▉    | 34

Epoch: 9 Training loss: 0.02



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.80it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  6.51it/s]
Epochs:  22%|██▎       | 9/40 [04:25<15:03, 29.13s/it]

Valid loss: 0.093; Valid accuracy: 0.685



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.43it/s][A
Train:   7%|▋         | 4/57 [00:01<00:24,  2.19it/s][A
Train:  11%|█         | 6/57 [00:02<00:23,  2.14it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:21,  2.23it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:23,  1.97it/s][A
Train:  21%|██        | 12/57 [00:05<00:22,  2.01it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:21,  2.04it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:19,  2.14it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:19,  1.97it/s][A
Train:  35%|███▌      | 20/57 [00:10<00:17,  2.09it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:18,  1.88it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:16,  2.02it/s][A
Train:  46%|████▌     | 26/57 [00:12<00:14,  2.13it/s][A
Train:  49%|████▉     | 28/57 [00:13<00:13,  2.21it/s][A
Train:  53%|█████▎    | 30/57 [00:14<00:14,  1.90it/s][A
Train:  56%|█████▌    | 32/57 [00:15<00:12,  2.00it/s][A
Train:  60%|█████▉    | 34

Epoch: 10 Training loss: 0.01



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:01<00:00,  5.81it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  4.06it/s]
Epochs:  25%|██▌       | 10/40 [04:55<14:37, 29.25s/it]

Valid loss: 0.089; Valid accuracy: 0.670



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.41it/s][A
Train:   7%|▋         | 4/57 [00:01<00:21,  2.42it/s][A
Train:  11%|█         | 6/57 [00:02<00:21,  2.41it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:20,  2.41it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:20,  2.34it/s][A
Train:  21%|██        | 12/57 [00:05<00:19,  2.35it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:20,  2.07it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:19,  2.07it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:17,  2.17it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:19,  1.95it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:17,  2.01it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:15,  2.11it/s][A
Train:  46%|████▌     | 26/57 [00:12<00:16,  1.93it/s][A
Train:  47%|████▋     | 27/57 [00:12<00:15,  1.93it/s][A
Train:  51%|█████     | 29/57 [00:13<00:13,  2.07it/s][A
Train:  54%|█████▍    | 31/57 [00:14<00:11,  2.17it/s][A
Train:  58%|█████▊    | 33

Epoch: 11 Training loss: 0.01



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.76it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  6.41it/s]
Epochs:  28%|██▊       | 11/40 [05:23<13:57, 28.87s/it]

Valid loss: 0.070; Valid accuracy: 0.660



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:23,  2.38it/s][A
Train:   7%|▋         | 4/57 [00:02<00:22,  2.38it/s][A
Train:  11%|█         | 6/57 [00:03<00:30,  1.68it/s][A
Train:  12%|█▏        | 7/57 [00:04<00:28,  1.73it/s][A
Train:  14%|█▍        | 8/57 [00:04<00:33,  1.48it/s][A
Train:  18%|█▊        | 10/57 [00:06<00:26,  1.77it/s][A
Train:  19%|█▉        | 11/57 [00:07<00:33,  1.38it/s][A
Train:  21%|██        | 12/57 [00:07<00:35,  1.28it/s][A
Train:  25%|██▍       | 14/57 [00:09<00:26,  1.60it/s][A
Train:  26%|██▋       | 15/57 [00:09<00:31,  1.34it/s][A
Train:  30%|██▉       | 17/57 [00:10<00:24,  1.62it/s][A
Train:  33%|███▎      | 19/57 [00:11<00:20,  1.84it/s][A
Train:  37%|███▋      | 21/57 [00:12<00:17,  2.01it/s][A
Train:  40%|████      | 23/57 [00:13<00:16,  2.04it/s][A
Train:  44%|████▍     | 25/57 [00:14<00:14,  2.14it/s][A
Train:  47%|████▋     | 27/57 [00:15<00:15,  1.95it/s][A
Train:  51%|█████     | 29/

Epoch: 12 Training loss: 0.01



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:01<00:00,  5.74it/s][A
Valid: 100%|██████████| 7/7 [00:02<00:00,  3.11it/s]
Epochs:  30%|███       | 12/40 [05:54<13:50, 29.66s/it]

Valid loss: 0.067; Valid accuracy: 0.645



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.46it/s][A
Train:   7%|▋         | 4/57 [00:01<00:21,  2.46it/s][A
Train:  11%|█         | 6/57 [00:02<00:24,  2.07it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:22,  2.13it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:21,  2.23it/s][A
Train:  21%|██        | 12/57 [00:05<00:23,  1.94it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:21,  2.03it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:19,  2.15it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:20,  1.92it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:18,  2.01it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:16,  2.11it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:17,  1.90it/s][A
Train:  44%|████▍     | 25/57 [00:12<00:16,  1.89it/s][A
Train:  47%|████▋     | 27/57 [00:13<00:14,  2.03it/s][A
Train:  51%|█████     | 29/57 [00:14<00:13,  2.15it/s][A
Train:  54%|█████▍    | 31/57 [00:14<00:11,  2.22it/s][A
Train:  58%|█████▊    | 33

Epoch: 13 Training loss: 0.01



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.86it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  5.09it/s]
Epochs:  32%|███▎      | 13/40 [06:23<13:14, 29.42s/it]

Valid loss: 0.065; Valid accuracy: 0.805



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:22,  2.41it/s][A
Train:   7%|▋         | 4/57 [00:01<00:25,  2.05it/s][A
Train:  11%|█         | 6/57 [00:02<00:24,  2.12it/s][A
Train:  14%|█▍        | 8/57 [00:04<00:21,  2.23it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:24,  1.94it/s][A
Train:  21%|██        | 12/57 [00:05<00:22,  2.03it/s][A
Train:  25%|██▍       | 14/57 [00:07<00:20,  2.14it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:21,  1.92it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:19,  1.99it/s][A
Train:  35%|███▌      | 20/57 [00:10<00:17,  2.10it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:18,  1.92it/s][A
Train:  42%|████▏     | 24/57 [00:12<00:16,  2.04it/s][A
Train:  46%|████▌     | 26/57 [00:12<00:15,  1.96it/s][A
Train:  49%|████▉     | 28/57 [00:13<00:13,  2.08it/s][A
Train:  53%|█████▎    | 30/57 [00:14<00:12,  2.17it/s][A
Train:  56%|█████▌    | 32/57 [00:16<00:11,  2.23it/s][A
Train:  60%|█████▉    | 34

Epoch: 14 Training loss: 0.01



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.85it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  4.31it/s]
Epochs:  35%|███▌      | 14/40 [06:52<12:42, 29.33s/it]

Valid loss: 0.064; Valid accuracy: 0.755



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.41it/s][A
Train:   7%|▋         | 4/57 [00:01<00:23,  2.30it/s][A
Train:  11%|█         | 6/57 [00:02<00:21,  2.35it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:24,  1.99it/s][A
Train:  18%|█▊        | 10/57 [00:05<00:22,  2.12it/s][A
Train:  21%|██        | 12/57 [00:05<00:23,  1.88it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:21,  2.02it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:19,  2.08it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:17,  2.18it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:19,  1.94it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:17,  2.02it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:15,  2.13it/s][A
Train:  46%|████▌     | 26/57 [00:12<00:16,  1.92it/s][A
Train:  49%|████▉     | 28/57 [00:13<00:14,  2.05it/s][A
Train:  53%|█████▎    | 30/57 [00:14<00:12,  2.15it/s][A
Train:  56%|█████▌    | 32/57 [00:15<00:13,  1.89it/s][A
Train:  60%|█████▉    | 34

Epoch: 15 Training loss: 0.01



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.73it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  4.99it/s]
Epochs:  38%|███▊      | 15/40 [07:21<12:13, 29.34s/it]

Valid loss: 0.072; Valid accuracy: 0.755



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:23,  2.39it/s][A
Train:   7%|▋         | 4/57 [00:02<00:34,  1.52it/s][A
Train:   9%|▉         | 5/57 [00:03<00:32,  1.62it/s][A
Train:  11%|█         | 6/57 [00:03<00:36,  1.39it/s][A
Train:  14%|█▍        | 8/57 [00:05<00:28,  1.72it/s][A
Train:  16%|█▌        | 9/57 [00:06<00:36,  1.33it/s][A
Train:  18%|█▊        | 10/57 [00:07<00:37,  1.25it/s][A
Train:  19%|█▉        | 11/57 [00:07<00:38,  1.18it/s][A
Train:  23%|██▎       | 13/57 [00:08<00:28,  1.52it/s][A
Train:  25%|██▍       | 14/57 [00:09<00:27,  1.59it/s][A
Train:  28%|██▊       | 16/57 [00:10<00:22,  1.84it/s][A
Train:  30%|██▉       | 17/57 [00:11<00:24,  1.65it/s][A
Train:  32%|███▏      | 18/57 [00:11<00:26,  1.45it/s][A
Train:  35%|███▌      | 20/57 [00:12<00:21,  1.75it/s][A
Train:  37%|███▋      | 21/57 [00:13<00:20,  1.78it/s][A
Train:  40%|████      | 23/57 [00:14<00:17,  1.98it/s][A
Train:  42%|████▏     | 24/5

Epoch: 16 Training loss: 0.01



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.75it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  5.17it/s]
Epochs:  40%|████      | 16/40 [07:56<12:21, 30.91s/it]

Valid loss: 0.154; Valid accuracy: 0.760



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:22,  2.40it/s][A
Train:   7%|▋         | 4/57 [00:01<00:25,  2.12it/s][A
Train:  11%|█         | 6/57 [00:02<00:23,  2.16it/s][A
Train:  14%|█▍        | 8/57 [00:04<00:21,  2.26it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:24,  1.95it/s][A
Train:  21%|██        | 12/57 [00:05<00:22,  1.99it/s][A
Train:  25%|██▍       | 14/57 [00:07<00:20,  2.10it/s][A
Train:  28%|██▊       | 16/57 [00:08<00:21,  1.91it/s][A
Train:  30%|██▉       | 17/57 [00:09<00:21,  1.88it/s][A
Train:  32%|███▏      | 18/57 [00:09<00:24,  1.62it/s][A
Train:  35%|███▌      | 20/57 [00:11<00:20,  1.85it/s][A
Train:  37%|███▋      | 21/57 [00:12<00:24,  1.46it/s][A
Train:  39%|███▊      | 22/57 [00:12<00:26,  1.34it/s][A
Train:  42%|████▏     | 24/57 [00:14<00:20,  1.62it/s][A
Train:  44%|████▍     | 25/57 [00:14<00:23,  1.36it/s][A
Train:  47%|████▋     | 27/57 [00:15<00:18,  1.64it/s][A
Train:  49%|████▉     | 28

Epoch: 17 Training loss: 0.01



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.79it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  4.37it/s]
Epochs:  42%|████▎     | 17/40 [08:30<12:12, 31.85s/it]

Valid loss: 0.077; Valid accuracy: 0.630



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:23,  2.39it/s][A
Train:   7%|▋         | 4/57 [00:01<00:23,  2.30it/s][A
Train:  11%|█         | 6/57 [00:02<00:21,  2.33it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:20,  2.36it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:24,  1.91it/s][A
Train:  21%|██        | 12/57 [00:05<00:22,  2.01it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:20,  2.13it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:21,  1.91it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:19,  2.03it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:17,  2.13it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:18,  1.89it/s][A
Train:  40%|████      | 23/57 [00:11<00:18,  1.89it/s][A
Train:  44%|████▍     | 25/57 [00:12<00:15,  2.04it/s][A
Train:  47%|████▋     | 27/57 [00:13<00:14,  2.08it/s][A
Train:  51%|█████     | 29/57 [00:14<00:12,  2.17it/s][A
Train:  54%|█████▍    | 31/57 [00:14<00:11,  2.17it/s][A
Train:  58%|█████▊    | 33

Epoch: 18 Training loss: 0.01



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.74it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  5.01it/s]
Epochs:  45%|████▌     | 18/40 [08:59<11:25, 31.15s/it]

Valid loss: 0.065; Valid accuracy: 0.640



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.41it/s][A
Train:   7%|▋         | 4/57 [00:01<00:22,  2.39it/s][A
Train:  11%|█         | 6/57 [00:02<00:21,  2.39it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:20,  2.40it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:19,  2.38it/s][A
Train:  21%|██        | 12/57 [00:05<00:18,  2.38it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:21,  2.03it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:19,  2.11it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:17,  2.19it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:19,  1.92it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:17,  1.98it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:15,  2.08it/s][A
Train:  46%|████▌     | 26/57 [00:12<00:16,  1.93it/s][A
Train:  47%|████▋     | 27/57 [00:12<00:15,  1.92it/s][A
Train:  51%|█████     | 29/57 [00:13<00:13,  2.06it/s][A
Train:  54%|█████▍    | 31/57 [00:14<00:12,  2.16it/s][A
Train:  58%|█████▊    | 33

Epoch: 19 Training loss: 0.01



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.82it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  5.05it/s]
Epochs:  48%|████▊     | 19/40 [09:28<10:39, 30.47s/it]

Valid loss: 0.048; Valid accuracy: 0.755



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.43it/s][A
Train:   7%|▋         | 4/57 [00:01<00:21,  2.42it/s][A
Train:  11%|█         | 6/57 [00:02<00:21,  2.42it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:21,  2.32it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:19,  2.35it/s][A
Train:  21%|██        | 12/57 [00:05<00:21,  2.05it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:20,  2.07it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:18,  2.17it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:19,  1.96it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:18,  1.99it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:16,  2.10it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:17,  1.93it/s][A
Train:  44%|████▍     | 25/57 [00:12<00:16,  1.93it/s][A
Train:  47%|████▋     | 27/57 [00:12<00:14,  2.07it/s][A
Train:  51%|█████     | 29/57 [00:13<00:12,  2.18it/s][A
Train:  54%|█████▍    | 31/57 [00:14<00:11,  2.24it/s][A
Train:  58%|█████▊    | 33

Epoch: 20 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.88it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  5.08it/s]
Epochs:  50%|█████     | 20/40 [09:57<09:57, 29.88s/it]

Valid loss: 0.045; Valid accuracy: 0.755



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:22,  2.40it/s][A
Train:   7%|▋         | 4/57 [00:02<00:26,  2.03it/s][A
Train:  11%|█         | 6/57 [00:02<00:24,  2.07it/s][A
Train:  14%|█▍        | 8/57 [00:04<00:22,  2.20it/s][A
Train:  18%|█▊        | 10/57 [00:05<00:24,  1.93it/s][A
Train:  19%|█▉        | 11/57 [00:05<00:24,  1.91it/s][A
Train:  23%|██▎       | 13/57 [00:06<00:21,  2.07it/s][A
Train:  26%|██▋       | 15/57 [00:07<00:19,  2.17it/s][A
Train:  30%|██▉       | 17/57 [00:08<00:17,  2.23it/s][A
Train:  33%|███▎      | 19/57 [00:09<00:17,  2.20it/s][A
Train:  37%|███▋      | 21/57 [00:09<00:16,  2.19it/s][A
Train:  40%|████      | 23/57 [00:11<00:15,  2.25it/s][A
Train:  44%|████▍     | 25/57 [00:12<00:15,  2.01it/s][A
Train:  47%|████▋     | 27/57 [00:12<00:14,  2.03it/s][A
Train:  51%|█████     | 29/57 [00:14<00:13,  2.13it/s][A
Train:  54%|█████▍    | 31/57 [00:14<00:13,  1.98it/s][A
Train:  58%|█████▊    | 33

Epoch: 21 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.80it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  6.48it/s]
Epochs:  52%|█████▎    | 21/40 [10:26<09:23, 29.65s/it]

Valid loss: 0.052; Valid accuracy: 0.570



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.41it/s][A
Train:   7%|▋         | 4/57 [00:01<00:24,  2.18it/s][A
Train:  11%|█         | 6/57 [00:02<00:23,  2.15it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:21,  2.23it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:23,  2.03it/s][A
Train:  21%|██        | 12/57 [00:05<00:20,  2.15it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:22,  1.94it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:20,  1.99it/s][A
Train:  32%|███▏      | 18/57 [00:09<00:18,  2.10it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:19,  1.90it/s][A
Train:  37%|███▋      | 21/57 [00:10<00:18,  1.90it/s][A
Train:  39%|███▊      | 22/57 [00:11<00:21,  1.64it/s][A
Train:  42%|████▏     | 24/57 [00:12<00:17,  1.86it/s][A
Train:  44%|████▍     | 25/57 [00:13<00:21,  1.46it/s][A
Train:  46%|████▌     | 26/57 [00:14<00:23,  1.34it/s][A
Train:  49%|████▉     | 28/57 [00:16<00:17,  1.62it/s][A
Train:  51%|█████     | 29

Epoch: 22 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:01<00:00,  5.80it/s][A
Valid: 100%|██████████| 7/7 [00:02<00:00,  3.13it/s]
Epochs:  55%|█████▌    | 22/40 [10:58<09:06, 30.35s/it]

Valid loss: 0.042; Valid accuracy: 0.730



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.45it/s][A
Train:   7%|▋         | 4/57 [00:02<00:21,  2.43it/s][A
Train:  11%|█         | 6/57 [00:02<00:25,  2.02it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:23,  2.06it/s][A
Train:  18%|█▊        | 10/57 [00:05<00:21,  2.17it/s][A
Train:  21%|██        | 12/57 [00:05<00:23,  1.95it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:21,  2.00it/s][A
Train:  28%|██▊       | 16/57 [00:08<00:19,  2.11it/s][A
Train:  32%|███▏      | 18/57 [00:09<00:20,  1.91it/s][A
Train:  33%|███▎      | 19/57 [00:09<00:20,  1.89it/s][A
Train:  35%|███▌      | 20/57 [00:10<00:22,  1.63it/s][A
Train:  39%|███▊      | 22/57 [00:12<00:18,  1.85it/s][A
Train:  40%|████      | 23/57 [00:12<00:23,  1.45it/s][A
Train:  42%|████▏     | 24/57 [00:13<00:24,  1.34it/s][A
Train:  46%|████▌     | 26/57 [00:15<00:19,  1.63it/s][A
Train:  47%|████▋     | 27/57 [00:15<00:22,  1.31it/s][A
Train:  49%|████▉     | 28

Epoch: 23 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.81it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  6.13it/s]
Epochs:  57%|█████▊    | 23/40 [11:32<08:57, 31.60s/it]

Valid loss: 0.045; Valid accuracy: 0.755



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:22,  2.42it/s][A
Train:   7%|▋         | 4/57 [00:02<00:32,  1.64it/s][A
Train:   9%|▉         | 5/57 [00:02<00:30,  1.69it/s][A
Train:  12%|█▏        | 7/57 [00:03<00:25,  1.95it/s][A
Train:  16%|█▌        | 9/57 [00:04<00:22,  2.11it/s][A
Train:  19%|█▉        | 11/57 [00:05<00:20,  2.21it/s][A
Train:  23%|██▎       | 13/57 [00:06<00:19,  2.20it/s][A
Train:  26%|██▋       | 15/57 [00:07<00:18,  2.27it/s][A
Train:  30%|██▉       | 17/57 [00:08<00:19,  2.01it/s][A
Train:  33%|███▎      | 19/57 [00:09<00:18,  2.03it/s][A
Train:  37%|███▋      | 21/57 [00:10<00:16,  2.13it/s][A
Train:  40%|████      | 23/57 [00:11<00:17,  1.95it/s][A
Train:  44%|████▍     | 25/57 [00:12<00:15,  2.01it/s][A
Train:  47%|████▋     | 27/57 [00:13<00:14,  2.08it/s][A
Train:  51%|█████     | 29/57 [00:14<00:13,  2.14it/s][A
Train:  54%|█████▍    | 31/57 [00:15<00:13,  1.94it/s][A
Train:  58%|█████▊    | 33/

Epoch: 24 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.77it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  3.93it/s]
Epochs:  60%|██████    | 24/40 [12:05<08:30, 31.89s/it]

Valid loss: 0.046; Valid accuracy: 0.625



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.43it/s][A
Train:   7%|▋         | 4/57 [00:01<00:22,  2.32it/s][A
Train:  11%|█         | 6/57 [00:02<00:21,  2.35it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:24,  2.02it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:23,  2.03it/s][A
Train:  21%|██        | 12/57 [00:05<00:20,  2.15it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:22,  1.94it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:20,  2.00it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:18,  2.11it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:19,  1.94it/s][A
Train:  37%|███▋      | 21/57 [00:10<00:18,  1.92it/s][A
Train:  39%|███▊      | 22/57 [00:11<00:21,  1.64it/s][A
Train:  42%|████▏     | 24/57 [00:12<00:17,  1.86it/s][A
Train:  44%|████▍     | 25/57 [00:13<00:21,  1.47it/s][A
Train:  46%|████▌     | 26/57 [00:14<00:22,  1.35it/s][A
Train:  49%|████▉     | 28/57 [00:15<00:17,  1.64it/s][A
Train:  53%|█████▎    | 30

Epoch: 25 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.30it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  6.21it/s]
Epochs:  62%|██████▎   | 25/40 [12:35<07:49, 31.32s/it]

Valid loss: 0.044; Valid accuracy: 0.760



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:22,  2.43it/s][A
Train:   7%|▋         | 4/57 [00:01<00:25,  2.09it/s][A
Train:  11%|█         | 6/57 [00:02<00:23,  2.22it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:21,  2.27it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:23,  2.01it/s][A
Train:  21%|██        | 12/57 [00:05<00:22,  2.03it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:20,  2.14it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:21,  1.92it/s][A
Train:  30%|██▉       | 17/57 [00:08<00:20,  1.91it/s][A
Train:  32%|███▏      | 18/57 [00:09<00:23,  1.64it/s][A
Train:  35%|███▌      | 20/57 [00:10<00:19,  1.86it/s][A
Train:  37%|███▋      | 21/57 [00:11<00:24,  1.45it/s][A
Train:  39%|███▊      | 22/57 [00:12<00:26,  1.34it/s][A
Train:  42%|████▏     | 24/57 [00:14<00:20,  1.63it/s][A
Train:  44%|████▍     | 25/57 [00:14<00:24,  1.32it/s][A
Train:  46%|████▌     | 26/57 [00:15<00:24,  1.27it/s][A
Train:  49%|████▉     | 28

Epoch: 26 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.83it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  6.17it/s]
Epochs:  65%|██████▌   | 26/40 [13:09<07:31, 32.25s/it]

Valid loss: 0.033; Valid accuracy: 0.835



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:23,  2.39it/s][A
Train:   7%|▋         | 4/57 [00:02<00:31,  1.68it/s][A
Train:  11%|█         | 6/57 [00:03<00:26,  1.90it/s][A
Train:  14%|█▍        | 8/57 [00:04<00:23,  2.08it/s][A
Train:  18%|█▊        | 10/57 [00:05<00:25,  1.85it/s][A
Train:  21%|██        | 12/57 [00:06<00:22,  1.97it/s][A
Train:  25%|██▍       | 14/57 [00:07<00:20,  2.09it/s][A
Train:  28%|██▊       | 16/57 [00:08<00:21,  1.87it/s][A
Train:  32%|███▏      | 18/57 [00:09<00:19,  1.97it/s][A
Train:  35%|███▌      | 20/57 [00:10<00:17,  2.09it/s][A
Train:  39%|███▊      | 22/57 [00:11<00:18,  1.88it/s][A
Train:  42%|████▏     | 24/57 [00:12<00:16,  2.01it/s][A
Train:  46%|████▌     | 26/57 [00:13<00:16,  1.88it/s][A
Train:  47%|████▋     | 27/57 [00:14<00:15,  1.88it/s][A
Train:  49%|████▉     | 28/57 [00:14<00:17,  1.62it/s][A
Train:  53%|█████▎    | 30/57 [00:16<00:14,  1.84it/s][A
Train:  54%|█████▍    | 31

Epoch: 27 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.85it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  6.16it/s]
Epochs:  68%|██████▊   | 27/40 [13:40<06:54, 31.87s/it]

Valid loss: 0.037; Valid accuracy: 0.795



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:22,  2.42it/s][A
Train:   7%|▋         | 4/57 [00:02<00:31,  1.68it/s][A
Train:  11%|█         | 6/57 [00:03<00:26,  1.90it/s][A
Train:  14%|█▍        | 8/57 [00:04<00:23,  2.08it/s][A
Train:  18%|█▊        | 10/57 [00:05<00:25,  1.85it/s][A
Train:  21%|██        | 12/57 [00:06<00:22,  1.97it/s][A
Train:  25%|██▍       | 14/57 [00:07<00:20,  2.09it/s][A
Train:  28%|██▊       | 16/57 [00:08<00:21,  1.89it/s][A
Train:  32%|███▏      | 18/57 [00:09<00:19,  2.03it/s][A
Train:  35%|███▌      | 20/57 [00:10<00:17,  2.13it/s][A
Train:  39%|███▊      | 22/57 [00:11<00:18,  1.86it/s][A
Train:  42%|████▏     | 24/57 [00:12<00:16,  1.97it/s][A
Train:  46%|████▌     | 26/57 [00:13<00:14,  2.09it/s][A
Train:  49%|████▉     | 28/57 [00:14<00:15,  1.89it/s][A
Train:  53%|█████▎    | 30/57 [00:15<00:13,  1.98it/s][A
Train:  56%|█████▌    | 32/57 [00:16<00:11,  2.09it/s][A
Train:  60%|█████▉    | 34

Epoch: 28 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.19it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  6.16it/s]
Epochs:  70%|███████   | 28/40 [14:10<06:15, 31.27s/it]

Valid loss: 0.039; Valid accuracy: 0.765



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.42it/s][A
Train:   7%|▋         | 4/57 [00:01<00:21,  2.42it/s][A
Train:  11%|█         | 6/57 [00:02<00:21,  2.41it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:21,  2.29it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:20,  2.33it/s][A
Train:  21%|██        | 12/57 [00:05<00:22,  2.02it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:20,  2.05it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:19,  2.15it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:19,  1.95it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:18,  2.00it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:16,  2.11it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:17,  1.93it/s][A
Train:  44%|████▍     | 25/57 [00:12<00:16,  1.91it/s][A
Train:  46%|████▌     | 26/57 [00:13<00:18,  1.65it/s][A
Train:  49%|████▉     | 28/57 [00:14<00:15,  1.86it/s][A
Train:  51%|█████     | 29/57 [00:15<00:19,  1.46it/s][A
Train:  53%|█████▎    | 30

Epoch: 29 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.80it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  3.93it/s]
Epochs:  72%|███████▎  | 29/40 [14:46<05:58, 32.60s/it]

Valid loss: 0.033; Valid accuracy: 0.805



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.42it/s][A
Train:   7%|▋         | 4/57 [00:01<00:23,  2.28it/s][A
Train:  11%|█         | 6/57 [00:02<00:21,  2.36it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:24,  1.99it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:23,  2.03it/s][A
Train:  21%|██        | 12/57 [00:05<00:20,  2.15it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:22,  1.93it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:20,  1.98it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:19,  2.02it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:17,  2.13it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:17,  1.98it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:15,  2.07it/s][A
Train:  46%|████▌     | 26/57 [00:12<00:14,  2.16it/s][A
Train:  49%|████▉     | 28/57 [00:13<00:15,  1.91it/s][A
Train:  53%|█████▎    | 30/57 [00:14<00:13,  2.00it/s][A
Train:  56%|█████▌    | 32/57 [00:15<00:11,  2.11it/s][A
Train:  60%|█████▉    | 34

Epoch: 30 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.80it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  6.23it/s]
Epochs:  75%|███████▌  | 30/40 [15:15<05:16, 31.65s/it]

Valid loss: 0.044; Valid accuracy: 0.660



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:22,  2.40it/s][A
Train:   7%|▋         | 4/57 [00:02<00:33,  1.60it/s][A
Train:   9%|▉         | 5/57 [00:02<00:31,  1.65it/s][A
Train:  12%|█▏        | 7/57 [00:03<00:25,  1.93it/s][A
Train:  16%|█▌        | 9/57 [00:05<00:30,  1.58it/s][A
Train:  18%|█▊        | 10/57 [00:05<00:28,  1.63it/s][A
Train:  21%|██        | 12/57 [00:06<00:24,  1.86it/s][A
Train:  25%|██▍       | 14/57 [00:07<00:21,  2.03it/s][A
Train:  28%|██▊       | 16/57 [00:08<00:19,  2.14it/s][A
Train:  32%|███▏      | 18/57 [00:09<00:17,  2.17it/s][A
Train:  35%|███▌      | 20/57 [00:10<00:16,  2.24it/s][A
Train:  39%|███▊      | 22/57 [00:11<00:17,  2.00it/s][A
Train:  42%|████▏     | 24/57 [00:12<00:16,  2.03it/s][A
Train:  46%|████▌     | 26/57 [00:13<00:14,  2.09it/s][A
Train:  49%|████▉     | 28/57 [00:14<00:13,  2.18it/s][A
Train:  53%|█████▎    | 30/57 [00:15<00:13,  1.93it/s][A
Train:  56%|█████▌    | 32/

Epoch: 31 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:01<00:00,  5.82it/s][A
Valid: 100%|██████████| 7/7 [00:02<00:00,  3.13it/s]
Epochs:  78%|███████▊  | 31/40 [15:46<04:41, 31.32s/it]

Valid loss: 0.053; Valid accuracy: 0.645



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.42it/s][A
Train:   7%|▋         | 4/57 [00:01<00:21,  2.43it/s][A
Train:  11%|█         | 6/57 [00:02<00:24,  2.07it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:22,  2.14it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:21,  2.23it/s][A
Train:  21%|██        | 12/57 [00:05<00:23,  1.96it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:20,  2.05it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:19,  2.15it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:20,  1.92it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:18,  2.04it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:16,  2.14it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:17,  1.88it/s][A
Train:  44%|████▍     | 25/57 [00:12<00:17,  1.88it/s][A
Train:  47%|████▋     | 27/57 [00:13<00:14,  2.03it/s][A
Train:  51%|█████     | 29/57 [00:13<00:13,  2.14it/s][A
Train:  54%|█████▍    | 31/57 [00:14<00:11,  2.22it/s][A
Train:  58%|█████▊    | 33

Epoch: 32 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.89it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  6.19it/s]
Epochs:  80%|████████  | 32/40 [16:15<04:03, 30.47s/it]

Valid loss: 0.034; Valid accuracy: 0.790



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:22,  2.41it/s][A
Train:   7%|▋         | 4/57 [00:02<00:31,  1.67it/s][A
Train:  11%|█         | 6/57 [00:03<00:27,  1.89it/s][A
Train:  14%|█▍        | 8/57 [00:04<00:23,  2.06it/s][A
Train:  18%|█▊        | 10/57 [00:05<00:25,  1.86it/s][A
Train:  21%|██        | 12/57 [00:06<00:22,  1.97it/s][A
Train:  25%|██▍       | 14/57 [00:07<00:20,  2.10it/s][A
Train:  28%|██▊       | 16/57 [00:08<00:23,  1.77it/s][A
Train:  32%|███▏      | 18/57 [00:10<00:20,  1.93it/s][A
Train:  35%|███▌      | 20/57 [00:11<00:24,  1.52it/s][A
Train:  39%|███▊      | 22/57 [00:12<00:20,  1.72it/s][A
Train:  40%|████      | 23/57 [00:13<00:21,  1.58it/s][A
Train:  42%|████▏     | 24/57 [00:14<00:22,  1.44it/s][A
Train:  44%|████▍     | 25/57 [00:14<00:24,  1.33it/s][A
Train:  47%|████▋     | 27/57 [00:16<00:18,  1.61it/s][A
Train:  49%|████▉     | 28/57 [00:17<00:21,  1.33it/s][A
Train:  51%|█████     | 29

Epoch: 33 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.85it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  5.06it/s]
Epochs:  82%|████████▎ | 33/40 [16:48<03:39, 31.35s/it]

Valid loss: 0.040; Valid accuracy: 0.705



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:22,  2.43it/s][A
Train:   7%|▋         | 4/57 [00:02<00:26,  2.03it/s][A
Train:  11%|█         | 6/57 [00:02<00:24,  2.08it/s][A
Train:  14%|█▍        | 8/57 [00:04<00:22,  2.19it/s][A
Train:  18%|█▊        | 10/57 [00:05<00:24,  1.94it/s][A
Train:  19%|█▉        | 11/57 [00:05<00:23,  1.93it/s][A
Train:  23%|██▎       | 13/57 [00:06<00:21,  2.08it/s][A
Train:  26%|██▋       | 15/57 [00:07<00:19,  2.16it/s][A
Train:  30%|██▉       | 17/57 [00:08<00:17,  2.23it/s][A
Train:  33%|███▎      | 19/57 [00:09<00:17,  2.19it/s][A
Train:  37%|███▋      | 21/57 [00:09<00:16,  2.17it/s][A
Train:  40%|████      | 23/57 [00:11<00:15,  2.23it/s][A
Train:  44%|████▍     | 25/57 [00:11<00:15,  2.04it/s][A
Train:  47%|████▋     | 27/57 [00:13<00:14,  2.14it/s][A
Train:  51%|█████     | 29/57 [00:14<00:14,  1.95it/s][A
Train:  54%|█████▍    | 31/57 [00:15<00:13,  1.99it/s][A
Train:  58%|█████▊    | 33

Epoch: 34 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.82it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  4.47it/s]
Epochs:  85%|████████▌ | 34/40 [17:17<03:04, 30.70s/it]

Valid loss: 0.039; Valid accuracy: 0.760



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:24,  2.24it/s][A
Train:   7%|▋         | 4/57 [00:01<00:22,  2.33it/s][A
Train:  11%|█         | 6/57 [00:02<00:21,  2.37it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:24,  2.01it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:23,  2.04it/s][A
Train:  21%|██        | 12/57 [00:05<00:20,  2.15it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:22,  1.95it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:20,  2.01it/s][A
Train:  32%|███▏      | 18/57 [00:09<00:18,  2.12it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:19,  1.90it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:17,  2.03it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:15,  2.09it/s][A
Train:  46%|████▌     | 26/57 [00:12<00:14,  2.18it/s][A
Train:  49%|████▉     | 28/57 [00:13<00:15,  1.93it/s][A
Train:  53%|█████▎    | 30/57 [00:14<00:13,  2.01it/s][A
Train:  56%|█████▌    | 32/57 [00:15<00:11,  2.11it/s][A
Train:  60%|█████▉    | 34

Epoch: 35 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.38it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  4.61it/s]
Epochs:  88%|████████▊ | 35/40 [17:47<02:31, 30.32s/it]

Valid loss: 0.039; Valid accuracy: 0.700



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:22,  2.46it/s][A
Train:   7%|▋         | 4/57 [00:02<00:32,  1.61it/s][A
Train:   9%|▉         | 5/57 [00:03<00:30,  1.68it/s][A
Train:  11%|█         | 6/57 [00:03<00:35,  1.44it/s][A
Train:  14%|█▍        | 8/57 [00:05<00:27,  1.77it/s][A
Train:  16%|█▌        | 9/57 [00:06<00:35,  1.36it/s][A
Train:  18%|█▊        | 10/57 [00:06<00:37,  1.26it/s][A
Train:  21%|██        | 12/57 [00:08<00:28,  1.58it/s][A
Train:  23%|██▎       | 13/57 [00:09<00:34,  1.28it/s][A
Train:  25%|██▍       | 14/57 [00:09<00:34,  1.25it/s][A
Train:  28%|██▊       | 16/57 [00:10<00:26,  1.56it/s][A
Train:  32%|███▏      | 18/57 [00:11<00:21,  1.79it/s][A
Train:  35%|███▌      | 20/57 [00:12<00:18,  1.96it/s][A
Train:  39%|███▊      | 22/57 [00:13<00:17,  2.06it/s][A
Train:  42%|████▏     | 24/57 [00:14<00:15,  2.16it/s][A
Train:  46%|████▌     | 26/57 [00:15<00:15,  1.96it/s][A
Train:  49%|████▉     | 28/5

Epoch: 36 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.85it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  3.92it/s]
Epochs:  90%|█████████ | 36/40 [18:19<02:03, 30.99s/it]

Valid loss: 0.040; Valid accuracy: 0.760



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.42it/s][A
Train:   7%|▋         | 4/57 [00:01<00:22,  2.32it/s][A
Train:  11%|█         | 6/57 [00:02<00:21,  2.36it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:24,  2.00it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:23,  2.03it/s][A
Train:  21%|██        | 12/57 [00:05<00:21,  2.13it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:21,  1.96it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:20,  1.99it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:18,  2.11it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:19,  1.93it/s][A
Train:  37%|███▋      | 21/57 [00:10<00:18,  1.92it/s][A
Train:  39%|███▊      | 22/57 [00:11<00:21,  1.64it/s][A
Train:  42%|████▏     | 24/57 [00:12<00:17,  1.86it/s][A
Train:  44%|████▍     | 25/57 [00:13<00:21,  1.47it/s][A
Train:  46%|████▌     | 26/57 [00:14<00:23,  1.34it/s][A
Train:  49%|████▉     | 28/57 [00:15<00:17,  1.63it/s][A
Train:  53%|█████▎    | 30

Epoch: 37 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.69it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  5.01it/s]
Epochs:  92%|█████████▎| 37/40 [18:49<01:32, 30.78s/it]

Valid loss: 0.040; Valid accuracy: 0.765



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:01<00:22,  2.41it/s][A
Train:   7%|▋         | 4/57 [00:02<00:34,  1.54it/s][A
Train:   9%|▉         | 5/57 [00:03<00:32,  1.61it/s][A
Train:  12%|█▏        | 7/57 [00:04<00:26,  1.90it/s][A
Train:  16%|█▌        | 9/57 [00:05<00:31,  1.53it/s][A
Train:  18%|█▊        | 10/57 [00:06<00:29,  1.60it/s][A
Train:  19%|█▉        | 11/57 [00:06<00:32,  1.43it/s][A
Train:  23%|██▎       | 13/57 [00:07<00:25,  1.71it/s][A
Train:  25%|██▍       | 14/57 [00:08<00:24,  1.73it/s][A
Train:  28%|██▊       | 16/57 [00:09<00:21,  1.95it/s][A
Train:  32%|███▏      | 18/57 [00:10<00:18,  2.09it/s][A
Train:  35%|███▌      | 20/57 [00:11<00:20,  1.81it/s][A
Train:  39%|███▊      | 22/57 [00:12<00:18,  1.93it/s][A
Train:  42%|████▏     | 24/57 [00:13<00:16,  2.06it/s][A
Train:  46%|████▌     | 26/57 [00:14<00:16,  1.88it/s][A
Train:  49%|████▉     | 28/57 [00:15<00:14,  2.03it/s][A
Train:  53%|█████▎    | 30/

Epoch: 38 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.75it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  6.44it/s]
Epochs:  95%|█████████▌| 38/40 [19:20<01:01, 30.75s/it]

Valid loss: 0.037; Valid accuracy: 0.765



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.42it/s][A
Train:   7%|▋         | 4/57 [00:01<00:23,  2.21it/s][A
Train:  11%|█         | 6/57 [00:02<00:23,  2.16it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:21,  2.24it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:23,  1.98it/s][A
Train:  21%|██        | 12/57 [00:05<00:22,  2.01it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:20,  2.05it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:19,  2.15it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:20,  1.95it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:18,  1.99it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:16,  2.10it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:16,  1.95it/s][A
Train:  46%|████▌     | 26/57 [00:12<00:15,  2.04it/s][A
Train:  49%|████▉     | 28/57 [00:13<00:13,  2.14it/s][A
Train:  53%|█████▎    | 30/57 [00:14<00:14,  1.93it/s][A
Train:  56%|█████▌    | 32/57 [00:15<00:12,  2.01it/s][A
Train:  60%|█████▉    | 34

Epoch: 39 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.82it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  4.52it/s]
Epochs:  98%|█████████▊| 39/40 [19:49<00:30, 30.23s/it]

Valid loss: 0.040; Valid accuracy: 0.715



Train:   0%|          | 0/57 [00:00<?, ?it/s][A
Train:   4%|▎         | 2/57 [00:00<00:22,  2.39it/s][A
Train:   7%|▋         | 4/57 [00:01<00:23,  2.29it/s][A
Train:  11%|█         | 6/57 [00:02<00:21,  2.34it/s][A
Train:  14%|█▍        | 8/57 [00:03<00:24,  2.01it/s][A
Train:  18%|█▊        | 10/57 [00:04<00:23,  2.04it/s][A
Train:  21%|██        | 12/57 [00:05<00:20,  2.15it/s][A
Train:  25%|██▍       | 14/57 [00:06<00:22,  1.95it/s][A
Train:  28%|██▊       | 16/57 [00:07<00:20,  1.99it/s][A
Train:  32%|███▏      | 18/57 [00:08<00:19,  2.02it/s][A
Train:  35%|███▌      | 20/57 [00:09<00:17,  2.12it/s][A
Train:  39%|███▊      | 22/57 [00:10<00:17,  1.96it/s][A
Train:  42%|████▏     | 24/57 [00:11<00:16,  2.04it/s][A
Train:  46%|████▌     | 26/57 [00:12<00:14,  2.14it/s][A
Train:  49%|████▉     | 28/57 [00:13<00:15,  1.93it/s][A
Train:  53%|█████▎    | 30/57 [00:14<00:13,  2.02it/s][A
Train:  56%|█████▌    | 32/57 [00:15<00:11,  2.12it/s][A
Train:  60%|█████▉    | 34

Epoch: 40 Training loss: 0.00



Valid:   0%|          | 0/7 [00:00<?, ?it/s][A
Valid:  43%|████▎     | 3/7 [00:00<00:00,  5.79it/s][A
Valid: 100%|██████████| 7/7 [00:01<00:00,  4.10it/s]
Epochs: 100%|██████████| 40/40 [20:18<00:00, 30.46s/it]

Valid loss: 0.039; Valid accuracy: 0.765





In [None]:
 model = AutoModelForSequenceClassification.from_pretrained(f'{MODELS_PATH}/best_checkpoint').to(device)

In [None]:
evaluate(model=model, dataloader=dataloader_valid, set_name='Valid')

Valid: 100%|██████████| 7/7 [00:01<00:00,  4.47it/s]

Valid loss: 0.137; Valid accuracy: 0.865





0.865

In [None]:
import pickle
import os

query_ids = []
queries = []
passage_ids = []
passages = []

with open(f'{main_dir}/trec-covid/run.trec-covid.bm25tuned.txt') as f:
  for line in f:
      fields = line.strip().split()
      query_id = fields[0]
      query_ids.append(query_id)
      passage_id = fields[2]
      passage_ids.append(passage_id)

      query_text = id_to_query[query_id]
      queries.append(query_text)

      passage_text = id_to_doc[passage_id]
      passages.append(passage_text)


In [None]:
len(queries)

50000

In [None]:
len(passages)

50000

In [None]:
dataset_test = MSMARCODataset(tokenizer, queries, passages, [1]*len(queries))

In [None]:
dataloader_test = data.DataLoader(dataset_test, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [None]:

def evaluate_test_dataset2(model, dataloader, set_name):
  scores = []
  model.eval()
  with torch.no_grad():
    for batch in tqdm(dataloader, mininterval=0.5, desc=set_name, disable=False):
      outputs = model(**batch.to(device))
      # Usa os logits brutos
      pos_score = outputs.logits[:,0]
      scores = scores + pos_score.tolist()
  return scores

In [None]:
logit_scores = evaluate_test_dataset2(model=model, dataloader=dataloader_test, set_name='Test')

Test: 100%|██████████| 1563/1563 [06:44<00:00,  3.86it/s]


In [None]:
trec_file_logits = evaluate_ndcg_10(logit_scores, model_name, "logits")

In [None]:
import pandas as pd

run_trec_file_logits = pd.read_csv(trec_file_logits, sep="\t", header=None, 
                   skiprows=1, names=["query", "docid", "rank", "score", "system"])
run_trec_file_logits["q0"] = "q0"
run_trec_file_logits = run_trec_file_logits.to_dict(orient="list")

In [None]:
eval_ndcg10(run_trec_file_logits)

Downloading builder script:   0%|          | 0.00/5.51k [00:00<?, ?B/s]

0.5269229736376339

## Finetuning with more datasets

In [None]:
import json
import glob

# Set the path of the directory containing JSONL files
directory_path = f"{main_dir}/trec-covid/datasets/"
print(directory_path)

# Define an empty list to store the concatenated data
data = []

# Loop through all the JSONL files in the directory
for filename in glob.glob(directory_path + "*.jsonl"):
  print(filename)
  with open(filename, "r") as file:
      # Read each line in the file and append to the data list
      for line in file:
          data.append(json.loads(line))

# Write the concatenated data to a new JSONL file
with open("concatenated.jsonl", "w") as outfile:
    for item in data:
        # Write each item as a JSON object on a separate line
        outfile.write(json.dumps(item) + "\n")


/content/gdrive/MyDrive/Unicamp-aula-9/trec-covid/datasets/
/content/gdrive/MyDrive/Unicamp-aula-9/trec-covid/datasets/monique_monteiro_1000_queries.jsonl
/content/gdrive/MyDrive/Unicamp-aula-9/trec-covid/datasets/leandro_carisio_01.jsonl
/content/gdrive/MyDrive/Unicamp-aula-9/trec-covid/datasets/manoel_2k_generated_queries_20230501.jsonl
/content/gdrive/MyDrive/Unicamp-aula-9/trec-covid/datasets/leonardo_avila_queries_v1.jsonl
/content/gdrive/MyDrive/Unicamp-aula-9/trec-covid/datasets/leonardo_pacheco_1k_generated_queries_20230502.jsonl
/content/gdrive/MyDrive/Unicamp-aula-9/trec-covid/datasets/manoel_1k_generated_queries_20230430.jsonl
/content/gdrive/MyDrive/Unicamp-aula-9/trec-covid/datasets/marcus_borela_1k_gptj6b_20230501.jsonl
/content/gdrive/MyDrive/Unicamp-aula-9/trec-covid/datasets/marcus_borela_1k_gptj6b_20230501_v2.jsonl
/content/gdrive/MyDrive/Unicamp-aula-9/trec-covid/datasets/mirelle_1k_generated_queries_20230501.jsonl
/content/gdrive/MyDrive/Unicamp-aula-9/trec-covid/da

In [None]:
!head concatenated.jsonl

{"query": "What is the most suitable protein for a diagnostic approach for Salmonella Enteritidis and why?", "positive_doc_id": "m1cmkkw3", "negative_doc_ids": ["0o3mryu1", "qpq7i1ya", "j5mkparg", "auoo0dm5", "dqfvrerw"]}
{"query": "What is Cryptosporidium parvum and why is it a major cause of disease in both humans and animals?", "positive_doc_id": "ukbl0svm", "negative_doc_ids": ["k3u2nvpe", "gc11fyms", "xoq9qblv", "20o4ufa3", "3huo5nf0"]}
{"query": "What is the role of the renin-angiotensin-aldosterone system in the context of SARS-CoV-2 infection?", "positive_doc_id": "12o4zey2", "negative_doc_ids": ["6gd6nwpu", "dt4t2wos", "8zwfkken", "sv7xpi4f", "6pf73z08"]}
{"query": "What are the functions of individual endolysosomal proteases in cellular processes such as autophagy and lipoprotein particle degradation?", "positive_doc_id": "eqv6a7tj", "negative_doc_ids": ["gmrty2uu", "uzn214j6", "032utjfh", "efet3ozc", "0muwl6oc"]}
{"query": "What is the prevalence of olfactory dysfunction in 

In [None]:
model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"

In [None]:
X_train, Y_train, X_val, Y_val = generate_training_data("concatenated.jsonl", model_name)

cross-encoder/ms-marco-MiniLM-L-6-v2
                                               query  \
0  What is the most suitable protein for a diagno...   
1  What is Cryptosporidium parvum and why is it a...   
2  What is the role of the renin-angiotensin-aldo...   
3  What are the functions of individual endolysos...   
4  What is the prevalence of olfactory dysfunctio...   

                                             passage  score  
0  Rapid identification of novel antigens of Salm...   True  
1  Cryptosporidium and host resistance: historica...   True  
2  Understanding the Renin-Angiotensin-Aldosteron...   True  
3  Specific functions of lysosomal proteases in e...   True  
4  Olfactory and rhinological evaluations in SARS...   True  


In [None]:
train_queries = list(X_train["query"])
train_passages = list(X_train["passage"])
val_queries = list(X_val["query"])
val_passages = list(X_val["passage"])

In [None]:
len(train_queries)

31808

In [None]:
len(val_queries)

3536

In [None]:
dataset_train = MSMARCODataset(tokenizer, train_queries, train_passages, Y_train)
assert len(dataset_train[0]['input_ids']) > 0
assert len(dataset_train[1]['attention_mask']) > 0

In [None]:
dataset_val = MSMARCODataset(tokenizer, val_queries, val_passages, Y_val)
assert len(dataset_val[0]['input_ids']) > 0
assert len(dataset_val[1]['attention_mask']) > 0

In [None]:

dataloader_train = data.DataLoader(dataset_train, batch_size=32, shuffle=True, collate_fn=collate_fn)
dataloader_valid = data.DataLoader(dataset_val, batch_size=32, shuffle=False, collate_fn=collate_fn)


In [None]:
model = train(model_name, 10, lr=1e-3)

Parameters 22713601


Valid: 100%|██████████| 111/111 [00:28<00:00,  3.94it/s]


Valid loss: 38.423; Valid accuracy: 0.521


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]
Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<06:55,  2.39it/s][A
Train:   0%|          | 4/994 [00:01<07:45,  2.13it/s][A
Train:   1%|          | 6/994 [00:02<07:23,  2.23it/s][A
Train:   1%|          | 8/994 [00:03<07:19,  2.24it/s][A
Train:   1%|          | 10/994 [00:04<07:07,  2.30it/s][A
Train:   1%|          | 12/994 [00:05<08:18,  1.97it/s][A
Train:   1%|▏         | 14/994 [00:06<08:08,  2.01it/s][A
Train:   2%|▏         | 16/994 [00:07<07:43,  2.11it/s][A
Train:   2%|▏         | 18/994 [00:08<08:20,  1.95it/s][A
Train:   2%|▏         | 20/994 [00:09<08:10,  1.99it/s][A
Train:   2%|▏         | 22/994 [00:10<07:42,  2.10it/s][A
Train:   2%|▏         | 24/994 [00:11<08:26,  1.91it/s][A
Train:   3%|▎         | 25/994 [00:12<08:28,  1.91it/s][A
Train:   3%|▎         | 26/994 [00:13<09:53,  1.63it/s][A
Train:   3%|▎         | 28/994 [00:14<08:44,  1.84it/s][A
Train:   3%|▎         

Epoch: 1 Training loss: 0.67



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:01<00:20,  5.40it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:30,  3.46it/s][A
Valid:   8%|▊         | 9/111 [00:02<00:23,  4.27it/s][A
Valid:  11%|█         | 12/111 [00:03<00:21,  4.58it/s][A
Valid:  14%|█▎        | 15/111 [00:03<00:23,  4.01it/s][A
Valid:  16%|█▌        | 18/111 [00:04<00:20,  4.48it/s][A
Valid:  19%|█▉        | 21/111 [00:04<00:21,  4.14it/s][A
Valid:  22%|██▏       | 24/111 [00:06<00:19,  4.57it/s][A
Valid:  24%|██▍       | 27/111 [00:06<00:22,  3.74it/s][A
Valid:  27%|██▋       | 30/111 [00:07<00:19,  4.21it/s][A
Valid:  30%|██▉       | 33/111 [00:07<00:16,  4.62it/s][A
Valid:  32%|███▏      | 36/111 [00:08<00:15,  4.95it/s][A
Valid:  36%|███▌      | 40/111 [00:09<00:13,  5.32it/s][A
Valid:  39%|███▊      | 43/111 [00:10<00:15,  4.48it/s][A
Valid:  41%|████▏     | 46/111 [00:10<00:16,  4.06it/s][A
Valid:  44%|████▍     | 49/111 [00:11<00:13,  4.46it/s][A
Valid:  

Valid loss: 0.037; Valid accuracy: 0.892


Epochs:  10%|█         | 1/10 [09:20<1:24:01, 560.16s/it]
Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<06:55,  2.39it/s][A
Train:   0%|          | 4/994 [00:01<07:23,  2.23it/s][A
Train:   1%|          | 6/994 [00:02<07:41,  2.14it/s][A
Train:   1%|          | 8/994 [00:03<07:20,  2.24it/s][A
Train:   1%|          | 10/994 [00:04<08:06,  2.02it/s][A
Train:   1%|          | 12/994 [00:05<07:47,  2.10it/s][A
Train:   1%|▏         | 14/994 [00:06<07:28,  2.19it/s][A
Train:   2%|▏         | 16/994 [00:07<08:37,  1.89it/s][A
Train:   2%|▏         | 17/994 [00:08<08:41,  1.87it/s][A
Train:   2%|▏         | 18/994 [00:09<10:05,  1.61it/s][A
Train:   2%|▏         | 20/994 [00:10<08:51,  1.83it/s][A
Train:   2%|▏         | 21/994 [00:10<09:01,  1.80it/s][A
Train:   2%|▏         | 23/994 [00:11<08:07,  1.99it/s][A
Train:   3%|▎         | 25/994 [00:12<08:43,  1.85it/s][A
Train:   3%|▎         | 26/994 [00:13<08:45,  1.84it/s][A
Train:   3%

Epoch: 2 Training loss: 0.02



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:18,  5.80it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:26,  3.95it/s][A
Valid:   8%|▊         | 9/111 [00:01<00:21,  4.66it/s][A
Valid:  11%|█         | 12/111 [00:02<00:19,  5.08it/s][A
Valid:  14%|█▎        | 15/111 [00:02<00:18,  5.32it/s][A
Valid:  16%|█▌        | 18/111 [00:03<00:17,  5.22it/s][A
Valid:  19%|█▉        | 21/111 [00:04<00:16,  5.43it/s][A
Valid:  22%|██▏       | 24/111 [00:04<00:15,  5.51it/s][A
Valid:  24%|██▍       | 27/111 [00:05<00:18,  4.50it/s][A
Valid:  27%|██▋       | 30/111 [00:06<00:19,  4.05it/s][A
Valid:  30%|██▉       | 33/111 [00:06<00:17,  4.50it/s][A
Valid:  32%|███▏      | 36/111 [00:07<00:15,  4.84it/s][A
Valid:  35%|███▌      | 39/111 [00:08<00:18,  3.97it/s][A
Valid:  38%|███▊      | 42/111 [00:09<00:18,  3.73it/s][A
Valid:  41%|████      | 45/111 [00:09<00:15,  4.19it/s][A
Valid:  43%|████▎     | 48/111 [00:10<00:13,  4.57it/s][A
Valid:  

Valid loss: 0.028; Valid accuracy: 0.909


Epochs:  20%|██        | 2/10 [18:25<1:13:29, 551.22s/it]
Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<06:53,  2.40it/s][A
Train:   0%|          | 4/994 [00:01<07:30,  2.20it/s][A
Train:   1%|          | 6/994 [00:02<07:10,  2.29it/s][A
Train:   1%|          | 8/994 [00:03<08:15,  1.99it/s][A
Train:   1%|          | 10/994 [00:04<08:08,  2.01it/s][A
Train:   1%|          | 12/994 [00:05<08:01,  2.04it/s][A
Train:   1%|▏         | 14/994 [00:06<07:36,  2.15it/s][A
Train:   2%|▏         | 16/994 [00:07<08:11,  1.99it/s][A
Train:   2%|▏         | 18/994 [00:08<07:46,  2.09it/s][A
Train:   2%|▏         | 20/994 [00:09<08:26,  1.92it/s][A
Train:   2%|▏         | 21/994 [00:10<08:30,  1.91it/s][A
Train:   2%|▏         | 23/994 [00:11<07:53,  2.05it/s][A
Train:   3%|▎         | 25/994 [00:12<07:29,  2.16it/s][A
Train:   3%|▎         | 27/994 [00:12<07:13,  2.23it/s][A
Train:   3%|▎         | 29/994 [00:13<07:23,  2.18it/s][A
Train:   3%

Epoch: 3 Training loss: 0.01



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:18,  5.89it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:17,  5.91it/s][A
Valid:   8%|▊         | 9/111 [00:02<00:27,  3.72it/s][A
Valid:  11%|█         | 12/111 [00:03<00:28,  3.53it/s][A
Valid:  14%|█▎        | 15/111 [00:03<00:23,  4.13it/s][A
Valid:  16%|█▌        | 18/111 [00:04<00:20,  4.53it/s][A
Valid:  19%|█▉        | 21/111 [00:05<00:23,  3.87it/s][A
Valid:  22%|██▏       | 24/111 [00:06<00:24,  3.51it/s][A
Valid:  23%|██▎       | 26/111 [00:06<00:26,  3.26it/s][A
Valid:  26%|██▌       | 29/111 [00:07<00:21,  3.75it/s][A
Valid:  29%|██▉       | 32/111 [00:08<00:18,  4.20it/s][A
Valid:  32%|███▏      | 35/111 [00:09<00:21,  3.59it/s][A
Valid:  35%|███▌      | 39/111 [00:09<00:16,  4.24it/s][A
Valid:  38%|███▊      | 42/111 [00:10<00:14,  4.62it/s][A
Valid:  41%|████      | 45/111 [00:11<00:15,  4.21it/s][A
Valid:  43%|████▎     | 48/111 [00:12<00:16,  3.73it/s][A
Valid:  

Valid loss: 0.023; Valid accuracy: 0.500



Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<07:01,  2.35it/s][A
Train:   0%|          | 4/994 [00:01<06:55,  2.38it/s][A
Train:   1%|          | 6/994 [00:02<06:53,  2.39it/s][A
Train:   1%|          | 8/994 [00:03<07:06,  2.31it/s][A
Train:   1%|          | 10/994 [00:04<07:02,  2.33it/s][A
Train:   1%|          | 12/994 [00:05<07:59,  2.05it/s][A
Train:   1%|▏         | 14/994 [00:06<07:53,  2.07it/s][A
Train:   2%|▏         | 16/994 [00:07<07:33,  2.16it/s][A
Train:   2%|▏         | 18/994 [00:08<08:16,  1.96it/s][A
Train:   2%|▏         | 20/994 [00:09<08:06,  2.00it/s][A
Train:   2%|▏         | 22/994 [00:10<07:39,  2.11it/s][A
Train:   2%|▏         | 24/994 [00:11<08:22,  1.93it/s][A
Train:   3%|▎         | 25/994 [00:12<08:25,  1.92it/s][A
Train:   3%|▎         | 26/994 [00:13<09:51,  1.64it/s][A
Train:   3%|▎         | 28/994 [00:14<08:41,  1.85it/s][A
Train:   3%|▎         | 29/994 [00:15<11:06,  1.45it/s][A
Train:   

Epoch: 4 Training loss: 0.01



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:18,  5.72it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:18,  5.79it/s][A
Valid:   8%|▊         | 9/111 [00:02<00:17,  5.87it/s][A
Valid:  11%|█         | 12/111 [00:02<00:24,  4.08it/s][A
Valid:  14%|█▎        | 15/111 [00:03<00:25,  3.80it/s][A
Valid:  16%|█▌        | 18/111 [00:04<00:21,  4.29it/s][A
Valid:  19%|█▉        | 21/111 [00:05<00:19,  4.69it/s][A
Valid:  22%|██▏       | 24/111 [00:05<00:21,  3.98it/s][A
Valid:  24%|██▍       | 27/111 [00:06<00:18,  4.43it/s][A
Valid:  27%|██▋       | 30/111 [00:06<00:17,  4.70it/s][A
Valid:  30%|██▉       | 33/111 [00:07<00:18,  4.12it/s][A
Valid:  32%|███▏      | 36/111 [00:08<00:16,  4.50it/s][A
Valid:  35%|███▌      | 39/111 [00:09<00:15,  4.77it/s][A
Valid:  38%|███▊      | 42/111 [00:09<00:16,  4.16it/s][A
Valid:  41%|████      | 45/111 [00:10<00:14,  4.54it/s][A
Valid:  43%|████▎     | 48/111 [00:11<00:13,  4.83it/s][A
Valid:  

Valid loss: 0.019; Valid accuracy: 0.907



Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<06:54,  2.39it/s][A
Train:   0%|          | 4/994 [00:02<07:41,  2.15it/s][A
Train:   1%|          | 6/994 [00:03<10:40,  1.54it/s][A
Train:   1%|          | 8/994 [00:04<09:10,  1.79it/s][A
Train:   1%|          | 10/994 [00:05<08:17,  1.98it/s][A
Train:   1%|          | 12/994 [00:06<08:50,  1.85it/s][A
Train:   1%|▏         | 13/994 [00:07<08:53,  1.84it/s][A
Train:   2%|▏         | 15/994 [00:07<08:08,  2.00it/s][A
Train:   2%|▏         | 17/994 [00:08<07:39,  2.12it/s][A
Train:   2%|▏         | 19/994 [00:09<07:23,  2.20it/s][A
Train:   2%|▏         | 21/994 [00:10<07:20,  2.21it/s][A
Train:   2%|▏         | 23/994 [00:11<07:09,  2.26it/s][A
Train:   3%|▎         | 25/994 [00:12<07:55,  2.04it/s][A
Train:   3%|▎         | 27/994 [00:13<07:51,  2.05it/s][A
Train:   3%|▎         | 29/994 [00:14<07:29,  2.15it/s][A
Train:   3%|▎         | 31/994 [00:15<08:11,  1.96it/s][A
Train:   

Epoch: 5 Training loss: 0.02



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:20,  5.30it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:27,  3.86it/s][A
Valid:   8%|▊         | 9/111 [00:01<00:22,  4.60it/s][A
Valid:  11%|█         | 12/111 [00:03<00:19,  5.04it/s][A
Valid:  14%|█▎        | 15/111 [00:03<00:24,  3.88it/s][A
Valid:  16%|█▌        | 18/111 [00:04<00:25,  3.69it/s][A
Valid:  20%|█▉        | 22/111 [00:05<00:20,  4.33it/s][A
Valid:  23%|██▎       | 25/111 [00:05<00:18,  4.70it/s][A
Valid:  25%|██▌       | 28/111 [00:06<00:19,  4.20it/s][A
Valid:  28%|██▊       | 31/111 [00:07<00:20,  3.89it/s][A
Valid:  31%|███       | 34/111 [00:07<00:17,  4.33it/s][A
Valid:  33%|███▎      | 37/111 [00:08<00:15,  4.69it/s][A
Valid:  36%|███▌      | 40/111 [00:09<00:17,  3.98it/s][A
Valid:  39%|███▊      | 43/111 [00:10<00:15,  4.42it/s][A
Valid:  41%|████▏     | 46/111 [00:10<00:14,  4.63it/s][A
Valid:  44%|████▍     | 49/111 [00:11<00:14,  4.14it/s][A
Valid:  

Valid loss: 0.032; Valid accuracy: 0.505



Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<07:23,  2.24it/s][A
Train:   0%|          | 4/994 [00:01<07:15,  2.27it/s][A
Train:   1%|          | 6/994 [00:02<07:44,  2.13it/s][A
Train:   1%|          | 8/994 [00:03<07:22,  2.23it/s][A
Train:   1%|          | 10/994 [00:04<07:09,  2.29it/s][A
Train:   1%|          | 12/994 [00:05<07:19,  2.24it/s][A
Train:   1%|▏         | 14/994 [00:06<07:07,  2.29it/s][A
Train:   2%|▏         | 16/994 [00:07<08:02,  2.03it/s][A
Train:   2%|▏         | 18/994 [00:08<07:55,  2.05it/s][A
Train:   2%|▏         | 20/994 [00:09<07:32,  2.15it/s][A
Train:   2%|▏         | 22/994 [00:10<08:13,  1.97it/s][A
Train:   2%|▏         | 24/994 [00:11<07:46,  2.08it/s][A
Train:   3%|▎         | 26/994 [00:12<08:24,  1.92it/s][A
Train:   3%|▎         | 27/994 [00:13<08:33,  1.88it/s][A
Train:   3%|▎         | 28/994 [00:14<09:55,  1.62it/s][A
Train:   3%|▎         | 29/994 [00:14<11:10,  1.44it/s][A
Train:   

Epoch: 6 Training loss: 0.01



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:19,  5.49it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:26,  3.97it/s][A
Valid:   8%|▊         | 9/111 [00:01<00:22,  4.61it/s][A
Valid:  11%|█         | 12/111 [00:02<00:20,  4.86it/s][A
Valid:  14%|█▎        | 15/111 [00:03<00:18,  5.14it/s][A
Valid:  16%|█▌        | 18/111 [00:03<00:20,  4.57it/s][A
Valid:  19%|█▉        | 21/111 [00:04<00:18,  4.93it/s][A
Valid:  22%|██▏       | 24/111 [00:05<00:22,  3.82it/s][A
Valid:  24%|██▍       | 27/111 [00:06<00:23,  3.64it/s][A
Valid:  27%|██▋       | 30/111 [00:06<00:19,  4.14it/s][A
Valid:  30%|██▉       | 33/111 [00:07<00:17,  4.54it/s][A
Valid:  32%|███▏      | 36/111 [00:08<00:19,  3.83it/s][A
Valid:  35%|███▌      | 39/111 [00:09<00:19,  3.65it/s][A
Valid:  38%|███▊      | 42/111 [00:09<00:16,  4.12it/s][A
Valid:  41%|████      | 45/111 [00:10<00:14,  4.53it/s][A
Valid:  43%|████▎     | 48/111 [00:11<00:16,  3.79it/s][A
Valid:  

Valid loss: 0.027; Valid accuracy: 0.500



Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<06:49,  2.42it/s][A
Train:   0%|          | 4/994 [00:02<06:50,  2.41it/s][A
Train:   1%|          | 6/994 [00:03<09:32,  1.73it/s][A
Train:   1%|          | 8/994 [00:04<08:26,  1.95it/s][A
Train:   1%|          | 10/994 [00:05<07:52,  2.08it/s][A
Train:   1%|          | 12/994 [00:06<08:34,  1.91it/s][A
Train:   1%|▏         | 14/994 [00:07<08:19,  1.96it/s][A
Train:   2%|▏         | 16/994 [00:08<07:48,  2.09it/s][A
Train:   2%|▏         | 18/994 [00:09<08:22,  1.94it/s][A
Train:   2%|▏         | 20/994 [00:10<07:50,  2.07it/s][A
Train:   2%|▏         | 22/994 [00:11<08:15,  1.96it/s][A
Train:   2%|▏         | 24/994 [00:11<07:52,  2.05it/s][A
Train:   3%|▎         | 26/994 [00:12<07:28,  2.16it/s][A
Train:   3%|▎         | 28/994 [00:14<07:13,  2.23it/s][A
Train:   3%|▎         | 30/994 [00:15<08:50,  1.82it/s][A
Train:   3%|▎         | 31/994 [00:16<08:49,  1.82it/s][A
Train:   

Epoch: 7 Training loss: 0.01



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:19,  5.67it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:18,  5.66it/s][A
Valid:   8%|▊         | 9/111 [00:01<00:17,  5.77it/s][A
Valid:  11%|█         | 12/111 [00:02<00:17,  5.81it/s][A
Valid:  14%|█▎        | 15/111 [00:03<00:23,  4.13it/s][A
Valid:  16%|█▌        | 18/111 [00:04<00:24,  3.79it/s][A
Valid:  20%|█▉        | 22/111 [00:04<00:20,  4.42it/s][A
Valid:  23%|██▎       | 25/111 [00:05<00:17,  4.79it/s][A
Valid:  25%|██▌       | 28/111 [00:06<00:19,  4.22it/s][A
Valid:  28%|██▊       | 31/111 [00:07<00:20,  3.90it/s][A
Valid:  31%|███       | 34/111 [00:07<00:17,  4.34it/s][A
Valid:  33%|███▎      | 37/111 [00:08<00:15,  4.71it/s][A
Valid:  36%|███▌      | 40/111 [00:09<00:17,  3.99it/s][A
Valid:  40%|███▉      | 44/111 [00:09<00:14,  4.56it/s][A
Valid:  42%|████▏     | 47/111 [00:10<00:13,  4.83it/s][A
Valid:  45%|████▌     | 50/111 [00:11<00:14,  4.26it/s][A
Valid:  

Valid loss: 0.022; Valid accuracy: 0.925


Epochs:  70%|███████   | 7/10 [1:04:16<27:20, 546.75s/it]
Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<07:14,  2.28it/s][A
Train:   0%|          | 4/994 [00:02<07:00,  2.35it/s][A
Train:   1%|          | 6/994 [00:03<08:57,  1.84it/s][A
Train:   1%|          | 8/994 [00:04<08:05,  2.03it/s][A
Train:   1%|          | 10/994 [00:05<08:50,  1.85it/s][A
Train:   1%|          | 11/994 [00:06<08:50,  1.85it/s][A
Train:   1%|          | 12/994 [00:06<10:22,  1.58it/s][A
Train:   1%|▏         | 14/994 [00:07<08:56,  1.83it/s][A
Train:   2%|▏         | 16/994 [00:08<08:10,  1.99it/s][A
Train:   2%|▏         | 18/994 [00:09<07:40,  2.12it/s][A
Train:   2%|▏         | 20/994 [00:10<07:20,  2.21it/s][A
Train:   2%|▏         | 22/994 [00:11<08:06,  2.00it/s][A
Train:   2%|▏         | 24/994 [00:12<07:56,  2.04it/s][A
Train:   3%|▎         | 26/994 [00:13<07:52,  2.05it/s][A
Train:   3%|▎         | 28/994 [00:14<07:30,  2.14it/s][A
Train:   3%

Epoch: 8 Training loss: 0.01



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:01<00:18,  5.88it/s][A
Valid:   5%|▌         | 6/111 [00:02<00:34,  3.02it/s][A
Valid:   7%|▋         | 8/111 [00:02<00:35,  2.89it/s][A
Valid:  10%|▉         | 11/111 [00:03<00:27,  3.68it/s][A
Valid:  13%|█▎        | 14/111 [00:04<00:22,  4.30it/s][A
Valid:  15%|█▌        | 17/111 [00:05<00:27,  3.39it/s][A
Valid:  18%|█▊        | 20/111 [00:05<00:27,  3.33it/s][A
Valid:  21%|██        | 23/111 [00:06<00:22,  3.87it/s][A
Valid:  23%|██▎       | 26/111 [00:07<00:19,  4.32it/s][A
Valid:  26%|██▌       | 29/111 [00:08<00:22,  3.72it/s][A
Valid:  29%|██▉       | 32/111 [00:08<00:22,  3.58it/s][A
Valid:  32%|███▏      | 35/111 [00:09<00:18,  4.08it/s][A
Valid:  35%|███▌      | 39/111 [00:10<00:15,  4.65it/s][A
Valid:  38%|███▊      | 42/111 [00:11<00:16,  4.18it/s][A
Valid:  41%|████      | 45/111 [00:12<00:17,  3.73it/s][A
Valid:  43%|████▎     | 48/111 [00:12<00:17,  3.60it/s][A
Valid:  

Valid loss: 0.021; Valid accuracy: 0.500



Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:01<06:44,  2.45it/s][A
Train:   0%|          | 4/994 [00:02<09:54,  1.66it/s][A
Train:   1%|          | 5/994 [00:03<09:42,  1.70it/s][A
Train:   1%|          | 6/994 [00:03<11:26,  1.44it/s][A
Train:   1%|          | 8/994 [00:05<09:18,  1.77it/s][A
Train:   1%|          | 9/994 [00:06<12:09,  1.35it/s][A
Train:   1%|          | 10/994 [00:07<13:10,  1.25it/s][A
Train:   1%|          | 11/994 [00:08<13:49,  1.19it/s][A
Train:   1%|          | 12/994 [00:08<14:19,  1.14it/s][A
Train:   1%|▏         | 14/994 [00:09<10:56,  1.49it/s][A
Train:   2%|▏         | 15/994 [00:10<10:25,  1.56it/s][A
Train:   2%|▏         | 17/994 [00:11<08:54,  1.83it/s][A
Train:   2%|▏         | 18/994 [00:11<09:56,  1.63it/s][A
Train:   2%|▏         | 20/994 [00:13<08:39,  1.87it/s][A
Train:   2%|▏         | 21/994 [00:13<10:48,  1.50it/s][A
Train:   2%|▏         | 23/994 [00:14<09:13,  1.75it/s][A
Train:   2%

Epoch: 9 Training loss: 0.00



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:18,  5.90it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:25,  4.19it/s][A
Valid:   8%|▊         | 9/111 [00:02<00:27,  3.78it/s][A
Valid:  11%|█         | 12/111 [00:02<00:22,  4.40it/s][A
Valid:  14%|█▎        | 15/111 [00:03<00:20,  4.74it/s][A
Valid:  16%|█▌        | 18/111 [00:04<00:22,  4.10it/s][A
Valid:  19%|█▉        | 21/111 [00:04<00:19,  4.55it/s][A
Valid:  22%|██▏       | 24/111 [00:05<00:17,  4.86it/s][A
Valid:  24%|██▍       | 27/111 [00:06<00:16,  5.14it/s][A
Valid:  27%|██▋       | 30/111 [00:06<00:17,  4.59it/s][A
Valid:  31%|███       | 34/111 [00:07<00:15,  5.03it/s][A
Valid:  33%|███▎      | 37/111 [00:08<00:17,  4.21it/s][A
Valid:  37%|███▋      | 41/111 [00:09<00:14,  4.73it/s][A
Valid:  40%|███▉      | 44/111 [00:10<00:18,  3.67it/s][A
Valid:  42%|████▏     | 47/111 [00:10<00:15,  4.09it/s][A
Valid:  45%|████▌     | 50/111 [00:11<00:13,  4.49it/s][A
Valid:  

Valid loss: 0.017; Valid accuracy: 0.968


Epochs:  90%|█████████ | 9/10 [1:22:09<09:01, 541.29s/it]
Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<07:17,  2.27it/s][A
Train:   0%|          | 4/994 [00:01<07:08,  2.31it/s][A
Train:   1%|          | 6/994 [00:02<07:08,  2.30it/s][A
Train:   1%|          | 8/994 [00:03<07:12,  2.28it/s][A
Train:   1%|          | 10/994 [00:04<07:05,  2.32it/s][A
Train:   1%|          | 12/994 [00:05<08:08,  2.01it/s][A
Train:   1%|▏         | 14/994 [00:06<07:47,  2.10it/s][A
Train:   2%|▏         | 16/994 [00:07<07:26,  2.19it/s][A
Train:   2%|▏         | 18/994 [00:08<08:34,  1.90it/s][A
Train:   2%|▏         | 19/994 [00:09<08:36,  1.89it/s][A
Train:   2%|▏         | 21/994 [00:10<07:57,  2.04it/s][A
Train:   2%|▏         | 23/994 [00:11<09:56,  1.63it/s][A
Train:   3%|▎         | 25/994 [00:12<08:54,  1.81it/s][A
Train:   3%|▎         | 27/994 [00:13<08:10,  1.97it/s][A
Train:   3%|▎         | 29/994 [00:14<08:40,  1.85it/s][A
Train:   3%

Epoch: 10 Training loss: 0.00



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:18,  5.86it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:24,  4.35it/s][A
Valid:   8%|▊         | 9/111 [00:02<00:27,  3.77it/s][A
Valid:  11%|█         | 12/111 [00:02<00:22,  4.40it/s][A
Valid:  14%|█▎        | 15/111 [00:03<00:19,  4.85it/s][A
Valid:  16%|█▌        | 18/111 [00:04<00:24,  3.82it/s][A
Valid:  19%|█▉        | 21/111 [00:05<00:25,  3.58it/s][A
Valid:  21%|██        | 23/111 [00:06<00:26,  3.32it/s][A
Valid:  24%|██▍       | 27/111 [00:06<00:20,  4.06it/s][A
Valid:  28%|██▊       | 31/111 [00:07<00:17,  4.62it/s][A
Valid:  31%|███       | 34/111 [00:08<00:19,  4.02it/s][A
Valid:  33%|███▎      | 37/111 [00:09<00:19,  3.79it/s][A
Valid:  36%|███▌      | 40/111 [00:09<00:16,  4.23it/s][A
Valid:  39%|███▊      | 43/111 [00:10<00:14,  4.61it/s][A
Valid:  41%|████▏     | 46/111 [00:11<00:17,  3.81it/s][A
Valid:  44%|████▍     | 49/111 [00:12<00:17,  3.63it/s][A
Valid:  

In [158]:
 model = AutoModelForSequenceClassification.from_pretrained(f'{MODELS_PATH}/best_checkpoint').to(device)

In [159]:
evaluate(model=model, dataloader=dataloader_valid, set_name='Valid')

Valid: 100%|██████████| 111/111 [00:25<00:00,  4.38it/s]

Valid loss: 0.017; Valid accuracy: 0.968





0.9677601809954751

In [160]:
logit_scores = evaluate_test_dataset2(model=model, dataloader=dataloader_test, set_name='Test')

Test: 100%|██████████| 1563/1563 [06:48<00:00,  3.83it/s]


In [164]:
trec_file_logits = evaluate_ndcg_10(logit_scores, model_name, "logits")

In [165]:
import pandas as pd

run_trec_file_logits = pd.read_csv(trec_file_logits, sep="\t", header=None, 
                   skiprows=1, names=["query", "docid", "rank", "score", "system"])
run_trec_file_logits["q0"] = "q0"
run_trec_file_logits = run_trec_file_logits.to_dict(orient="list")

In [166]:
eval_ndcg10(run_trec_file_logits)

0.6204472079872368

In [168]:
model = train(model_name, 10, lr=3e-5)

Parameters 22713601


Valid: 100%|██████████| 111/111 [00:27<00:00,  3.97it/s]


Valid loss: 38.423; Valid accuracy: 0.521


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]
Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<06:49,  2.42it/s][A
Train:   0%|          | 4/994 [00:01<06:46,  2.44it/s][A
Train:   1%|          | 6/994 [00:02<08:04,  2.04it/s][A
Train:   1%|          | 8/994 [00:03<07:43,  2.13it/s][A
Train:   1%|          | 10/994 [00:05<07:21,  2.23it/s][A
Train:   1%|          | 12/994 [00:06<08:38,  1.89it/s][A
Train:   1%|▏         | 13/994 [00:06<08:40,  1.89it/s][A
Train:   2%|▏         | 15/994 [00:08<08:00,  2.04it/s][A
Train:   2%|▏         | 17/994 [00:08<10:00,  1.63it/s][A
Train:   2%|▏         | 19/994 [00:09<08:58,  1.81it/s][A
Train:   2%|▏         | 21/994 [00:11<08:14,  1.97it/s][A
Train:   2%|▏         | 23/994 [00:12<08:45,  1.85it/s][A
Train:   2%|▏         | 24/994 [00:12<08:46,  1.84it/s][A
Train:   3%|▎         | 26/994 [00:14<08:04,  2.00it/s][A
Train:   3%|▎         | 28/994 [00:14<09:56,  1.62it/s][A
Train:   3%|▎         

Epoch: 1 Training loss: 2.71



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:18,  5.76it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:17,  5.87it/s][A
Valid:   8%|▊         | 9/111 [00:01<00:17,  5.79it/s][A
Valid:  11%|█         | 12/111 [00:02<00:22,  4.44it/s][A
Valid:  14%|█▎        | 15/111 [00:02<00:19,  4.87it/s][A
Valid:  16%|█▌        | 18/111 [00:03<00:18,  5.13it/s][A
Valid:  19%|█▉        | 21/111 [00:04<00:22,  4.06it/s][A
Valid:  22%|██▏       | 24/111 [00:05<00:22,  3.78it/s][A
Valid:  24%|██▍       | 27/111 [00:05<00:19,  4.27it/s][A
Valid:  28%|██▊       | 31/111 [00:06<00:16,  4.81it/s][A
Valid:  31%|███       | 34/111 [00:07<00:18,  4.23it/s][A
Valid:  33%|███▎      | 37/111 [00:08<00:19,  3.89it/s][A
Valid:  36%|███▌      | 40/111 [00:08<00:16,  4.32it/s][A
Valid:  39%|███▊      | 43/111 [00:10<00:14,  4.67it/s][A
Valid:  41%|████▏     | 46/111 [00:10<00:16,  3.87it/s][A
Valid:  44%|████▍     | 49/111 [00:11<00:17,  3.61it/s][A
Valid:  

Valid loss: 0.188; Valid accuracy: 0.605


Epochs:  10%|█         | 1/10 [09:03<1:21:33, 543.68s/it]
Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:01<06:55,  2.39it/s][A
Train:   0%|          | 4/994 [00:02<09:23,  1.76it/s][A
Train:   1%|          | 5/994 [00:02<09:16,  1.78it/s][A
Train:   1%|          | 7/994 [00:03<08:08,  2.02it/s][A
Train:   1%|          | 9/994 [00:04<07:37,  2.15it/s][A
Train:   1%|          | 11/994 [00:05<07:20,  2.23it/s][A
Train:   1%|▏         | 13/994 [00:06<07:28,  2.19it/s][A
Train:   2%|▏         | 15/994 [00:07<07:36,  2.15it/s][A
Train:   2%|▏         | 17/994 [00:08<07:20,  2.22it/s][A
Train:   2%|▏         | 19/994 [00:09<08:05,  2.01it/s][A
Train:   2%|▏         | 21/994 [00:10<07:59,  2.03it/s][A
Train:   2%|▏         | 23/994 [00:11<07:37,  2.12it/s][A
Train:   3%|▎         | 25/994 [00:12<08:10,  1.97it/s][A
Train:   3%|▎         | 27/994 [00:13<08:04,  2.00it/s][A
Train:   3%|▎         | 29/994 [00:14<07:39,  2.10it/s][A
Train:   3%|

Epoch: 2 Training loss: 0.12



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:18,  5.75it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:24,  4.32it/s][A
Valid:   8%|▊         | 9/111 [00:02<00:27,  3.75it/s][A
Valid:  11%|█         | 12/111 [00:02<00:22,  4.39it/s][A
Valid:  14%|█▎        | 15/111 [00:03<00:19,  4.83it/s][A
Valid:  16%|█▌        | 18/111 [00:04<00:23,  3.89it/s][A
Valid:  19%|█▉        | 21/111 [00:05<00:25,  3.55it/s][A
Valid:  21%|██        | 23/111 [00:06<00:26,  3.28it/s][A
Valid:  23%|██▎       | 26/111 [00:06<00:22,  3.75it/s][A
Valid:  26%|██▌       | 29/111 [00:07<00:19,  4.22it/s][A
Valid:  29%|██▉       | 32/111 [00:08<00:22,  3.57it/s][A
Valid:  31%|███       | 34/111 [00:09<00:23,  3.32it/s][A
Valid:  33%|███▎      | 37/111 [00:09<00:19,  3.84it/s][A
Valid:  36%|███▌      | 40/111 [00:10<00:16,  4.34it/s][A
Valid:  39%|███▊      | 43/111 [00:11<00:19,  3.49it/s][A
Valid:  41%|████▏     | 46/111 [00:12<00:19,  3.40it/s][A
Valid:  

Valid loss: 0.169; Valid accuracy: 0.666


Epochs:  20%|██        | 2/10 [18:22<1:13:40, 552.54s/it]
Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<07:00,  2.36it/s][A
Train:   0%|          | 4/994 [00:01<06:56,  2.37it/s][A
Train:   1%|          | 6/994 [00:02<06:55,  2.38it/s][A
Train:   1%|          | 8/994 [00:03<08:06,  2.03it/s][A
Train:   1%|          | 10/994 [00:04<08:02,  2.04it/s][A
Train:   1%|          | 12/994 [00:05<07:37,  2.15it/s][A
Train:   1%|▏         | 14/994 [00:06<08:14,  1.98it/s][A
Train:   2%|▏         | 16/994 [00:07<08:10,  2.00it/s][A
Train:   2%|▏         | 18/994 [00:08<07:46,  2.09it/s][A
Train:   2%|▏         | 20/994 [00:09<08:17,  1.96it/s][A
Train:   2%|▏         | 22/994 [00:10<07:50,  2.07it/s][A
Train:   2%|▏         | 24/994 [00:11<08:14,  1.96it/s][A
Train:   3%|▎         | 26/994 [00:12<07:55,  2.04it/s][A
Train:   3%|▎         | 28/994 [00:13<07:32,  2.13it/s][A
Train:   3%|▎         | 30/994 [00:14<07:16,  2.21it/s][A
Train:   3%

Epoch: 3 Training loss: 0.08



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:18,  5.83it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:27,  3.75it/s][A
Valid:   8%|▊         | 9/111 [00:02<00:28,  3.54it/s][A
Valid:  12%|█▏        | 13/111 [00:03<00:22,  4.36it/s][A
Valid:  14%|█▍        | 16/111 [00:03<00:19,  4.77it/s][A
Valid:  17%|█▋        | 19/111 [00:04<00:21,  4.19it/s][A
Valid:  20%|█▉        | 22/111 [00:05<00:23,  3.82it/s][A
Valid:  23%|██▎       | 25/111 [00:05<00:20,  4.29it/s][A
Valid:  25%|██▌       | 28/111 [00:07<00:17,  4.68it/s][A
Valid:  28%|██▊       | 31/111 [00:07<00:20,  3.88it/s][A
Valid:  31%|███       | 34/111 [00:08<00:20,  3.70it/s][A
Valid:  33%|███▎      | 37/111 [00:08<00:17,  4.13it/s][A
Valid:  36%|███▌      | 40/111 [00:09<00:15,  4.55it/s][A
Valid:  39%|███▊      | 43/111 [00:10<00:17,  3.86it/s][A
Valid:  41%|████▏     | 46/111 [00:11<00:17,  3.65it/s][A
Valid:  44%|████▍     | 49/111 [00:11<00:15,  4.09it/s][A
Valid:  

Valid loss: 0.135; Valid accuracy: 0.722


Epochs:  30%|███       | 3/10 [27:32<1:04:18, 551.22s/it]
Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<06:51,  2.41it/s][A
Train:   0%|          | 4/994 [00:02<06:52,  2.40it/s][A
Train:   1%|          | 6/994 [00:02<08:13,  2.00it/s][A
Train:   1%|          | 8/994 [00:03<08:00,  2.05it/s][A
Train:   1%|          | 10/994 [00:04<07:59,  2.05it/s][A
Train:   1%|          | 12/994 [00:05<07:36,  2.15it/s][A
Train:   1%|▏         | 14/994 [00:06<08:12,  1.99it/s][A
Train:   2%|▏         | 16/994 [00:08<07:46,  2.10it/s][A
Train:   2%|▏         | 18/994 [00:08<08:29,  1.91it/s][A
Train:   2%|▏         | 19/994 [00:09<08:31,  1.91it/s][A
Train:   2%|▏         | 20/994 [00:10<10:03,  1.61it/s][A
Train:   2%|▏         | 22/994 [00:12<08:49,  1.84it/s][A
Train:   2%|▏         | 23/994 [00:12<11:18,  1.43it/s][A
Train:   2%|▏         | 24/994 [00:13<11:56,  1.35it/s][A
Train:   3%|▎         | 25/994 [00:14<12:51,  1.26it/s][A
Train:   3%

Epoch: 4 Training loss: 0.06



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:18,  5.77it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:26,  3.98it/s][A
Valid:   8%|▊         | 9/111 [00:02<00:30,  3.40it/s][A
Valid:  10%|▉         | 11/111 [00:03<00:31,  3.16it/s][A
Valid:  13%|█▎        | 14/111 [00:03<00:25,  3.75it/s][A
Valid:  15%|█▌        | 17/111 [00:04<00:21,  4.29it/s][A
Valid:  18%|█▊        | 20/111 [00:05<00:24,  3.67it/s][A
Valid:  21%|██        | 23/111 [00:06<00:21,  4.17it/s][A
Valid:  23%|██▎       | 26/111 [00:06<00:20,  4.05it/s][A
Valid:  26%|██▌       | 29/111 [00:07<00:18,  4.49it/s][A
Valid:  29%|██▉       | 32/111 [00:08<00:22,  3.45it/s][A
Valid:  32%|███▏      | 35/111 [00:09<00:22,  3.34it/s][A
Valid:  33%|███▎      | 37/111 [00:10<00:23,  3.15it/s][A
Valid:  36%|███▌      | 40/111 [00:10<00:19,  3.63it/s][A
Valid:  39%|███▊      | 43/111 [00:11<00:16,  4.02it/s][A
Valid:  41%|████▏     | 46/111 [00:11<00:14,  4.44it/s][A
Valid:  

Valid loss: 0.141; Valid accuracy: 0.741


Epochs:  40%|████      | 4/10 [36:43<55:07, 551.25s/it]  
Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<07:55,  2.09it/s][A
Train:   0%|          | 4/994 [00:01<07:20,  2.25it/s][A
Train:   1%|          | 6/994 [00:02<07:32,  2.18it/s][A
Train:   1%|          | 8/994 [00:03<07:13,  2.28it/s][A
Train:   1%|          | 10/994 [00:04<08:17,  1.98it/s][A
Train:   1%|          | 12/994 [00:05<08:09,  2.01it/s][A
Train:   1%|▏         | 14/994 [00:07<07:42,  2.12it/s][A
Train:   2%|▏         | 16/994 [00:07<08:24,  1.94it/s][A
Train:   2%|▏         | 17/994 [00:08<08:28,  1.92it/s][A
Train:   2%|▏         | 18/994 [00:09<10:01,  1.62it/s][A
Train:   2%|▏         | 20/994 [00:11<08:47,  1.84it/s][A
Train:   2%|▏         | 21/994 [00:11<11:23,  1.42it/s][A
Train:   2%|▏         | 22/994 [00:12<11:54,  1.36it/s][A
Train:   2%|▏         | 24/994 [00:14<09:52,  1.64it/s][A
Train:   3%|▎         | 25/994 [00:14<12:14,  1.32it/s][A
Train:   3%

Epoch: 5 Training loss: 0.04



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:19,  5.67it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:18,  5.82it/s][A
Valid:   8%|▊         | 9/111 [00:02<00:17,  5.82it/s][A
Valid:  11%|█         | 12/111 [00:03<00:23,  4.19it/s][A
Valid:  14%|█▎        | 15/111 [00:03<00:26,  3.62it/s][A
Valid:  16%|█▌        | 18/111 [00:04<00:26,  3.47it/s][A
Valid:  19%|█▉        | 21/111 [00:05<00:22,  4.01it/s][A
Valid:  22%|██▏       | 24/111 [00:06<00:19,  4.44it/s][A
Valid:  24%|██▍       | 27/111 [00:07<00:21,  3.85it/s][A
Valid:  27%|██▋       | 30/111 [00:07<00:23,  3.51it/s][A
Valid:  29%|██▉       | 32/111 [00:08<00:23,  3.30it/s][A
Valid:  32%|███▏      | 35/111 [00:08<00:20,  3.79it/s][A
Valid:  34%|███▍      | 38/111 [00:09<00:17,  4.28it/s][A
Valid:  37%|███▋      | 41/111 [00:09<00:15,  4.61it/s][A
Valid:  40%|███▉      | 44/111 [00:11<00:13,  4.92it/s][A
Valid:  42%|████▏     | 47/111 [00:12<00:16,  3.85it/s][A
Valid:  

Valid loss: 0.083; Valid accuracy: 0.773


Epochs:  50%|█████     | 5/10 [46:03<46:12, 554.52s/it]
Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:01<06:49,  2.42it/s][A
Train:   0%|          | 4/994 [00:02<09:53,  1.67it/s][A
Train:   1%|          | 5/994 [00:03<09:27,  1.74it/s][A
Train:   1%|          | 6/994 [00:03<11:10,  1.47it/s][A
Train:   1%|          | 8/994 [00:05<09:11,  1.79it/s][A
Train:   1%|          | 9/994 [00:06<12:13,  1.34it/s][A
Train:   1%|          | 10/994 [00:06<12:39,  1.30it/s][A
Train:   1%|          | 12/994 [00:07<10:10,  1.61it/s][A
Train:   1%|▏         | 14/994 [00:08<08:58,  1.82it/s][A
Train:   2%|▏         | 16/994 [00:09<08:17,  1.96it/s][A
Train:   2%|▏         | 18/994 [00:10<07:47,  2.09it/s][A
Train:   2%|▏         | 20/994 [00:11<07:28,  2.17it/s][A
Train:   2%|▏         | 22/994 [00:12<08:10,  1.98it/s][A
Train:   2%|▏         | 24/994 [00:13<08:01,  2.01it/s][A
Train:   3%|▎         | 26/994 [00:14<07:36,  2.12it/s][A
Train:   3%|▎  

Epoch: 6 Training loss: 0.03



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:19,  5.68it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:18,  5.78it/s][A
Valid:   8%|▊         | 9/111 [00:02<00:23,  4.26it/s][A
Valid:  11%|█         | 12/111 [00:03<00:27,  3.60it/s][A
Valid:  13%|█▎        | 14/111 [00:03<00:29,  3.28it/s][A
Valid:  15%|█▌        | 17/111 [00:04<00:24,  3.81it/s][A
Valid:  18%|█▊        | 20/111 [00:05<00:21,  4.33it/s][A
Valid:  21%|██        | 23/111 [00:05<00:24,  3.57it/s][A
Valid:  23%|██▎       | 26/111 [00:06<00:20,  4.09it/s][A
Valid:  26%|██▌       | 29/111 [00:07<00:18,  4.50it/s][A
Valid:  29%|██▉       | 32/111 [00:08<00:20,  3.93it/s][A
Valid:  32%|███▏      | 35/111 [00:09<00:21,  3.54it/s][A
Valid:  33%|███▎      | 37/111 [00:09<00:22,  3.30it/s][A
Valid:  36%|███▌      | 40/111 [00:10<00:18,  3.78it/s][A
Valid:  39%|███▊      | 43/111 [00:11<00:15,  4.25it/s][A
Valid:  41%|████▏     | 46/111 [00:11<00:17,  3.62it/s][A
Valid:  

Valid loss: 0.048; Valid accuracy: 0.805


Epochs:  60%|██████    | 6/10 [55:20<37:01, 555.28s/it]
Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:01<06:53,  2.40it/s][A
Train:   0%|          | 4/994 [00:02<10:42,  1.54it/s][A
Train:   1%|          | 5/994 [00:02<09:59,  1.65it/s][A
Train:   1%|          | 7/994 [00:03<08:31,  1.93it/s][A
Train:   1%|          | 9/994 [00:04<08:01,  2.05it/s][A
Train:   1%|          | 11/994 [00:05<07:40,  2.13it/s][A
Train:   1%|▏         | 13/994 [00:06<07:26,  2.20it/s][A
Train:   2%|▏         | 15/994 [00:07<07:15,  2.25it/s][A
Train:   2%|▏         | 17/994 [00:08<08:08,  2.00it/s][A
Train:   2%|▏         | 19/994 [00:09<08:01,  2.03it/s][A
Train:   2%|▏         | 21/994 [00:10<07:56,  2.04it/s][A
Train:   2%|▏         | 23/994 [00:11<07:34,  2.14it/s][A
Train:   3%|▎         | 25/994 [00:12<08:07,  1.99it/s][A
Train:   3%|▎         | 27/994 [00:13<07:59,  2.02it/s][A
Train:   3%|▎         | 29/994 [00:14<07:35,  2.12it/s][A
Train:   3%|▎ 

Epoch: 7 Training loss: 0.03



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:19,  5.66it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:18,  5.72it/s][A
Valid:   8%|▊         | 9/111 [00:02<00:27,  3.73it/s][A
Valid:  11%|█         | 12/111 [00:03<00:28,  3.53it/s][A
Valid:  14%|█▎        | 15/111 [00:03<00:23,  4.12it/s][A
Valid:  16%|█▌        | 18/111 [00:04<00:20,  4.56it/s][A
Valid:  19%|█▉        | 21/111 [00:05<00:22,  3.92it/s][A
Valid:  22%|██▏       | 24/111 [00:06<00:24,  3.53it/s][A
Valid:  23%|██▎       | 26/111 [00:06<00:25,  3.28it/s][A
Valid:  26%|██▌       | 29/111 [00:07<00:21,  3.75it/s][A
Valid:  29%|██▉       | 32/111 [00:08<00:18,  4.24it/s][A
Valid:  32%|███▏      | 35/111 [00:09<00:21,  3.61it/s][A
Valid:  34%|███▍      | 38/111 [00:09<00:17,  4.12it/s][A
Valid:  37%|███▋      | 41/111 [00:10<00:15,  4.54it/s][A
Valid:  40%|███▉      | 44/111 [00:11<00:17,  3.85it/s][A
Valid:  42%|████▏     | 47/111 [00:12<00:18,  3.50it/s][A
Valid:  

Valid loss: 0.047; Valid accuracy: 0.806


Epochs:  70%|███████   | 7/10 [1:04:42<27:52, 557.59s/it]
Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:01<06:52,  2.40it/s][A
Train:   0%|          | 4/994 [00:02<08:59,  1.84it/s][A
Train:   1%|          | 5/994 [00:02<08:59,  1.83it/s][A
Train:   1%|          | 7/994 [00:04<07:59,  2.06it/s][A
Train:   1%|          | 9/994 [00:05<09:58,  1.64it/s][A
Train:   1%|          | 11/994 [00:06<08:49,  1.86it/s][A
Train:   1%|          | 12/994 [00:07<09:45,  1.68it/s][A
Train:   1%|▏         | 13/994 [00:08<11:05,  1.47it/s][A
Train:   1%|▏         | 14/994 [00:08<12:14,  1.33it/s][A
Train:   2%|▏         | 16/994 [00:10<10:02,  1.62it/s][A
Train:   2%|▏         | 17/994 [00:11<12:12,  1.33it/s][A
Train:   2%|▏         | 18/994 [00:12<13:01,  1.25it/s][A
Train:   2%|▏         | 19/994 [00:12<13:39,  1.19it/s][A
Train:   2%|▏         | 21/994 [00:13<10:43,  1.51it/s][A
Train:   2%|▏         | 22/994 [00:14<10:20,  1.57it/s][A
Train:   2%|

Epoch: 8 Training loss: 0.02



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:19,  5.54it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:18,  5.62it/s][A
Valid:   8%|▊         | 9/111 [00:02<00:18,  5.65it/s][A
Valid:  11%|█         | 12/111 [00:03<00:23,  4.24it/s][A
Valid:  14%|█▎        | 15/111 [00:03<00:26,  3.66it/s][A
Valid:  16%|█▌        | 18/111 [00:04<00:26,  3.51it/s][A
Valid:  19%|█▉        | 21/111 [00:05<00:22,  4.03it/s][A
Valid:  22%|██▏       | 24/111 [00:06<00:19,  4.43it/s][A
Valid:  24%|██▍       | 27/111 [00:06<00:22,  3.77it/s][A
Valid:  27%|██▋       | 30/111 [00:07<00:22,  3.67it/s][A
Valid:  30%|██▉       | 33/111 [00:08<00:18,  4.12it/s][A
Valid:  32%|███▏      | 36/111 [00:09<00:16,  4.50it/s][A
Valid:  35%|███▌      | 39/111 [00:09<00:18,  3.80it/s][A
Valid:  38%|███▊      | 42/111 [00:10<00:18,  3.67it/s][A
Valid:  41%|████      | 45/111 [00:11<00:15,  4.14it/s][A
Valid:  43%|████▎     | 48/111 [00:11<00:13,  4.50it/s][A
Valid:  

Valid loss: 0.034; Valid accuracy: 0.887


Epochs:  80%|████████  | 8/10 [1:14:11<18:42, 561.09s/it]
Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<06:50,  2.42it/s][A
Train:   0%|          | 4/994 [00:01<07:30,  2.20it/s][A
Train:   1%|          | 6/994 [00:02<07:45,  2.12it/s][A
Train:   1%|          | 8/994 [00:03<07:25,  2.21it/s][A
Train:   1%|          | 10/994 [00:04<08:06,  2.02it/s][A
Train:   1%|          | 12/994 [00:05<08:01,  2.04it/s][A
Train:   1%|▏         | 14/994 [00:06<07:37,  2.14it/s][A
Train:   2%|▏         | 16/994 [00:07<08:16,  1.97it/s][A
Train:   2%|▏         | 18/994 [00:08<08:09,  2.00it/s][A
Train:   2%|▏         | 20/994 [00:09<07:46,  2.09it/s][A
Train:   2%|▏         | 22/994 [00:10<08:18,  1.95it/s][A
Train:   2%|▏         | 23/994 [00:11<08:25,  1.92it/s][A
Train:   3%|▎         | 25/994 [00:13<07:52,  2.05it/s][A
Train:   3%|▎         | 27/994 [00:13<09:47,  1.65it/s][A
Train:   3%|▎         | 29/994 [00:14<08:50,  1.82it/s][A
Train:   3%

Epoch: 9 Training loss: 0.02



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:01<00:19,  5.68it/s][A
Valid:   5%|▌         | 6/111 [00:02<00:33,  3.12it/s][A
Valid:   7%|▋         | 8/111 [00:03<00:36,  2.79it/s][A
Valid:   9%|▉         | 10/111 [00:03<00:39,  2.55it/s][A
Valid:  12%|█▏        | 13/111 [00:04<00:30,  3.26it/s][A
Valid:  14%|█▍        | 16/111 [00:05<00:24,  3.89it/s][A
Valid:  17%|█▋        | 19/111 [00:05<00:26,  3.46it/s][A
Valid:  20%|█▉        | 22/111 [00:06<00:22,  3.97it/s][A
Valid:  23%|██▎       | 25/111 [00:06<00:21,  3.95it/s][A
Valid:  25%|██▌       | 28/111 [00:08<00:18,  4.39it/s][A
Valid:  28%|██▊       | 31/111 [00:09<00:23,  3.47it/s][A
Valid:  31%|███       | 34/111 [00:09<00:22,  3.40it/s][A
Valid:  33%|███▎      | 37/111 [00:10<00:19,  3.88it/s][A
Valid:  36%|███▌      | 40/111 [00:11<00:16,  4.33it/s][A
Valid:  39%|███▊      | 43/111 [00:12<00:18,  3.72it/s][A
Valid:  41%|████▏     | 46/111 [00:13<00:18,  3.49it/s][A
Valid:  

Valid loss: 0.031; Valid accuracy: 0.855



Train:   0%|          | 0/994 [00:00<?, ?it/s][A
Train:   0%|          | 2/994 [00:00<06:51,  2.41it/s][A
Train:   0%|          | 4/994 [00:02<06:51,  2.40it/s][A
Train:   1%|          | 6/994 [00:03<08:29,  1.94it/s][A
Train:   1%|          | 8/994 [00:04<08:15,  1.99it/s][A
Train:   1%|          | 10/994 [00:05<08:09,  2.01it/s][A
Train:   1%|          | 12/994 [00:05<08:04,  2.03it/s][A
Train:   1%|▏         | 14/994 [00:06<07:50,  2.08it/s][A
Train:   2%|▏         | 16/994 [00:08<07:31,  2.16it/s][A
Train:   2%|▏         | 18/994 [00:09<08:34,  1.90it/s][A
Train:   2%|▏         | 19/994 [00:10<08:38,  1.88it/s][A
Train:   2%|▏         | 20/994 [00:11<10:07,  1.60it/s][A
Train:   2%|▏         | 21/994 [00:11<11:19,  1.43it/s][A
Train:   2%|▏         | 23/994 [00:12<09:27,  1.71it/s][A
Train:   2%|▏         | 24/994 [00:12<09:26,  1.71it/s][A
Train:   3%|▎         | 26/994 [00:13<08:36,  1.87it/s][A
Train:   3%|▎         | 28/994 [00:15<07:56,  2.03it/s][A
Train:   

Epoch: 10 Training loss: 0.02



Valid:   0%|          | 0/111 [00:00<?, ?it/s][A
Valid:   3%|▎         | 3/111 [00:00<00:19,  5.66it/s][A
Valid:   5%|▌         | 6/111 [00:01<00:18,  5.74it/s][A
Valid:   8%|▊         | 9/111 [00:02<00:17,  5.77it/s][A
Valid:  11%|█         | 12/111 [00:03<00:24,  4.11it/s][A
Valid:  14%|█▎        | 15/111 [00:03<00:25,  3.70it/s][A
Valid:  16%|█▌        | 18/111 [00:04<00:26,  3.51it/s][A
Valid:  19%|█▉        | 21/111 [00:05<00:22,  4.00it/s][A
Valid:  22%|██▏       | 24/111 [00:06<00:19,  4.39it/s][A
Valid:  24%|██▍       | 27/111 [00:06<00:22,  3.74it/s][A
Valid:  27%|██▋       | 30/111 [00:07<00:22,  3.58it/s][A
Valid:  29%|██▉       | 32/111 [00:08<00:23,  3.32it/s][A
Valid:  32%|███▏      | 35/111 [00:08<00:20,  3.78it/s][A
Valid:  34%|███▍      | 38/111 [00:09<00:17,  4.16it/s][A
Valid:  37%|███▋      | 41/111 [00:09<00:15,  4.57it/s][A
Valid:  40%|███▉      | 44/111 [00:10<00:13,  4.89it/s][A
Valid:  42%|████▏     | 47/111 [00:11<00:12,  5.08it/s][A
Valid:  

Valid loss: 0.028; Valid accuracy: 0.848





In [169]:
 model = AutoModelForSequenceClassification.from_pretrained(f'{MODELS_PATH}/best_checkpoint').to(device)

In [170]:
evaluate(model=model, dataloader=dataloader_valid, set_name='Valid')

Valid: 100%|██████████| 111/111 [00:29<00:00,  3.82it/s]

Valid loss: 0.034; Valid accuracy: 0.887





0.8871606334841629

In [171]:
logit_scores = evaluate_test_dataset2(model=model, dataloader=dataloader_test, set_name='Test')

Test: 100%|██████████| 1563/1563 [06:50<00:00,  3.81it/s]


In [172]:
trec_file_logits = evaluate_ndcg_10(logit_scores, model_name, "logits")

In [173]:
import pandas as pd

run_trec_file_logits = pd.read_csv(trec_file_logits, sep="\t", header=None, 
                   skiprows=1, names=["query", "docid", "rank", "score", "system"])
run_trec_file_logits["q0"] = "q0"
run_trec_file_logits = run_trec_file_logits.to_dict(orient="list")

In [174]:
eval_ndcg10(run_trec_file_logits)

0.6751980134653743

## Checking microsoft/MiniLM-L12-H384-uncased in the full dataset

In [19]:
model_name = 'microsoft/MiniLM-L12-H384-uncased'

In [20]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [21]:
X_train, Y_train, X_val, Y_val = generate_training_data("concatenated.jsonl", model_name)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":0.0},
  df_pos = df_pos.append({"query":row[0], "passage":row[1], "score":1.0},
  df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":0.0},
  df_pos = df_pos.append({"query":row[0], "passage":row[1], "score":1.0},
  df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":0.0},
  df_pos = df_pos.append({"query":row[0], "passage":row[1], "score":1.0},
  df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":0.0},
  df_pos = df_pos.append({"query":row[0], "passage":row[1], "score":1.0},
  df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":0.0},
  df_pos = df_pos.append({"query":row[0], "passage":row[1], "score":1.0},
  df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":0.0},
  df_pos = df_pos.append({"query":row[0], "passage":row[1], "score":1.0},
  df_neg = df_neg.append({"query":row[0], "pass

microsoft/MiniLM-L12-H384-uncased
                                               query  \
0  What is the most suitable protein for a diagno...   
1  What is Cryptosporidium parvum and why is it a...   
2  What is the role of the renin-angiotensin-aldo...   
3  What are the functions of individual endolysos...   
4  What is the prevalence of olfactory dysfunctio...   

                                             passage  score  
0  Rapid identification of novel antigens of Salm...    1.0  
1  Cryptosporidium and host resistance: historica...    1.0  
2  Understanding the Renin-Angiotensin-Aldosteron...    1.0  
3  Specific functions of lysosomal proteases in e...    1.0  
4  Olfactory and rhinological evaluations in SARS...    1.0  


  df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":0.0},
  df_pos = df_pos.append({"query":row[0], "passage":row[1], "score":1.0},
  df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":0.0},
  df_pos = df_pos.append({"query":row[0], "passage":row[1], "score":1.0},
  df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":0.0},
  df_pos = df_pos.append({"query":row[0], "passage":row[1], "score":1.0},
  df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":0.0},
  df_pos = df_pos.append({"query":row[0], "passage":row[1], "score":1.0},
  df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":0.0},
  df_pos = df_pos.append({"query":row[0], "passage":row[1], "score":1.0},
  df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":0.0},
  df_pos = df_pos.append({"query":row[0], "passage":row[1], "score":1.0},
  df_neg = df_neg.append({"query":row[0], "passage":row[2], "score":0.0},
  df_pos = df_pos.append({"query":row[

In [22]:
train_queries = list(X_train["query"])
train_passages = list(X_train["passage"])
val_queries = list(X_val["query"])
val_passages = list(X_val["passage"])

train_queries_tokenized = tokenizer(train_queries, truncation=True, max_length=max_length_query)
train_passages_tokenized = tokenizer(train_passages, truncation=True, max_length=max_length_passage)
val_queries_tokenized = tokenizer(val_queries, truncation=True, max_length=max_length_query)
val_passages_tokenized = tokenizer(val_passages, truncation=True, max_length=max_length_passage)

In [23]:
dataset_train = Dataset(train_queries_tokenized, train_passages_tokenized, Y_train)
assert len(dataset_train[0]['input_ids']) > 0
assert len(dataset_train[1]['attention_mask']) > 0

In [24]:
dataset_valid = Dataset(val_queries_tokenized, val_passages_tokenized, Y_val)

In [25]:
len(dataset_train)

31808

In [26]:
len(dataset_valid)

3536

In [27]:
#dataloader_train = data.DataLoader(dataset_train, batch_size=32, shuffle=True, collate_fn=collate_fn)
#dataloader_valid = data.DataLoader(dataset_valid, batch_size=32, shuffle=False, collate_fn=collate_fn)
dataloader_train = data.DataLoader(dataset_train, batch_size=16, shuffle=True, collate_fn=collate_fn)
dataloader_valid = data.DataLoader(dataset_valid, batch_size=16, shuffle=False, collate_fn=collate_fn)


for batch in dataloader_train:
    assert batch['input_ids'].shape[0] <= dataloader_train.batch_size
    assert batch['input_ids'].shape[1] <= max_length
    break

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [None]:
model = train(model_name, 5)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/MiniLM-L12-H384-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Parameters 33360770


Valid:   0%|          | 0/221 [00:00<?, ?it/s]

Valid loss: 0.694; Valid accuracy: 0.500


Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Train:   0%|          | 0/1988 [00:00<?, ?it/s]