### High-dimensional synthetic data

- C: Vector of {c_dim} {dice_size}-sided dice rolls.
- A: Flip 1 + {dice_size} - median(C) coins. A is 1 if at least one flip comes up heads.
- Y: Flip f(C) + A coins and write down the number of heads.

In [49]:
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
import time

In [2]:
def get_smf_model_a_param(ols, df):
    """
    Fit a model with statsmodels
    Return the parameter corresponding to the treatment
    """
    return smf.ols(ols, data=df).fit().params['a']

In [43]:
def observed(n=100, c_dim=6, dice_size=2, power=1):
    """
    The observed data distribution
      C: roll {c_dim} {dice_size}-sided dice and record the results
      A: flip `1 + dice_size - np.median(C)` fair coins
          and record 1 if at least one flip lands heads
      
      Y: flip `C + A` fair coins, and record the number of heads
    """

    c_powers = np.ones([c_dim * power])
    c_coefs = np.array([(-1) ** i * j * 1.3 ** i
                        for i in range(1, 1 + c_dim)
                        for j in range(1, 1 + power)])

    # what's the smallest c ** pow @ c_coefs could be? Subtract that off.
    worst_roll = np.where(c_coefs > 0, 1, dice_size)
    y_min_dice = np.power(worst_roll, c_powers) @ c_coefs

    c = np.random.randint(1, 1 + dice_size, (n, c_dim))
    c_median = np.median(c, axis=1).astype(int)
    a = np.random.binomial(n=1 + dice_size - c_median, p=0.5, size=n)
    a = (a > 0).astype(np.int32)

    c = np.tile(c, power)
    y_n_dice = np.ceil(-y_min_dice + a + c @ c_coefs).astype(int)
    y = np.random.binomial(n=y_n_dice, p=0.5)

    columns = {"a": a, "y": y}
    c_col_names = [f"c{i}_{j}" for i in range(1, 1 + c_dim) for j in range(1, 1 + power)]
    c_cols = [col.reshape(-1) for col in np.array_split(c, c_dim * power, axis=1)]
    columns.update(dict(zip(c_col_names, c_cols)))
    df = pd.DataFrame(data=columns)

    return df

In [45]:
def experiment(estimator="ols", n=100, c_dim=6,
               repeats=1, power=1, dice_size=2,
               ground_truth=None, prec=(3,0)):

    # c_col_names = [f"c{i}_{j}" for i in range(1, 1 + c_dim) for j in range(1, 1 + power)]
    c_col_names = [f"c{i}_1" for i in range(1, 1 + c_dim)]
    results = []
    np.random.seed(42)
    for i in range(repeats):
        df = observed(n=n, c_dim=c_dim, power=power, dice_size=dice_size)

        if estimator == "ols":
            ols = "y ~ a + " + " + ".join(c_col_names)
            results.append(get_smf_model_a_param(ols, df))

        elif estimator == "count":
            total = 0
            denominator = 0
            unique_c, counts = np.unique(
                df[c_col_names], axis=0, return_counts=True)

            for uniq, count in zip(unique_c, counts):
                if count == 1: continue
                subdf = df[(df[c_col_names] == uniq).all(axis=1)]
                if np.unique(subdf["a"]).shape[0] == 1: continue
                e_y_a1 = subdf[subdf["a"] == 1]["y"].mean()
                e_y_a0 = subdf[subdf["a"] == 0]["y"].mean()
                total += count * (e_y_a1 - e_y_a0)
                denominator += count

            if denominator == 0:
                results.append(np.nan)
            else:
                results.append(total / denominator)

    if ground_truth is not None:
        results = [np.abs(r - ground_truth) for r in results]
    err = ""
    prec_mean, prec_std = prec
    if repeats > 1 and prec_std >= 0:
        err = f" ± {np.std(results):.{prec_std}f}"
    return f"{np.mean(results):.{prec_mean}f}{err}"

In [41]:
default_kwargs = dict(
  n=100,
  c_dim=2,
  dice_size=8,
  repeats=10,
  power=1,
  prec=(2, -1),
  ground_truth=0.5,
)

In [46]:
kwargs = default_kwargs.copy()

for est in ["ols", "count"]:
    kwargs["estimator"] = est
    print(est, end=" ")
    print(experiment(**kwargs))

ols 0.30
count 0.65


In [47]:
kwargs = default_kwargs.copy()
kwargs["power"] = 3

for est in ["ols", "count"]:
    kwargs["estimator"] = est
    print(est, end=" ")
    print(experiment(**kwargs))

ols 0.93
count 1.43


In [48]:
kwargs = default_kwargs.copy()
kwargs["prec"] = (2, -1)

