In [None]:
#default_exp mlp

In [None]:
#exports
from kesscore.imports import *
from kesscore.functional import *
from kesscore.tensor import *
import copy

In [None]:
#exports
class MultiActs(nn.Sequential):
    '''Run every contained activation on the same input and interleave the results along dim 1.

    With acts=[a0,a1] and input entries [v0,v1] along dim 1 the output is
    [a0(v0),a1(v0),a0(v1),a1(v1)], so dim 1 grows by a factor of `len(self)`.
    '''
    def forward(self, x):
        # Iterating the container yields the activations in the order they were passed in.
        per_act = []
        for act in self:
            per_act.append(act(x))
        return interleaved(per_act, dim=1)

In [None]:
# ReLU maps both negatives to 0 while Identity keeps them; the two outputs are
# interleaved per input column, giving [relu(-1), id(-1), relu(-2), id(-2)].
_t = [[-1,-2]]
_r = [[0,-1,0,-2]]
test_eq(
    MultiActs(nn.ReLU(), nn.Identity())(torch.tensor(_t)),
    torch.tensor(_r),
)

In [None]:
#exports
def _init_acts(acts, acts_params):
    '''Instantiate activation module(s) from constructor(s) `acts` and kwargs dict(s) `acts_params`.

    Each argument may be a single item or a list. The two lengths must be equal unless one of
    them is 1; constructor/kwargs pairs come from `zip_cycle_longest` (presumably the shorter
    input is repeated to match the longer one - confirm against its definition).

    Returns `(module, n)`:
      * `nn.Identity(), 1` when nothing was instantiated,
      * the single activation and 1 when exactly one was instantiated,
      * a `MultiActs` wrapping all of them and their count otherwise.
    '''
    acts        = acts        if isinstance(acts,        list) else [acts]
    acts_params = acts_params if isinstance(acts_params, list) else [acts_params]
    na,nap = len(acts),len(acts_params)
    assert na==nap or na==1 or nap==1, f'either equal or one is eqaul to 1 {[na,nap]}'

    built = []
    for act,params in zip_cycle_longest(acts,acts_params):
        built.append(act(**params))
    n_built = len(built)
    if n_built == 0: return nn.Identity(), 1
    if n_built == 1: return built[0], 1
    return MultiActs(*built), n_built

In [None]:
#exports
class Linear(Module):
    '''Linear layer for 2-D input with optional channel groups.

    With `groups == 1` the wrapped module `self.m` is an `nn.Linear`. Otherwise it is an
    `nn.Conv2d` with kernel size 1 and that many groups; `forward` then appends two
    singleton dims to the input before the call and drops them from the result.
    '''
    __repr__=basic_repr('in_channels,out_channels,bias,groups')
    def __init__(self, in_channels, out_channels, bias=True, groups=1):
        store_attr()
        if groups != 1:
            self.m = nn.Conv2d(in_channels, out_channels, 1, bias=bias, groups=groups)
        else:
            self.m = nn.Linear(in_channels, out_channels, bias)
    def forward(self, x):
        # Only 2-D input is accepted, whichever module is wrapped.
        test_eq(x.ndim, 2)
        if self.groups != 1:
            n_rows, n_cols = x.shape
            y = self.m(x.view(n_rows, n_cols, 1, 1))
            # Keep the first two dims of the conv output, dropping the singleton ones.
            return y.view(y.shape[0], y.shape[1])
        return self.m(x)

In [None]:
# The repr lists the four constructor arguments named in Linear.__repr__; groups takes its default of 1.
test_eq(
    repr(Linear(10,20,True)),
    'Linear(in_channels=10, out_channels=20, bias=True, groups=1)',
)

In [None]:
#exports
def _linear_act_norm(c_in, c_out, *, is_final, groups, bias=True, bn=nn.BatchNorm1d, bn_params={}, acts=[nn.LeakyReLU], acts_params={}):
    '''Build one MLP stage: a (possibly grouped) `Linear`, followed by activation(s) and a norm layer unless `is_final`.

    c_in, c_out : input / output width of the linear layer.
    is_final    : when true, return the bare linear layer with no activation or norm.
    groups      : number of groups passed to `Linear`.
    bias        : whether the linear layer has a bias.
    bn, bn_params     : norm constructor and its extra kwargs; it is called with the stage's output width.
    acts, acts_params : activation constructor(s) and kwargs, resolved by `_init_acts`.

    Returns `(module, c_in, n_out)` where `n_out` is `c_out` for the final stage and
    `c_out * expansion` otherwise (`expansion` is the count returned by `_init_acts`).
    The mutable defaults are only read here, never modified.
    '''
    model = Linear(c_in, c_out, bias, groups)
    if is_final: return model, c_in, c_out
    act_module, expansion = _init_acts(acts=acts, acts_params=acts_params)
    n_out = c_out * expansion
    norm = bn(n_out, **bn_params)
    # Reuse the grouped `model` built above: constructing a fresh nn.Linear here would
    # silently ignore `groups` for every non-final stage and allocate a second layer.
    return nn.Sequential(model, act_module, norm), c_in, n_out

