# Train and Fine-Tune Sentence Transformers Models

## Defining the model object


In [1]:
#importing necessary libraries

import pandas as pd
from sentence_transformers import SentenceTransformer, models, InputExample, losses
from torch.utils.data import DataLoader


In [2]:
# Build a model object with a simple architecture: a pretrained language
# model followed by a pooling layer.

# Step 1: load an existing language model that produces token embeddings
word_embedding_model = models.Transformer('distilroberta-base')

# Step 2: pool over the token embeddings; the pooling layer is sized to the
# transformer's word-embedding dimension
embedding_dim = word_embedding_model.get_word_embedding_dimension()
pooling_model = models.Pooling(embedding_dim)

# Step 3: chain both stages through the `modules` argument
model_modules = [word_embedding_model, pooling_model]
model = SentenceTransformer(modules=model_modules)

## Preparing the dataset for training a Sentence Transformers model


In the example this notebook follows, `dataset['train']['set']` (a list of dictionaries) appears to be used as the training data. In the following section, the dataloader is instead defined using the MusicCaps data.

In [3]:
# triplets_train_lp_musiccaps_msd.csv holds music captions as triplets in the
# three columns `query`, `pos` and `neg`; the file's first column is used as
# the row index.
triplets_csv_path = '/home/mendu/Thesis/data/musiccaps/triplets_train_lp_musiccaps_msd.csv'
music_corpus = pd.read_csv(triplets_csv_path, index_col=[0])

In [4]:
# Display the loaded triplets frame (columns: query, pos, neg) as a quick visual check.
music_corpus

Unnamed: 0,query,pos,neg
0,"This aggressive, confrontational, and energeti...",This song is an explosive and cathartic anthem...,The song is a hardcore gangsta rap with heavy ...
1,This alternative indie rock song combines gidd...,The song is a playful and fun alternative indi...,"This song speaks to a global audience, transce..."
2,This song's Nashville Sound Countrypolitan ble...,"A melancholic, reflective, and bittersweet cou...",The song features a blend of various cultural ...
3,Get lost in the captivating sound of alternati...,An upbeat and energetic indie-pop rock song wi...,The song is a representation of the hip-hop su...
4,Take a nostalgic journey down memory lane with...,A sentimental and reflective country pop song ...,This electrifying jazz fusion track is a celeb...
...,...,...,...
444860,Get ready to experience the perfect fusion of ...,The song has a catchy beat and upbeat rock ins...,Smells Like Teen Spirit by Nirvana is a grunge...
444861,This mind-bending tune combines the raw energy...,A genre-bending song that blends psychedelic g...,This electronic song is a chill and mellow tun...
444862,This easy listening instrumental pop track is ...,This instrumental pop-jazz song is a laid-back...,This song is a reflective and introspective ex...
444863,This electronic techno track incorporates elem...,This electronic song is a combination of techn...,This house-inspired dance track will get you m...


In [5]:
# Print the first five query captions in full (the DataFrame display above
# truncates long strings).
# `.head(5)` selects by position, so this does not depend on the index labels
# being the consecutive integers 0, 1, 2, ... and avoids building a Series
# per row the way `iterrows()` does.
for query_text in music_corpus['query'].head(5):
    print(query_text)

This aggressive, confrontational, and energetic alternative indie rock song boasts self-conscious, rowdy bravado with heavy punk and pop rock influence, filled with passionate, confident, and gutsy vocals, as well as swaggering urgency, and anguished distraught feelings. Its cathartic and rebellious lyrics, dramatic delivery, and street-smart attitude make it a perfect fit for anyone seeking a cutting-edge, alternative pop rock, punk rock, or hardcore punk sound, while also featuring a summery, knotty, volatile, and fiery new wave vibe.
This alternative indie rock song combines giddy rhythms and crunchy guitar riffs with angst-ridden lyrics, evoking anguished distraught emotions while still maintaining a playful and fun vibe that's perfect for hanging out or driving around town. The mix of punk pop and pop rock influences creates a unique sound that's somewhere between alternative pop rock and rock, making it perfect for fans of all genres.
This song's Nashville Sound Countrypolitan bl

