Skip to content

Commit

Permalink
pair not nested in lm_type
Browse files Browse the repository at this point in the history
  • Loading branch information
kddubey committed Aug 2, 2024
1 parent c917588 commit abf7c35
Show file tree
Hide file tree
Showing 5 changed files with 3,242 additions and 3,348 deletions.
3,248 changes: 1,598 additions & 1,650 deletions analysis/fit_posteriors/m50/m50_n200.ipynb

Large diffs are not rendered by default.

3,256 changes: 1,602 additions & 1,654 deletions analysis/fit_posteriors/m50/m50_n500.ipynb

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions analysis/test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
"metadata": {},
"outputs": [],
"source": [
"equation = \"p(num_correct, num_test) ~ method + lm_type + (1|dataset/pair)\"\n",
"equation = \"p(num_correct, num_test) ~ method + lm_type + (1|dataset/method) + (1|dataset/pair)\"\n",
"id_vars = [\"num_test\", \"pair\", \"lm_type\", \"dataset\"]"
]
},
Expand All @@ -129,13 +129,15 @@
" treatment_effect_prior_mean_std: tuple[float, float],\n",
" random_seed: int = 123,\n",
"):\n",
" # Picking treatment and control columns is arbitrary. Just need to tell bambi what\n",
" # the data looks like, not what's in it. The actual data is unused\n",
" treatment = \"test\"\n",
" control = \"extra\"\n",
" # Build model to set up sampling from priors\n",
" model = utils.create_model(\n",
" # Picking treatment and control columns is arbitrary. Just need to tell bambi\n",
" # what the data looks like, not what's in it. The actual data is unused\n",
" utils.melt_num_correct(\n",
" num_correct_df, treatment=\"test\", control=\"extra\", id_vars=id_vars\n",
" ),\n",
" num_correct_df, treatment, control, id_vars=id_vars\n",
" ).to_pandas(),\n",
" equation,\n",
" treatment_effect_prior_mean_std=treatment_effect_prior_mean_std,\n",
" )\n",
Expand All @@ -148,8 +150,8 @@
" \"p(num_correct, num_test)\"\n",
" ]\n",
" return utils.num_correct_df_from_predicions(\n",
" num_correct_df, predictions_prior[:, 0].to_numpy()\n",
" )"
" num_correct_df, predictions_prior[:, 0].to_numpy(), treatment, control\n",
" ).rename({treatment: \"treatment\", control: \"control\"})"
]
},
{
Expand Down
68 changes: 32 additions & 36 deletions analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from IPython.display import display
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
Expand Down Expand Up @@ -254,8 +253,8 @@ def violin_plot_multiple_lms(accuracy_df: pl.DataFrame, num_test: int, num_train
)
xlim = (-0.5, 0.5)
axes: list[plt.Axes] = [axes] if isinstance(axes, plt.Axes) else axes
for subplot_idx, (lm_type, accuracy_df_lm) in enumerate(
accuracy_df.group_by("lm_type", maintain_order=True)
for subplot_idx, ((lm_type,), accuracy_df_lm) in enumerate(
accuracy_df.partition_by("lm_type", maintain_order=True, as_dict=True).items()
):
ax = axes[subplot_idx]
ax.set_xlim(xlim)
Expand Down Expand Up @@ -349,7 +348,7 @@ def melt_num_correct(
treatment: str,
control: str,
id_vars: Sequence[str] = ("num_test", "pair", "dataset"),
) -> pd.DataFrame:
) -> pl.DataFrame:
"""
Melt `num_correct_df` into a format which is compatible with Wilkinson notation.
A pandas (not polars) dataframe is returned because `bambi` (and others) assume a
Expand All @@ -359,12 +358,13 @@ def melt_num_correct(
if "num_test" not in id_vars:
id_vars = ["num_test"] + id_vars
return (
num_correct_df.with_columns(pl.Series("pair", range(len(num_correct_df))))
# pair indexes the subsample
num_correct_df.group_by("lm_type", maintain_order=True)
.map_groups(lambda df: df.with_columns(pl.Series("pair", range(len(df)))))
# pair indexes the subsample. It's common across LM types b/c seeds were set
# when subsampling.
.select(id_vars + [control, treatment])
.melt(id_vars=id_vars, variable_name="method", value_name="num_correct")
.sort("pair")
.to_pandas()
.sort("lm_type", "pair")
)


Expand Down Expand Up @@ -422,7 +422,7 @@ def stat_model(
"""
num_correct_df_melted = melt_num_correct(
num_correct_df, treatment, control, id_vars=id_vars
)
).to_pandas()

