Robustify configurable, more options for testrun and contrib changes #138

Merged · 7 commits · Jul 22, 2022
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -12,7 +12,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.6, 3.7, 3.8]
+        python-version: [3.7, 3.8]

     steps:
       - uses: actions/checkout@v2
3 changes: 0 additions & 3 deletions azure-pipelines.yml
@@ -12,9 +12,6 @@ trigger:

 strategy:
   matrix:
-    Python36:
-      IMAGE_NAME: 'ubuntu-18.04'
-      python.version: '3.6'
     Python37:
       IMAGE_NAME: 'ubuntu-18.04'
       python.version: '3.7'
49 changes: 44 additions & 5 deletions padertorch/configurable.py
@@ -434,10 +434,21 @@ def from_config(
         """Produce a Configurable instance from a valid config."""
         # TODO: assert do not use defaults

+        if isinstance(config, _DogmaticConfig):
+            config = config.to_dict()  # if called in finalize_dogmatic dict
         assert 'factory' in config, (cls, config)
         if cls is not Configurable:
+
+            if cls.__module__ == '__main__':
+                # When a class is defined in the main script, it will be
+                # __main__.<ModelName>, but it should be <script>.<ModelName>.
+                # This fix is active when the script is called with
+                # "python -m <script> ..."
+                # but not when it is called with "python <script>.py ...".
+                # pylint: disable=self-cls-assignment
+                cls = import_class(class_to_str(cls))
+
             assert issubclass(import_class(config['factory']), cls), \
                 (config['factory'], cls)
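The `__main__` fix above is the core robustification: `class_to_str` followed by `import_class` re-resolves a class defined in the launched script under its importable module name. A minimal stdlib-only sketch of the same idea (the helper name `resolve_main_class` is hypothetical, not part of this diff):

    import importlib
    import sys

    def resolve_main_class(cls):
        # Under "python -m <script>", __main__.__spec__ names the real
        # module; under "python <script>.py" it is None, so the fix
        # cannot apply there (matching the comment in the diff).
        if cls.__module__ == '__main__':
            spec = sys.modules['__main__'].__spec__
            if spec is not None:
                module = importlib.import_module(spec.name)
                cls = getattr(module, cls.__qualname__)
        return cls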
@@ -1321,6 +1332,10 @@ def config_to_instance(config, strict=False):
     >>> config_to_instance(config)
     <class 'torch.nn.modules.linear.Linear'>
     """
+
+    if isinstance(config, _DogmaticConfig):
+        config = config.to_dict()  # if called in finalize_dogmatic dict
+
     if isinstance(config, dict):
         special_key = _get_special_key(config)
         if special_key:
@@ -1534,7 +1549,7 @@ def _sacred_dogmatic_to_dict(config):
     return config


-def _get_signature(cls, drop_positional_only=False):
+def _get_signature(cls, drop_positional_only=False, drop_type_annotations=False):
     """

     >>> _get_signature(dict)
@@ -1555,6 +1570,12 @@ def _get_signature(cls, drop_positional_only=False):
     ...
     ValueError: no signature found for builtin type <class 'set'>

+
+    >>> _get_signature(Configurable.from_file)
+    <Signature (config_path: pathlib.Path, in_config_path: str = '', consider_mpi=False)>
+    >>> _get_signature(Configurable.from_file, drop_type_annotations=True)
+    <Signature (config_path, in_config_path='', consider_mpi=False)>
+
     """
     if cls in [
         set,  # py38: set missing signature
@@ -1567,7 +1588,7 @@ def _get_signature(cls, drop_positional_only=False):
                 default=(),
             )]
         )
-    elif cls in [dict]:
+    elif cls.__init__ in [dict.__init__]:
         # Dict has no correct signature, hence return the signature, that is
         # needed here.
         sig = inspect.Signature(
@@ -1586,6 +1607,15 @@ def _get_signature(cls, drop_positional_only=False):
             if p.kind != inspect.Parameter.POSITIONAL_ONLY
         ]
     )
+    if drop_type_annotations:
+        sig = sig.replace(
+            parameters=[
+                p.replace(annotation=p.empty)
+                for p in sig.parameters.values()
+            ],
+            return_annotation=sig.empty
+        )
+
     return sig

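The annotation stripping above needs only the stdlib `inspect` API. A self-contained sketch of the same technique (`from_file` here is a toy function mirroring the doctest):

    import inspect
    from pathlib import Path

    def from_file(config_path: Path, in_config_path: str = '', consider_mpi=False):
        pass

    sig = inspect.signature(from_file)
    bare = sig.replace(
        parameters=[
            p.replace(annotation=p.empty)  # drop each parameter annotation
            for p in sig.parameters.values()
        ],
        return_annotation=sig.empty,       # drop the return annotation
    )
    print(bare)  # (config_path, in_config_path='', consider_mpi=False)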
@@ -1609,7 +1639,16 @@ def get_signature(factory):
     """
     if factory in [tuple, list, set, dict]:
         return {}
-    sig = inspect.signature(factory)
+    try:
+        sig = inspect.signature(factory)
+    except ValueError:
+        if factory.__init__ in [tuple.__init__, list.__init__, set.__init__, dict.__init__]:
+            # Builtin type is in the MRO and __init__ is not overwritten, e.g.
+            # ValueError: no signature found for builtin type <class 'paderbox.utils.mapping.Dispatcher'>
+            return {}
+        else:
+            raise
+
     defaults = {}
     param: inspect.Parameter
     for name, param in sig.parameters.items():
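To illustrate the failure mode the new try/except guards against: on the targeted Python versions (3.7/3.8), a class that inherits `__init__` from a builtin has no introspectable signature. A minimal sketch with a local stand-in for `paderbox.utils.mapping.Dispatcher`:

    import inspect

    class Dispatcher(dict):
        """Stand-in: a dict subclass that does not override __init__."""

    try:
        inspect.signature(Dispatcher)
    except ValueError as e:
        print(e)  # no signature found for builtin type <class '...Dispatcher'>

    # The new branch detects this case and returns no defaults:
    assert Dispatcher.__init__ is dict.__init__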
@@ -1800,7 +1839,7 @@ def _check_redundant_keys(self, msg):
             f'{msg}\n'
             f'Too many keywords for the factory {imported}.\n'
             f'Redundant keys: {redundant_keys}\n'
-            f'Signature: {_get_signature(imported)}\n'
+            f'Signature: {_get_signature(imported, drop_type_annotations=True)}\n'
             f'Current config with fallbacks:\n{pretty(self.data)}'
         )

@@ -1946,7 +1985,7 @@ def to_dict(self):
         except KeyError as ex:
             from IPython.lib.pretty import pretty
             if self.special_key == 'factory' \
-                    and self.special_key in self._key_candidates() and \
+                    and self.special_key in self._key_candidates() and \
                     k != self.special_key:
                 # KeyError has a bad __repr__, use Exception
                 missing_keys = set(self._key_candidates()) - set(self.data.keys())
42 changes: 38 additions & 4 deletions padertorch/contrib/cb/tensorboard_symlink_tree.py
@@ -1,7 +1,7 @@
 """
 Create a symlink tree for all specified files in the current folder.

-python -m padertorch.contrib.cb.tensorboard_symlink_tree ../*/*tfevents*
+python -m padertorch.contrib.cb.tensorboard_symlink_tree ../*/*tfevents* --max_age=1days

 Usecase:
@@ -28,8 +28,13 @@
 Because of this I created a Makefile in that folder:

     .../tensorboard$ cat Makefile
+    symlink_tree1day:
+        find . -xtype l -delete  # Remove broken symlinks: https://unix.stackexchange.com/a/314975/283777
+        python -m padertorch.contrib.cb.tensorboard_symlink_tree --prefix=.. ../*/*tfevents* --max_age=1days
+
     symlink_tree:
-        python -m padertorch.contrib.cb.tensorboard_symlink_tree ../*/*tfevents*
+        find . -xtype l -delete  # Remove broken symlinks: https://unix.stackexchange.com/a/314975/283777
+        python -m padertorch.contrib.cb.tensorboard_symlink_tree --prefix=.. ../*/*tfevents*

     tensorboard:
         date && $(cd .../tensorboard && ulimit -v 10000000 && tensorboard --bind_all -v 1 --logdir=. --port=...) && date || date
@@ -38,22 +43,51 @@

 import os
 from pathlib import Path
+import datetime

 import paderbox as pb


-def main(*files, prefix=None):
+def main(*files, prefix=None, max_age=None):
     if prefix is None:
         prefix = os.path.commonpath(files)
     print('Common Prefix', prefix)
     print('Create')

+    files = [Path(f) for f in files]
+
+    if max_age is not None:
+        # The pandas import is slow, but pd.Timedelta
+        # accepts many styles for time
+        # (e.g. '1day')
+        import pandas as pd
+        max_age = pd.Timedelta(max_age)
+        now = pd.Timestamp('now')
+
+    files = sorted(files, key=lambda file: file.stat().st_mtime)
+
     for file in files:
-        file = Path(file)
         link_name = file.relative_to(prefix)
+        if max_age is not None:
+            last_modified = file.stat().st_mtime
+            last_modified = datetime.datetime.fromtimestamp(last_modified)
+
+            if max_age > now - last_modified:
+                # Create the symlink if it doesn't exist.
+                pass
+            else:
+                if not link_name.is_symlink():
+                    print(f'Skip {file}, it is {now - last_modified} > {max_age} old.')
+                    continue
+
         link_name.parent.mkdir(exist_ok=True)
         source = os.path.relpath(file, link_name.parent)
         if not link_name.exists():
             print(f'\t{link_name} -> {source}')
+
+        # Create the symlink if it does not exist,
+        # or check that the symlink points to the
+        # same file.
         pb.io.symlink(source, link_name)
     print('Finish')
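The new `--max_age` option leans on pandas for its flexible duration parsing, which is why the import happens lazily. A short sketch of the age check, assuming an mtime taken from `file.stat().st_mtime` (the literal timestamp is an arbitrary example):

    import datetime
    import pandas as pd

    max_age = pd.Timedelta('1days')  # also accepts '1day', '24h', ...
    now = pd.Timestamp('now')
    last_modified = datetime.datetime.fromtimestamp(1_650_000_000)

    if max_age > now - last_modified:
        print('fresh enough: (re)create the symlink')
    else:
        print(f'Skip, it is {now - last_modified} > {max_age} old.')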
6 changes: 3 additions & 3 deletions padertorch/contrib/cb/track.py
@@ -391,7 +391,7 @@ def get_shape(self, obj):
         try:
             return list(obj.shape)
         except AttributeError:
-            return 'unknown'
+            return '?'

     def pre(self, module, input):
         self.input_shape = self.get_shape(input)
@@ -655,7 +655,7 @@ def post(self, module, input, output):
         self.maybe_add(t, 'tensors_learnable', 'tensors_fixed')

     def _to_str(self, value):
-        return f'{value:6}'
+        return f'{value:6_}'

     @property
     def data(self):
@@ -706,7 +706,7 @@ def get_size(self, tensor):
         return tensor.nelement() * tensor.element_size()

     def _to_str(self, value):
-        return f'{value:6} B'
+        return f'{value:6_} B'

     @property
     def data(self):
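The only change to `_to_str` is the format spec: the `_` option (available since Python 3.6) groups integer digits with underscores, so large parameter counts and byte sizes stay readable:

    print(f'{1234567:6}')     # 1234567
    print(f'{1234567:6_}')    # 1_234_567
    print(f'{1234567:6_} B')  # 1_234_567 B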
14 changes: 12 additions & 2 deletions padertorch/train/runtime_tests.py
@@ -81,6 +81,9 @@ def test_run(
         *,
         deterministic_atol=1e-5,
         deterministic_rtol=1e-5,
+        loss_atol=1e-6,
+        loss_rtol=1e-6,
+        virtual_minibatch_size=None,
 ):
     """

@@ -139,6 +142,13 @@ def backup_state_dict(trainer: pt.Trainer):
         'epoch',
         new=-1,
     ))
+    if virtual_minibatch_size is not None:
+        assert virtual_minibatch_size > 0, virtual_minibatch_size
+        exit_stack.enter_context(mock.patch.object(
+            trainer,
+            'virtual_minibatch_size',
+            new=virtual_minibatch_size,
+        ))

     class SpyMagicMock(mock.MagicMock):
         def __init__(self, *args, **kw):
@@ -316,8 +326,8 @@ def trainer_step_mock_to_inputs_output_review(review_mock):
     # nested_test_assert_allclose(dt4['review'], dt8['review'])

     # Expect that the initial loss is equal for two runs
-    nested_test_assert_allclose(dt1['loss'], dt5['loss'], rtol=1e-6, atol=1e-6)
-    nested_test_assert_allclose(dt2['loss'], dt6['loss'], rtol=1e-6, atol=1e-6)
+    nested_test_assert_allclose(dt1['loss'], dt5['loss'], rtol=loss_rtol, atol=loss_atol)
+    nested_test_assert_allclose(dt2['loss'], dt6['loss'], rtol=loss_rtol, atol=loss_atol)
     try:
         with np.testing.assert_raises(AssertionError):
             # Expect that the loss changes after training.
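The `virtual_minibatch_size` override follows the same pattern as the existing `epoch` patch: `mock.patch.object` entered on the shared `ExitStack` swaps the attribute for the duration of the test and restores it on exit. A self-contained sketch with a dummy object standing in for the trainer:

    import contextlib
    from unittest import mock

    class DummyTrainer:
        virtual_minibatch_size = 8

    trainer = DummyTrainer()
    with contextlib.ExitStack() as exit_stack:
        exit_stack.enter_context(mock.patch.object(
            trainer,
            'virtual_minibatch_size',
            new=2,
        ))
        assert trainer.virtual_minibatch_size == 2
    assert trainer.virtual_minibatch_size == 8  # restored when the stack closes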
6 changes: 6 additions & 0 deletions padertorch/train/trainer.py
@@ -156,6 +156,9 @@ def test_run(
             temporary_directory=None,
             deterministic_atol=1e-5,
             deterministic_rtol=1e-5,
+            loss_atol=1e-6,
+            loss_rtol=1e-6,
+            virtual_minibatch_size=None,
     ):
         """
         Run a test on the trainer instance (i.e. model test).
Expand Down Expand Up @@ -193,6 +196,9 @@ def test_run(
temporary_directory=temporary_directory,
deterministic_atol=deterministic_atol,
deterministic_rtol=deterministic_rtol,
loss_atol=loss_atol,
loss_rtol=loss_rtol,
virtual_minibatch_size=virtual_minibatch_size,
)

def train(
Expand Down