# Imports 

In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np

from datasets import load_dataset
from datasets import Value, ClassLabel, Features, DatasetDict


import transformers
from transformers import AutoTokenizer
import torch

In [3]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
# NOTE(review): the wildcard import hides which names this notebook takes from the
# project module. `expand_abbreviations` (used in a later cell) presumably comes
# from here — confirm and import it explicitly.
from preprocessing.cleaning_utils import *

# Read in Data

In [5]:
# Absolute directories (each with a trailing slash, since file names are appended
# with `+`) for the MIMIC-III 1.4 files and the n2c2 2022 folders.
# NOTE(review): these are user-specific paths and must be adjusted to run elsewhere.
# `n2c2_dir` is not referenced in the visible cells.
mimic_dir = "/home/vs428/project/MIMIC/files/mimiciii/1.4/"
n2c2_dir = "/home/vs428/project/n2c2/2022/N2C2-AP-Reasoning/"
n2c2_data_dir =  "/home/vs428/project/n2c2/2022/Data/"


In [6]:
# NOTE(review): `classes` is not referenced in the visible cells.
classes = ['Not Relevant', 'Neither', 'Indirect', 'Direct']

# Explicit column schema for the CSV files; `Relation` is read as a plain string.
features = Features(
    {
        'ROW ID': Value("int64"),
        'HADM ID': Value("int64"),
        'Assessment': Value("string"),
        'Plan Subsection': Value("string"),
        "Relation": Value("string"),
    }
)

# Split name -> CSV file under `n2c2_data_dir`.
split_files = {
    "train": n2c2_data_dir + "train.csv",
    "valid": n2c2_data_dir + "dev.csv",
}

# Load both splits with the schema above and expose `Relation` as `label`.
# The `class_encode_column("Relation")` step stays disabled, so labels remain strings.
# dataset = dataset.class_encode_column("Relation")
dataset = load_dataset("csv", data_files=split_files, features=features).rename_column(
    "Relation", "label"
)

Using custom data configuration default-b1948d86214b7517
Reusing dataset csv (/home/vs428/.cache/huggingface/datasets/csv/default-b1948d86214b7517/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519)


  0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
# Training-split `Assessment` texts as a pandas Series.
# NOTE(review): only referenced by a commented-out `test.sample(10)` call in the next cell.
test = pd.Series(dataset['train']['Assessment'])

In [8]:
# test.sample(10).tolist()
# Read NOTEEVENTS.csv from the MIMIC directory into a DataFrame.
# NOTE(review): `notes` is not referenced again in the visible cells, and the recorded
# output looks like the tail of a pandas warning — confirm whether this load is still needed.
notes = pd.read_csv(mimic_dir + "NOTEEVENTS.csv")

  notes = pd.read_csv(mimic_dir + "NOTEEVENTS.csv")


In [10]:
dataset['train'][123]

{'ROW ID': 560236,
 'HADM ID': 199097,
 'Assessment': 'Patient is a 58 year old male with history of atrial fibrillation on\n   coumadin, hyperlipidemia, hypertension, peripheral vascular disease,\n   s/p CVA who presented to [**Hospital3 847**] on [**2185-3-2**] with heart\n   failure exacerbation and atrial fibrillation with RVR, transferred due\n   to difficulty with diuresis [**2-7**] hypotension.',
 'Plan Subsection': 'GERD\n   - Continue [**Hospital1 **] PPI',
 'label': 'Neither'}

# Electra-MEDAL

In [9]:
# Five example assessment sentences used further down to try out the ELECTRA tokenizer.
sample_texts = [
    "35 yo F with SLE, restrictive lung disease, CM (EF 15-20%) presents with severe lupus induced cardiomyopathy.",
    "54 year old man with pmh significant for bipolar disorder presenting with lithium toxicity",
    "69 F w/ MMP including copd, diastolic chf, dm2, recent prolonged hospitalization where she became trach- dependent secondary to prolonged virulent pseudomonal PNA, presented in septic shock, ARF and AMS.",
    "57 y/o man with long smokign history admitted to [**Hospital Unit Name 10**] with makred dypnea and lower extremity swelling:",
    "73 yo male with history of bilateral renal cell carcinoma metastatic to right adrenalectomy presents for post-operative monitoring after right adrenalectomy and prostate biopsy.",
]

# Read in Meta-Inventory Abbreviations File

In [11]:
import scispacy
import spacy
from scispacy.abbreviation import AbbreviationDetector


In [12]:
# Load the `en_core_sci_scibert` spaCy pipeline and add scispaCy's
# abbreviation detector component to it.
nlp = spacy.load("en_core_sci_scibert")
nlp.add_pipe("abbreviation_detector")

<scispacy.abbreviation.AbbreviationDetector at 0x2b4a279336a0>

In [13]:
# `partial` is used below to bind the extra arguments of `expand_abbreviations`
# before passing it to `dataset.map`.
from functools import partial

In [14]:
# Read the pipe-delimited Meta-Inventory file; `na_filter=False` keeps every field
# as read instead of converting NA-like strings to NaN.
abbreviations = pd.read_csv(
    "/home/vs428/project/Abbreviations/Metainventory_Version1.0.0.csv",
    sep="|",
    na_filter=False,
)

In [15]:
# Keep only the inventory rows whose `Source` is one of these five sources.
kept_sources = [
    "Vanderbilt Clinic Notes",
    "Vanderbilt Discharge Sums",
    "Berman",
    "Stetson",
    "Columbia",
]
from_kept_source = abbreviations['Source'].isin(kept_sources)
med_abbvs = abbreviations[from_kept_source]

In [16]:
# Drop short forms that are spaCy stop words or one of a few hand-picked
# demographic/title tokens, using a single filter over the union of both sets.
excluded_sfs = set(nlp.Defaults.stop_words) | {"man", "woman", "old", "Mr.", "Ms.", "Mrs", "M", "F"}
is_excluded = med_abbvs['SF'].isin(excluded_sfs)
med_abbvs = med_abbvs[~is_excluded]

In [17]:
# Convert `Source` to a categorical column; later cells set its category order
# and sort by it.
med_abbvs = med_abbvs.astype({"Source":"category"})

In [18]:
# Order in which the `Source` categories are arranged (applied with `set_categories`
# in the next cell). NOTE(review): presumably a preference order between sources — confirm.
sorter = ["Vanderbilt Discharge Sums", "Vanderbilt Clinic Notes",  "Stetson", "Columbia", "Berman"]

In [19]:
# Arrange the `Source` categories in the order given by `sorter`.
# `set_categories(..., inplace=True)` is deprecated (and removed in pandas 2.x),
# so build the re-categorised column and rebind the frame instead of mutating in place.
med_abbvs = med_abbvs.assign(Source=med_abbvs["Source"].cat.set_categories(sorter))

  res = method(*args, **kwargs)


In [20]:
# Sort rows by `Source`; because the column is categorical, rows follow the category
# order set above (the displayed frame starts with "Vanderbilt Discharge Sums" and
# ends with "Berman").
med_abbvs = med_abbvs.sort_values(['Source'])

In [21]:
med_abbvs[med_abbvs["SF"] == "CVA"]

Unnamed: 0,GroupID,RecordID,SF,SFUI,NormSF,LF,LFUI,NormLF,Source,Modified
400041,G155037,R404085,CVA,S015761,cva,costovertebral angle,L062247,costovertebral angle,Vanderbilt Discharge Sums,
400039,G155036,R404083,CVA,S015761,cva,cerebrovascular accident,L055946,cerebrovascular accident,Vanderbilt Discharge Sums,
398834,G155036,R402858,CVA,S015761,cva,cerebral vascular accident,L055869,cerebral vascular accident,Vanderbilt Clinic Notes,
398833,G018226,R402857,CVA,S015761,cva,costovertebral angle tenderness,L062248,costovertebral angle tenderness,Vanderbilt Clinic Notes,


In [22]:
med_abbvs[med_abbvs['SF'] == 'yr']

Unnamed: 0,GroupID,RecordID,SF,SFUI,NormSF,LF,LFUI,NormLF,Source,Modified
391050,G181844,R394879,yr,S103595,yr,year,L169257,year,Berman,


In [23]:
# Distinct short forms (`SF` column) that remain after filtering; passed to
# `expand_abbreviations` in a later cell.
unq_sfs = med_abbvs['SF'].unique()

In [32]:
med_abbvs

Unnamed: 0,GroupID,RecordID,SF,SFUI,NormSF,LF,LFUI,NormLF,Source,Modified
400018,G154987,R404062,CRI,S015189,cri,chronic renal insufficiency,L057690,chronic renal insufficiency,Vanderbilt Discharge Sums,
400604,G156042,R404663,MAC,S042996,mac,mitral annular calcification,L112066,mitral annular calcification,Vanderbilt Discharge Sums,
400603,G156197,R404662,Mac,S045973,mac,macular,L106324,macular,Vanderbilt Discharge Sums,
400602,G045611,R404661,MAC,S042996,mac,multi-lumen access catheters,L113563,,Vanderbilt Discharge Sums,
400601,G156043,R404660,MAC,S042996,mac,mycobacterium avium complex,L114672,Mycobacterium avium complex,Vanderbilt Discharge Sums,
...,...,...,...,...,...,...,...,...,...,...
391057,G107533,R394886,zbp,S103626,zbp,zinc binding protein,L169489,zinc binding protein,Berman,
391058,G107534,R394887,zcl,S103627,zcl,zoonotic cutaneous leishmaniasis,L169666,zoonotic cutaneous leishmaniasis,Berman,
391059,G107535,R394888,zd,S103628,zd,zone drilling,L169635,,Berman,
391045,G107494,R394874,yohf,S103587,yohf,year old hispanic female,L169265,,Berman,


In [24]:
# Bind the spaCy pipeline and the abbreviation tables to `expand_abbreviations`,
# then apply the bound function to every example of every split.
# NOTE(review): the keyword is spelled `spacy_pip`; confirm it matches the helper's signature.
expand_with_inventory = partial(
    expand_abbreviations,
    spacy_pip=nlp,
    abbv_map=med_abbvs,
    unq_sfs=unq_sfs,
)
out_test = dataset.map(expand_with_inventory)

  0%|          | 0/4633 [00:00<?, ?ex/s]

  0%|          | 0/597 [00:00<?, ?ex/s]

In [25]:
# Print up to 10 randomly chosen training assessments before and after abbreviation
# expansion. A seeded generator keeps the sampled rows (and so this cell's output)
# reproducible, and sampling without replacement avoids showing the same row twice.
rng = np.random.default_rng(0)
n_train = len(out_test['train'])
rand_idxs = rng.choice(n_train, size=min(10, n_train), replace=False)
print(dataset['train'][rand_idxs]['Assessment'])
print(out_test['train'][rand_idxs]['Assessment'])


['42 y/o lady with CVID, HepC, Type 1 DM, distant IBD > 20 yrs ago, last\n   flare, recent cryptospordial infection presented to OSH with worsening\n   abdominal pain, nausea and vomitting.', '67F with stage 4 pancreatic cancer now p/w 4 days of nausea and\n   vomiting, inabaility to tolerate po, and fatigue, found to have\n   hyponatremia, hyerkalemia, and acute renal insufficiency\n   .', '87 yo male w/ HTN, dementia, past prostate CA presents w/ shock and\n   hypoxemic respiratory failure', '87 y/o F h/o CAD, COPD, HTN a/w sepsis in the setting of GNR bacteremia\n   and c/f obstructive pyelonephritis.', '43-year-old woman with pulmonary hypertension s/p L pneumonectomy due\n   to TB, also with history of OSA, presented with recurrent dyspnea.', '67yo woman transferred to the medical ICU with hypotension and\n   somnolence in the setting of IL-2 treatment for metastatic renal cell\n   carcinoma.', '[**Age over 90 **] yo female with severe dementia with poor functional status, multipl

In [26]:
# abbreviations = abbreviations[~abbreviations['SF'].isin(nlp.Defaults.stop_words)]
# abbreviations = abbreviations[~abbreviations['SF'].isin(["man", "woman", "old", "Mr.", "Ms.", "Mrs"])]

In [29]:
set(dataset['train']['Assessment'])

{'# Chronic atrial fibrillation: status post pulmonary vein isolation.\n   Telemetry shows normal sinus rythym with PACs, versus multifocal atrial\n   tachycardia followed by sinus pauses.  Pt failed prior atrial ablation\n   and trial without nodal blocking agents.\n   --continue propafenone and diltiazem as above\n   -- hold atenolol\n   -- restart Coumadin\n   -- omeprazole for esophageal protection after ablation',
 '# Respiratory alkalosis.  pH 7.68 in ED.  Compensating for metabolic\n   acidosis but hyperventilating in addition.  Improved on repeat; likely\n   anxiety related as patient hyperventilating during testing.  ASA\n   negative.\n   -          Ativan prn anxiety as per home regimen.',
 '#) LLE DVT - felt secondary to venous stasis from lymphatic obstruction\n   and malignancy. IVC filter placed due to extensive clot burden and\n   limitations in anticoagulation given recent hematemesis and\n   thrombocytopenia.  PLTs 69, HCT 27.9.\n   -  Original plan was to cont. hepari

In [31]:
# Write the abbreviation-expanded train and valid splits as Parquet files under
# `n2c2_data_dir`. The second call is the cell's last expression, so its return
# value is what the notebook displays.
out_test['train'].to_parquet(n2c2_data_dir + "train-abbv.parquet")
out_test['valid'].to_parquet(n2c2_data_dir + "valid-abbv.parquet")
# dataset['test'].to_parquet()

419974

## Read in Electra Model

In [None]:
from transformers import AutoTokenizer, AutoModel

# Load the tokenizer and the model weights from the same checkpoint.
electra_checkpoint = "xhlu/electra-medal"
electra_tokenizer = AutoTokenizer.from_pretrained(electra_checkpoint)
electra = AutoModel.from_pretrained(electra_checkpoint)

In [None]:
# electra = torch.hub.load("BruceWen120/medal", "electra")

## Add abbreviations to tokenizer vocab

In [None]:
# NOTE(review): `filtered_abbvs` is not defined in any visible cell — confirm where it comes from.
# Sort the de-duplicated, lower-cased abbreviations before adding them: new token ids
# follow the order of this list, and iterating a bare set of strings can differ
# between Python processes, which would make the ids unstable across sessions.
new_abbv_tokens = sorted(set(filtered_abbvs['abbreviation'].str.lower().tolist()))
electra_tokenizer.add_tokens(new_abbv_tokens)

4317

In [None]:
# Resize the model's token-embedding matrix to the tokenizer's current vocabulary
# size (which includes any tokens added to the tokenizer).
electra.resize_token_embeddings(len(electra_tokenizer))

Embedding(34839, 128)

## Identify abbreviations

In [None]:
# Tokenize the example sentences; the `input_ids` entry is read in the next cell.
tokenized = electra_tokenizer(sample_texts)

In [None]:
# For every tokenized sample, record the positions of tokens that match a known
# (lower-cased) abbreviation. The original line `abbv_locs[i for ...]` was a syntax
# error; each sample's position list is now appended to `abbv_locs`.
# NOTE(review): `filtered_abbvs` is not defined in any visible cell — confirm its origin.
abbv_vocab = set(filtered_abbvs['abbreviation'].str.lower().tolist())  # built once, O(1) lookups

abbv_locs = []
for input_ids in tokenized['input_ids']:
    tokens = electra_tokenizer.convert_ids_to_tokens(input_ids)
    abbv_locs.append([i for i, token in enumerate(tokens) if token in abbv_vocab])

# test = [electra_tokenizer.convert_ids_to_tokens(tokenized['input_ids'][x]) for x in range(len(tokenized['input_ids']))]

['[CLS]', '35', 'yo', 'f', 'with', 'sle', ',', 're', 'str', 'ict', '##ive', 'lung', 'd', 'ise', 'as', '##e', ',', 'cm', '(', 'ef', '15', '-', '20', '%', ')', 'pre', '##se', 'nts', 'with', 'sev', 'ere', 'lu', '##pus', 'i', 'nd', 'uc', '##ed', 'c', 'ard', 'iom', 'y', 'opa', 'thy', '.', '[SEP]']
[5, 7, 8, 9, 13, 14, 17, 19, 27, 29, 30, 34, 35, 38, 39, 41]


## Create MeDAL-esque Dataset

In [None]:
# for ELECTRA
class HuggingfaceInferenceDataset(torch.utils.data.Dataset):
    """Dataset over a DataFrame that tokenizes a batch of rows on access.

    Args:
        df: DataFrame providing the ``TEXT``, ``LOCATION`` and ``LABEL_NUM``
            columns (the columns read by ``__getitem__``).
        tokenizer: object exposing ``batch_encode_plus`` (e.g. a Hugging Face
            tokenizer).
        max_length: padding length handed to the tokenizer; also the
            whitespace word-count limit used to drop long rows.
        device: device the label tensor is moved to.
    """

    def __init__(self, df, tokenizer, max_length=512, device='cpu'):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.device = device
        self.df = df

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return self.df.shape[0]

    def __getitem__(self, idxs):
        """Tokenize the rows at integer positions ``idxs``.

        Rows whose ``TEXT`` has ``max_length`` or more whitespace-separated
        words are dropped from the batch before tokenization.

        Args:
            idxs: iterable of integer row positions (used with ``df.iloc``).

        Returns:
            Tuple ``(input_ids, locations, labels)``: the tokenizer's
            ``input_ids`` for the kept rows, a tensor built from their
            ``LOCATION`` values, and a tensor of their ``LABEL_NUM`` values
            moved to ``self.device``.
        """
        # Local import: `compress` is not imported in the visible notebook cells,
        # so the original code raised NameError here.
        from itertools import compress

        batch_df = self.df.iloc[idxs]
        # NOTE(review): this filter counts whitespace-separated words, not tokens, and no
        # truncation is requested below, so an encoded row can still exceed `max_length`.
        short_enough = batch_df['TEXT'].apply(
            lambda string: len(string.split()) < self.max_length
        ).to_list()
        idxs = list(compress(idxs, short_enough))
        batch_df = self.df.iloc[idxs]

        locs = batch_df['LOCATION'].values
        labels = torch.tensor(batch_df['LABEL_NUM'].values).to(self.device)
        # `padding="max_length"` is the current spelling of the deprecated
        # `pad_to_max_length=True`.
        tokenized = self.tokenizer.batch_encode_plus(
            batch_df['TEXT'].tolist(),
            max_length=self.max_length,
            padding="max_length",
        )['input_ids']
        return tokenized, torch.tensor(locs), labels

# SciSpacy

In [15]:
import scispacy
import spacy
from scispacy.abbreviation import AbbreviationDetector


In [16]:
# Load the `en_core_sci_scibert` pipeline again and add the abbreviation detector to the
# fresh pipeline. NOTE(review): this repeats an earlier cell of the notebook.
nlp = spacy.load("en_core_sci_scibert")
nlp.add_pipe("abbreviation_detector")

<scispacy.abbreviation.AbbreviationDetector at 0x2b710ea43a00>

In [None]:
# Toy text that introduces "SBMA" in parentheses right after its long form and then
# uses the short form on its own. The backslash continuations keep the leading
# spaces of each continued line inside the string.
doc = nlp("Spinal and bulbar muscular atrophy (SBMA) is an \
           inherited motor neuron disease caused by the expansion \
           of a polyglutamine tract within the androgen receptor AR. \
           SBMA can be caused by this easily.")

In [None]:
# Header row, then one line per detected abbreviation: the short form, its token
# span (start, end) and the long form the detector resolved.
print("Abbreviation", "\t", "Definition")
detected_abbreviations = doc._.abbreviations
for abrv in detected_abbreviations:
    print(f"{abrv} \t ({abrv.start}, {abrv.end}) {abrv._.long_form}")

Abbreviation 	 Definition
SBMA 	 (31, 32) Spinal and bulbar muscular atrophy
SBMA 	 (6, 7) Spinal and bulbar muscular atrophy


In [None]:
# Pick one random position in the training split (unseeded, so it differs per run).
idx = np.random.randint(len(dataset['train']))

In [None]:
# Take the sampled training example and replace its assessment with a hand-edited
# text in which several abbreviations are wrapped in parentheses.
doc = dataset['train'][idx]
doc['Assessment'] = 'This is a 58 yom with history of (ESRD) on (HD) since [**3-8**], then Peritoneal\n   Dialysis since [**9-10**], (DM2), (HTN), Diastolic (CHF), history of (MSSA)\n   peritonitis [**6-11**] who presents to ED with new onset abdominal pain since\n   this morning, associated with nausea/vomiting and fevers, admittted\n   with bacterial peritonitis.'
print(f"Assessment: {doc['Assessment']}\nPlan: {doc['Plan Subsection']}")

# Run both sections through the pipeline.
parsed_assess = nlp(doc['Assessment'])
parsed_plan = nlp(doc['Plan Subsection'])

# List the detected abbreviations, assessment first and plan second.
print("Abbreviation", "\t", "Definition")
for parsed_section in (parsed_assess, parsed_plan):
    for abrv in parsed_section._.abbreviations:
        print(f"{abrv} \t ({abrv.start}, {abrv.end}) {abrv._.long_form}")


Assessment: This is a 58 yom with history of (ESRD) on (HD) since [**3-8**], then Peritoneal
   Dialysis since [**9-10**], (DM2), (HTN), Diastolic (CHF), history of (MSSA)
   peritonitis [**6-11**] who presents to ED with new onset abdominal pain since
   this morning, associated with nausea/vomiting and fevers, admittted
   with bacterial peritonitis.
Plan: Bacterial Peritonitis:  Patient presents with symptoms on abdominal
   pain, fevers, N/V and paracentesis consistent with bacterial
   peritonitis.  He has been started on Vanco/Ceftaz for treatment of
   infection.  Renal aware and is following, will continue with PD per
   renal recs.  Patient meets sepsis protocol, however CVP elevated.
   - cont Vanco 1gm daily and Ceftaz 1gm daily
   - Vanco level in AM, dose accordingly
   - Bolus with 250cc IVF for hypotension, mental status changes
   - Add Levophed if needed
   - f/u renal recs
   - Morphine 2-4mg IV q6h PRN pain
Abbreviation 	 Definition


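**Note:** This approach did not work. The abbreviation detector only resolves an abbreviation when its long form is defined in the same text, as in "coronary artery disease (CAD)". The assessment above never spells out the long forms, so nothing was detected, even with the short forms wrapped in parentheses.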