## Simpson's paradox

The goal is a Simpson's paradox data generator.

In [None]:
import os
from pathlib import Path
from functools import partial
from typing import Tuple

import pytest

import numpy as np
import pandas as pd
import xarray as xr

from fake_data_for_learning.simpson import compute_margin, transform_data_array_component

In [None]:
# Counts from Table 1.1 of Pearl, Primer, laid out as
# gender x drug_taken x recovery. Every cell starts as NaN so that an
# entry left unassigned stands out.

init_data = np.full((2, 2, 2), np.nan)
data = xr.DataArray(
    init_data,
    dims=('gender', 'drug_taken', 'recovery'),
    coords={
        'gender': ['male', 'female'],
        'drug_taken': [0, 1],
        'recovery': [0, 1],
    },
)

# Recovered counts (recovery=1).
# Order switched from book, as no recovery <-> 0
data.loc[:, :, 1] = [[234, 81], [55, 192]]

## Spot check
# Male, drug, recovered
assert data.sel(gender='male', drug_taken=1, recovery=1).item() == 81
# Female, no drug, recovered
assert data.sel(gender='female', drug_taken=0, recovery=1).item() == 55

# Non-recovered counts (recovery=0).
data.loc[:, :, 0] = [[36, 6], [25, 71]]

## Check sums
# Summing out recovery must give the size of each (gender, drug_taken) group.
np.testing.assert_array_equal(
    data.sum(dim='recovery').values,
    [[270, 87], [80, 263]],
)

## Recovery rates

In [None]:
# Recovery rate within each (gender, drug_taken) group:
# recovered count divided by the group's total count.
recovery_by_gender = data.loc[dict(recovery=1)] / data.sum('recovery')
recovery_by_gender.values

In [None]:
# Per gender: is the recovery rate with the drug above the rate without it?
print("Recovery higher if taking drug by gender:")
(
    recovery_by_gender.sel(drug_taken=0) < recovery_by_gender.sel(drug_taken=1)
).values

In [None]:
# Pool both genders: recovered count per drug_taken group divided by the
# total count of that group.
recovery = (
    data.sum(dim='gender').sel(recovery=1)
    / data.sum(dim=['gender', 'recovery'])
)
recovery.values

In [None]:
# With gender summed out: is the recovery rate with the drug above the
# rate without it?
print("Recovery higher if taking drug project out gender:")
(
    recovery.sel(drug_taken=0) < recovery.sel(drug_taken=1)
).values

## Generating examples

Starting from a contingency table on the Simpson's boundary, generate examples that move further from the boundary.

In [None]:
# Counts indexed as [recovered][gender][treated]; each inner pair lists
# the treated=0 count followed by the treated=1 count.
contingency_table = xr.DataArray(
    data=[
        # recovered = 0: gender 0, then gender 1
        [[6, 8], [3, 1]],
        # recovered = 1: gender 0, then gender 1
        [[4, 22], [27, 9]],
    ],
    dims=("recovered", "gender", "treated"),
    coords={"recovered": [0, 1], "gender": [0, 1], "treated": [0, 1]},
)


# Convenience functions
def treated_gender_recovery_relative_risk(
    contingency_table: xr.DataArray, gender: int
) -> float:
    """Relative risk of recovery, treated vs. untreated, within one gender.

    Writing u_1gt for the count at (recovered=1, gender=g, treated=t) and
    u_+gt for the margin that ``compute_margin`` returns for the selection
    (gender=g, treated=t), the result is
    ``u_1g1 * u_+g0 / (u_1g0 * u_+g1)``.

    Args:
        contingency_table: counts with dimensions recovered, gender, treated.
        gender: index of the gender to evaluate.

    Returns:
        The relative risk for the requested gender.
    """

    def recovered_count(treated: int) -> float:
        # Single recovered cell for this gender and treatment status.
        cell = dict(recovered=1, gender=gender, treated=treated)
        return float(contingency_table[cell])

    def group_margin(treated: int):
        # Margin of the (gender, treated) group as given by compute_margin.
        return compute_margin(
            contingency_table, dict(gender=gender, treated=treated)
        )

    numerator = recovered_count(1) * group_margin(0)
    denominator = recovered_count(0) * group_margin(1)
    return numerator / denominator

