In [1]:
import os
import random
import gc
import argparse
import numpy as np
import h5py as h5
import torch
import torch.nn.functional as F
import torch.nn as nn
import wandb

from models.autoencoder import Autoencoder 
#from losses.cyl_ptpz_mae import CylPtPzMAE

from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm
from scipy.stats import binned_statistic



In [2]:
#setting a seed like in ae_legacy
def set_seed(seed=123):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [3]:
def distance_corr(var_1,var_2,normedweight,power=1):
    """var_1: First variable to decorrelate (eg mass)
    var_2: Second variable to decorrelate (eg classifier output)
    normedweight: Per-example weight. Sum of weights should add up to N (where N is the number of examples)
    power: Exponent used in calculating the distance correlation
    
    va1_1, var_2 and normedweight should all be 1D torch tensors with the same number of entries
    
    Usage: Add to your loss function. total_loss = BCE_loss + lambda * distance_corr
    """ 
    
    xx = var_1.view(-1, 1).repeat(1, len(var_1)).view(len(var_1),len(var_1))
    yy = var_1.repeat(len(var_1),1).view(len(var_1),len(var_1))
    amat = (xx-yy).abs()

    xx = var_2.view(-1, 1).repeat(1, len(var_2)).view(len(var_2),len(var_2))
    yy = var_2.repeat(len(var_2),1).view(len(var_2),len(var_2))
    bmat = (xx-yy).abs()

    amatavg = torch.mean(amat*normedweight,dim=1)
    Amat=amat-amatavg.repeat(len(var_1),1).view(len(var_1),len(var_1))\
        -amatavg.view(-1, 1).repeat(1, len(var_1)).view(len(var_1),len(var_1))\
        +torch.mean(amatavg*normedweight)

    bmatavg = torch.mean(bmat*normedweight,dim=1)
    Bmat=bmat-bmatavg.repeat(len(var_2),1).view(len(var_2),len(var_2))\
        -bmatavg.view(-1, 1).repeat(1, len(var_2)).view(len(var_2),len(var_2))\
        +torch.mean(bmatavg*normedweight)

    ABavg = torch.mean(Amat*Bmat*normedweight,dim=1)
    AAavg = torch.mean(Amat*Amat*normedweight,dim=1)
    BBavg = torch.mean(Bmat*Bmat*normedweight,dim=1)

    if(power==1):
        dCorr=(torch.mean(ABavg*normedweight))/torch.sqrt((torch.mean(AAavg*normedweight)*torch.mean(BBavg*normedweight)))
    elif(power==2):
        dCorr=(torch.mean(ABavg*normedweight))**2/(torch.mean(AAavg*normedweight)*torch.mean(BBavg*normedweight))
    else:
        dCorr=((torch.mean(ABavg*normedweight))/torch.sqrt((torch.mean(AAavg*normedweight)*torch.mean(BBavg*normedweight))))**power
    
    return dCorr

In [4]:
#calculates anomaly score like in ae legacy
def distance_pt(model_ae, data_np, device):
    x = torch.tensor(data_np, dtype=torch.float32, device=device)
    z_mean, z_logvar, _ = model_ae.encoder(x)
    score = torch.sum(z_mean**2, dim=1)
    return score.detach().cpu().numpy()

In [5]:
#function to make 2D histograms
def make_2D_hist(x, y, xlabel, ylabel, title, wandb_key, bins=40):
    fig = plt.figure(figsize=(5,4))
    plt.hist2d(x, y, bins=bins)
    plt.colorbar(label='Counts')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    wandb.log({wandb_key: wandb.Image(fig)})
    plt.close(fig)

