# Quickstart: Train and Evaluate a Graph Classifier
Train a GroupClassifier on synthetic data, then evaluate with standardized metrics and plots.

In [None]:

import torch
from torch_geometric.data import Data

from pioneerml.models import GroupClassifier
from pioneerml.training import GraphLightningModule, GraphDataModule, plot_loss_curves
from pioneerml.pipelines import Pipeline, Context, StageConfig
from pioneerml.pipelines.stages import LightningTrainStage, EvaluateStage, CollectPredsStage
from pioneerml.evaluation import (
    plot_multilabel_confusion_matrix,
    plot_precision_recall_curves,
    resolve_preds_targets,
)


In [None]:

def make_synthetic_group(num_nodes=16, num_classes=3):
    """Build one random graph whose node features are shifted by a class-specific offset.

    A class label is drawn uniformly from ``range(num_classes)``. Every node gets
    5 Gaussian features (std 1.2) plus a unit offset on the feature whose index
    equals the label, so classes overlap and the task is non-trivial. In this
    notebook classes 0, 1, 2 play the roles of pi, mu and e+.

    Args:
        num_nodes: Number of nodes in the graph; ``2 * num_nodes`` random
            directed edges are sampled between them.
        num_classes: Number of classes. Must be between 1 and 5 because each
            class needs its own feature column to be offset.

    Returns:
        A ``Data`` object with ``x`` of shape ``(num_nodes, 5)``, ``edge_index``
        of shape ``(2, 2 * num_nodes)``, ``edge_attr`` of shape
        ``(2 * num_nodes, 4)`` and a one-hot ``y`` of shape ``(num_classes,)``.

    Raises:
        ValueError: If ``num_classes`` is outside ``[1, 5]``.
    """
    num_features = 5
    if not 1 <= num_classes <= num_features:
        raise ValueError(
            f"num_classes must be between 1 and {num_features}, got {num_classes}"
        )

    # Row k is the unit vector along feature k. For num_classes <= 3 this equals
    # the previously hard-coded 3x5 table, but it no longer indexes out of range
    # when a caller asks for 4 or 5 classes.
    class_offsets = torch.eye(num_classes, num_features)

    # The order of the random draws below is kept as-is so that a seeded run
    # produces the same graphs as before.
    label = torch.randint(0, num_classes, (1,)).item()
    x = torch.randn(num_nodes, num_features) * 1.2 + class_offsets[label]
    edge_index = torch.randint(0, num_nodes, (2, num_nodes * 2))
    edge_attr = torch.randn(edge_index.shape[1], 4)
    y = torch.zeros(num_classes)
    y[label] = 1.0  # one-hot label
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

# Seed torch's global RNG before sampling: make_synthetic_group draws from it,
# so without a seed every Restart & Run All would train on different graphs.
SEED = 0
torch.manual_seed(SEED)

# 256 synthetic graphs; a quarter of them are held out for validation.
records = [make_synthetic_group() for _ in range(256)]
datamodule = GraphDataModule(dataset=records, val_split=0.25, batch_size=16)


In [None]:

# --- Model -------------------------------------------------------------------
model = GroupClassifier(num_classes=3)
lightning_module = GraphLightningModule(model, task='classification', lr=5e-4)

# --- Stage 1: train ----------------------------------------------------------
# Small budget (10 epochs, capped batch counts, no logger/checkpoints) so the
# quickstart finishes fast. NOTE(review): these keys are presumably forwarded
# to the Lightning Trainer -- confirm against LightningTrainStage.
trainer_params = {
    'max_epochs': 10,
    'limit_train_batches': 5,
    'limit_val_batches': 1,
    'logger': False,
    'enable_checkpointing': False,
}
train_config = StageConfig(
    name='train',
    params={
        'module': lightning_module,
        'datamodule': datamodule,
        'trainer_params': trainer_params,
    },
)
train_stage = LightningTrainStage(config=train_config)

# --- Stage 2: collect predictions on the validation loader -------------------
collect_config = StageConfig(
    name='collect_preds',
    inputs=['lightning_module', 'datamodule'],
    outputs=['preds', 'targets'],
    params={
        'dataloader': 'val',
        'preds_key': 'preds',
        'targets_key': 'targets',
    },
)
collect_stage = CollectPredsStage(config=collect_config)

# --- Stage 3: compute metrics and save plots ---------------------------------
eval_params = {
    'task': 'multilabel',
    'plots': ['multilabel_confusion', 'precision_recall'],
    'save_dir': 'outputs/tutorial_quickstart',
    'metric_params': {'class_names': ['pi', 'mu', 'e+']},
}
eval_config = StageConfig(
    name='evaluate',
    inputs=['preds', 'targets'],
    outputs=['metrics'],
    params=eval_params,
)
eval_stage = EvaluateStage(config=eval_config)

# --- Run the stages in order on a fresh context ------------------------------
stages = [train_stage, collect_stage, eval_stage]
pipeline = Pipeline(stages, name='quickstart')
ctx = pipeline.run(Context())

# Last expression of the cell: display the metrics produced by the evaluate stage.
ctx['metrics']


## Loss curves
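Both training and validation loss should trend down; a validation loss that rises while training loss keeps falling suggests overfitting, and a loss that grows or oscillates suggests a learning-rate issue.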

In [None]:

# Plot the loss history from the module stored in the pipeline context.
# The return value is bound to `_` so the cell shows only the figure.
loss_plot_options = {
    'title': 'Training/Validation Loss (quickstart)',
    'xlabel': 'Epoch',
    'show': True,
}
_ = plot_loss_curves(ctx['lightning_module'], **loss_plot_options)


## Metrics and plots
Precision–recall curves come from sweeping the decision threshold; confusion matrices show class-wise true/false positives/negatives (normalized to sum to 1).

In [None]:

# Pull the collected predictions and targets back out of the pipeline context.
preds, targets = resolve_preds_targets(ctx)

# Display names for the three classes, in label-index order.
class_names = ('pi', 'mu', 'e+')

# Per-class confusion matrices at a fixed 0.5 decision threshold.
plot_multilabel_confusion_matrix(
    predictions=preds,
    targets=targets,
    class_names=list(class_names),
    threshold=0.5,
    normalize=True,
    show=True,
)

# Precision-recall curves; kept as the cell's final expression so its return
# value is still what the cell displays.
plot_precision_recall_curves(
    predictions=preds,
    targets=targets,
    class_names=list(class_names),
    show=True,
)
