## Imports and notebook setup

In [None]:
import os
import shutil
import re
import json
from datetime import datetime, timedelta
from collections import Counter, defaultdict
from typing import Literal
from typing_extensions import TypeIs
import textwrap

import numpy as np
import pandas as pd
from pandas._libs.missing import NAType
from scipy.stats import mannwhitneyu, fisher_exact, binomtest
from scipy.stats.contingency import relative_risk
from IPython.display import Image

import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.stats.inter_rater import fleiss_kappa
from statsmodels.stats.proportion import proportion_confint

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.ticker as mtick
from graphviz import Digraph
import seaborn as sns

from pydantic import BaseModel, ValidationError

import blobfile as bf

import types_os as types

In [None]:
parquet_fp = 'REMOVED'
clinical_study_results_fp = 'REMOVED'
gpt41_results_fp = 'REMOVED'
o3_results_fp = 'REMOVED'
plots_folder = 'ENTER_FOLDER_NAME_HERE'
# NOTE: the plots folder is reset (removed if present, then recreated) by the next
# cell through blobfile, which handles both local and blob paths. Doing it here
# as well with shutil/os was redundant and raised FileNotFoundError whenever the
# folder did not exist yet.

In [None]:
# Start each run from an empty plots folder so stale outputs never linger.
if bf.exists(plots_folder):
    print(f"Removing existing plots folder: {plots_folder}")
    bf.rmtree(plots_folder)
print(f"Creating plots folder: {plots_folder}")
bf.makedirs(plots_folder)


def save_fig(plot_name: str, fig: plt.Figure):
    """Write ``fig`` to ``<plots_folder>/<plot_name>.png`` at 300 dpi with a tight bounding box."""
    png_path = bf.join(plots_folder, f"{plot_name}.png")
    with bf.BlobFile(png_path, "wb") as sink:
        fig.savefig(sink, format="png", dpi=300, bbox_inches="tight")


def save_latex(name: str, df: pd.DataFrame):
    """Write ``df`` as a LaTeX table to ``<plots_folder>/<name>.tex``.

    Every percent sign in the generated LaTeX is escaped with a backslash so
    LaTeX does not treat the rest of the line as a comment.
    """
    tex_path = bf.join(plots_folder, f"{name}.tex")
    tex_source = df.to_latex()
    tex_source = tex_source.replace('%', '\\%')
    with bf.BlobFile(tex_path, "w") as sink:
        sink.write(tex_source)


def save_csv(name: str, df: pd.DataFrame):
    """Write ``df`` (without its index) to ``<plots_folder>/data_for_web/<name>.csv``, printing the target first."""
    csv_path = bf.join(plots_folder, "data_for_web/", f"{name}.csv")
    print(f"Saving {name} to {csv_path}")
    with bf.BlobFile(csv_path, "w") as sink:
        df.to_csv(sink, index=False)


# Named values collected during the analysis (presumably exported for the LaTeX write-up).
latexvars = dict()

In [None]:
# Global plotting defaults: serif font, muted palette, unbolded titles/labels,
# no top/right spines, frameless compact legends, small tick labels.
_seaborn_rc = {
    "figure.dpi": 120,  # TODO: modify for paper figures
    "figure.autolayout": True,
    "axes.titleweight": "normal",
    "axes.labelweight": "normal",
    "axes.spines.top": False,
    "axes.spines.right": False,
    "legend.frameon": False,
    "legend.fontsize": "small",
    "legend.title_fontsize": "medium",
    "xtick.labelsize": "small",
    "ytick.labelsize": "small",
}
sns.set_theme(style="dark", palette="muted", font="serif", rc=_seaborn_rc)

## Load and process data

In [None]:
parquet_dir = os.path.expanduser(parquet_fp)

# Load every ``*.parquet`` file in the directory into ``dfs``, keyed by the file
# name without its extension (e.g. ``outcomes.parquet`` -> ``dfs['outcomes']``).
dfs = {}
for file in bf.listdir(parquet_dir):
    if not file.endswith('.parquet'):
        continue

    file_path = bf.join(parquet_dir, file)
    with bf.BlobFile(file_path, "rb") as f:
        df = pd.read_parquet(f)
    # removesuffix strips only the trailing extension; str.replace would also
    # delete any '.parquet' occurring elsewhere in the name.
    filename = os.path.basename(file_path).removesuffix('.parquet')
    dfs[filename] = df

In [None]:
def concat_old_new_dfs(old_df: pd.DataFrame, new_df: pd.DataFrame, assert_no_dupes: bool = True, index_col: str = 'VisitCode', columns_should_be_subset: tuple[str, ...] | None = ('UserId', 'CallDate')) -> pd.DataFrame:
    assert (old_df.columns == new_df.columns).all()
    for col in old_df.columns:
        assert old_df[col].dtype == new_df[col].dtype, f"Column {col} has different dtypes in the old and new dfs"

    concat_df = pd.concat([old_df, new_df], ignore_index=True)

    for subset in [None, index_col]:
        n_dupes_old = len(old_df) - len(old_df.drop_duplicates(subset=subset))
        n_dupes_new = len(new_df) - len(new_df.drop_duplicates(subset=subset))
        if assert_no_dupes:
            assert n_dupes_old == 0, f"Old df has {n_dupes_old} duplicates" + ("looking at all columns" if subset is None else f"looking at {subset} column")
            assert n_dupes_new == 0, f"New df has {n_dupes_new} duplicates" + ("looking at all columns" if subset is None else f"looking at {subset} column")

        assert len(concat_df) == len(concat_df.drop_duplicates(subset=subset)) + n_dupes_old + n_dupes_new, f"Duplicates exist across the old and new dfs" + ("looking at all columns" if subset is None else f"looking at {subset} column")

    if columns_should_be_subset is not None:
        for col in columns_should_be_subset:
            diff = set(new_df[col].dropna()) - set(old_df[col].dropna())
            assert len(diff) == 0, f"New df has {col} that are not in the old df: {diff}"

    return concat_df

### Interrogate outcomes data

In [None]:
# Append the newer outcomes pull to the original. Duplicates are tolerated here;
# repeated calls for the same visit are reconciled further below.
outcomes, outcomes_new = dfs['outcomes'], dfs['outcomes_new']
outcomes = concat_old_new_dfs(
    outcomes,
    outcomes_new,
    assert_no_dupes=False,
    columns_should_be_subset=('Q1',),
)

In [None]:
print(f'Number of outcomes: {outcomes.shape[0]}')

In [None]:
# no row may be repeated verbatim across all columns
assert not outcomes.duplicated().any()

#### Normalize the outcomes results

In [None]:
outcomes['Q1'].value_counts(dropna=False)

In [None]:
def outcomes_q1_remapper(s: str | NAType) -> int | NAType:
    """Map a raw Q1 answer to the integer given by its first character.

    Missing answers, and answers containing '|' (several options selected),
    become ``pd.NA``.
    """
    # `or` short-circuits, so the '|' test never runs on a missing value
    if pd.isna(s) or '|' in s:
        return pd.NA
    leading_char = s[0]
    return int(leading_char)

outcomes['Q1'] = outcomes['Q1'].map(outcomes_q1_remapper)
outcomes['Q1'].value_counts(dropna=False, normalize=True).mul(100)

In [None]:
outcomes['Q2'].value_counts(dropna=False)

In [None]:
outcomes_map = {
    'I got all my treatment and medicines at Penda': 'all_at_penda',
    'I went myself to another hospital or specialist': 'self_referred',
    'Penda referred me to another hospital or specialist': 'penda_referred',
    'I visited another chemist': 'another_chemist',
}
def map_outcomes(s: str) -> str:
    mapped_outcomes = []
    for k, v in outcomes_map.items():
        if k in s:
            mapped_outcomes.append(v)
    assert len(mapped_outcomes) == 1, f"Multiple outcomes in {s} ({mapped_outcomes})"
    return mapped_outcomes[0]

def outcomes_q2_remapper(s: str | None) -> str | None:
    if s is None:
        return s

    outcomes_present = s.split('|')
    outcomes_mapped = {map_outcomes(o) for o in outcomes_present}

    match len(outcomes_mapped):
        case 0:
            raise ValueError(f"No outcomes were found in non-None {s}")
        case 1:
            return outcomes_mapped.pop()
        case 2:
            if 'all_at_penda' in outcomes_mapped:
                outcomes_mapped.remove('all_at_penda')
                return outcomes_mapped.pop()
            else:
                return None
        case _:
            raise ValueError(f"Unexpected: three outcomes were returned for {s}")

outcomes.Q2 = outcomes.Q2.map(outcomes_q2_remapper)
outcomes.Q2.value_counts(dropna=False, normalize=True) * 100


#### VisitCodes

In [None]:
# every outcome row must carry both a visit code and a call date
assert outcomes['VisitCode'].notna().all()
assert outcomes['CallDate'].notna().all()

In [None]:
counts_visit_codes = outcomes['VisitCode'].value_counts()
# visit codes with more than one outcome row
repeat_visit_code_counts = counts_visit_codes.loc[counts_visit_codes.gt(1)]
repeat_visit_codes = repeat_visit_code_counts.index
print(f'Number of visit codes that occur more than once: {repeat_visit_codes.size}')

In [None]:
print("Example of a visit code that occurs more than once:")
outcomes.loc[outcomes['VisitCode'].eq(repeat_visit_codes[0])]

In [None]:
def keep_most_recent_call_date(group):
    """Reduce one VisitCode's outcome rows to a single row.

    Only the rows with the latest CallDate are considered. If there is exactly
    one, it is returned unchanged. Otherwise the first of them is kept, with Q1
    and Q2 set to their unique non-NA value across those rows, or None when the
    rows disagree or are all NA.

    Returns:
        A one-row DataFrame with the same columns (and dtypes) as ``group``.
    """
    latest = group[group.CallDate == group.CallDate.max()]
    if len(latest) == 1:
        return latest

    # iloc[[0]] keeps a one-row DataFrame with the original column dtypes.
    # Going through a Series (.iloc[0] ... .to_frame().T) would cast every
    # column (e.g. CallDate) to object, and that leaks into the concatenated result.
    merged = latest.iloc[[0]].copy()
    for col in ('Q1', 'Q2'):
        distinct = latest[col].dropna().unique()
        merged[col] = distinct[0] if len(distinct) == 1 else None
    return merged

outcomes = outcomes.groupby('VisitCode', group_keys=False).apply(keep_most_recent_call_date)
assert outcomes.VisitCode.is_unique, "VisitCode is not unique in outcomes_deduplicated"

### Allocations

The `allocations` table is subsumed by `visits_with_users` and is not needed for the analysis. Not every allocated provider saw patients; `visits_with_users` is the canonical record of who actually saw patients, and should be used as the source of truth for which providers saw patients within the scope of the study.

In [None]:
# User-to-study-group allocation table (UserId, StaffId, Group, ...); displayed for inspection.
allocations = dfs['allocations']
allocations

In [None]:
# User and staff IDs must be unique and populated; every row needs a Group.
for _col in ('UserId', 'StaffId'):
    assert allocations[_col].is_unique, f"{_col} is not unique in allocations"
    assert allocations[_col].notna().all(), f"{_col} has nulls in allocations"
assert allocations['Group'].notna().all(), "Group has nulls in allocations"

In [None]:
# Cross-check study arms: for every visit row whose UserId appears in the
# allocation table, the row's Group must equal the allocated Group.
visits_with_users = dfs['visits_with_users']
vw_alloc = pd.merge(
    visits_with_users,
    allocations[['UserId', 'Group']],
    how='left',
    on='UserId',
    suffixes=('', '_alloc'),
)

mismatch = vw_alloc.loc[
    vw_alloc['Group_alloc'].notna() & vw_alloc['Group'].ne(vw_alloc['Group_alloc'])
]
assert mismatch.empty, "There are mismatches between visits_with_users.Group and allocations.Group"

In [None]:
# Distinct (user, role, group) combinations among users who saw patients; each
# user must have exactly one. Compare the count against the allocation table.
unique_users_with_visits = visits_with_users.loc[:, ['UserId', 'RoleName', 'Group']].drop_duplicates()
assert not unique_users_with_visits['UserId'].duplicated().any(), "UserId is not unique in unique_users_with_visits"
(unique_users_with_visits.shape[0], allocations.shape[0])

In [None]:
uuv_alloc = pd.merge(
    unique_users_with_visits,
    allocations.loc[:, ['UserId', 'Group']],
    how='left',
    on='UserId',
    suffixes=('', '_alloc'),
)

In [None]:
uuv_alloc_providers = uuv_alloc.loc[uuv_alloc['RoleName'].eq('Provider')]
uuv_alloc_providers.Group.value_counts()

In [None]:
assert uuv_alloc_providers['Group'].eq(uuv_alloc_providers['Group_alloc']).all(), "Group does not match Group_alloc in uuv_alloc"

### Visits

In [None]:
visits, visits_new = dfs['visits'], dfs['visits_new']

# Append the newer visits pull; the listed columns of the new pull may only
# contain values already present in the original pull.
visits = concat_old_new_dfs(
    visits,
    visits_new,
    columns_should_be_subset=(
        'VisitCategory', 'Gender', 'LocationName',
        'VisitType', 'VisitDate', 'HadUnplannedVisit',
    ),
)

visits.head()

In [None]:
visits.shape[0]

In [None]:
assert not visits['VisitCode'].duplicated().any(), "VisitCode is not unique in visits"

### Visits with users

In [None]:
# Display the original visits_with_users pull loaded above (the newer pull is appended in the next cell).
visits_with_users

In [None]:
visits_with_users, visits_with_users_new = dfs['visits_with_users'], dfs['visits_with_users_new']

# Append the newer pull. A visit can have several clinician rows, so duplicates
# within a pull are allowed here.
visits_with_users = concat_old_new_dfs(
    visits_with_users,
    visits_with_users_new,
    assert_no_dupes=False,
    columns_should_be_subset=('UserId', 'RoleName', 'Group'),
)

visits_with_users.head()

In [None]:
# Role and study group of every user who saw patients. This table is later
# merged onto the AI interaction log by UserId, so each UserId must map to exactly
# one (RoleName, Group); otherwise that merge would silently duplicate AI calls.
# The check at the allocations step only covered the original pull.
nurse_provider_mapping_df = visits_with_users[['UserId', 'RoleName', 'Group']].drop_duplicates()
assert nurse_provider_mapping_df.UserId.is_unique, "UserId maps to multiple RoleName/Group values in nurse_provider_mapping_df"
nurse_provider_mapping_df

In [None]:
assert not visits_with_users.duplicated().any()
# every row must identify the visit, the clinician, their role and study group
for _col in ('VisitCode', 'UserId', 'RoleName', 'Group'):
    assert visits_with_users[_col].notna().all(), f"{_col} has nulls in visits_with_users"

In [None]:
visits_with_users['RoleName'].value_counts(dropna=False)

In [None]:
visits_with_users['Group'].value_counts(dropna=False)

In [None]:
def summarize_duplicated_visits_with_users(visits_with_users_curr):
    """Print how often a VisitCode spans several rows, and the study-group mix of those visits.

    Args:
        visits_with_users_curr: DataFrame with at least ``VisitCode`` and ``Group`` columns.

    Percentages with a zero denominator (empty input, or no duplicated visit
    codes) are printed as ``nan`` rather than raising ZeroDivisionError.
    """
    def _pct(numerator, denominator):
        # percentage, or nan when the denominator is zero
        return numerator / denominator * 100 if denominator else float('nan')

    # Rows whose VisitCode appears more than once
    dups = visits_with_users_curr[visits_with_users_curr.VisitCode.duplicated(keep=False)]
    print(f"Rows with duplicated VisitCode: {len(dups)} ({_pct(len(dups), len(visits_with_users_curr)):.2f}% of all rows)")

    n_dup_codes = dups.VisitCode.nunique(dropna=False)
    n_all_codes = visits_with_users_curr.VisitCode.nunique(dropna=False)
    print(f'Number of unique visit codes with duplicates: {n_dup_codes} ({_pct(n_dup_codes, n_all_codes):.2f}% of all unique visit codes)')

    # Set of study groups present for each duplicated VisitCode
    group_sets = dups.groupby('VisitCode')['Group'].unique().map(lambda arr: set(arr))
    n_groups = len(group_sets)
    all_silent = sum(s == {'Silent AI'} for s in group_sets)
    all_active = sum(s == {'Active AI'} for s in group_sets)
    mixed = sum(s == {'Silent AI', 'Active AI'} for s in group_sets)

    print(f"Number of duplicated VisitCodes where all groups are Silent AI: {all_silent} ({_pct(all_silent, n_groups):.2f}% of all duplicated visit codes)")
    print(f"Number of duplicated VisitCodes where all groups are Active AI: {all_active} ({_pct(all_active, n_groups):.2f}% of all duplicated visit codes)")
    print(f"Number of duplicated VisitCodes where there is a mix of Silent AI and Active AI: {mixed} ({_pct(mixed, n_groups):.2f}% of all duplicated visit codes)")

In [None]:
# Summarize duplicates for all clinicians, then for providers only.
for _title, _subset in (
    ("Duplicated visits with users (all clinicians)", visits_with_users),
    ("\nDuplicated visits with users (providers only)", visits_with_users[visits_with_users.RoleName == "Provider"]),
):
    print(_title)
    summarize_duplicated_visits_with_users(_subset)

In [None]:
# Keep only Provider rows; report what share of visit-user rows that retains.
visits_with_users_providers_only = visits_with_users.loc[visits_with_users['RoleName'].eq("Provider")]
n_total_visits_with_users = visits_with_users.shape[0]
n_provider_visits_with_users = visits_with_users_providers_only.shape[0]
print(f"Percentage of visits with users that are providers: {n_provider_visits_with_users / n_total_visits_with_users * 100:.2f}% - remaining visits were dropped")

In [None]:
def summarize_visits_with_users(group: pd.DataFrame) -> pd.Series:
    """Summarize the provider rows recorded for a single visit.

    Args:
        group: the rows of one VisitCode (the call below passes provider rows only).

    Returns:
        Series with
          - ClinicianGroup: 'AI' if every row's Group is 'Active AI', 'Non-AI' if
            every row's Group is 'Silent AI', 'Crossover' if both appear;
          - UserIDs: the distinct UserIds, in order of first appearance.

    Raises:
        ValueError: if the set of Group values is anything else.
    """
    group_set_all = set(group["Group"].unique())
    if group_set_all == {"Active AI"}:
        all_user_role = "AI"
    elif group_set_all == {"Silent AI"}:
        all_user_role = "Non-AI"
    elif group_set_all == {"Active AI", "Silent AI"}:
        all_user_role = "Crossover"
    else:
        raise ValueError(f"Unexpected group composition: {group_set_all}")

    return pd.Series({
        "ClinicianGroup": all_user_role,
        "UserIDs": group["UserId"].unique().tolist(),
    })

summarized_visits_with_users = visits_with_users_providers_only.groupby("VisitCode", group_keys=False).apply(summarize_visits_with_users).reset_index()

In [None]:
summarized_visits_with_users.head(n=5)

### ai_interaction_scrubbed

In [None]:
ai_interaction_scrubbed, ai_interaction_scrubbed_new = dfs['ai_interaction_scrubbed'], dfs['scrubbed_ai_interaction_new']

# Append the newer AI-interaction pull; a visit may have many AI calls.
ai_interaction_scrubbed = concat_old_new_dfs(
    ai_interaction_scrubbed,
    ai_interaction_scrubbed_new,
    assert_no_dupes=False,
    index_col='VisitCode',
    columns_should_be_subset=(
        'Version', 'SystemRolePrompt', 'AiLike', 'ClinicalDecisionRule',
        'FocusOutEmrComponent', 'Silent', 'Acknowledged',
    ),
)

In [None]:
# List the columns of the combined AI-interaction table.
ai_interaction_scrubbed.columns

In [None]:
# How many AI responses are wrapped in a markdown code fence.
ai_interaction_scrubbed['AiResponse'].str.startswith("```").sum()

In [None]:
ai_interaction_scrubbed['CreatedOn'] = pd.to_datetime(ai_interaction_scrubbed['CreatedOn'])

In [None]:
# Earliest CreatedOn of any acknowledged AI call.
DATE_ACKNOWLEDGEMENT_STARTED = ai_interaction_scrubbed.loc[ai_interaction_scrubbed['Acknowledged'].eq(1.0), 'CreatedOn'].min()

In [None]:
# replace thumbs up data with 'Up', 'Down', 'None'
ai_interaction_scrubbed.AiLike = ai_interaction_scrubbed.AiLike.map({1.0: 'Up', 0.0: 'Down'}, na_action='ignore')
ai_interaction_scrubbed.AiLike = ai_interaction_scrubbed.AiLike.fillna(value = 'None')

# replace acknowledgement data with True/False (and None if before acknowledgement data started to be collected):
#   1.0                                                    -> True
#   anything else, created on/after DATE_ACKNOWLEDGEMENT_STARTED -> False
#   anything else, created earlier                          -> None
# `>=` so unacknowledged calls created at the very first acknowledgement
# timestamp count as False rather than "not yet collected".
_acknowledged = ai_interaction_scrubbed.Acknowledged.map({1.0: True}, na_action='ignore').astype(object)
_ack_collected = ai_interaction_scrubbed.CreatedOn >= DATE_ACKNOWLEDGEMENT_STARTED
_acknowledged[_acknowledged.isna() & _ack_collected] = False
ai_interaction_scrubbed.Acknowledged = _acknowledged.where(_acknowledged.notna(), None)

# replace silent data with 'Silent'/'Active'
ai_interaction_scrubbed['Silent'] = ai_interaction_scrubbed['Silent'].map({True: 'Silent', False: 'Active'}, na_action='ignore')

In [None]:
pd.crosstab(
    index=ai_interaction_scrubbed['FocusOutEmrComponent'],
    columns=ai_interaction_scrubbed['ClinicalDecisionRule'],
    dropna=False,
)

In [None]:
failed_parses = []  # stripped strings that json.loads rejected; broken down in the next cell
def parse_ai_response(ai_response):
    """
    Parse an AiResponse string with ``json.loads``.

    Handles:
      - raw JSON
      - JSON wrapped in ```
      - JSON wrapped in ```json ... ``` (the "json" tag is matched case-insensitively)

    Returns the decoded JSON value, or None if the input is not a non-blank
    string or does not decode. Strings that fail to decode are appended (after
    fence stripping) to the module-level ``failed_parses`` list.
    """
    if not isinstance(ai_response, str) or not ai_response.strip():
        return None

    s = ai_response.strip()
    if s.startswith("```"):  # markdown-fenced JSON
        s = s[3:].lstrip()
        if s.lower().startswith("json"):
            s = s[4:].lstrip()
        s = s.removesuffix("```").strip()
    try:
        return json.loads(s)
    except json.JSONDecodeError:
        # Only decode failures are expected; anything else is a real bug and
        # should surface instead of being recorded as a malformed response.
        failed_parses.append(s)
        return None

AIResponseParsed = ai_interaction_scrubbed.AiResponse.map(parse_ai_response)
sum(AIResponseParsed.isna()), sum(AIResponseParsed.isna()) / len(ai_interaction_scrubbed), sum(ai_interaction_scrubbed.AiResponse.isna()) / len(ai_interaction_scrubbed)

In [None]:
def _pct_of(count, total):
    """Return count/total as a percentage, or nan when total is zero."""
    return count / total * 100 if total else float('nan')

# Classify each failed parse. The two named categories can overlap (a string may
# contain both a curly brace and "error"/"http"), so "other" is counted directly
# as failures matching neither, not as total minus both counts.
has_curly_brace = [('{' in s or '}' in s) for s in failed_parses]
has_error_and_http = [('error' in s.lower() and 'http' in s.lower()) for s in failed_parses]

print(f'Percentage of rows that failed to parse: {_pct_of(len(failed_parses), len(ai_interaction_scrubbed)):.2f}%')
n_failed_parse_curly_brace = sum(has_curly_brace)
print(f'Percentage of failed parses that contain a curly brace: {_pct_of(n_failed_parse_curly_brace, len(failed_parses)):.2f}%')
n_failed_parse_error = sum(has_error_and_http)
print(f'Percentage of failed parses that contain an error and http: {_pct_of(n_failed_parse_error, len(failed_parses)):.2f}%')
n_failed_parse_unknown = sum(not (brace or err) for brace, err in zip(has_curly_brace, has_error_and_http))
print(f'Percentage of failed parses that are of other types (typically model not comforming to JSON format): {_pct_of(n_failed_parse_unknown, len(failed_parses)):.2f}%')

In [None]:
ai_interaction_scrubbed['AIResponseParsed'] = AIResponseParsed
pre_drop_n = ai_interaction_scrubbed.shape[0]
# discard interactions whose response could not be decoded
ai_interaction_scrubbed.dropna(subset=['AIResponseParsed'], inplace=True)
post_drop_n = ai_interaction_scrubbed.shape[0]
print(f'Percentage of rows that were dropped: {(pre_drop_n - post_drop_n) / pre_drop_n * 100:.2f}%')

In [None]:
ai_interaction_scrubbed['ClinicalDecisionRule'].unique()

In [None]:
def parse_ai_response(ai_response: dict):
    """Validate an already-decoded AiResponse against the ``types.AIResponse`` schema.

    Returns the validated model, or None when pydantic validation fails.
    Note: this rebinds the name ``parse_ai_response``, replacing the
    JSON-parsing helper defined in an earlier cell.
    """
    schema = types.AIResponse
    try:
        validated = schema.model_validate(ai_response)
    except ValidationError:
        validated = None
    return validated

In [None]:
# validate every decoded response; failures become None
ai_interaction_scrubbed['AIResponseValidated'] = ai_interaction_scrubbed['AIResponseParsed'].apply(parse_ai_response)
n_none = ai_interaction_scrubbed['AIResponseValidated'].isna().sum()
n_rows = len(ai_interaction_scrubbed)
print(f'Percentage of rows that failed to validate: {n_none / n_rows * 100:.4f}% ({n_none} rows / {n_rows} total)')

In [None]:
original_length = ai_interaction_scrubbed.shape[0]
# keep only interactions whose response passed schema validation
ai_interaction_scrubbed.dropna(subset=['AIResponseValidated'], inplace=True)
post_drop_length = ai_interaction_scrubbed.shape[0]
n_dropped = original_length - post_drop_length
print(f'Percentage of rows that were dropped: {n_dropped / original_length * 100:.4f}% ({n_dropped} rows / {original_length} total)')

In [None]:
# Attach each interaction's clinician role and study group (by UserId), then keep
# only Provider interactions. validate='many_to_one' makes pandas raise if a
# UserId maps to several role/group rows, which would otherwise silently
# duplicate interactions.
ai_interaction_scrubbed = ai_interaction_scrubbed.merge(nurse_provider_mapping_df, on = 'UserId', how = 'left', validate = 'many_to_one')
ai_interaction_scrubbed = ai_interaction_scrubbed[ai_interaction_scrubbed.RoleName == 'Provider']

In [None]:
# Helper to convert a single dataframe row into an AICall object
def _row_to_aicall(row) -> types.AICall:
    """Build one ``types.AICall`` from a row of ``ai_interaction_scrubbed``."""
    return types.AICall(
        rule=row.ClinicalDecisionRule,
        response=row.AIResponseValidated,
        user_id=row.UserId,
        time=row.CreatedOn,
        thumbs_up_down=row.AiLike,
        silent=row.Silent,
        acknowledged=row.Acknowledged,
        user_role_prompt=row.UserRolePrompt,
    )

