In [None]:
#| hide
#| eval: false
! [ -e /content ] && pip install -Uqq xcube  # install/upgrade xcube on colab

In [None]:
#| default_exp layers

In [None]:
#| export
from fastai.imports import *
from fastai.torch_imports import *
from fastai.torch_core import *
from fastai.layers import *
from fastai.text.models.awdlstm import EmbeddingDropout, RNNDropout

from xcube.utils import *

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

# Layers

> Some layers that top up the ones in [fastai](https://docs.fast.ai/layers.html)

## Basic manipulations and resizing

One can easily create a beautiful layer with minimum boilerplate using fastai utilities. We will show a few simple examples here. For details and extensive illustrations please refer to [decorated fastai layers](https://docs.fast.ai/layers.html#Basic-manipulations-and-resize).

An easy way to create a pytorch layer for a simple `func`

In [None]:
show_doc(Lambda)

---

[source](https://github.com/fastai/fastai/blob/master/fastai/layers.py#LNone){target="_blank" style="float:right; font-size:smaller"}

### Lambda

>      Lambda (func)

An easy way to create a pytorch layer for a simple `func`

In [None]:
# `Lambda` turns a plain function into a layer; a module-level function is used so the layer can be pickled.
def _add2(x): return x+2
tst = Lambda(_add2)
x = torch.randn(10,20)
test_eq(tst(x), x+2)
# The layer must still compute the same thing after a pickle round-trip.
tst2 = pickle.loads(pickle.dumps(tst))
test_eq(tst2(x), x+2)

In [None]:
show_doc(PartialLambda)

---

[source](https://github.com/fastai/fastai/blob/master/fastai/layers.py#LNone){target="_blank" style="float:right; font-size:smaller"}

### PartialLambda

>      PartialLambda (func)

Layer that applies `partial(func, **kwargs)`

In [None]:
def test_func(a,b=2): return a+b
# `b=5` is bound at construction time and replaces the default `b=2`.
tst = PartialLambda(test_func, b=5)
# `x` is the tensor created in the `Lambda` example cell earlier in this notebook.
test_eq(tst(x), x+5)

## Linear

In [None]:
#| export
def _create_bias(size, with_zeros=False):
    if with_zeros: return nn.Parameter(torch.zeros(*size))
    return nn.Parameter(torch.zeros(*size).uniform_(-0.1, 0.1))

In [None]:
#| export
class ElemWiseLin(Module):
    "Scale the input element-wise by a learnable `(dim0, dim1)` weight, optionally adding a learnable scalar bias"
    initrange=0.1
    def __init__(self, dim0, dim1, add_bias=False, **kwargs):
        store_attr()
        # The `nn.Linear` is only a container for the `(dim0, dim1)` weight: `forward` never calls it
        # and reads nothing but `self.lin.weight`.
        self.lin = nn.Linear(dim1, dim0, **kwargs)
        init_default(self.lin)
        if self.add_bias: self.bias = _create_bias((1, ))

    def forward(self, x):
        # Without a bias, a single zero (same dtype/device as `x`) is used as the additive term.
        offset = self.bias if self.add_bias else x.new_zeros(1)
        # offset + x * weight, with `weight` broadcast against `x`
        return torch.addcmul(offset, x, self.lin.weight)

In [None]:
bs, dim0, dim1 = 10, 1271, 400
tst = ElemWiseLin(dim0, dim1)
# The wrapped linear layer holds one weight per (dim0, dim1) position.
test_eq(tst.lin.weight.shape, (dim0, dim1))
x = torch.randn(bs, dim0, dim1)
# Element-wise scaling leaves the input shape unchanged.
test_eq(tst(x).shape, (bs, dim0, dim1))

## BatchNorm Layers

In [None]:
#|export
class LinBnFlatDrop(nn.Sequential):
    "Module grouping `BatchNorm1dFlat`, `Dropout` and `Linear` layers"
    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):
        # Normalisation part: optional batchnorm (sized for whichever side of the linear it sits on),
        # followed by optional dropout.
        norm_drop = []
        if bn: norm_drop.append(BatchNorm1dFlat(n_out if lin_first else n_in))
        if p != 0: norm_drop.append(nn.Dropout(p))
        # Linear part: the linear layer (no bias when batchnorm is used) and an optional activation.
        linear = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None: linear.append(act)
        ordered = linear + norm_drop if lin_first else norm_drop + linear
        super().__init__(*ordered)

In [None]:
#| export
class LinBnDrop(nn.Sequential):
    """Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers

    Each part is optional: `bn=False` skips the batchnorm, `p=0` skips the dropout and `ln=False`
    skips the linear layer together with `act`. `n_out` is only used when a linear layer is built.
    Raises `AssertionError` if `lin_first=True` is combined with `ln=False`.
    """
    def __init__(self, n_in, n_out=None, bn=True, ln=True, p=0., act=None, lin_first=False, ndim=1):
        # `lin_first` asks for the linear layer to come first, which is contradictory without one.
        # The message keeps the literal "AssertionError" so checks that match on the error text still pass.
        if lin_first and not ln:
            raise AssertionError("AssertionError: `lin_first=True` requires a linear layer, but `ln=False`")
        # Past the guard `lin_first` implies `ln`, so batchnorm is sized by `n_out` only when it follows the linear.
        layers = [BatchNorm(n_out if lin_first else n_in, ndim=ndim)] if bn else []
        if p != 0: layers.append(nn.Dropout(p))
        # The linear layer drops its bias when batchnorm is present.
        lin = [nn.Linear(n_in, n_out, bias=not bn)] if ln else []
        # The activation belongs to the linear layer: no linear, no activation.
        if ln and act is not None: lin.append(act)
        layers = lin+layers if lin_first else layers+lin
        super().__init__(*layers)

`LinBnDrop` is just like [fastai's LinBnDrop](https://github.com/fastai/fastai/blob/master/fastai/layers.py#L174) with an extra argument `ln` which provides the option of skipping the linear layer. That is, the `BatchNorm` layer is skipped if `bn=False`, the `Linear` layer is skipped if `ln=False`, and the dropout is skipped if `p=0`. Optionally, you can add an activation after the linear layer with `act`.

In [None]:
# Defaults (`bn=True`, `ln=True`, `p=0.`, `lin_first=False`): batchnorm first, then the linear layer, no dropout.
tst = LinBnDrop(10, 20)
mods = list(tst.children())
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Linear)

The `LinBnDrop` layer does not add an activation (even if one is provided) when `ln` is `False`, and it raises an error if `not ln and lin_first`:

In [None]:
# With `ln=False` only batchnorm and dropout remain; the supplied activation is dropped along with the linear layer.
tst = LinBnDrop(10, 20, ln=False, p=0.02, act=nn.ReLU(inplace=True))
mods = list(tst.children())
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Dropout)
# Asking for the linear layer first while disabling it is rejected; the error text mentions 'AssertionError'.
test_fail(lambda : LinBnDrop(10, 20, ln=False, lin_first=True), contains='AssertionError')

## Embeddings

In [None]:
#| export
class Embedding(nn.Embedding):
    "Embedding layer with truncated normal initialization"
    def __init__(self, ni, nf, std=0.01, **kwargs):
        super().__init__(ni, nf, **kwargs)
        # Overwrite the default `nn.Embedding` initialisation in place.
        # NOTE(review): this re-initialises every row, so a row selected by a `padding_idx` kwarg is
        # overwritten as well - confirm that is intended.
        emb_wgts = self.weight.data
        trunc_normal_(emb_wgts, std=std)

## Attention Layers for Extreme Multi-Label Classification

In [None]:
#| export
def _linear_attention(sentc:Tensor, # Sentence typically `(bs, bptt, nh)`
                   based_on: Embedding|Module # xcube's `Embedding(n_lbs, nh)` layer holding the label embeddings or a full fledged model
                  ):
    return sentc @ based_on.weight.transpose(0,1)

In [None]:
show_doc(_linear_attention)

---

[source](https://github.com/debjyotiSRoy/xcube/blob/main/xcube/layers.py#L66){target="_blank" style="float:right; font-size:smaller"}

### _linear_attention

>      _linear_attention (sentc:torch.Tensor,
>                         based_on:__main__.Embedding|fastai.torch_core.Module)

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| sentc | Tensor | Sentence typically `(bs, bptt, nh)` |
| based_on | __main__.Embedding \| fastai.torch_core.Module | xcube's `Embedding(n_lbs, nh)` layer holding the label embeddings or a full fledged model |

In [None]:
#|export
class _Pay_Attention:
    "Callable that binds an attention function `f` to the object `based_on` it attends with"
    def __init__(self, f, based_on):
        store_attr('f,based_on')

    def __call__(self, sentc):
        # Delegate to the stored function, supplying the stored `based_on` as its second argument.
        attn_func, target = self.f, self.based_on
        return attn_func(sentc, target)

In [None]:
#| export
def Linear_Attention(based_on: Module):
    "Create a callable that applies `_linear_attention` to its input using `based_on`"
    attn = _Pay_Attention(_linear_attention, based_on)
    return attn

In [None]:
#| export
def Ranked_Attention(based_on: Module):
    "Placeholder for a ranking-based attention: not implemented yet, ignores `based_on` and returns `None`"
    # TODO: Deb Create an architecture same as the Learning2Rank Model here, so that we can preload it just like fastai preloads LM encoder during text classification.
    return None

In [None]:
bs, bptt, nh, n_lbs = 16, 72, 100, 10
tst_lbs = Embedding(n_lbs, nh)
tst_Lin_Attn = Linear_Attention(tst_lbs)
# Wrapping the callable in `Lambda` turns it into a layer.
attn_layer = Lambda(tst_Lin_Attn)
sentc = torch.randn(bs, bptt, nh)
# One score per (token, label) pair.
test_eq(tst_Lin_Attn(sentc).shape , (bs, bptt, n_lbs))
# The layer, the bare callable and the explicit matmul must all agree.
test_eqs(attn_layer(sentc), tst_Lin_Attn(sentc), sentc @ tst_lbs.weight.transpose(0,1))

# The layer must still produce the same scores after a pickle round-trip.
attn_layer2 = pickle.loads(pickle.dumps(attn_layer))
test_eqs(attn_layer2(sentc), sentc @ tst_lbs.weight.transpose(0,1))

In [None]:
#| export
def lincomb(t, wgts=None):
    "Linearly combine the dim-1 slices of the 3d tensor `t` using `wgts` (with `wgts=None` the slices are simply added up)"
    if wgts is not None: return torch.bmm(wgts, t)
    # Default: a single row of ones per batch item, so the batched matmul sums over dim 1.
    n_batch, n_rows = t.size(0), t.size(1)
    return torch.bmm(t.new_ones(n_batch, 1, n_rows), t)

In [None]:
t = torch.randn(16, 72, 100)
# Without weights `lincomb` is a batched matmul with a row of ones.
wgts = t.new_ones(t.size(0), 1, t.size(1))
test_eq(torch.bmm(wgts, t), lincomb(t))
# 15 weight rows per batch item, filled in place by `random_(10)`.
rand_wgts = t.new_empty(t.size(0), 15, t.size(1)).random_(10)
# test_eq(lincomb(t, wgts=rand_wgts), torch.bmm(rand_wgts, t))
# `PartialLambda` binds the weights so the layer only takes `t`.
tst_LinComb = PartialLambda(lincomb, wgts=rand_wgts)
test_eq(tst_LinComb(t), torch.bmm(rand_wgts, t))
In [None]:
#| export
@torch.no_grad()
def topkmax(x, k=None, dim=1):
    """
    Returns the softmax over dim 1 of the 3d tensor `x` after zeroing out the values of `x` that are smaller than
    the entry at (0-based) position `k` of the descending sort along that dim, i.e. the `k+1` largest values
    (and anything tied with the threshold) are kept. `k` is capped at the last valid position.
    If `k` is `None` nothing is zeroed and it behaves like `x.softmax(dim=dim)`.
    Intuitively, `topkmax` hedges more compared to `F.softmax`: zeroed entries still enter the softmax with value 0.

    `x` itself is left untouched; a new tensor is returned. Runs under `torch.no_grad()`.
    Only `dim=1` is supported; any other `dim` raises `NotImplementedError`.
    """
    if dim!=1: raise NotImplementedError
    last = x.size(dim)-1
    # `k=None` selects the smallest entry as threshold, so no value falls below it.
    k = last if k is None else min(k, last)
    # Threshold per (batch, feature) column, kept as a size-1 dim so it broadcasts against `x`.
    kth_largest = x.sort(dim=dim, descending=True).values[:,k,:][:,None,:]
    # Out-of-place fill: the caller's tensor must not be modified as a side effect.
    clipped = x.masked_fill(x < kth_largest, 0.)
    return clipped.softmax(dim=1)

TODO: DEB 
- Make it work for other dims
- Hyperparameter-schedule the `k` in `topkmax` (start with a high value and gradually decrease it)

In [None]:
x = torch.randn((2, 7, 3))
# With the default `k=None` no value is zeroed, so the result matches a plain softmax.
test_eq(topkmax(x, dim=1) , F.softmax(x, dim=1))
test_fail(topkmax, args=(x, ), kwargs=dict(dim=-1)) # only `dim=1` is supported: raises NotImplementedError

In [None]:
#| export
class XMLAttention(Module):
    "Compute label specific attention weights for each token in a sequence"
    def __init__(self, n_lbs, emb_sz, embed_p=0.0):
        # NOTE: `embed_p` is only stored; no embedding dropout is applied in this class.
        store_attr('n_lbs,emb_sz,embed_p')
        self.lbs = Embedding(n_lbs, emb_sz)  # one embedding vector per label
        self.LinAttn = Lambda(Linear_Attention(self.lbs))

    def forward(self, sentc, mask):
        # `sentc` is the output of the sentence encoder, i.e. (bs, max_len tokens, nh);
        # `mask` is (bs, max_len) and marks the token positions whose weights get zeroed.
        scores = self.LinAttn(sentc)
        # Normalise over the token dim, then zero the masked tokens: label specific wts for each token (bs, max_len, n_lbs).
        attn_wgts = F.softmax(scores, dim=1)
        attn_wgts = attn_wgts.masked_fill(mask[:,:,None], 0)
        # For each label, a linear combination of all the tokens based on `attn_wgts`: (bs, n_lbs, nh).
        pooled = lincomb(sentc, wgts=attn_wgts.transpose(1,2))
        return pooled, attn_wgts

In [None]:
sentc = torch.randn(bs, bptt, nh)
# Random boolean mask with one entry per (batch item, token).
mask = sentc.new_empty(sentc.size()[:-1]).random_(2).bool()
# Make sure both masked and unmasked positions are present.
test_eq(mask.unique(), tensor([0., 1.]))
xml_attn = XMLAttention(n_lbs, nh)
# `forward` returns the attended representation followed by the attention weights.
attn, *wgts = xml_attn(sentc, mask)
test_eq(attn.shape, (bs, n_lbs, nh))
# Rebuild the same computation by hand from the module's own label embeddings.
tst_lbs = xml_attn.lbs
tst_Lin_Attn = Linear_Attention(tst_lbs)
attn_layer = Lambda(tst_Lin_Attn)
attn_wgts = F.softmax(attn_layer(sentc), dim=1) # topkmax(attn_layer(sentc), dim=1)
# test_eq(attn, torch.bmm(attn_wgts.transpose(1,2), sentc))
# The weights of masked tokens are zeroed after the softmax, exactly as in `XMLAttention.forward`.
test_eq(attn, torch.bmm(attn_wgts.masked_fill(mask[:, :, None], 0).transpose(1,2), sentc))

Test masking works:

In [None]:
sentc = torch.randn(bs, bptt, nh)
# Zero the masked token vectors up front (reuses `mask`, `xml_attn` and `attn_layer` from the previous cell).
sentc = sentc.masked_fill(mask[:, :, None], 0)
assert sentc[mask].sum().item() == 0
attn, *wgts = xml_attn(sentc, mask)
# The forward pass must not have changed the input.
assert sentc[mask].sum().item() == 0
attn_wgts = F.softmax(attn_layer(sentc), dim=1) # topkmax(attn_layer(sentc), dim=1)
# Masked tokens are all-zero vectors here, so they contribute nothing to the weighted sum
# and the unmasked weights give the same result as the module's masked ones.
test_eq(attn, torch.bmm(attn_wgts.transpose(1,2), sentc))

## Export -

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()