Skip to content

Commit

Permalink
mean average precision meter (#265)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #265

Average Precision (AP) summarizes a precision-recall curve as the weighted mean of precisions achieved at each threshold, with the increase in recall from the previous threshold used as the weight.

See AP formal definition (https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html).

Mean Average Precision is an average of AP over different classes.

In this diff, we implement Mean Average Precision meter that is useful for multi-label classification task.

We implement a simple `SparseBinaryMatrix` to store multi-hot groundtruth label to save memory. It also supports `max_capacity` argument to limit the memory footprint of the meter because model predictions, which is stored as a dense matrix of size N x K (  N is number of samples, and K number of classes), can be quite large specially for training set. For example, SSVP has a training set of size 2.6M, and has 7K+ classes.

Reviewed By: aadcock

Differential Revision: D18715190

fbshipit-source-id: c4777b179b64394adece4d8c5975f97b0909753e
  • Loading branch information
stephenyan1231 authored and facebook-github-bot committed Dec 5, 2019
1 parent 34a0055 commit 2d33c03
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 35 deletions.
27 changes: 21 additions & 6 deletions classy_vision/generic/distributed_util.py
Expand Up @@ -82,7 +82,7 @@ def all_reduce_sum(tensor):
return tensor


def gather_from_all(tensor):
def gather_tensors_from_all(tensor):
"""
Wrapper over torch.distributed.all_gather for performing
'gather' of 'tensor' over all processes in both distributed /
Expand All @@ -91,11 +91,26 @@ def gather_from_all(tensor):
if tensor.ndim == 0:
# 0 dim tensors cannot be gathered. so unsqueeze
tensor = tensor.unsqueeze(0)
gathered_tensor = [
torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(gathered_tensor, tensor)
gathered_tensor = torch.cat(gathered_tensor, 0)

if is_distributed_training_run():
tensor, orig_device = convert_to_distributed_tensor(tensor)
gathered_tensors = [
torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(gathered_tensors, tensor)
gathered_tensors = [
convert_to_normal_tensor(_tensor, orig_device)
for _tensor in gathered_tensors
]
else:
gathered_tensors = [tensor]

return gathered_tensors


def gather_from_all(tensor):
gathered_tensors = gather_tensors_from_all(tensor)
gathered_tensor = torch.cat(gathered_tensors, 0)
return gathered_tensor


Expand Down
12 changes: 10 additions & 2 deletions classy_vision/hooks/tensorboard_plot_hook.py
Expand Up @@ -132,7 +132,15 @@ def on_phase_end(
log.warn(f"Skipping meter {meter.name} with value: {meter.value}")
continue
for name, value in meter.value.items():
meter_key = f"{phase_type}_{meter.name}_{name}"
self.tb_writer.add_scalar(meter_key, value, global_step=phase_type_idx)
if isinstance(value, float):
meter_key = f"{phase_type}_{meter.name}_{name}"
self.tb_writer.add_scalar(
meter_key, value, global_step=phase_type_idx
)
else:
log.warn(
f"Skipping meter name {meter.name}_{name} with value: {value}"
)
continue

logging.info(f"Done plotting to Tensorboard")
94 changes: 67 additions & 27 deletions test/generic/meter_test_utils.py
Expand Up @@ -106,23 +106,45 @@ def _apply_updates_and_test_meter(
self.assertTrue(
key in meter_value, msg="{0} not in meter value!".format(key)
)
self.assertAlmostEqual(
meter_value[key],
val,
places=4,
msg="{0} meter value mismatch!".format(key),
)
if torch.is_tensor(meter_value[key]):
self.assertTrue(
torch.all(torch.eq(meter_value[key], val)),
msg="{0} meter value mismatch!".format(key),
)
else:
self.assertAlmostEqual(
meter_value[key],
val,
places=4,
msg="{0} meter value mismatch!".format(key),
)

def _values_match_expected_value(self, value0, value1, expected_value):
for key, val in expected_value.items():
self.assertTrue(key in value0, msg="{0} not in meter value!".format(key))
self.assertAlmostEqual(
value0[key], val, places=4, msg="{0} meter value mismatch!".format(key)
)
self.assertTrue(key in value1, msg="{0} not in meter value!".format(key))
self.assertAlmostEqual(
value1[key], val, places=4, msg="{0} meter value mismatch!".format(key)
)
if torch.is_tensor(val):
self.assertTrue(
torch.all(torch.eq(value0[key], val)),
"{0} meter value mismatch!".format(key),
)
self.assertTrue(
torch.all(torch.eq(value1[key], val)),
"{0} meter value mismatch!".format(key),
)
else:
self.assertAlmostEqual(
value0[key],
val,
places=4,
msg="{0} meter value mismatch!".format(key),
)
self.assertAlmostEqual(
value1[key],
val,
places=4,
msg="{0} meter value mismatch!".format(key),
)

def meter_update_and_reset_test(
self, meter, model_outputs, targets, expected_value, **kwargs
Expand Down Expand Up @@ -192,25 +214,43 @@ def meter_get_set_classy_state_test(
meter1.sync_state()
value1 = meter1.value
for key, val in value0.items():
self.assertNotEqual(
value1[key], val, msg="{0} meter values should not be same!".format(key)
)
if torch.is_tensor(value1[key]):
self.assertFalse(
torch.all(torch.eq(value1[key], val)),
msg="{0} meter values should not be same!".format(key),
)
else:
self.assertNotEqual(
value1[key],
val,
msg="{0} meter values should not be same!".format(key),
)

meter0.set_classy_state(meter1.get_classy_state())
value0 = meter0.value
for key, val in value0.items():
self.assertAlmostEqual(
value1[key],
val,
places=4,
msg="{0} meter value mismatch after state transfer!".format(key),
)
self.assertAlmostEqual(
value1[key],
expected_value[key],
places=4,
msg="{0} meter value mismatch from ground truth!".format(key),
)
if torch.is_tensor(value1[key]):
self.assertTrue(
torch.all(torch.eq(value1[key], val)),
msg="{0} meter value mismatch after state transfer!".format(key),
)
self.assertTrue(
torch.all(torch.eq(value1[key], expected_value[key])),
msg="{0} meter value mismatch from ground truth!".format(key),
)
else:
self.assertAlmostEqual(
value1[key],
val,
places=4,
msg="{0} meter value mismatch after state transfer!".format(key),
)
self.assertAlmostEqual(
value1[key],
expected_value[key],
places=4,
msg="{0} meter value mismatch from ground truth!".format(key),
)

def _spawn_all_meter_workers(self, world_size, meters, is_train):
filename = tempfile.NamedTemporaryFile(delete=True).name
Expand Down

0 comments on commit 2d33c03

Please sign in to comment.