# Note

We found that the logifold augmented with 'largely perturbed samples' via the entropy method works well. However, for the original samples and for slightly perturbed adversarial samples, we got negative bars in the bar graph.

Therefore, we can try to match the ratio of low-entropy to high-entropy samples in the testing dataset with that of the validation dataset.

Here, we are going to use the largely perturbed PGD samples.

Also, for the standard perturbation, we ran into an over-fitting issue. Perhaps the high-entropy dataset is slightly skewed across classes, so we try to make it class-balanced.

Here, we are going to use

DeepFool, CWL2, and standard PGD (three adversarial sample sets in total).


In [None]:
from __future__ import annotations
import glob
from pathlib import Path
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import tensorflow as tf
from keras.models import load_model
from keras.utils import to_categorical
from keras.datasets import cifar10
from sklearn.model_selection import train_test_split

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# Define paths (everything hangs off the directory the notebook runs in).
ROOT = Path(".").resolve()
DATA = ROOT / "data"

# Inputs: model checkpoints, adversarial samples and project modules.
MODELS_DIR = DATA / "models"
ADV_MODELS_DIR = DATA / "adversarial_models"
ADV_SAMPLES = DATA / "samples"
EXPERTS_DIR = DATA / "specialized_models"
LOGIFOLD_MODS = ROOT / "logifold_modules"

# Cache locations (only named here; not created).
CACHE = DATA / "cache"
CACHE_PREDS, CACHE_METRICS, CACHE_INDEX = (CACHE / part for part in ("preds", "metrics", "index"))

# Output locations, created eagerly.
ANALYSIS = DATA / "analysis"
FIGURES = ANALYSIS / "figures"
REPORTS = ANALYSIS / "reports"
LGFD_PATH = DATA / "logifold"
for _out_dir in (ANALYSIS, FIGURES, REPORTS, LGFD_PATH):
    _out_dir.mkdir(parents=True, exist_ok=True)
del _out_dir

# Define Judge: sorted list of model paths matching the
# "resnet*original_tuned-once-on_original*" naming pattern (a list, not a directory).
JUDGES_DIR = sorted(glob.glob(str(MODELS_DIR / 'resnet*original_tuned-once-on_original*')))


In [None]:
from logifold_modules.logifoldv1_4_modified import Logifold, _stem_all, int_from_model_path
from logifold_modules.resnet_modified import ResNet
import logifold_modules.custom_specialization as specialization
from adv_logifold import AdvLogifold, get_statistics, plot_disagreements
import cache_store



In [None]:
def load_adv_samples(
    pattern: str,
    _print_: bool = False,
    root: "Path | str | None" = None,
) -> "np.ndarray | list[np.ndarray]":
    """Load the ``.npy`` sample file(s) whose names match ``pattern``.

    Args:
        pattern: glob pattern, resolved relative to ``root``.
        _print_: if True, print the list of matched files before loading.
        root: directory to search; ``None`` (the default) means ``ADV_SAMPLES``.

    Returns:
        A single ``np.ndarray`` when exactly one file matches; otherwise a list
        of arrays in sorted file-name order.

    Raises:
        FileNotFoundError: if no file matches ``pattern``.
    """
    search_dir = ADV_SAMPLES if root is None else Path(root)
    files = sorted(glob.glob(str(search_dir / pattern)))
    if not files:
        raise FileNotFoundError(f"No samples for pattern: {pattern}")
    if _print_:
        print(f"Loading {len(files)} files matching pattern: {pattern}")
        for f in files:
            print(f" - {f}")
    samples = [np.load(f) for f in files]
    # Callers use ``.shape`` / indexing directly, so a lone match is unwrapped.
    return samples[0] if len(samples) == 1 else samples

In [None]:
# Split the CIFAR-10 training set 80/20 into train/val with a fixed seed;
# the separate test split is kept as-is.
(x, y), (x_test, y_test) = cifar10.load_data()
x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=42)

# Divide pixel values by 255 as float32.
x_train, x_val, x_test = (
    images.astype('float32') / 255.0 for images in (x_train, x_val, x_test)
)

# 10-class one-hot copies; the integer label arrays are left unchanged.
y_train_categorical_10, y_val_categorical_10, y_test_categorical_10 = (
    to_categorical(labels, 10) for labels in (y_train, y_val, y_test)
)