def _rows_to_aicalls(rows: pd.DataFrame) -> types.AICalls:
    """Wrap all interaction rows of one visit, in row order, into a ``types.AICalls``."""
    return types.AICalls(calls=[_row_to_aicall(row) for _, row in rows.iterrows()])

# One row per VisitCode (in order of first appearance) holding that visit's AICalls.
# Built by iterating the groups rather than with groupby().apply(). Since pandas
# 2.2, apply warns about operating on the grouping column. This form also always
# yields a "VisitCode"/"AICalls" frame, even when there are no interactions.
visit_aicalls_df = pd.DataFrame(
    [
        (visit_code, _rows_to_aicalls(rows))
        for visit_code, rows in ai_interaction_scrubbed.groupby("VisitCode", sort=False)
    ],
    columns=["VisitCode", "AICalls"],
)

In [None]:
# Flatten every visit's calls into one Series, keep active red calls, and show
# the share of each acknowledgement value among them.
all_aicalls = visit_aicalls_df['AICalls'].apply(lambda visit_calls: visit_calls.calls).explode()
active_aicalls = all_aicalls[all_aicalls.map(lambda c: c.silent == 'Active')]
active_red_aicalls = active_aicalls[active_aicalls.map(lambda c: c.color == types.Color.Red)]
active_red_aicalls.map(lambda c: c.acknowledged).value_counts(normalize=True)

### clinical_documentation

In [None]:
# The original pull carries an extra VisitId_1 column that the newer pull lacks.
clinical_documentation = dfs['clinical_documentation'].drop(columns=['VisitId_1'])
clinical_documentation_new = dfs['scrubbed_clinical_documentation_cleaned']
clinical_documentation = concat_old_new_dfs(
    clinical_documentation,
    clinical_documentation_new,
    columns_should_be_subset=('LocationName', 'VisitCategory', 'VisitType', 'VisitDate'),
)

clinical_documentation.head()

In [None]:
assert not clinical_documentation['VisitCode'].duplicated().any(), "VisitCode is not unique in clinical_documentation"

In [None]:
# List the columns of the combined clinical documentation table.
clinical_documentation.columns

In [None]:
# These columns must be fully populated; every other column is expected to have
# at least one missing value.
categories_not_na = ['VisitCode', 'LocationName', 'VisitCategory', 'VisitType', 'VisitDate', 'Age']
categories_na = [name for name in clinical_documentation.columns if name not in categories_not_na]
for col_name in categories_not_na:
    assert clinical_documentation[col_name].notna().all(), f"{col_name} has nulls"
for col_name in categories_na:
    assert not clinical_documentation[col_name].notna().all(), f"{col_name} has no nulls"

In [None]:
clinical_documentation.shape[0]

In [None]:
# Per VisitDate, the fraction of visits whose CC field is missing
# (CC is presumably the chief complaint).
cc_na_rate_by_date = (
    clinical_documentation['CC']
    .isna()
    .groupby(clinical_documentation['VisitDate'])
    .mean()
    .sort_index()
)

fig, ax = plt.subplots(figsize=(12, 5))
cc_na_rate_by_date.plot(ax=ax, marker='o')
ax.set_title('Rate of NAs in clinical_documentation.CC vs VisitDate')
ax.set_ylabel('NA Rate in CC')
ax.set_xlabel('VisitDate')
plt.tight_layout()
plt.show()

In [None]:
# Latest VisitDate on which CC was still recorded.
last_visitdate_with_cc = clinical_documentation.loc[clinical_documentation['CC'].notna()]['VisitDate'].max()
last_visitdate_with_cc

### one_day_outcomes

In [None]:
one_day_outcomes = dfs['one_day_outcomes']
one_day_outcomes_new = dfs['one_day_outcomes_new']

one_day_outcomes = concat_old_new_dfs(one_day_outcomes, one_day_outcomes_new, assert_no_dupes=False, columns_should_be_subset=('Q1',))

# Give the answer column a self-describing name (appended as the last column).
one_day_outcomes['one_day_call_outcome'] = one_day_outcomes['Q1']
one_day_outcomes.drop(columns=['Q1'], inplace=True)

# Last expression of the cell so the notebook actually displays it.
one_day_outcomes.head()

In [None]:
# Collapse the (at most two) one-day-call answers per VisitCode into one value:
# - identical non-missing answers -> that answer
# - 'No' + missing -> 'No'; 'Yes' + missing -> 'Yes'
# - 'No' + 'Yes' (conflict) -> None
# - all missing, or any other combination -> None

def resolve_outcome(s: pd.Series):
    """Return the single agreed one-day-call answer among the values of ``s``, or None.

    Raises:
        ValueError: if ``s`` holds more than two values.
    """
    values = s.tolist()
    if len(values) > 2:
        raise ValueError(f"Expected at most 2 values, got {len(values)}")

    answered = [v for v in values if pd.notnull(v)]
    distinct_answers = set(answered)
    if len(distinct_answers) != 1:
        # nothing answered, or the answers conflict
        return None

    (answer,) = distinct_answers
    if len(answered) == len(values):
        # every value is the same non-missing answer
        return values[0]
    # one answer alongside a missing one: only a plain Yes/No is kept
    return answer if answer in ('No', 'Yes') else None


one_day_outcomes_grouped = (
    one_day_outcomes
    .groupby('VisitCode', as_index=False)
    .agg({'one_day_call_outcome': resolve_outcome})
)

one_day_outcomes = one_day_outcomes_grouped

### Visit durations

In [None]:
durations = dfs['visit_provider_durations']
# Length of each provider interval, in minutes.
durations['duration_minutes'] = durations['EndDate'].sub(durations['StartDate']).dt.total_seconds().div(60)
assert durations['duration_minutes'].notna().all()
assert durations['duration_minutes'].min() > 0

# Total provider minutes per visit, summed over all of its intervals.
total_duration_by_visit = durations.groupby('VisitCode')['duration_minutes'].sum()

## Combine data

Process to merge the data:
- Use "visits" as the core data frame
- Join on "visits with users" (should be a completely 1:1 join) to get which providers saw that patient and what group those providers were in (i.e., the allocation). This is a right join onto the provider-only per-visit summary, so visits without a provider record are dropped
- Join on "outcomes" (left join) to add in the outcomes where those are present for a given patient
- Join on "clinical documentation" (left join)
- Join on "ai_interaction_scrubbed" (left join) to get the final red rate
- Join on "visit_provider_durations" (left join) to get the total provider time per visit
- Join on "one_day_outcomes" (left join) to get the deeper outcome data

Skip:
- Allocations - redundant with "visits with users"

So in sum:
- Every visit has basic information about the patient
- Every visit has information about the provider, including what group they were in
- The vast majority of visits, but not every single visit, has clinical documentation
- A fraction of visits have outcomes

In [None]:
def deduplicate_columns_modulo_nans(df, allowed_exceptions: list[tuple[str, str, str]] = []):
    for x_col in df.columns:
        if x_col.endswith('_x'):
            y_col = x_col.replace('_x', '_y')
            # Check equality only where y_col is notna
            mask = df[y_col].notna()
            unequal = df.loc[mask, [x_col, y_col]][df.loc[mask, x_col] != df.loc[mask, y_col]]
            if not unequal.empty:
                for colname, xval, yval in allowed_exceptions:
                    if colname in x_col and colname in y_col:
                        unequal = unequal[~((unequal[x_col] == xval) & (unequal[y_col] == yval))]

                if not unequal.empty:
                    print(f"Cells where {x_col} and {y_col} differ (showing index and values):")
                    print(unequal)
                    raise AssertionError(f"{x_col} and {y_col} are not the same modulo NaNs in {y_col}")
            df = df.drop(columns=[y_col])
            df.rename(columns={x_col: x_col.replace('_x', '')}, inplace=True)
    return df

In [None]:
# Work on a copy so that adding the ClinicalDocumentation column below does not
# also add it to `visits`.
df = visits.copy()

assert set(df.VisitCode) >= set(clinical_documentation.VisitCode), "visits.VisitCode is not a superset of clinical_documentation.VisitCode"
df_clinical_documentation = df.merge(clinical_documentation, on='VisitCode', how='left')
df_clinical_documentation = deduplicate_columns_modulo_nans(df_clinical_documentation, allowed_exceptions=[('Age', '57y 9m', '56y 2m')]) # one case of a mismatch, allowing it as spurious / a one-off change
# NaN -> None before pydantic validation (presumably so missing fields validate as None)
df_clinical_documentation = df_clinical_documentation.replace(np.nan, None)
clinical_documentation_objects = {row['VisitCode']: types.ClinicalDocumentation.model_validate(row.to_dict()) for _, row in df_clinical_documentation.iterrows()}
df['ClinicalDocumentation'] = df.VisitCode.map(clinical_documentation_objects)

# outcomes and visit_aicalls_df hold one row per VisitCode. validate makes pandas
# raise instead of silently multiplying visits if that ever stops holding.
assert set(df.VisitCode) >= set(outcomes.VisitCode), "visits.VisitCode is not a superset of outcomes.VisitCode"
df = df.merge(outcomes, on='VisitCode', how='left', validate='one_to_one')

assert set(df.VisitCode) >= set(visit_aicalls_df.VisitCode), "visits.VisitCode is not a superset of visit_aicalls_df.VisitCode"
df = df.merge(visit_aicalls_df, on='VisitCode', how='left', validate='one_to_one')

# Right join: only visits with a provider summary are kept.
assert set(df.VisitCode) >= set(summarized_visits_with_users.VisitCode), "Expected visits.VisitCode to be a superset of summarized_visits_with_users.VisitCode"
df = df.merge(summarized_visits_with_users, on='VisitCode', how='right')
assert len(df) == len(summarized_visits_with_users)

assert set(df.VisitCode) <= set(total_duration_by_visit.index), "Expected visits.VisitCode to be a subset of visit_provider_durations.VisitCode"
df = df.merge(total_duration_by_visit, on='VisitCode', how='left')

len_pre = len(df)
df = df.merge(one_day_outcomes, on='VisitCode', how='left')
len_post = len(df)
assert len_pre == len_post

In [None]:
# drop g45 visits
# Show visit counts per location. The bare expression this replaces sat mid-cell,
# where the notebook discards its value, so it was never displayed.
print(df.LocationName.value_counts())
g45_dropped_length = len(df[df.LocationName == 'Githurai 45'])
print(f'Dropping {g45_dropped_length} visits from Githurai 45')

g45_initial_len = len(df)
df = df[df.LocationName != 'Githurai 45']
g45_final_len = len(df)
assert g45_initial_len - g45_dropped_length == g45_final_len
all_included_visits_len = g45_final_len

In [None]:
# drop crossover visits (seen by providers from both the Active AI and Silent AI groups)
crossover_dropped_len = len(df[df.ClinicianGroup == 'Crossover'])
print(f"Dropping {crossover_dropped_len} visits with mixed clinician groups")
crossover_initial_len = len(df)
# .copy(): the following cells add columns to df; assigning into a filtered
# frame triggers pandas' SettingWithCopyWarning.
df = df[df.ClinicianGroup.isin(['Non-AI', 'AI'])].copy()
crossover_final_len = len(df)
assert crossover_final_len == crossover_initial_len - crossover_dropped_len, "Expected final length to be initial length minus dropped length"

### Add additional computed variables

In [None]:
def _long_date(d) -> str:
    """Format a date as e.g. 'March 1 2025' (full month name, unpadded day, year).

    Same output as strftime("%B %-d %Y"), but the "%-d" no-padding flag is a
    glibc/BSD extension that strftime rejects on Windows.
    """
    return f"{d:%B} {d.day} {d.year}"

latexvars['rolloutStartDate'] = _long_date(df.VisitDate.min())
latexvars['rolloutEndDate'] = _long_date(df.VisitDate.max())
# Visits on/after this date belong to the main study; earlier ones to the induction period.
cut_point = datetime.strptime('2025-03-01', '%Y-%m-%d')
cut_point_str = _long_date(cut_point)
latexvars['inductionEndDate'] = _long_date(cut_point - timedelta(days=1))
latexvars['restStudyStartDate'] = _long_date(cut_point)

df['during_main_study'] = df['VisitDate'] >= cut_point

In [None]:
def treatment_group(row: pd.Series) -> str:
    if row.ClinicianGroup == 'Non-AI':
        return 'Non-AI'
    elif row.ClinicianGroup == 'AI':
        if row.VisitDate >= cut_point:
            return 'AI - main study'
        else:
            return 'AI - induction period'
    else:
        raise ValueError(f"Invalid ClinicianGroup: {row.ClinicianGroup}")

df['treatment_group'] = df.apply(treatment_group, axis=1)
df.treatment_group.value_counts()

In [None]:
def is_na(x: object) -> TypeIs[NAType]:
    """``pd.isna`` with a type-narrowing signature, for scalar values."""
    return pd.isna(x)

def n_relevant_calls(
    call: types.AICalls | NAType,
    rules: list[types.ClinicalDecisionRule] | None = None,
    colors: list[types.Color] | None = None,
) -> int | NAType:
    """Count the calls in ``call`` whose rule is in ``rules`` and color is in ``colors``.

    ``None`` for ``rules``/``colors`` means "every member of the enum".
    Returns ``pd.NA`` when the visit has no AICalls.
    """
    if is_na(call):
        return pd.NA

    wanted_rules = list(types.ClinicalDecisionRule) if rules is None else rules
    wanted_colors = list(types.Color) if colors is None else colors
    return sum(1 for c in call.calls if c.rule in wanted_rules and c.color in wanted_colors)

df['n_aicalls'] = df.AICalls.map(n_relevant_calls)

In [None]:
# Start timestamp of the calendar week (pandas "W" period) each visit falls in.
df["week"] = df["VisitDate"].dt.to_period("W").dt.start_time

In [None]:
def risky_cases(
    call: types.AICalls | NAType,
    rules: list[types.ClinicalDecisionRule],
    colors: list[types.Color] | None = None,
) -> bool:
    """Determine if a visit is a risky case: at least one call for one of ``rules``
    whose color is in ``colors`` (default: red only).

    Args:
        call: the visit's AI calls, or a missing value.
        rules: decision rules to consider.
        colors: colors that count as risky; ``None`` means ``[types.Color.Red]``.

    Returns:
        A plain bool. A missing ``call`` is treated as not risky (``False``), never NA.
    """
    if is_na(call):
        return False

    if colors is None:
        colors = [types.Color.Red]

    return any(x.rule in rules and x.color in colors for x in call.calls)

# Per-section risk flags: True when the visit has at least one red AI call for the
# rule(s) backing that section (history, investigations, diagnosis, treatment).
# Visits with missing AICalls get False (see risky_cases).
df['risky_cases_for_history'] = df.AICalls.map(lambda x: risky_cases(x, rules=[types.ClinicalDecisionRule.VitalsChiefComplaintEvaluation, types.ClinicalDecisionRule.ClinicalNotes]))
df['risky_cases_for_investigations'] = df.AICalls.map(lambda x: risky_cases(x, rules=[types.ClinicalDecisionRule.InvestigationRecommendations]))
df['risky_cases_for_diagnosis'] = df.AICalls.map(lambda x: risky_cases(x, rules=[types.ClinicalDecisionRule.DiagnosisEvaluation]))
df['risky_cases_for_treatment'] = df.AICalls.map(lambda x: risky_cases(x, rules=[types.ClinicalDecisionRule.TreatmentRecommendation]))

In [None]:
# For each VisitCode, keep the `durations` row with the earliest StartDate
# (NOTE(review): presumably the first clinician to open the visit — confirm
# against how `durations` is built in an earlier cell).
durations_with_min_start_date = durations.loc[durations.groupby('VisitCode')['StartDate'].idxmin()]
assert durations_with_min_start_date.VisitCode.is_unique
earliest_user_id = durations_with_min_start_date[['VisitCode', 'UserId']]
earliest_user_id_dict = dict(zip(earliest_user_id.VisitCode, earliest_user_id.UserId))

# Visits without a `durations` entry get NaN.
df['user_id'] = df.VisitCode.map(earliest_user_id_dict)

# Weeks elapsed since the first visit in the analysis set (whole days / 7).
df.loc[:, 'weeks_since_start'] = ((df.VisitDate - df.VisitDate.min()).dt.days / 7)

## Descriptive statistics

In [None]:
# Study timeline figure: the AI group bar is split at cut_point into induction and
# main study; the Non-AI ("No CDSS") bar spans the whole rollout; dotted lines mark
# the start, main-study start and end dates.
study_start   = min(df.VisitDate)
induction_start = study_start          # induction starts same day
main_start      = cut_point
main_end        = max(df.VisitDate)

# Matplotlib date numbers (float days) are used for every x position below.
d2n = mdates.date2num

col_induction = '#c7e9c0'  # light green tint
col_main      = '#74c476'  # medium green tint
col_silent    = '#b3cde3'  # pastel blue

# Font sizes
label_fs  = 12   # bar labels
marker_fs = 12   # milestone labels
axis_fs   = 10    # axis tick / group label font

# Figure
fig, ax = plt.subplots(figsize=(9, 3))

# Row positions for the AI / Non-AI bars and the shared bar height.
y_active, y_silent, h = 1, 0.5, 0.3

# Active bars
ax.barh(
    y_active,
    float(d2n(main_start)) - float(d2n(induction_start)),
    left=float(d2n(induction_start)),
    height=h,
    color=col_induction, edgecolor='none'
)
ax.barh(
    y_active,
    float(d2n(main_end)) - float(d2n(main_start)),
    left=float(d2n(main_start)),
    height=h,
    color=col_main, edgecolor='none'
)

# Silent bar
ax.barh(
    y_silent,
    float(d2n(main_end)) - float(d2n(study_start)),
    left=float(d2n(study_start)),
    height=h,
    color=col_silent, edgecolor='none'
)

# Bar labels (just phase names), centred on each bar
ax.text(
    float(d2n(induction_start)) + (float(d2n(main_start)) - float(d2n(induction_start))) / 2,
    y_active,
    "Induction",
    ha='center', va='center', fontsize=label_fs, color='black'
)
ax.text(
    float(d2n(main_start)) + (float(d2n(main_end)) - float(d2n(main_start))) / 2,
    y_active,
    "Main study",
    ha='center', va='center', fontsize=label_fs, color='black'
)
# The trailing "+ 1" nudges this label one day to the right of the bar centre.
ax.text(
    float(d2n(study_start)) + (float(d2n(main_end)) - float(d2n(study_start))) / 2 + 1,
    y_silent,
    "No CDSS",
    ha='center', va='center', fontsize=label_fs, color='black'
)

# ---------------- Time Point Markers ----------------
marker_y = y_active + 0.28  # closer but slightly up

# Horizontal timeline bar
ax.hlines(marker_y - 0.03, float(d2n(study_start)), float(d2n(main_end)), color='black', linewidth=1)

# Milestones: dotted vertical line plus a label above the timeline bar
offset = 0.07  # vertical offset for milestone labels
ax.axvline(float(d2n(study_start)), color='black', linewidth=1.2, linestyle=':')
ax.text(
    float(d2n(study_start)),
    marker_y + offset,
    f"Start ({study_start.strftime('%b %-d')})",
    ha='center', va='bottom', fontsize=marker_fs
)

ax.axvline(float(d2n(main_start)), color='black', linewidth=1.2, linestyle=':')
ax.text(
    float(d2n(main_start)),
    marker_y + offset,
    f"Main study begins ({main_start.strftime('%b %-d')})",
    ha='center', va='bottom', fontsize=marker_fs, weight='bold'
)

ax.axvline(float(d2n(main_end)), color='black', linewidth=1.2, linestyle=':')
ax.text(
    float(d2n(main_end)),
    marker_y + offset,
    f"End ({main_end.strftime('%b %-d')})",
    ha='center', va='bottom', fontsize=marker_fs
)

# ---------------- Axis & Styling ----------------
ax.set_yticks([y_active, y_silent])
ax.set_yticklabels(["AI group", "Non-AI group"], fontsize=axis_fs)

# Add xticks every two weeks (starting at study_start) and set xlabels
tick_dates = []
tick_labels = []
current = study_start
while current <= main_end:
    tick_dates.append(float(d2n(current)))
    tick_labels.append(current.strftime('%b %-d'))
    current += timedelta(days=14)

ax.set_xticks(tick_dates)
ax.set_xticklabels(tick_labels, fontsize=axis_fs)

# Add faint full-height vertical guide lines at each xtick
for x in tick_dates:
    ax.axvline(x, color='gray', linestyle='-', linewidth=0.7, ymin=0, ymax=1, alpha=0.2)

# Two days of padding on each side of the study window.
ax.set_xlim(float(d2n(study_start)) - 2, float(d2n(main_end)) + 2)

# Remove all grid & spines
ax.grid(False)
for spine in ['top', 'right', 'bottom', 'left']:
    ax.spines[spine].set_visible(False)
ax.xaxis.set_ticks_position('none')
ax.yaxis.set_ticks_position('none')

plt.tight_layout()
plt.show()

In [None]:
# Save the timeline figure (`fig` from the previous cell) as study_timeline.png in plots_folder.
save_fig('study_timeline', fig)

In [None]:
# Participant-flow counts for the manuscript. The screening counts below are
# hard-coded (NOTE(review): their source is not visible here — confirm before
# re-running on new data); the asserts check they are internally consistent.
latexvars['nTotalVisits'] = 87931
latexvars['nEligibleVisits'] = 52409
latexvars['nIneligibleVisits'] = 35522
assert latexvars['nTotalVisits'] == latexvars['nEligibleVisits'] + latexvars['nIneligibleVisits']

# Ineligibility reasons must add up to the ineligible total.
latexvars['nOutsideNairobiCounty'] = 3878
latexvars['nIneligibleVisitType'] = 29210
latexvars['nNotSeenByProvider'] = 2434
assert latexvars['nIneligibleVisits'] == latexvars['nOutsideNairobiCounty'] + latexvars['nIneligibleVisitType'] + latexvars['nNotSeenByProvider']

# Consented count must match all_included_visits_len (defined in an earlier cell).
latexvars['nConsentedVisits'] = 40745
assert latexvars['nConsentedVisits'] == all_included_visits_len
latexvars['nSeenByProviderInBothGroups'] = crossover_dropped_len
latexvars['nNotConsentedVisits'] = latexvars['nEligibleVisits'] - latexvars['nConsentedVisits']

# Arm sizes after crossover removal ("Silent" = Non-AI group, "Active" = AI group).
latexvars['nSilentVisits'] = df[df.ClinicianGroup == 'Non-AI'].shape[0]
latexvars['nActiveVisits'] = df[df.ClinicianGroup == 'AI'].shape[0]
assert len(df) == latexvars['nSilentVisits'] + latexvars['nActiveVisits']

def has_outcome_data(s: pd.Series) -> bool:
    """Return True if the visit row has a non-missing ``Q1`` or ``Q2`` value."""
    return bool(pd.notna(s.Q1) or pd.notna(s.Q2))

# Visits with clinical documentation / outcome data available, per arm.
latexvars['nDocsSilentVisits'] = df[df.ClinicianGroup == 'Non-AI'].ClinicalDocumentation.notna().sum()
latexvars['nOutcomesSilentVisits'] = df[df.ClinicianGroup == 'Non-AI'].apply(has_outcome_data, axis=1).sum()

latexvars['nDocsActiveVisits'] = df[df.ClinicianGroup == 'AI'].ClinicalDocumentation.notna().sum()
latexvars['nOutcomesActiveVisits'] = df[df.ClinicianGroup == 'AI'].apply(has_outcome_data, axis=1).sum()

# Graphviz flow diagram: top-to-bottom layout with orthogonal edges.
dot = Digraph(comment="AI CDSS Study Flow Diagram (Visits) – v2", format="png")
dot.attr(rankdir="TB", splines="ortho")
dot.attr("node",
            shape="box",
            style="filled",
            fontname="Times",
            fontsize="10",
            color="black",
            width="2.4",
            fixedsize="width")  # lock width; height auto

dot.attr("edge", arrowhead="normal", arrowsize="0.7")

# --- Backbone column (col 1) ------------------------------------------
dot.node("total", f"Total visits during study period\n(n = {latexvars['nTotalVisits']})", fillcolor="white")
latexvars['pctEligible'] = f'{latexvars["nEligibleVisits"] / latexvars["nTotalVisits"] * 100:.1f}%'
latexvars['pctConsented'] = f'{latexvars["nConsentedVisits"] / latexvars["nEligibleVisits"] * 100:.1f}%'
dot.node("eligible", f"Eligible visits\n(n = {latexvars['nEligibleVisits']}, {latexvars['pctEligible']})", fillcolor="white")
dot.node("consent", f"Visits where patients consented to\nPenda's general consent form\n(n = {latexvars['nConsentedVisits']}, {latexvars['pctConsented']})", fillcolor="white")

# Invisible point to fan out edges
dot.node("split", "", shape="point", width="0.01", height="0.01", color="black")

dot.edge("total", "eligible")
dot.edge("eligible", "consent")
dot.edge("consent", "split", arrowhead="none")

# --- Column 4: Excluded / ineligible ----------------------------------
# Ineligibility percentages are relative to all visits; non-consent is relative
# to eligible visits (the box it branches from).
latexvars['pctIneligible'] = f'{latexvars["nIneligibleVisits"] / latexvars["nTotalVisits"] * 100:.1f}%'
latexvars['pctOutsideNairobiCounty'] = f'{latexvars["nOutsideNairobiCounty"] / latexvars["nTotalVisits"] * 100:.1f}%'
latexvars['pctIneligibleVisitType'] = f'{latexvars["nIneligibleVisitType"] / latexvars["nTotalVisits"] * 100:.1f}%'
latexvars['pctNotSeenByProvider'] = f'{latexvars["nNotSeenByProvider"] / latexvars["nTotalVisits"] * 100:.1f}%'
dot.node("ineligible",
            f"Ineligible visits (n = {latexvars['nIneligibleVisits']}, {latexvars['pctIneligible']})\n• Outside Nairobi County (n = {latexvars['nOutsideNairobiCounty']}, {latexvars['pctOutsideNairobiCounty']})\n• Ineligible visit type (n = {latexvars['nIneligibleVisitType']}, {latexvars['pctIneligibleVisitType']})\n• Not seen by a provider (n = {latexvars['nNotSeenByProvider']}, {latexvars['pctNotSeenByProvider']})",
            fillcolor="#eeeeee")
latexvars['pctNotConsented'] = f'{latexvars["nNotConsentedVisits"] / latexvars["nEligibleVisits"] * 100:.1f}%'
dot.node("nonconsent",
            f"Did not consent\n(n = {latexvars['nNotConsentedVisits']}, {latexvars['pctNotConsented']})",
            fillcolor="#eeeeee")
# "Both groups" is one of the three branches fanning out of the consented visits
# (with the Non-AI and AI groups), so it shares their denominator — consented
# visits — and the three branch percentages sum to 100%.
latexvars['pctSeenByProviderInBothGroups'] = f'{latexvars["nSeenByProviderInBothGroups"] / latexvars["nConsentedVisits"] * 100:.1f}%'
dot.node("both",
            f"Seen by providers in\nboth groups (excluded)\n(n = {latexvars['nSeenByProviderInBothGroups']}, {latexvars['pctSeenByProviderInBothGroups']})",
            fillcolor="#eeeeee")

