In [1]:
!pip install pandas




In [2]:
!pip install sentence-transformers
!pip install datasets



In [3]:
# Convert data/csw24.txt into data/csw24.csv.
# Input: one entry per line, formatted as word<tab>definition.
# Output: a CSV with two columns, `word` and `definition`.
import pandas as pd

# Read with an explicit encoding so the result does not depend on the
# platform's default locale.
with open('data/csw24.txt', 'r', encoding='utf-8') as file:
    lines = file.readlines()

data = []
blank_lines = 0
for line_number, line in enumerate(lines, start=1):
    entry = line.strip()
    if not entry:
        # Tolerate blank lines (e.g. trailing newlines at the end of the file).
        blank_lines += 1
        continue
    # Split on the first tab only, so any later tabs stay in the definition.
    word, tab, definition = entry.partition('\t')
    if not tab:
        # Fail with the offending line instead of an opaque unpacking error.
        raise ValueError(
            f"data/csw24.txt line {line_number}: "
            f"expected 'word<tab>definition', got {entry!r}"
        )
    data.append({'word': word, 'definition': definition})

# Explicit columns keep the CSV header present even if no rows were parsed.
df = pd.DataFrame(data, columns=['word', 'definition'])

# Every non-blank input line must have produced exactly one row.
assert len(df) == len(lines) - blank_lines
df.to_csv('data/csw24.csv', index=False)


In [5]:
from sentence_transformers import SentenceTransformer, losses, SentenceTransformerTrainer


In [48]:
# Identifier of the base model that is loaded with SentenceTransformer and
# then fine-tuned. Swap in the commented-out line to try the alternative.
MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
# MODEL_NAME = 'Qwen/Qwen3-Embedding-0.6B'

In [49]:
# Load the base embedding model selected by MODEL_NAME.
model = SentenceTransformer(model_name_or_path=MODEL_NAME)

In [50]:
# Maximum sequence length the model is configured to use.
MAX_SEQ_LENGTH = 256
model.max_seq_length = MAX_SEQ_LENGTH

In [51]:
# Base ranking loss; it is wrapped by the Matryoshka loss in the next cell.
base_loss = losses.MultipleNegativesRankingLoss(model=model)

In [57]:
# Note: Target dims should contain the model's dimensions.
# Read the full embedding size from the loaded model instead of hard-coding
# it, so the list stays correct when MODEL_NAME is changed.
full_dim = model.get_sentence_embedding_dimension()
# Truncated sizes to train in addition to the full size; only sizes strictly
# smaller than the full embedding are kept.
truncated_dims = [dim for dim in (256,) if dim < full_dim]
target_dims = [full_dim, *truncated_dims]
mrl_loss = losses.MatryoshkaLoss(model, base_loss, target_dims)

In [58]:
from datasets import load_dataset

# Fixed seed so train/val/test membership is the same on every run.
SPLIT_SEED = 42

# keep_default_na=False: the CSV loader is pandas-based and by default parses
# strings such as "NA", "NULL", "nan" or "null" as missing values. Here they
# may be legitimate dictionary words, so every field is kept as a literal string.
dataset = load_dataset("csv", data_files="data/csw24.csv", keep_default_na=False)

# First hold out 10% as the test set, then 10% of the remainder as validation.
splits = dataset['train'].train_test_split(test_size=0.1, seed=SPLIT_SEED)
test_dataset = splits['test']
train_val_dataset = splits['train']
train_val_splits = train_val_dataset.train_test_split(test_size=0.1, seed=SPLIT_SEED)
train_dataset = train_val_splits['train']
val_dataset = train_val_splits['test']

# The three splits must partition the full dataset.
assert len(train_dataset) + len(val_dataset) + len(test_dataset) == len(dataset['train'])

print(len(train_dataset), len(val_dataset), len(test_dataset))

227518 25280 28089


In [59]:
from sentence_transformers.evaluation import InformationRetrievalEvaluator

# Retrieval task over the validation split: each word is a query, each
# definition is a corpus document, and both are keyed by the row position.
val_queries = dict(enumerate(val_dataset['word']))
val_corpus = dict(enumerate(val_dataset['definition']))
# Word i's def is always doc i
val_relevant_docs = {row_idx: [row_idx] for row_idx in range(len(val_dataset))}

evaluator = InformationRetrievalEvaluator(
    queries=val_queries,
    corpus=val_corpus,
    relevant_docs=val_relevant_docs,
    name='dictionary-test',
)

In [60]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

training_args = SentenceTransformerTrainingArguments(
    output_dir='./output',
    per_device_train_batch_size=64,  # batch size
    num_train_epochs=1,
    fp16=True,
    learning_rate=2e-5,
    # MultipleNegativesRankingLoss treats the other samples in a batch as
    # negatives, so a batch must not contain the same text twice.
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # Run the evaluator and write a checkpoint every 100 steps, keeping only
    # the two most recent checkpoints on disk.
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
)

# NOTE(review): no eval_dataset is passed, which is presumably why the
# "Validation Loss" column of the training log shows "No log" - confirm.
# The retrieval metrics come from `evaluator` alone.
trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=mrl_loss,
    args=training_args,
    evaluator=evaluator
)

                                                                     

In [61]:
# Run the fine-tuning loop; the returned summary is displayed as the cell output.
train_result = trainer.train()
train_result



