# Unraveling scalar mults and countermeasures

In [None]:
import pickle
import itertools
import glob
import random
import math

from collections import Counter

import numpy as np
import pandas as pd
from scipy.stats import binom, entropy
from scipy.spatial import distance
from tqdm.auto import tqdm, trange
from statsmodels.stats.proportion import proportion_confint
from anytree import PreOrderIter, Walker

from pyecsca.ec.mult import *
from pyecsca.misc.utils import TaskExecutor, silent
from pyecsca.sca.re.tree import Map, Tree

from common import *

In [None]:
def conf_interval(p: float, samples: int, alpha: float = 0.05) -> tuple[float, float]:
    """Return the Wilson confidence interval of a proportion.

    The proportion ``p`` is first turned into a whole number of successes out
    of ``samples`` (by rounding), which is what ``proportion_confint`` is
    called with.

    Args:
        p: Observed (or assumed) proportion.
        samples: Number of observations the proportion is based on.
        alpha: Significance level passed on to ``proportion_confint``.

    Returns:
        The interval as returned by ``proportion_confint`` with
        ``method="wilson"``.
    """
    success_count = round(p * samples)
    interval = proportion_confint(success_count, samples, alpha, method="wilson")
    return interval

In [None]:
def powers_of(k, max_power=20):
    """Return the list ``[k**1, k**2, ..., k**(max_power - 1)]``.

    ``max_power`` is an exclusive bound on the exponent, so the default
    yields 19 values.
    """
    powers = []
    for exponent in range(1, max_power):
        powers.append(k**exponent)
    return powers

def prod_combine(one, other):
    """Return the products ``a * b`` for every pair of the Cartesian product.

    The pairs are enumerated by ``itertools.product(one, other)``, i.e. the
    elements of ``other`` vary fastest.
    """
    products = []
    for left, right in itertools.product(one, other):
        products.append(left * right)
    return products

# The primes from 3 to 199 (2 is not included).
small_primes = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199]
# The primes from 211 to 397.
medium_primes = [211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397]
# The primes from 401 to 997.
large_primes = [401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]
# All integers 1..399, and the even / odd ones among them.
all_integers = list(range(1, 400))
all_even = list(range(2, 400, 2))
all_odd = list(range(1, 400, 2))
# Concatenation of the three prime lists above (3..997).
all_primes = small_primes + medium_primes + large_primes

# Named divisor sets that can be selected as the feature set.
divisor_map = {
    "small_primes": small_primes,
    "medium_primes": medium_primes,
    "large_primes": large_primes,
    "all_primes": all_primes,
    "all_integers": all_integers,
    "all_even": all_even,
    "all_odd": all_odd,
    "powers_of_2": powers_of(2),
    "powers_of_2_large": powers_of(2, 256),
    # 3 * 2^e, 2^e + 1 and 2^e - 1 for the exponents e = 1..255.
    "powers_of_2_large_3": [3 * power for power in powers_of(2, 256)],
    "powers_of_2_large_p1": [power + 1 for power in powers_of(2, 256)],
    "powers_of_2_large_m1": [power - 1 for power in powers_of(2, 256)],
    # Positive values 2^e + offset for offset in -5..4, deduplicated and sorted.
    "powers_of_2_large_pmautobus": sorted(
        {
            power + offset
            for power in powers_of(2, 256)
            for offset in range(-5, 5)
            if power + offset > 0
        }
    ),
    "powers_of_3": powers_of(3),
}
# "all" is the sorted union of every set defined above.
divisor_map["all"] = sorted(set().union(*divisor_map.values()))

## Prepare
Select *divisor name* to restrict the features. Select *kind* to pick the probmap source.

In [None]:
# Key of divisor_map that selects the divisors used as features.
divisor_name = "all"
# Part of the file name of the pickled distributions loaded in the next cell.
kind = "all"
# The divisors (features) selected by divisor_name.
allfeats = divisor_map[divisor_name]

In [None]:
# Load the pickled distributions (multiplier -> probmap) for the selected
# divisor set and kind.
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
try:
    with open(f"{divisor_name}_{kind}_distrs.pickle", "rb") as f:
        distributions_mults = pickle.load(f)
except FileNotFoundError:
    # No file specific to this divisor set exists: fall back to the "all"
    # file and narrow each probmap to the selected features
    # (presumably narrow() restricts the probmap to the given divisors -- TODO confirm).
    with open(f"all_{kind}_distrs.pickle", "rb") as f:
        distributions_mults = pickle.load(f)
    for probmap in distributions_mults.values():
        probmap.narrow(allfeats)

