Merge pull request #730 from OscarBarreraGithub/main
Updated Examples to GraphNetDataModule
OscarBarreraGithub authored Aug 6, 2024
2 parents 652f194 + 85a81ed commit 6e78d88
Showing 6 changed files with 150 additions and 88 deletions.
**`examples/04_training/01_train_dynedge.py`** (28 additions, 14 deletions)

```diff
@@ -16,9 +16,11 @@
 from graphnet.models.task.reconstruction import EnergyReconstruction
 from graphnet.training.callbacks import PiecewiseLinearLR
 from graphnet.training.loss_functions import LogCoshLoss
-from graphnet.training.utils import make_train_validation_dataloader
 from graphnet.utilities.argparse import ArgumentParser
 from graphnet.utilities.logging import Logger
+from graphnet.data import GraphNeTDataModule
+from graphnet.data.dataset import SQLiteDataset
+from graphnet.data.dataset import ParquetDataset
 
 # Constants
 features = FEATURES.PROMETHEUS
@@ -68,6 +70,9 @@ def main(
             "gpus": gpus,
             "max_epochs": max_epochs,
         },
+        "dataset_reference": SQLiteDataset
+        if path.endswith(".db")
+        else ParquetDataset,
     }
 
     archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs")
@@ -79,21 +84,30 @@
     # Define graph representation
     graph_definition = KNNGraph(detector=Prometheus())
 
-    (
-        training_dataloader,
-        validation_dataloader,
-    ) = make_train_validation_dataloader(
-        db=config["path"],
-        graph_definition=graph_definition,
-        pulsemaps=config["pulsemap"],
-        features=features,
-        truth=truth,
-        batch_size=config["batch_size"],
-        num_workers=config["num_workers"],
-        truth_table=truth_table,
-        selection=None,
+    # Use GraphNetDataModule to load in data
+    dm = GraphNeTDataModule(
+        dataset_reference=config["dataset_reference"],
+        dataset_args={
+            "truth": truth,
+            "truth_table": truth_table,
+            "features": features,
+            "graph_definition": graph_definition,
+            "pulsemaps": [config["pulsemap"]],
+            "path": config["path"],
+        },
+        train_dataloader_kwargs={
+            "batch_size": config["batch_size"],
+            "num_workers": config["num_workers"],
+        },
+        test_dataloader_kwargs={
+            "batch_size": config["batch_size"],
+            "num_workers": config["num_workers"],
+        },
     )
+
+    training_dataloader = dm.train_dataloader
+    validation_dataloader = dm.val_dataloader
 
     # Building model
 
     backbone = DynEdge(
```
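Across the four updated example scripts the change is identical: the `make_train_validation_dataloader` helper gives way to a `GraphNeTDataModule` that owns the dataset and both dataloaders. Below is a minimal sketch of the new pattern assembled from the diff; the import paths for `KNNGraph`, `Prometheus`, `FEATURES`, and `TRUTH` are not shown in this commit, and the database path, pulsemap name, and truth-table name are placeholder assumptions, not values the commit prescribes.

```python
from graphnet.data import GraphNeTDataModule
from graphnet.data.constants import FEATURES, TRUTH
from graphnet.data.dataset import ParquetDataset, SQLiteDataset
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.graphs import KNNGraph

path = "/path/to/prometheus_events.db"  # placeholder path

# Pick the Dataset implementation from the file extension, as the diff does.
dataset_reference = SQLiteDataset if path.endswith(".db") else ParquetDataset

# Graph representation shared by the dataset and, later, the model.
graph_definition = KNNGraph(detector=Prometheus())

dm = GraphNeTDataModule(
    dataset_reference=dataset_reference,
    dataset_args={
        "truth": TRUTH.PROMETHEUS,
        "truth_table": "mc_truth",  # assumed truth-table name
        "features": FEATURES.PROMETHEUS,
        "graph_definition": graph_definition,
        "pulsemaps": ["total"],  # assumed pulsemap name
        "path": path,
    },
    train_dataloader_kwargs={"batch_size": 16, "num_workers": 2},
    test_dataloader_kwargs={"batch_size": 16, "num_workers": 2},
)

# The updated examples read these as attributes, not method calls.
training_dataloader = dm.train_dataloader
validation_dataloader = dm.val_dataloader
```

Centralising dataloader construction in one Lightning-style object keeps the train/validation split and the dataloader settings together, instead of scattering them across a per-script helper call.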
**`examples/04_training/02_train_tito_model.py`** (34 additions, 19 deletions)

```diff
@@ -18,9 +18,11 @@
 )
 from graphnet.training.labels import Direction
 from graphnet.training.loss_functions import VonMisesFisher3DLoss
-from graphnet.training.utils import make_train_validation_dataloader
 from graphnet.utilities.argparse import ArgumentParser
 from graphnet.utilities.logging import Logger
+from graphnet.data import GraphNeTDataModule
+from graphnet.data.dataset import SQLiteDataset
+from graphnet.data.dataset import ParquetDataset
 
 # Constants
 features = FEATURES.PROMETHEUS
@@ -70,6 +72,9 @@ def main(
             "gpus": gpus,
             "max_epochs": max_epochs,
         },
+        "dataset_reference": SQLiteDataset
+        if path.endswith(".db")
+        else ParquetDataset,
     }
 
     graph_definition = KNNGraph(detector=Prometheus())
@@ -79,27 +84,37 @@
     # Log configuration to W&B
     wandb_logger.experiment.config.update(config)
 
-    (
-        training_dataloader,
-        validation_dataloader,
-    ) = make_train_validation_dataloader(
-        db=config["path"],
-        graph_definition=graph_definition,
-        selection=None,
-        pulsemaps=config["pulsemap"],
-        features=features,
-        truth=truth,
-        batch_size=config["batch_size"],
-        num_workers=config["num_workers"],
-        truth_table=truth_table,
-        index_column="event_no",
-        labels={
-            "direction": Direction(
-                azimuth_key="injection_azimuth", zenith_key="injection_zenith"
-            )
+    # Use GraphNetDataModule to load in data
+    dm = GraphNeTDataModule(
+        dataset_reference=config["dataset_reference"],
+        dataset_args={
+            "truth": truth,
+            "truth_table": truth_table,
+            "features": features,
+            "graph_definition": graph_definition,
+            "pulsemaps": [config["pulsemap"]],
+            "path": config["path"],
+            "index_column": "event_no",
+            "labels": {
+                "direction": Direction(
+                    azimuth_key="injection_azimuth",
+                    zenith_key="injection_zenith",
+                )
+            },
         },
+        train_dataloader_kwargs={
+            "batch_size": config["batch_size"],
+            "num_workers": config["num_workers"],
+        },
+        test_dataloader_kwargs={
+            "batch_size": config["batch_size"],
+            "num_workers": config["num_workers"],
+        },
     )
+
+    training_dataloader = dm.train_dataloader
+    validation_dataloader = dm.val_dataloader
 
     # Building model
     backbone = DynEdgeTITO(
         nb_inputs=graph_definition.nb_outputs,
```
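In the TITO example (and the two scripts that follow), the `Direction` truth label moves from a keyword of the removed helper into `dataset_args["labels"]`, which the data module forwards to the underlying `Dataset`; the `labels` parameter added to `ParquetDataset` at the bottom of this commit exists to receive it. A sketch of that label dictionary on its own, with the azimuth/zenith key names taken from the Prometheus truth table used in the examples:

```python
from graphnet.training.labels import Direction

# Each entry maps a label name to a callable that derives the label from an
# event's truth values, here a direction vector from the two angles.
labels = {
    "direction": Direction(
        azimuth_key="injection_azimuth",
        zenith_key="injection_zenith",
    )
}
```

Defining the label once in `dataset_args` means every split produced by the data module derives it identically.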
**`examples/04_training/05_train_RNN_TITO.py`** (34 additions, 19 deletions)

```diff
@@ -22,9 +22,11 @@
 )
 from graphnet.training.labels import Direction
 from graphnet.training.loss_functions import VonMisesFisher3DLoss
-from graphnet.training.utils import make_train_validation_dataloader
 from graphnet.utilities.argparse import ArgumentParser
 from graphnet.utilities.logging import Logger
+from graphnet.data import GraphNeTDataModule
+from graphnet.data.dataset import SQLiteDataset
+from graphnet.data.dataset import ParquetDataset
 
 # Constants
 features = FEATURES.PROMETHEUS
@@ -74,6 +76,9 @@ def main(
             "gpus": gpus,
             "max_epochs": max_epochs,
         },
+        "dataset_reference": SQLiteDataset
+        if path.endswith(".db")
+        else ParquetDataset,
     }
 
     graph_definition = KNNGraph(
@@ -91,27 +96,37 @@
     # Log configuration to W&B
     wandb_logger.experiment.config.update(config)
 
-    (
-        training_dataloader,
-        validation_dataloader,
-    ) = make_train_validation_dataloader(
-        db=config["path"],
-        graph_definition=graph_definition,
-        selection=None,
-        pulsemaps=config["pulsemap"],
-        features=features,
-        truth=truth,
-        batch_size=config["batch_size"],
-        num_workers=config["num_workers"],
-        truth_table=truth_table,
-        index_column="event_no",
-        labels={
-            "direction": Direction(
-                azimuth_key="injection_azimuth", zenith_key="injection_zenith"
-            )
+    # Use GraphNetDataModule to load in data
+    dm = GraphNeTDataModule(
+        dataset_reference=config["dataset_reference"],
+        dataset_args={
+            "truth": truth,
+            "truth_table": truth_table,
+            "features": features,
+            "graph_definition": graph_definition,
+            "pulsemaps": [config["pulsemap"]],
+            "path": config["path"],
+            "index_column": "event_no",
+            "labels": {
+                "direction": Direction(
+                    azimuth_key="injection_azimuth",
+                    zenith_key="injection_zenith",
+                )
+            },
         },
+        train_dataloader_kwargs={
+            "batch_size": config["batch_size"],
+            "num_workers": config["num_workers"],
+        },
+        test_dataloader_kwargs={
+            "batch_size": config["batch_size"],
+            "num_workers": config["num_workers"],
+        },
     )
+
+    training_dataloader = dm.train_dataloader
+    validation_dataloader = dm.val_dataloader
 
     # Building model
     backbone = RNN_TITO(
         nb_inputs=graph_definition.nb_outputs,
```
**`examples/04_training/06_train_icemix_model.py`** (34 additions, 19 deletions)

```diff
@@ -23,9 +23,11 @@
 )
 from graphnet.training.labels import Direction
 from graphnet.training.loss_functions import VonMisesFisher3DLoss
-from graphnet.training.utils import make_train_validation_dataloader
 from graphnet.utilities.argparse import ArgumentParser
 from graphnet.utilities.logging import Logger
+from graphnet.data import GraphNeTDataModule
+from graphnet.data.dataset import SQLiteDataset
+from graphnet.data.dataset import ParquetDataset
 
 # Constants
 features = FEATURES.PROMETHEUS
@@ -76,6 +78,9 @@ def main(
             "max_epochs": max_epochs,
             "distribution_strategy": "ddp_find_unused_parameters_true",
         },
+        "dataset_reference": SQLiteDataset
+        if path.endswith(".db")
+        else ParquetDataset,
     }
 
     graph_definition = KNNGraph(
@@ -96,27 +101,37 @@
     # Log configuration to W&B
     wandb_logger.experiment.config.update(config)
 
-    (
-        training_dataloader,
-        validation_dataloader,
-    ) = make_train_validation_dataloader(
-        db=config["path"],
-        graph_definition=graph_definition,
-        selection=None,
-        pulsemaps=config["pulsemap"],
-        features=features,
-        truth=truth,
-        batch_size=config["batch_size"],
-        num_workers=config["num_workers"],
-        truth_table=truth_table,
-        index_column="event_no",
-        labels={
-            "direction": Direction(
-                azimuth_key="injection_azimuth", zenith_key="injection_zenith"
-            )
+    # Use GraphNetDataModule to load in data
+    dm = GraphNeTDataModule(
+        dataset_reference=config["dataset_reference"],
+        dataset_args={
+            "truth": truth,
+            "truth_table": truth_table,
+            "features": features,
+            "graph_definition": graph_definition,
+            "pulsemaps": [config["pulsemap"]],
+            "path": config["path"],
+            "index_column": "event_no",
+            "labels": {
+                "direction": Direction(
+                    azimuth_key="injection_azimuth",
+                    zenith_key="injection_zenith",
+                )
+            },
         },
+        train_dataloader_kwargs={
+            "batch_size": config["batch_size"],
+            "num_workers": config["num_workers"],
+        },
+        test_dataloader_kwargs={
+            "batch_size": config["batch_size"],
+            "num_workers": config["num_workers"],
+        },
     )
+
+    training_dataloader = dm.train_dataloader
+    validation_dataloader = dm.val_dataloader
 
     # Building model
     backbone = DeepIce(
         hidden_dim=768,
```
**`examples/04_training/README.md`** (17 additions, 17 deletions)

````diff
@@ -2,7 +2,7 @@
 
 This subfolder contains two main training scripts:
 
-**`01_train_dynedge.py`** ** Shows how to train a GNN on neutrino telescope data **without configuration files,** i.e., by programatically constructing the dataset and model used. This is good for debugging and experimenting with different dataset settings and model configurations, as it is easier to build the model using the API than by writing configuration files from scratch. **This is our recommended way of getting started with the library**. For instance, try running:
+**`01_train_dynedge.py`** Shows how to train a GNN on neutrino telescope data **without configuration files,** i.e., by programmatically constructing the dataset and model used. This is good for debugging and experimenting with different dataset settings and model configurations, as it is easier to build the model using the API than by writing configuration files from scratch. **This is our recommended way of getting started with the library**. For instance, try running:
 
 ```bash
 # Show the CLI
@@ -11,35 +11,35 @@ This subfolder contains two main training scripts:
 # Train energy regression model
 (graphnet) $ python examples/04_training/01_train_dynedge.py
 
-# Same as above, as this is the default model config.
-(graphnet) $ python examples/04_training/01_train_model.py \
-    --model-config configs/models/example_energy_reconstruction_model.yml
-
 # Train using a single GPU
 (graphnet) $ python examples/04_training/01_train_dynedge.py --gpus 0
 
 # Train using multiple GPUs
 (graphnet) $ python examples/04_training/01_train_dynedge.py --gpus 0 1
+```
 
+**`03_train_model_dynedge_from_config.py`** Shows how to train a GNN on neutrino telescope data **using configuration files** to construct the dataset that loads the data and the model that is trained. This is the recommended way to configure standard datasets and models, as it is easier to read and share than doing so in pure code. This example can be run using a few different models targeting different physics use cases. For instance, you can try running:
+
+```bash
+# Show the CLI
+(graphnet) $ python examples/04_training/03_train_dynedge_from_config.py --help
+
+# Train energy regression model
+(graphnet) $ python examples/04_training/03_train_dynedge_from_config.py
+
+# Same as above, as this is the default model config.
+(graphnet) $ python examples/04_training/03_train_dynedge_from_config.py \
+    --model-config configs/models/example_energy_reconstruction_model.yml
+
 # Train a vertex position reconstruction model
-(graphnet) $ python examples/04_training/01_train_dynedge.py \
+(graphnet) $ python examples/04_training/03_train_dynedge_from_config.py \
     --model-config configs/models/example_vertex_position_reconstruction_model.yml
 
 # Trains a direction (zenith, azimuth) reconstruction model. Note that the
 # chosen `Task` in the model config file also returns estimated "kappa" values,
 # i.e. inverse variance, for each predicted feature, meaning that we need to
 # manually specify the names of these.
-(graphnet) $ python examples/04_training/01_train_model_dynedge.py --gpus 0 \
+(graphnet) $ python examples/04_training/03_train_dynedge_from_config.py --gpus 0 \
     --model-config configs/models/example_direction_reconstruction_model.yml \
     --prediction-names zenith_pred zenith_kappa_pred azimuth_pred azimuth_kappa_pred
 ```
-
-**`03_train_model_dynedge_from_config.py** Shows how to train a GNN on neutrino telescope data **using configuration files** to construct the dataset that loads the data and the model that is trained. This is the recommended way to configure standard dataset and models, as it is easier to ready and share than doing so in pure code. This example can be run using a few different models targeting different physics use cases. For instance, you can try running:
-
-```bash
-# Show the CLI
-(graphnet) $ python examples/04_training/02_train_dynedge_from_config.py --help
-
-# Train energy regression model
-(graphnet) $ python examples/04_training/02_train_dynedge_from_config.py
-```
````
**`src/graphnet/data/dataset/parquet/parquet_dataset.py`** (3 additions, 0 deletions)

```diff
@@ -54,6 +54,7 @@ def __init__(
         loss_weight_default_value: Optional[float] = None,
         seed: Optional[int] = None,
         cache_size: int = 1,
+        labels: Optional[Dict[str, Any]] = None,
     ):
         """Construct Dataset.
@@ -102,6 +103,7 @@ def __init__(
             graph_definition: Method that defines the graph representation.
             cache_size: Number of batches to cache in memory.
                 Must be at least 1. Defaults to 1.
+            labels: Dictionary of labels to be added to the dataset.
         """
         self._validate_selection(selection)
         # Base class constructor
@@ -122,6 +124,7 @@
             loss_weight_default_value=loss_weight_default_value,
             seed=seed,
             graph_definition=graph_definition,
+            labels=labels,
         )
 
         # mypy..
```
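This small change is what lets `dataset_args["labels"]` from the example scripts reach a Parquet-backed dataset: the keyword is now accepted by `ParquetDataset` and forwarded to the base `Dataset` constructor. A sketch of constructing the dataset directly follows; the path and table names are placeholders, the remaining argument names follow the diff context and the example scripts, and the full signature includes further options not shown here.

```python
from graphnet.data.constants import FEATURES, TRUTH
from graphnet.data.dataset import ParquetDataset
from graphnet.models.detector.prometheus import Prometheus
from graphnet.models.graphs import KNNGraph
from graphnet.training.labels import Direction

# Placeholder path and table names; point these at your own converted data.
dataset = ParquetDataset(
    path="/path/to/prometheus_events.parquet",
    pulsemaps=["total"],
    features=FEATURES.PROMETHEUS,
    truth=TRUTH.PROMETHEUS,
    truth_table="mc_truth",
    graph_definition=KNNGraph(detector=Prometheus()),
    # New in this commit: labels can be attached at construction time.
    labels={
        "direction": Direction(
            azimuth_key="injection_azimuth",
            zenith_key="injection_zenith",
        )
    },
)
```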
