In [None]:
#!/usr/bin/env python3
from __future__ import annotations

import pandas as pd
from typing import Literal, Dict, Any

Feature = Literal["ctg", "obj", "loc"]
InterOrder = Literal["ABAB", "ABBA"]
TaskKind = Literal["interdms", "1back"]


def _truth_to_resp_code(is_match: bool) -> int:
    """Boolean match -> expected response code (2=yes, 3=no)."""
    return 2 if is_match else 3


def _is_missing(x) -> bool:
    return pd.isna(x)


# =========================
# InterDMS scoring
# =========================
def score_interdms_row(
    row: pd.Series,
    feature: Feature,
    order: InterOrder,
    resp_cols: tuple[str, str] = ("response_2", "response_3"),
) -> Dict[str, Any]:
    """
    InterDMS has 4 stimuli features per row (feature1..feature4) and 2 scored responses.

    ABAB:
      ans1: compare 3 vs 1 (A2 vs A1)
      ans2: compare 4 vs 2 (B2 vs B1)

    ABBA:
      ans1: compare 3 vs 2 (B2 vs B1)
      ans2: compare 4 vs 1 (A2 vs A1)

    Returns flat dict:
      ans1, ans2, trial_correct,
      expected_ans1, expected_ans2,
      ans1_comp, ans2_comp
    """
    # Pull features
    f1 = row[f"{feature}1"]
    f2 = row[f"{feature}2"]
    f3 = row[f"{feature}3"]
    f4 = row[f"{feature}4"]

    # Which comparisons are performed depends on order
    if order == "ABAB":
        left1, right1 = 3, 1
        left2, right2 = 4, 2
    elif order == "ABBA":
        left1, right1 = 3, 2
        left2, right2 = 4, 1
    else:
        raise ValueError(f"Unknown InterDMS order: {order}")

    # Compute match truth values
    vals = {1: f1, 2: f2, 3: f3, 4: f4}
    match1 = bool(vals[left1] == vals[right1])
    match2 = bool(vals[left2] == vals[right2])

    exp1 = _truth_to_resp_code(match1)
    exp2 = _truth_to_resp_code(match2)

    # Pull responses
    r1 = row.get(resp_cols[0], pd.NA)
    r2 = row.get(resp_cols[1], pd.NA)

    # Score answers (False if missing)
    ans1 = (not _is_missing(r1)) and (int(r1) == exp1)
    ans2 = (not _is_missing(r2)) and (int(r2) == exp2)

    # trial_correct requires all answers present + correct
    trial_correct = bool(ans1 and ans2 and (not _is_missing(r1)) and (not _is_missing(r2)))

    ans1_comp = f"{feature}{left1}=={feature}{right1}"
    ans2_comp = f"{feature}{left2}=={feature}{right2}"

    return {
        "ans1": bool(ans1),
        "ans2": bool(ans2),
        "trial_correct": bool(trial_correct),
        "expected_ans1": int(exp1),
        "expected_ans2": int(exp2),
        "ans1_comp": ans1_comp,
        "ans2_comp": ans2_comp,
    }


def score_interdms_df(
    df: pd.DataFrame,
    feature: Feature,
    order: InterOrder,
    resp_cols: tuple[str, str] = ("response_2", "response_3"),
) -> pd.DataFrame:
    scored = df.apply(lambda r: score_interdms_row(r, feature=feature, order=order, resp_cols=resp_cols), axis=1)
    return pd.concat([df, pd.DataFrame(scored.tolist(), index=df.index)], axis=1)


# =========================
# 1-back scoring
# =========================
def score_1back_row(
    row: pd.Series,
    feature: Feature,
    resp_cols: tuple[str, str, str, str, str] = ("response_1", "response_2", "response_3", "response_4", "response_5"),
) -> Dict[str, Any]:
    """
    1-back has 6 stimuli per row: feature1..feature6 and 5 scored responses:
      ans1 compares 2 vs 1
      ans2 compares 3 vs 2
      ans3 compares 4 vs 3
      ans4 compares 5 vs 4
      ans5 compares 6 vs 5

    Returns flat dict:
      ans1..ans5, trial_correct,
      expected_ans1..expected_ans5,
      ans1_comp..ans5_comp
    """
    vals = {i: row[f"{feature}{i}"] for i in range(1, 7)}

    out: Dict[str, Any] = {}
    all_correct = True

    for k in range(1, 6):
        left = k + 1
        right = k
        match = bool(vals[left] == vals[right])
        exp = _truth_to_resp_code(match)

        r = row.get(resp_cols[k - 1], pd.NA)
        ans = (not _is_missing(r)) and (int(r) == exp)

        out[f"ans{k}"] = bool(ans)
        out[f"expected_ans{k}"] = int(exp)
        out[f"ans{k}_comp"] = f"{feature}{left}=={feature}{right}"

        # trial_correct requires all answers present and correct
        if _is_missing(r) or (not ans):
            all_correct = False

    out["trial_correct"] = bool(all_correct)
    return out


