# Classify Groups (Standardized)

Decisions:
- Single-group samples; upstream code prefilters groups.
- Node features: `[coord, z, energy, view, group_energy]` (5 dims).
- Views use a single flag `view` (0=x-strip, 1=y-strip).
- Fully connected graphs per sample, built in the dataset (not in the model).
- Edge features always used: `[dx, dz, dE, same_view]` (4 dims).
- Models read `data.x`, `data.edge_index`, `data.edge_attr` only (no graph construction in forward).

In [None]:
import math
from typing import List, Dict, Any, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import TransformerConv, JumpingKnowledge
from torch_geometric.nn.aggr import AttentionalAggregation


In [None]:
def fully_connected_edge_index(num_nodes: int, device=None) -> torch.Tensor:
    """
    Build a fully connected directed edge_index (excluding self-loops)
    for a single graph with `num_nodes` nodes.

    Edges are grouped by source node in ascending order; within one source,
    destinations are listed in ascending order with the source itself skipped.

    Args:
        num_nodes: Number of nodes in the graph.
        device: Optional torch device on which the result is allocated.

    Returns:
        A LongTensor of shape [2, num_nodes * (num_nodes - 1)] with rows
        [src, dst]. For `num_nodes <= 1` an empty [2, 0] LongTensor is
        returned, since no non-self edge exists.
    """
    if num_nodes <= 1:
        return torch.empty((2, 0), dtype=torch.long, device=device)

    node_ids = torch.arange(num_nodes, device=device)
    # Enumerate all num_nodes**2 ordered pairs in row-major order
    # (src varies slowest), then drop the diagonal (i, i) entries. This avoids
    # a Python-level loop that allocates several small tensors per node.
    src_index = node_ids.repeat_interleave(num_nodes)
    dst_index = node_ids.repeat(num_nodes)
    off_diagonal = src_index != dst_index
    return torch.stack([src_index[off_diagonal], dst_index[off_diagonal]], dim=0)

def build_edge_attr(node_features: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """
    Derive per-edge features from the endpoints of every edge in `edge_index`.

    Node features are expected in the standardized column layout
    [coord, z, energy, view, group_energy]; only the first four columns
    are read here.

    Args:
        node_features: Per-node feature matrix, one row per node.
        edge_index: Tensor whose row 0 holds source node indices and whose
            row 1 holds destination node indices.

    Returns:
        A tensor of shape [num_edges, 4] with columns [dx, dz, dE, same_view].
        The three differences are taken as destination minus source, and
        `same_view` is 1.0 when both endpoints carry the same view flag and
        0.0 otherwise. With no edges, a float tensor of shape [0, 4] on the
        device of `node_features` is returned.
    """
    # Nothing to compute for an edgeless graph.
    if edge_index.numel() == 0:
        return torch.zeros((0, 4), dtype=torch.float, device=node_features.device)

    sources, targets = edge_index[0], edge_index[1]

    # Gather the full feature rows of both endpoints once.
    tail_rows = node_features[sources]
    head_rows = node_features[targets]

    # Columns 0..2 are coord, z, energy -> dx, dz, dE in a single subtraction.
    deltas = head_rows[:, :3] - tail_rows[:, :3]

    # Column 3 is the view flag; compare endpoints for equality.
    view_match = (head_rows[:, 3] == tail_rows[:, 3]).float().unsqueeze(1)

    return torch.cat([deltas, view_match], dim=1)


In [None]:
class GroupClassificationDataset(Dataset):
    """
    Map-style dataset over single-group samples; groups are assumed to be
    prefiltered upstream.

    Every input record is a dict with keys:
      - 'coord': 1D array-like (strip coordinate in mm; x or y depending on view)
      - 'z': 1D array-like (z coordinate in mm)
      - 'energy': 1D array-like (deposited energy per hit)
      - 'view': 1D array-like (0 for x-strip, 1 for y-strip)
      - 'label': int class label (optional)
      - 'event_id': int identifier (optional)

    Every returned sample is a torch_geometric.data.Data carrying:
      - x: [num_hits, 5] node features [coord, z, energy, view, group_energy]
      - edge_index: [2, num_edges] fully-connected directed edges (no self loops)
      - edge_attr: [num_edges, 4] edge features [dx, dz, dE, same_view]
      - y: class label as a long tensor, or None when no label is given
      - event_id: graph-level identifier, set only when one is given
    """
    def __init__(self, groups: List[Dict[str, Any]]):
        # The list is kept by reference; records are converted lazily per access.
        self.items = groups

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, index: int) -> Data:
        record = self.items[index]

        # Per-hit quantities as float32 arrays (view: 0=x-strip, 1=y-strip).
        coord_mm, z_mm, energy_dep, view_flag = [
            np.asarray(record[key], dtype=np.float32)
            for key in ('coord', 'z', 'energy', 'view')
        ]
        hit_count = coord_mm.shape[0]

        # Every hit also carries the summed energy of its whole group.
        group_energy = np.full(hit_count, energy_dep.sum(), dtype=np.float32)

        # Column order defines the node layout: [coord, z, energy, view, group_energy].
        columns = [coord_mm, z_mm, energy_dep, view_flag, group_energy]
        x = torch.tensor(np.stack(columns, axis=1), dtype=torch.float)

        # Graph topology and edge features are prepared here, not in the model.
        edge_index = fully_connected_edge_index(hit_count, device=None)
        edge_attr = build_edge_attr(x, edge_index)

        # Label is optional; a missing or None label yields y=None.
        label = record.get('label')
        y = torch.tensor(label, dtype=torch.long) if label is not None else None

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

        # The event id is attached only when the record provides one.
        event_id = record.get('event_id')
        if event_id is not None:
            data.event_id = torch.tensor(event_id, dtype=torch.long)
        return data


