In [1]:
# Design a siamese network to learn the similarity between two sentences
# It uses BERT embeddings to encode the sentences
# It need to support the triplet loss function

from torch import nn
from transformers import BertModel

criteria = nn.TripletMarginLoss(margin=1.0, p=2)

class SiameseNetwork(nn.Module):
    def __init__(self, bert_model):
        super(SiameseNetwork, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        self.fc = nn.Linear(768, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)
        pooled_output = outputs[1]
        return self.sigmoid(self.fc(pooled_output))

    def encode(self, input_ids, attention_mask):
        outputs = self.bert(input_ids, attention_mask)
        return outputs[1]

    def predict(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
        return self.forward(input_ids1, attention_mask1), self.forward(input_ids2, attention_mask2)

    def triplet_loss(self, input_ids1, attention_mask1, input_ids2, attention_mask2, input_ids3, attention_mask3, margin):
        # Compute the embeddings for the three inputs
        emb1 = self.encode(input_ids1, attention_mask1)
        emb2 = self.encode(input_ids2, attention_mask2)
        emb3 = self.encode(input_ids3, attention_mask3)

        # Compute the distances between the embeddings
        dist_pos = nn.functional.pairwise_distance(emb1, emb2)
        dist_neg = nn.functional.pairwise_distance(emb1, emb3)

        # Compute the triplet loss
        loss = nn.functional.relu(dist_pos - dist_neg + margin)

        return loss.mean()



In [12]:
from transformers import BertTokenizer
from transformers import BertModel
from torch import nn

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
criteria = nn.TripletMarginLoss(margin=1.0, p=2)

def encode_sentence(sentence):
    tokens = tokenizer(sentence, add_special_tokens=True, return_tensors='pt', max_length=128, padding='max_length', truncation=True)
    return tokens['input_ids'], tokens['attention_mask']

anchor = "Trucks are awesome"
positive = "Pigs are cool animals"
negative = "I like trucks"

input_ids1, attention_mask1 = encode_sentence(anchor)
input_ids2, attention_mask2 = encode_sentence(positive)
input_ids3, attention_mask3 = encode_sentence(negative)

anchoer_emb = model(input_ids1, attention_mask1)[1] # Gets the pooled output
positive_emb = model(input_ids2, attention_mask2)[1]
negative_emb = model(input_ids3, attention_mask3)[1]

loss = criteria(anchoer_emb, positive_emb, negative_emb)
print(loss) # tensor(2.3362, grad_fn=<MeanBackward0>)

tensor(2.3362, grad_fn=<MeanBackward0>)
