In [2]:
import pandas as pd
import numpy as np
import altair as alt
alt.data_transformers.disable_max_rows()
from joblib import Parallel, delayed
from numpy.random import Generator, PCG64
rng = Generator(PCG64())
import random
import math

In [3]:
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score

# Trial Parition Data

In [4]:
responses = pd.read_csv("https://data-visualization-benchmark.s3.us-west-2.amazonaws.com/vt-fusion/participant_responses.csv")
performance_df = responses.copy()
print("Number of participants: ", len(responses[['participant_id']].value_counts()))

Number of participants:  426


In [5]:
# responses.columns
responses['image_file'].nunique()

202

In [6]:
performance_df.iloc[-1][['mc1', 'mc2', 'mc3', 'mc4', 'question']].to_list()

["More than 60% of residents have a Bachelor's degree or above.",
 'The majority of residents have a graduate degree.',
 'There are fewer residents whose education level is Associate degree than those with a high school degree.',
 'None of the above.',
 'Which of the following is true about the education level of residents in city Z?']

### Substitute task categories

In [20]:
vlat_format_map = {
    'find_extremum': 'value identification', 
    'determine_Range': 'value identification', 
    'retrieve_value': 'value identification',
    'make_comparisons': 'arithmetic computation',
    'find_correlations_trends': 'statistical inference', 
    'characterize_distribution': 'statistical inference',
    'find_anomolies': 'statistical inference', 
    'find_clusters': 'statistical inference', 
}

calvi_format_map = {
    'Make Comparisons': 'arithmetic computation',
    'Find Correlations/Trends': 'statistical inference', 
    'Find Extremum': 'value identification',
    'Retrieve Value': 'value identification', 
    'Make Predictions': 'statistical inference', 
    'Aggregate Values': 'statistical inference'
}

brbf_task_map = {
    'max': "value identification", 
    'min': 'value identification', 
    'trend': 'statistical inference',
    'trendComp': 'arithmetic computation', 
    'average': 'statistical inference', 
    'intersection': 'value identification',
}

wainer_question_map = {
    "How many months have more rain than the average month?": "statistical inference",
    "Which season has more rain, the summer or the spring?": "arithmetic computation",
    "Which season has the most rain?": "value identification",
    "In San Francisco, as the weather gets warmer, there is generally more rain.": "statistical inference",
    "In which season does each month have less rain than the month before?": "arithmetic computation",
    "How many months have less than 40mm of rain?": "arithmetic computation",
    "How much does it rain in March?": "value identification",
    "Which month has 25 mm of rain?": "value identification"
}

ggr_question_map = {
    # from level 1
    "Approximately what percentage of people had Adeolitis in the year 2000?": "value identification",
    "What percentage of patients recovered after chemotherapy?": "value identification",
    "Of 100 patients with disease X, how many are women?": "value identification",
    "Of all the people who die from cancer, approximately what percentage dies from lung cancer?": "value identification",

    # from level 2
    "Approximately what percentage of people who die from cancer die from colon cancer, breast cancer, and prostate cancer taken together?": "arithmetic computation",
    "How many more men than women are there among 100 patients with disease X?": "arithmetic computation",
    "What is the difference between the percentage of patients who recovered after a surgery and the percentage of patients who recovered after radiation therapy?": "arithmetic computation",
    "When was the increase in the percentage of people with Adeolitis higher?": "arithmetic computation",

    # from level 3
    "What percentage of patients recovered after chemotherapy?": "value identification",
    "According to your best guess, what will the percentage of people with Adeolitis be in the year 2010?": "statistical inference",
    "Between 1980 and 1990, which disease had a higher increase in the percentage of people affected?": "statistical inference",
    "Compared to the placebo, which treatment leads to a larger decrease in the percentage of patients who die?": "statistical inference",
    "What is the percentage of cancer patients who die after chemotherapy?": "value identification",
    "Which of the treatments contributes to a larger decrease in the percentage of sick patients?": "statistical inference"
}

task_category_map = ggr_question_map | wainer_question_map | brbf_task_map | vlat_format_map | calvi_format_map

### Substitute chart categories

In [21]:
chart_categories = {
    'Dot Plot': 'Dot Plot',
    'Line Chart': 'Line',
    'Pie Chart': 'Pie',
    'chart': 'Table',
    'radial': 'Radial',
    'line': 'Line',
    'bar': 'Bar',
    'Scatterplot': 'Scatter',
    'Bar chart': 'Bar',
    'Stacked area chart': 'Stacked Area',
    '100% stacked bar chart': '100% Stacked Bar',
    'Line chart': 'Line',
    'Choropleth map': 'Map',
    'Area chart': 'Area',
    'Choropleth Map': 'Map',
    'Pie chart': 'Pie',
    'Stacked bar chart': 'Stacked Bar',
    'bg-table': 'Table',
    'bg': 'Bar',
    'sp-table': 'Table',
    'sp': 'Scatter',
    'lg2': 'Line',
    'lg1-table': 'Table',
    'lg1': 'Line',
    'lg2-table': 'Table',
    'Bar Chart': 'Bar',
    'Bubble Chart': 'Scatter',
    'Area Chart': 'Area',
    '100 % Stacked Bar Chart': '100% Stacked Bar',
    'Stacked Area Chart': 'Stacked Area',
    'Stacked Bar Chart': 'Stacked Bar',
    'Treemap': 'Treemap',
    'table': 'Table'
}

### Create item dataframe

In [22]:
def create_item_df():
    tests = ['wainer', 'brbf', 'ggr-mc', 'vlat', 'calvi']
    item_df = []
    for test in tests:
        idf = pd.read_csv(f'https://data-visualization-benchmark.s3.us-west-2.amazonaws.com/{test}/questions.csv')
        idf['test_type'] = test
        item_df.append(idf)

    item_df = pd.concat(item_df)
    item_df['question_image'] = item_df['question'] + " + " + item_df['image_file']
    return item_df
    
def add_brbf_categories(r):
    if (r['test_type'] == 'brbf'):
        return '-'.join(r['image_file'].split("-")[:-1])
    return r['chart_type']


item_df = create_item_df()
item_df['chart_type_filled'] = item_df.apply(add_brbf_categories, axis=1)
item_df['chart_type'] = item_df['chart_type_filled'].replace(chart_categories)
# item_df = item_df[item_df['chart_type'] != 'Table']

item_df['task_category'] = item_df['task_category'].apply(
    lambda t : task_category_map[t] if t in task_category_map.keys() else t
)
item_df['task_category'] = item_df.apply(
    lambda r : task_category_map[r['question']] if r['question'] in task_category_map.keys() else r['task_category'],
    axis=1
)
item_df['task_category'] = item_df['task_category'].apply(lambda x : x.replace(" ", "-")).to_numpy()
item_df['chart_type'] = item_df['chart_type'].apply(lambda x : x.replace(" ", "-")).to_numpy()

item_df['item_id'] = item_df['question_image']

In [23]:
def bootstrap_ci(
        data: pd.DataFrame,
        measure,
        id_col,
        n_iterations=1000,
        statistic=np.mean
    ):
    """
    
    """
    
    items = list(data[id_col].unique())
    n_size = len(items)
    df = data.copy()

    def bootstrap_iteration(data, chosen_items):
        # filter_df = data[data[id_col].isin(chosen_items)] # Filter based on chosen questions
        filter_df = pd.concat(
            [data[data[id_col].isin([item])] for item in chosen_items]
        )
        
        bs_mean = statistic(filter_df[measure]) 
        return (bs_mean, list(chosen_items))

    qset_means = Parallel(n_jobs=-1)(
        delayed(bootstrap_iteration)(
            df.copy(),
            rng.choice(items, n_size,  replace=True)
        ) for _ in range(n_iterations)
    )
    
    means = []
    qs_used = []
    means = [bs_mean for bs_mean, chosen_qs in qset_means]
 
    # 95% confidence interval
    lower = np.percentile(means, 2.5)
    upper = np.percentile(means, 97.5)
    
    return lower, upper

def create_confidence_interval_df(
    data: pd.DataFrame,
    measure, 
    id_col,
    condition_col,
    statistic=np.mean
):
    """ create the dataframe for 95% bootstrapped confidence interval
    data: dataset
    measure: dependent variable
    id_col: units to bootstrap along (e.g. participant_d or item_id)
    condition_col: the different conditions of the experiments
    """
    
    data_list = []

    for condition in data[condition_col].unique():
        condition_data = data[data[condition_col] == condition]

        lower, upper = bootstrap_ci(condition_data, measure=measure, statistic=statistic, id_col=id_col)
        mean = condition_data[measure].mean()
        
        data_list.append({
            "category": condition,
            "mean": mean,
            "ci_upper": upper, 
            "ci_lower": lower,
        })

    ci_df = pd.DataFrame(data_list)

    item_level_data = data.rename(columns={
        condition_col: 'category', 
    }).groupby([id_col, 'category'])[measure].mean().reset_index()

    return item_level_data, ci_df

In [24]:
assessment_color_map = {
    'wainer': '#3a0ca3',
    'ggr-mc': '#e26d5c',
    'brbf': '#f4a261',
    'vlat': '#247ba0',
    'calvi': '#70c1b3'
}


def box_muller():
    '''
    apparently controls how much jitter but who knows :shrug:
    '''
    # jitter_level * math.sqrt(-2 * math.log(random.random())) #* math.cos(2 * math.pi * random.random())
    return rng.normal() 


# Performance

## How does performance vary across items?

In [25]:
def find_amount_at_chance(r):
    option_count = 0
    for i in range(1, 11):
        if not pd.isna(r[f'mc{i}']):
            option_count += 1
    return 1 / option_count

item_df['chance_selection'] = item_df.apply(find_amount_at_chance, axis=1)

In [26]:
try:
    print("Reloading item-performance dataframes")
    item_ci_df = pd.read_csv("./dataframes/item_ci_df.csv")
    item_participant_level_data = pd.read_csv("./dataframes/item_participant_level_data.csv")
    
except:
    print("Regenerating item-performance dataframes")
    item_participant_level_data, item_ci_df = create_confidence_interval_df(
        performance_df, 
        measure='is_correct',
        id_col='participant_id',
        condition_col='question_image',
    )
    item_ci_df = pd.merge(item_ci_df, item_df[['question_image', 'test_type']].rename(columns={'question_image': 'category'}))
    item_participant_level_data.to_csv("./dataframes/item_participant_level_data.csv")
    item_ci_df.to_csv("./dataframes/item_ci_df.csv")

Reloading item-performance dataframes


In [27]:
assessment_color_map = {
    'wainer': '#8947BF', # '#3a0ca3',
    'ggr-mc': '#e26d5c',
    'brbf': '#f4a261',
    'vlat': '#247ba0',
    'calvi': '#70c1b3'
}

item_ci_df_sorted = item_ci_df.sort_values(by='mean', ascending=False)
item_domain_order = item_ci_df_sorted['category'].to_list()

test_domain = list(assessment_color_map.keys())
test_domain_color = list(assessment_color_map.values())

def create_item_accuracy_chart(ci_df, domain_order):
    """ Generate chart for point estimate and confidence intervals overlayed with item level scatteplots
    """
    # Create dot+error bar plot
    chart = alt.Chart(ci_df).mark_bar(opacity=0.7).encode(
        x=alt.X('category:N', title='Category', scale=alt.Scale(domain=domain_order), axis=None),
        y=alt.Y('mean:Q', title='Prop. Correct'),
        color=alt.Color('test_type:N', scale=alt.Scale(domain=test_domain, range=test_domain_color), legend=None),
    ) + alt.Chart(ci_df).mark_rule(strokeWidth=0.4, opacity=1).encode(
        x=alt.X('category:N'),
        y=alt.Y('ci_lower:Q'),
        y2='ci_upper:Q',
        color=alt.value("#717d7e")
        # color=alt.Color('test_type:N', scale=alt.Scale(domain=test_domain, range=test_domain_color), legend=None),
    )
    
    return chart.properties(width=800, height=200)

item_performance_reliability_plot = create_item_accuracy_chart(
    item_ci_df, 
    domain_order=item_domain_order,
)
# item_performance_reliability_plot.save("./figures/item_performance_reliability.pdf")
item_performance_reliability_plot

In [28]:
item_ci_df = pd.merge(item_df[['item_id', 'chance_selection']], item_ci_df, left_on='item_id', right_on='category')
item_ci_df['mean'] = item_ci_df['mean'] - item_ci_df['chance_selection']
item_ci_df['ci_upper'] = item_ci_df['ci_upper'] - item_ci_df['chance_selection']
item_ci_df['ci_lower'] = item_ci_df['ci_lower'] - item_ci_df['chance_selection']

