# ⚖️ Legal QA Domain – Low-Resource Survey Notebook

This notebook accompanies the paper *"QA Analysis in Medical and Legal Domains: A Survey of Data Augmentation in Low-Resource Settings"* and focuses on the **legal domain**.

We analyze and visualize the characteristics of various legal QA datasets and compute their semantic similarity to a parent corpus using domain-specific embeddings.

In [None]:
from datasets import concatenate_datasets, load_dataset
import datasets

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity 

import matplotlib.pyplot as plt
from matplotlib_venn import venn3
import seaborn as sns

from collections import Counter
import umap.umap_ as umap
import pandas as pd 
import numpy as np
import random
import spacy
import json
import umap
import os
import re

import nltk
from nltk.corpus import stopwords

import torch
# Report the compute device visible to PyTorch before any model is loaded.
use_gpu = torch.cuda.is_available()
print("CUDA available:", use_gpu)
print("CUDA device count:", torch.cuda.device_count())

## 🔧 NLTK & spaCy Setup
# NLTK data: tokenizer models for nltk.word_tokenize ("punkt" / "punkt_tab")
# and the stop-word corpus.
nltk.download("punkt")
nltk.download("stopwords")
nltk.download("punkt_tab")

# spaCy small English pipeline, and the NLTK English stop-word set.
nlp = spacy.load("en_core_web_sm")
stop_words = set(stopwords.words("english"))

## 🧠 Sentence Transformer Model Setup
model_name = "infly/inf-retriever-v1-1.5b"
model_legal = SentenceTransformer(
    model_name,
    trust_remote_code=True,  # allow running model code shipped with the Hub repository
    device="cuda" if use_gpu else "cpu",
)

## 📥 Load and Prepare QA Texts

In [None]:
# Load MMLU
def load_and_prepare_mmlu(*categories):
    """Return "<question>\\n\\n<correct choice>" strings for the given MMLU subjects.

    For each subject name in *categories*, the ``cais/mmlu`` configuration of
    that name is loaded and its "test", "validation" and "dev" splits (whichever
    are present) are concatenated. Rows with empty choices or a missing answer
    index are skipped.
    """
    wanted_splits = ("test", "validation", "dev")
    collected = []
    for subject in categories:
        splits = load_dataset("cais/mmlu", subject)
        merged = concatenate_datasets(
            [splits[split_name] for split_name in wanted_splits if split_name in splits]
        )
        for record in merged:
            choices = record["choices"]
            answer_idx = record["answer"]
            if not choices or answer_idx is None:
                continue
            collected.append(f"{record['question']}\n\n{choices[answer_idx]}")
    return collected

# MMLU subjects pooled into the parent corpus. Besides the law subjects, the
# list also includes logic, ethics, public-relations and foreign-policy subjects.
mmlu_categories = [
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "moral_disputes",
    "moral_scenarios",
    "professional_law",
    "public_relations",
    "us_foreign_policy",
]

mmlu_texts = load_and_prepare_mmlu(*mmlu_categories)