In [None]:
def train_union_and_specialize(
    x_adv_tr: np.ndarray, x_adv_val: np.ndarray, adv_label: str,
) -> Tuple[tf.keras.Model, tf.keras.Model, dict | None]:
    """Load the union baseline for one attack and its once-tuned counterpart.

    The union sets are ``x_train`` + ``x_adv_tr`` and ``x_val`` + ``x_adv_val``.
    Adversarial rows are labelled with the first ``len(x_adv_tr)`` /
    ``len(x_adv_val)`` entries of ``y_train`` / ``y_val``; this assumes the i-th
    adversarial sample was generated from the i-th clean sample (TODO confirm
    against the generator). The adversarial set may be shorter than the clean
    one (the CWL2 training set has 10001 rows, not 40000).

    Args:
        x_adv_tr: adversarial training images.
        x_adv_val: adversarial validation images.
        adv_label: attack label embedded in the checkpoint file names.

    Returns:
        (baseline_before_tuning, baseline_after_tuning, history). ``history`` is
        ``{"history", "params", "epoch"}`` when tuning runs here; otherwise it is
        whatever ``specialization.load_history`` returns for the existing
        checkpoint (possibly None).

    Raises:
        FileNotFoundError: if the untuned union baseline checkpoint is missing.
    """
    n_adv_tr = x_adv_tr.shape[0]
    n_adv_val = x_adv_val.shape[0]
    train_union = np.concatenate([x_train, x_adv_tr], axis=0)
    val_union = np.concatenate([x_val, x_adv_val], axis=0)

    # Slice both label sets by the adversarial size so label counts always
    # match the image unions, even when the adversarial set is shorter.
    training_y_long = np.concatenate([y_train, y_train[:n_adv_tr]], axis=0)
    validating_y_long = np.concatenate([y_val, y_val[:n_adv_val]], axis=0)
    # Integer labels (shape (N,) or (N, 1)) become 10-class one-hot; one-hot input passes through.
    if training_y_long.ndim == 1 or training_y_long.shape[1] != 10:
        training_y_long = to_categorical(training_y_long, 10)
    if validating_y_long.ndim == 1 or validating_y_long.shape[1] != 10:
        validating_y_long = to_categorical(validating_y_long, 10)

    base_path = ADV_MODELS_DIR / f"ResNet56v1_union-of-original-and-{adv_label}_ver0.keras"
    if not base_path.exists():
        raise FileNotFoundError(f'no model at {base_path}')
    baseline_before_tuning = load_model(base_path)
    print(f'load {base_path} to specialize once')

    tuned_path = ADV_MODELS_DIR / (
        f"ResNet56v1_union-of-original-and-{adv_label}"
        f"_tuned-once-on_stratified_union-of-original-and-{adv_label}_ver0.keras"
    )
    if tuned_path.exists():
        baseline_after_tuning = load_model(tuned_path)
        print(f'{tuned_path} already exists. try to get history of the training procedure')
        hist_baseline = specialization.load_history(tuned_path)  # may be None
        if hist_baseline is None:
            print(f"[WARN] No history found for {tuned_path}")
    else:
        print(f'{tuned_path} training...')
        # NOTE(review): the same model object is passed in here and returned as
        # ``baseline_before_tuning``; if turn_specialist trains it in place, the
        # "before" model is also tuned -- confirm in custom_specialization.
        baseline_after_tuning, fit_history = specialization.turn_specialist(
            baseline_before_tuning, path=tuned_path,
            x_tr=train_union, y_tr=training_y_long,
            x_v=val_union, y_v=validating_y_long,
            epochs=21, learning_rate=1e-3, batch_size=128, verbose=1, name="tuned_once",
        )
        hist_baseline = {
            "history": fit_history.history,
            "params": fit_history.params,
            "epoch": fit_history.epoch,
        }

    return baseline_before_tuning, baseline_after_tuning, hist_baseline

In [None]:
@dataclass(frozen=True)
class AttackEntry:
    """One adversarial-attack configuration used throughout this notebook.

    ``@dataclass`` is required: ``ATTACKS`` builds entries positionally as
    (short_tag, glob_pattern, adv_label), which a plain annotated class rejects
    with ``TypeError: AttackEntry() takes no arguments``.
    """

    short_tag: str     # short name used to select the attack, e.g. "CWL2"
    glob_pattern: str  # glob for the training .npy file(s) in data/samples; "train" -> "val" gives the val set
    adv_label: str     # label used in entropy-cache sample names, figure names and model file names

# Attacks evaluated here. ``glob_pattern`` matches the training samples; the
# validation samples are found by replacing "train" with "val" in the pattern.
ATTACKS: List[AttackEntry] = [
    AttackEntry(
        short_tag="CWL2",
        glob_pattern="*cwl2*untargeted_train_by_resnet56v1_ver0.npy",
        adv_label="cwl2-untargeted-gen-by-resnet56v1-ver0",
    ),
    AttackEntry(
        short_tag="PGD_standard",
        glob_pattern="*pgd*eps8*untargeted_train_by_resnet56v1_ver0.npy",
        adv_label="pgd-eps8-iter2-10steps-untargeted-gen-by-resnet56v1-ver0",
    ),
    AttackEntry(
        short_tag="DeepFool",
        glob_pattern="*deepfool_untargeted_train_by_resnet56v1_ver0.npy",
        adv_label="deepfool-untargeted-gen-by-resnet56v1-ver0",
    ),
]