item_performance_reliability_below_chance_plot = create_item_accuracy_chart(
    item_ci_df, 
    domain_order=item_domain_order,
)
item_performance_reliability_below_chance_plot

In [29]:
# item_ci_df['err_diff'] =  item_ci_df['ci_upper'] - item_ci_df['ci_lower']
# item_ci_df.sort_values(by='err_diff')

In [30]:
len(item_ci_df[item_ci_df['mean'] < 0])

36

## How does performance vary across tests?

In [31]:
try:
    print("Reloading test-performance dataframe")
    test_ci_df = pd.read_csv("./dataframes/test_ci_df.csv")
    test_item_level_data = pd.read_csv("./dataframes/test_item_level_data.csv")
    # raise Error
except:
    print("Regenerating test-performance dataframe")
    test_item_level_data, test_ci_df = create_confidence_interval_df(
        performance_df, 
        measure='is_correct',
        id_col='question_image',
        condition_col='test_type',
    )
    test_item_level_data.to_csv("./dataframes/test_item_level_data.csv", index=False)
    test_ci_df.to_csv("./dataframes/test_ci_df.csv", index=False)

Reloading test-performance dataframe


In [32]:
# item_df

In [33]:
def transform_ci_df(original_ci_df, item_df):
    ci_df = pd.merge(item_df[['item_id', 'chance_selection']], original_ci_df, left_on='item_id', right_on='category').copy()
    ci_df['mean'] = ci_df['mean'] - ci_df['chance_selection']
    ci_df['ci_upper'] = ci_df['ci_upper'] - ci_df['chance_selection']
    ci_df['ci_lower'] = ci_df['ci_lower'] - ci_df['chance_selection']
    return ci_df

transform_ci_df(test_ci_df, item_df)

Unnamed: 0,item_id,chance_selection,category,mean,ci_upper,ci_lower


In [34]:
test_ci_df

Unnamed: 0,category,mean,ci_upper,ci_lower
0,wainer,0.777124,0.840152,0.711434
1,brbf,0.687439,0.733594,0.633982
2,ggr-mc,0.616352,0.752746,0.478788
3,vlat,0.64327,0.711412,0.576438
4,calvi,0.414917,0.483861,0.347289


In [35]:
test_order = list(assessment_color_map.keys())
test_domain_colors = list(assessment_color_map.values())

def create_test_accuracy_chart(ci_df, item_level_data, domain_order, test_domain_colors=[]):
    """ Generate chart for point estimate and confidence intervals overlayed with item level scatteplots
    """
    # Create dot+error bar plot
    chart = alt.Chart(ci_df).mark_point(filled=True, size=75, opacity=1).encode(
        x=alt.X('category:N', title='Category', scale=alt.Scale(domain=domain_order)),
        y=alt.Y('mean:Q', title='Prop. Correct'),
        color=alt.Color('category:N', scale=alt.Scale(domain=domain_order, range=test_domain_colors), legend=None),
    ) + alt.Chart(ci_df).mark_rule(strokeWidth=2).encode(
        x=alt.X('category:N'),
        y=alt.Y('ci_lower:Q'),
        y2='ci_upper:Q',
        color=alt.Color('category:N', legend=None),
        opacity=alt.value(1)
    )
    
    scatter_plot = alt.Chart(item_level_data).mark_point(filled=True).encode(
        x=alt.X('category:N', title='Category'),
        y=alt.Y("is_correct:Q",),
        xOffset="jitter:Q",
        color=alt.Color('category:N', legend=None),
        size=alt.value(16),
        opacity=alt.value(0.3)
    )
    return (chart + scatter_plot)

test_item_level_data['jitter'] = test_item_level_data.apply(
    lambda _ : box_muller(),
    axis=1
)

test_type_performance = create_test_accuracy_chart(
    test_ci_df, test_item_level_data, test_order, test_domain_colors
).properties(
    width=40*5, height=200, title=f"Prop. correct across tests", 
)
test_type_performance.save("./figures/test_type_performance.pdf")
test_type_performance

In [36]:
test_ci_df.sort_values(by="mean")

Unnamed: 0,category,mean,ci_upper,ci_lower
4,calvi,0.414917,0.483861,0.347289
2,ggr-mc,0.616352,0.752746,0.478788
3,vlat,0.64327,0.711412,0.576438
1,brbf,0.687439,0.733594,0.633982
0,wainer,0.777124,0.840152,0.711434


In [37]:
test_item_level_data.groupby('category')['is_correct'].std()

category
brbf      0.222174
calvi     0.284198
ggr-mc    0.265484
vlat      0.261555
wainer    0.197308
Name: is_correct, dtype: float64

In [38]:
total_chances = []
for i, row in item_df[item_df['test_type'] == 'calvi'].iterrows():
    valid_mc_counter = 0
    for i in range(1, 11):
        
        if not pd.isna(row[f'mc{i}']):
            valid_mc_counter += 1

    total_chances.append(1 / valid_mc_counter)
np.mean(total_chances)

0.2833333333333333

In [39]:
import scipy.stats as stats

item_means = performance_df.groupby(["test_type", 'question_image'])['is_correct'].mean().reset_index()

groups = [performance_df["is_correct"].values for _, group in item_means.groupby('test_type')]
f_statistic, p_value = stats.f_oneway(*groups)

In [40]:
# item_means

## How does performance vary across different types of graphs?

In [41]:
try:
    print("Reloading test-performance dataframe")
    graph_item_level_data = pd.read_csv("./dataframes/graph_item_level_data.csv")
    graph_ci_df = pd.read_csv("./dataframes/graph_ci_df.csv")
except:
    print("Regenerating test-performance dataframe")
    
    def add_brbf_categories(r):
        if (r['test_type'] == 'brbf'):
            return '-'.join(r['image_file'].split("-")[:-1])
        return r['chart_type']

    chart_performance_df = performance_df.copy()
    chart_performance_df['chart_type_filled'] = chart_performance_df.apply(add_brbf_categories, axis=1)
    chart_performance_df['common_chart_type'] = chart_performance_df['chart_type_filled'].replace(chart_categories)
    chart_performance_df = chart_performance_df[chart_performance_df['common_chart_type'] != 'Table']
    
    graph_item_level_data, graph_ci_df = create_confidence_interval_df(
        chart_performance_df, 
        measure='is_correct',
        id_col='question_image',
        condition_col='common_chart_type',
    )
    graph_item_level_data = pd.merge(graph_item_level_data, item_df[['test_type', 'question_image']])
    graph_ci_df.to_csv("./dataframes/graph_ci_df.csv",index=False)
    graph_item_level_data.to_csv("./dataframes/graph_item_level_data.csv",index=False)

Reloading test-performance dataframe


In [42]:
test_order = list(assessment_color_map.keys())
test_domain_colors = list(assessment_color_map.values())
chart_order = graph_ci_df.sort_values(by="mean", ascending=False)['category'].to_list()

def format_hexcode_alpha(hexcode):
    samples = np.linspace(60, 99, num=len(chart_order))
    return [hexcode + str(int(s)) for s in samples][::-1]

# color_order = [hexcode for hexcode in format_hexcode_alpha("#17202a")]
color_order = ["#17202a" for _ in range(len(chart_order))]

def create_graph_performance_chart(ci_df, item_level_data, domain_order, domain_color=[]):
    """ Generate chart for point estimate and confidence intervals overlayed with item level scatteplots
    """
    # Create dot+error bar plot
    chart = alt.Chart(ci_df).mark_point(filled=True, size=75, opacity=1).encode(
        x=alt.X('category:N', title='Category', scale=alt.Scale(domain=domain_order)),
        y=alt.Y('mean:Q', title='Prop. Correct'),
        color=alt.Color('category:N', scale=alt.Scale(domain=domain_order, range=domain_color), legend=None),
    ) + alt.Chart(ci_df).mark_rule(strokeWidth=2).encode(
        x=alt.X('category:N'),
        y=alt.Y('ci_lower:Q'),
        y2='ci_upper:Q',
        color=alt.Color('category:N', scale=alt.Scale(domain=domain_order, range=domain_color), legend=None)
    )
    
    scatter_plot = alt.Chart(item_level_data).mark_point(opacity=0.35, filled=True).encode(
        x=alt.X('category:N', title='Category'),
        y=alt.Y("is_correct:Q",),
        xOffset="jitter:Q",
        color=alt.Color('test_type:N', scale=alt.Scale(domain=test_order, range=test_domain_colors), legend=None),
        size=alt.value(16),
    )

    return (chart + scatter_plot).resolve_scale(color='independent')


graph_item_level_data['jitter'] = graph_item_level_data.apply(
    lambda _ : box_muller(),
    axis=1
)
chart_performance = create_graph_performance_chart(graph_ci_df, graph_item_level_data, chart_order, color_order).properties(
    width=13*34, height=200, title=f"Prop. correct across tasks", 
)
chart_performance.save("./figures/graph_type_performance.pdf")
chart_performance

In [43]:
counted_df = graph_item_level_data[['test_type', 'category']].value_counts().reset_index()

# counted_df.pivot(values='count', index='category', columns='test_type')
counted_df.groupby('category')['count'].sum()

category
100% Stacked Bar     5
Area                 8
Bar                 38
Dot Plot             2
Histogram            3
Line                41
Map                 13
Pie                 10
Radial               8
Scatter             36
Stacked Area         8
Stacked Bar         11
Treemap              3
Name: count, dtype: int64

In [44]:
graph_ci_df.sort_values(by='mean')

Unnamed: 0,category,mean,ci_upper,ci_lower
6,Stacked Bar,0.395299,0.509522,0.27965
10,Stacked Area,0.397059,0.672735,0.207292
11,Map,0.465766,0.658661,0.278163
9,Area,0.510355,0.704065,0.316322
3,Scatter,0.532263,0.607401,0.449259
7,100% Stacked Bar,0.574419,0.711628,0.450116
0,Line,0.598045,0.67719,0.516103
4,Pie,0.637011,0.834148,0.414683
1,Bar,0.694978,0.773654,0.610742
5,Dot Plot,0.747126,0.896552,0.597701


In [45]:
graph_item_level_data.groupby('category')['is_correct'].std().sort_values()

category
Histogram           0.064656
Treemap             0.177028
100% Stacked Bar    0.177746
Radial              0.181634
Stacked Bar         0.204251
Dot Plot            0.211319
Scatter             0.244339
Line                0.263493
Bar                 0.267851
Area                0.300049
Map                 0.352601
Stacked Area        0.352747
Pie                 0.358117
Name: is_correct, dtype: float64

## How does performance vary across different kinds of tasks?

In [46]:
try:
    print("Reloading task-performance dataframe")
    task_item_level_data = pd.read_csv("./dataframes/task_item_level_data.csv")
    task_ci_df = pd.read_csv("./dataframes/task_ci_df.csv")
except:
    print("Regenerating task-performance dataframe")
    task_performance_df = performance_df.copy()
    task_performance_df['task_category'] = task_performance_df['task_category'].apply(
        lambda t : task_category_map[t] if t in task_category_map.keys() else t
    )
    task_performance_df['task_category'] = task_performance_df.apply(
        lambda r : task_category_map[r['question']] if r['question'] in task_category_map.keys() else r['task_category'],
        axis=1
    )
    # question_task_df = task_performance_df[['question_image', 'task_category']].value_counts().reset_index()
    # question_task_df[['task_category']].value_counts().reset_index()
    
    task_item_level_data, task_ci_df = create_confidence_interval_df(
        task_performance_df, 
        measure='is_correct',
        id_col='question_image',
        condition_col='task_category',
    )
    task_item_level_data = pd.merge(task_item_level_data, item_df[['test_type', 'question_image']])
    task_item_level_data = task_item_level_data.reset_index()

    task_item_level_data.to_csv("./dataframes/task_item_level_data.csv", index=False)
    task_ci_df.to_csv("./dataframes/task_ci_df.csv", index=False)

Reloading task-performance dataframe


In [47]:
task_performance_df = performance_df.copy()
task_performance_df['task_category'] = task_performance_df['task_category'].apply(
    lambda t : task_category_map[t] if t in task_category_map.keys() else t
)
task_performance_df['task_category'] = task_performance_df.apply(
    lambda r : task_category_map[r['question']] if r['question'] in task_category_map.keys() else r['task_category'],
    axis=1
)
task_performance_df[['test_type', 'task_category']].value_counts()

