In [1]:
!pip install sentence-transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentence-transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 86.0/86.0 kB 5.0 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting transformers<5.0.0,>=4.6.0 (from sentence-transformers)
  Downloading transformers-4.29.2-py3-none-any.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m85.5 MB/s[0m eta [36m0:00:00[0m
Collecting sentencepiece (from sentence-transformers)
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 84.6 MB/s eta 0:00:00
Collecting huggingface-hub>=0.4.0 (from sentence-transformers)
  Downloading huggingface_hub-0.15.1-py3-

In [2]:
import csv
import gzip
import logging
import math
import os
from datetime import datetime

from sentence_transformers import SentenceTransformer, util
from sentence_transformers import models, losses
from sentence_transformers import LoggingHandler, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import torch
from torch.utils.data import DataLoader

In [3]:
# Fetch each dataset only if no local copy exists yet; later runs reuse the files in data/.
nli_dataset_path = 'data/AllNLI.tsv.gz'
sts_dataset_path = 'data/stsbenchmark.tsv.gz'

for local_path, source_url in (
    (nli_dataset_path, 'https://sbert.net/datasets/AllNLI.tsv.gz'),
    (sts_dataset_path, 'https://sbert.net/datasets/stsbenchmark.tsv.gz'),
):
    if os.path.exists(local_path):
        continue
    util.http_get(source_url, local_path)

  0%|          | 0.00/40.8M [00:00<?, ?B/s]

  0%|          | 0.00/392k [00:00<?, ?B/s]

In [9]:
# Integer class ids for the three NLI labels (also sizes the classifier via len(label2int)).
label2int = {'contradiction': 0, "entailment": 1, "neutral": 2}

# Collect every 'train' row of AllNLI as a (sentence1, sentence2) pair with its class id.
with gzip.open(nli_dataset_path, 'rt', encoding='utf8') as nli_file:
    nli_rows = csv.DictReader(nli_file, delimiter='\t', quoting=csv.QUOTE_NONE)
    train_samples = [
        InputExample(label=label2int[row['label']],
                     texts=[row['sentence1'], row['sentence2']])
        for row in nli_rows
        if row['split'] == 'train'
    ]

len(train_samples)

942069

In [13]:
# --- Training configuration -------------------------------------------------
model_name = "bert-base-uncased"

train_batch_size = 16
num_epochs = 1
# Train on only the first `max_train_samples` NLI training pairs (a prefix in
# file order, not a random sample) to keep the run short.
max_train_samples = 50000
random_seed = 42

model_path = 'output/training_nli'

# Show INFO-level log records from sentence-transformers (presumably including the
# periodic sts-dev evaluation scores) via its tqdm-friendly LoggingHandler.
# NOTE(review): basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

# Seed torch's RNG so random parts of the run (e.g. DataLoader shuffling, freshly
# initialised loss-head weights, dropout) repeat across restarts; some GPU kernels
# may still be nondeterministic.
torch.manual_seed(random_seed)

# --- Model ------------------------------------------------------------------
# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
word_embedding_model = models.Transformer(model_name)

# Apply mean pooling to get one fixed sized sentence vector
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                               pooling_mode_mean_tokens=True,
                               pooling_mode_cls_token=False,
                               pooling_mode_max_tokens=False)

model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# --- Training data & loss -----------------------------------------------------
# Reuse `train_samples` and `label2int` built by the NLI-loading cell above rather
# than re-parsing the whole AllNLI file a second time.
train_dataloader = DataLoader(train_samples[:max_train_samples], shuffle=True, batch_size=train_batch_size)
train_loss = losses.SoftmaxLoss(model=model,
                                sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
                                num_labels=len(label2int))

# --- Development set ------------------------------------------------------------
# STSbenchmark 'dev' split; gold scores are divided by 5.0 before being used as labels.
dev_samples = []
with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as f:
    reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        if row['split'] == 'dev':
            score = float(row['score']) / 5.0
            dev_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=score))

dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, batch_size=train_batch_size, name='sts-dev')

# --- Training -----------------------------------------------------------------
# Linear warm-up over the first 10% of all optimisation steps.
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)

model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=dev_evaluator,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path=model_path)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/3125 [00:00<?, ?it/s]

In [14]:
# Score the model saved at model_path on the STSbenchmark 'test' split. Gold scores
# are divided by 5.0, the same normalisation applied to the sts-dev set during training.
with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as sts_file:
    sts_rows = csv.DictReader(sts_file, delimiter='\t', quoting=csv.QUOTE_NONE)
    test_samples = [
        InputExample(label=float(row['score']) / 5.0,
                     texts=[row['sentence1'], row['sentence2']])
        for row in sts_rows
        if row['split'] == 'test'
    ]

# Reload from disk so the reported score reflects what was written to model_path,
# not whatever state the in-memory model was left in.
model = SentenceTransformer(model_path)
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
    test_samples, batch_size=train_batch_size, name='sts-test')
test_evaluator(model, output_path=model_path)

0.6888110595893926