Requirements

Download dataset


---



In [None]:
!pip install datasets

from datasets import load_dataset
import json
import os
from pathlib import Path


In [None]:
# Training split of the PubMed summarization corpus (article / abstract pairs).
dataset = load_dataset("ccdv/pubmed-summarization", split="train")

OUT_DIR = "/content/"
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)
OUT_FILE = os.path.join(OUT_DIR, "pubmed_500.json")

# Number of leading rows to scan; rows with an empty field are dropped below,
# so the number of kept examples can be smaller than N.
N = 512

final_data = []
for i, ex in enumerate(dataset.select(range(N))):
    article, abstract = ex["article"].strip(), ex["abstract"].strip()

    # Keep the pair only when both the article and its abstract are non-empty.
    if article and abstract:
        # The ID encodes the row position within the scanned slice, so
        # skipped rows leave gaps in the numbering.
        final_data.append(
            {
                "stringID": f"pubmed_{i:06d}",
                "source": article,
                "summary": abstract,
            }
        )

print("Total examples:", len(final_data))


In [None]:
# Write the kept examples as a single pretty-printed UTF-8 JSON array.
with open(OUT_FILE, mode="w", encoding="utf-8") as out_fh:
    json.dump(final_data, out_fh, indent=2, ensure_ascii=False)

print("Saved to:", OUT_FILE)

Saved to: /content/askqe_pubmed/pubmed_500.json


FACT EXTRACTION


---



In [None]:
!pip install -U transformers bitsandbytes accelerate

In [None]:
import json
import os
import ast
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


In [None]:
PUBMED_FILE = "/content/pubmed_500.json"

# Reload the sampled subset: a JSON array of dicts.
with open(PUBMED_FILE, mode="r", encoding="utf-8") as pubmed_fh:
    data_pubmed = json.load(pubmed_fh)

# Quick sanity check: record count and the keys of the first record.
print("Loaded:", len(data_pubmed))
print(data_pubmed[0].keys())

Loaded: 500
dict_keys(['stringID', 'source', 'summary'])


In [None]:
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"

# bitsandbytes settings: 8-bit weights with the fp32 CPU-offload option enabled.
quant_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Use the end-of-sequence token for padding as well.
tokenizer.pad_token = tokenizer.eos_token

# device_map="auto" leaves weight placement to the library; .eval() switches
# the module to inference mode and returns the same module.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=quant_config,
).eval()

print("Model loaded")


Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/291 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

Model loaded


In [None]:
def run_llm(prompt, max_new_tokens=256):
    """Run one greedy, single-turn chat completion and return only the new text.

    Args:
        prompt: Content of the single user message; it is rendered with the
            tokenizer's chat template before generation.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped) with surrounding whitespace stripped.
    """
    messages = [{"role": "user", "content": prompt}]

    # return_dict=True is requested explicitly because the result is indexed
    # by "input_ids" / "attention_mask" below. Without it, transformers
    # releases whose default is return_dict=False hand back a bare tensor of
    # token ids and the key lookups fail.
    enc = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )

    input_ids = enc["input_ids"].to(model.device)
    attention_mask = enc["attention_mask"].to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; slice the prompt tokens off.
    prompt_len = input_ids.shape[-1]
    return tokenizer.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True,
    ).strip()


In [None]:
# Instruction template for the fact-extraction step. `{sentence}` is the only
# placeholder; it is filled through str.format with the text chunk to analyse.
# The template asks the model to reply with a Python list of strings.
FACT_PROMPT = """You are extracting facts from a biomedical research article.

Rules:
- Extract ONLY facts that are explicitly stated.
- Each fact must be a single, atomic statement.
- Include numerical values when present.
- Do NOT explain, summarize, or infer.
- If no facts are found, output an empty Python list [].

Output format:
A valid Python list of strings.

Text:
{sentence}

Facts:
"""


