Reduce peak memory in mid_cnn fusion and load checkpoints to CPU first#2
Merged
sezginerr merged 1 commit intoApr 17, 2026
Conversation
- mid_cnn fusion: process volumes sequentially through forward_cnn with gradient checkpointing instead of batching all r volumes as (b*r, c, d, h, w). Peak activation memory now scales with one volume rather than r, matching the late/late_attn pattern. - MRRATE.load: pass map_location="cpu" to torch.load so checkpoints are deserialized onto CPU first, avoiding OOM when loading multi-GPU checkpoints onto smaller/fewer GPUs. Made-with: Cursor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Two small, independent memory fixes in
contrastive-pretraining/mr_rate/mr_rate/mr_rate.py:mid_cnnfusion mode (_encode_visual_tokens)rvolumes are batched as(b*r, c, d, h, w)and passed throughforward_cnnin a single checkpointed call.rvolumes, callingforward_cnnper volume withrun_checkpoint, thentorch.stackthe per-volume features.r, mirroring the existing per-volume pattern inlate/late_attn. No change to numerical result — same features, just materialized sequentially.MRRATE.load()torch.load(str(path))→torch.load(str(path), map_location="cpu").Notes
contrastive-pretraining/mr_rate/mr_rate/mr_rate.pyis touched; no changes to public API, CLI, or checkpoint format.