<a href="https://colab.research.google.com/github/jtlagumbay/cebqa/blob/main/retriever/bm_roberta.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **QA Pipeline**

1. Elasticsearch indexer
2. Retriever: BM25 (Elasticsearch) or dense FAISS (SentenceTransformer / DPR)
3. Fine-tuned XLM-R reader


# Dependencies

In [None]:
pip install elasticsearch transformers datasets evaluate rank_bm25 nltk fuzzywuzzy sentence_transformers


In [None]:
pip install --upgrade --no-cache-dir numpy==1.26.4

In [None]:
pip install faiss-cpu

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# ---- Dataset ------------------------------------------------------------
# Hugging Face dataset id; QA evaluates its "test" split.
CEBQA_DATASET = "jhoannarica/cebquad_split"

# ---- Reader checkpoint ---------------------------------------------------
# Earlier Colab configuration (Drive checkpoint + ngrok-exposed Elasticsearch),
# kept for reference:
# ELASTIC_URL = "https://tender-separately-mudfish.ngrok-free.app"
# CURRENT_MODEL = "/content/drive/MyDrive/UP Files/IV - 2nd sem/CMSC 198.1/cebqa_roberta/xlmr/2025-03-15_09-36/model"
# CURRENT_TOKENIZER = "/content/drive/MyDrive/UP Files/IV - 2nd sem/CMSC 198.1/cebqa_roberta/xlmr/2025-03-15_09-36/tokenizer"
CURRENT_ROOT = "/Users/jhoannaricalagumbay/Library/CloudStorage/GoogleDrive-jtlagumbay@up.edu.ph/My Drive/UP Files/IV - 2nd sem/CMSC 198.1/cebqa_roberta/new-split/xlmr_body-filtered/2025-04-04_05-13"
# Model and tokenizer are saved side by side under the run directory.
CURRENT_MODEL = f"{CURRENT_ROOT}/model"
CURRENT_TOKENIZER = f"{CURRENT_ROOT}/tokenizer"

# ---- Retrieval -----------------------------------------------------------
ELASTIC_URL = "http://localhost:9200"
INDEX_NAME = "superbalita"
K = 3  # default number of articles retrieved per question
# Article corpus CSV (FAISSIndexer reads it by default).
DATASET_CSV = "/Users/jhoannaricalagumbay/School/cebqa/dataset/articles_202503120405_author_removed_fixed.csv"
# Values accepted by QA(indexer_type=...): BM25 selects Elasticsearch,
# anything else selects FAISS.
BM25 = "bm25"
FAISS = "faiss"

# ---- Dense retrieval (DPR) checkpoints ------------------------------------
CEBQA_DPR_MODEL = "/Users/jhoannaricalagumbay/Library/CloudStorage/GoogleDrive-jtlagumbay@up.edu.ph/My Drive/cebqa/dpr/2025-04-16_03-37/model"
CEBQA_DPR_TOKENIZER = "/Users/jhoannaricalagumbay/Library/CloudStorage/GoogleDrive-jtlagumbay@up.edu.ph/My Drive/cebqa/dpr/2025-04-16_03-37/tokenizer"
DPR_CONTEXT_ENCODER = "voidful/dpr-ctx_encoder-bert-base-multilingual"

# Indexer

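Start Elasticsearch locally:
1. Start the Elasticsearch Docker container (listening on port 9200, matching `ELASTIC_URL`).
2. Only when running from Colab: expose it with ngrok (`ngrok http --url=tender-separately-mudfish.ngrok-free.app 9200`) and point `ELASTIC_URL` at the ngrok URL.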


In [None]:
# CORS preflight check: send an OPTIONS request to Elasticsearch with a Colab
# Origin header and print the response headers, presumably to confirm that
# cross-origin requests from the notebook are allowed. `headers` is reused by
# the client cell below.
headers = {
    "Origin": "https://colab.research.google.com",
     "Content-Type": "application/json",
}

response = requests.options(ELASTIC_URL, headers=headers)
print(response.headers)


In [None]:
# Connectivity check: print cluster info through a client that sends the same
# headers as the preflight cell. In the visible notebook this global `es` is
# not used elsewhere; ElasticSearchIndexer and BM25Retriever create their own
# clients. verify_certs=False only matters for an https URL (e.g. ngrok).
es = Elasticsearch([ELASTIC_URL], verify_certs=False, headers=headers)
print(es.info())
# try:
#     print(es.transport.perform_request('GET', '/'))
# except Exception as e:
#     print("Error:", e)

In [None]:
class ElasticSearchIndexer:
    """Load article rows into an Elasticsearch index for lexical search.

    Args:
        index_name: Name of the target Elasticsearch index.
    """

    def __init__(self, index_name=INDEX_NAME):
        self.index_name = index_name
        # Connects to ELASTIC_URL; the server must already be running.
        self.es = Elasticsearch(ELASTIC_URL)
        print(f"Initiating ESIndexer {self.index_name}")

    def create_index(self):
        """Create the index (one shard, no replicas) unless it already exists.

        ``id`` is mapped as an exact-match ``keyword``; ``title`` and ``body``
        are full-text ``text`` fields that BM25 queries can score.
        """
        if self.es.indices.exists(index=self.index_name):
            return

        index_body = {
            "settings": {"number_of_shards": 1, "number_of_replicas": 0},
            "mappings": {
                "properties": {
                    "id": {"type": "keyword"},
                    "title": {"type": "text"},
                    "body": {"type": "text"},
                }
            },
        }
        self.es.indices.create(index=self.index_name, body=index_body)
        print(f"Index '{self.index_name}' created.")

    def index_documents(self, documents):
        """Bulk-index article records.

        Args:
            documents: Sequence of dicts with ``id``, ``pseudonymized_title``
                and ``pseudonymized_body`` keys (e.g. rows of the article CSV).
        """
        def to_action(doc):
            # The article id doubles as the Elasticsearch _id, so re-indexing
            # an article replaces it rather than adding a duplicate.
            return {
                "_index": self.index_name,
                "_id": doc["id"],
                "_source": {
                    "id": doc["id"],
                    "title": doc["pseudonymized_title"],
                    "body": doc["pseudonymized_body"],
                },
            }

        bulk(self.es, [to_action(doc) for doc in documents])
        print(f"Indexed {len(documents)} documents.")

    def index_from_csv(self, file_path):
        """Read the article CSV at ``file_path`` and bulk-index every row."""
        records = pd.read_csv(file_path).to_dict(orient="records")
        self.index_documents(records)



In [None]:
# Sample usage
# indexer = ElasticSearchIndexer()
# indexer.create_index()
# indexer.index_from_csv("/Users/jhoannaricalagumbay/School/cebqa/dataset/articles_202503120405_author_removed_fixed.csv")

# BM25

In [None]:
class BM25Retriever:
    """Lexical retriever backed by an Elasticsearch index.

    Every question is sent as a ``match`` query on the ``body`` field, so the
    ranking is the index's similarity (BM25 for a default mapping).

    Args:
        index_name: Elasticsearch index to search.
    """

    def __init__(self, index_name="superbalita"):
        print(f"Initiating retriever with index_name: {index_name}")
        self.index_name = index_name
        self.es = Elasticsearch(ELASTIC_URL)

    @staticmethod
    def _match_query(question, top_k):
        """Return the search body used for a single question."""
        return {"query": {"match": {"body": question}}, "size": top_k}

    def _msearch(self, questions, top_k):
        """Send all ``questions`` in a single multi-search request.

        The request is a list of header/body dicts serialised by the client,
        so the index name and question text are JSON-escaped correctly.

        Returns:
            list[list[dict]]: the ``_source`` of the hits for each question,
            in the same order as ``questions``.

        Raises:
            RuntimeError: if Elasticsearch reports an error for any sub-search.
        """
        if not questions:
            # An empty msearch body is not a valid request; nothing to fetch.
            return []

        searches = []
        for question in questions:
            searches.append({"index": self.index_name})
            searches.append(self._match_query(question, top_k))

        response = self.es.msearch(body=searches)

        results = []
        for question, query_response in zip(questions, response["responses"]):
            # A failed sub-search comes back as {"error": ...} without "hits".
            if "error" in query_response:
                raise RuntimeError(
                    f"Elasticsearch search failed for query {question!r}: "
                    f"{query_response['error']}"
                )
            results.append([hit["_source"] for hit in query_response["hits"]["hits"]])
        return results

    def retrieve(self, query, top_k=3):
        """Return the ``_source`` of the top-k hits for a single query string."""
        print(f"retrieving {top_k} docs for [{query}]")
        response = self.es.search(index=self.index_name, body=self._match_query(query, top_k))
        return [hit["_source"] for hit in response["hits"]["hits"]]

    def retrieve_batch(self, queries, top_k=3):
        """Retrieve the top-k documents for several query strings in one request.

        Args:
            queries: list of question strings.
            top_k: number of hits per question.

        Returns:
            list[list[dict]]: one list of hit ``_source`` dicts per query.

        Raises:
            ValueError: if ``queries`` is not a list.
            RuntimeError: if Elasticsearch fails a sub-search.
        """
        print(f"Retrieve Batch for {len(queries)} queries")
        if not isinstance(queries, list):
            raise ValueError("queries should be a list of strings")
        return self._msearch(queries, top_k)

    def retrieve_batch_query_dict(self, queries_list, top_k=3):
        """Retrieve the top-k documents for query dicts, keeping their ids.

        Args:
            queries_list: list of dicts, each with ``id`` and ``question`` keys.
            top_k: number of hits per question.

        Returns:
            list[dict]: one ``{"query_id": str(id), "top_docs": [...]}`` entry
            per query, in input order.

        Raises:
            ValueError: if ``queries_list`` is malformed.
            RuntimeError: if Elasticsearch fails a sub-search.
        """
        print(f"Retrieve Batch Dict for {len(queries_list)} queries")
        if not isinstance(queries_list, list) or not all(
            isinstance(q, dict) and 'id' in q and 'question' in q for q in queries_list
        ):
            raise ValueError("queries_list should be a list of dictionaries with 'id' and 'question' keys")

        hits_per_query = self._msearch([q["question"] for q in queries_list], top_k)
        return [
            {"query_id": str(q["id"]), "top_docs": docs}
            for q, docs in zip(queries_list, hits_per_query)
        ]



In [None]:
# Sample usage
# Sample usage: BM25 retrieval for one question through the batch API.
# retrieve_batch requires a list, so the single question is wrapped in one;
# the result is a list with one list of hit sources per question.
retriever = BM25Retriever()
query = ['Unsa ang giingon ni Gobernador Abalayan nga mabuhat ra "with a united country"?']
top_docs = retriever.retrieve_batch(query)
print("Retrieved Documents:", top_docs)

# FAISS Indexer

In [None]:
class FAISSIndexer:
    """Dense FAISS index over the article CSV.

    Article bodies are embedded with a SentenceTransformer model, or, when
    ``use_fine_tuned`` is True, with the multilingual DPR context encoder
    (L2-normalised). The index is built from ``DATASET_CSV`` in the
    constructor and saved to ``index_file``.

    Args:
        index_file: path the FAISS index is written to / read from.
        model_name: SentenceTransformer model used when ``use_fine_tuned`` is False.
        use_fine_tuned: embed contexts with the DPR context encoder instead.
    """

    def __init__(self, index_file="faiss_index.idx", model_name="sentence-transformers/all-MiniLM-L6-v2", use_fine_tuned = False):
        self.index_file = index_file
        # FAISSRetriever encodes queries with this model on the non-DPR path,
        # so it is always loaded.
        self.model = SentenceTransformer(model_name)
        self.use_fine_tuned = use_fine_tuned
        # The DPR context encoder is only needed for use_fine_tuned; load it
        # on demand instead of unconditionally.
        self.dpr_tokenizer = None
        self.dpr_model = None
        if use_fine_tuned:
            self._load_dpr_encoder()
        self.index = None
        self.documents = []  # article bodies, row-aligned with the FAISS index
        self.article_ids = []
        self.titles = []
        self.index_from_csv()
        print("FAISS Indexer initialized.")

    def _load_dpr_encoder(self):
        """Load the multilingual DPR context encoder and its tokenizer."""
        print("Loading DPR multilingual context encoder...")
        self.dpr_tokenizer = DPRContextEncoderTokenizer.from_pretrained(DPR_CONTEXT_ENCODER)
        self.dpr_model = DPRContextEncoder.from_pretrained(DPR_CONTEXT_ENCODER)

    def encode_contexts_with_dpr(self, batch_size=32):
        """Embed ``self.documents`` with the DPR context encoder.

        Documents are encoded in mini-batches of ``batch_size`` instead of a
        single forward pass over the whole corpus, so activation memory stays
        bounded as the corpus grows.

        Args:
            batch_size: number of documents per forward pass.

        Returns:
            np.ndarray: float32, C-contiguous, L2-normalised, one row per document.
        """
        if self.dpr_model is None or self.dpr_tokenizer is None:
            self._load_dpr_encoder()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dpr_model.to(device)
        self.dpr_model.eval()

        batches = []
        with torch.no_grad():
            for start in range(0, len(self.documents), batch_size):
                inputs = self.dpr_tokenizer(
                    self.documents[start:start + batch_size],
                    return_tensors="pt",
                    truncation=True,
                    padding="max_length",
                    max_length=512
                ).to(device)
                batches.append(self.dpr_model(**inputs).pooler_output.cpu().numpy())

        # FAISS's Python API expects C-contiguous float32 arrays.
        embeddings = np.ascontiguousarray(np.concatenate(batches, axis=0), dtype=np.float32)
        # Normalised vectors make L2 ranking equivalent to cosine ranking.
        faiss.normalize_L2(embeddings)
        return embeddings

    def create_index(self, d):
        """Create a new, empty exact-L2 FAISS index of dimension ``d``."""
        self.index = faiss.IndexFlatL2(d)
        print(f"Created FAISS index with dimension {d}.")

    def index_documents(self, documents):
        """Embed ``documents`` and rebuild the FAISS index from them.

        Args:
            documents: dicts with ``id``, ``pseudonymized_title`` and
                ``pseudonymized_body`` keys.

        Raises:
            ValueError: if ``documents`` is empty.
        """
        if not documents:
            raise ValueError("No documents to index.")

        self.documents = [doc['pseudonymized_body'] for doc in documents]
        self.article_ids = [doc['id'] for doc in documents]
        self.titles = [doc['pseudonymized_title'] for doc in documents]

        if self.use_fine_tuned:
            embeddings = self.encode_contexts_with_dpr()
        else:
            embeddings = self.model.encode(self.documents, convert_to_numpy=True)

        # Always start from a fresh index: the metadata lists above were just
        # replaced, so vectors left over from an earlier call (or a loaded
        # index) would no longer line up with them.
        self.create_index(embeddings.shape[1])
        self.index.add(embeddings)
        print(f"Indexed {len(documents)} documents into FAISS.")
        self.save_index()

    def index_from_csv(self, file_path=DATASET_CSV):
        """Load article rows from a CSV file and index them."""
        df = pd.read_csv(file_path)
        documents = df.to_dict(orient="records")
        self.index_documents(documents)

    def save_index(self):
        """Write the FAISS index to ``self.index_file``."""
        faiss.write_index(self.index, self.index_file)
        print(f"FAISS index saved to {self.index_file}.")

    def load_index(self):
        """Replace the in-memory index with the one stored at ``self.index_file``."""
        self.index = faiss.read_index(self.index_file)
        print(f"FAISS index loaded from {self.index_file}.")



In [None]:
class FAISSRetriever:
    """Dense retriever that searches a FAISSIndexer built over the article CSV.

    Args:
        q_encoder: DPR question encoder. When None it is loaded from
            CEBQA_DPR_MODEL the first time it is needed.
        q_tokenizer: matching tokenizer. When None it is loaded from
            CEBQA_DPR_TOKENIZER the first time it is needed.
        index_file: FAISS index path passed to FAISSIndexer.
        top_k: number of documents returned per query.
        device: torch device string for the question encoder. Defaults to
            "cuda" when available, else "cpu".
        use_fine_tuned: encode queries with the DPR question encoder (and
            contexts with DPR) instead of the SentenceTransformer model.
    """

    def __init__(
            self,
            q_encoder = None,
            q_tokenizer = None,
            index_file="faiss_index.idx",
            top_k=3,
            device=None,
            use_fine_tuned = False
        ):
        print("Initializing FAISS Retriever")
        self.indexer = FAISSIndexer(index_file=index_file, use_fine_tuned=use_fine_tuned)
        self.top_k = top_k
        self.use_fine_tuned = use_fine_tuned
        self.q_encoder = q_encoder
        self.q_tokenizer = q_tokenizer
        self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")

        # Checkpoints are resolved here, not in the signature, so defining the
        # class loads nothing and instances never share one encoder object.
        if use_fine_tuned or q_encoder is not None:
            self._load_question_encoder()

    def _load_question_encoder(self):
        """Load any missing question encoder/tokenizer and put the encoder in eval mode on ``self.device``."""
        if self.q_tokenizer is None:
            self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(CEBQA_DPR_TOKENIZER)
        if self.q_encoder is None:
            self.q_encoder = DPRQuestionEncoder.from_pretrained(CEBQA_DPR_MODEL)
        self.q_encoder.to(self.device)
        self.q_encoder.eval()

    def encode_query(self, query):
        """Encode question string(s) with the DPR question encoder.

        Args:
            query: a list of question strings (or a single string).

        Returns:
            np.ndarray: C-contiguous, L2-normalised embeddings, one row per question.
        """
        if self.q_encoder is None or self.q_tokenizer is None:
            self._load_question_encoder()

        with torch.no_grad():
            inputs = self.q_tokenizer(
                query,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=64
            ).to(self.device)

            embeddings = self.q_encoder(**inputs).pooler_output.cpu().numpy()

            # FAISS requires C-contiguous arrays.
            if not embeddings.flags['C_CONTIGUOUS']:
                embeddings = np.ascontiguousarray(embeddings)

            # Normalise so L2 ranking matches the normalised DPR context vectors.
            faiss.normalize_L2(embeddings)

            return embeddings

    def _embed_queries(self, questions):
        """Embed question strings with the encoder that matches the index."""
        if self.use_fine_tuned:
            return self.encode_query(questions)
        return self.indexer.model.encode(questions, convert_to_numpy=True)

    def _ranked_hits(self, distances, ids):
        """Turn one row of FAISS search output into ranked document dicts.

        FAISS pads the id row with -1 when fewer than ``top_k`` vectors are
        available. Those entries are skipped rather than used as a Python
        index, where -1 would silently select the last document.
        ``score`` is the raw distance from the index (for FAISSIndexer's
        IndexFlatL2, lower means closer).
        """
        n_docs = len(self.indexer.documents)
        return [
            {
                "rank": rank + 1,
                "score": float(distances[rank]),
                "id": self.indexer.article_ids[idx],
                "title": self.indexer.titles[idx],
                "body": self.indexer.documents[idx],
            }
            for rank, idx in enumerate(ids) if 0 <= idx < n_docs
        ]

    def retrieve(self, query):
        """Return the top-k documents for one question string.

        Each document is a dict with ``rank``, ``score``, ``id``, ``title`` and ``body``.
        """
        D, I = self.indexer.index.search(self._embed_queries([query]), self.top_k)
        return self._ranked_hits(D[0], I[0])

    def retrieve_batch(self, queries):
        """Retrieve the top-k documents for several queries with one FAISS search.

        Args:
            queries: dicts with a ``question`` key.

        Returns:
            list of ``{"query": query, "top_docs": [...]}``. Each document has
            the same keys as ``retrieve`` plus ``text`` (the body), which is
            kept for callers of the older batch format.
        """
        print(f"processing {len(queries)}")
        if not queries:
            return []
        questions = [query["question"] for query in queries]
        D, I = self.indexer.index.search(self._embed_queries(questions), self.top_k)
        print(f"done {len(D)}")

        results = []
        for query, distances, ids in zip(queries, D, I):
            top_docs = self._ranked_hits(distances, ids)
            for doc in top_docs:
                doc["text"] = doc["body"]
            results.append({"query": query, "top_docs": top_docs})
        return results


In [None]:
# # Initialize the FAISS indexer
# indexer = FAISSIndexer(index_file="faiss_index.idx")

# Index documents from a CSV file
# # Save the FAISS index for later use
# indexer.save_index()

# indexer = FAISSIndexer(index_file="faiss_index.idx")
# indexer.load_index()

In [None]:
# Initialize the retriever with the loaded indexer
# retriever = FAISSRetriever(index_file="faiss_index.idx", top_k=K)

# # Retrieve relevant documents for a single query
# query = "kanus-a ang palarong pambansa??"
# results = retriever.retrieve(query)

# # Print retrieved documents
# print(results)


# Retrieve relevant documents for a single query
# query = [{"question": "kanus-a ang palarong pambansa?"}, {"question":"Kinsa ang hepe sa Cebu Police?"}]
# results = retriever.retrieve_batch(query)

# # Print retrieved documents
# for res in results:
#     print(res)


# Reader

In [None]:
class Reader:
    """Extractive QA reader wrapping a fine-tuned question-answering checkpoint.

    Args:
        model_path: path/name of the AutoModelForQuestionAnswering checkpoint.
        tokenizer_path: path/name of the matching tokenizer.
        device: torch device for the pipeline. When None, the first available
            of CUDA, Apple MPS and CPU is used.
    """
    def __init__(
        self,
        model_path = CURRENT_MODEL,
        tokenizer_path = CURRENT_TOKENIZER,
        device = None
      ):
        print(f"Initiating reader with model: {model_path}")
        model_best = AutoModelForQuestionAnswering.from_pretrained(model_path)
        tokenizer_best = AutoTokenizer.from_pretrained(tokenizer_path)

        # A hard-coded "mps" device fails on any machine without Apple
        # silicon (e.g. Colab), so pick the best available backend.
        if device is None:
            device = self._default_device()
        self.qa_pipeline = pipeline(
            "question-answering",
            model=model_best,
            tokenizer=tokenizer_best,
            device=device
            )

    @staticmethod
    def _default_device():
        """Return CUDA if available, else Apple MPS if available, else CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def extract_answer_batch(self, queries_list, top_docs):
        """Run the QA pipeline on one question against each retrieved document.

        Args:
            queries_list: despite the name, a single query dict; only
                ``queries_list["question"]`` is read.
            top_docs: a retrieval entry whose ``top_docs`` list holds dicts
                with a ``body`` key.

        Returns:
            The pipeline output for the (question, body) pairs, presumably
            one prediction per document.
        """
        # NOTE(review): not used by QA.run; the printed count is len() of the
        # query dict, not a number of queries.
        print(f"Extracting batch answer for {len(queries_list)} queries")
        qa_dataset = Dataset.from_dict({
            "question": [queries_list["question"] for _ in top_docs['top_docs']],
            "context": [doc['body'] for doc in top_docs['top_docs']]
        })

        return self.qa_pipeline(qa_dataset)

    def extract_answer(self, question, documents, num_chunks = 1, overlap = 0.3):
        """Return the highest-scoring answer span across the retrieved documents.

        Each document body is used whole (``num_chunks == 1``) or split by
        ``chunk_text`` into windows of ``num_chunks`` sentences with
        fractional ``overlap``. The reader runs on every context and the span
        with the highest pipeline score wins.

        Args:
            question: question string.
            documents: dicts with ``id``, ``title`` and ``body`` keys.
            num_chunks: sentences per window (1 = whole body).
            overlap: fraction of a window shared with the next one.

        Returns:
            dict with ``article_id``, ``title``, ``body``, ``answer`` and
            ``score`` for the best span, or None if no context was produced
            (e.g. ``documents`` is empty).
        """
        print(f"extracting answer for {question}")
        best_result = None
        # Start below any possible score so a prediction is always returned
        # whenever at least one context exists.
        best_score = float("-inf")

        for doc in documents:
            if num_chunks == 1:
                contexts = [doc["body"]]
            else:
                contexts = self.chunk_text(doc["body"], num_chunks, overlap)

            for context in contexts:
                result = self.qa_pipeline(question=question, context=context)
                if result["score"] > best_score:
                    best_result = {
                        "article_id": doc["id"],
                        "title": doc["title"],
                        "body": doc["body"],
                        "answer": result["answer"],
                        "score": result["score"]
                    }
                    best_score = result["score"]

        return best_result

    def chunk_text(self, text, chunk_size=3, overlap=0.5):
        """Split ``text`` into windows of ``chunk_size`` sentences.

        Consecutive windows start ``int(chunk_size * (1 - overlap))``
        sentences apart, clamped to at least 1 so the loop always advances
        (a step of 0 would make ``range`` raise).
        """
        sentences = sent_tokenize(text)
        step = max(1, int(chunk_size * (1 - overlap)))

        chunks = []
        for i in range(0, len(sentences), step):
            chunk = sentences[i:i + chunk_size]
            if not chunk:
                continue
            chunks.append(" ".join(chunk))

        return chunks

# QA Pipeline

In [None]:
class QA:
    """Retrieve-then-read QA evaluation on the CebQA test split.

    Construction loads the reader, prepares the ``test`` split of ``dataset``,
    builds ``self.queries`` and immediately retrieves the top-``k`` articles
    for every query (``self.top_docs``, aligned with ``self.queries``).
    ``run`` extracts answers from those articles, ``evaluate`` scores them,
    and ``evaluate_retriever`` / ``compute_retrieval_metrics`` score the
    retrieval step.

    Args:
        model_path: reader model checkpoint (passed to ``Reader``).
        tokenizer_path: reader tokenizer (passed to ``Reader``).
        dataset: Hugging Face dataset id; only its ``test`` split is used.
        indexer_type: ``BM25`` for Elasticsearch; any other value uses FAISS.
        index_name: Elasticsearch index (BM25 only).
        k: number of articles retrieved per question.
        sample: if set, keep only this many examples (clamped to the number
            that survive filtering).
        isRandom: draw ``sample`` examples at random instead of the first ones.
        overlap: window overlap forwarded to ``Reader.extract_answer``.
        num_chunks: sentences per window forwarded to
            ``Reader.extract_answer`` (1 = whole article).
        use_fine_tuned: FAISS only; use the DPR encoders.
    """
    def __init__(
        self,
        model_path = CURRENT_MODEL,
        tokenizer_path = CURRENT_TOKENIZER,
        dataset = CEBQA_DATASET,
        indexer_type = BM25,
        index_name = INDEX_NAME,
        k = K,
        sample = None,
        isRandom = False,
        overlap = 0.0,
        num_chunks = 1,
        use_fine_tuned = False
      ):
        reader = Reader(model_path=model_path, tokenizer_path=tokenizer_path)

        self.model_path = model_path
        self.tokenizer_path = tokenizer_path
        self.reader = reader
        # Tokenizer for dataset preparation (span location and checks).
        self.tokenizer = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-base")
        test_dataset = load_dataset(dataset)["test"]
        # Drop incomplete rows, NFKC-normalise text, tokenise to locate the
        # gold span, and keep only rows whose span decodes back to the answer.
        self.dataset = test_dataset.filter(self.filter_incomplete_examples) \
            .map(self.normalize_row, batched=True) \
            .map(self.tokenize_train_function, batched=True)\
            .filter(self.decode_error)
        self.sentence_transformer = SentenceTransformer("all-MiniLM-L6-v2")
        self.k = k
        self.overlap = overlap
        self.num_chunks = num_chunks
        self.sample = sample
        self.isRandom = isRandom
        self.index_name = index_name
        self.indexer_type = indexer_type

        if sample is not None:
            # Clamp so asking for more rows than survived filtering does not
            # make random.sample / select raise.
            n = min(sample, len(self.dataset))
            if isRandom:
                indices = random.sample(range(len(self.dataset)), n)
            else:
                indices = range(n)
            self.dataset = self.dataset.select(indices)

        print(f"Initiating QA Pipeline.")
        print(f"QA model {self.model_path}")
        print(f"QA tokenizer {self.tokenizer_path}")
        print(f"QA reader {self.reader}")
        print(f"QA dataset {len(self.dataset)}")
        print(f"QA k {self.k}")
        print(f"QA overlap {self.overlap}")
        print(f"QA num_chunks {self.num_chunks}")
        print(f"QA sample {self.sample}")
        print(f"QA isRandom {self.isRandom}")
        print(f"QA index_name {self.index_name}")
        print(f"QA indexer {self.indexer_type}")
        self.queries = [
            {
                "id": item['id'],
                "article_id": item['article_id'],
                "question": item['question'],
                "context": {
                    "text": item['context'],
                    "start": item['context_start']
                },
                "answer": {
                    "text": item['answer'],
                    "start": item['answer_start']
                }
            }
            for item in self.dataset
        ]

        if indexer_type == BM25:
            self.retriever = BM25Retriever(index_name=index_name)
            self.run_top_docs_batch_bm25()
        else:
            self.retriever = FAISSRetriever(top_k=self.k, use_fine_tuned = use_fine_tuned)
            self.run_top_docs_batch_faiss()
        print(f"QA retriever {self.retriever}")

    def run_top_docs_batch_bm25(self):
        """Retrieve the top-k articles for all queries with one BM25 msearch."""
        self.top_docs = self.retriever.retrieve_batch_query_dict(
            queries_list = self.queries,
            top_k=self.k
        )

        return self.top_docs

    def run_top_docs_batch_faiss(self):
        """Retrieve the top-k articles for every query with FAISS, one query at a time."""
        self.top_docs = [
            {
                "query_id": item["id"],
                "top_docs": self.retriever.retrieve(item["question"])
            }
            for item in self.dataset
        ]
        return self.top_docs

    def run(self):
        """Extract an answer for every query from its retrieved articles.

        Returns:
            list of dicts: a copy of each query with ``pred`` (the dict from
            ``Reader.extract_answer``, possibly None) and ``top_docs`` added.
            Also stored in ``self.results``; run time goes to ``self.stats``.
        """
        start_time = time.time()
        date_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
        print(f"QA run for {self.model_path} on {date_now}")

        results = []
        for index, query in enumerate(self.queries):
            print(f"{index} processing {query['id']}")
            docs = self.top_docs[index]['top_docs']
            answer = self.reader.extract_answer(
                question = query["question"],
                documents = docs,
                num_chunks = self.num_chunks,
                overlap= self.overlap
            )
            # Copy so repeated runs do not write predictions into self.queries.
            result = dict(query)
            result["pred"] = answer
            result["top_docs"] = docs
            results.append(result)

        self.results = results
        self.stats = {'run_time': time.time() - start_time}
        return self.results

    def normalize_row(self, examples):
        """NFKC-normalise context, article body, answer and question (batched map)."""
        examples["context"] = [unicodedata.normalize("NFKC", context) for context in examples["context"]]
        examples["article_body"] = [unicodedata.normalize("NFKC", body) for body in examples["article_body"]]
        examples["answer"] = [unicodedata.normalize("NFKC", answer) for answer in examples["answer"]]
        examples["question"] = [unicodedata.normalize("NFKC", q) for q in examples["question"]]
        return examples

    def normalize_text(self, text):
        """Lowercase, replace runs of non-word characters with a space, and collapse whitespace."""
        text = text.lower()
        text = re.sub(r'\W+', ' ', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def compute_similarity(self, text1, text2):
        """Cosine similarity of two texts under the SentenceTransformer model."""
        emb1 = self.sentence_transformer.encode(text1, convert_to_tensor=True)
        emb2 = self.sentence_transformer.encode(text2, convert_to_tensor=True)
        return util.pytorch_cos_sim(emb1, emb2).item()

    def evaluate_batch(self):
        """Placeholder; not implemented."""
        pass

    def evaluate_retriever(self):
        """Return the gold ``article_id`` of every query whose article was not retrieved."""
        wrong_doc = []
        for index, query in enumerate(self.queries):
            top_doc = self.top_docs[index]["top_docs"]
            if not any(doc["id"] == query["article_id"] for doc in top_doc):
                wrong_doc.append(query["article_id"])

        return wrong_doc

    def compute_retrieval_metrics(self):
        """Compute hits@{1,3,5,10,50,100} and MRR of the gold article over all queries.

        Cutoffs above ``k`` can never exceed hits@k, since only ``k`` articles
        are retrieved. Returns all zeros when there are no queries.
        """
        cutoffs = (1, 3, 5, 10, 50, 100)
        metrics = {f"hits@{c}": 0 for c in cutoffs}
        metrics["mrr"] = 0.0

        total = len(self.queries)
        if total == 0:
            return {key: 0.0 for key in metrics}

        for index, query in enumerate(self.queries):
            correct_id = query["article_id"]
            docs = self.top_docs[index]["top_docs"]  # ranked list of dicts with 'id'

            for rank, doc in enumerate(docs, start=1):
                if doc["id"] == correct_id:
                    for c in cutoffs:
                        if rank <= c:
                            metrics[f"hits@{c}"] += 1
                    metrics["mrr"] += 1 / rank
                    break  # only the first (best-ranked) match counts

        return {key: value / total for key, value in metrics.items()}

    def _prediction_text(self, result):
        """Normalised predicted answer of a ``run`` result, or "" when the reader found none."""
        pred = result.get('pred')
        return self.normalize_text(pred['answer']) if pred else ""

    def evaluate(self):
        """Score ``self.results`` with the SQuAD metric plus a substring match.

        ``sentence_match`` is the percentage of predictions that are a
        non-empty substring of the normalised gold answer. Requires ``run``
        to have been called.

        Returns:
            (metrics dict, config dict, stats dict).
        """
        print(f"QA evaluate for {len(self.results)} results on {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
        pred = [
            {
                'id': str(result['id']),
                'prediction_text': self._prediction_text(result)
            }
            for result in self.results
        ]

        ref = [
            {
                'id': str(item['id']),
                'answers': {
                    'text': [self.normalize_text(item['answer']['text'])],
                    'answer_start': [item['answer']['start']]
                }
            }
            for item in self.results
        ]

        metric = load("squad")
        res = metric.compute(predictions=pred, references=ref)

        # An empty string is a substring of everything, so an empty prediction
        # must not count as a match.
        sentence_match_scores = [
            bool(p['prediction_text']) and p['prediction_text'] in r['answers']['text'][0]
            for p, r in zip(pred, ref)
        ]
        avg_sentence_match = np.mean(sentence_match_scores) if sentence_match_scores else 0.0

        res["sentence_match"] = float(avg_sentence_match) * 100
        print(res)

        self.config = {
            'model_path': self.model_path,
            'tokenizer_path': self.tokenizer_path,
            'k': self.k,
            'sample': self.sample,
            'isRandom': self.isRandom,
            'overlap': self.overlap,
            'num_chunks': self.num_chunks,
            'indexer_type': self.indexer_type
        }
        self.eval_res = res

        return self.eval_res, self.config, self.stats

    def filter_incomplete_examples(self, example):
        """Keep examples with a non-empty question, article body and answer.

        An empty article body leaves no second-sequence tokens, which would
        make the context scan in ``tokenize_train_function`` run off the end.
        """
        return bool(example.get("question")) and \
            bool(example.get("article_body")) and \
            bool(example.get("answer"))

    def filter_by_token_length(self, example):
        """Keep examples whose untruncated question + article fit in 512 tokens."""
        tokens = self.tokenizer(example["question"], example["article_body"], truncation=False)
        return len(tokens["input_ids"]) <= 512

    def decode_error(self, example):
        """Keep an example only if its labelled token span decodes exactly to the answer.

        Despite the name, this returns True for *correct* rows.
        """
        input_ids = example["input_ids"]
        start_positions = example["start_positions"]
        end_positions = example["end_positions"]
        predict_answer_tokens = input_ids[start_positions : end_positions+1]
        return self.tokenizer.decode(predict_answer_tokens) == example["answer"]

    def tokenize_train_function(self, examples):
        """Tokenise question/article pairs and add gold start/end token positions.

        The gold span's character offset in the article is
        ``context_start + answer_start``. Examples whose span is not fully
        inside the (possibly truncated) article tokens get positions (0, 0).
        """
        article_text = [article for article in examples.get("article_body", [""])]
        answer_text = examples.get("answer", [""])
        answer_start = examples.get("answer_start", [0])
        context_start_list = examples.get("context_start", [0])
        question_text = [q for q in examples.get("question", [""])]
        start_positions = []
        end_positions = []

        inputs = self.tokenizer(
            question_text,
            article_text,
            truncation="only_second",  # Truncate only the article
            max_length=512,
            stride=128,                # Has no effect while return_overflowing_tokens=False
            return_overflowing_tokens=False,  # One window per example; long articles are truncated
            return_offsets_mapping=True,
            padding="max_length"
        )

        offset_mapping = inputs.pop("offset_mapping")

        for i, offset in enumerate(offset_mapping):
            answer = answer_text[i]
            start_char = int(context_start_list[i]) + int(answer_start[i])
            end_char = start_char + len(answer)

            sequence_ids = inputs.sequence_ids(i)

            # Find the first and last token of the article (sequence id 1).
            idx = 0
            while sequence_ids[idx] != 1:
                idx += 1
            context_start = idx
            while sequence_ids[idx] == 1:
                idx += 1
            context_end = idx - 1

            # If the answer is not fully inside the article tokens, label is (0, 0).
            if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
                start_positions.append(0)
                end_positions.append(0)
            else:
                # Last token starting at or before start_char.
                idx = context_start
                while idx <= context_end and offset[idx][0] <= start_char:
                    idx += 1
                start_positions.append(idx - 1)

                # First token ending at or after end_char.
                idx = context_end
                while idx >= context_start and offset[idx][1] >= end_char:
                    idx -= 1
                end_positions.append(idx + 1)

        inputs["start_positions"] = start_positions
        inputs["end_positions"] = end_positions

        return inputs


# QA - BM25

In [None]:
# Sanity check after pinning numpy==1.26.4: print the NumPy version and confirm
# that a torch tensor still converts to a NumPy array.
print(np.__version__)  # Check if NumPy is available
print(torch.randn(1).numpy())


1.26.4
[-1.4586302]


In [None]:
# BM25 retrieval over the whole filtered test split (no sampling), fetching
# k=100 articles per question so every hits@N cutoff up to 100 is reachable.
qa_bm25 = QA(
    model_path=CURRENT_MODEL,
    k = 100, overlap=0.0, num_chunks=1)
wrong = qa_bm25.evaluate_retriever()  # gold article_ids that were not retrieved
metrics = qa_bm25.compute_retrieval_metrics()
print(len(wrong))
print(metrics)

In [None]:
# Extract answers from the BM25-retrieved articles, then score them with the
# SQuAD metric plus the sentence_match substring score.
qa_bm25.run()
qa_bm25.evaluate()


# QA - FAISS

In [None]:
# FAISS retrieval with the DPR encoders (use_fine_tuned=True) on the first 100
# filtered test examples (sample=100; isRandom defaults to False), k=100.
qa_faiss = QA(
    sample=100,
    model_path=CURRENT_MODEL,
    k = 100, overlap=0.0, num_chunks=1,
    indexer_type=FAISS,
    use_fine_tuned=True)
wrong_faiss = qa_faiss.evaluate_retriever()  # gold article_ids that were not retrieved
metrics = qa_faiss.compute_retrieval_metrics()
print(len(wrong_faiss))
print(metrics)

In [None]:
# Extract answers from the FAISS-retrieved articles, then score them with the
# SQuAD metric plus the sentence_match substring score.
qa_faiss.run()
qa_faiss.evaluate()