**Tutorial: how to wrap a custom loss function with CustomGBM**  

Let's import all relevant packages and append the path to the code used for this example.

In [9]:
# general packages
import sys
sys.path.append("../code")
from custom_gbm import CustomGBM
from loss_function import loss_function
import pickle as pkl
from sklearn.metrics import average_precision_score
import time
import lightgbm as lgb
import numpy as np

# jax dependencies
from jax import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit
from functools import partial

Let's import the data we will use for this tutorial.  

We are going to train a Quantitative Structure-Activity Relationship (QSAR) model for identifying CYP2C9 substrates. The raw data was downloaded from Therapeutics Data Commons (https://tdcommons.ai/single_pred_tasks/adme/#cyp2c9-substrate-carbon-mangels-et-al).  

The compounds have already been converted to 208 2D molecular descriptors computed with RDKit. The training set contains 468 compounds (90 are active), while the test set has 135 (38 are active). The train and test sets were obtained via scaffold split using the TDC API.
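Both sets are imbalanced, which is why PR-AUC is used later on. As a rough reference point, a classifier that ranks compounds at random is expected to reach an average precision close to the fraction of actives. The short sketch below only uses the counts quoted above; the helper name `active_fraction` is made up for illustration.

```python
# Class balance implied by the counts quoted in the text (no data files needed).
counts = {
    "train": {"total": 468, "active": 90},
    "test": {"total": 135, "active": 38},
}


def active_fraction(total, active):
    """Fraction of compounds labelled as active."""
    return active / total


for name, c in counts.items():
    frac = active_fraction(c["total"], c["active"])
    print(f"{name}: {frac:.3f} of the compounds are active")
```

This gives roughly 0.192 for the training set and 0.281 for the test set, so a test PR-AUC should be read against a chance level of about 0.28.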

In [2]:
with open('../data/train.pkl', 'rb') as handle:
    train = pkl.load(handle)
with open('../data/test.pkl', 'rb') as handle:
    test = pkl.load(handle) 

x_train, y_train = train
x_test, y_test = test

Let's make a new custom loss function. For this example, we will use the binary cross-entropy.  

There are four rules to follow to successfully implement a custom loss function with CustomGBM:  
1. The new loss function must subclass the `loss_function` class, and its `__init__` method must call the `__init__` method of the parent class.
2. The new loss function can take any number of arguments, but they must include the y_true vector of the training set, which is then passed to the `__init__` method of the parent class.  
3. It must include a `__call__` method, which computes the loss from the raw LightGBM output (e.g. logits).  
4. The `__call__` method must use JAX to speed up computation. As such, replace numpy with jnp and make sure to wrap the method with the decorator shown below.  

For additional examples, take a look at `focal.py` and other pre-implemented loss functions in "../code".

The parent class takes care of computing the gradients, hessians, optimal init_score and class weights. Check `loss_function.py` in "../code" for further information.
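To give an idea of what computing gradients and Hessians from a JAX loss can look like, here is a minimal sketch based on automatic differentiation with `jax.grad` and `jax.vmap`. It is an assumption about the general technique, not a copy of `loss_function.py`, and the function name `bce_per_sample` and its sign convention (a quantity to minimise) are chosen for illustration only.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def bce_per_sample(y_true, y_pred):
    """Binary cross-entropy for one sample, given a raw score (logit)."""
    p = jax.nn.sigmoid(y_pred)
    return -(y_true * jnp.log(p) + (1.0 - y_true) * jnp.log(1.0 - p))


# First and second derivatives with respect to the raw score (argument 1),
# mapped over all samples with vmap and compiled with jit.
d_loss = jax.grad(bce_per_sample, argnums=1)
grad_fn = jax.jit(jax.vmap(d_loss))
hess_fn = jax.jit(jax.vmap(jax.grad(d_loss, argnums=1)))

# Illustrative labels and raw scores, not the tutorial data.
y_true = jnp.array([1.0, 0.0, 1.0])
y_pred = jnp.array([0.5, -1.0, 2.0])

grad = grad_fn(y_true, y_pred)
hess = hess_fn(y_true, y_pred)

# For this loss the closed forms are p - y and p * (1 - p).
p = jax.nn.sigmoid(y_pred)
print(jnp.max(jnp.abs(grad - (p - y_true))))
print(jnp.max(jnp.abs(hess - p * (1.0 - p))))
```

A gradient boosting library typically expects one gradient and one Hessian value per sample, which is why the derivatives are taken per sample rather than on a summed loss.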

In [3]:
class cross_entropy(loss_function):
    
    def __init__(self,
                 y_true
                 ): 
        
        super(cross_entropy, self).__init__(y_true)
    
    @partial(jit, static_argnums=(0,))
    def __call__(self, y_true, y_pred):

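        # convert the raw LightGBM scores (logits) to probabilities
        p = 1/(1+jnp.exp(-y_pred))
        q = 1-p
        # per-sample log-probabilities of the positive and negative class
        pos_loss = jnp.log(p)
        neg_loss = jnp.log(q)
        
        return y_true * pos_loss + (1 - y_true) * neg_loss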
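One practical note on the log terms: computing `jnp.log` of an explicit sigmoid can overflow to `-inf` for very large raw scores. A possible alternative, sketched below as a suggestion rather than as part of the original implementation, is `jax.nn.log_sigmoid`, using the identity log(1 - sigmoid(x)) = log(sigmoid(-x)). The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def log_probs_naive(y_pred):
    """log(p) and log(1 - p) computed from an explicit sigmoid."""
    p = 1 / (1 + jnp.exp(-y_pred))
    return jnp.log(p), jnp.log(1 - p)


def log_probs_stable(y_pred):
    """Same quantities via log_sigmoid: log(1 - sigmoid(x)) = log_sigmoid(-x)."""
    return jax.nn.log_sigmoid(y_pred), jax.nn.log_sigmoid(-y_pred)


# Extreme scores chosen to trigger overflow in the naive version.
extreme = jnp.array([-800.0, 0.0, 800.0])
print(log_probs_naive(extreme))
print(log_probs_stable(extreme))

# For moderate scores both versions agree.
moderate = jnp.array([-3.0, 0.0, 3.0])
print(jnp.allclose(log_probs_naive(moderate)[0], log_probs_stable(moderate)[0]))
```

With scores in the usual range of a trained model the two versions give the same values, so the change only matters for extreme predictions.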

Now that we have implemented the new custom class, we can build a new LightGBM model with it using the CustomGBM API. We need to pass the class as the first argument, followed by a dictionary with the additional arguments of the loss (excluding y_true, so it is empty here) and a dictionary with the arguments for the LightGBM model.  

We then train the model, measure the training time and evaluate PR-AUC on the test set.

In [6]:
booster_params = {"num_boost_round":100, "verbose":-100}
loss_fn = cross_entropy
loss_params = {}
gbm = CustomGBM(loss_fn, loss_params, booster_params)

t_start = time.time()
gbm.fit(x_train, y_train)
t_end = time.time()
t_custom = t_end - t_start

predictions_custom = gbm.predict(x_test)
pr_auc_custom = average_precision_score(y_test, predictions_custom)

print(f"CustomGBM PR-AUC: {pr_auc_custom}")
print(f"CustomGBM training time (s): {t_custom}")

CustomGBM PR-AUC: 0.35089992803682485
CustomGBM training time (s): 0.18747258186340332




To check that the implementation is correct, let's replicate the training procedure using a standard LightGBM classifier with the default cross-entropy loss.

In [7]:
gbm = lgb.LGBMClassifier(n_estimators=100)

t_start = time.time()
gbm.fit(x_train, y_train)
t_end = time.time()
t_default = t_end - t_start

predictions_default = gbm.predict_proba(x_test)[:,1]
pr_auc_default = average_precision_score(y_test, predictions_default)

print(f"Default PR-AUC: {pr_auc_default}")
print(f"Default training time (s): {t_default}")

Default PR-AUC: 0.35089992803682485
Default training time (s): 0.1328582763671875


The performance is identical and the training time is of the same order of magnitude, although CustomGBM is somewhat slower. This is expected given that LightGBM is implemented in C++ and uses analytical formulas for the gradients and the Hessians, while we use numerical approximations.  
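When timing JAX code, it also helps to remember that a jitted function is traced and compiled on its first call, so a single measurement can include compilation time. Whether this contributes to the gap measured above was not checked here; the sketch below only illustrates the effect on a small stand-alone function, with hypothetical names and made-up inputs.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def mean_log_loss(y_true, y_pred):
    """Mean binary cross-entropy computed from raw scores (logits)."""
    log_p = jax.nn.log_sigmoid(y_pred)
    log_q = jax.nn.log_sigmoid(-y_pred)
    return -jnp.mean(y_true * log_p + (1.0 - y_true) * log_q)


def timed(fn, *args):
    """Wall-clock seconds for one call, waiting until the result is ready."""
    start = time.perf_counter()
    fn(*args).block_until_ready()
    return time.perf_counter() - start


# Illustrative inputs, not the tutorial data.
y_true = jnp.array([1.0, 0.0, 1.0, 0.0])
y_pred = jnp.array([0.3, -1.2, 2.0, 0.1])

first_call = timed(mean_log_loss, y_true, y_pred)   # tracing + compilation + execution
second_call = timed(mean_log_loss, y_true, y_pred)  # reuses the compiled function
print(f"first call: {first_call:.6f} s, second call: {second_call:.6f} s")
```

Repeating the measurement a few times and reporting the later runs separately gives a fairer comparison than a single timed call.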

As a last check, let's verify that the numerical difference between the predictions is sufficiently small:

In [10]:
print(f"Average prediction delta: {np.mean(np.abs(predictions_custom - predictions_default))}")

Average prediction delta: 1.2600912305828415e-08
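The mean can hide a few large deviations, so it can be useful to also look at the largest difference and at a tolerance-based check. A minimal sketch with NumPy follows; the helper `compare_predictions`, the tolerance and the example arrays are all assumptions made for illustration, not values from this tutorial.

```python
import numpy as np


def compare_predictions(a, b, atol=1e-6):
    """Summarise the element-wise differences between two prediction vectors."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    diff = np.abs(a - b)
    return {
        "mean_abs": float(diff.mean()),
        "max_abs": float(diff.max()),
        "all_close": bool(np.allclose(a, b, rtol=0.0, atol=atol)),
    }


# Made-up predictions for illustration only.
a = np.array([0.10, 0.50, 0.90])
b = a + np.array([1e-8, -2e-8, 0.0])

report = compare_predictions(a, b)
print(report)
```

In the notebook, the same function could be called as `compare_predictions(predictions_custom, predictions_default)`.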
