In [1]:
# Project-local notebook bootstrap, executed before any other import in the notebook.
# NOTE(review): presumably adjusts sys.path / the working directory so that the
# project modules imported in the next cell resolve — confirm in utils.prepare_jupyter.
from utils import prepare_jupyter
prepare_jupyter()

In [64]:
import numpy as np
import lightgbm as lgb

from scipy.special import softmax
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

from structure.Dataset import Dataset
from parser.lightgbm import parse_lightgbm
from extract import get_lgb_trees

In [65]:
# Fixed seed so the train/validation split (and therefore every output below)
# is reproducible under Restart Kernel -> Run All.
SEED = 42

iris = load_iris()
X_train, X_val, y_train, y_val = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=SEED
)
# Project-local container bundling the training data with feature/class names.
iris_data = Dataset(X_train, y_train, iris.feature_names, iris.target_names)

# Keep the ensemble tiny so individual trees are easy to inspect.
n_estimators = 5
clf = lgb.LGBMClassifier(
    n_estimators=n_estimators, objective='softmax', random_state=SEED
)
clf.fit(X_train, y_train)

# Convert the fitted booster into the project's own tree objects; each of them
# exposes a .predict(X) method that is used in the following cells.
trees = parse_lightgbm(clf, iris_data)

In [66]:
# Evaluate every parsed tree on the validation set: one row per tree,
# one column per validation sample.
per_tree_outputs = [tree.predict(X_val) for tree in trees]
trees_preds = np.array(per_tree_outputs)
print(f'Shape = {trees_preds.shape}')

Shape = (15, 30)


## Jak LightGBM działa

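* Jeśli `objective='multiclass'` (alias: `'softmax'`), prawdopodobieństwa klas to softmax z sum wartości `leaf_value` policzonych osobno dla każdej klasy.
* Suma `leaf_value` dla danej klasy obejmuje tylko drzewa tej klasy: drzewa są ułożone naprzemiennie, dla klas A, B oraz C w kolejności [A, B, C, A, B, C, ...].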

In [67]:
# Hard class labels predicted by LightGBM itself; used as the reference below.
clf_preds = clf.predict(X_val)

# Index of the validation sample to inspect.
which = 0
y_true, y_clf = y_val[which], clf_preds[which]

# Column `which` of trees_preds holds the raw output of every tree for that sample.
preds = trees_preds[:, which]
print(preds.shape)

# Trees are interleaved by class, so each row is one estimator (boosting
# iteration) and each column is one class.
preds = np.reshape(preds, (n_estimators, len(iris.target_names)))
print(preds)

# Sum the raw scores per class over all estimators, then normalise with softmax.
probs = softmax(preds.sum(axis=0))

(15,)
[[-1.1042137  -0.99861272 -1.19800417]
 [-0.0747298   0.13038377 -0.07151707]
 [-0.07195114  0.11231368 -0.06931541]
 [-0.06950033  0.10145332 -0.06742889]
 [-0.06738723  0.09276975 -0.06573405]]


In [68]:
# Show three labels for the inspected sample side by side: the argmax of the
# probabilities rebuilt from the parsed trees, LightGBM's own prediction, and
# the ground-truth label.
print(f'Is {np.argmax(probs)}, clf says {y_clf}, should be {y_true}')

Is 1, clf says 1, should be 1


In [69]:
# LightGBM's probabilities for the inspected sample (passed as a one-row batch)
# next to the probabilities rebuilt from the parsed trees; the two should agree.
(clf.predict_proba(X_val[which].reshape(1, -1)), probs)

(array([[0.23789129, 0.54343157, 0.21867714]]),
 array([0.23789129, 0.54343157, 0.21867714]))

In [76]:
n_classes = len(iris.target_names)

# trees_preds is (n_trees, n_samples) with trees interleaved by class
# ([A, B, C, A, B, C, ...]). Transposing puts samples first, so each sample's
# row of tree outputs can be viewed as an (n_estimators, n_classes) grid.
# A plain transpose replaces np.rollaxis, which NumPy keeps only for backward
# compatibility; for this 2-D array the result is identical.
preds = trees_preds.T.reshape(len(X_val), n_estimators, n_classes)

# Sum the raw scores over the estimators, then softmax across classes per sample.
probs = softmax(np.sum(preds, axis=1), axis=1)

probs

array([[0.23789129, 0.54343157, 0.21867714],
       [0.24899423, 0.25013052, 0.50087525],
       [0.24499309, 0.52964702, 0.22535989],
       [0.57759978, 0.20873435, 0.21366587],
       [0.57759978, 0.20873435, 0.21366587],
       [0.58205075, 0.20658375, 0.2113655 ],
       [0.57759978, 0.20873435, 0.21366587],
       [0.23789129, 0.54343157, 0.21867714],
       [0.24913077, 0.25008504, 0.50078418],
       [0.57076505, 0.21207781, 0.21715714],
       [0.2378794 , 0.54340442, 0.21871618],
       [0.23776124, 0.54353117, 0.21870759],
       [0.30110203, 0.42205227, 0.27684569],
       [0.23250995, 0.21275609, 0.55473396],
       [0.24040522, 0.5384551 , 0.22113968],
       [0.24899423, 0.25013052, 0.50087525],
       [0.24499309, 0.52964702, 0.22535989],
       [0.23250995, 0.21275609, 0.55473396],
       [0.57520548, 0.20990813, 0.21488639],
       [0.23250995, 0.21275609, 0.55473396],
       [0.23250995, 0.21275609, 0.55473396],
       [0.23789129, 0.54343157, 0.21867714],
       [0.

(30,)