# In [None]:
import jax.numpy as jnp
from entmax_jax import entmax, entmax15, sparsemax

class SmoothBinaryNode:
    """Soft binary decision node whose split is computed with ``entmax15``.

    Each input row is projected onto one feature: ``features`` dotted with
    ``entmax15(params['weights'])``. That feature is compared with
    ``params['bias']`` and turned into a pair of routing probabilities
    ``(left, right)`` via ``entmax15``. A branch's value is either the
    ``predict`` output of the attached child node, or (when the branch has no
    child) this node's own leaf-based value ``_val``.

    Parameters are passed in a dict on every call. Keys read here are
    ``'weights'``, ``'bias'`` and ``'leaves'``. Child parameters are nested
    under ``'left_params'`` / ``'right_params'``.

    Shapes: assumes ``features`` is (n_samples, n_features), ``weights`` is
    (n_features,) and ``leaves`` is (2,). TODO confirm against callers.
    """

    def __init__(self, left_node=None, right_node=None, scale=1.0):
        """Create a node.

        Args:
            left_node: optional child whose ``predict`` supplies the value of
                the left branch.
            right_node: optional child whose ``predict`` supplies the value of
                the right branch.
            scale: divisor applied to the split logit. For positive values,
                smaller ``scale`` gives a sharper split. Defaults to 1.0.
        """
        self.left_node = left_node
        self.right_node = right_node
        self.scale = scale

    def _get_params(self, params):
        """Return ``(weights, bias, scale, leaves)``.

        ``scale`` comes from the node itself, not from ``params``.
        """
        return params['weights'], params['bias'], self.scale, params['leaves']

    def _choices(self, params, features):
        """Return routing probabilities with the last axis ordered (left, right).

        The left logit is ``-(feature - bias) / scale`` and the right logit is
        fixed at 0. For positive ``scale``, rows whose projected feature is
        below ``bias`` therefore lean left.
        """
        weights, bias, scale, _ = self._get_params(params)
        # Project every row onto a single entmax15-normalised direction.
        feature = features.dot(entmax15(weights))
        # Stacking on the last axis keeps (..., 2) logits even if `feature`
        # carries extra leading batch axes.
        logits = jnp.stack(
            [-(feature - bias) / scale, jnp.zeros(feature.shape)], axis=-1
        )
        return entmax15(logits)

    def left(self, params, features):
        """Return the probability of routing each row to the left branch."""
        return self._choices(params, features)[..., 0]

    def right(self, params, features):
        """Return the probability of routing each row to the right branch."""
        return self._choices(params, features)[..., 1]

    def _val(self, params, features):
        """Return the node-local value: routing probabilities dotted with ``leaves``.

        Used as a branch's value when that branch has no child node.
        """
        _, _, _, leaves = self._get_params(params)
        return jnp.dot(self._choices(params, features), leaves)

    def left_val(self, params, features):
        """Return the value of the left branch for each row.

        With a left child, this is the child's full ``predict`` on
        ``params['left_params']``, so the child's own routing and both of its
        subtrees contribute. Without a left child it falls back to ``_val``.
        """
        if self.left_node is not None:
            return self.left_node.predict(params['left_params'], features)
        # NOTE(review): the fallback blends *both* leaves by this node's
        # routing probabilities, even for the left branch. Using only
        # leaves[0] may have been intended; confirm.
        return self._val(params, features)

    def right_val(self, params, features):
        """Return the value of the right branch for each row.

        With a right child, this is the child's full ``predict`` on
        ``params['right_params']``. Without a right child it falls back to
        ``_val``.
        """
        if self.right_node is not None:
            return self.right_node.predict(params['right_params'], features)
        # NOTE(review): see left_val. This fallback also blends both leaves.
        return self._val(params, features)

    def predict(self, params, features):
        """Return the soft prediction ``P(left) * left_val + P(right) * right_val``."""
        # Compute the routing probabilities once instead of once per branch.
        choices = self._choices(params, features)
        return (choices[..., 0] * self.left_val(params, features)
                + choices[..., 1] * self.right_val(params, features))