In [6]:
class PerSampleMSE(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss(reduction='none')
    def forward(self, recon, target):
        per_feat = self.mse(recon, target)
        return per_feat.mean(dim=1)

In [7]:
def fit_standard_scaler(X, eps=1e-8):
    mu  = X.mean(axis=0).astype(np.float32)
    std = X.std(axis=0).astype(np.float32)
    std = np.where(std < eps, 1.0, std)
    return mu, std

def transform_standard(X, mu, std):
    return (X - mu) / (std + 1e-8)

In [8]:
#prints h5 tree to examine
def print_h5_tree(h, prefix=""):
    for k in h.keys():
        item = h[k]
        if hasattr(item, 'keys'):
            print(prefix + f"[GROUP] {k}")
            print_h5_tree(item, prefix + "  ")
        else:
            try:
                print(prefix + f"{k}: shape={item.shape}, dtype={item.dtype}")
            except Exception:
                print(prefix + f"{k}: <dataset>")

In [9]:
def profile_plot(ax, x, y, nbins=30, logx=False, min_per_bin=20, label="mean ± SE"):
    x = np.asarray(x)
    y = np.asarray(y)
    m = np.isfinite(x) & np.isfinite(y)
    if logx:
        m &= (x > 0)

    x = x[m]
    y = y[m]

    # bin along x (linear or log space)
    if logx:
        xu = np.log10(x)
    else:
        xu = x

    # uniform bins over the chosen coordinate
    lo = float(xu.min())
    hi = float(xu.max())
    if lo == hi:
        hi = np.nextafter(hi, np.inf)

    edges = np.linspace(lo, hi, nbins + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])

    # stats per bin
    mean, _, _ = binned_statistic(xu, y, statistic="mean", bins=edges)
    std,  _, _ = binned_statistic(xu, y, statistic="std",  bins=edges)
    cnt,  _, _ = binned_statistic(xu, y, statistic="count", bins=edges)

    sem = std / np.sqrt(np.maximum(cnt, 1))

    # keep well-populated bins
    good = cnt >= min_per_bin
    xc = centers[good]
    ym   = mean[good]
    ye   = sem[good]

    # convert x-axis back from log if needed
    if logx:
        xplot = 10.0 ** xc
        ax.set_xscale("log")
    else:
        xplot = xc

    ax.errorbar(xplot, ym, yerr=ye, fmt="o", ms=3, lw=1, capsize=2, label=label)
    ax.grid(alpha=0.3)
    return {"x": xplot, "mean": ym, "sem": ye, "count": cnt[good]}

