# 🩺 Medical QA Domain – Low-Resource Survey Notebook

This notebook accompanies the paper *"QA Analysis in Medical and Legal Domains: A Survey of Data Augmentation in Low-Resource Settings"* and focuses on the **medical domain**.

We analyze and visualize the characteristics of various medical QA datasets and compute their semantic similarity to a larger parent corpus using domain-specific embeddings.

In [None]:
from datasets import concatenate_datasets, load_dataset
import datasets

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity 

import matplotlib.pyplot as plt
from matplotlib_venn import venn3
import seaborn as sns

from collections import Counter
import umap.umap_ as umap
import pandas as pd 
import numpy as np
import random
import spacy
import json
import umap
import os
import re

import nltk
from nltk.corpus import stopwords

import torch
# Report what accelerator is visible before anything is loaded onto it.
print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())

## 🔧 NLTK & spaCy Setup
# Fetch the NLTK data used by the preprocessing cells: tokenizer data
# (punkt / punkt_tab) and the stop-word lists.
for _nltk_resource in ("punkt", "stopwords", "punkt_tab"):
    nltk.download(_nltk_resource)

# English spaCy pipeline, plus the NLTK English stop-word set used when tokenizing.
nlp = spacy.load("en_core_web_sm")
stop_words = set(stopwords.words("english"))

## 🧠 Sentence Transformer Model Setup
model_name = "NovaSearch/jasper_en_vision_language_v1"
use_gpu = torch.cuda.is_available()

# GPU present -> CUDA + bfloat16; otherwise CPU + float32.
_model_device = "cuda" if use_gpu else "cpu"
_model_dtype = torch.bfloat16 if use_gpu else torch.float32

model_medical = SentenceTransformer(
    model_name,
    trust_remote_code=True,  # allows executing code shipped in the model's Hub repository
    device=_model_device,
    model_kwargs={"torch_dtype": _model_dtype, "attn_implementation": "sdpa"},
    # Model-specific config switch; presumably selects text-only encoding — confirm on the model card.
    config_kwargs={"is_text_encoder": True},
)

## 📥 Load and Prepare QA Texts

In [None]:
# Load MedMCQA; its training split becomes the large "parent" corpus (big_data) that the
# smaller datasets are compared against.
medmcqa = datasets.load_dataset("openlifescienceai/medmcqa")
# A single split needs no concatenate_datasets wrapper; use it directly.
medmcqa_train = medmcqa["train"]
medmcqa_texts = []
for ex in medmcqa_train:
    q, e = ex.get("question"), ex.get("exp")
    # Skip rows missing either the question or the explanation.
    if q is None or e is None:
        continue

    # Drop a leading "Ans." marker (then surrounding whitespace); other explanations are kept verbatim.
    if e.startswith("Ans."):
        e = e[4:].strip()
    medmcqa_texts.append(f"{q}\n\n{e}")

# Load CareQA (English open-ended config) and pool the examples of every available split.
careqa = load_dataset("HPAI-BSC/CareQA", "CareQA_en_open")
available_splits = careqa.keys()
# Each split is used as-is: wrapping a single split in concatenate_datasets is a no-op.
merged_ds = {split: careqa[split] for split in available_splits}

# One "question\n\nanswer" string per example, across all splits.
careqa_texts = [
    f"{example['question']}\n\n{example['answer']}"
    for split in merged_ds.values()
    for example in split
]

# Load COVID-QA from the local JSON file (nested data -> paragraphs -> qas layout).
with open("data/COVID-QA.json", "r", encoding="utf-8") as file:
    json_data = json.load(file)

# Build one "question\n\nfirst answer text" string per question that has at least one answer.
covidqa_texts = []
for entry in json_data['data']:
    for paragraph in entry['paragraphs']:
        for qa in paragraph['qas']:
            if not qa['answers']:
                continue  # unanswered question: nothing to pair it with
            covidqa_texts.append(f"{qa['question']}\n\n{qa['answers'][0]['text']}")

