# In [3]:
import ast
import math
import os
import random
import sys

import numpy as np
from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample
from sentence_transformers import models, losses, datasets
from tqdm import tqdm

class SentenceTripletTrainer:
    """Fine-tune a SentenceTransformer on (anchor, positive, negative) triplets.

    ``run`` executes the whole pipeline:

    1. Read triplets from ``train_file`` and optionally subsample them.
    2. Build a Transformer + mean-pooling model.
    3. Wrap the triplets in a ``NoDuplicatesDataLoader``.
    4. Train with cosine ``TripletLoss``, saving the result to ``output_path``.

    Args:
        train_file: Path to a text file with one triplet per line. Each line
            is a literal sequence parsable by ``ast.literal_eval``, e.g.
            ``["anchor", "positive", "negative"]``.
        model_name: Model name or path passed to ``models.Transformer``.
        train_batch: Batch size for the training data loader.
        max_seq_length: Maximum sequence length given to the transformer.
        num_epoch: Number of training epochs.
        sample_size: Maximum number of triplets to train on. If the file holds
            more, a random subset of this size is drawn. ``None`` keeps all.
        output_path: Directory handed to ``fit`` for saving the model.
        device: Device string the model is created on and moved to.
        triplet_margin: Margin of the cosine triplet loss.
        warmup_ratio: Fraction of total training steps used for LR warmup.
    """

    def __init__(self, train_file, model_name, train_batch, max_seq_length, num_epoch,
                 sample_size=500000, output_path='mini_lm_5M', device='cuda',
                 triplet_margin=0.1, warmup_ratio=0.05):
        self.train_file = train_file
        self.model_name = model_name
        self.train_batch = train_batch
        self.max_seq_length = max_seq_length
        self.num_epoch = num_epoch
        self.sample_size = sample_size
        self.output_path = output_path
        self.device = device
        self.triplet_margin = triplet_margin
        self.warmup_ratio = warmup_ratio
        # Populated by the pipeline steps below.
        self.train_set = None
        self.word_embedding_model = None
        self.pooling_model = None
        self.model = None
        self.train_samples = []
        self.loader = None
        self.train_loss = None
        self.epochs = None
        self.warmup_steps = None

    def load_data(self):
        """Read triplets from ``train_file`` into ``self.train_set``.

        Blank lines are skipped, and each remaining line is parsed with
        ``ast.literal_eval``. That function evaluates only literals, so it does
        not execute code from the file. When ``sample_size`` is set, a random
        subset of at most that many rows is kept.

        Raises:
            ValueError: If the file contains no triplets.
        """
        with open(self.train_file, encoding='utf-8') as f:
            stripped = (line.strip() for line in f)
            train_set = [ast.literal_eval(text) for text in stripped if text]
        if not train_set:
            raise ValueError(f'no triplets found in {self.train_file!r}')
        if self.sample_size is not None:
            # Cap at the population size. random.sample raises ValueError when
            # asked for more items than exist, and it also shuffles the rows.
            k = min(self.sample_size, len(train_set))
            train_set = random.sample(train_set, k)
        self.train_set = train_set

    def build_model(self):
        """Build a Transformer encoder with mean pooling on ``self.device``."""
        self.word_embedding_model = models.Transformer(
            self.model_name, max_seq_length=self.max_seq_length)
        self.pooling_model = models.Pooling(
            self.word_embedding_model.get_word_embedding_dimension(), pooling_mode='mean')
        self.model = SentenceTransformer(
            modules=[self.word_embedding_model, self.pooling_model], device=self.device)
        # Explicit move kept: some sentence-transformers versions only record
        # the target device in __init__ instead of moving the weights.
        self.model = self.model.to(self.device)

    def prepare_samples(self):
        """Wrap each triplet in an ``InputExample`` and build the data loader.

        The list is rebuilt on every call, so calling this twice does not
        duplicate the training examples.
        """
        self.train_samples = [
            InputExample(texts=[row[0], row[1], row[2]]) for row in self.train_set
        ]
        self.loader = datasets.NoDuplicatesDataLoader(
            self.train_samples, batch_size=self.train_batch)

    def set_loss(self):
        """Use a cosine-distance triplet loss with margin ``triplet_margin``."""
        self.train_loss = losses.TripletLoss(
            self.model,
            distance_metric=losses.TripletDistanceMetric.COSINE,
            triplet_margin=self.triplet_margin)

    def set_params(self):
        """Derive the epoch count and warmup steps from the loader length."""
        self.epochs = self.num_epoch
        self.warmup_steps = int(len(self.loader) * self.epochs * self.warmup_ratio)

    def train(self):
        """Run ``fit`` with the configured loader, loss and schedule."""
        self.model.fit(train_objectives=[(self.loader, self.train_loss)],
                       epochs=self.epochs,
                       output_path=self.output_path,
                       warmup_steps=self.warmup_steps,
                       show_progress_bar=True)

    def run(self):
        """Execute the full pipeline: load, build, prepare, configure, train."""
        self.load_data()
        self.build_model()
        self.prepare_samples()
        self.set_loss()
        self.set_params()
        self.train()

# create an instance of the class and run the model
#sentence_triplet_trainer = SentenceTripletTrainer('data/triplets.jsonl', 'sentence-transformers/all-MiniLM-L6-v2', 28, 64, 1)
#sentence_triplet_trainer.run()


# In [None]:
# Notebook cell: fine-tune all-MiniLM-L6-v2 on data/triplets.jsonl
# (batch size 28, max sequence length 64, a single epoch).
sentence_triplet_trainer = SentenceTripletTrainer(
    train_file='data/triplets.jsonl',
    model_name='sentence-transformers/all-MiniLM-L6-v2',
    train_batch=28,
    max_seq_length=64,
    num_epoch=1,
)
sentence_triplet_trainer.run()