In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import prada
import xgboost as xgb
import veritas

In [2]:
# Load Spambase through prada and apply its robust normalisation, then carve
# out three subsets with two successive splits: 0.6 on the full data, then 0.5
# on the second part (presumably the argument is the fraction that goes to the
# first returned subset -- TODO confirm against prada's `split`).
d = prada.Spambase()
d.load_dataset()
d.robust_normalize()
dtrain, dholdout = d.split(0.6)
dtest, dvalid = dholdout.split(0.5)

loading cached /home/laurens/prada_data/Spambase.h5


# Train a model

In [72]:
# Fit a small boosted ensemble and convert it into a veritas additive tree.
clf = xgb.XGBClassifier(n_estimators=10, max_leaves=8)
clf.fit(dtrain.X, dtrain.y)
at = veritas.get_addtree(clf, silent=True)

# Conversion check: veritas receives the converted ensemble, the training
# features and the classifier's positive-class probabilities for them.
veritas.test_conversion(
    at,
    dtrain.X.to_numpy().astype(veritas.FloatT),
    clf.predict_proba(dtrain.X)[:, 1],
)

# Score the classifier on each subset, in the order train, validation, test.
mtrain, mvalid, mtest = (
    subset.metric(clf) for subset in (dtrain, dvalid, dtest)
)

print(f"mtrain {mtrain:.3f} mvalid {mvalid:.3f} mtest {mtest:.3f}")

test_conversion: no problems detected (rel_tol 0.0001)
mtrain 0.945 mvalid 0.933 mtest 0.934


In [73]:
def transformx_lookup(t):
    """Build the leaf -> path-indicator table for a single tree.

    The result is a boolean array of shape ``(K, K-1)`` where ``K`` is
    ``t.num_nodes()``. Row ``l`` is only filled in for leaf ids; entry
    ``[l, n-1]`` is True when node ``n`` lies on the path from leaf ``l`` up
    to, but not including, the root. Column ``j`` therefore stands for node id
    ``j+1`` (this indexing assumes the root has id 0 and ids run 0..K-1).

    Parameters
    ----------
    t : tree object providing ``num_nodes()``, ``get_leaf_ids()``,
        ``is_root(id)`` and ``parent(id)``.

    Returns
    -------
    numpy.ndarray of dtype bool, shape ``(K, K-1)``.
    """
    node_count = t.num_nodes()
    # The root gets no column: it is on every path and carries no information.
    on_path = np.zeros((node_count, node_count - 1), dtype=bool)

    for leaf_id in t.get_leaf_ids():
        # Walk upwards from the leaf, marking every node visited before the root.
        current = leaf_id
        while not t.is_root(current):
            on_path[leaf_id, current - 1] = True
            current = t.parent(current)

    return on_path

def transformx_for_tree(X, t):
    """Encode every sample in ``X`` as the path indicator of the leaf it reaches in ``t``.

    Each sample is routed through the tree with ``t.eval_node(X)``; its row in
    the result is the corresponding row of ``transformx_lookup(t)``, i.e. a 0/1
    vector with one entry per non-root node that is 1 for the nodes on the
    path to the reached leaf.

    Parameters
    ----------
    X : sample container exposing ``shape[0]`` (number of samples) and
        accepted by ``t.eval_node``.
    t : tree object as expected by ``transformx_lookup``, additionally
        providing ``eval_node(X)`` which yields one leaf id per sample.

    Returns
    -------
    numpy.ndarray of dtype float32, shape ``(N, K-1)`` with ``N = X.shape[0]``
    and ``K = t.num_nodes()``.

    Raises
    ------
    ValueError
        If ``eval_node`` does not return exactly one leaf id per sample.
    """
    lookup = transformx_lookup(t)

    # Flatten to a 1-D integer index so a list, an (N,) or an (N, 1) result all
    # work, and so an empty result still indexes cleanly.
    leaf_ids = np.asarray(t.eval_node(X), dtype=np.intp).reshape(-1)
    if leaf_ids.shape[0] != X.shape[0]:
        raise ValueError(
            f"eval_node returned {leaf_ids.shape[0]} leaf ids "
            f"for {X.shape[0]} samples"
        )

    # One row gather instead of a Python loop over the samples.
    return lookup[leaf_ids].astype(np.float32)

def transformx(X, at):
    """Encode ``X`` for a whole ensemble by joining the per-tree encodings.

    ``transformx_for_tree`` is applied to every tree yielded by iterating
    ``at`` and the resulting blocks are placed side by side, in ensemble
    order, along the column axis.

    Parameters
    ----------
    X : samples, passed unchanged to ``transformx_for_tree``.
    at : iterable of trees.

    Returns
    -------
    numpy.ndarray with one row per sample and one column per non-root node of
    every tree.
    """
    per_tree_blocks = []
    for tree in at:
        per_tree_blocks.append(transformx_for_tree(X, tree))
    return np.hstack(per_tree_blocks)

