In [None]:
# Growing Neural Tree
class GrowingNeuralTree:
    """Chain of classifier nodes where each new node is trained on earlier mistakes.

    ``fit`` trains a root node on the full training set, then repeatedly
    trains a child node on the training samples that the current chain
    still misclassifies, up to ``max_depth`` children.

    The class relies on module-level helpers that are defined outside this
    class: ``create_softmax_node``, ``probs``, ``pred_int_from_probs``,
    ``ce_loss_from_probs`` and ``accuracy_score``.
    NOTE(review): nodes are presumably Keras-style models (they are called
    with ``fit(..., validation_data=..., epochs=..., batch_size=...,
    verbose=...)``) -- confirm against ``create_softmax_node``.
    """

    # Width of the probability matrix returned when no node produced any
    # output (no nodes yet, or an empty ``X``). Kept at 10 for backward
    # compatibility with the previously hard-coded shape.
    _FALLBACK_N_CLASSES = 10

    def __init__(self, max_depth=6):
        """Create an empty tree.

        Args:
            max_depth: Maximum number of child nodes ``fit`` may add after
                the root.
        """
        self.max_depth = max_depth
        self.nodes = []

    def _sequential_predict_probs(self, X, y_true_int=None):
        """Route samples through the node chain and return per-sample probabilities.

        Routing rule:
        - Every sample starts at the root (``self.nodes[0]``).
        - When ``y_true_int`` is given, a sample whose predicted class at a
          node differs from its label is forwarded to the next node.
        - The output of the last node a sample visited is retained.
        - When ``y_true_int`` is ``None`` there is nothing to route on, so
          only the root node is consulted.

        Args:
            X: Samples; must support ``len`` and indexing with an integer
                index array.
            y_true_int: Optional integer labels aligned with ``X``.

        Returns:
            Float array with one row per sample. The number of columns is
            taken from the first node's output; if no node ran, a zero
            matrix with ``_FALLBACK_N_CLASSES`` columns is returned.
        """
        n_samples = len(X)
        active = np.arange(n_samples)
        final_probs = None

        for node in self.nodes:
            if active.size == 0:
                break

            p = probs(node, X[active])
            if final_probs is None:
                # Size the output from what the node actually emits instead
                # of assuming a fixed number of classes.
                final_probs = np.zeros((n_samples, np.shape(p)[-1]), dtype=float)
            final_probs[active] = p

            if y_true_int is None:
                # Without labels we cannot tell which samples are wrong, so
                # the root's output is final.
                break

            pred = pred_int_from_probs(p)
            still_wrong = pred != y_true_int[active]
            active = active[still_wrong]

        if final_probs is None:
            final_probs = np.zeros((n_samples, self._FALLBACK_N_CLASSES), dtype=float)
        return final_probs

    def _misclassified_indices(self, X_train, y_train_int):
        """Return indices of samples still misclassified after label-guided routing."""
        p = self._sequential_predict_probs(X_train, y_true_int=y_train_int)
        yhat = pred_int_from_probs(p)
        return np.where(yhat != y_train_int)[0]

    def _evaluate_split(self, X, y_int, y_oh):
        """Compute (loss, accuracy, predicted labels) for one data split."""
        p = self._sequential_predict_probs(X, y_true_int=y_int)
        yhat = pred_int_from_probs(p)
        acc = float(accuracy_score(y_int, yhat))
        loss = ce_loss_from_probs(p, y_oh)
        return loss, acc, yhat

    def _log_metrics(self, logs, depth,
                     X_train, y_train_int, y_train_oh,
                     X_val, y_val_int, y_val_oh,
                     X_test, y_test_int, y_test_oh):
        """Append loss/accuracy for all three splits to ``logs`` and print a summary.

        NOTE: validation and test predictions are routed using their true
        labels (see ``_sequential_predict_probs``), so the reported
        accuracies describe label-guided routing, not label-free inference.

        Args:
            logs: Dict of lists with the keys created in ``fit``; mutated
                in place.
            depth: Depth value recorded for this entry.
            X_*, y_*_int, y_*_oh: Features, integer labels and one-hot
                labels of the train / validation / test splits.

        Returns:
            Indices of the training samples that are still misclassified,
            so the caller does not have to run inference a second time.
        """
        tr_loss, tr_acc, yhat_tr = self._evaluate_split(X_train, y_train_int, y_train_oh)
        va_loss, va_acc, _ = self._evaluate_split(X_val, y_val_int, y_val_oh)
        te_loss, te_acc, _ = self._evaluate_split(X_test, y_test_int, y_test_oh)

        wrong_tr = yhat_tr != y_train_int
        mis = int(np.sum(wrong_tr))

        logs["depth"].append(depth)
        logs["train_loss"].append(tr_loss)
        logs["val_loss"].append(va_loss)
        logs["test_loss"].append(te_loss)
        logs["train_acc"].append(tr_acc)
        logs["val_acc"].append(va_acc)
        logs["test_acc"].append(te_acc)
        logs["train_misclassified"].append(mis)

        print(f"[Depth {depth}] mis={mis} | train_acc={tr_acc:.4f} | val_acc={va_acc:.4f} | test_acc={te_acc:.4f}")

        return np.where(wrong_tr)[0]

    def fit(self,
            X_train, y_train_int, y_train_oh,
            X_val, y_val_int, y_val_oh,
            X_test, y_test_int, y_test_oh,
            root_epochs=10, root_bs=256, root_lr=1e-3,
            child_epochs=20, child_bs=128, child_lr=1e-4,
            stop_misclassified_threshold=0,
            stop_val_acc_threshold=None):
        """Train the root node, then grow child nodes on remaining mistakes.

        Args:
            X_train, y_train_int, y_train_oh: Training features, integer
                labels and one-hot labels.
            X_val, y_val_int, y_val_oh: Validation split (also passed as
                ``validation_data`` to every node's ``fit``).
            X_test, y_test_int, y_test_oh: Test split (used for logging only).
            root_epochs, root_bs, root_lr: Epochs, batch size and learning
                rate for the root node.
            child_epochs, child_bs, child_lr: Same, for every child node.
            stop_misclassified_threshold: Stop growing once the number of
                misclassified training samples is at or below this value.
            stop_val_acc_threshold: If not ``None``, stop growing once the
                most recently logged validation accuracy reaches it.

        Returns:
            Dict of lists with one entry per logged depth: ``depth``,
            ``train_loss``, ``val_loss``, ``test_loss``, ``train_acc``,
            ``val_acc``, ``test_acc`` and ``train_misclassified``.
        """
        logs = {
            "depth": [],
            "train_loss": [], "val_loss": [], "test_loss": [],
            "train_acc": [], "val_acc": [], "test_acc": [],
            "train_misclassified": []
        }
        splits = (X_train, y_train_int, y_train_oh,
                  X_val, y_val_int, y_val_oh,
                  X_test, y_test_int, y_test_oh)

        # Root node (single parent node)
        root = create_softmax_node(lr=root_lr)
        root.fit(X_train, y_train_oh,
                 validation_data=(X_val, y_val_oh),
                 epochs=root_epochs, batch_size=root_bs, verbose=1)
        self.nodes = [root]

        # _log_metrics already evaluates the training set, so reuse the
        # misclassified indices it returns instead of re-running inference.
        mis_idx = self._log_metrics(logs, 0, *splits)

        # Grow children
        for depth in range(1, self.max_depth + 1):
            if len(mis_idx) <= stop_misclassified_threshold:
                print("Stopping: misclassified threshold reached.")
                break

            if stop_val_acc_threshold is not None and logs["val_acc"][-1] >= stop_val_acc_threshold:
                print("Stopping: validation accuracy threshold reached.")
                break

            # Each child only sees the samples the current chain gets wrong.
            X_mis = X_train[mis_idx]
            y_mis_oh = y_train_oh[mis_idx]

            child = create_softmax_node(lr=child_lr)
            child.fit(X_mis, y_mis_oh,
                      validation_data=(X_val, y_val_oh),
                      epochs=child_epochs, batch_size=child_bs, verbose=1)

            self.nodes.append(child)

            mis_idx = self._log_metrics(logs, depth, *splits)

        return logs