
Commit

used all gather
justanhduc committed Jun 17, 2021
1 parent bc5c7df commit 1991a36
Showing 1 changed file with 51 additions and 20 deletions.
71 changes: 51 additions & 20 deletions neural_monitor/monitor.py
@@ -211,9 +211,8 @@ def distributed_collect(f):
def func(self, name: str, value: T.Tensor, *args, **kwargs):
if self.distributed:
assert isinstance(value, T.Tensor), 'value must be a Tensor in distributed mode'

T.distributed.all_reduce(value, op=T.distributed.ReduceOp.SUM)
value = value / T.distributed.get_world_size()
tensor_list = [T.zeros_like(value) for _ in range(self.world_size)]
T.distributed.all_gather(tensor_list, value)
return f(self, name, value, *args, **kwargs)

return func
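For readers unfamiliar with the collective used above, here is a minimal standalone sketch (not part of this commit; plain torch.distributed names, run under e.g. torchrun with one process per GPU) of what all_gather does compared to the previous all_reduce-and-average: it fills a pre-allocated list with one tensor per rank, so the caller sees every rank's value rather than a single mean.

import torch
import torch.distributed as dist

def gather_from_all_ranks(value: torch.Tensor) -> list:
    # one placeholder per rank, matching the local tensor's shape and dtype
    gathered = [torch.zeros_like(value) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, value)  # after this, every rank holds the full list
    return gathered

def average_across_ranks(value: torch.Tensor) -> torch.Tensor:
    # the behaviour this commit replaces: sum in place, then divide by world size
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    return value / dist.get_world_size()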
@@ -229,6 +228,21 @@ def func(self, *args, **kwargs):
return func


def standardize_image(img):
if isinstance(img, T.Tensor):
img = utils.to_numpy(img)

if img.dtype != 'uint8':
img = (255.99 * img).astype('uint8')

if len(img.shape) == 3:
img = np.transpose(img, (2, 0, 1))[None]
elif len(img.shape) == 2:
img = img[None, None]

return img
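A quick usage sketch of the new helper, assuming standardize_image above is in scope (the arrays below are invented for illustration): a float HWC image in [0, 1] or a 2-D grayscale array comes out as a uint8 batch of one in NCHW layout.

import numpy as np

rgb = np.random.rand(32, 32, 3)              # float HWC image in [0, 1]
batch = standardize_image(rgb)               # scaled to uint8, transposed, batched
assert batch.dtype == np.uint8 and batch.shape == (1, 3, 32, 32)

gray = np.random.rand(32, 32)                # 2-D grayscale image
assert standardize_image(gray).shape == (1, 1, 32, 32)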


class Monitor:
"""
Collects statistics and displays the results using various backends.
@@ -341,6 +355,7 @@ def __init__(self):
self._thread = threading.Thread(target=self._flush, daemon=True)
self.rank = None
self.distributed = None
self.world_size = None

# schedule to flush when the program finishes
atexit.register(self._atexit)
@@ -413,6 +428,7 @@ def initialize(self, model_name: Optional[str] = None, root: Optional[str] = Non
self.with_git = with_git
self.distributed = T.distributed.is_initialized()
self.rank = T.distributed.get_rank() if self.distributed else 0
self.world_size = T.distributed.get_world_size() if self.distributed else 1
if self.distributed and self.rank != 0:
return

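For context, a hedged sketch of driving this in a distributed job (the import path, model name, and backend below are assumptions, not taken from the repository): every rank calls initialize, but with the check above only rank 0 goes on to set up logging.

import torch.distributed as dist
from neural_monitor.monitor import Monitor   # assumed import path

dist.init_process_group(backend='nccl')      # e.g. launched with torchrun, one process per GPU

monitor = Monitor()
monitor.initialize(model_name='my-model')    # hypothetical name; ranks other than 0 return early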
@@ -931,17 +947,23 @@ def filter_files(file_tuples):
except FileNotFoundError:
root_logger.warning('No such file or directory: %s' % src)

@distributed_collect
@standardize_name
def add_hparam(self, name: str, value: Union[T.Tensor, np.ndarray, float]):
if name not in self._options[self._hparams].keys():
if isinstance(value, (list, tuple)): # in distributed mode
value = value[-1]
if isinstance(value, T.Tensor):
value = utils.to_numpy(value)

self._options[self._hparams][name] = value

@distributed_collect
@standardize_name
def add_metric(self, name: str, value: Union[T.Tensor, np.ndarray, float]):
if name not in self._options[self._hparam_metrics].keys():
if isinstance(value, (list, tuple)): # in distributed mode
value = value[-1]
if isinstance(value, T.Tensor):
value = utils.to_numpy(value)

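With the decorator above now handing these methods a list of per-rank tensors, a small sketch (values invented) of the branch they both take: the list is reduced to its last element before being converted and stored.

import torch

gathered = [torch.tensor(0.91), torch.tensor(0.93)]   # e.g. one accuracy value per rank
value = gathered[-1] if isinstance(gathered, (list, tuple)) else gathered
# value is now the last rank's tensor; it is converted to numpy before being stored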
@@ -978,6 +1000,9 @@ def plot(self, name: str, value: Union[T.Tensor, np.ndarray, float], smooth: Opt
self._options[name]['smooth'] = smooth
self._options[name]['filter_outliers'] = filter_outliers
self._options[name]['precision'] = precision
if isinstance(value, (list, tuple)):
value = sum(value) / len(value)

if isinstance(value, T.Tensor):
value = utils.to_numpy(value)

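A one-line sketch (numbers invented) of the averaging the plot hunk above applies when it receives a gathered list: per-rank scalars are reduced to their mean before being recorded.

import torch

per_rank_loss = [torch.tensor(0.52), torch.tensor(0.48)]   # one loss per rank
mean_loss = sum(per_rank_loss) / len(per_rank_loss)        # tensor(0.5000), what plot() records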
@@ -1015,14 +1040,20 @@ def plot_matrix(self, name: str, value: Union[T.Tensor, np.ndarray, float],

self._options[name]['labels'] = labels
self._options[name]['show_values'] = show_values
if isinstance(value, (list, tuple)):
raise ValueError('Plotting a list of matrices is not supported')

if isinstance(value, T.Tensor):
value = utils.to_numpy(value)

self._mat_since_last_flush[name] = np.array(value)

@distributed_collect
@standardize_name
def scatter(self, name: str, value: Union[T.Tensor, np.ndarray], latest_only: bool = False, **kwargs):
def scatter(self, name: str,
value: Union[T.Tensor, np.ndarray, List[T.Tensor], List[np.ndarray]],
latest_only: bool = False,
**kwargs):
"""
schedules a scatter plot of (a batch of) points.
A 3D :mod:`matplotlib` figure will be rendered and saved every :attr:`~print_freq` iterations.
@@ -1039,15 +1070,22 @@ def scatter(self, name: str, value: Union[T.Tensor, np.ndarray], latest_only: bo
"""

self._options[name]['latest_only'] = latest_only
if isinstance(value, T.Tensor):
value = utils.to_numpy(value)
if isinstance(value, (list, tuple)):
value = [utils.to_numpy(v[None] if len(v.shape) == 2 else v) for v in value]
else:
if isinstance(value, T.Tensor):
value = utils.to_numpy(value)

if len(value.shape) == 2:
value = value[None]
if len(value.shape) == 2:
value = value[None]

self._points_since_last_flush[name][self.iter] = value
if self.writer is not None:
self.writer.add_mesh(name, value, global_step=self.iter, **kwargs)
if isinstance(value, list):
for i, v in enumerate(value):
self.writer.add_mesh(f'{name}-{i}', v, global_step=self.iter, **kwargs)
else:
self.writer.add_mesh(name, value, global_step=self.iter, **kwargs)

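A hedged sketch (random data, hypothetical tags) of how the gathered branch above prepares per-rank point clouds: 2-D (N, 3) arrays get a leading batch axis, and each rank's cloud is written to TensorBoard as its own mesh.

import numpy as np

# the gathered list as it would arrive from two ranks
gathered = [np.random.rand(1024, 3).astype('float32') for _ in range(2)]
prepared = [v[None] if len(v.shape) == 2 else v for v in gathered]   # each -> (1, 1024, 3)
# with a TensorBoard writer, each entry would then be logged under its own tag,
# e.g. 'point-cloud-0', 'point-cloud-1'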
@distributed_collect
@standardize_name
@@ -1078,17 +1116,8 @@ def imwrite(self, name: str, value: Union[T.Tensor, np.ndarray], latest_only: Op
"""

self._options[name]['latest_only'] = latest_only
if isinstance(value, T.Tensor):
value = utils.to_numpy(value)

if value.dtype != 'uint8':
value = (255.99 * value).astype('uint8')

if len(value.shape) == 3:
value = np.transpose(value, (2, 0, 1))[None]
elif len(value.shape) == 2:
value = value[None, None]

value = np.concatenate([standardize_image(v) for v in value], 0) \
if isinstance(value, (list, tuple)) else standardize_image(value) # handler for distributed training
self._img_since_last_flush[name][self.iter] = value
if self.writer is not None:
prefix = kwargs.pop('prefix', 'image/')
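A short sketch (random images) of the distributed handler above, assuming standardize_image from earlier in this file: each rank's float HWC image becomes a (1, C, H, W) uint8 array and the per-rank results are concatenated along the batch axis.

import numpy as np

per_rank = [np.random.rand(32, 32, 3) for _ in range(2)]              # one image per rank
batch = np.concatenate([standardize_image(v) for v in per_rank], 0)   # -> (2, 3, 32, 32)
assert batch.shape == (2, 3, 32, 32) and batch.dtype == np.uint8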
@@ -1117,6 +1146,8 @@ def hist(self, name, value: Union[T.Tensor, np.ndarray], n_bins: int = 20, lates

self._options[name]['latest_only'] = latest_only
self._options[name]['n_bins'] = n_bins
if isinstance(value, (list, tuple)): # in distributed training
value = T.stack(value)
if isinstance(value, T.Tensor):
value = utils.to_numpy(value)

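A last sketch (random tensors) of the stacking step above: per-rank samples gathered into a list are stacked into a single tensor so one histogram covers every rank.

import torch

gathered = [torch.randn(1000), torch.randn(1000)]   # e.g. one batch of activations per rank
value = torch.stack(gathered)                       # shape (2, 1000), histogrammed as a whole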