In [74]:
# Encode the training and validation samples with the path-indicator transform
# and wrap them as tensors; the labels become float32 column vectors so they
# line up with the network's (N, 1) output.
xxtrain = torch.from_numpy(transformx(dtrain.X, at))
yytrain = torch.from_numpy(
    dtrain.y.to_numpy().astype(np.float32).reshape(-1, 1)
)
xxvalid = torch.from_numpy(transformx(dvalid.X, at))
yyvalid = torch.from_numpy(
    dvalid.y.to_numpy().astype(np.float32).reshape(-1, 1)
)

In [75]:
class Net(nn.Module):
    """Single linear layer followed by a sigmoid, sized from a tree ensemble.

    The input width is the total number of non-root nodes over all trees in
    ``at`` (one input per column produced by the path-indicator encoding);
    the number of outputs is ``at.num_leaf_values()``.
    """

    def __init__(self, at):
        """Create the layer for ensemble ``at`` (iterable of trees with ``num_nodes()``)."""
        super().__init__()
        total_inputs = 0
        for tree in at:
            # Every node except the root contributes one input feature.
            total_inputs += tree.num_nodes() - 1
        self.width = total_inputs
        self.lin = nn.Linear(
            in_features=self.width,
            out_features=at.num_leaf_values(),
        )

    def forward(self, x):
        """Return the sigmoid of the linear layer applied to ``x``."""
        logits = self.lin(x)
        return torch.sigmoid(logits)

In [76]:
# Seed PyTorch's global RNG before the layer is created: its initial weights
# are drawn at random, so without a seed the training losses and accuracies
# change on every Restart & Run All.
torch.manual_seed(0)
model = Net(at)

In [90]:
# Full-batch training of the linear layer with RMSprop; the weight decay acts
# as an L2 penalty on the layer parameters.
learning_rate = 1e-3
optimizer = torch.optim.RMSprop(
    model.parameters(),
    lr=learning_rate,
    weight_decay=0.04,
)

for step in range(2000):
    # Gradients accumulate across backward() calls, so clear them first.
    optimizer.zero_grad()

    # Forward pass on the whole training set and binary cross-entropy loss.
    ypred = model(xxtrain)
    loss = F.binary_cross_entropy(ypred, yytrain)

    # Backward pass followed by one parameter update.
    loss.backward()
    optimizer.step()

    # Report the loss computed in this step once every 100 iterations.
    if (step + 1) % 100 == 0:
        print(step, loss.item())

99 0.1699610948562622
199 0.16998697817325592
299 0.16999202966690063
399 0.16999799013137817
499 0.17000199854373932
599 0.16999542713165283
699 0.16999460756778717
799 0.16999605298042297
899 0.1699967384338379
999 0.16999608278274536
1099 0.16999590396881104
1199 0.16999614238739014
1299 0.1699962019920349
1399 0.16999611258506775
1499 0.16999609768390656
1599 0.16999612748622894
1699 0.16999612748622894
1799 0.16999612748622894
1899 0.16999611258506775
1999 0.16999611258506775


In [91]:
# Inference only, so no autograd graph is recorded for the predictions.
with torch.no_grad():
    ypred_tr, ypred_va = model(xxtrain), model(xxvalid)

# Threshold the sigmoid outputs at 0.5 and compare against the true labels.
acc_tr = np.mean(
    (ypred_tr > 0.5).numpy().ravel() == dtrain.y.to_numpy()
)
acc_va = np.mean(
    (ypred_va > 0.5).numpy().ravel() == dvalid.y.to_numpy()
)

print(f"mtrain {acc_tr:.5f}, mvalid {acc_va:.5f}")

mtrain 0.94402, mvalid 0.92464


# Getting the weights back to the trees

In [92]:
def update_addtree_weights(at, model):
    """Write the weights of the trained linear layer back into the ensemble ``at``.

    The layer's inputs are path indicators (one per non-root node, trees laid
    out one after another). A sample reaching leaf ``l`` of a tree therefore
    contributes the sum of the weights of the nodes on the path to ``l``; that
    sum becomes the new value of leaf ``l``. The layer bias becomes the base
    score. ``at`` is modified in place and nothing is returned.

    Parameters
    ----------
    at : tree ensemble; iterating it yields trees providing ``num_nodes()``,
        ``get_leaf_ids()`` and ``set_leaf_value(leaf_id, index, value)``, and it
        provides ``set_base_score(index, value)``.
    model : object with a ``lin`` linear layer whose weight has shape
        ``(num_outputs, width)``, where ``width`` is the total number of
        non-root nodes in ``at``.

    Raises
    ------
    ValueError
        If the layer's input width does not match the ensemble.
    """
    # Keep the weight matrix 2-D (one row per output) instead of flattening it,
    # so a layer with several outputs is not silently truncated to the first.
    # .cpu() makes this work when the model lives on an accelerator.
    weights = model.lin.weight.detach().cpu().numpy()
    biases = model.lin.bias.detach().cpu().numpy()

    expected_width = sum(t.num_nodes() - 1 for t in at)
    if weights.shape[1] != expected_width:
        raise ValueError(
            f"linear layer has {weights.shape[1]} inputs but the ensemble "
            f"has {expected_width} non-root nodes"
        )

    # Pass plain Python floats, not tensors, to the ensemble. The index is
    # taken to be the output index, as in the single-output call
    # `set_base_score(0, ...)` -- TODO confirm for multi-output ensembles.
    for k, bias in enumerate(biases):
        at.set_base_score(k, float(bias))

    offset = 0
    for t in at:
        lookup = transformx_lookup(t)
        K = t.num_nodes() - 1
        # Columns of the weight matrix that belong to this tree.
        ws = weights[:, offset:offset + K]
        offset += K
        for l in t.get_leaf_ids():
            # Sum of the path-node weights, one value per output.
            values = ws.dot(lookup[l, :])
            for k, v in enumerate(values):
                t.set_leaf_value(l, k, float(v))