In [None]:
# Number of multipliers (classes) and number of selected divisors (features).
nmults = len(distributions_mults.keys())
nallfeats = len(allfeats)

## Build dmap and tree

Select the n for building the tree.

In [None]:
# Sample count assumed when computing the confidence intervals used for clustering.
nbuild = 10000
# Significance level of those confidence intervals.
alpha = 0.05

In [None]:
# For every divisor, cluster the multipliers by their probability for that
# divisor: multipliers are visited from the highest to the lowest probability
# and a new cluster starts whenever the confidence interval (for nbuild
# samples) of a multiplier lies entirely below that of the previous one.
io_map = {mult: {} for mult in distributions_mults.keys()}
for divisor in allfeats:
    prev_ci_low = None
    groups = {}
    pvals = {}
    group = 0
    for mult, probmap in sorted(distributions_mults.items(), key=lambda item: -item[1][divisor]):
        # We are going from high to low p.
        pval = probmap[divisor]
        pvals[mult] = pval
        ci_low, ci_high = conf_interval(pval, nbuild, alpha)
        ci_low = max(ci_low, 0.0)
        ci_high = min(ci_high, 1.0)
        # Open a new group *before* adding the multiplier whose interval no
        # longer overlaps the previous one, so that it becomes the first
        # member of the new group instead of the last member of the old one.
        if prev_ci_low is not None and prev_ci_low >= ci_high:
            group += 1
        groups.setdefault(f"arbitrary{group}", set()).add(mult)
        prev_ci_low = ci_low

    # The response stored for a multiplier is its group label together with
    # summary statistics of the probabilities inside that group.
    for group_name, mults in groups.items():
        mult_pvals = [pvals[mult] for mult in mults]
        group_pval_avg = np.mean(mult_pvals)
        group_pval_var = np.var(mult_pvals)
        group_pval_min = np.min(mult_pvals)
        group_pval_max = np.max(mult_pvals)
        for mult in mults:
            io_map[mult][divisor] = (group_name, group_pval_avg, group_pval_var, group_pval_min, group_pval_max)

# then build dmap
dmap = Map.from_io_maps(set(distributions_mults.keys()), io_map)

In [None]:
# Print a summary of the dmap built in the previous cell.
print(dmap.describe())

In [None]:
# Deduplicate the dmap (the return value is not used, so this presumably
# modifies dmap in place -- TODO confirm).
dmap.deduplicate()

In [None]:
# Print the dmap summary again, now after the deduplicate() call.
print(dmap.describe())

In [None]:
# Build a tree over all multipliers from the dmap; the call is wrapped in
# silent() (presumably to suppress its output -- TODO confirm).
with silent():
    tree = Tree.build(set(distributions_mults.keys()), dmap)

In [None]:
# Print a summary of the built tree.
print(tree.describe())

In [None]:
# Print the basic rendering of the tree.
print(tree.render_basic())

## Simulate distinguishing using a tree
We can now simulate distinguishing using the tree and observe how its success rate changes as the number of samples collected per divisor increases.

In [None]:
simulations = 1000
# The candidate multipliers do not change during the simulation.
candidate_mults = list(distributions_mults.keys())

# For a growing number of samples per divisor, walk the tree with simulated
# measurements of a randomly chosen multiplier and count how often the walk
# ends in a leaf that contains it.
for nattack in trange(100, 10000, 100):
    successes = 0
    pathiness = 0
    for _ in range(simulations):
        true_mult = random.choice(candidate_mults)
        probmap = distributions_mults[true_mult]
        node = tree.root
        while not node.is_leaf:
            divisor = node.dmap_input
            # Empirical probability from nattack simulated trials.
            sampled_prob = binom(nattack, probmap[divisor]).rvs() / nattack
            chosen = None
            chosen_distance = None
            for child in node.children:
                _, _, _, lowest, highest = child.response
                # Distance of the sample to the nearer end of the group's range.
                gap = min(abs(sampled_prob - lowest), abs(sampled_prob - highest))
                if chosen is None or gap < chosen_distance:
                    chosen = child
                    chosen_distance = gap
                # A group whose range strictly contains the sample wins outright.
                if lowest < sampled_prob < highest:
                    chosen = child
                    break
            node = chosen
            # Count every step that stays on a path containing the true multiplier.
            if true_mult in node.cfgs:
                pathiness += 1
        if true_mult in node.cfgs:
            successes += 1
    print(f"{nattack}: success rate {successes/simulations}, pathiness {pathiness/simulations}")

