In [1]:
import os
import random
import gc
import argparse
import numpy as np
import h5py as h5
import torch
import torch.nn.functional as F
import wandb

from models.vae import VAE 
from losses.cyl_ptpz_mae import CylPtPzMAE

from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict

In [2]:
#setting a seed like in vae_legacy
def set_seed(seed=123):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [3]:
def distance_corr(var_1,var_2,normedweight,power=1):
    """var_1: First variable to decorrelate (eg mass)
    var_2: Second variable to decorrelate (eg classifier output)
    normedweight: Per-example weight. Sum of weights should add up to N (where N is the number of examples)
    power: Exponent used in calculating the distance correlation
    
    va1_1, var_2 and normedweight should all be 1D torch tensors with the same number of entries
    
    Usage: Add to your loss function. total_loss = BCE_loss + lambda * distance_corr
    """ 
    
    xx = var_1.view(-1, 1).repeat(1, len(var_1)).view(len(var_1),len(var_1))
    yy = var_1.repeat(len(var_1),1).view(len(var_1),len(var_1))
    amat = (xx-yy).abs()

    xx = var_2.view(-1, 1).repeat(1, len(var_2)).view(len(var_2),len(var_2))
    yy = var_2.repeat(len(var_2),1).view(len(var_2),len(var_2))
    bmat = (xx-yy).abs()

    amatavg = torch.mean(amat*normedweight,dim=1)
    Amat=amat-amatavg.repeat(len(var_1),1).view(len(var_1),len(var_1))\
        -amatavg.view(-1, 1).repeat(1, len(var_1)).view(len(var_1),len(var_1))\
        +torch.mean(amatavg*normedweight)

    bmatavg = torch.mean(bmat*normedweight,dim=1)
    Bmat=bmat-bmatavg.repeat(len(var_2),1).view(len(var_2),len(var_2))\
        -bmatavg.view(-1, 1).repeat(1, len(var_2)).view(len(var_2),len(var_2))\
        +torch.mean(bmatavg*normedweight)

    ABavg = torch.mean(Amat*Bmat*normedweight,dim=1)
    AAavg = torch.mean(Amat*Amat*normedweight,dim=1)
    BBavg = torch.mean(Bmat*Bmat*normedweight,dim=1)

    if(power==1):
        dCorr=(torch.mean(ABavg*normedweight))/torch.sqrt((torch.mean(AAavg*normedweight)*torch.mean(BBavg*normedweight)))
    elif(power==2):
        dCorr=(torch.mean(ABavg*normedweight))**2/(torch.mean(AAavg*normedweight)*torch.mean(BBavg*normedweight))
    else:
        dCorr=((torch.mean(ABavg*normedweight))/torch.sqrt((torch.mean(AAavg*normedweight)*torch.mean(BBavg*normedweight))))**power
    
    return dCorr

In [4]:
#calculates anomaly score like in vae legacy
def distance_pt(model_vae, data_np, device):
    x = torch.tensor(data_np, dtype=torch.float32, device=device)
    z_mean, z_logvar, _ = model_vae.encoder(x)
    score = torch.sum(z_mean**2, dim=1)
    return score.detach().cpu().numpy()

In [5]:
#print h5 tree
def print_h5_tree(h, prefix=""):
    for k in h.keys():
        item = h[k]
        if hasattr(item, 'keys'):
            print(prefix + f"[GROUP] {k}")
            print_h5_tree(item, prefix + "  ")
        else:
            try:
                print(prefix + f"{k}: shape={item.shape}, dtype={item.dtype}")
            except Exception:
                print(prefix + f"{k}: <dataset>")

In [6]:
#function to make 2D histograms
def make_2D_hist(x, y, xlabel, ylabel, title, wandb_key, bins=40):
    fig = plt.figure(figsize=(5,4))
    plt.hist2d(x, y, bins=bins)
    plt.colorbar(label='Counts')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    wandb.log({wandb_key: wandb.Image(fig)})
    plt.close(fig)

