# Data Generation

In [1]:
import numpy as np

rng = np.random.default_rng()

In [2]:
def generate_dense_gamma(n, m, intercept=1, k=3):
    beta_star = rng.standard_normal(m)
    X = rng.standard_normal((n, m))
    mu = np.exp(intercept + X @ beta_star)
    y = rng.gamma(k, mu / k, n)
    return X, y, beta_star


def generate_sparse_gamma(n, m, intercept=1, k=3, sparse=0.20):
    beta_star = rng.standard_normal(m)
    beta_star[rng.integers(0, m, int(sparse * m))] = 0
    X = rng.standard_normal((n, m))
    mu = np.exp(intercept + X @ beta_star)
    y = rng.gamma(k, mu / k, n)
    return X, y, beta_star

In [650]:
n = 1000
m = 10
X, y, beta_star = generate_sparse_gamma(n, m)

# JAX Optimization

In [647]:
import jax.numpy as jnp
import jax
import optax
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.validation import check_is_fitted

In [658]:
@jax.jit
def gamma_deviance(y_hat, y):
    return 2 * (jnp.log(y_hat / y) + y / y_hat - 1)


@jax.jit
def coef_gamma_deviance(params, X, y):
    intercept = params[0]
    beta = params[1:]
    y_hat = jnp.exp(intercept + X @ beta)
    return gamma_deviance(y_hat, y).mean()


@jax.jit
def lasso_penalty(beta):
    return jnp.abs(beta).sum()


@jax.jit
def ridge_penalty(beta):
    return 1 / 2 * jnp.sum(beta**2)


@jax.jit
def gamma_reg(params, X, y, lam, alpha):
    intercept = params[0]
    beta = params[1:]
    y_hat = jnp.exp(intercept + X @ beta)
    return gamma_deviance(y_hat, y).mean() + lam * (
        alpha * lasso_penalty(beta) + (1 - alpha) * ridge_penalty(beta)
    )


class GammaRegressor(RegressorMixin, BaseEstimator):
    def __init__(self, lam=0, alpha=0.5) -> None:
        self.lam = lam
        self.alpha = alpha

    def fit(self, X, y):
        X, y = self._validate_data(X, y)

        X_jax = jnp.asarray(X)
        y_jax = jnp.asarray(y)
        params = jnp.zeros(X.shape[1] + 1)

        def obj(params):
            return gamma_reg(params, X_jax, y_jax, self.lam, self.alpha)

        solver = optax.lbfgs()
        opt_state = solver.init(params)
        value_and_grad = optax.value_and_grad_from_state(obj)

        old_value = obj(params)
        no_prog_steps = 0
        while True:
            value, grad = value_and_grad(params, state=opt_state)
            print(f"Objective Value: {value}")

            if jnp.abs(grad).sum() < 1e-5:
                print("Gradient Norm Stop")
                break

            updates, opt_state = solver.update(
                grad / jnp.linalg.norm(grad),
                opt_state,
                params,
                value=value,
                grad=grad,
                value_fn=obj,
            )
            params = optax.apply_updates(params, updates)

            new_value = obj(params)
            if jnp.abs((new_value - old_value) / old_value) < 1e-3:
                no_prog_steps += 1
            if no_prog_steps > 5:
                print("No objective progress stop")
                break
            old_value = new_value

        self.intercept_ = params[0].item()
        self.coef_ = np.asarray(params[1:], dtype=np.float64)

    def predict(self, X):
        check_is_fitted(self)
        return np.exp(self.intercept_ + X @ self.coef_)

    def score(self, X, y):
        return (
            gamma_deviance(jnp.asarray(self.predict(X)), jnp.asarray(y)).mean().item()
        )

In [659]:
model = GammaRegressor(lam=0)
model.fit(X, y)

Objective Value: 39.902191162109375
Objective Value: 4.674535274505615
Objective Value: 1.949567198753357
Objective Value: 1.6746418476104736
Objective Value: 1.6747959852218628
Objective Value: 1.5127439498901367
Objective Value: 1.2911567687988281
Objective Value: 1.1280925273895264
Objective Value: 1.010317325592041
Objective Value: 0.9427123069763184
Objective Value: 0.8438864946365356
Objective Value: 0.7746723294258118
Objective Value: 0.7185803055763245
Objective Value: 0.6727139949798584
Objective Value: 0.6354771852493286
Objective Value: 0.5853992104530334
Objective Value: 0.5489557385444641
Objective Value: 0.49130696058273315
Objective Value: 0.4322171211242676
Objective Value: 0.4030422568321228
Objective Value: 0.3894992172718048
Objective Value: 0.38036978244781494
Objective Value: 0.3748360276222229
Objective Value: 0.37274113297462463
Objective Value: 0.37199628353118896
Objective Value: 0.3716709017753601
Objective Value: 0.37158092856407166
Objective Value: 0.3714780

In [660]:
model.intercept_

0.9785037040710449

In [663]:
model.coef_.round(2)

array([-1.14, -0.  , -0.97, -0.65, -0.34,  0.02,  0.18, -0.29,  0.94,
        0.81])

In [664]:
beta_star.round(2)

array([-1.12,  0.  , -0.99, -0.67, -0.31,  0.  ,  0.19, -0.3 ,  0.94,
        0.82])