-
Notifications
You must be signed in to change notification settings - Fork 0
/
container.py
48 lines (33 loc) · 994 Bytes
/
container.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
r"""Container modules"""
# Public API of this module: the names listed here are the ones exported by a
# wildcard import (`from ... import *`).
__all__ = [
'Sequential',
]
import jax
from textwrap import indent
from typing import *
from .module import Module
from ..tree_util import tree_repr
class Sequential(Module):
    r"""Chains several layers into a single module.

    The layers are applied one after the other, the output of each layer
    being fed as input to the next one.

    .. math:: y = f_n \circ \dots \circ f_2 \circ f_1(x)

    Arguments:
        layers: The layers :math:`f_1, f_2, \dots, f_n`, given in the order
            in which they are applied.
    """

    def __init__(self, *layers: Module):
        # The variadic arguments arrive as a tuple, which is stored as is.
        self.layers = layers

    def __call__(self, x: Any) -> Any:
        r"""Applies the layers to the input, in order.

        Arguments:
            x: The input :math:`x`.

        Returns:
            The output :math:`y`. With no layers, the input is returned
            unchanged.
        """

        out = x

        for f in self.layers:
            out = f(out)

        return out

    def tree_repr(self, **kwargs) -> str:
        r"""Builds a textual representation of the container and its layers.

        Arguments:
            kwargs: Keyword arguments forwarded to :func:`tree_repr` for each
                layer.

        Returns:
            The class name followed by the comma-separated layer
            representations, each on its own indented line, within
            parentheses. If there is nothing to show, the parentheses are
            left empty.
        """

        name = self.__class__.__name__
        body = ',\n'.join([tree_repr(f, **kwargs) for f in self.layers])

        # Nothing to display: keep the compact `Name()` form.
        if not body:
            return f'{name}()'

        body = indent(body, ' ')

        return f'{name}(\n{body}\n)'