In [None]:
# Load PolicyQA
def load_policyqa_json(file_path):
    """Read a SQuAD-style JSON file and return one string per answered question.

    Each item is "<question>\\n\\n<first_answer_text>". A question is skipped
    when it is missing/empty, when it has no answers, or when its first answer
    has no text.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    def _iter_qas():
        # Flatten data -> paragraphs -> qas, tolerating missing levels.
        for article in payload.get("data", []):
            for paragraph in article.get("paragraphs", []):
                yield from paragraph.get("qas", [])

    pairs = []
    for qa in _iter_qas():
        question = qa.get("question")
        answers = qa.get("answers", [])
        if not (question and answers):
            continue
        first_text = answers[0].get("text")
        if first_text:
            pairs.append(f"{question}\n\n{first_text}")
    return pairs

# PolicyQA: question / first-answer pairs from the SQuAD-style JSON file.
policyqa = load_policyqa_json("data/policyqa.json")

In [None]:
# Load PrivacyQA
def load_policyqa_csv(file_path):
    """Load the PrivacyQA tab-separated file as "<query>\\n\\n<segment>" strings.

    Despite its name, this loader reads PrivacyQA data (it is called on
    ``data/privacyqa.csv``). The file is parsed as tab-separated values with
    ``Query``, ``Segment`` and ``Label`` columns. Only rows labelled
    "Relevant" are kept. Rows missing either the query or the segment are
    skipped, so no "nan" text is produced.

    Args:
        file_path: Path to the tab-separated PrivacyQA file.

    Returns:
        A list of strings, in file order.
    """
    df = pd.read_csv(file_path, delimiter="\t")
    usable = df[
        (df["Label"] == "Relevant") & df["Query"].notna() & df["Segment"].notna()
    ]
    return [
        f"{query}\n\n{segment}"
        for query, segment in zip(usable["Query"], usable["Segment"])
    ]

# PrivacyQA: query / relevant-segment pairs from the tab-separated file.
privacyqa_texts = load_policyqa_csv("data/privacyqa.csv")

In [None]:
# Load TruthfulQA: keep rows that have both a question and a best answer and
# join each pair as "<question>\n\n<best answer>".
truthful_df = pd.read_csv("data/TruthfulQA.csv")

truthful_texts = [
    f"{qa_row['Question']}\n\n{qa_row['Best Answer']}"
    for _, qa_row in truthful_df.dropna(subset=["Question", "Best Answer"]).iterrows()
]

In [None]:
# Number of QA texts loaded per dataset: (MMLU, TruthfulQA, PolicyQA, PrivacyQA).
len(mmlu_texts), len(truthful_texts), len(policyqa), len(privacyqa_texts)

## 🔍 Generate Embeddings for Each QA Dataset

In [None]:
# Encode each dataset using the SentenceTransformer model.
# No `device=` is passed to encode(), so it runs on the device the model was
# loaded on (CUDA when available, otherwise CPU). Hard-coding device="cuda"
# crashed on CPU-only machines even though the model itself was put on CPU.
def _encode_and_release(texts, batch_size=4):
    """Encode *texts* with `model_legal`, then release cached CUDA memory."""
    embeddings = model_legal.encode(texts, batch_size=batch_size)
    torch.cuda.empty_cache()  # safe on CPU-only machines (no-op without CUDA init)
    return embeddings


mmlu_embeddings = _encode_and_release(mmlu_texts)
truthful_embeddings = _encode_and_release(truthful_texts)
privacyqa_embeddings = _encode_and_release(privacyqa_texts)
policyqa_embeddings = _encode_and_release(policyqa)

In [None]:
# Cosine similarity of every parent (MMLU) embedding against every embedding of
# each low-resource target dataset, flattened to one 1-D array per target.
policyqa_similarity, truthful_similarity, privacyqa_similarity = (
    cosine_similarity(mmlu_embeddings, target_embeddings).flatten()
    for target_embeddings in (policyqa_embeddings, truthful_embeddings, privacyqa_embeddings)
)

## 📊 Cosine Similarity Distribution Between Parent and Target Datasets

In [None]:
# Histogram of MMLU-vs-target cosine similarities for each target dataset.
# density=True normalises each histogram to unit area. The targets yield very
# different numbers of pairs, so raw counts are not comparable, and the y-axis
# label already promises a probability density.
plt.figure(figsize=(10, 5))
for scores, name, color in [
    (policyqa_similarity, "PolicyQA", "orange"),
    (truthful_similarity, "TruthfulQA", "green"),
    (privacyqa_similarity, "PrivacyQA", "red"),
]:
    plt.hist(scores, bins=30, density=True, alpha=0.5, color=color, label=name)

plt.xlabel("Cosine Similarity")
plt.ylabel("Density (Probability Density)")

plt.legend()
plt.show()

## 🧹 Preprocessing & Vocabulary Extraction

In [None]:
def preprocess_and_tokenize(text):
    """Lower-case *text*, strip punctuation, tokenize, and drop stop words.

    Every character other than ASCII ``a-z``, ``0-9`` or whitespace is replaced
    by a space, so non-ASCII letters are removed too. The cleaned text is then
    split with ``nltk.word_tokenize``. Tokens found in the module-level
    ``stop_words`` set, and single-character tokens, are discarded.
    """
    cleaned = re.sub(r"[^a-z0-9\s]", " ", text.lower())
    return [
        token
        for token in nltk.word_tokenize(cleaned)
        if len(token) > 1 and token not in stop_words
    ]

def get_frequency_counter(texts):
    """Return a Counter of preprocessed tokens over every document in *texts*."""
    return Counter(
        token for document in texts for token in preprocess_and_tokenize(document)
    )

# Token frequency distribution of each corpus (parent MMLU + three targets).
freq_mmlu = get_frequency_counter(mmlu_texts)
freq_truthful = get_frequency_counter(truthful_texts)
freq_privacyqa = get_frequency_counter(privacyqa_texts)
freq_policyqa = get_frequency_counter(policyqa)

# Vocabulary (distinct tokens) of each corpus, for OOV / overlap analysis.
vocab_mmlu = set(freq_mmlu)
vocab_privacyqa = set(freq_privacyqa)
vocab_policyqa = set(freq_policyqa)
vocab_truthful = set(freq_truthful)

# Tokens of each target corpus that never appear in the parent MMLU corpus.
new_vocab_policyqa = vocab_policyqa.difference(vocab_mmlu)
new_vocab_truthful = vocab_truthful.difference(vocab_mmlu)
new_vocab_privacyqa = vocab_privacyqa.difference(vocab_mmlu)

In [None]:
# Pairwise overlap between the "new" vocabularies (tokens absent from the parent
# MMLU corpus) of the target datasets. Cell (i, j) counts the tokens shared by
# dataset i and dataset j; the diagonal is each dataset's own new-token count.
# NOTE: named `dataset_names`, not `datasets`, so the Hugging Face `datasets`
# module imported at the top of the notebook is not shadowed.
dataset_names = ["TruthfulQA", "PolicyQA", "PrivacyQA"]
vocab_sets = [new_vocab_truthful, new_vocab_policyqa, new_vocab_privacyqa]

overlap_matrix = np.array(
    [[len(row_set & col_set) for col_set in vocab_sets] for row_set in vocab_sets],
    dtype=float,
)

# Plot heatmap
plt.figure(figsize=(6, 5))
sns.heatmap(
    overlap_matrix,
    annot=True,
    xticklabels=dataset_names,
    yticklabels=dataset_names,
    cmap="Blues",
    fmt=".0f",
)
plt.title("Shared vocabulary absent from the parent corpus")
plt.show()

In [None]:
# Stack every embedding matrix in exactly the order used for `labels`, so that
# row k of `all_embeddings` belongs to corpus labels[k]. The parent (MMLU)
# embeddings were previously missing from the stack even though `labels`
# listed them first, which misaligned every label with its projected point.
all_embeddings = np.vstack([
    mmlu_embeddings,
    policyqa_embeddings,
    truthful_embeddings,
    privacyqa_embeddings,
])
labels = (
    ["ParentQA"] * len(mmlu_embeddings)
    + ["PolicyQA"] * len(policyqa_embeddings)
    + ["TruthfulQA"] * len(truthful_embeddings)
    + ["PrivacyQA"] * len(privacyqa_embeddings)
)

def tokenize_corpus(texts):
    """Flatten *texts* into one list of preprocessed tokens, in document order."""
    tokens = []
    for document in texts:
        tokens.extend(preprocess_and_tokenize(document))
    return tokens

# Out-of-vocabulary report: for each corpus, how many of its distinct tokens
# never occur in the parent (MMLU) corpus. ParentQA is listed as a sanity
# check; its OOV set is always empty.
parent_vocab = set(tokenize_corpus(mmlu_texts))
for name, texts in [
    ("ParentQA", mmlu_texts),
    ("PolicyQA", policyqa),
    ("TruthfulQA", truthful_texts),
    ("PrivacyQA", privacyqa_texts),
]:
    vocab = set(tokenize_corpus(texts))
    oov = vocab - parent_vocab
    # Guard against an empty corpus (e.g. a loader that matched no rows).
    oov_rate = len(oov) / len(vocab) if vocab else 0.0
    print(f"{name} — vocab size: {len(vocab):5d}, OOV size: {len(oov):5d}, OOV rate: {oov_rate:.2%}")
    # sorted() makes the printed sample reproducible: set iteration order for
    # strings depends on per-process hash randomization.
    print("  → exemples d'OOV:", sorted(oov)[:10])

In [None]:
# Shannon entropy (bits) of each corpus' unigram token distribution:
# H = -sum_w p(w) * log2 p(w), where p(w) is the relative frequency of token w.
for name, texts in [
    ("ParentQA", mmlu_texts),
    ("PolicyQA", policyqa),
    ("TruthfulQA", truthful_texts),
    ("PrivacyQA", privacyqa_texts),
]:
    tokens = tokenize_corpus(texts)
    freq = Counter(tokens)
    if not freq:
        # No tokens means no distribution; report it instead of failing.
        print(f"{name} entropie Shannon: n/a (no tokens)")
        continue
    p = np.array(list(freq.values()), dtype=float)
    p /= p.sum()
    H = -np.sum(p * np.log2(p))
    print(f"{name} entropie Shannon: {H:.2f} bits")

In [None]:
# 2-D UMAP projection (cosine metric) of all embeddings, coloured by corpus.
label_array = np.asarray(labels)
# Fail fast, before the expensive fit, if labels and embedding rows disagree;
# otherwise the per-corpus masks below would select the wrong points.
if len(label_array) != len(all_embeddings):
    raise ValueError(
        f"labels ({len(label_array)}) and all_embeddings ({len(all_embeddings)}) "
        "are misaligned; stack the same corpora, in the same order, as labels"
    )

reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, metric="cosine", n_components=2, random_state=42)
proj = reducer.fit_transform(all_embeddings)

# Dict order sets draw order: the parent corpus is drawn first and fainter so
# the target corpora appear on top of it.
palette = {"ParentQA": "gray", "PolicyQA": "C0", "TruthfulQA": "C1", "PrivacyQA": "C2"}

plt.figure(figsize=(6, 5))
for corpus, color in palette.items():
    mask = label_array == corpus
    alpha = 0.3 if corpus == "ParentQA" else 0.6
    plt.scatter(proj[mask, 0], proj[mask, 1], c=color, s=5, alpha=alpha, label=corpus)

plt.legend(markerscale=2)
plt.title("Projection UMAP des embeddings")
plt.xlabel("UMAP-1")
plt.ylabel("UMAP-2")
plt.tight_layout()
plt.show()