# Horizontal arrows from backbone to exclusions: each exclusion box shares a rank
# with its backbone box, and constraint="false" keeps these edges from affecting ranking.
with dot.subgraph() as s:
    s.attr(rank="same")
    s.node("eligible")
    s.node("ineligible")
dot.edge("eligible:e", "ineligible:w", constraint="false")

with dot.subgraph() as s:
    s.attr(rank="same")
    s.node("consent")
    s.node("nonconsent")
dot.edge("consent:e", "nonconsent:w", constraint="false")

# Edge from split to 'both' (excluded third branch) - make it go to the top of the 'both' box
dot.edge("split", "both")

# --- Column 2: Non-AI ("silent") branch --------------------------------
# Arm size relative to consented visits; documentation/outcome availability relative to the arm.
latexvars['pctSilentVisits'] = f'{latexvars["nSilentVisits"] / latexvars["nConsentedVisits"] * 100:.1f}%'
latexvars['pctDocsSilentVisits'] = f'{latexvars["nDocsSilentVisits"] / latexvars["nSilentVisits"] * 100:.1f}%'
latexvars['pctOutcomesSilentVisits'] = f'{latexvars["nOutcomesSilentVisits"] / latexvars["nSilentVisits"] * 100:.1f}%'
dot.node("silent",
            f"Non-AI group visits\n(n = {latexvars['nSilentVisits']}, {latexvars['pctSilentVisits']})", fillcolor="#d0e7ff")
dot.node("docs_silent",
            f"Documentation available – non-AI group\n(n = {latexvars['nDocsSilentVisits']}, {latexvars['pctDocsSilentVisits']})", fillcolor="#d0e7ff")
dot.node("out_silent",
            f"Outcome data available – non-AI group\n(n = {latexvars['nOutcomesSilentVisits']}, {latexvars['pctOutcomesSilentVisits']})", fillcolor="#d0e7ff")

dot.edge("split", "silent")
dot.edge("silent", "docs_silent")
dot.edge("docs_silent", "out_silent")

# --- Column 3: AI ("active") branch ------------------------------------
latexvars['pctActiveVisits'] = f'{latexvars["nActiveVisits"] / latexvars["nConsentedVisits"] * 100:.1f}%'
latexvars['pctDocsActiveVisits'] = f'{latexvars["nDocsActiveVisits"] / latexvars["nActiveVisits"] * 100:.1f}%'
latexvars['pctOutcomesActiveVisits'] = f'{latexvars["nOutcomesActiveVisits"] / latexvars["nActiveVisits"] * 100:.1f}%'
dot.node("active",
            f"AI group visits\n(n = {latexvars['nActiveVisits']}, {latexvars['pctActiveVisits']})", fillcolor="#ffd6d6")
dot.node("docs_active",
            f"Documentation available – AI group\n(n = {latexvars['nDocsActiveVisits']}, {latexvars['pctDocsActiveVisits']})", fillcolor="#ffd6d6")
dot.node("out_active",
            f"Outcome data available – AI group\n(n = {latexvars['nOutcomesActiveVisits']}, {latexvars['pctOutcomesActiveVisits']})", fillcolor="#ffd6d6")

dot.edge("split", "active")
dot.edge("active", "docs_active")
dot.edge("docs_active", "out_active")

# --- Keep the exclusion boxes in one vertical column ---------------------
# Invisible edges stack ineligible -> nonconsent -> both so they line up.
dot.edge("ineligible", "nonconsent", style="invis")
dot.edge("nonconsent", "both", style="invis")

# Render PNG to bytes (in memory, no temp file)
png_bytes = dot.pipe(format='png')

# Display from in-memory bytes
display(Image(data=png_bytes))

# Save inside plots_folder. bf.join supplies the path separator (as save_fig does);
# plain string concatenation would write "<plots_folder>flow_diagram.png" next to the folder.
bf.write_bytes(bf.join(plots_folder, "flow_diagram.png"), png_bytes)

In [None]:
# Map each clinic to the region reported in Table 1.
location_name_map = {
    'Tassia': 'Eastlands',
    'Umoja 1': 'Eastlands',
    'Umoja 2': 'Eastlands',
    'Embakasi': 'Eastlands',
    'Pipeline': 'Eastlands',
    'Kangemi': 'Southwest',
    'Kawangware': 'Southwest',
    'Kimathi Street': 'Southwest',
    "Lang'ata": 'Southwest',
    'Mathare North': 'Thika Road Corridor',
    'Kasarani': 'Thika Road Corridor',
    'Sunton': 'Thika Road Corridor',
    'Lucky Summer': 'Thika Road Corridor',
    'Zimmerman': 'Thika Road Corridor',
    'Kahawa West': 'Thika Road Corridor',
}

df['LocationGroup'] = df['LocationName'].map(location_name_map)

# A clinic missing from the map gets a NaN LocationGroup and silently drops out of
# every location row of Table 1, so report any such clinics.
unmapped_locations = df.loc[df['LocationGroup'].isna(), 'LocationName'].unique().tolist()
if unmapped_locations:
    print(f"Warning: no LocationGroup mapping for locations: {unmapped_locations}")

In [None]:
# table 1: demographics

def _age_to_years(age_str: str | float | int) -> float:
    """
    Convert strings like “29y 10m” / “3y” / “8m” into decimal years.
    If already numeric (float / int) or NaN, return as-is.
    """
    if pd.isna(age_str):
        return np.nan
    if isinstance(age_str, (int, float)):
        return float(age_str)

    s = str(age_str).strip().lower()
    yrs, mths = 0, 0
    m = re.match(r'(?:(\d+)\s*y)?\s*(?:(\d+)\s*m)?', s)
    if m:
        if m.group(1):
            yrs = int(m.group(1))
        if m.group(2):
            mths = int(m.group(2))
    return yrs + mths / 12


def _fmt_n_pct(n: int, denom: int) -> str:
    """Format a count as ``"n (pct%)"``; n gets thousands separators, pct one decimal.

    A zero (falsy) ``denom`` yields a percentage of 0.0 instead of dividing by zero.
    """
    if not denom:
        return f"{n:,} ({0:0.1f}%)"
    share = 100 * n / denom
    return f"{n:,} ({share:0.1f}%)"


def _fmt_mean_sd(series: pd.Series) -> str:
    """Format ``series`` as ``"mean ± sd"`` with one decimal place each.

    NOTE(review): uses the population SD (ddof=0); Table 1 conventions usually
    report the sample SD (ddof=1). This helper is not called in the Table 1 cell.
    """
    centre = series.mean()
    spread = series.std(ddof=0)
    return f"{centre:0.1f} ± {spread:0.1f}"


# --------------------------------------------------------------------------- #
# Derive / clean variables
# --------------------------------------------------------------------------- #
df['VisitDate'] = pd.to_datetime(df['VisitDate'])
df['AgeYears'] = df['Age'].apply(_age_to_years)

# Study period of each visit, split at cut_point (values are the raw Period labels).
df["Period"] = np.where(
    df["VisitDate"] < cut_point,
    f"Visits before {cut_point_str}",
    f"Visits {cut_point_str} or later",
)

# Payment type collapsed to Insurance / Cash / Other: any other VisitType value,
# including missing, becomes "Other".
df["Payor"] = df["VisitType"].where(df["VisitType"].isin(["Insurance", "Cash"]), "Other")

df["RespondedTo8DayCall"] = df["CallDate"].notna().map({True: "Yes", False: "No"})

# --------------------------------------------------------------------------- #
# Sub-cohort to display and helper containers
# --------------------------------------------------------------------------- #
# df was already restricted to AI / Non-AI visits; the filter is kept as a safeguard.
cohort_df = df[df["ClinicianGroup"].isin(["AI", "Non-AI"])].copy()
groups = ["AI", "Non-AI"]
group_counts = (
    cohort_df["ClinicianGroup"]
    .value_counts()
    .reindex(groups)
    .fillna(0)
    .astype(int)
)

rows: list[dict] = []
section_breaks: list[int] = []  # index after which to insert a blank line

# --------------------------------------------------------------------------- #
# Helper functions for building the table
# --------------------------------------------------------------------------- #
def _add_total_n_row() -> None:
    """Append the "n" row: each group's visit count with thousands separators."""
    rows.append({
        "Variable": "n",
        **{grp: f"{group_counts[grp]:,}" for grp in groups},
    })


def _add_categorical(
    var: str,
    *,
    levels: list[str] | None = None,
    custom_labels: dict[str, str] | None = None,
) -> None:
    """
    Append one table row per level of the categorical column `var`.

    Each cell is ``"n (pct%)"`` for one group, with pct relative to that group's
    total visit count.

    Parameters
    ----------
    var : str
        Column of ``cohort_df`` to tabulate.
    levels : list[str] | None
        Levels to show, in order. ``None`` -> non-missing levels ordered by
        descending frequency (``value_counts()`` order).
    custom_labels : dict[str, str] | None
        ``raw_value -> label`` overrides; levels without one are shown as
        "<var>: <value>".
    """
    if levels is None:
        shown_levels = cohort_df[var].dropna().value_counts().index.tolist()
    else:
        shown_levels = levels
    label_overrides = custom_labels or {}

    for level in shown_levels:
        label = label_overrides.get(level)
        if label is None:
            label = f"{var}: {level}"

        new_row = {"Variable": label}
        for grp in groups:
            in_cell = (cohort_df["ClinicianGroup"] == grp) & (cohort_df[var] == level)
            new_row[grp] = _fmt_n_pct(int(in_cell.sum()), group_counts[grp])
        rows.append(new_row)


def _fmt_median_iqr(series: pd.Series) -> str:
    """Format ``series`` as ``"median [q25, q75]"`` with one decimal place each.

    Quartiles come from ``Series.quantile`` with its default (linear) interpolation.
    """
    lower = series.quantile(0.25)
    upper = series.quantile(0.75)
    middle = series.median()
    return f"{middle:0.1f} [{lower:0.1f}, {upper:0.1f}]"


def _add_continuous(series_name: str, display_name: str) -> None:
    """Append one row summarising column `series_name` as median [q25, q75] per group."""
    rows.append({
        "Variable": display_name,
        **{
            grp: _fmt_median_iqr(cohort_df.loc[cohort_df["ClinicianGroup"] == grp, series_name])
            for grp in groups
        },
    })


# --------------------------------------------------------------------------- #
# Build the descriptive table
# --------------------------------------------------------------------------- #
# 1. Study period
# Each section_breaks.append(len(rows)) marks where a blank spacer row goes later.
_add_total_n_row()
section_breaks.append(len(rows))

# Raw Period values are relabelled as induction / main study.
_add_categorical(
    "Period",
    levels=[
        f"Visits before {cut_point_str}",
        f"Visits {cut_point_str} or later",
    ],
    custom_labels={
        f"Visits before {cut_point_str}": f"Induction period (before {cut_point_str})",
        f"Visits {cut_point_str} or later": f"Main study period ({cut_point_str} or later)",
    },
)
section_breaks.append(len(rows))

# 2. Location
_add_categorical(
    "LocationGroup",
    levels=["Eastlands", "Southwest", "Thika Road Corridor"],
    custom_labels={"Eastlands": "Visit location: Eastlands clinics", "Southwest": "Visit location: Southwest clinics", "Thika Road Corridor": "Visit location: Thika Road Corridor clinics"},
)
section_breaks.append(len(rows))

# 3. Age
_add_continuous("AgeYears", "Age (years), median [q25, q75]")
section_breaks.append(len(rows))

# 4. Gender
_add_categorical(
    "Gender",
    levels=["Female", "Male"],
    custom_labels={"Female": "Female", "Male": "Male"},
)
section_breaks.append(len(rows))

# 5. Payor (only Insurance and Cash are shown; "Other" visits are not listed)
_add_categorical(
    "Payor",
    levels=["Insurance", "Cash"],
    custom_labels={
        "Insurance": "Insurance visit",
        "Cash": "Cash visit",
    },
)
section_breaks.append(len(rows))

# 6. Responded to 8-day follow-up call
_add_categorical(
    "RespondedTo8DayCall",
    levels=["Yes", "No"],
    custom_labels={
        "Yes": "Did respond to 8-day follow-up call",
        "No": "Did not respond to 8-day follow-up call",
    },
)
section_breaks.append(len(rows))

# --------------------------------------------------------------------------- #
# Insert blank rows between thematic sections
# --------------------------------------------------------------------------- #
# Copy rows section by section, adding an empty spacer row between sections
# (none after the last). Rows after the final break would be dropped, but the
# last break is len(rows), so every row is kept.
rows_with_blanks: list[dict] = []
last_break = 0
for i, break_idx in enumerate(section_breaks):
    rows_with_blanks.extend(rows[last_break:break_idx])
    if i < len(section_breaks) - 1:
        rows_with_blanks.append({"Variable": "", "AI": "", "Non-AI": ""})
    last_break = break_idx

# --------------------------------------------------------------------------- #
# Final DataFrame
# --------------------------------------------------------------------------- #
# Table 1 indexed by row label, with the Non-AI column first.
table1_df = (
    pd.DataFrame(rows_with_blanks)
    .set_index("Variable")
    .loc[:, ["Non-AI", "AI"]]
)

display(table1_df)
save_latex('table1', table1_df)

# Age: median and IQR for each group, read back from the formatted Table 1 cell strings.
age_row = table1_df.loc["Age (years), median [q25, q75]"]
def _parse_median_iqr(val):
    """Split a ``"median [q25, q75]"`` cell into its three numeric substrings.

    Returns ``(None, None, None)`` if ``val`` does not start with that pattern
    (e.g. ``"nan [nan, nan]"``).
    """
    match = re.match(r"([0-9.]+) \[([0-9.]+), ([0-9.]+)\]", val)
    return match.groups() if match else (None, None, None)

# Export median / quartile ages as strings (None if the cell could not be parsed).
medianAgeSilent, ageQFirstSilent, ageQThirdSilent = _parse_median_iqr(age_row["Non-AI"])
medianAgeActive, ageQFirstActive, ageQThirdActive = _parse_median_iqr(age_row["AI"])
latexvars["medianAgeSilent"] = medianAgeSilent
latexvars["ageQFirstSilent"] = ageQFirstSilent
latexvars["ageQThirdSilent"] = ageQThirdSilent
latexvars["medianAgeActive"] = medianAgeActive
latexvars["ageQFirstActive"] = ageQFirstActive
latexvars["ageQThirdActive"] = ageQThirdActive

# Gender: percent female for each group (parsed from the "n (pct%)" cells below)
female_row = table1_df.loc["Female"]
def _extract_pct(val):
    """Return the percentage from an ``"n (pct%)"`` cell as a LaTeX-escaped string.

    For example ``"1,234 (61.7%)"`` gives ``61.7`` followed by an escaped percent
    sign (backslash + ``%``); returns ``""`` if no ``(pct%)`` group is found.
    """
    found = re.search(r"\(([\d.]+)%\)", val)
    if not found:
        return ""
    return f"{float(found.group(1)):0.1f}\\%"
# Percentages are copied from the formatted Table 1 cells so text and table match.
latexvars["pctFemaleSilent"] = _extract_pct(female_row["Non-AI"])
latexvars["pctFemaleActive"] = _extract_pct(female_row["AI"])

# Payor: percent insurance for each group
insurance_row = table1_df.loc["Insurance visit"]
latexvars["pctInsuranceSilent"] = _extract_pct(insurance_row["Non-AI"])
latexvars["pctInsuranceActive"] = _extract_pct(insurance_row["AI"])

# 8-day follow-up: percent responded for each group
followup_row = table1_df.loc["Did respond to 8-day follow-up call"]
latexvars["pctFollowupSilent"] = _extract_pct(followup_row["Non-AI"])
latexvars["pctFollowupActive"] = _extract_pct(followup_row["AI"])

# Location: percent for each region and group
eastlands_row = table1_df.loc["Visit location: Eastlands clinics"]
southwest_row = table1_df.loc["Visit location: Southwest clinics"]
thika_row = table1_df.loc["Visit location: Thika Road Corridor clinics"]
latexvars["pctEastlandsSilent"] = _extract_pct(eastlands_row["Non-AI"])
latexvars["pctEastlandsActive"] = _extract_pct(eastlands_row["AI"])
latexvars["pctSouthwestSilent"] = _extract_pct(southwest_row["Non-AI"])
latexvars["pctSouthwestActive"] = _extract_pct(southwest_row["AI"])
latexvars["pctThikaSilent"] = _extract_pct(thika_row["Non-AI"])
latexvars["pctThikaActive"] = _extract_pct(thika_row["AI"])

In [None]:
# Provider workload: number of visits (rows) each UserID appears on, summarised per
# ClinicianGroup as median and IQR (q25, q75). A visit listing several UserIDs counts
# once for each of them.
user_group_counts = (
    df[['UserIDs', 'ClinicianGroup']]
    .explode('UserIDs')
    .dropna(subset=['UserIDs', 'ClinicianGroup'])
    .groupby(['ClinicianGroup', 'UserIDs'])
    .size()
    .rename('n_rows')
    .reset_index()
)

# Median and IQR (q25, q75) number of rows per UserID in each ClinicianGroup
agg_rows_per_userid = (
    user_group_counts.groupby('ClinicianGroup')['n_rows']
    .agg(
        median='median',
        q25=lambda x: x.quantile(0.25),
        q75=lambda x: x.quantile(0.75)
    )
)

# Round with :.0f, the same format the latexvars below use, so the table and the
# manuscript text agree (int() would truncate, e.g. a q75 of 12.75 -> 12 vs "13").
median_iqr_series = agg_rows_per_userid.apply(
    lambda row: f"{row['median']:.0f} [{row['q25']:.0f}-{row['q75']:.0f}]", axis=1
)

# Number of unique UserIDs in each ClinicianGroup
unique_userids_per_group = (
    user_group_counts.groupby('ClinicianGroup')['UserIDs'].nunique()
)

# Concatenate unique_userids_per_group and median_iqr_series into a single DataFrame for summary
userids_and_median_iqr = pd.concat(
    [unique_userids_per_group.rename("n providers"), median_iqr_series.rename("median [q25-q75]")],
    axis=1
)
userids_and_median_iqr = userids_and_median_iqr.rename_axis("Group")

latexvars['nSilentAI'] = unique_userids_per_group['Non-AI']
nsilentai = unique_userids_per_group['Non-AI']
latexvars['nActiveAI'] = unique_userids_per_group['AI']
nactiveai = unique_userids_per_group['AI']
# NOTE(review): assumes no UserID appears in both groups, otherwise it is counted twice — confirm.
latexvars['nProviders'] = unique_userids_per_group['Non-AI'] + unique_userids_per_group['AI']

latexvars["visitsPerProviderMedSilent"] = f"{agg_rows_per_userid['median']['Non-AI']:.0f}"
latexvars["visitsPerProviderQFirstSilent"] = f"{agg_rows_per_userid['q25']['Non-AI']:.0f}"
latexvars["visitsPerProviderQThirdSilent"] = f"{agg_rows_per_userid['q75']['Non-AI']:.0f}"
latexvars["visitsPerProviderMedActive"] = f"{agg_rows_per_userid['median']['AI']:.0f}"
latexvars["visitsPerProviderQFirstActive"] = f"{agg_rows_per_userid['q25']['AI']:.0f}"
latexvars["visitsPerProviderQThirdActive"] = f"{agg_rows_per_userid['q75']['AI']:.0f}"


save_latex('table2', userids_and_median_iqr)
userids_and_median_iqr

## Clinician rater study analysis

### Load data

