
# Positron Angle Predictor

Training utilities for predicting the positron emission direction from time-group graphs, including per-epoch angular-error histograms.


In [None]:

import sys
import os

# Make the local ML project importable from this notebook. The path is
# machine-specific (a WSL mount of a Windows user directory); adjust it when
# running on another machine.
PROJECT_ROOT = '/mnt/c/Users/obbee/research/notebooks/ML'
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Mirror the entry into PYTHONPATH so separately launched Python interpreters
# (e.g. via subprocess) can import the project too. This is checked
# independently of sys.path so the two cannot drift apart. No separator is
# appended when PYTHONPATH was unset: a dangling empty entry may be interpreted
# as the current working directory.
_existing_pythonpath = os.environ.get('PYTHONPATH', '')
if PROJECT_ROOT not in _existing_pythonpath.split(os.pathsep):
    os.environ['PYTHONPATH'] = (
        PROJECT_ROOT + os.pathsep + _existing_pythonpath
        if _existing_pythonpath
        else PROJECT_ROOT
    )


In [None]:

import math
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torch_geometric.loader import DataLoader

from graph_data.utils import PositronAngleGraphDataset, PositronAngleRecord
from graph_data.models import PositronAngleModel


In [None]:

def build_angle_records(
    groups: Sequence[Iterable],
    angle_targets: Sequence[Sequence[float]],
    *,
    event_ids: Optional[Sequence[int]] = None,
    pion_stops: Optional[Sequence[Sequence[float]]] = None,
) -> List[PositronAngleRecord]:
    """Normalize raw group arrays plus per-group angle labels into PositronAngleRecord objects.

    Args:
        groups: One array-like per group, each of shape ``[N, >=4]`` with ``N >= 2``.
            Columns 0..3 are read as ``coord``, ``z``, ``energy`` and ``view``
            (cast to float32); any extra columns are ignored.
        angle_targets: One angle label per group, passed through unchanged as
            ``PositronAngleRecord.angle``. NOTE(review): the usage example suggests
            either ``[theta, phi]`` in radians or a unit vector ``[x, y, z]`` --
            confirm which form the dataset expects.
        event_ids: Optional per-group event ids (converted with ``int``).
        pion_stops: Optional per-group pion stop positions (converted to float32 arrays).

    Returns:
        One record per group, with ``group_id`` set to the group's index in ``groups``.

    Raises:
        ValueError: If ``angle_targets``, ``event_ids`` or ``pion_stops`` has a
            different length than ``groups``, or if a group has the wrong shape.
    """
    n_groups = len(groups)
    if len(angle_targets) != n_groups:
        raise ValueError("angle_targets length must match groups length")
    # Validate the optional per-group sequences up front. A short sequence would
    # otherwise fail with an IndexError part-way through; a long one would be
    # silently misaligned with the groups.
    if event_ids is not None and len(event_ids) != n_groups:
        raise ValueError("event_ids length must match groups length")
    if pion_stops is not None and len(pion_stops) != n_groups:
        raise ValueError("pion_stops length must match groups length")

    records: List[PositronAngleRecord] = []
    for idx, (group, angle) in enumerate(zip(groups, angle_targets)):
        arr = np.asarray(group)
        # Need at least two rows and the four feature columns read below.
        if arr.ndim != 2 or arr.shape[0] < 2 or arr.shape[1] < 4:
            raise ValueError(f"Group at index {idx} must be [N>=2, >=4], got shape {arr.shape}")

        record_event = int(event_ids[idx]) if event_ids is not None else None
        # Already float32 here, so no second cast is needed when building the record.
        pion_stop = None if pion_stops is None else np.asarray(pion_stops[idx], dtype=np.float32)

        records.append(PositronAngleRecord(
            coord=arr[:, 0].astype(np.float32),
            z=arr[:, 1].astype(np.float32),
            energy=arr[:, 2].astype(np.float32),
            view=arr[:, 3].astype(np.float32),
            angle=angle,
            event_id=record_event,
            group_id=idx,
            pion_stop=pion_stop,
        ))
    return records


In [None]:

