Implement opt hparams per stream #2029
Conversation
| :return: List of param group dicts for torch.optim.AdamW. | ||
| """ | ||
| # unwrap DDP if necessary | ||
| raw_model = model.module if hasattr(model, "module") else model |
There was a problem hiding this comment.
Isn't there a better way to detect this? In the trainer we should also retain a handle to the original model (I don't think we do that now).
|
|
||
| default_wd = optimizer_cfg.weight_decay | ||
| stream_param_ids: set[int] = set() | ||
| groups: list[dict] = [] |
There was a problem hiding this comment.
Is there a reason this is not a dict of dicts, with the outer one keyed by stream name?
| ) | ||
|
|
||
| # shared group: everything not assigned to a stream | ||
| shared_params = [ |
There was a problem hiding this comment.
It would be more natural to drop the stream parameters above rather than working with id()
| if is_root(): | ||
| for g in groups: | ||
| logger.info( | ||
| f"Param group '{g['name']}': {len(g['params'])} params, " |
There was a problem hiding this comment.
This should be prefaced with "Optimizer parameters" or something similar.
| self.model, stream_optimizer_cfgs, self.training_cfg.optimizer | ||
| ) | ||
| lr_start = self.training_cfg.learning_rate_scheduling.lr_start | ||
| for g in param_groups: |
There was a problem hiding this comment.
This should be done in build_param_groups()
|
|
||
| import weathergen.common.config as config | ||
| from weathergen.train.utils import TRAIN | ||
| from weathergen.evaluate.plotting.plot_utils import create_filename |
There was a problem hiding this comment.
If we want to have the changes in this file, then please open a separate PR that also moves create_filename to packages/common/src/weathergen/common/paths.py
| and col_split[3] == channel | ||
| and int(col_split[4]) in forecast_steps | ||
| ): | ||
| if col == stream_name.lower(): |
There was a problem hiding this comment.
It seems this removes some branches from the current code? This doesn't seem functionally equivalent
| nargs="+", | ||
| help="List of channels to plot", | ||
| ) | ||
| parser.add_argument( |
| nargs="+", | ||
| help="List of metrics (e.g. mse) to plot", | ||
| ) | ||
| parser.add_argument( |
| clean_plot_folder(out_dir) | ||
|
|
||
| # collect all physical streams from all run_ids if requested | ||
| if "all" in streams: |
Description
See issue and commit messages
Issue Number
Closes #2028
Is this PR a draft? Mark it as draft.
Checklist before asking for review
- ./scripts/actions.sh lint
- ./scripts/actions.sh unit-test
- ./scripts/actions.sh integration-test
- launch-slurm.py --time 60