diff --git a/classy_vision/generic/distributed_util.py b/classy_vision/generic/distributed_util.py
index 3deab3df29..8395c1f96e 100644
--- a/classy_vision/generic/distributed_util.py
+++ b/classy_vision/generic/distributed_util.py
@@ -82,7 +82,7 @@ def all_reduce_sum(tensor):
     return tensor
 
 
-def gather_from_all(tensor):
+def gather_tensors_from_all(tensor):
     """
     Wrapper over torch.distributed.all_gather for performing
     'gather' of 'tensor' over all processes in both distributed /
@@ -91,11 +91,21 @@ def gather_from_all(tensor):
     if tensor.ndim == 0:
         # 0 dim tensors cannot be gathered. so unsqueeze
         tensor = tensor.unsqueeze(0)
-    gathered_tensor = [
-        torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
-    ]
-    torch.distributed.all_gather(gathered_tensor, tensor)
-    gathered_tensor = torch.cat(gathered_tensor, 0)
+
+    if is_distributed_training_run():
+        gathered_tensors = [
+            torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
+        ]
+        torch.distributed.all_gather(gathered_tensors, tensor)
+    else:
+        gathered_tensors = [tensor]
+
+    return gathered_tensors
+
+
+def gather_from_all(tensor):
+    gathered_tensors = gather_tensors_from_all(tensor)
+    gathered_tensor = torch.cat(gathered_tensors, 0)
     return gathered_tensor
 
 
diff --git a/classy_vision/hooks/tensorboard_plot_hook.py b/classy_vision/hooks/tensorboard_plot_hook.py
index 35f55ac63b..367ed20c56 100644
--- a/classy_vision/hooks/tensorboard_plot_hook.py
+++ b/classy_vision/hooks/tensorboard_plot_hook.py
@@ -129,7 +129,15 @@ def on_phase_end(
                 log.warn(f"Skipping meter {meter.name} with value: {meter.value}")
                 continue
             for name, value in meter.value.items():
-                meter_key = f"{phase_type}_{meter.name}_{name}"
-                self.tb_writer.add_scalar(meter_key, value, global_step=phase_type_idx)
+                if isinstance(value, int) or isinstance(value, float):
+                    meter_key = f"{phase_type}_{meter.name}_{name}"
+                    self.tb_writer.add_scalar(
+                        meter_key, value, global_step=phase_type_idx
+                    )
+                else:
+                    log.warn(
+                        f"Skipping meter name {meter.name}_{name} with value: {value}"
+                    )
+                    continue
 
         logging.info(f"Done plotting to Tensorboard")
diff --git a/test/generic/meter_test_utils.py b/test/generic/meter_test_utils.py
index 7d489f8df4..a7d9a83ba4 100644
--- a/test/generic/meter_test_utils.py
+++ b/test/generic/meter_test_utils.py
@@ -49,10 +49,13 @@ def _meter_worker(qin, qout, meter, is_train, world_size, rank, filename):
             continue
 
         if signal == UPDATE_SIGNAL:
+            print("signal UPDATE_SIGNAL")
            meter.update(val[0], val[1], is_train=is_train)
 
         elif signal == VALUE_SIGNAL:
+            print("signal VALUE_SIGNAL")
            meter.sync_state()
+            print("meter.value: %s" % str(meter.value))
            qout.put(meter.value)
 
        elif signal == SHUTDOWN_SIGNAL:
@@ -106,23 +109,45 @@ def _apply_updates_and_test_meter(
             self.assertTrue(
                 key in meter_value, msg="{0} not in meter value!".format(key)
             )
-            self.assertAlmostEqual(
-                meter_value[key],
-                val,
-                places=4,
-                msg="{0} meter value mismatch!".format(key),
-            )
+            if torch.is_tensor(meter_value[key]):
+                self.assertTrue(
+                    torch.all(torch.eq(meter_value[key], val)),
+                    msg="{0} meter value mismatch!".format(key),
+                )
+            else:
+                self.assertAlmostEqual(
+                    meter_value[key],
+                    val,
+                    places=4,
+                    msg="{0} meter value mismatch!".format(key),
+                )
 
     def _values_match_expected_value(self, value0, value1, expected_value):
         for key, val in expected_value.items():
             self.assertTrue(key in value0, msg="{0} not in meter value!".format(key))
-            self.assertAlmostEqual(
-                value0[key], val, places=4, msg="{0} meter value mismatch!".format(key)
-            )
             self.assertTrue(key in value1, msg="{0} not in meter value!".format(key))
-            self.assertAlmostEqual(
-                value1[key], val, places=4, msg="{0} meter value mismatch!".format(key)
-            )
+            if torch.is_tensor(val):
+                self.assertTrue(
+                    torch.all(torch.eq(value0[key], val)),
+                    "{0} meter value mismatch!".format(key),
+                )
+                self.assertTrue(
+                    torch.all(torch.eq(value1[key], val)),
+                    "{0} meter value mismatch!".format(key),
+                )
+            else:
+                self.assertAlmostEqual(
+                    value0[key],
+                    val,
+                    places=4,
+                    msg="{0} meter value mismatch!".format(key),
+                )
+                self.assertAlmostEqual(
+                    value1[key],
+                    val,
+                    places=4,
+                    msg="{0} meter value mismatch!".format(key),
+                )
 
     def meter_update_and_reset_test(
         self, meter, model_outputs, targets, expected_value, **kwargs
@@ -192,25 +217,43 @@ def meter_get_set_classy_state_test(
         meter1.sync_state()
         value1 = meter1.value
         for key, val in value0.items():
-            self.assertNotEqual(
-                value1[key], val, msg="{0} meter values should not be same!".format(key)
-            )
+            if torch.is_tensor(value1[key]):
+                self.assertFalse(
+                    torch.all(torch.eq(value1[key], val)),
+                    msg="{0} meter values should not be same!".format(key),
+                )
+            else:
+                self.assertNotEqual(
+                    value1[key],
+                    val,
+                    msg="{0} meter values should not be same!".format(key),
+                )
         meter0.set_classy_state(meter1.get_classy_state())
         value0 = meter0.value
         for key, val in value0.items():
-            self.assertAlmostEqual(
-                value1[key],
-                val,
-                places=4,
-                msg="{0} meter value mismatch after state transfer!".format(key),
-            )
-            self.assertAlmostEqual(
-                value1[key],
-                expected_value[key],
-                places=4,
-                msg="{0} meter value mismatch from ground truth!".format(key),
-            )
+            if torch.is_tensor(value1[key]):
+                self.assertTrue(
+                    torch.all(torch.eq(value1[key], val)),
+                    msg="{0} meter value mismatch after state transfer!".format(key),
+                )
+                self.assertTrue(
+                    torch.all(torch.eq(value1[key], expected_value[key])),
+                    msg="{0} meter value mismatch from ground truth!".format(key),
+                )
+            else:
+                self.assertAlmostEqual(
+                    value1[key],
+                    val,
+                    places=4,
+                    msg="{0} meter value mismatch after state transfer!".format(key),
+                )
+                self.assertAlmostEqual(
+                    value1[key],
+                    expected_value[key],
+                    places=4,
+                    msg="{0} meter value mismatch from ground truth!".format(key),
+                )
 
     def _spawn_all_meter_workers(self, world_size, meters, is_train):
         filename = tempfile.NamedTemporaryFile(delete=True).name