Skip to content

Commit

Permalink
refactor metric aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Feb 24, 2023
1 parent 60b1ab3 commit 244f484
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import warnings
import logging

import torch

from rastervision.pytorch_learner.learner import Learner
from rastervision.pytorch_learner.utils import (compute_conf_mat_metrics,
compute_conf_mat)
from rastervision.pytorch_learner.utils import (
compute_conf_mat_metrics, compute_conf_mat, aggregate_metrics)
from rastervision.pytorch_learner.dataset.visualizer import (
ClassificationVisualizer)

Expand Down Expand Up @@ -34,16 +32,12 @@ def validate_step(self, batch, batch_ind):

return {'val_loss': val_loss, 'conf_mat': conf_mat}

def validate_end(self, outputs):
    """Aggregate the outputs of validate_step at the end of the epoch.

    Args:
        outputs: a list of outputs of validate_step; each is expected to
            contain a 'val_loss' value and a 'conf_mat' tensor.

    Returns:
        Dict of aggregated metrics: the averaged 'val_loss' plus per-class
        metrics derived from the summed confusion matrix.
    """
    # Average everything except the confusion matrix, which must be
    # summed (not averaged) across batches before deriving metrics.
    metrics = aggregate_metrics(outputs, exclude_keys={'conf_mat'})
    conf_mat = sum([o['conf_mat'] for o in outputs])
    conf_mat_metrics = compute_conf_mat_metrics(conf_mat,
                                                self.cfg.data.class_names)
    metrics.update(conf_mat_metrics)
    return metrics

def prob_to_pred(self, x):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from rastervision.pipeline.utils import terminate_at_exit
from rastervision.pipeline.config import (build_config, upgrade_config,
save_pipeline_config)
from rastervision.pytorch_learner.utils import (get_hubconf_dir_from_cfg,
log_metrics_to_csv)
from rastervision.pytorch_learner.utils import (
get_hubconf_dir_from_cfg, aggregate_metrics, log_metrics_to_csv)
from rastervision.pytorch_learner.dataset.visualizer import Visualizer

