In [86]:
import random
from collections import defaultdict

import torch
from datasets import Dataset, load_dataset
from sentence_transformers import SentenceTransformer, SentencesDataset, SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses
from sentence_transformers.losses import TripletLoss
from sentence_transformers.readers import LabelSentenceReader, InputExample
from sentence_transformers.sampler import BatchSampler
from sentence_transformers.training_args import BatchSamplers
from torch.utils.data import DataLoader

# Directory used both as the trainer's output_dir and as the final save location.
trained_model_path = "models/bert-base-nli-stsb-mean-tokens"

# Size of the two contiguous splits: the first TRAIN_ROWS rows are used for
# training and the next EVAL_ROWS rows for evaluation.
TRAIN_ROWS = 7
EVAL_ROWS = 4

# Load pre-trained model - we are using the original Sentence-BERT for this example.
sbert_model = SentenceTransformer('bert-base-nli-stsb-mean-tokens')

# Load the raw records and rename their fields to the (anchor, positive, negative)
# triplet column names in a single pass.
dataset = Dataset.from_json("registration-v1.jsonl").rename_columns(
    {"question": "anchor", "response": "positive", "nonanswer": "negative"}
)

# Project onto the triplet columns once, then carve both splits out of that view.
triplet_dataset = dataset.select_columns(["anchor", "positive", "negative"])
train_dataset = triplet_dataset.take(TRAIN_ROWS)
eval_dataset = triplet_dataset.skip(TRAIN_ROWS).take(EVAL_ROWS)

# Only the last expression of a cell is rendered, so show the evaluation split.
eval_dataset


Dataset({
    features: ['anchor', 'positive', 'negative'],
    num_rows: 4
})

In [87]:
# Training configuration for the triplet fine-tuning run.
args = SentenceTransformerTrainingArguments(
    # Required parameter: directory the trainer writes its checkpoints to.
    output_dir=trained_model_path,
    # Optional training parameters:
    num_train_epochs=100,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    # Request FP16 mixed precision only when a CUDA device is present, so the
    # notebook also runs on machines without one instead of raising at startup.
    fp16=torch.cuda.is_available(),
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # losses that use "in-batch negatives" benefit from no duplicates
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=10,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=10,
    # Named after the model actually being fine-tuned here.
    run_name="bert-base-nli-stsb-mean-tokens-triplet",  # Will be used in W&B if `wandb` is installed
)

In [88]:
# Fine-tune on the (anchor, positive, negative) triplets with a triplet loss.
loss = TripletLoss(model=sbert_model)

trainer = SentenceTransformerTrainer(
    model=sbert_model,
    args=args,
    loss=loss,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

# Write the model that was just trained to disk so later cells can reload it.
sbert_model.save_pretrained(trained_model_path)

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss,Validation Loss
10,0.6621,0.689493
20,0.1203,0.076757
30,0.0,0.0
40,0.0,0.0
50,0.0,0.0
60,0.0,0.0
70,0.0,0.0
80,0.0,0.0
90,0.0,0.0
100,0.0,0.0


In [None]:
# Ad-hoc probe question used by the cells below.
new_question = "Hello LCD, I have a question about IAR CE. What do I do?"

# Reload the saved model from disk rather than reusing the in-memory object.
model2 = SentenceTransformer(trained_model_path)

# encode() receives a one-element list, so index 0 is the vector for new_question;
# print its length.
encoded_question = model2.encode([new_question])
embedding_size = len(encoded_question[0])
print(embedding_size)

In [None]:

# Encode the same probe question with the in-memory model object.
# NOTE(review): sbert_model is the instance that was handed to the trainer, so it
# presumably already holds the fine-tuned weights rather than the original
# pre-trained ones - confirm before treating this as a "before" baseline.
encoded_question2 = sbert_model.encode([new_question])

# Re-run evaluation on the evaluation split and print the resulting metrics.
metrics = trainer.evaluate(eval_dataset=eval_dataset)
print(metrics)


In [None]:
# Similarity between the two encodings of new_question (one produced by model2,
# one by sbert_model), scored with model2's similarity function.
predictions = model2.similarity(encoded_question, encoded_question2)
predictions

In [None]:
# Encode every anchor question of the evaluation split on its own (one-element
# batches), so each embedding has the same shape as encoded_question.
eval_anchor_embeddings = [model2.encode([row['anchor']]) for row in eval_dataset]

# Print how similar the probe question is to each evaluation anchor, in dataset
# order. Looping over the split avoids hard-coding its size and the copy-pasted
# per-row statements (one of which was duplicated verbatim).
for anchor_embedding in eval_anchor_embeddings:
    print(model2.similarity(encoded_question, anchor_embedding))

