Skip to content

Commit

Permalink
Improve layer debug strings (increase signal-to-noise ratio):
Browse files Browse the repository at this point in the history
  - Include n_in/n_out info only when non-default values (!= 1).
  - Change from "MyLayer{in=2,out=2}" to "MyLayer_in2_out2".
  - Remove "sublayers=" from repr string.
  - Add spaces inside square brackets, remove commas.
  - Combined example:
      "Foo{in=1,out=2,sublayers=[Bar{in=1,out=1}, Baz{in=1,out=1}]}"
          --> "Foo_out2[ Bar Baz ]"
  - Add selection indices to Select layer name.

Samples using new __repr__:

1. Mnist model:  Serial[ Flatten Dense Relu Dense Relu Dense LogSoftmax ]

2. Atari CNN model, raw debug string:

    Serial[ F Branch_out4[ Select[0,0,0,0]_out4 Parallel_in4_out4[ Serial Serial[ ShiftRight ] Serial[ ShiftRight ShiftRight ] Serial[ ShiftRight ShiftRight ShiftRight ] ] ] Concatenate_in4 Conv Relu Conv Relu Flatten Dense Relu ]

2'. Atari CNN model, debug string + hand-done white space additions

    Serial[
      F
      Branch_out4[
        Select[0,0,0,0]_out4
        Parallel_in4_out4[
          Serial
          Serial[
            ShiftRight
          ]
          Serial[
            ShiftRight
            ShiftRight
          ]
          Serial[
            ShiftRight
            ShiftRight
            ShiftRight
          ]
        ]
      ]
      Concatenate_in4
      Conv
      Relu
      Conv
      Relu
      Flatten
      Dense
      Relu
    ]

PiperOrigin-RevId: 307533870
  • Loading branch information
j2i2 authored and Copybara-Service committed Apr 21, 2020
1 parent 172819b commit 0294404
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 7 additions & 5 deletions trax/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,16 @@ def __init__(self, n_in=1, n_out=1, name=None):
self._jit_cache = {}

def __repr__(self):
class_str = self._name
fields_str = 'in={},out={}'.format(self.n_in, self.n_out)
name_str = self._name
n_in, n_out = self.n_in, self.n_out
if n_in != 1: name_str += f'_in{n_in}'
if n_out != 1: name_str += f'_out{n_out}'
objs = self.sublayers
if objs:
objs_str = ', '.join(str(x) for x in objs)
return '{}{{{},sublayers=[{}]}}'.format(class_str, fields_str, objs_str)
objs_str = ' '.join(str(x) for x in objs)
return f'{name_str}[ {objs_str} ]'
else:
return '{}{{{}}}'.format(class_str, fields_str)
return name_str

def __call__(self, x, weights=None, state=None, rng=None, n_accelerators=0):
"""Makes Layer instances callable; for use in tests or interactive settings.
Expand Down
2 changes: 2 additions & 0 deletions trax/layers/combinators.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,8 @@ def Select(indices, n_in=None, name=None):
"""
if n_in is None:
n_in = max(indices) + 1
if name is None:
name = f'Select{indices}'.replace(' ', '')

@base.layer(n_in=n_in, n_out=len(indices), name=name)
def Selection(xs): # pylint: disable=invalid-name
Expand Down

0 comments on commit 0294404

Please sign in to comment.