In [1]:
import argparse
from transformers import AutoTokenizer, TrainingArguments
from pathlib import Path
from tqdm import tqdm
from pathlib import Path
from torch.utils.data import Dataset
import json
from ir_measures import read_trec_run
from collections import Counter
from tqdm import tqdm
from pathlib import Path
from torch.utils.data import Dataset
from transformers import Trainer
from torch.utils.data import DataLoader
from collections import defaultdict
from tqdm import tqdm
import torch
import ir_measures
from ir_measures import *
import os
from transformers.utils import WEIGHTS_NAME
import logging
from torch import nn
from transformers import AutoModel

TRAINING_ARGS_NAME = "training_args.bin"
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

In [2]:
def read_pairs(path: str):
    pairs = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in tqdm(f.readlines(), desc=f'reading pairs from {Path(path).name}'):
            qid, did = line.strip().split('\t')
            pairs.append((qid, did))
    return pairs

def read_triplets(path: str):
    triplets = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in tqdm(f.readlines(), desc=f'reading triplets from {Path(path).name}'):
            qid, pos_id, neg_id = line.strip().split('\t')
            triplets.append((qid, pos_id, neg_id))
    return triplets

In [3]:
class PairDataset(Dataset):

    def __init__(
        self,
        collection_path: str,
        queries_path: str,
        query_doc_pair_path: str,
        qrels_path: str = None,
        top_k: int = 100,
    ):
        self.collection = dict(read_pairs(collection_path))
        self.queries = dict(read_pairs(queries_path))
        with open(qrels_path, 'r') as r:
            self.qrels = json.load(r)
        self.pairs = []
        query_count = {}
        for pair in read_trec_run(query_doc_pair_path):
            q, d = pair.query_id, pair.doc_id
            if q not in query_count:
                query_count[q] = 1
            elif query_count[q] < top_k:
                query_count[q] += 1
            self.pairs.append((q, d))
        self.top_k = min([max(Counter(pair[0] for pair in self.pairs).values()), top_k])
    
    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        query_id, doc_id = self.pairs[idx]
        query_text = self.queries[query_id]
        doc_text = self.collection[doc_id]
        return query_id, doc_id, query_text, doc_text

In [4]:
class TripletDataset(Dataset):

    def __init__(self, collection_path: str, queries_path: str, train_triplets_path: str):
        self.collection = {}
        with open(collection_path, "r", encoding="utf-8") as f:
            for line in f:
                id_, text = line.rstrip("\n").split("\t")
                self.collection[id_] = text

        self.queries = {}
        with open(queries_path, "r", encoding="utf-8") as f:
            for line in f:
                id_, text = line.rstrip("\n").split("\t")
                self.queries[id_] = text

        self.triplets = read_triplets(train_triplets_path)

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        qid, pid, nid = self.triplets[idx]
        q_text = self.queries[qid]
        p_text = self.collection[pid]
        n_text = self.collection[nid]
        return q_text, p_text, n_text

In [5]:
# Import the cross attention layer from the transformers library
from transformers import BertModel

# Modify the MyDenseBiEncoder class to use cross attention
class MyDenseBiEncoder(nn.Module):

    def __init__(self, model_name_or_dir, n_heads) -> None:
        super().__init__()
        # Use a pre-trained model with a cross attention layer
        self.model = BertModel.from_pretrained(model_name_or_dir, add_cross_attention=True)
        self.loss = nn.CrossEntropyLoss()
        # Add a multi head attention layer after the model
        self.attention = MultiHeadAttention(n_heads)

    def encode(self, input_ids, attention_mask, **kwargs):
        # Get the query and document inputs from the kwargs
        query_input_ids = kwargs["query_input_ids"]
        query_attention_mask = kwargs["query_attention_mask"]
        # Pass the query and document inputs to the model with cross attention
        outputs = self.model(input_ids, attention_mask=attention_mask, encoder_hidden_states=query_input_ids, encoder_attention_mask=query_attention_mask, **kwargs)
        # Get the last hidden state of the model
        last_hidden_state = outputs.last_hidden_state
        # Mask padding tokens
        mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size())
        last_hidden_state = last_hidden_state.masked_fill(mask == 0, 0.0)
        # Apply the multi head attention layer to the last hidden state
        attention_output = self.attention(last_hidden_state, last_hidden_state, last_hidden_state, mask)
        # Avg over hidden states to get document representation
        sum_attention_output = torch.sum(attention_output, dim=1)
        sum_attention_mask = torch.clamp(torch.sum(attention_mask, dim=1), min=1e-9)
        mean_attention_output = sum_attention_output / sum_attention_mask.unsqueeze(-1)
        return mean_attention_output

    def score_pairs(self, queries, docs):
        q_vectors = self.encode(queries.input_ids, queries.attention_mask)
        d_vectors = self.encode(docs.input_ids, docs.attention_mask)
        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        scores = cos(q_vectors, d_vectors)
        return scores

    def forward(self, queries, pos_docs, neg_docs):
        pos_scores = self.score_pairs(queries, pos_docs)
        neg_scores = self.score_pairs(queries, neg_docs)
        loss = self.loss(torch.cat((pos_scores, neg_scores), dim=0), torch.cat((torch.ones_like(pos_scores), torch.zeros_like(neg_scores)), dim=0))
        return loss, pos_scores, neg_scores

    def save_pretrained(self, model_dir, state_dict=None):
        self.model.save_pretrained(model_dir, state_dict=state_dict)

    @classmethod
    def from_pretrained(cls, model_name_or_dir, n_heads):
        return cls(model_name_or_dir, n_heads)