def _split_records(
    records: Sequence[PositronAngleRecord | Dict[str, Any]],
    *,
    train_fraction: float = 0.85,
    seed: int = 13,
) -> Tuple[List[PositronAngleRecord | Dict[str, Any]], List[PositronAngleRecord | Dict[str, Any]]]:
    """Shuffle ``records`` reproducibly and split them into ``(train, val)`` lists.

    The permutation comes from ``np.random.default_rng(seed)``. The training share
    is ``int(len(records) * train_fraction)``, clamped to ``[1, len(records) - 1]``
    so that both sides are non-empty whenever there are at least two records.
    A single record goes to training; no records gives two empty lists.
    """
    if not records:
        return [], []

    total = len(records)
    order = np.arange(total)
    np.random.default_rng(seed).shuffle(order)

    if total == 1:
        return [records[int(order[0])]], []

    # Keep at least one record on each side of the cut.
    cut = max(int(total * train_fraction), 1)
    cut = min(cut, total - 1)

    train_part = [records[i] for i in order[:cut]]
    val_part = [records[i] for i in order[cut:]]
    return train_part, val_part


def _make_loaders(
    records: Sequence[PositronAngleRecord | Dict[str, Any]],
    *,
    batch_size: int,
    train_fraction: float,
    seed: int,
):
    """Split ``records`` and wrap each side in a PyG ``DataLoader``.

    Returns ``(train_loader, val_loader, train_dataset, val_dataset)``. The two
    validation entries are ``None`` when the split leaves no validation records.
    Training batches are shuffled; validation batches are not.

    Raises:
        ValueError: If the training split is empty (with ``_split_records`` as
            written, that happens only for empty ``records``).
    """
    train_part, val_part = _split_records(records, train_fraction=train_fraction, seed=seed)
    if not train_part:
        raise ValueError("Training set is empty. Check filtering parameters.")

    train_ds = PositronAngleGraphDataset(train_part)
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    if not val_part:
        return train_dl, None, train_ds, None

    val_ds = PositronAngleGraphDataset(val_part)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    return train_dl, val_dl, train_ds, val_ds


