In [1]:
#source: https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/cross-encoder/training_quora_duplicate_questions.py
# Adapted from the sentence-transformers Quora duplicate-questions cross-encoder example; the data here is a
# locally constructed dataset (see the input file paths below), not Quora.

"""
This example trains a CrossEncoder on a binary relatedness task over (clariQ, fascet) text pairs,
adapted from the Quora Duplicate Questions Detection example. A CrossEncoder takes a sentence pair
as input and outputs a label. Here, it outputs a continuous label in 0...1 to indicate the similarity between the input pair.
It does NOT produce a sentence embedding and does NOT work for individual sentences.
Usage:
run the notebook cells top to bottom (the original script was run as: python training_quora_duplicate_questions.py)
"""

'\nThis examples trains a CrossEncoder for the Quora Duplicate Questions Detection task. A CrossEncoder takes a sentence pair\nas input and outputs a label. Here, it output a continious labels 0...1 to indicate the similarity between the input pair.\nIt does NOT produce a sentence embedding and does NOT work for individual sentences.\nUsage:\npython training_quora_duplicate_questions.py\n'

In [2]:
import csv
import gzip
import logging
import math
import os
import random
from datetime import datetime
from zipfile import ZipFile

import torch
from torch.utils.data import DataLoader
from sentence_transformers import LoggingHandler, util
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
from sentence_transformers.readers import InputExample

In [3]:
#### Logging setup: INFO-level records, timestamped to the second, emitted through
#### sentence-transformers' LoggingHandler so training progress shows up in the cell output.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)
logger = logging.getLogger(__name__)
#### end of logging setup

In [4]:
#Input files: train/test splits of the constructed CrossEncoder dataset,
#resolved relative to the current working directory.
DATA_DIR = '../Data/Crossencoder'
train_dataset_file = DATA_DIR + '/DF_Constructed_CrossEncoder_train.csv'
test_dataset_file = DATA_DIR + '/DF_Constructed_CrossEncoder_test.csv'

In [5]:
# Read the quora dataset split for classification
logger.info("Read train dataset")
train_samples = []

with open(test_dataset_file, 'r', encoding='utf8') as fIn:
    reader = csv.reader(fIn, delimiter=',', quotechar='"')
    next(reader, None)
    for row in reader:
        try:
            train_samples.append(InputExample(texts=[row[1], row[2]], label=int(row[3])))
        except:
            print(row)
                

2022-06-18 14:54:09 - Read train dataset


In [6]:
# Read the quora dataset split for classification
logger.info("Read dev dataset")
dev_samples = []

with open(train_dataset_file, 'r', encoding='utf8') as fIn:
    reader = csv.reader(fIn, delimiter=',', quotechar='"')
    for row in reader:
        try:
            dev_samples.append(InputExample(texts=[row[1], row[2]], label=int(row[3])))
        except:
            print(row)
                

2022-06-18 14:54:09 - Read dev dataset
['', 'clariQ', 'fascet', 'isRelated']


In [7]:
#Configuration
# Fixed seed so the newly initialized classifier head (see the model-loading warning) and the
# DataLoader shuffle order are repeatable across Restart & Run All. Must run before the model is built.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

train_batch_size = 64
num_epochs = 10
# Timestamped so separate runs write to separate checkpoint directories
model_save_path = '../Models/training_F0_7-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")



In [8]:
# Cross-encoder on top of roberta-base with a single output label, i.e. it outputs one value
# between 0 and 1 indicating how related the two texts of a pair are.
# 'distilroberta-base' is a smaller, faster alternative backbone.
model_name = 'roberta-base'
model = CrossEncoder(model_name, num_labels=1)


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.weight', 'roberta.pooler.dense.weight', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

2022-06-18 14:54:18 - Use pytorch device: cpu


In [9]:
# Batch the training pairs (a List[InputExample]) for fit(); the order is reshuffled every epoch.
train_dataloader = DataLoader(
    train_samples,
    batch_size=train_batch_size,
    shuffle=True,
)

In [10]:
# Dev-set evaluator used during training: reports accuracy, F1, precision, recall and
# average precision on the dev pairs; `name` labels its log lines.
evaluator = CEBinaryClassificationEvaluator.from_input_examples(
    dev_samples,
    name='Quora-dev',
)

