In [1]:
import os
import random
import gc
import argparse
import numpy as np
import h5py as h5
import torch
import torch.nn.functional as F
import torch.nn as nn
import wandb

from models.autoencoder import Autoencoder 
#from losses.cyl_ptpz_mae import CylPtPzMAE

from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm
from scipy.stats import binned_statistic

In [2]:
#setting a seed like in ae_legacy
def set_seed(seed=123):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), exports ``PYTHONHASHSEED`` and sets the
    cuDNN ``deterministic`` flag while turning its ``benchmark`` flag off.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU and every CUDA device).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN reproducibility flags.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

In [3]:
def distance_corr(var_1,var_2,normedweight,power=1):
    """Weighted distance correlation between two 1D tensors.

    Args:
        var_1: First variable to decorrelate (e.g. mass), 1D torch tensor.
        var_2: Second variable to decorrelate (e.g. classifier output),
            1D torch tensor with the same number of entries as ``var_1``.
        normedweight: Per-example weight, 1D torch tensor of the same length.
            The sum of the weights should equal N, the number of examples.
        power: Exponent applied to the distance correlation.

    Returns:
        Scalar tensor holding the distance correlation raised to ``power``.

    Usage: add it to the loss, e.g.
    ``total_loss = BCE_loss + lambda * distance_corr(...)``.
    """

    def _double_centered(values):
        # Pairwise |v_i - v_j| via broadcasting: rows index i, columns index j.
        dist = (values.view(-1, 1) - values.view(1, -1)).abs()
        # Weighted mean of each row (the weight is broadcast along the columns).
        row_avg = torch.mean(dist * normedweight, dim=1)
        # Subtract the per-column and per-row averages, add back the grand mean.
        return dist - row_avg.view(1, -1) - row_avg.view(-1, 1) \
            + torch.mean(row_avg * normedweight)

    def _weighted_mean(matrix):
        # Weighted average over columns first, then over rows.
        return torch.mean(torch.mean(matrix * normedweight, dim=1) * normedweight)

    centered_a = _double_centered(var_1)
    centered_b = _double_centered(var_2)

    cov_ab = _weighted_mean(centered_a * centered_b)
    var_aa = _weighted_mean(centered_a * centered_a)
    var_bb = _weighted_mean(centered_b * centered_b)

    if power == 1:
        return cov_ab / torch.sqrt(var_aa * var_bb)
    if power == 2:
        return cov_ab ** 2 / (var_aa * var_bb)
    return (cov_ab / torch.sqrt(var_aa * var_bb)) ** power

In [4]:
#compute differentiable ABCD region counts (just like discotec paper)
def sigmoid_counts(var1, var2, cut1, cut2, weights, scale=100.0):
    """Differentiable (soft) event counts in the four ABCD regions.

    Hard cuts are replaced by steep sigmoids so the counts stay
    differentiable with respect to ``var1`` and ``var2``.

    Args:
        var1, var2: Tensors of per-event values (flattened internally).
        cut1, cut2: Cut positions applied to ``var1`` and ``var2``.
        weights: Per-event weights (flattened internally).
        scale: Sigmoid steepness; larger values approach a hard cut.

    Returns:
        Tuple ``(NA, NB, NC, ND)`` of scalar tensors:
        A = both above their cut, B = var1 above / var2 below,
        C = var1 below / var2 above, D = both below.
    """
    flat1, flat2, flat_w = var1.view(-1), var2.view(-1), weights.view(-1)

    # Soft membership of each event on either side of each cut.
    above1 = torch.sigmoid(scale * (flat1 - cut1))
    below1 = torch.sigmoid(scale * (cut1 - flat1))
    above2 = torch.sigmoid(scale * (flat2 - cut2))
    below2 = torch.sigmoid(scale * (cut2 - flat2))

    def _region(side1, side2):
        # Weighted soft count of events falling in one region.
        return torch.sum(side1 * side2 * flat_w)

    return (
        _region(above1, above2),
        _region(above1, below2),
        _region(below1, above2),
        _region(below1, below2),
    )

#compute the ABCD closure loss for a batch
def closure_loss_batch(var1, var2, weights, symmetrize=True,
                       n_events_min=10, max_tries=20):
    """ABCD closure loss for one batch, using randomly placed soft cuts.

    A cut is drawn uniformly between the 1% and 99% quantiles of each
    variable; the draw is repeated (up to ``max_tries`` times) until every
    ABCD region holds more than ``n_events_min`` (soft, weighted) events.

    Args:
        var1, var2: Per-event tensors (flattened internally).
        weights: Per-event weights (flattened internally).
        symmetrize: If True return ``|NA*ND - NB*NC| / (NA*ND + NB*NC)``,
            otherwise ``|NA*ND - NB*NC| / (NB*NC)``.
        n_events_min: Minimum soft count required in each region.
        max_tries: Maximum number of random cut pairs to try.

    Returns:
        Scalar tensor with the closure loss. A constant 0 (no gradient) is
        returned when the batch is empty or no acceptable cuts were found,
        so the caller never receives NaN from this term.
    """
    v1 = var1.view(-1)
    v2 = var2.view(-1)
    w = weights.view(-1)

    zero = torch.tensor(0.0, device=var1.device, dtype=var1.dtype)
    if v1.numel() == 0 or v2.numel() == 0:
        # torch.quantile refuses empty input; nothing to constrain anyway.
        return zero

    # The quantile range does not depend on the try, so compute it once
    # instead of re-evaluating four quantiles on every retry.
    with torch.no_grad():
        x_min = torch.quantile(v1, 0.01).item()
        x_max = torch.quantile(v1, 0.99).item()
        y_min = torch.quantile(v2, 0.01).item()
        y_max = torch.quantile(v2, 0.99).item()

    for _ in range(max_tries):
        cut1 = np.random.uniform(x_min, x_max)
        cut2 = np.random.uniform(y_min, y_max)

        NA, NB, NC, ND = sigmoid_counts(v1, v2, cut1, cut2, w)

        # One host sync for the smallest region instead of four .item() calls.
        smallest = torch.stack((NA, NB, NC, ND)).detach().min().item()
        if smallest > n_events_min:
            break
    else:
        # Never found cuts with enough statistics in all four regions.
        return zero

    num = torch.abs(NA * ND - NB * NC)
    # 1e-8 keeps the ratio finite if the denominator underflows to zero.
    den = (NA * ND + NB * NC + 1e-8) if symmetrize else (NB * NC + 1e-8)
    return num / den


