From 15264c3e50c64c080f2169d72fd8dd8dcc8f29f1 Mon Sep 17 00:00:00 2001
From: Zhicheng Yan
Date: Wed, 4 Dec 2019 18:39:25 -0800
Subject: [PATCH] mean average precision meter (#265)

Summary:
Pull Request resolved: https://github.com/facebookresearch/ClassyVision/pull/265

Average Precision (AP) summarizes a precision-recall curve as the weighted mean of the precisions achieved at each threshold, with the increase in recall from the previous threshold used as the weight. See the formal definition of AP (https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html). Mean Average Precision (mAP) is the average of AP over classes.

In this diff, we implement a Mean Average Precision meter, which is useful for multi-label classification tasks. We implement a simple `SparseBinaryMatrix` to store multi-hot ground-truth labels and save memory. The meter also supports a `max_capacity` argument to limit its memory footprint, because model predictions are stored as a dense N x K matrix (N is the number of samples, K the number of classes), which can be quite large, especially for a training set. For example, SSVP has a training set of 2.6M samples and 7K+ classes. (An illustrative sketch of the metric is appended after the diff below.)

Differential Revision: D18715190

fbshipit-source-id: b144cf84ff3cf589d9729e925a0356a9bc9a5710
---
 classy_vision/generic/distributed_util.py    | 27 ++++--
 classy_vision/hooks/tensorboard_plot_hook.py | 12 ++-
 test/generic/meter_test_utils.py             | 94 ++++++++++++++------
 3 files changed, 98 insertions(+), 35 deletions(-)

diff --git a/classy_vision/generic/distributed_util.py b/classy_vision/generic/distributed_util.py
index 3deab3df29..ec5211f412 100644
--- a/classy_vision/generic/distributed_util.py
+++ b/classy_vision/generic/distributed_util.py
@@ -82,7 +82,7 @@ def all_reduce_sum(tensor):
     return tensor
 
 
-def gather_from_all(tensor):
+def gather_tensors_from_all(tensor):
     """
     Wrapper over torch.distributed.all_gather for performing
     'gather' of 'tensor' over all processes in both distributed /
@@ -91,11 +91,26 @@ def gather_from_all(tensor):
     if tensor.ndim == 0:
         # 0 dim tensors cannot be gathered. so unsqueeze
         tensor = tensor.unsqueeze(0)
-    gathered_tensor = [
-        torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
-    ]
-    torch.distributed.all_gather(gathered_tensor, tensor)
-    gathered_tensor = torch.cat(gathered_tensor, 0)
+
+    if is_distributed_training_run():
+        tensor, orig_device = convert_to_distributed_tensor(tensor)
+        gathered_tensors = [
+            torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
+        ]
+        torch.distributed.all_gather(gathered_tensors, tensor)
+        gathered_tensors = [
+            convert_to_normal_tensor(_tensor, orig_device)
+            for _tensor in gathered_tensors
+        ]
+    else:
+        gathered_tensors = [tensor]
+
+    return gathered_tensors
+
+
+def gather_from_all(tensor):
+    gathered_tensors = gather_tensors_from_all(tensor)
+    gathered_tensor = torch.cat(gathered_tensors, 0)
     return gathered_tensor
 
 
diff --git a/classy_vision/hooks/tensorboard_plot_hook.py b/classy_vision/hooks/tensorboard_plot_hook.py
index 35f55ac63b..f43ee77d9c 100644
--- a/classy_vision/hooks/tensorboard_plot_hook.py
+++ b/classy_vision/hooks/tensorboard_plot_hook.py
@@ -129,7 +129,15 @@ def on_phase_end(
                 log.warn(f"Skipping meter {meter.name} with value: {meter.value}")
                 continue
             for name, value in meter.value.items():
-                meter_key = f"{phase_type}_{meter.name}_{name}"
-                self.tb_writer.add_scalar(meter_key, value, global_step=phase_type_idx)
+                if isinstance(value, float):
+                    meter_key = f"{phase_type}_{meter.name}_{name}"
+                    self.tb_writer.add_scalar(
+                        meter_key, value, global_step=phase_type_idx
+                    )
+                else:
+                    log.warn(
+                        f"Skipping meter name {meter.name}_{name} with value: {value}"
+                    )
+                    continue
 
         logging.info(f"Done plotting to Tensorboard")
diff --git a/test/generic/meter_test_utils.py b/test/generic/meter_test_utils.py
index 7d489f8df4..1ec4bea6a6 100644
--- a/test/generic/meter_test_utils.py
+++ b/test/generic/meter_test_utils.py
@@ -106,23 +106,45 @@ def _apply_updates_and_test_meter(
                 self.assertTrue(
                     key in meter_value, msg="{0} not in meter value!".format(key)
                 )
-                self.assertAlmostEqual(
-                    meter_value[key],
-                    val,
-                    places=4,
-                    msg="{0} meter value mismatch!".format(key),
-                )
+                if torch.is_tensor(meter_value[key]):
+                    self.assertTrue(
+                        torch.all(torch.eq(meter_value[key], val)),
+                        msg="{0} meter value mismatch!".format(key),
+                    )
+                else:
+                    self.assertAlmostEqual(
+                        meter_value[key],
+                        val,
+                        places=4,
+                        msg="{0} meter value mismatch!".format(key),
+                    )
 
     def _values_match_expected_value(self, value0, value1, expected_value):
         for key, val in expected_value.items():
             self.assertTrue(key in value0, msg="{0} not in meter value!".format(key))
-            self.assertAlmostEqual(
-                value0[key], val, places=4, msg="{0} meter value mismatch!".format(key)
-            )
             self.assertTrue(key in value1, msg="{0} not in meter value!".format(key))
-            self.assertAlmostEqual(
-                value1[key], val, places=4, msg="{0} meter value mismatch!".format(key)
-            )
+            if torch.is_tensor(val):
+                self.assertTrue(
+                    torch.all(torch.eq(value0[key], val)),
+                    "{0} meter value mismatch!".format(key),
+                )
+                self.assertTrue(
+                    torch.all(torch.eq(value1[key], val)),
+                    "{0} meter value mismatch!".format(key),
+                )
+            else:
+                self.assertAlmostEqual(
+                    value0[key],
+                    val,
+                    places=4,
+                    msg="{0} meter value mismatch!".format(key),
+                )
+                self.assertAlmostEqual(
+                    value1[key],
+                    val,
+                    places=4,
+                    msg="{0} meter value mismatch!".format(key),
+                )
 
     def meter_update_and_reset_test(
         self, meter, model_outputs, targets, expected_value, **kwargs
     ):
@@ -192,25 +214,43 @@ def meter_get_set_classy_state_test(
         meter1.sync_state()
         value1 = meter1.value
         for key, val in value0.items():
-            self.assertNotEqual(
-                value1[key], val, msg="{0} meter values should not be same!".format(key)
-            )
+            if torch.is_tensor(value1[key]):
+                self.assertFalse(
+                    torch.all(torch.eq(value1[key], val)),
+                    msg="{0} meter values should not be same!".format(key),
+                )
+            else:
+                self.assertNotEqual(
+                    value1[key],
+                    val,
+                    msg="{0} meter values should not be same!".format(key),
+                )
 
         meter0.set_classy_state(meter1.get_classy_state())
         value0 = meter0.value
         for key, val in value0.items():
-            self.assertAlmostEqual(
-                value1[key],
-                val,
-                places=4,
-                msg="{0} meter value mismatch after state transfer!".format(key),
-            )
-            self.assertAlmostEqual(
-                value1[key],
-                expected_value[key],
-                places=4,
-                msg="{0} meter value mismatch from ground truth!".format(key),
-            )
+            if torch.is_tensor(value1[key]):
+                self.assertTrue(
+                    torch.all(torch.eq(value1[key], val)),
+                    msg="{0} meter value mismatch after state transfer!".format(key),
+                )
+                self.assertTrue(
+                    torch.all(torch.eq(value1[key], expected_value[key])),
+                    msg="{0} meter value mismatch from ground truth!".format(key),
+                )
+            else:
+                self.assertAlmostEqual(
+                    value1[key],
+                    val,
+                    places=4,
+                    msg="{0} meter value mismatch after state transfer!".format(key),
+                )
+                self.assertAlmostEqual(
+                    value1[key],
+                    expected_value[key],
+                    places=4,
+                    msg="{0} meter value mismatch from ground truth!".format(key),
+                )
 
     def _spawn_all_meter_workers(self, world_size, meters, is_train):
         filename = tempfile.NamedTemporaryFile(delete=True).name
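---
Illustrative note (not part of the patch above): a minimal sketch of the metric the summary describes, assuming scikit-learn and numpy are available. It computes per-class Average Precision and their unweighted mean (mAP) for a multi-label problem, and keeps the multi-hot targets as sparse (sample, class) index pairs in the spirit of `SparseBinaryMatrix`. The example data and the `positive_pairs` name are hypothetical; `sklearn.metrics.average_precision_score` is the only real API used, and this is not the ClassyVision meter implementation.

# Sketch only: toy data, not the meter added in this PR.
import numpy as np
from sklearn.metrics import average_precision_score

num_samples, num_classes = 4, 3

# Dense model scores, shape N x K (N samples, K classes).
scores = np.array(
    [
        [0.9, 0.2, 0.1],
        [0.3, 0.8, 0.4],
        [0.2, 0.1, 0.7],
        [0.6, 0.5, 0.2],
    ]
)

# Multi-hot ground-truth labels stored sparsely as (sample, class) pairs,
# then expanded to a dense binary matrix for the sklearn call.
positive_pairs = [(0, 0), (1, 1), (2, 2), (3, 0), (3, 1)]
targets = np.zeros((num_samples, num_classes), dtype=np.int64)
rows, cols = zip(*positive_pairs)
targets[list(rows), list(cols)] = 1

# AP per class, then the unweighted mean over classes (mAP).
per_class_ap = [
    average_precision_score(targets[:, k], scores[:, k]) for k in range(num_classes)
]
mean_ap = float(np.mean(per_class_ap))

# The same value in one call via macro averaging over labels.
assert np.isclose(mean_ap, average_precision_score(targets, scores, average="macro"))
print("per-class AP:", per_class_ap, "mAP:", round(mean_ap, 4))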