Skip to content

Commit

Permalink
Merge pull request #816 from helmholtz-analytics/features/746-print0-print-toggle
Browse files Browse the repository at this point in the history

Features/746 print0 print toggle
  • Loading branch information
coquelin77 committed Jan 24, 2022
2 parents 2867fe9 + e542860 commit 35f39d1
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 2 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
### Linear Algebra
- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot`

## Feature additions
### Communication
- [#868](https://github.com/helmholtz-analytics/heat/pull/868) New `MPICommunication` method `Split`

### DNDarray
- [#856](https://github.com/helmholtz-analytics/heat/pull/856) New `DNDarray` method `__torch_proxy__`
- [#885](https://github.com/helmholtz-analytics/heat/pull/885) New `DNDarray` method `conj`
Expand All @@ -35,6 +35,9 @@
- [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll`
- [#853](https://github.com/helmholtz-analytics/heat/pull/853) New Feature: `swapaxes`
- [#854](https://github.com/helmholtz-analytics/heat/pull/854) New Feature: `moveaxis`
### Printing
- [#816](https://github.com/helmholtz-analytics/heat/pull/816) New feature: Local printing (`ht.local_printing()`) and global printing options
- [#816](https://github.com/helmholtz-analytics/heat/pull/816) New feature: print only on process 0 with `print0(...)` and `ht.print0(...)`
### Random
- [#858](https://github.com/helmholtz-analytics/heat/pull/858) New Feature: `standard_normal`, `normal`
### Rounding
Expand Down
130 changes: 129 additions & 1 deletion heat/core/printing.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
"""Allows to output DNDarrays to stdout."""

import builtins
import copy
import torch
from .communication import MPI_WORLD

from .dndarray import DNDarray

__all__ = ["get_printoptions", "set_printoptions"]
__all__ = ["get_printoptions", "global_printing", "local_printing", "print0", "set_printoptions"]


# set the default printing width to a 120
_DEFAULT_LINEWIDTH = 120
torch.set_printoptions(profile="default", linewidth=_DEFAULT_LINEWIDTH)
LOCAL_PRINT = False

# printing
__PREFIX = "DNDarray"
Expand All @@ -24,6 +27,126 @@ def get_printoptions() -> dict:
return copy.copy(torch._tensor_str.PRINT_OPTS.__dict__)


def local_printing() -> None:
    """
    Switch to LOCAL printing mode: the builtin `print` function will print the
    process-local PyTorch tensor values for `DNDarray`s given as arguments,
    prefixed with ``[rank/commsize]`` per output line.

    Reverted with :func:`global_printing`.

    Examples
    --------
    >>> x = ht.arange(15 * 5, dtype=ht.float).reshape((15, 5)).resplit(0)
    >>> ht.local_printing()
    [0/2] Printing options set to LOCAL. DNDarrays will print the local PyTorch Tensors
    >>> print(x)
    [0/2] [[ 0., 1., 2., 3., 4.],
    [0/2] [ 5., 6., 7., 8., 9.],
    [0/2] [10., 11., 12., 13., 14.],
    [0/2] [15., 16., 17., 18., 19.],
    [0/2] [20., 21., 22., 23., 24.]]
    [1/2] [[25., 26., 27., 28., 29.],
    [1/2] [30., 31., 32., 33., 34.],
    [1/2] [35., 36., 37., 38., 39.],
    [1/2] [40., 41., 42., 43., 44.],
    [1/2] [45., 46., 47., 48., 49.]]
    [2/2] [[50., 51., 52., 53., 54.],
    [2/2] [55., 56., 57., 58., 59.],
    [2/2] [60., 61., 62., 63., 64.],
    [2/2] [65., 66., 67., 68., 69.],
    [2/2] [70., 71., 72., 73., 74.]]
    """
    # flip the module-level toggle; __str__ and print0 consult it
    global LOCAL_PRINT
    LOCAL_PRINT = True
    print0("Printing options set to LOCAL. DNDarrays will print the local PyTorch Tensors")


def global_printing() -> None:
    """
    Switch back to GLOBAL printing mode (the default): for `DNDarray`s, the
    builtin `print` gathers all of the data, formats it, then prints it on
    ONLY rank 0. Calling this while already in GLOBAL mode is a no-op.

    Returns
    -------
    None

    Examples
    --------
    >>> x = ht.arange(15 * 5, dtype=ht.float).reshape((15, 5)).resplit(0)
    >>> print(x)
    [0] DNDarray([[ 0., 1., 2., 3., 4.],
        ...
        [70., 71., 72., 73., 74.]], dtype=ht.float32, device=cpu:0, split=0)
    """
    global LOCAL_PRINT
    # only act (and announce) when we are actually leaving LOCAL mode
    if LOCAL_PRINT:
        LOCAL_PRINT = False
        print0(
            "Printing options set to GLOBAL. DNDarrays will be collected on process 0 before printing"
        )


def print0(*args, **kwargs) -> None:
    """
    Wrap the builtin `print` so the actual output happens on rank 0 only.

    In GLOBAL printing mode every `DNDarray` argument is first converted to its
    gathered string representation; in LOCAL printing mode arguments are passed
    through unchanged, so only the data local to process 0 is shown.
    This function is also installed as a builtin when importing heat.

    Examples
    --------
    >>> x = ht.arange(15 * 5, dtype=ht.float).reshape((15, 5)).resplit(0)
    >>> # GLOBAL PRINTING
    >>> ht.print0(x)
    [0] DNDarray([[ 0., 1., 2., 3., 4.],
        ...
        [70., 71., 72., 73., 74.]], dtype=ht.float32, device=cpu:0, split=0)
    >>> ht.local_printing()
    [0/2] Printing options set to LOCAL. DNDarrays will print the local PyTorch Tensors
    >>> print0(x)
    [0/2] [[ 0., 1., 2., 3., 4.],
    [0/2] [ 5., 6., 7., 8., 9.],
    [0/2] [10., 11., 12., 13., 14.],
    [0/2] [15., 16., 17., 18., 19.],
    [0/2] [20., 21., 22., 23., 24.]], device: cpu:0, split: 0
    """
    if not LOCAL_PRINT:
        # stringifying a split DNDarray gathers data, so it must run on every
        # rank — not just rank 0 — before the rank check below
        args = tuple(__str__(arg) if isinstance(arg, DNDarray) else arg for arg in args)
    if MPI_WORLD.rank == 0:
        print(*args, **kwargs)


builtins.print0 = print0


def set_printoptions(
precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None
):
Expand Down Expand Up @@ -67,6 +190,11 @@ def __str__(dndarray) -> str:
dndarray : DNDarray
The array for which to obtain the corresponding string
"""
if LOCAL_PRINT:
return (
torch._tensor_str._tensor_str(dndarray.larray, 0)
+ f", device: {dndarray.device}, split: {dndarray.split}"
)
tensor_string = _tensor_str(dndarray, __INDENT + 1)
if dndarray.comm.rank != 0:
return ""
Expand Down
36 changes: 36 additions & 0 deletions heat/core/tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,49 @@ class TestPrinting(TestCase):
def setUp(self):
    """Pin device and printing mode so the expected strings are deterministic."""
    # move to CPU only for the testing printing, otherwise the compare string will become messy
    ht.use_device("cpu")
    # start each test in the default GLOBAL mode in case a previous test
    # left the module-level LOCAL_PRINT toggle set
    ht.global_printing()

def tearDown(self):
    """Restore print options and the default device after each test."""
    # reset the print options back to default after each test run
    ht.set_printoptions(profile="default")
    # reset the default device
    ht.use_device(self.device)

def test_local_printing(self):
    """Compare GLOBAL vs LOCAL string representations of a split DNDarray."""
    x = ht.arange(15 * 5, dtype=ht.float).reshape((15, 5)).resplit(0)
    expected_global = (
        "DNDarray([[ 0., 1., 2., 3., 4.],\n"
        " [ 5., 6., 7., 8., 9.],\n"
        " [10., 11., 12., 13., 14.],\n"
        " [15., 16., 17., 18., 19.],\n"
        " [20., 21., 22., 23., 24.],\n"
        " [25., 26., 27., 28., 29.],\n"
        " [30., 31., 32., 33., 34.],\n"
        " [35., 36., 37., 38., 39.],\n"
        " [40., 41., 42., 43., 44.],\n"
        " [45., 46., 47., 48., 49.],\n"
        " [50., 51., 52., 53., 54.],\n"
        " [55., 56., 57., 58., 59.],\n"
        " [60., 61., 62., 63., 64.],\n"
        " [65., 66., 67., 68., 69.],\n"
        " [70., 71., 72., 73., 74.]], dtype=ht.float32, device=cpu:0, split=0)"
    )
    # in GLOBAL mode only rank 0 renders the full array, every other rank ""
    if x.comm.rank != 0:
        self.assertEqual(str(x), "")
    else:
        self.assertEqual(str(x), expected_global)
    ht.local_printing()
    expected_local = (
        "[[ 0., 1., 2., 3., 4.],\n"
        " [ 5., 6., 7., 8., 9.],\n"
        " [10., 11., 12., 13., 14.],\n"
        " [15., 16., 17., 18., 19.],\n"
        " [20., 21., 22., 23., 24.]], device: cpu:0, split: 0"
    )
    # the local comparison string is only valid for rank 0 of a 3-process run
    if x.comm.size == 3 and x.comm.rank == 0:
        self.assertEqual(str(x), expected_local)
    ht.global_printing()  # needed to keep things correct for the other tests

def test_get_default_options(self):
print_options = ht.get_printoptions()
comparison = {
Expand Down

0 comments on commit 35f39d1

Please sign in to comment.