# 1. Examine generated jets

In [1]:
import numpy as np
import json
import torch
import matplotlib.pyplot as plt

from tensorclass import TensorMultiModal
from datamodules.utils import JetFeatures, ParticleClouds
from datamodules.aoj import AspenOpenJets

# Result directory of the experiment run inspected in this notebook.
dir_path = "/home/df630/Multimodal-Bridges/experiments/results/comet/multimodal-jets/e5812472ffd44fa38ad3915f85806da2"

# Metadata stored as JSON next to the samples.
with open(f"{dir_path}/metadata.json") as f:
    metadata = json.load(f)

# Sample files live in the run's `data` sub-directory.
test_path = f"{dir_path}/data/test_sample.h5"
paths_path = f"{dir_path}/data/paths_sample.h5"

# Load in the same order as before: `paths` first, then `test`.
paths = TensorMultiModal.load_from(paths_path)
test = TensorMultiModal.load_from(test_path)

In [2]:
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from tqdm.auto import tqdm

class MMDMetricMultiGPU:
    """
    Batched, multi-GPU estimate of the squared MMD between two particle samples.

    Two particles are compared with the product kernel
    ``k = RBF(continuous features) * delta(discrete features)``; the product is
    taken pair by pair before averaging, which is what a product kernel requires
    (multiplying the two mean kernel values is only equivalent when kinematics and
    flavor are independent).

    Work is partitioned by rank: rank ``r`` evaluates chunk ``r`` of the first
    sample against chunk ``r`` of the second sample. Pairs that fall into
    different chunks are therefore never evaluated, i.e. every kernel mean is a
    sub-sampled estimate over same-rank pairs only.

    NOTE(review): worker processes use torch.multiprocessing's default start
    method. With ``fork`` the parent process presumably must not have initialised
    CUDA before ``compute_mmd`` is called -- confirm on the target machine.
    """

    def __init__(self, sigma=1.0, batch_size=5000, device_ids=None):
        """
        Parameters:
        - sigma: Bandwidth parameter for the RBF kernel.
        - batch_size: Number of particles per batch for kernel computation.
        - device_ids: Iterable of GPU device IDs to use (default: [0, 1, 2, 3]).

        Raises:
        - ValueError: if `device_ids` is empty or `batch_size` is not positive.
        """
        # A fresh list per instance: a list literal as default argument would be
        # shared (and mutable) across every instance of the class.
        device_ids = [0, 1, 2, 3] if device_ids is None else list(device_ids)
        if not device_ids:
            raise ValueError("device_ids must contain at least one GPU id")
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")

        self.sigma = sigma
        self.batch_size = batch_size
        self.device_ids = device_ids
        self.world_size = len(device_ids)  # Number of GPUs / worker processes

    def _split_data(self, data, rank):
        """Return the contiguous chunk of `data` (along dim 0) assigned to `rank`."""
        num_samples = data.shape[0]
        # Ceiling division so that every sample is assigned to some rank; the
        # last rank(s) may receive a shorter or even empty chunk.
        chunk_size = (num_samples + self.world_size - 1) // self.world_size
        start_idx = rank * chunk_size
        end_idx = min(start_idx + chunk_size, num_samples)
        return data[start_idx:end_idx]

    def _kernel_block(self, X_kin, X_flavor, Y_kin, Y_flavor):
        """
        Product-kernel matrix for one (X batch, Y batch) pair.

        Returns a tensor of shape (len(X batch), len(Y batch)) holding
        ``exp(-||x - y||^2 / (2 sigma^2)) * [flavor_x == flavor_y]``.
        """
        dist_sq = torch.cdist(X_kin, Y_kin, p=2) ** 2
        K = torch.exp(-dist_sq / (2 * self.sigma ** 2))

        same = X_flavor[:, None] == Y_flavor[None, :]
        if same.dim() > 2:
            # Flavor stored with trailing feature dimension(s): two particles
            # match only if every discrete component matches.
            same = same.flatten(start_dim=2).all(dim=-1)
        return K * same.to(K.dtype)

    def _accumulate_kernel(self, X_kin, X_flavor, Y_kin, Y_flavor, progress=None):
        """
        Sum the product kernel over all (x, y) pairs, `batch_size` rows/columns at a time.

        Parameters:
        - X_kin, X_flavor: continuous / discrete features of the first sample
          (same length along dim 0).
        - Y_kin, Y_flavor: continuous / discrete features of the second sample.
        - progress: optional callable wrapping the outer batch iterator
          (e.g. a tqdm factory).

        Returns:
        - (kernel_sum, pair_count); both are zero when either sample is empty.
        """
        num_X, num_Y = X_kin.shape[0], Y_kin.shape[0]
        row_starts = range(0, num_X, self.batch_size)
        if progress is not None:
            row_starts = progress(row_starts)

        total = 0.0
        count = 0
        for start_i in row_starts:
            rows = slice(start_i, start_i + self.batch_size)
            X_kin_batch, X_flavor_batch = X_kin[rows], X_flavor[rows]

            for start_j in range(0, num_Y, self.batch_size):
                cols = slice(start_j, start_j + self.batch_size)
                K_batch = self._kernel_block(
                    X_kin_batch, X_flavor_batch, Y_kin[cols], Y_flavor[cols]
                )
                # Reduce in float64: the final MMD^2 is a difference of nearly
                # equal means, so rounding in the partial sums matters.
                total += K_batch.sum(dtype=torch.float64).item()
                count += K_batch.numel()

        return total, count

    def _product_kernel_batchwise(self, rank, X_kin, X_flavor, Y_kin, Y_flavor, results_dict):
        """
        Worker entry point: evaluate this rank's chunk pair on its GPU.

        Stores ``(kernel_sum, pair_count)`` under `rank` in `results_dict`, so the
        parent can form a correctly weighted mean even when chunks differ in size
        or are empty.
        """
        # Map the rank onto the configured device id rather than assuming
        # that rank r owns cuda:r.
        device = f"cuda:{self.device_ids[rank]}"
        X_kin = self._split_data(X_kin, rank).to(device)
        X_flavor = self._split_data(X_flavor, rank).to(device)
        Y_kin = self._split_data(Y_kin, rank).to(device)
        Y_flavor = self._split_data(Y_flavor, rank).to(device)

        results_dict[rank] = self._accumulate_kernel(
            X_kin,
            X_flavor,
            Y_kin,
            Y_flavor,
            progress=lambda it: tqdm(it, desc=f"GPU {rank}: Kernel Batches", position=rank),
        )

    def _mean_kernel(self, X_kin, X_flavor, Y_kin, Y_flavor):
        """
        Launch one worker per rank and return the mean product kernel over all
        evaluated pairs.

        Raises:
        - RuntimeError: if a worker exits abnormally or reports no result.
        - ValueError: if there are no pairs to evaluate.
        """
        # A fresh result dict per call, so a crashed worker can never be masked
        # by a stale value from a previous kernel term. The manager's server
        # process is shut down when the `with` block exits.
        with mp.Manager() as manager:
            results = manager.dict()
            processes = [
                mp.Process(
                    target=self._product_kernel_batchwise,
                    args=(rank, X_kin, X_flavor, Y_kin, Y_flavor, results),
                )
                for rank in range(self.world_size)
            ]
            for p in processes:
                p.start()
            for p in processes:
                p.join()
            partials = dict(results)  # copy out before the manager shuts down

        failed = [
            rank
            for rank, p in enumerate(processes)
            if p.exitcode != 0 or rank not in partials
        ]
        if failed:
            raise RuntimeError(f"MMD worker(s) for rank(s) {failed} did not finish successfully")

        total = sum(kernel_sum for kernel_sum, _ in partials.values())
        count = sum(pair_count for _, pair_count in partials.values())
        if count == 0:
            raise ValueError("no valid particle pairs to evaluate")
        return total / count

    def compute_mmd(self, gen, test):
        """
        Compute the MMD metric between real and generated jet distributions using multi-GPU acceleration and batching.

        Parameters:
        - gen: TensorMultiModal (generated data).
        - test: TensorMultiModal (test data).

        Returns:
        - MMD^2 score as a Python float. Self-pairs (i == j) are included in the
          XX and YY terms, so this is the biased estimate.
        """
        # Extract valid (non-zero-padded) particles
        gen_valid = gen.mask.squeeze(-1) > 0
        test_valid = test.mask.squeeze(-1) > 0
        gen_kin, gen_flavor = gen.continuous[gen_valid], gen.discrete[gen_valid]
        test_kin, test_flavor = test.continuous[test_valid], test.discrete[test_valid]

        mean_K_XX = self._mean_kernel(test_kin, test_flavor, test_kin, test_flavor)
        mean_K_YY = self._mean_kernel(gen_kin, gen_flavor, gen_kin, gen_flavor)
        mean_K_XY = self._mean_kernel(test_kin, test_flavor, gen_kin, gen_flavor)

        return mean_K_XX + mean_K_YY - 2 * mean_K_XY

