In [None]:
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from datasets import load_dataset
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import random
from collections import defaultdict
from utils import get_data, get_batches, validate_batches
from sklearn.model_selection import train_test_split

In [None]:
def set_seed(seed=42):
    """Make runs repeatable by seeding every RNG this notebook touches.

    Seeds Python's ``random`` module, PyTorch's CPU generator and all CUDA
    generators, then forces cuDNN to pick deterministic kernels (and disables
    its autotuner, which may choose different kernels between runs).

    Args:
        seed: value passed to every generator.
    """
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

In [None]:
set_seed(seed=42)  # fix RNG state before any stochastic step below

In [None]:
# Encoder checkpoint to fine-tune.
MODEL_NAME = "cointegrated/rubert-tiny2"
# (question, answer) source. Alternative dataset id: "petkopetkov/medical-question-answering-all"
DATASET_NAME = "tom-010/google_natural_questions_answerability"

# Use the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# When True, the data-preparation cell drops the indices listed in
# filtered_array.txt / filtered_array_val.txt.
TRAIN_ON_FILTERED = False

In [None]:
# Used below: the "train" and "validation" splits and their "question" / "answer" columns.
dataset = load_dataset(path=DATASET_NAME)

In [None]:
# Unused alternative: a random 70/30 split of the "train" split. Its 'input'/'output' columns
# presumably belong to the alternative medical QA dataset; the active dataset's own
# "validation" split is used instead in the next cell.
#X_train, X_test, y_train, y_test = train_test_split(dataset["train"]['input'], dataset["train"]["output"], test_size=0.3, random_state=42)

In [None]:
def _split_to_pairs(split):
    """Turn one dataset split into {"question", "context"} dicts, skipping rows whose answer is None."""
    return [
        {"question": question, "context": answer}
        for question, answer in zip(split["question"], split["answer"])
        if answer is not None
    ]


train_data = _split_to_pairs(dataset["train"])
valid_data = _split_to_pairs(dataset["validation"])

In [None]:
def _indices_excluding(num_items, filtered_path):
    """Return {0, ..., num_items - 1} minus the whitespace-separated ints stored in ``filtered_path``."""
    with open(filtered_path, "r") as fh:
        excluded = {int(token) for token in fh.read().split()}
    return set(range(num_items)) - excluded


if TRAIN_ON_FILTERED:
    # Drop the examples flagged by the filtering cell at the end of this notebook.
    queries_train, passages_train = get_data(_indices_excluding(len(train_data), "filtered_array.txt"), train_data)
    queries_valid, passages_valid = get_data(_indices_excluding(len(valid_data), "filtered_array_val.txt"), valid_data)
else:
    queries_train, passages_train = get_data(range(len(train_data)), train_data)
    queries_valid, passages_valid = get_data(range(len(valid_data)), valid_data)