# Load ReDis-QA (test split) and pair each question with the text of its correct option.
redis_data = load_dataset("guan-wang/ReDis-QA")
# NOTE(review): `cop` is treated as 1-based (1 -> opa ... 4 -> opd) and cop == 0 is skipped.
# If this dataset stores `cop` 0-based, that drops every option-A item and pairs the rest
# with the wrong option — confirm against the dataset card.
option_map = {1: "opa", 2: "opb", 3: "opc", 4: "opd"}
redisqa_texts = []
for entry in redis_data['test']:
    try:
        if entry['cop'] == 0:
            continue  # silently skipped (no log line)
        correct_option = option_map.get(entry['cop'])
        # Unknown cop value, or the option column is absent: log and move on.
        if not correct_option or correct_option not in entry:
            print(f"Skipping entry due to unexpected 'cop' value: {entry}")
            continue
        redisqa_texts.append(f"{entry['question']}\n\n{entry[correct_option]}")
    except KeyError as e:
        print(f"Skipping entry due to missing key: {e}")

# Load MedBullets: the `_op4` and `_op5` CSV exports, stacked into a single frame.
df1 = pd.read_csv("data/medbullets_op4.csv")
df2 = pd.read_csv("data/medbullets_op5.csv")
combined_df = pd.concat([df1, df2], ignore_index=True)

# Keep only rows where both the question and the answer are present.
_complete_rows = combined_df.dropna(subset=['question', 'answer'])
medbullets_texts = [
    f"{question}\n\n{answer}"
    for question, answer in zip(_complete_rows['question'], _complete_rows['answer'])
]

# MedMCQA texts serve as the parent corpus (an alias, not a copy).
big_data = medmcqa_texts

# Display corpus sizes: parent, COVID-QA, CareQA, ReDis-QA, MedBullets.
tuple(len(corpus) for corpus in (big_data, covidqa_texts, careqa_texts, redisqa_texts, medbullets_texts))

## 🔍 Generate Embeddings for Each QA Dataset

In [None]:
# Encode each dataset using the SentenceTransformer model.
# No device is passed: encode() then runs on the device the model was loaded on, which
# the setup cell picks as "cuda" only when CUDA is available. A hard-coded
# device="cuda" here would fail on CPU-only machines.
covidqa_embeddings = model_medical.encode(covidqa_texts, show_progress_bar=True)
careqa_embeddings = model_medical.encode(careqa_texts, show_progress_bar=True)
redisqa_embeddings = model_medical.encode(redisqa_texts, show_progress_bar=True)
medbullets_embeddings = model_medical.encode(medbullets_texts, show_progress_bar=True)
parent_embeddings = model_medical.encode(big_data, show_progress_bar=True)

parent_embeddings.shape, covidqa_embeddings.shape, careqa_embeddings.shape, redisqa_embeddings.shape, medbullets_embeddings.shape

In [None]:
# Compare each low-resource QA dataset to the parent corpus: the cosine similarity of
# every (parent, target) pair, as one 1-D array per target dataset.
# .ravel() returns a view of the freshly computed (contiguous) n_parent x n_target
# matrix, whereas .flatten() would allocate a second full-size copy. The values are identical.
covidqa_similarity = cosine_similarity(parent_embeddings, covidqa_embeddings).ravel()
careqa_similarity = cosine_similarity(parent_embeddings, careqa_embeddings).ravel()
redisqa_similarity = cosine_similarity(parent_embeddings, redisqa_embeddings).ravel()
medbullets_similarity = cosine_similarity(parent_embeddings, medbullets_embeddings).ravel()

## 📊 Cosine Similarity Distribution Between Parent and Target Datasets

In [None]:
# Overlaid, density-normalised histograms of parent-vs-target cosine similarities.
# Each entry: (similarity values, legend label, bar opacity).
similarity_series = [
    (covidqa_similarity, "CovidQA", 0.5),
    (careqa_similarity, "CareQA", 0.5),
    (redisqa_similarity, "ReDis-QA", 0.5),
    (medbullets_similarity, "Medbullets", 0.8),
]

