In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "openai/gpt-oss-20b" 
# MODEL_ID = "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8" 

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, dtype=torch.bfloat16 if torch.cuda.is_available() else None, device_map="auto"
)

MXFP4 quantization requires triton >= 3.4.0 and kernels installed, we will default to dequantizing the model to bf16


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.80G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.17G [00:00<?, ?B/s]

model-00000-of-00002.safetensors:   0%|          | 0.00/4.79G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/177 [00:00<?, ?B/s]

In [2]:
import os 
import re
import json
import os
import random

pattern = re.compile(r"\[\*\*\s*(.*?)\s*\*\*\]")

for filename in os.listdir("data/entities_txt"):
    with open(os.path.join("data/entities_txt", filename), 'r') as file:
        content = file.read()
        
        pattern_brat = re.compile(r'^[^\t]+\t[^\t]+\s+\d+\s+\d+\t(.+)$')

        original_entities = []
        fictive_entities = []
        with open(os.path.join("data/original_ann", filename), 'r') as file:
                lines = file.readlines()
                for line in lines:
                    match = pattern_brat.match(line.strip())
                    if match:
                        original_entities.append(match.group(1))
                        
        with open(os.path.join("data/fictive_ann", filename), 'r') as file:
                lines = file.readlines()
                for line in lines:
                    match = pattern_brat.match(line.strip())
                    if match:
                        fictive_entities.append(match.group(1))

        for m in pattern.finditer(content):
            start, end = m.span()
            
            LETTER = content[:start] + "[** ENTITY_SLOT **]" + content[end:]
            
            # Pull the next pair
            original_entity = original_entities.pop(0)
            fictive_entity   = fictive_entities.pop(0)

            # Randomly decide which option slot gets the real original
            if random.random() < 0.5:
                OPTION_A, OPTION_B = original_entity, fictive_entity
                truth_label, truth_text = "A", original_entity
            else:
                OPTION_A, OPTION_B = fictive_entity, original_entity
                truth_label, truth_text = "B", original_entity
            
            SYSTEM_PROMPT = """2AFC LINKAGE EVALUATOR
            You evaluate a two-alternative forced-choice (2AFC) linkage task on doctor letters.

            Input:
            • One letter with a single masked slot: [**ENTITY_SLOT**]
            • Two candidates labeled A) and B)

            Task: Choose the candidate that most likely belongs into masked slot given the context of the letter.

            Output (STRICT):
            You have to return exactly one capital letter — A or B — with nothing else (no spaces, punctuation, or newlines).
            You can not Answer with None.
            """

            # ---- Beispieltext zur Verarbeitung ----
            INPUT_TEXT = f""""
            DOCTOR LETTER:
                {LETTER}
            OPTIONS:
                A) {OPTION_A}
                B) {OPTION_B}
            """
                
            def build_input(tokenizer, INPUT_TEXT: str):
                messages = [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": INPUT_TEXT},
                ]
                
                if hasattr(tokenizer, "apply_chat_template"):
                    return tokenizer.apply_chat_template(
                        messages, add_generation_prompt=True, return_tensors="pt", reasoning_effort="low"
                    )
                
            # ---- Hauptverarbeitung ----
            input_ids = build_input(tokenizer, INPUT_TEXT).to(model.device)


            gen_kwargs = dict(
                max_new_tokens=1000,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
            )

            with torch.no_grad():
                out = model.generate(input_ids, **gen_kwargs)

            # Slice off the prompt part
            gen_ids = out[0][input_ids.shape[-1]:]
            raw = tokenizer.decode(gen_ids, skip_special_tokens=False).strip()

            # Generic: capture ANY channel blocks
            def grab_block(text: str, kind: str):
                start_tag = f"<|channel|>{kind}<|message|>"
                i = text.rfind(start_tag)
                if i == -1:
                    return None
                i += len(start_tag)
                # find nearest terminator after i
                j_end = text.find("<|end|>", i)
                j_ret = text.find("<|return|>", i)
                j_candidates = [x for x in (j_end, j_ret) if x != -1]
                j = min(j_candidates) if j_candidates else len(text)
                return text[i:j].strip()

            analysis_text = grab_block(raw, "analysis")
            final_text = grab_block(raw, "final")

            output_data = {
                "letter": LETTER,
                "option_a": OPTION_A,
                "option_b": OPTION_B,
                "resoning": analysis_text,
                "choice": final_text,
                "truth":truth_label,
                "is_correct": (final_text == truth_label)
            }

            output_file = os.path.join("LLM_2AFC_output", f"result_{os.path.splitext(filename)[0]}_{m.start()}.json")

            with open(output_file, 'w') as f:
                json.dump(output_data, f, indent=4)

In [13]:
import os
import json

correct_count = 0
incorrect_count = 0
same = 0

directory = "LLM_2AFC_output"
for filename in os.listdir(directory):
    if filename.endswith(".json"):
        with open(os.path.join(directory, filename), 'r') as f:
            data = json.load(f)
            op_a = data["option_a"]
            ob_b = data["option_b"]
            same += op_a == ob_b
            if op_a == ob_b:
                pass
                #print(op_a, ob_b)
            if data["is_correct"]:
                correct_count += 1
            else:
                incorrect_count += 1

print(f"Correct predictions: {correct_count}")
print(f"Incorrect predictions: {incorrect_count}")
print(f"Accuracy: {correct_count / (correct_count + incorrect_count):.2%}")
print(f"Same option: {same:}")

Correct predictions: 764
Incorrect predictions: 672
Accuracy: 53.20%
Same option: 291


