# Tutorial 3: Train your own ESM

In this tutorial, we will train our own ESM and publish it on the Hugging Face Hub so that others can use it for intermediate task selection.

In [1]:
from hfselect import ESMTrainer, Dataset
from transformers import BertModel, BertTokenizer

An ESM should always be trained on embeddings of one dataset computed twice: once with a base model and once with the same base model after it has been fine-tuned on that dataset. For example, we fine-tune BERT on the IMDB dataset. Then, we embed the IMDB dataset using both BERT and the fine-tuned version of BERT. The ESM is trained on the resulting pairs of embeddings.

For the sake of this tutorial, we will work around this, as we don't want to fine-tune a language model ourselves. We will use BERT, a fine-tuned version of BERT from the HF Hub, and the dataset that it is listed as having been fine-tuned on. Furthermore, we will not embed the complete training dataset but only a sample.

<span style="color:red">
<strong>We advise you to embed exactly the same dataset (including its size) that was used for fine-tuning the language model. This way, other users know exactly what the ESM represents.
</strong></span>
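To make the idea of training on pairs of embeddings more concrete, here is a minimal sketch in plain PyTorch. It does not show how `ESMTrainer` works internally. The random stand-in embeddings, the toy dimension of 16, the name `toy_esm`, the single linear layer, and the choice of AdamW with a mean squared error loss are all assumptions made for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-ins for real embeddings. In practice, both tensors come from embedding
# the same texts, once with the base model and once with the fine-tuned model.
num_texts, dim = 256, 16  # toy sizes; BERT-base embeddings have 768 dimensions
base_embeddings = torch.randn(num_texts, dim)
pretend_shift = torch.eye(dim) + 0.1 * torch.randn(dim, dim)
tuned_embeddings = base_embeddings @ pretend_shift  # pretend effect of fine-tuning

# Assumption: the ESM is modelled here as one linear layer.
toy_esm = nn.Linear(dim, dim)
optimizer = torch.optim.AdamW(toy_esm.parameters(), lr=1e-2, weight_decay=0.01)
loss_fn = nn.MSELoss()

initial_loss = loss_fn(toy_esm(base_embeddings), tuned_embeddings).item()

# Full-batch training: push the transformed base embeddings towards the
# embeddings produced by the fine-tuned model.
for step in range(200):
    optimizer.zero_grad()
    loss = loss_fn(toy_esm(base_embeddings), tuned_embeddings)
    loss.backward()
    optimizer.step()

final_loss = loss_fn(toy_esm(base_embeddings), tuned_embeddings).item()
print(f"MSE before training: {initial_loss:.4f}, after: {final_loss:.4f}")
```

The trained layer approximates how fine-tuning moved the embeddings, which is why the base and fine-tuned embeddings should come from the same texts.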

In [2]:
base_model = BertModel.from_pretrained("google-bert/bert-base-uncased")
tuned_model = BertModel.from_pretrained("prithivMLmods/Spam-Bert-Uncased")

tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")

In [3]:
dataset = Dataset.from_hugging_face(
    name="prithivMLmods/Spam-Text-Detect-Analysis",
    split="train",
    text_col="Message",
    label_col="Category",
    is_regression=False,
    num_examples=1000
)

# Training the ESM

We train the ESM with the default training parameters.

In [4]:
device_name = "cpu"  # Change this to "cuda" if you want to use a GPU

trainer = ESMTrainer(device_name=device_name)

esm = trainer.train_with_models(dataset=dataset, base_model=base_model, tuned_model=tuned_model, tokenizer=tokenizer)

Computing embedding dataset:   0%|          | 0/8 [00:00<?, ?batch/s]

Training ESM:   0%|          | 0/10 [00:00<?, ?epoch/s]

In [5]:
esm

ESM - Task ID: prithivMLmods/Spam-Text-Detect-Analysis - Subset: None

The config of the ESM gets filled with as much metadata as possible from the training process. Feel free to supplement it with relevant information.

In [6]:
esm.config

ESMConfig {
  "base_model_name": "google-bert/bert-base-uncased",
  "developers": null,
  "esm_architecture": null,
  "esm_batch_size": 32,
  "esm_embedding_dim": null,
  "esm_learning_rate": 0.001,
  "esm_num_epochs": 10,
  "esm_optimizer": null,
  "esm_weight_decay": 0.01,
  "label_column": "Category",
  "language": null,
  "lm_batch_size": null,
  "lm_learning_rate": null,
  "lm_num_epochs": null,
  "lm_optimizer": null,
  "lm_weight_decay": null,
  "num_examples": 1000,
  "seed": null,
  "streamed": false,
  "task_id": "prithivMLmods/Spam-Text-Detect-Analysis",
  "task_split": "train",
  "task_subset": null,
  "text_column": "Message",
  "transformers_version": "4.47.1"
}
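To see which fields are still open for supplementing, one option is to list the entries that are `null`. The sketch below works on a copy of the JSON printed above and uses only the standard library, so it makes no assumption about the `ESMConfig` API. The variable names are illustrative.

```python
import json

# Copy of the config printed above, pasted as a JSON string.
config_json = """
{
  "base_model_name": "google-bert/bert-base-uncased",
  "developers": null,
  "esm_architecture": null,
  "esm_batch_size": 32,
  "esm_embedding_dim": null,
  "esm_learning_rate": 0.001,
  "esm_num_epochs": 10,
  "esm_optimizer": null,
  "esm_weight_decay": 0.01,
  "label_column": "Category",
  "language": null,
  "lm_batch_size": null,
  "lm_learning_rate": null,
  "lm_num_epochs": null,
  "lm_optimizer": null,
  "lm_weight_decay": null,
  "num_examples": 1000,
  "seed": null,
  "streamed": false,
  "task_id": "prithivMLmods/Spam-Text-Detect-Analysis",
  "task_split": "train",
  "task_subset": null,
  "text_column": "Message",
  "transformers_version": "4.47.1"
}
"""

config = json.loads(config_json)

# JSON null becomes None in Python; false stays False and is not reported.
missing_fields = [key for key, value in config.items() if value is None]

print(f"{len(missing_fields)} fields without a value:")
for key in missing_fields:
    print(f"  {key}")
```

Fields such as `developers`, `language`, and the `lm_*` fine-tuning settings cannot be inferred from the training run, so they are natural candidates to fill in by hand. On the real object this would presumably be done by setting the corresponding attributes on `esm.config`, but that is an assumption to check against the hfselect documentation.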

# Testing the ESM

This step is not necessary for using the ESM. We demonstrate that it successfully transforms an embedding from the 768-dimensional embedding space into an embedding space of the same dimension.

In [7]:
tokenized_input = tokenizer(
    dataset[0]["Message"],
    padding='max_length',
    truncation=True,
    max_length=128,
    return_tensors='pt',
    return_token_type_ids=False
)

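# Index 1 of the BertModel output is the pooled output (pooler_output)
base_embedding = base_model(**tokenized_input)[1]
# Apply the ESM to the base embedding
transformed_embedding = esm(base_embedding)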

print(transformed_embedding.shape)

torch.Size([1, 768])


# Saving / Publishing the ESM

In [8]:
esm.to_disk("esm.safetensors")

In [9]:
esm.publish("davidschulte/ESM__prithivMLmods_Spam-Text-Detect-Analysis")

model.safetensors:   0%|          | 0.00/2.36M [00:00<?, ?B/s]

The ESM was successfully published on the HF Hub and can now be accessed by other users to rank the *Spam-Text-Detect-Analysis* dataset.
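As a rough sanity check on the upload size of 2.36M reported above, the following arithmetic estimates the file size under the assumption that the ESM is a single dense layer mapping 768 dimensions to 768 dimensions, with a bias, stored as 32-bit floats. This architecture is an assumption for illustration and is not stated in the config, where `esm_architecture` is `null`.

```python
# Assumption: one dense layer (weight matrix plus bias) stored as float32.
embedding_dim = 768
bytes_per_float32 = 4

num_parameters = embedding_dim * embedding_dim + embedding_dim  # weights + bias
size_in_bytes = num_parameters * bytes_per_float32

print(f"{num_parameters:,} parameters, about {size_in_bytes / 1e6:.2f} MB")
```

Under these assumptions the estimate comes out at about 2.36 MB, which is consistent with the reported size. The file header adds a small amount on top.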