In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from sklearn.model_selection import train_test_split

import uproot
import awkward as ak

In [2]:
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Function to calculate weights analytically

def weight_fn(thetas, angles):
    theta0, theta1, theta2 = thetas[:, 0], thetas[:, 1], thetas[:, 2]
    phi, costh = angles[:, 0], angles[:, 1]
    weight = 1. + theta0* costh * costh + 2.* theta1* costh * np.sqrt(1. - costh * costh) * np.cos(phi) + 0.5* theta2* (1. - costh * costh)* np.cos(2. * phi)
    return weight / (1. + costh * costh)

In [4]:
# Load E906 simulated data

nevents = 10**6

tree = uproot.open("BMFData.root:save")
events = tree.arrays(["mass", "pT", "xF", "phi", "costh", "true_phi", "true_costh"])

X = np.array([(mass, pT, xF, phi, costh) for mass, pT, xF, phi, costh in zip(events.mass, events.pT, events.xF, events.phi, events.costh)])
Y = np.array([(true_phi, true_costh) for true_phi, true_costh in zip(events.true_phi, events.true_costh)])

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=nevents, shuffle=True)
thetas = np.random.uniform(-1., 1, (nevents, 3))
W_train = weight_fn(thetas, Y_train).reshape(-1, 1)

In [5]:
batch_size = 1024
num_epochs = 100

# Convert to torch tensor
X_tensor = torch.from_numpy(X_train).float()
weight_tensor = torch.from_numpy(W_train).float()
thetas_tensor = torch.from_numpy(thetas).float()

dataset = TensorDataset(X_tensor, weight_tensor, thetas_tensor)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Variational Autoencoders for Parameter Extraction

In [6]:
class ParamsEstimator(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, params_dim):
        super(ParamsEstimator, self).__init__()

        self.fc_encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, bias=True),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.ReLU(),
            # nn.Linear(hidden_dim, latent_dim, bias=True),
        )

        self.fc_mu = nn.Linear(hidden_dim, latent_dim, bias=True)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim, bias=True)

        self.fc_decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim, bias=True),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=True),
            nn.ReLU(),
            nn.Linear(hidden_dim, params_dim, bias=True),
        )

    def encode(self, x):
        h = self.fc_encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h = self.fc_decoder(z)
        return h

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [7]:
class EstimatorLoss(nn.Module):
    def __init__(self):
        super(EstimatorLoss, self).__init__()

    def forward(self, reco_params, weight, params, mu, logvar):
        reco_loss = (weight* (reco_params - params)*(reco_params - params)).sum()
        kld_loss = torch.sum(1. + logvar - mu.pow(2) - logvar.exp())
        return (reco_loss + kld_loss)/params.size(0)

In [8]:
# Create the VAE
input_dim = 5
hidden_dim = 64
latent_dim = 32
params_dim = 3
estimator = ParamsEstimator(input_dim, hidden_dim, latent_dim, params_dim)

In [9]:
# Define the optimizer
optimizer = optim.Adam(estimator.parameters(), lr=0.001)
loss_fn = EstimatorLoss()

In [10]:
# Training loop
for epoch in range(num_epochs):
    estimator.train()
    for batch_x, batch_weight, batch_params in dataloader:
        optimizer.zero_grad()
        reco_params, mu, logvar = estimator(batch_x)
        loss_batch = loss_fn(reco_params, batch_weight, batch_params, mu, logvar)
        loss_batch.backward()
        optimizer.step()
    if epoch % 10 == 0:
        print(f"===> Epoch [{epoch}/{num_epochs}], Loss: {loss_batch.item()}")

===> Epoch [0/100], Loss: -1.3560248406020016e+22


KeyboardInterrupt: 