mean average precision meter (#265)
Summary:
Pull Request resolved: #265

Average Precision (AP) summarizes a precision-recall curve as the weighted mean of precisions achieved at each threshold, with the increase in recall from the previous threshold used as the weight.

See the formal definition of AP (https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html).

Mean Average Precision (mAP) is the average of AP across classes.

In this diff, we implement a Mean Average Precision meter, which is useful for multi-label classification tasks.
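For illustration only, a minimal sketch of the metric itself: per class, AP = sum_n (R_n - R_{n-1}) * P_n over the precision-recall curve, and mAP is the unweighted mean of the per-class APs. The sketch below uses scikit-learn and made-up data; it shows the metric, not the implementation of the meter added in this diff.

    import numpy as np
    from sklearn.metrics import average_precision_score

    # Hypothetical multi-label example: 4 samples, 3 classes.
    # targets[i, c] == 1 means sample i belongs to class c.
    targets = np.array([[1, 0, 1],
                        [0, 1, 0],
                        [1, 1, 0],
                        [0, 0, 1]])
    scores = np.array([[0.9, 0.2, 0.8],
                       [0.1, 0.7, 0.3],
                       [0.6, 0.8, 0.2],
                       [0.3, 0.1, 0.9]])

    # AP per class, then mAP as the unweighted mean over classes.
    per_class_ap = [
        average_precision_score(targets[:, c], scores[:, c])
        for c in range(targets.shape[1])
    ]
    mean_ap = float(np.mean(per_class_ap))
    print(per_class_ap, mean_ap)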

Differential Revision: D18715190

fbshipit-source-id: 83cffde189d4e2f100523c1d33353d85fd4815e7
stephenyan1231 authored and facebook-github-bot committed Nov 30, 2019
1 parent 790b048 commit cdb6415
Showing 3 changed files with 96 additions and 35 deletions.
22 changes: 16 additions & 6 deletions classy_vision/generic/distributed_util.py
@@ -82,7 +82,7 @@ def all_reduce_sum(tensor):
return tensor


def gather_from_all(tensor):
def gather_tensors_from_all(tensor):
"""
Wrapper over torch.distributed.all_gather for performing
'gather' of 'tensor' over all processes in both distributed /
@@ -91,11 +91,21 @@ def gather_from_all(tensor):
if tensor.ndim == 0:
# 0 dim tensors cannot be gathered. so unsqueeze
tensor = tensor.unsqueeze(0)
gathered_tensor = [
torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(gathered_tensor, tensor)
gathered_tensor = torch.cat(gathered_tensor, 0)

if is_distributed_training_run():
gathered_tensors = [
torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(gathered_tensors, tensor)
else:
gathered_tensors = [tensor]

return gathered_tensors


def gather_from_all(tensor):
gathered_tensors = gather_tensors_from_all(tensor)
gathered_tensor = torch.cat(gathered_tensors, 0)
return gathered_tensor


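A hedged sketch of how these helpers might be used (the helper below is hypothetical and not part of this diff): a meter could gather per-rank score and target tensors before computing mAP, with gather_from_all falling back to the local tensor in a non-distributed run.

    from classy_vision.generic.distributed_util import gather_from_all

    # Hypothetical helper: concatenate predictions from all ranks so that
    # mAP is computed over the full evaluation set, not just the local shard.
    def sync_predictions(local_scores, local_targets):
        scores = gather_from_all(local_scores)    # shape: (world_size * N, num_classes)
        targets = gather_from_all(local_targets)  # same leading dimension as scores
        return scores, targets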
12 changes: 10 additions & 2 deletions classy_vision/hooks/tensorboard_plot_hook.py
@@ -129,7 +129,15 @@ def on_phase_end(
log.warn(f"Skipping meter {meter.name} with value: {meter.value}")
continue
for name, value in meter.value.items():
meter_key = f"{phase_type}_{meter.name}_{name}"
self.tb_writer.add_scalar(meter_key, value, global_step=phase_type_idx)
if isinstance(value, int) or isinstance(value, float):
meter_key = f"{phase_type}_{meter.name}_{name}"
self.tb_writer.add_scalar(
meter_key, value, global_step=phase_type_idx
)
else:
log.warn(
f"Skipping meter name {meter.name}_{name} with value: {value}"
)
continue

logging.info(f"Done plotting to Tensorboard")
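To illustrate why the hook now checks for scalar values: a mean-AP-style meter can expose both a scalar summary and per-class tensors in its value dict. The snippet below uses a hypothetical value dict and print statements in place of the hook's writer and logger; only the scalar entry would reach add_scalar.

    import torch

    # Hypothetical meter value dict; key names are illustrative only.
    meter_value = {
        "mAP": 0.73,                                 # scalar -> plotted via add_scalar
        "per_class_AP": torch.tensor([0.90, 0.55]),  # tensor -> skipped with a warning
    }
    for name, value in meter_value.items():
        if isinstance(value, (int, float)):
            print(f"add_scalar for {name}: {value}")
        else:
            print(f"skip non-scalar meter value {name}: {value}")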
97 changes: 70 additions & 27 deletions test/generic/meter_test_utils.py
@@ -49,10 +49,13 @@ def _meter_worker(qin, qout, meter, is_train, world_size, rank, filename):
continue

if signal == UPDATE_SIGNAL:
print("signal UPDATE_SIGNAL")
meter.update(val[0], val[1], is_train=is_train)

elif signal == VALUE_SIGNAL:
print("signal VALUE_SIGNAL")
meter.sync_state()
print("meter.value: %s" % str(meter.value))
qout.put(meter.value)

elif signal == SHUTDOWN_SIGNAL:
@@ -106,23 +109,45 @@ def _apply_updates_and_test_meter(
self.assertTrue(
key in meter_value, msg="{0} not in meter value!".format(key)
)
self.assertAlmostEqual(
meter_value[key],
val,
places=4,
msg="{0} meter value mismatch!".format(key),
)
if torch.is_tensor(meter_value[key]):
self.assertTrue(
torch.all(torch.eq(meter_value[key], val)),
msg="{0} meter value mismatch!".format(key),
)
else:
self.assertAlmostEqual(
meter_value[key],
val,
places=4,
msg="{0} meter value mismatch!".format(key),
)

def _values_match_expected_value(self, value0, value1, expected_value):
for key, val in expected_value.items():
self.assertTrue(key in value0, msg="{0} not in meter value!".format(key))
self.assertAlmostEqual(
value0[key], val, places=4, msg="{0} meter value mismatch!".format(key)
)
self.assertTrue(key in value1, msg="{0} not in meter value!".format(key))
self.assertAlmostEqual(
value1[key], val, places=4, msg="{0} meter value mismatch!".format(key)
)
if torch.is_tensor(val):
self.assertTrue(
torch.all(torch.eq(value0[key], val)),
"{0} meter value mismatch!".format(key),
)
self.assertTrue(
torch.all(torch.eq(value1[key], val)),
"{0} meter value mismatch!".format(key),
)
else:
self.assertAlmostEqual(
value0[key],
val,
places=4,
msg="{0} meter value mismatch!".format(key),
)
self.assertAlmostEqual(
value1[key],
val,
places=4,
msg="{0} meter value mismatch!".format(key),
)
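The same is_tensor branching recurs in several assertions in this file; a hypothetical helper (not part of this diff) could express the pattern once: exact equality for tensor values such as per-class AP vectors, approximate equality for plain floats.

    import torch

    def assert_meter_value_matches(test_case, actual, expected, msg):
        # Tensors are compared element-wise and exactly; floats approximately,
        # mirroring the branching used in the assertions above.
        if torch.is_tensor(expected):
            test_case.assertTrue(torch.all(torch.eq(actual, expected)), msg=msg)
        else:
            test_case.assertAlmostEqual(actual, expected, places=4, msg=msg)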

def meter_update_and_reset_test(
self, meter, model_outputs, targets, expected_value, **kwargs
@@ -192,25 +217,43 @@ def meter_get_set_classy_state_test(
meter1.sync_state()
value1 = meter1.value
for key, val in value0.items():
self.assertNotEqual(
value1[key], val, msg="{0} meter values should not be same!".format(key)
)
if torch.is_tensor(value1[key]):
self.assertFalse(
torch.all(torch.eq(value1[key], val)),
msg="{0} meter values should not be same!".format(key),
)
else:
self.assertNotEqual(
value1[key],
val,
msg="{0} meter values should not be same!".format(key),
)

meter0.set_classy_state(meter1.get_classy_state())
value0 = meter0.value
for key, val in value0.items():
self.assertAlmostEqual(
value1[key],
val,
places=4,
msg="{0} meter value mismatch after state transfer!".format(key),
)
self.assertAlmostEqual(
value1[key],
expected_value[key],
places=4,
msg="{0} meter value mismatch from ground truth!".format(key),
)
if torch.is_tensor(value1[key]):
self.assertTrue(
torch.all(torch.eq(value1[key], val)),
msg="{0} meter value mismatch after state transfer!".format(key),
)
self.assertTrue(
torch.all(torch.eq(value1[key], expected_value[key])),
msg="{0} meter value mismatch from ground truth!".format(key),
)
else:
self.assertAlmostEqual(
value1[key],
val,
places=4,
msg="{0} meter value mismatch after state transfer!".format(key),
)
self.assertAlmostEqual(
value1[key],
expected_value[key],
places=4,
msg="{0} meter value mismatch from ground truth!".format(key),
)

def _spawn_all_meter_workers(self, world_size, meters, is_train):
filename = tempfile.NamedTemporaryFile(delete=True).name
