In [None]:
import functools

import polars as pl
import pyarrow.parquet as pq
import altair as alt

@functools.cache
def get_df(table_name: str, data_dir: str = 'data') -> pl.DataFrame:
    """Load ``<data_dir>/<table_name>.parquet`` as a polars DataFrame.

    The file is read with pyarrow and converted with ``pl.from_arrow``.
    ``functools.cache`` memoizes the result per argument tuple, so each file is
    read from disk at most once per kernel session. Call ``get_df.cache_clear()``
    or restart the kernel to pick up changes on disk.

    Args:
        table_name: Base name of the parquet file, without the extension.
        data_dir: Directory containing the parquet files. The default is
            ``'data'``, resolved relative to the kernel's working directory.

    Returns:
        The table as a polars DataFrame. Repeated calls with the same
        arguments return the *same* cached object. Prefer non-mutating polars
        methods, because in-place changes would be seen by every later caller.
    """
    table = pq.read_table(f'{data_dir}/{table_name}.parquet')
    return pl.from_arrow(table)

In [None]:
# Plot the trial sequence of one example session. Facets are block_index
# (columns) and rewarded_modality (rows).
#
# The session is drawn uniformly from the distinct, non-null session ids:
# - Sampling unique ids rather than trial rows avoids favouring sessions that
#   have more trials.
# - Sorting first makes the candidate order deterministic.
# - The fixed seed makes Restart & Run All reproduce the same figure.
# Change EXAMPLE_SESSION_SEED to look at a different session.
EXAMPLE_SESSION_SEED = 0
trials = get_df('trials')
example_session_id = (
    trials['session_id']
    .drop_nulls()
    .unique()
    .sort()
    .sample(1, seed=EXAMPLE_SESSION_SEED)
    .item()
)
(
    trials
    .filter(pl.col('session_id') == example_session_id)
    .plot.scatter(
        x='trial_index_in_block',
        color='is_instruction',
        column='block_index',
        row='rewarded_modality',
        y='is_response',
    )
    .properties(width=200, title=f'session {example_session_id}')
    # Each block facet gets its own x range of trial indices.
    .resolve_scale(x='independent')
)

In [None]:
# Column names and dtypes of the 'performance' table.
performance_schema = get_df('performance').schema
performance_schema

In [None]:
# Signed cross-modality d' against date: one point per row of the
# 'performance' table, coloured by the rewarded modality.
performance = get_df('performance')
performance.plot.scatter(
    x='date:T',
    y='signed_cross_modality_dprime',
    color='rewarded_modality',
)

In [None]:
# Median CCF position (ccf_ap vs ccf_dv) of the units in each structure.
# Point size shows how many distinct unit_ids the structure contributes.
# The DV axis is reversed and neither axis is forced to include zero.
# NOTE(review): the reversal presumably puts dorsal at the top; confirm the
# direction of increasing ccf_dv.
#
# Rows are sorted by unit count, largest first, before plotting. Vega-Lite
# draws point marks in data order, so small structures are drawn on top of
# large bubbles instead of being hidden. The sort also makes the otherwise
# unspecified `group_by` row order deterministic across runs.
(
    get_df('units')
    .group_by('structure')
    .agg(
        pl.col('ccf_ap').median(),
        pl.col('ccf_dv').median(),
        pl.col('unit_id').n_unique().alias('n_units'),
    )
    .sort(['n_units', 'structure'], descending=[True, False])
    .plot.scatter(
        x=alt.X('ccf_ap').scale(zero=False),
        y=alt.Y('ccf_dv').scale(reverse=True, zero=False),
        size='n_units',
        color=alt.Color('structure', legend=None),
    )
    .properties(title='Median CCF position per structure (size = number of units)')
)