In [286]:
import h5py
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

In [287]:
#read the first 100000 entries of the 'Particles' dataset from the training file
with h5py.File('data/background_for_training.h5', 'r') as h5_file:
    #slicing the dataset copies the selected entries out of the file,
    #so X stays usable after the file is closed
    X = h5_file['Particles'][0:100000]  #only a subset for now; load more once this works

In [288]:
#collapse every event into one flat 1D feature vector and replace NaN with 0
#(np.nan_to_num also maps +/-inf to the largest finite values of the dtype)
X = np.nan_to_num(X.reshape(len(X), -1))

#standardise the features with sklearn's StandardScaler; the fitted scaler is
#kept in `scaler`
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

#cast the scaled features to float32
X_scaled = X_scaled.astype(np.float32)

#wrap the scaled array in a torch tensor
X_tensor = torch.tensor(X_scaled)

In [289]:
#standard vae model class
class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    The encoder maps an input vector to a hidden layer and then to
    ``2 * latent_dim`` outputs that are split into ``mu`` and ``logvar``.
    The decoder maps a latent sample ``z`` through a hidden layer back to a
    vector of size ``input_dim``.

    Args:
        input_dim: number of features per input vector.
        latent_dim: size of the latent space (default 16).
        hidden_dim: width of the single hidden layer used in both the encoder
            and the decoder (default 64, the value that used to be hard-coded).
    """

    def __init__(self, input_dim, latent_dim=16, hidden_dim=64):
        super().__init__()

        #keep the sizes around so they can be inspected later
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim

        #the encoder (maps input to hidden layer, then to mu and logvar)
        self.encoder = nn.Sequential(

            #input to hidden layer
            nn.Linear(input_dim, hidden_dim),

            #activation function
            nn.ReLU(),

            #outputs mu and logvar concatenated on the last dimension
            nn.Linear(hidden_dim, 2 * latent_dim)
        )

        #decoder (maps latent vector z back to recon. input)
        self.decoder = nn.Sequential(

            #latent to hidden layer
            nn.Linear(latent_dim, hidden_dim),

            #act. function
            nn.ReLU(),

            #hidden to recon. input
            nn.Linear(hidden_dim, input_dim)
        )

    def encode(self, x):
        """Encode input ``x`` into the latent parameters ``(mu, logvar)``.

        Both returned tensors have ``latent_dim`` entries on the last
        dimension.
        """

        #encoder output has 2*latent_dim entries on the last dimension
        h = self.encoder(x)

        #first half is mu, second half is logvar
        mu, logvar = h.chunk(2, dim=-1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + std * eps`` with ``eps ~ N(0, 1)``.

        Writing the sample this way (the reparameterization trick) keeps
        ``z`` differentiable with respect to ``mu`` and ``logvar``.
        """
        #std = exp(0.5 * logvar), i.e. the square root of the variance
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)

        #reparam latent sample
        return mu + eps * std

    def decode(self, z):
        """Decode a latent sample ``z`` back into input space."""
        return self.decoder(z)

    def forward(self, x):
        """Run the full VAE pass.

        Returns:
            Tuple ``(x_hat, mu, logvar, z)``: the reconstruction, the latent
            mean and log-variance, and the sampled latent vector.
        """

        #encode input to mu and logvar
        mu, logvar = self.encode(x)

        #sample z via reparameterize
        z = self.reparameterize(mu, logvar)

        #decode z to recon input
        x_hat = self.decode(z)
        return x_hat, mu, logvar, z


In [290]:
#distance-correlation (DisCo) style penalty between two batches x and y
#this is from the paper
def disco_loss(x, y):
    """Return a normalised distance covariance between ``x`` and ``y``.

    Each input is mean-centred along dim 0, turned into a pairwise distance
    matrix with ``torch.cdist`` and double-centred. The result is
    ``mean(A * B)`` divided by ``sqrt(mean(A * A)) * sqrt(mean(B * B))``.
    The two normalisation factors are detached, so gradients only flow
    through the numerator.

    Args:
        x, y: tensors accepted by ``torch.cdist`` with the same number of
            rows (one row per sample).
    """

    def centred_distances(samples):
        #subtract the per-column mean, then build the pairwise distance matrix
        shifted = samples - samples.mean(0)
        dist = torch.cdist(shifted, shifted)
        #double centring: remove column means and row means, add back the grand mean
        return dist - dist.mean(0) - dist.mean(1, keepdim=True) + dist.mean()

    A = centred_distances(x)
    B = centred_distances(y)

    covariance = (A * B).mean()
    scale_x = (A * A).mean().sqrt()
    scale_y = (B * B).mean().sqrt()

    #the small constant guards against a zero denominator
    normaliser = scale_x.detach() * scale_y.detach() + 1e-10
    return covariance / normaliser

In [291]:
#vae loss = mse reconstruction term + kl divergence term
def vae_loss(x, x_hat, mu, logvar):
    """Return the sum of the reconstruction error and the KL term.

    Args:
        x: original input batch.
        x_hat: reconstruction of ``x``.
        mu: latent means.
        logvar: latent log-variances.

    Both terms are averaged over all of their elements.
    """
    #mean squared error between the reconstruction and the input
    reconstruction_term = F.mse_loss(x_hat, x, reduction='mean')

    #kl divergence: -0.5 * mean(1 + logvar - mu^2 - exp(logvar))
    kl_integrand = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.mean(kl_integrand)

    return reconstruction_term + kl_term