In [None]:
def specialize_Committee(adversarial_lgfd: AdvLogifold, Comm_keys: List[Tuple], adv_short_tag: str):
    """Specialize each committee (judge) model on high-entropy original + adversarial samples.

    Steps:
      1. Load the train/val adversarial samples registered in ``ATTACKS`` for
         ``adv_short_tag``.
      2. Compute committee entropy arrays on original and adversarial train/val
         sets and plot their disagreement histograms (the original-data plots
         are only written if the figure file does not exist yet).
      3. Threshold at alpha = mean of the two sets' ``get_statistics`` averages,
         separately for train and val; samples with entropy >= alpha are kept.
      4. For each judge key, load or train a specialist on the kept samples,
         add it to ``adversarial_lgfd`` if its key is absent, and call
         ``getFuzDoms`` for it on the kept validation samples.

    Adversarial rows are labelled with the first ``len(adv)`` entries of
    ``y_train`` / ``y_val`` (assumes sample i was generated from clean sample i).

    Returns:
        (EXPERTS_KEYS, experts_paths, alpha) where alpha is the validation threshold.

    Raises:
        ValueError: if ``adv_short_tag`` does not match any entry in ``ATTACKS``.
    """
    # Resolve the attack (first matching entry) and load its samples.
    atk = next((a for a in ATTACKS if a.short_tag == adv_short_tag), None)
    if atk is None:
        raise ValueError(f"Unknown adv_short_tag: {adv_short_tag}")
    adv_type = adv_short_tag
    adv_sample_name = atk.adv_label
    adv_sample_train = load_adv_samples(atk.glob_pattern)
    adv_sample_val = load_adv_samples(atk.glob_pattern.replace("train", "val"))

    # Entropy arrays by the JUDGE committee, plus disagreement plots.
    ent_original_train = adversarial_lgfd.get_entropy_array(Comm_keys, sample_name='original_train', sample=x_train)
    ent_adv_train = adversarial_lgfd.get_entropy_array(Comm_keys, sample_name=adv_sample_name + '_train', sample=adv_sample_train)
    fp = FIGURES / "entropy-disagreements-on-original_train.png"
    if not fp.exists():
        plot_disagreements(ent_original_train, title="Entropy Disagreements on original_train", save_path=fp)
    plot_disagreements(ent_adv_train, title=f"Entropy Disagreements on {adv_sample_name}_train",
                       save_path=FIGURES / f"entropy-disagreements-on-{adv_sample_name}_train.png")

    ent_original_val = adversarial_lgfd.get_entropy_array(Comm_keys, sample_name='original_val', sample=x_val)
    fp = FIGURES / "entropy-disagreements-on-original_val.png"
    if not fp.exists():
        plot_disagreements(ent_original_val, title="Entropy Disagreements on original_val", save_path=fp)

    ent_adv_val = adversarial_lgfd.get_entropy_array(Comm_keys, sample_name=adv_sample_name + '_val', sample=adv_sample_val)
    plot_disagreements(ent_adv_val, title=f"Entropy Disagreements on {adv_sample_name}_val",
                       save_path=FIGURES / f"entropy-disagreements-on-{adv_sample_name}_val.png")

    # Union threshold: mean of the original and adversarial average entropies.
    stats = {
        ('original', 'train'): get_statistics(ent_original_train),
        ('original', 'val'): get_statistics(ent_original_val),
        ('adv', 'train'): get_statistics(ent_adv_train),
        ('adv', 'val'): get_statistics(ent_adv_val),
    }
    train_alpha_union = (stats[('original', 'train')]['average'] + stats[('adv', 'train')]['average']) / 2
    val_alpha_union = (stats[('original', 'val')]['average'] + stats[('adv', 'val')]['average']) / 2

    # True => high entropy (kept for specialization).
    loc_1_original_train = ent_original_train >= train_alpha_union
    loc_1_adv_train = ent_adv_train >= train_alpha_union
    loc_1_original_val = ent_original_val >= val_alpha_union
    loc_1_adv_val = ent_adv_val >= val_alpha_union
    print('alpha for train: {}, for val: {}'.format(train_alpha_union, val_alpha_union))
    print('the number of data greater than alpha:')
    print(f'Training set original + {adv_type}:', np.sum(loc_1_original_train), '+', np.sum(loc_1_adv_train), '=', np.sum(loc_1_original_train) + np.sum(loc_1_adv_train))
    print(f'Validation set original + {adv_type}:', np.sum(loc_1_original_val), '+', np.sum(loc_1_adv_val), '=', np.sum(loc_1_original_val) + np.sum(loc_1_adv_val))

    # High-entropy union; labels sliced to each adversarial set's length so the masks line up.
    n_adv_train = adv_sample_train.shape[0]
    n_adv_val = adv_sample_val.shape[0]
    x_experts_tr = np.concatenate([x_train[loc_1_original_train], adv_sample_train[loc_1_adv_train]])
    y_experts_tr = to_categorical(
        np.concatenate([y_train[loc_1_original_train], y_train[:n_adv_train][loc_1_adv_train]]), 10)
    x_experts_va = np.concatenate([x_val[loc_1_original_val], adv_sample_val[loc_1_adv_val]])
    y_experts_va = to_categorical(
        np.concatenate([y_val[loc_1_original_val], y_val[:n_adv_val][loc_1_adv_val]]), 10)

    # Specialize each judge on the high-entropy samples.
    EXPERTS_KEYS = []
    experts_paths = []
    for a_judge_key in Comm_keys:
        a_judge = adversarial_lgfd.getModel(a_judge_key)
        a_judge_name = adversarial_lgfd.model_source_name(a_judge_key)
        expert_file = f"{a_judge_name}_specialized-once-on_high-entropy-union-of-original-and-{adv_sample_name}_ver0.keras"
        path = EXPERTS_DIR / expert_file

        if path.exists():
            print(f"There is specialized Judge {a_judge_name} on union of original and {adv_type} samples.")
            specialist = load_model(str(path))
        else:
            print(f"Specializing Judge {a_judge_name} on union of original and {adv_type} samples...")
            specialist, _ = specialization.turn_specialist(
                model=a_judge, path=path,
                x_tr=x_experts_tr, y_tr=y_experts_tr,
                x_v=x_experts_va, y_v=y_experts_va,
                epochs=21, learning_rate=1e-3, batch_size=128, verbose=0,
                name=f"specialized_once_on_high-entropy_union_of_original_and_{adv_sample_name}")

        # Key = (judge key[0], int_from_model_path(<specialist file name>)).
        key = (a_judge_key[0], int_from_model_path(expert_file))
        print('prepared key:', key)
        if key in adversarial_lgfd.keys():
            print('specialized model is already a member of logifold')
        else:
            print('Adding specialized model...')
            adversarial_lgfd.add(specialist,
                                 key=key,
                                 model_path=_stem_all(path),
                                 description=f'specialized on high entropy union of original and {adv_sample_name}',
                                 fuzDom={})
        # compute fuzdom on the high-entropy validation union
        adversarial_lgfd.getFuzDoms(keys=[key],
                                    x=x_experts_va, y=y_experts_va,
                                    sample_name=f'union_of_original_and_{adv_sample_name}_val',
                                    update=False, autosave=False, verbose=0)
        EXPERTS_KEYS.append(key)
        experts_paths.append(path)

    # Assigned after the loop so it is defined even when Comm_keys is empty.
    alpha = val_alpha_union
    return EXPERTS_KEYS, experts_paths, alpha

