Warm start and frozen teachers #1876
Conversation
clessig left a comment
Overall looks fine. I pushed some minor changes. config_jepa.yml has a 2D rope param but it's not in here; this should be removed (it was also one of the things that caused problems for me).
```python
class FrozenTeacher(EncoderTeacher):
    """SSL teacher using a frozen pre-trained encoder.

    The encoder is loaded from a checkpoint and never updated. Non-encoder ...
```
The teacher_model is assumed to have the non-encoder parts discarded, no?
The code should do the discarding: the original model, as specified in the config associated with its run id, may contain more than just an encoder.
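To make the discussion above concrete, here is a minimal, self-contained sketch of the frozen-teacher idea in plain PyTorch. The class and method names are illustrative only and do not match the weathergen API:

```python
import torch
import torch.nn as nn


class MinimalFrozenTeacher:
    """Illustrative frozen teacher: wraps an encoder whose weights never change."""

    def __init__(self, encoder: nn.Module):
        self.teacher_model = encoder
        # Freeze all parameters and switch to eval mode so normalization and
        # dropout behave deterministically and no gradients are tracked.
        for p in self.teacher_model.parameters():
            p.requires_grad_(False)
        self.teacher_model.eval()

    @torch.no_grad()
    def forward_teacher(self, x: torch.Tensor) -> torch.Tensor:
        # Produce SSL targets without building an autograd graph.
        return self.teacher_model(x)


# Toy usage with a stand-in encoder:
teacher = MinimalFrozenTeacher(nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 16)))
targets = teacher.forward_teacher(torch.randn(4, 8))
```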
```python
        self.teacher_model.eval()

    @classmethod
    def from_pretrained(cls, cf: Config, dataset, device, params: dict) -> FrozenTeacher:
```
This function is inconsistent with what is done for EMATeacher in model_interface. Either we have from_pretrained() for both classes or we have the functionality in model_interface.py.
But they conceptually and functionally do different things, so I don't follow
Ok, can you then maybe briefly explain what the difference is for you between this here and load_encoder_from_checkpoint()?
Copied from a different reply
Because one applies to the meta teacher model, e.g. for EMA from scratch; in this case we have no params and we must do something like prepare_encoder_teacher does. In contrast, when we load pre-trained models we can select the params we want.
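To illustrate the distinction being made here, a rough sketch of the two setup paths under discussion; the function names, checkpoint layout, and key prefixes are assumptions for the example, not the actual weathergen code:

```python
import copy

import torch
import torch.nn as nn


def init_ema_teacher_from_student(student_encoder: nn.Module) -> nn.Module:
    """EMA-from-scratch path: the teacher starts as a copy of the live student
    encoder (no checkpoint, no param selection) and is updated via EMA later."""
    teacher = copy.deepcopy(student_encoder)
    for p in teacher.parameters():
        p.requires_grad_(False)
    return teacher


def init_frozen_teacher_from_checkpoint(encoder: nn.Module, ckpt_path: str) -> nn.Module:
    """Pre-trained path: weights come from a stored run, and only the encoder
    entries of the checkpoint's state dict are selected and loaded."""
    state = torch.load(ckpt_path, map_location="cpu")
    encoder_state = {
        k.removeprefix("encoder."): v for k, v in state.items() if k.startswith("encoder.")
    }
    encoder.load_state_dict(encoder_state, strict=False)
    for p in encoder.parameters():
        p.requires_grad_(False)
    return encoder
```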
```python
    3. Creates fresh latent_heads based on the student's SSL loss config
    """
    # Strip non-encoder components
    model.forecast_engine = None
```
Can we formulate it as "is not encoder" so that we are robust to changes in the model design? E.g. we discussed having a decoder-type model for the stream-specific prediction heads, and we will most likely forget this hidden dependency here. Otherwise, we could have a function in the model that reduces it to the encoder and is called here.
Something similar to:

```python
encoder_params = {
    k: v for k, v in params.items() if k.startswith(("encoder.", "latent_pre_norm"))
}
```
okay, will change this
Actually this is tricky to implement because you need to know what the non-existent state is and be aware of the hierarchy; I can explain more in a call.
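For reference, a small sketch of the prefix-based filtering suggested above, together with a variant that derives the prefixes from the model's own submodules so the non-encoder parts do not have to be enumerated by hand. All names here are placeholders:

```python
import torch.nn as nn


def encoder_state_dict(model: nn.Module, prefixes=("encoder.", "latent_pre_norm")) -> dict:
    """Keep only the parameters that belong to the encoder part of the model."""
    return {k: v for k, v in model.state_dict().items() if k.startswith(prefixes)}


def encoder_state_dict_from_submodules(model: nn.Module, encoder_attrs=("encoder",)) -> dict:
    """Variant: derive the prefixes from named submodules, so the model itself
    defines what counts as encoder and the hierarchy is handled automatically."""
    prefixes = tuple(f"{name}." for name in encoder_attrs if hasattr(model, name))
    return {k: v for k, v in model.state_dict().items() if k.startswith(prefixes)}
```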
```python
            logger.warning(f"Unknown SSL loss type {name!r} in teacher setup, skipping.")


def load_encoder_from_checkpoint(
```
Why do we need this as well as the first part of prepare_encoder_teacher()? It seems to be the same functionality.
Because one applies to the meta teacher model, e.g. for EMA from scratch; in this case we have no params and we must do something like prepare_encoder_teacher does. In contrast, when we load pre-trained models we can select the params we want.
```diff
@@ -0,0 +1,16 @@
+training_config:
```
How is this config to be used? Maybe we can give an example at the top of what pretraining it can be used with. The copyright header is also missing.
It is for testing purposes; I will remove it at the end.
```diff
@@ -0,0 +1,7 @@
+training_config:
```
How is this config to be used? Maybe we can give an example at the top of what pretraining it can be used with. The copyright header is also missing.
@clessig pinging this
```diff
@@ -11,7 +11,7 @@ embed_orientation: "channels"
 embed_unembed_mode: "block"
```
Revert changes to default config
```diff
 from weathergen.model.utils import apply_fct_to_blocks, freeze_weights
 from weathergen.train.target_and_aux_module_base import PhysicalTargetAndAux
-from weathergen.train.target_and_aux_ssl_teacher import EMATeacher
+from weathergen.train.target_and_aux_ssl_teacher import EMATeacher, FrozenTeacher
```
```python
        device: Target device
        params: Dict with 'teacher_run_id' and optional 'teacher_mini_epoch'
    """
    from weathergen.model.model import ModelParams
```
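For orientation, a hypothetical example of the params dict mentioned in the docstring above; only the keys come from the docstring, the run id and mini-epoch values are made up:

```python
params = {
    "teacher_run_id": "abc123xy",   # run whose checkpoint provides the frozen encoder
    "teacher_mini_epoch": 3,        # optional: which saved mini-epoch to load
}
```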
```python
    # Load only encoder weights
    load_encoder_from_checkpoint(teacher_model, cf, teacher_run_id, teacher_mini_epoch, device)

    # Strip to encoder + create fresh heads
```
Can you just explain "create fresh heads" here please? Not sure what that refers to in the context of the teacher. As in create a fresh, e.g. identity, predictor head?
The teacher may have had a predictor head before; if it did, we strip it, and if it needs one, as in DINOv2, we create a fresh one.
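As a rough illustration of what "create fresh heads" means in this context: new, randomly initialised heads are built from the student's SSL loss config instead of reusing whatever head the pre-trained model had. The loss names and head shapes below are invented for the example:

```python
import logging

import torch.nn as nn

logger = logging.getLogger(__name__)


def make_fresh_latent_heads(ssl_losses: list[str], dim: int) -> nn.ModuleDict:
    """Build one fresh head per SSL loss in the student's config."""
    heads = {}
    for name in ssl_losses:
        if name == "jepa":
            heads[name] = nn.Identity()         # JEPA-style: raw latents as targets
        elif name == "dino":
            heads[name] = nn.Linear(dim, 4096)  # DINO-style: (simplified) projection head
        else:
            # Mirrors the warning in the diff above for unknown loss types.
            logger.warning(f"Unknown SSL loss type {name!r} in teacher setup, skipping.")
    return nn.ModuleDict(heads)
```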
```diff
     def compute(self, bidx, batch, model_params, model) -> TargetAuxOutput:
         with torch.no_grad():
-            outputs = self.ema_model.forward_eval(model_params, batch).get_latent_prediction(0)
+            outputs = self.forward_teacher(model_params, batch).get_latent_prediction(0)
```
Just curious what is the (0) here?
The fstep, so it might be relevant to your latent forecasting!
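In other words, the argument is the forecast-step index. A toy sketch of that access pattern; the container class is invented purely for illustration:

```python
from dataclasses import dataclass

import torch


@dataclass
class ToyTeacherOutput:
    latents_per_fstep: list  # one latent tensor per forecast step (fstep)

    def get_latent_prediction(self, fstep: int) -> torch.Tensor:
        return self.latents_per_fstep[fstep]


out = ToyTeacherOutput([torch.randn(2, 16), torch.randn(2, 16)])
targets = out.get_latent_prediction(0)  # fstep 0: the first forecast step's latents
```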
* Write first solution with Claude
* Add test configs, works on santis
* Disabling rope; removing model config from finetuning since it needs to be (at least here) identical
* Add new JEPA config
* Address comments on PR
* Linting
* Linting
* Fixed some corner cases in handling of when batch samples are NaN and when batch is empty for latent loss
* Fixed handling of when batch valid is
* Fixed path handling
* Fixed problems with loading of teacher model
* Revert incorrect changes to default_config
* Fix problem with missing run_id as dir in path for loading teacher model
* Updated logging
* Updated config
* Address PR review
* Lint

---------

Co-authored-by: Sophie Xhonneux <sophiex@Sophies-MacBook-Pro.local>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>
Description
Allow for the warm start with EMA and Frozen Teachers
Issue Number
Closes #1881
Is this PR a draft? Mark it as draft.
Checklist before asking for review
- ./scripts/actions.sh lint
- ./scripts/actions.sh unit-test
- ./scripts/actions.sh integration-test
- launch-slurm.py --time 60