# Fit model
model = create_model(
Expand Down Expand Up @@ -463,44 +463,34 @@ def stat_model(
def num_correct_df_from_predicions(
num_correct_df: pl.DataFrame,
predictions: Sequence[float],
treatment: str,
control: str,
id_vars: Sequence[str] = ("num_test", "pair", "lm_type", "dataset"),
) -> pl.DataFrame:
"""
Returns a dataframe which looks like `num_correct_df` where observations are filled
by `predictions`.
`predictions` is assumed to come from an `InferenceData` xarray. So `predictions` is
melted while `num_correct_df` is not.
from the melted version of `num_correct_df`.
"""
    # Some data we'll need to populate the simulated DF
num_test: int = num_correct_df.select("num_test")[0].item()
datasets = num_correct_df["dataset"].unique(maintain_order=True)
num_subsamples = _num_subsamples(num_correct_df)
lm_types = num_correct_df["lm_type"].unique(maintain_order=True)
# Inverse of melt is pivot
# The number 2 in the code below refers to treatment and control
return (
pl.DataFrame(
{
"pair": np.repeat(np.arange(len(num_correct_df)), 2),
"lm_type": (
lm_types.to_numpy().repeat(datasets.len() * num_subsamples * 2)
),
"dataset": np.tile(
datasets.to_numpy().repeat(num_subsamples * 2), reps=lm_types.len()
),
"method": np.tile(["control", "treatment"], reps=len(num_correct_df)),
"num_correct": predictions,
"num_test": num_test,
}
melt_num_correct(num_correct_df, treatment, control, id_vars=id_vars)
.with_columns(pl.Series(predictions).alias("num_correct"))
.pivot(
values="num_correct",
index=id_vars,
on="method",
)
.pivot(values="num_correct", index=id_vars, columns="method")
.drop("pair")
)


def _marginal_mean_diffs(
num_correct_df: pl.DataFrame, predictions: xr.DataArray
num_correct_df: pl.DataFrame,
predictions: xr.DataArray,
treatment: str,
control: str,
) -> list[float]:
"""
Distribution for the expected accuracy difference between the treatment and control.
Expand All @@ -512,11 +502,11 @@ def _marginal_mean_diffs(
range(predictions.shape[1]), desc=f"Marginalizing each draw (n = {num_test})"
):
num_correct_df_simulated = num_correct_df_from_predicions(
num_correct_df, predictions[:, draw].to_numpy()
num_correct_df, predictions[:, draw].to_numpy(), treatment, control
)
mean_diffs.append(
num_correct_df_simulated.select(
(pl.col("treatment") - pl.col("control")) / num_test
(pl.col(treatment) - pl.col(control)) / num_test
)
.mean()
.item()
Expand All @@ -525,10 +515,16 @@ def _marginal_mean_diffs(


def posterior_marginal_mean_diffs(
summary: az.InferenceData, num_correct_df: pl.DataFrame
summary: az.InferenceData,
num_correct_df: pl.DataFrame,
treatment: str,
control: str,
) -> list[float]:
posterior_predictive = az.extract(summary, group="posterior_predictive")
mean_diffs = _marginal_mean_diffs(
num_correct_df, posterior_predictive["p(num_correct, num_test)"]
num_correct_df,
posterior_predictive["p(num_correct, num_test)"],
treatment,
control,
)
return mean_diffs
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ gcp = [
stat = [
"bambi>=0.13.0",
"jupyter>=1.0.0",
"polars>=0.18.0",
"polars>=1.0.0",
"seaborn>=0.13.0",
"statsmodels>=0.14.0",
]
Expand Down

0 comments on commit abf7c35

Please sign in to comment.