In [None]:
# --- imports & globals ---
import os, io
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, Image

# Output folder for ratio tables/plots, relative to the current working directory.
OUTDIR = "ratio_summary"
Path(OUTDIR).mkdir(parents=True, exist_ok=True)

# --- helpers ---
def class_stats(y, mask, uniform=0.1):
    """Tabulate class frequencies among the rows of ``y`` selected by ``mask``.

    Returns:
        (table, total). ``table`` is indexed by class id (index name "class")
        with columns "count", "ratio" (share of the selection, 4 decimals) and
        "dev_from_0.1" (ratio minus ``uniform``, 4 decimals). Only classes that
        occur in the selection appear. ``total`` is the number of selected rows.
    """
    picked = y[mask].ravel()
    labels, freq = np.unique(picked, return_counts=True)
    total = int(freq.sum())
    if total > 0:
        share = freq / total
    else:
        share = np.zeros_like(freq, dtype=float)
    table = pd.DataFrame(
        {
            "count": freq,
            "ratio": share.round(4),
            "dev_from_0.1": (share - uniform).round(4),
        },
        index=pd.Index(labels, name="class"),
    )
    return table, total

def summarize_split(name, y, high_mask):
    """Build LOW/HIGH per-class summaries for labels ``y`` split by ``high_mask``.

    Convention: LOW = ~high_mask, HIGH = high_mask. ``high_mask`` is coerced to
    bool so a 0/1 integer mask selects rows (instead of being used as fancy
    indices, with ``~`` acting as bitwise NOT).

    Returns:
        dict with "name", "low"/"high" (tables from ``class_stats``) and
        "low_total"/"high_total" (selected row counts).

    Raises:
        ValueError: if ``y`` and ``high_mask`` have different lengths.
    """
    y = np.asarray(y)
    high_mask = np.asarray(high_mask, dtype=bool)
    # Explicit check (an ``assert`` would vanish under ``python -O``).
    if y.shape[0] != high_mask.shape[0]:
        raise ValueError(f"Length mismatch: y={y.shape} vs mask={high_mask.shape}")
    low_df, low_n = class_stats(y, ~high_mask)
    high_df, high_n = class_stats(y, high_mask)
    return {"name": name, "low": low_df, "low_total": low_n, "high": high_df, "high_total": high_n}

def pretty_table(title, df, total, save_basename=None):
    """Print a header and display a class-ratio table, optionally saving it as CSV.

    Adds a "%" column (ratio * 100, 2 decimals) and orders the columns as
    count, ratio, %, dev_from_0.1. When ``save_basename`` is truthy the table
    is also written to ``OUTDIR/<save_basename>.csv``.
    """
    table = df.copy()
    table["%"] = table["ratio"].mul(100).round(2)
    table = table.loc[:, ["count", "ratio", "%", "dev_from_0.1"]]
    print(f"\n=== {title} (n={total}) ===")
    display(table)
    if not save_basename:
        return
    csv_path = os.path.join(OUTDIR, f"{save_basename}.csv")
    table.to_csv(csv_path)
    print(f"[Saved table] {csv_path}")

