In [1]:
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from datasets import load_dataset
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

In [2]:
MODEL_NAME = "cointegrated/rubert-tiny2"
DATASET_NAME = "sberquad"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
dataset = load_dataset(DATASET_NAME)
train_data = dataset["train"]

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)

In [4]:
class ContrastiveLoss(nn.Module):
    def __init__(self):
        super(ContrastiveLoss, self).__init__()

    def forward(self, query, passage, negative_passages, temperature):
        s_positive = F.cosine_similarity(query, passage, dim=-1) / temperature
        s_negative = F.cosine_similarity(query.unsqueeze(1), negative_passages, dim=-1) / temperature

        exp_for_sum = torch.cat([s_positive.unsqueeze(-1), s_negative], dim=-1)
        log_exp_sum = torch.logsumexp(exp_for_sum, dim=-1)
        
        return (-s_positive + log_exp_sum).mean()

In [5]:
BATCH_SIZE = 16
NUM_EPOCHS = 10
WARMUP_RATIO = 0.1

trainloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
loss_function = ContrastiveLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

total_steps = len(trainloader) * NUM_EPOCHS
num_warmup_steps = int(total_steps * WARMUP_RATIO)

scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=total_steps
)

In [6]:
model.train()

for epoch in range(NUM_EPOCHS):
    total_loss = 0
    progressBar = tqdm(range(len(trainloader)), desc=f"Epoch {epoch+1}")

    for batch in trainloader:
        query = tokenizer(batch["question"], return_tensors="pt", truncation=True, padding=True).to(DEVICE)
        passage = tokenizer(batch["context"], return_tensors="pt", truncation=True, padding=True).to(DEVICE)

        query_emb = model(**query).last_hidden_state.mean(dim=1)
        passage_emb = model(**passage).last_hidden_state.mean(dim=1)

        negative_passages = []
        for i in range(len(passage_emb)):
            negatives = torch.cat([passage_emb[:i], passage_emb[i + 1:]])
            negative_passages.append(negatives)

        negative_passages = torch.stack(negative_passages)

        loss = loss_function(query_emb, passage_emb, negative_passages, 0.01)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        progressBar.update(1)
        total_loss += loss.item()
        progressBar.set_postfix({"Loss": loss.item()})

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(trainloader)}")

Epoch 1:   0%|          | 0/2833 [00:00<?, ?it/s]

Epoch 1, Loss: 0.20955448313900576


Epoch 2:   0%|          | 0/2833 [00:00<?, ?it/s]

Epoch 2, Loss: 0.09102391052111442


Epoch 3:   0%|          | 0/2833 [00:00<?, ?it/s]

Epoch 3, Loss: 0.04133797020109307


Epoch 4:   0%|          | 0/2833 [00:00<?, ?it/s]

Epoch 4, Loss: 0.029266249615944952


Epoch 5:   0%|          | 0/2833 [00:00<?, ?it/s]

Epoch 5, Loss: 0.021405379173825498


Epoch 6:   0%|          | 0/2833 [00:00<?, ?it/s]

Epoch 6, Loss: 0.014963899279835786


Epoch 7:   0%|          | 0/2833 [00:00<?, ?it/s]

Epoch 7, Loss: 0.011368514790452498


Epoch 8:   0%|          | 0/2833 [00:00<?, ?it/s]

Epoch 8, Loss: 0.009662995353588315


Epoch 9:   0%|          | 0/2833 [00:00<?, ?it/s]

Epoch 9, Loss: 0.009089007884253472


Epoch 10:   0%|          | 0/2833 [00:00<?, ?it/s]

Epoch 10, Loss: 0.007146501221338683


In [7]:
model.save_pretrained("new_rubert-tiny2")
tokenizer.save_pretrained("tokenizer_rubert-tiny2")

('tokenizer_rubert-tiny2\\tokenizer_config.json',
 'tokenizer_rubert-tiny2\\special_tokens_map.json',
 'tokenizer_rubert-tiny2\\vocab.txt',
 'tokenizer_rubert-tiny2\\added_tokens.json',
 'tokenizer_rubert-tiny2\\tokenizer.json')