<a href="https://colab.research.google.com/github/baicheto/AML_Bitcoin/blob/Kri/AML_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Deep Neural Decision Forest

The `Deep Neural Decision Forest` (DNDF) takes the same input as `DeepWalk`: the feature tensor `feat_tensor`.

## DNDF Machinery

In [None]:
class NeuralDecisionForest(nn.Module):
    """Soft decision forest whose split probabilities come from a shared MLP.

    Every tree is a complete binary tree of ``depth`` levels stored in heap
    order (node ``i`` has children ``2*i + 1`` and ``2*i + 2``), giving
    ``2**depth`` leaves and ``2**depth - 1`` decision nodes. The router MLP
    emits one logit per decision node per tree; ``sigmoid(logit)`` is the
    probability of routing LEFT at that node. Each leaf stores a learnable
    class distribution ``softmax(leaf_probs[tree, leaf])``.

    ``forward`` returns the log of the forest's class distribution (the
    equal-weight average of the per-tree distributions). Those values are
    valid logits: ``softmax(forward(x))`` recovers the distribution, and
    ``nn.CrossEntropyLoss`` applied to them is its negative log-likelihood.

    Args:
        input_dim: number of input features per row of ``x``.
        num_trees: number of trees (>= 1).
        depth: depth of every tree (>= 1).
        num_classes: number of output classes.
        hidden_dim: width of the router's hidden layer.
        dropout: dropout probability applied inside the router.

    Raises:
        ValueError: if ``num_trees`` or ``depth`` is smaller than 1.
    """

    def __init__(self, input_dim, num_trees, depth, num_classes,
                 hidden_dim=128, dropout=0.3):
        super().__init__()
        import math

        if num_trees < 1:
            raise ValueError(f"num_trees must be >= 1, got {num_trees}")
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        self.num_trees = num_trees
        self.depth = depth
        self.num_leaves = 2 ** depth
        self.num_dec_nodes = self.num_leaves - 1
        # Used to turn a sum over trees into a mean in log space.
        self._log_num_trees = math.log(num_trees)

        self.router = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_trees * self.num_dec_nodes),
        )

        # Raw (pre-softmax) class scores for every leaf of every tree.
        self.leaf_probs = nn.Parameter(
            torch.randn(num_trees, self.num_leaves, num_classes) * 0.1
        )

        # leaf_mask[l, n]  = 1 if the path to leaf l goes LEFT at node n.
        # right_mask[l, n] = 1 if the path to leaf l goes RIGHT at node n.
        # Nodes off the path are 0 in both masks and so contribute a factor
        # of 1. A single 0/1 mask cannot express that, because 0 would
        # already mean "go right".
        left = torch.zeros(self.num_leaves, self.num_dec_nodes)
        right = torch.zeros(self.num_leaves, self.num_dec_nodes)
        for leaf in range(self.num_leaves):
            node = 0
            for level in range(depth):
                # Bits of the leaf index, MSB first, spell out the route:
                # 0 = left, 1 = right.
                go_right = (leaf >> (depth - 1 - level)) & 1
                if go_right:
                    right[leaf, node] = 1.0
                else:
                    left[leaf, node] = 1.0
                node = 2 * node + 1 + go_right
        # leaf_mask keeps its original name, values and persistence so
        # previously saved state_dicts still load; right_mask is derived
        # from the tree shape, so it is not persisted.
        self.register_buffer('leaf_mask', left)
        self.register_buffer('right_mask', right, persistent=False)

    def _leaf_log_reach(self, x):
        """Return log P(reach leaf) with shape (B, num_trees, num_leaves)."""
        logits = self.router(x).view(x.size(0), self.num_trees,
                                     self.num_dec_nodes)
        # logsigmoid stays finite and accurate for large |logit|, unlike
        # log(clamp(sigmoid(logit))).
        log_left = nn.functional.logsigmoid(logits)
        log_right = nn.functional.logsigmoid(-logits)
        # (B, T, D) @ (D, L) -> (B, T, L): sum of log split probabilities
        # along each root-to-leaf path.
        return (log_left @ self.leaf_mask.t()
                + log_right @ self.right_mask.t())

    def forward(self, x):
        """Return the forest's log class probabilities, shape (B, num_classes).

        ``x`` is expected to be 2-D, ``(batch, input_dim)``.
        """
        log_reach = self._leaf_log_reach(x)                    # (B, T, L)
        log_leaf = torch.log_softmax(self.leaf_probs, dim=-1)  # (T, L, C)
        # Per-tree class distribution: sum over leaves of reach * leaf dist.
        log_tree = torch.logsumexp(
            log_reach.unsqueeze(-1) + log_leaf.unsqueeze(0), dim=2
        )                                                      # (B, T, C)
        # Equal-weight mixture of the trees' distributions.
        return torch.logsumexp(log_tree, dim=1) - self._log_num_trees