## Simulate distinguishing using a distance metric

We need to first select some features (divisors) from the set of all divisors that we will query
the target with. This set should be the smallest (to not do a lot of queries) yet allow us to distinguish as
much as possible.

### Feature selection using trees + classification error

We can reuse the clustering + tree building approach above and just take the inputs that the greedy tree building chooses as the features. However, we can also use more conventional feature selection approaches.

In [None]:
# The set of dmap inputs (divisors) queried at the inner nodes of the tree.
feats_in_tree = {
    node.dmap_input
    for node in PreOrderIter(tree.root)
    if not node.is_leaf
}

In [None]:
def bayes(nattack: int, feat_vector: list[int], feats, probmap):
    """Score a candidate multiplier by the log-likelihood of the observation.

    Args:
        nattack: Number of trials each feature was sampled with.
        feat_vector: Observed success counts, one per entry of ``feats``.
        feats: The divisors (features) that were sampled.
        probmap: Mapping from divisor to the candidate's success probability.

    Returns:
        The sum over all features of the binomial log-probability of the
        observed count under the candidate's probability.
    """
    log_likelihood = 0.0
    for sampled, divisor in zip(feat_vector, feats):
        # Call the distribution directly; no frozen object is needed per feature.
        log_likelihood += binom.logpmf(sampled, nattack, probmap[divisor])
    return log_likelihood


# A higher log-likelihood is better, so rankings sort in descending order.
# Set at definition time so the attribute exists before the first call.
bayes.reverse = True

def euclid(nattack: int, feat_vector: list[int], feats, probmap):
    """Score a candidate multiplier by Euclidean distance to the observation.

    The observed success counts are divided by ``nattack`` so that they are
    compared with the candidate's probabilities on the same scale.

    Args:
        nattack: Number of trials each feature was sampled with.
        feat_vector: Observed success counts, one per entry of ``feats``.
        feats: The divisors (features) that were sampled.
        probmap: Mapping from divisor to the candidate's success probability.

    Returns:
        The Euclidean distance between the observed frequencies and the
        candidate's probabilities.
    """
    observed = np.asarray(feat_vector, dtype=np.float64) / nattack
    # Size the vector by the features actually passed in, not by a global.
    expected = np.array([probmap[divisor] for divisor in feats], dtype=np.float64)
    return distance.euclidean(observed, expected)


# A smaller distance is better, so rankings sort in ascending order.
# Set at definition time so the attribute exists before the first call.
euclid.reverse = False

# TODO: Adjust scorers to penalize/reject when sampled prob of a feature is != 1.0 but the mult has that feature at 1.0 proba.

def one_simulation(nattack, true_mult, mults, feats, scorer):
    """Simulate one measurement of ``true_mult`` and rank all multipliers.

    For every feature a success count out of ``nattack`` trials is drawn from
    the true multiplier's probability. Every multiplier in ``mults`` is then
    scored against this sample with ``scorer`` and the scores are sorted
    (descending if ``scorer.reverse`` is true, ascending otherwise).

    Returns:
        The zero-based position of ``true_mult`` in the ranking (``None`` if
        it does not appear in it).
    """
    true_probmap = mults[true_mult]
    feat_vector = [binom(nattack, true_probmap[divisor]).rvs() for divisor in feats]
    ranking = [
        (scorer(nattack, feat_vector, feats, candidate_probmap), candidate)
        for candidate, candidate_probmap in mults.items()
    ]
    ranking.sort(key=lambda entry: entry[0], reverse=scorer.reverse)
    for position, (_, candidate) in enumerate(ranking):
        if candidate == true_mult:
            return position
    return None

