Reduce peak memory in mid_cnn fusion and load checkpoints to CPU first #2

Merged
sezginerr merged 1 commit into forithmus:main from dongyang0122:fix/contrastive-pretraining-oom
Apr 17, 2026
@dongyang0122 dongyang0122 commented Apr 17, 2026

Summary

Two small, independent memory fixes in contrastive-pretraining/mr_rate/mr_rate/mr_rate.py:

  1. mid_cnn fusion mode (_encode_visual_tokens)

    • Before: all r volumes are batched as (b*r, c, d, h, w) and passed through forward_cnn in a single checkpointed call.
    • After: iterate over the r volumes, calling forward_cnn per volume with run_checkpoint, then torch.stack the per-volume features.
    • Effect: peak CNN activation memory scales with one volume instead of r, mirroring the existing per-volume pattern in late / late_attn. No change to numerical result — same features, just materialized sequentially.
  2. MRRATE.load()

    • torch.load(str(path)) → torch.load(str(path), map_location="cpu").
    • Prevents OOM when loading a checkpoint whose tensors were saved on a different / larger GPU layout (e.g., multi-GPU training → fewer-GPU finetune or eval).
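For reviewers, a minimal sketch of the mid_cnn change in item 1. It is not the code from mr_rate.py: `cnn` is a hypothetical stand-in for forward_cnn, `torch.utils.checkpoint.checkpoint` stands in for the repo's run_checkpoint helper, and the tensor shape (b, r, c, d, h, w) is assumed from the description above. The stand-in deliberately contains no layers that compute statistics across the batch, which is the condition under which the two paths are expected to agree.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical stand-in for forward_cnn: a tiny 3D conv stack with no
# cross-sample layers (e.g. no BatchNorm in training mode).
cnn = nn.Sequential(
    nn.Conv3d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool3d(1),
    nn.Flatten(),
)


def encode_batched(x: torch.Tensor) -> torch.Tensor:
    """Old pattern: fold r into the batch axis and run one checkpointed call."""
    b, r, c, d, h, w = x.shape
    flat = x.reshape(b * r, c, d, h, w)
    feats = checkpoint(cnn, flat, use_reentrant=False)
    return feats.reshape(b, r, -1)


def encode_sequential(x: torch.Tensor) -> torch.Tensor:
    """New pattern: one checkpointed call per volume, then stack along r."""
    per_volume = [
        checkpoint(cnn, x[:, i], use_reentrant=False)  # x[:, i] is (b, c, d, h, w)
        for i in range(x.shape[1])
    ]
    return torch.stack(per_volume, dim=1)  # (b, r, features)


x = torch.randn(2, 3, 1, 8, 8, 8, requires_grad=True)  # b=2, r=3
batched = encode_batched(x)
sequential = encode_sequential(x)
same = torch.allclose(batched, sequential, atol=1e-5)
print(batched.shape, sequential.shape, same)

# Gradients still flow through the per-volume path.
sequential.sum().backward()
```

In the sequential version only one volume's intermediate activations are alive at a time, both in the forward pass and when each checkpointed segment is recomputed during backward. Agreement is checked with a small tolerance because convolution kernels may accumulate in a different order for different batch sizes.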

Notes

  • Only contrastive-pretraining/mr_rate/mr_rate/mr_rate.py is touched; no changes to public API, CLI, or checkpoint format.
  • Matches the repo's existing style of sequential per-volume processing with gradient checkpointing.
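The MRRATE.load() change from the summary can be illustrated in isolation. The sketch below is an assumption-laden example, not the repo's load method: it uses a hypothetical nn.Linear model and a temporary checkpoint file, and shows that `map_location="cpu"` places every deserialized tensor on the CPU, after which the caller chooses the target device.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical model standing in for MRRATE; only the loading pattern matters.
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ckpt.pt")
    torch.save(model.state_dict(), path)

    # Deserialize onto CPU regardless of the device the tensors were saved from.
    state = torch.load(str(path), map_location="cpu")

# Every tensor in the loaded state dict lives on the CPU at this point.
device_types = {t.device.type for t in state.values()}
print(device_types)

restored = nn.Linear(4, 2)
restored.load_state_dict(state)

# Move to the accelerator only after loading, if one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
restored.to(device)
```

Without `map_location`, torch.load tries to restore each tensor to the device it was saved from, which is where the OOM or a missing-device error can arise when the GPU layout at load time differs from the one used for training.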

- mid_cnn fusion: process volumes sequentially through forward_cnn with
  gradient checkpointing instead of batching all r volumes as (b*r, c, d, h, w).
  Peak activation memory now scales with one volume rather than r, matching
  the late/late_attn pattern.

- MRRATE.load: pass map_location="cpu" to torch.load so checkpoints are
  deserialized onto CPU first, avoiding OOM when loading multi-GPU checkpoints
  onto smaller/fewer GPUs.

Made-with: Cursor
@sezginerr sezginerr merged commit f609ca7 into forithmus:main Apr 17, 2026
3 checks passed