test_type  task_category         
brbf       value identification      3062
calvi      arithmetic computation    2717
vlat       value identification      2529
calvi      statistical inference     2207
brbf       statistical inference     2026
vlat       arithmetic computation    1033
wainer     arithmetic computation    1024
           value identification      1011
brbf       arithmetic computation    1010
vlat       statistical inference      940
wainer     statistical inference      684
ggr-mc     value identification       436
           arithmetic computation     341
           statistical inference      336
calvi      value identification       171
Name: count, dtype: int64

In [48]:
task_order = task_ci_df.sort_values(by="mean", ascending=False)['category'].to_list()
test_order = list(assessment_color_map.keys())
test_domain_colors = list(assessment_color_map.values())
task_color_order = ["#17202a" for _ in range(len(task_order))]

def create_task_performance_chart(ci_df, item_level_data, domain_order, domain_color=[]):
    """ Generate chart for point estimate and confidence intervals overlayed with item level scatteplots
    """
    # Create dot+error bar plot
    chart = alt.Chart(ci_df).mark_point(filled=True, size=75).encode(
        x=alt.X('category:N', title='Category', scale=alt.Scale(domain=domain_order)),
        y=alt.Y('mean:Q', title='Prop. Correct'),
        color=alt.Color('category:N', scale=alt.Scale(domain=domain_order, range=domain_color), legend=None),
        opacity=alt.value(1)
    ) + alt.Chart(ci_df).mark_rule(strokeWidth=2).encode(
        x=alt.X('category:N'),
        y=alt.Y('ci_lower:Q'),
        y2='ci_upper:Q',
        color=alt.Color('category:N', scale=alt.Scale(domain=domain_order, range=domain_color), legend=None),
        opacity=alt.value(1)
    )

    scatter_plot = alt.Chart(item_level_data).mark_point(filled=True).encode(
        x=alt.X('category:N', title='Category'),
        y=alt.Y("is_correct:Q",),
        xOffset="jitter:Q",
        color=alt.Color('test_type:N', scale=alt.Scale(domain=test_order, range=test_domain_colors), legend=None),
        size=alt.value(16),
        opacity=alt.value(0.3)
    )    
    return (scatter_plot + chart).resolve_scale(color='independent')

task_item_level_data['jitter'] = task_item_level_data.apply(
    lambda i : box_muller(),
    axis=1
)
task_performance = create_task_performance_chart(
    task_ci_df, task_item_level_data, task_order, task_color_order).properties(
    width=40*3, height=200, title=f"Prop. correct across tasks", 
)
task_performance.save("./figures/task_type_performance.pdf")
task_performance

In [49]:
task_ci_df

Unnamed: 0,category,mean,ci_upper,ci_lower
0,value identification,0.710917,0.766009,0.650338
1,arithmetic computation,0.516408,0.577568,0.45282
2,statistical inference,0.599548,0.654147,0.539598


In [50]:
task_item_level_data.groupby('category')['is_correct'].std()

category
arithmetic computation    0.277484
statistical inference     0.253635
value identification      0.266221
Name: is_correct, dtype: float64

## How does performance vary across presentation modality (tables vs. plots)?

### Preprocess

In [51]:
modality_performance_df = performance_df.copy()
modality_performance_df = modality_performance_df[
    modality_performance_df['test_type'].isin(['brbf', 'wainer'])
]

modality_performance_df['chart_type'] = modality_performance_df.apply(
    lambda r : r['graph_type'] if pd.isna(r['chart_type']) else r['chart_type'],
    axis=1
)
modality_performance_df['chart_type'] = modality_performance_df['chart_type'].astype(str)
modality_performance_df['modality'] = modality_performance_df.apply(
    lambda r : "table" if "table" in r['chart_type'] else "graph",
    axis=1
)
modality_question_count = modality_performance_df[['question_image', 'modality']].value_counts().reset_index()
modality_count = modality_question_count[['modality']].value_counts().reset_index()
modality_count

Unnamed: 0,modality,count
0,graph,60
1,table,44


In [52]:
wainer_gt_performance_df = modality_performance_df[
    modality_performance_df['test_type'].isin(['wainer'])
].copy()

wainer_gt_performance_means = wainer_gt_performance_df.groupby(
    ['question_image', 'question', 'chart_type', 'test_type', 'modality']
)['is_correct'].mean().reset_index()
wainer_gt_performance_means['question_id'] = wainer_gt_performance_means['chart_type']


wainer_gt_performance_means_dfs = []
for chart in ['bar', 'line', 'radial']:
    wainer_gt_performance_chart_means = wainer_gt_performance_means[
        wainer_gt_performance_means['chart_type'].isin(['table', chart])
    ].copy()

    wainer_gt_performance_chart_means['question_id'] = (
        wainer_gt_performance_chart_means['question'] + chart
    )
    
    wainer_gt_performance_means_dfs.append(
        wainer_gt_performance_chart_means
    )
wainer_gt_performance_means = pd.concat(wainer_gt_performance_means_dfs)
wainer_gt_performance_pivot = wainer_gt_performance_means.pivot(
    index='question_id', values='is_correct', columns='modality'
).reset_index()


In [53]:
question_category = [
    ('bg-1', 'modified'),
    ('bg-2', 'modified'),
    ('bg-3', 'base'),
    ('bg-4', 'base'),
    ('bg-5', 'modified'),
    ('bg-6', 'modified'),
    ('bg-7', 'modified'),
    ('bg-8', 'base'),
    ('bg-9', 'modified'),
    ('bg-10', 'base'),
    ('bg-11', 'base'),
    ('bg-12', 'base'),
    ##
    ('sp-1', 'modified'),
    ('sp-2', 'base'),
    ('sp-3', 'base'),
    ('sp-4', 'base'),
    ('sp-5', 'modified'),
    ('sp-6', 'modified'),
    ('sp-7', 'modified'),
    ('sp-8', 'modified'),
    ('sp-9', 'modified'),
    ('sp-10', 'base'),
    ('sp-11', 'base'),
    ('sp-12', 'base'),
    ##
    ('lg1-1', 'base'),
    ('lg1-2', 'modified'),
    ('lg1-3', 'modified'),
    ('lg1-4', 'base'),
    ('lg1-5', 'modified'),
    ('lg1-6', 'modified'),
    ('lg1-7', 'base'),
    ('lg1-8', 'base'),
    ('lg1-9', 'base'),
    ('lg1-10', 'modified'),
    ('lg1-11', 'base'),
    ('lg1-12', 'modified'),
]

for i in range(1, 13):
    if i <= 6:
        question_category.append((f'bg-table-{i}', "base"))
        question_category.append((f'sp-table-{i}', "base"))
        question_category.append((f'lg1-table-{i}', "base"))
    else:
        question_category.append((f'bg-table-{i}', "modified"))
        question_category.append((f'sp-table-{i}', "modified"))
        question_category.append((f'lg1-table-{i}', "modified"))

In [54]:
modality_category_df = pd.DataFrame(question_category, columns=['chart_type', 'item_type'])
brbf_gt_performance_df = modality_performance_df[
    modality_performance_df['test_type'].isin(['brbf'])
].copy()

brbf_gt_performance_df['chart_type'] = brbf_gt_performance_df['chart_type'].apply(
    lambda c : c.replace('.png', '')
)

brbf_gt_performance_means = brbf_gt_performance_df.groupby(
    ['question_image', 'task_category', 'chart_type', 'test_type', 'modality']
)['is_correct'].mean().reset_index()

brbf_gt_performance_means = pd.merge(brbf_gt_performance_means, modality_category_df)
brbf_gt_performance_means['question_id'] = (
    brbf_gt_performance_means['task_category'] + '-' +
    brbf_gt_performance_means['item_type']
)
brbf_gt_performance_means['chart_type'] = brbf_gt_performance_means['chart_type'].apply(
    lambda c : '-'.join(c.split('-')[:-1])
)

brbf_gt_performance_means['question_id'] = (
    brbf_gt_performance_means['question_id'] + '-' 
    + brbf_gt_performance_means['chart_type'].apply(lambda c : c.replace("-table", ""))
)

brbf_gt_performance_means_pivot = brbf_gt_performance_means.pivot(
    index='question_id', columns='modality' , values='is_correct'
).reset_index()

wainer_gt_performance_pivot['test_type'] = 'wainer'
brbf_gt_performance_means_pivot['test_type'] = 'brbf'

table_graph_performance_means = pd.concat([
    wainer_gt_performance_pivot,
    brbf_gt_performance_means_pivot
])
table_graph_performance_means = table_graph_performance_means.melt(
    id_vars=['question_id', 'test_type'], 
    value_vars=['graph', 'table']
)

In [55]:

# item_level_data, ci_df = create_confidence_interval_df(
#     modality_performance_df, 
#     measure='is_correct',
#     id_col='question_image',
#     condition_col='modality',
# )


try:
    modality_ci_df = pd.read_csv("./dataframes/modality_item_level_data.csv")
    modality_item_level_data = pd.read_csv("./dataframes/modality_item_level_data.csv")
    table_graph_comparison_df = pd.read_csv("./dataframes/table_graph_comparison_df.csv")
    print("Reloading modality dataframes")
except:
    print("Regenerating modality dataframes")
    mod_df = modality_performance_df.copy()
    mod_df['task_category'] = mod_df['task_category'].apply(
        lambda t : task_category_map[t] if t in task_category_map.keys() else t
    )
    mod_df['task_category'] = mod_df.apply(
        lambda r : task_category_map[r['question']] if r['question'] in task_category_map.keys() else r['task_category'],
        axis=1
    )
    mod_df['test_modality'] = mod_df['modality']  + ' - ' + mod_df['test_type']

    
    modality_item_level_data, modality_ci_df = create_confidence_interval_df(
        mod_df,
        measure='is_correct',
        id_col='question_image',
        condition_col='test_modality',
    )
    
    modality_ci_df = modality_ci_df.rename(columns={'category': 'modality'})
    modality_ci_df['test_type'] = modality_ci_df['modality'].apply(lambda m : m.split(" - ")[1])
    modality_ci_df['modality'] = modality_ci_df['modality'].apply(lambda m : m.split(" - ")[0])

    table_graph_comparison_df = table_graph_performance_means.pivot(
        index='question_id', columns='modality' , values='value'
    ).reset_index()
    
    modality_ci_df.to_csv("./dataframes/modality_ci_df.csv")
    modality_item_level_data.to_csv("./dataframes/modality_item_level_data.csv")
    table_graph_comparison_df.to_csv("./dataframes/table_graph_comparison_df.csv")

Reloading modality dataframes


In [56]:
# modality_ci_df


In [57]:
# table_graph_comparison_df

### Figure

In [58]:
# table_graph_performance_means

In [59]:

domain = ['table', 'graph']
color_domain = [ assessment_color_map['wainer'], assessment_color_map['brbf']]
test_domain = ['wainer', 'brbf']


chart = alt.Chart(
    table_graph_performance_means
).mark_line(opacity=0.2).encode(
    x='modality:N',
    y='value',
    detail='question_id',
    color=alt.Color('test_type', scale=alt.Scale(domain=test_domain, range=color_domain),)
)

point_plot = alt.Chart(
    table_graph_performance_means
).mark_circle(opacity=0.1).encode(
    x='modality:N',
    y='value',
    color=alt.Color('test_type',legend=None, scale=alt.Scale(domain=test_domain, range=color_domain)),
    size=alt.value(30)
)

offset_domain = [30, 45]

ci_point_plot =  alt.Chart(modality_ci_df).mark_line(strokeWidth=3, opacity=1).encode(
    x=alt.X('modality:N', title='Category', scale=alt.Scale(domain=domain)),
    y=alt.Y('mean:Q', title='Prop. Correct'),
    opacity=alt.value(1),
    xOffset=alt.XOffset('test_type', scale=alt.Scale(domain=test_domain, range=offset_domain)),
    color=alt.Color(
        'test_type', 
        scale=alt.Scale(domain=test_domain, range=color_domain),
        legend=None
    ),
) + alt.Chart(modality_ci_df).mark_rule(strokeWidth=3, opacity=1).encode(
    x=alt.X('modality:N'),
    y=alt.Y('ci_lower:Q'),
    y2='ci_upper:Q',
    xOffset=alt.XOffset('test_type', scale=alt.Scale(domain=test_domain, range=offset_domain)),
    color=alt.Color(
        'test_type', 
        scale=alt.Scale(domain=test_domain, range=color_domain),
        legend=None
    ),
    detail='modality'
) + alt.Chart(modality_ci_df).mark_point(filled=True, size=75, fill='white', strokeWidth=2).encode(
    x=alt.X('modality:N', title='Category', scale=alt.Scale(domain=domain)),
    y=alt.Y('mean:Q', title='Prop. Correct'),
    opacity=alt.value(1),
    xOffset=alt.XOffset('test_type', scale=alt.Scale(domain=test_domain, range=offset_domain)),
    stroke=alt.Color(
        'test_type', 
        scale=alt.Scale(domain=test_domain, range=color_domain),
        legend=None
    ),
    # detail='modality'
)

