# In [None]:
from jax import grad, jit
import jax.numpy as jnp
import numpy as np

from smooth_binary_node import SmoothBinaryNode
from utils import mse

node = SmoothBinaryNode()

# Feature 1 is used as the split criterion.
# The threshold is 50.
params = {'weights': jnp.array([0, 10.0, 0]),
          'leaves': jnp.array([-1.0, 1.0]),
          'bias': 50.0}
# Four samples whose feature 1 (43, 55, 46, 52) lies on either side of 50.
features = jnp.array([[1, 43, 7], [1, 55, 7], [1, 46, 33], [1, 52, 16]])
# Targets take the same values as the leaves; note the array has shape (1, 4).
# NOTE(review): assumes mse broadcasts this against node.predict's output - confirm.
y_true = jnp.array([[-1, 1, -1, 1]])
# mse(node) returns the loss callable, used below as err(params, features, y_true).
err = mse(node)

# Compute gradient using jax: grad differentiates with respect to the first
# argument (params); jit compiles the resulting function.
jgrad = jit(grad(err))
grads = jgrad(params, features, y_true)
# Parameters are optimal, so the gradient is expected to be null.
# (A redundant earlier node.predict call, whose result was never read, was removed.)
pred = node.predict(params, features)
print('pred (right pred)', pred)
print('err (right pred):', err(params, features, y_true))
print('grads (right pred)', grads)
#> grads {'bias': Array(0., dtype=float32, weak_type=True), 'leaves': Array([0., 0.], dtype=float32), 'weights': Array([0., 0., 0.], dtype=float32)}

# Scenario: nudge the left leaf value away from its optimum (-1.0 -> -0.9),
# keeping the weights and the bias untouched.
params = dict(
    weights=jnp.array([0, 10.0, 0]),
    leaves=jnp.array([-0.9, 1.0]),
    bias=50.0,
)
# The parameters are no longer optimal, so a non-null gradient is expected.
pred = node.predict(params, features)
grads = jgrad(params, features, y_true)
print('pred (leave err)', pred)
print('err: (leave err)', err(params, features, y_true))
print('grads (leave err)', grads)
#> grads {'bias': Array(-1.1920918e-09, dtype=float32, weak_type=True), 'leaves': Array([0.20000005, 0.        ], dtype=float32), 'weights': Array([0., 0., 0.], dtype=float32)}
# The gradient on the left leaf is positive, so a gradient-descent step
# (param -= learning_rate * grad) lowers the left leaf value, moving it
# from -0.9 back toward -1.0.


# Scenario: perturb the bias this time (50.0 -> 45.0); the weights and the
# leaves keep the values used in the optimal configuration.
params = {'weights': jnp.array([0, 10.0, 0]),
          'leaves': jnp.array([-1.0, 1.0]),
          'bias': 45.0}
grads = jgrad(params, features, y_true)
# Parameters are no longer optimal, so a non-null gradient is expected.
pred = node.predict(params, features)
print('pred (bias err)', pred)
print('err: (bias err)', err(params, features, y_true))
# Label carries the scenario suffix like the other prints, so this line can be
# told apart from the other gradient dumps in the output.
print('grads (bias err)', grads)
# NOTE(review): the sample output and interpretation previously recorded here
# were a verbatim copy of the leaf-perturbation scenario (same values, same
# "left leaf" remark) and do not appear to belong to this run, where the leaves
# are at their optimal values. Re-run and paste the actual output if a
# reference is wanted.

# Plain gradient descent on the current `params`.
learning_rate = 0.1
n_steps = 1000
for _ in range(n_steps):
    # Compute the gradient for the current parameters inside the loop, so the
    # first step does not depend on a `grads` value left over from earlier code.
    grads = jgrad(params, features, y_true)
    # Every entry of `params` has a matching entry in `grads` (same keys).
    params = {name: value - learning_rate * grads[name]
              for name, value in params.items()}
# Gradient at the final parameters, for reporting only.
grads = jgrad(params, features, y_true)

pred = node.predict(params, features)
print('params (trained)', params)
print('pred (trained)', pred)
print('err: (trained)', err(params, features, y_true))
# Label carries the scenario suffix like the other prints in this section.
print('grads (trained)', grads)