In [None]:
# A forest of 10 depth-4 trees over the columns of feat_tensor, with a
# two-class output, placed on the notebook's compute device.
dnf = NeuralDecisionForest(
    feat_tensor.shape[1],  # input_dim
    num_trees=10,
    depth=4,
    num_classes=2,
).to(device)

## Training Setup

In [None]:
# Adam with a small weight decay. The scheduler halves the learning rate
# (down to 1e-6) after 5 steps without improvement. mode='max' because it
# is stepped with a higher-is-better score.
opt = torch.optim.Adam(params=dnf.parameters(), lr=5e-4, weight_decay=1e-5)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt,
    mode='max',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
)

In [None]:
# Mini-batch training. After each epoch, score the validation rows with
# average precision, keep the best checkpoint on disk, and stop after
# `patience` epochs without improvement.
best_val, wait = -1, 0
patience = 15
max_epochs = 200
batch_size = 1024

for epoch in range(1, max_epochs + 1):
    dnf.train()
    # np.random.permutation returns a shuffled COPY. Shuffling
    # train_idx.cpu().numpy() in place would permute train_idx itself when
    # it lives on the CPU, because Tensor.numpy() shares its storage.
    idxs = np.random.permutation(train_idx.cpu().numpy())
    losses = []
    for b in range(0, len(idxs), batch_size):
        batch = torch.tensor(idxs[b:b + batch_size], device=device)
        logits = dnf(feat_tensor[batch])
        loss = criterion(logits, labels[batch])
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())

    dnf.eval()
    with torch.no_grad():
        val_logits = dnf(feat_tensor[val_idx])
        val_probs = F.softmax(val_logits, dim=1)[:, 1].cpu().numpy()
        val_true = labels[val_idx].cpu().numpy()
    val_pr = average_precision_score(val_true, val_probs)
    # The scheduler was built with mode='max': a rising PR-AUC is progress.
    sched.step(val_pr)

    print(f"Epoch {epoch:03d}  loss={np.mean(losses):.4f}  Val PR={val_pr:.4f}")
    if val_pr > best_val:
        best_val, wait = val_pr, 0
        torch.save(dnf.state_dict(), "dnf_best.pt")
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping.")
            break

# Restore the best checkpoint; map_location keeps its tensors on `device`.
dnf.load_state_dict(torch.load("dnf_best.pt", map_location=device))

Epoch 001  loss=0.6916  Val PR=0.6931
Epoch 002  loss=0.6304  Val PR=0.6527
Epoch 003  loss=0.5198  Val PR=0.3922
Epoch 004  loss=0.4196  Val PR=0.3108
Epoch 005  loss=0.3364  Val PR=0.2693
Epoch 006  loss=0.2848  Val PR=0.2944
Epoch 007  loss=0.2583  Val PR=0.3605
Epoch 008  loss=0.2447  Val PR=0.3881
Epoch 009  loss=0.2370  Val PR=0.4454
Epoch 010  loss=0.2291  Val PR=0.5305
Epoch 011  loss=0.2216  Val PR=0.6476
Epoch 012  loss=0.2161  Val PR=0.7231
Epoch 013  loss=0.2101  Val PR=0.7512
Epoch 014  loss=0.2059  Val PR=0.7796
Epoch 015  loss=0.1997  Val PR=0.7918
Epoch 016  loss=0.1975  Val PR=0.8031
Epoch 017  loss=0.1924  Val PR=0.8109
Epoch 018  loss=0.1888  Val PR=0.8126
Epoch 019  loss=0.1872  Val PR=0.8153
Epoch 020  loss=0.1849  Val PR=0.8184
Epoch 021  loss=0.1824  Val PR=0.8192
Epoch 022  loss=0.1797  Val PR=0.8207
Epoch 023  loss=0.1788  Val PR=0.8225
Epoch 024  loss=0.1759  Val PR=0.8243
Epoch 025  loss=0.1760  Val PR=0.8248
Epoch 026  loss=0.1734  Val PR=0.8262
Epoch 027  l

<All keys matched successfully>

## Final Test Evaluation

In [None]:
# Score every row once in inference mode, then keep only the test rows.
dnf.eval()
with torch.no_grad():
    logits_all = dnf(feat_tensor)
    labels_all = labels.cpu().numpy()
    probs_all = torch.softmax(logits_all, dim=1)[:, 1].cpu().numpy()

