In [None]:
# Flow matching conditioned on particle type: padded slots are fed to the model as type 0 (pad) rather than masked out of the velocity field or the loss

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from matplotlib import animation
from torch.utils.data import DataLoader, Dataset
from torchcfm.conditional_flow_matching import ExactOptimalTransportConditionalFlowMatcher
from torchdyn.core import NeuralODE
from tqdm import tqdm

import load_and_preprocess as lap
import train_and_eval_functions as taef

In [None]:
# Output directory handed to lap.load_and_preprocess as `plots_path`.
# Set the CTFM_PLOTS_PATH environment variable to use a different location
# instead of editing this user-specific default.
# NOTE(review): `datasets` is not used by any cell visible here — TODO confirm it is needed.
PLOTS_PATH = os.environ.get(
    "CTFM_PLOTS_PATH",
    "/eos/home-m/mmcohen/ctfm_development/trained_models/trial_1/plots",
)
datasets = lap.load_and_preprocess(plots_path=PLOTS_PATH)

In [None]:
class TransformerVectorField(nn.Module):
    """Transformer velocity field v(t, x | particle type) for set-valued flow matching.

    Each particle slot's features are linearly projected to ``model_dim`` and
    summed with an embedding of the slot's particle-type id (0 = pad, 1-4 =
    particle types) and with an embedding of the time ``t``.  The result goes
    through a transformer encoder and is projected back to ``input_dim``.
    Padded slots are *conditioned on* (type 0); they are not masked out of the
    attention or of the output.

    Args:
        input_dim: number of per-particle features.
        model_dim: transformer width (must be divisible by ``num_heads``).
        num_heads: attention heads per encoder layer.
        num_layers: number of encoder layers.
        n_mask_vals: number of distinct particle-type ids (embedding table size).
        ff_dim: feed-forward width inside each encoder layer.
    """

    def __init__(self, input_dim=3, model_dim=128, num_heads=8, num_layers=4,
                 n_mask_vals=5, ff_dim=512):
        super().__init__()
        self.model_dim = model_dim

        # project the per-particle features
        self.input_proj = nn.Linear(input_dim, model_dim)

        # embed mask (0=pad, 1-4=particle types)
        self.mask_emb = nn.Embedding(n_mask_vals, model_dim)

        # time embedding MLP
        self.time_mlp = nn.Sequential(
            nn.Linear(1, model_dim),
            nn.SiLU(),
            nn.Linear(model_dim, model_dim)
        )

        # transformer stack
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            activation='gelu',
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # back to velocity
        self.output_proj = nn.Linear(model_dim, input_dim)

        # Particle-type ids used when forward() is called without `mask`, e.g. by
        # a NeuralODE, which only passes (t, x).  This is a plain attribute, not a
        # buffer, so it never enters the state_dict.
        self.cond_mask = None

    def forward(self, t, x, *args, mask=None, **kwargs):
        """Predict the velocity at time ``t`` for particle sets ``x``.

        Args:
            t: scalar time (as passed by an ODE solver) or one time per sample,
                shape [B] or [B, 1] (as produced by the flow matcher).
            x: [B, N, input_dim] particle features.
            *args, **kwargs: ignored; accepted so ODE wrappers that pass extra
                arguments (e.g. torchdyn's ``args=``) can call this module.
            mask: particle-type ids, shape [B, N] or [B, N, 1], integer or
                float-valued.  Falls back to ``self.cond_mask`` when omitted.

        Returns:
            [B, N, input_dim] velocity.

        Raises:
            ValueError: if no conditioning mask is available, or if the shapes of
                ``t`` or ``mask`` do not match ``x``.
        """
        if mask is None:
            mask = self.cond_mask
        if mask is None:
            raise ValueError(
                "TransformerVectorField needs particle-type ids: pass `mask=` "
                "or set `cond_mask` before integrating."
            )

        batch = x.shape[0]

        # Broadcast a scalar solver time to every sample in the batch.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device).reshape(-1)
        if t.numel() == 1:
            t = t.expand(batch)
        elif t.numel() != batch:
            raise ValueError(f"t has {t.numel()} entries but x has batch size {batch}")
        t_emb = self.time_mlp(t.reshape(batch, 1, 1))  # [B, 1, model_dim], broadcast over particles

        # nn.Embedding needs integer indices; the ids may arrive as floats in [B, N, 1].
        if mask.dim() == x.dim():
            mask = mask.squeeze(-1)
        if mask.is_floating_point():
            mask = mask.round()
        type_ids = mask.long().to(x.device)
        if type_ids.shape != x.shape[:2]:
            raise ValueError(
                f"mask shape {tuple(type_ids.shape)} does not match particles {tuple(x.shape[:2])}"
            )

        h = self.input_proj(x) + self.mask_emb(type_ids) + t_emb
        h = self.transformer(h)
        return self.output_proj(h)

