In [9]:
import numpy as np
import polars as pl
import tensorflow as tf
from scipy.stats import gamma
from nn_intreg import IntReg

In [10]:
# set seed for reproducibility of the synthetic data
np.random.seed(500)
# number of simulations for the synthetic data
NSIM = 10000
SHAPE_PAR = 1.6
N_BINS = 30
# define the covariates
intercept_labels = ["intercept"]
x1_labels = ["foo", "bar"]
x2_labels = ["apple", "orange", "lemon"]
# define the relativities
intercept_factors = [16]
x1_factors = [1.0, 1.07]
x2_factors = [1.0, 0.95, 1.02]

# function to collect covariates information
def _helper_pl(cov_name, labels, factors):
    return pl.DataFrame({cov_name: labels, f"factor_{cov_name}": factors})

Simulate covariates

In [11]:
sample = []
# sampling of the labels and the factors
for el in [
    ("intercept", intercept_labels, intercept_factors),
    ("X1", x1_labels, x1_factors),
    ("X2", x2_labels, x2_factors),
]:
    cov_name, labels, factors = el
    _df = _helper_pl(cov_name, labels, factors).sample(NSIM, with_replacement=True)
    sample.append(_df)
sample = pl.concat(sample, how="horizontal")
sample = (
    sample.with_columns(l=pl.concat_list((pl.col("^factor_.*$"))))
    .with_columns(
        factor_total=pl.col("l").list.eval(pl.col("*").product(), parallel=True)
    )
    .explode("factor_total")
    .select(["X1", "X2", "factor_total"])
)
sample.head()

X1,X2,factor_total
str,str,f64
"""bar""","""apple""",17.12
"""foo""","""lemon""",16.32
"""foo""","""orange""",15.2
"""bar""","""apple""",17.12
"""bar""","""lemon""",17.4624


Simulate the response variable

In [12]:

y = gamma.rvs(size=NSIM, a=SHAPE_PAR, scale=sample["factor_total"] / SHAPE_PAR)
sample = (
    sample.with_columns(y=pl.lit(y))
    .with_columns(breaks=pl.col("y").qcut(N_BINS, include_breaks=True))
    .unnest("breaks")
)

intervals = (
    sample["y_bin"]
    .cast(str)
    .str.strip_chars('( ] "')
    .str.split(", ")
    .list.to_struct()
    .struct.rename_fields(["left_break", "right_break"])
    .struct.unnest()
    .cast(pl.Float32)
)

sample = pl.concat([sample, intervals], how="horizontal").select(
    ["X1", "X2", "left_break", "right_break", "factor_total"]
)
sample.head()

X1,X2,left_break,right_break,factor_total
str,str,f32,f32,f64
"""bar""","""apple""",8.848387,9.615126,17.12
"""foo""","""lemon""",13.963438,14.953102,16.32
"""foo""","""orange""",9.615126,10.391552,15.2
"""bar""","""apple""",5.719008,6.461315,17.12
"""bar""","""lemon""",3.34258,4.18456,17.4624


Prepare the design matrix and the response matrix

In [13]:
X_y_train = sample.to_dummies(["X1", "X2"], drop_first=True)
X_train = X_y_train.select(
    pl.all().exclude(["left_break", "right_break", "factor_total"])
).to_numpy()
y_train = X_y_train.select(["left_break", "right_break"]).to_numpy()
y_train = np.clip(y_train, 0.000001, 99999999)
X_y_train.head()

X1_foo,X2_lemon,X2_orange,left_break,right_break,factor_total
u8,u8,u8,f32,f32,f64
0,0,0,8.848387,9.615126,17.12
1,1,0,13.963438,14.953102,16.32
1,0,1,9.615126,10.391552,15.2
0,0,0,5.719008,6.461315,17.12
0,1,0,3.34258,4.18456,17.4624


Build the model

In [14]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.4)
model = IntReg(optimizer)
history = model.fit(X_train, y_train)
scales_hat = model.predict(X_train)
results = (
    sample.with_columns(yhat=pl.lit(scales_hat.flatten()))
    .group_by(["X1", "X2"])
    .agg(truth=pl.col("factor_total").mean(), model=pl.col("yhat").mean())
)
print("Found shape:")
print(np.exp(model.layers[-1].get_vars().numpy()))
print("True shape:")
print(SHAPE_PAR)
print("E(Y|X) check:")
print(results)


Found shape:
1.5860531
True shape:
1.6
E(Y|X) check:
shape: (6, 4)
┌─────┬────────┬─────────┬───────────┐
│ X1  ┆ X2     ┆ truth   ┆ model     │
│ --- ┆ ---    ┆ ---     ┆ ---       │
│ str ┆ str    ┆ f64     ┆ f32       │
╞═════╪════════╪═════════╪═══════════╡
│ bar ┆ orange ┆ 16.264  ┆ 16.209957 │
│ bar ┆ lemon  ┆ 17.4624 ┆ 17.235085 │
│ foo ┆ orange ┆ 15.2    ┆ 15.264956 │
│ bar ┆ apple  ┆ 17.12   ┆ 16.896196 │
│ foo ┆ lemon  ┆ 16.32   ┆ 16.230455 │
│ foo ┆ apple  ┆ 16.0    ┆ 15.91115  │
└─────┴────────┴─────────┴───────────┘