In [10]:
def inference_loss_plots(
    ae_1,
    ae_2,
    dataset_test,                # single array OR (X1_test, X2_test)
    reco_loss_fn1=None,
    reco_loss_fn2=None,
    device=None,
    batch_size=2048,
    outdir="plots",
    bins=200,
    logy=False,
    ae1_input_slicer=None,
    ae2_input_slicer=None,
    ae1_loss_slicer=None,
    ae2_loss_slicer=None
):
    import os, numpy as np, torch, matplotlib.pyplot as plt
    from matplotlib.colors import LogNorm

    os.makedirs(outdir, exist_ok=True)
    assert device is not None

    do_ae1 = ae_1 is not None
    do_ae2 = ae_2 is not None
    assert do_ae1 or do_ae2

    if do_ae1:
        ae_1.eval(); assert reco_loss_fn1 is not None
    if do_ae2:
        ae_2.eval(); assert reco_loss_fn2 is not None

    # Allow (X1_test, X2_test) OR a single dataset for both
    if isinstance(dataset_test, (tuple, list)) and len(dataset_test) == 2:
        ds1, ds2 = dataset_test
    else:
        ds1 = dataset_test
        ds2 = dataset_test

    losses_1, losses_2 = [], []

    with torch.no_grad():
        if do_ae1 and ds1 is not None:
            X1_full = torch.tensor(ds1, dtype=torch.float32, device=device)
            for i in range(0, X1_full.size(0), batch_size):
                xb_full = X1_full[i:i + batch_size]
                xb1 = ae1_input_slicer(xb_full) if ae1_input_slicer else xb_full
                recon1, _ = ae_1(xb1)
                y1_pred, y1_true = recon1, xb1
                if ae1_loss_slicer:
                    y1_pred = ae1_loss_slicer(y1_pred); y1_true = ae1_loss_slicer(y1_true)
                l1 = reco_loss_fn1(y1_pred, y1_true).detach().cpu().numpy().reshape(-1)
                l1 = np.nan_to_num(l1, nan=0.0, posinf=0.0, neginf=0.0)
                losses_1.append(l1)

        if do_ae2 and ds2 is not None:
            X2_full = torch.tensor(ds2, dtype=torch.float32, device=device)
            for i in range(0, X2_full.size(0), batch_size):
                xb_full = X2_full[i:i + batch_size]
                xb2 = ae2_input_slicer(xb_full) if ae2_input_slicer else xb_full
                recon2, _ = ae_2(xb2)
                y2_pred, y2_true = recon2, xb2
                if ae2_loss_slicer:
                    y2_pred = ae2_loss_slicer(y2_pred); y2_true = ae2_loss_slicer(y2_true)
                l2 = reco_loss_fn2(y2_pred, y2_true).detach().cpu().numpy().reshape(-1)
                l2 = np.nan_to_num(l2, nan=0.0, posinf=0.0, neginf=0.0)
                losses_2.append(l2)

    losses_1_all = np.concatenate(losses_1, axis=0) if (do_ae1 and losses_1) else None
    losses_2_all = np.concatenate(losses_2, axis=0) if (do_ae2 and losses_2) else None

    # Save raw arrays
    if losses_1_all is not None:
        np.save(os.path.join(outdir, "test_reco_loss_ae1.npy"), losses_1_all)
    if losses_2_all is not None:
        np.save(os.path.join(outdir, "test_reco_loss_ae2.npy"), losses_2_all)

    # Helper to plot one histogram
    def _one_hist(arr, tag):
        if arr is None or arr.size == 0:
            print(f"[WARN] No losses for {tag}; skipping histogram.")
            return None
        mask = arr > 0
        x = arr[mask]
        if x.size == 0:
            print(f"[WARN] No positive losses for {tag}; skipping histogram.")
            return None
        x_max = np.nextafter(float(x.max()), np.inf)
        x_min = max(min(x.min(), x_max/1e6), 1e-12)
        edges = np.logspace(np.log10(x_min), np.log10(x_max), bins + 1)
        plt.figure(figsize=(7, 4))
        plt.hist(x, bins=edges)
        plt.xscale("log")
        if logy: plt.yscale("log")
        plt.xlabel(f"{tag} reconstruction loss"); plt.ylabel("Counts")
        plt.title(f"{tag} reconstruction loss (Test)")
        plt.xlim(edges[0], edges[-1])
        plt.tight_layout()
        out_png = os.path.join(outdir, f"hist_{tag}.png")
        plt.savefig(out_png, dpi=150); plt.close()
        try:
            import wandb; wandb.log({f"Eval/hist_{tag}": wandb.Image(out_png)})
        except Exception:
            pass
        return edges

    # Per-AE histograms
    edges1 = _one_hist(losses_1_all, "AE1") if do_ae1 else None
    edges2 = _one_hist(losses_2_all, "AE2") if do_ae2 else None

    # Cross-AE plots only if BOTH exist and are aligned in length
    if (losses_1_all is not None) and (losses_2_all is not None) and \
       (losses_1_all.shape[0] == losses_2_all.shape[0]):

        mask = (losses_1_all > 0) & (losses_2_all > 0)
        x1 = losses_1_all[mask]; x2 = losses_2_all[mask]

        # Choose edges robustly if None (fallback)
        if edges1 is None and x1.size:
            x1_max = np.nextafter(float(x1.max()), np.inf)
            x1_min = max(min(x1.min(), x1_max/1e6), 1e-12)
            edges1 = np.logspace(np.log10(x1_min), np.log10(x1_max), bins + 1)
        if edges2 is None and x2.size:
            x2_max = np.nextafter(float(x2.max()), np.inf)
            x2_min = max(min(x2.min(), x2_max/1e6), 1e-12)
            edges2 = np.logspace(np.log10(x2_min), np.log10(x2_max), bins + 1)

        # 2D hist
        if x1.size and x2.size:
            plt.figure(figsize=(6, 5))
            plt.hist2d(x1, x2, bins=[edges1, edges2], norm=LogNorm(vmin=1), cmin=1)
            plt.xscale("log"); plt.yscale("log")
            plt.xlabel("AE1 reconstruction loss"); plt.ylabel("AE2 reconstruction loss")
            plt.title("AE1 vs AE2 reconstruction loss (Test) — log-log")
            plt.xlim(edges1[0], edges1[-1]); plt.ylim(edges2[0], edges2[-1])
            plt.colorbar(label="Counts (log)")
            plt.tight_layout()
            out_png = os.path.join(outdir, "hist2d_AE1_vs_AE2.png")
            plt.savefig(out_png, dpi=150); plt.close()
            try:
                import wandb; wandb.log({"Eval/hist2d_AE1_vs_AE2": wandb.Image(out_png)})
            except Exception:
                pass

    # Profile plots
    if (losses_1_all is not None) and (losses_2_all is not None):
        n = min(len(losses_1_all), len(losses_2_all))
        x1 = losses_1_all[:n]
        x2 = losses_2_all[:n]
    
        fig, ax = plt.subplots(figsize=(6.2, 4.6))
        profile_plot(ax, x=x2, y=x1, nbins=100, logx=True, min_per_bin=50, label="mean ± SE")
        ax.set_title("Profile: ⟨AE1 loss⟩ vs AE2 loss (Test)")
        ax.set_xlabel("AE2 reconstruction loss")
        ax.set_ylabel("⟨AE1 reconstruction loss⟩")
        ax.set_yscale("log")
        ax.legend()
        fig.tight_layout()
        out_png = os.path.join(outdir, "profile_AE1_vs_AE2.png")
        fig.savefig(out_png, dpi=150)
        plt.close(fig)
        try:
            import wandb; wandb.log({"Profiles/AE1_vs_AE2": wandb.Image(out_png)})
        except Exception:
            pass
    
        fig, ax = plt.subplots(figsize=(6.2, 4.6))
        profile_plot(ax, x=x1, y=x2, nbins=100, logx=True, min_per_bin=50, label="mean ± SE")
        ax.set_title("Profile: ⟨AE2 loss⟩ vs AE1 loss (Test)")
        ax.set_xlabel("AE1 reconstruction loss")
        ax.set_ylabel("⟨AE2 reconstruction loss⟩")
        ax.set_yscale("log")
        ax.legend()
        fig.tight_layout()
        out_png = os.path.join(outdir, "profile_AE2_vs_AE1.png")
        fig.savefig(out_png, dpi=150)
        plt.close(fig)
        try:
            import wandb; wandb.log({"Profiles/AE2_vs_AE1": wandb.Image(out_png)})
        except Exception:
            pass
                
    return losses_1_all, losses_2_all