In [5]:
#calculates anomaly score like in ae legacy
def distance_pt(model_ae, data_np, device):
    """Per-event anomaly score: squared norm of the encoder's latent mean.

    Args:
        model_ae: Model exposing ``encoder(x)`` that returns a 3-tuple whose
            first element is the latent mean (the other two are ignored).
        data_np: Array-like of shape (n_events, n_features).
        device: Torch device on which the encoder is evaluated.

    Returns:
        NumPy array of shape (n_events,) with ``sum(z_mean**2, dim=1)``.
    """
    x = torch.as_tensor(data_np, dtype=torch.float32, device=device)
    # Pure inference: skip autograd bookkeeping so no graph is kept in memory.
    with torch.no_grad():
        z_mean, _, _ = model_ae.encoder(x)
        score = torch.sum(z_mean ** 2, dim=1)
    return score.detach().cpu().numpy()

In [6]:
#function to make 2D histograms
def make_2D_hist(x, y, xlabel, ylabel, title, wandb_key, bins=40):
    """Draw a 2D histogram of ``x`` vs ``y`` and log it to wandb as an image.

    Args:
        x, y: Values for the horizontal / vertical axis.
        xlabel, ylabel, title: Axis labels and figure title.
        wandb_key: Key under which the image is logged.
        bins: Binning specification forwarded to ``hist2d``.
    """
    fig, ax = plt.subplots(figsize=(5, 4))
    *_, mesh = ax.hist2d(x, y, bins=bins)
    fig.colorbar(mesh, ax=ax, label='Counts')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    wandb.log({wandb_key: wandb.Image(fig)})
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [7]:
class PerSampleMSE(nn.Module):
    """Mean squared error averaged over dim 1, giving one loss per sample."""

    def __init__(self):
        super().__init__()
        # Element-wise squared error; the reduction is done in forward().
        self.mse = nn.MSELoss(reduction='none')

    def forward(self, recon, target):
        """Return the squared error between ``recon`` and ``target`` averaged over dim 1."""
        return self.mse(recon, target).mean(dim=1)

In [8]:
def fit_standard_scaler(X, eps=1e-8):
    """Compute per-column mean and standard deviation of ``X`` as float32.

    Columns whose standard deviation is below ``eps`` get a scale of 1.0 so
    that (near-)constant features are not divided by ~0 later on.

    Returns:
        Tuple ``(mu, std)`` of 1D arrays, one entry per column.
    """
    column_mean = X.mean(axis=0).astype(np.float32)
    column_std = X.std(axis=0).astype(np.float32)
    safe_std = np.where(column_std < eps, 1.0, column_std)
    return column_mean, safe_std

def transform_standard(X, mu, std):
    """Standardize ``X`` with a previously fitted mean ``mu`` and scale ``std``."""
    # The small constant guards against division by an exactly-zero scale.
    scale = std + 1e-8
    centered = X - mu
    return centered / scale

In [9]:
#prints h5 tree to examine
def print_h5_tree(h, prefix=""):
    """Recursively print the structure of a mapping-like (HDF5-style) tree.

    Anything exposing ``keys`` is treated as a group and descended into with
    two extra spaces of indentation; every other item is printed with its
    ``shape`` and ``dtype`` when those attributes are available.
    """
    for name in h.keys():
        node = h[name]

        if not hasattr(node, 'keys'):
            # Leaf: describe it, falling back to a generic tag if it has no
            # shape/dtype attributes.
            try:
                line = f"{name}: shape={node.shape}, dtype={node.dtype}"
            except Exception:
                line = f"{name}: <dataset>"
            print(prefix + line)
            continue

        print(prefix + f"[GROUP] {name}")
        print_h5_tree(node, prefix + "  ")

