In [1]:
# !pip install sentence-transformers -q

In [2]:
import os
import time
import ast
import gc
import random
import warnings
import multiprocessing as mp
from tqdm import tqdm
from types import SimpleNamespace

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import Sampler
from torch.nn.utils.rnn import pad_sequence
from torch.cuda.amp import GradScaler, autocast
import torch.nn.functional as F
from sklearn.model_selection import GroupKFold
from sklearn.metrics import accuracy_score

from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoConfig,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

from sentence_transformers import SentenceTransformer

# Notebook-wide runtime switches.
# NOTE: the blanket 'ignore' filter hides *all* Python warnings (including
# deprecation and tensor-copy warnings raised by later cells).
warnings.filterwarnings('ignore')
# Turn off the tokenizers library's internal parallelism before any tokenizer is used.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Read data

In [3]:
# Load the "pair" configuration of the sentence-transformers/all-nli dataset
# (files are cached under ../cache) and print the split/column summary.
from datasets import load_dataset
ds = load_dataset("sentence-transformers/all-nli", "pair", cache_dir="../cache")
print(ds)

DatasetDict({
    train: Dataset({
        features: ['anchor', 'positive'],
        num_rows: 314315
    })
    dev: Dataset({
        features: ['anchor', 'positive'],
        num_rows: 6808
    })
    test: Dataset({
        features: ['anchor', 'positive'],
        num_rows: 6831
    })
})


In [4]:
# Bind each split of the DatasetDict to its own notebook-level name.
train, dev, test = (ds[split_name] for split_name in ('train', 'dev', 'test'))

# Show the dev split as a DataFrame for a quick visual check.
dev.to_pandas()

Unnamed: 0,anchor,positive
0,Two women are embracing while holding to go pa...,Two woman are holding packages.
1,"Two young children in blue jerseys, one with t...",Two kids in numbered jerseys wash their hands.
2,A man selling donuts to a customer during a wo...,A man selling donuts to a customer.
3,Two young boys of opposing teams play football...,boys play football
4,A man in a blue shirt standing in front of a g...,A man is wearing a blue shirt
...,...,...
6803,"Under Ferdinand and Isabella, Spain underwent ...",Ferdinand and Isabella caused stunning changes...
6804,Kyoto's kabuki troupe performs in December and...,Kyoto has a kabuki troupe and so does Osaka.
6805,7) Nonautomated First-Class and Standard-A mai...,Nonautomated First-Class and Standard-A mailer...
6806,"Finally, the FDA will conduct workshops, issue...",The FDA is set to conduct workshops.


# Baseline

In [5]:
def calc_top_k_accuracy(top_k_preds, true_labels):
    """Return the share of samples whose true label is among its top-k predictions.

    Args:
        top_k_preds: iterable of per-sample candidate collections.
        true_labels: iterable of the expected label for each sample.

    Returns:
        Mean of the per-sample hit indicators (1 = label found, 0 = missed).
        The two iterables are walked in lockstep, so extra items in the
        longer one are ignored.
    """
    hit_flags = [
        int(expected in candidates)
        for candidates, expected in zip(top_k_preds, true_labels)
    ]
    return np.mean(hit_flags)

In [6]:
# Checkpoints to benchmark before fine-tuning; commented entries are optional candidates.
model_paths = [
    'intfloat/multilingual-e5-small',
    'intfloat/multilingual-e5-base',
    # 'intfloat/multilingual-e5-large',
    # 'bkai-foundation-models/vietnamese-bi-encoder',
    # 'VoVanPhuc/sup-SimCSE-VietNamese-phobert-base'
]

