# Generate MCQs using MedAlpaca

In [1]:
import json
import re
from IPython.display import clear_output
from pathlib import Path
from tqdm import tqdm
from google.colab import files

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# # Upload JSONL file with question and answer fields
# uploaded = files.upload()
# file_path = list(uploaded.keys())[0]

# Get input from Google Drive
file_path = Path("/content/drive/MyDrive/medquad_sampled.jsonl")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# checkpoint = "microsoft/phi-2"
# checkpoint = "stanford-crfm/BioMedLM"
checkpoint = "medalpaca/medalpaca-7b"
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Load and parse JSONL file
with open(file_path, "r") as f:
    qa_data = [json.loads(line) for line in f]
    print(f"Loaded {len(qa_data)} QA pairs from {file_path}")

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16).to(device)
model.eval()
print('Tokenizer and model loaded.')

Loaded 3000 QA pairs from /content/drive/MyDrive/medquad_sampled.jsonl


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message
You are using the default legacy behav

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Tokenizer and model loaded.


In [4]:
torch.backends.cuda.matmul.allow_tf32 = True

# Check if output file already exists
# Set output path in Google Drive
drive_path = Path("/content/drive/MyDrive/generated_mcqs.jsonl")
# Check if output file already exists
output_file = drive_path
if output_file.exists():
    response = input(f"File {output_file} exists. Delete it? (y/n): ")
    if response.lower().startswith("y"):
        output_file.unlink()
    else:
        raise RuntimeError("Aborting to avoid overwriting the file.")

# Few-shot example to encourage structure
few_shot_example = """
Question: What is asthma?
Answer: Asthma is a lung condition that causes airway inflammation and breathing difficulty.
<choices>
A. A skin condition
B. A viral infection
C. A lung disease that causes airway inflammation
D. A type of heart failure
</choices>
<reason>Asthma is a lung condition that causes airway inflammation and breathing difficulty.</reason>
<answer>C</answer>
"""

# Format QA into prompt
def format_prompt(question, answer):
    return f"""
Create layperson multiple-choice answer choices and reasoning for the given medical Q&A pair.
Include 4 total choices: one correct and three plausible, but incorrect. Follow the XML-style structure below:
{few_shot_example}
Now it's your turn:
Question: {question}
Answer: {answer}
"""

tokenizer.padding_side = "right"

# Batch-generate MCQs
BATCH_SIZE = 16
NUM_RETURN_SEQUENCES = 3
num_prompts = 0
num_success = 0
# for i in tqdm(range(0, BATCH_SIZE*4+1, BATCH_SIZE),
#               desc=f"LLM Running on Micro Batches {BATCH_SIZE}"):  # DEBUG
with open(output_file, "a") as f:
    for i in tqdm(range(0, len(qa_data), BATCH_SIZE),
                        desc=f"LLM Running on Micro Batches {BATCH_SIZE}"):
        batch = qa_data[i:i + BATCH_SIZE]
        prompts = [format_prompt(q["question"], q["answer"]) for q in batch]

        inputs = tokenizer(prompts, return_tensors="pt",
                          truncation=True,
                          max_length=250,
                          padding=True).to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=250,
                temperature=0.7,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=NUM_RETURN_SEQUENCES,
            )

        prompt_lengths = [len(x) for x in inputs["input_ids"]]

        for j in range(len(batch)):
            num_prompts += 1
            for k in range(NUM_RETURN_SEQUENCES):
                idx = j * NUM_RETURN_SEQUENCES + k
                generated_tokens = outputs[idx][prompt_lengths[j]:]  # Strip prompt
                completion = tokenizer.decode(generated_tokens, clean_up_tokenization_spaces=False)

                mcq_choices = re.search(r"<choices>(.*?)</choices>", completion, re.DOTALL)
                mcq_reason = re.search(r"<reason>(.*?)</reason>", completion, re.DOTALL)
                mcq_answer = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
                four_choices = 'A. ' in completion and 'B. ' in completion and 'C. ' in completion and 'D. ' in completion and 'E. ' not in completion

                if not all([mcq_choices, mcq_reason, mcq_answer, four_choices]):
                    continue
                else:
                    # Add to output json
                    entry = {
                          "question": batch[j]["question"],
                          "mcq_choices": mcq_choices.group(1).strip(),
                          "mcq_reason": mcq_reason.group(1).strip(),
                          "mcq_answer": mcq_answer.group(1).strip(),
                    }
                    f.write(json.dumps(entry) + "\n")
                    f.flush()  # Save immediately
                    num_success += 1
                    break
        clear_output(wait=True)
        print(f"\n{num_success} / {num_prompts} ({num_success/num_prompts * 100:.2f}%) generated MCQs saved to {output_file}")
clear_output(wait=True)
print(f"\nDone! {num_success} / {num_prompts} ({num_success/num_prompts * 100:.2f}%) generated MCQs saved to {output_file}")


Done! 825 / 3000 (27.50%) generated MCQs saved to /content/drive/MyDrive/generated_mcqs.jsonl
