In [1]:
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM
)
import pandas as pd
import os
from medgemma_wrapper import MedGemmaQAWrapper

from dotenv import load_dotenv
load_dotenv()

# load_dotenv() above is expected to have copied HF_TOKEN from a local .env file
# into the process environment already, so this step only validates it.
# os.getenv returns None when the variable is missing, and assigning None into
# os.environ raises an opaque TypeError, so guard the assignment.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    os.environ["HF_TOKEN"] = hf_token
else:
    # NOTE(review): a cached `huggingface-cli login` may still work without the
    # variable, so warn instead of aborting the notebook here.
    print("WARNING: HF_TOKEN is not set (checked environment and .env); "
          "model downloads that require authentication may fail.")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Choose the device *before* touching any torch.cuda.* query:
# torch.cuda.get_device_properties initialises CUDA and raises on a machine
# without a usable GPU, which previously made the CPU fallback unreachable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    # Memory figures are for GPU index 0, reported in decimal gigabytes (bytes / 1e9).
    print(f"Total VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    print(f"Cached: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")
else:
    print("CUDA is not available; skipping GPU memory report")

print(f"Using device: {device}")

Total VRAM: 34.19 GB
Allocated: 0.00 GB
Cached: 0.00 GB
Using device: cuda


In [3]:
# Load the yes/no question set with the fastparquet backend and replace the
# stored index with a fresh 0..n-1 range so rows can be addressed by position.
yn_df = (
    pd.read_parquet("yn_df.parquet", engine="fastparquet")
    .reset_index(drop=True)
)
yn_df.head(n=3)

Unnamed: 0,dataset_name,id_in_dataset,question,options,answer_label,question_type,prompt_text
0,pubmedqa,0,Do mitochondria play a role in remodelling lac...,Answer Choices:\nA. Yes\nB. No,yes,Y/N,Question:\nDo mitochondria play a role in remo...
1,pubmedqa,1,Landolt C and snellen e acuity: differences in...,Answer Choices:\nA. Yes\nB. No,no,Y/N,Question:\nLandolt C and snellen e acuity: dif...
2,pubmedqa,4,Can tailored interventions increase mammograph...,Answer Choices:\nA. Yes\nB. No,yes,Y/N,Question:\nCan tailored interventions increase...


In [4]:
# Load the multiple-choice question set with the fastparquet backend and
# replace the stored index with a fresh 0..n-1 range.
mcq_df = (
    pd.read_parquet("mcq_df.parquet", engine="fastparquet")
    .reset_index(drop=True)
)
mcq_df.head(n=3)

Unnamed: 0,dataset_name,id_in_dataset,question,options,answer_label,question_type,prompt_text
0,medmcqa,7131,Urogenital Diaphragm is made up of the followi...,Answer Choices:\nA. Deep transverse Perineus\n...,C,MCQ,Question:\nUrogenital Diaphragm is made up of ...
1,medmcqa,7133,Child with Type I Diabetes. What is the advise...,Answer Choices:\nA. After 5 years\nB. After 2 ...,A,MCQ,Question:\nChild with Type I Diabetes. What is...
2,medmcqa,7134,Most sensitive test for H pylori is-,Answer Choices:\nA. Fecal antigen test\nB. Bio...,B,MCQ,Question:\nMost sensitive test for H pylori is...


In [5]:
# Peek at the prompt text stored for the row labelled 0 of the yes/no set.
sample_prompt = yn_df.at[0, "prompt_text"]
print(sample_prompt)

Question:
Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?

Answer Choices:
A. Yes
B. No


In [6]:
# Model checkpoint identifier, kept in one named constant so it is easy to swap.
MODEL_ID = "google/medgemma-4b-it"

# Use the device detected in the earlier cell instead of hard-coding "cuda",
# so the notebook stays consistent with the reported "Using device" line.
# str(torch.device('cuda')) is "cuda", i.e. the same value as before on a GPU box.
# NOTE(review): assumes MedGemmaQAWrapper accepts "cpu" as well — confirm.
model = MedGemmaQAWrapper(MODEL_ID, device=str(device))

Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.14s/it]


In [7]:
# Smoke test: run the model on the first row of each dataset and print the
# prompt, the free-text output, the extracted answer with its confidence, and
# whether that answer matches the ground truth.

print("=" * 80)
print("YES/NO SAMPLE (iloc=0)")
print("=" * 80)

