In [None]:
#| eval: false
# Colab only: `/content` exists on Colab; elsewhere the `[ -e /content ]` test fails and `&&` skips the install
! [ -e /content ] && pip install -Uqq xcube # upgrade xcube on colab

In [None]:
from fastai.data.core import *
from xcube.l2r.all import *

In [None]:
#| hide
# Reload edited library modules automatically before each cell is executed
%load_ext autoreload
%autoreload 2

# Boot L2R 

> Bootstrapping a learning-to-rank model

In this tutorial we will find a needle in the haystack with mutual information gain:

#### Mutual-Information Computation

In [None]:
# Fetch the MIMIC-III L2R bundle with `untar_xxx` and list the files it contains
source = untar_xxx(XURLs.MIMIC3_L2R)
source.ls()

(#6) [Path('/home/deb/.xcube/data/mimic3_l2r/info.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/code_descriptions.csv'),Path('/home/deb/.xcube/data/mimic3_l2r/mimic3-9k_tok_lbl_info.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/code_desc.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/p_TL.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/mimic3-9k.csv')]

In [None]:
data = source/'mimic3-9k.csv'
# Column name -> dtype; the dict's key order doubles as the column order passed to `names`
col_dtypes = {'subject_id': str, 'hadm_id': str, 'text': str, 'labels': str, 'length': np.int64, 'is_valid': bool}
df = pd.read_csv(data, header=0, names=list(col_dtypes), dtype=col_dtypes)
# Coerce to str so that a missing text/labels value becomes a string instead of staying a float NaN
df[['text', 'labels']] = df[['text', 'labels']].astype(str)
len(df)

52726

In [None]:
# Peek at a few admissions: discharge-summary text and ';'-separated codes
df.head(3)

Unnamed: 0,subject_id,hadm_id,text,labels,length,is_valid
0,86006,111912,admission date discharge date date of birth sex f service surgery allergies patient recorded as having no known allergies to drugs attending first name3 lf chief complaint 60f on coumadin was found slightly drowsy tonight then fell down stairs paramedic found her unconscious and she was intubated w o any medication head ct shows multiple iph transferred to hospital1 for further eval major surgical or invasive procedure none past medical history her medical history is significant for hypertension osteoarthritis involving bilateral knee joints with a dependence on cane for ambulation chronic...,801.35;348.4;805.06;807.01;998.30;707.24;E880.9;427.31;414.01;401.9;V58.61;V43.64;707.00;E878.1;96.71,230,False
1,85950,189769,admission date discharge date service neurosurgery allergies sulfa sulfonamides attending first name3 lf chief complaint cc cc contact info major surgical or invasive procedure none history of present illness hpi 88m who lives with family had fall yesterday today had decline in mental status ems called pt was unresponsive on arrival went to osh head ct showed large r sdh pt was intubated at osh and transferred to hospital1 for further care past medical history cad s p mi in s p cabg in ventricular aneurysm at that time cath in with occluded rca unable to intervene chf reported ef 1st degre...,852.25;E888.9;403.90;585.9;250.00;414.00;V45.81;96.71,304,False
2,88025,180431,admission date discharge date date of birth sex f service surgery allergies no known allergies adverse drug reactions attending first name3 lf chief complaint s p fall major surgical or invasive procedure none history of present illness 45f etoh s p fall from window at feet found ambulating and slurring speech on scene intubated en route for declining mental status in the er the patient was found to be bradycardic to the s with bp of systolic she was given atropine dilantin and was started on saline past medical history unknown social history unknown family history unknown physical exam ex...,518.81;348.4;348.82;801.25;427.89;E882;V49.86;305.00;96.71;38.93,359,False


The file `'code_desc.pkl'` contains a short description for the labels. 

In [None]:
# NOTE(security): pickle.load can execute arbitrary code -- only load this file from a trusted source
with open(source/'code_desc.pkl', 'rb') as f: lbs_desc = pickle.load(f)
assert isinstance(lbs_desc, dict)  # label code -> short description
test_eq(lbs_desc['427.31'], 'Atrial fibrillation')

Note that performing some computations in this notebook on the full dataset is going to take a lot of time. But don't worry: `untar_xxx` has already downloaded everything you need. You can still run the following cells if you want to generate everything from scratch; preferably, run them on a sampled dataset for quick iterations.

**Run the cell below only if you want to sample from the full dataset to create a tiny dataset for the purpose of quick iterations.**

*Technical Point:* If we want to sample to perform quick iterations, we need to make sure the number of data points in the sample is a multiple of `bs`, so that we do not have to set `drop_last=True` while creating the `DataLoaders`. This matters because we are about to do some probability computations, and dropping data points is not a good idea as the probabilities would no longer sum to 1.

In [None]:
bs = 8
# Drop the tail rows so that the row count is an exact multiple of `bs`
cut = (len(df) // bs) * bs
df = df.iloc[:cut]
len(df)

52720

In [None]:
_arr = np.arange(0, len(df), bs)  # candidate sample sizes: every multiple of `bs` below len(df)
# Restrict to sizes strictly between 500 and 1000 (widen, e.g. to 4000..5000, for a bigger sample)
mask = (_arr > 500) & (_arr < 1000)
_candidates = _arr[mask]
assert _candidates.size > 0, "df is too small for the requested sample-size range"
# Seeded so the sample size is reproducible (df.sample's random_state alone does not fix it),
# and converted to a plain int because DataFrame.sample expects an integer `n`
_n = int(np.random.default_rng(89).choice(_candidates))
df = df.sample(n=_n, random_state=89, ignore_index=True)
len(df)

832

In [None]:
# Peek at the sampled dataframe
df.head(3)

Unnamed: 0,subject_id,hadm_id,text,labels,length,is_valid
0,2258,139169,admission date discharge date date of birth sex m service cardiothoracic surgery history of present illness the patient is a year old male with a past medical history significant for poorly controlled diabetes mellitus and hypertension as well as known coronary disease and a previous non q myocardial infarction and right coronary artery stenting in he was admitted to an outside hospital on the day prior to admission with unstable angina and found to have borderline positive troponin hypertension and st depressions in the lateral lead he was given aspirin nitrates beta blockers morphine and...,414.01;998.31;411.1;599.0;412;V45.82;250.00;401.9;530.81;36.13;37.22;36.15;36.19;39.61;39.64;88.56;88.53;33.23;96.56;33.24;78.41,1271,False
1,41217,161582,admission date discharge date date of birth sex m service medicine allergies no known allergies adverse drug reactions attending first name3 lf chief complaint new diagnosis of scc of base of tongue major surgical or invasive procedure egd w biopsy history of present illness yo man with h o cad heavy smoking and new diagnosis of scc of base of tongue with lymph node involvement pt was referred to dr last name stitle ent in for a rt neck mass at that time a cm rt cervical lymph node was palpated and fiberoptic laryngoscopy showed a cm rt base of tongue mass a ct and biopsy were recommended ...,141.0;507.0;196.0;293.0;519.09;786.30;286.9;427.89;790.29;276.52;414.01;338.3;280.0;272.0;412;V69.4;V15.82;V45.82;V66.7;E879.8;E932.0;31.42;25.01;42.23;43.11;96.6;38.93;99.25;38.93,2743,False
2,30204,172114,admission date discharge date date of birth sex f service medicine allergies etomidate norpace quinidine demerol penicillins lipitor attending doctor first name chief complaint cardiac tamponade s p pulmonary vein isolation major surgical or invasive procedure attempted pulmonary vein isolation pericardiocentesis history of present illness year old woman with a long history of paroxysmal atrial fibrillation refractory to mulitple pharmacologic interventions and multiple cardioversions who presents to the ccu with cardiac tamponade s p pulmonary vein isolation procedure past medical history...,427.31;998.2;423.3;423.9;573.0;276.6;E878.8;37.34;37.27;37.0;37.21,1764,False


**[Mutual Information](https://en.wikipedia.org/wiki/Mutual_information#)**

<img alt="Pictorial representation of mutual information gain" width="400" src="info-gain.svg" caption="Pictorial representation of mutual information gain" id="img_mut_info">

The mutual information of two jointly discrete random variables $T$ (token) and $L$ (label) is calculated as a double sum:

$$I(T;L) = \sum_{l \in \mathcal{L}} \sum_{t \in \mathcal{T}} P_{(T,L)}(t,l) \log \Bigg(\frac{P_{(T,L)}(t,l)}{P_T(t) P_L(l)} \Bigg)$$

where $P_{(T,L)}$ is the [joint probability mass function](https://en.wikipedia.org/wiki/Joint_distribution) of $T$ and $L$, and $P_T$ and $P_L$ are the [marginal probability mass functions](https://en.wikipedia.org/wiki/Marginal_probability) of $T$ and $L$ respectively. To compute $I$, the only quantity we need to compute is the joint pmf $P_{(T,L)}$, as the marginal pmfs can be computed from the joint pmf.

With regard to implementation, $P_{(T,L)}$ can be thought of as a 2x2 tensor as shown below:

In [None]:
# The 2x2 table for a single token-label pair: rows = label present/absent, columns = token present/absent
p_TL = pd.DataFrame(0, columns=['t', 'not t'], index=['lbl', 'not lbl'])
p_TL

Unnamed: 0,t,not t
lbl,0,0
not lbl,0,0


...and we need to compute this $P_{(T,L)}$ for every token-label pair. In other words, we need to fill in the `joint` dataframe shown below. Note that each cell of the `joint` dataframe can be thought of as being further subdivided into a 2x2 grid containing the corresponding `p_TL`.

In [None]:
# Batch size and number of labels handled per chunk (see the implementation notes below)
bs, chnk_sz = 8, 200
info = MutualInfoGain(df, bs=bs, chnk_sz=chnk_sz, lbs_desc=source/'code_desc.pkl') # provide lbs_desc if you have it

In [None]:
%%time
# One-hot encode the `text` and `labels` fields of `df`
dsets = info.onehotify()

CPU times: user 386 ms, sys: 151 ms, total: 538 ms
Wall time: 2.09 s


In [None]:
# Token and label vocabularies; their product is the number of token-label pairs to score
toks, lbs = dsets.vocab
L(toks), L(lbs), len(toks)*len(lbs)

((#11104) ['xxunk','xxpad','xxbos','xxeos','xxfld','xxrep','xxwrep','xxup','xxmaj','the'...],
 (#2244) ['008.45','008.8','009.0','009.1','031.0','031.2','038.0','038.10','038.11','038.19'...],
 24917376)

In [None]:
# Illustrative (toks x lbs) grid of zeros: one cell per token-label pair.
# int8 instead of the default int64: with len(toks)*len(lbs) cells (~25M even for the sample)
# this renders identically but needs 1/8 of the memory.
joint = pd.DataFrame(0, columns=range(len(lbs)), index=range(len(toks)), dtype=np.int8)
joint.index.name = 'toks (T)'
joint.columns.name = 'lbs (L)'
joint

lbs (L),0,1,2,3,4,5,6,7,8,9,...,2234,2235,2236,2237,2238,2239,2240,2241,2242,2243
toks (T),Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11099,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
11100,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
11101,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
11102,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


We can perform a tensorized computation if we think of `p_TL` as a 4-dim tensor of size `(len(toks), len(lbs), 2, 2)`. Next, to estimate `p_TL` we just need to iterate over the dataset and, for each data point and each token-label pair, record the `p_TL` information in the last two dimensions of the tensor `p_TL`. Finally, we divide by the size of the dataset.

Some more implementation details (skip this if not interested):

- We are going to one-hot encode the dataset (both the `text` and `labels` fields in `df`). This is done by `onehotify`.
- For efficiency, we are not actually going to iterate over the dataset one data point at a time; instead we use a dataloader and perform the `p_TL` computation on a mini-batch.
- Unless you are doing this in 2035 you probably do not have enough GPU-RAM to fit the entire `p_TL` tensor of dimension `(len(toks), len(lbs), 2, 2)`. So we are going to split the `lbs` dimension into chunks. (Why the `lbs` dimension and not the `toks`? Because in XML datasets `toks` are approximately 60000, but the number of `lbs` could be really large, of the order of millions.) With regard to implementation, this means that instead of one dataloader we roll with multiple dataloaders. Each dataloader loads the dataset in a way that its mini-batches contain the full one-hot encoding of the `text` field but only a certain `chunk` of the one-hot encoded `labels` field in `df`. Another way to think about this is that each data point, specifically its `labels`, is split across multiple dataloaders. This way, once we are done iterating over one such dataloader we have filled a certain chunk of the `joint` dataframe shown above, and we fill the entire `joint` only once we are done iterating over all the dataloaders.

In [None]:
# Round-trip check on the first data point: decoding the one-hot `x`/`y` gives back the indices of their 1s.
# NOTE(review): presumably tfms[1][2] / tfms[0][-1] are the OneHotEncode steps of the label / text pipelines -- confirm.
x, y = dsets[0]
test_eq(tensor(dsets.tfms[1][2].decode(y)), torch.where(y==1)[0])
test_eq(tensor(dsets.tfms[0][-1].decode(x)), torch.where(x==1)[0])

In [None]:
# Vocabulary tokens present in the first document (positions where x is 1)
' '.join(L(toks)[torch.where(x==1)[0]])

'xxunk xxbos the and of to was with a on in for mg no patient is he blood at name or discharge as day one his left history last were had right by this date admission that pain hospital an pt from p normal first has have which but medications up d chest o hours also well given status dr time care after stable stitle disease please follow course started x known continued two days service artery prior per showed m it without cardiac medical past glucose heart post q namepattern1 present unit physical weeks aortic i pulmonary transferred year allergies md edema pressure t due did surgery surgical condition number procedure b lower found remained prn fluid coronary admitted hypertension soft further non all diagnosis should rate placed increased three birth bilaterally sodium abdomen aspirin bilateral illness over social than old secondary sex primary however some following examination lung positive disposition significant moderate take floor room insulin namepattern4 bleeding extremities c

In [None]:
# Label codes of the first document (positions where y is 1)
lbs.map_ids(torch.where(y==1)[0])

(#21) ['250.00','33.23','33.24','36.13','36.15','36.19','37.22','39.61','39.64','401.9'...]

In [None]:
#| hide
#| eval: false
from fastai.data.transforms import *
from fastai.data.core import *
from fastai.text.core import *

In [None]:
#| hide
#| eval: false
splits = ColSplitter()(df)
splits

# lm_vocab = torch.load(dls_lm_vocab_path)

@Transform
def Cleanser(toks):
    "Keep only the tokens of `toks` that appear in `lm_vocab`, preserving their order"
    # `o in lm_vocab` on a list is a linear scan per token (O(len(toks) * len(vocab))); hashing the vocab
    # once per call makes each lookup O(1). Sets/dicts are already hashed, so use them directly.
    vocab = lm_vocab if isinstance(lm_vocab, (set, frozenset, dict)) else set(lm_vocab)
    return [o for o in toks if o in vocab]

class MyNumericalize(Transform):
    "Numericalize tokens with `vocab`, dropping any token that is not in `vocab`"
    def __init__(self, vocab=None, min_freq=3, max_vocab=60000, special_toks=None):
        store_attr('vocab,min_freq,max_vocab,special_toks')
        # token -> index; a key missing from the map falls back to index 0
        self.o2i = None if vocab is None else defaultdict(int, {v: i for i,v in enumerate(vocab)})
    
    def setups(self, dsets):
        "Build `vocab` (and `o2i`) from the token counts of `dsets` when no vocab was given"
        if dsets is None: return
        if self.vocab is None:
            count = dsets.counter if getattr(dsets, 'counter', None) is not None else Counter(p for o in dsets for p in o)
            if self.special_toks is None and hasattr(dsets, 'special_toks'):
                self.special_toks = dsets.special_toks
            self.vocab = make_vocab(count, min_freq=self.min_freq, max_vocab=self.max_vocab, special_toks=self.special_toks)
            # 'xxfake' entries are left out of o2i, so encoding one maps it to index 0 via the defaultdict
            self.o2i = defaultdict(int, {v:i for i,v in enumerate(self.vocab) if v != 'xxfake'})
    
    def encodes(self, o):
        "Map each in-vocab token of `o` to its index, skipping out-of-vocab tokens"
        # If `vocab` is a list, `o_ in self.vocab` is a linear scan per token; a set makes each check O(1)
        known = set(self.vocab)
        return TensorText(tensor([self.o2i[o_] for o_ in o if o_ in known]))

    def decodes(self, o):
        "Map indices back to their tokens"
        return L(self.vocab[o_] for o_ in o)

# Fallback pipelines -- resort to these if anything goes wrong below.
# x: tokenize the `text` column, keep only tokens in `lm_vocab` (must be defined first), then one-hot
#    encode over `lm_vocab` (several positions can be 1).
# y: split the ';'-delimited `labels` column and one-hot encode it the same way.
x_tfms = [Tokenizer.from_df('text', n_workers=num_cpus()), attrgetter("text"), Cleanser, MultiCategorize(vocab=lm_vocab), OneHotEncode()]
y_tfms = [ColReader('labels', label_delim=';'), MultiCategorize(), OneHotEncode()]
tfms = [x_tfms, y_tfms]

class Chunkifize(Transform):
    "Encode: split a tensor along dim 0 into a list of chunks. Decode: concatenate the chunks back"
    order = 4  # this transform's position when a fastai pipeline sorts its transforms
    def __init__(self, num_chunks=3): store_attr('num_chunks')

    def encodes(self, o):
        # torch.chunk yields at most `num_chunks` pieces (fewer when they don't divide evenly)
        pieces = torch.chunk(o, self.num_chunks)
        return [piece for piece in pieces]

    def decodes(self, o): return torch.cat(o)

# Sanity-check Chunkifize: encoding returns a list of chunks, decoding concatenates them back
full = torch.arange(10)
chnk_tfm = Chunkifize()
chnks = chnk_tfm(full)
test_eq(type(chnks), list)
test_eq(chnks, [tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8, 9])])  # 10 items in 3 chunks -> 4, 4, 2
test_eq(chnk_tfm.decode(chnks), full)

# Prototype: chunk the one-hot labels as the last step of the y pipeline
# y_tfms = [ColReader('labels', label_delim=';'), MultiCategorize(), OneHotEncode(), Chunkifize()]
# tfmd_y = TfmdLists(df, tfms=y_tfms)
# tfmd_y.decode(tfmd_y[0])

In [None]:
# One dataloader per chunk of labels (see the implementation notes above)
dls = info.lbs_chunked()

In [None]:
assert isinstance(dls[0], TfmdDL)
# One dataloader per chunk of `chnk_sz` labels (use the configured chunk size, not a hard-coded 200)
test_eq(len(dls), np.ceil(len(lbs)/chnk_sz))
test_eq(len(dls[0]), np.ceil(len(dsets)/bs)) # drop_last is False
# test to prove that the labels for each data point is split across multiple dataloaders:
# concatenating the first row's label chunks from every dataloader rebuilds the full one-hot `y`
# (assumes the dataloaders do not shuffle)
lbs_0 = torch.cat([yb[0] for dl in dls for _,yb in itertools.islice(dl, 1)])
# Compare on the dataloaders' device without rebinding the notebook-level `y`
test_eq(lbs_0, y.to(default_device()))

Now let's compute the joint pmf table (`p_TL`) we saw earlier.

In [None]:
%%time
# Accumulate p_TL by iterating over all the label-chunked dataloaders
p_TL = info.joint_pmf()

CPU times: user 10.8 s, sys: 2.94 s, total: 13.8 s
Wall time: 16 s


In [None]:
# One 2x2 table per (token, label) pair
test_eq(p_TL.shape, (info.toksize, info.lblsize, 2, 2))

Technicality: `p_TL` is not really the joint pmf (yes, I lied before!) but contains all the information needed to compute the joint pmf `p_TxL` and the mutual info gain `I_TL`. This computation is done by `compute`:

In [None]:
%%time
# Derive the marginals, the joint pmf, the entropies and the mutual info gain from p_TL
p_T, p_L, p_TxL, H_T, H_L, I_TL = info.compute()

CPU times: user 708 ms, sys: 135 ms, total: 843 ms
Wall time: 931 ms


If you have been working with the sampled dataset so far, you can continue to do so for the rest of this notebook. But if you want a real feel for how things look, at this point you can load the pregenerated `p_TL` and `(p_T, p_L, p_TxL, H_T, H_L, I_TL)` for the full dataset, which `untar_xxx` downloaded:

In [None]:
# Pregenerated full-dataset pickles that `untar_xxx` downloaded
L(source.glob("**/*.pkl"))

(#4) [Path('/home/deb/.xcube/data/mimic3_l2r/info.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/mimic3-9k_tok_lbl_info.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/code_desc.pkl'),Path('/home/deb/.xcube/data/mimic3_l2r/p_TL.pkl')]

In [None]:
%%time 
# Full-dataset tensors, mapped onto the CPU whatever device they were saved from.
# NOTE(security): torch.load unpickles its input -- only load files from a trusted source.
p_TL = torch.load(source/'p_TL.pkl', map_location=torch.device('cpu'))
p_T, p_L, p_TxL, H_T, H_L, I_TL = torch.load(source/'info.pkl', map_location=torch.device('cpu'))

CPU times: user 0 ns, sys: 4.48 s, total: 4.48 s
Wall time: 6.92 s


Make sure that there aren't any of those pesky NaNs or negatives:

In [None]:
def test_nanegs(*args):
    "Raise if any tensor in `args` contains a NaN or a negative value; the message names the tensor via `namestr`"
    for o in args:
        # `.any()`: a single NaN is a problem (`.all()` only caught tensors that are entirely NaN)
        has_nans = o.isnan().any()
        # NaNs are reported first, so `o < 0` covers every remaining bad value
        has_negs = (o < 0).any()
        if has_nans: raise Exception(f"{namestr(o, globals())[0]} has nans")
        if has_negs: raise Exception(f"{namestr(o, globals())[0]} has negs")

In [None]:
# Expected: the only failure is I_TL, which has tiny negative values (explained below)
test_fail(test_nanegs, args=(p_T, p_L, p_TxL, H_T, H_L, I_TL), contains='I_TL has negs')

Theoretically, mutual information as defined [here](https://en.wikipedia.org/wiki/Mutual_information) is supposed to be non-negative (this can be proved by tossing in [Jensen's inequality](https://en.wikipedia.org/wiki/Jensen%27s_inequality)). But in practice `I_TL` has some negatives because we distorted `p_TL` and `p_TxL` with `eps` in the `I_TL` computation.

In [None]:
# The 10 most negative mutual-info values
torch.topk(I_TL.flatten(), 10, largest=False)

torch.return_types.topk(
values=TensorMultiCategory([-1.5736e-07, -1.5736e-07, -1.5736e-07, -1.5736e-07,
                     -1.5736e-07, -1.5736e-07, -1.5736e-07, -1.5736e-07,
                     -1.5736e-07, -1.5736e-07], device='cuda:0'),
indices=TensorMultiCategory([1867682, 1867581, 1867490, 1867556, 1867114, 1867016,
                     1867175, 1867278, 1867705, 1867872], device='cuda:0'))

In [None]:
# Average of the negative I_TL entries: sum of the negatives divided by how many there are
neg_mask = I_TL < 0
howmany = neg_mask.sum().item()
negs = I_TL.where(neg_mask, I_TL.new_zeros(I_TL.shape))
negs.sum()/howmany

TensorMultiCategory(-5.2029e-08, device='cuda:0')

Those negs on an avg are pretty close to zero. So we need not worry. Let's roll!

In [None]:
# Shape contract of the computed quantities.
# NOTE(review): `info` was built on the (possibly sampled) df; if the full-dataset tensors were loaded
# above, info.toksize/info.lblsize may not match them -- confirm.
test_eq(p_TL.shape, (info.toksize, info.lblsize, 2, 2))
test_eq(p_T.shape, (info.toksize, 2, 1))
test_eq(p_L.shape, (info.lblsize, 1, 2))
test_eq(p_TxL.shape, (info.toksize, info.lblsize, 2, 2))
test_eq(H_T.shape, [info.toksize])
test_eq(H_L.shape, [info.lblsize])
test_eq(I_TL.shape, (info.toksize, info.lblsize))

In [None]:
#| hide
# Disabled spot checks on a random (token, label) pair -- uncomment to run.
# r_t, r_l = random.randrange(0, len(toks)), random.randrange(0, len(lbs))
# toks[r_t], lbs[r_l]

# test_close(p_TL[r_t,r_l].sum(), 1, eps=1e-1)
# test_eq(p_T[r_t].sum(), 1)
# test_eq(p_L[r_l].sum(), 1)

# p_TL[r_t,r_l].sum(-1), p_TL[r_t, 400].sum(-1)

# p_T[r_t], p_L[r_l]
# I_TL[r_t,r_l]

In [None]:
# Tiny constant (same dtype/device as I_TL) guarding the divisions below against zero denominators
eps = I_TL.new_full((1,), 1e-15)
# Mutual info gain normalized by the label entropy H(L), broadcast over the label axis...
info_lbl_entropy = I_TL/(H_L + eps)
# ...and Jaccard-style by H(T) + H(L) - I(T;L), which equals the joint entropy H(T,L)
info_jaccard = I_TL/(H_T.unsqueeze(-1) + H_L.unsqueeze(0) - I_TL + eps)
# `.any()`: a single NaN must fail the check (`.all()` only caught all-NaN tensors)
assert not info_lbl_entropy.isnan().any(); assert not info_jaccard.isnan().any()
l2r_bootstrap = {'toks': toks, 'lbs': lbs, 'mut_info_lbl_entropy': info_lbl_entropy, 'mutual_info_jaccard': info_jaccard}

`l2r_bootstrap` for the full dataset was downloaded by `untar_xxx` to `boot_path`; the following cell checks that the file exists, so you can load it from there. `l2r_bootstrap` will be used to bootstrap our learning-to-rank model.

In [None]:
# Pregenerated l2r_bootstrap for the full dataset
boot_path = (source/'mimic3-9k_tok_lbl_info.pkl')
assert boot_path.exists()

#### Save those Mutual Information Gain values

Let's take a look at the Jaccard-normalized *Mutual Information Gain* (`info_jaccard`) for each of the labels:

In [None]:
# Build the per-label summary (top k=10 tokens per label) from the Jaccard-normalized info gain and save it
# into a temporary directory that is deleted when the block exits -- this only demonstrates `save_as`.
with tempfile.TemporaryDirectory() as tmpdir:
    args = (p_TL, p_T, p_L, info_jaccard, H_T, H_L)
    kwargs = {'k':10, 'save_as': Path(tmpdir)/'mut_info_jaccard.ft'}
    df_info = info.show(*args, **kwargs)
    assert (Path(tmpdir)/'mut_info_jaccard.ft').exists()

In [None]:
# One row per label: frequency, probability, entropy, description and its top-k tokens
df_info.head()

Unnamed: 0,label,freq,prob,entropy,description,"top-k (token, prob, entropy, joint, info)"
0,8.45,19,0.022837,0.108882,Intestinal infection due to clostridium difficile,[['difficile' '0.05048077' '0.19992837' '0.015625' '0.12893958']\n ['cdiff' '0.010817308' '0.059724282' '0.0060096155' '0.10641697']\n ['colitis' '0.056490384' '0.21719941' '0.014423078' '0.09652457']\n ['loosely' '0.0024038462' '0.016897745' '0.0024038462' '0.07903936']\n ['reformatted' '0.006009616' '0.036727034' '0.0036057695' '0.07314324']\n ['flagyl' '0.16466346' '0.44732267' '0.021634616' '0.068708725']\n ['clostridium' '0.028846156' '0.13070811' '0.007211539' '0.055680964']\n ['rocephin' '0.008413462' '0.04857681' '0.0036057695' '0.055435326']\n ['enteritis' '0.0036057695' '0.023882...
1,8.8,2,0.002404,0.016898,"Intestinal infection due to other organism, not elsewhere classified",[['vasotec' '0.0024038462' '0.016897745' '0.0012019231' '0.21375288']\n ['gastroenteritis' '0.009615385' '0.054226533' '0.0024038462'\n '0.19268773']\n ['tachyarrhythmia' '0.0036057695' '0.023882464' '0.0012019231'\n '0.1501656']\n ['probenecid' '0.0036057695' '0.023882464' '0.0012019231' '0.1501656']\n ['electronics' '0.0036057695' '0.023882464' '0.0012019231' '0.1501656']\n ['poles' '0.0036057695' '0.023882464' '0.0012019231' '0.1501656']\n ['9th' '0.0036057695' '0.023882464' '0.0012019231' '0.1501656']\n ['aborted' '0.0036057695' '0.023882464' '0.0012019231' '0.1501656']\n ['tenting' ...
2,9.0,1,0.001202,0.009283,"Infectious colitis, enteritis, and gastroenteritis",[['presacral' '0.0012019231' '0.009282676' '0.0012019231' '0.9999966']\n ['vibrio' '0.0012019231' '0.009282676' '0.0012019231' '0.9999966']\n ['yersinia' '0.0024038462' '0.016897745' '0.0012019231' '0.41028064']\n ['ova' '0.0024038462' '0.016897745' '0.0012019231' '0.41028064']\n ['parasites' '0.0036057695' '0.023882464' '0.0012019231' '0.26692808']\n ['resucitation' '0.0036057695' '0.023882464' '0.0012019231' '0.26692808']\n ['exlap' '0.0036057695' '0.023882464' '0.0012019231' '0.26692808']\n ['tenting' '0.0036057695' '0.023882464' '0.0012019231' '0.26692808']\n ['adhesiolysis' '0.0036057...
3,9.1,2,0.002404,0.016898,"Colitis, enteritis, and gastroenteritis of presumed infectious origin",[['44yf' '0.0012019231' '0.009282676' '0.0012019231' '0.41028064']\n ['ischioanal' '0.0012019231' '0.009282676' '0.0012019231' '0.41028064']\n ['perianal' '0.0012019231' '0.009282676' '0.0012019231' '0.41028064']\n ['paraplegia' '0.0024038462' '0.016897745' '0.0012019231' '0.21375288']\n ['paraplegic' '0.0024038462' '0.016897745' '0.0012019231' '0.21375288']\n ['vicarious' '0.0024038462' '0.016897745' '0.0012019231' '0.21375288']\n ['spasticity' '0.0024038462' '0.016897745' '0.0012019231' '0.21375288']\n ['tuberosity' '0.0024038462' '0.016897745' '0.0012019231' '0.21375288']\n ['comp' '0.0...
4,31.0,1,0.001202,0.009283,Pulmonary diseases due to other mycobacteria,[['gist' '0.0024038462' '0.016897745' '0.0012019231' '0.41028064']\n ['disrupted' '0.0024038462' '0.016897745' '0.0012019231' '0.41028064']\n ['eroding' '0.0024038462' '0.016897745' '0.0012019231' '0.41028064']\n ['vaginitis' '0.0024038462' '0.016897745' '0.0012019231' '0.41028064']\n ['circumscribed' '0.0024038462' '0.016897745' '0.0012019231'\n '0.41028064']\n ['discern' '0.0024038462' '0.016897745' '0.0012019231' '0.41028064']\n ['inseparable' '0.0024038462' '0.016897745' '0.0012019231' '0.41028064']\n ['pearls' '0.0024038462' '0.016897745' '0.0012019231' '0.41028064']\n ['ileocecal' '...


#### Let's look at those Mutual-Information Gain values:

In [None]:
# Keep only the moderately rare labels: 50 < freq < 150
mask = df_info.freq.gt(50) & df_info.freq.lt(150)
df_info = df_info.loc[mask].reset_index(drop=True)
len(df_info)

33

The dataframe below shows the top 10 tokens (based on the mutual-info-gain values) for labels that are rare (freq between 50 and 150). Feel free to ask ChatGPT about the label descriptions and the tokens to find out whether we were able to find the needle in the haystack.

In [None]:
# Show untruncated cell contents for this table only; option_context restores the setting afterwards,
# so the wide display does not leak into the rest of the session
with pd.option_context('display.max_colwidth', None):
    display(df_info.head())

Unnamed: 0,label,freq,prob,entropy,description,"top-k (token, prob, entropy, joint, info)"
0,38.9,60,0.072115,0.259077,Unspecified septicemia,[['expired' '0.10336539' '0.33241516' '0.037259616' '0.08065721']\n ['pressors' '0.13100962' '0.38830143' '0.042067308' '0.076888874']\n ['septic' '0.066105776' '0.24344721' '0.025240386' '0.06168491']\n ['spectrum' '0.06129808' '0.23052433' '0.02283654' '0.05475075']\n ['levophed' '0.07692308' '0.27118933' '0.025240386' '0.04954973']\n ['tpn' '0.04567308' '0.18557212' '0.018028846' '0.04873537']\n ['antifungal' '0.0048076925' '0.030457318' '0.0048076925' '0.04623133']\n ['broad' '0.070913464' '0.2559954' '0.02283654' '0.04429924']\n ['lactate' '0.265625' '0.5788586' '0.051682696' '0.04339258']\n ['sepsis' '0.18750001' '0.48257756' '0.042067308' '0.043172438']]
1,244.9,81,0.097356,0.319234,Unspecified hypothyroidism,[['hypothyroidism' '0.10576923' '0.33757737' '0.081730776' '0.3886398']\n ['levothyroxine' '0.11057693' '0.3477198' '0.078125' '0.3144576']\n ['synthroid' '0.046875004' '0.18920892' '0.032451924' '0.11703823']\n ['hypothyroid' '0.01923077' '0.09503007' '0.015625' '0.07425124']\n ['levoxyl' '0.018028848' '0.09026522' '0.014423078' '0.06755623']\n ['mcg' '0.3641827' '0.6557869' '0.08052885' '0.05304037']\n ['88mcg' '0.0060096155' '0.03672703' '0.0060096155' '0.041458126']\n ['cystitis' '0.0036057695' '0.023882464' '0.0036057695' '0.02528072']\n ['kyphotic' '0.0036057695' '0.023882464' '0.0036057695' '0.02528072']\n ['cvat' '0.0036057695' '0.023882464' '0.0036057695' '0.02528072']]
2,250.0,133,0.159856,0.439431,"type II diabetes mellitus [non-insulin dependent type] [NIDDM type] [adult-onset type] or unspecified type, not stated as uncontrolled, without mention of complication",[['metformin' '0.073317304' '0.2621364' '0.05048077' '0.08942482']\n ['diabetes' '0.28846157' '0.6007684' '0.11538462' '0.0821224']\n ['glyburide' '0.03846154' '0.16302359' '0.028846156' '0.061709817']\n ['dm' '0.15264423' '0.42726845' '0.067307696' '0.05150881']\n ['mellitus' '0.15264425' '0.42726848' '0.0625' '0.040773362']\n ['noninsulin' '0.010817308' '0.059724223' '0.009615385' '0.029499378']\n ['dm2' '0.044471156' '0.18190216' '0.02283654' '0.026129907']\n ['glipizide' '0.02283654' '0.10888202' '0.014423078' '0.024919808']\n ['insulin' '0.28245193' '0.5952542' '0.082932696' '0.023776595']\n ['avandia' '0.013221154' '0.07032718' '0.009615385' '0.0214933']]
3,272.0,98,0.117788,0.362495,Pure hypercholesterolemia,[['hypercholesterolemia' '0.13822116' '0.40172112' '0.07331731'\n '0.13655801']\n ['lipitor' '0.16826923' '0.45313114' '0.04567308' '0.023694605']\n ['crestor' '0.016826924' '0.08541869' '0.009615385' '0.023464732']\n ['aspirin' '0.546875' '0.6887462' '0.097355776' '0.022431085']\n ['ceased' '0.0036057695' '0.023882464' '0.0036057695' '0.020499693']\n ['carotids' '0.027644232' '0.12645534' '0.012019231' '0.019025374']\n ['nonreactive' '0.086538464' '0.29445505' '0.0' '0.017625172']\n ['palate' '0.085336536' '0.29161334' '0.0' '0.017440615']\n ['mrs' '0.03846154' '0.16302359' '0.014423078' '0.01726563']\n ['crossclamp' '0.009615385' '0.054226533' '0.0060096155' '0.017237019']]
4,276.2,80,0.096154,0.316549,Acidosis,[['acidosis' '0.07572116' '0.26819247' '0.049278848' '0.16342042']\n ['metabolic' '0.086538464' '0.29445505' '0.036057692' '0.06290171']\n ['lactic' '0.013221154' '0.07032718' '0.009615385' '0.041882366']\n ['unclear' '0.18509617' '0.4790337' '0.048076924' '0.036246117']\n ['volumes' '0.0625' '0.2337916' '0.02283654' '0.03440499']\n ['cvvh' '0.02283654' '0.10888202' '0.012019231' '0.034102257']\n ['acidemia' '0.0048076925' '0.030457318' '0.0048076925' '0.03387371']\n ['hypercarbia' '0.013221154' '0.07032718' '0.008413462' '0.031534124']\n ['ascites' '0.082932696' '0.2858742' '0.026442308' '0.030813422']\n ['hapto' '0.028846156' '0.13070805' '0.013221154' '0.030686395']]


In [None]:
# Restore the default for `display.max_colwidth`, the only option this notebook changes;
# reset_option('all') would also wipe any other pandas options set in this session
pd.reset_option('display.max_colwidth')

In [None]:
# NOTE(review): pandas 2.0+ can no longer write the legacy .xls format; use 'jaccard.xlsx' if you enable this
# df_info.to_excel('jaccard.xls', index=False)