In [6]:

# Alias of the full triplets frame (not used further in this cell).
train_data = music_corpus

# Size of an optional 1/8 subsample of the available rows.
# NOTE: this cap is currently NOT applied -- every row of `music_corpus` is
# converted below. To train on the subsample for faster iteration, build the
# examples from `music_corpus.head(n_examples)` instead.
n_examples = music_corpus.shape[0] // 8

# Turn each (query, pos, neg) row into one InputExample whose `texts` are
# ordered [anchor, positive, negative].
# Zipping the three columns yields the same values as `iterrows()` without
# constructing a Series for each of the rows.
train_examples = [
    InputExample(texts=[query_text, pos_text, neg_text])
    for query_text, pos_text, neg_text in zip(
        music_corpus['query'], music_corpus['pos'], music_corpus['neg']
    )
]

In [7]:
# Sanity check: number of training triplets that were built.
len(train_examples)

444865

We wrap our training dataset in a PyTorch `DataLoader` to shuffle the examples and group them into batches.

In [8]:
# Serve the training triplets in shuffled mini-batches of 16 examples.
train_dataloader = DataLoader(
    dataset=train_examples,
    batch_size=16,
    shuffle=True,
)

## Loss functions for training a Sentence Transformers model


In [9]:
# Triplet loss bound to the model being fine-tuned; each training example
# supplies the three texts [query, pos, neg].
# NOTE(review): distance metric and margin are left at the library defaults --
# confirm they suit this data.
train_loss = losses.TripletLoss(model=model)

## How to train a Sentence Transformer model


In [10]:
num_epochs = 10

# Use 10% of the total number of training steps (batches per epoch x epochs)
# as warm-up steps.
steps_per_epoch = len(train_dataloader)
warmup_steps = int(steps_per_epoch * num_epochs * 0.1)

Training takes around 45 minutes with a Google Colab Pro account. Decrease the number of epochs and examples if you are using a free account or no GPU.

In [11]:
# Fine-tune the model on the triplet dataloader with the triplet loss for
# `num_epochs` epochs, passing `warmup_steps` as the warm-up length.
# NOTE(review): `fit` presumably returns None in sentence-transformers, so
# `history` would not contain any training metrics -- confirm against the
# installed version before relying on it.
history = model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=num_epochs,
          warmup_steps=warmup_steps)
# TODO: maybe check model.run / model.forward to make this a sentence embedding

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/27805 [00:00<?, ?it/s]

Iteration:   0%|          | 0/27805 [00:00<?, ?it/s]

Iteration:   0%|          | 0/27805 [00:00<?, ?it/s]

Iteration:   0%|          | 0/27805 [00:00<?, ?it/s]

Iteration:   0%|          | 0/27805 [00:00<?, ?it/s]

Iteration:   0%|          | 0/27805 [00:00<?, ?it/s]

Iteration:   0%|          | 0/27805 [00:00<?, ?it/s]

Iteration:   0%|          | 0/27805 [00:00<?, ?it/s]

Iteration:   0%|          | 0/27805 [00:00<?, ?it/s]

Iteration:   0%|          | 0/27805 [00:00<?, ?it/s]

In [12]:
# Save the fine-tuned model to disk.
# `model_name` and `train_datasets` are presumably used for the generated
# model card -- confirm against the installed sentence-transformers version.
# NOTE(review): the commented-out load cell further down points at
# '.../musiccaps/embedding_model', which is a different directory from the
# one written here.
model.save(path = '/home/mendu/Thesis/data/musiccaps/new_embedding_model',
           model_name = 'sentence_embedding_finetunned_on_musiccaps',
           train_datasets = ['triplets_train_lp_musiccaps_msd'])

In [13]:
# Load the fine-tuned model
# model_ = SentenceTransformer('/home/mendu/Thesis/data/musiccaps/embedding_model')

In [14]:
# model_