In [1]:
!pip install transformers torch faiss-cpu scikit-learn rank_bm25 rank bm25 tqdm pyserini==0.22.1 sentence_transformers python-dotenv

Collecting faiss-cpu
  Downloading faiss_cpu-1.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.4 kB)
Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Collecting rank
  Downloading rank-1.0.0.zip (1.1 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting bm25
  Downloading BM25-1.0.0.tar.gz (1.1 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pyserini==0.22.1
  Downloading pyserini-0.22.1-py3-none-any.whl.metadata (4.5 kB)
Collecting python-dotenv
  Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)
Collecting pyjnius>=1.4.0 (from pyserini==0.22.1)
  Downloading pyjnius-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting nmslib>=2.1.1 (from pyserini==0.22.1)
  Downloading nmslib-2.1.1.tar.gz (188 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m188.7/188.7 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Prepari

In [2]:
import os
import re
import csv
import nltk
import tqdm
import json
import torch
import requests
import subprocess
import numpy as np
import networkx as nx
nltk.download('punkt')
nltk.download('stopwords')
from functools import cache
from dotenv import load_dotenv
from rank_bm25 import BM25Okapi
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from typing import List, Dict,Tuple
from collections import defaultdict
from sentence_transformers import SentenceTransformer
from nltk.tokenize import word_tokenize, sent_tokenize
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import BertTokenizer, BertModel, AutoTokenizer, pipeline, AutoModelForSeq2SeqLM, RagTokenizer, RagRetriever, RagTokenForGeneration, AutoModelForSequenceClassification, LlamaTokenizer, LlamaForCausalLM

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
  from tqdm.autonotebook import tqdm, trange


In [3]:
!git clone https://github.com/RegNLP/ObliQADataset.git

Cloning into 'ObliQADataset'...
remote: Enumerating objects: 68, done.[K
remote: Counting objects: 100% (68/68), done.[K
remote: Compressing objects: 100% (65/65), done.[K
remote: Total 68 (delta 11), reused 47 (delta 1), pack-reused 0 (from 0)[K
Receiving objects: 100% (68/68), 11.83 MiB | 12.73 MiB/s, done.
Resolving deltas: 100% (11/11), done.


### Retrieve individual documents with BM25

In [4]:
# Module-level containers for the document-level corpus built in this cell.
documents = []  # not populated anywhere in this cell
document_ids = []  # not populated anywhere in this cell
doc_id_to_text = {}  # DocumentID (as str) -> that document's preprocessed passages joined by single spaces
doc_id_to_tokens = defaultdict(list)  # DocumentID (as str) -> word tokens of all its passages, in read order

def preprocess_text(text):
    """Normalise text for indexing and querying by lowercasing it."""
    lowered = text.lower()
    return lowered

# Read files 1.json .. 40.json and accumulate, per DocumentID, both the concatenated
# passage text and the running token list that BM25 will index.
for file_number in range(1, 41):
    json_path = os.path.join("ObliQADataset/StructuredRegulatoryDocuments", f"{file_number}.json")
    with open(json_path) as handle:
        doc = json.load(handle)

    for passage in doc:
        doc_id = str(passage["DocumentID"])
        passage_text = preprocess_text(passage["Passage"])

        # First passage of a document starts its text; later ones are space-joined onto it.
        existing_text = doc_id_to_text.get(doc_id)
        doc_id_to_text[doc_id] = (
            passage_text if existing_text is None else existing_text + " " + passage_text
        )

        doc_id_to_tokens[doc_id] += word_tokenize(passage_text)

# Parallel lists: position k in one corresponds to position k in the other.
all_documents = [tokens for tokens in doc_id_to_tokens.values()]
all_document_ids = [key for key in doc_id_to_tokens]

bm25 = BM25Okapi(all_documents, k1=1.5, b=0.75)

def retrieve_passages(user_query: str, topk: int = 10) -> List[Dict]:
    """Rank whole documents against a query with the module-level BM25 index.

    Args:
        user_query: Free-text question; it is preprocessed and tokenized the
            same way as the indexed documents.
        topk: Maximum number of documents to return.

    Returns:
        Up to ``topk`` dicts with keys ``doc_id``, ``rank`` (1-based),
        ``score`` (BM25 score) and ``text`` (the document's full text),
        ordered from highest to lowest score.
    """
    query_tokens = word_tokenize(preprocess_text(user_query))
    doc_scores = bm25.get_scores(query_tokens)

    # Pair each document id with its score and order best-first.
    ranked = sorted(zip(all_document_ids, doc_scores), key=lambda pair: pair[1], reverse=True)

    hits = []
    for position, (doc_id, score) in enumerate(ranked[:topk], start=1):
        hits.append({"doc_id": doc_id, "rank": position, "score": score, "text": doc_id_to_text[doc_id]})

    return hits

# Smoke test of the document-level retriever: run one sample question (default topk)
# and print each hit's rank, document id, score and the first 500 characters of its text.
results = retrieve_passages("Can the ADGM provide examples of legal risks associated with securitisation that Authorised Persons should particularly be aware of and manage?")
for result in results:
    print(f"Rank: {result['rank']}, Doc ID: {result['doc_id']}, Score: {result['score']}")
    print(f"Text: {result['text'][:500]}...")

Rank: 1, Doc ID: 13, Score: 32.030300265205206
Text: application, interpretation and categorisation application  subject to (2), these rules apply to every authorised person where its financial services permission authorises it to carry on one or more of the regulated activities listed in 1.3.1(a), 1.3.2(a), 1.3.3(1)(a), 1.3.4(a), 1.3.5(a), 1.3.6(a) or 1.3.7(a). in respect of a fund manager that:
(a) 	manages only venture capital funds; or
(b)	(i)	manages only venture capital funds; and
(ii)	undertakes one or both of the regulated activities of ad...
Rank: 2, Doc ID: 11, Score: 29.807693012956896
Text: introduction application  the rules in this rulebook ("mkt") are made for the purposes of the financial services and markets regulations 2015 ("fsmr") and apply to every person to whom that legislation applies. for the purposes of these rules the regulator may refer to itself as the listing authority. without limiting the generality of (1), this rulebook applies to a:
(a)	person making a

### Answer Generation

## RePASs Evaluator

In [5]:
class RePASSEvaluator:
    """Score a generated answer against its source passages with RePASs.

    RePASs (Regulatory Passage Answer Stability Score) combines three
    NLI-derived components, each in [0, 1]:

    * entailment score (Es): for every answer sentence, the strongest
      entailment probability over all source passages, averaged;
    * contradiction score (Cs): the same with contradiction probabilities;
    * obligation coverage score (OCs): the share of source obligations that
      at least one answer sentence entails.

    The final score is ``(Es - Cs + OCs + 1) / 3``.
    """

    # Entailment probability above which an obligation counts as covered.
    OBLIGATION_ENTAILMENT_THRESHOLD = 0.7

    # Label order used when the model config carries no usable label names.
    # This is the order documented for the cross-encoder NLI checkpoints.
    _DEFAULT_LABEL_ORDER = ("contradiction", "entailment", "neutral")

    def __init__(self,
                 nli_model_name="cross-encoder/nli-deberta-v3-xsmall",
                 obligation_model_name="nlpaueb/legal-bert-base-uncased"):
        """Load the NLI model and the (placeholder) obligation model.

        Args:
            nli_model_name: Hugging Face id of the sequence-classification NLI model.
            obligation_model_name: Hugging Face id of the obligation classifier.
        """
        # NLI model for entailment and contradiction
        self.nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
        self.nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name)

        # Obligation detection model (placeholder - would require fine-tuning).
        # It is loaded so the attributes exist, but no method of this class uses it:
        # detect_obligations() currently returns its input unchanged.
        self.obligation_tokenizer = AutoTokenizer.from_pretrained(obligation_model_name)
        self.obligation_model = AutoModelForSequenceClassification.from_pretrained(obligation_model_name)

    @staticmethod
    def _split_sentences(text: str) -> List[str]:
        """Split text after ., ! or ? only when whitespace follows.

        Splitting on every '.' would break rule references such as "11.2.1"
        into meaningless fragments that would then be scored as sentences.
        """
        pieces = re.split(r'(?<=[.!?])\s+', (text or "").strip())
        return [piece.strip() for piece in pieces if piece.strip()]

    def _nli_label_indices(self) -> Tuple[int, int]:
        """Return (entailment_index, contradiction_index) for the NLI logits.

        The indices are read from the model's ``config.id2label`` so the class
        does not silently depend on one checkpoint's label order; when no such
        names are present the default order above is used.
        """
        config = getattr(self.nli_model, "config", None)
        id2label = getattr(config, "id2label", None) or {}
        index_by_name = {str(label).lower(): int(idx) for idx, label in id2label.items()}
        entail_idx = index_by_name.get("entailment", self._DEFAULT_LABEL_ORDER.index("entailment"))
        contra_idx = index_by_name.get("contradiction", self._DEFAULT_LABEL_ORDER.index("contradiction"))
        return entail_idx, contra_idx

    def _nli_probabilities(self, premise: str, hypothesis: str) -> Tuple[float, float]:
        """Run the NLI model on one pair and return (P(entailment), P(contradiction))."""
        inputs = self.nli_tokenizer(premise, hypothesis, return_tensors="pt", truncation=True)

        with torch.no_grad():
            logits = self.nli_model(**inputs).logits

        probs = torch.softmax(logits, dim=1)[0]
        entail_idx, contra_idx = self._nli_label_indices()
        return probs[entail_idx].item(), probs[contra_idx].item()

    def calculate_entailment_contradiction_scores(self, source_passages: List[str], answer_sentences: List[str]) -> Tuple[float, float]:
        """
        Calculate entailment and contradiction scores between source passages and answer sentences

        Args:
            source_passages (List[str]): List of source passage sentences
            answer_sentences (List[str]): List of answer sentences

        Returns:
            Tuple[float, float]: Entailment score and Contradiction score.
            Both are 0.0 when either input list is empty.
        """
        # Nothing to compare: avoid max() on an empty list / NaN from np.mean([]).
        if not source_passages or not answer_sentences:
            return 0.0, 0.0

        best_entailment = []
        best_contradiction = []

        for answer_sent in answer_sentences:
            pair_probs = [self._nli_probabilities(source_sent, answer_sent)
                          for source_sent in source_passages]

            # Keep the strongest signal over all sources for this answer sentence.
            best_entailment.append(max(entail for entail, _ in pair_probs))
            best_contradiction.append(max(contra for _, contra in pair_probs))

        return float(np.mean(best_entailment)), float(np.mean(best_contradiction))

    def detect_obligations(self, source_passages: List[str]) -> List[str]:
        """
        Detect obligation sentences in source passages

        Note: This is a placeholder method that would require a trained obligation classifier.
        It currently treats every source passage as an obligation.

        Args:
            source_passages (List[str]): List of source passage sentences

        Returns:
            List[str]: List of detected obligation sentences
        """
        # Placeholder implementation
        return source_passages

    def calculate_obligation_coverage(self, source_obligations: List[str], answer_sentences: List[str]) -> float:
        """
        Calculate obligation coverage score

        Args:
            source_obligations (List[str]): List of obligation sentences from source
            answer_sentences (List[str]): List of answer sentences

        Returns:
            float: Fraction of obligations entailed by at least one answer sentence
            (0.0 when there are no obligations).
        """
        if not source_obligations:
            return 0.0

        threshold = self.OBLIGATION_ENTAILMENT_THRESHOLD
        # any() stops at the first answer sentence that covers the obligation.
        covered_obligations = sum(
            1
            for obligation in source_obligations
            if any(self._nli_probabilities(obligation, answer_sent)[0] > threshold
                   for answer_sent in answer_sentences)
        )

        return covered_obligations / len(source_obligations)

    def calculate_repass_score(self, source_passages: List[str], answer: str) -> Dict[str, float]:
        """
        Calculate the Regulatory Passage Answer Stability Score (RePASs)

        Args:
            source_passages (List[str]): List of source passage sentences
            answer (str): Generated answer

        Returns:
            Dict[str, float]: Detailed RePASs score components and final score
        """
        # Tokenize answer into sentences
        answer_sentences = self._split_sentences(answer)

        # Calculate Entailment and Contradiction Scores
        entailment_score, contradiction_score = self.calculate_entailment_contradiction_scores(
            source_passages, answer_sentences
        )

        # Detect obligations in source passages
        source_obligations = self.detect_obligations(source_passages)

        # Calculate Obligation Coverage Score
        obligation_coverage_score = self.calculate_obligation_coverage(
            source_obligations, answer_sentences
        )

        # RePASs = (Es - Cs + OCs + 1) / 3.  The "+ 1" shifts the range of
        # (Es - Cs + OCs) from [-1, 2] to [0, 3]; dividing by 3 normalises it.
        repass_score = (entailment_score - contradiction_score + obligation_coverage_score + 1) / 3

        # Clip defensively; with components in [0, 1] this is already a no-op.
        repass_score = max(0.0, min(1.0, repass_score))

        return {
            "entailment_score": entailment_score,
            "contradiction_score": contradiction_score,
            "obligation_coverage_score": obligation_coverage_score,
            "repass_score": repass_score
        }

# QA System with Llama

In [8]:
from google.colab import userdata

# Groq API key used in the Authorization header built by EnhancedQASystem.generate_answer;
# read from the Colab userdata store under the name 'GROQ_API_KEY' so it is not hard-coded here.
GROQ_API_KEY = userdata.get('GROQ_API_KEY')

class EnhancedQASystem:
    """Passage-level BM25 retrieval, Groq-hosted Llama answer generation, and RePASs scoring.

    Typical use: construct, call ``load_documents(data_dir)`` once, then call
    ``answer_query(question)``.
    """

    # Groq's OpenAI-compatible chat endpoint and the model used for generation.
    GROQ_CHAT_URL = "https://api.groq.com/openai/v1/chat/completions"
    GROQ_MODEL = "llama3-70b-8192"
    # Upper bound (seconds) on the HTTP call; without it requests waits indefinitely.
    REQUEST_TIMEOUT_SECONDS = 120

    def __init__(self, max_token_length=512):
        """Create an empty system.

        Args:
            max_token_length: Maximum number of word tokens per indexed chunk.
        """
        self.max_token_length = max_token_length
        self.doc_id_to_passages = defaultdict(list)  # DocumentID -> preprocessed passages
        self.all_passages = []        # token lists, one per chunk (the BM25 corpus)
        self.passage_to_doc_map = []  # (doc_id, chunk text), parallel to all_passages
        self.bm25 = None              # built by _prepare_bm25_index()
        self.repass_evaluator = RePASSEvaluator()

    def preprocess_text(self, text: str) -> str:
        """Drop characters other than word chars, whitespace and . , ! ? -, collapse whitespace, lowercase."""
        text = re.sub(r'[^\w\s.,!?-]', '', text)
        text = ' '.join(text.split())
        return text.lower()

    @staticmethod
    def _capitalize_first(sentence: str) -> str:
        """Upper-case only the first character, leaving the rest untouched.

        ``str.capitalize()`` also lowercases everything after the first
        character, which destroys acronyms and proper nouns in the answer.
        """
        return sentence[:1].upper() + sentence[1:]

    def load_documents(self, data_dir: str):
        """Load 1.json .. 40.json from ``data_dir`` and (re)build the BM25 index.

        A file that cannot be read or parsed is reported and skipped so one bad
        file does not abort the whole load.
        """
        for i in range(1, 41):
            try:
                # JSON is UTF-8; do not depend on the platform's default encoding.
                with open(os.path.join(data_dir, f"{i}.json"), encoding="utf-8") as f:
                    doc = json.load(f)
                for passage in doc:
                    doc_id = str(passage["DocumentID"])
                    passage_text = self.preprocess_text(passage["Passage"])
                    self.doc_id_to_passages[doc_id].append(passage_text)
            except Exception as e:
                print(f"Error loading document {i}: {str(e)}")

        self._prepare_bm25_index()

    def _add_chunk(self, doc_id: str, chunk_tokens: list):
        """Register one chunk in the BM25 corpus and in the parallel doc map."""
        self.all_passages.append(chunk_tokens)
        self.passage_to_doc_map.append((doc_id, ' '.join(chunk_tokens)))

    def _prepare_bm25_index(self):
        """Chunk every passage on sentence boundaries and build the BM25 index.

        Sentences are packed greedily into chunks of at most
        ``max_token_length`` tokens; a single sentence longer than that limit
        becomes a chunk of its own.
        """
        # Rebuild from scratch so a repeated load_documents() call does not
        # index every passage a second time.
        self.all_passages = []
        self.passage_to_doc_map = []

        for doc_id, passages in self.doc_id_to_passages.items():
            for passage in passages:
                current_chunk = []
                current_length = 0

                for sentence in sent_tokenize(passage):
                    tokens = word_tokenize(sentence)
                    if current_length + len(tokens) > self.max_token_length:
                        # Adding this sentence would overflow: close the chunk and start a new one.
                        if current_chunk:
                            self._add_chunk(doc_id, current_chunk)
                        current_chunk = tokens
                        current_length = len(tokens)
                    else:
                        current_chunk.extend(tokens)
                        current_length += len(tokens)

                if current_chunk:
                    self._add_chunk(doc_id, current_chunk)

        self.bm25 = BM25Okapi(self.all_passages, k1=1.5, b=0.75)

    def retrieve_passages(self, query: str, topk: int = 5) -> list:
        """Return the ``topk`` chunks ranked by a blend of BM25 score and query-term overlap.

        Each result is a dict with ``rank`` (1-based), ``doc_id``, ``score``
        (0.7 * BM25 + 0.3 * fraction of distinct query terms present) and ``text``.
        """
        preprocessed_query = self.preprocess_text(query)
        tokenized_query = word_tokenize(preprocessed_query)

        scores = self.bm25.get_scores(tokenized_query)
        query_terms = set(tokenized_query)

        def calculate_relevance_score(passage_text: str, bm25_score: float) -> float:
            passage_terms = set(word_tokenize(passage_text.lower()))
            term_overlap = len(query_terms.intersection(passage_terms))
            overlap_score = term_overlap / len(query_terms) if query_terms else 0
            return bm25_score * 0.7 + overlap_score * 0.3

        scored_passages = [
            (calculate_relevance_score(passage, score), doc_id, passage)
            for score, (doc_id, passage) in zip(scores, self.passage_to_doc_map)
        ]

        top_results = sorted(scored_passages, key=lambda x: x[0], reverse=True)[:topk]

        return [
            {
                "rank": rank + 1,
                "doc_id": doc_id,
                "score": score,
                "text": passage
            }
            for rank, (score, doc_id, passage) in enumerate(top_results)
        ]

    def generate_answer(self, query: str, retrieved_passages: list) -> str:
        """Generate answer using Groq API.

        Args:
            query: The user's question.
            retrieved_passages: Result dicts from ``retrieve_passages``; their
                texts are concatenated best-score-first as the prompt context.

        Returns:
            The post-processed model answer.

        Raises:
            Exception: If the API responds with a non-200 status code.
        """
        sorted_passages = sorted(retrieved_passages, key=lambda x: x['score'], reverse=True)
        context = " ".join([result['text'] for result in sorted_passages])

        # Prepare prompt
        prompt = (
            "Based on the following context, provide a detailed and structured answer to the question. "
            "Focus on specific requirements and procedures mentioned in the regulatory documents.\n\n"
            f"Question: {query}\n"
            f"Context: {context}\n"
            "Answer: Please provide a comprehensive response that directly addresses the question."
        )

        # Groq API request (bounded by a timeout so a stalled connection cannot hang the cell)
        response = requests.post(
            self.GROQ_CHAT_URL,
            headers={"Authorization": f"Bearer {GROQ_API_KEY}", "Content-Type": "application/json"},
            json={
                "model": self.GROQ_MODEL,
                "messages": [{"role": "user", "content": prompt}]
            },
            timeout=self.REQUEST_TIMEOUT_SECONDS
        )

        if response.status_code != 200:
            raise Exception(f"API request failed with status code {response.status_code}: {response.text}")

        answer = response.json()["choices"][0]["message"]["content"].strip()
        return self._post_process_answer(answer)

    def _post_process_answer(self, answer: str) -> str:
        """Normalise whitespace and punctuation, capitalise sentence starts, and break before list markers."""
        # Strip prompt labels the model may echo back, then normalise spacing and dots.
        answer = re.sub(r'Question:|Context:|Answer:', '', answer)
        answer = re.sub(r'\s+', ' ', answer).strip()
        answer = re.sub(r'\.+', '.', answer)
        answer = re.sub(r'\s+\.', '.', answer)

        sentences = sent_tokenize(answer)
        processed_sentences = []

        for sentence in sentences:
            # Only the first character is upper-cased; acronyms later in the sentence keep their case.
            sentence = self._capitalize_first(sentence)
            sentence = re.sub(r'\.+$', '.', sentence)
            if not sentence.endswith(('.', '?', '!')):
                sentence += '.'
            processed_sentences.append(sentence)

        answer = ' '.join(processed_sentences)

        # Put pieces that start with a list marker (digits + '.', '-' or a bullet) on their own line.
        # NOTE(review): splitting on '. ' separates "1" from the text after "1. ", so numbered
        # markers rarely match here - confirm whether that is intended.
        if re.search(r'\d+\.|\-|\•', answer):
            lines = answer.split('. ')
            formatted_lines = []
            for line in lines:
                if re.match(r'^\d+\.|\-|\•', line.strip()):
                    formatted_lines.append('\n' + line)
                else:
                    formatted_lines.append(line)
            answer = '. '.join(formatted_lines)

        return answer

    def answer_query(self, query: str, topk: int = 5) -> Dict:
        """Retrieve passages, generate an answer and evaluate it with RePASs.

        Returns:
            Dict with ``answer``, ``source_passages`` (retrieval result dicts),
            ``source_documents`` (distinct doc ids, in rank order) and
            ``repass_evaluation`` (the evaluator's score dict).
        """
        retrieved_passages = self.retrieve_passages(query, topk)
        answer = self.generate_answer(query, retrieved_passages)

        source_passage_texts = [result['text'] for result in retrieved_passages]

        # Calculate RePASs score
        repass_evaluation = self.repass_evaluator.calculate_repass_score(source_passage_texts, answer)

        return {
            'answer': answer,
            'source_passages': retrieved_passages,
            # dict.fromkeys de-duplicates while keeping rank order (a set has no stable order).
            'source_documents': list(dict.fromkeys(p['doc_id'] for p in retrieved_passages)),
            'repass_evaluation': repass_evaluation
        }

def main():
    """Run one end-to-end demo: load the corpus, answer a sample question, print the results."""
    system = EnhancedQASystem()
    system.load_documents("ObliQADataset/StructuredRegulatoryDocuments")

    sample_question = "What kind of documentation and verification does the FSRA require from a Mining Reporting Entity to prove adherence to the appropriate Mining Reporting Standard when disclosing Exploration Targets and Production Targets?"
    outcome = system.answer_query(sample_question)

    # Answer first, then its evaluation, then the evidence it was built from.
    print("\nGenerated Answer:", outcome['answer'])
    print("\nRePASs Score:", outcome['repass_evaluation'])
    print("\nRetrieved Passages:")
    for passage in outcome['source_passages']:
        print(f"Rank: {passage['rank']}, Doc ID: {passage['doc_id']}, Score: {passage['score']:.4f}")
        print(f"Context: {passage['text'][:200]}...")


if __name__ == "__main__":
    main()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/419 [00:00<?, ?B/s]

spm.model:   0%|          | 0.00/2.46M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/18.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/156 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/1.05k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/283M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/222k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Generated Answer: Here is a detailed and structured answer to the question: **documentation and verification requirements for mining reporting entities** to prove adherence to the appropriate mining reporting standard when disclosing exploration targets and production targets, the financial services regulatory authority (fsra) requires the following documentation and verification from a mining reporting entity: **1. Disclosure statement**: the mining reporting entity must prepare a disclosure statement that includes a statement about exploration targets, exploration results, mineral resources, ore reserves, or production targets, in accordance with a mining reporting standard and the requirements of mkt chapter 11 (rule 11.2.1). **2. Compliance with mining reporting standard**: the disclosure statement must be prepared in accordance with a mining reporting standard, which is defined as a standard accepted by the fsra as a basis for reporting mineral resources and ore reserves, such as

## Writing answers to JSON

In [37]:
from google.colab import userdata

# Groq API key used in the Authorization header built by EnhancedQASystem.generate_answer;
# read from the Colab userdata store under the name 'GROQ_API_KEY' so it is not hard-coded here.
GROQ_API_KEY = userdata.get('GROQ_API_KEY')

class EnhancedQASystem:
    """Passage-level BM25 retrieval, Groq-hosted Llama answer generation, and RePASs scoring.

    In this variant ``answer_query`` also writes its result to a JSON file.
    Typical use: construct, call ``load_documents(data_dir)`` once, then call
    ``answer_query(question)``.
    """

    # Groq's OpenAI-compatible chat endpoint and the model used for generation.
    GROQ_CHAT_URL = "https://api.groq.com/openai/v1/chat/completions"
    GROQ_MODEL = "llama3-70b-8192"
    # Upper bound (seconds) on the HTTP call; without it requests waits indefinitely.
    REQUEST_TIMEOUT_SECONDS = 120

    def __init__(self, max_token_length=512):
        """Create an empty system.

        Args:
            max_token_length: Maximum number of word tokens per indexed chunk.
        """
        self.max_token_length = max_token_length
        self.doc_id_to_passages = defaultdict(list)  # DocumentID -> preprocessed passages
        self.all_passages = []        # token lists, one per chunk (the BM25 corpus)
        self.passage_to_doc_map = []  # (doc_id, chunk text), parallel to all_passages
        self.bm25 = None              # built by _prepare_bm25_index()
        self.repass_evaluator = RePASSEvaluator()

    def preprocess_text(self, text: str) -> str:
        """Drop characters other than word chars, whitespace and . , ! ? -, collapse whitespace, lowercase."""
        text = re.sub(r'[^\w\s.,!?-]', '', text)
        text = ' '.join(text.split())
        return text.lower()

    @staticmethod
    def _capitalize_first(sentence: str) -> str:
        """Upper-case only the first character, leaving the rest untouched.

        ``str.capitalize()`` also lowercases everything after the first
        character, which destroys acronyms and proper nouns in the answer.
        """
        return sentence[:1].upper() + sentence[1:]

    def load_documents(self, data_dir: str):
        """Load 1.json .. 40.json from ``data_dir`` and (re)build the BM25 index.

        A file that cannot be read or parsed is reported and skipped so one bad
        file does not abort the whole load.
        """
        for i in range(1, 41):
            try:
                # JSON is UTF-8; do not depend on the platform's default encoding.
                with open(os.path.join(data_dir, f"{i}.json"), encoding="utf-8") as f:
                    doc = json.load(f)
                for passage in doc:
                    doc_id = str(passage["DocumentID"])
                    passage_text = self.preprocess_text(passage["Passage"])
                    self.doc_id_to_passages[doc_id].append(passage_text)
            except Exception as e:
                print(f"Error loading document {i}: {str(e)}")

        self._prepare_bm25_index()

    def _add_chunk(self, doc_id: str, chunk_tokens: list):
        """Register one chunk in the BM25 corpus and in the parallel doc map."""
        self.all_passages.append(chunk_tokens)
        self.passage_to_doc_map.append((doc_id, ' '.join(chunk_tokens)))

    def _prepare_bm25_index(self):
        """Chunk every passage on sentence boundaries and build the BM25 index.

        Sentences are packed greedily into chunks of at most
        ``max_token_length`` tokens; a single sentence longer than that limit
        becomes a chunk of its own.
        """
        # Rebuild from scratch so a repeated load_documents() call does not
        # index every passage a second time.
        self.all_passages = []
        self.passage_to_doc_map = []

        for doc_id, passages in self.doc_id_to_passages.items():
            for passage in passages:
                current_chunk = []
                current_length = 0

                for sentence in sent_tokenize(passage):
                    tokens = word_tokenize(sentence)
                    if current_length + len(tokens) > self.max_token_length:
                        # Adding this sentence would overflow: close the chunk and start a new one.
                        if current_chunk:
                            self._add_chunk(doc_id, current_chunk)
                        current_chunk = tokens
                        current_length = len(tokens)
                    else:
                        current_chunk.extend(tokens)
                        current_length += len(tokens)

                if current_chunk:
                    self._add_chunk(doc_id, current_chunk)

        self.bm25 = BM25Okapi(self.all_passages, k1=1.5, b=0.75)

    def retrieve_passages(self, query: str, topk: int = 5) -> list:
        """Return the ``topk`` chunks ranked by a blend of BM25 score and query-term overlap.

        Each result is a dict with ``rank`` (1-based), ``doc_id``, ``score``
        (0.7 * BM25 + 0.3 * fraction of distinct query terms present) and ``text``.
        """
        preprocessed_query = self.preprocess_text(query)
        tokenized_query = word_tokenize(preprocessed_query)

        scores = self.bm25.get_scores(tokenized_query)
        query_terms = set(tokenized_query)

        def calculate_relevance_score(passage_text: str, bm25_score: float) -> float:
            passage_terms = set(word_tokenize(passage_text.lower()))
            term_overlap = len(query_terms.intersection(passage_terms))
            overlap_score = term_overlap / len(query_terms) if query_terms else 0
            return bm25_score * 0.7 + overlap_score * 0.3

        scored_passages = [
            (calculate_relevance_score(passage, score), doc_id, passage)
            for score, (doc_id, passage) in zip(scores, self.passage_to_doc_map)
        ]

        top_results = sorted(scored_passages, key=lambda x: x[0], reverse=True)[:topk]

        return [
            {
                "rank": rank + 1,
                "doc_id": doc_id,
                "score": score,
                "text": passage
            }
            for rank, (score, doc_id, passage) in enumerate(top_results)
        ]

    def generate_answer(self, query: str, retrieved_passages: list) -> str:
        """Generate answer using Groq API.

        Args:
            query: The user's question.
            retrieved_passages: Result dicts from ``retrieve_passages``; their
                texts are concatenated best-score-first as the prompt context.

        Returns:
            The post-processed model answer.

        Raises:
            Exception: If the API responds with a non-200 status code.
        """
        sorted_passages = sorted(retrieved_passages, key=lambda x: x['score'], reverse=True)
        context = " ".join([result['text'] for result in sorted_passages])

        # Prepare prompt
        prompt = (
            "Based on the following context, provide a detailed and structured answer to the question. "
            "Focus on specific requirements and procedures mentioned in the regulatory documents.\n\n"
            f"Question: {query}\n"
            f"Context: {context}\n"
            "Answer: Please provide a comprehensive response that directly addresses the question."
        )

        # Groq API request (bounded by a timeout so a stalled connection cannot hang the cell)
        response = requests.post(
            self.GROQ_CHAT_URL,
            headers={"Authorization": f"Bearer {GROQ_API_KEY}", "Content-Type": "application/json"},
            json={
                "model": self.GROQ_MODEL,
                "messages": [{"role": "user", "content": prompt}]
            },
            timeout=self.REQUEST_TIMEOUT_SECONDS
        )

        if response.status_code != 200:
            raise Exception(f"API request failed with status code {response.status_code}: {response.text}")

        answer = response.json()["choices"][0]["message"]["content"].strip()
        return self._post_process_answer(answer)

    def _post_process_answer(self, answer: str) -> str:
        """Normalise whitespace and punctuation, capitalise sentence starts, and break before list markers."""
        # Strip prompt labels the model may echo back, then normalise spacing and dots.
        answer = re.sub(r'Question:|Context:|Answer:', '', answer)
        answer = re.sub(r'\s+', ' ', answer).strip()
        answer = re.sub(r'\.+', '.', answer)
        answer = re.sub(r'\s+\.', '.', answer)

        sentences = sent_tokenize(answer)
        processed_sentences = []

        for sentence in sentences:
            # Only the first character is upper-cased; acronyms later in the sentence keep their case.
            sentence = self._capitalize_first(sentence)
            sentence = re.sub(r'\.+$', '.', sentence)
            if not sentence.endswith(('.', '?', '!')):
                sentence += '.'
            processed_sentences.append(sentence)

        answer = ' '.join(processed_sentences)

        # Put pieces that start with a list marker (digits + '.', '-' or a bullet) on their own line.
        # NOTE(review): splitting on '. ' separates "1" from the text after "1. ", so numbered
        # markers rarely match here - confirm whether that is intended.
        if re.search(r'\d+\.|\-|\•', answer):
            lines = answer.split('. ')
            formatted_lines = []
            for line in lines:
                if re.match(r'^\d+\.|\-|\•', line.strip()):
                    formatted_lines.append('\n' + line)
                else:
                    formatted_lines.append(line)
            answer = '. '.join(formatted_lines)

        return answer

    def answer_query(self, query: str, topk: int = 5,
                     question_id: str = "Q1", output_path: str = 'input.json') -> List[Dict]:
        """Answer a query, evaluate it with RePASs, and save the record as JSON.

        Args:
            query: The user's question.
            topk: Number of passages to retrieve.
            question_id: Identifier stored under "QuestionID" (defaults to "Q1").
            output_path: File the one-record list is written to (defaults to 'input.json').

        Returns:
            A list holding a single dict with keys "QuestionID", "Question",
            "RetrievedPassages", "Answer" and "RepassEvaluation" - the same
            structure that is written to ``output_path``.
        """
        retrieved_passages = self.retrieve_passages(query, topk)
        answer = self.generate_answer(query, retrieved_passages)

        source_passage_texts = [result['text'] for result in retrieved_passages]

        # Calculate RePASs score
        repass_evaluation = self.repass_evaluator.calculate_repass_score(source_passage_texts, answer)

        output_data = [{
            "QuestionID": question_id,
            "Question": query,
            "RetrievedPassages": source_passage_texts,
            "Answer": answer,
            "RepassEvaluation": repass_evaluation
        }]

        # Save the output (overwrites any existing file at output_path)
        with open(output_path, 'w', encoding="utf-8") as f:
            json.dump(output_data, f, indent=4)

        return output_data

def main():
    """Run one end-to-end demo: load the corpus, answer a sample question, print answer and score."""
    system = EnhancedQASystem()
    system.load_documents("ObliQADataset/StructuredRegulatoryDocuments")

    sample_question = "What kind of documentation and verification does the FSRA require from a Mining Reporting Entity to prove adherence to the appropriate Mining Reporting Standard when disclosing Exploration Targets and Production Targets?"

    # answer_query returns a one-element list; the single record holds everything we print.
    record = system.answer_query(sample_question)[0]

    print("\nGenerated Answer:", record['Answer'])
    print("\nRePASs Score:", record['RepassEvaluation'])


if __name__ == "__main__":
    main()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Generated Answer: To prove adherence to the appropriate mining reporting standard when disclosing exploration targets and production targets, the fsra requires a mining reporting entity to provide specific documentation and verification as outlined in the regulatory documents. **documentation requirements:** 1. Disclosure prepared in accordance with a mining reporting standard (rule 11.2.1): the mining reporting entity must prepare disclosures related to exploration targets, exploration results, mineral resources, ore reserves, or production targets in accordance with a mining reporting standard, such as the australasian code for reporting of exploration results, mineral resources and ore reserves (jorc code) or the canadian institute of mining, metallurgy and petroleum (cim) definition standards. 2. Compliance with mkt chapter 11 requirements (rule 11.2.1): the disclosure must also comply with the requirements of mkt chapter 11, which outlines the disclosure obligations for mining re

In [38]:
# Write the package versions installed in this runtime to requirements.txt.
!pip freeze > requirements.txt

from google.colab import files

# Pass the generated file to the google.colab `files.download` helper.
files.download("requirements.txt")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [42]:
from sentence_transformers import SentenceTransformer, util
import numpy as np

class EnhancedQASystem:
    """Question answering over regulatory passages using dense retrieval.

    Pipeline: passages are loaded from JSON files and embedded with a
    SentenceTransformer model; the passages closest to a query (cosine
    similarity) are retrieved; an answer is generated through the Groq
    chat-completions endpoint; the answer is scored by ``RePASSEvaluator``.
    """

    def __init__(self, max_token_length=512, request_timeout=60):
        """Create an empty system and load the embedding model and evaluator.

        Args:
            max_token_length: Stored on the instance; not used by any method
                of this class.
            request_timeout: Seconds to wait for the Groq API before giving
                up, so a stalled connection cannot hang the kernel forever.
        """
        self.max_token_length = max_token_length
        self.request_timeout = request_timeout
        # doc_id -> list of preprocessed passage texts, in load order.
        self.doc_id_to_passages = defaultdict(list)
        # Flat list of passage texts; row i matches passage_embeddings[i].
        self.all_passages = []
        # Row i -> (doc_id, passage_text); parallel to passage_embeddings.
        self.passage_to_doc_map = []
        # Embedding matrix built by _prepare_vector_index(); None until then.
        self.passage_embeddings = None
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        self.repass_evaluator = RePASSEvaluator()

    def preprocess_text(self, text: str) -> str:
        """Normalise text for embedding.

        Drops every character that is not a word character, whitespace or
        one of ``. , ! ? -``, collapses whitespace runs to single spaces and
        lower-cases the result.
        """
        text = re.sub(r'[^\w\s.,!?-]', '', text)
        text = ' '.join(text.split())
        return text.lower()

    def load_documents(self, data_dir: str):
        """Load ``1.json`` .. ``40.json`` from ``data_dir`` and build the index.

        Each file is iterated as a sequence of records providing
        ``"DocumentID"`` and ``"Passage"`` keys. Loading is best-effort: a
        file that cannot be read or parsed is reported on stdout and skipped.
        Passages accumulate across calls; the vector index is rebuilt from
        everything loaded so far.
        """
        for i in range(1, 41):
            try:
                # JSON is UTF-8 by specification; do not depend on the locale default.
                with open(os.path.join(data_dir, f"{i}.json"), encoding="utf-8") as f:
                    doc = json.load(f)
                for passage in doc:
                    doc_id = str(passage["DocumentID"])
                    passage_text = self.preprocess_text(passage["Passage"])
                    self.doc_id_to_passages[doc_id].append(passage_text)
            except Exception as e:
                print(f"Error loading document {i}: {str(e)}")

        self._prepare_vector_index()

    def _prepare_vector_index(self):
        """Rebuild the passage list, the row->document map and the embeddings.

        Everything is rebuilt from scratch so the map and the embedding
        matrix always have the same number of rows, even when this method
        runs more than once (e.g. the notebook cell is re-executed).
        """
        doc_map = [
            (doc_id, passage)
            for doc_id, passages in self.doc_id_to_passages.items()
            for passage in passages
        ]
        self.passage_to_doc_map = doc_map
        self.all_passages = [passage for _, passage in doc_map]

        if not self.all_passages:
            # Nothing to embed; retrieve_passages() reports this state clearly.
            self.passage_embeddings = None
            return

        # Embed all passages and store them
        self.passage_embeddings = self.model.encode(self.all_passages, convert_to_tensor=True)

    def retrieve_passages(self, query: str, topk: int = 5) -> list:
        """Retrieve the passages most similar to ``query``.

        Args:
            query: Raw question text; it is preprocessed like the passages.
            topk: Maximum number of passages to return. Values larger than
                the index size are clamped; values <= 0 yield an empty list.

        Returns:
            A list of dicts with keys ``rank`` (1-based), ``doc_id``,
            ``score`` (cosine similarity as a float) and ``text``, ordered
            as returned by ``torch.topk``.

        Raises:
            RuntimeError: If no passages have been indexed yet.
        """
        if self.passage_embeddings is None or not self.passage_to_doc_map:
            raise RuntimeError("No passages are indexed; call load_documents() first.")

        # torch.topk raises when k exceeds the number of candidates.
        k = min(topk, len(self.passage_to_doc_map))
        if k <= 0:
            return []

        preprocessed_query = self.preprocess_text(query)
        query_embedding = self.model.encode(preprocessed_query, convert_to_tensor=True)

        # Calculate cosine similarity
        cos_scores = util.pytorch_cos_sim(query_embedding, self.passage_embeddings)[0]
        top_results = torch.topk(cos_scores, k=k)

        # Convert to plain Python numbers once instead of indexing with tensors.
        scores = top_results.values.tolist()
        indices = top_results.indices.tolist()

        return [
            {
                "rank": rank,
                "doc_id": self.passage_to_doc_map[idx][0],
                "score": score,
                "text": self.passage_to_doc_map[idx][1]
            }
            for rank, (score, idx) in enumerate(zip(scores, indices), start=1)
        ]

    def generate_answer(self, query: str, retrieved_passages: list) -> str:
        """Generate an answer with the Groq chat-completions API.

        The passage texts are concatenated in descending score order and sent
        together with the question as a single user message. The module-level
        ``GROQ_API_KEY`` is used as the bearer token.

        Args:
            query: The user's question.
            retrieved_passages: Dicts with at least ``score`` and ``text``.

        Returns:
            The model's reply after ``_post_process_answer``.

        Raises:
            RuntimeError: If the API responds with a non-200 status code.
            requests.exceptions.RequestException: On network failure or when
                ``request_timeout`` is exceeded.
        """
        sorted_passages = sorted(retrieved_passages, key=lambda x: x['score'], reverse=True)
        context = " ".join([result['text'] for result in sorted_passages])

        # Prepare prompt
        prompt = (
            "Based on the following context, provide a detailed and structured answer to the question. "
            "Focus on specific requirements and procedures mentioned in the regulatory documents.\n\n"
            f"Question: {query}\n"
            f"Context: {context}\n"
            "Answer: Please provide a comprehensive response that directly addresses the question."
        )

        # Groq API request; requests waits indefinitely unless a timeout is given.
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={"Authorization": f"Bearer {GROQ_API_KEY}", "Content-Type": "application/json"},
            json={
                "model": "llama3-70b-8192",
                "messages": [{"role": "user", "content": prompt}]
            },
            timeout=self.request_timeout
        )

        if response.status_code != 200:
            raise RuntimeError(f"API request failed with status code {response.status_code}: {response.text}")

        answer = response.json()["choices"][0]["message"]["content"].strip()
        return self._post_process_answer(answer)

    def _post_process_answer(self, answer: str) -> str:
        """Clean up the generated answer.

        Removes echoed ``Question:``/``Context:``/``Answer:`` labels,
        collapses whitespace and repeated periods, upper-cases the first
        character of every sentence, ensures each sentence ends with
        terminal punctuation, and moves pieces that start with a list marker
        onto their own line.
        """
        answer = re.sub(r'Question:|Context:|Answer:', '', answer)
        answer = re.sub(r'\s+', ' ', answer).strip()
        answer = re.sub(r'\.+', '.', answer)
        answer = re.sub(r'\s+\.', '.', answer)

        sentences = sent_tokenize(answer)
        processed_sentences = []

        for sentence in sentences:
            # Upper-case only the first character: str.capitalize() would also
            # lower-case the rest and destroy acronyms and proper nouns.
            sentence = sentence[:1].upper() + sentence[1:]
            sentence = re.sub(r'\.+$', '.', sentence)
            if not sentence.endswith(('.', '?', '!')):
                sentence += '.'
            processed_sentences.append(sentence)

        answer = ' '.join(processed_sentences)

        if re.search(r'\d+\.|\-|\•', answer):
            lines = answer.split('. ')
            formatted_lines = []
            for line in lines:
                # A piece that begins with "<digits>.", "-" or "•" starts a new line.
                if re.match(r'^\d+\.|\-|\•', line.strip()):
                    formatted_lines.append('\n' + line)
                else:
                    formatted_lines.append(line)
            answer = '. '.join(formatted_lines)

        return answer

    def answer_query(self, query: str, topk: int = 5) -> Dict:
        """Retrieve, generate and evaluate an answer for ``query``.

        Returns:
            A dict with keys ``answer`` (str), ``source_passages`` (the
            retrieval results), ``source_documents`` (unique doc ids of those
            passages, in no particular order) and ``repass_evaluation`` (the
            value returned by ``RePASSEvaluator.calculate_repass_score``).
        """
        retrieved_passages = self.retrieve_passages(query, topk)
        answer = self.generate_answer(query, retrieved_passages)

        source_passage_texts = [result['text'] for result in retrieved_passages]

        # Calculate RePASs score
        repass_evaluation = self.repass_evaluator.calculate_repass_score(source_passage_texts, answer)

        return {
            'answer': answer,
            'source_passages': retrieved_passages,
            'source_documents': list(set(p['doc_id'] for p in retrieved_passages)),
            'repass_evaluation': repass_evaluation
        }

def main():
    """Run one end-to-end demo: index the dataset, answer a sample question, print the results."""
    system = EnhancedQASystem()
    system.load_documents("ObliQADataset/StructuredRegulatoryDocuments")

    question = "What kind of documentation and verification does the FSRA require from a Mining Reporting Entity to prove adherence to the appropriate Mining Reporting Standard when disclosing Exploration Targets and Production Targets?"

    outcome = system.answer_query(question)

    # Print each labelled field of the result in a fixed order.
    for label, key in (
        ("\nGenerated Answer:", 'answer'),
        ("\nRePASs Score:", 'repass_evaluation'),
    ):
        print(label, outcome[key])

# Entry point: call main() only when this code runs as the top-level module
# (not when it is imported from elsewhere).
if __name__ == "__main__":
    main()

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Generated Answer: According to the fsra's regulatory documents, when disclosing exploration targets and production targets, a mining reporting entity is required to provide specific documentation and verification to prove adherence to the appropriate mining reporting standard. The following requirements and procedures must be fulfilled: 1. **competent person statement**: the fsra expects to see disclosure of the appropriate competent person statement in relation to the original disclosure of the estimates of ore reserves and/or mineral resources. This statement must be prepared by a competent person in accordance with a mining reporting standard. 2. **disclosure of exploration targets and production targets**: any disclosure by a mining reporting entity that includes a statement about exploration targets, exploration results, mineral resources, ore reserves, or production targets must be prepared in accordance with a mining reporting standard and in accordance with the requirements of