Skip to content

Commit

Permalink
[Enhance] Extract parse losses
Browse files Browse the repository at this point in the history
  • Loading branch information
fanqiNO1 committed Mar 6, 2024
1 parent e0b4cea commit 4d76d6c
Showing 1 changed file with 42 additions and 26 deletions.
68 changes: 42 additions & 26 deletions mmengine/runner/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch.utils.data import DataLoader

from mmengine.evaluator import Evaluator
from mmengine.logging import print_log
from mmengine.logging import HistoryBuffer, print_log
from mmengine.registry import LOOPS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
Expand Down Expand Up @@ -363,7 +363,7 @@ def __init__(self,
logger='current',
level=logging.WARNING)
self.fp16 = fp16
self.val_loss: Dict[str, list] = dict()
self.val_loss: Dict[str, HistoryBuffer] = dict()

def run(self) -> dict:
"""Launch validation."""
Expand All @@ -380,14 +380,8 @@ def run(self) -> dict:
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

if self.val_loss:
# get val loss and save to metrics
val_loss = 0
for loss_name, loss_value in self.val_loss.items():
avg_loss = sum(loss_value) / len(loss_value)
metrics[loss_name] = avg_loss
if 'loss' in loss_name:
val_loss += avg_loss # type: ignore
metrics['val_loss'] = val_loss
loss_dict = _parse_losses(self.val_loss, 'val')
metrics.update(loss_dict)

self.runner.call_hook('after_val_epoch', metrics=metrics)
self.runner.call_hook('after_val')
Expand All @@ -406,6 +400,7 @@ def run_iter(self, idx, data_batch: Sequence[dict]):
# outputs should be sequence of BaseDataElement
with autocast(enabled=self.fp16):
outputs = self.runner.model.val_step(data_batch)

if isinstance(outputs[-1],
BaseDataElement) and outputs[-1].keys() == ['loss']:
loss = outputs[-1].loss # type: ignore
Expand All @@ -415,11 +410,12 @@ def run_iter(self, idx, data_batch: Sequence[dict]):
# get val loss and avoid breaking change
for loss_name, loss_value in loss.items():
if loss_name not in self.val_loss:
self.val_loss[loss_name] = []
self.val_loss[loss_name] = HistoryBuffer()
if isinstance(loss_value, torch.Tensor):
self.val_loss[loss_name].append(loss_value.item())
self.val_loss[loss_name].update(loss_value.item())
elif is_list_of(loss_value, torch.Tensor):
self.val_loss[loss_name].extend([v.item() for v in loss_value])
for loss_value_i in loss_value:
self.val_loss[loss_name].update(loss_value_i.item())

self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
Expand Down Expand Up @@ -465,7 +461,7 @@ def __init__(self,
logger='current',
level=logging.WARNING)
self.fp16 = fp16
self.test_loss: Dict[str, list] = dict()
self.test_loss: Dict[str, HistoryBuffer] = dict()

def run(self) -> dict:
"""Launch test."""
Expand All @@ -482,14 +478,8 @@ def run(self) -> dict:
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

if self.test_loss:
# get test loss and save to metrics
test_loss = 0
for loss_name, loss_value in self.test_loss.items():
avg_loss = sum(loss_value) / len(loss_value)
metrics[loss_name] = avg_loss
if 'loss' in loss_name:
test_loss += avg_loss # type: ignore
metrics['test_loss'] = test_loss
loss_dict = _parse_losses(self.test_loss, 'test')
metrics.update(loss_dict)

self.runner.call_hook('after_test_epoch', metrics=metrics)
self.runner.call_hook('after_test')
Expand All @@ -507,6 +497,7 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
# predictions should be sequence of BaseDataElement
with autocast(enabled=self.fp16):
outputs = self.runner.model.test_step(data_batch)

if isinstance(outputs[-1],
BaseDataElement) and outputs[-1].keys() == ['loss']:
loss = outputs[-1].loss # type: ignore
Expand All @@ -516,16 +507,41 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
# get val loss and avoid breaking change
for loss_name, loss_value in loss.items():
if loss_name not in self.test_loss:
self.test_loss[loss_name] = []
self.test_loss[loss_name] = HistoryBuffer()
if isinstance(loss_value, torch.Tensor):
self.test_loss[loss_name].append(loss_value.item())
self.test_loss[loss_name].update(loss_value.item())
elif is_list_of(loss_value, torch.Tensor):
self.test_loss[loss_name].extend(
[v.item() for v in loss_value])
for loss_value_i in loss_value:
self.test_loss[loss_name].update(loss_value_i.item())

self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
'after_test_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=outputs)


def _parse_losses(losses: Dict[str, HistoryBuffer],
stage: str) -> Dict[str, float]:
"""Parses the raw losses of the network.
Args:
losses (dict): raw losses of the network.
stage (str): The stage of loss, e.g., 'val' or 'test'.
Returns:
dict[str, float]: The key is the loss name, and the value is the
average loss.
"""
all_loss = 0
loss_dict: Dict[str, float] = dict()

for loss_name, loss_value in losses.items():
avg_loss = loss_value.mean()
loss_dict[loss_name] = avg_loss
if 'loss' in loss_name:
all_loss += avg_loss

loss_dict[f'{stage}_loss'] = all_loss
return loss_dict

0 comments on commit 4d76d6c

Please sign in to comment.