In [None]:
#| hide
#| eval: false
! [ -e /content ] && pip install -Uqq xcube  # install/upgrade xcube, only when /content exists (i.e. on Colab)

In [None]:
#| default_exp layers

In [None]:
#| export
from fastai.imports import *
from fastai.torch_imports import *
from fastai.torch_core import *
from fastai.layers import *
from fastai.text.models.awdlstm import EmbeddingDropout, RNNDropout

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

# Layers

> Some layers that complement the ones in [fastai](https://docs.fast.ai/layers.html)

## Basic manipulations and resizing

One can easily create a beautiful layer with minimal boilerplate using fastai utilities. We will show a simple example here. For details and extensive illustrations, please refer to [decorated fastai layers](https://docs.fast.ai/layers.html#Basic-manipulations-and-resize).

`Lambda` is an easy way to create a PyTorch layer from a simple `func`:

In [None]:
def _add2(x):
    "Add two to `x`; a named module-level function (not a lambda) so the pickle round trip below can find it"
    return x + 2

# Wrap the plain function in a layer and compare against the direct computation.
x = torch.randn(10, 20)
tst = Lambda(_add2)
test_eq(tst(x), x + 2)

# The layer must still compute the same thing after being pickled and restored.
tst2 = pickle.loads(pickle.dumps(tst))
test_eq(tst2(x), x + 2)

## BatchNorm Layers

In [None]:
#|export
class LinBnFlatDrop(nn.Sequential):
    "Sequential block made of an optional `BatchNorm1dFlat`, an optional `Dropout` and a `Linear` layer (plus optional `act`)"
    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):
        # Normalisation part: optional batchnorm, then optional dropout.
        # The batchnorm is sized for what it receives: `n_out` when it sits after
        # the linear layer, `n_in` when it sits before it.
        norm_drop = []
        if bn:
            bn_size = n_out if lin_first else n_in
            norm_drop.append(BatchNorm1dFlat(bn_size))
        if p != 0:
            norm_drop.append(nn.Dropout(p))

        # Linear part: the bias is dropped whenever a batchnorm is present,
        # and the activation (if any) directly follows the linear layer.
        lin_act = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None:
            lin_act.append(act)

        if lin_first:
            super().__init__(*lin_act, *norm_drop)
        else:
            super().__init__(*norm_drop, *lin_act)

In [None]:
# With the defaults (`bn=True`, `p=0.`, `lin_first=False`) this builds a
# `BatchNorm1dFlat(400)` followed by a bias-free `Linear(400, 1)`.
tst = LinBnFlatDrop(400, 1)

In [None]:
# Display the module's repr to inspect the layers that were created.
tst

LinBnFlatDrop(
  (0): BatchNorm1dFlat(400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): Linear(in_features=400, out_features=1, bias=False)
)

In [None]:
#| export
class LinBnDrop(nn.Sequential):
    """Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers

    Layers are ordered batchnorm -> dropout -> linear -> activation, or
    linear -> activation -> batchnorm -> dropout when `lin_first=True`.

    - `n_in`: number of input features
    - `n_out`: number of output features; required when `ln=True`, unused when `ln=False`
    - `bn`: include the `BatchNorm` layer (the linear layer is then created without bias)
    - `ln`: include the `Linear` layer
    - `p`: dropout probability; no `Dropout` is added when it is 0
    - `act`: optional activation module placed right after the linear layer (ignored when `ln=False`)
    - `lin_first`: put the linear layer (and activation) before batchnorm/dropout
    - `ndim`: forwarded to `BatchNorm` as its `ndim` argument

    Raises `AssertionError` if `lin_first=True` while `ln=False`, and `TypeError`
    if `ln=True` but `n_out` is not given.
    """
    def __init__(self, n_in, n_out=None, bn=True, ln=True, p=0., act=None, lin_first=False, ndim=1):
        if lin_first and not ln:
            # Raise a real `AssertionError` (still an `Exception` subclass). The text keeps the
            # word "AssertionError" so callers/tests matching on the message keep working.
            raise AssertionError("AssertionError: `lin_first=True` requires `ln=True` (there is no linear layer to put first)")
        if ln and n_out is None:
            # Fail early with a readable message instead of an obscure error from inside `nn.Linear`.
            raise TypeError("`n_out` is required when `ln=True`")
        # Batchnorm is sized for what it receives: `n_out` after the linear layer, `n_in` otherwise.
        layers = [BatchNorm(n_out if ln and lin_first else n_in, ndim=ndim)] if bn else []
        if p != 0: layers.append(nn.Dropout(p))
        # The linear layer has no bias whenever a batchnorm is present.
        lin = [nn.Linear(n_in, n_out, bias=not bn)] if ln else []
        # The activation only makes sense after a linear layer, so it is dropped when `ln=False`.
        if ln and act is not None: lin.append(act)
        layers = lin+layers if lin_first else layers+lin
        super().__init__(*layers)

`LinBnDrop` is just like [fastai's LinBnDrop](https://github.com/fastai/fastai/blob/master/fastai/layers.py#L174) with an extra option `ln`, which allows skipping the linear layer. That is, the `BatchNorm` or the `Linear` layer is skipped if `bn=False` or `ln=False` respectively, as is the dropout if `p=0`. Optionally, you can add an activation after the linear layer with `act`.

In [None]:
# With the defaults the block is a batchnorm followed by a linear layer.
tst = LinBnDrop(10, 20)
mods = [*tst.children()]
assert isinstance(mods[0], nn.BatchNorm1d) and isinstance(mods[1], nn.Linear)

The `LinBnDrop` layer does not add an activation (even if one is provided) when `ln` is `False`, and it raises an error if `not ln and lin_first`:

In [None]:
# An activation is supplied, but with `ln=False` there is no linear layer for it to follow,
# so only the batchnorm and the dropout are built.
tst = LinBnDrop(10, 20, ln=False, p=0.02, act=nn.ReLU(inplace=True))
mods = list(tst.children())
assert len(mods) == 2  # in particular, no activation was appended
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Dropout)
# Asking for the linear layer to come first while disabling it is rejected.
test_fail(lambda : LinBnDrop(10, 20, ln=False, lin_first=True), contains='AssertionError')

## Embeddings

In [None]:
#| export
class Embedding(nn.Embedding):
    """Embedding layer with truncated normal initialization

    - `ni`: number of embeddings (rows)
    - `nf`: size of each embedding vector
    - `std`: standard deviation passed to `trunc_normal_`
    - `kwargs`: forwarded unchanged to `nn.Embedding` (e.g. `padding_idx`)
    """
    def __init__(self, ni, nf, std=0.01, **kwargs):
        super().__init__(ni, nf, **kwargs)
        trunc_normal_(self.weight.data, std=std)
        # The re-initialisation above touches every row, including the `padding_idx` row that
        # `nn.Embedding` sets to zero for a new layer. That row receives no gradient, so it
        # would otherwise stay at its random value; put the zeros back.
        if self.padding_idx is not None:
            self.weight.data[self.padding_idx].zero_()

## Attention Layers for Extreme Multi-Label Classification

In [None]:
#| export
class XMLAttention(Module):
    "Label-specific attention: builds one attention-weighted summary of the token sequence per label"
    def __init__(self, n_lbs, emb_sz, embed_p=0.0):
        # `embed_p` is only recorded as an attribute; the label-embedding dropout and the
        # input dropout it was meant for are currently disabled.
        store_attr('n_lbs,emb_sz,embed_p')
        self.lbs_weight = Embedding(n_lbs, emb_sz)

    def forward(self, x):
        # `x` is expected to be the SentenceEncoder output, i.e. (bs, max_len, nh)
        lbl_vecs = self.lbs_weight.weight            # the full label embedding table (n_lbs, emb_sz)
        scores = x @ lbl_vecs.t()                    # token/label affinities (bs, max_len, n_lbs)
        attn_wgts = scores.softmax(dim=1)            # normalise over dim 1 (the tokens), separately for each label
        # for each label, a linear combination of all the tokens weighted by `attn_wgts` (bs, n_lbs, nh)
        return attn_wgts.transpose(1, 2) @ x
    

## Export -

In [None]:
#| hide
# Run nbdev's export step (presumably writes the `#| export` cells to the module named by `default_exp` — confirm against the nbdev docs).
import nbdev; nbdev.nbdev_export()