In [None]:
#default_exp mlp

In [None]:
#exports
from kesscore.base import *
import copy

In [None]:
#exports
class MultiActs(Module):
    '''Apply several activations to the same input and interleave their outputs along dim 1.

    Given acts=[a0,a1] and v=[v0,v1] returns [a0(v0),a1(v0),a0(v1),a1(v1)]: the channel count
    grows by a factor of `len(acts)` and the outputs derived from one input channel stay adjacent.
    With a single activation its output is returned as-is (no interleaving).'''
    def __init__(self, *acts): self.acts = nn.ModuleList(acts)
    def forward(self, x):
        outs = [act(x) for act in self.acts]
        if len(outs) == 1: return outs[0]
        return interleaved(*outs, dim=1)

In [None]:
_t,_r = [[-1,-2]],[[0,-1,0,-2]]
test_eq(MultiActs(nn.ReLU(),nn.Identity())(torch.tensor(_t)),torch.tensor(_r))

In [None]:
#exports
def _listify(x): return x if isinstance(x, list) else [x]
def _init_acts(acts, acts_params):
    '''Instantiate activation class(es) with their kwargs and combine them with `MultiActs`.

    `acts` and `acts_params` may each be a single item or a list. A length-1 list is broadcast
    against the other one (via `zip_cycle_longest`); otherwise both lengths must be equal.

    _init_acts(nn.ReLU, {'inplace':True})
    _init_acts([nn.ReLU,nn.ReLU], {'inplace':True})       #two relu with inplace
    _init_acts([nn.ReLU,nn.ReLU], [{'inplace':True}, {}]) #two relu, first with inplace
    _init_acts(nn.ReLU,           [{'inplace':True}, {}]) #two relu, first with inplace

    Returns `(module, n)`, where `n` is the number of activation modules built. Because
    `MultiActs` interleaves their outputs, `n` is also the channel multiplier. An empty `acts`
    gives `(nn.Identity(), 1)`.'''
    act_list = _listify(acts)
    param_list = _listify(acts_params)
    if not act_list: return nn.Identity(),1

    n1,n2 = len(act_list),len(param_list)
    assert n1==n2 or 1 in [n1,n2], f'either equal or can be broadcast {[n1,n2]}'
    built = [cls(**kw) for cls,kw in zip_cycle_longest(act_list, param_list)]
    return MultiActs(*built),len(built)

In [None]:
#exports
class Linear(Module):
    '''Linear layer over 2D `(batch, channels)` input with optional channel groups.

    `groups == 1` wraps `nn.Linear`. Otherwise a grouped 1x1 `nn.Conv2d` is used: the input is
    lifted to `(N, C, 1, 1)`, convolved, and flattened back to `(N, out_channels)`.'''
    __repr__=basic_repr('in_channels,out_channels,bias,groups')
    def __init__(self, in_channels, out_channels, bias=True, groups=1):
        store_attr()
        self._m = (nn.Linear(in_channels, out_channels, bias) if groups == 1 else
                   nn.Conv2d(in_channels, out_channels, 1, bias=bias, groups=groups))
    def forward(self, x):
        test_eq(x.ndim, 2)
        if self.groups == 1: return self._m(x)
        # add two trailing spatial dims for the 1x1 conv, then drop them again
        y = self._m(x.view(*x.shape, 1, 1))
        return y.view(*y.shape[:2])

In [None]:
# repr lists the constructor args; groups>1 switches to a grouped 1x1 conv with weight [out, in/groups, 1, 1].
lin_repr = repr(Linear(10,20,True))
test_eq(lin_repr, 'Linear(in_channels=10, out_channels=20, bias=True, groups=1)')
grouped_out = Linear(10,20,groups=2)(torch.zeros(5,10))
test_eq(grouped_out.shape, [5,20])
w_plain = Linear(10,20)._m.weight
test_eq(w_plain.shape, [20,10])
w_grouped = Linear(10,20,groups=2)._m.weight
test_eq(w_grouped.shape, [20,5,1,1])

In [None]:
#exports
def _linear_act_norm(c_in, c_out, *, is_final, groups, bias=True, bn=nn.BatchNorm1d, bn_params={}, acts=[nn.LeakyReLU], acts_params={}):
    '''Build one MLP layer and report its input/output widths as `(module, c_in, width)`.

    - `is_final=True`: `module` is a bare `Linear(c_in, c_out, bias, groups)` and `width == c_out`.
    - otherwise: `module` is `nn.Sequential(Linear, acts, bn)`. The activations come from
      `_init_acts(acts, acts_params)` and multiply the width by the number of activations built, so
      `width == c_out * n_acts`. The norm layer is created as `bn(width, **bn_params)`.'''
    # NOTE: the dict/list defaults are shared between calls; they are only read here, never mutated.
    lin = Linear(c_in, c_out, bias, groups)
    if not is_final:
        act_mod, n_acts = _init_acts(acts=acts, acts_params=acts_params)
        width = c_out * n_acts
        return nn.Sequential(lin, act_mod, bn(width, **bn_params)), c_in, width
    return lin, c_in, c_out

In [None]:
#exports
def _build_channels(*, c_in, c_mid, c_out, n_layers):
    assert n_layers >= 1
    return [c_in] + [c_mid]*(n_layers-1) + [c_out]
def _build_heads(*, channels, heads, groups, in_groups):
    assert groups == in_groups == 1,'if you use head, dont touch groups'
    channels = channels[:1] + [c*heads for c in channels[1:]]
    groups = heads
    return channels,groups