modality_plot = (chart + point_plot + ci_point_plot).resolve_scale(color='independent', xOffset='independent')
modality_plot = modality_plot.properties(width=150, height=200, title=f"Table vs Graph")
modality_plot.save("./figures/modality_performance.pdf")
modality_plot

ValueError: Unable to determine data type for the field "test_type"; verify that the field name is not misspelled. If you are referencing a field from a transform, also confirm that the data type is specified correctly.

In [689]:
modality_ci_df

Unnamed: 0,modality,mean,ci_upper,ci_lower,test_type
0,graph,0.754912,0.831047,0.663044,wainer
1,table,0.843338,0.928416,0.744828,wainer
2,graph,0.684159,0.754768,0.615986,brbf
3,table,0.690718,0.76315,0.611557,brbf


In [60]:
def create_pairwise_agent_heatmap(df, x, y, domain, units_of_measure, include_text=True, text_format=".3f"):

    base = alt.Chart(df).mark_rect().encode(
        x=alt.X(x, scale=alt.Scale(domain=domain)),  
        y=alt.Y(y, scale=alt.Scale(domain=domain)),
    )

    color_domain=[-1,1]
    color_reverse = False
    color_condition = alt.condition(
        alt.datum[units_of_measure] < 0.5,
        alt.value('black'),
        alt.value('white')
    )


    height=300
    width=400 
    title=f"Accuracy Vector Correlation between Participants"

    heatmap = base.mark_rect().encode(
        color=alt.Color(f'{units_of_measure}:Q', 
                        legend=None, 
                        scale=alt.Scale(
                            scheme="redgrey",
                            domain=color_domain,
                            reverse=color_reverse)
                       ),
    )

    text = base.mark_text(baseline='middle').encode(
        alt.Text(f'{units_of_measure}:Q', format=text_format),
        color=color_condition
    )

    chart = heatmap.properties(
        width=width,
        height=height,
        title=title
    ) 
    if include_text:
        chart = chart + text

    return chart

### Comparison

In [61]:
from scipy.stats import ttest_rel

In [62]:
table_graph_comparison_df = pd.read_csv("./dataframes/table_graph_comparison_df.csv")
ttest_rel(table_graph_comparison_df['graph'], table_graph_comparison_df['table'])

TtestResult(statistic=-1.4230257645359137, pvalue=0.1599950523321966, df=59)

In [63]:
# table_graph_comparison_df

### Difference of means

In [64]:
def bootstrap_diff_of_mean_ci(
        data: pd.DataFrame,
        measure,
        id_col,
        n_iterations=1000,
        statistic=np.mean
    ):
    """
    
    """
    
    items = list(data[id_col].unique())
    n_size = len(items)
    df = data.copy()

    def bootstrap_iteration(data, chosen_items):
        # filter_df = data[data[id_col].isin(chosen_items)] # Filter based on chosen questions
        filter_df = pd.concat(
            [data[data[id_col].isin([item])] for item in chosen_items]
        )
        
        bs_mean = statistic(filter_df[measure]) 
        return (bs_mean, list(chosen_items))

    
    qset_means1 = Parallel(n_jobs=-1)(
        delayed(bootstrap_iteration)(
            df[df['modality'] == 'table'].copy(),
            rng.choice(items, n_size,  replace=True)
        ) for _ in range(n_iterations)
    )

    qset_means2 = Parallel(n_jobs=-1)(
        delayed(bootstrap_iteration)(
            df[df['modality'] == 'graph'].copy(),
            rng.choice(items, n_size,  replace=True)
        ) for _ in range(n_iterations)
    )
    
    means = []
    qs_used = []
    means1 = [bs_mean for bs_mean, chosen_qs in qset_means1]
    means2 = [bs_mean for bs_mean, chosen_qs in qset_means2]
    means = np.subtract(means1, means2)
 
    # 95% confidence interval
    lower = np.percentile(means, 2.5)
    upper = np.percentile(means, 97.5)
    
    return lower, upper

def create_diff_of_mean_ci_df(
    data: pd.DataFrame,
    measure,
    id_col,
    statistic=np.mean
):
    """ create the dataframe for 95% bootstrapped confidence interval
    data: dataset
    measure: dependent variable
    id_col: units to bootstrap along (e.g. participant_d or item_id)
    condition_col: the different conditions of the experiments
    """
    
    data_list = []


    lower, upper = bootstrap_diff_of_mean_ci(data, measure=measure, statistic=statistic, id_col=id_col)
    mean1 = data[data['modality'] == 'table'][measure].mean()
    mean2 = data[data['modality'] == 'graph'][measure].mean()
    mean = mean1 - mean2
    
    data_list.append({
        "mean": mean,
        "ci_upper": upper, 
        "ci_lower": lower,
    })

    # item_level_data = data.rename(columns={
    #     condition_col: 'category', 
    # }).groupby([id_col, 'category'])[measure].mean().reset_index()

    return pd.DataFrame(data_list)

In [65]:
diff_of_mean_modality_ci_df = create_diff_of_mean_ci_df(
    mod_df[mod_df['test_type'] == 'brbf'],
    measure='is_correct',
    id_col='question_image',
)
diff_of_mean_modality_ci_df

NameError: name 'mod_df' is not defined

In [66]:
diff_of_mean_modality_ci_df

NameError: name 'diff_of_mean_modality_ci_df' is not defined

In [67]:
diff_of_mean_modality_ci_df = create_diff_of_mean_ci_df(
    mod_df[mod_df['test_type'] == 'wainer'],
    measure='is_correct',
    id_col='question_image',
)
diff_of_mean_modality_ci_df

NameError: name 'mod_df' is not defined

In [19]:
# mod_df['test_type'].unique()
# mod_df[mod_df['test_type'] == 'brbf']['is_correct'].mean()

# Model Fit

# Model comparison fit

## Load R + Libs

In [68]:
%load_ext rpy2.ipython

In [69]:
%%R
install.packages("lme4")
library(lme4)

--- Please select a CRAN mirror for use in this session ---
Secure CRAN mirrors 

 1: 0-Cloud [https]
 2: Australia (Canberra) [https]
 3: Australia (Melbourne 1) [https]
 4: Australia (Melbourne 2) [https]
 5: Austria (Wien 1) [https]
 6: Belgium (Brussels) [https]
 7: Brazil (PR) [https]
 8: Brazil (SP 1) [https]
 9: Brazil (SP 2) [https]