model.set_task("yn")
yn_row = yn_df.iloc[0]  # look the row up once instead of once per column
yn_prompt = yn_row["prompt_text"]
yn_gt = yn_row["answer_label"]

# Get full output with rationale
yn_full = model.generate(yn_prompt)

# Get confidence scores
yn_answer, yn_conf, yn_probs = model.generate_with_confidence(yn_prompt)

# The extracted answer is an option letter ('A'/'B') while the ground-truth
# label for this dataset is 'yes'/'no'. Comparing them directly can never
# match, so translate the letter first (same mapping as the evaluation cell).
yn_letter_to_label = {'A': 'yes', 'B': 'no'}
yn_is_correct = yn_letter_to_label.get(yn_answer) == str(yn_gt).strip().lower()

print(f"PROMPT:\n{yn_prompt}\n")
print(f"GROUND TRUTH: {yn_gt}")
print("-" * 80)
print(f"FULL OUTPUT:\n{yn_full}")
print("-" * 80)
print(f"ANSWER: {yn_answer}")
print(f"CONFIDENCE: {yn_conf:.3f} ({yn_conf*100:.1f}%)")
print(f"PROBABILITIES: P(A)={yn_probs['A']:.3f}, P(B)={yn_probs['B']:.3f}")
print(f"CORRECT: {'✓' if yn_is_correct else '✗'}")
print("=" * 80)
print()

print("=" * 80)
print("MCQ SAMPLE (iloc=0)")
print("=" * 80)

model.set_task("mcq")
mcq_row = mcq_df.iloc[0]
mcq_prompt = mcq_row["prompt_text"]
mcq_gt = mcq_row["answer_label"]

# Get full output with rationale
mcq_full = model.generate(mcq_prompt)

# Get confidence scores
mcq_answer, mcq_conf, mcq_probs = model.generate_with_confidence(mcq_prompt)

print(f"PROMPT:\n{mcq_prompt}\n")
print(f"GROUND TRUTH: {mcq_gt}")
print("-" * 80)
print(f"FULL OUTPUT:\n{mcq_full}")
print("-" * 80)
print(f"ANSWER: {mcq_answer}")
print(f"CONFIDENCE: {mcq_conf:.3f} ({mcq_conf*100:.1f}%)")
print("PROBABILITIES:")
for option, prob in mcq_probs.items():
    # Arrow marks the option the model selected.
    marker = "←" if option == mcq_answer else " "
    print(f"  {option}: {prob:.3f} ({prob*100:.1f}%) {marker}")
# MCQ ground truth is already an option letter, so a direct comparison is valid.
print(f"CORRECT: {'✓' if mcq_answer == mcq_gt else '✗'}")
print("=" * 80)

YES/NO SAMPLE (iloc=0)
PROMPT:
Question:
Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?

Answer Choices:
A. Yes
B. No

GROUND TRUTH: yes
--------------------------------------------------------------------------------
FULL OUTPUT:
Answer: A
Rationale: Mitochondria are essential organelles involved in cellular respiration and energy production. During programmed cell death (PCD), also known as apoptosis, mitochondria play a crucial role in initiating and executing the process. They release cytochrome c, a protein that activates caspases, a family of proteases that trigger the dismantling of the cell. Furthermore, mitochondria can also directly contribute to cell death by releasing reactive oxygen species (ROS) and other damaging molecules. Therefore, mitochondria are definitely involved in the remodelling of lace plant leaves during programmed cell death.
--------------------------------------------------------------------------------
ANSWER: 

In [8]:
# Qualitative check on the yes/no set: for the rows labelled 0-9, print the
# prompt, the model's generated text, and the ground-truth label.
model.set_task("yn")
for i in range(10):
    prompt = yn_df.at[i, "prompt_text"]
    ground_truth = yn_df.at[i, "answer_label"]
    out = model.generate(prompt)
    # One print call with newline separators emits the same text as six calls.
    print(
        prompt,
        "*" * 80,
        out,
        "*" * 80,
        f"Ground truth: {ground_truth}",
        "=" * 80,
        sep="\n",
    )

Question:
Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?