def score_1back_df(
    df: pd.DataFrame,
    feature: Feature,
    resp_cols: tuple[str, str, str, str, str] = ("response_1", "response_2", "response_3", "response_4", "response_5"),
) -> pd.DataFrame:
    scored = df.apply(lambda r: score_1back_row(r, feature=feature, resp_cols=resp_cols), axis=1)
    return pd.concat([df, pd.DataFrame(scored.tolist(), index=df.index)], axis=1)


# =========================
# Wrapper
# =========================
def score_block_df(
    df: pd.DataFrame,
    task_kind: TaskKind,
    feature: Feature,
    order: InterOrder | None = None,
) -> pd.DataFrame:
    if task_kind == "interdms":
        if order is None:
            raise ValueError("InterDMS scoring requires order='ABAB' or 'ABBA'.")
        return score_interdms_df(df, feature=feature, order=order)
    if task_kind == "1back":
        return score_1back_df(df, feature=feature)
    raise ValueError(f"Unknown task_kind: {task_kind}")

if __name__ == "__main__":
    df = pd.read_csv("/mnt/tempdata/lucas/fmri/recordings/TR/behav/sub-01/ses-01/sub-01_ses-1_20251106-120104_task-1back_loc_block_0_events_12-22-30.tsv", sep="\t")
    df_scored = score_1back_df(df, feature="loc")


[3, 3, 3, 3, 3]
[True, True, True, True, True]
[3, 3, 3, 3, 3]
[True, True, True, True, True]
[3, 3, 3, 3, 3]
[True, True, True, True, True]
[3, 3, 2, 2, 2]
[True, True, True, True, True]
[2, 2, 3, 3, 3]
[True, True, True, True, True]
[2, 3, 3, 3, 3]
[True, True, True, True, True]
[3, 3, 3, 3, 2]
[True, True, True, True, True]
[2, 2, 2, 2, 2]
[True, True, True, True, True]
[2, 2, 2, 2, 2]
[True, True, True, True, True]
[3, 3, 3, 2, 2]
[True, True, True, True, True]


In [30]:
df_scored[["TrialNumber", "loc1","loc2","loc3","loc4","loc5","ans1","expected_ans1","ans2","expected_ans2","ans3","expected_ans3","ans4","expected_ans4","ans5","expected_ans5","trial_correct","n_missing_responses"]].head()

Unnamed: 0,TrialNumber,loc1,loc2,loc3,loc4,loc5,ans1,expected_ans1,ans2,expected_ans2,ans3,expected_ans3,ans4,expected_ans4,ans5,expected_ans5,trial_correct,n_missing_responses
0,1,1,0,1,0,1,True,3,True,3,True,3,True,3,True,3,True,0
1,2,0,1,0,1,0,True,3,True,3,True,3,True,3,True,3,True,0
2,3,0,1,0,1,0,True,3,True,3,True,3,True,3,True,3,True,0
3,4,1,0,1,1,1,True,3,True,3,True,2,True,2,True,2,True,0
4,5,0,0,0,1,0,True,2,True,2,True,3,True,3,True,3,True,0


In [18]:
df_scored

Unnamed: 0,TrialNumber,loc1,ref1,obj1,ctg1,ang1,loc2,ref2,obj2,ctg2,...,response_3,response_3_time,all_responses_3,ans1,ans2,trial_correct,expected_ans1,expected_ans2,ans1_comp,ans2_comp
0,1,1,7,3,1,1,1,2,1,0,...,2,15.112075,"[['2', 15.112074399832636]]",False,False,False,3,3,B2_vs_B1,A2_vs_A1
1,2,0,7,3,1,1,0,4,2,1,...,2,32.69019,"[['2', 32.690190399996936]]",True,True,True,2,2,B2_vs_B1,A2_vs_A1
2,3,0,1,0,0,1,1,7,3,1,...,2,50.465,"[['2', 50.46499879984185]]",True,True,True,3,2,B2_vs_B1,A2_vs_A1
3,4,0,0,0,0,0,0,0,0,0,...,2,68.043144,"[['2', 68.04314699978568]]",False,False,False,2,3,B2_vs_B1,A2_vs_A1
4,5,0,1,0,0,1,1,5,2,1,...,2,85.46452,"[['2', 85.46452639997005]]",False,False,False,3,3,B2_vs_B1,A2_vs_A1
5,6,1,3,1,0,1,0,4,2,1,...,2,103.40326,"[['2', 103.40325739979744]]",False,False,False,3,3,B2_vs_B1,A2_vs_A1
6,7,0,7,3,1,1,1,2,1,0,...,3,121.49681,"[['3', 121.49680999992415]]",False,False,False,2,2,B2_vs_B1,A2_vs_A1