def many_simulations(nattack, mults, feats, scorer, simulations):
    """Run ``simulations`` simulated attacks and aggregate the rankings.

    When there are at most ``simulations`` multipliers, the true multiplier
    cycles through all of them in order (so each one is used at least once);
    otherwise it is drawn at random for every simulation.

    Args:
        nattack: Number of trials per feature in each simulation.
        mults: Mapping from multiplier to its probmap.
        feats: The divisors (features) to sample.
        scorer: Scoring function passed on to ``one_simulation``.
        simulations: Number of simulations to run.

    Returns:
        A tuple ``(mean_pos, successes)``: the mean zero-based position of
        the true multiplier in the ranking, and a dict mapping ``k`` in
        1..10 to the fraction of simulations in which the true multiplier
        ranked within the top ``k``.
    """
    successes = {k: 0 for k in range(1, 11)}
    total_pos = 0
    mults_l = list(mults)
    cycle_through_all = len(mults_l) <= simulations
    for i in range(simulations):
        if cycle_through_all:
            # Wrap around: there are more simulations than multipliers, so a
            # plain mults_l[i] would run past the end of the list.
            true_mult = mults_l[i % len(mults_l)]
        else:
            true_mult = random.choice(mults_l)
        pos = one_simulation(nattack, true_mult, mults, feats, scorer)
        total_pos += pos
        # Position pos (zero-based) counts as a hit for every top-k with k > pos.
        for k in range(pos + 1, 11):
            successes[k] += 1
    mean_pos = total_pos / simulations
    for k in successes:
        successes[k] /= simulations
    return mean_pos, successes

def find_features_random(feat_subset, nfeat_range, nattack_range, num_workers, feat_retries, simulations, scorer):
    """Search for good feature subsets by evaluating random subsets.

    For every combination of a subset size from ``nfeat_range`` and a sample
    count from ``nattack_range``, ``feat_retries`` random subsets of
    ``feat_subset`` are evaluated with ``many_simulations`` (on the
    module-level ``distributions_mults``) and the subset with the lowest mean
    position is reported.

    Args:
        feat_subset: The features to draw the random subsets from.
        nfeat_range: Iterable of subset sizes to try.
        nattack_range: Iterable of per-feature sample counts to try.
        num_workers: Number of workers given to the ``TaskExecutor``.
        feat_retries: Number of random subsets tried per combination.
        simulations: Number of simulations per subset.
        scorer: Scoring function passed on to ``many_simulations``.

    Returns:
        A dict mapping ``(nfeats, nattack)`` to the tuple
        ``(best_feats, best_mean_pos, best_successes)``. Combinations for
        which no subset was evaluated are left out.
    """
    candidates = sorted(feat_subset)
    results = {}
    for nfeats in nfeat_range:
        for nattack in nattack_range:
            best_feats = None
            best_feats_mean_pos = None
            best_successes = None
            # Remember which subset belongs to which task: results arrive in
            # completion order, not in submission order.
            feats_by_retry = {}
            with TaskExecutor(max_workers=num_workers) as pool:
                for retry in range(feat_retries):
                    feats = random.sample(candidates, nfeats)
                    feats_by_retry[retry] = feats
                    pool.submit_task(retry,
                                     many_simulations,
                                     nattack, distributions_mults, feats, scorer, simulations)
                for retry, future in tqdm(pool.as_completed(), leave=False, desc="Retries", total=feat_retries, smoothing=0):
                    mean_pos, successes = future.result()
                    if best_feats is None or best_feats_mean_pos > mean_pos:
                        best_feats = feats_by_retry[retry]
                        best_feats_mean_pos = mean_pos
                        best_successes = successes

            if best_feats is None:
                # Nothing was evaluated (feat_retries is 0); nothing to report.
                continue
            print(f"Best results for {nfeats} feats at {nattack} samples out of {feat_retries} random feat subsets.")
            print(f"Features: {best_feats}")
            print(f"mean_pos: {best_feats_mean_pos:.2f}")
            print(f"top1: {best_successes[1]:.2f}, top2: {best_successes[2]:.2f}, top5: {best_successes[5]:.2f}, top10: {best_successes[10]:.2f}")
            results[(nfeats, nattack)] = (best_feats, best_feats_mean_pos, best_successes)
    return results

In [None]:
# Simulations per feature subset and number of random subsets per setting.
simulations = 500
retries = 200
# Progress-bar ranges: 1..10 features and 50..300 samples per divisor (step 50).
# NOTE(review): this rebinds the module-level names nfeats and nattack to trange objects.
nfeats = trange(1, 11, leave=False, desc="nfeats")
nattack = trange(50, 350, 50, leave=False, desc="nattack")
num_workers = 30

# Random-subset feature search over the features used by the tree, scored with euclid.
selected_random_euclid = find_features_random(feats_in_tree, nfeats, nattack, num_workers, retries, simulations, euclid)

## Simulate distinguishing using a Bayes classifier