col_header_width = 9
cell_width = 8
powers = [1, 2, 4, 6]
n_values = [100, 1000, 10000]
c_dims = [1, 2, 4, 8, 16]

for power in powers:
    for est in ["ols", "count"]:
    
        print(f"Power: {power}; Estimator: {est}")
        header = [" " * col_header_width]
        header += [f"{c:^{cell_width}}" for c in c_dims]
        print(" ".join(header))

        for n in n_values:
            runtime = time.time()
            row = [f"n={n:<6d}: "]
            for c_dim in c_dims:
                kwargs.update(dict(
                    n=n,
                    c_dim=c_dim,
                    estimator=est,
                    power=power,
                ))
                result = experiment(**kwargs)
                row.append(f"{result:{cell_width}s}")
            row.append(f"     in {time.time() - runtime:.1f}s")
            print(" ".join(row))

        print()

Power: 1; Estimator: ols
             1        2        4        8        16   
n=100   :  0.25     0.30     1.23     2.18     7.01          in 0.2s
n=1000  :  0.06     0.15     0.32     0.63     1.71          in 0.2s
n=10000 :  0.04     0.05     0.06     0.23     0.52          in 0.6s

Power: 1; Estimator: count
             1        2        4        8        16   
n=100   :  0.35     0.65     nan      nan      nan           in 0.2s
n=1000  :  0.15     0.16     0.48     nan      nan           in 0.6s
n=10000 :  0.10     0.07     0.09     nan      nan           in 11.6s

Power: 2; Estimator: ols
             1        2        4        8        16   
n=100   :  0.35     0.95     4.18     15.38    103.46        in 0.3s
n=1000  :  0.08     0.65     1.84     2.41     34.69         in 0.3s
n=10000 :  0.04     0.70     1.43     3.13     22.98         in 0.6s

Power: 2; Estimator: count
             1        2        4        8        16   
n=100   :  0.49     1.22     5.13     nan      nan 

Defaults:
```
default_kwargs = dict(
  n=100,
  repeats=10,
  prec=(2, -1),
  ground_truth=0.5,
)
```

## Power: 1; Estimator: ols
```
c_dim        1        2        4        8        16   
n=100   :  0.25     0.30     1.23     2.18     7.01          in 0.2s
n=1000  :  0.06     0.15     0.32     0.63     1.71          in 0.2s
n=10000 :  0.04     0.05     0.06     0.23     0.52          in 0.6s
```
## Power: 1; Estimator: count
```
c_dim        1        2        4        8        16   
n=100   :  0.35     0.65     nan      nan      nan           in 0.2s
n=1000  :  0.15     0.16     0.48     nan      nan           in 0.6s
n=10000 :  0.10     0.07     0.09     nan      nan           in 11.6s
```

## Power: 2; Estimator: ols
```
c_dim        1        2        4        8        16   
n=100   :  0.35     0.95     4.18     15.38    103.46        in 0.3s
n=1000  :  0.08     0.65     1.84     2.41     34.69         in 0.3s
n=10000 :  0.04     0.70     1.43     3.13     22.98         in 0.6s
```

## Power: 2; Estimator: count
```
c_dim        1        2        4        8        16   
n=100   :  0.49     1.22     5.13     nan      nan           in 0.3s
n=1000  :  0.27     0.58     2.02     13.87    nan           in 0.9s
n=10000 :  0.19     0.67     1.48     3.46     nan           in 12.9s
```

## Power: 4; Estimator: ols
```
c_dim        1        2        4        8        16   
n=100   :  0.45     3.73     3.49     58.05    292.20        in 0.3s
n=1000  :  0.17     2.35     5.75     17.23    158.10        in 0.3s
n=10000 :  0.07     2.13     5.30     10.35    76.04         in 0.6s
```

## Power: 4; Estimator: count
```
c_dim        1        2        4        8        16   
n=100   :  0.65     4.00     3.58     103.24   nan           in 0.3s
n=1000  :  0.55     2.20     5.91     16.74    367.28        in 1.1s
n=10000 :  0.43     2.01     5.59     9.46     86.98         in 13.6s
```

## Power: 6; Estimator: ols
```
c_dim        1        2        4        8        16   
n=100   :  0.93     4.52     13.14    23.70    284.06        in 0.2s
n=1000  :  0.26     4.47     7.86     16.48    72.39         in 0.3s
n=10000 :  0.08     4.31     6.77     17.07    102.84        in 0.7s
```

## Power: 6; Estimator: count
```
c_dim        1        2        4        8        16   
n=100   :  1.87     6.15     22.02    nan      nan           in 0.3s
n=1000  :  0.69     3.56     7.72     26.42    nan           in 1.1s
n=10000 :  0.57     4.28     6.39     14.88    nan           in 12.8s
```