y_test_true = labels_all[test_idx.cpu().numpy()]
y_test_prob = probs_all[test_idx.cpu().numpy()]  # probability of class 1
M_test = len(y_test_true)  # number of test rows
# Mean test label, i.e. the positive rate if labels are 0/1.
prevalence = y_test_true.mean()

In [None]:
# Number of highest-scored test rows flagged positive at each operating
# point. All counts are sized for the full test set (M_test rows), so
# scoring a subset of it requires scaling them.
# "Prevalence" flags as many rows as there are positives. round() guards
# against float error (e.g. 299.9999 -> 299) that int() alone truncates.
cutoffs = {
    "Top 0.1%": max(1, int(0.001 * M_test)),
    "Top 1%":   max(1, int(0.01  * M_test)),
    "Top 10%":  max(1, int(0.10  * M_test)),
    "Prevalence": max(1, int(round(prevalence * M_test))),
}

In [None]:
# One list of per-run scores per metric: ROC/PR AUC plus precision,
# recall and F1 for every cutoff (all _P keys, then _R, then _F1).
n_runs = 100
metrics = {
    "roc_auc": [],
    "pr_auc": [],
    **{f"{name}_{suffix}": [] for suffix in ("P", "R", "F1") for name in cutoffs},
}

half_size = M_test // 2            # rows drawn per run
rng = np.random.RandomState(42)    # fixed seed -> reproducible draws

In [None]:
# Each run scores a random half of the test set drawn WITHOUT replacement.
for run_i in range(n_runs):
    idxs = rng.choice(M_test, size=half_size, replace=False)
    y_bs = y_test_true[idxs]
    p_bs = y_test_prob[idxs]
    n_sub = len(y_bs)

    metrics["roc_auc"].append(roc_auc_score(y_bs, p_bs))
    metrics["pr_auc"].append(average_precision_score(y_bs, p_bs))

    order = np.argsort(p_bs)  # ascending: highest scores are at the end
    for name, k_full in cutoffs.items():
        # cutoffs holds counts for the full test set (M_test rows). Scale
        # them to this subsample, so "Top 1%" means 1% of the rows scored
        # here and "Prevalence" flags about as many rows as it has positives.
        # Without this, every cutoff flagged ~2x too many rows.
        k = max(1, int(round(k_full * n_sub / M_test)))
        pred = np.zeros_like(p_bs, dtype=int)
        pred[order[-k:]] = 1

        metrics[f"{name}_P"].append(precision_score(y_bs, pred, zero_division=0))
        metrics[f"{name}_R"].append(recall_score(y_bs, pred, zero_division=0))
        metrics[f"{name}_F1"].append(f1_score(y_bs, pred, zero_division=0))

In [None]:
def fmt(arr):
    """Render numbers as "mean ± std" with three decimals (population std, ddof=0)."""
    values = np.asarray(arr)
    centre, spread = values.mean(), values.std()
    return f"{centre:.3f} ± {spread:.3f}"

# Report mean ± std over the resampling runs. Each run scores a random half
# of the test set drawn without replacement, so these are subsamples rather
# than bootstrap resamples. The run count comes from n_runs, not a literal.
print(f"\n=== DNDF Test Results (n={n_runs} half-test subsamples) ===")
print(f"AUC-ROC  : {fmt(metrics['roc_auc'])}")
print(f"AUC-PR   : {fmt(metrics['pr_auc'])}")
for name in cutoffs:
    print(f"{name:12} Precision: {fmt(metrics[f'{name}_P'])}")
    print(f"{name:12} Recall   : {fmt(metrics[f'{name}_R'])}")
    print(f"{name:12} F1-score : {fmt(metrics[f'{name}_F1'])}")



=== DNDF Test Results (n=100 bootstraps) ===
AUC-ROC  : 0.759 ± 0.013
AUC-PR   : 0.584 ± 0.020
Top 0.1%     Precision: 1.000 ± 0.000
Top 0.1%     Recall   : 0.035 ± 0.002
Top 0.1%     F1-score : 0.067 ± 0.003
Top 1%       Precision: 0.995 ± 0.005
Top 1%       Recall   : 0.378 ± 0.018
Top 1%       F1-score : 0.548 ± 0.019
Top 10%      Precision: 0.159 ± 0.009
Top 10%      Recall   : 0.607 ± 0.021
Top 10%      F1-score : 0.252 ± 0.013
Prevalence   Precision: 0.290 ± 0.017
Prevalence   Recall   : 0.582 ± 0.021
Prevalence   F1-score : 0.387 ± 0.019