# Re-weight a copy so the original ensemble stays available for comparison.
atc = at.copy()
update_addtree_weights(atc, model)

# Accuracy of the original ensemble first, then of the re-weighted copy.
# The ensemble outputs are thresholded at 0 (presumably they are raw scores
# rather than probabilities -- confirm against veritas' `eval`).
for ensemble in (at, atc):
    ypred_tr = ensemble.eval(dtrain.X.to_numpy().astype(veritas.FloatT))
    ypred_va = ensemble.eval(dvalid.X.to_numpy().astype(veritas.FloatT))

    acc_tr = np.mean((ypred_tr > 0.0).flatten() == dtrain.y.to_numpy())
    acc_va = np.mean((ypred_va > 0.0).flatten() == dvalid.y.to_numpy())

    print(f"mtrain {acc_tr:.5f}, mvalid {acc_va:.5f}")

mtrain 0.94511, mvalid 0.93261
mtrain 0.94402, mvalid 0.92464


In [93]:
# Show the last tree of the re-weighted copy `atc`.
print(atc[len(atc)-1])

Node(id=0, split=[F26 < 0.00862069], sz=15, left=1, right=2)
├─ Node(id=1, split=[F16 < 0.340426], sz=13, left=3, right=4)
│  ├─ Node(id=3, split=[F22 < 0.314286], sz=7, left=5, right=6)
│  │  ├─ Node(id=5, split=[F51 < 1.12688], sz=3, left=7, right=8)
│  │  │  ├─ Leaf(id=7, sz=1, value=[-0.168157])
│  │  │  └─ Leaf(id=8, sz=1, value=[0.146894])
│  │  ├─ Node(id=6, split=[F16 < 0.170213], sz=3, left=9, right=10)
│  │  │  ├─ Leaf(id=9, sz=1, value=[0.316405])
│  │  │  └─ Leaf(id=10, sz=1, value=[0.209658])
│  ├─ Node(id=4, split=[F24 < 0.0718232], sz=5, left=11, right=12)
│  │  ├─ Node(id=11, split=[F11 < 0.781457], sz=3, left=13, right=14)
│  │  │  ├─ Leaf(id=13, sz=1, value=[0.57572])
│  │  │  └─ Leaf(id=14, sz=1, value=[0.396323])
│  │  └─ Leaf(id=12, sz=1, value=[0.208413])
└─ Leaf(id=2, sz=1, value=[-0.166717])



In [94]:
# Show the last tree of the original ensemble `at`, to compare its leaf values
# with those of the re-weighted copy.
print(at[len(at)-1])

Node(id=0, split=[F26 < 0.00862069], sz=15, left=1, right=2)
├─ Node(id=1, split=[F16 < 0.340426], sz=13, left=3, right=4)
│  ├─ Node(id=3, split=[F22 < 0.314286], sz=7, left=5, right=6)
│  │  ├─ Node(id=5, split=[F51 < 1.12688], sz=3, left=7, right=8)
│  │  │  ├─ Leaf(id=7, sz=1, value=[-0.0827581])
│  │  │  └─ Leaf(id=8, sz=1, value=[0.29238])
│  │  ├─ Node(id=6, split=[F16 < 0.170213], sz=3, left=9, right=10)
│  │  │  ├─ Leaf(id=9, sz=1, value=[0.35356])
│  │  │  └─ Leaf(id=10, sz=1, value=[-0.00801275])
│  ├─ Node(id=4, split=[F24 < 0.0718232], sz=5, left=11, right=12)
│  │  ├─ Node(id=11, split=[F11 < 0.781457], sz=3, left=13, right=14)
│  │  │  ├─ Leaf(id=13, sz=1, value=[0.37056])
│  │  │  └─ Leaf(id=14, sz=1, value=[0.0510146])
│  │  └─ Leaf(id=12, sz=1, value=[-0.256735])
└─ Leaf(id=2, sz=1, value=[-0.318531])