In [None]:
class FullTransformerBlock(nn.Module):
    """
    A transformer-style graph block:
      - Multi-head attention via TransformerConv (uses edge_attr)
      - Residual + LayerNorm
      - Position-wise FFN + Residual + LayerNorm

    Args:
        hidden: Width of the node embeddings entering and leaving the block.
            Must be divisible by `heads`.
        heads: Number of attention heads; each head produces
            `hidden // heads` channels and the heads are concatenated.
        edge_dim: Number of features per edge passed to the attention layer.
        dropout: Dropout probability used inside the attention layer, inside
            the FFN, and on both residual branches.

    Raises:
        ValueError: If `heads` is not positive or `hidden` is not divisible
            by `heads`.
    """
    def __init__(self, hidden: int = 200, heads: int = 4, edge_dim: int = 4, dropout: float = 0.05):
        super().__init__()
        # Fail early: with a non-divisible width the concatenated attention
        # output would have heads * (hidden // heads) != hidden channels and
        # the residual addition in forward() would break with a shape error.
        if heads <= 0 or hidden % heads != 0:
            raise ValueError(
                f"hidden ({hidden}) must be divisible by a positive number of heads ({heads})"
            )
        # With concat=True, out_channels * heads = hidden => per-head = hidden // heads
        self.attn = TransformerConv(hidden, hidden // heads, heads=heads, concat=True, edge_dim=edge_dim, dropout=dropout)
        self.norm1 = nn.LayerNorm(hidden)
        # Position-wise feed-forward network with a 4x expansion.
        self.ff = nn.Sequential(
            nn.Linear(hidden, hidden * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden * 4, hidden)
        )
        self.norm2 = nn.LayerNorm(hidden)
        self.drop = nn.Dropout(dropout)

    def forward(self, node_features: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        """
        Apply one attention + feed-forward round to the node embeddings.

        Args:
            node_features: Node embeddings whose last dimension is `hidden`.
            edge_index: Edge list forwarded to the attention layer.
            edge_attr: Per-edge features forwarded to the attention layer.

        Returns:
            Updated node embeddings with the same last dimension `hidden`.
        """
        # Attention over edges (uses edge_attr), then residual + LayerNorm
        attn_out = self.attn(node_features, edge_index, edge_attr)
        node_features = self.norm1(node_features + self.drop(attn_out))
        # Position-wise FFN, residual, and LayerNorm
        ff_out = self.ff(node_features)
        node_features = self.norm2(node_features + self.drop(ff_out))
        return node_features

class GroupClassifier(nn.Module):
    """
    Graph-level classifier: a linear input embedding, a stack of
    FullTransformerBlocks whose outputs are concatenated with
    JumpingKnowledge, AttentionalAggregation pooling to one vector per
    graph, and an MLP head producing class logits.

    Args:
        in_dim: Number of input features per node.
        edge_dim: Number of features per edge.
        hidden: Embedding width used inside every block.
        num_blocks: Number of stacked FullTransformerBlocks.
        heads: Attention heads per block.
        dropout: Dropout probability for the blocks and the classifier head.
        num_classes: Number of output classes (the number of possible
            particle types).
    """
    def __init__(self, in_dim: int = 5, edge_dim: int = 4, hidden: int = 200, num_blocks: int = 2, heads: int = 4, dropout: float = 0.0, num_classes: int = 3):
        super().__init__()
        # Lift the raw node features to the model width.
        self.input_embed = nn.Linear(in_dim, hidden)

        encoder_blocks = [
            FullTransformerBlock(hidden=hidden, heads=heads, edge_dim=edge_dim, dropout=dropout)
            for _ in range(num_blocks)
        ]
        self.blocks = nn.ModuleList(encoder_blocks)
        self.jk_layer = JumpingKnowledge(mode='cat')

        # JK concatenation yields one `hidden`-wide slice per block.
        jk_width = hidden * num_blocks
        half_width = jk_width // 2

        # Gate network scoring every node for the attention pooling.
        gate_mlp = nn.Sequential(
            nn.Linear(jk_width, half_width),
            nn.ReLU(),
            nn.Linear(half_width, 1)
        )
        self.attn_pool = AttentionalAggregation(gate_nn=gate_mlp)

        # Final MLP mapping the pooled graph vector to per-class logits.
        self.classifier_head = nn.Sequential(
            nn.Linear(jk_width, half_width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half_width, num_classes)
        )

    def forward(self, data: Data) -> torch.Tensor:
        """
        Compute class logits for the graph(s) held in `data`.

        Reads only `data.x`, `data.edge_index`, `data.edge_attr` and
        `data.batch`; no graph construction happens here.

        Returns:
            Logits produced by the classifier head from the pooled
            graph embeddings.
        """
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        graph_ids = data.batch

        # Embed, then run the blocks sequentially while keeping every
        # intermediate output for the JumpingKnowledge concatenation.
        hidden_state = self.input_embed(x)
        per_block_states = []
        for layer in self.blocks:
            hidden_state = layer(hidden_state, edge_index, edge_attr)
            per_block_states.append(hidden_state)

        stacked = self.jk_layer(per_block_states)

        # Attention pooling collapses the nodes of each graph into one vector.
        graph_vectors = self.attn_pool(stacked, index=graph_ids)

        return self.classifier_head(graph_vectors)


In [None]:
# Example usage (replace with your real prefiltered groups):
# groups = [{
#     'coord': np.random.randn(32),
#     'z': np.random.randn(32),
#     'energy': np.abs(np.random.randn(32)),
#     'view': np.random.randint(0, 2, size=32),
#     'label': np.random.randint(0, 3),
#     'event_id': 123
# } for _ in range(100)]
# ds = GroupClassificationDataset(groups)
# loader = DataLoader(ds, batch_size=16, shuffle=True)
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# model = GroupClassifier(num_classes=3).to(device)
# opt = torch.optim.Adam(model.parameters(), lr=1e-3)
# for epoch in range(5):
#     model.train()
#     for batch in loader:
#         batch = batch.to(device)
#         logits = model(batch)
#         loss = F.cross_entropy(logits, batch.y)
#         opt.zero_grad(); loss.backward(); opt.step()
#     print('epoch', epoch)