We need to first select some features (divisors) from the set of all divisors that we will query
the target with. This set should be the smallest (to not do a lot of queries) yet allow us to distinguish as
much as possible.

Then, we can build a true Bayes classifier. Since our features are conditionally independent (when conditioned on the class label) in our case naive Bayes == non-naive Bayes. We examine four feature selection algorithms:
 - Feature selection by pre-selection using tree-building and final selection by random subsets + classification error.
 - Feature selection via greedy classification error.
 - Feature selection via mRMR (maximal relevance, minimal redundancy) using mutual information.
 - Feature selection via JMI (Joint Mutual Information).

### Feature selection using trees + classification error

We can reuse the clustering + tree building approach above and just take the inputs that the greedy tree building chooses as the features. However, we can also use more conventional feature selection approaches.

In [None]:
# The set of dmap inputs (divisors) queried at the inner nodes of the tree.
feats_in_tree = {
    node.dmap_input
    for node in PreOrderIter(tree.root)
    if not node.is_leaf
}

In [None]:
# Simulations per feature subset and number of random subsets per setting.
simulations = 500
retries = 200
# Progress-bar ranges: 1..10 features and 50..300 samples per divisor (step 50).
# NOTE(review): this rebinds the module-level names nfeats and nattack to trange objects.
nfeats = trange(1, 11, leave=False, desc="nfeats")
nattack = trange(50, 350, 50, leave=False, desc="nattack")
num_workers = 30

# Random-subset feature search over the features used by the tree, scored with bayes.
selected_random_bayes = find_features_random(feats_in_tree, nfeats, nattack, num_workers, retries, simulations, bayes)

### Feature selection via greedy classification
We can also use the classifier itself for feature selection. We iterate over all the divisors to pick the first feature with the best classifier results in simulation. Then we iteratively add features to it.

In [None]:
def find_features_greedy(nfeats, nattack, num_workers, simulations, scorer, start_features=None):
    """Greedily pick features by simulated classification performance.

    Starting from ``start_features`` (or from nothing), every remaining
    candidate is added to the current feature list in turn and evaluated with
    ``many_simulations`` (on the module-level ``distributions_mults``). The
    candidate giving the lowest mean position is kept. This repeats until
    ``nfeats`` features are picked.

    Args:
        nfeats: Total number of features to end up with.
        nattack: Number of trials per feature in each simulation.
        num_workers: Number of workers given to the ``TaskExecutor``.
        simulations: Number of simulations per candidate.
        scorer: Scoring function passed on to ``many_simulations``.
        start_features: Optional features that are already picked.

    Returns:
        The list of picked features, ``start_features`` first.

    Raises:
        ValueError: If ``start_features`` already has ``nfeats`` or more
            entries, if one of them is not an available feature, or if the
            candidates run out before ``nfeats`` features are picked.
    """
    # NOTE(review): selected_divisors is not defined in the visible part of
    # this file; presumably it comes from `common` -- TODO confirm.
    # Work on a copy so that the module-level list is not modified.
    available_feats = list(selected_divisors)
    feats = []
    if start_features is not None:
        if nfeats <= len(start_features):
            raise ValueError("Features already picked.")
        feats.extend(start_features)
        for feat in start_features:
            available_feats.remove(feat)

    while len(feats) < nfeats:
        if not available_feats:
            raise ValueError("No features left to pick from.")
        best_feat = None
        best_feat_mean_pos = None
        best_successes = None
        # A fresh executor per round keeps the results of earlier rounds out
        # of this round's as_completed() iteration.
        # NOTE(review): whether TaskExecutor forgets finished tasks on its own
        # is not visible here.
        with TaskExecutor(max_workers=num_workers) as pool:
            for feat in available_feats:
                pool.submit_task(feat,
                                 many_simulations,
                                 nattack, distributions_mults, feats + [feat], scorer, simulations)
            for feat, future in tqdm(pool.as_completed(), total=len(available_feats), desc=f"Picking feature {len(feats)}", smoothing=0):
                mean_pos, successes = future.result()
                if best_feat is None or best_feat_mean_pos > mean_pos:
                    best_feat = feat
                    best_feat_mean_pos = mean_pos
                    best_successes = successes
        print(f"Picked {best_feat} with mean pos: {best_feat_mean_pos:.2f}")
        print(f"top1: {best_successes[1]:.2f}, top2: {best_successes[2]:.2f}, top5: {best_successes[5]:.2f}, top10: {best_successes[10]:.2f}")
        feats.append(best_feat)
        available_feats.remove(best_feat)
    return feats

