In [186]:
import polars as pl
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
import jax as jax
import matplotlib.pyplot as plt

# Fixed seed so the random test predictions below are identical on every
# Restart & Run All; change SEED to draw a different sample.
SEED = 42
rng = np.random.default_rng(SEED)

In [2]:
# Frequency file, read as-is.
df_freq = pl.read_csv("./data/insurance/freMTPL2freq.csv")

# Severity file, collapsed to one row per IDpol by summing the remaining
# columns within each group.
df_sev = pl.read_csv(
    "./data/insurance/freMTPL2sev.csv", infer_schema_length=25000
).group_by("IDpol").sum()

# The left join keeps every row of df_freq. Rows without a match in df_sev get
# a null ClaimAmount, which is counted as 0 when forming
# PurePremium = ClaimAmount / Exposure.
df = (
    df_freq
    .join(df_sev, on="IDpol", how="left", coalesce=True)
    .with_columns(
        (pl.col("ClaimAmount").fill_null(0) / pl.col("Exposure")).alias("PurePremium")
    )
)

In [None]:
@jax.jit
def log_norm_pdf(x, mu, sigma):
    """Density of a log-normal distribution evaluated at ``x``.

    Args:
        x: Point at which to evaluate the density.
        mu: Mean of ``log(x)``.
        sigma: Standard deviation of ``log(x)``.

    Returns:
        ``exp(-(log(x) - mu)^2 / (2 sigma^2)) / (x * sigma * sqrt(2 pi))``.
    """
    normaliser = x * sigma * jnp.sqrt(2 * jnp.pi)
    squared_dev = jnp.square(jnp.log(x) - mu)
    kernel = jnp.exp(-squared_dev / (2 * jnp.square(sigma)))
    return 1 / normaliser * kernel

@jax.jit
def hurdle_log_prob(x, params):
    """Negative log-likelihood of one observation under a hurdle log-normal model.

    The model puts probability ``1 - p`` on ``x == 0`` and spreads the
    remaining probability ``p`` over non-zero values with a log-normal density.

    Args:
        x: Observed value; exactly 0 selects the zero branch.
        params: Unconstrained parameters ``[logit(p), mu, log(sigma)]``.

    Returns:
        The negative log-probability (negated so that it can be minimized).
    """
    logit_p = params[0]
    mu = params[1]
    log_sigma = params[2]

    is_zero = x == 0
    # Replace zeros with a harmless placeholder before taking the log, so the
    # branch discarded by the final `where` never yields NaN/inf values or
    # gradients (reverse-mode differentiation would otherwise propagate them).
    safe_x = jnp.where(is_zero, 1.0, x)
    log_x = jnp.log(safe_x)

    # Log-normal log-density computed directly in log space. Evaluating the
    # density first and then taking its log underflows to log(0) = -inf when x
    # is far from the mode.
    z = (log_x - mu) * jnp.exp(-log_sigma)
    log_density = (
        -log_x - log_sigma - 0.5 * jnp.log(2 * jnp.pi) - 0.5 * jnp.square(z)
    )

    # log(p) = log_sigmoid(logit) and log(1 - p) = log_sigmoid(-logit); both
    # stay finite when the sigmoid saturates.
    log_p = jax.nn.log_sigmoid(logit_p)
    log_not_p = jax.nn.log_sigmoid(-logit_p)

    # Returning negative log-prob for minimization
    return -jnp.where(is_zero, log_not_p, log_p + log_density)

# Forward-mode first and second derivatives of the loss with respect to its
# second argument (`params`).
hlp_jac_func = jax.jacfwd(hurdle_log_prob, argnums=1)
hlp_hess_func = jax.jacfwd(hlp_jac_func, argnums=1)

@jax.jit
def hlp_grad(x, params):
    """Gradient of the hurdle loss with respect to ``params``, with NaNs zeroed.

    A zero observation carries no gradient information for the log-normal part
    of the model, so any NaN entries are replaced by 0. Only ``nan`` is
    overridden here; infinities follow ``jnp.nan_to_num``'s defaults.
    """
    raw_grad = hlp_jac_func(x, params)
    return jnp.nan_to_num(raw_grad, nan=0.0)

@jax.jit
def hlp_hess(x, params):
    """Hessian of the hurdle loss with respect to ``params``, with NaNs set to 0."""
    raw_hess = hlp_hess_func(x, params)
    return jnp.nan_to_num(raw_hess, nan=0.0)

# Xgboost emits one prediction vector per input row, i.e. an [n, m] matrix,
# while the log-prob function takes a single parameter vector of shape [m,].
# Map over the leading axis of both the observations and the predictions.
vector_grad = jax.vmap(hlp_grad, in_axes=(0, 0), out_axes=0)  # [n,] x [n, m] -> [n, m]
vector_hess = jax.vmap(hlp_hess, in_axes=(0, 0), out_axes=0)  # [n,] x [n, m] -> [n, m, m]

@jax.jit
def xg_obj(y_true, y_pred):
    """Gradient and reduced hessian of the hurdle loss for every row.

    Args:
        y_true: Observed values, shape [n,].
        y_pred: Predicted parameter vectors, shape [n, m].

    Returns:
        ``(grad, hess)``, both of shape [n, m]. ``hess`` is the sum of absolute
        values over the last axis of the full [n, m, m] hessian.
    """
    full_hessian = vector_hess(y_true, y_pred)
    # Xgboost needs the hessian in [n, m] shape, so collapse the last axis
    reduced_hessian = jnp.sum(jnp.abs(full_hessian), axis=2)
    return vector_grad(y_true, y_pred), reduced_hessian

In [196]:
# Observed targets as a NumPy vector
y_true = df["PurePremium"].to_numpy()
# Random stand-in predictions laid out the way xgboost produces them: one row
# per observation and one column per model parameter (3 in total)
y_pred = rng.random(size=(len(y_true), 3))

In [198]:
# Smoke-test the objective on the random predictions: `test` receives the
# gradients and `test2` the reduced hessian returned by xg_obj.
test, test2 = xg_obj(y_true, y_pred)