From d138e8910b0fac7979b76ef5d274a1b2c1303598 Mon Sep 17 00:00:00 2001 From: vsey Date: Tue, 3 Mar 2026 16:43:08 +0100 Subject: [PATCH] use inplace foreach op to update model ema --- lightly/models/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/lightly/models/utils.py b/lightly/models/utils.py index b1911ce16..fb8726b2b 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -283,8 +283,12 @@ def update_momentum(model: nn.Module, model_ema: nn.Module, m: float): >>> update_momentum(moco, moco_momentum, m=0.999) >>> update_momentum(projection_head, projection_head_momentum, m=0.999) """ - for model_ema, model in zip(model_ema.parameters(), model.parameters()): - model_ema.data = model_ema.data * m + model.data * (1.0 - m) + + model_params = list(model.parameters()) + model_ema_params = list(model_ema.parameters()) + + torch._foreach_mul_(model_ema_params, m) + torch._foreach_add_(model_ema_params, model_params, alpha=1.0 - m) @torch.no_grad()