def treated_recovery_relative_risk(contingency_table: xr.DataArray) -> float:
    """Relative risk of recovery, treated vs. untreated, with gender pooled.

    Writing u_1+t for the margin that ``compute_margin`` returns for
    (recovered=1, treated=t) and u_++t for the one it returns for
    (treated=t), the result is ``u_1+1 * u_++0 / (u_1+0 * u_++1)``.

    Args:
        contingency_table: counts with dimensions recovered, gender, treated.

    Returns:
        The relative risk over the whole population.
    """

    def margin(**selection):
        # Thin wrapper so each margin reads as a keyword selection.
        return compute_margin(contingency_table, selection)

    recovered_treated = margin(recovered=1, treated=1)
    untreated_total = margin(treated=0)
    recovered_untreated = margin(recovered=1, treated=0)
    treated_total = margin(treated=1)

    return (
        recovered_treated * untreated_total
        / (recovered_untreated * treated_total)
    )

# Relative risks of the starting table: one per gender, then pooled.
rr_0, rr_1 = (
    treated_gender_recovery_relative_risk(contingency_table, gender=g)
    for g in (0, 1)
)
rr = treated_recovery_relative_risk(contingency_table)

# Compare against values worked out by hand from the table's counts and
# margins; the pooled relative risk equals 1.
assert rr_0 == pytest.approx(22 * 10 / (4 * 30))
assert rr_1 == pytest.approx(9 * 30 / (27 * 10))
assert rr == pytest.approx(31 * 40 / (31 * 40))

print(f'Treated female recovery relative risk: {rr_0}')
print(f'Treated male recovery relative risk: {rr_1}')
print(f'Treated recovery relative risk: {rr}')


### Transform counts to make Simpson's paradox more extreme

Note: we aren't strictly in a case of Simpson's paradox, as a trend has not reversed by going from population to sub-population.

Now we transform the counts to increase the female recovery relative risk, while keeping the male and total population relative risks the same.

In [None]:
def translate_component_by(
    a_data_array: xr.DataArray, component: dict, by: float
) -> float:
    """Return the entry of ``a_data_array`` at ``component`` shifted by ``by``.

    Args:
        a_data_array: array to read the entry from.
        component: index passed to ``a_data_array[...]`` to pick one entry.
        by: amount added to the selected entry.

    Returns:
        The shifted entry converted to a plain ``float``.
    """
    shifted_entry = a_data_array[component] + by
    return float(shifted_entry)


# Amounts added to the treated-and-recovered counts u_101 and u_111.
# (Alternative setting: u_101_translation = 0, u_111_translation = 1.)
u_101_translation = 2
u_111_translation = 1


# One entry per translated cell: where to translate and by how much.
translations = [
    {
        'component': {'recovered': 1, 'gender': gender, 'treated': 1},
        'by': shift,
    }
    for gender, shift in ((0, u_101_translation), (1, u_111_translation))
]


# Evaluate each translation against the original table and store the
# outcome under 'res'; 'by' is dropped once it has been consumed.
for translation_dict in translations:
    component, by = translation_dict['component'], translation_dict['by']
    res = transform_data_array_component(
        contingency_table,
        component_function=partial(
            translate_component_by, component=component, by=by
        ),
    )
    translation_dict['res'] = res
    _ = translation_dict.pop('by')

def u110_substitution(a_data_array: xr.DataArray, offset: float) -> float:
    """Return the value to substitute for the u_110 cell.

    Computes ``(u_111 + offset) * u_+10 / u_+11``, where u_111 is the cell
    at (recovered=1, gender=1, treated=1) and u_+1t is the margin that
    ``compute_margin`` returns for (gender=1, treated=t). With the margins
    held fixed, this value puts the gender=1 relative risk at 1 for the
    shifted u_111.

    Args:
        a_data_array: counts with dimensions recovered, gender, treated.
        offset: amount added to u_111.

    Returns:
        The substituted count as a ``float``.
    """
    shifted_u111 = a_data_array[dict(recovered=1, gender=1, treated=1)] + offset

    untreated_margin = compute_margin(a_data_array, dict(gender=1, treated=0))
    treated_margin = compute_margin(a_data_array, dict(gender=1, treated=1))
    margin_ratio = untreated_margin / treated_margin

    return float(shifted_u111 * margin_ratio)

