# [Sentence-BERT](https://arxiv.org/pdf/1908.10084.pdf)

[Reference Code](https://www.pinecone.io/learn/series/nlp/train-sentence-transformers-softmax/)

In [None]:
import math

import datasets
import torch

# Stanford NLI: premise/hypothesis sentence pairs with an entailment label.
snli = datasets.load_dataset(path='snli', split='train')
snli  # echo the dataset summary in the notebook

In [None]:
# MultiNLI training split, distributed as the 'mnli' config of GLUE.
m_nli = datasets.load_dataset(path='glue', name='mnli', split='train')
m_nli  # echo the dataset summary in the notebook

In [None]:
# MNLI carries an extra 'idx' column; drop it so both datasets share one schema,
# cast SNLI to MNLI's feature types, then stack the two row-wise.
m_nli = m_nli.remove_columns('idx')
snli = snli.cast(features=m_nli.features)
dataset = datasets.concatenate_datasets(dsets=[snli, m_nli])

In [None]:
print(len(dataset))
# A label of -1 marks pairs for which no class could be decided; they cannot
# supervise the classifier, so drop them. The predicate returns a real bool, and
# input_columns='label' passes only that column, so the text columns are not
# decoded for every row.
dataset = dataset.filter(
    lambda label: label != -1, input_columns='label'
)
print(len(dataset))

In [None]:
from transformers import BertTokenizer

# Tokenizer for the same 'bert-base-uncased' checkpoint the encoder is loaded from.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path='bert-base-uncased'
)

In [None]:
# Tokenize premise and hypothesis in two separate passes. Each pass writes
# 'input_ids' and 'attention_mask', which are renamed straight away with the
# sentence prefix so the second pass cannot overwrite the first. all_cols
# collects the columns the DataLoader will later expose as tensors.
all_cols = ['label']

for part in ['premise', 'hypothesis']:
    # map() runs eagerly here, so the lambda sees the current value of `part`.
    dataset = dataset.map(
        lambda x: tokenizer(
            x[part], max_length=128, padding='max_length',
            truncation=True,
            # token_type_ids are never fed to the model in the training loop;
            # without this each pass would write a 'token_type_ids' column that
            # the next pass silently overwrites.
            return_token_type_ids=False,
        ), batched=True
    )
    for col in ['input_ids', 'attention_mask']:
        new_col = f'{part}_{col}'
        dataset = dataset.rename_column(col, new_col)
        all_cols.append(new_col)
print(all_cols)

In [1]:
# Convert the tokenized columns (plus the label) to PyTorch tensors. Only the
# columns listed in all_cols are returned when indexing the dataset.
dataset.set_format(type='torch', columns=all_cols)

# Shuffle and batch the merged NLI pairs for training.
batch_size = 16
loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True
)

In [None]:
from transformers import BertModel

# Use the GPU when one is available; the training loop moves every batch to
# this device, so the encoder must live there too.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Start from a pretrained bert-base-uncased encoder and place it on that device.
model = BertModel.from_pretrained('bert-base-uncased').to(device)

In [None]:
def mean_pool(token_embeds, attention_mask):
    """Average token embeddings along dimension 1, skipping padded positions.

    Args:
        token_embeds: per-token embeddings; presumably (batch, seq_len, hidden)
            as returned in BERT's first output — confirm against the caller.
        attention_mask: 1 for real tokens, 0 for padding; after a trailing axis
            is added it must expand to token_embeds' shape.

    Returns:
        The masked mean over dimension 1 (one pooled vector per row).
    """
    # broadcast the mask across the embedding dimension as float weights
    weights = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    summed = (token_embeds * weights).sum(dim=1)
    # clamp keeps an all-padding row from dividing by zero
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts

In [None]:
def concat_features(u, v):
    """Build the SBERT classifier input by concatenating u, v and |u - v|.

    The original cell referenced `u` and `v` before they exist (they are only
    created inside the training loop), so it raised NameError when the notebook
    ran top to bottom. The same operation is wrapped here as a function.

    Args:
        u: pooled sentence embeddings for the first sentence of each pair.
        v: pooled sentence embeddings for the second sentence, same shape as u.

    Returns:
        Tensor whose last dimension is three times that of u: [u, v, |u - v|].
    """
    uv_abs = torch.abs(torch.sub(u, v))  # element-wise |u - v|
    return torch.cat([u, v, uv_abs], dim=-1)

In [None]:
# Classification head: maps the concatenated [u, v, |u - v|] vector to the three
# NLI classes (0, 1, 2). Its input width is derived from the encoder config
# rather than a hard-coded 768. It is placed on the encoder's device so the
# training loop can feed it pooled embeddings directly; it is applied there.
ffnn = torch.nn.Linear(model.config.hidden_size * 3, 3).to(
    next(model.parameters()).device
)

In [None]:
# "Softmax loss": cross-entropy between the head's raw logits and the true class
# index (0, 1 or 2). It is applied per batch inside the training loop.
loss_func = torch.nn.CrossEntropyLoss()

In [None]:
from transformers.optimization import get_linear_schedule_with_warmup

# Optimise both the BERT encoder and the classification head. The head starts
# from random weights, so leaving its parameters out would keep it frozen at
# initialisation for the whole run.
optim = torch.optim.Adam(
    list(model.parameters()) + list(ffnn.parameters()), lr=2e-5
)

# Linear warmup over the first ~10% of steps, then linear decay.
# The scheduler is stepped once per batch, and the training loop runs a single
# epoch, so the schedule length is the number of batches. The DataLoader keeps
# the final partial batch, hence the ceiling.
total_steps = math.ceil(len(dataset) / batch_size)
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optim,
    num_warmup_steps=warmup_steps,
    # num_training_steps is the full schedule length, warmup included
    num_training_steps=total_steps,
)
# NOTE: scheduler.step() is called inside the training loop after each
# optim.step(). Calling it here, before training, would skip the first
# warmup learning rate.

In [None]:
from tqdm.auto import tqdm

# A single pass over the merged data; widen the range to train for more epochs.
for epoch in range(1):
    model.train()  # put BERT in training mode
    progress = tqdm(loader, leave=True)  # progress bar over the batches
    for batch in progress:
        optim.zero_grad()  # clear gradients from the previous step

        # move premise (A) and hypothesis (B) tensors plus labels to the device
        ids_a, ids_b = (
            batch['premise_input_ids'].to(device),
            batch['hypothesis_input_ids'].to(device),
        )
        mask_a = batch['premise_attention_mask'].to(device)
        mask_b = batch['hypothesis_attention_mask'].to(device)
        target = batch['label'].to(device)

        # BERT's first output holds the token embeddings; mean-pool them into
        # one vector per sentence (A is encoded before B, as before)
        emb_a = mean_pool(model(ids_a, attention_mask=mask_a)[0], mask_a)
        emb_b = mean_pool(model(ids_b, attention_mask=mask_b)[0], mask_b)

        # classifier input [u, v, |u - v|] -> logits -> cross-entropy loss
        features = torch.cat([emb_a, emb_b, torch.abs(emb_a - emb_b)], dim=-1)
        logits = ffnn(features)
        loss = loss_func(logits, target)

        # backpropagate, update weights, then advance the LR schedule one step
        loss.backward()
        optim.step()
        scheduler.step()

        # refresh the tqdm bar with the epoch and the latest loss
        progress.set_description(f'Epoch {epoch}')
        progress.set_postfix(loss=loss.item())

In [None]:
import os

# Output directory for the fine-tuned encoder.
model_path = './sbert_test_a'

# makedirs(exist_ok=True) also creates missing parent directories and is a
# no-op if the directory already exists, so no separate existence check
# (and no check-then-create race) is needed.
os.makedirs(model_path, exist_ok=True)

# Only the BERT encoder is saved here; the ffnn classification head is not.
model.save_pretrained(model_path)