# 2AFC stats from n (trials) and k (correct)
- Accuracy
- Standard error
- Wilson 95% CI
- 90% Wilson CI (for equivalence checks)
- Binomial test vs 50%:

In [5]:

import math

def accuracy(n, k):
    return k / n

def se_proportion(n, k):
    p = k / n
    return math.sqrt(p * (1 - p) / n)

def wilson_ci(n, k, alpha=0.05):
    # Wilson score interval
    z = {0.10: 1.2815515655446004, 0.05: 1.6448536269514722, 0.025: 1.959963984540054}[alpha/2]
    p = k / n
    denom = 1 + (z*z)/n
    center = (p + (z*z)/(2*n)) / denom
    half = (z * math.sqrt((p*(1-p)/n) + (z*z)/(4*n*n))) / denom
    return (center - half, center + half)

def normal_binom_test_vs_half(n, k):
    """
    Continuity-corrected normal approx for H0: p = 0.5.
    Returns (one_sided_p_gt, two_sided_p).
    """
    mu = n / 2
    sigma = math.sqrt(n / 4)
    # For upper-tail (k >= observed), use k - 0.5 continuity correction
    z = ((k - 0.5) - mu) / sigma
    # Survival function of standard normal using erf
    def norm_sf(z):
        cdf = 0.5 * (1 + math.erf(z / math.sqrt(2)))
        return 1 - cdf
    one_sided = norm_sf(z)              # P(K >= k | H0)
    # two-sided: double the smaller tail
    cdf_lower = 1 - one_sided           # P(K <= k-1 | H0, with CC)
    two_sided = 2 * min(cdf_lower, one_sided)
    return one_sided, two_sided

def try_exact_binom(n, k):
    """
    Try to compute exact binomial test & exact (Clopper–Pearson) CI via SciPy.
    """
    from scipy.stats import binomtest, beta
    # Exact test vs 0.5
    exact_test_two = binomtest(k, n, p=0.5, alternative='two-sided').pvalue
    exact_test_one  = binomtest(k, n, p=0.5, alternative='greater').pvalue
    # Exact 95% CI (Clopper–Pearson)
    alpha = 0.05
    if k == 0:
        cp_lo = 0.0
    else:
        cp_lo = beta.ppf(alpha/2, k, n - k + 1)
    if k == n:
        cp_hi = 1.0
    else:
        cp_hi = beta.ppf(1 - alpha/2, k + 1, n - k)
    # Exact 90% CI
    alpha90 = 0.10
    if k == 0:
        cp_lo_90 = 0.0
    else:
        cp_lo_90 = beta.ppf(alpha90/2, k, n - k + 1)
    if k == n:
        cp_hi_90 = 1.0
    else:
        cp_hi_90 = beta.ppf(1 - alpha90/2, k + 1, n - k)
    return {
        "p_two_sided": exact_test_two,
        "p_one_sided": exact_test_one,
        "cp95": (cp_lo, cp_hi),
        "cp90": (cp_lo_90, cp_hi_90),
    }

def summarize(n, k, equivalence_margin=0.05):
    p_hat = accuracy(n, k)
    se = se_proportion(n, k)

    # Wilson 95% and 90%
    wilson95 = wilson_ci(n, k, alpha=0.05)
    wilson90 = wilson_ci(n, k, alpha=0.10)

    # Tests vs 50%
    exact = try_exact_binom(n, k)
    if exact:
        p_two = exact["p_two_sided"]
        p_one = exact["p_one_sided"]
        ci95 = exact["cp95"]
        ci90 = exact["cp90"]
        ci_label = "Exact Clopper–Pearson"
        ci90_label = "Exact 90% (Clopper–Pearson)"
    else:
        p_one, p_two = normal_binom_test_vs_half(n, k)
        ci95 = wilson95
        ci90 = wilson90
        ci_label = "Wilson (approx.)"
        ci90_label = "Wilson 90% (approx.)"

    # Equivalence check: is the entire 90% CI inside [0.5 ± margin]?
    eq_low, eq_high = 0.5 - equivalence_margin, 0.5 + equivalence_margin
    equivalent_to_chance = (ci90[0] >= eq_low) and (ci90[1] <= eq_high)

    # Pretty print
    def pct(x): return f"{100*x:.2f}%"
    print(f"n = {n}, k = {k}")
    print(f"Accuracy (p̂)      : {pct(p_hat)}")
    print(f"Std. error         : {se:.6f}  ({100*se:.3f} percentage points)")
    print(f"95% CI ({ci_label}): [{pct(ci95[0])}, {pct(ci95[1])}]")
    print(f"Binomial test vs 50%: two-sided p = {p_two:.3f} | one-sided p(>0.5) = {p_one:.3f}")
    print(f"90% CI ({ci90_label}): [{pct(ci90[0])}, {pct(ci90[1])}]")
    print(f"Equivalence margin  : ±{100*equivalence_margin:.1f} pp → "
          f"window [{pct(eq_low)}, {pct(eq_high)}]")
    print(f"Equivalence result  : "
          f"{'PASS (inside window)' if equivalent_to_chance else 'FAIL (touches/exceeds window)'}")

summarize(n=1436, k=746, equivalence_margin=0.05)


n = 1436, k = 746
Accuracy (p̂)      : 51.95%
Std. error         : 0.013184  (1.318 percentage points)
95% CI (Exact Clopper–Pearson): [49.33%, 54.56%]
Binomial test vs 50%: two-sided p = 0.147 | one-sided p(>0.5) = 0.073
90% CI (Exact 90% (Clopper–Pearson)): [49.74%, 54.15%]
Equivalence margin  : ±5.0 pp → window [45.00%, 55.00%]
Equivalence result  : PASS (inside window)