In [None]:
def train(model, dataloader, optimizer, device, num_epochs=50, sigma=0.0):
    """Train ``model`` with exact-OT conditional flow matching.

    Each batch from ``dataloader`` yields ``(x1, mask)``: target particle
    features and the per-particle type ids used as conditioning.  Noise
    ``x0 ~ N(0, I)`` is paired with ``x1`` through a minibatch optimal-transport
    plan.  That plan re-pairs the rows of ``x1``, so the mask is routed through
    the same plan; otherwise every target would be conditioned on another
    event's particle types.

    The loss is the MSE between predicted and target velocities over all
    slots, padded ones included (padding is conditioned on, not masked).

    Args:
        model: vector field called as ``model(t, xt, mask=mask)``.
        dataloader: iterable of ``(x1, mask)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device to train on.
        num_epochs: number of passes over ``dataloader``.
        sigma: flow-matcher noise level.

    Returns:
        list[float]: mean training loss of each epoch.
    """
    model.to(device)
    model.train()
    FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
    epoch_losses = []

    for epoch in range(num_epochs):
        total_loss = 0.
        n_batches = 0
        for x1, mask in tqdm(dataloader):
            x1 = x1.to(device)
            mask = mask.to(device)
            optimizer.zero_grad()

            # sample noise
            x0 = torch.randn_like(x1)
            # The guided variant applies the OT pairing to the labels as well. The plain
            # sample_location_and_conditional_flow would re-pair x1 only, which would
            # leave `mask` misaligned with the returned xt/ut.
            t, xt, ut, _, mask = FM.guided_sample_location_and_conditional_flow(x0, x1, y1=mask)

            # model prediction, conditioned on the (re-paired) particle types
            vt = model(t, xt, mask=mask)

            # mean squared error over all samples, particles and features
            loss = torch.mean((vt - ut) ** 2)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        avg_loss = total_loss / max(n_batches, 1)
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1} - Loss: {avg_loss:.6f}")

    return epoch_losses

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): EvtDataset, sel_train, ntrain, batchSize, input_dim, model_dim,
# num_heads, num_layers, trained, numberOfEpochs, flowSigma and label are not
# defined in any visible cell — TODO define them in a config cell so the
# notebook survives Restart & Run All.
dataset = EvtDataset(sel_train[:ntrain].astype(np.float32))
dataloader = DataLoader(dataset, batch_size=batchSize, shuffle=True, drop_last=True)

model = TransformerVectorField(input_dim=input_dim, model_dim=model_dim, num_heads=num_heads, num_layers=num_layers)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

ckpt_path = "cond_no_mask_flow_model_%s.pt" % label

if not trained:
    train(model, dataloader, optimizer, device=device, num_epochs=numberOfEpochs, sigma=flowSigma)

    # Save the vector-field weights plus every constructor argument needed to rebuild it.
    torch.save({
        'state_dict': model.state_dict(),
        'model_kwargs': {
            'input_dim': input_dim,
            'model_dim': model_dim,
            'num_heads': num_heads,
            'num_layers': num_layers,
            'n_mask_vals': model.mask_emb.num_embeddings,
            'ff_dim': model.transformer.layers[0].linear1.out_features,
        }
    }, ckpt_path)

else:
    ckpt = torch.load(ckpt_path, map_location=device)

    # Recreate the architecture; older checkpoints without n_mask_vals/ff_dim
    # fall back to the constructor defaults.
    model = TransformerVectorField(**ckpt['model_kwargs'])
    model.load_state_dict(ckpt['state_dict'])

# Inference mode in BOTH branches: train() leaves the model in training mode, and
# the encoder layers' dropout would otherwise make the ODE right-hand side random.
model.to(device).eval()

node = NeuralODE(model, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)

In [None]:
# animate flows