In [None]:
# Parameters of the greedy search: number of features to pick, samples per
# divisor, worker count and simulations per candidate.
nfeats = 5
nattack = 100
num_workers = 30
simulations = 500
# Scoring function used to rank the multipliers in the simulations.
scorer = bayes

selected_greedy = find_features_greedy(nfeats, nattack, num_workers, simulations, scorer)

### Feature selection via mRMR using mutual information

In [None]:
def mutual_information(class_priors, p_ci_list, n):
    """
    Compute the mutual information I(X; Y) = H(Y) - H(Y | X) in bits, where
    X is a binomial feature whose success probability depends on the class Y.

    Args:
        class_priors (np.array): P(Y=c), shape (num_classes,)
        p_ci_list (np.array): Binomial parameter of the feature for each class, shape (num_classes,)
        n (int): Number of trials in the binomial distribution

    Returns:
        float: Mutual information I(X; Y)
    """
    # Rows are the outcomes x = 0..n, columns are the classes.
    outcomes = np.arange(0, n + 1)[:, None]
    likelihoods = binom.pmf(outcomes, n, p_ci_list[None, :])

    # P(X=x, Y=c) and its marginal P(X=x).
    joint = likelihoods * class_priors[None, :]
    marginal = np.sum(joint, axis=1)

    # H(Y | X) = sum_x P(X=x) * H(Y | X=x); negligible outcomes are skipped.
    conditional_entropy = 0.0
    for joint_row, p_x in zip(joint, marginal):
        if p_x < 1e-9:
            continue
        posterior = joint_row / p_x
        conditional_entropy += p_x * entropy(posterior, base=2)

    prior_entropy = entropy(class_priors, base=2)
    return prior_entropy - conditional_entropy


def mutual_information_between_features(class_priors, p_ci_i, p_ci_j, n):
    """
    Compute the mutual information I(X_i; X_j) = H(X_i) + H(X_j) - H(X_i, X_j)
    in bits between two binomial features that are independent given the class.

    Parameters:
        class_priors (array): Prior probabilities of each class. Shape: (num_classes,)
        p_ci_i (array): Binomial parameters for feature i across classes. Shape: (num_classes,)
        p_ci_j (array): Binomial parameters for feature j across classes. Shape: (num_classes,)
        n (int): Number of trials for the binomial distribution.

    Returns:
        float: Mutual information I(X_i; X_j)
    """
    counts = np.arange(0, n + 1)  # Possible values of a feature
    counts_column = counts[:, None]
    prior_row = class_priors[None, :]

    def marginal_entropy(params):
        # Entropy of the feature's marginal distribution (mixture over classes).
        per_class = binom.pmf(counts_column, n, params[None, :])
        marginal = np.sum(per_class * prior_row, axis=1)
        if np.allclose(marginal, 0.0):
            return 0.0
        return entropy(marginal, base=2)

    entropy_i = marginal_entropy(p_ci_i)
    entropy_j = marginal_entropy(p_ci_j)

    # Joint distribution: mixture over the classes of the per-class outer products.
    joint = np.zeros((n + 1, n + 1))
    for c in range(len(class_priors)):
        pmf_i = binom.pmf(counts, n, p_ci_i[c])
        pmf_j = binom.pmf(counts, n, p_ci_j[c])
        joint += class_priors[c] * np.outer(pmf_i, pmf_j)

    # Only cells above the threshold contribute, which avoids log(0).
    significant = joint > 1e-10
    entropy_joint = -np.sum(joint[significant] * np.log2(joint[significant]))

    return entropy_i + entropy_j - entropy_joint