def barplot(title, df, save_basename=None, dpi=160):
    """Draw a per-class ratio bar chart with a dashed reference line at 0.1.

    Args:
        title: plot title.
        df: table with a "ratio" column indexed by class id (as from ``class_stats``).
        save_basename: if truthy, also save to ``OUTDIR/<save_basename>.png``.
        dpi: resolution for both the saved file and the inline image.

    The figure is rendered to PNG bytes and shown with IPython ``display`` (the
    file selects the non-interactive Agg backend), and is always closed
    afterwards -- even if saving or displaying raises -- so figures do not
    accumulate in pyplot's registry.
    """
    df = df.sort_index()
    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.bar(df.index, df["ratio"].values)
        ax.axhline(0.1, linestyle="--")
        ax.set_title(title)
        ax.set_xlabel("class")
        ax.set_ylabel("ratio")

        if save_basename:
            png_path = os.path.join(OUTDIR, f"{save_basename}.png")
            fig.savefig(png_path, bbox_inches="tight", dpi=dpi)
            print(f"[Saved plot] {png_path}")

        # Force inline render as an image (works even when fig repr shows).
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=dpi)
        display(Image(data=buf.getvalue()))
    finally:
        plt.close(fig)

def summarize_union_two(name, y_A, high_A, y_B, high_B):
    """Summarize the concatenation of two labelled splits (e.g. original then adversarial).

    Labels and masks are each flattened and concatenated A-then-B; the mask is
    cast to bool and both are handed to ``summarize_split``.
    """
    labels = np.concatenate([part.reshape(-1) for part in (y_A, y_B)])
    is_high = np.concatenate([part.reshape(-1) for part in (high_A, high_B)]).astype(bool)
    return summarize_split(name, labels, is_high)

def equalize_by_min_per_class(indices, labels, classes=None, seed=42, shuffle=True):
    """Down-sample ``indices`` so every class keeps the same number of items.

    The per-class target is the smallest class count found among ``indices``;
    each class is sampled without replacement with ``np.random.default_rng(seed)``,
    class by class in ``classes`` order, so results are reproducible.

    Args:
        indices: 1-D integer indices into ``labels`` (e.g. a LOW/HIGH bucket).
            An empty list is accepted.
        labels: label vector for the *full* dataset. A column vector of shape
            (N, 1) (as returned by ``cifar10.load_data``) is flattened first.
        classes: optional class ids to enforce presence/order; a listed class
            with no items makes the target 0. Defaults to the classes present
            in ``labels[indices]``.
        seed: RNG seed for reproducibility.
        shuffle: if True, shuffle the concatenated result; otherwise it stays
            grouped by class.

    Returns:
        chosen_indices: 1-D array of selected indices (balanced, no replacement).
        per_class_k: target count per class (the minimum count).
        counts_before: dict[class -> available count within ``indices``].
    """
    rng = np.random.default_rng(seed)
    idx = np.asarray(indices)
    if idx.size == 0:
        # np.asarray([]) is float64, which numpy refuses as an index array.
        idx = idx.astype(int)
    label_arr = np.asarray(labels)
    if label_arr.ndim == 2 and label_arr.shape[1] == 1:
        # (N, 1) labels would yield a 2-D mask that cannot index the 1-D ``idx``.
        label_arr = label_arr[:, 0]
    y = label_arr[idx]

    cls = np.unique(y) if classes is None else np.array(classes)

    counts_before = {c: int(np.sum(y == c)) for c in cls}
    per_class_k = min(counts_before.values(), default=0)

    # per_class_k is the minimum count, so every class has at least that many items.
    picked = [rng.choice(idx[y == c], size=per_class_k, replace=False) for c in cls]

    if not picked:
        return np.array([], dtype=int), per_class_k, counts_before

    chosen = np.concatenate(picked)
    if shuffle:
        rng.shuffle(chosen)
    return chosen, per_class_k, counts_before