In [None]:
class ContrastiveLoss(nn.Module):
    """InfoNCE-style loss: score each query's positive passage against its negatives.

    Every cosine similarity is divided by ``temperature``. The per-row loss is
    ``logsumexp([positive, negatives]) - positive``, averaged over the batch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, query, passage, negative_passages, temperature):
        """Compute the mean contrastive loss.

        Args:
            query: query embeddings, compared row-wise with ``passage`` along the last dim.
            passage: positive passage embeddings, same shape as ``query``.
            negative_passages: negatives per query. ``query.unsqueeze(1)`` is broadcast
                against it, so presumably (batch, num_negatives, dim).
            temperature: divisor applied to every cosine similarity.

        Returns:
            Scalar tensor.
        """
        positive_logits = F.cosine_similarity(query, passage, dim=-1) / temperature
        negative_logits = F.cosine_similarity(query.unsqueeze(1), negative_passages, dim=-1) / temperature

        # Column 0 holds the positive; the remaining columns hold the negatives.
        all_logits = torch.cat((positive_logits.unsqueeze(-1), negative_logits), dim=-1)
        per_example = torch.logsumexp(all_logits, dim=-1) - positive_logits
        return per_example.mean()

In [None]:
# Tokenizer and encoder come from the same checkpoint; the encoder is moved to DEVICE.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(DEVICE)

In [None]:
BATCH_SIZE = 128
NUM_EPOCHS = 4
WARMUP_RATIO = 0.1
LEARNING_RATE = 9e-5
WEIGHT_DECAY = 0.01

# get_batches (from utils) presumably returns a sequence of ready-made batch dicts; the training
# loop reads batch["question"] / batch["context"] from them. batch_size=None turns off the
# DataLoader's automatic batching, so each element is yielded as one batch.
train_data_batched = get_batches(queries_train, passages_train, BATCH_SIZE)
valid_data_batched = get_batches(queries_valid, passages_valid, BATCH_SIZE)

trainloader = DataLoader(train_data_batched, batch_size=None, collate_fn=lambda x: x, shuffle=True)
# The validation loss is an average over batches, so order is irrelevant; don't shuffle.
validloader = DataLoader(valid_data_batched, batch_size=None, collate_fn=lambda x: x, shuffle=False)

loss_function = ContrastiveLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Must equal the number of scheduler.step() calls: one per training batch for NUM_EPOCHS
# epochs. Only then does the linear decay reach zero at the end of training.
total_steps = len(trainloader) * NUM_EPOCHS
num_warmup_steps = int(total_steps * WARMUP_RATIO)

scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=total_steps
)

In [None]:
# Sanity-check the pre-built train/validation batches before training.
# NOTE(review): validate_batches is defined in utils (not visible here); presumably it raises
# on malformed batches — confirm.
validate_batches(train_data_batched)
validate_batches(valid_data_batched)

In [None]:
TEMPERATURE = 0.01  # divisor applied to cosine similarities inside ContrastiveLoss


def mean_pool(last_hidden_state, attention_mask):
    """Average token states over real (non-padding) tokens only.

    The tokenizer is called with padding=True, so shorter sequences in a batch are
    padded. A plain ``.mean(dim=1)`` would average pad-token states into their
    embedding, making a text's embedding depend on its batch-mates. It would also
    disagree with the unpadded single-text embeddings computed later in the notebook.

    Args:
        last_hidden_state: encoder output, (batch, seq, hidden).
        attention_mask: tokenizer mask, (batch, seq), 1 for real tokens.

    Returns:
        (batch, hidden) pooled embeddings.
    """
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    return (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)


def in_batch_negatives(passage_emb):
    """For each row i, gather every other passage in the batch as its negatives.

    Vectorized equivalent of stacking ``torch.cat([passage_emb[:i], passage_emb[i + 1:]])``
    for every i, with the same ordering.

    Returns:
        (batch, batch - 1, hidden) tensor.
    """
    batch_size, hidden = passage_emb.shape
    off_diagonal = ~torch.eye(batch_size, dtype=torch.bool, device=passage_emb.device)
    all_pairs = passage_emb.unsqueeze(0).expand(batch_size, batch_size, hidden)
    # Explicit sizes (not -1) so a batch of 1 (zero negatives) still reshapes.
    return all_pairs[off_diagonal].reshape(batch_size, batch_size - 1, hidden)


def batch_loss(batch):
    """Contrastive loss for one batch dict holding "question" and "context" string lists."""
    query = tokenizer(batch["question"], return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    passage = tokenizer(batch["context"], return_tensors="pt", truncation=True, padding=True).to(DEVICE)

    query_emb = mean_pool(model(**query).last_hidden_state, query["attention_mask"])
    passage_emb = mean_pool(model(**passage).last_hidden_state, passage["attention_mask"])

    return loss_function(query_emb, passage_emb, in_batch_negatives(passage_emb), TEMPERATURE)


model.train()

for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    progress = tqdm(trainloader, desc=f"Epoch {epoch+1}")

    for batch in progress:
        loss = batch_loss(batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_value = loss.item()
        total_loss += loss_value
        progress.set_postfix({"Loss": loss_value})

    model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for batch in validloader:
            valid_loss += batch_loss(batch).item()
    model.train()

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(trainloader)}, Valid_Loss: {valid_loss / len(validloader)}")

In [None]:
# Persist the fine-tuned encoder and its tokenizer.
TRAINED_MODEL_DIR = "2_new_rubert-tiny2"
TRAINED_TOKENIZER_DIR = "2_tokenizer_rubert-tiny2"

model.eval()
model.save_pretrained(TRAINED_MODEL_DIR)
tokenizer.save_pretrained(TRAINED_TOKENIZER_DIR)

In [None]:
# Reload an encoder/tokenizer pair from disk for the embedding-based filtering below.
# NOTE(review): this loads the "1_" checkpoint, while the cell above saves to "2_..." —
# confirm this is intentional (e.g. filtering with a model from a previous run).
tokenizer = AutoTokenizer.from_pretrained("./1_tokenizer_rubert-tiny2")
model = AutoModel.from_pretrained("./1_new_rubert-tiny2").to(DEVICE)

In [None]:
def _embed_one(text):
    """Mean of the last hidden state for a single text, returned on CPU with shape (1, hidden)."""
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(DEVICE)
    return model(**encoded).last_hidden_state.mean(dim=1).cpu()


model.eval()

# One text per forward pass, so no padding is involved and identical texts embed identically.
queries_emb = []
passages_emb = []

with torch.no_grad():
    for example in tqdm(valid_data):
        queries_emb.append(_embed_one(example["question"]))
        passages_emb.append(_embed_one(example["context"]))

In [None]:
# Hashable exact-value copies of each passage embedding; the next cell uses them to find duplicate passages.
passages_emb_tuples = [tuple(embedding[0].tolist()) for embedding in passages_emb]

In [None]:
def find_hard_examples(queries_emb, passages_emb, passages_emb_tuples, tuple_to_indices,
                       pool_size=500, rank_threshold=10, seed=42):
    """Flag examples whose own passage is outscored by too many random negatives.

    For each example i, up to ``pool_size`` negatives are sampled from all other
    passages. Passage i itself and every passage with an identical embedding
    (a duplicate) are excluded. Example i is flagged when strictly more than
    ``rank_threshold`` sampled negatives have a higher cosine similarity to query i
    than passage i does. This matches the original ``cos_sim < sorted_desc[10]`` test.

    Args:
        queries_emb, passages_emb: per-example embeddings, each of shape (1, hidden).
        passages_emb_tuples: hashable copies of ``passages_emb`` (duplicate detection).
        tuple_to_indices: maps each embedding tuple to all indices sharing it.
        pool_size: maximum number of negatives sampled per example.
        rank_threshold: how many negatives may outscore the positive before flagging.
        seed: seed for a private RNG, so the result does not depend on how much of the
            global ``random`` stream earlier cells consumed.

    Returns:
        Flagged indices in ascending order.
    """
    rng = random.Random(seed)
    passage_matrix = torch.cat(passages_emb)  # (N, hidden)
    all_indices = set(range(len(queries_emb)))
    flagged = []

    for i, query_emb in enumerate(tqdm(queries_emb)):
        forbidden = set(tuple_to_indices[passages_emb_tuples[i]])
        forbidden.add(i)
        candidates = sorted(all_indices - forbidden)  # sorted => reproducible sampling

        # random.sample raises if asked for more items than exist. With at most
        # rank_threshold negatives, more than rank_threshold of them cannot outscore
        # the positive, so the example is simply not flagged.
        sample_size = min(pool_size, len(candidates))
        if sample_size <= rank_threshold:
            continue
        pool = rng.sample(candidates, sample_size)

        negative_sims = F.cosine_similarity(query_emb, passage_matrix[pool], dim=-1)
        positive_sim = F.cosine_similarity(query_emb, passages_emb[i], dim=-1).item()
        # (rank_threshold + 1)-th highest negative similarity; topk is sorted descending.
        cutoff = torch.topk(negative_sims, rank_threshold + 1).values[-1].item()

        if positive_sim < cutoff:
            flagged.append(i)

    return flagged


tuple_to_indices = defaultdict(list)
for idx, tup in enumerate(passages_emb_tuples):
    tuple_to_indices[tup].append(idx)

pool_size = 500
filtered = find_hard_examples(queries_emb, passages_emb, passages_emb_tuples, tuple_to_indices,
                              pool_size=pool_size)

with open("filtered_array_val.txt", "w") as f:
    f.write(" ".join(map(str, filtered)))

print(len(filtered))
print(len(valid_data))