In [11]:
def run(config):
    seed = 123
    set_seed(seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"Using: {device}", flush=True)

    print("Logging in to wandb...", flush=True)
    wandb.login(key="24d1d60ce26563c74d290d7b487cb104fc251271")
    wandb.init(project="Double Disco Axo Training",
               settings=wandb.Settings(_disable_stats=True),
               config=config)
    run_name = wandb.run.name
    print(f"Run name: {run_name}", flush=True)

    alpha = float(config['alpha'])
    ae_lr = float(config['ae_lr'])
    latent_dim = int(config['ae_latent'])
    enc_nodes = list(config['ae_nodes'])
    dec_nodes_template = [24, 32, 64, 128]

    #load data
    data_path = '/axovol/HLT_data_oct_14.h5'
    print("Loading dataset...", flush=True)
    with h5.File(data_path, 'r') as f:
        r = f['data'] if 'data' in f else f

        print("H5 tree:")
        print_h5_tree(r)

        x_train = r['Background_data']['Train']['DATA'][:]  
        x_test  = r['Background_data']['Test']['DATA'][:]   
        scale = r['Normalization']['norm_scale'][:]     
        bias = r['Normalization']['norm_bias'][:]        

        print(f"Train shape: {x_train.shape}, Test shape: {x_test.shape}", flush=True)

        #undo normalization (next time you do preprocessing just remove it)
        Xtr_raw = x_train * scale + bias
        Xte_raw = x_test  * scale + bias
        
        #padding
        pad_tr = (x_train == 0.0).all(axis=-1)
        pad_te = (x_test  == 0.0).all(axis=-1)
        Xtr_raw[pad_tr] = 0.0
        Xte_raw[pad_te] = 0.0
        
        #splitting AE datasets by object
        SLOTS = {
            "ELECTRONS": (0, 4),
            "MUONS":     (4, 8),
            "PHOTONS":   (8, 12),
            "JETS":      (12, 22),
            "FATJETS":   (22, 32),
            "MET":       (32, 33),
        }

        def _slice_slots(x, start, end):
            n, _, fdim = x.shape
            return x[:, start:end, :].reshape(n, (end - start) * fdim)

        def _take_groups(x, groups):
            parts = [_slice_slots(x, *SLOTS[g]) for g in groups]
            return np.concatenate(parts, axis=1) if len(parts) > 1 else parts[0]

        # AE-1 = jets (+ fatjets), AE-2 = leptons + photons + MET 
        X1_train = _take_groups(Xtr_raw, ["JETS", "FATJETS", "ELECTRONS", "MUONS", "PHOTONS", "MET"])
        X1_test = _take_groups(Xte_raw, ["JETS", "FATJETS", "ELECTRONS", "MUONS", "PHOTONS", "MET"])
        X2_train = X1_train
        X2_test = X1_test
        # X2_train = _take_groups(Xtr_raw, ["ELECTRONS", "MUONS", "PHOTONS", "MET"])
        # X2_test = _take_groups(Xte_raw, ["ELECTRONS", "MUONS", "PHOTONS", "MET"])


        # keep raw copies for plots/inference
        X1_train_raw, X1_test_raw = X1_train, X1_test
        X2_train_raw, X2_test_raw = X2_train, X2_test

        mu1, std1 = fit_standard_scaler(X1_train_raw)
        # mu2, std2 = fit_standard_scaler(X2_train_raw)
        mu2, std2 = mu1, std1
        
        X1_train_z = transform_standard(X1_train_raw, mu1, std1)
        X1_test_z = transform_standard(X1_test_raw,  mu1, std1)
        X2_train_z = transform_standard(X2_train_raw, mu2, std2)
        X2_test_z = transform_standard(X2_test_raw,  mu2, std2)

    feat1 = X1_train.shape[1]
    feat2 = X2_train.shape[1]

    reco_loss_fn1 = PerSampleMSE().to(device)
    reco_loss_fn2 = PerSampleMSE().to(device)
    print("Loss functions ready.", flush=True)

    ae1_cfg = {
        "features": feat1,
        "latent_dim": latent_dim,
        "encoder_config": {"nodes": enc_nodes},
        "decoder_config": {"nodes": dec_nodes_template + [feat1]},
        "alpha": alpha
    }
    ae2_cfg = {
        "features": feat2,
        "latent_dim": latent_dim,
        "encoder_config": {"nodes": enc_nodes},
        "decoder_config": {"nodes": dec_nodes_template + [feat2]},
        "alpha": alpha
    }

    ae_1 = Autoencoder(ae1_cfg).to(device)
    ae_2 = Autoencoder(ae2_cfg).to(device)
    print("Autoencoders are ready.", flush=True)

    optimizer = torch.optim.Adam(
        list(ae_1.parameters()) + list(ae_2.parameters()),
        lr=ae_lr
    )

    warmup_epochs = 10
    cos = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=32, T_mult=2, eta_min=0.0
    )

    def set_lr(lr):
        for g in optimizer.param_groups:
            g['lr'] = lr

    Epochs_AE = 100
    Batch_size = 1024
    lambda_disco = float(config.get("lambda_disco", 1.0))

    print("Moving data to device...", flush=True)
    X1 = torch.tensor(X1_train_z, dtype=torch.float32, device=device)
    X2 = torch.tensor(X2_train_z, dtype=torch.float32, device=device)
    print("Data on device.", flush=True)

    def _check_finite(name, t):
        ok = torch.isfinite(t).all().item()
        print(f"[finite-check] {name}: {ok} | min={t.min().item():.3g}, max={t.max().item():.3g}")
        if not ok:
            bad = (~torch.isfinite(t)).nonzero(as_tuple=False)[:5]
            print(f"[finite-check] {name} bad idx (first 5): {bad}")
            raise RuntimeError(f"{name} contains non-finite values")

    _check_finite("X1_train", torch.tensor(X1_train[:1000], device=device))
    _check_finite("X2_train", torch.tensor(X2_train[:1000], device=device))

    #training
    print("Starting the training loop!", flush=True)
    N1 = X1.size(0)
    N2 = X2.size(0)
    Nmin = min(N1, N2)

    for epoch in range(Epochs_AE):
        ae1_reco_loss = []
        ae2_reco_loss = []

        if epoch < warmup_epochs:
            lr = ae_lr * (epoch + 1) / warmup_epochs
            set_lr(lr)
        else:
            cos.step(epoch - warmup_epochs)

        perm1 = torch.randperm(N1, device=device)
        perm2 = torch.randperm(N2, device=device)

        num_batches = 0
        total_loss = total_reco1 = total_reco2 = total_disco = 0.0

        for i0 in range(0, Nmin, Batch_size):
            i1 = min(i0 + Batch_size, Nmin)
            bsz = i1 - i0

            idx1 = perm1[i0:i1]
            idx2 = perm2[i0:i1]

            xb1 = X1[idx1]
            xb2 = X2[idx2]

            recon1, z1 = ae_1(xb1)
            recon2, z2 = ae_2(xb2)

            reco1_per = reco_loss_fn1(recon1, xb1)
            reco2_per = reco_loss_fn2(recon2, xb2)

            w = torch.ones(bsz, device=device, dtype=reco1_per.dtype)
            disco = distance_corr(reco1_per, reco2_per, w, power=1)

            reco1 = ae_1.alpha * reco1_per.mean()
            reco2 = ae_2.alpha * reco2_per.mean()

            loss = reco1 + reco2 + lambda_disco * disco

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                list(ae_1.parameters()) + list(ae_2.parameters()), max_norm=5.0
            )
            optimizer.step()

            total_loss  += loss.item()
            total_reco1 += reco1.item()
            total_reco2 += reco2.item()
            total_disco += disco.item()
            num_batches += 1

            ae1_reco_loss.append(reco1.item())
            ae2_reco_loss.append(reco2.item())

            avg_loss  = total_loss  / max(1, num_batches)
            avg_reco1 = total_reco1 / max(1, num_batches)
            avg_reco2 = total_reco2 / max(1, num_batches)
            avg_disco = total_disco / max(1, num_batches)

        print(f"[EPOCH {epoch}/{Epochs_AE}] "
              f"Loss={avg_loss:.4f} "
              f"Reco1(jets)={avg_reco1:.4f} Reco2(lep/photon/MET)={avg_reco2:.4f} "
              f"DisCo={avg_disco:.4f}", flush=True)

        wandb.log({
            "epoch": epoch,
            "TotalLoss": avg_loss,
            "RecoLoss_AE1_jets": avg_reco1,
            "RecoLoss_AE2_lepPhotMET": avg_reco2,
            "DisCoLoss": avg_disco,
            "lr": optimizer.param_groups[0]["lr"],
        })

        ae1_reco_np = np.array(ae1_reco_loss)
        ae2_reco_np = np.array(ae2_reco_loss)

        make_2D_hist(ae1_reco_np, ae2_reco_np,
                     "Reco Loss (jets)", "Reco Loss (lep/photon/MET)",
                     f"Epoch {epoch}: Reco jets vs Reco lep/photon/MET",
                     wandb_key="Hists2D/Reco_jets_vs_Reco_lepPhotMET")

    print("Finished training.", flush=True)

    #save model
    torch.save(ae_1.state_dict(), "ae1_trained_jets.pth")
    torch.save(ae_2.state_dict(), "ae2_trained_lepPhotMET.pth")

    mu1_t = torch.tensor(mu1, dtype=torch.float32, device=device)
    std1_t = torch.tensor(std1, dtype=torch.float32, device=device)
    mu2_t = torch.tensor(mu2, dtype=torch.float32, device=device)
    std2_t = torch.tensor(std2, dtype=torch.float32, device=device)
    
    def ae1_input_from_raw(t_flat: torch.Tensor):
        return (t_flat - mu1_t) / (std1_t + 1e-8)
    
    def ae2_input_from_raw(t_flat: torch.Tensor):
        return (t_flat - mu2_t) / (std2_t + 1e-8)

    _ae1_losses, _ae2_losses = inference_loss_plots(
        ae_1=ae_1,
        ae_2=ae_2,
        dataset_test=(X1_test_raw, X2_test_raw),
        reco_loss_fn1=reco_loss_fn1,
        reco_loss_fn2=reco_loss_fn2,
        device=device,
        batch_size=2048,
        outdir="plots_profiles_same_events",
        bins=200,
        logy=False,
        ae1_input_slicer=ae1_input_from_raw,
        ae2_input_slicer=ae2_input_from_raw,
        ae1_loss_slicer=None,
        ae2_loss_slicer=None
    )