10: Bulgaria [https]
11: Canada (MB) [https]
12: Canada (ON 1) [https]
13: Canada (ON 2) [https]
14: Chile (Santiago) [https]
15: China (Beijing 1) [https]
16: China (Beijing 2) [https]
17: China (Beijing 3) [https]
18: China (Hefei) [https]
19: China (Hong Kong) [https]
20: China (Jinan) [https]
21: China (Lanzhou) [https]
22: China (Nanjing) [https]
23: China (Shanghai 2) [https]
24: China (Shenzhen) [https]
25: China (Wuhan) [https]
26: Colombia (Cali) [https]
27: Cyprus [https]
28: Czech Republic [https]
29: Denmark [https]
30: East Asia [https]
31: Ecuador (Cuenca) [https]
32: Finland (Helsinki) [https]
33: France (Lyon 1) [https]
34: France (L

Selection:  68



The downloaded binary packages are in
	/var/folders/v8/3zpbxkws53b3x6m8509jyml80000gn/T//Rtmp63v1Wr/downloaded_packages


trying URL 'https://repo.miserver.it.umich.edu/cran/bin/macosx/big-sur-arm64/contrib/4.4/lme4_1.1-37.tgz'
Content type 'application/octet-stream' length 7092567 bytes (6.8 MB)
downloaded 6.8 MB

Loading required package: Matrix
In doTryCatch(return(expr), name, parentenv, handler) :
  unable to load shared object '/Library/Frameworks/R.framework/Resources/modules//R_X11.so':
  dlopen(/Library/Frameworks/R.framework/Resources/modules//R_X11.so, 0x0006): Library not loaded: /opt/X11/lib/libSM.6.dylib
  Referenced from: <34C5A480-1AC4-30DF-83C9-30A913FC042E> /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/modules/R_X11.so
  Reason: tried: '/opt/X11/lib/libSM.6.dylib' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/X11/lib/libSM.6.dylib' (no such file), '/opt/X11/lib/libSM.6.dylib' (no such file), '/usr/local/lib/libSM.6.dylib' (no such file), '/usr/lib/libSM.6.dylib' (no such file, not in dyld cache)


In [70]:
%%R
install.packages("MuMIn")
library(MuMIn)


The downloaded binary packages are in
	/var/folders/v8/3zpbxkws53b3x6m8509jyml80000gn/T//Rtmp63v1Wr/downloaded_packages


trying URL 'https://repo.miserver.it.umich.edu/cran/bin/macosx/big-sur-arm64/contrib/4.4/MuMIn_1.48.11.tgz'
Content type 'application/octet-stream' length 913416 bytes (892 KB)
downloaded 892 KB



In [71]:
%%R

install.packages("dplyr")
library(dplyr)


The downloaded binary packages are in
	/var/folders/v8/3zpbxkws53b3x6m8509jyml80000gn/T//Rtmp63v1Wr/downloaded_packages


trying URL 'https://repo.miserver.it.umich.edu/cran/bin/macosx/big-sur-arm64/contrib/4.4/dplyr_1.1.4.tgz'
Content type 'application/octet-stream' length 1599250 bytes (1.5 MB)
downloaded 1.5 MB


Attaching package: ‘dplyr’

The following objects are masked from ‘package:stats’:

    filter, lag

The following objects are masked from ‘package:base’:

    intersect, setdiff, setequal, union



### Model Comparison

## Create dataframe for models

In [72]:
def add_brbf_categories(r):
    if (r['test_type'] == 'brbf'):
        return '-'.join(r['image_file'].split("-")[:-1])
    return r['chart_type']


all_df = performance_df.copy()
all_df['chart_type_filled'] = all_df.apply(add_brbf_categories, axis=1)
all_df['chart_type'] = all_df['chart_type_filled'].replace(chart_categories)
all_df = all_df[all_df['chart_type'] != 'Table']

all_df['task_category'] = all_df['task_category'].apply(
    lambda t : task_category_map[t] if t in task_category_map.keys() else t
)
all_df['task_category'] = all_df.apply(
    lambda r : task_category_map[r['question']] if r['question'] in task_category_map.keys() else r['task_category'],
    axis=1
)

all_df['item_id'] = all_df['question_image']
all_df = all_df[['test_type', 'chart_type', 'task_category', 'item_id', 'is_correct']]
initial_df = all_df.copy()

print(
    all_df['chart_type'].unique(), 
    all_df['task_category'].unique(), 
    all_df['test_type'].unique()
)

['Line' 'Bar' 'Radial' 'Scatter' 'Pie' 'Dot Plot' 'Stacked Bar'
 '100% Stacked Bar' 'Histogram' 'Area' 'Stacked Area' 'Map' 'Treemap'] ['value identification' 'arithmetic computation' 'statistical inference'] ['wainer' 'brbf' 'ggr-mc' 'vlat' 'calvi']


## Model Comparisons

In [73]:
%%R

# define options for all models
control <- glmerControl(
    optimizer = "bobyqa",
    optCtrl = list(maxfun = 2e5)
)

### Null model

In [74]:
%%R -i all_df

null_model <- glmer(
    is_correct ~ 1 + (1 | item_id), 
    data = all_df, 
    family = binomial, 
    control = control
)

### Comparing test fixed effect to null model

In [75]:
%%R -i all_df

test_model <- glmer(
    is_correct ~ test_type + (1 | item_id), 
    data = all_df, 
    family = binomial,
    control = control
)

anova(null_model, test_model, test="LRT")

Data: all_df
Models:
null_model: is_correct ~ 1 + (1 | item_id)
test_model: is_correct ~ test_type + (1 | item_id)
           npar   AIC   BIC  logLik -2*log(L)  Chisq Df Pr(>Chisq)    
null_model    2 16396 16411 -8195.9     16392                         
test_model    6 16364 16410 -8176.0     16352 39.764  4  4.843e-08 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1


In [77]:
%%R

r.squaredGLMM(test_model, null_model)

                   R2m       R2c
theoretical 0.08449389 0.4247414
delta       0.07373037 0.3706344


### Comparing task type fixed effect to null model

In [447]:
%%R -i all_df

task_model <- glmer(
    is_correct ~ task_category + (1 | item_id), 
    data = all_df, 
    family = binomial, 
    control = control
)

anova(null_model, task_model, test="LRT")

Data: all_df
Models:
null_model: is_correct ~ 1 + (1 | item_id)
task_model: is_correct ~ task_category + (1 | item_id)
           npar   AIC   BIC  logLik deviance  Chisq Df Pr(>Chisq)    
null_model    2 16396 16411 -8195.9    16392                         
task_model    4 16386 16417 -8189.0    16378 13.847  2  0.0009846 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1


### Comparing graph type fixed effect to null model

In [451]:
%%R -i all_df

control <- glmerControl(
    optimizer = "bobyqa",
    optCtrl = list(maxfun = 2e5)
)

graph_model <- glmer(
    is_correct ~ chart_type + (1 | item_id), 
    data = all_df, 
    family = binomial,
    control = control
)

anova(null_model, graph_model, test="LRT")

Data: all_df
Models:
null_model: is_correct ~ 1 + (1 | item_id)
graph_model: is_correct ~ chart_type + (1 | item_id)
            npar   AIC   BIC  logLik deviance  Chisq Df Pr(>Chisq)   
null_model     2 16396 16411 -8195.9    16392                        
graph_model   14 16391 16498 -8181.4    16363 29.106 12     0.0038 **
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1


### Comparing graph/test type fixed effect to item model

In [526]:
%%R -i all_df

control <- glmerControl(
    optimizer = "bobyqa",
    optCtrl = list(maxfun = 2e5)
)

task_model <- glmer(
    is_correct ~ task_category + (1 | item_id), 
    data = all_df, 
    family = binomial,
    control = control
)

task_graph_model <- glmer(
    is_correct ~ task_category + chart_type + (1 | item_id), 
    data = all_df, 
    family = binomial,
    control = control
)

task_test_model <- glmer(
    is_correct ~ test_type + task_category + (1 | item_id), 
    data = all_df, 
    family = binomial,
    control = control
)

In [527]:
%%R
anova(task_model, task_graph_model, test="LRT")

Data: all_df
Models:
task_model: is_correct ~ task_category + (1 | item_id)
task_graph_model: is_correct ~ task_category + chart_type + (1 | item_id)
                 npar   AIC   BIC  logLik deviance  Chisq Df Pr(>Chisq)   
task_model          4 16386 16417 -8189.0    16378                        
task_graph_model   16 16383 16505 -8175.3    16351 27.387 12   0.006794 **
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1


In [528]:
%%R
anova(task_model, task_test_model, test="LRT")

Data: all_df
Models:
task_model: is_correct ~ task_category + (1 | item_id)
task_test_model: is_correct ~ test_type + task_category + (1 | item_id)
                npar   AIC   BIC  logLik deviance  Chisq Df Pr(>Chisq)    
task_model         4 16386 16417 -8189.0    16378                         
task_test_model    8 16365 16426 -8174.4    16349 29.233  4   7.01e-06 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1


In [554]:
# all_df

### Find model fit + variance

In [661]:
%%R -i all_df

control <- glmerControl(
    optimizer = "bobyqa",
    optCtrl = list(maxfun = 2e5)
)

# Logistic regression with item as a fixed effect
null_model <- glmer(is_correct ~ 1 + (1 | item_id), data = all_df, family = binomial, control = control)
test_model <- glmer(is_correct ~ test_type + (1 | item_id), data = all_df, family = binomial, control = control)
task_model <- glmer(is_correct ~ task_category + (1 | item_id), data = all_df, family = binomial, control = control)
graph_model <- glmer(is_correct ~ chart_type + (1 | item_id), data = all_df, family = binomial, control = control)
graph_test_model <- glmer(is_correct ~ test_type + chart_type + (1 | item_id), data = all_df, family = binomial, control = control)
graph_task_model <- glmer(is_correct ~ task_category + chart_type + (1 | item_id), data = all_df, family = binomial, control = control)
task_test_model <- glmer(is_correct ~ test_type + task_category + (1 | item_id), data = all_df, family = binomial, control = control)
graph_task_test_model <- glmer(is_correct ~ test_type + task_category + chart_type + (1 | item_id), data = all_df, family = binomial, control = control)

In [662]:
%%R -i all_df

control <- glmerControl(
    optimizer = "bobyqa",
    optCtrl = list(maxfun = 2e5)
)

interaction_model <- glmer(
    is_correct ~ test_type*task_category + chart_type + (1 | item_id), 
    data = all_df, family = binomial,
    control = control
)

In [663]:
%%R -i all_df

results <- data.frame(
  Model = character(),
  Marginal_R2 = numeric(),
  Conditional_R2 = numeric(),
  BIC = numeric(),
  AIC = numeric(),
  bootstap_iter = numeric(),
  stringsAsFactors = FALSE
)

# List of models with their formulas
model_formulas <- list(
null_model = is_correct ~ (1 | item_id),
test_model = is_correct ~ test_type + (1 | item_id),
task_model = is_correct ~ task_category + (1 | item_id),
graph_model = is_correct ~ chart_type + (1 | item_id),
graph_test_model = is_correct ~ test_type + chart_type + (1 | item_id),
graph_task_model = is_correct ~ task_category + chart_type + (1 | item_id),
task_test_model = is_correct ~ test_type + task_category + (1 | item_id),
graph_task_test_model = is_correct ~ test_type + task_category + chart_type + (1 | item_id),
interaction_model = is_correct ~ test_type * task_category + chart_type + (1 | item_id)
)

for (model_name in names(model_formulas)) {
formula <- model_formulas[[model_name]]
warning_occurred <- FALSE

model <- tryCatch({
  fit <- glmer(
    formula,
    data = all_df,
    family = binomial,
    control = control
  )
  fit
}, warning = function(w) {
  if (grepl("fixed-effect model matrix is rank deficient", w$message)) {
    warning_occurred <<- TRUE
  }
  invokeRestart("muffleWarning")
})

# Skip to the next iteration if model fitting failed
if (inherits(model, "error")) next

# Compute R-squared, BIC, and AIC
r_squared <- r.squaredGLMM(model, null_model)
bic_value <- BIC(model)
aic_value <- AIC(model)

# Append results to the data frame
results <- rbind(results, data.frame(
  Model = model_name,
  Marginal_R2 = r_squared[1],
  Conditional_R2 = r_squared[2],
  BIC = bic_value,
  AIC = aic_value,
  bootstap_iter = i,
  Warning = warning_occurred  # Store whether warning occurred
))
}

write.csv(results, "./dataframes/model_fit_fulldata.csv", row.names = FALSE)

In [665]:
# %%R
# results

In [25]:
initial_df = all_df.copy()

In [42]:
len(initial_df)

15795

### Bootstrap model fit

In [26]:
%%R -i initial_df

control <- glmerControl(
    optimizer = "bobyqa",
    optCtrl = list(maxfun = 2e5)
)

unique_items <- unique(initial_df$item_id)

# Initialize an empty data frame to store the results
results <- data.frame(
  Model = character(),
  Marginal_R2 = numeric(),
  Conditional_R2 = numeric(),
  BIC = numeric(),
  AIC = numeric(),
  bootstap_iter = numeric(),
  stringsAsFactors = FALSE
)

all_sampled_items <- list()

for (i in 1:100) {
  print(i)
  sampled_items <- sample(unique_items, size = length(unique_items), replace = TRUE)
  all_sampled_items <- c(all_sampled_items, list(sampled_items))

  all_df <- do.call(rbind, lapply(sampled_items, function(item) {
    initial_df[initial_df$item_id == item, ]
  }))

  null_model <- glmer(
    is_correct ~ 1 + (1 | item_id), 
    data = all_df, 
    family = binomial, 
    control = control
  )

  # List of models with their formulas
  model_formulas <- list(
    null_model = is_correct ~ (1 | item_id),
    test_model = is_correct ~ test_type + (1 | item_id),
    task_model = is_correct ~ task_category + (1 | item_id),
    graph_model = is_correct ~ chart_type + (1 | item_id),
    graph_test_model = is_correct ~ test_type + chart_type + (1 | item_id),
    graph_task_model = is_correct ~ task_category + chart_type + (1 | item_id),
    task_test_model = is_correct ~ test_type + task_category + (1 | item_id),
    graph_task_test_model = is_correct ~ test_type + task_category + chart_type + (1 | item_id),
    interaction_model = is_correct ~ test_type * task_category + chart_type + (1 | item_id)
  )

  for (model_name in names(model_formulas)) {
    formula <- model_formulas[[model_name]]
    warning_occurred <- FALSE

    model <- tryCatch({
      fit <- glmer(
        formula,
        data = all_df,
        family = binomial,
        control = control
      )
      fit
    }, warning = function(w) {
      if (grepl("fixed-effect model matrix is rank deficient", w$message)) {
        warning_occurred <<- TRUE
      }
      invokeRestart("muffleWarning")
    })

    # Skip to the next iteration if model fitting failed
    if (inherits(model, "error")) next

    # Compute R-squared, BIC, and AIC
    r_squared <- r.squaredGLMM(model, null_model)
    bic_value <- BIC(model)
    aic_value <- AIC(model)

    # Append results to the data frame
    results <- rbind(results, data.frame(
      Model = model_name,
      Marginal_R2 = r_squared[1],
      Conditional_R2 = r_squared[2],
      BIC = bic_value,
      AIC = aic_value,
      bootstap_iter = i,
      Warning = warning_occurred  # Store whether warning occurred
    ))
  }
}

write.csv(results, "./figures/model_statistics2.csv", row.names = FALSE)

[1] 1
[1] 2
[1] 3
[1] 4
[1] 5
[1] 6
[1] 7
[1] 8
[1] 9
[1] 10
[1] 11
[1] 12
[1] 13
[1] 14
[1] 15
[1] 16
[1] 17
[1] 18
[1] 19
[1] 20
[1] 21
[1] 22
[1] 23
[1] 24
[1] 25
[1] 26
[1] 27
[1] 28
[1] 29
[1] 30
[1] 31
[1] 32
[1] 33
[1] 34
[1] 35
[1] 36
[1] 37
[1] 38
[1] 39
[1] 40
[1] 41
[1] 42
[1] 43
[1] 44
[1] 45
[1] 46
[1] 47
[1] 48
[1] 49
[1] 50
[1] 51
[1] 52
[1] 53
[1] 54
[1] 55
[1] 56
[1] 57
[1] 58
[1] 59
[1] 60
[1] 61
[1] 62
[1] 63
[1] 64
[1] 65
[1] 66
[1] 67
[1] 68
[1] 69
[1] 70
[1] 71
[1] 72
[1] 73
[1] 74
[1] 75
[1] 76
[1] 77
[1] 78
[1] 79
[1] 80
[1] 81
[1] 82
[1] 83
[1] 84
[1] 85
[1] 86
[1] 87
[1] 88
[1] 89
[1] 90
[1] 91
[1] 92
[1] 93
[1] 94
[1] 95
[1] 96
[1] 97
[1] 98
[1] 99
[1] 100


fixed-effect model matrix is rank deficient so dropping 1 column / coefficient
fixed-effect model matrix is rank deficient so dropping 1 column / coefficient
fixed-effect model matrix is rank deficient so dropping 1 column / coefficient
fixed-effect model matrix is rank deficient so dropping 1 column / coefficient
fixed-effect model matrix is rank deficient so dropping 1 column / coefficient
fixed-effect model matrix is rank deficient so dropping 1 column / coefficient
fixed-effect model matrix is rank deficient so dropping 1 column / coefficient
fixed-effect model matrix is rank deficient so dropping 1 column / coefficient
fixed-effect model matrix is rank deficient so dropping 1 column / coefficient
fixed-effect model matrix is rank deficient so dropping 1 column / coefficient
fixed-effect model matrix is rank deficient so dropping 1 column / coefficient
fixed-effect model matrix is rank deficient so dropping 1 column / coefficient
fixed-effect model matrix is rank deficient so dropp

In [54]:
%%R

my_list <- all_sampled_items
num_columns <- max(sapply(my_list, length))
column_names <- paste0("Column_", seq_len(num_columns))

df <- do.call(rbind, lapply(my_list, function(x) {
  length(x) <- num_columns
  x
}))

# Assign column names to the dataframe
colnames(df) <- column_names

# Convert to a data.frame and ensure appropriate types
df <- as.data.frame(df, stringsAsFactors = FALSE)

# Save the dataframe as a CSV file
write.csv(df, "./model_bootstrapped_items.csv", row.names = FALSE)

In [61]:
model_bootstrapped_items = pd.read_csv("model_bootstrapped_items.csv")

### Plot Model Fit

In [274]:
valid_interaction_combos = all_df[['item_id', 'test_type', 'task_category']].drop_duplicates()
valid_interaction_combos

Unnamed: 0,item_id,test_type,task_category
0,How much does it rain in March? + line.png,wainer,value identification
82,Which month has 25 mm of rain? + line.png,wainer,value identification
162,How many months have less than 40mm of rain? +...,wainer,arithmetic computation
249,Which season has the most rain? + line.png,wainer,value identification
328,In which season does each month have less rain...,wainer,arithmetic computation
...,...,...,...
18767,Which of the following is true about the resid...,calvi,statistical inference
18850,Which of the following is true about the marke...,calvi,arithmetic computation
18929,Which of the following is true about the educa...,calvi,arithmetic computation
19011,Which of the following states has the highest ...,calvi,arithmetic computation


In [279]:
# items

In [305]:
invalid_interaction_runs = []
for i, row in pd.read_csv("./model_bootstrapped_items.csv").iterrows():
    item = pd.DataFrame({'item_id': row.to_list()}).drop_duplicates()
    test_task_interaction_counts = pd.merge(valid_interaction_combos, item)[['test_type', 'task_category']].value_counts()
    test_task_interaction_counts = test_task_interaction_counts.reset_index()
    test_task_interaction_counts = test_task_interaction_counts.pivot(index='test_type', columns='task_category', values='count')

    if (len(test_task_interaction_counts) != len(test_task_interaction_counts.dropna())):
        invalid_interaction_runs.append(i)
        print("Invalid bootstrap for interaction model found: Row ", i)
print("Total runs invalid: ", len(invalid_interaction_runs))

Invalid bootstrap for interaction model found: Row  5
Invalid bootstrap for interaction model found: Row  8
Invalid bootstrap for interaction model found: Row  11
Invalid bootstrap for interaction model found: Row  17
Invalid bootstrap for interaction model found: Row  30
Invalid bootstrap for interaction model found: Row  32
Invalid bootstrap for interaction model found: Row  36
Invalid bootstrap for interaction model found: Row  37
Invalid bootstrap for interaction model found: Row  43
Invalid bootstrap for interaction model found: Row  48
Invalid bootstrap for interaction model found: Row  52
Invalid bootstrap for interaction model found: Row  58
Invalid bootstrap for interaction model found: Row  73
Invalid bootstrap for interaction model found: Row  74
Invalid bootstrap for interaction model found: Row  85
Invalid bootstrap for interaction model found: Row  86
Invalid bootstrap for interaction model found: Row  97
Total runs invalid:  17


In [354]:
model_statistics_allbs = pd.read_csv("./figures/model_statistics2.csv")
model_statistics_allbs = model_statistics_allbs[~model_statistics_allbs['bootstap_iter'].isin(invalid_interaction_runs)]
len(model_statistics_allbs['bootstap_iter'].unique())

83

In [413]:
model_statistics_allbs

Unnamed: 0,Model,Marginal_R2,Conditional_R2,BIC,AIC,bootstap_iter,Warning,relative_aic
0,null_model,0.000000,0.000000,16546.932759,16531.596849,1,False,0.000000
1,test_model,0.078278,0.069070,16559.373119,16513.365388,1,False,-18.231461
2,task_model,0.032751,0.028921,16556.511238,16525.839417,1,False,-5.757432
3,graph_model,0.072501,0.064033,16641.274538,16533.923167,1,False,2.326319
4,graph_test_model,0.132326,0.116837,16657.822654,16519.799462,1,False,-11.797387
...,...,...,...,...,...,...,...,...
895,graph_test_model,0.154736,0.134961,16171.021067,16040.590695,100,False,-16.804302
896,graph_task_model,0.090404,0.078771,16171.735529,16056.649907,100,False,-0.745090
897,task_test_model,0.099425,0.086626,16101.675058,16040.296059,100,False,-17.098937
898,graph_task_test_model,0.155241,0.135383,16189.818756,16044.043634,100,False,-13.351362


In [397]:
null_model = model_statistics_allbs[model_statistics_allbs['Model'] == 'null_model']

def get_bootstrap_iter_aic(r):
    aic = r['AIC']
    null_aic = null_model[null_model['bootstap_iter'] == r['bootstap_iter']]['AIC'].iloc[0]
    return aic - null_aic

model_statistics_allbs['relative_aic'] = model_statistics_allbs.apply(
    get_bootstrap_iter_aic,
    axis=1
)

In [674]:
model_statistics = model_statistics_allbs.copy()
model_statistics = model_statistics[model_statistics['Model'] != 'null_model']
model_statistics_ci = model_statistics.groupby('Model').agg(
    bic_mean=('BIC', 'mean'),
    bic_std=('BIC', 'std'),
    relative_aic_mean=('relative_aic', 'mean'),
    relative_aic_std=('relative_aic', 'std'),
    mr2_mean=('Marginal_R2', 'mean'),
    mr2_std=('Marginal_R2', 'std'),
    mr2_ci_upper=('Marginal_R2', lambda m : np.percentile(m, 97.5)),
    mr2_ci_lower=('Marginal_R2', lambda m : np.percentile(m, 2.5 )),
    cr2_mean=('Conditional_R2', 'mean'),
    cr2_std=('Conditional_R2', 'std'),
).reset_index()

model_statistics_ci['mr2_ystd'] = model_statistics_ci['mr2_mean'] - model_statistics_ci['mr2_std']
model_statistics_ci['mr2_ystd2'] = model_statistics_ci['mr2_mean'] + model_statistics_ci['mr2_std']

model_statistics_ci['cr2_ystd'] = model_statistics_ci['cr2_mean'] - model_statistics_ci['cr2_std']
model_statistics_ci['cr2_ystd2'] = model_statistics_ci['cr2_mean'] + model_statistics_ci['cr2_std']

model_statistics_ci['relative_aic_ystd'] = model_statistics_ci['relative_aic_mean'] - model_statistics_ci['relative_aic_std']
model_statistics_ci['relative_aic_ystd2'] = model_statistics_ci['relative_aic_mean'] + model_statistics_ci['relative_aic_std']

# replace means with full data fit
model_fit_on_fulldata = pd.read_csv("./dataframes/model_fit_fulldata.csv")
model_fit_on_fulldata = model_fit_on_fulldata.rename(columns={'Marginal_R2': 'mr2_fulldata'})
model_statistics_ci = pd.merge(model_statistics_ci, model_fit_on_fulldata[['Model', 'mr2_fulldata']])

In [680]:
# model_fit_on_fulldata
model_statistics_ci

Unnamed: 0,Model,bic_mean,bic_std,relative_aic_mean,relative_aic_std,mr2_mean,mr2_std,mr2_ci_upper,mr2_ci_lower,cr2_mean,cr2_std,mr2_ystd,mr2_ystd2,cr2_ystd,cr2_ystd2,relative_aic_ystd,relative_aic_ystd2,mr2_fulldata
0,graph_model,16246.977604,375.28034,-0.487784,6.113237,0.079913,0.018858,0.117538,0.048798,0.069699,0.016427,0.061054,0.098771,0.053272,0.086126,-6.60102,5.625453,0.063993
1,graph_task_model,16257.822767,375.733654,-4.977314,7.354326,0.104436,0.021866,0.149406,0.060948,0.091092,0.019078,0.08257,0.126303,0.072014,0.11017,-12.33164,2.377013,0.088065
2,graph_task_test_model,16278.721705,376.165084,-14.747764,8.474622,0.150874,0.02366,0.190158,0.098778,0.131634,0.02099,0.127214,0.174534,0.110644,0.152624,-23.222387,-6.273142,0.131151
3,graph_test_model,16262.6691,375.974696,-15.465675,8.481975,0.142899,0.023356,0.184254,0.096595,0.124677,0.020709,0.119542,0.166255,0.103969,0.145386,-23.94765,-6.9837,0.123841
4,interaction_model,16340.614974,374.798605,-12.900075,9.315639,0.183741,0.023285,0.226307,0.140717,0.160325,0.020818,0.160456,0.207025,0.139507,0.181143,-22.215714,-3.584435,0.162607
5,task_model,16166.503043,375.538786,-5.951764,4.120332,0.035192,0.014322,0.061799,0.01153,0.030708,0.012523,0.02087,0.049514,0.018185,0.04323,-10.072097,-1.831432,0.031744
6,task_test_model,16184.828572,375.759524,-18.295623,7.421085,0.098251,0.023149,0.139012,0.058304,0.085771,0.020488,0.075103,0.1214,0.065283,0.106259,-25.716708,-10.874538,0.090818
7,test_model,16168.421287,375.726353,-19.368215,7.132049,0.08995,0.022403,0.13115,0.050946,0.078532,0.019826,0.067547,0.112353,0.058706,0.098358,-26.500264,-12.236165,0.084494


In [670]:
# model_fit_on_fulldata

In [677]:
# model_statistics_ci

#### Get Split-half R^2 

In [500]:
split_half_correlations = []
for _ in range(10000):
    correlation_df = all_df.sample(frac=1, replace=False)
    half_one = correlation_df.iloc[0:(len(all_df) // 2)].groupby('item_id')['is_correct']
    half_one = half_one.mean().reset_index().sort_values(by='item_id')
    
    half_two = correlation_df.iloc[(len(all_df) // 2):len(all_df)].groupby('item_id')['is_correct']
    half_two = half_two.mean().reset_index().sort_values(by='item_id')
    
    iter_corr = np.corrcoef(half_one['is_correct'], half_two['is_correct'])
    split_half_correlations.append(iter_corr[0][1]**2)

split_half_mean = np.mean(split_half_correlations)
split_half_ci_lower, split_half_ci_upper = np.percentile(split_half_correlations, [2.5, 97.5])
split_half_df = pd.DataFrame({
    'index': [0],
    'mean': split_half_mean,
    'ci_upper': split_half_ci_upper,
    'ci_lower': split_half_ci_lower,
})

In [660]:
split_half_df

Unnamed: 0,index,mean,ci_upper,ci_lower
0,0,0.906834,0.92521,0.886527


In [682]:
model_order = [
    'test_model',
    'task_model',
    'graph_model',
    'task_test_model',
    'graph_test_model',
    'graph_task_model',     
    'graph_task_test_model',
    'interaction_model'
]

line_path = "M -3.5,0 L 3.5,0"

def plot_model_fit(measure, yscale):
    mean_model_plot = alt.Chart(model_statistics_ci).mark_point(color='black', shape=line_path).encode(
        x=alt.X('Model:N', scale=alt.Scale(domain=model_order)),
        y=alt.Y(f'{measure}_fulldata:Q', scale=alt.Scale(domain=yscale), title=""),
        strokeWidth=alt.value(2)
    )
    
    std_model_plot = alt.Chart(model_statistics_ci).mark_bar(color='black', opacity=0.1).encode(
        x=alt.X('Model:N', scale=alt.Scale(domain=model_order)),
        y=alt.Y(f'{measure}_ci_upper:Q'),
        y2=(f'{measure}_ci_lower:Q'),
        size=alt.value(20)
    )

    split_half_band = alt.Chart(split_half_df).mark_rect(color='black', opacity=0.2).encode(
        y=alt.Y(f'ci_upper:Q'),
        y2=(f'ci_lower:Q'),
    )

    split_half_mean = alt.Chart(split_half_df).mark_rule(color='black', strokeDash=[5, 5]).encode(
        y=alt.Y(f'mean:Q'),
        strokeWidth=alt.value(1)
    )
    
    model_fit_figure = (mean_model_plot + std_model_plot + split_half_band + split_half_mean).properties(
        height=200,
        width=200
    )

    return model_fit_figure

# mr2_fulldata
model_fit_figure = plot_model_fit('mr2', yscale=[0, 1])
model_fit_figure.save("./figures/model_fit.pdf")
model_fit_figure

In [688]:
model_statistics_ci[['Model', 'mr2_fulldata', 'mr2_ci_upper', 'mr2_ci_lower']]

Unnamed: 0,Model,mr2_fulldata,mr2_ci_upper,mr2_ci_lower
0,graph_model,0.063993,0.117538,0.048798
1,graph_task_model,0.088065,0.149406,0.060948
2,graph_task_test_model,0.131151,0.190158,0.098778
3,graph_test_model,0.123841,0.184254,0.096595
4,interaction_model,0.162607,0.226307,0.140717
5,task_model,0.031744,0.061799,0.01153
6,task_test_model,0.090818,0.139012,0.058304
7,test_model,0.084494,0.13115,0.050946


In [54]:
temp = item_df[['test_type', 'task_category']].value_counts().reset_index()
temp.pivot(index='test_type', columns='task_category', values='count')

task_category,arithmetic-computation,statistical-inference,value-identification
test_type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
brbf,12,24,36
calvi,32,26,2
ggr-mc,4,4,5
vlat,12,11,30
wainer,12,8,12


In [2015]:
model_statistics = pd.read_csv("./figures/model_statistics.csv")

Unnamed: 0,Model,Marginal_R2,Conditional_R2,BIC,AIC
0,test_model,0.084495,0.073731,16410.062809,16364.058116
1,task_model,0.031745,0.027708,16416.645323,16385.975529
2,graph_model,0.063993,0.055838,16498.060456,16390.716174
3,graph_test_model,0.123848,0.10806,16504.394309,16366.380233
4,graph_task_model,0.088041,0.076828,16505.267872,16382.588693
5,task_test_model,0.090817,0.07925,16426.082211,16364.742621
6,graph_task_test_model,0.131133,0.114423,16519.452036,16366.103061
7,interaction_model,0.16328,0.142483,16692.847284,16386.149335


# Cross validation across top-down categories

### Preparing dataframe

In [1998]:
def add_brbf_categories(r):
    if (r['test_type'] == 'brbf'):
        return '-'.join(r['image_file'].split("-")[:-1])
    return r['chart_type']


all_df = performance_df.copy()
all_df['chart_type_filled'] = all_df.apply(add_brbf_categories, axis=1)
all_df['chart_type'] = all_df['chart_type_filled'].replace(chart_categories)
all_df = all_df[all_df['chart_type'] != 'Table']

all_df['task_category'] = all_df['task_category'].apply(
    lambda t : task_category_map[t] if t in task_category_map.keys() else t
)
all_df['task_category'] = all_df.apply(
    lambda r : task_category_map[r['question']] if r['question'] in task_category_map.keys() else r['task_category'],
    axis=1
)

all_df['item_id'] = all_df['question_image']
all_df = all_df[['test_type', 'chart_type', 'task_category', 'item_id', 'is_correct']]

### Load R

In [52]:
%load_ext rpy2.ipython

In [53]:
%%R
install.packages("lme4")
install.packages("caret")
install.packages("dplyr")

--- Please select a CRAN mirror for use in this session ---
Secure CRAN mirrors 

 1: 0-Cloud [https]
 2: Australia (Canberra) [https]
 3: Australia (Melbourne 1) [https]
 4: Australia (Melbourne 2) [https]
 5: Austria (Wien 1) [https]
 6: Belgium (Brussels) [https]
 7: Brazil (PR) [https]
 8: Brazil (SP 1) [https]
 9: Brazil (SP 2) [https]
10: Bulgaria [https]
11: Canada (MB) [https]
12: Canada (ON 1) [https]
13: Canada (ON 2) [https]
14: Chile (Santiago) [https]
15: China (Beijing 2) [https]
16: China (Beijing 3) [https]
17: China (Hefei) [https]
18: China (Hong Kong) [https]
19: China (Jinan) [https]
20: China (Lanzhou) [https]
21: China (Nanjing) [https]
22: China (Shanghai 2) [https]
23: China (Shenzhen) [https]
24: China (Wuhan) [https]
25: Colombia (Cali) [https]
26: Costa Rica [https]
27: Cyprus [https]
28: Czech Republic [https]
29: Denmark [https]
30: East Asia [https]
31: Ecuador (Cuenca) [https]
32: France (Lyon 1) [https]
33: France (Lyon 2) [https]
34: France (Marseille) 

Selection:  68



The downloaded binary packages are in
	/var/folders/v8/3zpbxkws53b3x6m8509jyml80000gn/T//RtmpcGPK4J/downloaded_packages

The downloaded binary packages are in
	/var/folders/v8/3zpbxkws53b3x6m8509jyml80000gn/T//RtmpcGPK4J/downloaded_packages

The downloaded binary packages are in
	/var/folders/v8/3zpbxkws53b3x6m8509jyml80000gn/T//RtmpcGPK4J/downloaded_packages


trying URL 'https://ftp.osuosl.org/pub/cran/bin/macosx/big-sur-arm64/contrib/4.4/lme4_1.1-36.tgz'
Content type 'application/x-gzip' length 7079300 bytes (6.8 MB)
downloaded 6.8 MB

trying URL 'https://ftp.osuosl.org/pub/cran/bin/macosx/big-sur-arm64/contrib/4.4/caret_7.0-1.tgz'
Content type 'application/x-gzip' length 3590247 bytes (3.4 MB)
downloaded 3.4 MB

trying URL 'https://ftp.osuosl.org/pub/cran/bin/macosx/big-sur-arm64/contrib/4.4/dplyr_1.1.4.tgz'
Content type 'application/x-gzip' length 1599250 bytes (1.5 MB)
downloaded 1.5 MB

In doTryCatch(return(expr), name, parentenv, handler) :
  unable to load shared object '/Library/Frameworks/R.framework/Resources/modules//R_X11.so':
  dlopen(/Library/Frameworks/R.framework/Resources/modules//R_X11.so, 0x0006): Library not loaded: /opt/X11/lib/libSM.6.dylib
  Referenced from: <34C5A480-1AC4-30DF-83C9-30A913FC042E> /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/modules/R_X11.so
  Reason: tried: '/opt/X11/lib/libSM.6.dylib'

In [None]:
library(lme4)

### Cross-fold eval

In [2016]:
def create_item_accuracy_chart(ci_df, domain_order):
    """ Generate chart for point estimate and confidence intervals overlayed with item level scatteplots
    """
    # Create dot+error bar plot
    chart = alt.Chart(ci_df).mark_bar(opacity=0.7).encode(
        x=alt.X('category:N', title='Category', scale=alt.Scale(domain=domain_order), axis=None),
        y=alt.Y('mean:Q', title='Prop. Correct'),
        color=alt.Color('test_type:N', scale=alt.Scale(domain=test_domain, range=test_domain_color), legend=None),
    ) 
    # + alt.Chart(ci_df).mark_rule(strokeWidth=0.4, opacity=1).encode(
    #     x=alt.X('category:N'),
    #     y=alt.Y('ci_lower:Q'),
    #     y2='ci_upper:Q',
    #     color=alt.value("#717d7e")
    #     # color=alt.Color('test_type:N', scale=alt.Scale(domain=test_domain, range=test_domain_color), legend=None),
    # )
    
    return chart.properties(width=1200, height=200)

In [None]:
%%R

library(lme4)    # For mixed-effects models
library(caret)   # For creating cross-validation folds
library(dplyr)   # For data manipulation

# Assume all_df is your dataset
# all_df <- read.csv("your_data.csv")

# Define the model formula
model_formula <- is_correct ~ test_type + task_category * chart_type + (1 | item_id)

# Set up 5-fold cross-validation
set.seed(42)  # For reproducibility
folds <- createFolds(all_df$is_correct, k = 5, list = TRUE, returnTrain = TRUE)

# Initialize vectors to store results
log_likelihoods <- numeric(length(folds))
accuracies <- numeric(length(folds))

# Perform 5-fold cross-validation
for (i in seq_along(folds)) {
  # Split the data into training and test sets
  train_indices <- folds[[i]]
  train_data <- all_df[train_indices, ]
  test_data <- all_df[-train_indices, ]
  
  # Fit the model on the training data
  model <- glmer(model_formula, data = train_data, family = binomial)
  
  # Predict probabilities on the test data
  test_data$predicted <- predict(model, newdata = test_data, type = "response")
  
  # Convert probabilities to binary predictions (threshold = 0.5)
  test_data$predicted_binary <- ifelse(test_data$predicted > 0.5, 1, 0)
  
  # Calculate log-likelihood for this fold
  log_likelihoods[i] <- logLik(model)
  
  # Calculate accuracy for this fold
  accuracies[i] <- mean(test_data$is_correct == test_data$predicted_binary)
}

# Summary of cross-validation results
mean_log_likelihood <- mean(log_likelihoods)
mean_accuracy <- mean(accuracies)

# Print results
cat("Mean Log-Likelihood:", mean_log_likelihood, "\n")
cat("Mean Accuracy:", mean_accuracy, "\n")


In [2006]:
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from pymer4.models import Lmer

# Load your data (replace this with your actual DataFrame)
# Example: all_df = pd.read_csv('your_data.csv')
all_df = ...  # Your data here

# Define the model formula
model_formula = "is_correct ~ test_type + task_category * chart_type + (1|item_id)"

# Initialize KFold cross-validation
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# To store results
log_likelihoods = []
accuracy_scores = []

# Perform 5-fold cross-validation
for train_index, test_index in kf.split(all_df):
    # Split the data into training and test sets
    train_data = all_df.iloc[train_index]
    test_data = all_df.iloc[test_index]
    
    # Fit the mixed-effects model on the training data
    model = Lmer(model_formula, data=train_data, family='binomial')
    model.fit(summarize=False)
    
    # Predict probabilities on the test data
    test_data['predicted'] = model.predict(test_data)
    
    # Convert probabilities to binary predictions (threshold = 0.5)
    test_data['predicted_binary'] = (test_data['predicted'] > 0.5).astype(int)
    
    # Calculate log-likelihood (or another scoring metric)
    log_likelihood = model.logLike
    log_likelihoods.append(log_likelihood)
    
    # Calculate accuracy for this fold
    accuracy = np.mean(test_data['is_correct'] == test_data['predicted_binary'])
    accuracy_scores.append(accuracy)

# Summary of cross-validation results
mean_log_likelihood = np.mean(log_likelihoods)
mean_accuracy = np.mean(accuracy_scores)

print(f"Mean Log-Likelihood: {mean_log_likelihood}")
print(f"Mean Accuracy: {mean_accuracy}")

ModuleNotFoundError: No module named 'pymer4'

In [None]:
# nosise ceiling 
# 

In [None]:
def cross_fold_performance_evaluation():
    regularization_param = 10**(-3)
    max_iter=1000
    classifier_type = 'multiclass_logistic'

    model = LogisticRegression(random_state=0, C=regularization_param, max_iter=max_iter)
    
    X = all_df[[]]
    values = []
    cross_val_values = []

    for prediction_variable in ['task_category', 'test_type', 'chart_type']:
        y = all_df[prediction_variable].to_numpy()
        scores = cross_val_score(estimator=model, X=X, y=y, cv=5)
        for i, score in enumerate(scores):
            row = [backbone, prediction_variable, classifier_type, regularization_param, score, i]
            cross_val_values.append(row)
        score_mean = np.mean(scores)
        score_std = np.std(scores)
        row = [backbone, prediction_variable, classifier_type, regularization_param, score_mean, score_std]
        values.append(row)
        # print(row)

        
    crossval_df = pd.DataFrame(
        cross_val_values, 
        columns=['backbone', 'prediction_variable', 'classifier_type', 'regularization_param', 'crossval_acc', 'crossval_fold']
    )
    val_df = pd.DataFrame(
        values, 
        columns=['backbone', 'prediction_variable', 'classifier_type', 'regularization_param', 'crossval_mean', 'crossval_std']
    )
    
    return val_df, crossval_df

# Calssification of category of embedding

In [24]:
# question_embeddings = torch.load("/Users/arnav/Desktop/vt-fusion/analysis/embedding_extraction/embeddings_pt/question_embeddings.pt", weights_only=False) 
# image_embeddings = torch.load("/Users/arnav/Desktop/vt-fusion/analysis/embedding_extraction/embeddings_pt/image_embeddings.pt", weights_only=False)
# combined_embeddings = torch.load("/Users/arnav/Desktop/vt-fusion/analysis/embedding_extraction/embeddings_pt/combined_embeddings.pt", weights_only=False) 

combined_embeddings = torch.load("/Users/arnav/Desktop/vt-fusion/data/embedding/multimodal_embeddings.pt", weights_only=False)
assert combined_embeddings.shape == (230, 4096)

In [25]:
# all_df[prediction_variable]

In [26]:
def create_item_df():
    tests = ['wainer', 'brbf', 'ggr-mc', 'vlat', 'calvi']
    item_df = []
    for test in tests:
        idf = pd.read_csv(f'https://data-visualization-benchmark.s3.us-west-2.amazonaws.com/{test}/questions.csv')
        idf['test_type'] = test
        item_df.append(idf)

    item_df = pd.concat(item_df)
    item_df['question_image'] = item_df['question'] + " + " + item_df['image_file']
    return item_df
    
def add_brbf_categories(r):
    if (r['test_type'] == 'brbf'):
        return '-'.join(r['image_file'].split("-")[:-1])
    return r['chart_type']


item_df = create_item_df()
item_df['chart_type_filled'] = item_df.apply(add_brbf_categories, axis=1)
item_df['chart_type'] = item_df['chart_type_filled'].replace(chart_categories)
# item_df = item_df[item_df['chart_type'] != 'Table']

item_df['task_category'] = item_df['task_category'].apply(
    lambda t : task_category_map[t] if t in task_category_map.keys() else t
)
item_df['task_category'] = item_df.apply(
    lambda r : task_category_map[r['question']] if r['question'] in task_category_map.keys() else r['task_category'],
    axis=1
)
item_df['task_category'] = item_df['task_category'].apply(lambda x : x.replace(" ", "-")).to_numpy()
item_df['chart_type'] = item_df['chart_type'].apply(lambda x : x.replace(" ", "-")).to_numpy()

item_df['item_id'] = item_df['question_image']

print(item_df['chart_type'].unique(), item_df['task_category'].unique(), item_df['test_type'].unique())

['Line' 'Bar' 'Radial' 'Table' 'Scatter' 'Pie' 'Dot-Plot' 'Stacked-Bar'
 '100%-Stacked-Bar' 'Histogram' 'Area' 'Stacked-Area' 'Map' 'Treemap'] ['value-identification' 'arithmetic-computation' 'statistical-inference'] ['wainer' 'brbf' 'ggr-mc' 'vlat' 'calvi']


In [27]:
item_df_indexed = item_df.reset_index().rename(columns={'index': 'embedding_index'})

In [33]:
item_df

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,mc1,mc2,mc3,mc4,question,correct_answer,chart_type,task_category,...,Unnamed: 12,misleader_type,misleading_answer,misleading_answer_2,task_category_2,task_category_3,correct_answer_2,question_image,chart_type_filled,item_id
0,0.0,,30 mm,40 mm,50 mm,60 mm,How much does it rain in March?,60 mm,Line,value-identification,...,,,,,,,,How much does it rain in March? + line.png,line,How much does it rain in March? + line.png
1,1.0,,March,April,October,November,Which month has 25 mm of rain?,October,Line,value-identification,...,,,,,,,,Which month has 25 mm of rain? + line.png,line,Which month has 25 mm of rain? + line.png
2,2.0,,5,6,7,8,How many months have less than 40mm of rain?,6,Line,arithmetic-computation,...,,,,,,,,How many months have less than 40mm of rain? +...,line,How many months have less than 40mm of rain? +...
3,3.0,,Winter,Spring,Summer,Fall,Which season has the most rain?,Winter,Line,value-identification,...,,,,,,,,Which season has the most rain? + line.png,line,Which season has the most rain? + line.png
4,4.0,,Winter,Spring,Summer,Fall,In which season does each month have less rain...,Spring,Line,arithmetic-computation,...,,,,,,,,In which season does each month have less rain...,line,In which season does each month have less rain...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
55,,55.0,There is a negative relationship between sleep...,There is a positive relationship between sleep...,Most residents in town X spend more than 6 hou...,None of the above.,Which of the following is true about the resid...,There is a positive relationship between sleep...,Scatter,statistical-inference,...,,,,,Make Comparisons,,,Which of the following is true about the resid...,Scatterplot,Which of the following is true about the resid...
56,,56.0,Store A has less than half of the total market...,Store C has a higher market share than store B.,Store B has the second highest market share.,None of the above.,Which of the following is true about the marke...,Store B has the second highest market share.,Pie,arithmetic-computation,...,,,,,,,,Which of the following is true about the marke...,Pie chart,Which of the following is true about the marke...
57,,57.0,More than 60% of residents have a Bachelor's d...,The majority of residents have a graduate degree.,There are fewer residents whose education leve...,None of the above.,Which of the following is true about the educa...,More than 60% of residents have a Bachelor's d...,Pie,arithmetic-computation,...,,,,,,,,Which of the following is true about the educa...,Pie chart,Which of the following is true about the educa...
58,,58.0,AZ,NV,TN,Cannot be inferred / inadequate information,Which of the following states has the highest ...,AZ,Map,arithmetic-computation,...,,,,,,,,Which of the following states has the highest ...,Choropleth map,Which of the following states has the highest ...


In [555]:
item_response_classification_df = all_df.copy()
item_response_classification_df = pd.merge(item_response_classification_df, item_df_indexed[['item_id', 'embedding_index']])

array([ 0,  0,  0, ..., 56, 56, 57])

In [559]:
.shape

torch.Size([15795, 4096])

In [561]:
# for regularization_param in [10**(-3), 10**(0), 10**(3)]:

def cross_fold_evaluation(embeddings, item_df, backbone = 'all-mpnet-base-v2'):

    regularization_param = 10**(-3)
    run_crossfold_val = True
    model = LogisticRegression(random_state=0, C=regularization_param, max_iter=1000)
    classifier_type = 'multiclass_logistic'
    
    X = embeddings
    values = []
    cross_val_values = []

    for prediction_variable in ['task_category', 'test_type', 'chart_type']:
        y = item_df[prediction_variable].to_numpy()
        
        if run_crossfold_val:
            scores = cross_val_score(estimator=model, X=X, y=y, cv=5)
            for i, score in enumerate(scores):
                row = [backbone, prediction_variable, classifier_type, regularization_param, score, i]
                cross_val_values.append(row)
            score_mean = np.mean(scores)
            score_std = np.std(scores)
            row = [backbone, prediction_variable, classifier_type, regularization_param, score_mean, score_std]
            values.append(row)
            # print(row)
        else:
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=743)
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            accuracy = accuracy_score(y_test, y_pred)
            row = [backbone, prediction_variable, classifier_type, regularization_param, accuracy, 1]
            values.append(row)
            # print(row)
    
    
    crossval_df = pd.DataFrame(
        cross_val_values, 
        columns=['backbone', 'prediction_variable', 'classifier_type', 'regularization_param', 'crossval_acc', 'crossval_fold']
    )
    val_df = pd.DataFrame(
        values, 
        columns=['backbone', 'prediction_variable', 'classifier_type', 'regularization_param', 'crossval_mean', 'crossval_std']
    )
    
    return val_df

In [564]:
 mm_embedding_df = cross_fold_evaluation(
     combined_embeddings[item_response_classification_df['embedding_index'].to_numpy()], 
     item_response_classification_df, 
     backbone='meta-llama'
 )

In [565]:
mm_embedding_df

Unnamed: 0,backbone,prediction_variable,classifier_type,regularization_param,crossval_mean,crossval_std
0,meta-llama,task_category,multiclass_logistic,0.001,0.337575,0.027892
1,meta-llama,test_type,multiclass_logistic,0.001,0.113137,0.121473
2,meta-llama,chart_type,multiclass_logistic,0.001,0.254574,0.029522


In [1204]:
backbones = ['all-mpnet-base-v2', 'ViT/L', 'combined']
prediction_variables = ['test_type', 'task_category', 'chart_type']

error_plot = alt.Chart(embedding_df).mark_circle().encode(
    x=alt.X('prediction_variable:N', title=None, scale=alt.Scale(domain=prediction_variables)),
    y=alt.Y('crossval_mean:Q', scale=alt.Scale(domain=[0,1])),
    color='backbone:N',
    xOffset=alt.XOffset('backbone:N', scale=alt.Scale(domain=backbones) )
).properties(
    width=300,
    height=200,
)

error_bars = alt.Chart(embedding_df).mark_errorbar(extent='stdev').encode(
    x='prediction_variable:N',
    y='crossval_mean:Q',
    yError='crossval_std:Q',
    color='backbone:N',
    xOffset=alt.XOffset('backbone:N')
)


# Altair chart using the custom SVG shape
# dashed_line = alt.Chart(embedding_df).mark_point(shape=dashed_line_svg)

dashed_line = alt.Chart(embedding_df).mark_text(
    text='------------',  
    dx=5,  
    size=15, 
    color='black',
    opacity=0.1
).transform_calculate(
    calculated_y='datum.prediction_variable == "task_category" ? (1/3) : datum.prediction_variable == "test_type" ? (1/5) : (1/14)'
).encode(
    x=alt.X('prediction_variable:N'),
    y=alt.Y('calculated_y:Q')
)


plot = error_plot + error_bars + dashed_line
plot.properties(width=200)

In [1065]:
%%R -i all_df

# Model 1: Random intercepts model
# model1 <- lmer(accuracy ~ 1 + (1|target_word), data=corr_df)

# Model 2: Model with agentType as a fixed effect
# model2 <- lmer(is_correct ~ agentType + (1|target_word), data=corr_df)

# summary(model1)
# summary(model2)

# Logistic regression with item as a fixed effect

null_model <- glmer(is_correct ~ 1 + (1 | item_id), data = all_df, family = binomial)
test_model <- glmer(is_correct ~ test_type + (1 | item_id), data = all_df, family = binomial)
task_model <- glmer(is_correct ~ task_category + (1 | item_id), data = all_df, family = binomial)
graph_model <- glmer(is_correct ~ chart_type + (1 | item_id), data = all_df, family = binomial)
graph_test_model <- glmer(is_correct ~ test_type + chart_type + (1 | item_id), data = all_df, family = binomial)
graph_task_model <- glmer(is_correct ~ task_category + chart_type + (1 | item_id), data = all_df, family = binomial)
task_test_model <- glmer(is_correct ~ test_type + task_category + (1 | item_id), data = all_df, family = binomial)
graph_task_test_model <- glmer(is_correct ~ test_type + task_category + chart_type + (1 | item_id), data = all_df, family = binomial)

KeyError: 'index'

# Participant consistency across tests

Its hard to do this as each participant only does a partial version of the test and may receive easier / harder questions
- do we do this across paritions (e.g. people who answered the same item
- might not be ideal ... vastly different items on tests

In [1328]:
# len(data.dropna())

In [315]:
tests = ['ggr-mc', 'vlat', 'brbf', 'wainer', 'calvi']
data = performance_df.groupby(['participant_id', 'test_type']).mean('is_correct').reset_index()
data = data.pivot(index='participant_id', columns='test_type', values='is_correct').reset_index().dropna()

corr_df = []
for i, t1 in enumerate(tests):
    for j, t2 in enumerate(tests):
        if i <= j:  # only include the bottom triangle
            corr = np.corrcoef(data[t1], data[t2])[0][1]
            corr_df.append({
                'corr': corr,
                'test1': t1,
                'test2': t2
            })

corr_df = pd.DataFrame(corr_df)

create_pairwise_agent_heatmap(corr_df, x='test1', y='test2', domain=tests, units_of_measure='corr').properties(
    width=300,
    height=300,
)

Also do a correlation across 

In [316]:
def create_pairwise_agent_scatterplot(df, x, y, domain):

    charts = []
    for d1 in domain:
        chart_row = []
        for d2 in domain:
            scatter = alt.Chart(df).mark_circle(size=50).encode(
                y=alt.Y(d1, type='quantitative', title=d1, scale=alt.Scale(domain=[0,1])),
                x=alt.X(d2, type='quantitative', title=d2, scale=alt.Scale(domain=[0,1])),
                color=alt.Color('participant_id')
            ).properties(width=50, height=50)
            chart_row.append(scatter)
        charts.append(chart_row)

    return alt.vconcat(
        *[alt.hconcat(*chart_row) for chart_row in charts]
    )


create_pairwise_agent_scatterplot(data, x='test1', y='test2', domain=tests)