In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset
from collections import namedtuple
import torch
from functools import partial
from tqdm import tqdm
from torch.optim import Adam
from torch import nn
from sklearn.metrics import accuracy_score, confusion_matrix

In [3]:
from ner_ehr.training.datasets import EHRDataModule, EHRBatchCollator
from ner_ehr.data.vocab import TokenEntityVocab
from ner_ehr.training.models import LSTMNERTagger
from ner_ehr.training.losses import cross_entropy
from ner_ehr.training.metrics import accuracy_per_class

In [4]:
from typing import Union, List
from pathlib import Path
from ner_ehr.data.variables import AnnotationTuple, LongAnnotationTuple
from ner_ehr.data.utils import df_to_namedtuples
from ner_ehr.data.vocab import TokenEntityVocab
from ner_ehr.data.ehr import EHR
import os
from glob import glob


def read_annotatedtuples(dir: Union[str, Path]) -> List[AnnotationTuple]:
    """Read annotated tuples from every CSV directly inside a directory.

    Each file matching `*.csv` in `dir` (non-recursive) is loaded with
    `EHR.read_csv_tokens_with_annotations` and converted to namedtuples via
    `df_to_namedtuples`; the per-file results are concatenated into one list.

    Files are visited in sorted path order. `glob` on its own yields matches
    in arbitrary, filesystem-dependent order, which would make the order of
    the returned list (and anything fitted on it) vary between runs.

    Args:
        dir: directory containing CSVs with annotated tokens

    Returns:
        annotatedtuples: list of annotation namedtuples (empty when the
            directory holds no CSV file), e.g.
                [
                    Annotation(
                        doc_id='100035',
                        token='Admission',
                        start_idx=0,
                        end_idx=9,
                        entity='O'),
                    Annotation(
                        doc_id='100035',
                        token='Date',
                        start_idx=10,
                        end_idx=14,
                        entity='O'),
                ]
    """
    annotatedtuples: List[AnnotationTuple] = []
    # sorted(): deterministic file order regardless of the underlying filesystem
    for fp in sorted(glob(os.path.join(dir, "*.csv"))):
        annotatedtuples.extend(
            df_to_namedtuples(
                name=AnnotationTuple.__name__,
                df=EHR.read_csv_tokens_with_annotations(fp),
            )
        )

    return annotatedtuples

In [35]:
# Fit the token/entity vocabulary on the annotated training subset.
# NOTE(review): `dir` shadows the builtin `dir()` for the rest of the kernel session.
dir = "../processed/train_subset"
annotatedtuples = read_annotatedtuples(dir)
# to_lower=True: see the lowercase vs. original-case lookups in the cells below.
vocab = TokenEntityVocab(to_lower=True)
vocab.fit(annotatedtuples=annotatedtuples)

In [36]:
# Index lookup for a mixed-case token (the vocab was created with to_lower=True).
vocab.token_to_idx("Admission")

[2]

In [38]:
# Entity-label counts stored under the lowercase form of the token.
vocab.token_entity_freq["admission"]

Counter({'O': 6})

In [33]:
# Same lookup with the original casing; the recorded output is empty, unlike
# the lowercase key "admission".
vocab.token_entity_freq["Admission"].items()

dict_items([])

In [51]:
# Toy tie-break experiment: order labels by count (descending), then by label
# (ascending), and read column 1 (the count) of the top-ranked row.
(
    pd.DataFrame.from_dict({"a": 6, "b": 6, "c": 3}, orient="index")
    .reset_index()
    .sort_values(by=[0, "index"], ascending=[False, True])
    .values[0][1]
)

6

In [41]:
# Rank (entity, count) pairs by count, ties broken by entity name; reverse=True
# makes both keys descending.
sorted(vocab.token_entity_freq["admission"].items(), key=lambda x: (x[1], x[0]), reverse=True)

[('O', 6)]

In [20]:
# Lookup for a token that is presumably absent from the training data.
# NOTE(review): the recorded index 1 looks like the unknown-token slot; confirm.
vocab.token_to_idx("nitin")

[1]