plt.figure(figsize=(10, 5))
for values, series_label, opacity in similarity_series:
    plt.hist(values, bins=30, alpha=opacity, label=series_label, density=True)

plt.xlabel("Cosine Similarity")
plt.ylabel("Density (Probability Density)")

plt.legend()
plt.show()

## 🧹 Preprocessing & Vocabulary Extraction

In [None]:
def preprocess_and_tokenize(text):
    """Normalise *text* and return its content tokens.

    Lower-cases the text, replaces every character outside ``[a-z0-9]`` and whitespace
    with a space, tokenizes with ``nltk.word_tokenize``, then drops tokens that are in
    the module-level ``stop_words`` set or are a single character long.
    """
    normalized = re.sub(r"[^a-z0-9\s]", " ", text.lower())
    return [
        tok
        for tok in nltk.word_tokenize(normalized)
        if len(tok) > 1 and tok not in stop_words
    ]

def get_frequency_counter(texts):
    """Return a Counter of token frequencies over all documents in *texts*.

    Each document is tokenized with ``preprocess_and_tokenize``; counts are summed
    across documents.
    """
    return Counter(
        token
        for document in texts
        for token in preprocess_and_tokenize(document)
    )

# Token frequency distributions, one Counter per corpus
# (order in every tuple below: parent, COVID-QA, CareQA, ReDis-QA, MedBullets).
freq_big, freq_covidqa, freq_careqa, freq_redisqa, freq_medbullets = (
    get_frequency_counter(corpus)
    for corpus in (big_data, covidqa_texts, careqa_texts, redisqa_texts, medbullets_texts)
)

# Vocabulary sets (distinct tokens per corpus) for OOV / overlap analysis.
vocab_big, vocab_covidqa, vocab_careqa, vocab_redisqa, vocab_medbullets = (
    set(counter)
    for counter in (freq_big, freq_covidqa, freq_careqa, freq_redisqa, freq_medbullets)
)

# Tokens each target corpus contains that never occur in the parent corpus.
new_vocab_covidqa, new_vocab_careqa, new_vocab_redisqa, new_vocab_medbullets = (
    target_vocab - vocab_big
    for target_vocab in (vocab_covidqa, vocab_careqa, vocab_redisqa, vocab_medbullets)
)

In [None]:
# Pairwise overlap between the "new" vocabularies (tokens absent from the parent corpus)
# of the target datasets. The diagonal is each dataset's own new-vocabulary size.
# Named `dataset_names` rather than `datasets` so the `datasets` library imported at the
# top of the notebook is not shadowed. Re-running the loading cell
# (`datasets.load_dataset`) after this one would otherwise fail.
dataset_names = ["COVIDQA", "CareQA", "RedisQA", "MedBullets"]
vocab_sets = [new_vocab_covidqa, new_vocab_careqa, new_vocab_redisqa, new_vocab_medbullets]

n_sets = len(vocab_sets)
overlap_matrix = np.zeros((n_sets, n_sets))
for i, vocab_i in enumerate(vocab_sets):
    for j, vocab_j in enumerate(vocab_sets):
        overlap_matrix[i, j] = len(vocab_i & vocab_j)

# Plot heatmap
plt.figure(figsize=(6, 5))
sns.heatmap(overlap_matrix, annot=True, xticklabels=dataset_names, yticklabels=dataset_names, cmap="Blues", fmt=".0f")
plt.show()

In [None]:
# Stack every corpus's embeddings into one matrix for the UMAP projection, alongside
# a parallel list holding one corpus label per stacked row (same order).
_embedding_groups = [
    ("COVIDQA", covidqa_embeddings),
    ("CareQA", careqa_embeddings),
    ("ReDisQA", redisqa_embeddings),
    ("MedBullets", medbullets_embeddings),
    ("ParentQA", parent_embeddings),
]
all_embeddings = np.vstack([group for _name, group in _embedding_groups])
labels = [
    corpus
    for corpus, group in _embedding_groups
    for _row in range(len(group))
]