In [None]:
#exports
@delegates(_linear_act_norm, but='groups')
def MLP(*, c_in=None, c_mid=None, c_out=None, n_layers=None, channels=None, groups=1, in_groups=1, heads=None, bias_last=True, **kwargs):
    '''Build an `nn.Sequential` of `block_0..block_{k-1}` (`Linear -> acts -> bn`) followed by a plain `Linear` named `'head'`.

    Widths come either from `channels=[c0, c1, ..., cN]` or from `c_in`/`c_mid`/`c_out`/`n_layers`
    (giving `[c_in] + [c_mid]*(n_layers-1) + [c_out]`). Exactly one of the two forms must be used.

    - `in_groups`: `groups` of the first layer, i.e. the one that consumes the input.
    - `groups`: `groups` of every later layer, the head included.
    - `heads`: if set, every width after the input is multiplied by `heads` and `groups` becomes `heads`
      (requires `groups == in_groups == 1`). The first layer stays ungrouped, so every head sees the
      whole input. All later layers, the head included, are grouped, so each head only sees its own features.
    - `bias_last`: `bias` of the head; `None` keeps `kwargs['bias']` (or `_linear_act_norm`'s default).
    - remaining `kwargs` (`bias`, `bn`, `bn_params`, `acts`, `acts_params`) go to `_linear_act_norm`.
    '''
    # presumably checks that the four size arguments are either all None or all given
    test_all_eq(c_in, c_mid, c_out, n_layers, map=isinstance(NoneType))
    assert (c_in is None) != (channels is None),'either channels or in\\mid\\out\\nlayers'
    if in_groups != 1: assert in_groups == groups, 'unexpected usage'
    if c_in is not None: channels = _build_channels(c_in=c_in, c_mid=c_mid, c_out=c_out, n_layers=n_layers)
    if heads is not None: channels,groups = _build_heads(channels=channels, heads=heads, groups=groups, in_groups=in_groups)
    assert len(channels) >= 2
    # `c_prev`: width actually produced by the previous layer (already multiplied by the number of activations).
    # `g`: groups of the next layer to build -- `in_groups` for the first one, `groups` afterwards.
    blocks,c_prev,g = OrderedDict(),channels[0],in_groups
    for i,c_next in enumerate(channels[1:-1]):
        m,_,c_prev = _linear_act_norm(c_prev,c_next,groups=g,is_final=False,**kwargs)
        g = groups
        blocks[f'block_{i}'] = m
    if bias_last is not None: kwargs['bias'] = bias_last
    # The head continues the grouped stream, so it must use `g`. `g` is `in_groups` only when the head is also
    # the first layer. Using `in_groups` here built an ungrouped head that mixed every head's features.
    head,*_ = _linear_act_norm(c_prev,channels[-1],groups=g,is_final=True,**kwargs)
    blocks['head'] = head
    return nn.Sequential(blocks)

In [None]:
# Two activations per block: the hidden width doubles (20 -> 40) before the head maps it to 30.
mlp = MLP(channels=[10,20,30], bn=nn.BatchNorm1d, bn_params={'affine':False},
          acts=[nn.ReLU, nn.LeakyReLU], acts_params=[{},{'negative_slope':1e-3}])
bad_inp, good_inp = torch.zeros(10,20), torch.zeros(10,10)
test_eq(len(mlp), 2)
test_fail(lambda:mlp(bad_inp), contains='size mismatch')
test_eq(mlp(good_inp).shape, [10, 30])
test_eq(mlp[:1](good_inp).shape, [10, 40])

In [None]:
# heads=3: the test expects head 0's first output to depend only on head 0's 20-wide slice of the
# hidden features (non-zero gradient there, exactly zero gradient on the other heads' slices).
mlp = MLP(channels=[10,20,20,30], bn=lambda x:nn.Identity(), heads=3)
hidden = mlp[:2](torch.ones(1,10))
test_eq(hidden.shape, [1,60])
out = mlp[2:](hidden)
test_eq(out.shape, [1,90])
(d_hidden,) = torch.autograd.grad(out[0,0], hidden)
own_head, other_heads = d_hidden[0, :20], d_hidden[0, 20:]
assert not (own_head==0).any()
assert     (other_heads==0).all()

In [None]:
# heads=10: hidden width 20*10=200, doubled to 400 by the two activations; head output 30*10=300.
mlp = MLP(c_in=10, c_mid=20, c_out=30, n_layers=2, heads=10, bn=nn.BatchNorm1d, bn_params={'affine':False},
          acts=[nn.ReLU, nn.LeakyReLU], acts_params=[{},{'negative_slope':1e-3}])
bad_inp, good_inp = torch.zeros(10,20), torch.zeros(10,10)
test_eq(len(mlp), 2)
test_fail(lambda:mlp(bad_inp), contains='size mismatch')
test_eq(mlp(good_inp).shape, [10, 300])
test_eq(mlp[:1](good_inp).shape, [10, 200*2])

In [None]:
# nbdev export step: converts the notebooks' `#exports` cells into library modules
# (the cell output below lists each converted notebook).
from nbdev.sync import notebook2script;notebook2script()

Converted 00_functional.ipynb.
Converted 01_tests.ipynb.
Converted 02_tensor.ipynb.
Converted 03_images.ipynb.
Converted 04_random.ipynb.
Converted 05_domainadaptation.ipynb.
Converted 06_mlp.ipynb.
Converted 07_download.ipynb.
Converted index.ipynb.