In [None]:
def extract_facts(text):
    """Ask the LLM for atomic facts in ``text`` and return them as a list of strings.

    The reply is parsed in three stages:
      1. the whole reply as a Python literal;
      2. the outermost ``[...]`` span of the reply, which covers replies that
         wrap the list in extra prose or Markdown code fences;
      3. a line-based fallback that keeps lines longer than 20 characters
         with more than 4 spaces.

    Args:
        text: The passage to extract facts from.

    Returns:
        A list of fact strings (possibly empty).
    """
    raw = run_llm(FACT_PROMPT.format(sentence=text))

    candidates = [raw]
    first_bracket, last_bracket = raw.find("["), raw.rfind("]")
    if 0 <= first_bracket < last_bracket:
        candidates.append(raw[first_bracket:last_bracket + 1])

    for candidate in candidates:
        try:
            parsed = ast.literal_eval(candidate)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Only the errors literal_eval raises on malformed input are
            # swallowed; KeyboardInterrupt and real bugs still propagate.
            continue
        if isinstance(parsed, list):
            return [f.strip() for f in parsed if isinstance(f, str)]

    # Fallback: treat each sufficiently long line as one fact.
    facts = []
    for line in raw.split("\n"):
        line = line.strip()
        # Drop one leading and one trailing quote character, if present.
        line = re.sub(r'^["\']|["\']$', '', line)
        if len(line) > 20 and line.count(" ") > 4:
            facts.append(line)
    return facts

def deduplicate_facts(facts):
    """Drop repeated facts, comparing case-insensitively and ignoring outer whitespace.

    The first occurrence of each fact is kept, in its original spelling and
    in the original order.
    """
    first_seen = {}
    for fact in facts:
        # setdefault stores the value only for a key that is not yet present,
        # so later duplicates never overwrite the first spelling.
        first_seen.setdefault(fact.lower().strip(), fact)
    return list(first_seen.values())

def chunk_text_uniform(text, max_chars=1500, n_chunks=5):
    """Sample fixed-size character windows spread across ``text``.

    Windows start at 0%, 25%, 50%, 75% and 90% of the text length and span at
    most ``max_chars`` characters. Windows of 300 characters or fewer are
    discarded, and at most ``n_chunks`` windows are returned.
    """
    L = len(text)
    starts = (int(L * frac) for frac in (0.0, 0.25, 0.5, 0.75, 0.9))
    windows = (text[start:start + max_chars] for start in starts)
    long_enough = [window for window in windows if len(window) > 300]
    return long_enough[:n_chunks]

def extract_facts_from_document(text, facts_per_chunk=5, max_facts=15):
    """Extract a bounded, de-duplicated list of facts from a whole document.

    The document is split with ``chunk_text_uniform``. At most
    ``facts_per_chunk`` facts are taken from each chunk, duplicates are
    removed, and the first ``max_facts`` remaining facts are returned.
    """
    collected = [
        fact
        for chunk in chunk_text_uniform(text)
        for fact in extract_facts(chunk)[:facts_per_chunk]
    ]
    return deduplicate_facts(collected)[:max_facts]



In [None]:
# Half-open range [START_IDX, END_IDX) of numeric document indices handled by
# the extraction loop in this run; documents outside the range are skipped.
START_IDX = 400
END_IDX = 500

def get_idx(stringID):
    """Return the second underscore-separated field of ``stringID`` as an int."""
    fields = stringID.split("_")
    return int(fields[1])


In [None]:
OUT_FILE = "/content/pubmed_atomic_facts.jsonl"

# OUT_FILE is opened in append mode so index ranges can be processed in
# separate runs. Collect the IDs already stored there so that re-running a
# range (for example after an interrupted session) does not write duplicate
# records for the same document.
done_ids = set()
try:
    with open(OUT_FILE, "r", encoding="utf-8") as f_prev:
        for prev_line in f_prev:
            prev_line = prev_line.strip()
            if not prev_line:
                continue
            try:
                done_ids.add(json.loads(prev_line)["stringID"])
            except (json.JSONDecodeError, KeyError, TypeError):
                # An unreadable line cannot mark a document as done.
                continue
