JAX autodiff for XGBoost/LightGBM objectives.
Write a loss function, get gradients and Hessians automatically. No manual derivation needed.
Works with XGBoost and LightGBM.
**Installation**

```bash
pip install jaxboost
```

**Quick start (XGBoost)**

```python
import xgboost as xgb
import jax.numpy as jnp
from jaxboost import auto_objective, focal_loss, huber, quantile
# Prepare your data
dtrain = xgb.DMatrix(X_train, label=y_train)
params = {"max_depth": 4, "eta": 0.1}
# Built-in objectives - just use them
model = xgb.train(params, dtrain, num_boost_round=100, obj=focal_loss.xgb_objective)
model = xgb.train(params, dtrain, num_boost_round=100, obj=huber.xgb_objective)
model = xgb.train(params, dtrain, num_boost_round=100, obj=quantile(0.9).xgb_objective)
# Custom objective - write the loss, autodiff handles the rest
@auto_objective
def asymmetric_mse(y_pred, y_true, alpha=0.7):
    error = y_true - y_pred
    return jnp.where(error > 0, alpha * error**2, (1 - alpha) * error**2)
model = xgb.train(params, dtrain, num_boost_round=100, obj=asymmetric_mse.xgb_objective)
```

**Quick start (LightGBM)**

```python
import lightgbm as lgb
from jaxboost import huber
train_data = lgb.Dataset(X_train, label=y_train)
params = {"max_depth": 4, "learning_rate": 0.1}
model = lgb.train(params, train_data, num_boost_round=100, fobj=huber.lgb_objective)
```

**Built-in objectives**

**Regression**

| Objective | Description |
|---|---|
| mse | Mean squared error |
| huber | Huber loss (robust to outliers) |
| pseudo_huber | Smooth approximation of Huber loss |
| log_cosh | Log-cosh loss |
| mae_smooth | Smooth approximation of MAE |
| quantile(q) | Quantile regression |
| asymmetric(alpha) | Asymmetric squared error |
| poisson | Poisson deviance (count data) |
| gamma | Gamma deviance (positive continuous) |
| tweedie(p) | Tweedie deviance |
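For intuition about what these built-ins compute, here is a minimal sketch of a Huber-style loss written with @auto_objective; the delta=1.0 default and the exact form of the built-in huber are assumptions:

```python
import jax.numpy as jnp
from jaxboost import auto_objective

@auto_objective
def huber_like(y_pred, y_true, delta=1.0):
    # Quadratic near zero, linear in the tails, which is what makes
    # the loss robust to outliers.
    error = y_true - y_pred
    abs_error = jnp.abs(error)
    return jnp.where(
        abs_error <= delta,
        0.5 * error**2,
        delta * (abs_error - 0.5 * delta),
    )
```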
**Binary classification**

| Objective | Description |
|---|---|
| focal_loss | Focal loss for imbalanced data |
| binary_crossentropy | Standard log loss |
| weighted_binary_crossentropy | Weighted binary cross-entropy |
| hinge_loss | SVM-style hinge loss |
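Focal loss helps with imbalance by down-weighting well-classified examples. A hedged re-implementation in the same style (the gamma=2.0 default and the built-in's exact form are assumptions):

```python
import jax
import jax.numpy as jnp
from jaxboost import auto_objective

@auto_objective
def focal_like(y_pred, y_true, gamma=2.0):
    # p_t is the predicted probability of the true class; the
    # (1 - p_t)**gamma factor shrinks the loss on easy examples.
    p = jax.nn.sigmoid(y_pred)
    p_t = jnp.where(y_true == 1, p, 1.0 - p)
    return -((1.0 - p_t) ** gamma) * jnp.log(p_t + 1e-7)
```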
**Multi-class**

| Objective | Description |
|---|---|
| softmax_cross_entropy | Standard multi-class cross-entropy |
| focal_multiclass | Focal loss for multi-class |
| label_smoothing(eps) | Label smoothing regularization |
| class_balanced | Class-balanced loss |
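To make label_smoothing(eps) concrete, a hedged sketch of the per-row loss using the multiclass_objective decorator shown later in this README; the eps=0.1 default and the built-in's exact form are assumptions:

```python
import jax
import jax.numpy as jnp
from jaxboost import multiclass_objective

NUM_CLASSES = 3  # illustrative value

@multiclass_objective(num_classes=NUM_CLASSES)
def label_smoothing_like(logits, label, eps=0.1):
    # Mix the one-hot target with the uniform distribution, then take
    # cross-entropy against the smoothed target.
    target = (1.0 - eps) * jax.nn.one_hot(label, NUM_CLASSES) + eps / NUM_CLASSES
    return -jnp.sum(target * jax.nn.log_softmax(logits))
```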
**Survival**

| Objective | Description |
|---|---|
| aft | Accelerated failure time (log-normal) |
| weibull_aft | Weibull AFT model |
**Ordinal**

| Objective | Description |
|---|---|
| ordinal_logit | Cumulative Link Model (logit link) |
| ordinal_probit | Cumulative Link Model (probit link) |
| qwk_ordinal | QWK-aligned Expected Quadratic Error |
| squared_cdf_ordinal | CRPS / Ranked Probability Score |
| hybrid_ordinal | NLL + EQE hybrid |
| sord_objective | SORD (Soft Ordinal) from the SLACE paper |
| oll_objective | OLL (Ordinal Log-Loss) from the SLACE paper |
| slace_objective | SLACE (AAAI 2025) |
**Multi-task**

| Objective | Description |
|---|---|
| multi_task_regression | Multiple regression targets |
| multi_task_classification | Multiple classification targets |
| multi_task_huber | Multi-task Huber loss |
| multi_task_quantile | Multi-task quantile loss |
| MaskedMultiTaskObjective | Handles missing labels |
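The idea behind masking is to zero out the contribution of missing labels so they generate neither loss nor gradient. A generic sketch of a masked multi-task MSE (illustrative only, not the MaskedMultiTaskObjective API):

```python
import jax.numpy as jnp

def masked_multi_task_mse(y_pred, y_true, mask):
    # y_pred, y_true, mask: (n_samples, n_tasks); mask is 1.0 where a
    # label is observed and 0.0 where it is missing.
    per_entry = mask * (y_pred - y_true) ** 2
    # Average only over observed labels.
    return jnp.sum(per_entry) / jnp.maximum(jnp.sum(mask), 1.0)
```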
**Distributional**

| Objective | Description |
|---|---|
| gaussian_nll | Predict mean + variance |
| laplace_nll | Predict median + scale |
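For reference, the Gaussian negative log-likelihood that gaussian_nll plausibly minimizes, written over a mean and a log-scale output; the exact parameterization used by the library is an assumption:

```python
import jax.numpy as jnp

def gaussian_nll_sketch(mu, log_sigma, y_true):
    # NLL of y_true under N(mu, sigma^2); predicting log_sigma keeps the
    # scale positive without explicit constraints.
    sigma2 = jnp.exp(2.0 * log_sigma)
    return 0.5 * (jnp.log(2.0 * jnp.pi * sigma2) + (y_true - mu) ** 2 / sigma2)
```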
**Ordinal regression**

XGBoost and LightGBM have no native ordinal objective. JAXBoost implements proper Cumulative Link Models:

```python
from jaxboost import ordinal_logit, qwk_ordinal
# Wine quality: 6 ordered classes (3-8 mapped to 0-5)
ordinal = ordinal_logit(n_classes=6)
ordinal.init_thresholds_from_data(y_train)
# Works with XGBoost
model = xgb.train(params, dtrain, obj=ordinal.xgb_objective)
# Or LightGBM
model = lgb.train(params, train_data, fobj=ordinal.lgb_objective)
# Get class probabilities
probs = ordinal.predict_proba(model.predict(dtest))
classes = ordinal.predict(model.predict(dtest))
```
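Under a cumulative link model the booster learns one latent score per row, and ordered thresholds carve that score into classes. A hedged sketch of how class probabilities are plausibly derived (the library's threshold handling may differ):

```python
import jax
import jax.numpy as jnp

def clm_probs(score, thresholds):
    # P(y <= k) = sigmoid(theta_k - score); class probabilities are
    # differences of adjacent cumulative probabilities.
    cdf = jax.nn.sigmoid(thresholds - score)        # shape (n_classes - 1,)
    cdf = jnp.concatenate([cdf, jnp.array([1.0])])  # P(y <= last class) = 1
    return jnp.diff(cdf, prepend=0.0)               # shape (n_classes,)
```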
When using custom objectives, use matching evaluation metrics:

```python
from jaxboost import ordinal_logit
from jaxboost.metric import qwk_metric, mae_metric

ordinal = ordinal_logit(n_classes=6)
ordinal.init_thresholds_from_data(y_train)

# Train with custom metric monitoring
model = xgb.train(
    {'disable_default_eval_metric': 1, 'max_depth': 4},  # Disable default metrics!
    dtrain,
    obj=ordinal.xgb_objective,
    custom_metric=ordinal.qwk_metric.xgb_metric,  # Built-in QWK metric
    evals=[(dtest, 'test')]
)
```

**Built-in metrics**

| Category | Metrics |
|---|---|
| Ordinal | qwk_metric, ordinal_mae_metric, ordinal_accuracy_metric, adjacent_accuracy_metric |
| Classification | auc_metric, f1_metric, accuracy_metric, precision_metric, recall_metric |
| Regression | mse_metric, rmse_metric, mae_metric, r2_metric |
| Bounded | bounded_mse_metric, out_of_bounds_metric |
**Custom objectives**

The @auto_objective decorator turns any loss function into an XGBoost/LightGBM objective:

```python
import xgboost as xgb
import lightgbm as lgb
import jax.numpy as jnp
from jaxboost import auto_objective
@auto_objective
def my_custom_loss(y_pred, y_true, **kwargs):
    # Write your loss here - JAX computes grad/hess automatically
    return (y_pred - y_true) ** 2
# Use with XGBoost
dtrain = xgb.DMatrix(X_train, label=y_train)
params = {"max_depth": 4, "eta": 0.1}
model = xgb.train(params, dtrain, num_boost_round=100, obj=my_custom_loss.xgb_objective)
# Use with LightGBM
train_data = lgb.Dataset(X_train, label=y_train)
params = {"max_depth": 4, "learning_rate": 0.1}
model = lgb.train(params, train_data, num_boost_round=100, fobj=my_custom_loss.lgb_objective)
# Pass parameters
model = xgb.train(
    params, dtrain, num_boost_round=100,
    obj=my_custom_loss.get_xgb_objective(alpha=0.5)
)
```
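Under the hood, JAX makes the per-sample gradient and Hessian essentially free. A minimal sketch of how such a decorator can be built (a sketch of the general technique, not JAXBoost's actual internals):

```python
import jax

def grad_and_hess(loss_fn, y_pred, y_true):
    # Differentiate the scalar per-sample loss twice with respect to the
    # prediction, then vectorize over all samples with vmap.
    grad_fn = jax.vmap(jax.grad(loss_fn, argnums=0))
    hess_fn = jax.vmap(jax.grad(jax.grad(loss_fn, argnums=0), argnums=0))
    return grad_fn(y_pred, y_true), hess_fn(y_pred, y_true)
```

An XGBoost objective callback then only has to return these two arrays as (grad, hess).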
**Custom multi-class objectives**

Multi-class objectives are written per row, on a vector of logits:

```python
import xgboost as xgb
import jax
import jax.numpy as jnp
from jaxboost import multiclass_objective
@multiclass_objective(num_classes=3)
def custom_multiclass(logits, label):
    # logits: (num_classes,), label: scalar
    probs = jax.nn.softmax(logits)
    return -jnp.log(probs[label] + 1e-7)
dtrain = xgb.DMatrix(X_train, label=y_train)
model = xgb.train(
{"num_class": 3, "max_depth": 4, "eta": 0.1},
dtrain,
num_boost_round=100,
obj=custom_multiclass.xgb_objective
)
```

**Scikit-learn API**

Use custom objectives with XGBClassifier and XGBRegressor:

```python
from xgboost import XGBClassifier
from jaxboost import focal_loss
clf = XGBClassifier(
    objective=focal_loss.sklearn_objective,
    n_estimators=100,
    max_depth=4
)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
```

**Why jaxboost?**

| Traditional Approach | jaxboost |
|---|---|
| Derive gradients by hand | Write loss, get gradients free |
| Derive Hessians by hand | Write loss, get Hessians free |
| Error-prone math | JAX autodiff is correct by construction |
| One loss = hours of work | One loss = 5 lines of code |
JAXBoost shines where XGBoost/LightGBM have no native solution.

**Case study: bounded regression in [0, 1]**

Predicting proportions, where standard MSE can predict outside the valid range:
| Model | MSE | Out-of-Bounds | Code |
|---|---|---|---|
| JAXBoost Soft CE | 0.0181 | 0% | 5 lines |
| Native MSE + Clip | 0.0201 | 0% | post-hoc fix |
| Native MSE | 0.0201 | 4.9% | - |
9.5% improvement + guaranteed valid outputs.
```python
@auto_objective
def soft_crossentropy(y_pred, y_true):
    # Map the raw margin into (0, 1) before taking cross-entropy
    mu = jax.nn.sigmoid(y_pred)
    return -(y_true * jnp.log(mu) + (1 - y_true) * jnp.log(1 - mu))
```
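Because the objective is custom, the booster's raw predictions stay in margin space; applying the same sigmoid at prediction time recovers outputs in (0, 1) (sketch):

```python
probs = jax.nn.sigmoid(model.predict(dtest))  # back onto the (0, 1) scale
```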
**Case study: ordinal regression**

Predicting ordered categories (ratings 3-8), scored with Quadratic Weighted Kappa:

| Model | QWK | Probabilistic |
|---|---|---|
| Regression + OptRounder | 0.55 | No |
| JAXBoost Squared CDF | 0.54 | Yes |
| Native Multi-class | 0.51 | Yes |
| Native Regression | 0.48 | No |
JAXBoost ordinal objectives provide proper probability distributions over classes.
**Summary**

| Problem | XGBoost/LightGBM Native? | JAXBoost Advantage |
|---|---|---|
| Bounded regression [0,1] | ❌ No | ✅ 9.5% better MSE |
| Ordinal regression | ❌ No | ✅ Probabilistic outputs |
| Multi-task + missing labels | ❌ No | ✅ Proper masking |
| Custom business metrics | ❌ No | ✅ 5 lines of code |
**Requirements**

- Python >= 3.10
- JAX >= 0.4.20
Full documentation available at: https://jxucoder.github.io/jaxboost/
**License**

Apache 2.0