In [None]:
#| hide
#| eval: false
! [ -e /content ] && pip install -Uqq xcube  # upgrade xcube on colab

In [None]:
#| export
from fastai.basics import *
from fastai.text.learner import *
from fastai.callback.rnn import *
from fastai.text.models.awdlstm import *
from fastai.text.models.core import get_text_classifier
from xcube.text.models.core import *

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| default_exp text.learner

# Learner for the XML Text application:

> All the functions necessary to build `Learner` suitable for transfer learning in XML text classification.

The most important function of this module is `xmltext_classifier_learner`. This will help you define a `Learner` using a pretrained Language Model for the encoder and a pretrained Learning-to-Rank-Model for the decoder. (Tutorial: Coming Soon!). This module is inspired by [fastai's](https://github.com/fastai/fastai) [TextLearner](https://docs.fast.ai/text.learner.html), which is based on the [ULMFit](https://arxiv.org/pdf/1801.06146.pdf) paper.

## Loading label embeddings from a pretrained collab model

In [None]:
#| export
def _get_text_vocab(dls:DataLoaders) -> list:
    "Return the text vocab of `dls`: the first entry of `dls.vocab` when it is an `L`, else `dls.vocab` itself"
    vocabs = dls.vocab
    return vocabs[0] if isinstance(vocabs, L) else vocabs

In [None]:
#| export
def _get_label_vocab(dls:DataLoaders) -> list:
    "Return the label vocab of `dls`: the second entry of `dls.vocab` when it is an `L`, else `dls.vocab` itself"
    vocabs = dls.vocab
    return vocabs[1] if isinstance(vocabs, L) else vocabs

In [None]:
#| export
def match_collab(
    old_wgts:dict, # Embedding weights of the collab model
    collab_vocab:dict, # Vocabulary of `token` and `label` used for collab pre-training
    lbs_vocab:list # Current labels vocabulary
) -> tuple: # (`old_wgts` with the label embeddings re-indexed, number of labels absent from the collab vocab)
    """Convert the label embeddings in `old_wgts` to go from `collab_vocab['label']` to `lbs_vocab`.

    Row `i` of the new `i_weight.weight` (and of `i_bias.weight`, when present) is the collab row of
    label `lbs_vocab[i]`; labels unknown to the collab vocab get the mean of all collab rows.
    `old_wgts` is updated in place (only the `i_*` keys are replaced) and also returned.
    Raises `KeyError` if `old_wgts` has no `'i_weight.weight'` entry.
    """
    bias, wgts = old_wgts.get('i_bias.weight', None), old_wgts.get('i_weight.weight', None)
    if wgts is None: raise KeyError("'i_weight.weight' not found in `old_wgts`")
    wgts_m = wgts.mean(0)
    new_wgts = wgts.new_zeros((len(lbs_vocab), wgts.size(1)))
    if bias is not None:
        bias_m = bias.mean(0)
        new_bias = bias.new_zeros((len(lbs_vocab), 1))
    collab_lbs_vocab = collab_vocab['label']
    # The collab vocab may already carry an object->index mapping; otherwise build one
    collab_o2i = collab_lbs_vocab.o2i if hasattr(collab_lbs_vocab, 'o2i') else {w:i for i,w in enumerate(collab_lbs_vocab)}
    missing = 0
    for i,w in enumerate(lbs_vocab):
        idx = collab_o2i.get(w, -1)
        found = idx >= 0
        new_wgts[i] = wgts[idx] if found else wgts_m
        if bias is not None: new_bias[i] = bias[idx] if found else bias_m
        if not found: missing += 1
    old_wgts['i_weight.weight'] = new_wgts
    if bias is not None: old_wgts['i_bias.weight'] = new_bias
    return old_wgts, missing

In [None]:
# Toy collab weights: 3 "users", 4 "items" (labels, including '#na#'), 5 factors each
wgts = {'u_weight.weight': torch.randn(3,5),
        'i_weight.weight': torch.randn(4,5),
        'u_bias.weight'  : torch.randn(3,1),
        'i_bias.weight'  : torch.randn(4,1)}
collab_vocab = {'token': ['#na#', 'sun', 'moon', 'earth', 'mars'],
                'label': ['#na#', 'a', 'c', 'b']}
lbs_vocab = ['a', 'b', 'c']
# `match_collab` replaces keys of the dict it is given, so pass a (shallow) copy to keep `wgts` intact
new_wgts, missing = match_collab(wgts.copy(), collab_vocab, lbs_vocab)
test_eq(missing, 0)
# The user-side entries must be left untouched
test_close(wgts['u_weight.weight'], new_wgts['u_weight.weight'])
test_close(wgts['u_bias.weight'], new_wgts['u_bias.weight'])
# The label rows are re-ordered ('a','c','b' -> 'a','b','c'), so a positional comparison must fail.
# Each check gets its own context: inside a single `with`, the second one would never run.
with ExceptionExpected(ex=AssertionError, regex="close"):
    test_close(wgts['i_weight.weight'][1:], new_wgts['i_weight.weight'])
with ExceptionExpected(ex=AssertionError, regex="close"):
    test_close(wgts['i_bias.weight'][1:], new_wgts['i_bias.weight'])
old_w, new_w = wgts['i_weight.weight'], new_wgts['i_weight.weight']
old_b, new_b = wgts['i_bias.weight'], new_wgts['i_bias.weight']
# User tensors keep their size; item tensors shrink from 4 rows to len(lbs_vocab) rows
for (old_k,old_v), (new_k, new_v) in zip(wgts.items(), new_wgts.items()):
    if old_k.startswith('u'): test_eq(old_v.size(), new_v.size())
    else: test_ne(old_v.size(), new_v.size())
# Row-by-row: new index of each label -> its index in the collab label vocab
test_eq(new_w[0], old_w[1]); test_eq(new_b[0], old_b[1])
test_eq(new_w[1], old_w[3]); test_eq(new_b[1], old_b[3])
test_eq(new_w[2], old_w[2]); test_eq(new_b[2], old_b[2])
# Same values overall, just in a different order
test_shuffled(list(old_b[1:].squeeze().numpy()), list(new_b.squeeze().numpy()))
test_eq(torch.sort(old_b[1:], dim=0)[0], torch.sort(new_b, dim=0)[0])
test_eq(torch.sort(old_w[1:], dim=0)[0], torch.sort(new_w, dim=0)[0])

## Loading Pretrained Information Gain as Attention 

In [None]:
from xcube.l2r.all import *

In [None]:
# Resolve the MIMIC3 data location via `untar_xxx` and load the pickled classifier vocab stored there
source_mimic = untar_xxx(XURLs.MIMIC3)
xml_vocab = load_pickle(source_mimic/'mimic3-9k_clas_full_vocab.pkl')
# Wrap in an `L` of lists; later cells index it as [0] (tokens) and [1] (labels)
xml_vocab = L(xml_vocab).map(listify)

In [None]:
# Resolve the L2R data location and load the pickled token/label info onto the default device
source_l2r = untar_xxx(XURLs.MIMIC3_L2R)
boot_path = join_path_file('mimic3-9k_tok_lbl_info', source_l2r, ext='.pkl')
# presumably a dict-like object: the next cell reads it through `.get` with the keys 'toks', 'lbs', 'mutual_info_jaccard'
l2r_bootstrap = torch.load(boot_path, map_location=default_device())

In [None]:
# Split the bootstrap into its two vocabularies ('toks', 'lbs') and the 'mutual_info_jaccard' entry (the "brain")
*brain_vocab, brain = mapt(l2r_bootstrap.get, ['toks', 'lbs', 'mutual_info_jaccard'])
brain_vocab = L(brain_vocab).map(listify)
toks, lbs = brain_vocab
print(f"last two places in brain vocab has {toks[-2:]}")
# toks = CategoryMap(toks, sort=False)
# Label descriptions, expected to be a dict (presumably label code -> description; verify against the data)
lbs_des = load_pickle(source_mimic/'code_desc.pkl')
assert isinstance(lbs_des, dict)
# The brain has one row per brain token and one column per brain label
test_eq(brain.shape, (len(toks), len(lbs))) # last two places has 'xxfake'

last two places in brain vocab has ['xxfake', 'xxfake']


The tokens which are there in the xml vocab but not in the brain:

In [None]:
# Tokens of the xml vocab that have no counterpart in the brain's token vocab (set difference, so order is arbitrary)
not_found_in_brain = L(set(xml_vocab[0]).difference(set(brain_vocab[0])))
not_found_in_brain

(#20) ['remiained','promiscuity','q2day','cella','dobhoof','dissension','theses','1193p','unrmarkable','calcijex'...]

In [None]:
# 'cella' is one of the tokens listed as missing above, so looking it up in the brain tokens must fail
test_fail(lambda : toks.index('cella'), contains='is not in list')

The tokens which are in the brain but were not present in the xml vocab:

In [None]:
# Brain tokens that are absent from the xml token vocab
set(brain_vocab[0]).difference(xml_vocab[0])

set()

Thankfully, we have `info` for all the labels in the xml vocab:

In [None]:
# Every label of the xml vocab must also be present in the brain's label vocab.
# (The previous check compared `brain_vocab[1]` with itself and therefore could never fail.)
assert set(xml_vocab[1]).difference(brain_vocab[1]) == set()

In [None]:
#| export
def _xml2brain(xml_vocab, brain_vocab):
    "Creates a mapping between the indices of the xml vocab and the brainrmation-gain vocab"
    xml2brain = {i: brain_vocab.index(o) if o in brain_vocab else np.inf  for i,o in enumerate(xml_vocab)}
    xml2brain_notfnd = [o for o in xml2brain if xml2brain[o] is np.inf]
    return xml2brain, xml2brain_notfnd

In [None]:
# Map xml token indices to brain token indices and collect the xml indices without a match
toks_xml2brain, toks_notfnd = _xml2brain(xml_vocab[0], brain_vocab[0])

# xml indices that do have a counterpart in the brain
toks_found = set(toks_xml2brain).difference(set(toks_notfnd))
# The unmatched indices must name the same tokens as the set difference computed earlier
test_shuffled(array(xml_vocab[0])[toks_notfnd], not_found_in_brain)
# Spot-check 10 randomly drawn matched indices (unseeded, so the sample differs between runs):
# the token at the xml index must equal the token at the mapped brain index
some_xml_idxs = np.random.choice(array(L(toks_found)), size=10)
some_xml_toks = array(xml_vocab[0])[some_xml_idxs]
corres_brain_idxs = L(map(toks_xml2brain.get, some_xml_idxs))
corres_brain_toks = array(toks)[corres_brain_idxs]
assert all_equal(some_xml_toks, corres_brain_toks)

In [None]:
# Same check as for the tokens, now on the label vocabularies
lbs_xml2brain, lbs_notfnd = _xml2brain(xml_vocab[1], brain_vocab[1])

# xml label indices that have a counterpart in the brain
lbs_found = set(lbs_xml2brain).difference(set(lbs_notfnd))
# Spot-check 10 randomly drawn matched labels (unseeded sample)
some_xml_idxs = np.random.choice(array(L(lbs_found)), size=10)
some_xml_lbs = array(xml_vocab[1])[some_xml_idxs]
corres_brain_idxs = L(map(lbs_xml2brain.get, some_xml_idxs))
corres_brain_lbs = array(lbs)[corres_brain_idxs]
assert all_equal(some_xml_lbs, corres_brain_lbs)

In [None]:
#| export
def brainsplant(xml_vocab, brain_vocab, brain, device=None):
    """Transplant the pretrained `brain` into a tensor that is indexed by `xml_vocab`.

    `xml_vocab` and `brain_vocab` are `(tokens, labels)` pairs; `xml_vocab` must be an `L` (its `.map`
    is used). `brain` is assumed to be a 2-D `torch.Tensor` with one row per `brain_vocab[0]` token and
    one column per `brain_vocab[1]` label — TODO confirm for other callers.

    Returns `(xml_brain, toks_map, lbs_map, toks_xml2brain, lbs_xml2brain)`:
    `xml_brain` has shape `(len(xml_vocab[0]), len(xml_vocab[1]))` and lives on `device`
    (`default_device()` when `device` is None). Entry `[i, j]` is copied from the brain when both xml
    token `i` and xml label `j` exist in the brain vocab, and is left at zero otherwise.
    `toks_map` / `lbs_map` are `L`s of `(xml_idx, brain_idx)` pairs for the matched items.
    """
    toks_xml2brain, _ = _xml2brain(xml_vocab[0], brain_vocab[0])
    lbs_xml2brain, _ = _xml2brain(xml_vocab[1], brain_vocab[1])
    xml_brain = torch.zeros(*xml_vocab.map(len)).to(default_device() if device is None else device) # initialize empty brain
    toks_map = L((xml_idx, brn_idx) for xml_idx, brn_idx in toks_xml2brain.items() if brn_idx is not np.inf)
    lbs_map = L((xml_idx, brn_idx) for xml_idx, brn_idx in lbs_xml2brain.items() if brn_idx is not np.inf)

    def _idx(idxs, dev): return torch.tensor(list(idxs), dtype=torch.long, device=dev)
    tok_xml, lbl_xml = _idx(toks_map.itemgot(0), xml_brain.device), _idx(lbs_map.itemgot(0), xml_brain.device)
    tok_brn, lbl_brn = _idx(toks_map.itemgot(1), brain.device), _idx(lbs_map.itemgot(1), brain.device)
    # Paired advanced indexing: the (rows x cols) block of matched brain entries is written to the
    # matching (rows x cols) block of `xml_brain`. Unlike chaining two column selections, this stays
    # correct when some xml labels are missing from the brain (their columns simply remain zero).
    found = brain[tok_brn[:, None], lbl_brn[None, :]]
    xml_brain[tok_xml[:, None], lbl_xml[None, :]] = found.to(device=xml_brain.device, dtype=xml_brain.dtype)
    return xml_brain, toks_map, lbs_map, toks_xml2brain, lbs_xml2brain

In [None]:
# Transplant the brain into xml-vocab coordinates and check its basic invariants
xml_brain, toks_map, lbs_map, toks_xml2brain, lbs_xml2brain = brainsplant(xml_vocab, brain_vocab, brain)
# One row per xml token, one column per xml label
test_eq(xml_brain.shape, xml_vocab.map(len))
# Rows of tokens that are unknown to the brain must stay all-zero
test_eq(xml_brain[toks_notfnd], xml_brain.new_zeros(len(toks_notfnd), len(xml_vocab[1])))
# Every (xml_idx, brain_idx) pair must name the same token / label in both vocabularies
assert all_equal(array(xml_vocab[0])[toks_map.itemgot(0)], array(brain_vocab[0])[toks_map.itemgot(1)])
assert all_equal(array(xml_vocab[1])[lbs_map.itemgot(0)], array(brain_vocab[1])[lbs_map.itemgot(1)])

In [None]:
# tests to ensure `brainsplant` was successful 
lbl_idx_from_brn = brain_vocab[1].index('642.41')
top_toks_from_brn = brain[:, lbl_idx_from_brn].topk(k=20).indices.cpu()
array(brain_vocab[0])[top_toks_from_brn]
lbl_idx_from_xml = xml_vocab[1].index('642.41')
top_toks_from_xml = xml_brain[:, lbl_idx_from_xml].topk(k=20).indices.cpu()
test_eq(lbs_xml2brain[lbl_idx_from_xml], lbl_idx_from_brn)
test_eq(array(brain_vocab[0])[top_toks_from_brn], array(xml_vocab[0])[top_toks_from_xml])
# lbs_des['642.41'], array(xml_vocab[0])[top_toks_from_xml]

In [None]:
# Toy demonstration of copying a row/column-permuted block of `s` into `t`
t = torch.zeros(4, 3).long()
s = torch.arange(20).view(2, 10).long()
# NOTE(review): bare expression in the middle of the cell; it has no visible effect
t,s
# Pairs of (index in `t`, index in `s`)
row_perm = L((0, 1), (3, 0))
col_perm = L((2, 1), (1, 3), (0, -1))
# Rows 1 and 0 of `s` go to rows 0 and 3 of `t`. Columns are picked from `s` as [1, 3, -1] and then
# re-ordered with [2, 1, 0]; that second step lands each column on its target only because
# [2, 1, 0] is its own inverse permutation.
t[row_perm.itemgot(0)] = s[row_perm.itemgot(1)][:, col_perm.itemgot(1)][:, col_perm.itemgot(0)]
t

tensor([[19, 13, 11],
        [ 0,  0,  0],
        [ 0,  0,  0],
        [ 9,  3,  1]])

## Base `Learner` for NLP

In [None]:
#| export
def load_collab_keys(
    model, # Model architecture
    wgts:dict # Model weights
) -> tuple: # Result of `model.load_state_dict`
    """Load only collab `wgts` (`i_weight` and `i_bias`) in `model`, keeping the rest as is.

    `wgts['i_weight.weight']` is copied into the state-dict entries `1.attn.lbs_weight.weight` and
    `1.attn.lbs_weight_dp.emb.weight`, and `wgts['i_bias.weight']` into `1.attn.lbs_weight.bias`.
    Entries that do not exist in the model, or whose source weight is missing from `wgts`, are skipped.
    """
    # NOTE(review): the keys below are hard-coded; if the model names its attention parameters
    # differently, nothing is replaced and the call silently reloads the unchanged state dict.
    sd = model.state_dict()
    i_weight, i_bias = wgts.get('i_weight.weight', None), wgts.get('i_bias.weight', None)
    if i_weight is not None:
        for key in ('1.attn.lbs_weight.weight', '1.attn.lbs_weight_dp.emb.weight'):
            if key in sd: sd[key] = i_weight.data.clone()
    if i_bias is not None and '1.attn.lbs_weight.bias' in sd:
        sd['1.attn.lbs_weight.bias'] = i_bias.data.clone()
    # `load_state_dict` copies the (possibly replaced) entries into the model parameters
    return model.load_state_dict(sd)

In [None]:
# Build a tiny classifier (hidden size 10, embedding size 5) to inspect its state-dict keys
config = awd_lstm_clas_config.copy()
config.update({'n_hid': 10, 'emb_sz': 5})
# tst = get_text_classifier(AWD_LSTM, 100, 3, config=config)
tst = get_xmltext_classifier(AWD_LSTM, 100, 3, config=config)
# NOTE(review): `.copy()` here is shallow — the copied dict refers to the same tensors as the state dict
old_sd = tst.state_dict().copy()
# List every key containing 'attn'; the substring filter and the regex filter must agree
r = re.compile(".*attn.*")
test_eq([key for key in old_sd if 'attn' in key], list(filter(r.match, old_sd)))
print("\n".join(list(filter(r.match, old_sd))))

1.pay_attn.lbs.weight
1.boost_attn.lin.weight
1.boost_attn.lin.bias


In [None]:
import copy

In [None]:
# Snapshot the weights (deep copy), then load the collab label embeddings built earlier in `new_wgts`
old_sd = copy.deepcopy(tst.state_dict())
load_collab_keys(tst, new_wgts)
# NOTE(review): the 'attn' keys printed above are named `1.pay_attn.*` / `1.boost_attn.*`, none of which
# is a key that `load_collab_keys` looks for — so this call presumably changes nothing; confirm when fixing the tests
# <TODO: Deb> fix the following tests later
# test_ne(old_sd['1.attn.lbs_weight.weight'], tst.state_dict()['1.attn.lbs_weight.weight'])
# test_eq(tst.state_dict()['1.pay_attn.lbs_weight.weight'], new_wgts['i_weight.weight'])
# test_ne(old_sd['1.attn.lbs_weight_dp.emb.weight'], tst.state_dict()['1.attn.lbs_weight_dp.emb.weight'])
# test_eq(tst.state_dict()['1.attn.lbs_weight_dp.emb.weight'], new_wgts['i_weight.weight'])

<All keys matched successfully>

In [None]:
#| export
@delegates(Learner.__init__)
class TextLearner(Learner):
    "Basic class for a `Learner` in NLP."
    def __init__(self, 
        dls:DataLoaders, # Text `DataLoaders`
        model, # A standard PyTorch model
        alpha:float=2., # Param for `RNNRegularizer`
        beta:float=1., # Param for `RNNRegularizer`
        moms:tuple=(0.8,0.7,0.8), # Momentum for `Cosine Annealing Scheduler`
        **kwargs
    ):
        super().__init__(dls, model, moms=moms, **kwargs)
        # NOTE(review): `alpha` and `beta` are accepted but not forwarded to `rnn_cbs()`; left as is
        # because forwarding them would change which callbacks run during training — confirm intent.
        self.add_cbs(rnn_cbs())

    def save_encoder(self, 
        file:str # Filename for `Encoder` 
    ):
        "Save the encoder to `file` in the model directory"
        if rank_distrib(): return # don't save if child proc
        encoder = get_model(self.model)[0]
        # Unwrap the encoder when it is held inside a wrapper exposing `.module`
        if hasattr(encoder, 'module'): encoder = encoder.module
        torch.save(encoder.state_dict(), join_path_file(file, self.path/self.model_dir, ext='.pth'))
    
    @delegates(save_model)
    def save(self,
        file:str, # Filename for the state_directory of the model
        **kwargs
    ):
        """
        Save model and optimizer state (if `with_opt`) to `self.path/self.model_dir/{file}.pth`
        Save `self.dls.vocab` to `self.path/self.model_dir/{file}_vocab.pkl`
        Returns the path of the saved model file.
        """
        model_file = join_path_file(file, self.path/self.model_dir, ext='.pth')
        vocab_file = join_path_file(file+'_vocab', self.path/self.model_dir, ext='.pkl')
        save_model(model_file, self.model, getattr(self, 'opt', None), **kwargs)
        save_pickle(vocab_file, self.dls.vocab)
        return model_file

    def load_encoder(self, 
        file:str, # Filename of the saved encoder 
        device:(int,str,torch.device)=None # Device used to load, defaults to `dls` device
    ):
        "Load the encoder `file` from the model directory, optionally ensuring it's on `device`"
        encoder = get_model(self.model)[0]
        if device is None: device = self.dls.device
        if hasattr(encoder, 'module'): encoder = encoder.module
        distrib_barrier()
        wgts = torch.load(join_path_file(file,self.path/self.model_dir, ext='.pth'), map_location=device)
        encoder.load_state_dict(clean_raw_keys(wgts))
        self.freeze()
        return self
    
    def load_brain(self,
        file: str, # Filename of the saved attention wgts
        device:(int,str,torch.device)=None # Device used to load, defaults to `dls` device
    ):
        """Load the pre-learnt label specific attention weights for each token from `file` located in the 
        model directory, optionally ensuring it's on `device`
        """
        # Honour the documented default, as `load_encoder` and `load` do
        if device is None: device = self.dls.device
        brain_path = join_path_file(file, self.path/self.model_dir, ext='.pkl')
        brain_bootstrap = torch.load(brain_path, map_location=device)
        *brain_vocab, brain = mapt(brain_bootstrap.get, ['toks', 'lbs', 'mutual_info_jaccard'])
        brain_vocab = L(brain_vocab).map(listify)
        vocab = L(_get_text_vocab(self.dls), _get_label_vocab(self.dls)).map(listify)
        # Re-index the brain from its own vocab to the vocab of `self.dls`, on the requested device
        xml_brain, toks_map, lbs_map, toks_xml2brain, lbs_xml2brain = brainsplant(vocab, brain_vocab, brain, device=device)
        # NOTE(review): the transplanted `xml_brain` is computed but not yet attached to the model or the
        # learner, so this method currently has no lasting effect besides validating the file.
        return self

    def load_pretrained(self, 
        wgts_fname:str, # Filename of saved weights 
        vocab_fname:str, # Saved vocabulary filename in pickle format
        model=None # Model to load parameters from, defaults to `Learner.model`
    ):
        "Load a pretrained model and adapt it to the data vocabulary."
        old_vocab = load_pickle(vocab_fname)
        new_vocab = _get_text_vocab(self.dls)
        distrib_barrier()
        wgts = torch.load(wgts_fname, map_location = lambda storage,loc: storage)
        if 'model' in wgts: wgts = wgts['model'] #Just in case the pretrained model was saved with an optimizer
        wgts = match_embeds(wgts, old_vocab, new_vocab)
        load_ignore_keys(self.model if model is None else model, clean_raw_keys(wgts))
        self.freeze()
        return self

    #For previous versions compatibility. Remove at release
    @delegates(load_model_text)
    def load(self, 
        file:str, # Filename of saved model 
        with_opt:bool=None, # Enable to load `Optimizer` state
        device:(int,str,torch.device)=None, # Device used to load, defaults to `dls` device
        **kwargs
    ):
        "Load the model (and, depending on `with_opt`, the optimizer state) from `file` in the model directory"
        if device is None: device = self.dls.device
        if self.opt is None: self.create_opt()
        file = join_path_file(file, self.path/self.model_dir, ext='.pth')
        # Forward `with_opt`; it used to be accepted here but silently dropped
        load_model_text(file, self.model, self.opt, with_opt=with_opt, device=device, **kwargs)
        return self
    
    def load_collab(self,
        wgts_fname:str, # Filename of the saved collab model
        collab_vocab_fname:str, # Saved Vocabulary of collab labels in pickle format 
        model=None # Model to load parameters from, defaults to `Learner.model`
    ):
        "Load the label embeddings learned by the collab model, and adapt them to the label vocabulary."
        collab_vocab = load_pickle(collab_vocab_fname)
        lbs_vocab = _get_label_vocab(self.dls)
        distrib_barrier()
        wgts = torch.load(wgts_fname, map_location=lambda storage,loc: storage)
        if 'model' in wgts: wgts = wgts['model'] #Just in case the pretrained model was saved with an optimizer
        # The count of labels missing from the collab vocab is not needed here
        wgts, _ = match_collab(wgts, collab_vocab, lbs_vocab)
        load_collab_keys(self.model if model is None else model, wgts)
        self.freeze()
        return self

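Adds the callbacks returned by `rnn_cbs()` (intended to be a `ModelResetter` and an `RNNRegularizer` with `alpha` and `beta`) to the callbacks; the rest is the same as `Learner` init. Note that `alpha` and `beta` are currently accepted by `__init__` but are not forwarded to `rnn_cbs()`.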

This `Learner` adds functionality to the base class:

## `Learner` convenience functions

In [None]:
#| export
from xcube.text.models.core import _model_meta 

In [None]:
#| export 
@delegates(Learner.__init__)
def xmltext_classifier_learner(dls, arch, seq_len=72, config=None, backwards=False, pretrained=True, collab=False, drop_mult=0.5, n_out=None,
                           lin_ftrs=None, ps=None, max_len=72*20, y_range=None, splitter=None, **kwargs):
    """Create a `Learner` with a text classifier from `dls` and `arch`.

    With `pretrained`, the encoder (`learn.model[0]`) is initialised from the pretrained weights listed
    for `arch` in `_model_meta` (the `url_bwd` entry when `backwards`). With `collab`, the decoder
    (`learn.model[1]`) additionally receives the label embeddings of a collab model found under
    `learn.path/**/collab/`. The returned learner is frozen.
    NOTE(review): `lin_ftrs` and `ps` are accepted but not used by this function.
    """
    vocab = _get_text_vocab(dls)
    if n_out is None: n_out = get_c(dls)
    assert n_out, "`n_out` is not defined, and could not be inferred from the data, set `dls.c` or pass `n_out`"
    model = get_xmltext_classifier2(arch, len(vocab), n_out, seq_len=seq_len, config=config, y_range=y_range,
                                drop_mult=drop_mult, max_len=max_len)
    meta = _model_meta[arch]
    learn = TextLearner(dls, model, splitter=splitter if splitter is not None else meta['split_clas'], **kwargs)
    url = 'url_bwd' if backwards else 'url'
    if pretrained:
        if url not in meta:
            warn("There are no pretrained weights for that architecture yet!")
            return learn
        model_path = untar_data(meta[url], c_key='model')
        # Expect one weights file (.pth) and one vocab file (.pkl) in the downloaded folder
        try: fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
        except IndexError: print(f'The model in {model_path} is incomplete, download again'); raise
        learn = learn.load_pretrained(*fnames, model=learn.model[0])
    if collab:
        try: fnames = [list(learn.path.glob(f'**/collab/*collab*.{ext}'))[0] for ext in ['pth', 'pkl']]
        except IndexError: print(f'The collab model in {learn.path} is incomplete, re-train it!'); raise
        # The method is `load_collab`; the previous `load_colab` spelling raised `AttributeError`
        learn = learn.load_collab(*fnames, model=learn.model[1])
    learn.freeze()
    return learn

## Export -

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()