In [None]:
#| eval: false
# `/content` only exists on Colab, so the pip upgrade below is skipped everywhere else
! [ -e /content ] && pip install -Uqq xcube # upgrade xcube on colab

In [None]:
#| default_exp l2r.data.info_gain

In [None]:
#| export
from fastcore.basics import *
from fastai.torch_core import *
from fastai.data.core import *
from fastai.data.transforms import *
from fastai.text.core import *
from fastai.text.data import *
from xcube.imports import *
from xcube.torch_imports import *

In [None]:
#| hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Information Gain

> Computation of mutual information gain

This module contains all the classes and functions needed to compute the mutual information gain between tokens and labels. This mutual information is then used to bootstrap an L2R model from extreme multi-label (XML) text data. Please follow the [booting tutorial](14_tutorial.boot_l2r.ipynb) to understand how this module is used.

In [None]:
#| export
class BatchLbsChunkify(ItemTransform):
    "Batch transform that keeps only the label columns `chnk_st` (inclusive) to `chnk_end` (exclusive) of an `(x, y)` batch."
    order = 100  # large `order` so this presumably runs after the other batch transforms -- confirm against the pipeline
    def __init__(self, chnk_st, chnk_end): store_attr('chnk_st,chnk_end')
    def encodes(self, x):
        inp, lbls = x[0], x[1]
        lbl_cols = slice(self.chnk_st, self.chnk_end)
        # the input is passed through untouched; only the label columns are narrowed
        return (inp, lbls[:, lbl_cols])