def _angle_errors_deg(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    preds_n = F.normalize(preds, dim=1)
    targets_n = F.normalize(targets, dim=1)
    cos_sim = torch.clamp((preds_n * targets_n).sum(dim=1), -0.999999, 0.999999)
    return torch.arccos(cos_sim) * (180.0 / math.pi)


def _angle_histogram(errors: np.ndarray | Sequence[float], bins: int = 50, span: Tuple[float, float] = (0.0, 180.0)):
    values = np.asarray(errors, dtype=np.float32)
    if values.size == 0:
        return None
    counts, edges = np.histogram(values, bins=bins, range=span)
    return {
        'counts': counts.tolist(),
        'bin_edges': edges.tolist(),
    }


def _plot_histogram(title: str, hist_data: Optional[Dict[str, Any]]) -> None:
    """Show a histogram dict (``'counts'`` / ``'bin_edges'``) on linear and log y-axes.

    Does nothing when ``hist_data`` is ``None`` or holds empty lists. The x-range
    follows the histogram's own bin edges, so histograms built over a custom span
    are displayed in full. The figure is closed after ``plt.show()``.
    """
    if hist_data is None:
        return
    counts = np.asarray(hist_data['counts'], dtype=np.float32)
    edges = np.asarray(hist_data['bin_edges'], dtype=np.float32)
    if counts.size == 0 or edges.size == 0:
        return
    centers = 0.5 * (edges[:-1] + edges[1:])
    widths = np.diff(edges)
    # A log axis needs at least one positive bar; otherwise matplotlib warns and
    # renders an empty axis.
    can_log = bool(counts.max() > 0)
    fig, axes = plt.subplots(1, 2, figsize=(10, 3))
    for ax, log_scale in zip(axes, [False, True]):
        ax.bar(centers, counts, width=widths, align='center', alpha=0.75, color='crimson')
        ax.set_title(f"{title} angle error histogram" + (' (log)' if log_scale else ''))
        ax.set_xlabel('Angular error [deg]')
        ax.set_ylabel('Counts')
        # Use the histogram's own range instead of a fixed 0-180 window.
        ax.set_xlim(float(edges[0]), float(edges[-1]))
        if log_scale and can_log:
            ax.set_yscale('log')
        ax.grid(True, linestyle='--', alpha=0.3)
    fig.tight_layout()
    plt.show()
    # This is called up to twice per epoch. Without closing, figures accumulate
    # under non-inline backends (memory growth, "too many figures" warnings).
    plt.close(fig)


def _run_train_epoch(model, loader, optimizer, loss_fn, device, grad_clip=None):
    """Run one optimization pass over ``loader`` and collect angular errors.

    For every batch, this:
      - moves the batch to ``device`` and reshapes ``data.y`` to one row per graph;
      - steps the optimizer on ``loss_fn(preds, target)``, clipping the gradient
        norm when ``grad_clip`` is a positive number;
      - records per-graph angular errors of the detached predictions.

    Returns:
        ``(mean_loss, mean_error_deg, errors)``. Both means are weighted by the
        number of graphs per batch. ``errors`` is a numpy array of per-graph
        errors in degrees (an empty float32 array when ``loader`` yields nothing).
    """
    model.train()
    loss_sum = 0.0
    err_sum = 0.0
    seen = 0
    chunks: List[torch.Tensor] = []

    for batch in loader:
        batch = batch.to(device)
        target = batch.y.view(batch.num_graphs, -1).float()
        preds = model(batch)
        loss = loss_fn(preds, target)

        optimizer.zero_grad()
        loss.backward()
        if grad_clip is not None and grad_clip > 0:
            clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        errs = _angle_errors_deg(preds.detach(), target)
        chunks.append(errs.cpu())

        n = target.size(0)
        loss_sum += loss.item() * n
        err_sum += errs.sum().item()
        seen += n

    denom = max(seen, 1)
    if chunks:
        all_errors = torch.cat(chunks).numpy()
    else:
        all_errors = np.zeros(0, dtype=np.float32)
    return loss_sum / denom, err_sum / denom, all_errors


def _run_eval_epoch(model, loader, loss_fn, device):
    """Evaluate ``model`` on ``loader`` in eval mode without gradients.

    Returns ``(mean_loss, mean_error_deg, errors)`` in the same form as
    ``_run_train_epoch``. When ``loader`` is ``None`` (no validation split), the
    result is ``(nan, nan, empty float32 array)``.
    """
    if loader is None:
        return math.nan, math.nan, np.zeros(0, dtype=np.float32)

    model.eval()
    loss_sum = 0.0
    err_sum = 0.0
    seen = 0
    chunks: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            target = batch.y.view(batch.num_graphs, -1).float()
            preds = model(batch)
            batch_loss = loss_fn(preds, target).item()

            errs = _angle_errors_deg(preds, target)
            chunks.append(errs.cpu())

            n = target.size(0)
            loss_sum += batch_loss * n
            err_sum += errs.sum().item()
            seen += n

    denom = max(seen, 1)
    if chunks:
        all_errors = torch.cat(chunks).numpy()
    else:
        all_errors = np.zeros(0, dtype=np.float32)
    return loss_sum / denom, err_sum / denom, all_errors


def train_positron_angle_predictor(
    records: Sequence[PositronAngleRecord | Dict[str, Any]],
    *,
    model: Optional[torch.nn.Module] = None,
    batch_size: int = 128,
    epochs: int = 20,
    lr: float = 5e-4,
    weight_decay: float = 1e-5,
    train_fraction: float = 0.85,
    seed: int = 13,
    grad_clip: Optional[float] = 2.0,
    scheduler_step_size: Optional[int] = None,
    scheduler_gamma: float = 0.7,
    device: Optional[torch.device | str] = None,
    plot_histograms: bool = True,
):
    """Train a positron-angle regressor with MSE loss and AdamW.

    Args:
        records: Records (or dicts) accepted by ``PositronAngleGraphDataset``.
        model: Model to train; a default ``PositronAngleModel()`` when ``None``.
        batch_size: Batch size for both loaders.
        epochs: Number of training epochs.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        train_fraction: Share of records used for training (see ``_split_records``).
        seed: Seed for the train/validation split only. Model initialization and
            loader shuffling are not seeded here.
        grad_clip: Max gradient norm; ``None`` or ``<= 0`` disables clipping.
        scheduler_step_size: StepLR period in epochs; ``None`` or ``<= 0`` disables it.
        scheduler_gamma: StepLR decay factor.
        device: Target device; defaults to CUDA when available, else CPU.
        plot_histograms: When true (default), plot train/validation angular-error
            histograms after every epoch.

    Returns:
        A dict with:
          - ``'model'``: the trained model.
          - ``'history'``: one entry per epoch holding ``epoch``; ``lr`` (the rate
            used during that epoch); train/val loss; mean angular error in
            degrees; and angular-error histograms. The validation fields are
            ``None`` only when there is no validation split.
          - ``'train_size'`` and ``'val_size'``.
    """
    if model is None:
        model = PositronAngleModel()

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device_obj = torch.device(device)
    model = model.to(device_obj)

    train_loader, val_loader, train_dataset, val_dataset = _make_loaders(
        records,
        batch_size=batch_size,
        train_fraction=train_fraction,
        seed=seed,
    )
    # Decide validation reporting from whether a validation split exists, not
    # from isnan(val_loss). Otherwise a diverged (NaN) validation loss would be
    # silently treated as "no validation set" and hidden from the log/history.
    has_val = val_loader is not None

    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    scheduler = None
    if scheduler_step_size is not None and scheduler_step_size > 0:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step_size, gamma=scheduler_gamma)

    history = []
    for epoch in range(1, epochs + 1):
        # Read before scheduler.step() so this is the rate actually used in this epoch.
        current_lr = optimizer.param_groups[0]['lr']
        train_loss, train_err, train_errors = _run_train_epoch(model, train_loader, optimizer, loss_fn, device_obj, grad_clip)
        val_loss, val_err, val_errors = _run_eval_epoch(model, val_loader, loss_fn, device_obj)

        if scheduler is not None:
            scheduler.step()

        train_hist = _angle_histogram(train_errors)
        val_hist = _angle_histogram(val_errors) if has_val else None

        if plot_histograms:
            _plot_histogram('Train', train_hist)
            if val_hist is not None:
                _plot_histogram('Validation', val_hist)

        line = f"Epoch {epoch:02d} | lr={current_lr:.5f} | train_loss={train_loss:.5f} err={train_err:.3f}°"
        if has_val:
            line += f" | val_loss={val_loss:.5f} err={val_err:.3f}°"
        print(line)

        history.append({
            'epoch': epoch,
            'lr': float(current_lr),
            'train_loss': float(train_loss),
            'train_error_deg': float(train_err),
            'train_histogram': train_hist,
            'val_loss': float(val_loss) if has_val else None,
            'val_error_deg': float(val_err) if has_val else None,
            'val_histogram': val_hist,
        })

    return {
        'model': model,
        'history': history,
        'train_size': len(train_dataset),
        'val_size': 0 if val_dataset is None else len(val_dataset),
    }


In [None]:

# Example usage
# raw_groups = [...]  # sequence of per-group arrays
# angle_targets = [...]  # list of [theta, phi] (radians) or unit vectors [x, y, z]
# records = build_angle_records(raw_groups, angle_targets)
#
# model = PositronAngleModel(in_channels=5, hidden=150, heads=4, layers=3, dropout=0.1)
#
# results = train_positron_angle_predictor(
#     records,
#     model=model,
#     batch_size=64,
#     epochs=20,
#     lr=5e-4,
#     weight_decay=1e-5,
#     train_fraction=0.85,
#     seed=42,
#     grad_clip=2.0,
#     scheduler_step_size=10,
#     scheduler_gamma=0.6,
# )
#
# print(f"Train groups: {results['train_size']} | Val groups: {results['val_size']}")
# for entry in results['history']:
#     val_part = ''
#     if entry['val_loss'] is not None:
#         val_part = f" | val_loss={entry['val_loss']:.5f} err={entry['val_error_deg']:.3f}°"
#     print(
#         f"Epoch {entry['epoch']:02d}: train_loss={entry['train_loss']:.5f} err={entry['train_error_deg']:.3f}°{val_part}"
#     )
