In [1]:
from sentence_transformers import SentenceTransformer, SentencesDataset, InputExample, LoggingHandler, losses, models, util
from torch.utils.data import DataLoader
from sentence_transformers.evaluation import TripletEvaluator
from datetime import datetime
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
tqdm.pandas()

import csv
import logging
import os

  from pandas import Panel


In [2]:
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

In [3]:
model_name = "/var/patentmark/transformer-training/patent-electra-v4"

In [4]:
train_batch_size = 16
output_path = "output/training-contrastive-"+model_name+"-"+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
num_epochs = 1

In [5]:
word_embedding_model = models.Transformer(model_name)

In [6]:
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                               pooling_mode_mean_tokens=True,
                               pooling_mode_cls_token=False,
                               pooling_mode_max_tokens=False)

In [7]:
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

2020-11-21 12:22:06 - Use pytorch device: cuda


In [10]:

# if label not in label_map:
#                 label_map[label] = len(label_map)

#             label_id = label_map[label]
#             guid += 1
#             examples.append(InputExample(guid=guid, texts=[text], label=label_id))
            
            
# def build_example(row):
#     return InputExample(texts=[row['label'], row['positive'], row['negative']], label=0)

#logging.info("Read Triplet train dataset")
train_examples_df = pd.read_parquet("cte_tagged.parquet")
def build_example(row):
    return InputExample(texts=[row['abstract']], label=row['label_id'])


In [11]:
train_examples_df['label'] = train_examples_df.final_tags.astype('category')

In [12]:
train_examples_df['label_id'] = train_examples_df.label.cat.codes

In [13]:
train_examples = train_examples_df.progress_apply(build_example, axis=1).values

HBox(children=(FloatProgress(value=0.0, max=4382.0), HTML(value='')))




In [14]:
train_dataset = SentencesDataset(train_examples, model=model)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)


In [15]:
logging.info("Read Triplet dev dataset")
dev_examples_df = pd.read_parquet("cte_tagged_testing.parquet")
dev_examples = train_examples_df.progress_apply(build_example, axis=1).values

2020-11-21 12:22:14 - Read Triplet dev dataset


HBox(children=(FloatProgress(value=0.0, max=4382.0), HTML(value='')))




In [16]:
from sentence_transformers import evaluation

In [23]:
#evaluator = BatchSemiHardTripletLoss.
#evaluator = evaluation.LabelAccuracyEvaluator? #(dev_examples) #TripletEvaluator.from_input_examples(dev_examples, name='dev')

In [24]:
train_loss = losses.BatchSemiHardTripletLoss(model=model)  #TripletLoss(model=model)

In [25]:
warmup_steps = int(len(train_dataset) * num_epochs / train_batch_size * 0.1) #10% of train data


In [27]:
# Train the model
model.fit(train_objectives=[(train_dataloader, train_loss)],
          #evaluator=evaluator,
          epochs=50,
          #evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path=output_path)

# ##############################################################################
# #
# # Load the stored model and evaluate its performance on STS benchmark dataset
# #
# ##############################################################################

# logging.info("Read test examples")
# test_examples = []
# with open(os.path.join(dataset_path, 'test.csv'), encoding="utf-8") as fIn:
#     reader = csv.DictReader(fIn, delimiter=',', quoting=csv.QUOTE_MINIMAL)
#     for row in reader:
#         test_examples.append(InputExample(texts=[row['Sentence1'], row['Sentence2'], row['Sentence3']]))


# model = SentenceTransformer(output_path)
# test_evaluator = TripletEvaluator.from_input_examples(test_examples, name='test')
# test_evaluator(model, output_path=output_path)

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=50.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=274.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=274.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=274.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=274.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=274.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=274.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=274.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=274.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Iteration', max=274.0, style=ProgressStyle(description_wi…





KeyboardInterrupt: 

In [29]:
model.save("contrastive")

2020-11-21 12:49:01 - Save model to contrastive
