
JAXBoost


JAX autodiff for XGBoost/LightGBM objectives.

Write a loss function, get gradients and Hessians automatically. No manual derivation needed.

Works with XGBoost and LightGBM.

Install

pip install jaxboost

Quick Start

XGBoost

import xgboost as xgb
import jax.numpy as jnp
from jaxboost import auto_objective, focal_loss, huber, quantile

# Prepare your data
dtrain = xgb.DMatrix(X_train, label=y_train)
params = {"max_depth": 4, "eta": 0.1}

# Built-in objectives - just use them
model = xgb.train(params, dtrain, num_boost_round=100, obj=focal_loss.xgb_objective)
model = xgb.train(params, dtrain, num_boost_round=100, obj=huber.xgb_objective)
model = xgb.train(params, dtrain, num_boost_round=100, obj=quantile(0.9).xgb_objective)

# Custom objective - write the loss, autodiff handles the rest
@auto_objective
def asymmetric_mse(y_pred, y_true, alpha=0.7):
    error = y_true - y_pred
    return jnp.where(error > 0, alpha * error**2, (1 - alpha) * error**2)

model = xgb.train(params, dtrain, num_boost_round=100, obj=asymmetric_mse.xgb_objective)

LightGBM

import lightgbm as lgb
from jaxboost import huber

train_data = lgb.Dataset(X_train, label=y_train)
params = {"max_depth": 4, "learning_rate": 0.1}

model = lgb.train(params, train_data, num_boost_round=100, fobj=huber.lgb_objective)

Available Objectives

Regression

| Objective | Description |
|---|---|
| `mse` | Mean squared error |
| `huber` | Huber loss (robust to outliers) |
| `pseudo_huber` | Smooth approximation of Huber loss |
| `log_cosh` | Log-cosh loss |
| `mae_smooth` | Smooth approximation of MAE |
| `quantile(q)` | Quantile regression |
| `asymmetric(alpha)` | Asymmetric squared error |
| `poisson` | Poisson deviance (count data) |
| `gamma` | Gamma deviance (positive continuous) |
| `tweedie(p)` | Tweedie deviance |
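
As a sketch of how these map to plain code: the pinball loss behind quantile(q) can be written by hand with @auto_objective (an illustrative sketch; the packaged quantile objective may differ in details):

import jax.numpy as jnp
from jaxboost import auto_objective

# Illustrative sketch of a pinball (quantile) loss; the built-in
# quantile(q) objective may differ in implementation details.
@auto_objective
def pinball(y_pred, y_true, q=0.9):
    error = y_true - y_pred
    # Underprediction is penalized with weight q, overprediction with 1 - q.
    return jnp.where(error >= 0, q * error, (q - 1) * error)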

Binary Classification

| Objective | Description |
|---|---|
| `focal_loss` | Focal loss for imbalanced data |
| `binary_crossentropy` | Standard log loss |
| `weighted_binary_crossentropy` | Weighted binary cross-entropy |
| `hinge_loss` | SVM-style hinge loss |
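
For intuition, a hand-rolled binary focal loss under @auto_objective might look like this (a sketch only; the packaged focal_loss may use different defaults and numerics):

import jax
import jax.numpy as jnp
from jaxboost import auto_objective

# Sketch of a binary focal loss; the built-in focal_loss may use a
# different default gamma or different numerical safeguards.
@auto_objective
def focal_sketch(y_pred, y_true, gamma=2.0):
    p = jax.nn.sigmoid(y_pred)             # raw margin -> probability
    pt = jnp.where(y_true == 1, p, 1 - p)  # probability of the true class
    return -((1 - pt) ** gamma) * jnp.log(pt + 1e-7)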

Multi-class Classification

| Objective | Description |
|---|---|
| `softmax_cross_entropy` | Standard multi-class cross-entropy |
| `focal_multiclass` | Focal loss for multi-class |
| `label_smoothing(eps)` | Label smoothing regularization |
| `class_balanced` | Class-balanced loss |

Survival Analysis

| Objective | Description |
|---|---|
| `aft` | Accelerated failure time (log-normal) |
| `weibull_aft` | Weibull AFT model |

Ordinal Regression

| Objective | Description |
|---|---|
| `ordinal_logit` | Cumulative Link Model (logit link) |
| `ordinal_probit` | Cumulative Link Model (probit link) |
| `qwk_ordinal` | QWK-aligned Expected Quadratic Error |
| `squared_cdf_ordinal` | CRPS / Ranked Probability Score |
| `hybrid_ordinal` | NLL + EQE hybrid |
| `sord_objective` | SORD (Soft Ordinal) from the SLACE paper |
| `oll_objective` | OLL (Ordinal Log-Loss) from the SLACE paper |
| `slace_objective` | SLACE (AAAI 2025) |

Multi-task Learning

| Objective | Description |
|---|---|
| `multi_task_regression` | Multiple regression targets |
| `multi_task_classification` | Multiple classification targets |
| `multi_task_huber` | Multi-task Huber loss |
| `multi_task_quantile` | Multi-task quantile loss |
| `MaskedMultiTaskObjective` | Handles missing labels |

Uncertainty Estimation

| Objective | Description |
|---|---|
| `gaussian_nll` | Predict mean + variance |
| `laplace_nll` | Predict median + scale |
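
Conceptually, gaussian_nll treats two model outputs per sample as a mean and a log-variance; the loss is the Gaussian negative log-likelihood. A sketch of the loss form only (the packaged objective handles the two-output wiring itself):

import jax.numpy as jnp

# Gaussian negative log-likelihood with a log-variance parameterization
# (constant terms dropped). Illustrates the loss form only; the built-in
# gaussian_nll wires the two model outputs internally.
def gaussian_nll_form(mu, log_var, y_true):
    return 0.5 * (log_var + (y_true - mu) ** 2 / jnp.exp(log_var))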

Ordinal Regression

XGBoost/LightGBM have no native ordinal objective. JAXBoost implements proper Cumulative Link Models:

from jaxboost import ordinal_logit, qwk_ordinal

# Wine quality: 6 ordered classes (3-8 mapped to 0-5)
ordinal = ordinal_logit(n_classes=6)
ordinal.init_thresholds_from_data(y_train)

# Works with XGBoost
model = xgb.train(params, dtrain, obj=ordinal.xgb_objective)

# Or LightGBM
model = lgb.train(params, train_data, fobj=ordinal.lgb_objective)

# Get class probabilities
probs = ordinal.predict_proba(model.predict(dtest))
classes = ordinal.predict(model.predict(dtest))
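
For intuition, a logit Cumulative Link Model turns one raw score per sample into K class probabilities through K-1 ordered thresholds. A conceptual sketch with made-up thresholds (ordinal_logit learns and applies its own internally):

import jax
import jax.numpy as jnp

# Conceptual sketch: P(y <= k) = sigmoid(theta_k - f(x)). The thresholds
# below are made up; ordinal_logit initializes and manages its own.
thresholds = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])  # K-1 cutpoints for K=6 classes
score = 0.3                                          # one raw model output f(x)

cdf = jax.nn.sigmoid(thresholds - score)             # P(y <= k), k = 0..K-2
cdf = jnp.concatenate([cdf, jnp.array([1.0])])       # P(y <= K-1) = 1
probs = jnp.diff(cdf, prepend=0.0)                   # per-class probabilities, sum to 1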

Evaluation Metrics

When using custom objectives, use matching evaluation metrics:

from jaxboost import ordinal_logit
from jaxboost.metric import qwk_metric, mae_metric

ordinal = ordinal_logit(n_classes=6)
ordinal.init_thresholds_from_data(y_train)

# Train with custom metric monitoring
model = xgb.train(
    {'disable_default_eval_metric': 1, 'max_depth': 4},  # Disable default metrics!
    dtrain,
    obj=ordinal.xgb_objective,
    custom_metric=ordinal.qwk_metric.xgb_metric,  # Built-in QWK metric
    evals=[(dtest, 'test')]
)

Available Metrics

| Category | Metrics |
|---|---|
| Ordinal | `qwk_metric`, `ordinal_mae_metric`, `ordinal_accuracy_metric`, `adjacent_accuracy_metric` |
| Classification | `auc_metric`, `f1_metric`, `accuracy_metric`, `precision_metric`, `recall_metric` |
| Regression | `mse_metric`, `rmse_metric`, `mae_metric`, `r2_metric` |
| Bounded | `bounded_mse_metric`, `out_of_bounds_metric` |
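
The standalone metrics plug into training the same way. A sketch, assuming the imported metric objects expose the same .xgb_metric accessor used above:

import xgboost as xgb
from jaxboost import huber
from jaxboost.metric import mae_metric

# Sketch: pair a robust objective with a matching metric. Assumes
# mae_metric exposes the same .xgb_metric accessor shown above, and
# reuses dtrain/dtest from the earlier examples.
model = xgb.train(
    {"disable_default_eval_metric": 1, "max_depth": 4},
    dtrain,
    obj=huber.xgb_objective,
    custom_metric=mae_metric.xgb_metric,
    evals=[(dtest, "test")],
)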

Custom Objectives

The @auto_objective decorator turns any loss function into an XGBoost/LightGBM objective:

import xgboost as xgb
import lightgbm as lgb
import jax.numpy as jnp
from jaxboost import auto_objective

@auto_objective
def my_custom_loss(y_pred, y_true, alpha=1.0):
    # Write your loss here - JAX computes grad/hess automatically;
    # keyword arguments like alpha can be overridden at train time.
    return alpha * (y_pred - y_true) ** 2

# Use with XGBoost
dtrain = xgb.DMatrix(X_train, label=y_train)
params = {"max_depth": 4, "eta": 0.1}
model = xgb.train(params, dtrain, num_boost_round=100, obj=my_custom_loss.xgb_objective)

# Use with LightGBM
train_data = lgb.Dataset(X_train, label=y_train)
params = {"max_depth": 4, "learning_rate": 0.1}
model = lgb.train(params, train_data, num_boost_round=100, fobj=my_custom_loss.lgb_objective)

# Pass parameters
model = xgb.train(
    params, dtrain, num_boost_round=100,
    obj=my_custom_loss.get_xgb_objective(alpha=0.5)
)

Multi-class Example

import xgboost as xgb
import jax
import jax.numpy as jnp
from jaxboost import multiclass_objective

@multiclass_objective(num_classes=3)
def custom_multiclass(logits, label):
    # logits: (num_classes,), label: scalar
    probs = jax.nn.softmax(logits)
    return -jnp.log(probs[label] + 1e-7)

dtrain = xgb.DMatrix(X_train, label=y_train)
model = xgb.train(
    {"num_class": 3, "max_depth": 4, "eta": 0.1},
    dtrain,
    num_boost_round=100,
    obj=custom_multiclass.xgb_objective
)

Sklearn Interface

Use custom objectives with XGBClassifier and XGBRegressor:

from xgboost import XGBClassifier
from jaxboost import focal_loss

clf = XGBClassifier(
    objective=focal_loss.sklearn_objective,
    n_estimators=100,
    max_depth=4
)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
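
The same pattern should carry over to regression (a sketch, assuming huber exposes the same sklearn_objective accessor as focal_loss):

from xgboost import XGBRegressor
from jaxboost import huber

# Sketch: regression analog of the classifier example above, assuming
# huber exposes the same sklearn_objective accessor as focal_loss.
reg = XGBRegressor(objective=huber.sklearn_objective, n_estimators=100, max_depth=4)
reg.fit(X_train, y_train)
y_pred = reg.predict(X_test)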

Why JAXBoost?

| Traditional approach | JAXBoost |
|---|---|
| Derive gradients by hand | Write the loss, get gradients for free |
| Derive Hessians by hand | Write the loss, get Hessians for free |
| Error-prone math | JAX autodiff is correct by construction |
| One loss = hours of work | One loss = 5 lines of code |
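
The mechanism is ordinary JAX autodiff: differentiate a scalar per-sample loss twice and vectorize over the batch. A minimal sketch in plain JAX (for intuition, not jaxboost internals):

import jax
import jax.numpy as jnp

# Per-sample gradient and Hessian of a scalar loss via jax.grad,
# vectorized over the batch with jax.vmap.
def loss(y_pred, y_true):
    return (y_pred - y_true) ** 2

grad_fn = jax.grad(loss)            # first derivative w.r.t. y_pred
hess_fn = jax.grad(jax.grad(loss))  # second derivative w.r.t. y_pred

y_pred = jnp.array([0.2, 1.5, -0.3])
y_true = jnp.array([0.0, 1.0, 0.0])
grads = jax.vmap(grad_fn)(y_pred, y_true)     # -> [0.4, 1.0, -0.6]
hessians = jax.vmap(hess_fn)(y_pred, y_true)  # -> [2.0, 2.0, 2.0]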

Benchmark Results

JAXBoost shines when XGBoost/LightGBM have no native solution:

Bounded Regression (Proportions in [0, 1])

Predicting proportions, where a standard MSE objective can produce predictions outside the valid [0, 1] range.

| Model | MSE | Out-of-bounds | Code |
|---|---|---|---|
| JAXBoost Soft CE | 0.0181 | 0% | 5 lines |
| Native MSE + Clip | 0.0201 | 0% | post-hoc fix |
| Native MSE | 0.0201 | 4.9% | - |

A 9.5% MSE improvement, plus outputs guaranteed to stay in bounds.

import jax
import jax.numpy as jnp
from jaxboost import auto_objective

@auto_objective
def soft_crossentropy(y_pred, y_true):
    mu = jax.nn.sigmoid(y_pred)  # squash raw scores into (0, 1)
    eps = 1e-7                   # guard against log(0)
    return -(y_true * jnp.log(mu + eps) + (1 - y_true) * jnp.log(1 - mu + eps))
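
Because the loss squashes raw scores through a sigmoid, predictions need the same squashing at inference time (a sketch, reusing a model trained with this objective and the dtest matrix from the examples above):

import jax

# Raw margins -> probabilities in (0, 1); this is what keeps
# predictions in bounds at inference time.
preds = jax.nn.sigmoid(model.predict(dtest))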

Ordinal Regression (Wine Quality)

Predicting ordered categories (ratings 3-8), evaluated with Quadratic Weighted Kappa (QWK).

| Model | QWK | Probabilistic |
|---|---|---|
| Regression + OptRounder | 0.55 | No |
| JAXBoost Squared CDF | 0.54 | Yes |
| Native Multi-class | 0.51 | Yes |
| Native Regression | 0.48 | No |

JAXBoost ordinal objectives provide proper probability distributions over classes.

When to Use JAXBoost

| Problem | XGBoost/LightGBM native? | JAXBoost advantage |
|---|---|---|
| Bounded regression [0, 1] | ❌ No | ✅ 9.5% better MSE |
| Ordinal regression | ❌ No | ✅ Probabilistic outputs |
| Multi-task + missing labels | ❌ No | ✅ Proper masking |
| Custom business metrics | ❌ No | ✅ 5 lines of code |

📊 Full benchmark details →

Requirements

  • Python >= 3.10
  • JAX >= 0.4.20

Documentation

Full documentation available at: https://jxucoder.github.io/jaxboost/

License

Apache 2.0