In [None]:
def jsonl_load(file_path: str) -> list[dict]:
    """
    Load a JSON Lines file (one JSON value per line) into a list of records.

    The file is read as UTF-8 and streamed line by line. Blank or whitespace-only
    lines are skipped.

    Raises:
        ValueError: if a non-blank line is not valid JSON. The message gives the
            1-based physical line number in the file, counting blank lines.
    """
    records: list[dict] = []

    with open(file_path, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON on line {line_no} of {file_path!s}") from e

    return records

# Load the clinician rater results, give the form Likert columns descriptive names,
# attach visit-level data from `df`, and drop ratings whose visit has no
# ClinicianGroup after the left merge (i.e. not in the analysis set).
results_clinical_study = jsonl_load(clinical_study_results_fp)
results_clinical_study_df = (
    pd.DataFrame(results_clinical_study)
    .rename(columns={
        'form_a_likert': 'history_likert',
        'form_b_likert': 'investigations_likert',
        'form_c_likert': 'diagnosis_likert',
        'form_d_likert': 'treatment_likert',
    })
    .merge(df, left_on='visit_code', right_on='VisitCode', how='left')
    .dropna(subset=['ClinicianGroup'])
)

In [None]:
# Distinct rater (trainer) e-mail addresses among the retained ratings.
all_trainer_emails = results_clinical_study_df['trainer_email'].unique()

In [None]:
# Rating-exercise counts for the paper (computed on all loaded ratings, before the merge filter).
n_total_ratings = len(results_clinical_study)
visit_code_counts = Counter(state['visit_code'] for state in results_clinical_study)
n_unique_visits_rated = len(visit_code_counts)
# ratings-per-visit -> number of visits with that many ratings
visit_code_count_counts = Counter(visit_code_counts.values())
n_visits_single_rated = visit_code_count_counts[1]
n_visits_double_rated = visit_code_count_counts[2]
# Every rated visit must have exactly one or two ratings.
assert n_unique_visits_rated == n_visits_single_rated + n_visits_double_rated

latexvars['nTotalHumanRatings'] = n_total_ratings
latexvars['nUniqueVisitsHumanRated'] = n_unique_visits_rated
latexvars['nVisitsSingleRated'] = n_visits_single_rated
latexvars['nVisitsDoubleRated'] = n_visits_double_rated

In [None]:
def intify_likert(likert: str | int | float | None) -> int:
    """
    Convert a Likert value stored either as an int (1–5) or as a string whose
    first character is the integer code.  Raises ValueError otherwise.

    NOTE: `None` values are interpreted as NaN and will raise.
    """
    if isinstance(likert, str):
        try:
            return_val = int(likert[0])
        except (TypeError, ValueError):
            raise ValueError(
                f"Str value for Likert that could not be converted to int by taking first character: {likert=}"
            )
    elif isinstance(likert, int):
        return_val = likert
    elif isinstance(likert, float):
        assert likert.is_integer(), f"Invalid Likert value: {likert}"
        return_val = int(likert)
    else:  # includes `None`
        raise ValueError(f"Invalid value type: {type(likert)}, {likert=}")

    assert 1 <= return_val <= 5, f"Invalid Likert value: {return_val}"
    return return_val

# Normalise every section's Likert column to ints 1-5; unparseable values raise.
for col in ['history_likert', 'investigations_likert', 'diagnosis_likert', 'treatment_likert']:
    results_clinical_study_df[col] = results_clinical_study_df[col].apply(intify_likert)

In [None]:
# Blank out (set to NA, not zero) the history Likert score and MCQ for visits after
# last_visitdate_with_cc (defined in an earlier cell; NOTE(review): presumably the
# last visit date with chief-complaint data — confirm).
results_clinical_study_df.loc[results_clinical_study_df.VisitDate > last_visitdate_with_cc, ['history_likert']] = pd.NA
results_clinical_study_df.loc[results_clinical_study_df.VisitDate > last_visitdate_with_cc, ['history_mcq']] = pd.NA

In [None]:
def is_low_likert(likert: int | NAType) -> int | NAType:
    """Flag a low Likert score: 1 if ``likert`` <= 2, else 0; ``pd.NA`` if missing.

    Integral floats (e.g. 2.0) are accepted and treated as ints.

    Raises:
        ValueError: if the value is not missing and is not an int or an integral float.
    """
    if is_na(likert):
        return pd.NA

    if isinstance(likert, float) and likert.is_integer():
        likert = int(likert)

    if not isinstance(likert, int):
        raise ValueError(f"likert must be an int, got {type(likert)}. Likert: {likert}")

    return int(likert <= 2)

# Binary "low score" indicators per section: 1 for Likert 1-2, 0 for 3-5, NA if missing.
results_clinical_study_df['history_likert_is_low'] = results_clinical_study_df['history_likert'].apply(is_low_likert)
results_clinical_study_df['investigations_likert_is_low'] = results_clinical_study_df['investigations_likert'].apply(is_low_likert)
results_clinical_study_df['diagnosis_likert_is_low'] = results_clinical_study_df['diagnosis_likert'].apply(is_low_likert)
results_clinical_study_df['treatment_likert_is_low'] = results_clinical_study_df['treatment_likert'].apply(is_low_likert)

In [None]:
# Reduce a free-text acuity MCQ answer to its level name.
def get_acuity(mcq_str: str) -> str:
    """Return 'Medium', 'High' or 'Low' according to the prefix of ``mcq_str``.

    The prefix match is case-sensitive.

    Raises:
        ValueError: if ``mcq_str`` starts with none of those words.
    """
    for level in ('Medium', 'High', 'Low'):
        if mcq_str.startswith(level):
            return level
    raise ValueError(f'Unknown acuity: {mcq_str}')

# Replace the raw acuity answer (form_e_mcq) with its level: Low / Medium / High.
results_clinical_study_df['acuity'] = results_clinical_study_df.form_e_mcq.apply(get_acuity)
results_clinical_study_df.drop(columns=['form_e_mcq'], inplace=True)

In [None]:
# Each retained visit must have been rated exactly once or twice.
assert results_clinical_study_df['VisitCode'].value_counts().isin([1, 2]).all()

In [None]:
# Row weights: 0.5 for each rating of a double-rated visit, 1.0 otherwise. Visits have
# at most two ratings (checked above), so every visit carries a total weight of 1.
results_clinical_study_df['row_weight'] = np.where(results_clinical_study_df.duplicated("VisitCode", keep=False), 0.5, 1.0)

### Plotting utils

In [None]:
class TwoByTwoTable(BaseModel):
    """Cell counts of a 2x2 table: two groups x (outcome true / outcome false).

    ``get_2x2_stats`` passes group one as the control arm and group two as the
    exposed arm of ``relative_risk``.
    """

    # Outcome-true / outcome-false counts for group one.
    group_one_true_count: int
    group_one_false_count: int
    # Outcome-true / outcome-false counts for group two.
    group_two_true_count: int
    group_two_false_count: int

    @property
    def group_one_n(self) -> int:
        """Total number of group-one observations."""
        return sum((self.group_one_true_count, self.group_one_false_count))

    @property
    def group_two_n(self) -> int:
        """Total number of group-two observations."""
        total = self.group_two_true_count
        total += self.group_two_false_count
        return total

def get_2x2_table(
    df: pd.DataFrame,
    allocation_col: str,
    allocation_group_one_values: list,
    allocation_group_two_values: list,
    outcome_col: str,
    outcome_true_values: list,
    outcome_false_values: list,
    unique_column: str | None = None
) -> TwoByTwoTable:
    """
    Tally a 2x2 table (allocation group x binary outcome) from ``df``.

    Rows are kept only if the allocation value is in one of the two group lists
    and the outcome value is in one of the two outcome lists.

    When ``unique_column`` is given, each kept row is weighted by 1 / (number of
    kept rows sharing its ``unique_column`` value), so a unit observed k times
    contributes its k observations averaged. Each cell's weighted total is then
    rounded with Python's ``round`` (ties go to the even integer).

    Raises:
        ValueError: if a kept row's allocation or outcome value fails the
            per-row ``in`` membership check (allocation is checked first).
    """
    allocation = df[allocation_col]
    outcome = df[outcome_col]
    keep = (
        (allocation.isin(allocation_group_one_values) | allocation.isin(allocation_group_two_values))
        & (outcome.isin(outcome_true_values) | outcome.isin(outcome_false_values))
    )
    relevant_df = df[keep]

    multiplicity = Counter(relevant_df[unique_column]) if unique_column is not None else {}

    # Weighted cell totals keyed by (group, outcome), accumulated in row order.
    totals = {("one", True): 0, ("one", False): 0, ("two", True): 0, ("two", False): 0}

    def _outcome_flag(value) -> bool:
        # Map an outcome value to True/False; true values take precedence.
        if value in outcome_true_values:
            return True
        if value in outcome_false_values:
            return False
        raise ValueError(f"Outcome value {value} not in outcome_true_values or outcome_false_values")

    for _, row in relevant_df.iterrows():
        alloc_value = row[allocation_col]
        outcome_value = row[outcome_col]
        weight = 1.0 / multiplicity[row[unique_column]] if unique_column is not None else 1.0

        if alloc_value in allocation_group_one_values:
            group = "one"
        elif alloc_value in allocation_group_two_values:
            group = "two"
        else:
            raise ValueError(f"Allocation value {alloc_value} not in allocation_group_one_values or allocation_group_two_values")

        totals[(group, _outcome_flag(outcome_value))] += weight

    return TwoByTwoTable(
        group_one_true_count=round(totals[("one", True)]),
        group_one_false_count=round(totals[("one", False)]),
        group_two_true_count=round(totals[("two", True)]),
        group_two_false_count=round(totals[("two", False)]),
    )

def get_2x2_stats(
    table: TwoByTwoTable,
) -> dict[str, float]:
    """
    Summarise a 2x2 allocation-by-outcome table.

    Group one is treated as the control arm and group two as the exposed arm,
    so ``relative_risk`` is (group-two rate) / (group-one rate) and
    ``RRR = 1 - relative_risk``.

    Returned keys:
    * ``first_group_*`` / ``second_group_*``: outcome count (``y``), group size
      (``n``), observed rate, and 95 % Wilson CI bounds from ``binomtest``.
    * ``PVal``: Fisher exact test p-value for the 2x2 table.
    * ``relative_risk`` with its 95 % CI, and ``RRR`` whose CI is ``1 - RR``
      applied to the RR bounds (so the bounds swap).
    * ``ARR``: group-one rate minus group-two rate.
    * ``NNT``: ``1 / ARR`` when ``ARR > 0``, otherwise ``numpy.inf``.
    * ``symmetric_difference``: rate difference divided by the mean of the two
      rates, or ``pd.NA`` when the rates sum to 0.
    """

    def _rate_with_wilson_ci(successes, trials):
        # Observed proportion plus its 95 % Wilson score interval.
        outcome = binomtest(successes, trials)
        interval = outcome.proportion_ci(confidence_level=0.95, method='wilson')
        return outcome.statistic, interval.low, interval.high

    rate_one, low_one, high_one = _rate_with_wilson_ci(table.group_one_true_count, table.group_one_n)
    rate_two, low_two, high_two = _rate_with_wilson_ci(table.group_two_true_count, table.group_two_n)

    contingency = np.array([
        [table.group_one_true_count, table.group_one_false_count],
        [table.group_two_true_count, table.group_two_false_count],
    ])
    p_value = fisher_exact(contingency).pvalue

    # Group two is the "exposed" arm, group one the "control" arm.
    rr_result = relative_risk(
        exposed_cases=table.group_two_true_count,
        exposed_total=table.group_two_n,
        control_cases=table.group_one_true_count,
        control_total=table.group_one_n,
    )
    rr_interval = rr_result.confidence_interval(confidence_level=0.95)

    arr = rate_one - rate_two
    nnt = 1 / arr if arr > 0 else np.inf

    rate_total = rate_one + rate_two
    sym_diff = arr / (rate_total / 2) if rate_total > 0 else pd.NA

    return {
        'first_group_y': table.group_one_true_count,
        'first_group_n': table.group_one_n,
        'first_group_rate': rate_one,
        'first_group_lower_CI': low_one,
        'first_group_upper_CI': high_one,
        'second_group_y': table.group_two_true_count,
        'second_group_n': table.group_two_n,
        'second_group_rate': rate_two,
        'second_group_lower_CI': low_two,
        'second_group_upper_CI': high_two,
        'PVal': p_value,
        'relative_risk': rr_result.relative_risk,
        'relative_risk_lower_CI': rr_interval.low,
        'relative_risk_upper_CI': rr_interval.high,
        'RRR': 1 - rr_result.relative_risk,
        'RRR_low_CI': 1 - rr_interval.high,
        'RRR_high_CI': 1 - rr_interval.low,
        'ARR': arr,
        'NNT': nnt,
        'symmetric_difference': sym_diff,
    }

In [None]:
# Scratch check: AI vs non-AI 2x2 table and stats for the treatment question,
# restricted to visits on/after cut_point. Likert 1-2 count as the outcome,
# 3-5 as not; each row is weighted by 1 / (number of rows for its VisitCode).
# NOTE(review): `table` / `stats` are not referenced again in this part of the notebook.
table = get_2x2_table(
    df = results_clinical_study_df[results_clinical_study_df.VisitDate >= cut_point],
    allocation_col = 'ClinicianGroup',
    allocation_group_one_values = ['Non-AI'],
    allocation_group_two_values = ['AI'],
    outcome_col = 'treatment_likert',
    outcome_true_values = [1, 2],
    outcome_false_values = [3, 4, 5],
    unique_column = 'VisitCode'
)
stats = get_2x2_stats(table)

def get_likert_bar_plot_data(
    df: pd.DataFrame | list[pd.DataFrame],
    allocation_col: str | list[str] = 'ClinicianGroup',
    allocation_group_one_values = ['Non-AI'],
    allocation_group_two_values = ['AI'],
    outcome_cols = ['history_likert', 'investigations_likert', 'diagnosis_likert', 'treatment_likert'],
    outcome_true_values = [1, 2],
    outcome_false_values = [3, 4, 5],
    unique_column: str | None = 'VisitCode'
) -> dict[str, dict[str, float]]:
    """
    Compute ``get_2x2_stats`` for every Likert outcome column.

    ``df`` and ``allocation_col`` may each be a single value (reused for every
    outcome) or a list with exactly one entry per element of ``outcome_cols``.
    Length mismatches raise ValueError from ``zip(..., strict=True)``.

    Returns a dict keyed by the outcome column with ``'_likert'`` removed and
    ``str.capitalize`` applied (``'history_likert'`` -> ``'History'``), in
    ``outcome_cols`` order.
    """
    n_outcomes = len(outcome_cols)
    allocations = [allocation_col] * n_outcomes if isinstance(allocation_col, str) else allocation_col
    frames = [df] * n_outcomes if isinstance(df, pd.DataFrame) else df

    stats_by_outcome = {}
    for frame, allocation, outcome in zip(frames, allocations, outcome_cols, strict=True):
        label = outcome.replace('_likert', '').capitalize()
        two_by_two = get_2x2_table(
            df = frame,
            allocation_col = allocation,
            allocation_group_one_values = allocation_group_one_values,
            allocation_group_two_values = allocation_group_two_values,
            outcome_col = outcome,
            outcome_true_values = outcome_true_values,
            outcome_false_values = outcome_false_values,
            unique_column = unique_column
        )
        stats_by_outcome[label] = get_2x2_stats(two_by_two)
    return stats_by_outcome

def get_rrr_bar_plot_data(
    dfs: list[pd.DataFrame] | None = None,
    likert_bar_plot_data_sets: list[dict[str, dict[str, float]]] | None = None,
    allocation_col = 'ClinicianGroup',
    allocation_group_one_values = ['Non-AI'],
    allocation_group_two_values = ['AI'],
    outcome_cols = ['history_likert', 'investigations_likert', 'diagnosis_likert', 'treatment_likert'],
    outcome_true_values = [1, 2],
    outcome_false_values = [3, 4, 5],
    unique_column = 'VisitCode'
) -> dict[str, dict[str, float]]:
    assert dfs is not None or likert_bar_plot_data_sets is not None, "Either dfs or likert_bar_plot_data_sets must be provided"
    assert dfs is None or likert_bar_plot_data_sets is None, "Only one of dfs or likert_bar_plot_data_sets can be provided"

    if dfs is not None:
        assert len(dfs) == 2, "Only two dataframes are supported"
        likert_bar_plot_data_sets = [
            get_likert_bar_plot_data(
                df = df,
                allocation_col = allocation_col,
                allocation_group_one_values = allocation_group_one_values,
                allocation_group_two_values = allocation_group_two_values,
                outcome_cols = outcome_cols,
                outcome_true_values = outcome_true_values,
                outcome_false_values = outcome_false_values,
                unique_column = unique_column
            )
            for df in dfs
        ]

    assert likert_bar_plot_data_sets is not None, "likert_bar_plot_data_sets is provided"

    categories = list(likert_bar_plot_data_sets[0].keys())

    return {
        category: {
            'first_group_rate': likert_bar_plot_data_sets[0][category]['RRR'],
            'first_group_lower_CI': likert_bar_plot_data_sets[0][category]['RRR_low_CI'],
            'first_group_upper_CI': likert_bar_plot_data_sets[0][category]['RRR_high_CI'],
            'second_group_rate': likert_bar_plot_data_sets[1][category]['RRR'],
            'second_group_lower_CI': likert_bar_plot_data_sets[1][category]['RRR_low_CI'],
            'second_group_upper_CI': likert_bar_plot_data_sets[1][category]['RRR_high_CI'],
        }
        for category in categories
    }

In [None]:
def _sig_star(p: float | NAType | None) -> str:
    if pd.isna(p):
        return ""
    if p < 0.001:
        return "***"
    if p < 0.01:
        return "**"
    if p < 0.05:
        return "*"
    return ""

def clustered_bar_plot(
    data: dict[str, dict[str, float]],
    group_one_name: str = "Group 1",
    group_two_name: str = "Group 2",
    x_axis_title: str = "",
    y_axis_title: str = "",
    fig_title: str = "",
    figsize: tuple[int, int] = (10, 6),
    bar_width: float = 0.35,
    colors: tuple[str, str] = ("C0", "C1"),
    y_axis_percent: bool = True,
    show_significance: bool = True,
):
    """
    Create a clustered (side-by-side) bar chart with error bars and
    return the underlying data in tabular form.

    Parameters
    ----------
    data
        Mapping of cluster label -> stats dict. Each stats dict must contain
        ``first_group_rate``, ``first_group_lower_CI``, ``first_group_upper_CI``,
        ``second_group_rate``, ``second_group_lower_CI`` and
        ``second_group_upper_CI``. An optional ``PVal`` drives the significance
        stars. Clusters are drawn in the dict's iteration order.
    group_one_name, group_two_name
        Legend and table labels for the first and second bar in each cluster.
    x_axis_title, y_axis_title, fig_title
        Optional titles. An empty string means "not set".
    figsize, bar_width, colors
        Matplotlib sizing and colour options.
    y_axis_percent
        Format y tick labels as percentages of 1.0.
    show_significance
        Draw ``_sig_star`` markers above each cluster and include ``p_value``
        in the returned table.

    The y-axis starts at 0 when every bar and CI bound is non-negative.
    Otherwise the autoscaled lower limit is kept, so negative values (e.g. a
    negative relative risk reduction) stay visible.

    Returns
    -------
    (matplotlib.figure.Figure, pandas.DataFrame)
        Figure object and a DataFrame with one row per (cluster, group) and
        columns ['cluster', 'group name', 'value', 'lower CI', 'upper CI',
        'p_value']. ``p_value`` is None when ``show_significance`` is False or
        the cluster has no ``PVal``.
    """
    # Pull the ordered categories and their statistics
    categories = list(data.keys())
    n = len(categories)
    idx = np.arange(n)

    g1_vals = [data[c]["first_group_rate"] for c in categories]
    g2_vals = [data[c]["second_group_rate"] for c in categories]
    g1_lows = [data[c]["first_group_lower_CI"] for c in categories]
    g1_highs = [data[c]["first_group_upper_CI"] for c in categories]
    g2_lows = [data[c]["second_group_lower_CI"] for c in categories]
    g2_highs = [data[c]["second_group_upper_CI"] for c in categories]

    # Asymmetric errors as distances from the bar top: shape (2, N) → [lower, upper]
    g1_err_lower = [v - lo for v, lo in zip(g1_vals, g1_lows)]
    g1_err_upper = [hi - v for v, hi in zip(g1_vals, g1_highs)]
    g2_err_lower = [v - lo for v, lo in zip(g2_vals, g2_lows)]
    g2_err_upper = [hi - v for v, hi in zip(g2_vals, g2_highs)]

    g1_yerr = np.array([g1_err_lower, g1_err_upper])
    g2_yerr = np.array([g2_err_lower, g2_err_upper])

    # Plot
    fig, ax = plt.subplots(figsize=figsize)
    ax.bar(
        idx - bar_width / 2,
        g1_vals,
        width=bar_width,
        label=group_one_name,
        color=colors[0],
        yerr=g1_yerr,
        capsize=5,
        align="center",
    )
    ax.bar(
        idx + bar_width / 2,
        g2_vals,
        width=bar_width,
        label=group_two_name,
        color=colors[1],
        yerr=g2_yerr,
        capsize=5,
        align="center",
    )

    # Significance stars, placed just above the taller bar (including its CI)
    if show_significance:
        for i, c in enumerate(categories):
            star = _sig_star(data[c].get("PVal"))
            if not star:
                continue
            top_height = max(
                g1_vals[i] + g1_err_upper[i],
                g2_vals[i] + g2_err_upper[i],
            )
            ax.text(
                x=i,
                y=top_height,
                s=star,
                ha="center",
                va="bottom",
                fontsize=14,
                fontweight="bold",
                color="black",
            )

    # Aesthetics
    ax.set_xticks(idx)
    ax.set_xticklabels(categories, rotation=0)
    if x_axis_title:
        ax.set_xlabel(x_axis_title)
    if y_axis_title:
        ax.set_ylabel(y_axis_title)
    if fig_title:
        ax.set_title(fig_title)
    ax.legend()
    ax.margins(x=0.05)

    if y_axis_percent:
        ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))

    # Anchor the y-axis at 0 only when nothing drawn goes below zero. Forcing
    # bottom=0 unconditionally hid negative bars/CI whiskers (e.g. negative RRR).
    plotted_lows = np.array(g1_vals + g2_vals + g1_lows + g2_lows, dtype=float)
    finite_lows = plotted_lows[np.isfinite(plotted_lows)]
    has_negative = finite_lows.size > 0 and finite_lows.min() < 0
    ymin, ymax = ax.get_ylim()
    ax.set_ylim(bottom=ymin if has_negative else 0, top=max(ymax, 0))

    fig.tight_layout()

    # Build the summary table: one row per (cluster, group)
    table_rows = []
    for i, cat in enumerate(categories):
        p_value = data[cat].get("PVal") if show_significance else None
        table_rows.append(
            {
                "cluster": cat,
                "group name": group_one_name,
                "value": g1_vals[i],
                "lower CI": g1_lows[i],
                "upper CI": g1_highs[i],
                "p_value": p_value,
            }
        )
        table_rows.append(
            {
                "cluster": cat,
                "group name": group_two_name,
                "value": g2_vals[i],
                "lower CI": g2_lows[i],
                "upper CI": g2_highs[i],
                "p_value": p_value,
            }
        )

    table_df = pd.DataFrame(table_rows)

    return fig, table_df

def capitalize_first_letter(s: str) -> str:
    """
    Upper-case only the first character of ``s``, leaving the rest untouched.

    Unlike ``str.capitalize`` this does not lower-case the remainder, so
    ``"PVal"`` stays ``"PVal"``. An empty string is returned unchanged.
    Slicing is used instead of ``s[0]``, which raises IndexError on ``""``
    (e.g. from an empty ``split("_")`` component).
    """
    return s[:1].upper() + s[1:]

def likert_bar_plot_data_to_latex_vars(
    data: dict[str, dict[str, float]],
    prefix: str,
    percent_variable: bool = True,
) -> dict[str, str]:
    """
    Flatten per-category 2x2 stats into formatted LaTeX variable strings.

    Key construction, for each ``metric_label`` (e.g. ``'History'``) and each
    stat key in its dict (e.g. ``'first_group_rate'``):

    1. The stat key is split on ``_`` and each part's first letter is
       upper-cased (``'FirstGroupRate'``).
    2. The metric label is inserted right after the first of ``FirstGroup``,
       ``SecondGroup``, ``RRR``, ``NNT`` (checked in that order) found in it.
       If none is found, the label is prepended.
    3. ``prefix`` is prepended.

    Example: ``prefix='induction'``, ``'History'`` / ``'first_group_rate'``
    -> ``'inductionFirstGroupHistoryRate'``.

    Value formatting, checked in this order on the final key:

    * ends with uppercase ``N`` or ``Y`` (counts): ``'.0f'``
    * contains ``NNT``: ``'.1f'``
    * contains ``PVal``: ``'.3f'``
    * otherwise ``value * 100`` as ``'.1f'`` plus ``%`` when
      ``percent_variable``, else ``'.1f'``.

    NOTE(review): the N/Y suffix test runs first and inspects the whole key,
    so a metric label ending in uppercase N/Y that lands at the end of a key
    (e.g. after ``NNT``) would be formatted as an integer.
    """
    # this function is this way for backwards compatibility with the old latex vars

    latex_vars = {}
    for metric_label, metric_data in data.items():
        for key, value in metric_data.items():
            latexified_key = ''.join([capitalize_first_letter(x) for x in key.split("_")])

            metric_label_camel = capitalize_first_letter(metric_label)

            # Determine where (or whether) to insert the metric label
            insertion_points = ["FirstGroup", "SecondGroup", "RRR", "NNT"]
            updated_key = None
            for substr in insertion_points:
                if substr in latexified_key:
                    # Insert the metric label immediately after the recognised substring
                    updated_key = latexified_key.replace(substr, f"{substr}{metric_label_camel}")
                    break

            # If none of the substrings were found, prepend the metric label
            if updated_key is None:
                updated_key = f"{metric_label_camel}{latexified_key}"

            # add on the prefix
            updated_key = f"{prefix}{updated_key}"

            # Counts (e.g. ...N / ...Y) are whole numbers
            if updated_key.endswith("N") or updated_key.endswith("Y"):
                value = f"{value:.0f}"

            elif "NNT" in updated_key:
                value = f"{value:.1f}"

            elif "PVal" in updated_key:
                value = f"{value:.3f}"

            # Rates, RR, RRR, ARR, ...: fractions rendered as percentages by default
            else:
                if percent_variable:
                    value = f"{value * 100:.1f}%"
                else:
                    value = f"{value:.1f}"

            latex_vars[updated_key] = value

    return latex_vars

In [None]:
def lbp_data_to_group_one_series(lbp_data):
    """
    Format the first group's rate and CI per category as ``"r% (lo%-hi%)"``.

    Fractions are multiplied by 100 and shown with one decimal place. The
    result is a Series indexed by category.
    """
    formatted = {}
    for label, stats in lbp_data.items():
        rate = stats['first_group_rate'] * 100
        low = stats['first_group_lower_CI'] * 100
        high = stats['first_group_upper_CI'] * 100
        formatted[label] = f"{rate:.1f}% ({low:.1f}%-{high:.1f}%)"
    return pd.Series(formatted)

def lbp_data_to_group_two_series(lbp_data):
    """
    Format the second group's rate and CI per category as ``"r% (lo%-hi%)"``.

    Fractions are multiplied by 100 and shown with one decimal place. The
    result is a Series indexed by category.
    """
    formatted = {}
    for label, stats in lbp_data.items():
        rate = stats['second_group_rate'] * 100
        low = stats['second_group_lower_CI'] * 100
        high = stats['second_group_upper_CI'] * 100
        formatted[label] = f"{rate:.1f}% ({low:.1f}%-{high:.1f}%)"
    return pd.Series(formatted)

def lbp_data_to_pval_series(lbp_data):
    """Return each category's ``PVal`` formatted to three decimal places, as a Series."""
    labels = list(lbp_data)
    p_strings = [f"{lbp_data[label]['PVal']:.3f}" for label in labels]
    return pd.Series(dict(zip(labels, p_strings)))

def lbp_data_to_rrr_series(lbp_data):
    """
    Format each category's relative risk reduction and CI as ``"r% (lo%-hi%)"``.

    Fractions are multiplied by 100 and shown with one decimal place. The
    result is a Series indexed by category.
    """
    formatted = {}
    for label, stats in lbp_data.items():
        point = stats['RRR'] * 100
        low = stats['RRR_low_CI'] * 100
        high = stats['RRR_high_CI'] * 100
        formatted[label] = f"{point:.1f}% ({low:.1f}%-{high:.1f}%)"
    return pd.Series(formatted)

def lbp_data_to_NNT_series(lbp_data):
    """
    Return each category's number needed to treat (one decimal place) as a Series.

    Categories whose ``PVal`` is not below 0.05 are shown as ``"-"``.
    """
    formatted = {}
    for label, stats in lbp_data.items():
        if stats['PVal'] < 0.05:
            formatted[label] = f"{stats['NNT']:.1f}"
        else:
            formatted[label] = "-"
    return pd.Series(formatted)

PENDA_ANNUAL_PATIENT_VOLUME = 400000
def lbp_data_to_absolute_benefit_series(lbp_data):
    """
    Return yearly errors averted per category, i.e. ``ARR * PENDA_ANNUAL_PATIENT_VOLUME``.

    Values are rounded to whole numbers. Categories whose ``PVal`` is not
    below 0.05 are shown as ``"-"``.
    """
    averted = {}
    for label, stats in lbp_data.items():
        if stats['PVal'] < 0.05:
            averted[label] = f"{stats['ARR'] * PENDA_ANNUAL_PATIENT_VOLUME:.0f}"
        else:
            averted[label] = "-"
    return pd.Series(averted)

### Main study analysis and main period vs induction period

In [None]:
# Split visits at cut_point: earlier visits form the induction period,
# visits on/after cut_point form the main study period.
induction_period_df = results_clinical_study_df[results_clinical_study_df.VisitDate < cut_point]
main_study_df = results_clinical_study_df[results_clinical_study_df.VisitDate >= cut_point]

# Per-category AI vs non-AI stats for each period, exported as LaTeX variables.
induction_period_lbp_data = get_likert_bar_plot_data(induction_period_df)
lv = likert_bar_plot_data_to_latex_vars(induction_period_lbp_data, "induction")
latexvars.update(lv)

main_study_lbp_data = get_likert_bar_plot_data(main_study_df)
lv = likert_bar_plot_data_to_latex_vars(main_study_lbp_data, "mainPeriodLikertErrorRates")
latexvars.update(lv)

In [None]:
# Main-period error rates per category (Non-AI vs AI), saved as PNG and as CSV for the web.
fig, fig_df = clustered_bar_plot(
    main_study_lbp_data,
    group_one_name = 'Non-AI',
    group_two_name = 'AI',
    fig_title = 'Error rates in history-taking, investigations, diagnosis & treatment questions\nNon-AI vs AI in main study period',
    y_axis_title = '% of visits with clinical errors for category',
)
save_fig('main_period_likert_error_rates', fig)
save_csv('main_period_likert_error_rates', fig_df)

In [None]:
# Quick per-category summary: RRR, Fisher p-value, NNT, and errors averted per year at
# PENDA_ANNUAL_PATIENT_VOLUME visits. Uses the shared constant rather than a hard-coded 400000.
{k: f'{v["RRR"]:0.3f} ({v["PVal"]:0.3f}) - NNT {v["NNT"]:0.1f} (for {PENDA_ANNUAL_PATIENT_VOLUME // 1000}k yearly visits, {PENDA_ANNUAL_PATIENT_VOLUME * v["ARR"]:.0f} errors reduced)' for k, v in main_study_lbp_data.items()}

In [None]:
# Yearly errors averted per category in the main period (PENDA_ANNUAL_PATIENT_VOLUME * ARR),
# exported as LaTeX variables. Unlike lbp_data_to_absolute_benefit_series, these values
# are not gated on PVal < 0.05.
PENDA_ANNUAL_PATIENT_VOLUME = 400000
lv = {
    f"totalErrorReduction{k}": f"{PENDA_ANNUAL_PATIENT_VOLUME * v['ARR']:.0f}"
    for k, v in main_study_lbp_data.items()
}
latexvars.update(lv)
lv

In [None]:
# Induction-period error rates per category (Non-AI vs AI). Only the PNG is saved here
# (no CSV export, unlike the main-period plot).
fig, fig_df = clustered_bar_plot(
    induction_period_lbp_data,
    group_one_name = 'Non-AI',
    group_two_name = 'AI',
    fig_title = 'Error rates in history-taking, investigation, diagnosis & treatment questions\nNon-AI vs AI in induction period',
    y_axis_title = '% of visits with clinical errors for category',
)
save_fig('induction_likert_error_rates', fig)

In [None]:
# Compare AI-vs-non-AI relative risk reduction (RRR) between study periods.
# get_rrr_bar_plot_data puts the induction-period RRR in the first bar and the
# main-period RRR in the second. The labels therefore name periods, and the axis
# shows RRR (not error rates).
fig, fig_df = clustered_bar_plot(
    get_rrr_bar_plot_data([induction_period_df, main_study_df]),
    group_one_name = 'Induction period',
    group_two_name = 'Main study period',
    y_axis_title = 'Relative risk reduction in clinical errors (AI vs non-AI)',
    show_significance = False,
)

### Cases with at least one red response

In [None]:
# Main-period visits that are "risky" for each category (one or more reds for that
# category): keep rows whose per-category risky flag equals True.
risky_cases_for_history = main_study_df[main_study_df.risky_cases_for_history == True]
risky_cases_for_investigations = main_study_df[main_study_df.risky_cases_for_investigations == True]
risky_cases_for_diagnosis = main_study_df[main_study_df.risky_cases_for_diagnosis == True]
risky_cases_for_treatment = main_study_df[main_study_df.risky_cases_for_treatment == True]

# One DataFrame per outcome column, in the default outcome_cols order
# (history, investigations, diagnosis, treatment).
risky_case_lbp_data = get_likert_bar_plot_data(
    [risky_cases_for_history, risky_cases_for_investigations, risky_cases_for_diagnosis, risky_cases_for_treatment]
)
lv = likert_bar_plot_data_to_latex_vars(risky_case_lbp_data, "mainPeriodRedOnly")
latexvars.update(lv)
lv

In [None]:
# Error rates (Non-AI vs AI) restricted to visits that are risky for each category.
fig, fig_df = clustered_bar_plot(
    risky_case_lbp_data,
    group_one_name = 'Non-AI - risky cases for that category',
    group_two_name = 'AI - risky cases for that category',
    fig_title = 'Error rates in history-taking, investigation, diagnosis & treatment questions\nVisits with risky cases for that category',
    y_axis_title = '% of visits with clinical errors for category',
)
save_fig('main_period_red_only_likert_error_rates', fig)