# Example usage:
# mmd_metric_multi_gpu = MMDMetricMultiGPU(sigma=1.0, batch_size=5000, device_ids=[0, 1, 2, 3])
# mmd_score = mmd_metric_multi_gpu.compute_mmd(paths[-1], test)
# print("MMD Score:", mmd_score)



In [2]:
import torch
from tqdm import tqdm 

class MMDMetricGPU:
    """
    Single-device, batched estimate of the squared MMD between two particle samples.

    Kernel matrices are never materialised in full: they are evaluated in
    `batch_size` x `batch_size` tiles and only their running sum is kept.
    """

    def __init__(self, sigma=1.0, batch_size=5000, device="cuda"):
        """
        Parameters:
        - sigma: Bandwidth parameter for the RBF kernel.
        - batch_size: Number of particles per batch for kernel computation.
        - device: "cuda" for GPU, "cpu" otherwise.

        Raises:
        - ValueError: if `batch_size` is not positive.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")
        self.sigma = sigma
        self.batch_size = batch_size
        self.device = device  # "cuda" for GPU, "cpu" for CPU fallback

    def _rbf_block(self, X_batch, Y_batch):
        """RBF kernel matrix ``exp(-||x - y||^2 / (2 sigma^2))`` for one tile."""
        dist_sq = torch.cdist(X_batch, Y_batch, p=2) ** 2
        return torch.exp(-dist_sq / (2 * self.sigma ** 2))

    @staticmethod
    def _delta_block(X_batch, Y_batch):
        """Element-wise Kronecker delta (1.0 where entries are equal) for one tile."""
        return (X_batch[:, None] == Y_batch[None, :]).float()

    def _batched_kernel_mean(self, X, Y, kernel, progress=None):
        """
        Mean of `kernel(x_batch, y_batch)` entries over all tiles of X x Y.

        Parameters:
        - X, Y: tensors indexed by particle along dim 0.
        - kernel: callable mapping (X_batch, Y_batch) to a tensor of kernel values.
        - progress: optional callable wrapping the outer batch iterator
          (e.g. a tqdm factory).

        Returns:
        - Mean kernel value as a Python float.

        Raises:
        - ValueError: if X or Y is empty (the mean would be undefined).
        """
        num_X, num_Y = X.shape[0], Y.shape[0]
        if num_X == 0 or num_Y == 0:
            raise ValueError("cannot compute a kernel mean over an empty sample")

        row_starts = range(0, num_X, self.batch_size)
        if progress is not None:
            row_starts = progress(row_starts)

        total = 0.0
        count = 0
        for start_i in row_starts:
            X_batch = X[start_i:start_i + self.batch_size].to(self.device)

            for start_j in range(0, num_Y, self.batch_size):
                Y_batch = Y[start_j:start_j + self.batch_size].to(self.device)
                K_batch = kernel(X_batch, Y_batch)
                # Reduce in float64: MMD^2 is a difference of nearly equal means.
                total += K_batch.sum(dtype=torch.float64).item()
                count += K_batch.numel()

        return total / count

    def _rbf_kernel_batchwise(self, X, Y):
        """
        Compute the RBF kernel between two datasets in batches using PyTorch operations
        on `self.device`. Returns the mean kernel value as a Python float.
        """
        return self._batched_kernel_mean(
            X,
            Y,
            self._rbf_block,
            progress=lambda it: tqdm(it, desc="RBF Kernel Batches (X)"),
        )

    def _discrete_kernel_batchwise(self, X, Y):
        """
        Compute the Kronecker delta kernel for discrete features in batches.
        Returns the mean kernel value as a Python float (same type as the RBF mean).
        """
        return self._batched_kernel_mean(X, Y, self._delta_block)

    def compute_mmd(self, gen, test):
        """
        Compute the MMD metric between real and generated jet distributions using GPU acceleration and batching.

        Only the continuous features enter the metric; the discrete kernel is not
        applied here.

        Parameters:
        - gen: TensorMultiModal (generated data).
        - test: TensorMultiModal (test data).

        Returns:
        - MMD^2 score as a Python float. Self-pairs (i == j) are included in the
          XX and YY terms, so this is the biased estimate.
        """
        # Extract valid (non-zero-padded) particles and move to the target device
        gen_kin = gen.continuous[gen.mask.squeeze(-1) > 0].to(self.device)
        test_kin = test.continuous[test.mask.squeeze(-1) > 0].to(self.device)

        # NOTE: if flavor is added back, the delta kernel must be multiplied
        # pair by pair inside each tile; multiplying the two mean kernel values
        # is not the mean of the product kernel.
        mean_K_XX = self._rbf_kernel_batchwise(test_kin, test_kin)
        mean_K_YY = self._rbf_kernel_batchwise(gen_kin, gen_kin)
        mean_K_XY = self._rbf_kernel_batchwise(test_kin, gen_kin)

        # The kernel means are already Python floats, so there is no tensor to
        # call .item() on.
        return float(mean_K_XX + mean_K_YY - 2 * mean_K_XY)



In [5]:

# Example usage: score the last entry of `paths` against the test sample.
# Reduce batch_size if memory is still an issue.
mmd_metric = MMDMetricMultiGPU(
    sigma=1.0,
    batch_size=5000,
    device_ids=[0, 1, 2, 3],
)
mmd_metric.compute_mmd(paths[-1], test)

