diff --git a/trax/layers/base.py b/trax/layers/base.py index e61e43761..4e0d122b6 100644 --- a/trax/layers/base.py +++ b/trax/layers/base.py @@ -111,14 +111,16 @@ def __init__(self, n_in=1, n_out=1, name=None): self._jit_cache = {} def __repr__(self): - class_str = self._name - fields_str = 'in={},out={}'.format(self.n_in, self.n_out) + name_str = self._name + n_in, n_out = self.n_in, self.n_out + if n_in != 1: name_str += f'_in{n_in}' + if n_out != 1: name_str += f'_out{n_out}' objs = self.sublayers if objs: - objs_str = ', '.join(str(x) for x in objs) - return '{}{{{},sublayers=[{}]}}'.format(class_str, fields_str, objs_str) + objs_str = ' '.join(str(x) for x in objs) + return f'{name_str}[ {objs_str} ]' else: - return '{}{{{}}}'.format(class_str, fields_str) + return name_str def __call__(self, x, weights=None, state=None, rng=None, n_accelerators=0): """Makes Layer instances callable; for use in tests or interactive settings. diff --git a/trax/layers/combinators.py b/trax/layers/combinators.py index c50fe2392..5da95d99a 100644 --- a/trax/layers/combinators.py +++ b/trax/layers/combinators.py @@ -534,6 +534,8 @@ def Select(indices, n_in=None, name=None): """ if n_in is None: n_in = max(indices) + 1 + if name is None: + name = f'Select{indices}'.replace(' ', '') @base.layer(n_in=n_in, n_out=len(indices), name=name) def Selection(xs): # pylint: disable=invalid-name