except FileNotFoundError:
    # First run: nothing has been extracted yet.
    pass

with open(OUT_FILE, "a", encoding="utf-8") as f:
    for ex in data_pubmed:
        sid = ex["stringID"]
        idx = get_idx(sid)

        # Only handle documents inside the configured [START_IDX, END_IDX) range.
        if idx < START_IDX or idx >= END_IDX:
            continue

        if sid in done_ids:
            print(sid, "already extracted, skipping")
            continue

        src = ex["source"]
        facts = extract_facts_from_document(src)

        record = {
            "stringID": sid,
            "facts_raw": facts
        }

        # Flush after every record so finished documents survive an interruption.
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
        f.flush()
        done_ids.add(sid)

        print(sid, "→", len(facts), "facts")


NLI filtering


---



In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import json
import os


In [None]:
# Directory holding the pipeline's intermediate files.
OUT_DIR = "/content/"

# JSONL file read by the NLI filtering loop (one record per document).
OUT_FACTS_FILE = os.path.join(OUT_DIR, "pubmed_atomic_facts.jsonl")

In [None]:
# JSONL file that receives the facts kept by the NLI filter.
FACTS_ENTAILED_FILE = os.path.join(OUT_DIR, "pubmed_facts_entailed.jsonl")

In [None]:
NLI_MODEL = "roberta-large-mnli"

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

nli_tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL)
nli_model = AutoModelForSequenceClassification.from_pretrained(NLI_MODEL)
# Both calls return the module itself: evaluation mode, then move to `device`.
nli_model = nli_model.eval().to(device)

print("NLI model loaded on", device)


In [None]:
def is_entailed_sliding(source, fact, window_chars=1000, step=500, threshold=0.4):
    """Check whether any character window of ``source`` entails ``fact``.

    Windows of ``window_chars`` characters are taken every ``step``
    characters. Each (window, fact) pair is scored by the NLI classifier,
    with the pair truncated to 512 tokens. The function returns True as soon
    as the probability at class index 2 reaches ``threshold``, and False if
    no window does, including when ``source`` is empty.

    NOTE(review): index 2 is assumed to be the entailment label of the
    checkpoint named in NLI_MODEL; confirm if the checkpoint changes.
    """
    for start in range(0, len(source), step):
        premise = source[start:start + window_chars]

        encoded = nli_tokenizer(
            premise,
            fact,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(device)

        with torch.no_grad():
            scores = nli_model(**encoded).logits

        class_probs = F.softmax(scores, dim=-1)
        entail_prob = class_probs[0][2].item()
        if entail_prob >= threshold:
            return True

    return False


In [None]:
PUBMED_FILE = os.path.join(OUT_DIR, "pubmed_500.json")

with open(PUBMED_FILE, mode="r", encoding="utf-8") as pubmed_fh:
    data_pubmed = json.load(pubmed_fh)

# Lookup table: document ID -> full article text (the NLI premise).
id_to_source = dict((doc["stringID"], doc["source"]) for doc in data_pubmed)

In [None]:
# Filter every document's raw facts through the NLI check and write the
# surviving facts to FACTS_ENTAILED_FILE, one JSON record per document.
with open(OUT_FACTS_FILE, "r", encoding="utf-8") as fin, \
     open(FACTS_ENTAILED_FILE, "w", encoding="utf-8") as fout:

    for line in fin:
        # Tolerate blank lines in the JSONL input.
        if not line.strip():
            continue

        obj = json.loads(line)

        sid = obj["stringID"]
        facts = obj["facts_raw"]

        # Skip records whose document is missing instead of aborting the run
        # with a KeyError; the QA stage guards the same lookup.
        source = id_to_source.get(sid)
        if source is None:
            print("No source document for", sid, "- skipping")
            continue

        entailed = []

        for fact in facts:
            try:
                if is_entailed_sliding(source, fact):
                    entailed.append(fact)
            except Exception as e:
                # Best effort: a failure on one fact must not abort the run.
                print("NLI error on", sid, fact[:50], e)

        out_record = {
            "stringID": sid,
            "facts_entailed": entailed
        }

        # Flush after every record so finished documents survive an interruption.
        fout.write(json.dumps(out_record, ensure_ascii=False) + "\n")
        fout.flush()

        print(sid, "→", len(entailed), "/", len(facts), "entailed")


Load model for Question Generation and Answering


---



In [None]:
import torch


In [None]:
!pip install -U transformers bitsandbytes accelerate


In [None]:
import json
import os
import ast
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


In [None]:
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"

# bitsandbytes settings: 8-bit weights with the fp32 CPU-offload option enabled.
quant_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Use the end-of-sequence token for padding as well.
tokenizer.pad_token = tokenizer.eos_token

# device_map="auto" leaves weight placement to the library; .eval() switches
# the module to inference mode and returns the same module.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=quant_config,
).eval()

