Skip to content

Commit

Permalink
[Feature] Support calculating loss during validation (#1503)
Browse files Browse the repository at this point in the history
  • Loading branch information
fanqiNO1 committed May 17, 2024
1 parent 66fb81f commit d1f1aab
Showing 1 changed file with 81 additions and 1 deletion.
82 changes: 81 additions & 1 deletion mmengine/runner/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from torch.utils.data import DataLoader

from mmengine.evaluator import Evaluator
from mmengine.logging import print_log
from mmengine.logging import HistoryBuffer, print_log
from mmengine.registry import LOOPS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
from .amp import autocast
from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals
Expand Down Expand Up @@ -363,17 +365,26 @@ def __init__(self,
logger='current',
level=logging.WARNING)
self.fp16 = fp16
self.val_loss: Dict[str, HistoryBuffer] = dict()

def run(self) -> dict:
"""Launch validation."""
self.runner.call_hook('before_val')
self.runner.call_hook('before_val_epoch')
self.runner.model.eval()

# clear val loss
self.val_loss.clear()
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)

# compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

if self.val_loss:
loss_dict = _parse_losses(self.val_loss, 'val')
metrics.update(loss_dict)

self.runner.call_hook('after_val_epoch', metrics=metrics)
self.runner.call_hook('after_val')
return metrics
Expand All @@ -391,6 +402,9 @@ def run_iter(self, idx, data_batch: Sequence[dict]):
# outputs should be sequence of BaseDataElement
with autocast(enabled=self.fp16):
outputs = self.runner.model.val_step(data_batch)

outputs, self.val_loss = _update_losses(outputs, self.val_loss)

self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
'after_val_iter',
Expand Down Expand Up @@ -435,17 +449,26 @@ def __init__(self,
logger='current',
level=logging.WARNING)
self.fp16 = fp16
self.test_loss: Dict[str, HistoryBuffer] = dict()

def run(self) -> dict:
"""Launch test."""
self.runner.call_hook('before_test')
self.runner.call_hook('before_test_epoch')
self.runner.model.eval()

# clear test loss
self.test_loss.clear()
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)

# compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

if self.test_loss:
loss_dict = _parse_losses(self.test_loss, 'test')
metrics.update(loss_dict)

self.runner.call_hook('after_test_epoch', metrics=metrics)
self.runner.call_hook('after_test')
return metrics
Expand All @@ -462,9 +485,66 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
# predictions should be sequence of BaseDataElement
with autocast(enabled=self.fp16):
outputs = self.runner.model.test_step(data_batch)

outputs, self.test_loss = _update_losses(outputs, self.test_loss)

self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
'after_test_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=outputs)


def _parse_losses(losses: Dict[str, HistoryBuffer],
stage: str) -> Dict[str, float]:
"""Parses the raw losses of the network.
Args:
losses (dict): raw losses of the network.
stage (str): The stage of loss, e.g., 'val' or 'test'.
Returns:
dict[str, float]: The key is the loss name, and the value is the
average loss.
"""
all_loss = 0
loss_dict: Dict[str, float] = dict()

for loss_name, loss_value in losses.items():
avg_loss = loss_value.mean()
loss_dict[loss_name] = avg_loss
if 'loss' in loss_name:
all_loss += avg_loss

loss_dict[f'{stage}_loss'] = all_loss
return loss_dict


def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]:
"""Update and record the losses of the network.
Args:
outputs (list): The outputs of the network.
losses (dict): The losses of the network.
Returns:
list: The updated outputs of the network.
dict: The updated losses of the network.
"""
if isinstance(outputs[-1],
BaseDataElement) and outputs[-1].keys() == ['loss']:
loss = outputs[-1].loss # type: ignore
outputs = outputs[:-1]
else:
loss = dict()

for loss_name, loss_value in loss.items():
if loss_name not in losses:
losses[loss_name] = HistoryBuffer()
if isinstance(loss_value, torch.Tensor):
losses[loss_name].update(loss_value.item())
elif is_list_of(loss_value, torch.Tensor):
for loss_value_i in loss_value:
losses[loss_name].update(loss_value_i.item())
return outputs, losses

0 comments on commit d1f1aab

Please sign in to comment.