In [1]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments, SentenceTransformer, losses
from datasets import load_dataset
from huggingface_hub import login

In [2]:
# Read data/csw24.txt and convert line by line to csv
# each line is word   definition.
# convert to csv with two columns: word and definition
# import pandas as pd

# with open('data/csw24.txt', 'r') as file:
#     lines = file.readlines()

# # each line is word<tab>definition.
# # convert to csv with two columns: word and definition
# data = []
# for line in lines:
#     word, definition = line.strip().split('\t', 1)
#     data.append({'word': word, 'definition': definition})

# df = pd.DataFrame(data)

# assert len(df) == len(lines)
# df.to_csv('data/csw24.csv', index=False)


In [3]:
MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
model = SentenceTransformer(MODEL_NAME)
# model.max_seq_length = 256 # For trial run purposes

## Matryoshka Loss

Matryoshka Representation Learning trains embeddings at multiple dimensions simultaneously, ensuring smaller embeddings (prefixes of the full embedding) remain useful while optimizing the full dimension. This provides flexible quality/speed tradeoffs from a single model.

In [4]:
base_loss = losses.MultipleNegativesRankingLoss(model)
target_dims = [384, 256]
mrl_loss = losses.MatryoshkaLoss(model, base_loss, target_dims)

## Dataset

The CSW24 dictionary dataset contains word-definition pairs used for training. The dataset is split into train/validation/test sets for model training and evaluation.


In [5]:
DATA_LOCATION = "data/non_indian_words.csv"
dataset = load_dataset("csv", data_files=DATA_LOCATION)

# Split into train and temp (test+val)
splits = dataset['train'].train_test_split(test_size=0.2)  # 80% train, 20% temp
train_dataset = splits['train']
temp = splits['test']

# Split temp into val and test
temp_splits = temp.train_test_split(test_size=0.5)  # 50% val, 50% test
val_dataset = temp_splits['train']
test_dataset = temp_splits['test']

print("Train Dataset Size:", len(train_dataset))
print("Val Dataset Size:", len(val_dataset))
print("Test Dataset Size:", len(test_dataset))

Train Dataset Size: 222635
Val Dataset Size: 27829
Test Dataset Size: 27830


## Evaluator

The InformationRetrievalEvaluator measures how well the model retrieves the correct definition for each word using cosine similarity. It computes accuracy, precision, recall, and NDCG metrics at various top-K thresholds.


In [6]:
evaluator = InformationRetrievalEvaluator(
    queries={i: example['word'] for i, example in enumerate(val_dataset)},
    corpus={i: example['definition'] for i, example in enumerate(val_dataset)},
    relevant_docs={i: [i] for i in range(len(val_dataset))}, # Word i's def is always doc i
    name='dictionary-test'
)

## Trainer

The SentenceTransformerTrainer handles the training loop with the specified loss function, training arguments, and evaluator. It automatically manages batching, gradient updates, evaluation, and checkpointing during training.


In [7]:

training_args = SentenceTransformerTrainingArguments(
    output_dir='./output',
    per_device_train_batch_size=64,
    num_train_epochs=1,
    fp16=True,
    learning_rate=2e-5,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
)

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=mrl_loss,
    args=training_args,
    evaluator=evaluator
)

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [8]:
trainer.train()



Step,Training Loss,Validation Loss,Dictionary-test Cosine Accuracy@1,Dictionary-test Cosine Accuracy@3,Dictionary-test Cosine Accuracy@5,Dictionary-test Cosine Accuracy@10,Dictionary-test Cosine Precision@1,Dictionary-test Cosine Precision@3,Dictionary-test Cosine Precision@5,Dictionary-test Cosine Precision@10,Dictionary-test Cosine Recall@1,Dictionary-test Cosine Recall@3,Dictionary-test Cosine Recall@5,Dictionary-test Cosine Recall@10,Dictionary-test Cosine Ndcg@10,Dictionary-test Cosine Mrr@10,Dictionary-test Cosine Map@100
100,1.0186,No log,0.593913,0.7637,0.79586,0.823709,0.593913,0.254567,0.159172,0.082371,0.593913,0.7637,0.79586,0.823709,0.71801,0.68313,0.685363
200,0.7633,No log,0.605088,0.774767,0.805203,0.830033,0.605088,0.258256,0.161041,0.083003,0.605088,0.774767,0.805203,0.830033,0.727449,0.693467,0.695764
300,0.75,No log,0.625211,0.784829,0.811168,0.834741,0.625211,0.26161,0.162234,0.083474,0.625211,0.784829,0.811168,0.834741,0.739836,0.708362,0.710553
400,0.7503,No log,0.63671,0.787955,0.81318,0.835675,0.63671,0.262652,0.162636,0.083568,0.63671,0.787955,0.81318,0.835675,0.745611,0.715737,0.717937
500,0.7271,No log,0.640806,0.793094,0.817025,0.839556,0.640806,0.264365,0.163405,0.083956,0.640806,0.793094,0.817025,0.839556,0.749598,0.719767,0.721885
600,0.6531,No log,0.641956,0.793561,0.818714,0.840526,0.641956,0.26452,0.163743,0.084053,0.641956,0.793561,0.818714,0.840526,0.750772,0.720998,0.723131
700,0.6586,No log,0.651515,0.796579,0.820655,0.842143,0.651515,0.265526,0.164131,0.084214,0.651515,0.796579,0.820655,0.842143,0.75605,0.727491,0.729628
800,0.6559,No log,0.656869,0.799094,0.822128,0.843149,0.656869,0.266365,0.164426,0.084315,0.656869,0.799094,0.822128,0.843149,0.759083,0.731183,0.733346
900,0.6116,No log,0.652233,0.798735,0.821912,0.843976,0.652233,0.266245,0.164382,0.084398,0.652233,0.798735,0.821912,0.843976,0.757233,0.728463,0.730594
1000,0.615,No log,0.663912,0.80064,0.822703,0.844155,0.663912,0.26688,0.164541,0.084416,0.663912,0.80064,0.822703,0.844155,0.762504,0.735456,0.73761


TrainOutput(global_step=3479, training_loss=0.6101975886429065, metrics={'train_runtime': 143607.9898, 'train_samples_per_second': 1.55, 'train_steps_per_second': 0.024, 'total_flos': 0.0, 'train_loss': 0.6101975886429065, 'epoch': 1.0})

In [9]:
FINAL_MODEL_REPO="models/scrabble-embed-v2"
model.save_pretrained(FINAL_MODEL_REPO)

## HuggingFace Hub Push

The trained model is pushed to the HuggingFace Hub for sharing and deployment. This allows the model to be easily loaded and used by others using the SentenceTransformer library.


In [10]:
login()
REPO_ID = "mehularora/scrabble-embed-v2"
model.push_to_hub(repo_id=REPO_ID)

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.svâ€¦

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

'https://huggingface.co/mehularora/scrabble-embed-v2/commit/a4bf89bfb833f234dbc5ca0fd8ba78503db961fe'