
Warm start and frozen teachers#1876

Merged
sophie-xhonneux merged 25 commits into develop from sophiex/dev/warm-and-frozen-teachers
Mar 27, 2026

Conversation

@sophie-xhonneux
Contributor

@sophie-xhonneux sophie-xhonneux commented Feb 18, 2026

Description

Allow for warm starts with EMA and frozen teachers.

Issue Number

Closes #1881

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@sophie-xhonneux sophie-xhonneux changed the title from Write first solution with Claude to Warm start and frozen teachers Feb 19, 2026
@github-actions github-actions Bot added the model (Related to model training or definition (not generic infra)) label Feb 19, 2026
Collaborator

@clessig clessig left a comment


Overall looks fine. I pushed some minor changes. config_jepa.yml has a 2D rope param but it's not in here. This should be removed (it was also one of the things that caused problems for me).

Comment thread src/weathergen/model/model_interface.py Outdated
Comment thread src/weathergen/train/target_and_aux_ssl_teacher.py Outdated
class FrozenTeacher(EncoderTeacher):
"""SSL teacher using a frozen pre-trained encoder.

The encoder is loaded from a checkpoint and never updated. Non-encoder
Collaborator


The teacher_model is assumed to have the non-encoder parts discarded, no?

Contributor Author


The code should do the discarding; the original model, as specified in the config associated with its run id, may have more than just an encoder.
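For illustration, a minimal sketch of that discarding step. forecast_engine appears in the diff below; the other attribute names here are placeholders, not the actual WeatherGenerator modules:

    import torch.nn as nn

    def strip_non_encoder(model: nn.Module) -> nn.Module:
        # Illustrative sketch: drop components that are not needed to produce
        # teacher latents, then freeze and switch to eval mode so the teacher
        # is never updated. Only forecast_engine comes from the diff below;
        # decoder and pred_heads are placeholder names.
        for name in ("forecast_engine", "decoder", "pred_heads"):
            if hasattr(model, name):
                setattr(model, name, None)
        model.requires_grad_(False)
        model.eval()
        return model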

self.teacher_model.eval()

@classmethod
def from_pretrained(cls, cf: Config, dataset, device, params: dict) -> FrozenTeacher:
Collaborator


This function is inconsistent with what is done for EMATeacher in model_interface. Either we have from_pretrained() for both classes or we have the functionality in model_interface.py.

Contributor Author


But they conceptually and functionally do different things, so I don't follow

Collaborator


Ok, can you then maybe briefly explain what the difference is for you between this and load_encoder_from_checkpoint()?

Contributor Author


Copied from a different reply

Because one applies to the meta teacher model, e.g. for EMA from scratch: in that case we have no params and we must do something like prepare_encoder_teacher does. When we load pre-trained models, in contrast, we can select the params we want.

Comment thread src/weathergen/train/teacher_utils.py Outdated
Comment thread src/weathergen/train/teacher_utils.py
3. Creates fresh latent_heads based on the student's SSL loss config
"""
# Strip non-encoder components
model.forecast_engine = None
Collaborator


Can we formulate it as "is not encoder" so that we are robust to changes in the model design? E.g. we discussed having a decoder-type model for the stream-specific prediction heads, and we would most likely forget this hidden dependency here. Alternatively, we could have a function in the model that reduces it to the encoder, which is called here.

Collaborator


Something similar to

    encoder_params = {
        k: v for k, v in params.items() if k.startswith(("encoder.", "latent_pre_norm"))
    }

Contributor Author


okay, will change this

Contributor Author


Actually this is tricky to implement because you need to know what the non-existent state is and be aware of the hierarchy; I can explain more in a call.
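For reference, a sketch of the prefix-based filtering suggested above and of the hierarchy subtlety. The prefixes are taken from the reviewer's snippet; none of this is the actual implementation:

    ENCODER_PREFIXES = ("encoder.", "latent_pre_norm")

    def select_encoder_params(params: dict) -> dict:
        # Keep only entries whose keys live under the encoder part of the
        # module hierarchy.
        return {k: v for k, v in params.items() if k.startswith(ENCODER_PREFIXES)}

    # The filtered dict no longer covers the full module tree, so loading it
    # back needs strict=False, and one has to know which submodules simply have
    # no state -- the hierarchy issue mentioned above:
    # teacher_model.load_state_dict(select_encoder_params(ckpt), strict=False)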

logger.warning(f"Unknown SSL loss type {name!r} in teacher setup, skipping.")


def load_encoder_from_checkpoint(
Collaborator


Why do we need this as well as the first part of prepare_encoder_teacher()? It seems to be the same functionality.

Contributor Author


Because one applies to the meta teacher model, e.g. for EMA from scratch: in that case we have no params and we must do something like prepare_encoder_teacher does. When we load pre-trained models, in contrast, we can select the params we want.
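A minimal sketch of the two paths, purely for illustration; build_ema_teacher and build_frozen_teacher are hypothetical names, not the functions in this PR:

    import copy
    import torch

    def build_ema_teacher(student):
        # EMA from scratch: the teacher starts as a copy of the student; there
        # are no pre-trained params to select from, so the whole model is kept.
        teacher = copy.deepcopy(student)
        teacher.requires_grad_(False)
        return teacher

    def build_frozen_teacher(student, ckpt_path, prefixes=("encoder.",)):
        # Pre-trained case: we can pick exactly the params we want from the
        # checkpoint (e.g. only the encoder) before freezing the result.
        teacher = copy.deepcopy(student)
        state = torch.load(ckpt_path, map_location="cpu")
        teacher.load_state_dict(
            {k: v for k, v in state.items() if k.startswith(prefixes)}, strict=False
        )
        teacher.requires_grad_(False)
        return teacher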

Comment thread config/config_ema_warm_start.yml Outdated
@@ -0,0 +1,16 @@
training_config:
Collaborator


How is this config used? Maybe we can give an example at the top of what pretraining can be used. The copyright header is also missing.

Contributor Author


It is for testing purposes; will remove at the end.

Comment thread config/config_frozen_teacher.yml Outdated
@@ -0,0 +1,7 @@
training_config:
Collaborator


How is this config used? Maybe we can give an example at the top of what pretraining can be used. The copyright header is also missing.

Contributor Author


see above

@sophie-xhonneux
Contributor Author

@clessig pinging this

Comment thread config/config_ema_warm_start.yml Outdated
Comment thread config/config_frozen_teacher.yml Outdated
Comment thread config/config_jepa_frozen_mtm_sweep.yml Outdated
Comment thread config/default_config.yml
@@ -11,7 +11,7 @@ embed_orientation: "channels"
embed_unembed_mode: "block"
Contributor


Revert changes to default config

from weathergen.model.utils import apply_fct_to_blocks, freeze_weights
from weathergen.train.target_and_aux_module_base import PhysicalTargetAndAux
from weathergen.train.target_and_aux_ssl_teacher import EMATeacher
from weathergen.train.target_and_aux_ssl_teacher import EMATeacher, FrozenTeacher
Contributor


Great!

device: Target device
params: Dict with 'teacher_run_id' and optional 'teacher_mini_epoch'
"""
from weathergen.model.model import ModelParams
Contributor


Imports here?

Contributor Author


done

# Load only encoder weights
load_encoder_from_checkpoint(teacher_model, cf, teacher_run_id, teacher_mini_epoch, device)

# Strip to encoder + create fresh heads
Contributor


Can you just explain "create fresh heads" here please? Not sure what that refers to in the context of the teacher. As in creating a fresh, e.g. identity, predictor head?

Contributor Author


The teacher may have had a predictor head before; if it did, we strip it. If it needs one, as in DINOv2, we strip the old one and create a fresh one.
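To make that concrete, a rough sketch of what a fresh head could look like; the loss names and head layout are illustrative assumptions, not the actual latent_heads construction:

    import torch.nn as nn

    def make_fresh_latent_head(dim: int, ssl_loss: str) -> nn.Module:
        # Illustrative only: the real head layout is derived from the student's
        # SSL loss config. A JEPA-style target can use the raw latents, while a
        # DINOv2-style loss needs a freshly initialised projection head.
        if ssl_loss == "jepa":
            return nn.Identity()
        return nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))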

def compute(self, bidx, batch, model_params, model) -> TargetAuxOutput:
with torch.no_grad():
outputs = self.ema_model.forward_eval(model_params, batch).get_latent_prediction(0)
outputs = self.forward_teacher(model_params, batch).get_latent_prediction(0)
Contributor


Just curious, what is the (0) here?

Contributor Author


It's the fstep (forecast step), so it might be relevant to your latent forecasting!
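A toy stand-in for what that index means, just to illustrate the answer above (LatentPredictions is hypothetical, not the class used in this PR):

    import torch

    class LatentPredictions:
        # Latents are stored per forecast step, so get_latent_prediction(0)
        # returns the latents for the first forecast step (fstep = 0).
        def __init__(self, latents_per_fstep: list[torch.Tensor]):
            self._latents = latents_per_fstep

        def get_latent_prediction(self, fstep: int) -> torch.Tensor:
            return self._latents[fstep]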

Contributor

@shmh40 shmh40 left a comment


Some minor comments. Tested that frozen and warm start work pretty extensively. Would also be interesting if you ran it through the Copilot review.

@sophie-xhonneux sophie-xhonneux merged commit 9b47c27 into develop Mar 27, 2026
5 checks passed
@sophie-xhonneux sophie-xhonneux self-assigned this Mar 27, 2026
wael-mika pushed a commit to wael-mika/WeatherGenerator that referenced this pull request Apr 13, 2026
* Write first solution with Claude

* Add test configs, works on santis

* Disabling rope; removing model config from finetuning since it needs to be (at least here) identical

* Add new JEPA config

* Address comments on PR

* Linting

* Linting

* Fixed some corner cases in handling of when batch samples are NaN and
when batch is empty for latent loss

* Fixed handling of when batch valid is

* Fixed path handling

* Fixed problems with loading of teacher model

* Revert incorrect changes to default_config

* Fix problem with missing run_id as dir in path for loading teacher model

* Updated logging

* Updated config

* Address PR review

* Lint

---------

Co-authored-by: Sophie Xhonneux <sophiex@Sophies-MacBook-Pro.local>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

Labels

model:pretrain, model (Related to model training or definition (not generic infra))

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

Allow for frozen teachers and warm starts

3 participants