In [1]:
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from datasets import load_dataset
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import random
from collections import defaultdict

In [2]:
MODEL_NAME = "cointegrated/rubert-tiny2"
DATASET_NAME = "sberquad"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
dataset = load_dataset(DATASET_NAME)
train_data = dataset["train"]

#tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
#model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)

In [4]:
queries = []
passages = []

for i in range(len(train_data)):
    query = train_data[i]["question"]
    queries.append(query)
    
    passage = train_data[i]["context"]
    passages.append(passage)

In [5]:
def get_batches(queries, passages, batch_size):
    batches = []
    pending = []
    current_batch = {"question": [], "context": []}
    current_passages = set()

    def process_pending():
        nonlocal current_batch, current_passages, pending
        new_pending = []
        for q, p in pending:
            if p not in current_passages and len(current_batch["question"]) < batch_size:
                current_batch["question"].append(q)
                current_batch["context"].append(p)
                current_passages.add(p)
            else:
                new_pending.append((q, p))
        pending = new_pending

    for q, p in zip(queries, passages):
        if p in current_passages:
            pending.append((q, p))
        else:
            current_batch["question"].append(q)
            current_batch["context"].append(p)
            current_passages.add(p)

        if len(current_batch["question"]) == batch_size:
            batches.append(current_batch)
            current_batch = {"question": [], "context": []}
            current_passages = set()
            process_pending()
    
    
    while pending and len(current_batch["question"]) < batch_size:
        q, p = pending.pop(0)
        if p not in current_passages:
            current_batch["question"].append(q)
            current_batch["context"].append(p)
            current_passages.add(p)
        else:
            pending.append((q, p))
            if all(p in current_passages for _, p in pending):
                break
            
    if current_batch["question"]:
        batches.append(current_batch)

    return batches

In [6]:
class ContrastiveLoss(nn.Module):
    def __init__(self):
        super(ContrastiveLoss, self).__init__()

    def forward(self, query, passage, negative_passages, temperature):
        s_positive = F.cosine_similarity(query, passage, dim=-1) / temperature
        s_negative = F.cosine_similarity(query.unsqueeze(1), negative_passages, dim=-1) / temperature

        exp_for_sum = torch.cat([s_positive.unsqueeze(-1), s_negative], dim=-1)
        log_exp_sum = torch.logsumexp(exp_for_sum, dim=-1)
        
        return (-s_positive + log_exp_sum).mean()

In [7]:
BATCH_SIZE = 16
NUM_EPOCHS = 10
WARMUP_RATIO = 0.1

train_data_batched = get_batches(queries, passages, BATCH_SIZE)
trainloader = DataLoader(train_data_batched, batch_size=None, collate_fn=lambda x: x, shuffle=True)
loss_function = ContrastiveLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

total_steps = len(trainloader) * NUM_EPOCHS
num_warmup_steps = int(total_steps * WARMUP_RATIO)

scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=total_steps
)

In [8]:
model.train()

for epoch in range(NUM_EPOCHS):
    total_loss = 0
    progressBar = tqdm(range(len(trainloader)), desc=f"Epoch {epoch+1}")

    for batch in trainloader:
        query = tokenizer(batch["question"], return_tensors="pt", truncation=True, padding=True).to(DEVICE)
        passage = tokenizer(batch["context"], return_tensors="pt", truncation=True, padding=True).to(DEVICE)

        query_emb = model(**query).last_hidden_state.mean(dim=1)
        passage_emb = model(**passage).last_hidden_state.mean(dim=1)

        negative_passages = []
        for i in range(len(passage_emb)):
            negatives = torch.cat([passage_emb[:i], passage_emb[i + 1:]])
            negative_passages.append(negatives)

        negative_passages = torch.stack(negative_passages)

        loss = loss_function(query_emb, passage_emb, negative_passages, 0.01)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        progressBar.update(1)
        total_loss += loss.item()
        progressBar.set_postfix({"Loss": loss.item()})

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(trainloader)}")

Epoch 1:   0%|          | 0/2832 [00:00<?, ?it/s]

Epoch 1, Loss: 0.21011831918678714


Epoch 2:   0%|          | 0/2832 [00:00<?, ?it/s]

Epoch 2, Loss: 0.0638824780761017


Epoch 3:   0%|          | 0/2832 [00:00<?, ?it/s]