# Use the GPU when present, otherwise fall back to CPU instead of crashing.
baseline_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): the E5 model family is usually run with "query: " / "passage: "
# input prefixes; none is applied here — confirm that is intended.
print("=====> BEFORE FINETUNE <=====")
for model_path in model_paths:
    embed_model = SentenceTransformer(model_path, cache_folder="../cache")

    # Corpus side: every dev anchor, L2-normalised so a dot product equals cosine similarity.
    desc_embed = embed_model.encode(
        dev['anchor'],
        batch_size=16,
        device=baseline_device,
        convert_to_tensor=True,
        normalize_embeddings=True,
        show_progress_bar=True
    )
    print(desc_embed.shape)

    # Query side: encode all dev positives in one batched call rather than
    # issuing one encode() call per query inside the scoring loop.
    query_embeds = embed_model.encode(
        dev['positive'],
        batch_size=16,
        device=baseline_device,
        convert_to_tensor=True,
        normalize_embeddings=True,
        show_progress_bar=True
    )

    top_1_preds = []
    top_3_preds = []
    top_5_preds = []
    labels = []
    # Row i of each column belongs to the same pair, so the correct corpus index for query i is i.
    for idx, question_embed in enumerate(query_embeds):
        scores = question_embed @ desc_embed.T

        top_1_preds.append(scores.argmax(dim=-1).item())
        top_3_preds.append(scores.topk(3).indices.cpu().tolist())
        top_5_preds.append(scores.topk(5).indices.cpu().tolist())
        labels.append(idx)

    # accuracy_score takes (y_true, y_pred).
    top_1_acc = accuracy_score(labels, top_1_preds)
    top_3_acc = calc_top_k_accuracy(top_3_preds, labels)
    top_5_acc = calc_top_k_accuracy(top_5_preds, labels)

    print(f'### {model_path}')
    print(f'Accuracy@1: {top_1_acc}')
    print(f'Accuracy@3: {top_3_acc}')
    print(f'Accuracy@5: {top_5_acc}\n')

=====> BEFORE FINETUNE <=====


Batches:   0%|          | 0/426 [00:00<?, ?it/s]

torch.Size([6808, 384])
### intfloat/multilingual-e5-small
Accuracy@1: 0.6429200940070505
Accuracy@3: 0.8043478260869565
Accuracy@5: 0.8425381903642774



Batches:   0%|          | 0/426 [00:00<?, ?it/s]

torch.Size([6808, 768])
### intfloat/multilingual-e5-base
Accuracy@1: 0.6527614571092832
Accuracy@3: 0.814042303172738
Accuracy@5: 0.8534077555816686



# Fine-tuning

In [7]:
# Fine-tuning hyper-parameters and runtime settings, exposed as attributes
# (cfg.batch_size, cfg.device, ...).
cfg = SimpleNamespace(
    model_name='intfloat/multilingual-e5-small',
    batch_size=32,
    max_length=512,
    epochs=10,
    learning_rate=2e-4,
    warmup_steps=0,
    weight_decay=0.1,
    intermediate_dropout=0.,
    num_workers=mp.cpu_count(),  # one DataLoader worker per CPU core
    seed=42,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

In [8]:
# Load the tokenizer that belongs to cfg.model_name (files are cached under ../cache/).
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, cache_dir='../cache/')

In [9]:
# Exploratory check: tokenize every dev anchor, padding/truncating each row to
# cfg.max_length, then print the resulting tensors and the input_ids shape.
encodings = tokenizer(
    dev['anchor'],
    padding='max_length',
    truncation=True,
    max_length=cfg.max_length,
    return_tensors='pt'
)

print(encodings)
print(encodings.input_ids.shape)