def u100_substitution(
    a_data_array: xr.DataArray, offset: Tuple[float, float]
) -> float:
    """Return the value to substitute for the u_100 cell.

    Computes
    ``-(u_111 + offset[1]) * u_+10 / u_+11
    + (u_101 + u_111 + offset[0] + offset[1]) * u_++0 / u_++1``,
    where u_1g1 is the cell at (recovered=1, gender=g, treated=1) and the
    margins are the values ``compute_margin`` returns for (gender=1,
    treated=t) and (treated=t). The first term is the negated u_110
    substitution; the second is the pooled recovered count among the
    untreated that puts the overall relative risk at 1.

    Args:
        a_data_array: counts with dimensions recovered, gender, treated.
        offset: amounts added to u_101 and u_111, in that order.

    Returns:
        The substituted count as a ``float``.
    """
    u101_shift, u111_shift = offset[0], offset[1]
    u_101 = a_data_array[dict(recovered=1, gender=0, treated=1)]
    u_111 = a_data_array[dict(recovered=1, gender=1, treated=1)]

    # Gender-1 part: negated shifted u_111 scaled by that gender's margins.
    negated_u111 = -(u_111 + u111_shift)
    gender1_ratio = (
        compute_margin(a_data_array, dict(gender=1, treated=0))
        / compute_margin(a_data_array, dict(gender=1, treated=1))
    )

    # Pooled part: shifted treated-and-recovered total scaled by the
    # treated=0 / treated=1 margins.
    shifted_treated_recovered = u_101 + u_111 + u101_shift + u111_shift
    pooled_ratio = (
        compute_margin(a_data_array, dict(treated=0))
        / compute_margin(a_data_array, dict(treated=1))
    )

    return float(
        negated_u111 * gender1_ratio + shifted_treated_recovered * pooled_ratio
    )

# Cells whose new value is derived from the translated cells: which cell,
# which function computes it, and the translation amount(s) it needs.
substitutions = [
    {
        'component': {'recovered': 1, 'gender': 1, 'treated': 0},
        'substitution_fn': u110_substitution,
        'offset': u_111_translation,
    },
    {
        'component': {'recovered': 1, 'gender': 0, 'treated': 0},
        'substitution_fn': u100_substitution,
        'offset': (u_101_translation, u_111_translation),
    },
]

# Evaluate every substitution against the original table and store the
# outcome under 'res'; 'offset' is dropped once it has been consumed.
for substitution_dict in substitutions:
    component = substitution_dict['component']
    offset = substitution_dict['offset']
    substitution_fn = substitution_dict['substitution_fn']

    res = substitution_fn(contingency_table, offset=offset)
    substitution_dict['res'] = res
    del substitution_dict['offset']


# Work on a copy so the original table stays available for comparison.
# NOTE(review): contingency_table is built from integer literals, so this
# copy presumably has an integer dtype and float results assigned below
# would be cast to integers -- confirm this is intended.
new_contingency_table = contingency_table.copy()

transformations = translations + substitutions

for transformation_dict in transformations:
    component = transformation_dict['component']
    new_contingency_table[component] = transformation_dict['res']

    # Adjust also non-recovered values to keep marginals u_+ij constant
    assert component['recovered'] == 1  # By choice of transformation
    nonrecovered_twin_component = component.copy()
    nonrecovered_twin_component['recovered'] = 0
    # The margin is taken from the ORIGINAL table, so each
    # (gender, treated) group keeps its original size.
    marginal = compute_margin(
        contingency_table,
        non_margin_sel=dict(
            gender=nonrecovered_twin_component['gender'],
            treated=nonrecovered_twin_component['treated'],
        ),
    )
    new_contingency_table[nonrecovered_twin_component] = (
        marginal - new_contingency_table[component]
    )

new_rr_0 = treated_gender_recovery_relative_risk(new_contingency_table, gender=0)
new_rr_1 = treated_gender_recovery_relative_risk(new_contingency_table, gender=1)
new_rr = treated_recovery_relative_risk(new_contingency_table)

# Tests

# Assert all counts are still non-negative
assert np.all((new_contingency_table >= 0).values)

# Assert total population has not changed
assert contingency_table.sum() == new_contingency_table.sum()

# Assert female relative risk of recovery has increased
assert new_rr_0 > rr_0

# Assert other relative risks unchanged. These are floating-point ratios,
# so compare approximately, as the checks on the starting table do.
assert float(new_rr_1) == pytest.approx(float(rr_1))
assert float(new_rr) == pytest.approx(float(rr))

print(f'Treated female recovery relative risk: {new_rr_0}')
print(f'Treated male recovery relative risk: {new_rr_1}')
print(f'Treated recovery relative risk: {new_rr}')