# --- main ---
def ratio_of_samples(adversarial_lgfd, Comm_keys, adv_short_tag):
    """Report per-class LOW/HIGH-entropy ratios for original U adversarial data.

    Computes committee entropy on the original and adversarial train/val sets,
    thresholds at alpha = mean(mean original entropy, mean adversarial entropy)
    (train and val separately), shows LOW/HIGH class-ratio bar plots for the
    train union and the val union, and prints how class-balanced down-sampling
    (``equalize_by_min_per_class``) changes each bucket.

    The train union is the first ``len(adv_train)`` original training samples
    followed by the adversarial ones. The val union is the full original
    validation set followed by all adversarial validation samples. Adversarial
    rows are labelled with the first ``len(adv)`` entries of ``y_train`` /
    ``y_val`` (assumes sample i was generated from clean sample i).

    Expects globals available:
        - ATTACKS: iterable with .short_tag, .adv_label, .glob_pattern
        - x_train, y_train, x_val, y_val
        - load_adv_samples, summarize_split, barplot, equalize_by_min_per_class

    Returns:
        dict with "train_union", "val_union" (``summarize_split`` results),
        "alpha_train", "alpha_val", and the balanced index arrays
        "low_train_balanced_idx", "high_train_balanced_idx",
        "low_val_balanced_idx", "high_val_balanced_idx" (indices into the
        respective union arrays).

    Raises:
        ValueError: if ``adv_short_tag`` is not in ``ATTACKS``.
    """
    # 1) resolve attack & load samples
    atk = next((a for a in ATTACKS if a.short_tag == adv_short_tag), None)
    if atk is None:
        raise ValueError(f"Unknown adv_short_tag: {adv_short_tag}")

    adv_sample_name = atk.adv_label
    adv_train = load_adv_samples(atk.glob_pattern)
    adv_val = load_adv_samples(atk.glob_pattern.replace("train", "val"))
    n_adv_tr = adv_train.shape[0]
    n_adv_va = adv_val.shape[0]

    # 2) entropy arrays (by the JUDGE committee)
    ent_orig_tr = adversarial_lgfd.get_entropy_array(Comm_keys, sample_name="original_train", sample=x_train)
    # Keep only the originals that enter the train union.
    ent_orig_tr = ent_orig_tr[:n_adv_tr]
    ent_adv_tr = adversarial_lgfd.get_entropy_array(Comm_keys, sample_name=f"{adv_sample_name}_train", sample=adv_train)
    ent_orig_va = adversarial_lgfd.get_entropy_array(Comm_keys, sample_name="original_val", sample=x_val)
    ent_adv_va = adversarial_lgfd.get_entropy_array(Comm_keys, sample_name=f"{adv_sample_name}_val", sample=adv_val)

    # 3) union alphas (average of means of orig & adv; NaN for an empty set)
    def mean_of(a):
        return float(np.mean(a)) if a.size else np.nan

    alpha_tr = 0.5 * (mean_of(ent_orig_tr) + mean_of(ent_adv_tr))
    alpha_va = 0.5 * (mean_of(ent_orig_va) + mean_of(ent_adv_va))
    print(f"alpha — train: {alpha_tr}, val: {alpha_va}")

    # 4) HIGH/LOW masks using union alpha (True => HIGH)
    high_orig_tr = ent_orig_tr >= alpha_tr
    high_adv_tr = ent_adv_tr >= alpha_tr
    high_orig_va = ent_orig_va >= alpha_va
    high_adv_va = ent_adv_va >= alpha_va

    # 5) report counts above alpha
    print("count ≥ alpha (HIGH):")
    print(f"  TRAIN original + {adv_short_tag}: {high_orig_tr.sum()} + {high_adv_tr.sum()} = {int(high_orig_tr.sum()+high_adv_tr.sum())}")
    print(f"  VAL   original + {adv_short_tag}: {high_orig_va.sum()} + {high_adv_va.sum()} = {int(high_orig_va.sum()+high_adv_va.sum())}")

    # 6) UNION arrays (train & val); labels sized to match each mask
    x_tr_union = np.concatenate([x_train[:n_adv_tr], adv_train])
    y_tr_union = np.concatenate([y_train[:n_adv_tr], y_train[:n_adv_tr]]).reshape(-1)
    high_tr_union = np.concatenate([high_orig_tr, high_adv_tr])

    x_va_union = np.concatenate([x_val, adv_val])
    y_va_union = np.concatenate([y_val, y_val[:n_adv_va]]).reshape(-1)
    high_va_union = np.concatenate([high_orig_va, high_adv_va])

    train_union = summarize_split(f"TRAIN union (original U {adv_short_tag})", y_tr_union, high_tr_union)
    val_union = summarize_split(f"VAL union (original U {adv_short_tag})", y_va_union, high_va_union)

    # 7) plots + saved artifacts
    for s in (train_union, val_union):
        base = s["name"].replace(" ", "_")
        barplot(f"{s['name']} — LOW", s["low"], save_basename=f"{base}_LOW_plot")
        barplot(f"{s['name']} — HIGH", s["high"], save_basename=f"{base}_HIGH_plot")

    # 8) class-balanced down-sampling of every bucket
    all_tr_idx = np.arange(len(y_tr_union))
    low_tr_idx = all_tr_idx[~high_tr_union]
    high_tr_idx = all_tr_idx[high_tr_union]

    all_va_idx = np.arange(len(y_va_union))
    low_va_idx = all_va_idx[~high_va_union]
    high_va_idx = all_va_idx[high_va_union]

    low_tr_balanced_idx, k_low_tr, counts_low_tr = equalize_by_min_per_class(low_tr_idx, y_tr_union, seed=42)
    high_tr_balanced_idx, k_high_tr, counts_high_tr = equalize_by_min_per_class(high_tr_idx, y_tr_union, seed=42)

    low_va_balanced_idx, k_low_va, counts_low_va = equalize_by_min_per_class(low_va_idx, y_va_union, seed=42)
    high_va_balanced_idx, k_high_va, counts_high_va = equalize_by_min_per_class(high_va_idx, y_va_union, seed=42)

    print("[TRAIN LOW]  per-class =", k_low_tr, " total =", len(low_tr_balanced_idx), "before =", counts_low_tr)
    print("[TRAIN HIGH] per-class =", k_high_tr, " total =", len(high_tr_balanced_idx), "before =", counts_high_tr)
    print("[VAL   LOW]  per-class =", k_low_va, " total =", len(low_va_balanced_idx), "before =", counts_low_va)
    print("[VAL   HIGH] per-class =", k_high_va, " total =", len(high_va_balanced_idx), "before =", counts_high_va)
    x_train_high_bal = x_tr_union[high_tr_balanced_idx]
    y_train_high_bal = y_tr_union[high_tr_balanced_idx]

    x_val_high_bal = x_va_union[high_va_balanced_idx]
    y_val_high_bal = y_va_union[high_va_balanced_idx]

    print('After balancing, the shape of high entropy training set:', x_train_high_bal.shape, y_train_high_bal.shape)
    print('After balancing, the shape of high entropy validation set:', x_val_high_bal.shape, y_val_high_bal.shape)

    # 9) structured results (original keys plus the balanced indices)
    return {
        "train_union": train_union,
        "val_union": val_union,
        "alpha_train": alpha_tr,
        "alpha_val": alpha_va,
        "low_train_balanced_idx": low_tr_balanced_idx,
        "high_train_balanced_idx": high_tr_balanced_idx,
        "low_val_balanced_idx": low_va_balanced_idx,
        "high_val_balanced_idx": high_va_balanced_idx,
    }