In [9]:
def run(config: Dict):
    # set seed
    seed = int(config.get('seed', 123))
    set_seed(seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"Using: {device}")

    #login to wandb
    print("Logging in to wandb...")
    wandb.login(key="24d1d60ce26563c74d290d7b487cb104fc251271")
    wandb.init(
        project="Double Disco Axo Training",
        settings=wandb.Settings(_disable_stats=True),
        config=config
    )
    run_name = wandb.run.name
    print(f"Run name: {run_name}")

    #scaling
    beta = float(config['beta'])
    alpha = float(config['alpha'])
    vae_lr = float(config['vae_lr'])

    #load data
    print("Loading dataset...")
    fpath = '/axovol/training/v5/conditionsupdate_apr25.h5'
    with h5.File(fpath, 'r') as f:
        root = f['data'] if 'data' in f else f

        # print h5 tree to view
        # print("Printing h5 tree...")
        # print_h5_tree(root)

        x_train = root['Background_data']['Train']['DATA'][:]
        x_test = root['Background_data']['Test']['DATA'][:]
        print(f"Train shape: {x_train.shape}, Test shape: {x_test.shape}")

        #flatten event for vae 1
        x_train_bkg = x_train.reshape(x_train.shape[0], -1).astype('float32')
        x_test_bkg = x_test.reshape(x_test.shape[0],  -1).astype('float32') 

        #scale and biases
        scale = root['Normalisation']['norm_scale'][:].astype('float32')
        bias = root['Normalisation']['norm_bias'][:].astype('float32')

        #HT and ET
        HT_train = root['Background_data']['Train']['HT'][:].astype('float32') 
        ET_train = root['Background_data']['Train']['ET'][:].astype('float32')
        HT_test = root['Background_data']['Test']['HT'][:].astype('float32')
        ET_test = root['Background_data']['Test']['ET'][:].astype('float32')

    print("Data finished loading.")

    #vae 1 is unchanged
    X1_train = x_train_bkg 
    features1 = X1_train.shape[1] 

    #vae 2 trained on HT, MET + jet kinematics
    #using 0 as MET, 1-4 and egamma, 5-8 mu, 9-18 jets (check with melissa)
    met_idx = 0
    n_eg = 4
    n_mu = 4
    n_jet = 10
    jet_start = 1 + n_eg + n_mu
    jet_stop = jet_start + n_jet

    #flatten jets
    jets_train = x_train[:, jet_start:jet_stop, :].reshape(x_train.shape[0], -1).astype('float32') 
    jets_test = x_test[:,  jet_start:jet_stop,  :].reshape(x_test.shape[0],  -1).astype('float32')

    #make vae 2 matrices
    X2_train = np.concatenate([HT_train[:, None], ET_train[:, None], jets_train], axis=1).astype('float32')
    X2_test = np.concatenate([HT_test[:,  None], ET_test[:,  None], jets_test],  axis=1).astype('float32')
    features2 = X2_train.shape[1]

    #scales/biases for jets only
    scale_jets = scale[jet_start:jet_stop, :] 
    bias_jets = bias[jet_start:jet_stop, :]

    #vae 1 reco loss
    reco1_loss_fn = CylPtPzMAE(scale, bias).to(device)

    #vae 2 reco loss is MSE on [HT, MET] + CylPtPzMAE on jets
    reco2_jet_loss_fn = CylPtPzMAE(scale_jets, bias_jets).to(device)
    def reco2_loss_fn(y_pred, y_true):
        mse_scalars = torch.mean((y_pred[:, :2] - y_true[:, :2])**2, dim=1)
        mae_jets = reco2_jet_loss_fn(y_pred[:, 2:], y_true[:, 2:])
        return mse_scalars + mae_jets

    #configs
    latent_dim = int(config['vae_latent'])
    enc_nodes = list(config['vae_nodes'])
    dec_nodes1 = [24, 32, 64, 128, features1]
    dec_nodes2 = [24, 32, 64, 128, features2]

    vae1_cfg = {
        "features": features1,
        "latent_dim": latent_dim,
        "encoder_config": {"nodes": enc_nodes},
        "decoder_config": {"nodes": dec_nodes1},
        "alpha": alpha,
        "beta":  beta,
    }
    vae2_cfg = {
        "features": features2,
        "latent_dim": latent_dim,
        "encoder_config": {"nodes": enc_nodes},
        "decoder_config": {"nodes": dec_nodes2},
        "alpha": alpha,
        "beta":  beta,
    }

    vae_1 = VAE(vae1_cfg).to(device)
    vae_2 = VAE(vae2_cfg).to(device)
    print("VAEs are ready.")

    #optimizer
    optimizer = torch.optim.Adam(
        list(vae_1.parameters()) + list(vae_2.parameters()),
        lr=vae_lr
    )

    #cosine restarts
    warmup_epochs = int(config.get('warmup_epochs', 10))
    cos = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=32, T_mult=2, eta_min=0.0
    )

    #sets learning rate
    def set_lr(lr):
        for g in optimizer.param_groups:
            g['lr'] = lr

    #hyperparameters
    Epochs_VAE = int(config.get('epochs', 50))
    Batch_size = int(config.get('batch_size', 16384))

    #disco scaling
    lambda_disco = float(config.get("lambda_disco", 1.0))

    #move data to device
    print("Moving data to device...")
    X1 = torch.tensor(X1_train, dtype=torch.float32, device=device)
    X2 = torch.tensor(X2_train, dtype=torch.float32, device=device) 
    print("Data on device.")

    #training loop
    print("Starting the training loop!")
    N = X1.size(0)
    for epoch in range(Epochs_VAE):

        #lists for 2D hists
        vae1_total_loss = []
        vae2_total_loss = []
        vae1_kl_loss = []
        vae2_kl_loss = []
        vae1_reco_loss = []
        vae2_reco_loss = []

        #cosine warmup step
        if epoch < warmup_epochs:
            lr = vae_lr * (epoch + 1) / warmup_epochs
            set_lr(lr)
        else:
            cos.step(epoch - warmup_epochs)

        #shuffle indices
        perm = torch.randperm(N, device=device)

        total_loss = total_reco1 = total_reco2 = total_kl1 = total_kl2 = total_disco = 0.0

        #training loop
        for i in range(0, N, Batch_size):
            idx = perm[i:i+Batch_size]
            xb1 = X1[idx]  # full input for vae 1
            xb2 = X2[idx]  # HT, MET, jets for vae 2

            #vae 1
            recon1, mu1, logvar1, z1 = vae_1(xb1)
            
            #vae 2
            recon2, mu2, logvar2, z2 = vae_2(xb2)

            #get reco loss from custom func
            reco1_per = reco1_loss_fn(recon1, xb1)
            reco2_per = reco2_loss_fn(recon2, xb2)

            #get kl div per sample
            kl1_per = VAE.kl_divergence(mu1, logvar1)
            kl2_per = VAE.kl_divergence(mu2, logvar2)
            
            tot1_per = reco1_per + kl1_per
            tot2_per = reco2_per + kl2_per

            #from paper code weight
            B = xb1.shape[0]
            w = torch.ones(B, device=tot1_per.device, dtype=tot1_per.dtype)

            #disco loss (ask Melissa about since using mu instead of z)
            #disco = disco_loss(mu1, mu2)
            disco = distance_corr((tot1_per), (tot2_per), w, power=1)

            reco1 = vae_1.reco_scale * reco1_per.mean()
            reco2 = vae_2.reco_scale * reco2_per.mean()
            kl1 = vae_1.kl_scale   * kl1_per.mean()
            kl2 = vae_2.kl_scale   * kl2_per.mean()

            #total loss
            loss = (reco1 + kl1) + (reco2 + kl2) + lambda_disco * disco

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                list(vae_1.parameters()) + list(vae_2.parameters()),
                max_norm=5.0
            )
            optimizer.step()

            total_loss+= float(loss.item())
            total_reco1+= float(reco1.item())
            total_reco2+= float(reco2.item())
            total_kl1+= float(kl1.item())
            total_kl2+= float(kl2.item())
            total_disco+= float(disco.item())

            #add loss to lists
            vae1_total_loss.append((reco1 + kl1).item())
            vae2_total_loss.append((reco2 + kl2).item())
            vae1_kl_loss.append(kl1.item())
            vae2_kl_loss.append(kl2.item())
            vae1_reco_loss.append(reco1.item())
            vae2_reco_loss.append(reco2.item())

        print(f"[EPOCH {epoch}/{Epochs_VAE}] "
              f"Loss={total_loss:.4f} "
              f"Reco1={total_reco1:.4f} Reco2={total_reco2:.4f} "
              f"KL1={total_kl1:.4f} KL2={total_kl2:.4f} "
              f"DisCo={total_disco:.4f}")

        wandb.log({
            "EpochVae": epoch,
            "TotalLossVae": total_loss,
            "RecoLossVae1": total_reco1,
            "RecoLossVae2": total_reco2,
            "KLLossVae1": total_kl1,
            "KLLossVae2": total_kl2,
            "DisCoLoss": total_disco,
        })

        #2d hist plotting
        vae1_total_np = np.array(vae1_total_loss)
        vae2_total_np = np.array(vae2_total_loss)
        vae1_kl_np = np.array(vae1_kl_loss)
        vae2_kl_np = np.array(vae2_kl_loss)
        vae1_reco_np = np.array(vae1_reco_loss)
        vae2_reco_np = np.array(vae2_reco_loss)

        make_2D_hist(vae1_total_np, vae2_total_np,
                     "Total Loss (VAE1)", "Total Loss (VAE2)",
                     f"Epoch {epoch}: Total VAE1 vs Total VAE2",
                     wandb_key="Hists2D/Total_VAE1_vs_Total_VAE2")

        make_2D_hist(vae1_kl_np, vae2_kl_np,
                     "KL Loss (VAE1)", "KL Loss (VAE2)",
                     f"Epoch {epoch}: KL VAE1 vs KL VAE2",
                     wandb_key="Hists2D/KL_VAE_1_vs_KL_VAE_2")

        make_2D_hist(vae1_reco_np, vae2_reco_np,
                     "Reco Loss (VAE1)", "Reco Loss (VAE2)",
                     f"Epoch {epoch}: Reco VAE1 vs Reco VAE2",
                     wandb_key="Hists2D/Reco_VAE1_vs_Reco_VAE2")

    print("Finished training.")
    torch.save(vae_1.state_dict(), "vae1_trained.pth")
    torch.save(vae_2.state_dict(), "vae2_trained.pth")
    print("Saved vae1_trained.pth and vae2_trained.pth")

In [10]:
config = {
    "vae_lr": 1e-4,
    "beta": 0.5,
    "alpha": 0.5,
    "vae_latent": 8,
    "vae_nodes": [28, 14],
    "lambda_disco": 1000.0,
    "epochs": 50,
    "batch_size": 16384,
    "warmup_epochs": 10,
}
run(config)



Using: cuda:0
Logging in to wandb...


Run name: ruby-vortex-120
Loading dataset...
Train shape: (1999965, 19, 3), Test shape: (4511092, 19, 3)
Data finished loading.
VAEs are ready.
Moving data to device...
Data on device.
Starting the training loop!


OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 10.57 GiB of which 841.12 MiB is free. Including non-PyTorch memory, this process has 9.74 GiB memory in use. Of the allocated memory 9.55 GiB is allocated by PyTorch, and 15.65 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)