print("Model loaded")


In [None]:
def run_llm(prompt, max_new_tokens=256):
    """Run one greedy, single-turn chat completion and return only the new text.

    Args:
        prompt: Content of the single user message; it is rendered with the
            tokenizer's chat template before generation.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped) with surrounding whitespace stripped.
    """
    messages = [{"role": "user", "content": prompt}]

    # return_dict=True is requested explicitly because the result is indexed
    # by "input_ids" / "attention_mask" below. Without it, transformers
    # releases whose default is return_dict=False hand back a bare tensor of
    # token ids and the key lookups fail.
    enc = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )

    input_ids = enc["input_ids"].to(model.device)
    attention_mask = enc["attention_mask"].to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; slice the prompt tokens off.
    prompt_len = input_ids.shape[-1]
    return tokenizer.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True,
    ).strip()


Question Generation


---



In [None]:
# Instruction template for the question-generation step. `{facts}` is the only
# placeholder; it is filled through str.format with the newline-joined facts
# of one document. The template asks for a Python list of questions.
QG_PROMPT = """You are generating factual questions for evaluation.

Given a list of factual statements extracted from a biomedical article,
generate one clear, specific question for each fact.

Rules:
- Do NOT include the answer.
- Avoid yes/no questions.
- Use precise biomedical terminology.
- If a fact cannot be turned into a good question, skip it.

Output format:
A Python list of questions.

Facts:
{facts}

Questions:
"""


In [None]:
import ast

def generate_questions_from_facts(facts, max_new_tokens=256):
    """Generate questions for a document's facts with the LLM.

    Args:
        facts: Fact strings for one document.
        max_new_tokens: Generation budget forwarded to ``run_llm``.

    Returns:
        A list of question strings, each containing "?". Leading list markers
        such as "1. ", "2) " or "- " are removed, so the numbering of the
        model's reply does not become part of the question. An empty
        ``facts`` list returns ``[]`` without calling the model: with
        nothing to ask about, any generated question would be invented.
    """
    if not facts:
        return []

    prompt = QG_PROMPT.format(facts="\n".join(facts))
    raw = run_llm(prompt, max_new_tokens=max_new_tokens)

    # Leading enumeration or bullet marker followed by whitespace.
    list_marker = re.compile(r"^\s*(?:\d+[.)]|[-*\u2022])\s+")

    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Only the errors literal_eval raises on malformed input are
        # swallowed; KeyboardInterrupt and real bugs still propagate.
        parsed = None

    if isinstance(parsed, list):
        return [
            list_marker.sub("", q.strip())
            for q in parsed
            if isinstance(q, str) and "?" in q
        ]

    # Fallback: every line that contains a question mark is one question.
    questions = []
    for line in raw.split("\n"):
        line = list_marker.sub("", line.strip())
        if "?" in line:
            questions.append(line)

    return questions