{'input_ids': tensor([[     0,  32964,  24793,  ...,      1,      1,      1],
        [     0,  32964,  27150,  ...,      1,      1,      1],
        [     0,     62,    332,  ...,      1,      1,      1],
        ...,
        [     0,  49413,   3775,  ...,      1,      1,      1],
        [     0, 201106,      4,  ...,      1,      1,      1],
        [     0, 114765,    944,  ...,      1,      1,      1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}
torch.Size([6808, 512])


In [10]:
class EmbedDataset(Dataset):
	"""Paired dataset yielding the tokenized fields of two aligned text sides.

	Args:
		encodings_1: mapping of field name -> per-sample rows for the first side
			(e.g. a tokenizer output, or a plain dict of tensors/lists).
		encodings_2: same kind of mapping for the second side; row ``i`` is the
			partner of row ``i`` in ``encodings_1``.

	Each item is a dict in which every field of side 1 is suffixed with ``_1``
	and every field of side 2 with ``_2``.
	"""

	def __init__(self, encodings_1, encodings_2):
		self.encodings_1 = encodings_1
		self.encodings_2 = encodings_2

	@staticmethod
	def _row_as_tensor(values, idx):
		"""Return row ``idx`` of ``values`` as an independent tensor."""
		row = values[idx]
		if isinstance(row, torch.Tensor):
			# torch.tensor(existing_tensor) works but raises a copy-construct
			# UserWarning; clone().detach() is the recommended equivalent.
			return row.clone().detach()
		return torch.tensor(row)

	def __getitem__(self, idx):
		item = {f'{key}_1': self._row_as_tensor(val, idx) for key, val in self.encodings_1.items()}
		item.update(
			{f'{key}_2': self._row_as_tensor(val, idx) for key, val in self.encodings_2.items()}
		)
		return item

	def __len__(self):
		# Key lookup (not attribute access) so plain dicts are accepted as well.
		return len(self.encodings_1['input_ids'])

In [11]:
def get_dataloader(tokenizer, questions, descriptions, mode, batch_size, max_length, num_workers):
	"""Tokenize two aligned text lists and wrap them in a DataLoader.

	Args:
		tokenizer: callable tokenizer applied to each list of texts.
		questions: texts for side 1 of every pair.
		descriptions: texts for side 2 of every pair.
		mode: ``'train'`` shuffles and drops the last incomplete batch; any
			other value keeps the original order and every sample.
		batch_size: number of pairs per batch.
		max_length: every sequence is padded/truncated to this many tokens.
		num_workers: number of DataLoader worker processes.

	Returns:
		A DataLoader over an ``EmbedDataset`` of the tokenized pairs.
	"""

	def encode(texts):
		# Fixed-length padding so all rows share one shape.
		return tokenizer(
			texts,
			padding='max_length',
			truncation=True,
			max_length=max_length,
			return_tensors='pt'
		)

	paired_dataset = EmbedDataset(encode(questions), encode(descriptions))

	# Shuffling and dropping the ragged last batch are training-only behaviours.
	is_train = mode == 'train'
	return DataLoader(
		dataset=paired_dataset,
		batch_size=batch_size,
		drop_last=is_train,
		shuffle=is_train,
		num_workers=num_workers
	)

In [12]:
# Build the training DataLoader and peek at one batch.
# NOTE(review): the training pairs are taken from the *dev* split (not `train`), and the
# evaluation cells further down also tokenize and score dev — confirm this is
# intentional, since the model is then evaluated on the data it was fine-tuned on.
api_questions = dev['anchor']
api_descriptions = dev['positive']

train_dataloader = get_dataloader(
    tokenizer=tokenizer,
    questions=api_questions,
    descriptions=api_descriptions,
    mode='train',
    batch_size=cfg.batch_size,
    max_length=cfg.max_length,
    num_workers=cfg.num_workers,
)

# Sanity check: print the first batch and the shapes of its four tensors, then stop.
for batch in train_dataloader:
    print(batch)
    anchor = batch['input_ids_1']
    positive = batch['input_ids_2']
    mask_1 = batch['attention_mask_1']
    mask_2 = batch['attention_mask_2']
    print(anchor.shape, positive.shape)
    print(mask_1.shape, mask_2.shape)
    break

{'input_ids_1': tensor([[    0,  3493,  2412,  ...,     1,     1,     1],
        [    0,    62, 21115,  ...,     1,     1,     1],
        [    0, 32964, 22556,  ...,     1,     1,     1],
        ...,
        [    0, 78289,  5510,  ...,     1,     1,     1],
        [    0,    62,  4000,  ...,     1,     1,     1],
        [    0,    17,  1600,  ...,     1,     1,     1]]), 'attention_mask_1': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'input_ids_2': tensor([[    0,   581,  3445,  ...,     1,     1,     1],
        [    0, 41021,   621,  ...,     1,     1,     1],
        [    0, 32964, 10269,  ...,     1,     1,     1],
        ...,
        [    0,  8622,   509,  ...,     1,     1,     1],
        [    0,    62,  4000,  ...,     1,     1,     1],
        [    0,  4263,   398,  ...,     1,     1,     1]]), 'att

In [13]:
def set_seed(seed=318):
	"""Seed the Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic mode.

	Args:
		seed: integer applied to every generator; also exported as the
			``PYTHONHASHSEED`` environment variable.
	"""
	os.environ['PYTHONHASHSEED'] = str(seed)
	# os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"

	# Same seed for every generator the notebook touches.
	for seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
		seed_rng(seed)

	# Disable the cuDNN autotuner and request deterministic kernels.
	torch.backends.cudnn.benchmark = False
	torch.backends.cudnn.deterministic = True
	# torch.use_deterministic_algorithms(True)

In [14]:
class MultiNegativesRankingLoss(nn.Module):
    """
    In-batch ranking loss: row i of `embed_1` must score highest against row i of
    `embed_2`; every other row of `embed_2` acts as a negative.

    Ref: https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/losses/MultipleNegativesRankingLoss.py
    """
    def __init__(self, scale=50):
        super().__init__()
        # Multiplier applied to the cosine similarities before the softmax.
        self.scale = scale
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, embed_1, embed_2, labels=None):
        """Return the mean cross-entropy over the scaled cosine-similarity matrix.

        `labels` is accepted for signature compatibility but is not used: the
        target of row i is always column i.
        """
        unit_1 = F.normalize(embed_1)
        unit_2 = F.normalize(embed_2)
        cosine_scores = torch.matmul(unit_1, unit_2.T) * self.scale

        # The matching pair sits on the diagonal.
        diagonal_targets = torch.arange(
            len(cosine_scores),
            dtype=torch.long,
            device=cosine_scores.device
        )

        return self.cross_entropy(cosine_scores, diagonal_targets)


loss_fn = MultiNegativesRankingLoss()

In [15]:
class TextMeanPooling(nn.Module):
    """Average token embeddings over the sequence axis, counting only unmasked tokens."""

    def __init__(self, eps=1e-06):
        super().__init__()
        # Lower bound for the token count so an all-zero mask does not divide by zero.
        self.eps = eps

    def forward(self, token_embeddings, attention_mask):
        """Return the mask-weighted mean of `token_embeddings` along dim 1.

        `attention_mask` is broadcast over the last dimension; positions with
        mask 0 contribute nothing to the sum or to the count.
        """
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        masked_sum = (token_embeddings * mask).sum(dim=1)
        token_count = mask.sum(dim=1).clamp(min=self.eps)
        return masked_sum / token_count

In [16]:
class EmbedModel(nn.Module):
    """Bi-encoder: one shared backbone embeds both sides of a pair and returns the ranking loss."""

    def __init__(self, cfg):
        """Build the backbone named by `cfg.model_name` with dropout set to `cfg.intermediate_dropout`."""
        super().__init__()

        backbone_config = AutoConfig.from_pretrained(cfg.model_name, cache_dir='../cache')
        # Apply the same dropout rate to attention probabilities and hidden states.
        for dropout_field in ('attention_probs_dropout_prob', 'hidden_dropout_prob'):
            setattr(backbone_config, dropout_field, cfg.intermediate_dropout)

        self.backbone = AutoModel.from_pretrained(
            cfg.model_name,
            config=backbone_config,
            cache_dir='../cache'
        )
        self.backbone.gradient_checkpointing_enable()

        self.pooler = TextMeanPooling()
        self.loss_fn = MultiNegativesRankingLoss()

    def _embed(self, input_ids, attention_mask):
        """Run the backbone on one side and mean-pool its last hidden state."""
        token_states = self.backbone(input_ids, attention_mask).last_hidden_state
        return self.pooler(token_states, attention_mask)

    def forward(self, input_ids_1, attention_mask_1, input_ids_2, attention_mask_2):
        """Embed both sides of the batch and return the in-batch ranking loss."""
        x_1 = self._embed(input_ids_1, attention_mask_1)
        x_2 = self._embed(input_ids_2, attention_mask_2)
        return self.loss_fn(x_1, x_2)

In [17]:
# Seed first so everything created below starts from the same RNG state.
set_seed(cfg.seed)
start_time = time.time()
scaler = GradScaler()

# Model: build, move to the target device, switch to training mode.
model = EmbedModel(cfg).to(cfg.device)
model.train()

# Optimizer plus a linear-decay LR schedule spanning every optimizer step of the run.
optimizer = optim.AdamW(
    model.parameters(),
    lr=cfg.learning_rate,
    weight_decay=cfg.weight_decay,
)
total_training_steps = len(train_dataloader) * cfg.epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=cfg.warmup_steps,
    num_training_steps=total_training_steps,
)

In [18]:
# Fit
# Mixed-precision training loop. Per batch: forward under autocast, backward on the
# scaled loss, optimizer step through the scaler, scaler update, gradient reset,
# then one LR-scheduler step.
for epoch in range(cfg.epochs):
    for batch_idx, batch in enumerate(train_dataloader):
        # Move both sides of the pair (token ids and attention masks) to the training device.
        input_ids_1 = batch['input_ids_1'].to(cfg.device)
        attention_mask_1 = batch['attention_mask_1'].to(cfg.device)
        input_ids_2 = batch['input_ids_2'].to(cfg.device)
        attention_mask_2 = batch['attention_mask_2'].to(cfg.device)

        # The model's forward returns the loss directly.
        with autocast():
            loss = model(
                input_ids_1=input_ids_1,
                attention_mask_1=attention_mask_1,
                input_ids_2=input_ids_2,
                attention_mask_2=attention_mask_2
            )

        # NOTE(review): scheduler.step() runs unconditionally every iteration; if the
        # GradScaler ever skips an optimizer step (presumably on inf/NaN gradients) the
        # learning rate still advances — confirm this is acceptable.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()

        # Report the current batch loss on every 10th batch (batch_idx 0, 10, 20, ...).
        if not batch_idx % 10:
            print(
                f'Epoch: {epoch + 1}/{cfg.epochs}'
                f' | Batch: {batch_idx}/{len(train_dataloader)}'
                f' | Loss: {loss.detach().cpu().item():.4f}')

Epoch: 1/10 | Batch: 0/212 | Loss: 0.1852
Epoch: 1/10 | Batch: 10/212 | Loss: 0.0195
Epoch: 1/10 | Batch: 20/212 | Loss: 0.0920
Epoch: 1/10 | Batch: 30/212 | Loss: 0.1731
Epoch: 1/10 | Batch: 40/212 | Loss: 0.2456
Epoch: 1/10 | Batch: 50/212 | Loss: 0.3065
Epoch: 1/10 | Batch: 60/212 | Loss: 0.1672
Epoch: 1/10 | Batch: 70/212 | Loss: 0.0637
Epoch: 1/10 | Batch: 80/212 | Loss: 0.1412
Epoch: 1/10 | Batch: 90/212 | Loss: 0.1030
Epoch: 1/10 | Batch: 100/212 | Loss: 0.1785
Epoch: 1/10 | Batch: 110/212 | Loss: 0.1163
Epoch: 1/10 | Batch: 120/212 | Loss: 0.1496
Epoch: 1/10 | Batch: 130/212 | Loss: 0.0328
Epoch: 1/10 | Batch: 140/212 | Loss: 0.0671
Epoch: 1/10 | Batch: 150/212 | Loss: 0.0993
Epoch: 1/10 | Batch: 160/212 | Loss: 0.1009
Epoch: 1/10 | Batch: 170/212 | Loss: 0.0901
Epoch: 1/10 | Batch: 180/212 | Loss: 0.0720
Epoch: 1/10 | Batch: 190/212 | Loss: 0.1213
Epoch: 1/10 | Batch: 200/212 | Loss: 0.1278
Epoch: 1/10 | Batch: 210/212 | Loss: 0.0685
Epoch: 2/10 | Batch: 0/212 | Loss: 0.0065
E

In [19]:
class ValDataset(Dataset):
	"""Single-sided dataset over tokenized texts, used for inference.

	Args:
		encodings: mapping of field name -> per-sample rows (e.g. a tokenizer
			output, or a plain dict of tensors/lists). Must contain
			``'input_ids'``, which defines the dataset length.

	Each item is a dict with the same keys as ``encodings``, holding row ``idx``
	of every field as an independent tensor.
	"""

	def __init__(self, encodings):
		self.encodings = encodings

	def __getitem__(self, idx):
		item = {}
		for key, val in self.encodings.items():
			row = val[idx]
			if isinstance(row, torch.Tensor):
				# torch.tensor(existing_tensor) works but raises a copy-construct
				# UserWarning; clone().detach() is the recommended equivalent.
				item[key] = row.clone().detach()
			else:
				item[key] = torch.tensor(row)
		return item

	def __len__(self):
		# len() rather than .shape[0] so list-backed encodings work too.
		return len(self.encodings['input_ids'])


def _build_val_loader(texts):
    """Tokenize `texts` to fixed length and wrap them in an ordered, complete DataLoader.

    Returns the tokenizer output, the ValDataset built on it, and the DataLoader.
    """
    encoded = tokenizer(
        texts,
        padding='max_length',
        truncation=True,
        max_length=cfg.max_length,
        return_tensors='pt'
    )
    dataset = ValDataset(encoded)
    loader = DataLoader(
        dataset=dataset,
        batch_size=cfg.batch_size,
        drop_last=False,
        shuffle=False,
        num_workers=cfg.num_workers
    )
    return encoded, dataset, loader


# Side 1 holds the dev positives, side 2 the dev anchors.
encodings_1, val_dataset_1, val_dataloader_1 = _build_val_loader(dev['positive'])
encodings_2, val_dataset_2, val_dataloader_2 = _build_val_loader(dev['anchor'])

In [29]:
# Embed every tokenized dev positive (encodings_1) with the fine-tuned backbone.
# NOTE(review): the baseline cell uses dev['anchor'] as the corpus and
# dev['positive'] as queries; here the roles are reversed — confirm the
# before/after accuracies are meant to be compared across directions.
pooler = TextMeanPooling()
embed_rows = []
with torch.no_grad():
    model.eval()
    for input_ids, attention_mask in zip(encodings_1.input_ids, encodings_1.attention_mask):
        # Add a batch dimension of 1 and move the sample to the model's device.
        input_ids = torch.unsqueeze(input_ids, 0).to(cfg.device)
        attention_mask = torch.unsqueeze(attention_mask, 0).to(cfg.device)
        # Forward pass, masked mean pooling, then L2-normalisation so that a
        # dot product between embeddings is a cosine similarity.
        embed = model.backbone(input_ids, attention_mask).last_hidden_state
        mean_embed = pooler(embed, attention_mask)
        mean_embed = F.normalize(mean_embed, dim=-1)
        embed_rows.append(mean_embed)

    # Concatenate once at the end: growing the matrix with torch.cat inside the
    # loop re-copies every previous row on each iteration (quadratic cost).
    if embed_rows:
        desc_embed = torch.cat(embed_rows, dim=0)
    else:
        desc_embed = torch.tensor([], device=cfg.device)

    print(input_ids.shape)
    print(embed.shape)
    print(mean_embed.shape)
    print(desc_embed.shape)

torch.Size([1, 512])
torch.Size([1, 512, 384])
torch.Size([1, 384])
torch.Size([6808, 384])


In [30]:
# Score every tokenized dev anchor (encodings_2) against the embedded corpus in
# `desc_embed`; the correct corpus row for query i is row i.
model.eval()
top_1_preds, top_3_preds, top_5_preds, labels = [], [], [], []
with torch.no_grad():
    query_rows = zip(encodings_2.input_ids, encodings_2.attention_mask)
    for idx, (input_ids, attention_mask) in enumerate(query_rows):
        # Add a batch dimension of 1 and move the sample to the model's device.
        input_ids = input_ids.unsqueeze(0).to(cfg.device)
        attention_mask = attention_mask.unsqueeze(0).to(cfg.device)

        # Forward pass -> masked mean pooling -> unit-length query embedding.
        question_embed = model.backbone(input_ids, attention_mask).last_hidden_state
        question_embed = F.normalize(pooler(question_embed, attention_mask), dim=-1)
        scores = question_embed @ desc_embed.T

        # `scores` keeps the batch dimension, so each tolist() yields a
        # one-element outer list; += unwraps it into the running results.
        top_1_preds += scores.argmax(dim=-1).cpu().tolist()
        top_3_preds += scores.topk(3).indices.cpu().tolist()
        top_5_preds += scores.topk(5).indices.cpu().tolist()
        labels.append(idx)

top_1_acc = accuracy_score(top_1_preds, labels)
top_3_acc = calc_top_k_accuracy(top_3_preds, labels)
top_5_acc = calc_top_k_accuracy(top_5_preds, labels)

print("=====> AFTER FINETUNE <=====")
print(f'### {cfg.model_name}')
for k, acc in ((1, top_1_acc), (3, top_3_acc), (5, top_5_acc)):
    print(f'Accuracy@{k}: {acc}')

=====> AFTER FINETUNE <=====
### intfloat/multilingual-e5-small
Accuracy@1: 0.7501468860164512
Accuracy@3: 0.9225910693301997
Accuracy@5: 0.9578437132784959
