In [None]:
# Load IPython's autoreload extension and switch it to mode 2 so that edits to
# imported modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pandas as pd
import numpy as np
import xgboost as xgb

from sklearn.metrics import mean_absolute_error

from treeck import SplitTree
from treeck.xgb import addtree_from_xgb_model
from treeck.verifier import Verifier
from treeck.distributed import DistributedVerifier
from treeck.z3backend import Z3Backend as Backend
import z3

from dask.distributed import Client

In [None]:
# Root directory of the datasets; the HDF5 file is read from
# <DATA_PATH>/youtube/youtube.h5.
DATA_PATH = "~/kuleuven/phd/data"
data = pd.read_hdf(os.path.join(DATA_PATH, "youtube", "youtube.h5"))

# NOTE(review): num_features is the column count of the *full* frame, while the
# model below is trained on the first 373 columns only -- confirm this is the
# value addtree_from_xgb_model is supposed to receive.
num_examples, num_features = data.shape

# Seeded shuffle of the row positions, followed by a 90% / 10% train/test cut.
np.random.seed(222)
indices = np.random.permutation(num_examples)

m = int(num_examples*0.9)
Itrain, Itest = indices[:m], indices[m:]

# Features are the first 373 columns; the target is the last column of `data`.
wordsonly_data = data.iloc[:, :373]

Xtrain = wordsonly_data.iloc[Itrain].to_numpy()
Xtest = wordsonly_data.iloc[Itest].to_numpy()
ytrain = data.iloc[Itrain, -1].to_numpy()
ytest = data.iloc[Itest, -1].to_numpy()

# Wrap both splits in xgboost's DMatrix container.
dtrain = xgb.DMatrix(Xtrain, label=ytrain, missing=None)
dtest = xgb.DMatrix(Xtest, label=ytest, missing=None)

In [None]:
# Booster configuration: squared-error regression with histogram-based tree
# construction, evaluated with mean absolute error.
params = dict(
    objective="reg:squarederror",
    tree_method="hist",
    max_depth=10,
    learning_rate=0.5,
    eval_metric="mae",
    seed=0,
)

# Train for at most 50 rounds with early stopping (patience of 5 rounds).
# NOTE(review): per the xgboost docs, early stopping watches the last entry of
# `evals`, i.e. the test split here -- confirm that using the test split for
# model selection is intended.
model = xgb.train(
    params,
    dtrain,
    num_boost_round=50,
    evals=[(dtrain, "train"), (dtest, "test")],
    early_stopping_rounds=5,
)

In [None]:
# Convert the trained booster to a treeck additive tree ensemble and sanity-check
# the conversion on the first 1000 test rows: both predictors are scored against
# the labels and against each other.
at = addtree_from_xgb_model(num_features, model)

pred_m = model.predict(xgb.DMatrix(Xtest[:1000]))
pred_a = at.predict(Xtest[:1000])
mae_m, mae_a = [mean_absolute_error(ytest[:1000], pred) for pred in (pred_m, pred_a)]

print(f"xgb: {mae_m}, at: {mae_a} diff: {mean_absolute_error(pred_m, pred_a)}")

In [None]:
# Column positions of a few words of interest:
#   (41, 'txt_black_panther')
#   (204, 'txt_marvel'),
#   (215, 'txt_movie'),
#   (216, 'txt_movies'),

# Map every word (the column name minus its first four characters, e.g. the
# "txt_" prefix shown above) to its column position in `wordsonly_data`.
words2id = {
    column[4:]: position
    for position, column in enumerate(wordsonly_data.columns)
}

In [None]:
def vfactory(at, leaf):
    """Build the two-instance verifier for one sub-problem.

    The resulting query asks for two inputs (instance 0 and instance 1) where

    * the words "marvel", "black_panther", "movie" and "movies" have feature
      value 1.0 in both instances,
    * exactly one of the remaining word features differs between the two
      instances, and
    * the model output of instance 0 is strictly greater than that of
      instance 1.

    Relies on the notebook-level ``words2id`` mapping (word -> feature index).

    Parameters
    ----------
    at
        The additive tree ensemble, handed to ``Verifier``.
    leaf
        Passed through unchanged to ``Verifier``.

    Returns
    -------
    Verifier
        The verifier with all constraints added; nothing is checked here.
    """
    # Words pinned to 1.0 in both instances; built once instead of per iteration.
    fixed_words = {"marvel", "black_panther", "movie", "movies"}
    # Number of free word features that must differ between the two instances.
    num_differing = 1

    v = Verifier(at, leaf, Backend(), num_instances=2)

    # (flag, weight) pairs for the pseudo-boolean constraint below.
    differ_flags = []
    for word, feat_id in words2id.items():
        xvar0, xvar1 = v.xvar(feat_id, instance=0), v.xvar(feat_id, instance=1)
        if word in fixed_words:
            v.add_constraint((xvar0 == 1.0) & (xvar1 == 1.0))
        else:
            # flag<feat_id> is true exactly when this feature differs between
            # the two instances.
            bvar_name = f"flag{feat_id}"
            v.add_bvar(bvar_name)
            bvar = v.bvar(bvar_name)
            v.add_constraint(z3.If(bvar.get(), xvar0.get() != xvar1.get(),
                                               xvar0.get() == xvar1.get()))
            differ_flags.append((bvar.get(), 1))

    # PbEq is an equality: exactly `num_differing` flags are true (not "at most").
    v.add_constraint(z3.PbEq(differ_flags, num_differing))
    v.add_constraint(v.fvar(instance=0) > v.fvar(instance=1))

    return v

# Connect to the Dask scheduler at the given address and run the verification
# through DistributedVerifier, which builds its verifiers with `vfactory`.
# `st` and `dv` stay available after the block (the plot cell reads `dv`).
with Client("tcp://localhost:30333") as client:
    # Restart before starting the run.
    client.restart()

    st = SplitTree(at, {})
    dv = DistributedVerifier(
        client,
        st,
        vfactory,
        check_paths=True,
        saturate_workers_factor=1,
        stop_when_sat=False,
    )
    dv.check()

In [None]:
from treeck.plot import TreePlot

# Plot the domain tree of the split tree held by the distributed verifier.
# NOTE(review): this reaches into the private `_st` attribute of `dv`;
# presumably it is the `st` object created above -- confirm.
g = TreePlot()
g.add_tree(dv._st.domtree())
# Last expression of the cell, so the notebook displays it
# (presumably the underlying graph object -- confirm).
g.g