def conditional_mutual_info(class_priors, XJ_params, XK_params, n):
    """
    Compute I(XK; Y | XJ) = H(Y | XJ) - H(Y | XJ, XK) in bits for two
    binomial features that are independent given the class.

    Args:
        class_priors (array): P(Y=c) for all classes c.
        XJ_params (array): p_{c,J} for all classes c.
        XK_params (array): p_{c,K} for all classes c.
        n (int): Number of trials in the binomial distribution.

    Returns:
        float: Conditional mutual information I(XK; Y | XJ), clamped to be
        at least 0.0.
    """
    x_values = np.arange(n + 1)

    # Per-class PMFs, transposed so that the class is the first axis
    # (shape (K, n+1) for 1-D parameter arrays of length K).
    P_XJ_T = binom.pmf(x_values[:, None], n, XJ_params).T
    P_XK_T = binom.pmf(x_values[:, None], n, XK_params).T

    ######################################################################
    ### Compute H(Y | XJ) ###############################################
    ######################################################################

    # P(XJ=xj) for all xj
    P_XJ_total = np.dot(class_priors, P_XJ_T)

    # Numerators of the posterior probabilities P(Y=c | XJ=xj)
    numerators_YgXJ = class_priors[:, None] * P_XJ_T

    # Posteriors are only computed for outcomes with non-negligible probability.
    valid_mask = P_XJ_total > 1e-9
    posterior_YgXJ = np.zeros_like(numerators_YgXJ, dtype=float)
    posterior_YgXJ[:, valid_mask] = (
        numerators_YgXJ[:, valid_mask] /
        P_XJ_total[valid_mask]
    )

    # The small offset keeps the logarithm finite; negligible posteriors are
    # left out of the sum.
    log_p = np.log2(posterior_YgXJ + 1e-9)
    entropy_terms_HYgXJ = -np.sum(
        posterior_YgXJ * log_p,
        axis=0,
        where=(posterior_YgXJ > 1e-9)
    )

    H_Y_given_XJ = np.dot(P_XJ_total, entropy_terms_HYgXJ)

    ######################################################################
    ### Compute H(Y | XJ, XK) ###########################################
    ######################################################################

    # Broadcast to the joint PMF P(XJ=xj, XK=xk | Y=c), shape (K, n+1, n+1)
    joint_pmf_conditional = P_XJ_T[..., None] * P_XK_T[:, None, :]

    numerators = class_priors[:, None, None] * joint_pmf_conditional

    denominators = np.sum(numerators, axis=0)  # P(XJ=xj, XK=xk), shape (n+1, n+1)

    # Divide only where the denominator is non-negligible; the other entries
    # stay 0.0, so no 0/0 is ever evaluated.
    valid_mask_3d = np.broadcast_to(denominators > 1e-9, numerators.shape)
    posterior_YgXJXK = np.divide(
        numerators,
        denominators[None, ...],
        out=np.zeros_like(numerators, dtype=float),
        where=valid_mask_3d,
    )

    log_p_joint = np.log2(posterior_YgXJXK + 1e-9)
    entropy_terms_HYgXJXK = -np.sum(
        posterior_YgXJXK * log_p_joint,
        axis=0,  # Sum over classes
        where=(posterior_YgXJXK > 1e-9),
    )

    H_Y_given_XJXK = np.sum(denominators * entropy_terms_HYgXJXK)

    ######################################################################
    ### Compute CMI #####################################################
    ######################################################################

    cmi = H_Y_given_XJ - H_Y_given_XJXK

    # Numerical error can push the difference slightly below zero.
    return max(cmi, 0.0)

#### Relevance and redundancy
First, let's pre-compute the relevance and redundancy metrics for mRMR (relevance is also used in JMI). We assume a uniform class prior.

In [None]:
# Uniform class prior and the matrix of probabilities: one row per divisor
# (feature), one column per multiplier (class).
priors = np.full(nmults, 1/nmults, dtype=np.float64)
probs = np.zeros((nallfeats, nmults), dtype=np.float64)
for col, probmap in enumerate(distributions_mults.values()):
    for row, divisor in enumerate(allfeats):
        probs[row, col] = probmap[divisor]

# Relevance of a feature is its mutual information with the class at nattack samples.
nattack = 100
relevance = np.zeros(nallfeats, dtype=np.float64)
for row in range(nallfeats):
    relevance[row] = mutual_information(priors, probs[row], nattack)
mis = sorted(zip(relevance, allfeats), key=lambda item: item[0], reverse=True)

print("Top 10 feats")
for mi, divisor in mis[:10]:
    print(f"{divisor} {mi:.3f}")

In [None]:
num_workers = 30

# Pairwise mutual information between the features. Only the pairs with
# j <= i are computed; each result is written to both (i, j) and (j, i).
redundancy = np.zeros((nallfeats, nallfeats), dtype=np.float64)
with TaskExecutor(max_workers=num_workers) as pool:
    for i in trange(nallfeats):
        for j in range(i + 1):
            pool.submit_task((i, j),
                             mutual_information_between_features,
                             priors, probs[i], probs[j], nattack)
        for (row, col), future in pool.as_completed():
            redundancy[row][col] = redundancy[col][row] = future.result()


