In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.io
import xgboost as xgb
import z3
from dask.distributed import Client
from sklearn.metrics import accuracy_score, mean_absolute_error

from treeck import SplitTree
from treeck.distributed import DistributedVerifier
from treeck.verifier import Verifier
from treeck.xgb import addtree_from_xgb_model
from treeck.z3backend import Z3Backend as Backend

In [None]:
# Load the MNIST feature matrix and labels from the MATLAB file in the test data.
mat = scipy.io.loadmat("../tests/data/mnist.mat")
X = mat["X"]
# Flatten the labels to 1-D without hard-coding the dataset size, so the cell
# keeps working if the .mat file holds a different number of examples.
y = mat["y"].reshape(-1)

num_examples = X.shape[0]
num_features = X.shape[1]
assert y.shape[0] == num_examples, "X and y must describe the same number of examples"

# Fixed seed so the shuffled train/test split is the same on every run.
np.random.seed(111)
indices = np.random.permutation(num_examples)

# First 90% of the shuffled indices are for training, the rest for testing.
m = int(num_examples*0.9)
Itrain = indices[0:m]
Itest = indices[m:]

In [None]:
# Binary target: True where the digit is a 5.
label = (y == 5)

# Zero-valued pixels are declared as missing values in both matrices.
dtrain = xgb.DMatrix(X[Itrain], label=label[Itrain], missing=0)
dtest = xgb.DMatrix(X[Itest], label=label[Itest], missing=0)

params = dict(
    objective="binary:logistic",
    tree_method="hist",
    max_depth=6,
    learning_rate=0.25,
    eval_metric="error",
)

# Disabled exploratory check: 5-fold `xgb.cv(params, dtrain, num_boost_round=10, ...)`
# with early stopping after 3 rounds.

# NOTE(review): per the XGBoost docs the last entry of `evals` drives early
# stopping, i.e. the test split here - confirm that is intended.
watchlist = [(dtrain, "train"), (dtest, "test")]
model = xgb.train(
    params,
    dtrain,
    num_boost_round=200,
    evals=watchlist,
    early_stopping_rounds=5,
)

In [None]:
# Build `at` from the trained booster with treeck's XGBoost converter.
at = addtree_from_xgb_model(num_features, model)
# presumably zeroed so that `at` reproduces the booster's raw margin; the MAE
# printed below is the check for that - TODO confirm
at.base_score = 0.0

# Raw margins (output_margin=True); a positive margin is read as "is a 5".
# NOTE(review): the training DMatrix objects were built with missing=0 but this
# one is not - confirm the difference is intended.
pred = model.predict(xgb.DMatrix(X), output_margin=True)

# accuracy_score expects (y_true, y_pred); the value is the same either way,
# but the documented order keeps the call correct if the metric is swapped.
acc = accuracy_score(label, pred > 0.0)
print(f"accuracy: {acc}")

# Compare the booster and `at` on the first 10000 examples.
mae = mean_absolute_error(pred[:10000], at.predict(X[:10000]))
print(f"mae model difference {mae}")

In [None]:
# np.argmax on a boolean array yields the index of the first True entry,
# i.e. the first example labelled 5.
instance_index = np.argmax(y == 5)
instance = X[instance_index, :]

# Render the selected example as a 28x28 image.
fig, ax = plt.subplots()
image = instance.reshape((28,28))
ax.imshow(image)
plt.show()

In [None]:
def vfactory(instance, num_features, at, leaf, *, offset=10, max_sum_offset=500):
    """Build a Verifier that looks for a bounded perturbation of `instance` with negative output.

    Constraints added to the returned verifier:
      * every feature variable j lies strictly between
        ``max(0, pixel - offset)`` and ``min(255, pixel + offset)``;
      * the sum of absolute deviations from `instance` is below `max_sum_offset`;
      * the model output variable ``v.fvar()`` is below 0.0.

    Parameters
    ----------
    instance : iterable of numbers
        Reference pixel values; only the first `num_features` entries are used.
    num_features : int
        Number of feature variables to constrain.
    at, leaf
        Passed unchanged to ``Verifier(at, leaf, Backend())``.
    offset : number, keyword-only, default 10
        Maximum allowed deviation per pixel.
    max_sum_offset : number, keyword-only, default 500
        Strict upper bound on the total absolute deviation.

    Returns
    -------
    Verifier
        The verifier with all constraints added.
    """
    v = Verifier(at, leaf, Backend())

    sum_constraint = 0
    for j, pixel in zip(range(num_features), instance):
        x = v.xvar(j)
        # NOTE(review): both bounds are strict, so a pixel that is exactly 0
        # (or 255) cannot keep its original value - confirm this is intended.
        v.add_constraint((x > max(0, pixel-offset)) & (x < min(255, pixel+offset)))
        # |x - pixel| expressed with z3.If; the backend variable is fetched once.
        xj = x.get()
        sum_constraint += z3.If(xj - pixel <= 0, pixel - xj, xj - pixel)
    v.add_constraint(sum_constraint < max_sum_offset)
    v.add_constraint(v.fvar() < 0.0)
    return v

# Pre-bind the instance-specific arguments; the remaining positional
# arguments of vfactory (at, leaf) are supplied by the caller of `vfac`.
vfac = partial(vfactory, instance, num_features)

verifier_options = dict(
    check_paths=False,
    saturate_workers_factor=1,
    stop_when_sat=False,
)

# Connect to the dask scheduler at this address; `dv` stays available to
# later cells after the client is closed.
with Client("tcp://localhost:30333") as client:
    #client.restart()
    st = SplitTree(at, {})
    dv = DistributedVerifier(client, st, vfac, **verifier_options)
    dv.check()

In [None]:
# Print the status of every verification result and plot each satisfying
# assignment next to its difference from the reference instance.
for domtree_node_id, res in dv.results.items():
    status = res["status"]
    print(domtree_node_id, status)

    if not status.is_sat():
        continue

    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(11,4))
    solution = res["model"]
    inst = np.array(solution["xs"]).reshape((28,28))
    delta = inst-instance.reshape((28,28))

    # Left panel: the instance found by the verifier; right panel: its
    # difference from the reference instance.
    im0 = ax0.imshow(inst)
    im1 = ax1.imshow(delta)
    fig.colorbar(im0, ax=ax0)
    fig.colorbar(im1, ax=ax1)

    # NOTE(review): the left title reports the output for the reference
    # `instance` although the panel shows the found instance, and the right
    # title reports the found instance's output above the difference image.
    ax0.set_title(f"f={at.predict_single(instance):.4f}")
    ax1.set_title(f"f={solution['f']:.4f}")
    plt.show()