In [None]:
from sentence_transformers import SentenceTransformer, SentencesDataset, InputExample, LoggingHandler, losses, models, util
from torch.utils.data import DataLoader
from sentence_transformers.evaluation import TripletEvaluator
from datetime import datetime
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
tqdm.pandas()
import pandas as pd
import numpy as np
from util import *
import csv
import logging
import os
import util

In [None]:
# Tier columns used as tags: every entry of all_tiers_100 except
# "PersonalizedProduct", de-duplicated and in sorted order.
subset = sorted(set(all_tiers_100) - {"PersonalizedProduct"})

# Both splits are stored as line-delimited JSON records.
training_set = pd.read_json("training_set.json.gz", orient="records", lines=True)
testing_set = pd.read_json("testing_set.json.gz", orient="records", lines=True)

In [None]:
from sklearn.preprocessing import LabelEncoder

# Collapse each row's tag columns into one label value with the project helper
# util.array_labels_textual (its exact return value is not visible from here).
training_labels = training_set[subset].apply(util.array_labels_textual, axis=1).values.tolist()
testing_labels = testing_set[subset].apply(util.array_labels_textual, axis=1).values.tolist()
all_labels = [*training_labels, *testing_labels]

# Fit on the union of both splits so transform() knows every label that
# occurs in either of them.
lbe = LabelEncoder().fit(all_labels)
training_set["label"] = lbe.transform(training_labels)
testing_set["label"] = lbe.transform(testing_labels)

In [None]:
import funcy as f
def create_examples(row):
    """Build the labelled text pairs derived from one training row.

    Returns a list of ``(text_a, text_b, label)`` tuples:

    * first, the row's abstract paired with its claims, labelled ``1``;
    * then, for the abstract and afterwards for the claims, one pair per tag
      in ``subset``: the text against ``"Tag: <translated tag>"``, labelled
      with the row's own value in that tag column.
    """
    abstract, claims = row.abstract, row.claims
    examples = [(abstract, claims, 1)]
    for passage in (abstract, claims):
        examples.extend(
            (passage, f"Tag: {tier_translations[tag]}", row[tag])
            for tag in subset
        )
    return examples
# Run create_examples on every training row (each call yields a list of
# (text_a, text_b, label) tuples) and explode so every tuple becomes its own
# entry. NOTE: despite the name, the entries are labelled text pairs rather
# than anchor/positive/negative triplets.
raw_triplets = training_set.apply(create_examples, axis=1).explode()

def build_example(entry):
    """Wrap one ``(text_a, text_b, label)`` entry in an ``InputExample``.

    The first two items become the example's ``texts`` and the third item
    becomes its ``label``.
    """
    first_text, second_text = entry[0], entry[1]
    return InputExample(texts=[first_text, second_text], label=entry[2])

# Fixed seed so the shuffle, and therefore the dev/train split, is the same on
# every run; otherwise evaluation scores are not comparable between runs.
SHUFFLE_SEED = 42
# Number of shuffled examples held out for the dev evaluator.
DEV_SET_SIZE = 1000

all_examples = raw_triplets.apply(build_example).sample(frac=1.0, random_state=SHUFFLE_SEED)
# .iloc states the positional intent explicitly: after explode() the index
# holds repeated row labels, so the split must not hinge on label semantics.
dev_examples = all_examples.iloc[:DEV_SET_SIZE].values
train_examples = all_examples.iloc[DEV_SET_SIZE:].values

In [None]:
# Route INFO-level (and higher) records through the LoggingHandler imported
# from sentence_transformers, each prefixed with a timestamp.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)

In [None]:
# Identifier handed to models.Transformer when loading the base model; it is
# also embedded in output_path. NOTE(review): the trailing "/" suggests a local
# directory rather than a hub model id -- confirm.
model_name: str = "bertForPatents/"

In [None]:
# Number of examples per training batch.
train_batch_size = 8
# Number of full passes over the training data.
num_epochs = 10
# Destination directory for the trained model, unique per run via the timestamp.
# model_name may contain "/" (e.g. a trailing slash on a local directory); used
# verbatim it would split the run name into nested directories, the last one
# starting with "-". Strip outer slashes and flatten any inner ones.
output_path = (
    "output/training-triplets-"
    + model_name.strip("/").replace("/", "-")
    + "-"
    + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)

In [None]:
# Token-level encoder loaded from model_name, with max_seq_length set to 192.
# NOTE(review): model_args={"gradient_checkpointing": True} is presumably meant
# to trade extra compute for lower memory use during training -- confirm it is
# honoured by the installed sentence-transformers / transformers versions.
word_embedding_model = models.Transformer(model_name, max_seq_length=192, model_args={"gradient_checkpointing": True})

In [None]:
# Mean pooling only: the mean-token mode is enabled while the CLS-token and
# max-token modes are switched off. The pooling layer is sized from the
# encoder's word embedding dimension.
pooling_model = models.Pooling(
    word_embedding_dimension=word_embedding_model.get_word_embedding_dimension(),
    pooling_mode_cls_token=False,
    pooling_mode_max_tokens=False,
    pooling_mode_mean_tokens=True,
)

In [None]:
# Assemble the sentence encoder from its two modules, in order: the
# transformer word-embedding model followed by the pooling layer.
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

In [None]:
from sentence_transformers.losses import OnlineContrastiveLoss

# Training objective, bound to the model being fine-tuned.
train_loss = OnlineContrastiveLoss(model=model)

# Wrap the training examples for batching; shuffle=True reorders them each epoch.
train_dataset = SentencesDataset(train_examples, model=model)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True)

In [None]:
from sentence_transformers.evaluation import BinaryClassificationEvaluator
# Binary-classification evaluator built from the held-out dev examples; it is
# passed to model.fit and reported under the name 'dev'.
evaluator = BinaryClassificationEvaluator.from_input_examples(dev_examples, name='dev')

In [None]:
# examples * epochs / batch size approximates the total number of training
# steps; 10% of that is used as the warm-up length.
warmup_steps = int(len(train_dataset) * num_epochs / train_batch_size * 0.1) # 10% of total training steps

In [None]:
# Fine-tune the model. The dev evaluator is configured with
# evaluation_steps=1000, AMP is enabled via use_amp, and output_path is the
# destination passed for the run's results.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    use_amp=True,
    evaluator=evaluator,
    evaluation_steps=1000,
    output_path=output_path,
)

# ##############################################################################
# #
# # Load the stored model and evaluate its performance on the triplet test set (test.csv)
# #
# ##############################################################################

# logging.info("Read test examples")
# test_examples = []
# with open(os.path.join(dataset_path, 'test.csv'), encoding="utf-8") as fIn:
#     reader = csv.DictReader(fIn, delimiter=',', quoting=csv.QUOTE_MINIMAL)
#     for row in reader:
#         test_examples.append(InputExample(texts=[row['Sentence1'], row['Sentence2'], row['Sentence3']]))


# model = SentenceTransformer(output_path)
# test_evaluator = TripletEvaluator.from_input_examples(test_examples, name='test')
# test_evaluator(model, output_path=output_path)