In [None]:
#| export
class MutualInfoGain:
    """
    Estimate token/label joint probabilities, entropies and mutual information from a DataFrame.

    `df` needs a 'text' column and a 'labels' column of ';'-delimited labels.
    Intended call order: `onehotify` -> `lbs_chunked` -> `joint_pmf` -> `compute`.
    Labels are processed in chunks of at most `chnk_sz` columns (one `TfmdDL` per chunk, batch size `bs`)
    so the per-chunk tensors stay bounded; `device` defaults to `default_device()`.
    """
    def __init__(self, df, bs=8, chnk_sz=200, device=None): store_attr('df,bs,chnk_sz,device')

    def onehotify(self):
        "Build `self.dsets` of (one-hot token bag, one-hot labels) and store vocab sizes in `self.toksize`/`self.lblsize`."
        x_tfms = [Tokenizer.from_df('text', n_workers=num_cpus()), attrgetter("text"), Numericalize(), OneHotEncode()]
        y_tfms = [ColReader('labels', label_delim=';'), MultiCategorize(), OneHotEncode()]
        self.dsets = Datasets(self.df, tfms=[x_tfms, y_tfms])
        self.toksize, self.lblsize = self.dsets.vocab.map(len)
        return self.dsets

    def lbs_chunked(self):
        "Create `self.dls`: one `TfmdDL` per chunk of at most `chnk_sz` labels, each keeping only its chunk's label columns."
        n_lbs = len(self.dsets.vocab[1])
        device = default_device() if self.device is None else self.device
        self.dls = [TfmdDL(self.dsets, bs=self.bs,
                           after_batch=[BatchLbsChunkify(st, min(st+self.chnk_sz, n_lbs))],
                           device=device)
                    for st in range(0, n_lbs, self.chnk_sz)]
        return self.dls

    def _mutual_info_gain(self, dl):
        """
        Estimate the joint pmf of every (token, label) pair of presence indicators over the whole dataset.
        `dl` yields (bag-of-words text, one-hot encoded targets) batches.
        Returns a `(toksize, lblsize, 2, 2)` tensor: axis -2 is the token (0: present, 1: absent) and
        axis -1 the label (0: present, 1: absent). It is the input to the
        [mutual information](https://en.wikipedia.org/wiki/Mutual_information) computed in `compute`.
        """
        xb, yb = dl.one_batch()
        toksize, lblsize = xb.size(1), yb.size(1)
        # accumulate on the batches' device: `self.device` may differ from `default_device()`
        counts = torch.zeros(toksize, lblsize, 4, dtype=torch.float, device=xb.device)
        for x,y in dl:
            n = x.size(0)  # the last batch can be smaller than `dl.bs` (it is not dropped)
            test_eq(x.shape, (n, toksize)); test_eq(y.shape, (n, lblsize))
            t = x.bool().float()  # 1. where the token is present (nonzero), else 0.
            l = y.bool().float()  # 1. where the label is present (nonzero), else 0.
            nt, nl = 1 - t, 1 - l
            # (toksize, n) @ (n, lblsize) sums indicator products over the batch, i.e. co-occurrence counts,
            # without materialising (n, toksize, lblsize) intermediates
            tt, tf = t.t() @ l, t.t() @ nl
            ft, ff = nt.t() @ l, nt.t() @ nl
            counts += torch.stack((tt, tf, ft, ff), dim=-1)
        p_TL = counts / len(self.dsets)
        p_TL = p_TL.view(toksize, lblsize, 2, 2) ; test_eq(p_TL.shape, (toksize, lblsize, 2, 2)) # last axis: lbl axis, 2nd last axis: token axis
        return p_TL

    def joint_pmf(self):
        "Compute the joint pmf for every label chunk in `self.dls` and concatenate into `self.p_TL_full` `(toksize, lblsize, 2, 2)`."
        chunks = [self._mutual_info_gain(dl) for dl in progress_bar(self.dls)]
        self.p_TL_full = torch.cat(chunks, dim=1); test_eq(self.p_TL_full.shape, (self.toksize, self.lblsize, 2, 2))
        return self.p_TL_full

    def compute(self):
        """
        Derive marginals, entropies and mutual information from `self.p_TL_full` (set by `joint_pmf`).
        Returns `(p_T, p_L, p_TxL, H_T, H_L, I_TL)` with shapes `(toksize, 2, 1)`, `(lblsize, 1, 2)`,
        `(toksize, lblsize, 2, 2)`, `(toksize,)`, `(lblsize,)` and `(toksize, lblsize)`.
        """
        p_TL = self.p_TL_full
        eps = p_TL.new_empty(1).fill_(1e-15)  # keeps log() finite for zero probabilities; 0 * log(eps) adds 0
        toksize, lblsize = p_TL.size(0), p_TL.size(1)
        p_T = p_TL[:,0].sum(-1, keepdim=True); test_eq(p_T.shape, (toksize, 2, 1)) # 0 because we can pick any label and apply total prob law
        p_L = p_TL[0,:].sum(-2, keepdim=True); test_eq(p_L.shape, (lblsize, 1, 2)) # 0 because we can pick any token and apply total prob law
        # outer product P(T) P(L) per (token, label) pair: (.., 2, 1) @ (.., 1, 2) -> (.., 2, 2)
        p_TxL = p_TL.sum(-1, keepdim=True) @ p_TL.sum(-2, keepdim=True); test_eq(p_TxL.shape, (toksize, lblsize, 2, 2))
        # squeeze only the reduced axis so a single token/label still yields a 1-d result
        H_T = -(p_T * torch.log(p_T+eps)).sum(-2).squeeze(-1); test_eq(H_T.shape, [toksize])
        H_L = -(p_L * torch.log(p_L+eps)).sum(-1).squeeze(-1); test_eq(H_L.shape, [lblsize])
        I_TL = (p_TL * torch.log((p_TL + eps)/(p_TxL + eps))).flatten(start_dim=-2).sum(-1); test_eq(I_TL.shape, (toksize, lblsize))
        return p_T, p_L, p_TxL, H_T, H_L, I_TL

In [None]:
#| export
@patch
def lbs_frqs(self:MutualInfoGain):
    "Count how often each label occurs in `self.df['labels']` (';'-delimited); also caches the `Counter` in `self._frqs`."
    # NOTE: this is a regular method -- call `mig.lbs_frqs()`. A `@property` stacked above `@patch` never
    # reached the class (it only wrapped the value `patch` returns at module level), so it was dropped.
    f = ColReader('labels', label_delim=';')
    self._frqs = Counter()
    for o in self.df.itertuples(): self._frqs.update(f(o))
    return self._frqs

## Export -

In [None]:
#| hide
# regenerate the library's python modules from the `#| export` cells (nbdev)
import nbdev; nbdev.nbdev_export()