In [6]:
# Build the data module for the training subset only; the validation and test
# directories are left commented out.
# return_meta=True: each batch then unpacks as (X, Y, data) in the cells that
# iterate the loader.
collate_fn = EHRBatchCollator(return_meta=True)
dm = EHRDataModule(
    vocab=vocab,
    collate_fn = collate_fn,
    dir_train="../processed/train_subset/",
    # dir_val="../processed/val/",
    # dir_test="../processed/test/",
    seq_length = 128, 
    batch_size_train = 32,
    annotated=True)
dm.setup()

In [7]:
# Pull exactly one batch from the training loader to sanity-check shapes.
# next(iter(...)) fetches the first batch directly instead of starting a
# for-loop that breaks on its first iteration.
dl = dm.train_dataloader()
X, Y, data = next(iter(dl))
X.shape, Y.shape, (len(data), len(data[0]))

(torch.Size([32, 128]), torch.Size([32, 128]), (32, 128))

In [59]:
# import spacy
# from spacy import displacy

# for i, (start, end) in enumerate(para_start_indexes3):
#     print(f"{'='*50}{i}{'='*50}")
#     window = 10
#     idx = (end - start + 2*window)//2
    
#     string = text[start-window:end+window]
    
#     ex = [{"text": string, 
#        "ents": [{"start": idx-1, "end": idx+1, "label": "O"}],
#        "title": None}]
#     html = displacy.render(ex, style="ent", manual=True)

In [None]:
from ner_ehr.data.embeddings import (
    GloveEmbeddings, 
    PubMedicalEmbeddings)

In [17]:
# Load the 50-dimensional GloVe vectors and eyeball nearest neighbours of a
# clinical term.
# NOTE(review): `glove_fp` is an absolute, user-specific path; the notebook will
# not run elsewhere without editing it.
# NOTE(review): the zero vector length (50) must agree with EMBEDDING_DIM, which
# is set in a later cell.
embed = GloveEmbeddings(
    unknown_token_embedding=np.zeros(50), 
    glove_fp="/scratch/mittal.nit/embeddings/glove.6B.50d.txt")
embed.load_word2vec()
embed.embeddings.most_similar("seizure")

[('unspecified', 0.766258955001831),
 ('seizures', 0.746199905872345),
 ('accidental', 0.7298815846443176),
 ('retaliation', 0.7024218440055847),
 ('torture', 0.6841019988059998),
 ('executions', 0.679871678352356),
 ('ordering', 0.6795619130134583),
 ('imprisonment', 0.6747194528579712),
 ('confiscation', 0.6736593842506409),
 ('deportation', 0.6688999533653259)]

In [18]:
# Model / training hyperparameters.
EMBEDDING_DIM = 50  # width of each embedding row; matches the glove.6B.50d file loaded earlier
NUM_CLASSES = vocab.num_uniq_entities  # one output class per entity label in the vocab
VOCAB_SIZE = vocab.num_uniq_tokens  # number of rows in the embedding table
HIDDEN_SIZE = 32
EPOCHS = 1  # single pass; this cell is a smoke test of the training loop

# Build the pretrained-embedding matrix: one row per vocab index, filled with
# the vector `embed` returns for that token. Rows that are never assigned stay
# all-zero.
# `np.float` was only an alias of the builtin `float` and was removed in
# NumPy 1.24 (AttributeError); `np.float64` is the identical dtype.
embedding_weights = np.zeros((vocab.num_uniq_tokens, EMBEDDING_DIM), dtype=np.float64)
for token, idx in tqdm(vocab._token_to_idx.items(), leave=False, position=0):
    embedding_weights[idx] = embed(tokens=token)[0]
# Convert to a float32 tensor for use as model weights.
embedding_weights = torch.tensor(embedding_weights, dtype=torch.float32)

# Instantiate the bidirectional LSTM tagger.
# The `embedding_weights` argument is commented out, so the matrix built just
# before this call is not passed to the model.
lstm = LSTMNERTagger(
    embedding_dim=EMBEDDING_DIM, 
    vocab_size=VOCAB_SIZE,
    hidden_size=HIDDEN_SIZE, 
    num_classes=NUM_CLASSES, 
    # embedding_weights=embedding_weights, 
    bidirectional=True)