In [None]:
def map_to_base_distribution(node, x1, mask, batch_size=1024):
    """
    Map data-space samples x1 to base-space samples x0 by integrating the flow
    backwards from t=1 to t=0, in batches.

    ``mask`` is the model's conditioning input (particle-type ids); it is NOT
    used to zero out padded particles.  A NeuralODE only forwards (t, x) to its
    vector field, so for each batch the matching slice of ``mask`` is bound to
    every submodule of ``node`` that has a ``cond_mask`` attribute.  The
    binding is reset to None afterwards, even if integration fails.

    Args:
        node: A trained NeuralODE wrapping the conditional vector field.
        x1: [B, N, D] torch.Tensor - data space samples.
        mask: [B, N, 1] (or [B, N]) torch.Tensor - particle-type ids, row-aligned with x1.
        batch_size: int - batch size for processing.

    Returns:
        x0: [B, N, D] torch.Tensor - mapped base samples (no masking applied).

    Raises:
        ValueError: if ``mask`` and ``x1`` have different numbers of rows.
    """
    import warnings

    if mask.shape[0] != x1.shape[0]:
        raise ValueError(f"mask has {mask.shape[0]} rows but x1 has {x1.shape[0]}")
    if x1.size(0) == 0:
        return x1.clone()

    # Submodules that accept per-batch conditioning through the `cond_mask` hook.
    cond_targets = [m for m in node.modules() if hasattr(m, 'cond_mask')]
    if not cond_targets:
        warnings.warn(
            "no submodule of `node` has a `cond_mask` attribute; the inverse flow "
            "runs WITHOUT particle-type conditioning"
        )

    t_span = torch.linspace(1., 0., 2, device=x1.device)
    outputs = []

    try:
        for start in tqdm(range(0, x1.size(0), batch_size)):
            xb = x1[start:start + batch_size]
            mb = mask[start:start + batch_size]
            for module in cond_targets:
                module.cond_mask = mb

            with torch.no_grad():
                traj = node.trajectory(xb, t_span=t_span)  # [len(t_span), bsz, N, D]
            outputs.append(traj[-1])  # state at t=0
    finally:
        # Do not leave a stale batch mask attached to the model.
        for module in cond_targets:
            module.cond_mask = None

    return torch.cat(outputs, dim=0)

test_evts_tensor = torch.tensor(parts_test, dtype=torch.float32).to(device)

ntest = 2000000  # upper bound on the number of test events to map

# Seeded permutation so the evaluated subset is reproducible across runs.
# `tp` is reused later to pick the matching rows of labels_test.
test_rng = np.random.default_rng(0)
tp = test_rng.permutation(test_evts_tensor.shape[0])
test_sel = test_evts_tensor[tp[:ntest]]  # one gather; the slices below are views

# Extract features and mask
x_test = test_sel[:, :, :3]       # [eta, phi, pt]
mask_test = test_sel[:, :, 3:]    # particle-type ids, expected shape [B, N, 1]
assert mask_test.shape[-1] == 1, f"expected one mask column, got {mask_test.shape[-1]}"

x0_latent = map_to_base_distribution(node, x_test, mask_test)

print("Mapped shape:", x0_latent.shape)  # Should be [B, maxparts, 3]
print("Latent mean (should be ~0):", x0_latent.mean().item())
print("Latent std (should be ~1):", x0_latent.std().item())

In [None]:
# plot features post flow

In [None]:
def neg_log_prob_gaussian(x0, mask=None):
    """Per-jet negative log-density of ``x0`` under an isotropic standard normal.

    Every particle slot and feature contributes, padded slots included, so
    ``mask`` is accepted only for call-site compatibility and is ignored.

    Args:
        x0: [B, N, D] tensor of base-space samples.
        mask: optional [B, N, 1] tensor; unused.

    Returns:
        [B] tensor: sum over particles and features of 0.5 * (x0^2 + log(2*pi)).
    """
    log_two_pi = torch.log(torch.tensor(2 * torch.pi, device=x0.device))
    per_entry_nll = 0.5 * (x0 ** 2 + log_two_pi)
    return per_entry_nll.sum(dim=(1, 2))

nll = neg_log_prob_gaussian(x0_latent, mask_test).cpu().numpy()  # [B]
print("Mean NLL per jet:", np.mean(nll))

labelnames = ['bkg', 'a4l', 'htaunu', 'htautau', 'lq']
# One-hot labels of the mapped events, aligned with `nll` through the same permutation.
labels_sel = labels_test[tp[:ntest]]

# Shared bins over the full NLL range. The previous approach took the bins from the
# background histogram only, which silently dropped (and renormalised away) any
# signal tail outside the background range.
bins = np.histogram_bin_edges(nll, bins=50)

fig, ax = plt.subplots()
for i, name in enumerate(labelnames):
    in_class = labels_sel[:, i] == 1
    if not np.any(in_class):
        continue  # density=True on an empty selection would yield NaNs
    ax.hist(nll[in_class], bins=bins, histtype='step', label=name, density=True)
ax.set_xlabel("NLL under standard-normal base distribution")
ax.set_ylabel("Normalised events")
ax.set_yscale('log')
ax.legend()
plt.show()

In [None]:
# ROC curve