In [1]:
import sys
from pathlib import Path

# --- Auto-detect project root by walking upward ---
cwd = Path().resolve()
ROOT = None

for parent in [cwd] + list(cwd.parents):
    if (parent / "src" / "pioneerml").exists():
        ROOT = parent
        break

if ROOT is None:
    raise RuntimeError("Could not find project root containing src/pioneerml")

sys.path.append(str(ROOT / "src"))
print("Using project root:", ROOT)

# --- Normal imports ---
import torch
import numpy as np
import matplotlib.pyplot as plt

from pioneerml.data.datasets.graph_group import GraphRecord
from pioneerml.training.datamodules.group import GroupClassificationDataModule
from pioneerml.models.classifiers.group_classifier import GroupClassifier
from pioneerml.training.lightning import GraphLightningModule
from pioneerml.training.utils import default_precision_for_accelerator, set_tensor_core_precision
from pioneerml.training.visualization import plot_loss_curves
from pioneerml.training.utils import default_precision_for_accelerator
from pioneerml.pipelines.stage import StageConfig
from pioneerml.pipelines.stages.model import LightningTrainStage
from pioneerml.pipelines.pipeline import Pipeline
from pioneerml.pipelines.context import Context
from pioneerml.training.progress import CleanProgressBar


import pytorch_lightning as pl


Using project root: /home/jack/python_projects/pioneerML


In [2]:
if torch.cuda.is_available():
    accelerator = "cuda"
    device = torch.device("cuda")
    precision = default_precision_for_accelerator("cuda")
    set_tensor_core_precision("medium")
else:
    accelerator = "cpu"
    device = torch.device("cpu")
    precision = "32-true"

accelerator, precision


('cuda', '16-mixed')

In [3]:
def make_synthetic_record(num_hits: int, event_id: int):
    coord = np.random.randn(num_hits).astype(np.float32)
    z = np.random.randn(num_hits).astype(np.float32)
    energy = np.abs(np.random.randn(num_hits)).astype(np.float32)
    view = np.random.randint(0, 2, size=num_hits).astype(np.float32)

    energy_mean = energy.mean()
    spatial_spread = coord.std()

    labels = []

    # Energy: class 0 (high) or 1 (low)
    if energy_mean > 1.0:
        labels.append(0)
    else:
        labels.append(1)

    # Hits: class 2 (high) or 3 (low)
    if num_hits > 20:
        labels.append(2)
    else:
        labels.append(3)

    # Spread: class 4 (high) or 5 (low)
    if spatial_spread > 1.0:
        labels.append(4)
    else:
        labels.append(5)

    return GraphRecord(
        coord=coord,
        z=z,
        energy=energy,
        view=view,
        labels=labels,
        event_id=event_id,
        group_id=event_id,
    )


In [4]:
records = [make_synthetic_record(num_hits=np.random.randint(5, 40), event_id=i)
           for i in range(600)]

len(records)


600

In [5]:
dm = GroupClassificationDataModule(
    records,
    num_classes=3,
    batch_size=32,
    val_split=0.2,
    test_split=0.0,
    num_workers=0,
)

dm.setup()


In [6]:
# 1. Build model for 6 classes
model = GroupClassifier(
    hidden=128,
    num_blocks=2,
    num_classes=3,
)

# 2. Wrap it in the Lightning module
lightning_module = GraphLightningModule(
    model=model,
    task="classification",
    lr=1e-3,
)

# 3. Build DataModule (you already have records)
datamodule = GroupClassificationDataModule(
    records,
    num_classes=6,             # IMPORTANT for your new 6-class setup
    batch_size=32,
    val_split=0.2,
    num_workers=4,             # or os.cpu_count()
    pin_memory=True,
)


In [7]:
# building training stage
train_stage = LightningTrainStage(
    config=StageConfig(
        name="train_synthetic_classifier",
        params={
            "module": lightning_module,
            "datamodule": datamodule,
            "trainer_params": {
                "accelerator": "auto",            # GPU if available
                "devices": 1,
                "max_epochs": 10,
                "logger": False,
                "enable_checkpointing": False,
                "precision": default_precision_for_accelerator("auto"),
                "enable_model_summary": True,
                "enable_progress_bar": False,
                "callbacks": [CleanProgressBar(bar_width=30)],
            },
        },
    )
)



# Build pipeline with stages, this pipeline will just train so it will have one stage
pipeline = Pipeline(
    stages=[train_stage],
    name="synthetic_classification_pipeline",
)

#Run pipeline
ctx = pipeline.run(Context())


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/jack/miniconda3/envs/pioneerml-uv/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | model   | GroupClassifier   | 465 K  | train
1 | loss_fn | BCEWithLogitsLoss | 0      | train
------------------------------------------------------
465 K     Trainable params
0         Non-trainable params
465 K     Total params
1.862     Total estimated model params size (MB)
47        Modules in train mode
0         Modules in eval mode
[1/1] Stage 'train_synthetic_classifier' failed: Target size (torch.Size([64, 3])) must be the same as input size (torc

ValueError: Target size (torch.Size([64, 3])) must be the same as input size (torch.Size([32, 3]))

In [None]:
#Get info from pipeline as context got filled out after training
print("\n=== PIPELINE CONTEXT SUMMARY ===")
print(ctx.summary())

print("\n=== TRAINING METRICS ===")
print(ctx.get("metrics", {}))

In [None]:
plot_loss_curves(
    lightning_module.train_epoch_loss_history,
    lightning_module.val_epoch_loss_history,
    title="Training vs Validation Loss",
    xlabel="Epoch",
    show=True,
)


In [None]:
preds = []
truths = []

trainer.model.eval()

for batch in dm.val_dataloader():
    batch = batch.to(device)
    with torch.no_grad():
        out = trainer.model(batch)
    preds.append(torch.sigmoid(out).cpu().numpy())
    truths.append(batch.y.cpu().numpy())

preds = np.vstack(preds)
truths = np.vstack(truths)

preds_binary = (preds > 0.5).astype(np.float32)


In [None]:
def extract_parameters(record: GraphRecord):
    coord = np.asarray(record.coord)
    energy = np.asarray(record.energy)
    return (
        energy.mean(),
        len(coord),
        coord.std(),
    )

params = np.array([extract_parameters(r) for r in dm.val_dataset.indices])
# Properly map val dataset indices back to original records
val_records = [records[i] for i in dm.val_dataset.indices]
params = np.array([extract_parameters(r) for r in val_records])


In [None]:
class_names = ["energy_high", "nhits_high", "spread_high"]
thresholds = [1.0, 20, 1.0]

fig, axes = plt.subplots(3, 1, figsize=(12, 12), sharex=True)

for i, ax in enumerate(axes):
    y_param = params[:, i]
    pred = preds_binary[:, i]
    truth = truths[:, i]

    colors = ["red" if p == 0 else "green" for p in pred]
    ax.scatter(np.arange(len(y_param)), y_param, c=colors, alpha=0.7)

    ax.axhline(thresholds[i], color="black", linestyle="--", label="truth threshold")

    wrong = pred != truth
    ax.scatter(np.where(wrong)[0], y_param[wrong], marker="x", s=80, c="black")

    ax.set_ylabel(class_names[i])

axes[-1].set_xlabel("sample index")
plt.tight_layout()
plt.show()
