In [None]:
import json
import os

import torch
from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses, SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

# --- Configuration -----------------------------------------------------------
# Pretrained checkpoint that fine-tuning starts from.
base_embedding = "sentence-transformers/all-mpnet-base-v2"
# Directory the final fine-tuned model is saved to once training finishes.
output_path = "models/sft-sql-embedding"
# JSON file holding the sentence1 / sentence2 / score training pairs.
train_path = "data/spider/train-sets/sql-embedding-train-set.json"

model = SentenceTransformer(model_name_or_path=base_embedding)

# Trainer hyperparameters. Intermediate checkpoints are written to `output_dir`
# during training; the final model is saved separately to `output_path` after
# `trainer.train()` returns.
args = SentenceTransformerTrainingArguments(
    # Required parameter: directory for intermediate checkpoints.
    output_dir="models/mpnet-base-spider",
    # Optional training parameters:
    num_train_epochs=8,
    per_device_train_batch_size=160,
    learning_rate=2e-5,
    warmup_ratio=0.2,  # warm the learning rate up over the first 20% of training steps
    # fp16 mixed precision requires a CUDA device; on a CPU-only kernel an
    # unconditional fp16=True makes argument construction fail, so fall back
    # to full precision there. Behaviour on GPU machines is unchanged.
    fp16=torch.cuda.is_available(),
    bf16=False,
    # NOTE(review): NO_DUPLICATES mainly matters for losses that use in-batch
    # negatives; it is harmless with CoSENTLoss but may be unnecessary here.
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    seed=42,  # same as the library default, stated explicitly for reproducibility
    # Optional tracking/debugging parameters:
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,  # keep only the two most recent checkpoints on disk
    logging_steps=100,
    run_name="mpnet-base-spider",
    dataloader_num_workers=16,
)

# Load the training pairs and wrap them in a Hugging Face Dataset.
# JSON text is UTF-8 by specification; pass the encoding explicitly so the
# platform default (e.g. cp1252 on Windows) cannot garble non-ASCII text.
with open(train_path, 'r', encoding='utf-8') as file:
    train = json.load(file)

# Upper bound on the number of pairs used; slicing simply takes everything
# if the file holds fewer entries.
num_train_samples = 300_000

# Slice once instead of re-slicing (and copying) the list for every column.
train_subset = train[:num_train_samples]
assert train_subset, f"no training pairs found in {train_path}"

sentence1s = [entry["sentence1"] for entry in train_subset]
sentence2s = [entry["sentence2"] for entry in train_subset]
scores = [entry["score"] for entry in train_subset]

# Column order matters to the trainer: the non-label columns are fed to the
# loss as inputs in order, and the `score` column is used as the label.
embedding_train_dataset = Dataset.from_dict({
    'sentence1': sentence1s,
    'sentence2': sentence2s,
    'score': scores,
})

# CoSENTLoss is trained on the (sentence1, sentence2) pairs, using the `score`
# column of `embedding_train_dataset` as the target.
loss = losses.CoSENTLoss(model=model)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    loss=loss,
    train_dataset=embedding_train_dataset,
)
trainer.train()

# Persist the fine-tuned weights; the push-to-hub cell reloads them from this path.
model.save_pretrained(output_path)

In [7]:
# Push the trained SQL embedding model to the Hugging Face Hub.
# The model is reloaded from disk, so this cell does not rely on variables
# left in the kernel by the training cell.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(model_name_or_path="models/sft-sql-embedding")
model.push_to_hub("sft-sql-embedding")

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

'https://huggingface.co/s2593817/sft-sql-embedding/commit/c89964f560f060f20de7c731dc274b0bc77d785a'