In [2]:
import os
# Point both the HTTP and HTTPS proxy environment variables at the local proxy.
# NOTE(review): presumably needed so later downloads can reach the network —
# confirm, and drop this cell on machines without a proxy on port 7890.
for _proxy_var in ('http_proxy', 'https_proxy'):
    os.environ[_proxy_var] = 'http://127.0.0.1:7890'

In [164]:
import torch
import torch.nn.functional as F
from copy import deepcopy

In [81]:
# Inspect which function the `F` alias resolves to for cosine similarity
# (used further down to reproduce the loss by hand).
F.cosine_similarity

<function torch._VariableFunctionsClass.cosine_similarity>

In [91]:
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, InputExample
from torch.utils.data import DataLoader

# Load the pretrained sentence encoder by name and keep it on the CPU.
model = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2', device='cpu')

## model

In [92]:
# Display the module stack of the loaded model (indexable as model[0], model[1], ...).
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

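- `model[0]`: Transformer (BertModel), produces `token_embeddings`
- `model[1]`: Pooling with `pooling_mode_mean_tokens`
    - `sentence_embedding`: the mean of the `token_embeddings` over the tokens of the sentence
- `model[2]`: Normalize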

## dataloader

In [122]:
# Toy training set holding a single positive pair (label=1).
# A negative pair (label=0) is kept here, disabled, for reference:
#     InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)
positive_pair = InputExample(
    texts=['This is a positive pair', 'Where the distance will be minimized'],
    label=1,
)
train_examples = [positive_pair]

In [147]:
# Wrap the examples in a shuffling DataLoader; batch_size=2 is an upper bound,
# so with a single example every batch holds just that one pair.
train_dataloader = DataLoader(dataset=train_examples, batch_size=2, shuffle=True)

In [124]:
# Inspect the collate function the DataLoader currently uses, before it is
# replaced by the model's own batching function in the next cell.
train_dataloader.collate_fn

<function torch.utils.data._utils.collate.default_collate(batch)>

In [176]:
# Swap in the model's collate function so that each batch comes out as
# (list of tokenized feature dicts, label tensor), then pull the first batch.
train_dataloader.collate_fn = model.smart_batching_collate
batch_iter = iter(train_dataloader)
batch = next(batch_iter)
batch

([{'input_ids': tensor([[ 101, 2023, 2003, 1037, 3893, 3940,  102]]),
   'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]),
   'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])},
  {'input_ids': tensor([[  101,  2073,  1996,  3292,  2097,  2022, 18478,  2094,   102]]),
   'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0]]),
   'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}],
 tensor([1]))

## losses

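- $(a,b)$: the embeddings of the two sentences in a pair; $d(a,b)$: the distance between them (cosine distance, $1-\cos(a,b)$, by default); $\epsilon$: the margin; $\ell$: the pair label

$$
\mathcal{L}(a,b,\ell)=
\begin{cases}
\frac12\,d(a,b)^2, & \ell=1\\
\frac12\,\text{ReLU}\big(\epsilon-d(a,b)\big)^2, & \ell=0
\end{cases}
$$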

In [97]:
# Contrastive loss bound to the model; it is called below as train_loss(features, labels).
train_loss = losses.ContrastiveLoss(model=model)

In [98]:
# Shape of the first parameter registered on the loss module.
# NOTE(review): presumably the word-embedding matrix of the wrapped model
# (the loss holds the model as a sub-module) — confirm.
next(train_loss.parameters()).shape

torch.Size([30522, 384])

In [99]:
# Distance function used by the loss. The default is
# SiameseDistanceMetric.COSINE_DISTANCE, i.e. 1 - cosine_similarity(x, y),
# which the manual check further down reproduces.
train_loss.distance_metric 

<function sentence_transformers.losses.ContrastiveLoss.SiameseDistanceMetric.<lambda>(x, y)>

## model.forward

In [152]:
# Second element of the batch tuple: the label tensor.
batch[1]

tensor([1])

In [149]:
# Shape of the token-id tensor for the first sentence of the pair.
batch[0][0]['input_ids'].shape

torch.Size([1, 7])

In [177]:
# Split the collated batch into the per-sentence feature dicts (one dict per
# side of the pair) and the label tensor, then display the features.
features, labels = batch
features

[{'input_ids': tensor([[ 101, 2023, 2003, 1037, 3893, 3940,  102]]),
  'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]),
  'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])},
 {'input_ids': tensor([[  101,  2073,  1996,  3292,  2097,  2022, 18478,  2094,   102]]),
  'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0]]),
  'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}]

In [61]:
# Disabled: the three sub-modules from the model repr applied one after another
# to the first sentence's features.
# model[2](model[1](model[0](features[0])))

In [129]:
# Show the module stack again as a reference for the model[0] / model[1] calls below.
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [178]:
# Keep an independent copy of the features for the manual pooling check below.
# NOTE(review): presumably needed because the sub-module calls add keys to the
# dict they are given — confirm.
feature_cpy = deepcopy(features)

In [183]:
# Run the first sentence through the transformer and then the pooling module,
# and show the first five dimensions of the pooled sentence embedding.
transformer_out = model[0](features[0])
pooled_out = model[1](transformer_out)
pooled_out['sentence_embedding'][0, :5]

tensor([-0.2930,  0.3243, -0.6169, -0.0097, -0.1806], grad_fn=<SliceBackward0>)

In [184]:
# Reproduce the mean pooling by hand on the untouched copy of the features.
# The divisor is the number of attended tokens taken from the attention mask
# instead of a hard-coded 7, so the check stays valid for sentences of any
# length and for padded batches (masked positions contribute nothing).
manual_tokens = model[0](feature_cpy[0])['token_embeddings']
token_mask = feature_cpy[0]['attention_mask'].unsqueeze(-1).to(manual_tokens.dtype)
manual_mean = torch.sum(manual_tokens * token_mask, dim=1) / token_mask.sum(dim=1)
manual_mean[0, :5]

tensor([-0.2930,  0.3243, -0.6169, -0.0097, -0.1806], grad_fn=<SliceBackward0>)

## forward loss

In [109]:
# Encode both sides of the pair with the full model and keep the final
# sentence embeddings (first sentence first, same order as in `features`).
sent1_embed, sent2_embed = (
    model(pair_side)['sentence_embedding'] for pair_side in features
)

In [116]:
# Distance between the two sentence embeddings as the loss measures it;
# equivalent to: 1 - F.cosine_similarity(sent1_embed, sent2_embed)
train_loss.distance_metric(sent1_embed, sent2_embed)

tensor([0.9867], grad_fn=<RsubBackward1>)

In [117]:
# Loss as computed by the library for this batch (a single pair with label 1);
# the next cell reproduces the same value by hand.
train_loss(features, labels)

tensor(0.4868, grad_fn=<MeanBackward0>)

In [118]:
# Manual contrastive loss for a positive pair: half the squared cosine distance.
cosine_distance = 1 - F.cosine_similarity(sent1_embed, sent2_embed)
0.5 * cosine_distance.pow(2)

tensor([0.4868], grad_fn=<MulBackward0>)

In [None]:
# Train with the single (dataloader, loss) objective, showing a progress bar.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    show_progress_bar=True,
)