In [4]:
import torch
from torch.utils.data import Dataset
import pandas as pd
import h5py
from pathlib import Path
from typing import Union, Optional, Callable
from tqdm import tqdm

In [1]:
from jepa.modules import JEA
import yaml

In [5]:
class TrackMLDataset(Dataset):
    r"""
    A Dataset subclass for the TrackML dataset in HDF5 format.

    Each item corresponds to one event and is a dict with the keys ``'x'``
    (hit positions, shape ``[num_hits, 3]``), ``'mask'`` (all-``True`` boolean
    tensor with one entry per hit), ``'pids'`` (particle ids of the hits that
    could be matched against the truth table) and ``'event'`` (always ``None``).

    Args:
        file (str or Path): path to the HDF5 file holding the data.
        scaling_factor (float, optional): a multiplicative scaling factor applied to the hit positions (default: ``1.0``).
            Note that, by default, positions are specified in millimeters.
        transform (Callable, optional): a function used to further process the output.
        float_dtype (torch.dtype, optional): the dtype of the returned tensors for floating-point features (default: ``torch.float32``).
    """

    def __init__(
        self,
        file: Union[str,Path],
        scaling_factor: float=1.0,
        transform: Optional[Callable]=None, #TODO transforms
        float_dtype=torch.float32,
    ):
        # Zero-argument super() is bound to this instance. The previous
        # ``super(TrackMLDataset)`` produced an unbound super object, so the
        # parent initializer was never reached.
        super().__init__()
        self.file = h5py.File(file, 'r')
        self.number_of_events = self.file.attrs['number_of_events']
        self.hits = self.file['hits']
        self.truth = self.file['truth']
        self.float_dtype = float_dtype
        self.scaling_factor = torch.tensor(scaling_factor, dtype=float_dtype)
        self.transform = transform #TODO transforms

    def __del__(self):
        # ``file`` is absent when __init__ failed before (or while) opening the
        # HDF5 file; closing unconditionally would raise AttributeError here.
        handle = getattr(self, 'file', None)
        if handle is not None:
            handle.close()

    def __len__(self):
        """Return the number of events stored in the file."""
        # The HDF5 attribute may be a numpy integer; len() expects a plain int.
        return int(self.number_of_events)

    def __getitem__(self, idx: int):
        """Return the sample dict for event ``idx``, passed through ``transform`` if one is set."""
        x, hit_id = self._get_hits(idx)
        pids = self._get_particle_ids(idx, hit_id)
        output = {
            'x': x,
            'mask': torch.ones(x.shape[0], dtype=bool),
            'pids': pids,
            'event': None,
        }
        if self.transform:
            output = self.transform(output)
        return output

    def _get_hits(self, idx: int):
        """
        Read the hits of event ``idx``.

        Returns:
            tuple: ``(x, hit_id)`` where ``x`` is a ``[num_hits, 3]`` tensor of scaled
            (x, y, z) positions and ``hit_id`` is a column-less DataFrame indexed by ``hit_id``.
        """
        # Events are stored back to back; offset/length locate this event's rows.
        offset = int(self.hits['event_offset'][idx])
        length = int(self.hits['event_length'][idx])
        event_slice = slice(offset, offset + length)
        hit_id = pd.DataFrame({'hit_id': self.hits['hit_id'][event_slice]}, copy=False).set_index('hit_id')
        x = torch.zeros((length, 3), dtype=self.float_dtype)
        for column, name in enumerate(('x', 'y', 'z')):
            x[:, column] = torch.from_numpy(self.hits[name][event_slice]) * self.scaling_factor
        return x, hit_id

    def _get_particle_ids(self, idx: int, detected_hits: pd.DataFrame):
        """
        Look up the true particle id of the detected hits of event ``idx``.

        Args:
            idx (int): event index.
            detected_hits (pd.DataFrame): frame indexed by ``hit_id``, as returned by ``_get_hits``.

        Returns:
            torch.Tensor: particle ids of the detected hits that have a truth entry.
        """
        # Note: not all hits in "hits" are also in "truth", and reciprocally
        # Note: the weight is ignored for now
        offset = int(self.truth['event_offset'][idx])
        length = int(self.truth['event_length'][idx])
        event_slice = slice(offset, offset + length)
        truth = pd.DataFrame({
            'hit_id': self.truth['hit_id'][event_slice],
            'particle_id': self.truth['particle_id'][event_slice],
        }, copy=False).set_index('hit_id')
        # Let's find the true particle_id corresponding to each detected hit_id.
        # NOTE(review): the inner join drops detected hits without a truth row, so
        # the result can be shorter than ``x`` returned by ``_get_hits``.
        joined = detected_hits.join(truth, on='hit_id', how='inner')
        # Guards against a silent dtype change of the id column during the join.
        assert joined['particle_id'].dtype == truth['particle_id'].dtype
        matched_particle_id = torch.from_numpy(joined['particle_id'].values)
        return matched_particle_id