In [None]:
def specialize_Committee_stratified_way(adversarial_lgfd, Comm_keys, adv_short_tag):
    """Specialize each judge on a class-balanced high-entropy union of original and adversarial data.

    Committee entropy is computed on the original and adversarial train/val
    sets; alpha is the mean of the two sets' mean entropies (train and val
    separately). Samples with entropy >= alpha form the HIGH bucket, which is
    down-sampled with ``equalize_by_min_per_class`` so every class has the same
    count. For each judge in ``Comm_keys`` a specialist is loaded from, or
    trained into, ``EXPERTS_DIR/<judge>_specialized-once-on_stratified_high-entropy-union-of-original-and-<adv_label>_ver0.keras``,
    added to ``adversarial_lgfd`` if its key is absent, and ``getFuzDoms`` is
    called for it on the balanced validation set.

    The train union is the first ``len(adv_train)`` original training samples
    followed by the adversarial ones; the val union is the full original
    validation set followed by all adversarial validation samples. Adversarial
    rows are labelled with the first ``len(adv)`` entries of ``y_train`` /
    ``y_val`` (assumes sample i was generated from clean sample i).

    Expects globals available:
        - ATTACKS, x_train, y_train, x_val, y_val, EXPERTS_DIR
        - load_adv_samples, equalize_by_min_per_class, specialization

    Returns:
        (EXPERTS_KEYS, experts_paths, alpha) with alpha the validation threshold.

    Raises:
        ValueError: if ``adv_short_tag`` is not in ``ATTACKS``.
    """
    # 1) resolve attack & load samples
    atk = next((a for a in ATTACKS if a.short_tag == adv_short_tag), None)
    if atk is None:
        raise ValueError(f"Unknown adv_short_tag: {adv_short_tag}")

    adv_sample_name = atk.adv_label
    adv_train = load_adv_samples(atk.glob_pattern)
    adv_val = load_adv_samples(atk.glob_pattern.replace("train", "val"))
    n_adv_tr = adv_train.shape[0]
    n_adv_va = adv_val.shape[0]

    # 2) entropy arrays (by the JUDGE committee)
    ent_orig_tr = adversarial_lgfd.get_entropy_array(Comm_keys, sample_name="original_train", sample=x_train)
    # Keep only the originals that enter the train union.
    ent_orig_tr = ent_orig_tr[:n_adv_tr]
    ent_adv_tr = adversarial_lgfd.get_entropy_array(Comm_keys, sample_name=f"{adv_sample_name}_train", sample=adv_train)
    ent_orig_va = adversarial_lgfd.get_entropy_array(Comm_keys, sample_name="original_val", sample=x_val)
    ent_adv_va = adversarial_lgfd.get_entropy_array(Comm_keys, sample_name=f"{adv_sample_name}_val", sample=adv_val)

    # 3) union alphas (average of means of orig & adv; NaN for an empty set)
    def mean_of(a):
        return float(np.mean(a)) if a.size else np.nan

    alpha_tr = 0.5 * (mean_of(ent_orig_tr) + mean_of(ent_adv_tr))
    alpha_va = 0.5 * (mean_of(ent_orig_va) + mean_of(ent_adv_va))
    print(f"alpha — train: {alpha_tr}, val: {alpha_va}")

    # 4) HIGH/LOW masks using union alpha (True => HIGH)
    high_orig_tr = ent_orig_tr >= alpha_tr
    high_adv_tr = ent_adv_tr >= alpha_tr
    high_orig_va = ent_orig_va >= alpha_va
    high_adv_va = ent_adv_va >= alpha_va

    # 5) report counts above alpha
    print("count ≥ alpha (HIGH):")
    print(f"  TRAIN original + {adv_short_tag}: {high_orig_tr.sum()} + {high_adv_tr.sum()} = {int(high_orig_tr.sum()+high_adv_tr.sum())}")
    print(f"  VAL   original + {adv_short_tag}: {high_orig_va.sum()} + {high_adv_va.sum()} = {int(high_orig_va.sum()+high_adv_va.sum())}")

    # 6) UNION arrays (train & val); labels sized to match each mask
    x_tr_union = np.concatenate([x_train[:n_adv_tr], adv_train])
    y_tr_union = np.concatenate([y_train[:n_adv_tr], y_train[:n_adv_tr]]).reshape(-1)
    high_tr_union = np.concatenate([high_orig_tr, high_adv_tr])

    x_va_union = np.concatenate([x_val, adv_val])
    y_va_union = np.concatenate([y_val, y_val[:n_adv_va]]).reshape(-1)
    high_va_union = np.concatenate([high_orig_va, high_adv_va])

    # 7) class-balance the HIGH buckets
    high_tr_idx = np.arange(len(y_tr_union))[high_tr_union]
    high_va_idx = np.arange(len(y_va_union))[high_va_union]

    high_tr_balanced_idx, k_high_tr, counts_high_tr = equalize_by_min_per_class(high_tr_idx, y_tr_union, seed=42)
    high_va_balanced_idx, k_high_va, counts_high_va = equalize_by_min_per_class(high_va_idx, y_va_union, seed=42)

    print("[TRAIN HIGH] per-class =", k_high_tr, " total =", len(high_tr_balanced_idx), "before =", counts_high_tr)
    print("[VAL   HIGH] per-class =", k_high_va, " total =", len(high_va_balanced_idx), "before =", counts_high_va)
    x_train_high_bal = x_tr_union[high_tr_balanced_idx]
    y_train_high_bal = y_tr_union[high_tr_balanced_idx]

    x_val_high_bal = x_va_union[high_va_balanced_idx]
    y_val_high_bal = y_va_union[high_va_balanced_idx]

    print('After balancing, the shape of high entropy training set:', x_train_high_bal.shape, y_train_high_bal.shape)
    print('After balancing, the shape of high entropy validation set:', x_val_high_bal.shape, y_val_high_bal.shape)

    y_train_high_bal_1hot = to_categorical(y_train_high_bal, 10)
    y_val_high_bal_1hot = to_categorical(y_val_high_bal, 10)

    # 8) specialize Judge models on the balanced high entropy samples
    EXPERTS_KEYS = []
    experts_paths = []
    for a_judge_key in Comm_keys:
        a_judge = adversarial_lgfd.getModel(a_judge_key)
        a_judge_name = adversarial_lgfd.model_source_name(a_judge_key)
        expert_file = f"{a_judge_name}_specialized-once-on_stratified_high-entropy-union-of-original-and-{adv_sample_name}_ver0.keras"
        path = EXPERTS_DIR / expert_file

        if path.exists():
            print(f"There is specialized Judge {a_judge_name} on union of (stratified) original and {adv_short_tag} samples.")
            specialist = load_model(str(path))
        else:
            print(f"Specializing Judge {a_judge_name} on union of (stratified) original and {adv_short_tag} samples...")
            specialist, _ = specialization.turn_specialist(
                model=a_judge, path=path,
                x_tr=x_train_high_bal, y_tr=y_train_high_bal_1hot,
                x_v=x_val_high_bal, y_v=y_val_high_bal_1hot,
                epochs=21, learning_rate=1e-3, batch_size=32, verbose=0,
                name=f"specialized_once_on_stratified_high-entropy_union_of_original_and_{adv_sample_name}")

        # Key = (judge key[0], int_from_model_path(<specialist file name>)).
        key = (a_judge_key[0], int_from_model_path(expert_file))
        print('prepared key:', key)
        if key in adversarial_lgfd.keys():
            print('specialized model is already a member of logifold')
        else:
            print('Adding specialized model...')
            adversarial_lgfd.add(specialist,
                                 key=key,
                                 model_path=_stem_all(path),
                                 description=f'specialized on stratified high entropy union of original and {adv_sample_name}',
                                 fuzDom={})
        # compute fuzdom on the balanced validation set
        adversarial_lgfd.getFuzDoms(keys=[key],
                                    x=x_val_high_bal, y=y_val_high_bal_1hot,
                                    sample_name=f'union_of_original_and_{adv_sample_name}_val',
                                    update=False, autosave=False, verbose=0)
        EXPERTS_KEYS.append(key)
        experts_paths.append(path)

    # Assigned after the loop so it is defined even when Comm_keys is empty.
    alpha = alpha_va
    return EXPERTS_KEYS, experts_paths, alpha

