Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/746 print0 print toggle #816

Merged
merged 20 commits into from
Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
9b05d6f
basic outline of local printing flag and print0 function
coquelin77 Mar 16, 2021
d62c405
print0 now plays well with global printing
coquelin77 Mar 16, 2021
42407a0
added print0 to the global namespace
coquelin77 Mar 18, 2021
0a7fe8f
doc updates
coquelin77 Mar 18, 2021
7a30e00
Merge branch 'master' into features/746-print0-print-toggle
coquelin77 Apr 6, 2021
9222540
added printed statement to notify user about printing change
coquelin77 Apr 6, 2021
278f82b
Added tests for local printing
coquelin77 Jun 22, 2021
11b01d2
added global print toggle to all printing functions just in case. cha…
coquelin77 Jun 22, 2021
483d525
Merge branch 'master' into features/746-print0-print-toggle
coquelin77 Jul 6, 2021
365dc36
Merge branch 'master' into features/746-print0-print-toggle
coquelin77 Jul 13, 2021
aee0225
Merge branch 'master' into features/746-print0-print-toggle
mtar Jul 28, 2021
68fd0fb
update changelog
mtar Jul 28, 2021
7e403c7
Update heat/core/printing.py
coquelin77 Oct 8, 2021
6c131bd
Merge branch 'master' into features/746-print0-print-toggle
coquelin77 Oct 8, 2021
585c659
Merge branch 'master' into features/746-print0-print-toggle
coquelin77 Oct 13, 2021
8cdc7a7
Merge branch 'master' into features/746-print0-print-toggle
coquelin77 Jan 4, 2022
976f505
adjusted print0 to show device and split, dtype handled by torch, red…
coquelin77 Jan 4, 2022
bb10637
moved device and split info to str generation when local printing is …
coquelin77 Jan 4, 2022
912df19
Merge branch 'master' into features/746-print0-print-toggle
coquelin77 Jan 17, 2022
e542860
Merge branch 'master' into features/746-print0-print-toggle
mtar Jan 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
### Linear Algebra
- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot`

## Feature additions
### Communication
- [#868](https://github.com/helmholtz-analytics/heat/pull/868) New `MPICommunication` method `Split`

### DNDarray
- [#856](https://github.com/helmholtz-analytics/heat/pull/856) New `DNDarray` method `__torch_proxy__`
- [#885](https://github.com/helmholtz-analytics/heat/pull/885) New `DNDarray` method `conj`
Expand All @@ -31,6 +31,9 @@
- [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll`
- [#853](https://github.com/helmholtz-analytics/heat/pull/853) New Feature: `swapaxes`
- [#854](https://github.com/helmholtz-analytics/heat/pull/854) New Feature: `moveaxis`
### Printing
- [#816](https://github.com/helmholtz-analytics/heat/pull/816) New feature: Local printing (`ht.local_printing()`) and global printing options
- [#816](https://github.com/helmholtz-analytics/heat/pull/816) New feature: print only on process 0 with `print0(...)` and `ht.print0(...)`
### Random
- [#858](https://github.com/helmholtz-analytics/heat/pull/858) New Feature: `standard_normal`, `normal`
### Rounding
Expand Down
130 changes: 129 additions & 1 deletion heat/core/printing.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
"""Allows to output DNDarrays to stdout."""

import builtins
import copy
import torch
from .communication import MPI_WORLD

from .dndarray import DNDarray

__all__ = ["get_printoptions", "set_printoptions"]
__all__ = ["get_printoptions", "global_printing", "local_printing", "print0", "set_printoptions"]


# set the default printing width to a 120
_DEFAULT_LINEWIDTH = 120
torch.set_printoptions(profile="default", linewidth=_DEFAULT_LINEWIDTH)
LOCAL_PRINT = False

# printing
__PREFIX = "DNDarray"
Expand All @@ -24,6 +27,126 @@ def get_printoptions() -> dict:
return copy.copy(torch._tensor_str.PRINT_OPTS.__dict__)


def local_printing() -> None:
    """
    Set printing to LOCAL mode: the builtin `print` function will now print the
    process-local PyTorch Tensor values for `DNDarrays` given as arguments,
    instead of gathering the full array on process 0 first.

    Examples
    --------
    >>> x = ht.arange(15 * 5, dtype=ht.float).reshape((15, 5)).resplit(0)
    >>> ht.local_printing()
    [0/2] Printing options set to LOCAL. DNDarrays will print the local PyTorch Tensors
    >>> print(x)
    [0/2] [[ 0., 1., 2., 3., 4.],
    [0/2] [ 5., 6., 7., 8., 9.],
    [0/2] [10., 11., 12., 13., 14.],
    [0/2] [15., 16., 17., 18., 19.],
    [0/2] [20., 21., 22., 23., 24.]]
    [1/2] [[25., 26., 27., 28., 29.],
    [1/2] [30., 31., 32., 33., 34.],
    [1/2] [35., 36., 37., 38., 39.],
    [1/2] [40., 41., 42., 43., 44.],
    [1/2] [45., 46., 47., 48., 49.]]
    [2/2] [[50., 51., 52., 53., 54.],
    [2/2] [55., 56., 57., 58., 59.],
    [2/2] [60., 61., 62., 63., 64.],
    [2/2] [65., 66., 67., 68., 69.],
    [2/2] [70., 71., 72., 73., 74.]]
    """
    # module-level flag read by __str__ and print0 to switch between local and global printing
    global LOCAL_PRINT
    LOCAL_PRINT = True
    # notify the user about the mode change; print0 emits this on rank 0 only
    print0("Printing options set to LOCAL. DNDarrays will print the local PyTorch Tensors")


def global_printing() -> None:
    """
    Set printing back to GLOBAL mode (the default): for `DNDarray`s, the builtin
    `print` function will gather all of the data, format it and then print it on
    ONLY rank 0.

    Returns
    -------
    None

    Examples
    --------
    >>> x = ht.arange(15 * 5, dtype=ht.float).reshape((15, 5)).resplit(0)
    >>> print(x)
    [0] DNDarray([[ 0., 1., 2., 3., 4.],
    [ 5., 6., 7., 8., 9.],
    [10., 11., 12., 13., 14.],
    [15., 16., 17., 18., 19.],
    [20., 21., 22., 23., 24.],
    [25., 26., 27., 28., 29.],
    [30., 31., 32., 33., 34.],
    [35., 36., 37., 38., 39.],
    [40., 41., 42., 43., 44.],
    [45., 46., 47., 48., 49.],
    [50., 51., 52., 53., 54.],
    [55., 56., 57., 58., 59.],
    [60., 61., 62., 63., 64.],
    [65., 66., 67., 68., 69.],
    [70., 71., 72., 73., 74.]], dtype=ht.float32, device=cpu:0, split=0)
    """
    global LOCAL_PRINT
    # only flip the flag and notify when the mode actually changes;
    # calling this while already in GLOBAL mode is a silent no-op
    if LOCAL_PRINT:
        LOCAL_PRINT = False
        print0(
            "Printing options set to GLOBAL. DNDarrays will be collected on process 0 before printing"
        )


def print0(*args, **kwargs) -> None:
    """
    Wrapper around the builtin `print` function that only produces output on
    rank 0. In GLOBAL printing mode, any `DNDarray` arguments are formatted
    (i.e. gathered) on every process before the rank check, since the gather is
    collective; in LOCAL printing mode, only the data local to process 0 is
    printed. For more information see the examples.

    This function is also available as a builtin when importing heat.

    Examples
    --------
    >>> x = ht.arange(15 * 5, dtype=ht.float).reshape((15, 5)).resplit(0)
    >>> # GLOBAL PRINTING
    >>> ht.print0(x)
    [0] DNDarray([[ 0., 1., 2., 3., 4.],
    [ 5., 6., 7., 8., 9.],
    [10., 11., 12., 13., 14.],
    [15., 16., 17., 18., 19.],
    [20., 21., 22., 23., 24.],
    [25., 26., 27., 28., 29.],
    [30., 31., 32., 33., 34.],
    [35., 36., 37., 38., 39.],
    [40., 41., 42., 43., 44.],
    [45., 46., 47., 48., 49.],
    [50., 51., 52., 53., 54.],
    [55., 56., 57., 58., 59.],
    [60., 61., 62., 63., 64.],
    [65., 66., 67., 68., 69.],
    [70., 71., 72., 73., 74.]], dtype=ht.float32, device=cpu:0, split=0)
    >>> ht.local_printing()
    [0/2] Printing options set to LOCAL. DNDarrays will print the local PyTorch Tensors
    >>> print0(x)
    [0/2] [[ 0., 1., 2., 3., 4.],
    [0/2] [ 5., 6., 7., 8., 9.],
    [0/2] [10., 11., 12., 13., 14.],
    [0/2] [15., 16., 17., 18., 19.],
    [0/2] [20., 21., 22., 23., 24.]], device: cpu:0, split: 0
    """
    if not LOCAL_PRINT:
        # GLOBAL mode: pre-format DNDarrays on ALL ranks — __str__ gathers the
        # data collectively, so every process must participate
        args = tuple(__str__(arg) if isinstance(arg, DNDarray) else arg for arg in args)
    if MPI_WORLD.rank == 0:
        print(*args, **kwargs)


# expose print0 as a builtin so it is available everywhere after `import heat`
builtins.print0 = print0


def set_printoptions(
precision=None, threshold=None, edgeitems=None, linewidth=None, profile=None, sci_mode=None
):
Expand Down Expand Up @@ -67,6 +190,11 @@ def __str__(dndarray) -> str:
dndarray : DNDarray
The array for which to obtain the corresponding string
"""
if LOCAL_PRINT:
return (
torch._tensor_str._tensor_str(dndarray.larray, 0)
+ f", device: {dndarray.device}, split: {dndarray.split}"
)
tensor_string = _tensor_str(dndarray, __INDENT + 1)
if dndarray.comm.rank != 0:
return ""
Expand Down
36 changes: 36 additions & 0 deletions heat/core/tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,49 @@ class TestPrinting(TestCase):
def setUp(self):
    # move to CPU only for the testing printing, otherwise the compare string will become messy
    ht.use_device("cpu")
    # ensure every test starts from the default GLOBAL printing mode,
    # even if a previous test switched to local printing
    ht.global_printing()

def tearDown(self):
    # reset the print options back to default after each test run
    ht.set_printoptions(profile="default")
    # reset the default device (self.device is presumably set by the TestCase base — TODO confirm)
    ht.use_device(self.device)

def test_local_printing(self):
    # array of 75 floats split along axis 0 across the participating processes
    x = ht.arange(15 * 5, dtype=ht.float).reshape((15, 5)).resplit(0)
    # expected GLOBAL-mode output: the full gathered array, produced on rank 0 only
    global_comp = (
        "DNDarray([[ 0., 1., 2., 3., 4.],\n"
        " [ 5., 6., 7., 8., 9.],\n"
        " [10., 11., 12., 13., 14.],\n"
        " [15., 16., 17., 18., 19.],\n"
        " [20., 21., 22., 23., 24.],\n"
        " [25., 26., 27., 28., 29.],\n"
        " [30., 31., 32., 33., 34.],\n"
        " [35., 36., 37., 38., 39.],\n"
        " [40., 41., 42., 43., 44.],\n"
        " [45., 46., 47., 48., 49.],\n"
        " [50., 51., 52., 53., 54.],\n"
        " [55., 56., 57., 58., 59.],\n"
        " [60., 61., 62., 63., 64.],\n"
        " [65., 66., 67., 68., 69.],\n"
        " [70., 71., 72., 73., 74.]], dtype=ht.float32, device=cpu:0, split=0)"
    )
    if x.comm.rank == 0:
        self.assertEqual(str(x), global_comp)
    else:
        # in GLOBAL mode, non-zero ranks render DNDarrays as the empty string
        self.assertEqual(str(x), "")
    ht.local_printing()
    # expected LOCAL-mode output of rank 0's chunk (5 rows) when run on 3 processes
    local_comp = (
        "[[ 0., 1., 2., 3., 4.],\n"
        " [ 5., 6., 7., 8., 9.],\n"
        " [10., 11., 12., 13., 14.],\n"
        " [15., 16., 17., 18., 19.],\n"
        " [20., 21., 22., 23., 24.]], device: cpu:0, split: 0"
    )
    # the local comparison string only matches the 3-process layout
    if x.comm.rank == 0 and x.comm.size == 3:
        self.assertEqual(str(x), local_comp)
    ht.global_printing()  # needed to keep things correct for the other tests

def test_get_default_options(self):
print_options = ht.get_printoptions()
comparison = {
Expand Down