In [None]:
# Directory that holds the pipeline's intermediate JSONL files.
OUT_DIR = "/content/"

In [None]:
FACTS_ENTAILED_FILE = os.path.join(OUT_DIR, "pubmed_facts_entailed.jsonl")
QUESTIONS_FILE = os.path.join(OUT_DIR, "pubmed_questions.jsonl")

# One input record per document in, one record of generated questions out.
with open(FACTS_ENTAILED_FILE, mode="r", encoding="utf-8") as fin:
    with open(QUESTIONS_FILE, mode="w", encoding="utf-8") as fout:
        for line in fin:
            obj = json.loads(line)
            sid, facts = obj["stringID"], obj["facts_entailed"]

            questions = generate_questions_from_facts(facts)

            out = {"stringID": sid, "questions": questions}
            # Flush after every record so finished documents survive an interruption.
            fout.write(json.dumps(out, ensure_ascii=False) + "\n")
            fout.flush()

            print(sid, "→", len(questions), "questions")

Question Answering


---



In [None]:
# Instruction template for the question-answering step. `{text}` and
# `{question}` are filled through str.format. The template tells the model to
# output "NOT_FOUND" when the answer is not explicitly stated in the text.
QA_PROMPT = """You are answering factual questions using ONLY the provided text.

Rules:
- Answer concisely.
- Use exact information from the text.
- If the answer is not explicitly stated, output "NOT_FOUND".
- Do NOT explain.

Text:
{text}

Question:
{question}

Answer:
"""

In [None]:
def answer_question(text, question, max_new_tokens=128):
    """Answer ``question`` from ``text`` with a single LLM call.

    Returns the reply collapsed onto one line (newlines become spaces, outer
    whitespace removed), or "NOT_FOUND" when the reply is empty.
    """
    reply = run_llm(
        QA_PROMPT.format(text=text, question=question),
        max_new_tokens=max_new_tokens,
    ).strip()
    return reply.replace("\n", " ").strip() if reply else "NOT_FOUND"

In [None]:
PUBMED_FILE = "/content/pubmed_500.json"

# Reload the sampled subset: a JSON array of dicts.
with open(PUBMED_FILE, mode="r", encoding="utf-8") as pubmed_fh:
    data_pubmed = json.load(pubmed_fh)

# Quick sanity check: record count and the keys of the first record.
print("Loaded:", len(data_pubmed))
print(data_pubmed[0].keys())


Loaded: 500
dict_keys(['stringID', 'source', 'summary'])


In [None]:
def _is_not_found_reply(answer):
    """Return True when ``answer`` is the NOT_FOUND sentinel in any common surface form."""
    # QA_PROMPT shows the sentinel inside double quotes, so a model may echo
    # it quoted, with a trailing period, with a Markdown-escaped underscore
    # ("NOT\_FOUND"), or followed by an explanation.
    cleaned = answer.replace("\\", "").strip(" \t\r\n\"'`*.").upper()
    return cleaned.startswith("NOT_FOUND")


def answer_question_sliding(
    text,
    question,
    window_chars=1200,
    step=600,
    max_new_tokens=64
):
    """Answer ``question`` from ``text`` by querying the LLM window by window.

    Windows of ``window_chars`` characters are taken every ``step``
    characters. The first reply that is neither empty nor a form of the
    NOT_FOUND sentinel is returned, collapsed onto one line.

    Args:
        text: Document to search.
        question: Question to answer.
        window_chars: Window size in characters.
        step: Distance in characters between window starts.
        max_new_tokens: Generation budget per window.

    Returns:
        The first accepted answer, or "NOT_FOUND" when no window yields one,
        including when ``text`` is empty.
    """
    for i in range(0, len(text), step):
        window = text[i:i+window_chars]

        prompt = QA_PROMPT.format(text=window, question=question)
        raw = run_llm(prompt, max_new_tokens=max_new_tokens)

        answer = raw.strip()
        # An exact comparison with "NOT_FOUND" would accept '"NOT_FOUND"' or
        # 'NOT_FOUND.' as a real answer and stop the search too early.
        if answer and not _is_not_found_reply(answer):
            return answer.replace("\n", " ").strip()

    return "NOT_FOUND"