if TYPE_CHECKING:
Expand Down Expand Up @@ -524,33 +524,23 @@ def validate_step(self, batch: Any, batch_ind: int) -> MetricDict:
"""
pass

def train_end(self, outputs: List[Dict[str, Union[float, Tensor]]]
              ) -> MetricDict:
    """Aggregate the output of train_step at the end of the epoch.

    Args:
        outputs: a list of outputs of train_step.

    Returns:
        Dict with one averaged value per metric key.
    """
    # Delegates to the shared helper so train and validation epochs
    # aggregate their per-batch metrics identically.
    return aggregate_metrics(outputs)

def validate_end(self, outputs: List[Dict[str, Union[float, Tensor]]]
                 ) -> MetricDict:
    """Aggregate the output of validate_step at the end of the epoch.

    Args:
        outputs: a list of outputs of validate_step.

    Returns:
        Dict with one averaged value per metric key.
    """
    # Delegates to the shared helper so train and validation epochs
    # aggregate their per-batch metrics identically.
    return aggregate_metrics(outputs)

def post_forward(self, x: Any) -> Any:
"""Post process output of call to model().
Expand Down Expand Up @@ -1162,7 +1152,6 @@ def train_epoch(
"""Train for a single epoch."""
start = time.time()
self.model.train()
num_samples = 0
outputs = []
with tqdm(self.train_dl, desc='Training') as bar:
for batch_ind, (x, y) in enumerate(bar):
Expand All @@ -1179,10 +1168,9 @@ def train_epoch(
outputs.append(output)
if step_scheduler is not None:
step_scheduler.step()
num_samples += x.shape[0]
if len(outputs) == 0:
raise ValueError('Training dataset did not return any batches')
metrics = self.train_end(outputs, num_samples)
metrics = self.train_end(outputs)
end = time.time()
train_time = datetime.timedelta(seconds=end - start)
metrics['train_time'] = str(train_time)
Expand All @@ -1192,7 +1180,6 @@ def validate_epoch(self, dl: DataLoader) -> MetricDict:
"""Validate for a single epoch."""
start = time.time()
self.model.eval()
num_samples = 0
outputs = []
with torch.inference_mode():
with tqdm(dl, desc='Validating') as bar:
Expand All @@ -1202,11 +1189,10 @@ def validate_epoch(self, dl: DataLoader) -> MetricDict:
batch = (x, y)
output = self.validate_step(batch, batch_ind)
outputs.append(output)
num_samples += x.shape[0]
end = time.time()
validate_time = datetime.timedelta(seconds=end - start)

metrics = self.validate_end(outputs, num_samples)
metrics = self.validate_end(outputs)
metrics['valid_time'] = str(validate_time)
return metrics

Expand Down Expand Up @@ -1253,7 +1239,7 @@ def train(self, epochs: Optional[int] = None):
self.epoch_scheduler.step()
valid_metrics = self.validate_epoch(self.valid_dl)
metrics = dict(epoch=epoch, **train_metrics, **valid_metrics)
log.info(f'metrics:\n{pformat(metrics)}')
log.info(f'metrics:\n{pformat(metrics, sort_dicts=False)}')

self.on_epoch_end(epoch, metrics)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def validate_step(self, batch, batch_ind):

return {'ys': ys, 'outs': outs}

def validate_end(self, outputs, num_samples):
def validate_end(self, outputs):
outs = []
ys = []
for o in outputs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from torch.nn import functional as F

from rastervision.pytorch_learner.learner import Learner
from rastervision.pytorch_learner.utils import (compute_conf_mat_metrics,
compute_conf_mat)
from rastervision.pytorch_learner.utils import (
compute_conf_mat_metrics, compute_conf_mat, aggregate_metrics)
from rastervision.pytorch_learner.dataset.visualizer import (
SemanticSegmentationVisualizer)

Expand Down Expand Up @@ -38,16 +38,12 @@ def validate_step(self, batch, batch_ind):

return {'val_loss': val_loss, 'conf_mat': conf_mat}

def validate_end(self, outputs):
    """Aggregate the outputs of validate_step at the end of the epoch.

    Args:
        outputs: a list of outputs of validate_step; each is expected to
            contain a 'val_loss' value and a 'conf_mat' tensor.

    Returns:
        Dict of aggregated metrics: the averaged 'val_loss' plus per-class
        metrics derived from the summed confusion matrix.
    """
    # Average everything except the confusion matrix, which must be
    # summed (not averaged) across batches before deriving metrics.
    metrics = aggregate_metrics(outputs, exclude_keys={'conf_mat'})
    conf_mat = sum([o['conf_mat'] for o in outputs])
    conf_mat_metrics = compute_conf_mat_metrics(conf_mat,
                                                self.cfg.data.class_names)
    metrics.update(conf_mat_metrics)
    return metrics

def post_forward(self, x):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, Sequence, Tuple, Optional, Union, List, Iterable
from typing import (Any, Dict, Sequence, Tuple, Optional, Union, List,
Iterable, Container)
from os.path import basename, join, isfile
import logging

Expand Down Expand Up @@ -359,3 +360,35 @@ def log_metrics_to_csv(csv_path: str, metrics: Dict[str, Any]):
log_file_exists = isfile(csv_path)
metrics_df.to_csv(
csv_path, mode='a', header=(not log_file_exists), index=False)


def aggregate_metrics(
        outputs: List[Dict[str, Union[float, torch.Tensor]]],
        exclude_keys: Container[str] = frozenset({'conf_mat'})
) -> Dict[str, float]:
    """Aggregate the output of validate_step at the end of the epoch.

    Args:
        outputs: A list of outputs of Learner.validate_step().
        exclude_keys: Keys to ignore. These will not be aggregated and will
            not be included in the output. Defaults to
            ``frozenset({'conf_mat'})``.

    Returns:
        Dict[str, float]: Dict with aggregated values.
    """
    # NOTE: the default must be frozenset({'conf_mat'}), not set('conf_mat');
    # the latter is the set of *characters* {'c', 'o', 'n', ...}, which
    # fails to exclude 'conf_mat' and wrongly excludes any single-character
    # metric name. frozenset also avoids a mutable default argument.
    metrics = {}
    metric_names = outputs[0].keys()
    for metric_name in metric_names:
        if metric_name in exclude_keys:
            continue
        metric_vals = [out[metric_name] for out in outputs]
        elem = metric_vals[0]
        if isinstance(elem, torch.Tensor):
            if elem.ndim == 0:
                # 0-dim tensors cannot be concatenated; stack them instead.
                metric_vals = torch.stack(metric_vals)
            else:
                metric_vals = torch.cat(metric_vals)
            metric_avg = metric_vals.float().mean().item()
        else:
            metric_avg = sum(metric_vals) / len(metric_vals)
        metrics[metric_name] = metric_avg
    return metrics
45 changes: 44 additions & 1 deletion tests/pytorch_learner/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
validate_albumentation_transform, A, color_to_triple,
channel_groups_to_imgs, plot_channel_groups,
serialize_albumentation_transform, deserialize_albumentation_transform,
log_metrics_to_csv)
aggregate_metrics, log_metrics_to_csv)
from tests.data_files.lambda_transforms import lambda_transforms
from tests import data_file_path

Expand Down Expand Up @@ -227,6 +227,49 @@ def test_channel_expansion(self):
self._test_attribs_equal(old_conv, new_conv_2[1][1])


class TestAggregateMetrics(unittest.TestCase):
    """Tests for aggregate_metrics()."""

    def assertNoError(self, fn: Callable, msg: str = ''):
        """Fail the test if calling fn raises any exception."""
        try:
            fn()
        except Exception:
            self.fail(msg)

    def test_scalars(self):
        """Plain floats are averaged across step outputs."""
        step_outputs = [{'train_loss': 0.}, {'train_loss': 1.}]
        aggregated = aggregate_metrics(step_outputs)
        self.assertIn('train_loss', aggregated)
        self.assertEqual(aggregated['train_loss'], 0.5)

    def test_tensors_zero_dim(self):
        """0-dim tensors are stacked and then averaged."""
        step_outputs = [{'key': torch.tensor(0)}, {'key': torch.tensor(1)}]
        aggregated = aggregate_metrics(step_outputs)
        self.assertIn('key', aggregated)
        self.assertEqual(aggregated['key'], 0.5)

    def test_tensors(self):
        """1-dim tensors are concatenated and then averaged."""
        step_outputs = [{'key': torch.zeros(8)}, {'key': torch.ones(8)}]
        aggregated = aggregate_metrics(step_outputs)
        self.assertIn('key', aggregated)
        self.assertEqual(aggregated['key'], 0.5)

    def test_exclude(self):
        """Keys listed in exclude_keys are dropped from the result."""
        step_outputs = [
            {'conf_mat': torch.randint(0, 100, (2, 2))},
            {'conf_mat': torch.randint(0, 100, (2, 2))},
        ]
        aggregated = aggregate_metrics(step_outputs, exclude_keys={'conf_mat'})
        self.assertNotIn('conf_mat', aggregated)


class TestOtherUtils(unittest.TestCase):
def assertNoError(self, fn: Callable, msg: str = ''):
try:
Expand Down

0 comments on commit 244f484

Please sign in to comment.