In [11]:
# Configure the training: warm up over the first 10% of all training steps (batches x epochs).
total_train_steps = len(train_dataloader) * num_epochs
# ceil(total / 10) instead of ceil(total * 0.1): the float product can land just above an
# integer (e.g. 30 * 0.1 == 3.0000000000000004), making math.ceil over-count by one step.
warmup_steps = math.ceil(total_train_steps / 10)
logger.info("Warmup-steps: {}".format(warmup_steps))

2022-06-18 14:54:18 - Warmup-steps: 176


In [12]:
# Train the model.
# In the logged run each epoch had 176 batches, so evaluation_steps=5000 is never reached
# inside an epoch and the evaluator only reports "after epoch N".
# The logs show "Save model to ..." only after epochs whose dev score improved (none after
# epoch 9), so the checkpoint on disk may differ from the final in-memory `model`.
model.fit(
    train_dataloader=train_dataloader,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    evaluator=evaluator,
    evaluation_steps=5000,
    output_path=model_save_path,
)

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/176 [00:00<?, ?it/s]

2022-06-18 16:06:08 - CEBinaryClassificationEvaluator: Evaluating the model on Quora-dev dataset after epoch 0:
2022-06-18 16:26:44 - Accuracy:           83.44	(Threshold: 0.7870)
2022-06-18 16:26:44 - F1:                 84.25	(Threshold: 0.7731)
2022-06-18 16:26:44 - Precision:          80.02
2022-06-18 16:26:44 - Recall:             88.95
2022-06-18 16:26:44 - Average Precision:  90.27

2022-06-18 16:26:44 - Save model to output/training_F0_7-2022-06-18_14-54-09


Iteration:   0%|          | 0/176 [00:00<?, ?it/s]

2022-06-18 17:35:19 - CEBinaryClassificationEvaluator: Evaluating the model on Quora-dev dataset after epoch 1:
2022-06-18 17:55:52 - Accuracy:           94.10	(Threshold: 0.7497)
2022-06-18 17:55:52 - F1:                 94.22	(Threshold: 0.7481)
2022-06-18 17:55:52 - Precision:          92.57
2022-06-18 17:55:52 - Recall:             95.92
2022-06-18 17:55:52 - Average Precision:  98.01

2022-06-18 17:55:52 - Save model to output/training_F0_7-2022-06-18_14-54-09


Iteration:   0%|          | 0/176 [00:00<?, ?it/s]

2022-06-18 18:55:54 - CEBinaryClassificationEvaluator: Evaluating the model on Quora-dev dataset after epoch 2:
2022-06-18 19:13:42 - Accuracy:           96.93	(Threshold: 0.7894)
2022-06-18 19:13:42 - F1:                 96.96	(Threshold: 0.7809)
2022-06-18 19:13:42 - Precision:          96.31
2022-06-18 19:13:42 - Recall:             97.61
2022-06-18 19:13:42 - Average Precision:  99.24

2022-06-18 19:13:42 - Save model to output/training_F0_7-2022-06-18_14-54-09


Iteration:   0%|          | 0/176 [00:00<?, ?it/s]

2022-06-18 20:06:55 - CEBinaryClassificationEvaluator: Evaluating the model on Quora-dev dataset after epoch 3:
2022-06-18 20:23:09 - Accuracy:           97.92	(Threshold: 0.9128)
2022-06-18 20:23:09 - F1:                 97.93	(Threshold: 0.9124)
2022-06-18 20:23:09 - Precision:          97.44
2022-06-18 20:23:09 - Recall:             98.43
2022-06-18 20:23:09 - Average Precision:  99.46

2022-06-18 20:23:09 - Save model to output/training_F0_7-2022-06-18_14-54-09


Iteration:   0%|          | 0/176 [00:00<?, ?it/s]

2022-06-18 21:13:03 - CEBinaryClassificationEvaluator: Evaluating the model on Quora-dev dataset after epoch 4:
2022-06-18 21:27:56 - Accuracy:           98.86	(Threshold: 0.7627)
2022-06-18 21:27:56 - F1:                 98.87	(Threshold: 0.7627)
2022-06-18 21:27:56 - Precision:          98.29
2022-06-18 21:27:56 - Recall:             99.45
2022-06-18 21:27:56 - Average Precision:  99.80

2022-06-18 21:27:56 - Save model to output/training_F0_7-2022-06-18_14-54-09


Iteration:   0%|          | 0/176 [00:00<?, ?it/s]

