In [39]:
from dataclasses import dataclass
from pathlib import Path

import altair as alt
from altair import datum
import json
import pandas as pd


RUNS = Path("./Runs")

@dataclass
class Run:
    patient: int
    temperature: int
    screenability_precision: float
    screenability_recall: float
    trial_precision: float
    trial_recall: float
        
        
def parse_data(file):
    with open(file) as f:
        patient = int(file.stem[2:5])
        temperature = 1 if "TEMP1" in file.stem else 0
        _ = json.load(f)
        return Run(
            patient=patient,
            temperature=temperature,
            screenability_precision=_.get("screenability_precision"),
            screenability_recall=_.get("screenability_recall"),
            trial_precision=_.get("trial_precision"),
            trial_recall=_.get("trial_precision"),
        )

In [32]:
data = []
for file in RUNS.glob("**/*.json"):
    if file.stem.startswith("fp") and ("GPT4" not in file.stem):
        data.append(parse_data(file))
df = pd.DataFrame.from_dict(data)
df

Unnamed: 0,patient,temperature,screenability_precision,screenability_recall,trial_precision,trial_recall
0,3,1,0.42,0.28,0.50,0.50
1,10,1,0.59,0.17,0.29,0.29
2,7,1,0.54,0.23,0.43,0.43
3,7,0,0.58,0.27,0.43,0.43
4,8,1,0.25,0.29,1.00,1.00
...,...,...,...,...,...,...
195,2,0,0.73,0.29,0.50,0.50
196,4,0,0.66,0.16,0.67,0.67
197,1,1,0.50,0.26,1.00,1.00
198,3,0,0.62,0.28,0.57,0.57


In [52]:
def make_stripplot(df=df, temperature=0):
    return (
        alt.Chart(df)
        .mark_circle(size=8)
        .encode(
            y="patient:N",
            x=alt.X(
                alt.repeat("row"), type="quantitative", scale=alt.Scale(domain=[0, 1])
            ),
            yOffset="jitter:Q",
            color=alt.Color("patient:N").legend(None),
        )
        .transform_filter(datum.temperature == temperature)
        .transform_calculate(
            # Generate Gaussian jitter with a Box-Muller transform
            jitter="sqrt(-2*log(random()))*cos(2*PI*random())"
        )
        .properties(width=300, height=200)
        .repeat(
            row=[
                "screenability_precision",
                "screenability_recall",
                "trial_precision",
                "trial_recall",
            ]
        )
    )
make_stripplot(temperature=0)

In [53]:
make_stripplot(temperature=1)

In [60]:
df.groupby(["patient", "temperature",]).agg(
    screenability_precision_mean=("screenability_precision", "mean"),
    screenability_precision_std=("screenability_precision", "std"),
    screenability_recall_mean=("screenability_recall", "mean"),
    screenability_recall_std=("screenability_recall", "std"),
    trial_precision_mean=("trial_precision", "mean"),
    trial_precision_std=("trial_precision", "std"),
    trial_recall_mean=("trial_recall", "mean"),
    trial_recall_std=("trial_recall", "std"),
).round(3)

Unnamed: 0_level_0,Unnamed: 1_level_0,screenability_precision_mean,screenability_precision_std,screenability_recall_mean,screenability_recall_std,trial_precision_mean,trial_precision_std,trial_recall_mean,trial_recall_std
patient,temperature,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
1,0,0.585,0.056,0.244,0.008,1.0,0.0,1.0,0.0
1,1,0.616,0.101,0.281,0.04,0.98,0.063,0.98,0.063
2,0,0.835,0.127,0.301,0.047,0.45,0.158,0.45,0.158
2,1,0.647,0.216,0.324,0.047,0.626,0.116,0.626,0.116
3,0,0.576,0.034,0.257,0.049,0.596,0.072,0.596,0.072
3,1,0.434,0.037,0.252,0.029,0.585,0.137,0.585,0.137
4,0,0.642,0.024,0.153,0.008,0.667,0.037,0.667,0.037
4,1,0.626,0.074,0.177,0.014,0.699,0.038,0.699,0.038
5,0,0.524,0.056,0.153,0.026,0.67,0.0,0.67,0.0
5,1,0.506,0.102,0.173,0.069,0.737,0.201,0.737,0.201