Step,Training Loss,Validation Loss,Dictionary-test Cosine Accuracy@1,Dictionary-test Cosine Accuracy@3,Dictionary-test Cosine Accuracy@5,Dictionary-test Cosine Accuracy@10,Dictionary-test Cosine Precision@1,Dictionary-test Cosine Precision@3,Dictionary-test Cosine Precision@5,Dictionary-test Cosine Precision@10,Dictionary-test Cosine Recall@1,Dictionary-test Cosine Recall@3,Dictionary-test Cosine Recall@5,Dictionary-test Cosine Recall@10,Dictionary-test Cosine Ndcg@10,Dictionary-test Cosine Mrr@10,Dictionary-test Cosine Map@100
100,1.5353,No log,0.504826,0.67057,0.709771,0.743513,0.504826,0.223523,0.141954,0.074351,0.504826,0.67057,0.709771,0.743513,0.630627,0.593713,0.596473
200,1.2836,No log,0.546519,0.691377,0.720807,0.748536,0.546519,0.230459,0.144161,0.074854,0.546519,0.691377,0.720807,0.748536,0.654278,0.623322,0.626132
300,1.2305,No log,0.559217,0.698972,0.727532,0.755617,0.559217,0.232991,0.145506,0.075562,0.559217,0.698972,0.727532,0.755617,0.663688,0.633532,0.636208
400,1.1669,No log,0.56068,0.700356,0.729945,0.757081,0.56068,0.233452,0.145989,0.075708,0.56068,0.700356,0.729945,0.757081,0.665149,0.634996,0.637667
500,1.1904,No log,0.572271,0.704905,0.732239,0.758109,0.572271,0.234968,0.146448,0.075811,0.572271,0.704905,0.732239,0.758109,0.67144,0.642986,0.645754
600,1.0998,No log,0.573774,0.707951,0.734375,0.761709,0.573774,0.235984,0.146875,0.076171,0.573774,0.707951,0.734375,0.761709,0.673844,0.645033,0.647714
700,1.0655,No log,0.575277,0.708861,0.735285,0.762658,0.575277,0.236287,0.147057,0.076266,0.575277,0.708861,0.735285,0.762658,0.675059,0.646343,0.649096
800,1.095,No log,0.580934,0.711472,0.735997,0.763252,0.580934,0.237157,0.147199,0.076325,0.580934,0.711472,0.735997,0.763252,0.678071,0.650149,0.652928
900,1.1535,No log,0.585839,0.712658,0.738924,0.765427,0.585839,0.237553,0.147785,0.076543,0.585839,0.712658,0.738924,0.765427,0.681258,0.653691,0.656405
1000,1.0047,No log,0.586432,0.713331,0.73754,0.765071,0.586432,0.237777,0.147508,0.076507,0.586432,0.713331,0.73754,0.765071,0.681367,0.653972,0.65674


TrainOutput(global_step=3555, training_loss=1.0422707043954926, metrics={'train_runtime': 4517.6472, 'train_samples_per_second': 50.362, 'train_steps_per_second': 0.787, 'total_flos': 0.0, 'train_loss': 1.0422707043954926, 'epoch': 1.0})

In [62]:
# Directory that receives the fine-tuned model.
FINAL_MODEL_DIR = "models/scrabble-embed/final"
model.save_pretrained(FINAL_MODEL_DIR)

In [66]:
!pip install ipywidgets

Collecting ipywidgets
  Downloading ipywidgets-8.1.8-py3-none-any.whl.metadata (2.4 kB)
Collecting widgetsnbextension~=4.0.14 (from ipywidgets)
  Downloading widgetsnbextension-4.0.15-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab_widgets~=3.0.15 (from ipywidgets)
  Downloading jupyterlab_widgets-3.0.16-py3-none-any.whl.metadata (20 kB)
Downloading ipywidgets-8.1.8-py3-none-any.whl (139 kB)
Downloading jupyterlab_widgets-3.0.16-py3-none-any.whl (914 kB)
   ---------------------------------------- 0.0/914.9 kB ? eta -:--:--
   ---------------------- ----------------- 524.3/914.9 kB 2.8 MB/s eta 0:00:01
   ---------------------------------------- 914.9/914.9 kB 3.8 MB/s  0:00:00
Downloading widgetsnbextension-4.0.15-py3-none-any.whl (2.2 MB)
   ---------------------------------------- 0.0/2.2 MB ? eta -:--:--
   --------- ------------------------------ 0.5/2.2 MB 2.4 MB/s eta 0:00:01
   ------------------- -------------------- 1.0/2.2 MB 2.3 MB/s eta 0:00:01
   ----------------

In [67]:
# Authenticate with the Hugging Face Hub before uploading the model.
# login() is called without arguments, so no token is stored in the notebook;
# in this run it rendered an interactive widget (see the cell output).
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [69]:
# Hub repository that receives the fine-tuned model.
HUB_REPO_ID = 'Mehularora/scrabble-embed-v1'
model.push_to_hub(repo_id=HUB_REPO_ID)

model.safetensors: 100%|██████████| 90.9M/90.9M [01:17<00:00, 1.17MB/s]


'https://huggingface.co/mehularora/scrabble-embed-v1/commit/6928d152581e6be299af0c53e4e0d48668c22622'