Answer Choices:
A. Yes
B. No
********************************************************************************
Answer: A
Rationale: Mitochondria are essential organelles involved in cellular respiration and energy production. During programmed cell death (PCD), also known as apoptosis, mitochondria play a crucial role in initiating and executing the process. They release cytochrome c, a protein that activates caspases, a family of proteases that trigger the dismantling of the cell. Furthermore, mitochondria can also directly contribute to cell death by releasing reactive oxygen species (ROS) and other damaging molecules. Therefore, mitochondria are definitely involved in the remodelling of lace plant leaves during programmed cell death.
********************************************************************************
Ground truth: yes
Question:
Landolt C and snellen e acuity: differences

In [9]:
# Qualitative check on the multiple-choice set: for the rows labelled 0-9,
# print the prompt, the model's generated text, and the ground-truth letter.
model.set_task("mcq")
for i in range(10):
    prompt = mcq_df.at[i, "prompt_text"]
    ground_truth = mcq_df.at[i, "answer_label"]
    out = model.generate(prompt)
    # One print call with newline separators emits the same text as six calls.
    print(
        prompt,
        "*" * 80,
        out,
        "*" * 80,
        f"Ground truth: {ground_truth}",
        "=" * 80,
        sep="\n",
    )

Question:
Urogenital Diaphragm is made up of the following, except:

Answer Choices:
A. Deep transverse Perineus
B. Perinial membrane
C. Colle's fascia
D. Sphincter Urethrae
********************************************************************************
Answer: A
Rationale: The urogenital diaphragm is a muscular structure that forms the floor of the pelvis. It is composed of the deep transverse perineal muscle, the perineal membrane, and the superficial transverse perineal muscle. The sphincter urethrae is a muscle that surrounds the urethra and is part of the pelvic floor, but it is not considered part of the urogenital diaphragm.
********************************************************************************
Ground truth: C
Question:
Child with Type I Diabetes. What is the advised time for fundus examinations from the time of diagnosis?

Answer Choices:
A. After 5 years
B. After 2 years
C. After 10 years
D. At the time of diagnosis
**************************************************

In [29]:
# Quantitative check on the yes/no set: score the model's extracted answer
# against the ground truth and report accuracy and mean confidence.
model.set_task("yn")

# Define mapping (model answers with an option letter, labels are yes/no)
answer_map = {'A': 'yes', 'B': 'no'}
results = []

# Evaluate at most 100 rows, but never more than the frame holds, so a smaller
# dataset does not abort the run with a KeyError part-way through.
n_eval = min(100, len(yn_df))

for i in range(n_eval):
    prompt = yn_df.loc[i, "prompt_text"]
    # normalize to lowercase; str()/strip() tolerate padded or non-string labels
    ground_truth = str(yn_df.loc[i, "answer_label"]).strip().lower()

    answer, conf, probs = model.generate_with_confidence(prompt)
    # Convert A/B to yes/no. An unexpected answer is scored as wrong (it cannot
    # equal 'yes'/'no') instead of raising KeyError and losing all results so far.
    answer_text = answer_map.get(answer, f"<unmapped {answer!r}>")
    correct = (answer_text == ground_truth)

    results.append({
        'answer': answer,
        'answer_text': answer_text,
        'confidence': conf,
        'correct': correct,
        'probs': probs
    })

    # Only show wrong answers or low confidence
    if not correct:
        print(f"Q{i+1:2d}: {answer_text} → {ground_truth} ✗  \t[conf: {conf:.1%}]")
    elif conf < 0.6:
        print(f"Q{i+1:2d}: {answer_text} ✓ (LOW) \t[conf: {conf:.1%}]")

print(f"\n{'='*50}")
if results:
    avg_conf = sum(r['confidence'] for r in results) / len(results)
    accuracy = sum(r['correct'] for r in results) / len(results)
    print(f"Results: {sum(r['correct'] for r in results)}/{len(results)} correct ({accuracy:.1%})")
    print(f"Avg confidence: {avg_conf:.1%}")
else:
    # Empty frame: avoid dividing by zero in the summary.
    print("Results: no rows evaluated")