In [None]:
#exports
@delegates(_linear_act_norm, but='groups')
def MLP(*, c_in=None, c_mid=None, c_out=None, n_layers=None, channels=None, groups=1, in_groups=1, heads=None, **kwargs):
    '''Build an `nn.Sequential` of `_linear_act_norm` stages named `block_0`, `block_1`, ... and `head`.

    The layer widths come from exactly one of:
      * `channels` - explicit list of widths (at least two entries), or
      * `c_in`, `c_mid`, `c_out`, `n_layers` - expanded to `[c_in] + [c_mid]*(n_layers-1) + [c_out]`.

    `in_groups` is the group count of the first stage; every stage built after a hidden
    block uses `groups`. If `in_groups` is not 1 it must equal `groups`.
    When `heads` is given, `groups` and `in_groups` must both be 1; every width except the
    input one is multiplied by `heads`, and `heads` is used as `groups`.
    Remaining keyword arguments are forwarded to `_linear_act_norm`.
    '''
    # Presumably checks the four scalar arguments are either all None or all given - confirm against `L`/`isinstance` helpers.
    L(c_in, c_mid, c_out, n_layers).map(isinstance(NoneType)).assert_all_eq()
    assert (c_in is None) != (channels is None),'either channels of in\\mid\\out\\nlayers'
    if c_in is not None:
        assert n_layers >= 1
        hidden = [c_mid]*(n_layers-1)
        channels = [c_in, *hidden, c_out]
    assert len(channels) >= 2
    assert in_groups == 1 or in_groups == groups, 'unexpected usage'
    if heads is not None:
        assert groups == in_groups == 1, 'if you use head, dont touch groups'
        groups = heads
        channels = channels[:1] + [width*heads for width in channels[1:]]

    stages = OrderedDict()
    width_in, stage_groups = channels[0], in_groups
    for idx,width in enumerate(channels[1:-1]):
        # Each hidden stage reports its real output width (after activation expansion).
        stage,_,width_in = _linear_act_norm(width_in,width,groups=stage_groups,is_final=False,**kwargs)
        stages[f'block_{idx}'] = stage
        # Only the first stage uses `in_groups`; everything after it uses `groups`.
        stage_groups = groups
    stages['head'] = _linear_act_norm(width_in,channels[-1],groups=stage_groups,is_final=True,**kwargs)[0]
    return nn.Sequential(stages)

In [None]:
# Explicit widths 10 -> 20 -> 30 with two activations per hidden stage:
# block_0 emits 20*2 = 40 features and the head maps them to 30.
mlp = MLP(
    channels=[10,20,30],
    bn=nn.BatchNorm1d, bn_params={'affine':False},
    acts=[nn.ReLU, nn.LeakyReLU], acts_params=[{},{'negative_slope':1e-3}],
)
test_eq(len(mlp), 2)
# NOTE(review): the matched text is raised by PyTorch, not by this module - presumably tied to the installed version; confirm on upgrade.
test_fail(
    lambda:mlp(torch.zeros(10,20)),
    contains='size mismatch, m1: [10 x 20], m2: [10 x 20]',
)
test_eq(mlp    (torch.zeros(10,10)).shape, [10, 30])
test_eq(mlp[:1](torch.zeros(10,10)).shape, [10, 40])

In [None]:
# heads=10 multiplies every width after the input: 20 -> 200 hidden, 30 -> 300 output.
# With two activations block_0 emits 200*2 features.
mlp = MLP(
    c_in=10, c_mid=20, c_out=30, n_layers=2, heads=10,
    bn=nn.BatchNorm1d, bn_params={'affine':False},
    acts=[nn.ReLU, nn.LeakyReLU], acts_params=[{},{'negative_slope':1e-3}],
)
test_eq(len(mlp), 2)
# NOTE(review): the matched text is raised by PyTorch, not by this module - presumably tied to the installed version; confirm on upgrade.
test_fail(
    lambda:mlp(torch.zeros(10,20)),
    contains='size mismatch, m1: [10 x 20], m2: [10 x 200]',
)
test_eq(mlp    (torch.zeros(10,10)).shape, [10, 300])
test_eq(mlp[:1](torch.zeros(10,10)).shape, [10, 200*2])

In [None]:
from nbdev.sync import notebook2script

In [None]:
# Run nbdev's notebook-to-script conversion; the cell output lists each notebook it converted.
notebook2script()

Converted 00_functional.ipynb.
Converted 01_images.ipynb.
Converted 02_download.ipynb.
Converted 03_tensor.ipynb.
Converted 04_random.ipynb.
Converted 05_domainadaptation.ipynb.
Converted 06_mlp.ipynb.
Converted 07_tests.ipynb.
Converted index.ipynb.
