From c793214540c1a9d506e4e13f530471a204e7dcdf Mon Sep 17 00:00:00 2001 From: yivona08 <1017201439@qq.com> Date: Thu, 8 Dec 2022 16:20:47 +0800 Subject: [PATCH] fix fpn distill --- .../algorithms/distill/configurable/fpn_teacher_distill.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mmrazor/models/algorithms/distill/configurable/fpn_teacher_distill.py b/mmrazor/models/algorithms/distill/configurable/fpn_teacher_distill.py index db995d107..9d87d9139 100644 --- a/mmrazor/models/algorithms/distill/configurable/fpn_teacher_distill.py +++ b/mmrazor/models/algorithms/distill/configurable/fpn_teacher_distill.py @@ -30,15 +30,16 @@ def loss( # If the `override_data` of a delivery is False, the delivery will # record the origin data. self.distiller.set_deliveries_override(False) + + # Unlike ``SingleTeacherDistill``, teacher will only execute + # back + neck, not head, so there will be no loss. if self.teacher_trainable: - # Unlike ``SingleTeacherDistill``, teacher will only execute - # back + neck, not head, so there will be no loss. with self.distiller.teacher_recorders, self.distiller.deliveries: _ = self.teacher.extract_feat(batch_inputs) else: with self.distiller.teacher_recorders, self.distiller.deliveries: with torch.no_grad(): - _ = self.teacher(batch_inputs, data_samples, mode='loss') + _ = self.teacher.extract_feat(batch_inputs) # If the `override_data` of a delivery is True, the delivery will # override the origin data with the recorded data.