In [12]:
config = {
    'ae_lr': 1e-4,
    'alpha': 0.5,
    'ae_latent': 8,
    'ae_nodes': [28, 14],
    'lambda_disco':1000.0
}

run(config)

Using: cuda:0
Logging in to wandb...


[34m[1mwandb[0m: Currently logged in as: [33mescheuller[0m ([33mescheuller-uc-san-diego[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Run name: winter-sun-259
Loading dataset...
H5 tree:
[GROUP] Background_data
  [GROUP] Test
    DATA: shape=(255856, 33, 15), dtype=float32
  [GROUP] Train
    DATA: shape=(1023420, 33, 15), dtype=float32
[GROUP] Normalization
  norm_bias: shape=(33, 15), dtype=float32
  norm_scale: shape=(33, 15), dtype=float32
Train shape: (1023420, 33, 15), Test shape: (255856, 33, 15)
Loss functions ready.
Moving data to device...
Data on device.
[finite-check] X1_train: True | min=-3.14, max=588
[finite-check] X2_train: True | min=-3.14, max=588
Starting the training loop!
[EPOCH 0/100] Loss=3.8588 Reco1(jets)=0.2482 Reco2(lep/photon/MET)=0.2473 DisCo=0.0034
[EPOCH 1/100] Loss=4.1258 Reco1(jets)=0.2454 Reco2(lep/photon/MET)=0.2416 DisCo=0.0036
[EPOCH 2/100] Loss=3.8061 Reco1(jets)=0.2490 Reco2(lep/photon/MET)=0.2416 DisCo=0.0033
[EPOCH 3/100] Loss=3.6073 Reco1(jets)=0.2654 Reco2(lep/photon/MET)=0.2534 DisCo=0.0031
[EPOCH 4/100] Loss=3.3571 Reco1(jets)=0.2809 Reco2(lep/photon/MET)=0.2701 DisCo=0.00