# Minimal training loop: Adam over all model parameters, one optimisation step
# per batch, with every batch loss appended to `losses`.
losses = []
dl = dm.train_dataloader()
adam = Adam(lstm.parameters(), lr=1e-3)
t = tqdm(range(EPOCHS))
for i in t:
    for j, (X, Y, data) in enumerate(dl):
        # Drop gradients from the previous step before computing new ones.
        adam.zero_grad()
        Y_hat = lstm(X)
        loss = cross_entropy(Y_hat, Y)
        loss.backward()
        adam.step()
        # Record the scalar loss and surface it on the progress bar.
        losses.append(loss.item())
        t.set_description(f"epoch: {i+1}/{EPOCHS}, batch={j}, loss={losses[-1]:.5f}")

epoch: 1/1, batch=0, loss=2.89801: 100%|██████████| 1/1 [00:00<00:00,  1.01it/s]


In [21]:
# Call the model's eval() and run a forward pass on the current `X` without
# gradient tracking.
lstm.eval()
with torch.no_grad():
    Y_hat = lstm(X)
# The argmax over the last axis is left disabled, so `Y_hat` keeps its trailing
# per-class dimension (see the recorded shapes).
# Y_hat = torch.argmax(Y_hat, axis=-1)
Y_hat.shape, Y.shape

(torch.Size([32, 128, 17]), torch.Size([32, 128]))

In [16]:
# Standalone softmax demo on random logits. It uses its own names so that it
# does not overwrite the model output `Y_hat` or the input batch `X`, which
# other cells (e.g. the accuracy cell) read.
demo_logits = torch.randn((5, 5), requires_grad=True)

# Under no_grad the result is detached from autograd even though the input
# requires gradients.
with torch.no_grad():
    demo_probs = nn.functional.softmax(demo_logits, dim=-1)
demo_probs

tensor([[0.2697, 0.3675, 0.1870, 0.0703, 0.1055],
        [0.0492, 0.4234, 0.0826, 0.4139, 0.0310],
        [0.1586, 0.2004, 0.2256, 0.1752, 0.2403],
        [0.1779, 0.3017, 0.2987, 0.1736, 0.0480],
        [0.1276, 0.2928, 0.1792, 0.3584, 0.0420]])

In [57]:
# Accuracy broken down per class. No `accuracy` name is defined or imported in
# this notebook; `accuracy_per_class` is the metric that is imported, so call
# that instead of failing with a NameError on a fresh kernel.
# NOTE(review): confirm `accuracy_per_class` takes (Y_hat, Y) in this order.
accuracy_per_class(Y_hat, Y)

array([0.01037577, 0.11494253, 0.03125   , 0.        , 0.16666667,
       0.        , 0.        , 0.07142857, 0.        , 0.        ,
       0.07692308, 0.        , 0.        , 0.35      , 0.10416667,
       0.28571429, 0.        ])

In [17]:
# embed = PubMedicalEmbeddings(
#     unknown_token_embedding=np.zeros(200), 
#     pubmed_fp="/scratch/mittal.nit/embeddings/pubmed2018_w2v_200D/pubmed2018_w2v_200D.bin")
# embed.load_word2vec()
# embed.embeddings.most_similar(positive="seizure")

In [43]:
# Sanity check: the number of vocab tokens missing from the pretrained
# vectors should equal the number of rows equal to the unknown-token embedding.
known_tokens = embed.embeddings.key_to_index
num_tokens_not_in_embed = sum(
    token.lower() not in known_tokens for token in vocab._token_to_idx
)

# `embeddings` is not defined by any other cell, so build it here: one vector
# per vocab token, using the same call that fills `embedding_weights`.
# NOTE(review): assumes `embed(tokens=token)[0]` is the vector for a single
# token — confirm against GloveEmbeddings.
embeddings = np.stack([embed(tokens=token)[0] for token in vocab._token_to_idx])

num_unknown_embed = sum(
    bool(np.allclose(embedding, embed.unknown_token_embedding))
    for embedding in embeddings
)
assert num_tokens_not_in_embed == num_unknown_embed
num_tokens_not_in_embed

214

In [None]:
from ner_ehr.utils import save_np, load_np

In [None]:
# Write the embedding array to "test" with the project's save helper.
save_np(embeddings, fp="test")

In [None]:
# Read the array back from the same location.
new_embeddings = load_np("test")

In [None]:
# Round-trip check: the reloaded array should match the in-memory one.
np.allclose(embeddings, new_embeddings)

True