In [None]:
# Quick per-category summary for risky visits: RRR (Fisher p-value).
{k: f'{v["RRR"]:0.3f} ({v["PVal"]:0.3f})' for k, v in risky_case_lbp_data.items()}

In [None]:
# RRR (AI vs non-AI) for all main-period visits vs only visits with reds.
# Bars are relative risk reductions, so the y-axis is labelled accordingly.
rrr_bar_plot_data_risky_vs_all = get_rrr_bar_plot_data(
    likert_bar_plot_data_sets = [main_study_lbp_data, risky_case_lbp_data]
)
_ = clustered_bar_plot(
    rrr_bar_plot_data_risky_vs_all,
    group_one_name = 'All visits',
    group_two_name = 'Only risky visits (i.e., one or more reds for the category in question)',
    fig_title = 'Relative risk reduction in history-taking, investigation, diagnosis & treatment errors from AI consult',
    y_axis_title = 'Relative risk reduction in clinical errors (AI vs non-AI)',
    show_significance = False,
)

In [None]:
# Main-period summary table per category: RRR (with CI), NNT and yearly errors averted.
# The last two are shown only where PVal < 0.05. Exported as LaTeX.
main_risky_nnt_df = pd.DataFrame({
    'RRR: all visits': lbp_data_to_rrr_series(main_study_lbp_data),
    'NNT': lbp_data_to_NNT_series(main_study_lbp_data),
    'Yearly errors averted at Penda': lbp_data_to_absolute_benefit_series(main_study_lbp_data),
})
save_latex('error_rate_reduction_main', main_risky_nnt_df)
main_risky_nnt_df

In [None]:
# Side-by-side RRR (with CI) per category: main period, induction period, and
# main-period visits with reds only. Exported as LaTeX.
induction_risky_main_df = pd.DataFrame({
    'Main period, all visits': lbp_data_to_rrr_series(main_study_lbp_data),
    'Induction period': lbp_data_to_rrr_series(induction_period_lbp_data),
    'Main period, only visits with reds': lbp_data_to_rrr_series(risky_case_lbp_data),
})
save_latex("induction_vs_risky_vs_main", induction_risky_main_df)
induction_risky_main_df

### Final red/yellow vs final green

In [None]:
def any_final_red(x: types.AICalls | NAType) -> bool:
    """
    Return True if any clinical decision rule's calls in ``x`` have ``final_red`` set.

    A missing ``x`` (``is_na``) counts as "no final red". Every rule in
    ``types.ClinicalDecisionRule`` is looked up before the flags are combined.
    """
    if is_na(x):
        return False

    red_flags = []
    for decision_rule in types.ClinicalDecisionRule:
        red_flags.append(x.for_rule(decision_rule).final_red)
    return any(red_flags)


def final_color_for_rule(x: types.AICalls | NAType, rule: types.ClinicalDecisionRule, color: Literal['red', 'red_yellow', 'yellow', 'green']) -> bool:
    # returns True if the final color is the one we're looking for - red/yellow; red otherwise
    if is_na(x):
        return False

    calls_for_rule = x.for_rule(rule)

    if calls_for_rule._is_empty():
        return False

    if color == 'red':
        is_final_color = calls_for_rule.final_is_color(types.Color.Red)
    elif color == 'red_yellow':
        is_final_color = calls_for_rule.final_is_color(types.Color.Red) or calls_for_rule.final_is_color(types.Color.Yellow)
    elif color == 'yellow':
        is_final_color = calls_for_rule.final_is_color(types.Color.Yellow)
    elif color == 'green':
        is_final_color = calls_for_rule.final_is_color(types.Color.Green)
    else:
        raise ValueError(f"Invalid color: {color}")

    if is_final_color is None:
        raise ValueError(f"Final {color} is None for rule {rule}") # should only happen if the object is empty

    return is_final_color

def final_color_results_for_category(x: types.AICalls | NAType, category: str, color: Literal['red', 'red_yellow', 'yellow', 'green']) -> bool | NAType:
    """
    Return whether the final AI colour for a review ``category`` is ``color``.

    Categories map to clinical decision rules as follows:
    * ``'history'``: VitalsChiefComplaintEvaluation OR ClinicalNotes (both
      rules are always evaluated and combined with ``|``)
    * ``'investigations'``: InvestigationRecommendations
    * ``'diagnosis'``: DiagnosisEvaluation
    * ``'treatment'``: TreatmentRecommendation

    Returns ``pd.NA`` when ``x`` is missing. Raises ValueError for an unknown
    category.
    """
    if is_na(x):
        return pd.NA

    if category == 'history':
        vitals_match = final_color_for_rule(x, types.ClinicalDecisionRule.VitalsChiefComplaintEvaluation, color)
        notes_match = final_color_for_rule(x, types.ClinicalDecisionRule.ClinicalNotes, color)
        return vitals_match | notes_match
    if category == 'investigations':
        return final_color_for_rule(x, types.ClinicalDecisionRule.InvestigationRecommendations, color)
    if category == 'diagnosis':
        return final_color_for_rule(x, types.ClinicalDecisionRule.DiagnosisEvaluation, color)
    if category == 'treatment':
        return final_color_for_rule(x, types.ClinicalDecisionRule.TreatmentRecommendation, color)
    raise ValueError(f"Invalid category: {category}")

def color_x_vs_color_y(x: types.AICalls | NAType, category: str, true_color: Literal['red', 'red_yellow', 'yellow', 'green'], false_color: Literal['red', 'red_yellow', 'yellow', 'green']) -> bool | NAType:
    """
    Label a visit for a two-colour comparison within ``category``.

    Returns True when the final colour is ``true_color``, False when it is
    ``false_color`` (and not ``true_color``), and ``pd.NA`` when ``x`` is missing
    or neither colour matches. Both colour checks are always evaluated.
    """
    if is_na(x):
        return pd.NA

    is_true_color = final_color_results_for_category(x, category, true_color)
    is_false_color = final_color_results_for_category(x, category, false_color)

    # Identity checks: only a literal True counts as a match.
    if is_true_color is True:
        return True
    if is_false_color is True:
        return False
    return pd.NA

In [None]:
# Per-category allocation columns: True = final red, False = final yellow, NA otherwise.
# The mapped Series covers main_study_df only; assigning it to results_clinical_study_df
# aligns on the index, so rows outside the main period get missing values.
results_clinical_study_df['final_red_vs_yellow_for_history'] = main_study_df.AICalls.map(lambda x: color_x_vs_color_y(x, 'history', 'red', 'yellow'))
results_clinical_study_df['final_red_vs_yellow_for_investigations'] = main_study_df.AICalls.map(lambda x: color_x_vs_color_y(x, 'investigations', 'red', 'yellow'))
results_clinical_study_df['final_red_vs_yellow_for_diagnosis'] = main_study_df.AICalls.map(lambda x: color_x_vs_color_y(x, 'diagnosis', 'red', 'yellow'))
results_clinical_study_df['final_red_vs_yellow_for_treatment'] = main_study_df.AICalls.map(lambda x: color_x_vs_color_y(x, 'treatment', 'red', 'yellow'))

# Error rates for final-red (group one) vs final-yellow (group two) visits, one
# allocation column per outcome.
final_red_yellow_lbp_data = get_likert_bar_plot_data(
    results_clinical_study_df,
    allocation_col = ['final_red_vs_yellow_for_history', 'final_red_vs_yellow_for_investigations', 'final_red_vs_yellow_for_diagnosis', 'final_red_vs_yellow_for_treatment'],
    allocation_group_one_values = [True],
    allocation_group_two_values = [False]
)
lv = likert_bar_plot_data_to_latex_vars(final_red_yellow_lbp_data, "redVYellow")
latexvars.update(lv)
lv

In [None]:
# Error rates for visits left in final red vs final yellow, per category.
fig, fig_df = clustered_bar_plot(
    final_red_yellow_lbp_data,
    group_one_name = 'Final red for the category',
    group_two_name = 'Final yellow for the category',
    fig_title = 'Error rates in history-taking, investigation, diagnosis & treatment questions\nVisits with final red vs visits with final yellow',
    y_axis_title = '% of visits with clinical errors for category',
)
save_fig('red_v_yellow_likert_error_rates', fig)

In [None]:
# Per-category allocation columns: True = final yellow, False = final green, NA otherwise.
# Index alignment leaves rows outside main_study_df missing.
results_clinical_study_df['final_yellow_vs_green_for_history'] = main_study_df.AICalls.map(lambda x: color_x_vs_color_y(x, 'history', 'yellow', 'green'))
results_clinical_study_df['final_yellow_vs_green_for_investigations'] = main_study_df.AICalls.map(lambda x: color_x_vs_color_y(x, 'investigations', 'yellow', 'green'))
results_clinical_study_df['final_yellow_vs_green_for_diagnosis'] = main_study_df.AICalls.map(lambda x: color_x_vs_color_y(x, 'diagnosis', 'yellow', 'green'))
results_clinical_study_df['final_yellow_vs_green_for_treatment'] = main_study_df.AICalls.map(lambda x: color_x_vs_color_y(x, 'treatment', 'yellow', 'green'))

# Error rates for final-yellow (group one) vs final-green (group two) visits.
final_yellow_green_lbp_data = get_likert_bar_plot_data(
    results_clinical_study_df,
    allocation_col = ['final_yellow_vs_green_for_history', 'final_yellow_vs_green_for_investigations', 'final_yellow_vs_green_for_diagnosis', 'final_yellow_vs_green_for_treatment'],
    allocation_group_one_values = [True],
    allocation_group_two_values = [False]
)
lv = likert_bar_plot_data_to_latex_vars(final_yellow_green_lbp_data, "yellowVGreen")
latexvars.update(lv)

In [None]:
# Error rates for visits left in final yellow vs final green, per category.
fig, fig_df = clustered_bar_plot(
    final_yellow_green_lbp_data,
    group_one_name = 'Final yellow for the category',
    group_two_name = 'Final green for the category',
    fig_title = 'Error rates in history-taking, investigation, diagnosis & treatment questions\nVisits with final yellow vs visits with final green',
    y_axis_title = '% of visits with clinical errors for category',
)
save_fig('yellow_v_green_likert_error_rates', fig)

In [None]:
# Error rate (with CI) by final colour, plus p-values for red vs yellow and yellow vs green.
# "Left in green" is group two of the yellow-vs-green comparison.
# Fixes the exported column header typo ('Left inyellow' -> 'Left in yellow').
ryg_corr_df = pd.DataFrame({
    'Left in red': lbp_data_to_group_one_series(final_red_yellow_lbp_data),
    'Left in yellow': lbp_data_to_group_two_series(final_red_yellow_lbp_data),
    'Left in green': lbp_data_to_group_two_series(final_yellow_green_lbp_data),
    'p: R vs Y': lbp_data_to_pval_series(final_red_yellow_lbp_data),
    'p: Y vs G': lbp_data_to_pval_series(final_yellow_green_lbp_data),
})
save_latex("ryg_corr", ryg_corr_df)
ryg_corr_df

### Acuity analysis

In [None]:
# Main-period visits (VisitDate >= cut_point) split by acuity level.
low_acuity_df = results_clinical_study_df[(results_clinical_study_df.VisitDate >= cut_point) & (results_clinical_study_df.acuity == 'Low')]
medium_acuity_df = results_clinical_study_df[(results_clinical_study_df.VisitDate >= cut_point) & (results_clinical_study_df.acuity == 'Medium')]
high_acuity_df = results_clinical_study_df[(results_clinical_study_df.VisitDate >= cut_point) & (results_clinical_study_df.acuity == 'High')]


In [None]:
# AI-vs-non-AI relative risk reduction for low- vs medium-acuity visits.
# The bars are RRRs with no p-value, so the axis is labelled RRR and stars are off.
fig, fig_df = clustered_bar_plot(
    get_rrr_bar_plot_data([low_acuity_df, medium_acuity_df]),
    group_one_name = 'Low acuity',
    group_two_name = 'Medium acuity',
    y_axis_title = 'Relative risk reduction in clinical errors (AI vs non-AI)',
    show_significance = False,
)

In [None]:
# AI-vs-non-AI relative risk reduction for low- vs high-acuity visits.
# The bars are RRRs with no p-value, so the axis is labelled RRR and stars are off.
fig, fig_df = clustered_bar_plot(
    get_rrr_bar_plot_data([low_acuity_df, high_acuity_df]),
    group_one_name = 'Low acuity',
    group_two_name = 'High acuity',
    y_axis_title = 'Relative risk reduction in clinical errors (AI vs non-AI)',
    show_significance = False,
)

In [None]:
# Per-acuity AI vs non-AI stats.
low_acuity_lbp_data = get_likert_bar_plot_data(low_acuity_df)
medium_acuity_lbp_data = get_likert_bar_plot_data(medium_acuity_df)
high_acuity_lbp_data = get_likert_bar_plot_data(high_acuity_df)

# NOTE: despite the *_error_rates names, these hold formatted RRR strings with CIs.
low_acuity_error_rates = lbp_data_to_rrr_series(low_acuity_lbp_data)
medium_acuity_error_rates = lbp_data_to_rrr_series(medium_acuity_lbp_data)
high_acuity_error_rates = lbp_data_to_rrr_series(high_acuity_lbp_data)

# RRR by acuity level, exported as LaTeX.
acuity_rrr_df = pd.DataFrame({
    "Low-acuity cases": low_acuity_error_rates,
    "Medium-acuity cases": medium_acuity_error_rates,
    "High-acuity cases": high_acuity_error_rates,
})
save_latex("acuity_rrr", acuity_rrr_df)
acuity_rrr_df

### Inter-rater agreement 

In [None]:
# Fleiss' kappa per Likert question, computed on the binary `<question>_is_low` flag
# for visits rated by exactly two raters.
likert_cols = ['history_likert', 'investigations_likert', 'diagnosis_likert', 'treatment_likert']

agreement_statistics = {}

for likert_col in likert_cols:
    error_col = f'{likert_col}_is_low'

    relevant_df = results_clinical_study_df.dropna(subset=[likert_col])
    grouped = relevant_df[["VisitCode", likert_col, error_col]].groupby("VisitCode")
    fleiss_kappa_results = []

    for _, group in grouped:
        # Single-rated visits carry no agreement information; >2 ratings is unexpected.
        if len(group) == 1:
            continue
        elif len(group) > 2:
            raise ValueError(f"Expected exactly one or two ratings per visit, got {len(group)}")

        is_error_1, is_error_2 = group[error_col]

        # One row per visit: how many of the two raters gave flag 0 and flag 1
        # (the subjects x categories count table fleiss_kappa expects).
        fleiss_kappa_results.append({
            0: (is_error_1 == 0) + (is_error_2 == 0),
            1: (is_error_1 == 1) + (is_error_2 == 1),
        })

    fleiss_kappa_df = pd.DataFrame(fleiss_kappa_results)
    fk = fleiss_kappa(fleiss_kappa_df.values, method="fleiss")

    agreement_statistics[likert_col] = {
        'fleiss_kappa': fk,
    }

# Columns are the Likert questions, single row 'fleiss_kappa'; export each as a LaTeX var.
agreement_statistics_df = pd.DataFrame(agreement_statistics)
latexvars['fleissKappaHistory'] = f"{agreement_statistics_df['history_likert']['fleiss_kappa']:.3f}"
latexvars['fleissKappaInvestigations'] = f"{agreement_statistics_df['investigations_likert']['fleiss_kappa']:.3f}"
latexvars['fleissKappaDiagnosis'] = f"{agreement_statistics_df['diagnosis_likert']['fleiss_kappa']:.3f}"
latexvars['fleissKappaTreatment'] = f"{agreement_statistics_df['treatment_likert']['fleiss_kappa']:.3f}"
agreement_statistics_df

In [None]:
def standard_confusion_matrix(
    df: pd.DataFrame,
    likert_col: str,
) -> pd.DataFrame:
    """
    Return a symmetric 5x5 matrix of paired-rating counts for one Likert question.

    Rows and columns are the Likert values 1-5. For every visit rated by
    exactly two raters, both (rating_a, rating_b) and (rating_b, rating_a) are
    counted. The matrix is therefore symmetric, and each such visit adds 2 to
    the total (2 to the diagonal cell when both ratings agree). Values are raw
    counts, not percentages. Rows with a missing VisitCode or rating are
    dropped, and visits with a single rating are skipped.

    Raises
    ------
    ValueError
        If any visit has more than two ratings.
    KeyError
        If a rating is not one of the labels 1-5 (raised by the ``.at`` lookup).
    """
    matrix = pd.DataFrame(0, index=range(1, 6), columns=range(1, 6), dtype=int)
    # dropna() already removes unrated rows, so no per-value NA re-check is needed below.
    rated = df[["VisitCode", likert_col]].dropna()

    for _, group in rated.groupby("VisitCode"):
        if len(group) == 1:
            continue
        if len(group) > 2:
            raise ValueError(f"Expected exactly one or two ratings per visit, got {len(group)}")

        rating_a, rating_b = group[likert_col].tolist()

        # Count both orderings so the matrix is symmetric (rater order is arbitrary).
        matrix.at[rating_a, rating_b] += 1
        matrix.at[rating_b, rating_a] += 1

    return matrix

def agreement_summary_table(
    df: pd.DataFrame,
    likert_cols: list[str] | None = None,
    diffs: tuple[int, ...] = (0, 1, 2),
) -> pd.DataFrame:
    """
    Build a summary table of inter-rater agreement (proportion and 95 % Wilson CI).

    For each Likert column, agreement at threshold ``k`` is the fraction of
    double-rated visits whose two ratings differ by at most ``k`` points.
    ``standard_confusion_matrix`` counts each visit twice (once per rating
    order), so the counts are halved to get one Bernoulli trial per visit.
    Using the doubled counts leaves the proportion unchanged but makes the
    confidence interval too narrow.

    Returns
    -------
    pandas.DataFrame
        Index: the Likert columns, in the given order. Columns: a 2-level
        MultiIndex ``(diff, metric)`` with ``diff`` labels ``"≤k"`` and
        metrics ``"prop"``, ``"low"``, ``"high"``. A question with no
        double-rated visits gets NaN instead of raising from ``binomtest``.
    """
    if likert_cols is None:
        likert_cols = [
            "history_likert",
            "investigations_likert",
            "diagnosis_likert",
            "treatment_likert",
        ]

    summary_rows: list[dict[str, float | str]] = []

    for question in likert_cols:
        confusion_matrix = standard_confusion_matrix(df, question)

        row_vals = confusion_matrix.index.to_numpy()
        col_vals = confusion_matrix.columns.to_numpy()
        abs_diff_grid = np.abs(row_vals[:, None] - col_vals[None, :])
        confusion_matrix_values = confusion_matrix.to_numpy()

        # Each double-rated visit contributes exactly 2 to the symmetric matrix.
        n_visits = int(confusion_matrix_values.sum()) // 2

        for diff_threshold in diffs:
            if n_visits == 0:
                prop = ci_low = ci_high = np.nan
            else:
                # Within-threshold cells come in symmetric pairs (the diagonal gets
                # +2 per agreeing visit), so this sum is even and halves exactly.
                n_agreeing = int(confusion_matrix_values[abs_diff_grid <= diff_threshold].sum()) // 2
                test_result = binomtest(n_agreeing, n_visits)
                prop = test_result.statistic
                ci_low, ci_high = test_result.proportion_ci(method="wilson")

            summary_rows.append(
                {
                    "likert": question,
                    "diff": f"≤{diff_threshold}",
                    "prop": prop,
                    "low": ci_low,
                    "high": ci_high,
                }
            )

    # Convert the tidy records list into the desired wide table
    summary_table = (
        pd.DataFrame(summary_rows)
        .set_index(["likert", "diff"])
        .unstack("diff")            # columns -> (metric, diff)
        .swaplevel(0, 1, axis=1)    # order -> (diff, metric)
        .sort_index(axis=1, level=0)
        .reindex(likert_cols)       # preserve original order of questions
    )

    summary_table.columns.names = ["diff", "metric"]
    return summary_table

