
Implement opt hparams per stream #2029

Open
sophie-xhonneux wants to merge 3 commits into develop from sophiex/dev/per-embed-lr

Conversation

@sophie-xhonneux
Contributor

Description

See issue and commit messages

Issue Number

Closes #2028

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@github-actions Bot added the "model" label (Related to model training or definition (not generic infra)) on Mar 14, 2026
:return: List of param group dicts for torch.optim.AdamW.
"""
# unwrap DDP if necessary
raw_model = model.module if hasattr(model, "module") else model
Collaborator

Isn't there a better way to detect this? In the trainer we should also retain a handle to the original model (I don't think we do that now).
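A minimal sketch of the suggestion, with hypothetical names (`Trainer`, `raw_model`, and the `wrap` callable standing in for DDP wrapping are illustrative, not the repo's actual API): keep a handle to the unwrapped model at construction time so downstream code never needs `hasattr(model, "module")` duck typing.

```python
class Trainer:
    """Sketch only: `wrap` stands in for torch's DistributedDataParallel."""

    def __init__(self, model, wrap=None):
        # Retain the unwrapped model before any DDP-style wrapping,
        # so param-group construction can use self.raw_model directly.
        self.raw_model = model
        self.model = wrap(model) if wrap is not None else model
```

With this handle in place, `build_param_groups` can take `trainer.raw_model` and the `hasattr` check disappears entirely.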


default_wd = optimizer_cfg.weight_decay
stream_param_ids: set[int] = set()
groups: list[dict] = []
Collaborator

Is there a reason this is not a dict of dicts, with the outer one having the stream name as the key?
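One way the dict-of-dicts could look; all names here are hypothetical, and the `embeds.<stream>.` parameter-name prefix is an assumption about the model's naming scheme:

```python
def build_param_groups(named_params, stream_cfgs, default_lr, default_wd):
    """Sketch: group parameters per stream, keyed by stream name, then
    flatten to the list of dicts that torch.optim.AdamW expects."""
    groups = {
        name: {
            "name": name,
            "params": [],
            "lr": cfg.get("lr", default_lr),
            "weight_decay": cfg.get("weight_decay", default_wd),
        }
        for name, cfg in stream_cfgs.items()
    }
    # catch-all group for everything not assigned to a stream
    groups["shared"] = {
        "name": "shared", "params": [],
        "lr": default_lr, "weight_decay": default_wd,
    }
    for pname, p in named_params:
        stream = next(
            (s for s in stream_cfgs if pname.startswith(f"embeds.{s}.")),
            "shared",
        )
        groups[stream]["params"].append(p)
    return list(groups.values())
```

Keying by stream name keeps lookups and logging straightforward, and the final flatten is a single `list(groups.values())`.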

)

# shared group: everything not assigned to a stream
shared_params = [
Collaborator

It would be more natural to drop the stream parameters above rather than working with id().
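A sketch of that alternative: partition parameters by name prefix in a single pass, so the shared group is built by exclusion directly and no `id()` bookkeeping set is needed (the prefix convention and function name are assumptions for illustration):

```python
def split_params(named_params, stream_prefixes):
    """Sketch: partition (name, param) pairs into per-stream lists and a
    shared remainder, without tracking parameter ids."""
    stream_params = {prefix: [] for prefix in stream_prefixes}
    shared_params = []
    for name, param in named_params:
        for prefix in stream_prefixes:
            if name.startswith(prefix):
                stream_params[prefix].append(param)
                break
        else:
            # no stream prefix matched: parameter belongs to the shared group
            shared_params.append(param)
    return stream_params, shared_params
```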

if is_root():
for g in groups:
logger.info(
f"Param group '{g['name']}': {len(g['params'])} params, "
Collaborator

This should be prefaced with "Optimizer parameters" or something similar.

self.model, stream_optimizer_cfgs, self.training_cfg.optimizer
)
lr_start = self.training_cfg.learning_rate_scheduling.lr_start
for g in param_groups:
Collaborator

This should be done in build_param_groups().
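A sketch of folding the `lr_start` assignment into group construction rather than looping over the groups at the call site; the helper name `finalize_param_groups` and the per-group `lr_scale` field are hypothetical:

```python
def finalize_param_groups(groups, lr_start):
    """Sketch: set the scheduler's starting lr on each group at build time,
    so the caller receives groups ready to hand to the optimizer."""
    for g in groups:
        # lr_scale is an assumed optional per-stream field;
        # the default of 1.0 leaves the shared group at lr_start.
        g["lr"] = lr_start * g.get("lr_scale", 1.0)
    return groups
```

Calling this at the end of `build_param_groups()` would remove the post-processing loop from the trainer.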

@github-actions Bot added the "eval" label (anything related to the model evaluation pipeline) on Mar 16, 2026

import weathergen.common.config as config
from weathergen.train.utils import TRAIN
from weathergen.evaluate.plotting.plot_utils import create_filename
Collaborator

If we want to have the changes in this file, then let's please open a separate PR that also moves create_filename to packages/common/src/weathergen/common/paths.py.

and col_split[3] == channel
and int(col_split[4]) in forecast_steps
):
if col == stream_name.lower():
Collaborator

It seems this removes some branches from the current code; this doesn't look functionally equivalent.

nargs="+",
help="List of channels to plot",
)
parser.add_argument(
Collaborator

Where is this option?

nargs="+",
help="List of metrics (e.g. mse) to plot",
)
parser.add_argument(
Collaborator

Where is this option?

clean_plot_folder(out_dir)

# collect all physical streams from all run_ids if requested
if "all" in streams:
Collaborator

Where is this option?


Labels

  • eval: anything related to the model evaluation pipeline
  • model: Related to model training or definition (not generic infra)

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

Learning rate per stream

2 participants