Store the relevance and redundancy arrays.

In [None]:
# Pickle both arrays, each into its own file in the working directory.
for path, payload in (("relevance.pickle", relevance), ("redundancy.pickle", redundancy)):
    with open(path, "wb") as f:
        pickle.dump(payload, f)

In [None]:
def mrmr_selection(relevance, redundancy, nfeats):
    """
    Select top features using mRMR (maximal relevance, minimal redundancy).

    The most relevant feature is always picked first. After that, the feature
    maximizing ``relevance - mean redundancy with the already selected
    features`` is added until ``nfeats`` features are selected or no feature
    is left.

    Args:
        relevance: Relevance of every feature, shape (num_features,).
        redundancy: Pairwise redundancy of the features,
            shape (num_features, num_features).
        nfeats: Number of features to select.

    Returns:
        indices of selected features, in the order they were picked.
    """
    relevance = np.asarray(relevance)
    redundancy = np.asarray(redundancy)

    # The number of features comes from the input, not from a global.
    remaining_indices = list(range(len(relevance)))

    # Initialize by selecting the most relevant feature
    first_feature_idx = int(np.argmax(relevance))
    selected_indices = [first_feature_idx]
    remaining_indices.remove(first_feature_idx)

    # Stop early when the features run out instead of failing on an empty argmax.
    while len(selected_indices) < nfeats and remaining_indices:
        candidates_scores = []
        for candidate in remaining_indices:
            # mRMR score: relevance - average redundancy with selected features
            sum_red = np.sum(redundancy[candidate][selected_indices])
            avg_red = sum_red / len(selected_indices)
            candidates_scores.append(relevance[candidate] - avg_red)

        # Select the candidate with highest score
        best_candidate_idx = remaining_indices[int(np.argmax(candidates_scores))]
        selected_indices.append(best_candidate_idx)
        remaining_indices.remove(best_candidate_idx)

    return selected_indices

In [None]:
# Translate the feature indices returned by mrmr_selection back to divisors.
selected_mrmr = [allfeats[i] for i in mrmr_selection(relevance, redundancy, nfeats=5)]

### Feature selection via JMI

In [None]:
def jmi_selection(features_params_list, class_priors, n_trials, relevance, nfeats):
    """
    Select top features using JMI (Joint Mutual Information).

    The most relevant feature is always picked first. After that, the feature
    maximizing ``relevance + mean conditional mutual information
    I(candidate; Y | selected)`` over the already selected features is added
    until ``nfeats`` features are selected or no feature is left.

    Args:
        features_params_list: Per-feature binomial parameters for all
            classes; indexing by a feature index yields the parameters
            passed to ``conditional_mutual_info``.
        class_priors: P(Y=c) for all classes c.
        n_trials: Number of trials in the binomial distribution.
        relevance: Relevance of every feature.
        nfeats: Number of features to select.

    Returns:
        indices of selected features, in the order they were picked.
    """
    # The number of features comes from the input, not from a global.
    remaining_indices = list(range(len(features_params_list)))

    # Initialize by selecting the most relevant feature
    first_feature_idx = int(np.argmax(relevance))
    selected_indices = [first_feature_idx]
    remaining_indices.remove(first_feature_idx)

    # Stop early when the features run out instead of failing on an empty argmax.
    while len(selected_indices) < nfeats and remaining_indices:
        candidates_scores = []

        for candidate in tqdm(remaining_indices):
            # JMI score: relevance + average conditional mutual information
            # with the selected features
            XK_params = features_params_list[candidate]

            sum_cmi = 0.0
            for selected in selected_indices:
                sum_cmi += conditional_mutual_info(
                    class_priors=class_priors,
                    XJ_params=features_params_list[selected],
                    XK_params=XK_params,
                    n=n_trials
                )
            avg_cmi = sum_cmi / len(selected_indices)
            candidates_scores.append(relevance[candidate] + avg_cmi)

        # Select the candidate with highest score
        best_candidate_idx = remaining_indices[int(np.argmax(candidates_scores))]
        selected_indices.append(best_candidate_idx)
        remaining_indices.remove(best_candidate_idx)

    return selected_indices

In [None]:
# Translate the feature indices returned by jmi_selection back to divisors.
selected_jmi = [allfeats[i] for i in jmi_selection(probs, priors, nattack, relevance, nfeats=5)]