In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

# Apply seaborn's "whitegrid" theme globally so the plots below share one look.
sns.set_theme(style="whitegrid")

In [None]:
# Locate the experiment CSV relative to the notebook directory.
path = Path('../data/ab_test_experiment.csv')
if path.exists():
    df = pd.read_csv(path)
else:
    # Fail fast with an actionable hint instead of a bare read error.
    raise FileNotFoundError(f"Missing {path}. Run: python ../scripts/generate_demo_datasets.py")
# Preview the first rows as the cell output.
df.head()

## SRM check
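A sample ratio mismatch (SRM) — observed group sizes that deviate significantly from the planned 50/50 split — signals that assignment or logging is broken and can invalidate the results.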

In [None]:
# Number of rows assigned to each variant (most frequent first).
counts = df.loc[:, 'variant'].value_counts()
counts

In [None]:
# Chi-square goodness-of-fit against 50/50
# fill_value=0 keeps an arm that is entirely absent from the data as a zero
# count, so the test reports a vanishing p-value (a blatant SRM) instead of NaN.
obs = counts.reindex(['control','treatment'], fill_value=0).values
# Expected counts under a perfect 50/50 split of the observed total.
exp = np.full(obs.shape, obs.sum() / 2)
chi2, p_srm = stats.chisquare(f_obs=obs, f_exp=exp)
print('SRM p-value:', p_srm)

## Conversion uplift
We estimate the difference in conversion rates (treatment − control), compute a 95% Wald confidence interval for it, and test its significance with a pooled two-proportion z-test.

In [None]:
# One row per variant: non-null user_id count, mean of `converted`, mean of `revenue`.
summary = (
    df.groupby('variant')
      .agg(
          n=pd.NamedAgg(column='user_id', aggfunc='count'),
          conv=pd.NamedAgg(column='converted', aggfunc='mean'),
          rpu=pd.NamedAgg(column='revenue', aggfunc='mean'),
      )
)
summary

In [None]:
# Per-arm rows of the summary table (conversion rate and sample size).
c = summary.loc['control']
t = summary.loc['treatment']

p1, n1 = c['conv'], c['n']
p2, n2 = t['conv'], t['n']

# Absolute uplift (treatment minus control) and its unpooled (Wald) standard error.
diff = p2 - p1
se = np.sqrt(p1*(1-p1)/n1 + p2*(1-p2)/n2)

# Two-sided critical value derived from the confidence level rather than a
# hard-coded 1.96, so the interval and its printed label cannot drift apart.
conf_level = 0.95
z = stats.norm.ppf(0.5 + conf_level/2)
ci = (diff - z*se, diff + z*se)

# Two-proportion z-test
p_pool = (p1*n1 + p2*n2)/(n1+n2)
se_pool = np.sqrt(p_pool*(1-p_pool)*(1/n1 + 1/n2))
# se_pool is exactly zero only when the pooled rate is 0 or 1, i.e. both arms
# have the same degenerate rate; report z = 0 instead of dividing 0 by 0.
z_stat = 0.0 if se_pool == 0 else diff / se_pool
# The survival function keeps precision in the far tail, where 1 - cdf(x)
# rounds to exactly 0 for large |z|.
p_val = 2*stats.norm.sf(abs(z_stat))

print(f'Control conv: {p1:.3%}  Treatment conv: {p2:.3%}')
print(f'Uplift (pp): {diff*100:.2f}  {conf_level:.0%} CI: [{ci[0]*100:.2f}, {ci[1]*100:.2f}]')
print('p-value:', p_val)

## Revenue per user (secondary metric)
A quick comparison of average revenue per user (including zeros for non-converters).

In [None]:
# Bar chart of mean revenue per user for each variant, drawn on an explicit Axes.
fig, ax = plt.subplots(figsize=(6,4))
sns.barplot(data=summary.reset_index(), x='variant', y='rpu', ax=ax)
ax.set_title('Revenue per user')
ax.set_ylabel('RPU')
plt.show()

## Segment check (device)
Useful for communicating *where* impact is concentrated.

In [None]:
# Mean of `converted` for every (device, variant) pair, flattened back to columns.
seg = (
    df.groupby(['device','variant'])['converted']
      .mean()
      .reset_index()
)

# Grouped bars: one cluster per device, one bar per variant.
fig, ax = plt.subplots(figsize=(7,4))
sns.barplot(data=seg, x='device', y='converted', hue='variant', ax=ax)
ax.set_title('Conversion rate by device')
ax.set_ylabel('Conversion rate')
plt.show()