In [None]:
# Training-history files (``.history.json``) of judges specialized on the
# stratified high-entropy union of original and adversarial samples.
# Each tuple: (judge architecture, judge version, attack prefix of the adversarial label).
specialized_model_paths = [
    f"{arch}_original_tuned-once-on_original_ver{ver}"
    f"_specialized-once-on_stratified_high-entropy-union-of-original-and-{attack}"
    "-untargeted-gen-by-resnet56v1-ver0_ver0.history.json"
    for arch, ver, attack in (
        ("resnet20v1", 0, "pgd-eps8-iter2-10steps"),
        ("resnet20v1", 6, "cwl2"),
        ("resnet20v2", 0, "deepfool"),
        ("resnet20v1", 6, "deepfool"),
        ("resnet20v1", 1, "deepfool"),
        ("resnet56v2", 0, "deepfool"),
        ("resnet56v1", 1, "deepfool"),
        ("resnet20v1", 0, "deepfool"),
        ("resnet20v1", 7, "deepfool"),
        ("resnet20v2", 1, "deepfool"),
        ("resnet56v1", 0, "deepfool"),
        ("resnet56v2", 2, "deepfool"),
        ("resnet56v1", 3, "deepfool"),
        ("resnet20v1", 4, "deepfool"),
        ("resnet20v2", 2, "deepfool"),
        ("resnet20v1", 3, "deepfool"),
        ("resnet56v1", 2, "deepfool"),
        ("resnet56v2", 3, "deepfool"),
        ("resnet20v1", 2, "deepfool"),
        ("resnet20v2", 3, "deepfool"),
        ("resnet20v1", 5, "deepfool"),
    )
]