In [None]:
# Lookup table: document ID -> {"source": article text, "summary": abstract}.
id_to_doc = dict(
    (doc["stringID"], {"source": doc["source"], "summary": doc["summary"]})
    for doc in data_pubmed
)

In [None]:
# Directory that holds the pipeline's input and output files.
OUT_DIR = "/content"

# Questions produced by the question-generation step.
QUESTIONS_FILE = os.path.join(OUT_DIR, "pubmed_questions.jsonl")

# Outputs of the QA / scoring steps. They are anchored to OUT_DIR like every
# other pipeline artifact, so they no longer depend on the kernel's current
# working directory.
ASKQE_QA_FILE = os.path.join(OUT_DIR, "askqe_qa_pairs.jsonl")
ASKQE_SCORES_FILE = os.path.join(OUT_DIR, "askqe_scores.jsonl")



In [None]:
# Load the generated questions: one JSON record per document.
questions_data = []
with open(QUESTIONS_FILE, "r", encoding="utf-8") as f:
    for line in f:
        # Tolerate blank lines in the JSONL input.
        if line.strip():
            questions_data.append(json.loads(line))

print("Loaded questions for", len(questions_data), "documents")

# File 1: one record per document holding the questions plus the answers
# obtained from the source article and from the summary.
with open(ASKQE_QA_FILE, "w", encoding="utf-8") as f_qa:
    for obj in questions_data:
        sid = obj["stringID"]
        questions = obj["questions"]

        # Skip question records that have no matching document.
        if sid not in id_to_doc:
            continue

        source = id_to_doc[sid]["source"]
        summary = id_to_doc[sid]["summary"]

        questions_list = []
        answers_source_list = []
        answers_summary_list = []

        for q in questions:
            # Answer the same question from the article and from the summary.
            a_src = answer_question_sliding(source, q)
            a_sum = answer_question_sliding(summary, q)

            questions_list.append(q)
            answers_source_list.append(a_src)
            answers_summary_list.append(a_sum)

            torch.cuda.empty_cache()

        qa_line = json.dumps({
            "stringID": sid,
            "questions": questions_list,
            "answers_source": answers_source_list,
            "answers_summary": answers_summary_list
        }, ensure_ascii=False) + "\n"

        f_qa.write(qa_line)
        # Flush after every document, as the other writer loops do, so
        # finished results survive an interrupted session.
        f_qa.flush()
        print(qa_line, end='')


Loaded questions for 16 documents
{"stringID": "pubmed_000196", "questions": ["1. Which primer sets were used for amplifying DNA fragments for the construction of the anea disruption cassette?", "2. How were the DNA fragments for the anea disruption cassette amplified using the primer sets anea - a1/-a2, anea - b1/-b2, and argb - for/-rev?", "3. Which nested PCR primer set was used for amplifying the complete anea disruption cassette?", "4. Is cop essential for the viability of Aspergillus nidulans?", "5. What roles do copi proteins play in responses to endoplasmic reticulum stress and thermal stress in yeast?", "6. Were any detectable changes observed in the cop deletion strain of Aspergillus nidulans upon treatment with calcofluor white, congo red, caspofungin, tunicamycin, terbinafine, fludioxonil, farnesol, and EGTA?"], "answers_source": ["The primer sets used for amplifying DNA fragments for the construction of the anea disruption cassette were anea - a1/-a2, anea - b1/-b2, and ar