-
Notifications
You must be signed in to change notification settings - Fork 275
/
map_list.py
36 lines (26 loc) · 1 KB
/
map_list.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from typing import Callable, List, Optional, Tuple, TypeVar
from ..model import Model
# Element type of the input list (InT) and of the output list (OutT)
# handled by the map_list combinator below.
InT, OutT = TypeVar("InT"), TypeVar("OutT")
def map_list(layer: Model[InT, OutT]) -> Model[List[InT], List[OutT]]:
    """Wrap ``layer`` so that it is applied to every item of a list.

    The returned model is named ``"map_list"`` and holds ``layer`` as its only
    child. Its forward pass calls the child once per list element, and its
    init hook initializes the child from the first sample of X/Y.

    Args:
        layer: The model to apply element-wise.

    Returns:
        A model mapping ``List[InT]`` to ``List[OutT]``.
    """
    return Model("map_list", forward, init=init, layers=[layer])
def forward(
    model: Model[List[InT], List[OutT]], Xs: List[InT], is_train: bool
) -> Tuple[List[OutT], Callable[[List[OutT]], List[InT]]]:
    """Apply the child layer to each element of ``Xs`` independently.

    Args:
        model: The map_list model; its first (only) child layer is applied.
        Xs: The input list; each element is passed to the child separately.
        is_train: Forwarded unchanged to every child call.

    Returns:
        A tuple ``(Ys, backprop)``. ``Ys[i]`` is the child's output for
        ``Xs[i]``. ``backprop`` maps a list of output gradients (one per
        element, in the same order) to the list of input gradients.
    """
    layer = model.layers[0]
    Ys = []
    callbacks = []
    # Calls must stay sequential: each yields its own backprop closure.
    for X in Xs:
        Y, get_dX = layer(X, is_train)
        Ys.append(Y)
        callbacks.append(get_dX)

    def backprop_map_list(dYs: List[OutT]) -> List[InT]:
        """Route each output gradient to the callback of its element.

        Raises:
            ValueError: If the number of gradients does not match the
                number of elements seen in the forward pass.
        """
        # zip() would silently truncate on a mismatch, dropping gradients
        # or leaving items without one; fail loudly instead.
        if len(dYs) != len(callbacks):
            raise ValueError(
                f"map_list backprop expected {len(callbacks)} gradients, "
                f"got {len(dYs)}"
            )
        return [callback(dY) for callback, dY in zip(callbacks, dYs)]

    return Ys, backprop_map_list
def init(
    model: Model[List[InT], List[OutT]],
    X: Optional[List[InT]] = None,
    Y: Optional[List[OutT]] = None,
) -> None:
    """Initialize the wrapped child layer from the first element of X and Y.

    Because the child operates on single elements, it is initialized with
    ``X[0]`` / ``Y[0]``. A missing or empty list is passed on as ``None``.

    Args:
        model: The map_list model whose first child is initialized.
        X: Optional sample input list.
        Y: Optional sample output list.
    """
    child = model.layers[0]
    sample_X = None
    if X:
        sample_X = X[0]
    sample_Y = None
    if Y:
        sample_Y = Y[0]
    child.initialize(X=sample_X, Y=sample_Y)