In [None]:
import jsonlines
import random
from sentence_transformers import SentenceTransformer, models, losses, InputExample
from sentence_transformers.evaluation import BinaryClassificationEvaluator
from torch.utils.data import DataLoader

In [None]:
# Location of the pretrained transformer checkpoint that the sentence
# encoder below is initialised from.
checkpoint_path = "/media/data/hr/BERTlike/checkpoint-495000"

In [None]:
# Token-level encoder loaded from the checkpoint; inputs are truncated to
# 256 tokens.
word_embedding_model = models.Transformer(
    model_name_or_path=checkpoint_path,
    max_seq_length=256,
)

# Pooling layer sized to the encoder's output dimension. No pooling mode is
# passed, so the library default applies
# (NOTE(review): mean pooling in current sentence-transformers — confirm for
# the installed version).
pooling_model = models.Pooling(
    word_embedding_dimension=word_embedding_model.get_word_embedding_dimension(),
)

# Sentence encoder = transformer followed by pooling.
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

In [None]:
# Training pairs: every record of titles_train.jsonl contributes one
# InputExample whose texts are [record['prev'], record['cur']].
train_data = []
with jsonlines.open('titles_train.jsonl') as reader:
    train_data.extend(
        InputExample(texts=[record['prev'], record['cur']])
        for record in reader
    )

In [None]:
# Keep only the head of the validation file for evaluation.
# NOTE(review): the cap is 1002 records (indices 0..1001), not 1000 —
# confirm whether an even 1000 was intended.
sentence_pairs = []
with jsonlines.open('titles_val.jsonl') as reader:
    # range() comes first in zip() so that no record beyond the cap is read.
    sentence_pairs.extend(record for _, record in zip(range(1002), reader))

In [None]:
# Build a balanced evaluation set from the validation pairs.
# For every pair we emit:
#   * one positive example: (prev, cur)                      -> label 1
#   * one negative example: (prev, 'cur' of a random pair)   -> label 0
# The negative is drawn by rejection sampling until its text differs from
# the true 'cur' of the current pair.

# Dedicated, seeded generator so the sampled negatives (and therefore the
# evaluation scores) are identical across kernel restarts and are not
# affected by other users of the global `random` state.
EVAL_SEED = 42
negative_rng = random.Random(EVAL_SEED)

# The rejection loop below can only terminate if at least one pair has a
# different 'cur'; otherwise it would spin forever. Fail loudly instead.
if sentence_pairs and all(
    candidate['cur'] == sentence_pairs[0]['cur'] for candidate in sentence_pairs
):
    raise ValueError(
        "Cannot sample negatives: all validation pairs share the same 'cur' text"
    )

sentences1 = []
sentences2 = []
labels = []
for pair in sentence_pairs:
    # Positive example: the real (prev, cur) pair.
    sentences1.append(pair['prev'])
    sentences2.append(pair['cur'])
    labels.append(1)

    # Negative example: same 'prev', but a 'cur' taken from another pair.
    while True:
        sampled_negative = negative_rng.choice(sentence_pairs)['cur']
        if sampled_negative != pair['cur']:
            break
    sentences1.append(pair['prev'])
    sentences2.append(sampled_negative)
    labels.append(0)

In [None]:
# Evaluator over the positive/negative pairs assembled above
# (label 1 = true pair, label 0 = sampled negative).
evaluator = BinaryClassificationEvaluator(
    sentences1=sentences1,
    sentences2=sentences2,
    labels=labels,
)

In [None]:
# Shuffled mini-batches of training pairs.
# NOTE(review): MultipleNegativesRankingLoss uses the other examples in a
# batch as negatives, so batch_size=2 yields a single negative per anchor —
# consider a larger batch if memory allows.
train_dataloader = DataLoader(
    train_data,
    batch_size=2,
    shuffle=True,
)

# Ranking loss over (prev, cur) pairs, bound to the model being trained.
train_loss = losses.MultipleNegativesRankingLoss(model=model)

In [None]:
# Fine-tune the sentence encoder. The number of epochs is not set here, so
# the library default is used.
model.fit(
    # One training objective: the pair dataloader with its ranking loss.
    train_objectives=[(train_dataloader, train_loss)],
    warmup_steps=100,
    # Evaluate with the binary-classification evaluator every 1000 steps.
    evaluator=evaluator,
    evaluation_steps=1000,
    # Checkpoints are written every 1000 steps; output_path and
    # checkpoint_path both point at the same 'output' directory.
    checkpoint_save_steps=1000,
    checkpoint_path='output',
    output_path='output',
    # Do not keep a separate "best model by evaluator score" copy.
    save_best_model=False,
)