In [292]:
#joint objective: both vae losses plus the weighted disco penalty
def total_loss(x, out_1, out_2, lambda_disco=0.01):
    """Combine the two VAE losses with a DisCo term on their latent samples.

    Args:
        x: input batch fed to both VAEs.
        out_1: ``(x_hat, mu, logvar, z)`` tuple from the first VAE.
        out_2: ``(x_hat, mu, logvar, z)`` tuple from the second VAE.
        lambda_disco: weight applied to the DisCo term (default 0.01).

    Returns:
        ``vae_loss`` of model 1 + ``vae_loss`` of model 2
        + ``lambda_disco * disco_loss(z_1, z_2)``.
    """
    #unpack what each model produced
    recon_a, mu_a, logvar_a, latent_a = out_1
    recon_b, mu_b, logvar_b, latent_b = out_2

    #per-model vae losses, added together
    vae_terms = vae_loss(x, recon_a, mu_a, logvar_a) + vae_loss(x, recon_b, mu_b, logvar_b)

    #the disco term penalizes correlation between the two latent spaces
    #to push them towards decorrelation
    decorrelation_penalty = lambda_disco * disco_loss(latent_a, latent_b)

    return vae_terms + decorrelation_penalty

In [293]:
#use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#init both vaes; the input size is the number of columns of X_tensor
vae_1 = VAE(X_tensor.shape[1]).to(device)
vae_2 = VAE(X_tensor.shape[1]).to(device)

#one Adam optimizer over the parameters of both models
#(the models are bound as vae_1 / vae_2; the previous vae1 / vae2 names were
#undefined and raised NameError on a fresh kernel)
optimizer = torch.optim.Adam(list(vae_1.parameters()) + list(vae_2.parameters()), lr=1e-3)

In [None]:
# training hyperparams
batch_size = 256
epochs = 20
lambda_disco = 0.01  # weight of the DisCo term in the total loss

# train the VAEs, tracking each loss component separately
for epoch in range(epochs):
    # draw a fresh random ordering of the events for this epoch
    perm = torch.randperm(X_tensor.size(0))

    # per-epoch running sums (summed over mini batches, not averaged)
    final_loss = 0
    total_disco = 0
    total_vae1 = 0
    total_vae2 = 0

    # iterate through the dataset in mini batches
    for i in range(0, X_tensor.size(0), batch_size):
        # shuffled mini batch, moved to the same device as the models
        # (without this the forward pass fails when the models are on the GPU)
        x_batch = X_tensor[perm[i:i+batch_size]].to(device)

        # forward pass through both VAEs
        out_1 = vae_1(x_batch)
        out_2 = vae_2(x_batch)

        # unpack outputs
        x1_hat, mu1, logvar1, z1 = out_1
        x2_hat, mu2, logvar2, z2 = out_2

        # compute individual VAE losses
        loss1 = vae_loss(x_batch, x1_hat, mu1, logvar1)
        loss2 = vae_loss(x_batch, x2_hat, mu2, logvar2)

        # compute DisCo between the two latent samples
        # (disco_loss is the function defined in this notebook; the previous
        # distance_correlation name is not defined anywhere in it)
        disco = disco_loss(z1, z2)

        # compute total loss with weighting
        loss = loss1 + loss2 + lambda_disco * disco

        # zero gradients
        optimizer.zero_grad()

        # backward pass (gradient propagation)
        loss.backward()

        # update parameters
        optimizer.step()

        # accumulate stats
        final_loss += loss.item()
        total_disco += disco.item()
        total_vae1 += loss1.item()
        total_vae2 += loss2.item()

    # print epoch summary
    print(f"Epoch {epoch} : Loss = {final_loss:.2f} | VAE1 = {total_vae1:.2f} | VAE2 = {total_vae2:.2f} | DisCo = {total_disco:.4f}")


Epoch 0 : Loss = 62411163151783.41 | VAE1 = 62391002048066.84 | VAE2 = 20159233893.31 | DisCo = 87.2577
Epoch 1 : Loss = 62511167459570.32 | VAE1 = 62486166590082.21 | VAE2 = 25000066695.47 | DisCo = 86.8345
Epoch 2 : Loss = 62790411183166.39 | VAE1 = 62771152763440.91 | VAE2 = 19256551981.71 | DisCo = 87.0986
Epoch 3 : Loss = 69854668997955.31 | VAE1 = 69835094694052.35 | VAE2 = 19572677824.01 | DisCo = 87.4907
Epoch 4 : Loss = 65042879660304.66 | VAE1 = 65022223150000.98 | VAE2 = 20657091922.99 | DisCo = 86.8738
Epoch 5 : Loss = 62462371716439.15 | VAE1 = 62440540943773.84 | VAE2 = 21829791690.74 | DisCo = 87.0557


In [None]:
#training hyperparams
batch_size = 256
epochs = 20

#train the vaes using the combined total_loss
for epoch in range(epochs):

    #draw a fresh random ordering of the events for this epoch
    perm = torch.randperm(X_tensor.size(0))

    #running sum of the mini-batch losses for this epoch
    final_loss = 0

    #iterates through the dataset in mini batches
    for i in range(0, X_tensor.size(0), batch_size):

        #shuffled mini batch, moved to the same device as the models
        #(without this the forward pass fails when the models are on the GPU)
        x_batch = X_tensor[perm[i:i+batch_size]].to(device)

        #forward pass through both vaes
        out_1 = vae_1(x_batch)
        out_2 = vae_2(x_batch)

        #calc total loss (both vae losses + weighted disco term)
        loss = total_loss(x_batch, out_1, out_2)

        #zero gradients
        optimizer.zero_grad()

        #backward pass (gradient prop)
        loss.backward()

        #update params
        optimizer.step()

        #add to the final loss
        final_loss += loss.item()

    #epoch summary (sum of mini-batch losses)
    print(f"Epoch {epoch} : Loss = {final_loss}")