Simplify examples #456

Merged
merged 53 commits on Apr 18, 2023
Changes from 42 commits

Commits (53)
bb78249
self contained output and target names
RasmusOrsoe Mar 18, 2023
35b7b08
polish
RasmusOrsoe Mar 18, 2023
045ca1a
added explicit, default target and output labels
RasmusOrsoe Mar 19, 2023
4e2aca4
polish
RasmusOrsoe Mar 19, 2023
682056d
shell script rename
RasmusOrsoe Mar 19, 2023
561ddb0
renamed files and updated readme
RasmusOrsoe Mar 19, 2023
4a089e4
further simplification
RasmusOrsoe Mar 19, 2023
8cc2426
removed print message
RasmusOrsoe Mar 19, 2023
f6652b1
removed changes to example scripts
RasmusOrsoe Mar 26, 2023
332c586
made wandb optional
RasmusOrsoe Mar 26, 2023
ffe7fb5
made wandb optional
RasmusOrsoe Mar 26, 2023
d9b25d7
Update examples/04_training/01_train_model.py
RasmusOrsoe Mar 28, 2023
23c162c
Update examples/04_training/01_train_model.py
RasmusOrsoe Mar 28, 2023
21618db
Update examples/04_training/01_train_model.py
RasmusOrsoe Mar 28, 2023
e902422
Update examples/04_training/02_train_model_without_configs.py
RasmusOrsoe Mar 28, 2023
711612c
Update examples/04_training/02_train_model_without_configs.py
RasmusOrsoe Mar 28, 2023
2daec4d
Update examples/04_training/01_train_model.py
RasmusOrsoe Mar 28, 2023
b18b012
Update examples/04_training/02_train_model_without_configs.py
RasmusOrsoe Mar 28, 2023
7c9011d
Update src/graphnet/models/task/task.py
RasmusOrsoe Mar 28, 2023
53b0abd
Update src/graphnet/models/task/task.py
RasmusOrsoe Mar 28, 2023
bd3e41f
Update src/graphnet/models/task/task.py
RasmusOrsoe Mar 28, 2023
a944453
Update src/graphnet/models/task/task.py
RasmusOrsoe Mar 28, 2023
f4a1650
Update src/graphnet/models/task/task.py
RasmusOrsoe Mar 28, 2023
2002d4a
Update src/graphnet/models/model.py
RasmusOrsoe Mar 28, 2023
bf6b062
Update src/graphnet/models/standard_model.py
RasmusOrsoe Mar 28, 2023
6f0cf5a
Update src/graphnet/models/task/reconstruction.py
RasmusOrsoe Mar 28, 2023
a8caefb
Update src/graphnet/models/task/reconstruction.py
RasmusOrsoe Mar 28, 2023
8701432
Update src/graphnet/models/task/reconstruction.py
RasmusOrsoe Mar 28, 2023
101f051
Update src/graphnet/models/task/reconstruction.py
RasmusOrsoe Mar 28, 2023
c609e76
Update src/graphnet/models/task/reconstruction.py
RasmusOrsoe Mar 28, 2023
396f039
Update src/graphnet/models/task/reconstruction.py
RasmusOrsoe Mar 28, 2023
d2a71ef
Update src/graphnet/models/task/reconstruction.py
RasmusOrsoe Mar 28, 2023
8415bea
Update src/graphnet/models/task/reconstruction.py
RasmusOrsoe Mar 28, 2023
a445c81
Update src/graphnet/models/task/reconstruction.py
RasmusOrsoe Mar 28, 2023
3f86536
Update src/graphnet/models/task/reconstruction.py
RasmusOrsoe Mar 28, 2023
b88fd78
Update src/graphnet/models/task/reconstruction.py
RasmusOrsoe Mar 28, 2023
43e8bbe
Update src/graphnet/models/task/reconstruction.py
RasmusOrsoe Mar 28, 2023
e0e7a52
refractor default callbacks
RasmusOrsoe Mar 28, 2023
8285c77
prediction columns
RasmusOrsoe Mar 28, 2023
dd5b195
le docce strings
RasmusOrsoe Mar 28, 2023
eb31074
mypy....
RasmusOrsoe Mar 28, 2023
2cb7dc7
mypy..
RasmusOrsoe Mar 28, 2023
4d446b7
added warning for patience & early stopping
RasmusOrsoe Apr 13, 2023
06e76bf
typo fix
RasmusOrsoe Apr 13, 2023
fe8b547
removed try except in predict_as_dataframe
RasmusOrsoe Apr 13, 2023
b952791
added predict_as_dataframe to standardmodel
RasmusOrsoe Apr 13, 2023
0b99a38
added predict_as_dataframe to standardmodel
RasmusOrsoe Apr 13, 2023
38d0e96
prediction labels in task
RasmusOrsoe Apr 13, 2023
6e3f0c4
Update src/graphnet/models/task/task.py
RasmusOrsoe Apr 18, 2023
9a6d3ea
Update src/graphnet/models/model.py
RasmusOrsoe Apr 18, 2023
0ef572f
black
RasmusOrsoe Apr 18, 2023
f45c527
flake8
RasmusOrsoe Apr 18, 2023
70a9e22
cmon
RasmusOrsoe Apr 18, 2023
33 changes: 20 additions & 13 deletions examples/04_training/01_train_model.py
@@ -19,11 +19,6 @@
 from graphnet.utilities.logging import Logger
 
 
-# Make sure W&B output directory exists
-WANDB_DIR = "./wandb/"
-os.makedirs(WANDB_DIR, exist_ok=True)
-
-
 def main(
     dataset_config_path: str,
     model_config_path: str,
@@ -34,18 +29,23 @@ def main(
     num_workers: int,
     prediction_names: Optional[List[str]],
     suffix: Optional[str] = None,
+    wandb: bool = False,
 ) -> None:
     """Run example."""
     # Construct Logger
     logger = Logger()
 
     # Initialise Weights & Biases (W&B) run
-    wandb_logger = WandbLogger(
-        project="example-script",
-        entity="graphnet-team",
-        save_dir=WANDB_DIR,
-        log_model=True,
-    )
+    if wandb:
+        # Make sure W&B output directory exists
+        wandb_dir = "./wandb/"
+        os.makedirs(wandb_dir, exist_ok=True)
+        wandb_logger = WandbLogger(
+            project="example-script",
+            entity="graphnet-team",
+            save_dir=wandb_dir,
+            log_model=True,
+        )
 
     # Build model
     model_config = ModelConfig.load(model_config_path)
@@ -80,7 +80,7 @@ def main(
     # Log configurations to W&B
     # NB: Only log to W&B on the rank-zero process in case of multi-GPU
     # training.
-    if rank_zero_only.rank == 0:
+    if wandb and rank_zero_only.rank == 0:
         wandb_logger.experiment.config.update(config)
         wandb_logger.experiment.config.update(model_config.as_dict())
         wandb_logger.experiment.config.update(dataset_config.as_dict())
@@ -98,7 +98,7 @@ def main(
         dataloaders["train"],
         dataloaders["validation"],
         callbacks=callbacks,
-        logger=wandb_logger,
+        logger=wandb_logger if wandb else None,
         **config.fit,
     )
 
@@ -166,6 +166,12 @@ def main(
         default=None,
     )
 
+    parser.add_argument(
+        "--wandb",
+        action="store_true",
+        help="If True, Weights & Biases are used to track the experiment.",
+    )
+
     args = parser.parse_args()
 
     main(
@@ -178,4 +184,5 @@ def main(
         args.num_workers,
         args.prediction_names,
         args.suffix,
+        args.wandb,
     )
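
Note on the pattern above: when `wandb` is False, `wandb_logger` is never defined, so every later use must stay behind a guard. A minimal self-contained sketch of the same opt-in pattern (the `train` function is hypothetical; only `WandbLogger` and its arguments are taken from the diff). Defaulting the logger to None is one defensive variant that makes the later `logger=wandb_logger if wandb else None` expression unnecessary:

import os
from typing import Optional

from pytorch_lightning.loggers import WandbLogger


def train(wandb: bool = False) -> None:
    """Hypothetical trainer illustrating the opt-in W&B pattern."""
    # Defaulting to None avoids a NameError if a later use is not guarded.
    wandb_logger: Optional[WandbLogger] = None
    if wandb:
        # Only create the W&B output directory when tracking is requested,
        # so a plain run touches neither the filesystem nor the W&B API.
        wandb_dir = "./wandb/"
        os.makedirs(wandb_dir, exist_ok=True)
        wandb_logger = WandbLogger(
            project="example-script",
            entity="graphnet-team",
            save_dir=wandb_dir,
            log_model=True,
        )

    # Downstream, `logger=wandb_logger` (possibly None) is then equivalent
    # to the diff's `logger=wandb_logger if wandb else None`.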
41 changes: 24 additions & 17 deletions examples/04_training/02_train_model_without_configs.py
@@ -25,10 +25,6 @@
 features = FEATURES.PROMETHEUS
 truth = TRUTH.PROMETHEUS
 
-# Make sure W&B output directory exists
-WANDB_DIR = "./wandb/"
-os.makedirs(WANDB_DIR, exist_ok=True)
-
 
 def main(
     path: str,
@@ -40,18 +36,23 @@ def main(
     early_stopping_patience: int,
     batch_size: int,
     num_workers: int,
+    wandb: bool = False,
 ) -> None:
     """Run example."""
     # Construct Logger
     logger = Logger()
 
     # Initialise Weights & Biases (W&B) run
-    wandb_logger = WandbLogger(
-        project="example-script",
-        entity="graphnet-team",
-        save_dir=WANDB_DIR,
-        log_model=True,
-    )
+    if wandb:
+        # Make sure W&B output directory exists
+        wandb_dir = "./wandb/"
+        os.makedirs(wandb_dir, exist_ok=True)
+        wandb_logger = WandbLogger(
+            project="example-script",
+            entity="graphnet-team",
+            save_dir=wandb_dir,
+            log_model=True,
+        )
 
     logger.info(f"features: {features}")
     logger.info(f"truth: {truth}")
@@ -72,9 +73,9 @@
 
     archive = os.path.join(EXAMPLE_OUTPUT_DIR, "train_model_without_configs")
     run_name = "dynedge_{}_example".format(config["target"])
-
-    # Log configuration to W&B
-    wandb_logger.experiment.config.update(config)
+    if wandb:
+        # Log configuration to W&B
+        wandb_logger.experiment.config.update(config)
 
     (
         training_dataloader,
@@ -137,17 +138,16 @@ def main(
         training_dataloader,
         validation_dataloader,
         callbacks=callbacks,
-        logger=wandb_logger,
+        logger=wandb_logger if wandb else None,
         **config["fit"],
     )
 
     # Get predictions
-    prediction_columns = [config["target"] + "_pred"]
-    additional_attributes = [config["target"]]
+    additional_attributes = model.target_labels
+    assert isinstance(additional_attributes, list)  # mypy
 
     results = model.predict_as_dataframe(
         validation_dataloader,
-        prediction_columns=prediction_columns,
         additional_attributes=additional_attributes + ["event_no"],
     )
@@ -206,6 +206,12 @@ def main(
         "num-workers",
     )
 
+    parser.add_argument(
+        "--wandb",
+        action="store_true",
+        help="If True, Weights & Biases are used to track the experiment.",
+    )
+
     args = parser.parse_args()
 
     main(
@@ -218,4 +224,5 @@ def main(
         args.early_stopping_patience,
         args.batch_size,
         args.num_workers,
+        args.wandb,
     )
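
The prediction changes above remove the hard-coded column bookkeeping from the example: truth columns come from `model.target_labels`, and `prediction_columns` is no longer passed at all. A sketch of the resulting calling convention, assuming a trained `StandardModel` and a `DataLoader` whose dataset exposes an `event_no` column (both taken from the example's context):

from typing import List

import pandas as pd
from torch.utils.data import DataLoader

from graphnet.models import StandardModel


def predictions_dataframe(
    model: StandardModel, dataloader: DataLoader
) -> pd.DataFrame:
    """Sketch of the post-PR prediction call."""
    # Truth columns now come from the model's tasks instead of the config.
    additional_attributes: List[str] = model.target_labels

    # `prediction_columns` may be omitted; the model falls back to the
    # prediction labels defined by its tasks.
    return model.predict_as_dataframe(
        dataloader,
        additional_attributes=additional_attributes + ["event_no"],
    )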
41 changes: 40 additions & 1 deletion src/graphnet/models/model.py
@@ -10,6 +10,7 @@
 import pandas as pd
 from pytorch_lightning import Trainer, LightningModule
 from pytorch_lightning.callbacks.callback import Callback
+from pytorch_lightning.callbacks import EarlyStopping
 from pytorch_lightning.loggers.logger import Logger as LightningLogger
 import torch
 from torch import Tensor
@@ -18,6 +19,7 @@
 
 from graphnet.utilities.logging import Logger
 from graphnet.utilities.config import Configurable, ModelConfig
+from graphnet.training.callbacks import ProgressBar
 
 
 class Model(Logger, Configurable, LightningModule, ABC):
@@ -85,9 +87,17 @@ def fit(
         log_every_n_steps: int = 1,
         gradient_clip_val: Optional[float] = None,
         distribution_strategy: Optional[str] = "ddp",
+        early_stopping_patience: int = 5,
         **trainer_kwargs: Any,
     ) -> None:
         """Fit `Model` using `pytorch_lightning.Trainer`."""
+        # Checks
+        if callbacks is None:
+            callbacks = self._create_default_callbacks(
+                val_dataloader=val_dataloader,
+                early_stopping_patience=early_stopping_patience,
+            )
+
         self.train(mode=True)
 
         self._construct_trainers(
@@ -110,6 +120,26 @@ def fit(
             self.warning("[ctrl+c] Exiting gracefully.")
             pass
 
+    def _create_default_callbacks(
+        self, val_dataloader: DataLoader, early_stopping_patience: int
+    ) -> List:
+        callbacks = [ProgressBar()]
+        if val_dataloader is not None:
+            has_es = False
+            assert isinstance(callbacks, list)
+            for callback in callbacks:
+                if isinstance(callback, EarlyStopping):
+                    has_es = True
+            if has_es is False:
+                callbacks.append(
+                    EarlyStopping(
+                        monitor="val_loss",
+                        patience=early_stopping_patience,
+                    )
+                )
+                self.info("EarlyStopping callback added automatically.")
+        return callbacks
+
     def predict(
         self,
         dataloader: DataLoader,
@@ -146,7 +176,7 @@ def predict(
     def predict_as_dataframe(
         self,
         dataloader: DataLoader,
-        prediction_columns: List[str],
+        prediction_columns: Optional[List[str]] = None,
         *,
         node_level: bool = False,
         additional_attributes: Optional[List[str]] = None,
@@ -164,6 +194,14 @@ def predict_as_dataframe(
             additional_attributes = []
         assert isinstance(additional_attributes, list)
 
+        if prediction_columns is None:
+            try:
+                prediction_columns = self.prediction_columns
+            except AttributeError:
+                assert (
+                    1 == 2
+                ), "Could not infer prediction_columns from model. Please specify prediction_columns."
+        assert isinstance(prediction_columns, list)
         if (
             not isinstance(dataloader.sampler, SequentialSampler)
             and additional_attributes
@@ -178,6 +216,7 @@ def predict_as_dataframe(
                 "doesn't resample batches; or do not request "
                 "`additional_attributes`."
            )
+        self.info(f"Column names for predictions are: \n {prediction_columns}")
         predictions_torch = self.predict(
             dataloader=dataloader,
             gpus=gpus,
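
With the `fit` changes in place, the minimal training call no longer needs explicit callbacks: `fit` builds a `ProgressBar` plus an `EarlyStopping` callback itself whenever a validation dataloader is supplied. A sketch of the resulting call site, under the assumptions that `model` and the two dataloaders already exist and that `**trainer_kwargs` (here `max_epochs`) is forwarded to `pytorch_lightning.Trainer`:

from torch.utils.data import DataLoader

from graphnet.models import Model


def fit_with_defaults(
    model: Model,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
) -> None:
    """Sketch: rely on the default callbacks introduced in this PR."""
    # With callbacks=None, `fit` builds [ProgressBar(), EarlyStopping(...)].
    # EarlyStopping monitors "val_loss", so the validation step must log a
    # metric under that exact name for early stopping to take effect.
    model.fit(
        train_dataloader,
        val_dataloader,
        early_stopping_patience=5,  # new keyword added by this PR
        max_epochs=10,
    )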
10 changes: 10 additions & 0 deletions src/graphnet/models/standard_model.py
@@ -62,6 +62,16 @@ def __init__(
         self._scheduler_kwargs = scheduler_kwargs or dict()
         self._scheduler_config = scheduler_config or dict()
 
+    @property
+    def target_labels(self) -> List[str]:
+        """Return target labels."""
+        return [label for task in self._tasks for label in task._target_labels]
+
+    @property
+    def prediction_labels(self) -> List[str]:
+        """Return prediction labels."""
+        return [label for task in self._tasks for label in task._output_labels]
+
     def configure_optimizers(self) -> Dict[str, Any]:
         """Configure the model's optimizer(s)."""
         optimizer = self._optimizer_class(
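
The two new properties simply flatten the per-task labels into one list, in task order. A self-contained illustration with stand-in classes (both classes and all label values below are hypothetical; only the flattening logic mirrors the diff):

from typing import List


class FakeTask:
    """Stand-in for graphnet's Task, holding per-task labels."""

    def __init__(self, target_labels: List[str], output_labels: List[str]):
        self._target_labels = target_labels
        self._output_labels = output_labels


class FakeModel:
    """Stand-in mirroring StandardModel's label flattening."""

    def __init__(self, tasks: List[FakeTask]):
        self._tasks = tasks

    @property
    def target_labels(self) -> List[str]:
        """Return target labels, concatenated across tasks in order."""
        return [label for task in self._tasks for label in task._target_labels]

    @property
    def prediction_labels(self) -> List[str]:
        """Return prediction labels, concatenated across tasks in order."""
        return [label for task in self._tasks for label in task._output_labels]


model = FakeModel(
    [
        FakeTask(["energy"], ["energy_pred"]),
        FakeTask(["zenith"], ["zenith_pred", "zenith_kappa"]),
    ]
)
print(model.target_labels)      # ['energy', 'zenith']
print(model.prediction_labels)  # ['energy_pred', 'zenith_pred', 'zenith_kappa']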