In [None]:
#| hide
#| eval: false
! [ -e /content ] && pip install -Uqq xcube  # upgrade fastai on colab

In [None]:
#| default_exp layers

In [None]:
#| export
from __future__ import annotations
from typing import Union
from fastai.imports import *
from fastai.torch_imports import *
from fastai.torch_core import *
from fastai.layers import *
from fastai.text.models.awdlstm import EmbeddingDropout, RNNDropout

from xcube.utils import *

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

# Layers

>Some layers that top up the ones in [fastai](https://docs.fast.ai/layers.html)

## Basic manipulations and resizing

One can easily create a beautiful layer with minimum boilerplate using fastai utilities. We will show a few simple examples here. For details and extensive illustrations please refer to [decorated fastai layers](https://docs.fast.ai/layers.html#Basic-manipulations-and-resize).

An easy way to create a pytorch layer for a simple `func`

In [None]:
show_doc(Lambda)

---

[source](https://github.com/fastai/fastai/blob/master/fastai/layers.py#LNone){target="_blank" style="float:right; font-size:smaller"}

### Lambda

>      Lambda (func)

An easy way to create a pytorch layer for a simple `func`

In [None]:
# `Lambda` wraps a plain function as a layer, and the layer survives a pickle round trip
def _add2(x): return x+2
tst = Lambda(_add2)
x = torch.randn(10,20)
test_eq(tst(x), x+2)
tst2 = pickle.loads(pickle.dumps(tst))
test_eq(tst2(x), x+2)

In [None]:
show_doc(PartialLambda)

---

[source](https://github.com/fastai/fastai/blob/master/fastai/layers.py#LNone){target="_blank" style="float:right; font-size:smaller"}

### PartialLambda

>      PartialLambda (func)

Layer that applies `partial(func, **kwargs)`

In [None]:
# `PartialLambda` binds `b=5` up front; `x` comes from the `Lambda` cell above
def test_func(a,b=2): return a+b
tst = PartialLambda(test_func, b=5)
test_eq(tst(x), x+5)

## Linear

In [None]:
#| export
def _create_bias(size, with_zeros=False):
    if with_zeros: return nn.Parameter(torch.zeros(*size))
    return nn.Parameter(torch.zeros(*size).uniform_(-0.1, 0.1))

In [None]:
#| export
class ElemWiseLin(Module):
    "Element-wise product of the input with a learnable `(dim0, dim1)` weight, plus an optional learnable scalar bias"
    initrange=0.1  # only referenced by a former uniform init; kept for compatibility
    def __init__(self, dim0, dim1, add_bias=False, **kwargs):
        store_attr()
        # `nn.Linear(dim1, dim0)` serves as the holder of a `(dim0, dim1)` weight; its own bias is never read in `forward`
        self.lin = nn.Linear(dim1, dim0, **kwargs)
        init_default(self.lin)
        if self.add_bias: self.bias = _create_bias((1, ))

    def forward(self, x):
        # x must broadcast against the `(dim0, dim1)` weight; the notebook test uses `(bs, dim0, dim1)`
        offset = self.bias if self.add_bias else x.new_zeros(1)
        return torch.addcmul(offset, x, self.lin.weight)  # offset + x * weight

In [None]:
# `ElemWiseLin` scales the input element-wise by a `(dim0, dim1)` weight (no bias by default)
bs, dim0, dim1 = 10, 1271, 400
tst = ElemWiseLin(dim0, dim1)
test_eq(tst.lin.weight.shape, (dim0, dim1))
x = torch.randn(bs, dim0, dim1)
test_eq(x * tst.lin.weight, tst(x))
test_eq(tst(x).shape, (bs, dim0, dim1))

## BatchNorm Layers

In [None]:
#|export
class LinBnFlatDrop(nn.Sequential):
    "`nn.Sequential` of `BatchNorm1dFlat` (if `bn`), `Dropout` (if `p != 0`) and `Linear` (+ `act`), linear last unless `lin_first`"
    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):
        norm_drop = []
        if bn:
            # the norm sees the linear's outputs when the linear comes first, its inputs otherwise
            norm_drop.append(BatchNorm1dFlat(n_out if lin_first else n_in))
        if p != 0:
            norm_drop.append(nn.Dropout(p))
        lin_act = [nn.Linear(n_in, n_out, bias=not bn)]
        if act is not None:
            lin_act.append(act)
        ordered = [*lin_act, *norm_drop] if lin_first else [*norm_drop, *lin_act]
        super().__init__(*ordered)

In [None]:
#| export
class LinBnDrop(nn.Sequential):
    """Module grouping `BatchNorm`, `Dropout` and `Linear` layers, like fastai's `LinBnDrop` plus an `ln` switch.

    - `bn=False` skips the `BatchNorm`, `p=0` skips the `Dropout`, `ln=False` skips the `Linear` and `act`.
    - `lin_first=True` puts `Linear` (+ `act`) before `BatchNorm`/`Dropout` and requires `ln=True`.
    - `n_out` is only used when `ln=True`.
    Raises `AssertionError` if `lin_first=True` and `ln=False`.
    """
    def __init__(self, n_in, n_out=None, bn=True, ln=True, p=0., act=None, lin_first=False, ndim=1):
        if lin_first and not ln:
            # keep the word "AssertionError" in the message so existing `contains='AssertionError'` checks still match
            raise AssertionError("LinBnDrop AssertionError: `lin_first=True` requires `ln=True` (there is no linear layer to put first)")
        # past the guard, `lin_first` implies `ln`: the BatchNorm sees `n_out` features exactly when the linear comes first
        layers = [BatchNorm(n_out if lin_first else n_in, ndim=ndim)] if bn else []
        if p != 0: layers.append(nn.Dropout(p))
        # the linear layer drops its bias whenever BatchNorm is used
        lin = [nn.Linear(n_in, n_out, bias=not bn)] if ln else []
        if ln and act is not None: lin.append(act)
        layers = lin+layers if lin_first else layers+lin
        super().__init__(*layers)

`LinBnDrop` is just like [fastai's LinBnDrop](https://github.com/fastai/fastai/blob/master/fastai/layers.py#L174) with an extra modality `ln`, which provides the option of skipping the linear layer. That is, the `BatchNorm` layer is skipped if `bn=False` and the `Linear` layer is skipped if `ln=False`, as is the dropout if `p=0`. Optionally, you can add an activation after the linear layer with `act`.

In [None]:
# default `LinBnDrop`: BatchNorm on the inputs, then the linear layer (no dropout since `p=0`)
tst = LinBnDrop(10, 20)
mods = list(tst.children())
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Linear)

The `LinBnDrop` layer does not add an activation (even if one is provided) when `ln` is `False`, and it raises an error if `not ln and lin_first`:

In [None]:
# with `ln=False` only BatchNorm + Dropout remain and `act` is ignored; `lin_first` without a linear layer must fail
tst = LinBnDrop(10, 20, ln=False, p=0.02, act=nn.ReLU(inplace=True))
mods = list(tst.children())
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Dropout)
test_fail(lambda : LinBnDrop(10, 20, ln=False, lin_first=True), contains='AssertionError')

## Embeddings

In [None]:
#| export
class Embedding(nn.Embedding):
    "`nn.Embedding` whose weights are re-initialized with fastai's `trunc_normal_` scaled by `std`"
    def __init__(self, ni, nf, std=0.01, **kwargs):
        super().__init__(ni, nf, **kwargs)
        weights = self.weight.data  # raw tensor, so the in-place init does not go through autograd
        trunc_normal_(weights, std=std)

## Attention Layers for Extreme Multi-Label Classification

In [None]:
#| export
def _linear_attention(sentc:Tensor, # Sentence typically `(bs, bptt, nh)` output of `SentenceEncoder`
                      based_on: nn.Embedding|Module # xcube's `Embedding(n_lbs, nh)` layer holding the label embeddings or a full fledged model
                  ):
    return sentc @ based_on.weight.transpose(0,1)

In [None]:
show_doc(_linear_attention)

---

[source](https://github.com/debjyotiSRoy/xcube/blob/main/xcube/layers.py#L68){target="_blank" style="float:right; font-size:smaller"}

### _linear_attention

>      _linear_attention (sentc:torch.Tensor,
>                         based_on:torch.nn.modules.sparse.Embedding|fastai.torch_core.Module)

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| sentc | Tensor | Sentence typically `(bs, bptt, nh)` output of `SentenceEncoder` |
| based_on | nn.Embedding \| Module | xcube's `Embedding(n_lbs, nh)` layer holding the label embeddings or a full fledged model |

In [None]:
#| export
def _planted_attention(sentc: Tensor, # Sentence typically `(bs, bptt)` containing the vocab idxs that goes inside the encoder
                       brain: Tensor # label specific attn wgts for each token in vocab, typically of shape `(vocab_sz, n_lbs)`
                     ):
    return brain[sentc.long()]

In [None]:
show_doc(_planted_attention)

---

[source](https://github.com/debjyotiSRoy/xcube/blob/main/xcube/layers.py#L74){target="_blank" style="float:right; font-size:smaller"}

### _planted_attention

>      _planted_attention (sentc:torch.Tensor, brain:torch.Tensor)

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| sentc | Tensor | Sentence typically `(bs, bptt)` containing the vocab idxs that goes inside the encoder |
| brain | Tensor | label specific attn wgts for each token in vocab, typically of shape `(vocab_sz, n_lbs)` |

In [None]:
#| export
def _diffntble_planted_attention(sentc_dec: Tensor, # Sentence `(bs, bptt)` typically containing the vocab idxs obtained after decoding what comes out of the encoder
                         l2r: nn.ModuleDict # containing `nn.Embedding` for `token_factors`, `token_bias`, `label_factors` and `label_bias` from pretrained L2R model
                        ):
    
    return l2r['token_factors'](sentc_dec.long()) @ l2r['label_factors'].weight.T + l2r['token_bias'](sentc_dec.long()) + l2r['label_bias'].weight.T

In [None]:
show_doc(_diffntble_planted_attention)

---

[source](https://github.com/debjyotiSRoy/xcube/blob/main/xcube/layers.py#L80){target="_blank" style="float:right; font-size:smaller"}

### _diffntble_planted_attention

>      _diffntble_planted_attention (sentc_dec:torch.Tensor,
>                                    l2r:torch.nn.modules.container.ModuleDict)

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| sentc_dec | Tensor | Sentence `(bs, bptt)` typically containing the vocab idxs obtained after decoding what comes out of the encoder |
| l2r | nn.ModuleDict | containing `nn.Embedding` for `token_factors`, `token_bias`, `label_factors` and `label_bias` from pretrained L2R model |

In [None]:
#|export
class _Pay_Attention:
    "Picklable callable that binds an attention function `f` to its second argument `based_on`"
    def __init__(self, f, based_on):
        store_attr('f,based_on')

    def __call__(self, sentc):
        attend = self.f
        return attend(sentc, self.based_on)

In [None]:
#| export
def Linear_Attention(based_on: Module):
    "Picklable callable mapping `sentc` to `_linear_attention(sentc, based_on)`"
    return _Pay_Attention(f=_linear_attention, based_on=based_on)

In [None]:
# linear attention scores each token against each label embedding: `(bs, bptt, nh) @ (nh, n_lbs)`; the wrapped layer must also pickle
bs, bptt, nh, n_lbs = 16, 72, 100, 10
tst_lbs = Embedding(n_lbs, nh)
tst_Lin_Attn = Linear_Attention(tst_lbs)
attn_layer = Lambda(tst_Lin_Attn)
sentc = torch.randn(bs, bptt, nh)
test_eq(tst_Lin_Attn(sentc).shape , (bs, bptt, n_lbs))
test_eqs(attn_layer(sentc), tst_Lin_Attn(sentc), sentc @ tst_lbs.weight.transpose(0,1))

attn_layer2 = pickle.loads(pickle.dumps(attn_layer))
test_eqs(attn_layer2(sentc), sentc @ tst_lbs.weight.transpose(0,1))

In [None]:
#| export
def Planted_Attention(brain: Tensor):
    "Picklable callable mapping token idxs `sentc` to `_planted_attention(sentc, brain)`, i.e. rows of `brain`"
    return _Pay_Attention(f=_planted_attention, based_on=brain)

In [None]:
# planted attention is a row lookup of `brain` for every token index (float idxs are cast with `.long()`)
bs, bptt, vocab_sz, n_lbs = 16, 72, 100, 10
inp = torch.zeros((bs, bptt)).random_(vocab_sz)
brain = torch.randn(vocab_sz, n_lbs)
tst_planted_Attn = Planted_Attention(brain)
attn_layer = Lambda(tst_planted_Attn)
attn = brain[inp.long()]
test_eq(attn.shape, (bs, bptt, n_lbs))
test_eqs(attn, tst_planted_Attn(inp), attn_layer(inp))
# test_eq(brain[sentc[8].long()][:, 4], attn[8, :, 4]) # looking at the attn wgts of the 8th sentence and 4th label (NOTE(review): stale - indexes with `sentc`, but `attn` was built from `inp`)

In [None]:
#| export
class PlantedLMDecoder(Module):
    """Decode encoder outputs to the index of the highest-scoring vocab token at every position.

    If `plant_wgts` is given it is loaded with `load_state_dict`, so its keys must match this module's
    (e.g. `decoder.weight`, `decoder.bias`)."""
    def __init__(self, 
        n_out:int, # vocab_sz 
        n_hid:int, # Number of features in encoder last layer output
        output_p:float=0.1, # Input dropout probability
        plant_wgts:dict=None, # If supplied loads `plant_wgts` into decoder
        bias:bool=True # If `False` the layer will not learn additive bias
    ):
        self.decoder = nn.Linear(n_hid, n_out, bias=bias)
        self.output_dp = RNNDropout(output_p)
        if plant_wgts: self.load_state_dict(plant_wgts)

    def forward(self, input):
        # `input` is presumably the encoder output `(bs, bptt, n_hid)` - TODO confirm against callers
        logits = self.decoder(self.output_dp(input))
        # softmax is monotonic, so the argmax of the logits equals the argmax of the probabilities;
        # skipping the vocab-sized softmax avoids a full normalization pass (and float ties from saturation)
        return logits.argmax(dim=-1)

In [None]:
from fastai.text.models.awdlstm import awd_lstm_lm_config
from fastai.text.models.core import AWD_LSTM

In [None]:
# a planted decoder must load the given weights exactly and map encoder outputs to one token index per position
bs, bptt, vocab_sz, n_lbs, n_fac = 16, 72, 100, 10, 200
config = awd_lstm_lm_config.copy()
emb_sz, output_p, out_bias = map(config.get, ['emb_sz', 'output_p', 'out_bias'])
lm_decoder_pretrained_wgts = {'decoder.weight': torch.randn(vocab_sz, emb_sz), 
                'decoder.bias': torch.randn(vocab_sz, )}
lm_decoder = PlantedLMDecoder(vocab_sz, emb_sz, output_p=output_p*0.3, plant_wgts=lm_decoder_pretrained_wgts, bias=out_bias)
test_eq(lm_decoder.decoder.weight, lm_decoder_pretrained_wgts['decoder.weight'])
test_eq(lm_decoder.decoder.bias, lm_decoder_pretrained_wgts['decoder.bias'])
enc = AWD_LSTM(vocab_sz, emb_sz, 10, 3)
inp = torch.randint(0, vocab_sz, (bs,bptt))
# inp = torch.zeros((bs, bptt)).random_(vocab_sz)
sentc = enc(inp)
sentc_decoded = lm_decoder(sentc)
test_eq(sentc_decoded.shape, inp.shape)

In [None]:
#| export
def Diffntble_Planted_Attention(l2r: nn.ModuleDict):
    "Picklable callable mapping decoded token idxs to `_diffntble_planted_attention(sentc_dec, l2r)`"
    return _Pay_Attention(f=_diffntble_planted_attention, based_on=l2r)

In [None]:
# `wgts` only supplies the embedding sizes (the `nn.Embedding`s keep their random init); checks shape and agreement with the closed form
tf = torch.randn(vocab_sz, n_fac)
tb = torch.randn(vocab_sz, 1)
lf = torch.randn(n_lbs, n_fac)
lb = torch.randn(n_lbs, 1)
wgts = dict(token_factors=tf, token_bias=tb, label_factors=lf, label_bias=lb)
l2r = nn.ModuleDict({k: nn.Embedding(*v.size()) for k,v in wgts.items()})
assert isinstance(l2r, nn.Module)
test_eq(l2r.keys(), ['token_factors', 'token_bias', 'label_factors', 'label_bias'])
tst_diffntble_planted_Attn = Diffntble_Planted_Attention(l2r)
attn_layer = Lambda(tst_diffntble_planted_Attn)
# attn = attn_layer(lm_decoder(sentc))
attn = attn_layer(sentc_decoded)
test_eq(attn.shape, (bs, bptt, n_lbs))
test_eqs(attn,
         tst_diffntble_planted_Attn(sentc_decoded),
         l2r['token_factors'](sentc_decoded.long()) @ l2r['label_factors'].weight.T + l2r['token_bias'](sentc_decoded.long()) + l2r['label_bias'].weight.T)

In [None]:
#| export
def lincomb(t, wgts=None):
    """Batched linear combination along dim 1 of the 3d tensor `t`, computed as `torch.bmm(wgts, t)`.
    With `t` of shape `(bs, n, d)` and `wgts` of shape `(bs, m, n)` the result is `(bs, m, d)`.
    If `wgts` is `None`, an all-ones `(bs, 1, n)` weight is used, i.e. the rows are summed."""
    if wgts is not None:
        return torch.bmm(wgts, t)
    ones = t.new_ones(t.size(0), 1, t.size(1))
    return torch.bmm(ones, t)

In [None]:
# without weights `lincomb` sums over dim 1; with weights it is a batched matmul (also usable through `PartialLambda`)
t = torch.randn(16, 72, 100)
wgts = t.new_ones(t.size(0), 1, t.size(1))
test_eq(torch.bmm(wgts, t), lincomb(t))
rand_wgts = t.new_empty(t.size(0), 15, t.size(1)).random_(10)
# test_eq(lincomb(t, wgts=rand_wgts), torch.bmm(rand_wgts, t))
tst_LinComb = PartialLambda(lincomb, wgts=rand_wgts)
test_eq(tst_LinComb(t), torch.bmm(rand_wgts, t))

In [None]:
#| export
@patch
@torch.no_grad()
def topkmax(self:Tensor, k=None, dim=1):
    """
    Softmax along dim 1 after zeroing every entry smaller than the `k`th largest (0-based rank) along that dim.
    Zeroed entries are not masked out: they still enter the softmax with value 0.
    If `k` is `None` (or `k >= self.size(dim)-1`) nothing is zeroed and this equals `self.softmax(dim=dim)`.
    `self` is left unmodified. Only `dim=1` is supported (raises `NotImplementedError` otherwise).
    """
    if dim != 1: raise NotImplementedError
    last = self.size(dim) - 1
    k = last if k is None else min(k, last)
    # `k`th largest along `dim`, kept as a size-1 dim so it broadcasts against `self`
    kth_largest = self.sort(dim=dim, descending=True).values.select(dim, k).unsqueeze(dim)
    res = self.clone()  # never write into the caller's tensor
    res[res < kth_largest] = 0.
    return res.softmax(dim=dim)

In [None]:
#| export
def split_sort(t, sp_dim, sort_dim, sp_sz=500, **kwargs):
    """Return `t.sort(dim=sort_dim, **kwargs).values`, sorting chunks of `sp_sz` along `sp_dim` separately
    and concatenating them (presumably to bound the memory of each individual sort).

    Chunking is only valid along an axis other than the one being sorted. When `sp_dim` and `sort_dim`
    name the same axis (or `t` has at most one dim), the whole tensor is sorted at once instead."""
    if t.ndim <= 1 or sp_dim % t.ndim == sort_dim % t.ndim:
        return t.sort(dim=sort_dim, **kwargs).values
    chunks = torch.split(t, split_size_or_sections=sp_sz, dim=sp_dim)
    return torch.cat([c.sort(dim=sort_dim, **kwargs).values for c in chunks], dim=sp_dim)

In [None]:
# chunking along dim 1 while sorting along the last dim must match a plain sort
t = torch.randn(16, 106, 819)
s_t = split_sort(t, sp_dim=1, sort_dim=-1, sp_sz=14)
test_eq(t.sort(dim=-1).values, s_t)

In [None]:
#| export
@patch
@torch.no_grad()
def inattention(self:Tensor, k=None, sort_dim=0, sp_dim=0):
    """
    Return a copy of `self` where every entry smaller than the `k`th largest (0-based rank) along `sort_dim` is set to 0.
    The descending sort is done by `split_sort`, chunked along `sp_dim`.
    If `k` is `None` (or `k >= self.size(sort_dim)-1`) nothing is zeroed and an unmodified copy of `self` is returned.
    """
    last = self.size(sort_dim) - 1
    k = last if k is None else min(k, last)
    srt = split_sort(self, sp_dim=sp_dim, sort_dim=sort_dim, descending=True)
    # `k`th largest along `sort_dim`, re-expanded to a size-1 dim so it broadcasts against `self`
    kth_largest = srt.select(sort_dim, k).unsqueeze(dim=sort_dim)
    res = self.detach().clone()
    res[res < kth_largest] = 0.
    return res

TODO: DEB 
- ~~Make it work for other dims~~
- Hyperparameter-schedule the `k` in `topkmax` (start high and gradually decrease it)

In [None]:
# `topkmax` with `k=None` is a plain softmax and only supports `dim=1`; `inattention(k=2)` zeroes everything below the 3rd largest
x = torch.randn((2, 7, 3))
test_eq(x.topkmax() , F.softmax(x, dim=1))
# test_fail(topkmax, args=(x, ), kwargs=dict(dim=-1)) # NotImplemented
test_fail(x.topkmax, kwargs=dict(dim=-1)) # NotImplemented
test_eq(x.inattention(k=2, sort_dim=-1), 
        torch.where(x < x.sort(dim=-1, descending=True).values[:, :, 2].unsqueeze(dim=-1), 0, x))

In [None]:
# 1d case (longer than the default `split_sort` chunk of 500): only values >= the 3rd largest survive
x = torch.randn((8820,) )
x_inattn = torch.where(x < x.sort(dim=0, descending=True).values[2].unsqueeze(dim=0), 0, x)
x_inattn1 = x.inattention(k=2, sort_dim=0)
test_eq(x_inattn, x_inattn1)

In [None]:
#| export
from xcube.utils import *

In [None]:
#| export
class XMLAttention(Module):
    """Compute label specific attention weights for each token in a sequence and pool the tokens per label.

    The scoring function is the one wrapped by `self.attn` (a `Lambda` around a `_Pay_Attention`). By default this is
    linear attention against the label embeddings `self.lbs`; planted or differentiable planted attention can be swapped
    in by assigning another such layer to `attn`."""
    # 0-based ranks handed to `Tensor.inattention` in the planted branches of `forward`
    # (entries below the `k`th largest along the sorted dim are zeroed)
    tok_k, lbs_k, plant_tok_k = 15, 5, 30

    def __init__(self, n_lbs, emb_sz, embed_p=0.0, plant=0.5):
        store_attr('n_lbs,emb_sz,embed_p,plant')
        self.lbs = Embedding(n_lbs, emb_sz)  # label embeddings used by the default linear attention
        # self.lbs_weight_dp = EmbeddingDropout(self.lbs_weight, embed_p)  # `embed_p` is currently unused
        self.attn = Lambda(Linear_Attention(self.lbs))

    @property
    def attn(self): return self._attn
    @attn.setter
    def attn(self, a): self._attn = a

    def forward(self, inp, sentc, mask):
        """
        inp:   token idxs `(bs, max_len)`, only read by the planted branch
        sentc: output of `SentenceEncoder` `(bs, max_len, nh)`
        mask:  `(bs, max_len)` bool; positions where it is `True` get zero attention weight
        Returns `(per-label pooled tokens (bs, n_lbs, nh), token attn wgts (bs, max_len, n_lbs), lbs_cf)`, where
        `lbs_cf` is `(bs, n_lbs)` for planted attention and `None` otherwise.
        Raises `NotImplementedError` if `self.attn` wraps an unknown attention function.
        """
        test_eqs(inp.shape, sentc.shape[:-1], mask.shape)
        attn_f = self.attn.func.f
        if attn_f is _linear_attention:
            # softmax over the tokens, then zero the masked positions -> (bs, max_len, n_lbs)
            top_tok_attn_wgts = F.softmax(self.attn(sentc), dim=1).masked_fill(mask[:,:,None], 0)
            lbs_cf = None
        elif attn_f is _planted_attention:
            attn_wgts = self.attn(inp).masked_fill(mask[:,:,None], 0)
            top_tok_attn_wgts = attn_wgts.inattention(k=self.tok_k, sort_dim=1)
            # apply `inattention` across the lbs dim
            top_lbs_attn_wgts = attn_wgts.clone().permute(0,2,1).inattention(k=self.lbs_k, sort_dim=1).permute(0,2,1).contiguous()
            lbs_cf = top_lbs_attn_wgts.sum(dim=1)  # (bs, n_lbs)
        elif attn_f is _diffntble_planted_attention:
            # NOTE: `self.lin_attn` and `self.lm_decoder` are not created in `__init__`; they must be attached externally
            top_tok_lin_attn_wgts = self.lin_attn(sentc).softmax(dim=1).masked_fill(mask[:,:,None], 0)  # (bs, max_len, n_lbs)
            # alternative under experiment: plant on the raw input, i.e. `self.attn(inp)` instead of `self.attn(self.lm_decoder(sentc))`
            top_tok_plant_attn_wgts = self.attn(self.lm_decoder(sentc)).masked_fill(mask[:,:,None], 0).inattention(k=self.plant_tok_k, sort_dim=1).softmax(dim=1)
            # `plant` interpolates between the linear and the planted attention weights
            top_tok_attn_wgts = (1-self.plant)*top_tok_lin_attn_wgts + self.plant*top_tok_plant_attn_wgts
            lbs_cf = None
        else:
            raise NotImplementedError(f"XMLAttention: unsupported attention function {attn_f!r}")
        # for each lbl do a linear combination of all the tokens based on the attn wgts -> (bs, n_lbs, nh)
        return lincomb(sentc, wgts=top_tok_attn_wgts.transpose(1,2)), top_tok_attn_wgts, lbs_cf

In [None]:
# testing linear attention
inp = torch.zeros(bs, bptt).random_(100)
sentc = torch.randn(bs, bptt, nh)
mask = sentc.new_empty(sentc.size()[:-1]).random_(2).bool()
test_eq(mask.unique(), tensor([0., 1.]))
xml_attn = XMLAttention(n_lbs, nh)
attn, tok_wgts, lbs_cf = xml_attn(inp, sentc, mask)
test_eq(attn.shape, (bs, n_lbs, nh))
tst_lbs = xml_attn.lbs
tst_Lin_Attn = Linear_Attention(tst_lbs)
lin_attn_layer = Lambda(tst_Lin_Attn)
attn_wgts = F.softmax(lin_attn_layer(sentc), dim=1) # topkmax(attn_layer(sentc), dim=1)
test_eq(attn, torch.bmm(attn_wgts.masked_fill(mask[:, :, None], 0).transpose(1,2), sentc))

# testing planted attention followed by inattention
assert xml_attn.attn.func.f is _linear_attention
inp = torch.zeros((bs, bptt)).random_(vocab_sz)
brain = torch.randn(vocab_sz, n_lbs)
plant_attn_layer = Lambda(Planted_Attention(brain))
# xml_attn.attn = plant_attn_layer
setattr(xml_attn, 'attn', plant_attn_layer)
assert xml_attn.attn.func.f is _planted_attention
attn, tok_wgts, lbs_cf = xml_attn(inp, sentc, mask)
test_eqs(tok_wgts, 
         plant_attn_layer(inp).masked_fill(mask[:,:,None], 0).inattention(k=15, sort_dim=1), 
         brain[inp.long()].masked_fill(mask[:,:,None], 0).inattention(k=15, sort_dim=1)
        )
test_eq(attn, 
        lincomb(sentc, 
                wgts=brain[inp.long()].masked_fill(mask[:,:,None], 0).inattention(k=15, sort_dim=1).transpose(1,2)
               )
       )

Test that masking works:

In [None]:
# for both attention types, masked rows of `sentc` stay zero; because those rows are zero, the unmasked linear
# reference weights give the same pooled output as the masked ones used inside `XMLAttention`
for attn_layer in (lin_attn_layer, plant_attn_layer):
    setattr(xml_attn, 'attn', attn_layer)
    inp = torch.zeros(bs, bptt).random_(100)
    sentc = torch.randn(bs, bptt, nh)
    sentc = sentc.masked_fill(mask[:, :, None], 0)
    assert sentc[mask].sum().item() == 0
    attn, tok_wgts, lbs_cf = xml_attn(inp, sentc, mask)
    assert sentc[mask].sum().item() == 0
    attn_wgts = F.softmax(attn_layer(sentc), dim=1) if attn_layer is lin_attn_layer else attn_layer(inp).masked_fill(mask[:,:,None], 0).inattention(k=15, sort_dim=1)# topkmax(attn_layer(sentc), dim=1)
    test_eq(attn, torch.bmm(attn_wgts.transpose(1,2), sentc))

## Export -

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()