In [7]:
from sentence_transformers import SentenceTransformer, InputExample
from sentence_transformers.cross_encoder import CrossEncoder
import gzip
import json
import random
from torch.utils.data import Dataset, IterableDataset
import logging
from torch.utils.data import DataLoader
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
from sentence_transformers.cross_encoder.evaluation import CERerankingEvaluator
import math
from sentence_transformers import LoggingHandler, util
import torch.nn

In [40]:
def get_example(raw_example):
    if isinstance(raw_example, dict):
        return InputExample(texts=[raw_example['query'], random.choice(raw_example['pos'])], label=random.randint(0, 3))
    else:
        return InputExample(texts=[raw_example[0], raw_example[1]], label=random.randint(0, 3))
        
def load_pair_dataset(filepath):
    examples=[]
    with gzip.open(filepath, 'rt') as fIn:
            for line in fIn:
                example = get_example(json.loads(line))
                examples.append(example)
    return examples

full_set = load_pair_dataset("/Users/g.salazar.2/git/trec_dh/gustavo/gooaq_pairs.jsonl.gz")

In [42]:
len(full_set)

3012496

In [41]:
from sklearn.model_selection import train_test_split
(train_set, test_set) = train_test_split(full_set, test_size=0.33, random_state=42, shuffle=False)

In [43]:
len(train_set), len(test_set)

(2018372, 994124)

In [None]:

logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

train_batch_size = 4
num_epochs = 1
num_labels = 4
max_length = 512
evaluation_steps = 100
lr = 7e-6

train_dataloader = DataLoader(train_set, shuffle=True, batch_size=train_batch_size)
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1) #10% of train data for warm-up

default_activation_function = torch.nn.Identity()

model = CrossEncoder('microsoft/deberta-v3-large', num_labels=num_labels, 
                     tokenizer_args={'pad_token': '[PAD]'}, 
                     default_activation_function=default_activation_function)

evaluator = CERerankingEvaluator(test_set, name='train-eval')


logger.info("Warmup-steps: {}".format(warmup_steps))
loss_fct=torch.nn.L1Loss()

model.config.pad_token_id = model.tokenizer.pad_token_id

model.fit(train_dataloader=train_dataloader,
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=evaluation_steps,
          warmup_steps=warmup_steps,
          optimizer_params={'lr': lr},
          output_path="model_saved")



Downloading (…)lve/main/config.json:   0%|          | 0.00/580 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/874M [00:00<?, ?B/s]

Some weights of the model checkpoint at microsoft/deberta-v3-large were not used when initializing DebertaV2ForSequenceClassification: ['mask_predictions.dense.weight', 'mask_predictions.dense.bias', 'mask_predictions.LayerNorm.bias', 'lm_predictions.lm_head.dense.weight', 'lm_predictions.lm_head.bias', 'mask_predictions.classifier.bias', 'lm_predictions.lm_head.LayerNorm.bias', 'mask_predictions.LayerNorm.weight', 'lm_predictions.lm_head.LayerNorm.weight', 'mask_predictions.classifier.weight', 'lm_predictions.lm_head.dense.bias']
- This IS expected if you are initializing DebertaV2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaV2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from 

Downloading (…)okenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

Downloading spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2023-06-04 03:19:45 - Use pytorch device: cpu
2023-06-04 03:19:45 - Warmup-steps: 50460


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/504593 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
