In [None]:
#| hide
#| eval: false
! [ -e /content ] && pip install -Uqq xcube  # upgrade xcube on colab

In [None]:
#| export
from fastai.basics import *
from fastai.text.learner import *
from fastai.callback.rnn import *
from fastai.text.models.awdlstm import *
from fastai.text.models.core import get_text_classifier
from fastprogress.fastprogress import master_bar, progress_bar
from xcube.text.models.core import *

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| default_exp text.learner

# Learner for the XML Text application:

> All the functions necessary to build `Learner` suitable for transfer learning in XML text classification.

The most important function of this module is `xmltext_classifier_learner`. It helps you define a `Learner` that uses a pretrained language model for the encoder and a pretrained learning-to-rank model for the decoder. (Tutorial: Coming Soon!). This module is inspired by [fastai's](https://github.com/fastai/fastai) [TextLearner](https://docs.fast.ai/text.learner.html), which is based on the [ULMFiT](https://arxiv.org/pdf/1801.06146.pdf) paper.

## Loading label embeddings from a pretrained collab model

In [None]:
#| export
def _get_text_vocab(dls:DataLoaders) -> list:
    "Return the text vocab of `dls`: the first entry when `dls.vocab` is an `L`, else `dls.vocab` itself"
    vocabs = dls.vocab
    return vocabs[0] if isinstance(vocabs, L) else vocabs

In [None]:
#| export
def _get_label_vocab(dls:DataLoaders) -> list:
    "Return the label vocab of `dls`: the second entry when `dls.vocab` is an `L`, else `dls.vocab` itself"
    vocabs = dls.vocab
    return vocabs[1] if isinstance(vocabs, L) else vocabs

In [None]:
#| export
def match_collab(
    old_wgts:dict, # Embedding weights of the collab model
    collab_vocab:dict, # Vocabulary of `token` and `label` used for collab pre-training
    lbs_vocab:list # Current labels vocabulary
) -> tuple: # (`old_wgts` with the label entries replaced, number of labels not found in the collab vocab)
    """Convert the label embedding in `old_wgts` to go from the collab label vocab to `lbs_vocab`.

    Row `i` of the new `i_weight.weight` (and of `i_bias.weight`, when present) is the collab row of
    `lbs_vocab[i]`; a label unknown to the collab vocab gets the mean of all collab rows instead.
    `old_wgts` is updated in place (pass a copy to keep the original mapping) and returned together
    with the count of such unknown labels.
    """
    bias, wgts = old_wgts.get('i_bias.weight', None), old_wgts.get('i_weight.weight')
    wgts_m = wgts.mean(0)
    new_wgts = wgts.new_zeros((len(lbs_vocab), wgts.size(1)))
    if bias is not None:
        bias_m = bias.mean(0)
        new_bias = bias.new_zeros((len(lbs_vocab), 1))
    collab_lbs_vocab = collab_vocab['label']
    # a vocab object may already carry its `o2i` lookup; otherwise build one from the plain sequence
    collab_o2i = collab_lbs_vocab.o2i if hasattr(collab_lbs_vocab, 'o2i') else {w:i for i,w in enumerate(collab_lbs_vocab)}
    missing = 0
    for i,w in enumerate(lbs_vocab):
        idx = collab_o2i.get(w, -1)
        found = idx >= 0
        new_wgts[i] = wgts[idx] if found else wgts_m
        if bias is not None: new_bias[i] = bias[idx] if found else bias_m
        if idx == -1: missing += 1
    old_wgts['i_weight.weight'] = new_wgts
    if bias is not None: old_wgts['i_bias.weight'] = new_bias
    return old_wgts, missing

In [None]:
# Toy collab state dict: 'u_*' entries have 3 rows, 'i_*' (label) entries have 4 rows
wgts = {'u_weight.weight': torch.randn(3,5), 
        'i_weight.weight': torch.randn(4,5),
        'u_bias.weight'  : torch.randn(3,1),
        'i_bias.weight'  : torch.randn(4,1)}
collab_vocab = {'token': ['#na#', 'sun', 'moon', 'earth', 'mars'],
                'label': ['#na#', 'a', 'c', 'b']}
lbs_vocab = ['a', 'b', 'c']
# `match_collab` rebinds keys on the dict it is given, so pass a shallow copy to keep `wgts` intact
new_wgts, missing = match_collab(wgts.copy(), collab_vocab, lbs_vocab)
test_eq(missing, 0)
# the 'u_*' entries are passed through untouched
test_close(wgts['u_weight.weight'], new_wgts['u_weight.weight'])
test_close(wgts['u_bias.weight'], new_wgts['u_bias.weight'])
# the label rows were reordered, so a plain row-wise comparison must fail;
# each check gets its own context manager, otherwise the second one is never reached
with ExceptionExpected(ex=AssertionError, regex="close"):
    test_close(wgts['i_weight.weight'][1:], new_wgts['i_weight.weight'])
with ExceptionExpected(ex=AssertionError, regex="close"):
    test_close(wgts['i_bias.weight'][1:], new_wgts['i_bias.weight'])
old_w, new_w = wgts['i_weight.weight'], new_wgts['i_weight.weight']
old_b, new_b = wgts['i_bias.weight'], new_wgts['i_bias.weight']
# only the 'i_*' entries change size (4 collab labels -> 3 current labels)
for (old_k,old_v), (new_k, new_v) in zip(wgts.items(), new_wgts.items()): 
    if old_k.startswith('u'): test_eq(old_v.size(), new_v.size())
    else: test_ne(old_v.size(), new_v.size())
# new row i holds the collab row of lbs_vocab[i]: 'a'->1, 'b'->3, 'c'->2
test_eq(new_w[0], old_w[1]); test_eq(new_b[0], old_b[1])
test_eq(new_w[1], old_w[3]); test_eq(new_b[1], old_b[3])
test_eq(new_w[2], old_w[2]); test_eq(new_b[2], old_b[2])
# apart from the dropped '#na#' row, the new entries are a permutation of the old ones
test_shuffled(list(old_b[1:].squeeze().numpy()), list(new_b.squeeze().numpy()))
test_eq(torch.sort(old_b[1:], dim=0)[0], torch.sort(new_b, dim=0)[0])
test_eq(torch.sort(old_w[1:], dim=0)[0], torch.sort(new_w, dim=0)[0])

## Loading Pretrained Information Gain as Attention 

In [None]:
from xcube.l2r.all import *

In [None]:
# Fetch the MIMIC3 artefacts and load the pickled classifier vocab.
source_mimic = untar_xxx(XURLs.MIMIC3)
xml_vocab = load_pickle(source_mimic/'mimic3-9k_clas_full_vocab.pkl')
# Wrap in an `L` and `listify` every entry; the cells below index [0] for tokens and [1] for labels.
xml_vocab = L(xml_vocab).map(listify)

In [None]:
# Fetch the pretrained L2R artefacts: the token/label info bootstrap and the 'p_L' file used as label bias.
source_l2r = untar_xxx(XURLs.MIMIC3_L2R)
boot_path = join_path_file('mimic3-9k_tok_lbl_info', source_l2r, ext='.pkl')
bias_path = join_path_file('p_L', source_l2r, ext='.pkl')
# Both files are read with `torch.load` straight onto the default device.
l2r_bootstrap = torch.load(boot_path, map_location=default_device())
brain_bias = torch.load(bias_path, map_location=default_device())

In [None]:
# Unpack the bootstrap dict: the first two entries form the brain vocab (tokens, labels),
# the last one ('mutual_info_jaccard') is the token-by-label matrix referred to as the "brain".
*brain_vocab, brain = mapt(l2r_bootstrap.get, ['toks', 'lbs', 'mutual_info_jaccard'])
brain_vocab = L(brain_vocab).map(listify)
toks, lbs = brain_vocab
print(f"last two places in brain vocab has {toks[-2:]}")
# toks = CategoryMap(toks, sort=False)
# Keep index 0 of the last axis so that one value per label remains (shape checked below).
brain_bias = brain_bias[:, :, 0].squeeze(-1)
# Label descriptions, used further down via `lbs_des.get`.
lbs_des = load_pickle(source_mimic/'code_desc.pkl')
assert isinstance(lbs_des, dict)
test_eq(brain.shape, (len(toks), len(lbs))) # last two places has 'xxfake'
test_eq(brain_bias.shape, [len(lbs)])

last two places in brain vocab has ['xxfake', 'xxfake']


The tokens which are there in the xml vocab but not in the brain:

In [None]:
# Tokens of the xml vocab that have no entry in the brain vocab.
not_found_in_brain = L(set(xml_vocab[0]).difference(set(brain_vocab[0])))
not_found_in_brain

(#20) ['cella','q2day','remiained','luteinizing','promiscuity','sharpio','calcijex','dissension','mhc','theses'...]

In [None]:
# 'cella' is one of the xml tokens missing from the brain, so looking it up in the brain tokens must fail.
test_fail(lambda : toks.index('cella'), contains='is not in list')

The tokens which are in the brain but were not present in the xml vocab:

In [None]:
# Brain tokens that are absent from the xml vocab.
set(brain_vocab[0]).difference(xml_vocab[0])

set()

Thankfully, we have `info` for all the labels in the xml vocab:

In [None]:
# Every label of the xml vocab must be present in the brain vocab.
# (The previous check compared `brain_vocab[1]` with itself and could never fail.)
assert set(xml_vocab[1]).difference(brain_vocab[1]) == set()

In [None]:
#| export
def _xml2brain(xml_vocab, brain_vocab, parent_bar=None):
    """Creates a mapping between the indices of the xml vocab and the brain vocab

    Returns `(xml2brain, xml2brain_notfnd)`: `xml2brain` maps each index of `xml_vocab` to the index of
    the same entry in `brain_vocab` (first occurrence), or to `np.inf` when the entry is absent;
    `xml2brain_notfnd` lists the `xml_vocab` indices that were mapped to `np.inf`.
    """
    # Build the lookup once instead of scanning `brain_vocab` for every xml entry.
    # `setdefault` keeps the first index of a repeated entry, which is what `list.index` returns.
    brain_o2i = {}
    for brn_idx, o in enumerate(brain_vocab): brain_o2i.setdefault(o, brn_idx)
    pbar = progress_bar(xml_vocab, parent=parent_bar, leave=True)
    # `np.inf` itself is the sentinel: callers test it with `is` / `is not`
    xml2brain = {i: brain_o2i.get(o, np.inf) for i,o in enumerate(pbar)}
    xml2brain_notfnd = [i for i, brn_idx in xml2brain.items() if brn_idx is np.inf]
    return xml2brain, xml2brain_notfnd

In [None]:
# Map every xml token index to its brain index (`np.inf` when the token is absent from the brain).
toks_xml2brain, toks_notfnd = _xml2brain(xml_vocab[0], brain_vocab[0])

toks_found = set(toks_xml2brain).difference(set(toks_notfnd))
# The unmatched indices must be exactly the tokens found above to be missing from the brain.
test_shuffled(array(xml_vocab[0])[toks_notfnd], not_found_in_brain)
# Spot-check 10 randomly drawn matched indices (unseeded, so the sample differs between runs):
# the xml token and the brain token at the mapped index must be the same.
some_xml_idxs = np.random.choice(array(L(toks_found)), size=10)
some_xml_toks = array(xml_vocab[0])[some_xml_idxs]
corres_brain_idxs = L(map(toks_xml2brain.get, some_xml_idxs))
corres_brain_toks = array(toks)[corres_brain_idxs]
assert all_equal(some_xml_toks, corres_brain_toks)

In [None]:
# Same mapping as for the tokens, this time for the labels.
lbs_xml2brain, lbs_notfnd = _xml2brain(xml_vocab[1], brain_vocab[1])

lbs_found = set(lbs_xml2brain).difference(set(lbs_notfnd))
# Spot-check 10 randomly drawn matched label indices (unseeded): both vocabs must hold the same label.
some_xml_idxs = np.random.choice(array(L(lbs_found)), size=10)
some_xml_lbs = array(xml_vocab[1])[some_xml_idxs]
corres_brain_idxs = L(map(lbs_xml2brain.get, some_xml_idxs))
corres_brain_lbs = array(lbs)[corres_brain_idxs]
assert all_equal(some_xml_lbs, corres_brain_lbs)

In [None]:
#| export
def brainsplant(xml_vocab, brain_vocab, brain, brain_bias, device=None):
    """Re-index the pretrained `brain` and `brain_bias` from `brain_vocab` order to `xml_vocab` order.

    `xml_vocab` and `brain_vocab` are indexed with 0 for the token vocab and 1 for the label vocab;
    `brain` is indexed as `[token, label]` and `brain_bias` as `[label]`, both in brain order.
    Rows of xml tokens that do not exist in the brain stay zero.

    Returns `(xml_brain, xml_lbsbias, toks_map, lbs_map, toks_xml2brain, lbs_xml2brain)` where the
    `*_map` are `L`s of `(xml_idx, brain_idx)` pairs for the entries present in both vocabs and the
    `*_xml2brain` are the full index mappings produced by `_xml2brain`.
    """
    device = default_device() if device is None else device
    mb = master_bar(range(2))
    xml2brain_maps = []
    for i in mb:
        # i == 0 -> tokens, i == 1 -> labels; only the mapping is needed here, not the not-found list
        xml2brain, _ = _xml2brain(xml_vocab[i], brain_vocab[i], parent_bar=mb)
        xml2brain_maps.append(xml2brain)
        mb.write(f"Finished Loop {i}")
    toks_xml2brain, lbs_xml2brain = xml2brain_maps
    xml_brain = torch.zeros(*xml_vocab.map(len)).to(device) # initialize empty brain
    xml_lbsbias = torch.zeros(len(xml_vocab[1])).to(device)
    toks_map = L((xml_idx, brn_idx) for xml_idx, brn_idx in toks_xml2brain.items() if brn_idx is not np.inf)
    lbs_map = L((xml_idx, brn_idx) for xml_idx, brn_idx in lbs_xml2brain.items() if brn_idx is not np.inf)
    xml_brain[toks_map.itemgot(0)] = brain[toks_map.itemgot(1)][:, lbs_map.itemgot(1)] # permute toks dim to match xml and brain
    xml_brain[:, lbs_map.itemgot(0)] = xml_brain.clone() # permute lbs dim to match xml and brain
    xml_lbsbias[lbs_map.itemgot(0)] = brain_bias[lbs_map.itemgot(1)].clone() # permute lbs bias to match xml and brain
    return xml_brain, xml_lbsbias, toks_map, lbs_map, toks_xml2brain, lbs_xml2brain

In [None]:
# Transplant the brain into xml-vocab order and check the result.
xml_brain, xml_lbsbias, toks_map, lbs_map, toks_xml2brain, lbs_xml2brain = brainsplant(xml_vocab, brain_vocab, brain, brain_bias)
test_eq(xml_brain.shape, xml_vocab.map(len))
# rows of xml tokens that are missing from the brain must have stayed zero
test_eq(xml_brain[toks_notfnd], xml_brain.new_zeros(len(toks_notfnd), len(xml_vocab[1])))
# every (xml_idx, brain_idx) pair must point at the same token / label in both vocabs
assert all_equal(array(xml_vocab[0])[toks_map.itemgot(0)], array(brain_vocab[0])[toks_map.itemgot(1)])
assert all_equal(array(xml_vocab[1])[lbs_map.itemgot(0)], array(brain_vocab[1])[lbs_map.itemgot(1)])

In [None]:
# tests to ensure `brainsplant` was successful 
# (several candidate labels are listed; only the last assignment is the one actually used)
lbl = '642.41'
lbl = '38.93'
lbl = '51.10'
lbl = '996.87'
# top-20 tokens for the label, read once from the original brain and once from the transplanted one
lbl_idx_from_brn = brain_vocab[1].index(lbl)
tok_vals_from_brn, top_toks_from_brn= L(brain[:, lbl_idx_from_brn].topk(k=20)).map(Self.cpu())
lbl_idx_from_xml = xml_vocab[1].index(lbl)
tok_vals_from_xml, top_toks_from_xml = L(xml_brain[:, lbl_idx_from_xml].topk(k=20)).map(Self.cpu())
# same label index mapping, same values, same tokens and same bias on both sides
test_eq(lbs_xml2brain[lbl_idx_from_xml], lbl_idx_from_brn)
test_eq(tok_vals_from_brn, tok_vals_from_xml)
test_eq(array(brain_vocab[0])[top_toks_from_brn], array(xml_vocab[0])[top_toks_from_xml])
test_eq(brain_bias[lbl_idx_from_brn], xml_lbsbias[lbl_idx_from_xml])
print(f"For the lbl {lbl} ({lbs_des.get(lbl)}), the top tokens that needs attention are:")
print('\n'.join(L(array(xml_vocab[0])[top_toks_from_xml], use_list=True).zipwith(L(tok_vals_from_xml.numpy(), use_list=True)).map(str).map(lambda o: "+ "+o)))

For the lbl 996.87 (Complications of transplanted intestine), the top tokens that needs attention are:
+ ('consultued', 0.25548762)
+ ('cip', 0.25548762)
+ ('parlor', 0.24661502)
+ ('transplantations', 0.18601614)
+ ('scaffoid', 0.18601614)
+ ('epineprine', 0.18601614)
+ ('culinary', 0.17232327)
+ ('coordinates', 0.1469037)
+ ('aminotransferases', 0.12153866)
+ ('hydronephroureter', 0.12153866)
+ ('27yom', 0.12153866)
+ ('27y', 0.103684604)
+ ('hardward', 0.090407245)
+ ('leukoreduction', 0.08014185)
+ ('venting', 0.07831942)
+ ('secrete', 0.07196123)
+ ('orthogonal', 0.07196123)
+ ('naac', 0.06891022)
+ ('mgso4', 0.0662555)
+ ('septecemia', 0.065286644)


In [None]:
# Mirror of the label check above, now for one token
# (several candidate tokens are listed; only the last assignment is the one actually used).
tok = 'fibrillation'
tok = 'colpo'
tok = 'amiodarone'
tok = 'flagyl'
tok = 'nasalilid'
tok = 'hemetemesis'
tok = 'restitched'
# top-20 labels for the token, read once from the original brain and once from the transplanted one
tok_idx_from_brn = brain_vocab[0].index(tok)
lbs_vals_from_brn, top_lbs_from_brn = L(brain[tok_idx_from_brn].topk(k=20)).map(Self.cpu())
tok_idx_from_xml = xml_vocab[0].index(tok)
test_eq(tok_idx_from_brn, toks_xml2brain[tok_idx_from_xml])
lbs_vals_from_xml, top_lbs_from_xml = L(xml_brain[tok_idx_from_xml].topk(k=20)).map(Self.cpu())
test_eq(lbs_vals_from_brn, lbs_vals_from_xml)
# the values must agree exactly; if the label order differs, fall back to an order-insensitive comparison
try: 
    test_eq(array(brain_vocab[1])[top_lbs_from_brn], array(xml_vocab[1])[top_lbs_from_xml])
except AssertionError as e: 
    print(type(e).__name__, "due to instability in sorting (nothing to worry!)");
    test_shuffled(array(brain_vocab[1])[top_lbs_from_brn], array(xml_vocab[1])[top_lbs_from_xml])
print('')
print(f"For the token {tok}, the top labels that needs attention are:")
print('\n'.join(L(mapt(lbs_des.get, array(xml_vocab[1])[top_lbs_from_xml])).zipwith(L(lbs_vals_from_xml.numpy(), use_list=True)).map(str).map(lambda o: "+ "+o)))


For the token restitched, the top labels that needs attention are:
+ ('Other operations on supporting structures of uterus', 0.29102018)
+ ('Other proctopexy', 0.29102018)
+ ('Other operations on cul-de-sac', 0.18601614)
+ (None, 0.07494824)
+ ('Intervertebral disc disorder with myelopathy, thoracic region', 0.055331517)
+ ('Excision of scapula, clavicle, and thorax [ribs and sternum] for graft', 0.04382947)
+ ('Other repair of omentum', 0.028067086)
+ ('Chronic lymphocytic thyroiditis', 0.01986737)
+ (None, 0.019236181)
+ ('Reclosure of postoperative disruption of abdominal wall', 0.016585195)
+ ('Other disorders of calcium metabolism', 0.009393147)
+ ('Pain in joint involving pelvic region and thigh', 0.008421187)
+ ('Exteriorization of small intestine', 0.00817792)
+ ('Fusion or refusion of 9 or more vertebrae', 0.00762466)
+ ('Kyphosis (acquired) (postural)', 0.0074228523)
+ ('Unspecified procedure as the cause of abnormal reaction of patient, or of later complication, without men

In [None]:
# Draw 10 xml token indices that exist in the brain, then resample 20 of them with repetitions.
# No seed is set, so the sampled tokens (and the outputs shown below) change from run to run.
some_toks = random.sample(toks_map.itemgot(0), 10)
# per-token repetition counts: a shuffled 0..9 scaled by 6 (so one token gets a count of 0)
counts = [c*6 for c in random.sample(range(10), 10)]
some_toks = random.sample(some_toks, 20, counts=counts)
# Counter(some_toks)
# the brain tokens at the mapped indices must be the same as the sampled xml tokens
cors_toks_brn = L(mapt(toks_xml2brain.get, some_toks))
test_eq(array(brain_vocab[0])[cors_toks_brn], array(xml_vocab[0])[some_toks])
print("some tokens (with repetitions):\n",'\n'.join(['-'+xml_vocab[0][t]for t in some_toks]))

some tokens (with repetitions):
 -disorientated
-disorientated
-dmh
-ibp
-literacy
-abruptly
-faxed
-delsym
-literacy
-delsym
-literacy
-ibp
-literacy
-delsym
-abruptly
-caox3
-caox3
-caox3
-caox3
-literacy


In [None]:
# One row of the transplanted brain per sampled token: shape (number of sampled tokens, number of labels).
attn = xml_brain[some_toks]
test_eq(attn.shape, (len(some_toks), xml_brain.shape[1]))
# semantics of attn
# row view: row i of `attn` is the brain row of the i-th sampled token (its value for every label)
for t, a in zip(some_toks,attn):
    test_eq(xml_brain[t], a)
# column view: column `lbl` of `attn` is that label's brain column restricted to the sampled tokens
for lbl in range(xml_brain.shape[1]):
    test_eq(xml_brain[:, lbl][some_toks], attn[:, lbl])

In [None]:
# For every sampled token show the label with the highest value in its `attn` row, sorted by that value;
# 'NF' marks labels without an entry in `lbs_des`.
pd.DataFrame([(xml_vocab[0][t], l:=xml_vocab[1][lbl_idx], val.item(), lbs_des.get(l, 'NF')) for t,lbl_idx,val in zip(some_toks,attn.max(dim=1).indices.cpu(), attn.max(dim=1).values.cpu())],
            columns=['token', 'most_relevant_lbl', 'lbl_attn', 'description']).sort_values(by='lbl_attn', ascending=False)

Unnamed: 0,token,most_relevant_lbl,lbl_attn,description
7,delsym,344.2,0.096369,Diplegia of upper limbs
13,delsym,344.2,0.096369,Diplegia of upper limbs
9,delsym,344.2,0.096369,Diplegia of upper limbs
2,dmh,983.1,0.041843,Toxic effect of acids
0,disorientated,171.0,0.036627,"Malignant neoplasm of connective and other soft tissue of head, face, and neck"
1,disorientated,171.0,0.036627,"Malignant neoplasm of connective and other soft tissue of head, face, and neck"
18,caox3,375.01,0.028963,Acute dacryoadenitis
17,caox3,375.01,0.028963,Acute dacryoadenitis
16,caox3,375.01,0.028963,Acute dacryoadenitis
15,caox3,375.01,0.028963,Acute dacryoadenitis


In [None]:
from xcube.layers import inattention

In [None]:
# define label inattention cutoff (passed as `k` to `inattention` in the next cell)
k = 5

In [None]:
# NOTE(review): `inattention` comes from `xcube.layers`; presumably it keeps only the top `k` entries
# along the label axis and zeroes the rest. Confirm against its definition.
top_lbs_attn = attn.clone().unsqueeze(0).permute(0,2,1).inattention(k=k).permute(0,2,1).squeeze(0).contiguous() # applying `inattention` across the lbs dim
test_eq(top_lbs_attn.shape, (len(some_toks), xml_brain.shape[1]))
# the result differs from `attn`, but the best label of every token is unchanged
test_ne(attn, top_lbs_attn)
test_eq(top_lbs_attn.argmax(dim=1), attn.argmax(dim=1))
# label confidence: sum over the sampled tokens of what is left after `inattention`
lbs_cf = top_lbs_attn.sum(dim=0)
test_eq(lbs_cf.shape, [top_lbs_attn.shape[1]])
# only labels with a non-zero confidence are reported
idxs = lbs_cf.nonzero().flatten().cpu()
print(f"After looking at the tokens {[xml_vocab[0][t]for t in some_toks]}, I am confident about the following labels:")
pd.DataFrame([(l:=xml_vocab[1][idx], val.item(), lbs_des.get(l, 'NF')) for idx,val in zip(idxs,lbs_cf[idxs])],
            columns=['lbl', 'lbl_cf', 'description']).sort_values(by='lbl_cf', ascending=False)

After looking at the tokens ['disorientated', 'disorientated', 'dmh', 'ibp', 'literacy', 'abruptly', 'faxed', 'delsym', 'literacy', 'delsym', 'literacy', 'ibp', 'literacy', 'delsym', 'abruptly', 'caox3', 'caox3', 'caox3', 'caox3', 'literacy'], I am confident about the following labels:


Unnamed: 0,lbl,lbl_cf,description
10,344.2,0.289106,Diplegia of upper limbs
22,367.1,0.289106,Myopia
36,706.1,0.158741,Other acne
35,691.8,0.14564,Other atopic dermatitis and related conditions
21,442.89,0.134519,Aneurysm of other specified site
50,374.89,0.115853,Other disorders of eyelid
49,375.01,0.115853,Acute dacryoadenitis
20,449,0.090417,Septic arterial embolism
11,438.13,0.087779,NF
23,423.1,0.080689,Adhesive pericarditis


In [None]:
#| hide
#| eval: false
# Toy illustration of the two-step row/column permutation used in `brainsplant`.
# semantics: `s` pulling out its 1st and 0th row is equivalent to `t` pulling out its 0th and 3rd row respectively (i.e., the data residing in the 1st and 0th row of the s matrix is same as the data residing at the 0th and the 3rd row of t's matrix)
t = torch.zeros(4, 3).long()
s = torch.arange(20).view(2, 10).long()
# s = torch.arange(6).view(2,3).long()
# each pair is (index in t, index in s)
row_perm = L((0, 1), (3, 0)) # 
col_perm = L((2, 1), (0, 3), (1, -1))
# col_perm = L((0,2), (1,0), (2,1))
ic(t,s);
ic(s[row_perm.itemgot(1)]); # pull out relevant rows from s
ic(s[row_perm.itemgot(1)][:, col_perm.itemgot(1)]); # pull out relevant cols from s
t[row_perm.itemgot(0)] = s[row_perm.itemgot(1)][:, col_perm.itemgot(1)] # permute rows
t[:, col_perm.itemgot(0)] = t.clone() # permute cols
# t[row_perm.itemgot(0)] = s[row_perm.itemgot(1)][:, col_perm.itemgot(1)][:, col_perm.itemgot(0)]
ic(t);

In [None]:
# Load the weights of the pretrained L2R model onto the default device.
l2r_wgts = torch.load(join_path_file('lin_lambdarank_full', source_l2r, ext='.pth'), map_location=default_device())
# a checkpoint may wrap the weights under a 'model' key; unwrap it in that case
if 'model' in l2r_wgts: l2r_wgts = l2r_wgts['model']

Need to match the wgts in xml and brain:

In [None]:
def brainsplant_diffntble(xml_vocab, brain_vocab, l2r_wgts, device=None):
    """Re-index the L2R embedding weights in `l2r_wgts` from `brain_vocab` order to `xml_vocab` order.

    `l2r_wgts` is expected to hold, in this order, the token factors, token bias, label factors and
    label bias weights, each with one row per brain-vocab entry. Rows of xml entries that are not in
    the brain vocab stay zero.

    Returns `(mod_dict, toks_map, lbs_map)`: an `nn.ModuleDict` with one `nn.Embedding` per weight
    (keyed by the part of the original key before the first '.'), loaded with the re-indexed weights,
    and the `L`s of `(xml_idx, brain_idx)` pairs used for tokens and labels.
    """
    device = default_device() if device is None else device
    mb = master_bar(range(2))
    xml2brain_maps = []
    for i in mb:
        # i == 0 -> tokens, i == 1 -> labels; only the mapping is needed here, not the not-found list
        xml2brain, _ = _xml2brain(xml_vocab[i], brain_vocab[i], parent_bar=mb)
        xml2brain_maps.append(xml2brain)
        mb.write(f"Finished Loop {i}")
    toks_xml2brain, lbs_xml2brain = xml2brain_maps
    toks_map = L((xml_idx, brn_idx) for xml_idx, brn_idx in toks_xml2brain.items() if brn_idx is not np.inf)
    lbs_map = L((xml_idx, brn_idx) for xml_idx, brn_idx in lbs_xml2brain.items() if brn_idx is not np.inf)
    tf_l2r, tb_l2r, lf_l2r, lb_l2r = list(l2r_wgts.values())

    def _transplant(l2r_wgt, n_xml, idx_map):
        "Zero tensor with `n_xml` rows whose rows `idx_map[*][0]` are copies of `l2r_wgt` rows `idx_map[*][1]`"
        # the width is taken from the pretrained weight instead of being hard-coded
        xml_wgt = torch.zeros(n_xml, l2r_wgt.size(1)).to(device)
        xml_wgt[idx_map.itemgot(0)] = l2r_wgt[idx_map.itemgot(1)].clone()
        return xml_wgt

    n_toks, n_lbs = len(xml_vocab[0]), len(xml_vocab[1])
    tf_xml, tb_xml = _transplant(tf_l2r, n_toks, toks_map), _transplant(tb_l2r, n_toks, toks_map)
    lf_xml, lb_xml = _transplant(lf_l2r, n_lbs, lbs_map), _transplant(lb_l2r, n_lbs, lbs_map)
    xml_wgts = {k: xml_val for k, xml_val in zip(l2r_wgts.keys(), (tf_xml, tb_xml, lf_xml, lb_xml))}
    mod_dict = nn.ModuleDict({k.split('.')[0]: nn.Embedding(*v.size()) for k,v in xml_wgts.items()}).to(device)
    mod_dict.load_state_dict(xml_wgts)
    return mod_dict, toks_map, lbs_map

In [None]:
# Transplant the L2R weights into xml-vocab order and check them against the originals.
mod_dict, toks_map, lbs_map = brainsplant_diffntble(xml_vocab, brain_vocab, l2r_wgts)
assert isinstance(mod_dict, nn.Module)
assert nn.Module in mod_dict.__class__.__mro__ 

# rows at the xml indices must equal the pretrained rows at the matching brain indices
test_eq(mod_dict['token_factors'].weight.data[toks_map.itemgot(0)], l2r_wgts['token_factors.weight'][toks_map.itemgot(1)])
test_eq(mod_dict['token_bias'].weight.data[toks_map.itemgot(0)], l2r_wgts['token_bias.weight'][toks_map.itemgot(1)])
test_eq(mod_dict['label_factors'].weight.data[lbs_map.itemgot(0)], l2r_wgts['label_factors.weight'][lbs_map.itemgot(1)])
test_eq(mod_dict['label_bias'].weight.data[lbs_map.itemgot(0)], l2r_wgts['label_bias.weight'][lbs_map.itemgot(1)])

In [None]:
# Display the embedding modules that were built and their sizes.
mod_dict

ModuleDict(
  (token_factors): Embedding(57376, 200)
  (token_bias): Embedding(57376, 1)
  (label_factors): Embedding(8922, 200)
  (label_bias): Embedding(8922, 1)
)

In [None]:
# Three example labels, printed with their description ('NF' when `lbs_des` has no entry).
some_lbs = ['996.87', '51.10', '38.93']

for lbl in some_lbs:
    print(f"{lbl}: {lbs_des.get(lbl, 'NF')}")

996.87: Complications of transplanted intestine
51.10: Endoscopic retrograde cholangiopancreatography [ERCP]
38.93: Venous catheterization, not elsewhere classified


In [None]:
# Indices of the example labels in the xml label vocab, as a tensor on the default device.
lbs_idx = tensor(mapt(xml_vocab[1].index, some_lbs)).to(default_device())

In [None]:
# 72 random xml token indices (unseeded, so the tokens listed below change from run to run).
toks_idx = torch.randint(0, len(xml_vocab[0]), (72,)).to(default_device())
print("-"+'\n-'.join(array(xml_vocab[0])[toks_idx.cpu()].tolist()))

-influx
-latissimus
-equinovarus
-deteriorates
-aap
-mvh
-135
-incipient
-rhubarb
-nizhoni
-trancutaneous
-indicaton
-subset
-largyngeal
-lemonade
-debulk
-aerations
-l34
-perserverates
-trendelenberg
-kettr
-meningitic
-bored
-hashimoto
-mountains
-wit
-asts
-ellicits
-pax
-adb
-alcholism
-violinist
-301b
-subpopulation
-intraorally
-98o2
-agreesive
-monilla
-jig
-paroxysmalatrial
-10pts
-knees
-conventionally
-soonest
-recap
-rediscuss
-spontanous
-pulmary
-repletement
-450x12
-symetrically
-fdi
-pshx
-svco2
-topimax
-2100cc
-conceal
-nauasea
-decontamination
-administrator
-fraction
-tachyarrythmia
-oversee
-dabigutran
-reiterated
-aftetr
-bues
-symettric
-powerful
-depocyte
-hyperextension
-hepsc


In [None]:
# Score of every sampled token for every example label:
# token factors @ label factors^T, plus the token bias (per row) and the label bias (per column).
apprx_brain = mod_dict['token_factors'](toks_idx) @ mod_dict['label_factors'](lbs_idx).T + mod_dict['token_bias'](toks_idx) + mod_dict['label_bias'](lbs_idx).T
apprx_brain.shape

torch.Size([72, 3])

These are the tokens as ranked by the pretrained L2R model (which is essentially an approximation of the actual brain):

In [None]:
# One column per example label: the sampled tokens sorted by decreasing `apprx_brain` score for that label.
pd.DataFrame(array(xml_vocab[0])[toks_idx[apprx_brain.argsort(dim=0, descending=True)].cpu()], columns=L(zip(some_lbs, mapt(lbs_des.get, some_lbs))).map(': '.join))

Unnamed: 0,996.87: Complications of transplanted intestine,51.10: Endoscopic retrograde cholangiopancreatography [ERCP],"38.93: Venous catheterization, not elsewhere classified"
0,fraction,wit,fraction
1,knees,fraction,subpopulation
2,subpopulation,administrator,knees
3,wit,subset,subset
4,administrator,knees,pshx
...,...,...,...
67,paroxysmalatrial,mvh,powerful
68,indicaton,ellicits,indicaton
69,rhubarb,indicaton,perserverates
70,depocyte,aftetr,rhubarb


Just to compare: This is how an actual brain would rank those tokens:

In [None]:
# Same table as above, but the sampled tokens are sorted by their value in the transplanted brain (`xml_brain`).
# array(xml_vocab[0])[xml_brain[:, lbl_idx].topk(k=20, dim=0).indices.cpu()]
pd.DataFrame(array(xml_vocab[0])[toks_idx[xml_brain[:, lbs_idx][toks_idx].argsort(descending=True, dim=0)].cpu()], columns=L(zip(some_lbs, mapt(lbs_des.get, some_lbs))).map(': '.join))

Unnamed: 0,996.87: Complications of transplanted intestine,51.10: Endoscopic retrograde cholangiopancreatography [ERCP],"38.93: Venous catheterization, not elsewhere classified"
0,fraction,wit,knees
1,knees,administrator,svco2
2,hyperextension,pshx,meningitic
3,meningitic,hashimoto,fraction
4,301b,reiterated,subset
...,...,...,...
67,latissimus,topimax,pshx
68,monilla,conceal,equinovarus
69,dabigutran,aftetr,debulk
70,trendelenberg,symettric,oversee


## Base `Learner` for NLP

In [None]:
#| export
def load_collab_keys(
    model, # Model architecture
    wgts:dict # Model weights
) -> tuple: # Result of `model.load_state_dict` (missing and unexpected keys)
    """Load only collab `wgts` (`i_weight` and `i_bias`) in `model`, keeping the rest as is

    The label weight and bias are written to the state-dict entries `1.attn.lbs_weight.weight` and
    `1.attn.lbs_weight.bias`, and the weight is also copied to `1.attn.lbs_weight_dp.emb.weight`.
    Entries that are absent from either the model or `wgts` are skipped.
    """
    # NOTE(review): the key names are hard-coded; confirm they match the state dict of the model in use.
    sd = model.state_dict()
    lbs_weight, i_weight = sd.get('1.attn.lbs_weight.weight', None), wgts.get('i_weight.weight', None)
    lbs_bias, i_bias = sd.get('1.attn.lbs_weight.bias', None), wgts.get('i_bias.weight', None)
    if lbs_weight is not None and i_weight is not None: lbs_weight.data = i_weight.data
    if lbs_bias is not None and i_bias is not None: lbs_bias.data = i_bias.data
    # guard on `i_weight` as well: without collab weights there is nothing to copy
    if i_weight is not None and '1.attn.lbs_weight_dp.emb.weight' in sd:
        sd['1.attn.lbs_weight_dp.emb.weight'] = i_weight.data.clone()
    return model.load_state_dict(sd)

In [None]:
# Build a tiny classifier to test against and list the state-dict keys that contain 'attn'.
config = awd_lstm_clas_config.copy()
config.update({'n_hid': 10, 'emb_sz': 5})
# tst = get_text_classifier(AWD_LSTM, 100, 3, config=config)
tst = get_xmltext_classifier(AWD_LSTM, 100, 3, config=config)
old_sd = tst.state_dict().copy()
r = re.compile(".*attn.*")
# the regex filter and the plain substring filter must select the same keys
test_eq([key for key in old_sd if 'attn' in key], list(filter(r.match, old_sd)))
print("\n".join(list(filter(r.match, old_sd))))

1.pay_attn.lbs.weight
1.boost_attn.lin.weight
1.boost_attn.lin.bias


In [None]:
import copy

In [None]:
# Snapshot the state dict, then load the collab weights produced earlier by `match_collab`.
old_sd = copy.deepcopy(tst.state_dict())
# NOTE(review): `load_collab_keys` looks up keys named '1.attn.*', while the keys printed above are
# '1.pay_attn.*' / '1.boost_attn.*'; presumably no weight is replaced here. Confirm the key names.
load_collab_keys(tst, new_wgts)
# <TODO: Deb> fix the following tests later
# test_ne(old_sd['1.attn.lbs_weight.weight'], tst.state_dict()['1.attn.lbs_weight.weight'])
# test_eq(tst.state_dict()['1.pay_attn.lbs_weight.weight'], new_wgts['i_weight.weight'])
# test_ne(old_sd['1.attn.lbs_weight_dp.emb.weight'], tst.state_dict()['1.attn.lbs_weight_dp.emb.weight'])
# test_eq(tst.state_dict()['1.attn.lbs_weight_dp.emb.weight'], new_wgts['i_weight.weight'])

<All keys matched successfully>

In [None]:
#| export
from xcube.layers import *
from xcube.layers import _planted_attention

In [None]:
#| export
@delegates(Learner.__init__)
class TextLearner(Learner):
    "Basic class for a `Learner` in NLP."
    def __init__(self, 
        dls:DataLoaders, # Text `DataLoaders`
        model, # A standard PyTorch model
        alpha:float=2., # Param for `RNNRegularizer`
        beta:float=1., # Param for `RNNRegularizer`
        moms:tuple=(0.8,0.7,0.8), # Momentum for `Cosine Annealing Scheduler`
        **kwargs
    ):
        super().__init__(dls, model, moms=moms, **kwargs)
        # NOTE(review): `alpha` and `beta` are accepted but not forwarded to `rnn_cbs()`; confirm this is intended
        self.add_cbs(rnn_cbs())

    def save_encoder(self, 
        file:str # Filename for `Encoder` 
    ):
        "Save the encoder to `file` in the model directory"
        if rank_distrib(): return # don't save if child proc
        encoder = get_model(self.model)[0]
        # unwrap a wrapped encoder so that the saved keys carry no `module.` prefix
        if hasattr(encoder, 'module'): encoder = encoder.module
        torch.save(encoder.state_dict(), join_path_file(file, self.path/self.model_dir, ext='.pth'))
    
    @delegates(save_model)
    def save(self,
        file:str, # Filename for the state_directory of the model
        **kwargs
    ):
        """
        Save model and optimizer state (if `with_opt`) to `self.path/self.model_dir/file`
        Save `self.dls.vocab` to `self.path/self.model_dir/{file}_vocab.pkl`
        Returns the path of the saved model file.
        """
        model_file = join_path_file(file, self.path/self.model_dir, ext='.pth')
        vocab_file = join_path_file(file+'_vocab', self.path/self.model_dir, ext='.pkl')
        save_model(model_file, self.model, getattr(self, 'opt', None), **kwargs)
        save_pickle(vocab_file, self.dls.vocab)
        return model_file

    def load_encoder(self, 
        file:str, # Filename of the saved encoder 
        device:(int,str,torch.device)=None # Device used to load, defaults to `dls` device
    ):
        "Load the encoder `file` from the model directory, optionally ensuring it's on `device`"
        encoder = get_model(self.model)[0]
        if device is None: device = self.dls.device
        if hasattr(encoder, 'module'): encoder = encoder.module
        distrib_barrier()
        wgts = torch.load(join_path_file(file,self.path/self.model_dir, ext='.pth'), map_location=device)
        encoder.load_state_dict(clean_raw_keys(wgts))
        self.freeze()
        return self
    
    def load_brain(self,
        file_wgts: str, # Filename of the saved attention wgts
        file_bias: str, # Filename of the saved label bias
        device:(int,str,torch.device)=None # Device used to load, defaults to `default_device()`
    ):
        """Load the pre-learnt label specific attention weights for each token from `file_wgts` (and the 
        label bias from `file_bias`) located in the model directory, optionally ensuring they are on `device`.

        The weights are re-indexed to the vocab of `self.dls` with `brainsplant`, stored on `self.brain` /
        `self.lbsbias`, and planted as the `attn` layer of `self.model[1].pay_attn`.
        """
        device = default_device() if device is None else device
        brain_path = join_path_file(file_wgts, self.path/self.model_dir, ext='.pkl')
        bias_path = join_path_file(file_bias, self.path/self.model_dir, ext='.pkl')
        # NOTE(review): `torch.load` unpickles its input, so only point this at trusted files
        brain_bootstrap = torch.load(brain_path, map_location=device)
        *brain_vocab, brain = mapt(brain_bootstrap.get, ['toks', 'lbs', 'mutual_info_jaccard'])
        brain_vocab = L(brain_vocab).map(listify)
        vocab = L(_get_text_vocab(self.dls), _get_label_vocab(self.dls)).map(listify)
        brain_bias = torch.load(bias_path, map_location=device)
        brain_bias = brain_bias[:, :, 0].squeeze(-1)
        print("Performing brainsplant...")
        # forward `device` so the transplanted tensors are created where the brain was loaded
        self.brain, self.lbsbias, *_ = brainsplant(vocab, brain_vocab, brain, brain_bias, device=device)
        print("Successfull!")
        plant_attn_layer = Lambda(Planted_Attention(self.brain))
        setattr(self.model[1].pay_attn, 'attn', plant_attn_layer)
        assert self.model[1].pay_attn.attn.func.f is _planted_attention
        return self

    def load_pretrained(self, 
        wgts_fname:str, # Filename of saved weights 
        vocab_fname:str, # Saved vocabulary filename in pickle format
        model=None # Model to load parameters from, defaults to `Learner.model`
    ):
        "Load a pretrained model and adapt it to the data vocabulary."
        old_vocab = load_pickle(vocab_fname)
        new_vocab = _get_text_vocab(self.dls)
        distrib_barrier()
        wgts = torch.load(wgts_fname, map_location = lambda storage,loc: storage)
        if 'model' in wgts: wgts = wgts['model'] #Just in case the pretrained model was saved with an optimizer
        wgts = match_embeds(wgts, old_vocab, new_vocab)
        load_ignore_keys(self.model if model is None else model, clean_raw_keys(wgts))
        self.freeze()
        return self

    #For previous versions compatibility. Remove at release
    @delegates(load_model_text)
    def load(self, 
        file:str, # Filename of saved model 
        with_opt:bool=None, # Enable to load `Optimizer` state
        device:(int,str,torch.device)=None, # Device used to load, defaults to `dls` device
        **kwargs
    ):
        "Load the model (and optimizer state) saved as `file` in the model directory"
        if device is None: device = self.dls.device
        if self.opt is None: self.create_opt()
        file = join_path_file(file, self.path/self.model_dir, ext='.pth')
        load_model_text(file, self.model, self.opt, device=device, **kwargs)
        return self
    
    def load_collab(self,
        wgts_fname:str, # Filename of the saved collab model
        collab_vocab_fname:str, # Saved Vocabulary of collab labels in pickle format 
        model=None # Model to load parameters from, defaults to `Learner.model`
    ):
        "Load the label embeddings learned by the collab model, and adapt them to the label vocabulary."
        collab_vocab = load_pickle(collab_vocab_fname)
        lbs_vocab = _get_label_vocab(self.dls)
        distrib_barrier()
        wgts = torch.load(wgts_fname, map_location=lambda storage,loc: storage)
        if 'model' in wgts: wgts = wgts['model'] #Just in case the pretrained model was saved with an optimizer
        wgts, _ = match_collab(wgts, collab_vocab, lbs_vocab)
        load_collab_keys(self.model if model is None else model, wgts)
        self.freeze()
        return self

Adds a `ModelResetter` and an `RNNRegularizer` with `alpha` and `beta` to the callbacks, the rest is the same as `Learner` init. 

This `Learner` adds functionality to the base class:

## `Learner` convenience functions

In [None]:
#| export
from xcube.text.models.core import _model_meta 

In [None]:
#| export 
@delegates(Learner.__init__)
def xmltext_classifier_learner(dls, arch, seq_len=72, config=None, backwards=False, pretrained=True, collab=False, drop_mult=0.5, n_out=None,
                           lin_ftrs=None, ps=None, max_len=72*20, y_range=None, splitter=None, running_decoder=True, **kwargs):
    """Create a `Learner` with a text classifier from `dls` and `arch`.

    With `pretrained`, the encoder (`model[0]`) is initialised from the pretrained weights registered for
    `arch` (`backwards` selects the backward model); with `collab`, the decoder (`model[1]`) additionally
    receives the label embeddings of a collab model found under `learn.path/**/collab/`.
    The returned learner is frozen.
    """
    # NOTE(review): `lin_ftrs` and `ps` are accepted but not used in this function
    vocab = _get_text_vocab(dls)
    if n_out is None: n_out = get_c(dls)
    assert n_out, "`n_out` is not defined, and could not be inferred from the data, set `dls.c` or pass `n_out`"
    model = get_xmltext_classifier2(arch, len(vocab), n_out, seq_len=seq_len, config=config, y_range=y_range,
                                drop_mult=drop_mult, max_len=max_len, running_decoder=running_decoder)
    meta = _model_meta[arch]
    learn = TextLearner(dls, model, splitter=splitter if splitter is not None else meta['split_clas'], **kwargs)
    url = 'url_bwd' if backwards else 'url'
    if pretrained:
        if url not in meta:
            warn("There are no pretrained weights for that architecture yet!")
            return learn
        model_path = untar_data(meta[url], c_key='model')
        # expects one weights file (.pth) and one vocab file (.pkl) in the downloaded folder
        try: fnames = [list(model_path.glob(f'*.{ext}'))[0] for ext in ['pth', 'pkl']]
        except IndexError: print(f'The model in {model_path} is incomplete, download again'); raise
        learn = learn.load_pretrained(*fnames, model=learn.model[0])
    if collab:
        try: fnames = [list(learn.path.glob(f'**/collab/*collab*.{ext}'))[0] for ext in ['pth', 'pkl']]
        except IndexError: print(f'The collab model in {learn.path} is incomplete, re-train it!'); raise
        # the method is `load_collab`; the previous `load_colab` raised `AttributeError`
        learn = learn.load_collab(*fnames, model=learn.model[1])
    learn.freeze()
    return learn

## Export -

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()