Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow checkpointing setup to be customized #26

Merged
merged 1 commit into from
Sep 1, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 14 additions & 5 deletions zoobot/pytorch/training/train_with_pytorch_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@ def train_default_zoobot_from_scratch(
mixed_precision=False,
# replication parameters
random_state=42,
wandb_logger=None
wandb_logger=None,
# checkpointing
checkpoint_file_template=None,
auto_insert_metric_name=True,
save_top_k=3
):

slurm_debugging_logs()
Expand Down Expand Up @@ -87,13 +91,13 @@ def train_default_zoobot_from_scratch(
if (gpus is not None) and (num_workers * gpus > os.cpu_count()):
logging.warning(
"""num_workers * gpu > num cpu.
You may be spawning more dataloader workers than you have cpus, causing bottlenecks.
You may be spawning more dataloader workers than you have cpus, causing bottlenecks.
Suggest reducing num_workers."""
)
if num_workers > os.cpu_count():
logging.warning(
"""num_workers > num cpu.
You may be spawning more dataloader workers than you have cpus, causing bottlenecks.
You may be spawning more dataloader workers than you have cpus, causing bottlenecks.
Suggest reducing num_workers."""
)

Expand Down Expand Up @@ -154,7 +158,12 @@ def loss_func(preds, labels): # pytorch convention is preds, labels
monitor="val/supervised_loss",
save_weights_only=True,
mode='min',
save_top_k=3
# optional custom filename template for checkpoints; useful because the monitored metric name contains a '/'
filename=checkpoint_file_template,
# https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint.params.auto_insert_metric_name
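# when True, the metric name is inserted into the checkpoint filename; a '/' in the metric name then creates extra folders, so pass False to avoid that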
auto_insert_metric_name=auto_insert_metric_name,
save_top_k=save_top_k
),
EarlyStopping(monitor='val/supervised_loss', patience=patience, check_finite=True)
]
Expand Down Expand Up @@ -202,7 +211,7 @@ def select_base_architecture_func_from_name(base_architecture):
else:
raise ValueError(
'Model architecture not recognised: got model={}, expected one of [efficientnet, resnet_detectron, resnet_torchvision]'.format(base_architecture))

return get_architecture,representation_dim


Expand Down