Epoch 3, Loss: 0.01893964079820346


Epoch 4:   0%|          | 0/2832 [00:00<?, ?it/s]

Epoch 4, Loss: 0.010464606132438225


Epoch 5:   0%|          | 0/2832 [00:00<?, ?it/s]

Epoch 5, Loss: 0.0057383512934275184


Epoch 6:   0%|          | 0/2832 [00:00<?, ?it/s]

Epoch 6, Loss: 0.002681931875744408


Epoch 7:   0%|          | 0/2832 [00:00<?, ?it/s]

Epoch 7, Loss: 0.002074575794624848


Epoch 8:   0%|          | 0/2832 [00:00<?, ?it/s]

Epoch 8, Loss: 0.00044417231994478855


Epoch 9:   0%|          | 0/2832 [00:00<?, ?it/s]

Epoch 9, Loss: 0.0002492921500556454


Epoch 10:   0%|          | 0/2832 [00:00<?, ?it/s]

Epoch 10, Loss: 5.8491878814859676e-05


In [None]:
#model.eval()
#model.save_pretrained("new_rubert-tiny2")
#tokenizer.save_pretrained("tokenizer_rubert-tiny2")

('tokenizer_rubert-tiny2\\tokenizer_config.json',
 'tokenizer_rubert-tiny2\\special_tokens_map.json',
 'tokenizer_rubert-tiny2\\vocab.txt',
 'tokenizer_rubert-tiny2\\added_tokens.json',
 'tokenizer_rubert-tiny2\\tokenizer.json')

In [4]:
tokenizer = AutoTokenizer.from_pretrained("tokenizer_rubert-tiny2")
model = AutoModel.from_pretrained("new_rubert-tiny2").to(DEVICE)
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(83828, 312, padding_idx=0)
    (position_embeddings): Embedding(2048, 312)
    (token_type_embeddings): Embedding(2, 312)
    (LayerNorm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-2): 3 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=312, out_features=312, bias=True)
            (key): Linear(in_features=312, out_features=312, bias=True)
            (value): Linear(in_features=312, out_features=312, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=312, out_features=312, bias=True)
            (LayerNorm): LayerNorm((312,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)

In [5]:
progressBar = tqdm(range(len(train_data)))
queries_emb = []
passages_emb = []

with torch.no_grad():
    for i in range(len(train_data)):
        query = tokenizer(train_data[i]["question"], return_tensors="pt", truncation=True, padding=True).to(DEVICE)
        query_emb = model(**query).last_hidden_state.mean(dim=1)
        queries_emb.append(query_emb.cpu())

        passage = tokenizer(train_data[i]["context"], return_tensors="pt", truncation=True, padding=True).to(DEVICE)
        passage_emb = model(**passage).last_hidden_state.mean(dim=1)
        passages_emb.append(passage_emb.cpu())

        progressBar.update(1)

  0%|          | 0/45328 [00:00<?, ?it/s]

In [6]:
passages_emb_tuples = [tuple(x[0].numpy().tolist()) for x in passages_emb]

In [7]:
tuple_to_indices = defaultdict(list)
for idx, tup in enumerate(passages_emb_tuples):
    tuple_to_indices[tup].append(idx)

progressBar = tqdm(range(len(train_data)))
number_range = set(range(len(train_data)))
pool_size = 5000
filtered = []

for i in range(len(train_data)):
    query_emb = queries_emb[i]
    
    forbidden_indices = set(tuple_to_indices[passages_emb_tuples[i]])
    forbidden_indices.add(i)

    available_indices = list(number_range - forbidden_indices)

    pool = random.sample(available_indices, pool_size)

    top_list = []
    for j in pool:
        passage_emb = passages_emb[j]
        cos_sim = F.cosine_similarity(query_emb, passage_emb, dim=-1).item()
        top_list.append(cos_sim)

    top_list.sort(reverse=True)

    passage_emb = passages_emb[i]
    cos_sim = F.cosine_similarity(query_emb, passage_emb, dim=-1).item()

    if cos_sim < top_list[1]:
        filtered.append(i)

    progressBar.update(1)

  0%|          | 0/45328 [00:00<?, ?it/s]

In [8]:
print(len(filtered))
print(len(train_data))

8699
45328


In [9]:
with open("filtered_array.txt", "w") as f:
    f.write(" ".join(map(str, filtered)))