Skip to content

Commit

Permalink
Contrib/sequential update (#301)
Browse files Browse the repository at this point in the history
* sequential-update

* version fix
  • Loading branch information
Scitator committed Aug 15, 2019
1 parent 5ab8354 commit a1d0963
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion catalyst/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "19.08.1"
__version__ = "19.08.4"
16 changes: 10 additions & 6 deletions catalyst/contrib/models/sequential.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Union
from collections import OrderedDict

import torch
Expand All @@ -22,22 +23,26 @@ def __init__(
self,
hiddens,
layer_fn=nn.Linear,
bias=True,
norm_fn=None,
activation_fn=nn.ReLU,
activation_fn=None,
bias=True,
dropout=None,
layer_order=None,
residual=False
residual: Union[bool, str] = False,
):

super().__init__()
assert len(hiddens) > 1, "No sequence found"

layer_fn = MODULES.get_if_str(layer_fn)
activation_fn = MODULES.get_if_str(activation_fn)
norm_fn = MODULES.get_if_str(norm_fn)
dropout = MODULES.get_if_str(dropout)
activation_fn = MODULES.get_if_str(activation_fn)
inner_init = create_optimal_inner_init(nonlinearity=activation_fn)

if isinstance(residual, bool) and residual:
residual = "hard"

layer_order = layer_order or ["layer", "norm", "drop", "act"]

if isinstance(dropout, float):
Expand Down Expand Up @@ -65,15 +70,14 @@ def _activation_fn(f_in, f_out, bias):
}

net = []

for i, (f_in, f_out) in enumerate(pairwise(hiddens)):
block = []
for key in layer_order:
fn = name2fn[key](f_in, f_out, bias)
if fn is not None:
block.append((f"{key}", fn))
block = torch.nn.Sequential(OrderedDict(block))
if residual:
if residual == "hard" or (residual == "soft" and f_in == f_out):
block = ResidualWrapper(net=block)
net.append((f"block_{i}", block))

Expand Down

0 comments on commit a1d0963

Please sign in to comment.