2022-06-18 22:15:47 - CEBinaryClassificationEvaluator: Evaluating the model on Quora-dev dataset after epoch 5:
2022-06-18 22:29:47 - Accuracy:           99.03	(Threshold: 0.8793)
2022-06-18 22:29:47 - F1:                 99.03	(Threshold: 0.8093)
2022-06-18 22:29:47 - Precision:          98.50
2022-06-18 22:29:47 - Recall:             99.57
2022-06-18 22:29:47 - Average Precision:  99.83

2022-06-18 22:29:47 - Save model to output/training_F0_7-2022-06-18_14-54-09


Iteration:   0%|          | 0/176 [00:00<?, ?it/s]

2022-06-18 23:09:12 - CEBinaryClassificationEvaluator: Evaluating the model on Quora-dev dataset after epoch 6:
2022-06-18 23:19:40 - Accuracy:           99.38	(Threshold: 0.8886)
2022-06-18 23:19:40 - F1:                 99.39	(Threshold: 0.8544)
2022-06-18 23:19:40 - Precision:          99.06
2022-06-18 23:19:40 - Recall:             99.71
2022-06-18 23:19:40 - Average Precision:  99.91

2022-06-18 23:19:40 - Save model to output/training_F0_7-2022-06-18_14-54-09


Iteration:   0%|          | 0/176 [00:00<?, ?it/s]

2022-06-18 23:54:11 - CEBinaryClassificationEvaluator: Evaluating the model on Quora-dev dataset after epoch 7:
2022-06-19 00:02:50 - Accuracy:           99.46	(Threshold: 0.7117)
2022-06-19 00:02:50 - F1:                 99.46	(Threshold: 0.7117)
2022-06-19 00:02:50 - Precision:          99.08
2022-06-19 00:02:50 - Recall:             99.84
2022-06-19 00:02:50 - Average Precision:  99.93

2022-06-19 00:02:50 - Save model to output/training_F0_7-2022-06-18_14-54-09


Iteration:   0%|          | 0/176 [00:00<?, ?it/s]

2022-06-19 00:28:15 - CEBinaryClassificationEvaluator: Evaluating the model on Quora-dev dataset after epoch 8:
2022-06-19 00:35:41 - Accuracy:           99.55	(Threshold: 0.9466)
2022-06-19 00:35:41 - F1:                 99.56	(Threshold: 0.9466)
2022-06-19 00:35:41 - Precision:          99.36
2022-06-19 00:35:41 - Recall:             99.75
2022-06-19 00:35:41 - Average Precision:  99.96

2022-06-19 00:35:41 - Save model to output/training_F0_7-2022-06-18_14-54-09


Iteration:   0%|          | 0/176 [00:00<?, ?it/s]

2022-06-19 01:01:25 - CEBinaryClassificationEvaluator: Evaluating the model on Quora-dev dataset after epoch 9:
2022-06-19 01:08:56 - Accuracy:           99.55	(Threshold: 0.5762)
2022-06-19 01:08:56 - F1:                 99.56	(Threshold: 0.5762)
2022-06-19 01:08:56 - Precision:          99.15
2022-06-19 01:08:56 - Recall:             99.96
2022-06-19 01:08:56 - Average Precision:  99.95



In [13]:
#Run CEBinaryClassificationEvaluator on the dev pairs with the final in-memory model.
#The displayed return value is the average precision as a fraction (99.95 in the log -> 0.9995...).
binary_dev_score = evaluator(model)
binary_dev_score

2022-06-19 01:08:57 - CEBinaryClassificationEvaluator: Evaluating the model on Quora-dev dataset:
2022-06-19 01:16:23 - Accuracy:           99.55	(Threshold: 0.5762)
2022-06-19 01:16:23 - F1:                 99.56	(Threshold: 0.5762)
2022-06-19 01:16:23 - Precision:          99.15
2022-06-19 01:16:23 - Recall:             99.96
2022-06-19 01:16:23 - Average Precision:  99.95



0.9995474561726904

In [14]:
#Run CECorrelationEvaluator on the same dev pairs.
#It gets its own variable so the binary `evaluator` is not overwritten; otherwise re-running
#the previous cell would silently run this correlation evaluator instead.
#Labels here are 0/1 (isRelated), so Pearson/Spearman are only a rough ranking signal;
#the displayed return value is the Spearman correlation.
correlation_evaluator = CECorrelationEvaluator.from_input_examples(dev_samples, name='Quora-dev')
correlation_evaluator(model)

2022-06-19 01:16:23 - CECorrelationEvaluator: Evaluating the model on sts-test dataset:
2022-06-19 01:23:50 - Correlation:	Pearson: 0.9917	Spearman: 0.8653


0.8653128212971559