def plot_standard_confusion_matrix(
    conf_mat: pd.DataFrame,
    title: str = "Standard confusion matrix",
    cmap: str = "Oranges",
    ax: plt.Axes | None = None,
):
    """
    Draw a rater-vs-rater confusion matrix as an annotated heatmap.

    Each cell is shown as a percentage of all pairs in ``conf_mat``; NaN cells
    count as zero. Rows are labelled "Rater 1 rating" and columns "Rater 2
    rating". The y axis is inverted so the last row label sits at the top.

    Args:
        conf_mat: Rating-pair counts, indexed by rater-1 rating (rows) and
            rater-2 rating (columns).
        title: Axes title.
        cmap: Colormap name passed to ``sns.heatmap``.
        ax: Axes to draw on; a new 6x5 figure is created when omitted.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 5))

    counts = conf_mat.fillna(0).astype(float).values
    total = counts.sum()
    # An all-zero matrix would otherwise divide by zero and annotate "nan%".
    data = counts / total * 100 if total > 0 else np.zeros_like(counts)

    annot = np.empty_like(data, dtype=object)
    for idx, pct in np.ndenumerate(data):
        annot[idx] = f"{pct:.1f}%"

    sns.heatmap(
        data,
        annot=annot,
        fmt="",
        cmap=cmap,
        cbar=False,
        ax=ax,
        linewidths=0.5,
        linecolor="white",
        xticklabels=list(conf_mat.columns),
        yticklabels=list(conf_mat.index),
    )
    ax.set_xlabel("Rater 2 rating")
    ax.set_ylabel("Rater 1 rating")
    ax.set_title(title)
    ax.invert_yaxis()  # So higher ratings are at the top

    # Rotate y tick labels to be horizontal
    for label in ax.get_yticklabels():
        label.set_rotation(0)
    # Lay out the figure that owns `ax`. plt.tight_layout() would act on
    # whichever figure is current, which need not be this one.
    ax.figure.tight_layout()
    return ax


# Inter-rater agreement: display the summary table, export each value as a
# LaTeX variable, and plot the per-domain confusion matrices.
inter_rater_agreement_table = agreement_summary_table(results_clinical_study_df)
print("Inter-rater agreement (proportion and 95 % CI):")
display(inter_rater_agreement_table)

for likert in inter_rater_agreement_table.index:
    likert_latexable = likert.replace('_likert', 'Likert')
    # Iterate over the actual (diff, metric) column pairs. Looping over the raw
    # level values (which repeat) would revisit every pair many times, printing
    # and re-assigning each variable over and over.
    for diff, metric in inter_rater_agreement_table.columns:
        diff_latexable = {
            "≤0": "Exact",
            "≤1": "OnePoint",
            "≤2": "TwoPoint",
        }[diff]
        varstring = f'{likert_latexable}{diff_latexable}{metric.capitalize()}'
        val = f'{inter_rater_agreement_table.loc[likert, (diff, metric)] * 100:.1f}%'
        print(f'{varstring}: {val}')
        latexvars[varstring] = val

# Four-panel (2x2) plot of standard_confusion_matrix for each Likert type
likert_titles = [
    ("history_likert", "History Likert – trainer agreement"),
    ("investigations_likert", "Investigations Likert – trainer agreement"),
    ("diagnosis_likert", "Diagnosis Likert – trainer agreement"),
    ("treatment_likert", "Treatment Likert – trainer agreement"),
]

fig, axes = plt.subplots(2, 2, figsize=(16, 14))
axes = axes.flatten()
for ax, (likert, title) in zip(axes, likert_titles):
    conf_mat_std = standard_confusion_matrix(results_clinical_study_df, likert)
    plot_standard_confusion_matrix(
        conf_mat_std,
        title=title,
        ax=ax,
    )
plt.tight_layout()
plt.show()

save_fig('human_rater_study_agreement', fig)

## Multivariate models

In [None]:
def tidy_risk_ratios(result):
    """
    Convert a fitted statsmodels model into a tidy table of exponentiated estimates.

    This notebook fits log-link models: log-binomial GLM/GEE and log-Poisson
    GLM. For those models, exp(coefficient) is a risk ratio, not an odds ratio
    (odds ratios would require a logit link).

    Args:
        result: A fitted statsmodels results object exposing ``params``,
            ``conf_int()`` and ``pvalues``.

    Returns:
        A DataFrame indexed by term with columns ``RR``, ``CI_low``, ``CI_high``
        (exponentiated ``conf_int()`` bounds, 95 % by statsmodels' default) and
        ``p``. All values are rounded to 3 decimals.
    """
    ci = result.conf_int()
    return (
        pd.DataFrame({
            "RR":       np.exp(result.params),
            "CI_low":   np.exp(ci[0]),
            "CI_high":  np.exp(ci[1]),
            "p":        result.pvalues,
        })
        .round(3)
    )

# Coefficient-name patterns produced by the model formulas in this notebook.
_TREATMENT_TERM_RE = re.compile(r"C\(([^,]+),\s*Treatment\(reference=['\"]([^'\"]+)['\"]\)\)\[T\.([^\]]+)\]")
_SUM_TERM_RE = re.compile(r"C\(([^,]+),\s*Sum\)\[S\.([^\]]+)\]")

# Terms that map straight to a fixed label.
_FIXED_TERM_LABELS = {
    "AgeYears": "Age (years)",
    "Intercept": "Intercept",
}

# Display names for categorical variables; unlisted variables keep their own name.
_VARIABLE_DISPLAY_NAMES = {
    "ClinicianGroup": "Group",
    "VisitType": "Visit type",
    "LocationName": "Clinic",
}


def clean_index_name(name):
    """
    Translate a patsy/statsmodels coefficient name into a readable label.

    Recognised forms:
      * ``AgeYears`` and ``Intercept``;
      * treatment-coded categoricals, ``C(var, Treatment(reference='ref'))[T.level]``
        -> ``"<var label>: level vs ref"``;
      * sum-coded clinics, ``C(LocationName, Sum)[S.level]``
        -> ``"Clinic: level vs mean clinic"``.

    Raises:
        ValueError: if ``name`` matches none of the recognised forms.
        AssertionError: if a sum-coded term is for a variable other than LocationName.
    """
    if name in _FIXED_TERM_LABELS:
        return _FIXED_TERM_LABELS[name]

    treatment = _TREATMENT_TERM_RE.match(name)
    if treatment is not None:
        variable, reference, level = treatment.groups()
        label = _VARIABLE_DISPLAY_NAMES.get(variable, variable)
        return f"{label}: {level} vs {reference}"

    summed = _SUM_TERM_RE.match(name)
    if summed is not None:
        variable, level = summed.groups()
        assert variable == "LocationName"
        return f"Clinic: {level} vs mean clinic"

    raise ValueError(f"Column name not recognized: {name}")


def clean_regression_result(result):
    """
    Tidy a fitted model and pull out the AI-vs-Non-AI group effect.

    Args:
        result: A fitted statsmodels results object (see ``tidy_risk_ratios``).

    Returns:
        ``(main_effect, tidied_df)``. ``tidied_df`` is the full table with
        readable term names and the columns 'Relative risk', '95% CI lower',
        '95% CI upper', 'p'. ``main_effect`` is its 'Group: AI vs Non-AI' row.
    """
    readable = tidy_risk_ratios(result)
    readable.index = readable.index.map(clean_index_name)
    readable.columns = ['Relative risk', '95% CI lower', '95% CI upper', 'p']
    group_row = readable.loc['Group: AI vs Non-AI', :]
    return group_row, readable

def main_effect_to_rrr(main_effect, latexname):
    """
    Express a relative-risk estimate as a relative risk reduction (RRR = 1 - RR).

    The CI bounds swap roles: the RRR lower bound comes from the RR upper bound
    and vice versa.

    Args:
        main_effect: Mapping/Series with 'Relative risk', '95% CI lower' and
            '95% CI upper' entries.
        latexname: Prefix for the generated LaTeX variable names.

    Returns:
        Dict with keys ``<latexname>RRRPoint``, ``<latexname>RRRLower`` and
        ``<latexname>RRRUpper``, each a percentage string with one decimal.
    """
    reductions = {
        'RRRPoint': 1 - main_effect['Relative risk'],
        'RRRLower': 1 - main_effect['95% CI upper'],
        'RRRUpper': 1 - main_effect['95% CI lower'],
    }
    return {
        f'{latexname}{suffix}': f'{value * 100:.1f}%'
        for suffix, value in reductions.items()
    }

In [None]:
# For each domain, model the "low Likert rating" outcome on clinician group in
# the main study period with three specifications:
#   1. an unadjusted log-binomial GLM (freq_weights = row_weight);
#   2. an adjusted log-binomial GEE clustered on user_id (exchangeable working
#      correlation, weights = row_weight);
#   3. an adjusted log-Poisson GLM with cluster-robust (user_id) covariance.
# Adjusted results are saved as LaTeX tables, and their RRRs as LaTeX variables.

# Adjusted model formula template, filled in per domain. Clinic uses sum
# coding, so each clinic is compared to the mean clinic. The AI analysis later
# in the notebook also reuses this global.
formula = """{domain}_likert_is_low ~
    C(ClinicianGroup, Treatment(reference = 'Non-AI')) +
    AgeYears +
    C(Gender, Treatment(reference = 'Male')) +
    C(VisitType, Treatment(reference = 'Cash')) +
    C(LocationName, Sum)"""

# The main-period row filter does not depend on the domain: compute it once.
main_period_mask = results_clinical_study_df.VisitDate >= cut_point

for domain in ["history", "investigations", "diagnosis", "treatment"]:
    outcome_col = f'{domain}_likert_is_low'

    print("-" * 100)
    print(f"FOR DOMAIN: {domain.upper()}")
    print("-" * 100)

    # DATA PREP
    # .copy() makes this an independent frame. Converting the outcome to int
    # below then does not write into a slice of results_clinical_study_df,
    # which avoids SettingWithCopyWarning.
    results_clinical_study_main_period_df = (
        results_clinical_study_df[main_period_mask]
        .dropna(subset=[outcome_col])
        .copy()
    )
    results_clinical_study_main_period_df[outcome_col] = results_clinical_study_main_period_df[outcome_col].astype(int)

    # BASELINE MODEL
    print("-" * 100)
    print(f"{domain.upper()}: BASELINE MODEL")
    print("-" * 100)
    result = smf.glm(
        f"{domain}_likert_is_low ~ C(ClinicianGroup, Treatment(reference = 'Non-AI'))",
        data=results_clinical_study_main_period_df,
        family=sm.families.Binomial(link=sm.families.links.Log()),
        freq_weights=results_clinical_study_main_period_df["row_weight"]
    ).fit()

    display(result.summary().tables[0])
    display(tidy_risk_ratios(result))

    # BINOMIAL GEE MODEL
    print("-" * 100)
    print(f"{domain.upper()}: BINOMIAL GEE MODEL")
    print("-" * 100)

    result = smf.gee(
        formula.format(domain=domain),
        groups = results_clinical_study_main_period_df.user_id,
        data=results_clinical_study_main_period_df,
        family=sm.families.Binomial(link=sm.families.links.Log()),
        weights=results_clinical_study_main_period_df["row_weight"],
        cov_struct = sm.genmod.cov_struct.Exchangeable(),
    ).fit()

    display(result.summary().tables[0])
    display(tidy_risk_ratios(result))

    main_effect, tidied_df = clean_regression_result(result)

    save_latex(f'main_study_gee_{domain}', tidied_df)
    lv = main_effect_to_rrr(main_effect, f'mainStudyGEE{domain}')
    latexvars.update(lv)

    # LOG-POISSON MODEL WITH ZOU-DONNER CLUSTERING FOR COVARIANCE
    print("-" * 100)
    print(f"{domain.upper()}: LOG-POISSON MODEL WITH ZOU-DONNER CLUSTERING FOR COVARIANCE")
    print("-" * 100)

    model = smf.glm(
        formula.format(domain=domain),
        data=results_clinical_study_main_period_df,
        family=sm.families.Poisson(link=sm.families.links.Log()),
        freq_weights=results_clinical_study_main_period_df["row_weight"]
    )

    result = model.fit(
        cov_type='cluster',
        cov_kwds={'groups': results_clinical_study_main_period_df['user_id']}
    )

    display(result.summary().tables[0])
    display(tidy_risk_ratios(result))

    main_effect, tidied_df = clean_regression_result(result)
    save_latex(f'main_study_poisson_{domain}', tidied_df)
    lv = main_effect_to_rrr(main_effect, f'mainStudyPoisson{domain}')
    latexvars.update(lv)

## Multiple-choice question analysis

In [None]:
def process_mcq_option(s: str) -> str:
    """
    Drop everything from the first '(' onward and strip surrounding whitespace.

    Strings without a '(' are returned unchanged (not stripped).
    """
    head, paren, _ = s.partition('(')
    return head.strip() if paren else s

def process_mcq_options(s: set[str]) -> set[str]:
    """Return a new set with ``process_mcq_option`` applied to every option in ``s``."""
    return set(map(process_mcq_option, s))

# MCQ columns, one per clinical domain, and the display name of each domain.
mcq_cols = ['form_a_mcq', 'form_b_mcq', 'form_c_mcq', 'form_d_mcq']
mcq_col_short_name_map = {
    'form_a_mcq': 'History',
    'form_b_mcq': 'Investigations',
    'form_c_mcq': 'Diagnosis',
    'form_d_mcq': 'Treatment'
}

# Strip parenthetical suffixes from every selected option, overwriting each
# MCQ column. Cells are expected to be sets of option strings (per
# process_mcq_options' signature).
for mcq_col in mcq_cols:
    results_clinical_study_df[mcq_col] = results_clinical_study_df[mcq_col].map(process_mcq_options)


In [None]:
# Collect, per MCQ column, every distinct option selected anywhere, excluding
# the catch-all 'None of the above'.
mcq_options_by_question = defaultdict(set)

for col in mcq_cols:
    for row in results_clinical_study_df[col]:
        for option in row:
            if option == 'None of the above':
                continue

            mcq_options_by_question[col].add(option)

# Option texts must be unique across questions, because each one becomes its
# own column name below.
all_options = set.union(*mcq_options_by_question.values())
assert len(all_options) == sum(len(opts) for opts in mcq_options_by_question.values())

# Add one boolean indicator column per option: True when the visit's selection
# contains it. Missing (is_na) selections count as not selected.
for col in mcq_cols:
    options = sorted(mcq_options_by_question[col])
    for opt in options:
        results_clinical_study_df[f"{opt}"] = results_clinical_study_df[col].apply(lambda opts: opt in opts if not is_na(opts) else False)

In [None]:
# Main-study visits only (VisitDate on or after cut_point).
main_study_df = results_clinical_study_df[results_clinical_study_df['VisitDate'] >= cut_point]

# For each domain's MCQ, compare per-option selection rates between the Non-AI
# and AI groups and plot them. Each comparison is also stored in
# all_bar_plot_data under the key "<Domain>: <key>".
all_bar_plot_data = {}
for mcq_col in mcq_cols:
    mcq_col_short_name = mcq_col_short_name_map[mcq_col]

    mcq_bar_plot_data = get_likert_bar_plot_data(
        main_study_df,
        allocation_col = 'ClinicianGroup',
        allocation_group_one_values = ['Non-AI'],
        allocation_group_two_values = ['AI'],
        outcome_cols = [f"{opt}" for opt in mcq_options_by_question[mcq_col]],
        outcome_true_values = [True],
        outcome_false_values = [False],
        unique_column = 'VisitCode'
    )

    for k, v in mcq_bar_plot_data.items():
        all_bar_plot_data[f"{mcq_col_short_name}: {k}"] = v

    # Narrower label wrapping when there are more options, so tick labels fit.
    wrap_length = 150 // len(mcq_bar_plot_data)
    mcq_bar_plot_data = {
        textwrap.fill(k, width = wrap_length): v
        for k, v in mcq_bar_plot_data.items()
    }

    # The value returned by clustered_bar_plot is not kept or saved here.
    clustered_bar_plot(
        mcq_bar_plot_data,
        group_one_name = 'Non-AI',
        group_two_name = 'AI',
        fig_title = f'Specific error category rates in AI vs Non-AI group: {mcq_col_short_name}',
        y_axis_title = "Proportion of visits with error",
        figsize = (15, 6),
        bar_width = 0.35,
        colors = ("C0", "C1"),
        y_axis_percent = True,
        show_significance = True
    )

In [None]:
# List every "<Domain>: <option>" key so categories can be selected by name.
for k in all_bar_plot_data.keys():
    print(k)

In [None]:
# Map each MCQ category key to a short display label for the summary plot.
# None means the category is left out of the summary.
key_mcq_cols_mapping = {
    "History: Chief complaint is absent": None,
    "History: Documentation of relevant systems on physical exam are absent": None,
    "History: Pertinent vital signs are absent": None,
    "History: Key details in the history are missing": "History: Key details are missing",
    "Investigations: Key investigations are missing": "Investigations: Key investigations are missing",
    "Investigations: Unjustified investigations are ordered": None,
    "Diagnosis: Primary diagnosis is missing": None,
    "Diagnosis: Primary diagnosis is likely incorrect": "Diagnosis: Primary diagnosis is likely incorrect",
    "Diagnosis: Primary diagnosis is too specific to be supported based on current documentation or investigations": None,
    "Diagnosis: Additional diagnosis is likely incorrect": None,
    "Diagnosis: Primary diagnosis is too broad when a more specific diagnosis is supported by the clinical notes": None,
    "Diagnosis: Clinically relevant additional diagnosis is missing": None,
    "Treatment: Medications are present but inappropriate": "Treatment: Inappropriate medications used",
    "Treatment: Likely inappropriate class of antibiotics used": None,
    "Treatment: Missing patient advice, education or follow up plan": "Treatment: Missing patient education or follow up plan",
    "Treatment: Incorrect patient advice, education or follow up plan": None,
    "Treatment: Needed procedures are missing": None,
    "Treatment: Likely inappropriate use of antibiotics overall": None,
    "Treatment: Medications are appropriate but incorrect dosages listed": None,
    "Treatment: Escalations of care are present but inappropriate": None,
    "Treatment: Procedures are present but inappropriate": None,
    "Treatment: Needed escalations of care are missing": None,
    "Treatment: Referrals are missing": None,
    "Treatment: Referrals are present but inappropriate": None,
    "Treatment: Medications are missing": None,
}
# Keep only the selected categories, re-keyed by their short labels.
# Raises KeyError if a selected key is missing from all_bar_plot_data.
summary_mcq_bar_plot_data = {
    short_name: all_bar_plot_data[key_col]
    for key_col, short_name in key_mcq_cols_mapping.items()
    if short_name is not None
}

In [None]:
# One row per MCQ error category with the group rates, relative risk reduction,
# p-value, NNT and the absolute-benefit figure from the lbp_data_* helpers.
mcq_df = pd.DataFrame({
    'Non-AI': lbp_data_to_group_one_series(all_bar_plot_data),
    'AI': lbp_data_to_group_two_series(all_bar_plot_data),
    'RRR': lbp_data_to_rrr_series(all_bar_plot_data),
    'p': lbp_data_to_pval_series(all_bar_plot_data),
    'NNT': lbp_data_to_NNT_series(all_bar_plot_data),
    'N errors reduced at Penda': lbp_data_to_absolute_benefit_series(all_bar_plot_data)
})

# Rename certain keys in the index for clarity
# (shortens the two longest diagnosis category labels for the LaTeX table)
mcq_df = mcq_df.rename(index={
    "Diagnosis: Primary diagnosis is too broad when a more specific diagnosis is supported by the clinical notes": "Diagnosis: Primary diagnosis broad when more specific is supported",
    "Diagnosis: Primary diagnosis is too specific to be supported based on current documentation or investigations": "Diagnosis: Primary diagnosis too specific to be supported",
})

save_latex('mcq', mcq_df)
mcq_df

In [None]:
# Wrap the selected category labels (narrower with more bars), then plot them,
# save the figure and export the plotted data as CSV.
wrap_length = 150 // len(summary_mcq_bar_plot_data)
summary_mcq_bar_plot_data = {
    textwrap.fill(k, width = wrap_length): v
    for k, v in summary_mcq_bar_plot_data.items()
}

fig, fig_df = clustered_bar_plot(
    summary_mcq_bar_plot_data,
    group_one_name = 'Non-AI',
    group_two_name = 'AI',
    fig_title = f'Specific error category rates in AI vs Non-AI group\nOnly select categories are shown',
    y_axis_title = "Proportion of visits with error",
    figsize = (15, 6),
    bar_width = 0.35,
    colors = ("C0", "C1"),
    y_axis_percent = True,
    show_significance = True
)
save_fig('select_mcq', fig)
save_csv('mcq', fig_df)

In [None]:
def q2_is_opt(x, opt):
    """
    Report whether ``opt`` is contained in the Q2 answer ``x``.

    Returns ``pd.NA`` when ``x`` is missing (per ``pd.isna``); otherwise
    returns the boolean result of ``opt in x``.
    """
    return pd.NA if pd.isna(x) else (opt in x)

# Does flagging inappropriate medications relate to patients seeking care
# outside Penda? Cross-tabulate the two, report the outside-care rate within
# each medication group with 95% Wilson CIs, and test independence with
# Fisher's exact test.


def _binomial_rate_with_ci(successes, trials):
    """Return the ``binomtest`` result for successes/trials and its 95% Wilson CI."""
    test = binomtest(successes, trials)
    return test, test.proportion_ci(confidence_level=0.95, method='wilson')


a = results_clinical_study_df.form_d_mcq.apply(lambda x: 'Medications are present but inappropriate' in x)
a.name = "Were inappropriate medications given?"
b = results_clinical_study_df.Q2.map(lambda x: not q2_is_opt(x, "all_at_penda") if pd.notna(x) else pd.NA)
b.name = "Did patients seek care outside Penda?"

crosstab = pd.crosstab(a, b)
crosstab_normalized = pd.crosstab(a, b, normalize='index')
display(crosstab)
display(crosstab_normalized)

# Rows: inappropriate medications flagged (True/False).
# Columns: sought care outside Penda (True/False).
seek_care_given_inappropriate = crosstab.at[True, True]
inappropriate_meds = crosstab.at[True, True] + crosstab.at[True, False]
bt_inappropriate, ci_seek_care_given_inappropriate = _binomial_rate_with_ci(seek_care_given_inappropriate, inappropriate_meds)
p_seek_care_given_inappropriate = bt_inappropriate.statistic
latexvars['seekCareGivenInappropriateMedsRate'] = f"{p_seek_care_given_inappropriate:.1%}"
latexvars['seekCareGivenInappropriateMedsRateLowerCI'] = f"{ci_seek_care_given_inappropriate.low:.1%}"
latexvars['seekCareGivenInappropriateMedsRateUpperCI'] = f"{ci_seek_care_given_inappropriate.high:.1%}"

seek_care_given_not_inappropriate = crosstab.at[False, True]
not_inappropriate_meds = crosstab.at[False, True] + crosstab.at[False, False]
bt_not_inappropriate, ci_seek_care_given_not_inappropriate = _binomial_rate_with_ci(seek_care_given_not_inappropriate, not_inappropriate_meds)
p_seek_care_given_not_inappropriate = bt_not_inappropriate.statistic
latexvars['seekCareGivenNotInappropriateMedsRate'] = f"{p_seek_care_given_not_inappropriate:.1%}"
latexvars['seekCareGivenNotInappropriateMedsRateLowerCI'] = f"{ci_seek_care_given_not_inappropriate.low:.1%}"
latexvars['seekCareGivenNotInappropriateMedsRateUpperCI'] = f"{ci_seek_care_given_not_inappropriate.high:.1%}"

# Fisher's exact test of independence on the 2x2 table, computed once.
fisher_result = fisher_exact(crosstab.to_numpy())
latexvars['seekCareVsAppropriateTableP'] = f"{fisher_result.pvalue:.3f}"

## AI Analysis

In [None]:
# Load each LLM grader's results and left-join the visit data from df on the
# visit code. For columns present in both frames, the grader-results copy gets
# the '_physician' suffix and df's copy keeps the plain name.
results_ai_gpt41 = jsonl_load(gpt41_results_fp)
results_ai_df_gpt41 = pd.DataFrame(results_ai_gpt41)
results_ai_df_gpt41 = results_ai_df_gpt41.merge(df, left_on='visit_code', right_on='VisitCode', how='left', suffixes=('_physician', ''))

results_ai_o3 = jsonl_load(o3_results_fp)
results_ai_df_o3 = pd.DataFrame(results_ai_o3)
results_ai_df_o3 = results_ai_df_o3.merge(df, left_on='visit_code', right_on='VisitCode', how='left', suffixes=('_physician', ''))

In [None]:
def process_ai_df(results_ai_df):
    """
    Normalise LLM-grader Likert ratings and add low-rating flags, in place.

    Steps:
      1. Convert each domain's ``*_likert`` column with ``intify_likert``.
      2. Set ``history_likert`` to NA for visits after ``last_visitdate_with_cc``.
         Only that column is changed here; the MCQ columns are left untouched.
      3. Add a ``*_likert_is_low`` column per domain via ``is_low_likert``.

    The DataFrame is modified in place and also returned.
    """
    domains = ['history', 'investigations', 'diagnosis', 'treatment']

    for domain in domains:
        col = f'{domain}_likert'
        results_ai_df[col] = results_ai_df[col].apply(intify_likert)

    # Blank out (NA, not zero) history ratings after last_visitdate_with_cc.
    # Presumably later visits lack the chief-complaint data; confirm.
    results_ai_df.loc[results_ai_df.VisitDate > last_visitdate_with_cc, ['history_likert']] = pd.NA

    for domain in domains:
        results_ai_df[f'{domain}_likert_is_low'] = results_ai_df[f'{domain}_likert'].apply(is_low_likert)

    return results_ai_df

# process_ai_df mutates and returns the same DataFrame. Bind the result back to
# the *_df names used by the analysis below instead of overwriting the raw JSONL
# record lists (results_ai_gpt41 / results_ai_o3).
results_ai_df_gpt41 = process_ai_df(results_ai_df_gpt41)
results_ai_df_o3 = process_ai_df(results_ai_df_o3)

In [None]:
# For each LLM grader:
#   * compare Non-AI vs AI error rates as rated by the LLM;
#   * measure LLM-physician agreement;
#   * refit the adjusted GEE and log-Poisson models on the LLM ratings.
# Results go to plots, LaTeX tables and `latexvars`.
lbp_data_ai_analysis = {}

for results_ai_df, model_name, latex_var_name in [
    (results_ai_df_gpt41, 'GPT-4.1', 'GPTFourOne'),
    (results_ai_df_o3, 'o3', 'OThree'),
]:
    print(f"AI ANALYSIS FOR MODEL: {model_name.upper()}...")
    main_study_ai_df = results_ai_df[results_ai_df.VisitDate >= cut_point]

    main_study_ai_lbp_data = get_likert_bar_plot_data(main_study_ai_df)
    lbp_data_ai_analysis[model_name] = main_study_ai_lbp_data
    lv = likert_bar_plot_data_to_latex_vars(main_study_ai_lbp_data, latex_var_name)
    latexvars.update(lv)

    fig, fig_df = clustered_bar_plot(
        main_study_ai_lbp_data,
        group_one_name = 'Non-AI',
        group_two_name = 'AI',
        fig_title = f'Error rates in history-taking, investigation, diagnosis & treatment questions\nNon-AI vs AI in main study period\nRatings provided by {model_name}',
        y_axis_title = '% of visits with clinical errors for category',
    )
    save_fig(f'main_period_ai_likert_error_rates_{model_name}', fig)

    # Human-rated vs LLM-rated comparison. clustered_bar_plot returns a
    # (figure, data) pair, so unpack it instead of binding the tuple to `fig`.
    fig, fig_df = clustered_bar_plot(
        get_rrr_bar_plot_data([main_study_df, main_study_ai_df]),
        group_one_name = 'Ratings from human raters',
        group_two_name = 'Ratings from LLM-based grader',
        y_axis_title = '% of visits with clinical errors for category',
    )

    print(f"AGREEMENT STATISTICS FOR {model_name.upper()}")

    # Cast visit codes to str on both sides so the merge keys compare equal.
    results_ai_df.VisitCode = results_ai_df.VisitCode.astype(str)
    results_clinical_study_df.VisitCode = results_clinical_study_df.VisitCode.astype(str)

    # Pair each human rating row with the LLM rating of the same visit. The
    # LLM side's overlapping columns get an "_ai" suffix.
    combined_df = results_clinical_study_df.merge(results_ai_df, on='VisitCode', how='inner', suffixes=('', '_ai'))

    # Share of human/LLM rating pairs within one Likert point of each other.
    within_one_agreement = {}
    likert_cols = ['history_likert', 'investigations_likert', 'diagnosis_likert', 'treatment_likert']
    for likert_col in likert_cols:
        relevant_df = combined_df.dropna(subset=[likert_col, f'{likert_col}_ai'])
        agreement = relevant_df.apply(lambda row: abs(row[likert_col] - row[f'{likert_col}_ai']) <= 1, axis=1)
        # The mean of the boolean series is the fraction of agreeing pairs.
        # It yields 0.0 rather than a KeyError when no pair agrees.
        within_one_agreement[likert_col] = agreement.mean()
        lv_string = f'{latex_var_name}{likert_col.replace("_likert", "").capitalize()}WithinOneAgreement'
        print(lv_string)
        latexvars[lv_string] = f"{within_one_agreement[likert_col]:.1%}"
        print(f"{likert_col}: {within_one_agreement[likert_col]:.1%}")

    # Fleiss' kappa on the binary low-rating flag, treating the LLM and the
    # physician as two raters of each row.
    agreement_statistics = {}
    for likert_col in likert_cols:
        error_col_physician = f'{likert_col}_is_low'
        error_col_ai = f'{likert_col}_is_low_ai'

        relevant_df = combined_df.dropna(subset=[error_col_physician, error_col_ai])
        is_error_ai = relevant_df[error_col_ai]
        is_error_physician = relevant_df[error_col_physician]

        # Subjects x categories count table for fleiss_kappa: per row, how many
        # of the two raters chose category 0 (not low) and category 1 (low).
        fleiss_kappa_df = pd.DataFrame({
            0: (is_error_ai == 0).to_numpy(dtype=int) + (is_error_physician == 0).to_numpy(dtype=int),
            1: (is_error_ai == 1).to_numpy(dtype=int) + (is_error_physician == 1).to_numpy(dtype=int),
        })
        fk = fleiss_kappa(fleiss_kappa_df.values, method="fleiss")

        agreement_statistics[likert_col] = {
            'fleiss_kappa': fk,
        }
        lv_string = f'{latex_var_name}{likert_col.replace("_likert", "").capitalize()}FleissKappa'
        latexvars[lv_string] = f"{fk:.3f}"
        print(lv_string)

    display(pd.DataFrame(agreement_statistics))

    for domain in ['history', 'investigations', 'diagnosis', 'treatment']:
        outcome_col = f'{domain}_likert_is_low'

        # .copy() lets the int conversion below write to an independent frame
        # rather than a slice of main_study_ai_df (avoids SettingWithCopyWarning).
        main_study_ai_df_notna = main_study_ai_df.dropna(subset=[outcome_col]).copy()
        main_study_ai_df_notna[outcome_col] = main_study_ai_df_notna[outcome_col].astype(int)

        # GEE LOG-BINOMIAL MODEL
        # Reuses the adjusted-model `formula` template from the multivariate-model cell.
        print("-" * 100)
        print(f"{model_name.upper()}: {domain.upper()}: GEE LOG-BINOMIAL MODEL")
        print("-" * 100)
        result = smf.gee(
            formula.format(domain=domain),
            groups = main_study_ai_df_notna.user_id,
            data=main_study_ai_df_notna,
            family=sm.families.Binomial(link=sm.families.links.Log()),
            cov_struct = sm.genmod.cov_struct.Exchangeable(),
        ).fit()

        display(result.summary().tables[0])
        display(tidy_risk_ratios(result))

        main_effect, tidied_df = clean_regression_result(result)
        save_latex(f'{model_name}_gee_{domain}', tidied_df)
        lv = main_effect_to_rrr(main_effect, f'{latex_var_name}GEE{domain}')
        latexvars.update(lv)
        print(lv)

        # LOG-POISSON MODEL WITH ZOU-DONNER CLUSTERING FOR COVARIANCE
        print("-" * 100)
        print(f"{model_name.upper()}: {domain.upper()}: LOG-POISSON MODEL WITH ZOU-DONNER CLUSTERING FOR COVARIANCE")
        print("-" * 100)

        model = smf.glm(
            formula.format(domain=domain),
            data=main_study_ai_df_notna,
            family=sm.families.Poisson(link=sm.families.links.Log()),
        )

        result = model.fit(
            cov_type='cluster',
            cov_kwds={'groups': main_study_ai_df_notna['user_id']}
        )

        display(result.summary().tables[0])
        display(tidy_risk_ratios(result))

        main_effect, tidied_df = clean_regression_result(result)
        save_latex(f'{model_name}_poisson_{domain}', tidied_df)
        lv = main_effect_to_rrr(main_effect, f'{latex_var_name}Poisson{domain}')
        latexvars.update(lv)
        print(lv)

In [None]:
_agreement_categories = ['history', 'investigations', 'diagnosis', 'treatment']

# Within-one-point agreement per domain: between physician raters, and between
# each LLM grader and the physicians. Values are the formatted strings already
# stored in latexvars.
within_one_agreement_table = {
    'Physician-physician agreement': {
        c.capitalize(): latexvars[f'{c}LikertOnePointProp'] for c in _agreement_categories
    },
    'GPT-4.1-physician agreement': {
        c.capitalize(): latexvars[f'GPTFourOne{c.capitalize()}WithinOneAgreement'] for c in _agreement_categories
    },
    'o3-physician agreement': {
        c.capitalize(): latexvars[f'OThree{c.capitalize()}WithinOneAgreement'] for c in _agreement_categories
    },
}
within_one_agreement_table_df = pd.DataFrame(within_one_agreement_table)
save_latex('within_one_agreement_table', within_one_agreement_table_df)
within_one_agreement_table_df

In [None]:
# Fleiss' kappa per domain for physician-physician, GPT-4.1-physician and
# o3-physician agreement, using the formatted values stored in latexvars.
# The column headers are raw strings: "\k" is not a valid Python escape, so a
# plain literal raises a DeprecationWarning/SyntaxWarning. The text is unchanged.
fleiss_kappa_table = {
    r'Physician-physician $\kappa$': {
        category.capitalize(): latexvars[f'fleissKappa{category.capitalize()}']
        for category in ['history', 'investigations', 'diagnosis', 'treatment']
    },
    r'GPT-4.1-physician $\kappa$': {
        category.capitalize(): latexvars[f'GPTFourOne{category.capitalize()}FleissKappa']
        for category in ['history', 'investigations', 'diagnosis', 'treatment']
    },
    r'o3-physician $\kappa$': {
        category.capitalize(): latexvars[f'OThree{category.capitalize()}FleissKappa']
        for category in ['history', 'investigations', 'diagnosis', 'treatment']
    }
}
fleiss_kappa_table_df = pd.DataFrame(fleiss_kappa_table)
save_latex('fleiss_kappa_table', fleiss_kappa_table_df)
fleiss_kappa_table_df

In [None]:
# AI-vs-Non-AI relative risk reductions per error category, estimated from
# physician ratings (main_study_lbp_data, built in an earlier cell) and from
# each LLM grader's ratings.
gpt_lbp_data = lbp_data_ai_analysis['GPT-4.1']
o3_lbp_data = lbp_data_ai_analysis['o3']

humans_ai_df = pd.DataFrame({
    'Physician raters': lbp_data_to_rrr_series(main_study_lbp_data),
    'GPT-4.1': lbp_data_to_rrr_series(gpt_lbp_data),
    'o3': lbp_data_to_rrr_series(o3_lbp_data),
})
save_latex('physicians_ai_rrr_table', humans_ai_df)
humans_ai_df

## Outcomes analysis

In [None]:
# Primary patient-reported outcome in the main study period. Survey Q1 answers
# 1-3 count as "not feeling better" and 4-5 as feeling better.
# .copy() makes this an independent frame. Later cells add outcome columns to
# it, and those writes then do not target a slice of df (avoids
# SettingWithCopyWarning).
main_study_all_results_df = df[df.VisitDate >= cut_point].copy()
main_outcome_bar_plot_data = get_likert_bar_plot_data(
    main_study_all_results_df,
    allocation_col = 'ClinicianGroup',
    allocation_group_one_values = ['Non-AI'],
    allocation_group_two_values = ['AI'],
    outcome_cols = ['Q1'],
    outcome_true_values = [1, 2, 3],
    outcome_false_values = [4, 5],
    unique_column = None
)

# Export LaTeX variables prefixed "notFeelingBetter", removing "Q1" from their names.
lv = likert_bar_plot_data_to_latex_vars(main_outcome_bar_plot_data, prefix = 'notFeelingBetter')
lv = {k.replace('Q1', ''): v for k, v in lv.items()}
latexvars.update(lv)

# Relabel the single (Q1) entry for display.
# NOTE: with several entries, only the last would survive this relabelling.
main_outcome_bar_plot_data = {
    "Rate of patients not feeling better": v
    for k, v in main_outcome_bar_plot_data.items()
}
clustered_bar_plot(
    main_outcome_bar_plot_data,
    group_one_name = 'Non-AI',
    group_two_name = 'AI',
    fig_title = 'Rate of patients not feeling better',
    y_axis_title = "Proportion of visits",
    figsize = (5, 6),
)

In [None]:
# Secondary outcomes as per-visit booleans (NA when the source value is missing):
# Q2 == "another_chemist", Q2 == "self_referred" and HadUnplannedVisit == 1.
main_study_all_results_df.loc[:, 'Saw a pharmacist'] = main_study_all_results_df.Q2.apply(lambda x: x == "another_chemist" if pd.notna(x) else pd.NA)
main_study_all_results_df.loc[:, 'Self-referred to another clinic or hospital'] = main_study_all_results_df.Q2.apply(lambda x: x == "self_referred" if pd.notna(x) else pd.NA)
main_study_all_results_df.loc[:, 'Unplanned visit at Penda'] = main_study_all_results_df.HadUnplannedVisit.apply(lambda x: x == 1 if pd.notna(x) else pd.NA)

# Compare each secondary outcome's rate between the Non-AI and AI groups.
other_patient_outcome_bar_plot_data = get_likert_bar_plot_data(
    main_study_all_results_df,
    allocation_col = 'ClinicianGroup',
    allocation_group_one_values = ['Non-AI'],
    allocation_group_two_values = ['AI'],
    outcome_cols = ['Saw a pharmacist', 'Self-referred to another clinic or hospital', 'Unplanned visit at Penda'],
    outcome_true_values = [True],
    outcome_false_values = [False],
    unique_column = None
)

clustered_bar_plot(
    other_patient_outcome_bar_plot_data,
    group_one_name = 'Non-AI',
    group_two_name = 'AI',
    fig_title = 'Additional patient outcomes',
    y_axis_title = "Proportion of visits",
    figsize = (10, 6),
)

In [None]:
# One-day follow-up outcome: True when one_day_call_outcome == 'Yes' (plotted as
# "Feeling worse on one-day follow-up"); NA when no outcome was recorded.
main_study_all_results_df.loc[:, 'Feeling worse on one-day follow-up'] = main_study_all_results_df.one_day_call_outcome.apply(lambda x: x == 'Yes' if pd.notna(x) else pd.NA)

one_day_outcome_bar_plot_data = get_likert_bar_plot_data(
    main_study_all_results_df,
    allocation_col = 'ClinicianGroup',
    allocation_group_one_values = ['Non-AI'],
    allocation_group_two_values = ['AI'],
    outcome_cols = ['Feeling worse on one-day follow-up'],
    outcome_true_values = [True],
    outcome_false_values = [False],
    unique_column = None
)

clustered_bar_plot(
    one_day_outcome_bar_plot_data,
    group_one_name = 'Non-AI',
    group_two_name = 'AI',
    fig_title = 'One-day follow-up outcome',
    y_axis_title = "Proportion of visits",
    figsize = (10, 6),
)

In [None]:
# Combine the primary, secondary and one-day patient-outcome comparisons into a
# single table of group rates and p-values.
_outcome_plot_data = [
    main_outcome_bar_plot_data,
    other_patient_outcome_bar_plot_data,
    one_day_outcome_bar_plot_data,
]
non_ai_series = pd.concat([lbp_data_to_group_one_series(d) for d in _outcome_plot_data])
ai_series = pd.concat([lbp_data_to_group_two_series(d) for d in _outcome_plot_data])
pval_series = pd.concat([lbp_data_to_pval_series(d) for d in _outcome_plot_data])

patient_outcomes_df = pd.DataFrame({
    "Rate in non-AI group": non_ai_series,
    "Rate in AI group": ai_series,
    "p": pval_series,
})

save_latex('patient_outcomes', patient_outcomes_df)
patient_outcomes_df

In [None]:
# Sensitivity analysis for missing Q1 answers. The observed-data comparison is
# shown alongside two extreme imputations of every missing answer.
# .copy() makes this an independent frame, so the imputed columns added below
# are not written into a slice of df (avoids SettingWithCopyWarning).
main_study_all_results_df = df[df.VisitDate >= cut_point].copy()
main_outcome_bar_plot_data = get_likert_bar_plot_data(
    main_study_all_results_df,
    allocation_col = 'ClinicianGroup',
    allocation_group_one_values = ['Non-AI'],
    allocation_group_two_values = ['AI'],
    outcome_cols = ['Q1'],
    outcome_true_values = [1, 2, 3],
    outcome_false_values = [4, 5],
    unique_column = None
)

q1_non_ai_rate = main_outcome_bar_plot_data['Q1']['first_group_rate']
q1_ai_rate = main_outcome_bar_plot_data['Q1']['second_group_rate']


# Imputation favouring AI. Missing answers become 5 in the AI group (not counted
# as "not feeling better") and 1 in the Non-AI group (counted).
main_study_all_results_df.loc[:, 'Q1: Favorable imputation'] = main_study_all_results_df.apply(lambda row: row.Q1 if pd.notna(row.Q1) else (5 if row.ClinicianGroup == 'AI' else 1), axis=1)
main_outcome_bar_plot_data_favorable_imputation = get_likert_bar_plot_data(
    main_study_all_results_df,
    allocation_col = 'ClinicianGroup',
    allocation_group_one_values = ['Non-AI'],
    allocation_group_two_values = ['AI'],
    outcome_cols = ['Q1: Favorable imputation'],
    outcome_true_values = [1, 2, 3],
    outcome_false_values = [4, 5],
    unique_column = None
)

# Imputation favouring Non-AI: the reverse assignment (AI -> 1, Non-AI -> 5).
main_study_all_results_df.loc[:, 'Q1: Unfavorable imputation'] = main_study_all_results_df.apply(lambda row: row.Q1 if pd.notna(row.Q1) else (1 if row.ClinicianGroup == 'AI' else 5), axis=1)
main_outcome_bar_plot_data_unfavorable_imputation = get_likert_bar_plot_data(
    main_study_all_results_df,
    allocation_col = 'ClinicianGroup',
    allocation_group_one_values = ['Non-AI'],
    allocation_group_two_values = ['AI'],
    outcome_cols = ['Q1: Unfavorable imputation'],
    outcome_true_values = [1, 2, 3],
    outcome_false_values = [4, 5],
    unique_column = None
)

# Rows in order: observed, favourable imputation, unfavourable imputation.
non_ai_series = pd.concat([
    lbp_data_to_group_one_series(main_outcome_bar_plot_data),
    lbp_data_to_group_one_series(main_outcome_bar_plot_data_favorable_imputation),
    lbp_data_to_group_one_series(main_outcome_bar_plot_data_unfavorable_imputation)
])

ai_series = pd.concat([
    lbp_data_to_group_two_series(main_outcome_bar_plot_data),
    lbp_data_to_group_two_series(main_outcome_bar_plot_data_favorable_imputation),
    lbp_data_to_group_two_series(main_outcome_bar_plot_data_unfavorable_imputation)
])

pval_series = pd.concat([
    lbp_data_to_pval_series(main_outcome_bar_plot_data),
    lbp_data_to_pval_series(main_outcome_bar_plot_data_favorable_imputation),
    lbp_data_to_pval_series(main_outcome_bar_plot_data_unfavorable_imputation)
])

imputation_df = pd.DataFrame({
    "non_ai": non_ai_series,
    "ai": ai_series,
    "pval": pval_series
})
save_latex("imputation_df", imputation_df)
imputation_df

## Survey analysis

In [None]:
# Hard-coded survey response counts per arm, and response rates relative to
# nsilentai / nactiveai. Those are defined elsewhere in the notebook; presumably
# they are the number of clinicians in each arm (confirm). "\%" escapes the
# percent sign for LaTeX.
n_silent_ai_survey_responses = 23
n_active_ai_survey_responses = 36
pct_silent_ai_survey_responses = rf'{n_silent_ai_survey_responses / nsilentai * 100:.0f}\%'
pct_active_ai_survey_responses = rf'{n_active_ai_survey_responses / nactiveai * 100:.0f}\%'
latexvars['nSilentAISurveyResponses'] = n_silent_ai_survey_responses
latexvars['nActiveAISurveyResponses'] = n_active_ai_survey_responses
latexvars['pctSilentAISurveyResponses'] = pct_silent_ai_survey_responses
latexvars['pctActiveAISurveyResponses'] = pct_active_ai_survey_responses
pct_silent_ai_survey_responses, pct_active_ai_survey_responses

In [None]:
# Original counts for each group
silent_quality_of_care = {
    3: 4,
    4: 6,
    5: 13
}
active_quality_of_care = {
    3: 0,
    4: 8,
    5: 28
}

# Expand counts into lists of scores
silent_scores = []
for score, count in silent_quality_of_care.items():
    silent_scores.extend([score] * count)

active_scores = []
for score, count in active_quality_of_care.items():
    active_scores.extend([score] * count)

# Mann-Whitney U test
u_stat, p_value = mannwhitneyu(active_scores, silent_scores, alternative='two-sided')
print(f"Mann-Whitney U test: U={u_stat}, p={p_value}")

latexvars['emrQualityUPVal'] = f'{p_value:.3f}'
latexvars['emrQualityActivePctFive'] = rf"{active_quality_of_care[5] / sum(active_quality_of_care.values()) * 100:.0f}\%"
latexvars['emrQualitySilentPctFive'] = rf"{silent_quality_of_care[5] / sum(silent_quality_of_care.values()) * 100:.0f}\%"

In [None]:
# Active-AI-only survey tallies ({score: number of responses}), hard-coded:
# AI-specific quality-of-care impact, satisfaction, and promoter score.
active_ai_specific_quality_of_care = {
    4: 9,
    5: 27
}

active_ai_specific_satisfaction = {
    4: 21,
    5: 15
}

active_ai_promoter_score = {
    8: 8,
    9: 16,
    10: 12
}

# Net Promoter Score: count of scores >= 9 (promoters) minus count of scores
# <= 6 (detractors), as a fraction of all responses.
nps = sum(v for k, v in active_ai_promoter_score.items() if k >= 9) - sum(v for k, v in active_ai_promoter_score.items() if k <= 6)
nps_pct = nps / sum(active_ai_promoter_score.values())

# Export NPS as a whole number, and the share of 4s and 5s for satisfaction
# and quality ("\%" escapes the percent sign for LaTeX).
latexvars['activeAISpecificNPS'] = rf'{nps_pct * 100:.0f}'
latexvars['activeAISatisfactionPctFour'] = rf'{active_ai_specific_satisfaction[4] / sum(active_ai_specific_satisfaction.values()) * 100:.0f}\%'
latexvars['activeAISatisfactionPctFive'] = rf'{active_ai_specific_satisfaction[5] / sum(active_ai_specific_satisfaction.values()) * 100:.0f}\%'
latexvars['activeAISpecificQualityPctFour'] = rf'{active_ai_specific_quality_of_care[4] / sum(active_ai_specific_quality_of_care.values()) * 100:.0f}\%'
latexvars['activeAISpecificQualityPctFive'] = rf'{active_ai_specific_quality_of_care[5] / sum(active_ai_specific_quality_of_care.values()) * 100:.0f}\%'


In [None]:
# Expand the quality-of-care tallies into one row per response, with a 0/1
# indicator column for each of the five answer options (scores 1-5).
# Silent-AI responses are labelled 'Non-AI' and active-AI responses 'AI'.
quality_of_care_long_format = []
for k, v in silent_quality_of_care.items():
    for _ in range(v):
        quality_of_care_long_format.append({
            'ClinicianGroup': 'Non-AI',
            'Substantially worsens': 1 if k == 1 else 0,
            'Somewhat worsens': 1 if k == 2 else 0,
            'No change': 1 if k == 3 else 0,
            'Somewhat improves': 1 if k == 4 else 0,
            'Substantially improves': 1 if k == 5 else 0,
        })

for k, v in active_quality_of_care.items():
    for _ in range(v):
        quality_of_care_long_format.append({
            'ClinicianGroup': 'AI',
            'Substantially worsens': 1 if k == 1 else 0,
            'Somewhat worsens': 1 if k == 2 else 0,
            'No change': 1 if k == 3 else 0,
            'Somewhat improves': 1 if k == 4 else 0,
            'Substantially improves': 1 if k == 5 else 0,
        })

quality_of_care_long_format_df = pd.DataFrame(quality_of_care_long_format)

# Share of responses choosing each option, by group.
quality_of_care_lbp_data = get_likert_bar_plot_data(
    quality_of_care_long_format_df,
    allocation_col = 'ClinicianGroup',
    allocation_group_one_values = ['Non-AI'],
    allocation_group_two_values = ['AI'],
    outcome_cols = ['Substantially worsens', 'Somewhat worsens', 'No change', 'Somewhat improves', 'Substantially improves'],
    outcome_true_values = [1],
    outcome_false_values = [0],
    unique_column = None
)
# Each label gets two trailing newlines (presumably for extra space under the bars; confirm).
fig, fig_df = clustered_bar_plot(
    {f'{k}\n\n': v for k, v in quality_of_care_lbp_data.items()},
    group_one_name = 'Non-AI',
    group_two_name = 'AI',
    fig_title = 'Clinician-reported impact of the EMR (including AI consult, if present) on quality of care',
    y_axis_title = '% of responses',
    show_significance = False,
    figsize = (10, 5),
)
save_fig('quality_of_care_emr_impact', fig)
save_csv('quality_of_care_emr_impact', fig_df)

In [None]:
def _dict_to_percentages_and_errors(d, order):
    """
    Convert a {score: count} dictionary into two lists:
    (1) percentages (as floats, e.g. 0.23)
    (2) error bars as a (2, n) array: [[lower_errs], [upper_errs]]
    """
    total = sum(d.values())
    pct = []
    lower_errs = []
    upper_errs = []
    for k in order:
        y = d.get(k, 0)
        n = total
        bt = binomtest(y, n)
        ci = bt.proportion_ci(confidence_level=0.95, method='wilson')
        p = bt.statistic
        pct.append(p)
        lower_errs.append(p - ci.low)
        upper_errs.append(ci.high - p)
    # yerr must be shape (2, n)
    yerr = [lower_errs, upper_errs]
    return pct, yerr

# Proportions and Wilson error bars for each survey item. The Likert items
# are reported on scores 1-5; the NPS item on scores 0-10.
score_order = list(range(1, 6))

active_pct, active_err = _dict_to_percentages_and_errors(active_quality_of_care, score_order)
silent_pct, silent_err = _dict_to_percentages_and_errors(silent_quality_of_care, score_order)
quality_of_care_pct, quality_of_care_err = _dict_to_percentages_and_errors(
    active_ai_specific_quality_of_care, score_order
)
satisfaction_pct, satisfaction_err = _dict_to_percentages_and_errors(
    active_ai_specific_satisfaction, score_order
)

nps_order = list(range(0, 11))
nps_pct, nps_err = _dict_to_percentages_and_errors(active_ai_promoter_score, nps_order)

def plot_pct_bar(pct, err, labels, title, color=None, ylim=(0, 0.9), ylabel="% of responses"):
    """
    Draw a single-series bar chart of proportions with asymmetric error bars.

    Parameters
    ----------
    pct : sequence of float
        Bar heights as proportions (0-1); the y-axis is formatted as percent.
    err : (2, n) array-like
        ``[lower_errs, upper_errs]`` passed to ``Axes.bar(yerr=...)``.
    labels : sequence of str
        One x tick label per bar.
    title : str
        Axes title.
    color : optional
        Bar color. Defaults to the second color of the active matplotlib
        property cycle.
    ylim : (float, float)
        y-axis limits.
    ylabel : str
        y-axis label.

    Returns
    -------
    matplotlib.figure.Figure
        A new 10x5 inch figure.
    """
    bar_color = (
        plt.rcParams['axes.prop_cycle'].by_key()['color'][1] if color is None else color
    )
    positions = np.arange(len(pct))

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(positions, pct, color=bar_color, yerr=err, capsize=4)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=0, ha="center")
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_ylim(*ylim)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
    fig.tight_layout()
    return fig

# Wrap long category labels at a 13-character width so they fit under the bars.
def wrap(s):
    return textwrap.fill(s, width=13)


labels_qoc = [
    wrap(text)
    for text in (
        "Substantially worsens quality",
        "Somewhat worsens quality",
        "Does not change quality",
        "Somewhat improves quality",
        "Substantially improves quality",
    )
]

labels_sat = [
    wrap(text)
    for text in (
        "Very dissatisfied",
        "Somewhat dissatisfied",
        "Neither satisfied nor dissatisfied",
        "Somewhat satisfied",
        "Very satisfied",
    )
]

labels_nps = [str(score) for score in nps_order]

# Render each AI-group survey distribution, display it, then save it as PNG.
fig_qoc = plot_pct_bar(
    quality_of_care_pct,
    quality_of_care_err,
    labels_qoc,
    "AI group – impact of AI consult on quality of care",
)
plt.show()
save_fig('quality_of_care_ai_impact', fig_qoc)

fig_sat = plot_pct_bar(
    satisfaction_pct,
    satisfaction_err,
    labels_sat,
    "AI group – satisfaction with AI consult",
)
plt.show()
save_fig('satisfaction_ai_impact', fig_sat)

fig_nps = plot_pct_bar(
    nps_pct,
    nps_err,
    labels_nps,
    "AI group – Net Promoter Score (NPS) for AI consult",
    ylim=(0, 0.7),
    ylabel="% of responses",
)
plt.show()
save_fig('nps_ai_impact', fig_nps)

In [None]:
# Export the AI-group quality-of-care distribution for the web: one row per
# answer option, with its proportion and absolute Wilson CI bounds
# (proportion minus / plus the error-bar distances).
qoc_lower_errs, qoc_upper_errs = quality_of_care_err
df_quality_of_care = pd.DataFrame(
    {
        "cluster": [
            "Substantially worsens quality",
            "Somewhat worsens quality",
            "Does not change quality",
            "Somewhat improves quality",
            "Substantially improves quality",
        ],
        "value": quality_of_care_pct,
        "lower CI": [share - gap for share, gap in zip(quality_of_care_pct, qoc_lower_errs)],
        "upper CI": [share + gap for share, gap in zip(quality_of_care_pct, qoc_upper_errs)],
    }
)

save_csv('quality_of_care_ai_impact', df_quality_of_care)

## Red reduction rate

In [None]:
def compute_final_reds_over_time_df(
    df: pd.DataFrame,
    *,
    first_final_red: Literal["first", "final", "all"] = "final",
    filter_to_rules: list[types.ClinicalDecisionRule] | None = None,
):
    """
    Return a tidy DataFrame of weekly red-flag rates per clinician group, with
    Wilson 95% confidence intervals.

    Parameters
    ----------
    df
        Visit-level frame with ``AICalls``, ``VisitDate`` and ``ClinicianGroup``
        columns. Rows whose ``AICalls`` is missing are dropped.
    first_final_red
        ``"first"`` / ``"final"``: one observation per visit, red when
        ``AICalls.any_first_red`` / ``AICalls.any_final_red`` is true.
        ``"all"``: one observation per individual call in ``AICalls.calls``,
        red when that call's ``color`` is ``types.Color.Red``.
    filter_to_rules
        If given, each visit's ``AICalls`` is first narrowed with
        ``AICalls.for_rules(filter_to_rules)``.

    Returns
    -------
    pd.DataFrame
        One row per (``week``, ``ClinicianGroup``) with these columns:

        - ``n_red``, ``n``: red and total observation counts.
        - ``pct_final_red``: the red proportion ``n_red / n``, whatever
          ``first_final_red`` is.
        - ``err_lower`` / ``err_upper``: distance from the proportion to the
          lower / upper Wilson bound.

        ``week`` is the start timestamp of the visit's weekly pandas period.

    This function *only* prepares data – it performs **no** plotting or other
    side-effects.
    """
    calls_df = df.dropna(subset=["AICalls"]).copy()

    if filter_to_rules is not None:
        calls_df["AICalls"] = calls_df["AICalls"].apply(
            lambda x: x.for_rules(filter_to_rules)
        )

    # Series.apply (rather than a row-wise DataFrame.apply) always yields a
    # Series, even for an empty frame.
    if first_final_red == "first":
        calls_df["red"] = calls_df["AICalls"].apply(lambda x: x.any_first_red)
    elif first_final_red == "final":
        calls_df["red"] = calls_df["AICalls"].apply(lambda x: x.any_final_red)
    else:  # "all": one row per individual AI call
        calls_df["calls"] = calls_df["AICalls"].apply(lambda x: x.calls)
        calls_df = calls_df.explode("calls", ignore_index=True)
        # explode() turns an empty list of calls into a NaN row; drop those
        # before reading `.color`, which the NaN placeholder does not have.
        calls_df = calls_df.dropna(subset=["calls"]).copy()
        calls_df["red"] = calls_df["calls"].apply(lambda c: c.color == types.Color.Red)

    # Rows whose `red` is missing are excluded. Cast the rest to bool so the
    # per-group sums below are integer counts rather than object values.
    calls_df = calls_df.dropna(subset=["red"]).copy()
    calls_df["red"] = calls_df["red"].astype(bool)

    calls_df["VisitDate"] = pd.to_datetime(calls_df["VisitDate"])
    calls_df["week"] = calls_df["VisitDate"].dt.to_period("W").dt.start_time

    summary = (
        calls_df.groupby(["week", "ClinicianGroup"])["red"]
        .agg(["sum", "count"])
        .reset_index()
        .rename(columns={"sum": "n_red", "count": "n"})
    )

    pct_red, err_lower, err_upper = [], [], []
    for n_red, n in zip(summary["n_red"], summary["n"]):
        bt = binomtest(int(n_red), int(n))
        ci = bt.proportion_ci(method="wilson")
        p = bt.statistic
        pct_red.append(p)
        err_lower.append(p - ci.low)
        err_upper.append(ci.high - p)

    summary["pct_final_red"] = pct_red
    summary["err_lower"] = err_lower
    summary["err_upper"] = err_upper

    return summary

def plot_no_ai_ai_over_time(
    red_by_week: pd.DataFrame,
    x_var: str = 'week',
    y_var: str = 'pct_final_red',
    x_lim: tuple[float, float] | None = None,
    y_lim: tuple[float, float] | None = None,
    x_label: str = '',
    y_label: str = '',
    title: str = '',
    grouping_var: str = 'ClinicianGroup',
    grouping_var_label: str = 'Group',
    group_order: list[str] | tuple[str, ...] = ("Non-AI", "AI"),
    y_axis_percent: bool = True,
    time_x_axis: bool = True,
):
    """
    Plot one error-bar line per group and return the Matplotlib ``Figure``.

    Written for the output of ``compute_final_reds_over_time_df``, but accepts
    any tidy frame with the columns ``x_var``, ``y_var``, ``grouping_var``,
    ``err_lower`` and ``err_upper``. The last two are the distances
    below/above each point.

    Parameters
    ----------
    x_lim, y_lim
        Axis limits. Without ``y_lim`` the y-axis starts at 0 with 5%
        headroom above the autoscaled maximum.
    group_order
        Values of ``grouping_var`` to draw, in order. Each is drawn with
        ``ax.errorbar`` without an explicit color, so colors come from the
        axes' color cycle in this order.
    y_axis_percent
        Format y tick labels as percentages of 1.0.
    time_x_axis
        Treat x as dates. Adds weekly ticks and a dashed vertical line at the
        module-level ``cut_point``, which must be defined elsewhere in the
        notebook. The line is labelled "Induction period" / "Active
        deployment" on either side.

    Closes all open pyplot figures before drawing.
    """
    plt.close("all")
    fig, ax = plt.subplots(figsize=(8, 5))

    for group in group_order:
        gdf = red_by_week[red_by_week[grouping_var] == group]
        ax.errorbar(
            gdf[x_var],
            gdf[y_var],
            yerr=[gdf["err_lower"], gdf["err_upper"]],
            marker="o",
            capsize=4,
            linestyle="-",
            label=group,
        )

    if x_lim is not None:
        ax.set_xlim(x_lim)

    if y_lim:
        ax.set_ylim(y_lim)
    else:
        _, ymax = ax.get_ylim()
        ax.set_ylim(0, ymax * 1.05)

    if time_x_axis:
        cut_point_num = float(mdates.date2num(cut_point))
        ax.axvline(cut_point_num, color="gray", linestyle="--", linewidth=0.5)

        ax.xaxis.set_major_locator(mdates.WeekdayLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%b %d\n%Y"))

        # Label both sides of the cut point. Offsets are in matplotlib date
        # units (days).
        for text, offset, align in (
            ("Induction period", -3, "right"),
            ("Active deployment", 3, "left"),
        ):
            ax.annotate(
                text,
                xy=(cut_point_num, 0),
                xytext=(cut_point_num + offset, 0),
                ha=align,
                va="bottom",
                fontsize=10,
                backgroundcolor="white",
                bbox=dict(facecolor="white", edgecolor="none", alpha=0.7),
                arrowprops=dict(arrowstyle="-", color="gray", lw=0),
            )

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend(title=grouping_var_label)
    if y_axis_percent:
        ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
    fig.tight_layout()

    return fig


# --------------------------------------------------------------------------- #
# Convenience wrapper: compute + plot in one call (preserves original API)    #
# --------------------------------------------------------------------------- #
def plot_final_reds_over_time(
    df: pd.DataFrame,
    first_final_red: Literal["first", "final", "all"] = "final",
    filter_to_rules: list[types.ClinicalDecisionRule] | None = None,
    y_lim: tuple[float, float] | None = None,
):
    """
    Back-compatibility wrapper that summarizes ``df`` by week and plots it.

    1. Builds the weekly summary with ``compute_final_reds_over_time_df``.
    2. Plots it with ``plot_no_ai_ai_over_time``.
    3. Returns ``(red_by_week, fig)``, keeping the original return signature.

    When ``filter_to_rules`` is given, the rules' ``.value`` strings are
    appended on a new line to both the y-axis label and the title.
    """
    red_by_week = compute_final_reds_over_time_df(
        df,
        first_final_red=first_final_red,
        filter_to_rules=filter_to_rules,
    )
    rule_suffix = (
        "" if filter_to_rules is None else "\n" + ", ".join(r.value for r in filter_to_rules)
    )
    if first_final_red == "all":
        # "all" mode has one observation per individual AI call, not per visit.
        y_label = "% of AI calls that were red"
    else:
        y_label = f"% of visits with {first_final_red.title()} red"
    fig = plot_no_ai_ai_over_time(
        red_by_week,
        y_label=y_label + rule_suffix,
        title=f"% of {first_final_red.title()} Reds Over Time by Clinician Group" + rule_suffix,
        y_lim=y_lim,
    )
    return red_by_week, fig

In [None]:
# Weekly rate of visits with a final red, per clinician group (also exported as CSV).
red_by_week, fig = plot_final_reds_over_time(
    df,
    first_final_red="final",
    y_lim=(0, 0.6),
)
save_fig("final_reds_over_time", fig)
save_csv("final_reds_over_time", red_by_week)

In [None]:
# Same weekly view, using each visit's first red instead of its final one.
red_by_week, fig = plot_final_reds_over_time(
    df,
    first_final_red="first",
    y_lim=(0, 0.6),
)
save_fig("first_reds_over_time", fig)

In [None]:
# First reds restricted to the vitals/chief-complaint and clinical-notes rules.
red_by_week, fig = plot_final_reds_over_time(
    df,
    first_final_red="first",
    filter_to_rules=[
        types.ClinicalDecisionRule.VitalsChiefComplaintEvaluation,
        types.ClinicalDecisionRule.ClinicalNotes,
    ],
)
save_fig("first_reds_over_time_vitals_history", fig)

In [None]:
# Final reds restricted to the treatment-recommendation rule.
red_by_week, fig = plot_final_reds_over_time(
    df,
    first_final_red="final",
    filter_to_rules=[types.ClinicalDecisionRule.TreatmentRecommendation],
)
save_fig("final_reds_over_time_treatment", fig)

In [None]:
# First reds restricted to the treatment-recommendation rule.
red_by_week, fig = plot_final_reds_over_time(
    df,
    first_final_red="first",
    filter_to_rules=[types.ClinicalDecisionRule.TreatmentRecommendation],
)
save_fig("first_reds_over_time_treatment", fig)

In [None]:
def compute_provider_level_final_red_heterogeneity_df(
    df: pd.DataFrame,
    *,
    first_final_red: Literal["first", "final", "all"] = "final",
    filter_to_rules: list[types.ClinicalDecisionRule] | None = None,
) -> pd.DataFrame:
    """
    Summarize how the weekly red-flag rate varies across AI-group providers.

    For each (week, provider) the red rate is ``n_red / n``. Then, for each
    week, the distribution of those provider rates is summarized by its min,
    p10, p25, median (p50), p75, p90 and max.

    Parameters
    ----------
    df
        Visit-level frame with ``AICalls``, ``UserIDs`` (list-like),
        ``ClinicianGroup`` and either ``week`` or ``VisitDate``. Only rows with
        ``ClinicianGroup == "AI"`` contribute.
    first_final_red
        ``"first"`` / ``"final"``: one observation per visit
        (``AICalls.any_first_red`` / ``any_final_red``). ``"all"``: one
        observation per individual AI call (``call.color == types.Color.Red``).
    filter_to_rules
        If given, each visit's ``AICalls`` is first narrowed with
        ``AICalls.for_rules(filter_to_rules)``.

    Returns
    -------
    pd.DataFrame
        Long format with these columns:

        - ``week``
        - ``quantile``: one of "min", "p10", "p25", "p50", "p75", "p90", "max".
        - ``pct_final_red``: that statistic of the provider rates.
        - ``err_lower`` / ``err_upper``: always 0. These are descriptive
          quantiles, not estimates with CIs; the columns exist so the frame
          plugs into ``plot_no_ai_ai_over_time``.

    This function *only* prepares data – it performs **no** plotting or other
    side-effects.
    """
    calls_df = df.dropna(subset=["AICalls"]).copy()

    if filter_to_rules is not None:
        calls_df["AICalls"] = calls_df["AICalls"].apply(
            lambda x: x.for_rules(filter_to_rules)
        )

    if first_final_red == "first":
        calls_df["red"] = calls_df["AICalls"].apply(lambda x: x.any_first_red)
    elif first_final_red == "final":
        calls_df["red"] = calls_df["AICalls"].apply(lambda x: x.any_final_red)
    else:  # "all": one row per individual AI call
        calls_df["calls"] = calls_df["AICalls"].apply(lambda x: x.calls)
        calls_df = calls_df.explode("calls", ignore_index=True)
        # explode() turns an empty list of calls into a NaN row; drop those
        # before reading `.color`, which the NaN placeholder does not have.
        calls_df = calls_df.dropna(subset=["calls"]).copy()
        calls_df["red"] = calls_df["calls"].apply(lambda c: c.color == types.Color.Red)

    calls_df = calls_df.dropna(subset=["red"]).copy()
    calls_df["red"] = calls_df["red"].astype(bool)

    # Attribute each visit to a single provider. Visits with zero or several
    # UserIDs get a missing UserID; groupby() drops missing keys by default,
    # so they do not contribute below.
    calls_df["UserID"] = calls_df["UserIDs"].apply(lambda x: x[0] if len(x) == 1 else pd.NA)
    calls_df = calls_df.drop(columns=["UserIDs"])

    calls_df = calls_df[calls_df["ClinicianGroup"] == "AI"].copy()

    # Use a caller-supplied `week` column if present; otherwise derive it from
    # VisitDate as the start of the visit's weekly pandas period.
    if "week" not in calls_df.columns:
        calls_df["week"] = (
            pd.to_datetime(calls_df["VisitDate"]).dt.to_period("W").dt.start_time
        )

    # Red rate per provider per week.
    provider_summary = (
        calls_df.groupby(["week", "UserID"])["red"]
        .agg(n_red="sum", n="count")
        .reset_index()
    )
    provider_summary["pct_final_red"] = provider_summary["n_red"] / provider_summary["n"]

    # Distribution of the provider rates within each week.
    summary = (
        provider_summary.groupby("week")["pct_final_red"]
        .agg(
            min="min",
            p10=lambda x: x.quantile(0.10),
            p25=lambda x: x.quantile(0.25),
            p50="median",
            p75=lambda x: x.quantile(0.75),
            p90=lambda x: x.quantile(0.90),
            max="max",
        )
        .reset_index()
    )
    summary = summary.melt(id_vars="week", var_name="quantile", value_name="pct_final_red")

    summary["err_lower"] = 0
    summary["err_upper"] = 0

    return summary

In [None]:
# Spread of the weekly final-red rate across individual AI-group providers,
# drawn as one line per quantile (p10 through p90).
ai_df = df[df["ClinicianGroup"] == "AI"]
plot_df = compute_provider_level_final_red_heterogeneity_df(ai_df, first_final_red="final")
fig = plot_no_ai_ai_over_time(
    plot_df,
    x_var="week",
    y_var="pct_final_red",
    y_label="Percentage of final red flags",
    title="Percentage of final red flags by week",
    y_axis_percent=True,
    time_x_axis=True,
    grouping_var="quantile",
    grouping_var_label="Quantile",
    group_order=["p10", "p25", "p50", "p75", "p90"],
)
save_fig("final_red_quantiles", fig)

## Timing analysis

In [None]:
# Plot a boxplot of clinician attending time per clinician group (outliers
# shown; y-axis limited to 0-60 minutes).
plt.figure(figsize=(8, 5))
ax = sns.boxplot(
    data=df,
    x='ClinicianGroup',
    y='duration_minutes',
    showfliers=True,  # show outliers
    palette='muted'
)
plt.title('Clinician attending time by clinician group')
plt.ylabel('Clinician attending time (minutes)')
plt.ylim(0, 60)

# Replace legacy group codes, if present, with display names.
xtick_labels = [tick.get_text() for tick in ax.get_xticklabels()]
label_map = {
    'silent_group': 'Non-AI',
    'active_group': 'AI',
}
new_labels = [label_map.get(lbl, lbl) for lbl in xtick_labels]
# Pin the current tick positions before replacing their labels. Matplotlib
# warns when set_xticklabels() is used without a fixed tick locator.
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(new_labels)

# Compute the layout last, once limits and tick labels are final.
plt.tight_layout()
plt.show()

# Compare attending time between the two clinician groups.
# Look the groups up by name instead of unpacking the groupby keys by
# position. groupby() sorts plain string keys, which puts "AI" before "Non-AI"
# (and "active_group" before "silent_group"); positional unpacking would then
# swap the Non-AI/AI latexvars below.
duration_group_aliases = {
    'Non-AI': 'Non-AI',
    'silent_group': 'Non-AI',
    'AI': 'AI',
    'active_group': 'AI',
}
grouped = df.groupby('ClinicianGroup')['duration_minutes']
groups = list(grouped.groups.keys())
group_by_role = {duration_group_aliases.get(g, g): g for g in groups}
if len(groups) == 2 and set(group_by_role) == {'Non-AI', 'AI'}:
    group1 = group_by_role['Non-AI']
    group2 = group_by_role['AI']
    # mannwhitneyu propagates NaN by default, so drop missing durations.
    # Series.median() already skips them.
    durations1 = grouped.get_group(group1).dropna()
    durations2 = grouped.get_group(group2).dropna()

    # Print medians
    median1 = durations1.median()
    median2 = durations2.median()
    print(f"Median duration for {group1}: {median1:.2f} minutes")
    print(f"Median duration for {group2}: {median2:.2f} minutes")

    latexvars['nonAIMedianDuration'] = f"{median1:.2f}"
    latexvars['AIMedianDuration'] = f"{median2:.2f}"

    # Two-sided Mann-Whitney U test (not a t-test).
    result = mannwhitneyu(durations1, durations2, alternative='two-sided')
    print(f"U-test between {group1} and {group2}: U={result.statistic:.3f}, p={result.pvalue:.3f}")
    latexvars['MedianDurationP'] = f"{result.pvalue:.3f}"
else:
    print(f"Expected exactly two clinician groups (Non-AI and AI) for the U-test; found {groups}.")


In [None]:
def compute_duration_by_var(
    df: pd.DataFrame,
    var: str,
    stratify_by: str = "ClinicianGroup",
    n_boot: int = 1_000,
) -> pd.DataFrame:
    records: list[dict] = []
    for (var_val, group_val), durations in (
        df.groupby([var, stratify_by])["duration_minutes"]
    ):
        durations = durations.dropna().to_numpy()

        # Point estimate
        median = np.median(durations)

        # Bootstrap CI of the median
        boot_samples = np.random.choice(
            durations, size=(n_boot, len(durations)), replace=True
        )
        boot_medians = np.median(boot_samples, axis=1)
        ci_low, ci_high = np.percentile(boot_medians, [2.5, 97.5])

        records.append(
            {
                var: var_val,
                stratify_by: group_val,
                "median_duration": median,
                "err_lower": median - ci_low,
                "err_upper": ci_high - median,
            }
        )

    summary = pd.DataFrame.from_records(records)
    summary = summary[[var, stratify_by, "median_duration", "err_lower", "err_upper"]]

    return summary

In [None]:
# Median attending time per number of AI Consult triggers (at most 12) and
# clinician group, with bootstrap CIs.
data_df = compute_duration_by_var(df, var="n_aicalls", stratify_by="ClinicianGroup")
data_df = data_df.loc[data_df["n_aicalls"] <= 12]
fig = plot_no_ai_ai_over_time(
    data_df,
    x_var="n_aicalls",
    y_var="median_duration",
    x_label="Number of times AI Consult was or would have been triggered",
    y_label="Median attending time (minutes)",
    title="Clinician attending time by number of AI consult triggers",
    y_lim=(0, 40),
    y_axis_percent=False,
    time_x_axis=False,
)
save_fig("duration_by_n_ai_calls", fig)

In [None]:
# Restrict the GPT-4.1 results frame to visits with at most 12 AI Consult triggers.
filtered_df = results_ai_df_gpt41.loc[results_ai_df_gpt41["n_aicalls"].le(12)]

# Compute the treatment-error rate (share of treatment_likert_is_low == 1) and its Wilson CI by n_aicalls and ClinicianGroup
def wilson_ci(data, alpha=0.05):
    """
    Return ``(mean, lower, upper)`` for binary 0/1 data.

    ``mean`` is the observed proportion; ``lower``/``upper`` are its Wilson
    score interval at confidence level ``1 - alpha`` (95% by default).

    ``data`` may be a pandas Series or any 1-D array-like (list, ndarray).
    Missing values (NaN/None) are ignored; if nothing remains, all three
    results are NaN.

    Raises
    ------
    ValueError
        If any non-missing value is not 0 or 1 (booleans count as 0/1).
    """
    # Wrapping in pd.Series lets lists/arrays use the same .dropna() path as Series.
    data = np.array(pd.Series(data).dropna())
    if len(data) == 0:
        return np.nan, np.nan, np.nan
    # If data is not binary, raise an error
    unique_vals = np.unique(data)
    if not np.all(np.isin(unique_vals, [0, 1])):
        raise ValueError("wilson_ci expects binary (0/1) data")
    count = np.sum(data)
    nobs = len(data)
    mean = count / nobs
    lower, upper = proportion_confint(count, nobs, alpha=alpha, method='wilson')
    return mean, lower, upper

# One record per (n_aicalls, ClinicianGroup) cell: the rate of visits with
# treatment_likert_is_low and its Wilson CI (from wilson_ci).
summary_records = []
for (n_calls, clinician_group), cell in filtered_df.groupby(["n_aicalls", "ClinicianGroup"]):
    rate, ci_low, ci_high = wilson_ci(cell["treatment_likert_is_low"])
    record = {
        "n_aicalls": n_calls,
        "ClinicianGroup": clinician_group,
        "mean_treatment_likert": rate,
        "ci_lower": ci_low,
        "ci_upper": ci_high,
        "err_lower": rate - ci_low,
        "err_upper": ci_high - rate,
    }
    summary_records.append(record)

summary_df = pd.DataFrame(summary_records)

# Plot with the shared error-bar helper used by the other figures.
fig = plot_no_ai_ai_over_time(
    summary_df,
    x_var="n_aicalls",
    y_var="mean_treatment_likert",
    x_label="Number of times AI Consult was or would have been triggered",
    y_label="Rate of treatment errors",
    title="Rate of treatment errors, assigned by GPT-4.1, vs number of AI calls",
    y_axis_percent=True,
    time_x_axis=False,
    y_lim=None,
)
save_fig("treatment_errors_vs_n_aicalls", fig)


In [None]:
# Treatment-error rate (GPT-4.1 grading) vs clinician attending time, for
# visits lasting at most `max_duration_minutes`, in `bin_width`-minute bins.
max_duration_minutes = 30
bin_width = 5
filtered_df = results_ai_df_gpt41[
    results_ai_df_gpt41["duration_minutes"] <= max_duration_minutes
].copy()

# Label each visit with the lower edge of its bin. Plain floor division would
# give visits of exactly `max_duration_minutes` (when that is a multiple of
# `bin_width`) a single-value bin of their own, e.g. a lone "30" point after
# [25, 30). Fold them into the last bin instead, which is therefore closed on
# the right. `-(-a // b)` is ceil(a / b).
last_bin_start = -(-max_duration_minutes // bin_width) * bin_width - bin_width
filtered_df["VisitDurationBin"] = (
    (filtered_df["duration_minutes"] // bin_width * bin_width)
    .clip(upper=last_bin_start)
    .astype(int)
)

# Compute treatment error rate (proportion where treatment_likert_is_low == 1) and Wilson CI for each bin/group
summary_records = []
for (duration_bin, group), subdf in filtered_df.groupby(["VisitDurationBin", "ClinicianGroup"]):
    mean, lower, upper = wilson_ci(subdf["treatment_likert_is_low"])
    summary_records.append({
        "VisitDurationBin": duration_bin,
        "ClinicianGroup": group,
        "treatment_error_rate": mean,
        "ci_lower": lower,
        "ci_upper": upper,
        "err_lower": mean - lower,
        "err_upper": upper - mean,
        "n": len(subdf),
    })

summary_df = pd.DataFrame(summary_records)
summary_df = summary_df.sort_values("VisitDurationBin")

# Plot using the same style as plot_no_ai_ai_over_time
fig = plot_no_ai_ai_over_time(
    summary_df,
    y_var="treatment_error_rate",
    x_var="VisitDurationBin",
    y_label="Rate of treatment errors",
    title="Rate of treatment errors, assigned by GPT-4.1, vs clinician attending time",
    y_axis_percent=True,
    time_x_axis=False,
    x_label="Clinician attending time (minutes, binned)",
    y_lim=None,
)
save_fig("treatment_errors_vs_visit_duration", fig)


In [None]:
def get_note_length(row):
    """
    Return the character length of the row's cleaned clinical note.

    The note is read from ``row['ClinicalDocumentation'].clinical_notes_clean``.
    Returns None when the documentation value is None, has no
    ``clinical_notes_clean`` attribute, or that attribute is None.
    """
    # getattr with a default also covers a None documentation object.
    note = getattr(row['ClinicalDocumentation'], 'clinical_notes_clean', None)
    return None if note is None else len(note)

date_col = 'VisitDate'

# Compute note length and week on a dedicated copy. A new name is used rather
# than rebinding the notebook-wide `df`, which other cells (e.g. the red-flag
# and timing analyses) read; rebinding it would silently change their input
# when they are re-run.
notes_df = results_ai_df_gpt41.copy()
notes_df['note_length'] = notes_df.apply(get_note_length, axis=1)
notes_df['week'] = pd.to_datetime(notes_df[date_col]).dt.to_period('W').dt.start_time


# Group by week and ClinicianGroup, compute median and CI
def bootstrap_median_ci_note(data, n_boot=1000, ci=95):
    """
    Return ``(median, lower, upper)``: the median of ``data`` and a
    percentile-bootstrap ``ci`` percent interval for it.

    The interval uses ``n_boot`` resamples drawn from the global ``np.random``
    state. Missing values are dropped first; if nothing remains, all three
    results are NaN.
    """
    data = np.array(data.dropna())
    if len(data) == 0:
        return np.nan, np.nan, np.nan
    medians = np.array([
        np.median(np.random.choice(data, size=len(data), replace=True))
        for _ in range(n_boot)
    ])
    median = np.median(data)
    lower = np.percentile(medians, (100 - ci) / 2)
    upper = np.percentile(medians, 100 - (100 - ci) / 2)
    return median, lower, upper


summary_records = []
for (week, group), subdf in notes_df.groupby(['week', 'ClinicianGroup']):
    median, lower, upper = bootstrap_median_ci_note(subdf['note_length'])
    summary_records.append({
        'week': week,
        'ClinicianGroup': group,
        'median_note_length': median,
        'ci_lower': lower,
        'ci_upper': upper,
        'err_lower': median - lower,
        'err_upper': upper - median,
    })

summary_notes_df = pd.DataFrame(summary_records)
summary_notes_df = summary_notes_df.sort_values('week')

# Plot
fig = plot_no_ai_ai_over_time(
    summary_notes_df,
    y_var="median_note_length",
    x_var="week",
    y_label="Median clinical note length (chars)",
    title="Median clinical note length over time (by week), AI vs non-AI",
    y_axis_percent=False,
    time_x_axis=True,
    y_lim=None,
)
save_fig("median_clinical_note_length_vs_week", fig)

In [None]:
# Counts of each AiLike rating among rows whose Silent column is 'Active'.
# value_counts() omits missing ratings by default.
ai_value_counts = (
    ai_interaction_scrubbed
    .loc[ai_interaction_scrubbed['Silent'] == 'Active', 'AiLike']
    .value_counts()
)

In [None]:
# Thumbs-up/down summary statistics for the paper.
# NOTE(review): ai_value_counts comes from value_counts(), which drops missing
# AiLike values by default. `nAICalls` therefore counts only calls with a
# non-missing AiLike; confirm this is the intended denominator for
# pctAICallsThumbsAny.
_n_up = ai_value_counts['Up']
_n_down = ai_value_counts['Down']
_n_rated = _n_up + _n_down
_n_total = ai_value_counts.sum()

latexvars['nAICalls'] = f"{_n_total:.0f}"

latexvars['nAICallsThumbsAny'] = f"{_n_rated:.0f}"
latexvars['pctAICallsThumbsAny'] = f"{_n_rated / _n_total:.1%}"

latexvars['nAICallsThumbsUp'] = f"{_n_up:.0f}"
latexvars['nAICallsThumbsDown'] = f"{_n_down:.0f}"

latexvars['pctAICallsThumbsUp'] = f"{_n_up / _n_rated:.1%}"
latexvars['pctAICallsThumbsDown'] = f"{_n_down / _n_rated:.1%}"

In [None]:
def proportion_ci(successes, n, ci=95):
    """
    Return the Wilson score interval for ``successes`` out of ``n`` trials.

    ``ci`` is the confidence level in percent (95 means 95%). Whole-number
    floats such as ``3.0`` are accepted and converted to int. Returns
    ``(low, high)``, or ``(nan, nan)`` when ``n == 0``.

    Raises
    ------
    ValueError
        If ``successes`` or ``n`` is not a whole number (checked in that order).
    """
    if n == 0:
        return np.nan, np.nan

    if not float(successes).is_integer():
        raise ValueError("successes must be an integer")
    if not float(n).is_integer():
        raise ValueError("n must be an integer")

    alpha = 1 - ci / 100
    interval = binomtest(int(successes), int(n)).proportion_ci(
        confidence_level=1 - alpha, method="wilson"
    )
    return interval.low, interval.high

# Weekly share of thumbs-up among Up/Down-rated AI interactions (rows whose
# Silent column is 'Active'), with Wilson 95% CIs from proportion_ci, drawn
# with the shared time-series plotting helper.
active_ai = ai_interaction_scrubbed[ai_interaction_scrubbed.Silent == 'Active'].copy()
# Unparseable timestamps become NaT (errors='coerce'); their `week` is NaT,
# and groupby() drops NaT keys by default below.
active_ai['CreatedOn'] = pd.to_datetime(active_ai['CreatedOn'], errors='coerce')
active_ai['week'] = active_ai['CreatedOn'].dt.to_period('W').dt.start_time

# NOTE(review): redundant with the isin(['Up', 'Down']) filter below, which
# also excludes missing values.
active_ai = active_ai[active_ai['AiLike'].notna()]

valid_likes = active_ai[active_ai['AiLike'].isin(['Up', 'Down'])].copy()

# One row per week and one column per rating value; weeks lacking a rating
# get 0 in that column.
ai_like_counts = (
    valid_likes.groupby(['week', 'AiLike'])
    .size()
    .unstack(fill_value=0)
    .sort_index()
)
# Constant group label so the frame matches plot_no_ai_ai_over_time's grouping.
ai_like_counts['ClinicianGroup'] = 'AI'

# .get(..., 0) falls back to 0 if no week had that rating (column absent).
ai_like_counts['n'] = ai_like_counts.get('Up', 0) + ai_like_counts.get('Down', 0)
ai_like_counts['percent_up'] = ai_like_counts.get('Up', 0) / ai_like_counts['n']

# Per-week Wilson bounds; result_type='expand' splits each (low, high) tuple
# into columns 0 and 1.
ci_bounds = ai_like_counts.apply(
    lambda row: proportion_ci(row.get('Up', 0), row['n']) if row['n'] > 0 else (np.nan, np.nan),
    axis=1, result_type='expand'
)
ai_like_counts['ci_lower'] = ci_bounds[0]
ai_like_counts['ci_upper'] = ci_bounds[1]

# Convert absolute bounds to error-bar distances, as the plotting helper expects.
ai_like_counts['err_lower'] = ai_like_counts['percent_up'] - ai_like_counts['ci_lower']
ai_like_counts['err_upper'] = ai_like_counts['ci_upper'] - ai_like_counts['percent_up']


# Use the shared plotting function for consistency
ai_like_counts = ai_like_counts.reset_index()
ai_like_counts = ai_like_counts.sort_values('week')

# NOTE(review): the '' group matches no rows. Presumably it is there so an
# empty first series takes the first color of the cycle and 'AI' gets the same
# color as the AI group in the two-group plots; confirm.
fig = plot_no_ai_ai_over_time(
    red_by_week=ai_like_counts,
    y_var="percent_up",
    x_var="week",
    y_label="Percent Up (%)",
    title="Percent of AI Consult interactions with a thumbs up rating among all interactions with ratings",
    grouping_var='ClinicianGroup',
    grouping_var_label="",
    group_order=['', 'AI'],
    y_axis_percent=True,
    time_x_axis=True,
    y_lim=None,
)
save_fig("thumbs_up_over_time", fig)

## Latexvars

In [None]:
# Process latexvars dict into LaTeX \newcommand strings
def process_value(v):
    """Return "N/A" for missing values, "Inf" for float infinities, else ``v`` unchanged."""
    if pd.isna(v):
        return "N/A"
    elif isinstance(v, float) and np.isinf(v):
        return "Inf"
    else:
        return v


def escape_percent(s):
    r"""
    Escape ``%`` as ``\%`` in string values; non-strings are returned unchanged.

    Every ``%`` is escaped, then any doubled backslash is collapsed to one.
    Values that were already escaped (``\%``) therefore stay ``\%`` rather
    than becoming ``\\%``. This also collapses any genuine ``\\`` in a value
    to ``\``.
    """
    if isinstance(s, str):
        s = s.replace("%", r"\%")
        s = s.replace("\\\\", "\\")
    return s


# One \newcommand per latexvars entry. Each value goes through process_value
# first (NaN -> "N/A", inf -> "Inf") and is then %-escaped.
latex_newcommands = "\n".join(
    f"\\newcommand\\{k}{{{escape_percent(process_value(v))}\\xspace}}"
    for k, v in latexvars.items()
)

# bf.join inserts the path separator. Plain string concatenation would write
# "<plots_folder>latex_commands.tex" next to the folder instead of inside it.
bf.write_text(bf.join(plots_folder, "latex_commands.tex"), latex_newcommands)