Q 2: no ✓ (LOW) 	[conf: 50.1%]
Q 5: yes → no ✗  	[conf: 87.5%]
Q11: no → yes ✗  	[conf: 96.2%]
Q17: no → yes ✗  	[conf: 99.8%]
Q24: no → yes ✗  	[conf: 85.1%]
Q25: yes ✓ (LOW) 	[conf: 58.6%]
Q27: no → yes ✗  	[conf: 93.5%]
Q33: yes → no ✗  	[conf: 60.9%]
Q35: no → yes ✗  	[conf: 97.9%]
Q51: no → yes ✗  	[conf: 79.4%]
Q53: yes → no ✗  	[conf: 87.0%]
Q60: no → yes ✗  	[conf: 70.0%]
Q61: no → yes ✗  	[conf: 98.7%]
Q66: no → yes ✗  	[conf: 96.4%]
Q70: yes → no ✗  	[conf: 88.1%]
Q74: no → yes ✗  	[conf: 91.5%]
Q78: no → yes ✗  	[conf: 72.1%]
Q79: yes → no ✗  	[conf: 90.7%]
Q81: yes → no ✗  	[conf: 98.8%]
Q82: yes → no ✗  	[conf: 52.1%]
Q88: yes → no ✗  	[conf: 95.6%]
Q89: no → yes ✗  	[conf: 99.3%]
Q90: no → yes ✗  	[conf: 98.2%]
Q91: yes → no ✗  	[conf: 99.5%]
Q93: no → yes ✗  	[conf: 99.2%]
Q99: no → yes ✗  	[conf: 73.1%]

Results: 76/100 correct (76.0%)
Avg confidence: 93.1%


In [28]:
# Quantitative check on the multiple-choice set: score the extracted answer
# letter against the ground-truth letter for the rows labelled 0-99, then
# report accuracy and mean confidence.
model.set_task("mcq")
results = []

for i in range(100):
    prompt, ground_truth = mcq_df.at[i, "prompt_text"], mcq_df.at[i, "answer_label"]

    answer, conf, probs = model.generate_with_confidence(prompt)
    correct = answer == ground_truth
    results.append(dict(answer=answer, confidence=conf, correct=correct, probs=probs))

    # Print only the interesting rows: correct-but-unsure answers and misses.
    if correct and conf < 0.6:
        print(f"Q{i+1:2d}: {answer} ✓ (LOW) \t[conf: {conf:.1%}]")
    elif not correct:
        print(f"Q{i+1:2d}: {answer} → {ground_truth} ✗  \t[conf: {conf:.1%}]")

n_total = len(results)
n_correct = sum(r['correct'] for r in results)
avg_conf = sum(r['confidence'] for r in results) / n_total
accuracy = n_correct / n_total
print(f"\n{'='*50}")
print(f"Results: {n_correct}/{n_total} correct ({accuracy:.1%})")
print(f"Avg confidence: {avg_conf:.1%}")

Q 1: A → C ✗  	[conf: 91.0%]
Q 2: B → A ✗  	[conf: 99.0%]
Q 3: D → B ✗  	[conf: 96.2%]
Q 4: B → D ✗  	[conf: 85.0%]
Q 6: C → B ✗  	[conf: 95.9%]
Q10: B → D ✗  	[conf: 98.5%]
Q12: B → C ✗  	[conf: 98.8%]
Q13: B ✓ (LOW) 	[conf: 48.0%]
Q15: A → B ✗  	[conf: 85.4%]
Q17: B → D ✗  	[conf: 79.9%]
Q18: B → D ✗  	[conf: 67.3%]
Q19: A ✓ (LOW) 	[conf: 51.0%]
Q20: A → D ✗  	[conf: 97.1%]
Q24: B → C ✗  	[conf: 70.2%]
Q30: C → D ✗  	[conf: 80.1%]
Q32: A → B ✗  	[conf: 69.4%]
Q37: C → A ✗  	[conf: 69.7%]
Q47: D ✓ (LOW) 	[conf: 47.8%]
Q53: D → C ✗  	[conf: 58.2%]
Q60: D → B ✗  	[conf: 99.8%]
Q63: B → A ✗  	[conf: 99.6%]
Q64: A → D ✗  	[conf: 83.0%]
Q66: A → B ✗  	[conf: 98.0%]
Q68: C → A ✗  	[conf: 86.8%]
Q75: D ✓ (LOW) 	[conf: 46.5%]
Q76: C → B ✗  	[conf: 99.0%]
Q79: B → A ✗  	[conf: 61.6%]
Q98: B ✓ (LOW) 	[conf: 51.3%]
Q99: B → A ✗  	[conf: 96.0%]
Q100: B → C ✗  	[conf: 99.8%]

Results: 75/100 correct (75.0%)
Avg confidence: 91.4%