In [6]:
class BiEncoderTripletCollator:
    def __init__(self, tokenizer, query_max_length, doc_max_length):
        self.tokenizer = tokenizer
        self.query_max_length = query_max_length
        self.doc_max_length = doc_max_length

    def __call__(self, batch):
        queries = []
        pos_docs = []
        neg_docs = []
        for query, pos, neg in batch:
            queries.append(query)
            pos_docs.append(pos)
            neg_docs.append(neg)
        queries = self.tokenizer(
            queries,
            padding=True,
            truncation=True,
            max_length=self.query_max_length,
            return_tensors="pt",
        )
        pos_docs = self.tokenizer(
            pos_docs,
            padding=True,
            truncation=True,
            max_length=self.doc_max_length,
            return_tensors="pt",
        )
        neg_docs = self.tokenizer(
            neg_docs,
            padding=True,
            truncation=True,
            max_length=self.doc_max_length,
            return_tensors="pt",
        )
        return {"queries": queries, "pos_docs": pos_docs, "neg_docs": neg_docs}

In [7]:
class BiEncoderPairCollator:
    def __init__(self, tokenizer, query_max_length, doc_max_length):
        self.tokenizer = tokenizer
        self.query_max_length = query_max_length
        self.doc_max_length = doc_max_length

    def __call__(self, batch):
        queries = []
        docs = []
        query_ids = []
        doc_ids = []
        for qid, did, query, doc in batch:
            query_ids.append(qid)
            doc_ids.append(did)
            queries.append(query)
            docs.append(doc)
        queries = self.tokenizer(
            queries,
            padding=True,
            truncation=True,
            max_length=self.query_max_length,
            return_tensors="pt",
        )
        docs = self.tokenizer(
            docs,
            padding=True,
            truncation=True,
            max_length=self.doc_max_length,
            return_tensors="pt",
        )
        return {
            "query_ids": query_ids,
            "doc_ids": doc_ids,
            "queries": queries,
            "docs": docs,
        }

In [8]:
pretrained = "distilbert-base-uncased"
output_dir = "output"
epochs = 1
train_batch_size = 8
eval_batch_size = 16
warmup_steps = 5000
max_steps = 4000
eval_steps = 100
lr = 5e-5
query_max_length = 100
doc_max_length = 250

In [9]:
OUTPUT_DIR = Path(output_dir) / "MyDenseBiEncoder"
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

In [10]:
tokenizer = AutoTokenizer.from_pretrained(pretrained)

In [11]:
train_dataset = TripletDataset(
    collection_path="data/collection.tsv",
    queries_path="data/train_queries.tsv",
    train_triplets_path="data/train_triplets.tsv",
)
dev_dataset = PairDataset(
    collection_path="data/collection.tsv",
    queries_path="data/dev_queries.tsv",
    query_doc_pair_path="data/dev_bm25.trec",
    qrels_path="data/dev_qrels.json",
)

UnicodeDecodeError: 'charmap' codec can't decode byte 0x9d in position 5720: character maps to <undefined>

In [None]:
triplet_collator = BiEncoderTripletCollator(tokenizer, query_max_length, doc_max_length)
pair_collator = BiEncoderPairCollator(tokenizer, query_max_length, doc_max_length)
model = MyDenseBiEncoder(pretrained)

In [None]:
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    learning_rate=lr,
    num_train_epochs=epochs,
    evaluation_strategy="steps",
    fp16=True,
    warmup_steps=warmup_steps,
    metric_for_best_model="RR@10",
    load_best_model_at_end=True,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    max_steps=max_steps,
    save_steps=eval_steps,
    eval_steps=eval_steps,
    save_total_limit=2,
)

In [None]:
trainer = HFTrainer(
    model,
    train_dataset=train_dataset,
    data_collator=triplet_collator,
    args=training_args,
    eval_dataset=dev_dataset,
    eval_collator=pair_collator,
)

In [None]:
trainer.train()
trainer.save_model()

In [None]:
test_dataset = PairDataset(
    collection_path="data/collection.tsv",
    queries_path="data/test_queries.tsv",
    query_doc_pair_path="data/test_bm25.trec",
    qrels_path="data/test_qrels.json",
)

test_pair_collator = BiEncoderPairCollator(tokenizer, query_max_length=query_max_length, doc_max_length=doc_max_length)

eval_results = trainer.evaluate(test_dataset, test_pair_collator)

# Print evaluation results
print("Evaluation Results:")
for metric, value in eval_results.items():
    print(f"{metric}: {value}")