def tokenize_corpus(texts):
    """Return one flat list of tokens for the whole corpus *texts*.

    Documents are tokenized with ``preprocess_and_tokenize`` and concatenated in order;
    duplicates are kept.
    """
    tokens = []
    for document in texts:
        tokens.extend(preprocess_and_tokenize(document))
    return tokens

# Out-of-vocabulary (OOV) analysis: for each corpus, the share of its distinct tokens
# that never occur in the parent corpus (big_data).
parent_vocab = set(tokenize_corpus(big_data))
for name, texts in [
    ("ParentQA",    big_data),
    ("COVIDQA",    covidqa_texts),
    ("CareQA",     careqa_texts),
    ("ReDisQA",    redisqa_texts),
    ("MedBullets", medbullets_texts),
]:
    # The parent corpus was tokenized just above; reuse that vocabulary instead of
    # tokenizing it again (its OOV set is empty by construction).
    vocab = parent_vocab if texts is big_data else set(tokenize_corpus(texts))
    oov = vocab - parent_vocab
    # An empty corpus (e.g. a dataset that yielded no texts) would otherwise divide by zero.
    oov_rate = len(oov) / len(vocab) if vocab else 0.0
    print(f"{name} — vocab size: {len(vocab):5d}, OOV size: {len(oov):5d}, OOV rate: {oov_rate:.2%}")
    # Sorted so the printed sample is reproducible; set iteration order of strings
    # changes between interpreter runs (hash randomization).
    print("  → exemples d'OOV:", sorted(oov)[:10])

In [None]:
# Calculate the Shannon entropy (bits) of each corpus's unigram token distribution:
# H = -sum_w p(w) * log2 p(w), with p(w) the relative frequency of token w.
for name, texts in [
    ("ParentQA",    big_data),
    ("COVIDQA",     covidqa_texts),
    ("CareQA",      careqa_texts),
    ("ReDisQA",     redisqa_texts),
    ("MedBullets",  medbullets_texts),
]:
    tokens = tokenize_corpus(texts)
    freq = Counter(tokens)
    # An empty corpus has no distribution (0/0 below). Previously the unused
    # `zip(*freq.most_common())` crashed on it with a ValueError.
    if not freq:
        print(f"{name} entropie Shannon: n/a (no tokens)")
        continue
    # Every count is >= 1, so every p > 0 and log2 is finite.
    p = np.array(list(freq.values()), dtype=float)
    p /= p.sum()
    H = -np.sum(p * np.log2(p))
    print(f"{name} entropie Shannon: {H:.2f} bits")

In [None]:
# 2-D UMAP projection of all stacked embeddings (cosine metric, fixed random_state).
reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, metric="cosine", n_components=2, random_state=42)
proj = reducer.fit_transform(all_embeddings)

palette = {"ParentQA":"gray","COVIDQA":"C0","CareQA":"C1","ReDisQA":"C2","MedBullets":"C3"}
# Array form of `labels` so each corpus is selected with one vectorised boolean mask
# instead of a Python scan over every label per corpus. Row order is preserved.
label_array = np.asarray(labels)

plt.figure(figsize=(6,5))
# Draw the parent corpus first and faint, so the target corpora are plotted on top.
parent_mask = label_array == "ParentQA"
plt.scatter(proj[parent_mask, 0], proj[parent_mask, 1], c=palette["ParentQA"], s=5, alpha=0.3, label="ParentQA")

for corpus, color in palette.items():
    if corpus == "ParentQA":
        continue
    mask = label_array == corpus
    plt.scatter(proj[mask, 0], proj[mask, 1], c=color, s=5, alpha=0.6, label=corpus)

plt.legend(markerscale=2)
plt.title("Projection UMAP des embeddings")
plt.xlabel("UMAP‐1")
plt.ylabel("UMAP‐2")
plt.tight_layout()
plt.show()