In [6]:
from typing import Dict

class WedgePatchify3d:
    """
    A class to transform hitwise data into wedges of annuli.

    Give an event of shape [num_hits, 3] (x, y, z), return two mask tensors (context, target) of shape [num_hits]

    Args:
        phi_range (float): full width of the azimuthal window around the randomly selected hit.
        eta_range (float): full width of the eta window around the randomly selected hit.
        radius_midpoint (float): transverse radius separating the inner region (``<=``) from the outer one (``>``).
        random_context (bool): if ``True``, the inner region is used as context with probability 0.5;
            otherwise (and when ``False``) the outer region is the context and the inner one the target.
    """

    def __init__(self, phi_range: float, eta_range: float, radius_midpoint: float, random_context: bool = True):
        self.phi_range = phi_range
        self.eta_range = eta_range
        self.radius_midpoint = radius_midpoint
        self.random_context = random_context

    def __call__(self, sample: Dict) -> Dict:
        """
        Apply the WedgePatchify transform to the input sample.

        Args:
            sample (Dict): A dictionary containing hitwise data. Must include an 'x' key
                           with a tensor of shape (num_hits, 3).

        Returns:
            Dict: The same dictionary, with boolean ``context_mask`` and ``target_mask``
            tensors of shape (num_hits,) added in place.
        """
        hits = sample["x"]
        if hits.shape[0] == 0:
            # No hit can serve as the random wedge centre (torch.randint(0, 0)
            # raises), so an empty event simply gets empty masks.
            sample["context_mask"] = torch.zeros(0, dtype=torch.bool, device=hits.device)
            sample["target_mask"] = torch.zeros(0, dtype=torch.bool, device=hits.device)
            return sample

        x, y, z = self._extract_coordinates(sample)
        radius, phi, eta = self._calculate_radius_phi_and_eta(x, y, z)
        selected_phi = self._select_random_phi(phi)
        phi_mask = self._create_phi_mask(phi, selected_phi)
        selected_eta = self._select_random_eta(eta)
        eta_mask = self._create_eta_mask(eta, selected_eta)
        inner_mask, outer_mask = self._create_radius_masks(radius)
        context_mask, target_mask = self._assign_masks(inner_mask, outer_mask, phi_mask, eta_mask)

        sample["context_mask"] = context_mask
        sample["target_mask"] = target_mask

        return sample

    def _extract_coordinates(self, sample: Dict) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        """Split ``sample['x']`` into its x, y and z columns."""
        x = sample["x"][:, 0]
        y = sample["x"][:, 1]
        z = sample["x"][:, 2]
        return x, y, z

    def _calculate_radius_phi_and_eta(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        """Return the transverse radius, the azimuth phi and ``atanh(z / |r|)`` (eta) of every hit."""
        radius = torch.sqrt(x**2 + y**2)
        norm = torch.sqrt(x**2 + y**2 + z**2)
        phi = torch.atan2(y, x)  # Returns values between -pi and pi
        # NOTE(review): a hit exactly on the z axis gives +/-inf and a hit at the
        # origin gives NaN here; such hits then fail the eta window comparison.
        eta = torch.atanh(z / norm)
        return radius, phi, eta

    def _select_random_phi(self, phi: torch.Tensor) -> torch.Tensor:
        """Pick the phi of one uniformly chosen hit (returned as a 1-element tensor)."""
        hit_idx = torch.randint(0, phi.shape[0], (1,))
        return phi[hit_idx]

    def _select_random_eta(self, eta: torch.Tensor) -> torch.Tensor:
        """Pick the eta of one uniformly chosen hit (returned as a 1-element tensor)."""
        hit_idx = torch.randint(0, eta.shape[0], (1,))
        return eta[hit_idx]

    def _create_phi_mask(self, phi: torch.Tensor, selected_phi: torch.Tensor) -> torch.Tensor:
        """
        Mark hits whose azimuth lies within ``phi_range / 2`` of ``selected_phi``.

        phi is periodic, so the window may stick out past +pi or below -pi. Each hit
        is therefore tested at phi, phi + 2*pi and phi - 2*pi. The previous version
        omitted the ``- 2*pi`` image and missed hits near +pi whenever the window was
        centred near -pi.
        """
        phi_min = selected_phi - self.phi_range / 2
        phi_max = selected_phi + self.phi_range / 2
        in_window = torch.zeros_like(phi, dtype=torch.bool)
        for shift in (-2 * torch.pi, 0.0, 2 * torch.pi):
            shifted = phi + shift
            in_window = in_window | ((shifted >= phi_min) & (shifted <= phi_max))
        return in_window

    def _create_eta_mask(self, eta: torch.Tensor, selected_eta: torch.Tensor) -> torch.Tensor:
        """Mark hits whose eta lies within ``eta_range / 2`` of ``selected_eta``."""
        eta_min = selected_eta - self.eta_range / 2
        eta_max = selected_eta + self.eta_range / 2
        return torch.logical_and(eta >= eta_min, eta <= eta_max)

    def _create_radius_masks(self, radius: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """Split hits into inner (``radius <= radius_midpoint``) and outer (``>``) masks."""
        inner_mask = radius <= self.radius_midpoint
        outer_mask = radius > self.radius_midpoint
        return inner_mask, outer_mask

    def _assign_masks(self, inner_mask: torch.Tensor, outer_mask: torch.Tensor, phi_mask: torch.Tensor, eta_mask: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """Combine the wedge (phi & eta) with the radial split into context and target masks."""
        wedge = phi_mask & eta_mask
        # Short-circuit: no random number is drawn when random_context is False.
        if self.random_context and torch.rand(1).item() > 0.5:
            context_mask = inner_mask & wedge
            target_mask = outer_mask & wedge
        else:
            context_mask = outer_mask & wedge
            target_mask = inner_mask & wedge
        return context_mask, target_mask

In [49]:
class TrackMLDataset(Dataset):
    """
    TrackML HDF5 dataset that splits every event into a context and a target wedge.

    For each event one hit is drawn at random to fix a phi window and another to fix
    an eta window. Hits inside both windows are divided by transverse radius into an
    inner (``<= radius_midpoint``) and an outer part, one serving as context and the
    other as target.

    Args:
        file (str or Path): path to the HDF5 file holding the data.
        scaling_factor (float, optional): multiplicative factor applied to the hit positions.
        phi_range (float, optional): full width of the phi window.
        eta_range (float, optional): full width of the eta window.
        radius_midpoint (float, optional): radius separating inner from outer hits,
            compared against the already scaled positions.
        random_context (bool, optional): if ``True``, the inner part is the context with
            probability 0.5; otherwise the outer part is always the context.
        float_dtype (torch.dtype, optional): dtype of the returned position tensors.
    """

    def __init__(
        self,
        file: Union[str, Path],
        scaling_factor: float = 1.0,
        phi_range: float = 0.5,
        eta_range: float = 0.5,
        radius_midpoint: float = 500.0,
        random_context: bool = True,
        float_dtype=torch.float32,
    ):
        # Zero-argument super() is bound to this instance. The previous
        # ``super(TrackMLDataset)`` produced an unbound super object, so the
        # parent initializer was never reached.
        super().__init__()
        self.file = h5py.File(file, 'r')
        self.number_of_events = self.file.attrs['number_of_events']
        self.hits = self.file['hits']
        self.truth = self.file['truth']
        self.float_dtype = float_dtype
        self.scaling_factor = torch.tensor(scaling_factor, dtype=float_dtype)

        # Store wedge parameters directly in the dataset
        self.phi_range = phi_range
        self.eta_range = eta_range
        self.radius_midpoint = radius_midpoint
        self.random_context = random_context

    def __del__(self):
        # ``file`` is absent when __init__ failed before (or while) opening the
        # HDF5 file; closing unconditionally would raise AttributeError here.
        handle = getattr(self, 'file', None)
        if handle is not None:
            handle.close()

    def __len__(self):
        """Return the number of events stored in the file."""
        # The HDF5 attribute may be a numpy integer; len() expects a plain int.
        return int(self.number_of_events)

    def __getitem__(self, idx: int):
        """
        Return the context/target split of event ``idx``.

        The dict holds ``x_context`` and ``x_target`` (``[n, 3]`` position tensors),
        all-``True`` boolean masks ``x_context_mask`` / ``x_target_mask`` of matching
        length, and ``event`` (always ``None``). The split is random, so repeated
        access to the same index generally yields different tensors.
        """
        x, hit_id = self._get_hits(idx)
        # pids = self._get_particle_ids(idx, hit_id) #TODO Do we need it?

        # Calculate splits directly in the dataset; the per-hit masks are not returned.
        x_context, x_target, _, _ = self._split_data(x)

        return {
            'x_context': x_context,
            'x_target': x_target,
            'x_context_mask': torch.ones(x_context.shape[0], dtype=bool),
            'x_target_mask': torch.ones(x_target.shape[0], dtype=bool),
            # 'pids': pids,
            'event': None,
        }

    def _get_hits(self, idx: int):
        """
        Read the hits of event ``idx``.

        Returns:
            tuple: ``(x, hit_id)`` where ``x`` is a ``[num_hits, 3]`` tensor of scaled
            (x, y, z) positions and ``hit_id`` is a column-less DataFrame indexed by ``hit_id``.
        """
        # Events are stored back to back; offset/length locate this event's rows.
        offset = int(self.hits['event_offset'][idx])
        length = int(self.hits['event_length'][idx])
        event_slice = slice(offset, offset + length)
        hit_id = pd.DataFrame({'hit_id': self.hits['hit_id'][event_slice]}, copy=False).set_index('hit_id')
        x = torch.zeros((length, 3), dtype=self.float_dtype)
        for column, name in enumerate(('x', 'y', 'z')):
            x[:, column] = torch.from_numpy(self.hits[name][event_slice]) * self.scaling_factor
        return x, hit_id

    def _get_particle_ids(self, idx: int, detected_hits: pd.DataFrame):
        """
        Look up the true particle id of the detected hits of event ``idx``.

        Returns:
            torch.Tensor: particle ids of the detected hits that have a truth entry.
        """
        offset = int(self.truth['event_offset'][idx])
        length = int(self.truth['event_length'][idx])
        event_slice = slice(offset, offset + length)
        truth = pd.DataFrame({
            'hit_id': self.truth['hit_id'][event_slice],
            'particle_id': self.truth['particle_id'][event_slice],
        }, copy=False).set_index('hit_id')
        # NOTE(review): the inner join drops detected hits without a truth row, so
        # the result can be shorter than the number of detected hits.
        joined = detected_hits.join(truth, on='hit_id', how='inner')
        # Guards against a silent dtype change of the id column during the join.
        assert joined['particle_id'].dtype == truth['particle_id'].dtype
        matched_particle_id = torch.from_numpy(joined['particle_id'].values)
        return matched_particle_id

    def _split_data(self, x: torch.Tensor):
        """
        Split the data into context and target based on geometric criteria.

        Args:
            x (torch.Tensor): hit positions of shape ``[num_hits, 3]``.

        Returns:
            tuple: ``(x_context, x_target, context_mask, target_mask)``; the masks are
            boolean tensors of shape ``[num_hits]`` and the position tensors are the
            rows of ``x`` selected by them.
        """
        if x.shape[0] == 0:
            # No hit can serve as the random wedge centre (torch.randint(0, 0)
            # raises), so an empty event yields empty splits.
            no_hits = torch.zeros(0, dtype=torch.bool, device=x.device)
            return x[no_hits], x[no_hits], no_hits, no_hits.clone()

        # Extract coordinates
        x_coord, y_coord, z_coord = x[:, 0], x[:, 1], x[:, 2]

        # Calculate geometric quantities
        radius = torch.sqrt(x_coord**2 + y_coord**2)
        norm = torch.sqrt(x_coord**2 + y_coord**2 + z_coord**2)
        phi = torch.atan2(y_coord, x_coord)
        # NOTE(review): hits exactly on the z axis give +/-inf and a hit at the
        # origin gives NaN here; such hits then fail the eta window comparison.
        eta = torch.atanh(z_coord / norm)

        # Select random reference points (phi and eta come from independent draws)
        selected_phi = phi[torch.randint(0, phi.shape[0], (1,))]
        selected_eta = eta[torch.randint(0, eta.shape[0], (1,))]

        # Create masks
        wedge = self._create_phi_mask(phi, selected_phi) & self._create_eta_mask(eta, selected_eta)
        inner_mask = radius <= self.radius_midpoint
        outer_mask = radius > self.radius_midpoint

        # Assign context and target based on random selection. Short-circuit:
        # no random number is drawn when random_context is False.
        if self.random_context and torch.rand(1).item() > 0.5:
            context_mask = inner_mask & wedge
            target_mask = outer_mask & wedge
        else:
            context_mask = outer_mask & wedge
            target_mask = inner_mask & wedge

        # Split the data based on masks
        x_context = x[context_mask]
        x_target = x[target_mask]

        return x_context, x_target, context_mask, target_mask

    def _create_phi_mask(self, phi: torch.Tensor, selected_phi: torch.Tensor) -> torch.Tensor:
        """
        Mark hits whose azimuth lies within ``phi_range / 2`` of ``selected_phi``.

        phi is periodic, so the window may stick out past +pi or below -pi. Each hit
        is therefore tested at phi, phi + 2*pi and phi - 2*pi. The previous version
        omitted the ``- 2*pi`` image and missed hits near +pi whenever the window was
        centred near -pi.
        """
        phi_min = selected_phi - self.phi_range / 2
        phi_max = selected_phi + self.phi_range / 2
        in_window = torch.zeros_like(phi, dtype=torch.bool)
        for shift in (-2 * torch.pi, 0.0, 2 * torch.pi):
            shifted = phi + shift
            in_window = in_window | ((shifted >= phi_min) & (shifted <= phi_max))
        return in_window

    def _create_eta_mask(self, eta: torch.Tensor, selected_eta: torch.Tensor) -> torch.Tensor:
        """Mark hits whose eta lies within ``eta_range / 2`` of ``selected_eta``."""
        eta_min = selected_eta - self.eta_range / 2
        eta_max = selected_eta + self.eta_range / 2
        return torch.logical_and(eta >= eta_min, eta <= eta_max)

In [51]:
# dset = TrackMLDataset('/home/ucloud/Particle-JEPA/data/TrackML/training-small.hdf5', transform=patchify, scaling_factor=1e-3)

In [52]:
# torch.random.manual_seed(42)
# patchify = WedgePatchify3d(phi_range=torch.pi/2, eta_range=0.5, radius_midpoint=0.5) # Very approximate midpoint

In [53]:
# Build the wedge-splitting dataset from the local TrackML HDF5 file.
# NOTE(review): the path is absolute and machine-specific; a configurable data
# directory would make the notebook portable.
# radius_midpoint is compared against the scaled positions, so it has to be
# expressed in the units obtained after applying scaling_factor.
dataset = TrackMLDataset(
    file='/home/ucloud/Particle-JEPA/data/TrackML/training-small.hdf5',
    scaling_factor=1e-3,
    phi_range=torch.pi/2,
    eta_range=0.5,
    radius_midpoint=0.5,
    random_context=True,
)

In [56]:
# Fetch event 0 once and list the keys of the returned dict.
# The wedge is drawn at random inside __getitem__, so every access to the same
# index can produce a different context/target split.
batch = dataset[0]
batch.keys()

dict_keys(['x_context', 'x_target', 'x_context_mask', 'x_target_mask', 'event'])

In [58]:
# Shapes of the context/target position tensors and of their masks, plus the
# 'event' placeholder.
batch['x_context'].shape, batch['x_target'].shape, batch['x_context_mask'].shape, batch['x_target_mask'].shape, batch['event']

(torch.Size([949, 3]),
 torch.Size([1035, 3]),
 torch.Size([949]),
 torch.Size([1035]),
 None)

In [62]:
# A second, independent access to event 0 (a new random wedge is drawn).
sample = dataset[0]

In [18]:
from typing import Optional, Union, List, Dict

In [19]:
def collate_fn(batch: List[Dict]) -> Dict:
    """
    Merge a list of per-sample dicts into one batch dict.

    The keys are taken from the first sample. For every key:
      * ``'event'`` entries are gathered into a plain Python list;
      * if the first sample's value is a 0-dim tensor, all values are stacked;
      * anything else is zero-padded to a common length with ``pad_sequence``
        (``batch_first=True``).

    Args:
        batch (List[Dict]): samples sharing the same keys.

    Returns:
        Dict: the collated batch, with keys in the order of the first sample.
    """
    reference = batch[0]
    merged = {}
    for name, probe in reference.items():
        values = [entry[name] for entry in batch]
        if name == 'event':
            merged[name] = values
            continue
        # Only the first sample decides between stacking and padding.
        is_scalar = isinstance(probe, torch.Tensor) and probe.dim() == 0
        if is_scalar:
            merged[name] = torch.stack(values)
        else:
            merged[name] = torch.nn.utils.rnn.pad_sequence(values, batch_first=True)
    return merged

In [63]:
def collate_fn_manual(batch: List[Dict]) -> Dict:
    """
    Collate context/target samples by zero-padding into preallocated tensors.

    Args:
        batch (List[Dict]): samples holding ``x_context`` / ``x_target`` tensors of
            shape ``[n, ...]`` and 1-D boolean ``x_context_mask`` / ``x_target_mask``.

    Returns:
        Dict: padded ``x_context`` / ``x_target`` of shape ``[batch, max_n, ...]``,
        padded boolean masks of shape ``[batch, max_n]`` (``False`` on padding), and
        ``num_context`` / ``num_target`` holding the unpadded length of each sample.
    """
    batch_size = len(batch)
    collated = {}

    # Handle x_context and x_target
    for key in ['x_context', 'x_target']:
        tensors = [item[key] for item in batch]
        max_length = max(tensor.size(0) for tensor in tensors)
        # Take the trailing dimensions from the data rather than assuming three
        # features per hit, so inputs with another feature size also collate.
        feature_shape = tuple(tensors[0].shape[1:])
        padded = torch.zeros((batch_size, max_length) + feature_shape, dtype=tensors[0].dtype)
        for i, tensor in enumerate(tensors):
            padded[i, :tensor.size(0)] = tensor
        collated[key] = padded

    # Handle masks: padding positions stay False
    for key in ['x_context_mask', 'x_target_mask']:
        tensors = [item[key] for item in batch]
        max_length = max(tensor.size(0) for tensor in tensors)
        padded = torch.zeros(batch_size, max_length, dtype=torch.bool)
        for i, tensor in enumerate(tensors):
            padded[i, :tensor.size(0)] = tensor
        collated[key] = padded

    collated['num_context'] = torch.tensor([item['x_context'].size(0) for item in batch])
    collated['num_target'] = torch.tensor([item['x_target'].size(0) for item in batch])

    return collated

def collate_fn_pad_sequence(batch: List[Dict]) -> Dict:
    """
    Collate context/target samples with ``torch.nn.utils.rnn.pad_sequence``.

    Positions and masks are padded along their first dimension to the longest
    sample in the batch (batch dimension first, padding value zero / ``False``).

    Returns:
        Dict: padded ``x_context``, ``x_target``, ``x_context_mask`` and
        ``x_target_mask``, plus ``num_context`` / ``num_target`` with the unpadded
        length of each sample.
    """
    pad = torch.nn.utils.rnn.pad_sequence
    padded_keys = ('x_context', 'x_target', 'x_context_mask', 'x_target_mask')

    # pad_sequence only pads the leading dimension, so the same call serves both
    # the [n, 3] position tensors and the [n] masks.
    result = {
        name: pad([entry[name] for entry in batch], batch_first=True)
        for name in padded_keys
    }

    context_sizes = [entry['x_context'].shape[0] for entry in batch]
    target_sizes = [entry['x_target'].shape[0] for entry in batch]
    result['num_context'] = torch.tensor(context_sizes)
    result['num_target'] = torch.tensor(target_sizes)

    return result


In [66]:

# Let's add a timing test
import time

def compare_collate_methods(batch_size=32, num_points=100, num_trials=100):
    """
    Benchmark ``collate_fn_manual`` against ``collate_fn_pad_sequence`` on random data.

    Args:
        batch_size (int): number of synthetic samples in the batch.
        num_points (int): maximum number of context / target points per sample.
        num_trials (int): number of timed repetitions per method.

    Returns:
        tuple: ``(manual_time, pad_sequence_time)``, the mean seconds per batch.

    Prints the mean time per batch of both methods, their ratio, and whether the two
    methods produce identical outputs.
    """
    # Create some dummy data
    batch = []
    for _ in range(batch_size):
        # Random number of points between 1 and num_points (inclusive). randint's
        # upper bound is exclusive, hence the + 1; without it num_points itself was
        # never drawn and num_points=1 raised.
        n_context = torch.randint(1, num_points + 1, (1,)).item()
        n_target = torch.randint(1, num_points + 1, (1,)).item()

        batch.append({
            'x_context': torch.randn(n_context, 3),
            'x_target': torch.randn(n_target, 3),
            'x_context_mask': torch.ones(n_context, dtype=torch.bool),
            'x_target_mask': torch.ones(n_target, dtype=torch.bool),
        })

    # Time manual padding. perf_counter is the monotonic, high-resolution clock
    # meant for measuring short durations.
    start = time.perf_counter()
    for _ in range(num_trials):
        _ = collate_fn_manual(batch)
    manual_time = (time.perf_counter() - start) / num_trials

    # Time pad_sequence
    start = time.perf_counter()
    for _ in range(num_trials):
        _ = collate_fn_pad_sequence(batch)
    pad_sequence_time = (time.perf_counter() - start) / num_trials

    # Guard against a zero denominator on very coarse clocks.
    ratio = pad_sequence_time / manual_time if manual_time > 0 else float('inf')

    print(f"Manual padding: {manual_time*1000:.2f}ms per batch")
    print(f"pad_sequence:   {pad_sequence_time*1000:.2f}ms per batch")
    print(f"Ratio (pad_sequence/manual): {ratio:.2f}x")

    # Verify outputs are the same: same keys, and exactly equal tensors.
    # torch.equal returns False on a shape mismatch instead of raising, and both
    # methods only copy values, so no floating-point tolerance is needed.
    out1 = collate_fn_manual(batch)
    out2 = collate_fn_pad_sequence(batch)

    all_equal = out1.keys() == out2.keys() and all(
        torch.equal(out1[k], out2[k]) for k in out1
    )
    print(f"\nOutputs are {'equal' if all_equal else 'different'}")

    return manual_time, pad_sequence_time

In [67]:
# Run the benchmark with its default arguments; the printed timings depend on the machine.
compare_collate_methods()

Manual padding: 0.74ms per batch
pad_sequence:   0.52ms per batch
Ratio (pad_sequence/manual): 0.71x

Outputs are equal


(0.0007410573959350585, 0.0005229401588439941)

In [68]:
# Use the pad_sequence variant for the DataLoader.
# NOTE(review): this rebinds the name collate_fn, shadowing the generic
# collate_fn function defined in an earlier cell.
collate_fn = collate_fn_pad_sequence

In [69]:
# Batches of two events, shuffled, padded to a common length by collate_fn.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

In [70]:
# Pull a single collated batch. This rebinds `batch`, which previously held one
# un-collated dataset item.
batch = next(iter(dataloader))

In [72]:
# Display the collated batch; rows of zeros at the end of a sample are padding.
batch

{'x_context': tensor([[[ 0.0076,  0.0311,  0.0587],
          [ 0.0038,  0.0313,  0.0680],
          [ 0.0036,  0.0313,  0.0480],
          ...,
          [-0.1734,  0.4673,  1.0042],
          [-0.3437,  0.3612,  1.0078],
          [-0.4036,  0.2946,  1.0438]],
 
         [[-0.1410, -0.0819, -0.8180],
          [-0.1510, -0.0818, -0.8180],
          [-0.1409, -0.0878, -0.8180],
          ...,
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]]),
 'x_target': tensor([[[-0.1893,  0.4632,  0.6666],
          [-0.3867,  0.3188,  0.6678],
          [ 0.1439,  0.4800,  0.7072],
          ...,
          [-0.8520,  0.4263,  2.1555],
          [-0.8724,  0.4073,  2.1555],
          [-0.8580,  0.3647,  2.1555]],
 
         [[-0.6131, -0.3374, -2.9485],
          [-0.5623, -0.3127, -2.9485],
          [-0.5434, -0.2653, -2.9485],
          ...,
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
         

In [73]:
def move_batch_to_device(batch, device):
    """
    Recursively moves all tensors in a batch to the specified device.

    Args:
        batch: A dictionary, list, tuple (including namedtuples), or any object;
            objects exposing a ``to`` method are moved with ``obj.to(device)``.
        device: The target device (e.g., 'cuda' or 'cpu')

    Returns:
        The batch with all tensors moved to the specified device. Dicts come back as
        plain dicts, lists/tuples keep their type, and objects without a ``to``
        method are returned unchanged.
    """
    if isinstance(batch, dict):
        return {k: move_batch_to_device(v, device) for k, v in batch.items()}
    if isinstance(batch, (list, tuple)):
        moved = [move_batch_to_device(v, device) for v in batch]
        # namedtuple constructors take one positional argument per field, not a
        # single iterable, so they must be rebuilt by unpacking.
        if isinstance(batch, tuple) and hasattr(batch, '_fields'):
            return type(batch)(*moved)
        return type(batch)(moved)
    if hasattr(batch, 'to'):
        return batch.to(device)
    return batch

# Usage example:
# batch = move_batch_to_device(batch, device='cuda')  # For GPU
# batch = move_batch_to_device(batch, device='cpu')   # For CPU


In [77]:
# Load the model configuration; the path is relative to the notebook's working directory.
# NOTE(review): if the config only contains plain scalars/lists/dicts,
# yaml.safe_load(f) would be the more restrictive choice — TODO confirm before switching.
with open("configs/9_testing.yaml", "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

In [78]:
# Override the input feature size: every hit carries three coordinates (x, y, z).
config['d_input'] = 3

In [79]:
# Instantiate the model, passing every config entry as a keyword argument.
model = JEA(**config)

In [80]:
# Prefer the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# As the last expression of the cell, this also displays the module's repr.
model.to(device)

# batch = move_batch_to_device(batch, device)


JEA(
  (encoder): Encoder(
    (transformer): Transformer(
      (input_encoder): Sequential(
        (0): Linear(in_features=3, out_features=128, bias=True)
        (1): Dropout(p=0.0, inplace=False)
        (2): SiLU()
        (3): Linear(in_features=128, out_features=32, bias=True)
      )
      (encoder_layers): ModuleList(
        (0-2): 3 x AttentionBlock(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
          )
          (feed_forward): Sequential(
            (0): Linear(in_features=32, out_features=128, bias=True)
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Linear(in_features=128, out_features=32, bias=True)
          )
          (dropout): Dropout(p=0, inplace=False)
          (norm_self_attn): SetNorm()
          (norm_ff): SetNorm()
          (activation): SiLU()
        )
      )
    )
    (aggregator): Aggregator(
      (encoder_layer

In [81]:
# NOTE(review): in the recorded output this call fails with KeyError: 'x'.
# The collated batch only has the keys x_context, x_target, x_context_mask,
# x_target_mask, num_context and num_target, so the lookup of 'x' cannot succeed
# until the batch keys and the model's expectations are aligned.
# The batch is also not moved to `device` here (that call is commented out above).
model.training_step(batch, 0)

Starting first training step...


KeyError: 'x'