In [10]:
def profile_plot(ax, x, y, nbins=30, logx=False, min_per_bin=20, label="mean ± SE"):
    """Draw a profile plot (mean of ``y`` in bins of ``x``) on ``ax``.

    Args:
        ax: Axes-like object providing ``errorbar``, ``grid`` and ``set_xscale``.
        x, y: Array-likes of equal length; non-finite pairs are dropped.
        nbins: Number of uniform bins (uniform in log10(x) when ``logx``).
        logx: Bin in log10(x) and use a log x-axis; non-positive x are dropped.
        min_per_bin: Bins with fewer entries are not drawn.
        label: Legend label of the error-bar series.

    Returns:
        Dict with the plotted bin centres ``"x"``, the per-bin ``"mean"``,
        the standard error ``"sem"`` and the entry ``"count"`` of the kept
        bins. All four are empty arrays when no usable point remains.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    keep = np.isfinite(x) & np.isfinite(y)
    if logx:
        keep &= (x > 0)

    x = x[keep]
    y = y[keep]

    if logx:
        ax.set_xscale("log")

    if x.size == 0:
        # Nothing survived the mask: min()/max() below would raise on an
        # empty array, so leave the axes without a series and report no bins.
        ax.grid(alpha=0.3)
        nothing = np.empty(0, dtype=float)
        return {"x": nothing, "mean": nothing.copy(),
                "sem": nothing.copy(), "count": nothing.copy()}

    # Coordinate in which the bins are uniform.
    xu = np.log10(x) if logx else x

    lo = float(xu.min())
    hi = float(xu.max())
    if lo == hi:
        # Degenerate range: widen by one ulp so linspace yields valid edges.
        hi = np.nextafter(hi, np.inf)

    edges = np.linspace(lo, hi, nbins + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])

    # Per-bin statistics of y.
    mean, _, _ = binned_statistic(xu, y, statistic="mean", bins=edges)
    std,  _, _ = binned_statistic(xu, y, statistic="std",  bins=edges)
    cnt,  _, _ = binned_statistic(xu, y, statistic="count", bins=edges)

    # Standard error of the mean; the max() avoids dividing by zero in empty bins.
    sem = std / np.sqrt(np.maximum(cnt, 1))

    # Keep only well-populated bins.
    good = cnt >= min_per_bin
    xc = centers[good]
    ym = mean[good]
    ye = sem[good]

    # Convert the bin centres back from log space if needed.
    xplot = 10.0 ** xc if logx else xc

    ax.errorbar(xplot, ym, yerr=ye, fmt="o", ms=3, lw=1, capsize=2, label=label)
    ax.grid(alpha=0.3)
    return {"x": xplot, "mean": ym, "sem": ye, "count": cnt[good]}

In [11]:
def _log_spaced_edges(positive_values, bins):
    """Return ``bins + 1`` log-spaced histogram edges for positive values.

    The upper edge is the next float above the maximum (so the largest value
    falls inside the last bin); the lower edge is the smaller of the data
    minimum and ``max / 1e6``, floored at 1e-12.
    """
    upper = np.nextafter(float(positive_values.max()), np.inf)
    lower = max(min(positive_values.min(), upper / 1e6), 1e-12)
    return np.logspace(np.log10(lower), np.log10(upper), bins + 1)


def _log_image_to_wandb(key, png_path):
    """Best-effort upload of a saved figure; a wandb failure only warns."""
    try:
        wandb.log({key: wandb.Image(png_path)})
    except Exception as err:
        print(f"[WARN] Could not log {key} to wandb: {err}")


def _per_event_reco_losses(model, data, loss_fn, device, batch_size,
                           input_slicer=None, loss_slicer=None):
    """Run ``model`` over ``data`` in batches and return per-event losses.

    Only one batch at a time is moved to ``device``. NaN/inf losses are
    replaced by 0. Returns None when ``data`` has no rows.
    """
    chunks = []
    with torch.no_grad():
        for start in range(0, len(data), batch_size):
            batch = torch.as_tensor(data[start:start + batch_size],
                                    dtype=torch.float32, device=device)
            model_in = input_slicer(batch) if input_slicer else batch
            recon, _ = model(model_in)
            pred, true = recon, model_in
            if loss_slicer:
                pred, true = loss_slicer(pred), loss_slicer(true)
            per_event = loss_fn(pred, true).detach().cpu().numpy().reshape(-1)
            chunks.append(np.nan_to_num(per_event, nan=0.0, posinf=0.0, neginf=0.0))
    return np.concatenate(chunks, axis=0) if chunks else None


def _save_loss_histogram(losses, tag, bins, logy, outdir):
    """Save a log-x histogram of the positive losses; return its bin edges or None."""
    if losses is None or losses.size == 0:
        print(f"[WARN] No losses for {tag}; skipping histogram.")
        return None
    positive = losses[losses > 0]
    if positive.size == 0:
        print(f"[WARN] No positive losses for {tag}; skipping histogram.")
        return None

    edges = _log_spaced_edges(positive, bins)
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.hist(positive, bins=edges)
    ax.set_xscale("log")
    if logy:
        ax.set_yscale("log")
    ax.set_xlabel(f"{tag} reconstruction loss")
    ax.set_ylabel("Counts")
    ax.set_title(f"{tag} reconstruction loss (Test)")
    ax.set_xlim(edges[0], edges[-1])
    fig.tight_layout()
    out_png = os.path.join(outdir, f"hist_{tag}.png")
    fig.savefig(out_png, dpi=150)
    plt.close(fig)
    _log_image_to_wandb(f"Eval/hist_{tag}", out_png)
    return edges


def _save_profile(x, y, x_tag, y_tag, outdir):
    """Save the profile of the mean ``y_tag`` loss in bins of the ``x_tag`` loss."""
    if not np.any((x > 0) & np.isfinite(x) & np.isfinite(y)):
        # A log-x profile needs at least one positive, finite point.
        print(f"[WARN] No positive {x_tag} losses; skipping {y_tag} vs {x_tag} profile.")
        return

    fig, ax = plt.subplots(figsize=(6.2, 4.6))
    profile_plot(ax, x=x, y=y, nbins=100, logx=True, min_per_bin=50, label="mean ± SE")
    ax.set_title(f"Profile: ⟨{y_tag} loss⟩ vs {x_tag} loss (Test)")
    ax.set_xlabel(f"{x_tag} reconstruction loss")
    ax.set_ylabel(f"⟨{y_tag} reconstruction loss⟩")
    ax.set_yscale("log")
    ax.legend()
    fig.tight_layout()
    out_png = os.path.join(outdir, f"profile_{y_tag}_vs_{x_tag}.png")
    fig.savefig(out_png, dpi=150)
    plt.close(fig)
    _log_image_to_wandb(f"Profiles/{y_tag}_vs_{x_tag}", out_png)


def inference_loss_plots(
    ae_1,
    ae_2,
    dataset_test,                # single array OR (X1_test, X2_test)
    reco_loss_fn1=None,
    reco_loss_fn2=None,
    device=None,
    batch_size=2048,
    outdir="plots",
    bins=200,
    logy=False,
    ae1_input_slicer=None,
    ae2_input_slicer=None,
    ae1_loss_slicer=None,
    ae2_loss_slicer=None
):
    """Evaluate one or two autoencoders on test data and save diagnostic plots.

    For every model given, per-event reconstruction losses are computed,
    saved as ``test_reco_loss_ae{1,2}.npy`` and histogrammed. When both
    models are given and their loss arrays have the same length (i.e. they
    describe the same events), a log-log 2D histogram and both profile
    plots are produced as well. All figures are written to ``outdir`` and
    logged to wandb on a best-effort basis.

    Args:
        ae_1, ae_2: Models called as ``recon, _ = model(x)``; either may be
            None, but not both. Both are switched to eval mode.
        dataset_test: One array used for both models, or a 2-element
            tuple/list ``(X1_test, X2_test)``.
        reco_loss_fn1, reco_loss_fn2: Per-event loss callables; required for
            the corresponding model.
        device: Torch device (required).
        batch_size: Number of events per forward pass.
        outdir: Output directory, created if missing.
        bins: Number of log-spaced histogram bins.
        logy: Use a log y-axis in the 1D histograms.
        ae1_input_slicer, ae2_input_slicer: Optional callables applied to
            each batch before the model (e.g. normalisation).
        ae1_loss_slicer, ae2_loss_slicer: Optional callables applied to both
            reconstruction and target before the loss.

    Returns:
        Tuple ``(losses_1_all, losses_2_all)`` of 1D NumPy arrays; an entry
        is None when the model or its data is missing.
    """
    os.makedirs(outdir, exist_ok=True)
    assert device is not None

    do_ae1 = ae_1 is not None
    do_ae2 = ae_2 is not None
    assert do_ae1 or do_ae2

    if do_ae1:
        ae_1.eval()
        assert reco_loss_fn1 is not None
    if do_ae2:
        ae_2.eval()
        assert reco_loss_fn2 is not None

    # Allow (X1_test, X2_test) OR a single dataset for both models.
    if isinstance(dataset_test, (tuple, list)) and len(dataset_test) == 2:
        ds1, ds2 = dataset_test
    else:
        ds1 = ds2 = dataset_test

    losses_1_all = None
    if do_ae1 and ds1 is not None:
        losses_1_all = _per_event_reco_losses(
            ae_1, ds1, reco_loss_fn1, device, batch_size,
            input_slicer=ae1_input_slicer, loss_slicer=ae1_loss_slicer)

    losses_2_all = None
    if do_ae2 and ds2 is not None:
        losses_2_all = _per_event_reco_losses(
            ae_2, ds2, reco_loss_fn2, device, batch_size,
            input_slicer=ae2_input_slicer, loss_slicer=ae2_loss_slicer)

    # Save raw arrays.
    if losses_1_all is not None:
        np.save(os.path.join(outdir, "test_reco_loss_ae1.npy"), losses_1_all)
    if losses_2_all is not None:
        np.save(os.path.join(outdir, "test_reco_loss_ae2.npy"), losses_2_all)

    # Per-AE histograms.
    edges1 = _save_loss_histogram(losses_1_all, "AE1", bins, logy, outdir) if do_ae1 else None
    edges2 = _save_loss_histogram(losses_2_all, "AE2", bins, logy, outdir) if do_ae2 else None

    have_both = (losses_1_all is not None) and (losses_2_all is not None)
    aligned = have_both and (losses_1_all.shape[0] == losses_2_all.shape[0])

    if have_both and not aligned:
        # Pairing event i of one array with event i of the other is only
        # meaningful when both arrays describe the same events.
        print(f"[WARN] AE1/AE2 loss arrays differ in length "
              f"({losses_1_all.shape[0]} vs {losses_2_all.shape[0]}); "
              f"skipping cross-AE plots.")

    if aligned:
        both_positive = (losses_1_all > 0) & (losses_2_all > 0)
        x1 = losses_1_all[both_positive]
        x2 = losses_2_all[both_positive]

        # A non-empty joint selection implies both 1D histograms were drawn,
        # so edges1/edges2 are available here.
        if x1.size and edges1 is not None and edges2 is not None:
            fig, ax = plt.subplots(figsize=(6, 5))
            *_, mesh = ax.hist2d(x1, x2, bins=[edges1, edges2],
                                 norm=LogNorm(vmin=1), cmin=1)
            ax.set_xscale("log")
            ax.set_yscale("log")
            ax.set_xlabel("AE1 reconstruction loss")
            ax.set_ylabel("AE2 reconstruction loss")
            ax.set_title("AE1 vs AE2 reconstruction loss (Test) — log-log")
            ax.set_xlim(edges1[0], edges1[-1])
            ax.set_ylim(edges2[0], edges2[-1])
            fig.colorbar(mesh, ax=ax, label="Counts (log)")
            fig.tight_layout()
            out_png = os.path.join(outdir, "hist2d_AE1_vs_AE2.png")
            fig.savefig(out_png, dpi=150)
            plt.close(fig)
            _log_image_to_wandb("Eval/hist2d_AE1_vs_AE2", out_png)

        # Profile plots: mean of one loss in bins of the other, both ways.
        _save_profile(x=losses_2_all, y=losses_1_all, x_tag="AE2", y_tag="AE1", outdir=outdir)
        _save_profile(x=losses_1_all, y=losses_2_all, x_tag="AE1", y_tag="AE2", outdir=outdir)

    return losses_1_all, losses_2_all




In [12]:
def run(config):
    """Train two autoencoders with DisCo + ABCD-closure penalties and evaluate them.

    AE-1 is trained on jets + fat jets + MET, AE-2 on electrons + muons +
    photons. The total loss per batch is
    ``alpha*reco1 + alpha*reco2 + lambda_disco*DisCo + lambda_closure*closure``
    where DisCo and closure are computed between the two per-event
    reconstruction losses.

    Args:
        config: Dict of hyper-parameters.
            Required: ``alpha``, ``ae_lr``, ``ae_latent``, ``ae_nodes``.
            Optional (defaults in parentheses): ``lambda_disco`` (1.0),
            ``lambda_closure`` (0.0), ``disco_warmup_epochs`` (0),
            ``seed`` (123), ``epochs`` (100), ``batch_size`` (2048),
            ``lr_warmup_epochs`` (10), ``data_path``
            ('//axovol/HLT_data_nov_12_2024I.h5').

    Returns:
        Dict with the trained models (``ae_1``, ``ae_2``), the standardized
        test inputs (``X1_test_z``, ``X2_test_z``), the two loss modules and
        the ``device`` used.

    Side effects: starts a wandb run (left open so later cells can keep
    logging to it), writes ``ae1_trained_jets.pth`` and
    ``ae2_trained_lepPhotMET.pth`` and the plots in
    ``plots_profiles_same_events``.
    """
    seed = int(config.get("seed", 123))
    set_seed(seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"Using: {device}", flush=True)

    print("Logging in to wandb...", flush=True)
    # SECURITY: the API key must never be hardcoded in the notebook. It is
    # read from the WANDB_API_KEY environment variable; when that is unset,
    # wandb falls back to its own stored credentials (e.g. ~/.netrc).
    wandb.login(key=os.environ.get("WANDB_API_KEY"))
    wandb.init(project="Double Disco Axo Training",
               settings=wandb.Settings(_disable_stats=True),
               config=config)
    run_name = wandb.run.name
    print(f"Run name: {run_name}", flush=True)

    # ---- hyper-parameters -------------------------------------------------
    alpha = float(config['alpha'])
    ae_lr = float(config['ae_lr'])
    latent_dim = int(config['ae_latent'])
    enc_nodes = list(config['ae_nodes'])
    dec_nodes_template = [24, 32, 64, 128]

    lambda_disco_max = float(config.get("lambda_disco", 1.0))
    lambda_closure_max = float(config.get("lambda_closure", 0.0))
    disco_warmup_epochs = int(config.get("disco_warmup_epochs", 0))

    Epochs_AE = int(config.get("epochs", 100))
    Batch_size = int(config.get("batch_size", 2048))
    warmup_epochs = int(config.get("lr_warmup_epochs", 10))

    # ---- load data ----------------------------------------------------------
    data_path = config.get("data_path", '//axovol/HLT_data_nov_12_2024I.h5')
    print("Loading dataset...", flush=True)
    with h5.File(data_path, 'r') as f:
        r = f['data'] if 'data' in f else f

        print("H5 tree:")
        print_h5_tree(r)

        # [:] reads the datasets into memory, so the file can be closed
        # before any further processing.
        x_train = r['Background_data']['Train']['DATA'][:]
        x_test = r['Background_data']['Test']['DATA'][:]

    print(f"Train shape: {x_train.shape}, Test shape: {x_test.shape}", flush=True)

    Xtr_raw = x_train
    Xte_raw = x_test

    # NOTE(review): the former "zero out padding objects" step selected the
    # objects whose features are all exactly 0 and set them to 0 again, i.e.
    # it changed nothing, so it was dropped. If padding should be treated
    # differently (e.g. excluded from the scaler), that still needs doing.

    # Object slots along axis 1 of the (events, objects, features) array.
    SLOTS = {
        "ELECTRONS": (0, 4),
        "MUONS": (4, 8),
        "PHOTONS": (8, 12),
        "JETS": (12, 22),
        "FATJETS": (22, 32),
        "MET": (32, 33),
    }

    def _slice_slots(x, start, end):
        """Flatten objects [start, end) of every event into one feature row."""
        n, _, fdim = x.shape
        return x[:, start:end, :].reshape(n, (end - start) * fdim)

    def _take_groups(x, groups):
        """Concatenate the flattened features of the named object groups."""
        parts = [_slice_slots(x, *SLOTS[g]) for g in groups]
        return np.concatenate(parts, axis=1) if len(parts) > 1 else parts[0]

    # AE-1 = jets + fatjets + MET, AE-2 = leptons + photons
    X1_train = _take_groups(Xtr_raw, ["JETS", "FATJETS", "MET"])
    X1_test = _take_groups(Xte_raw, ["JETS", "FATJETS", "MET"])
    X2_train = _take_groups(Xtr_raw, ["ELECTRONS", "MUONS", "PHOTONS"])
    X2_test = _take_groups(Xte_raw, ["ELECTRONS", "MUONS", "PHOTONS"])

    # Un-normalised views, used again for the inference plots below.
    X1_train_raw, X1_test_raw = X1_train, X1_test
    X2_train_raw, X2_test_raw = X2_train, X2_test

    # Feature-wise standardization, fitted on the training split only.
    mu1, std1 = fit_standard_scaler(X1_train_raw)
    mu2, std2 = fit_standard_scaler(X2_train_raw)

    X1_train_z = transform_standard(X1_train_raw, mu1, std1)
    X1_test_z = transform_standard(X1_test_raw,  mu1, std1)
    X2_train_z = transform_standard(X2_train_raw, mu2, std2)
    X2_test_z = transform_standard(X2_test_raw,  mu2, std2)

    feat1 = X1_train.shape[1]
    feat2 = X2_train.shape[1]

    # ---- models, losses, optimizer ------------------------------------------
    reco_loss_fn1 = PerSampleMSE().to(device)
    reco_loss_fn2 = PerSampleMSE().to(device)
    print("Loss functions ready.", flush=True)

    ae1_cfg = {
        "features": feat1,
        "latent_dim": latent_dim,
        "encoder_config": {"nodes": enc_nodes},
        "decoder_config": {"nodes": dec_nodes_template + [feat1]},
        "alpha": alpha
    }
    ae2_cfg = {
        "features": feat2,
        "latent_dim": latent_dim,
        "encoder_config": {"nodes": enc_nodes},
        "decoder_config": {"nodes": dec_nodes_template + [feat2]},
        "alpha": alpha
    }

    ae_1 = Autoencoder(ae1_cfg).to(device)
    ae_2 = Autoencoder(ae2_cfg).to(device)
    print("Autoencoders are ready.", flush=True)

    # One optimizer drives both models because they share the coupled loss.
    optimizer = torch.optim.Adam(
        list(ae_1.parameters()) + list(ae_2.parameters()),
        lr=ae_lr
    )

    cos = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=32, T_mult=2, eta_min=0.0
    )

    def set_lr(lr):
        """Set the learning rate of every parameter group."""
        for g in optimizer.param_groups:
            g['lr'] = lr

    print("Moving data to device...", flush=True)
    X1 = torch.tensor(X1_train_z, dtype=torch.float32, device=device)
    X2 = torch.tensor(X2_train_z, dtype=torch.float32, device=device)
    print("Data on device.", flush=True)

    # ---- training loop --------------------------------------------------------
    print("Starting the training loop!", flush=True)
    N1 = X1.size(0)
    N2 = X2.size(0)
    Nmin = min(N1, N2)

    for epoch in range(Epochs_AE):
        ae1_reco_loss = []
        ae2_reco_loss = []

        # Linear LR warmup, then cosine annealing with warm restarts.
        if epoch < warmup_epochs:
            lr = ae_lr * (epoch + 1) / warmup_epochs
            set_lr(lr)
        else:
            cos.step(epoch - warmup_epochs)

        # Optional linear ramp of the decorrelation / closure weights.
        if disco_warmup_epochs > 0 and epoch < disco_warmup_epochs:
            ramp = (epoch + 1) / disco_warmup_epochs
            lambda_disco = lambda_disco_max * ramp
            lambda_closure = lambda_closure_max * ramp
        else:
            lambda_disco = lambda_disco_max
            lambda_closure = lambda_closure_max

        # The same permutation indexes X1 and X2 so both AEs see the same events.
        perm = torch.randperm(Nmin, device=device)

        num_batches = 0
        total_loss = total_reco1 = total_reco2 = total_disco = total_closure = 0.0

        for i0 in range(0, Nmin, Batch_size):
            i1 = min(i0 + Batch_size, Nmin)
            bsz = i1 - i0
            if bsz < 2:
                # With a single event every pairwise distance is 0, so the
                # distance correlation is 0/0 = NaN and one optimizer step
                # would poison all weights. Skip such a tail batch.
                continue

            idx = perm[i0:i1]

            xb1 = X1[idx]
            xb2 = X2[idx]

            recon1, z1 = ae_1(xb1)
            recon2, z2 = ae_2(xb2)

            # Per-event reconstruction losses.
            reco1_per = reco_loss_fn1(recon1, xb1)
            reco2_per = reco_loss_fn2(recon2, xb2)

            # Unit weights: their sum equals the batch size, as distance_corr expects.
            w = torch.ones(len(idx), device=device, dtype=reco1_per.dtype)

            # Distance correlation between the two per-event losses.
            disco = distance_corr(reco1_per, reco2_per, w, power=1)

            # ABCD closure loss on this batch (AE1 loss vs AE2 loss).
            closure = closure_loss_batch(
                reco1_per,
                reco2_per,
                w,
                symmetrize=True,
                n_events_min=10,
            )

            # Mean reconstruction losses, scaled by each model's alpha.
            reco1 = ae_1.alpha * reco1_per.mean()
            reco2 = ae_2.alpha * reco2_per.mean()

            # Total loss with the (possibly ramped) lambdas.
            loss = reco1 + reco2 + lambda_disco * disco + lambda_closure * closure

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                list(ae_1.parameters()) + list(ae_2.parameters()), max_norm=5.0
            )
            optimizer.step()

            total_loss += loss.item()
            total_reco1 += reco1.item()
            total_reco2 += reco2.item()
            total_disco += disco.item()
            total_closure += closure.item()
            num_batches += 1

            ae1_reco_loss.append(reco1.item())
            ae2_reco_loss.append(reco2.item())

        avg_loss = total_loss / max(1, num_batches)
        avg_reco1 = total_reco1 / max(1, num_batches)
        avg_reco2 = total_reco2 / max(1, num_batches)
        avg_disco = total_disco / max(1, num_batches)
        avg_closure = total_closure / max(1, num_batches)

        # Labels follow the actual inputs: MET belongs to AE-1, not AE-2.
        print(f"[EPOCH {epoch}/{Epochs_AE}] "
              f"Loss={avg_loss:.4f} "
              f"Reco1(jets/MET)={avg_reco1:.4f} Reco2(lep/photon)={avg_reco2:.4f} "
              f"DisCo={avg_disco:.4f} "
              f"Closure={avg_closure:.4f} "
              f"(lambda_dis={lambda_disco:.3g})", flush=True)

        # NOTE: metric keys are kept unchanged for continuity with earlier
        # wandb runs, even though "lepPhotMET" is a misnomer (MET is in AE-1).
        wandb.log({
            "epoch": epoch,
            "TotalLoss": avg_loss,
            "RecoLoss_AE1_jets": avg_reco1,
            "RecoLoss_AE2_lepPhotMET": avg_reco2,
            "DisCoLoss": avg_disco,
            "ClosureLoss": avg_closure,
            "lr": optimizer.param_groups[0]["lr"],
            "lambda_disco": lambda_disco,
            "lambda_closure": lambda_closure,
        })

        # 2D histogram of the per-batch mean losses seen during this epoch.
        ae1_reco_np = np.array(ae1_reco_loss)
        ae2_reco_np = np.array(ae2_reco_loss)

        make_2D_hist(ae1_reco_np, ae2_reco_np,
                     "Reco Loss (jets/MET)", "Reco Loss (lep/photon)",
                     f"Epoch {epoch}: Reco jets/MET vs Reco lep/photon",
                     wandb_key="Hists2D/Reco_jets_vs_Reco_lepPhotMET")

    print("Finished training.", flush=True)

    # Save models (file names kept for compatibility with existing tooling).
    torch.save(ae_1.state_dict(), "ae1_trained_jets.pth")
    torch.save(ae_2.state_dict(), "ae2_trained_lepPhotMET.pth")

    # Scaler statistics as tensors so raw inputs can be normalised on-device.
    mu1_t = torch.tensor(mu1, dtype=torch.float32, device=device)
    std1_t = torch.tensor(std1, dtype=torch.float32, device=device)
    mu2_t = torch.tensor(mu2, dtype=torch.float32, device=device)
    std2_t = torch.tensor(std2, dtype=torch.float32, device=device)

    def ae1_input_from_raw(t_flat: torch.Tensor):
        """Apply the AE-1 training standardization to raw inputs."""
        return (t_flat - mu1_t) / (std1_t + 1e-8)

    def ae2_input_from_raw(t_flat: torch.Tensor):
        """Apply the AE-2 training standardization to raw inputs."""
        return (t_flat - mu2_t) / (std2_t + 1e-8)

    _ae1_losses, _ae2_losses = inference_loss_plots(
        ae_1=ae_1,
        ae_2=ae_2,
        dataset_test=(X1_test_raw, X2_test_raw),
        reco_loss_fn1=reco_loss_fn1,
        reco_loss_fn2=reco_loss_fn2,
        device=device,
        batch_size=2048,
        outdir="plots_profiles_same_events",
        bins=200,
        logy=False,
        ae1_input_slicer=ae1_input_from_raw,
        ae2_input_slicer=ae2_input_from_raw,
        ae1_loss_slicer=None,
        ae2_loss_slicer=None
    )

    return {
        "ae_1": ae_1,
        "ae_2": ae_2,
        "X1_test_z": X1_test_z,
        "X2_test_z": X2_test_z,
        "reco_loss_fn1": reco_loss_fn1,
        "reco_loss_fn2": reco_loss_fn2,
        "device": device,
    }


In [13]:
# Hyper-parameters passed to run().
config = dict(
    ae_lr=1e-4,             # learning rate of the shared optimizer
    alpha=0.5,              # weight applied to each reconstruction loss
    ae_latent=8,            # latent dimension of both autoencoders
    ae_nodes=[28, 14],      # encoder layer widths
    lambda_disco=100.0,     # weight of the distance-correlation term
    lambda_closure=1,       # weight of the ABCD closure term
    disco_warmup_epochs=0,  # 0 = no ramp, full lambdas from the first epoch
)


In [14]:
# run(config)

In [15]:
#run and store the returned stuff here
training_vars = run(config)

# Pull the objects needed by the evaluation below out of the returned dict.
(ae_1, ae_2,
 X1_test_z, X2_test_z,
 reco_loss_fn1, reco_loss_fn2,
 device) = (
    training_vars[key]
    for key in ("ae_1", "ae_2",
                "X1_test_z", "X2_test_z",
                "reco_loss_fn1", "reco_loss_fn2",
                "device")
)

#inference of test set
def inference(ae, Xz, loss_fn, device, batch_size=4096):
    """Per-event reconstruction loss of ``ae`` on the inputs ``Xz``.

    Args:
        ae: Model called as ``recon, _ = ae(x)``; switched to eval mode.
        Xz: Array of model inputs, one row per event.
        loss_fn: Callable ``loss_fn(recon, x)`` returning one loss per event.
        device: Torch device used for the forward passes.
        batch_size: Number of events per forward pass.

    Returns:
        float32 NumPy array with one loss per row of ``Xz``.
    """
    ae.eval()
    total = Xz.shape[0]
    losses = np.empty(total, dtype=np.float32)

    with torch.no_grad():
        for start in range(0, total, batch_size):
            stop = min(start + batch_size, total)
            batch = torch.tensor(Xz[start:stop], dtype=torch.float32, device=device)
            recon, _ = ae(batch)
            # Write this batch's losses straight into the preallocated output.
            losses[start:stop] = loss_fn(recon, batch).detach().cpu().numpy()

    return losses

#reco losses after running on test set
# One per-event loss array per autoencoder, in the same event order.
reco_test_1, reco_test_2 = (
    inference(model, inputs, loss_fn, device)
    for model, inputs, loss_fn in (
        (ae_1, X1_test_z, reco_loss_fn1),
        (ae_2, X2_test_z, reco_loss_fn2),
    )
)

#divides events into ABCD based on their losses compared to threshold
def abcd_counts(loss_1, loss_2, percent_1, percent_2):
    """Count events in the four ABCD regions defined by quantile thresholds.

    Args:
        loss_1, loss_2: NumPy arrays of per-event losses (same length).
        percent_1, percent_2: Quantiles (0-1) defining each threshold.

    Returns:
        ``(thresh_1, thresh_2, A, B, C, D)`` where A = both losses above
        their threshold, B = loss_1 above / loss_2 not above,
        C = loss_1 not above / loss_2 above, D = neither above.
    """
    thresh_1 = np.quantile(loss_1, percent_1)
    thresh_2 = np.quantile(loss_2, percent_2)

    above_1, below_1 = loss_1 > thresh_1, loss_1 <= thresh_1
    above_2, below_2 = loss_2 > thresh_2, loss_2 <= thresh_2

    A, B, C, D = (
        int(np.count_nonzero(first & second))
        for first, second in (
            (above_1, above_2),
            (above_1, below_2),
            (below_1, above_2),
            (below_1, below_2),
        )
    )
    return thresh_1, thresh_2, A, B, C, D

#calculate nonclosure in region A
def nonclosure_A(A, B, C, D, eps=1e-8):
    """Relative non-closure of the ABCD prediction in region A.

    The predicted count is ``A_hat = B*C / max(D, eps)``.

    Returns:
        ``((A - A_hat) / A_hat, A_hat)``, or ``(np.inf, A_hat)`` when the
        prediction is not positive.
    """
    denominator = max(D, eps)
    A_hat = (B * C) / denominator
    if A_hat <= 0:
        # No usable prediction: flag it with an infinite non-closure.
        return np.inf, A_hat
    relative_diff = (A - A_hat) / A_hat
    return relative_diff, A_hat

#scan thresholds to minimize nonclosure
percent_1s = np.linspace(0.75, 0.98, 24)
percent_2s = np.linspace(0.75, 0.98, 24)

#dict to store best combo of thresholds in ABCD scan
best = {
    "percent_1": None, "percent_2": None, "t1": None, "t2": None,
    "A": None, "B": None, "C": None, "D": None,
    "nonclosure": np.inf, "A_hat": None
}
#require at least 200 events in A
min_A = 200
#at least 1000 in D
min_D = 1000

#loop over thresholds (scan) and keep the pair with the smallest |non-closure|
for percent_1 in percent_1s:
    for percent_2 in percent_2s:
        thresh_1, thresh_2, A, B, C, D = abcd_counts(reco_test_1, reco_test_2, percent_1, percent_2)
        if (A < min_A) or (D < min_D):
            continue
        nc, A_hat = nonclosure_A(A, B, C, D)
        if np.isfinite(nc) and abs(nc) < abs(best["nonclosure"]):
            best.update({
                "percent_1": percent_1, "percent_2": percent_2, "t1": thresh_1, "t2": thresh_2,
                "A": A, "B": B, "C": C, "D": D,
                "nonclosure": nc, "A_hat": A_hat
            })

#take optimized parameters from scan
t1_opt = best["t1"]
t2_opt = best["t2"]
percent_1_opt = best["percent_1"]
percent_2_opt = best["percent_2"]
N_A = best["A"]
N_B = best["B"]
N_C = best["C"]
N_D = best["D"]
N_A_hat = best["A_hat"]
# Stored under its own name: assigning it to `nonclosure_A` would overwrite
# the nonclosure_A() function used in the scan above.
nonclosure_opt = best["nonclosure"]

# False when no threshold pair passed the min_A / min_D requirements.
scan_succeeded = best["percent_1"] is not None

#print them
if scan_succeeded:
    print(f"Optimized percentiles: percent_1={percent_1_opt:.3f}, percent_2={percent_2_opt:.3f}", flush=True)
    print(f"Optimized thresholds: threshold 1={t1_opt:.6g}, threshold  2={t2_opt:.6g}", flush=True)
    print(f"ABCD counts: A={N_A}, B={N_B}, C={N_C}, D={N_D}", flush=True)
    print(f"Predicted A: N_A_hat={N_A_hat:.3f}", flush=True)
    print(f"Non-closure: {(100.0*nonclosure_opt):.2f}%  ((A - A_hat)/A_hat)", flush=True)
else:
    print(f"[WARN] No threshold pair satisfied A >= {min_A} and D >= {min_D}; "
          f"no optimized working point available.", flush=True)

#make plot dir
plot_dir = "plots/"
os.makedirs(plot_dir, exist_ok=True)

#scan over percentiles
p_list = np.linspace(0.75, 0.98, 24)
#init selection efficiency list
effs = []
#init predicted/true background list
bkg_ratio_ae = []
#init poisson error band list
bkg_ae_poisson_unc = []

#total number test background events
Ntot = float(len(reco_test_1))

#loop over percentiles
for p in p_list:
    #find ABCD counts at (p, p)
    thresh_1, thresh_2, A, B, C, D = abcd_counts(reco_test_1, reco_test_2, p, p)
    A_hat = (B * C) / max(D, 1e-8)
    ratio = A_hat / max(A, 1e-8)

    #poisson error propagation: relative variances of the four counts add up
    invA = 0.0 if A == 0 else 1.0/A
    invB = 0.0 if B == 0 else 1.0/B
    invC = 0.0 if C == 0 else 1.0/C
    invD = 0.0 if D == 0 else 1.0/D
    rel_var = invA + invB + invC + invD
    sigma = abs(ratio) * np.sqrt(rel_var) if rel_var > 0 else 0.0

    #selection efficiency for region A
    effs.append(A / max(Ntot, 1.0))
    bkg_ratio_ae.append(ratio)
    bkg_ae_poisson_unc.append(sigma)

#sort by efficiency
effs = np.array(effs)
bkg_ratio_ae = np.array(bkg_ratio_ae)
bkg_ae_poisson_unc = np.array(bkg_ae_poisson_unc)
order = np.argsort(effs)
effs = effs[order]
bkg_ratio_ae = bkg_ratio_ae[order]
bkg_ae_poisson_unc = bkg_ae_poisson_unc[order]

#colors and size like TNT code
colors = ['g', 'b']
fig_size = (8, 6)
fs = 28
fs_leg = 24

# A single figure: an extra plt.figure() here would leave an empty figure
# open that the notebook then displays ("Figure ... with 0 Axes").
fig, ax = plt.subplots(figsize=fig_size)

#main AE curve plot
ax.plot(effs, bkg_ratio_ae, c=colors[0], label="Autoencoders")

#error band plot
alpha = 0.5
ae_low = bkg_ratio_ae - bkg_ae_poisson_unc
ae_high = bkg_ratio_ae + bkg_ae_poisson_unc
ax.fill_between(effs, ae_low, ae_high, facecolor=colors[0], alpha=alpha, interpolate=True)

#reference lines at 1 (perfect closure) and +/-5%
one = np.ones_like(effs)
one_m = np.full_like(effs, 0.95)
one_p = np.full_like(effs, 1.05)
ax.plot(effs, one,   linestyle='-',  color='black')
ax.plot(effs, one_m, linestyle='--', color='black')
ax.plot(effs, one_p, linestyle='--', color='black')

#plot the optimized (percent_1_opt, percent_2_opt) operating point, if any
if scan_succeeded:
    eff_opt = N_A / max(Ntot, 1.0)
    ax.plot([eff_opt], [N_A_hat / max(N_A, 1e-8)], marker='o', c='red', label='Optimized (bkg)')

#labels and style like TNT plot code
ax.set_xlabel('Selection Efficiency', fontsize=fs)
ax.set_ylabel('Predicted Bkg. / True Bkg.', fontsize=fs)
ax.set_ylim([0.0, 1.5])
ax.set_xscale('log')
ax.tick_params(axis='x', labelsize=fs_leg)
ax.tick_params(axis='y', labelsize=fs_leg)
ax.legend(loc="lower right", fontsize=fs_leg)

#save plot
out_path = os.path.join(plot_dir, "cut_and_count_bkg_check.png")
fig.savefig(out_path, dpi=200, bbox_inches='tight')
plt.close(fig)

#log closure plot into wandb
wandb.log({"Closure/plot": wandb.Image(out_path)})


Using: cuda:0
Logging in to wandb...


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/jovyan/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mescheuller[0m ([33mescheuller-uc-san-diego[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Run name: distinctive-shadow-302
Loading dataset...
H5 tree:
[GROUP] Background_data
  [GROUP] Test
    DATA: shape=(284221, 33, 15), dtype=float32
  [GROUP] Train
    DATA: shape=(1136883, 33, 15), dtype=float32
Train shape: (1136883, 33, 15), Test shape: (284221, 33, 15)
Loss functions ready.
Autoencoders are ready.
Moving data to device...
Data on device.
Starting the training loop!
[EPOCH 0/100] Loss=4.1854 Reco1(jets)=0.3334 Reco2(lep/photon/MET)=0.1381 DisCo=0.0333 Closure=0.383814(lambda_dis=100,flush=True)
[EPOCH 2/100] Loss=0.8855 Reco1(jets)=0.3919 Reco2(lep/photon/MET)=0.1742 DisCo=0.0022 Closure=0.094937(lambda_dis=100,flush=True)
[EPOCH 3/100] Loss=0.8535 Reco1(jets)=0.3742 Reco2(lep/photon/MET)=0.1707 DisCo=0.0021 Closure=0.102570(lambda_dis=100,flush=True)
[EPOCH 4/100] Loss=0.8024 Reco1(jets)=0.3511 Reco2(lep/photon/MET)=0.1651 DisCo=0.0019 Closure=0.095530(lambda_dis=100,flush=True)
[EPOCH 5/100] Loss=0.7601 Reco1(jets)=0.3230 Reco2(lep/photon/MET)=0.1670 DisCo=0.0018 

<Figure size 800x600 with 0 Axes>