In [1]:
!pip -q install openai langchain  tiktoken  pinecone-client openpyxl  sentence-transformers pinecone-text 

In [2]:
import time
import json
import os
import re
from collections import defaultdict 
from collections import Counter

import numpy as np
from numpy.linalg import norm
import pandas as pd

In [3]:
import openai
import pinecone
from pinecone_text.sparse import BM25Encoder

from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.output_parsers import ResponseSchema
from langchain.output_parsers import StructuredOutputParser

from langchain.vectorstores import Pinecone
from langchain.chains import VectorDBQAWithSourcesChain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import SimpleSequentialChain
from langchain.chains import LLMRequestsChain, LLMChain

import tiktoken
from sentence_transformers import SentenceTransformer

  from tqdm.autonotebook import tqdm
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\zelalemgero\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping tokenizers\punkt.zip.
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\zelalemgero\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping corpora\stopwords.zip.


## OpenAI API key configuration

In [5]:
# Azure OpenAI client configuration.
# Credentials are taken from the environment when present so that real keys never
# have to be written into the notebook; otherwise the placeholders below are used
# and must be replaced before any model call is made.
openai.api_key = os.environ.get("OPENAI_API_KEY", "YOUR_API_KEY")
openai.api_base = os.environ.get("OPENAI_API_BASE", "YOUR_API_BASE")
openai.api_type = 'azure'
openai.api_version ='2023-03-15-preview'

## Read the preprocessed MIMIC-III file

In [None]:
# Load the preprocessed sample; the placeholder path must be replaced with the local pickle file.
# NOTE(review): unpickling can execute arbitrary code — only load files from a trusted source.
filename = '/PATH/TO/YOUR/FILE/mimiciii_sampled.pkl'
sampled_mimiciii = pd.read_pickle(filename)

In [None]:
# Keep only the columns used downstream: note id, note text and the gold labels.
df = sampled_mimiciii[['_id','text','target']]
# note id -> gold labels; read by calculate_metrics_9
mimiciii_dict = dict(zip(df['_id'],df['target']))

# Kept so that any other cell still referencing it keeps working; it is no longer used to fit BM25.
token_text = df.iloc[1]['text']
chunks = 800

# NOTE: `tiktoken_len` is defined in the tokenizer cell further down the notebook;
# that cell has to be executed before this one on a fresh kernel.
# The splitter does not depend on the row, so it is built once instead of once per note.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunks, chunk_overlap=50, length_function=tiktoken_len, separators=["\n\n","\n", " ",""])

new_splits = []
for idx,row in df.iterrows():
  # flatten new lines and drop the '[**', '**]', '*', '--' and '__' markers
  txt = row['text'].replace('\n', ' ').replace('[**', '').replace('**]','').replace('*', '').replace('--','').replace('__','')
  # collapse whitespace runs (raw string avoids the invalid '\s' escape warning)
  txt = re.sub(r'\s{2,}', ' ', txt)
  splits = text_splitter.split_text(txt)
  # one row per chunk: [chunk text, source note id, placeholder code]
  new_splits.extend([[x,row['_id'],'NA'] for x in splits])
  # stop once roughly 10k chunks were collected (the note that crosses the limit is kept whole)
  if len(new_splits) > 10000:
    break

df2 = pd.DataFrame(new_splits, columns =['text', 'index','code'])
# ids of the notes that actually made it into the chunk table
test_keys = list(set(df2['index']))

bm25 = BM25Encoder()
# BM25 statistics must be fitted on a corpus, i.e. a list of documents. Passing a single
# string makes the encoder iterate over its characters, so fit on the chunks that get indexed.
bm25.fit(df2['text'].tolist())
embd_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

## Initialize a Pinecone vector database and insert the chunked MIMIC notes

In [None]:
# Connect to Pinecone.
# NOTE(review): `pinecone_key` and `pinecone_env` are not defined in any visible cell —
# they must be set before this cell is run.
pinecone.init(api_key=pinecone_key, environment=pinecone_env)

index_name = "mimic-search"

# check if the mimic-search index exists
if index_name not in pinecone.list_indexes():
    # create the index if it does not exist
    # NOTE(review): dimension=384 presumably matches the output size of `embd_model` — confirm
    pinecone.create_index(
        index_name,
        dimension=384,
        metric="dotproduct",
        pod_type="s1"
    )
# connect to the mimic-search index (created above when it was missing)
index = pinecone.Index(index_name)

In [None]:
from tqdm.auto import tqdm
def insert_text(df):
  """Embed the rows of `df` and upsert them into the Pinecone `index` in batches.

  Every row is stored with a dense vector (from `embd_model`), a sparse vector
  (from `bm25`) and the complete row as metadata.

  Args:
    df: DataFrame with a "text" column; all columns of a row become its metadata.

  Returns:
    The result of `index.describe_index_stats()` after the last upsert, so the
    calling cell can display it.
  """
  batch_size = 8

  for i in tqdm(range(0, len(df), batch_size)):
    # find end of batch
    i_end = min(i+batch_size, len(df))
    # extract batch
    batch = df.iloc[i:i_end]
    texts = batch["text"].tolist()
    # dense vectors are converted to plain lists, as encode_query does
    dense_vecs = embd_model.encode(texts).tolist()
    sparse_vecs = bm25.encode_documents(texts)
    # one metadata dict per row, in the same order as the vectors
    metadata = batch.to_dict(orient="records")

    # The id format ('note' + position in batch + batch offset) is unchanged so that
    # re-running overwrites previously inserted vectors instead of duplicating them.
    upserts = [
      {
        'id': 'note' + str(pos) + str(i),
        'sparse_values': sparse,
        'values': dense,
        'metadata': meta,
      }
      for pos, (sparse, dense, meta) in enumerate(zip(sparse_vecs, dense_vecs, metadata))
    ]

    index.upsert(vectors=upserts)

  return index.describe_index_stats()

In [None]:
# insert the chunked notes (df2) into the vector db
insert_text(df2)

In [6]:
def encode_query(text: str):
    """Turn a query string into the pair of vectors used for hybrid search.

    Returns:
        (dense_vec, sparse_vec): the `embd_model` embedding converted to a list,
        and the `bm25` query encoding of the same text.
    """
    # dense part first, then the sparse part
    return embd_model.encode(text).tolist(), bm25.encode_queries(text)



def hybrid_scale(dense, sparse, alpha: float):
    """Weight a dense/sparse vector pair for hybrid search.

    Dense values are multiplied by `alpha` and sparse values by `1 - alpha`.

    Args:
        dense: sequence of dense vector values.
        sparse: mapping with 'indices' and 'values' entries.
        alpha: weight of the dense part.

    Returns:
        (scaled_dense, scaled_sparse); the sparse indices are passed through as is.

    Raises:
        ValueError: if alpha is below 0 or above 1.
    """
    if alpha > 1 or alpha < 0:
        raise ValueError("Alpha must be between 0 and 1")

    scaled_values = []
    for v in sparse['values']:
        scaled_values.append(v * (1 - alpha))
    scaled_sparse = dict(indices=sparse['indices'], values=scaled_values)

    scaled_dense = []
    for v in dense:
        scaled_dense.append(v * alpha)

    return scaled_dense, scaled_sparse

def search_pinecone(query, alpha=1, top_k=150):
    """Run a hybrid (dense + sparse) query against the Pinecone `index`.

    Args:
        query: either `[text]`, `[text, note_id]` or a plain non-empty string.
            When a note id is given, only vectors whose metadata field "index"
            equals it are searched.
        alpha: weight of the dense vector, passed to `hybrid_scale`.
        top_k: number of matches requested from the index.

    Returns:
        dict with
          "index score": best score per metadata "index" value (ICD entries are
                         keyed as 'icd_<code>'),
          "metadata":    metadata of every match, in the order returned.

    Raises:
        ValueError: if `query` is empty or has more than two elements.
    """
    # A bare string used to be dispatched on its character count; treat it as the query text.
    if isinstance(query, str):
        query = [query] if query else []

    if len(query) == 2:
        text = query[0]
        filter_criteria = {"index": {"$eq": query[1]}}
    elif len(query) == 1:
        text = query[0]
        filter_criteria = None
    else:
        raise ValueError('Please input a query')

    dense, sparse = encode_query(text)
    hdense, hsparse = hybrid_scale(dense, sparse, alpha=alpha)

    xc = index.query(top_k=top_k, vector=hdense, sparse_vector=hsparse, include_metadata=True, filter=filter_criteria)

    indx_score = {}
    for match in xc['matches']:
        meta = match['metadata']
        idx = meta['index']
        if idx == 'icd':
            idx = 'icd_' + meta['code']
        score = match['score']
        # Several chunks can share one id: keep the best score rather than
        # whichever match happens to come last.
        if idx not in indx_score or score > indx_score[idx]:
            indx_score[idx] = score

    r = [match["metadata"] for match in xc["matches"]]
    return {"index score": indx_score, "metadata": r}

#### tiktoken tokenizer used to count and limit the number of tokens passed to the various OpenAI models

In [7]:
# Single shared encoder; every token count in this notebook goes through it.
tik_tokenizer = tiktoken.get_encoding('p50k_base')
def tiktoken_len(text):
  """Return how many tokens `tik_tokenizer` produces for `text`.

  NOTE(review): `disallowed_special=()` presumably lets special-token strings in
  the text be encoded as ordinary text instead of raising — confirm.
  """
  return len(tik_tokenizer.encode(text, disallowed_special = ()))

### Create the schemas and prompts using LangChain's ResponseSchema module

In [8]:
# --- Disease extraction prompt -------------------------------------------------
# The parsed answer is expected under the key 'diseases'.
# `response_schemas`, `output_parser`, `format_instructions` and `template_string`
# are rebound by every prompt section below; each prompt keeps the parser and
# format instructions that were current when it was built.
disease_schema = ResponseSchema(name = 'diseases', description="this is focuses on medical history of the patient.It involves extracting information about diseases, disorders, or medical conditions that have affected the patient")
response_schemas = [disease_schema]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
template_string = """You are an expert in information extraction from clinical notes. Given a clinical note, you extract diseases/disorders/procedures and return all as a single python list of strings like ["disease","disease","procedure"]. Here is the note: {text_note}. {format_instructions}
"""
# `format_instructions` is pre-filled as a partial variable, so callers only supply `text_note`.
disease_prompt = ChatPromptTemplate(messages=[HumanMessagePromptTemplate.from_template(template_string)],
                            input_variables=['text_note'],
                            partial_variables={"format_instructions":format_instructions},
                            output_parser=output_parser
)

# --- Evidence / fact-checking prompt ------------------------------------------
# The parsed answer is expected under the key 'evidence' as a list of
# [disease, evidence sentence, "True"/"False"] triples (see the example below).
# NOTE(review): the {diseases} / {text_input} placeholders inside this description are
# presumably sent to the model as literal text, since only the template string is
# formatted — confirm.
evidence_schema = ResponseSchema(name = 'evidence', description="""You are an expert fact checker. Your job is to fact check the {diseases} provided based on the full input text given and return a python list of lists. Individual lists have
                                            evidence text span next to each of the disease. Your job is to determine whether each of the values in the list {diseases} is correct or not. Find the span of text (a one sentence) from the 
                                            {text_input} as an evidence.  Add the text span evidence that makes the answer True or False. Dont add any extra text.
                                             Example output is here ###
                                              [
                                             ["ESRD",  "ESRD secondary to hypertensive nephrosclerosis ","True"],
                                            ["DM", " DM, on glipizide at home","True"],
                                            ["Hypertension", "high blood pressure ruled out", "False"]
                                                        ]                               """)
response_schemas = [evidence_schema]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
# The template now interpolates {diseases}: it was declared as an input variable but never
# appeared in the template, so the list to be verified was not part of the prompt.
template_string = """You are an expert fact checker. Given a clinical note and list of diseases and procedures, you verify by extracting evidence and return all as a single python list of strings. Here is the note: {text_note} and list of diseases and procedures: {diseases}. {format_instructions}
"""
evidence_prompt = ChatPromptTemplate(messages=[HumanMessagePromptTemplate.from_template(template_string)],
                            input_variables=['text_note','diseases'],
                            partial_variables={"format_instructions":format_instructions},
                            output_parser=output_parser
)




# --- Entailment prompt ----------------------------------------------------------
# The parsed answer is expected under the key 'entail'.
# Only `text_note` is supplied by the caller; `does_ential` passes the list of
# (disease, evidence text) pairs in that slot.
entail_schema = ResponseSchema(name = 'entail', description="this focuses on checking wether the disease/procedure can be entailed from the text fragment in the provided {text_note} in the format [disease, text fragment] ")
response_schemas = [entail_schema]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
template_string = """ You are an expert textual entailment agent. Textual Entailment(aka Natural Language inference) is directional relation between text fragments. The relation holds whenever the truth of one text fragment follows from another text. Your job is to check wether the disease can be entailed from the text in the provided {text_note} in the format (disease, text fragment).Dont extrapolate, entail only based on the provided text fragment. If the disease can be entailed from the text fragment based on the {text_note}, add the diseases to the output list. Return the disease list of all diseases that can be entailed. All values must be in a string format inside double quotations . Here is the note: {text_note}. {format_instructions}
"""
entail_prompt = ChatPromptTemplate(messages=[HumanMessagePromptTemplate.from_template(template_string)],
                            input_variables=['text_note'],
                            partial_variables={"format_instructions":format_instructions},
                            output_parser=output_parser
)

# --- Omissions prompt -----------------------------------------------------------
# The parsed answer is expected under the key 'omissions'.
# Callers supply the note (`text_note`) and the list of already known items (`diseases`).
omissions_schema = ResponseSchema(name = 'omissions', description="this focuses on finding all the missing disease/procedure from the provided {text_note} that are not in the list {diseases} ")
response_schemas = [omissions_schema]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
template_string = """ You are an expert disease inspector. Your job is to find all possible diseases/procedures in the given {text_note} exhaustively and return in a python list of strings. Your response should be in the form of 
                                            python list with all the diseases/procedures that you can verify do exist in the {text_note}. Make sure to return the disease/procedures exhaustively.Dont include a disease/procedure if it is in the {diseases} list. Return only unique diseases/procedures. All diseases/procedures in the list must be in a string format.  Here is the note: {text_note} and the list of diseases/procedures {diseases}. {format_instructions}
"""
omissions_prompt = ChatPromptTemplate(messages=[HumanMessagePromptTemplate.from_template(template_string)],
                            input_variables=['text_note','diseases'],
                            partial_variables={"format_instructions":format_instructions},
                            output_parser=output_parser
)


# --- ICD 9 coding prompt --------------------------------------------------------
# The parsed answer is expected under the key 'icd'.
# Callers supply the note (`text_note`) and the list of diseases/procedures (`diseases`).
icd_schema = ResponseSchema(name = 'icd', description="this focuses on assining ICD 9 codes for all the diseases/procedures listed in the {diseases} based on the context in {text_note} ")
response_schemas = [icd_schema]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
template_string = """You are an expert clinical data encoder.  your job is to extract ICD 9 codes for all the diseases/procedures listed in {diseases} based ont the context in {text_note}. Then return a python list of strings containing all the ICD 9 codes you assigned.{format_instructions}
"""
icd_prompt = ChatPromptTemplate(messages=[HumanMessagePromptTemplate.from_template(template_string)],
                            input_variables=['text_note','diseases'],
                            partial_variables={"format_instructions":format_instructions},
                            output_parser=output_parser
)

# --- ICD 10 coding prompt -------------------------------------------------------
# Same structure as the ICD 9 prompt; the parsed answer also uses the key 'icd'.
# After this section the module-level `output_parser` / `format_instructions` refer to
# this ICD 10 schema.
icd_schema_10 = ResponseSchema(name = 'icd', description="this focuses on assining ICD 10 codes for all the diseases/procedures listed in the {diseases} based on the context in {text_note} ")
response_schemas = [icd_schema_10]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
template_string = """You are an expert clinical data encoder.  your job is to extract ICD 10 codes for all the diseases/procedures listed in {diseases} based ont the context in {text_note}. Then return a python list of strings containing all the ICD 10 codes you assigned.{format_instructions}
"""
icd_prompt_10 = ChatPromptTemplate(messages=[HumanMessagePromptTemplate.from_template(template_string)],
                            input_variables=['text_note','diseases'],
                            partial_variables={"format_instructions":format_instructions},
                            output_parser=output_parser
)

In [11]:
# Retry policy shared by all LLM helpers below.
_MAX_ATTEMPTS = 3
_RATE_LIMIT_WAIT_SECONDS = 5


def _run_chain(prompt, result_key=None, **inputs):
  """Run `prompt` through the global `llm` and return the parsed output, with retries.

  Args:
    prompt: prompt template (with an output parser) to execute.
    result_key: when given, return only this entry of the parsed output; a missing
      key counts as a failed attempt and is retried.
    **inputs: values for the prompt's input variables.

  Returns:
    The parsed output (or its `result_key` entry), or None when every attempt failed.
  """
  for _ in range(_MAX_ATTEMPTS):
    try:
      chain = LLMChain(llm=llm,prompt=prompt)
      parsed = chain.predict_and_parse(**inputs)
      return parsed if result_key is None else parsed[result_key]
    except openai.error.RateLimitError:
      # back off before the next attempt
      time.sleep(_RATE_LIMIT_WAIT_SECONDS)
    except Exception as e:
      # best effort: report the problem (e.g. an unparsable answer) and try again
      print(e)
  return None


def get_diseases(note):
  """Extract the diseases/procedures mentioned in `note`; None if all attempts fail."""
  return _run_chain(disease_prompt, 'diseases', text_note = note)


def get_evidence(note,diseases):
  """Ask for evidence for each entry of `diseases` in `note`.

  Returns the whole parsed output (the triples are under its 'evidence' key), or None.
  """
  return _run_chain(evidence_prompt, None, text_note = note, diseases = diseases)


def does_ential(evidence):
  """Return the diseases that are entailed by their supporting evidence.

  Args:
    evidence: output of `get_evidence` — a dict whose 'evidence' entry is a list of
      [disease, text, verdict] triples. Only triples with a true verdict are checked.

  Returns:
    The parsed 'entail' entry, [] when no triple is marked true, or None when the
    evidence is missing/malformed or every attempt failed.
  """
  try:
    # str() so that a boolean verdict is handled like the strings "True"/"False"
    evidence_list = [(x[0],x[1]) for x in evidence['evidence'] if str(x[2]).lower()=='true']
  except (TypeError, KeyError, IndexError) as e:
    # evidence is None / lacks the key / has short entries: nothing to verify
    print(e)
    return None
  if not evidence_list:
    # nothing was marked true, so nothing can be entailed; skip the model call
    return []
  return _run_chain(entail_prompt, 'entail', text_note = evidence_list)


def find_omissions(note, verified):
  """Return diseases/procedures found in `note` that are not in `verified`; None on failure."""
  return _run_chain(omissions_prompt, 'omissions', text_note = note, diseases = verified)


def get_icds(note, diseases):
  """Assign ICD 9 codes to `diseases`; returns the whole parsed output (codes under 'icd') or None."""
  return _run_chain(icd_prompt, None, text_note = note, diseases = diseases)


def get_icds_10(note, diseases):
  """Assign ICD 10 codes to `diseases`; returns the whole parsed output (codes under 'icd') or None."""
  return _run_chain(icd_prompt_10, None, text_note = note, diseases = diseases)

def extract_notes(k):
  """Retrieve the stored chunks of note `k` and return them as one text.

  Runs a fixed clinical query through `search_pinecone`, restricted to vectors whose
  metadata "index" equals `k`, with alpha 0.7.

  Args:
    k: note id (value of the "index" metadata field).

  Returns:
    The retrieved chunk texts joined by single spaces, capped at 400000 characters.
  """
  # The note id is the parameter `k`; the undefined global `indx` was used here before.
  query = [ "disease, patient conditions, Diagnosis, treatment, examination, laboratory tests, imaging, medical problem, medical condition past medical history, what are the diseases/conditions the patient had, procedures ?",k]
  result = search_pinecone(query,0.7)
  # join with a space so the last and first words of neighbouring chunks are not glued together
  retrieved_text = " ".join([x['text'] for x in result['metadata']])[:400000]

  return retrieved_text

def calculate_metrics_9(output, icd):
    """Print and return micro-averaged metrics of predicted codes against `mimiciii_dict`.

    Only codes contained in `icd` are considered. Notes are skipped when their
    prediction has no 'icd' entry or when none of their gold codes is in `icd`.

    Args:
        output: mapping note id -> prediction dict with the predicted codes under 'icd'.
        icd: iterable of the codes to evaluate on.

    Returns:
        dict with precision 'P', recall 'R', 'F1' and accuracy 'A'. A metric whose
        denominator is zero is reported as 0.0 instead of raising ZeroDivisionError.
    """
    label_set = set(icd)
    tp, tn, fn, fp, total_pred, total_y = 0, 0, 0, 0, 0, 0

    for k, vv in output.items():
        try:
            new_v = vv['icd']
        except (KeyError, TypeError, IndexError):
            # no usable prediction for this note (e.g. [] or None was stored)
            continue
        y_true = {str(x) for x in mimiciii_dict[k] if str(x) in label_set}
        if not y_true:
            continue
        # str() guards against codes returned as numbers; a missing list counts as no prediction
        v = {str(xx).strip() for xx in (new_v or []) if str(xx).strip() in label_set}

        neg = label_set - y_true
        potential_neg = label_set - v

        tn += len(neg & potential_neg)
        fn += len(potential_neg & y_true)
        fp += len(v & neg)
        total_y += len(y_true)
        tp += len(v & y_true)
        total_pred += len(v)

    P = tp / total_pred if total_pred else 0.0
    R = tp / total_y if total_y else 0.0
    decisions = tp + tn + fn + fp
    A = (tp + tn) / decisions if decisions else 0.0
    F1 = 2*((P*R)/(P+R)) if (P + R) else 0.0

    print("total correct:", tp, "total pred:",total_pred, "total true:", total_y)
    print('precision is ',P)
    print('recall is ',R)
    print('F1 is ',F1)
    print('Accuracy is ',A)

    return {'P': P, 'R': R, 'F1': F1, 'A': A}


### Top-10 and top-50 most common ICD-9 codes used for evaluation

In [None]:
# Label sets for evaluation: calculate_metrics_9 ignores every code that is not in the
# list it is given. Codes are strings because they are compared with str() values.
icd_top10 = ['401.9','272.4','530.81','250.00','428.0','427.31','414.01','518.81','599.0','584.9']
icd_top50 = ['401.9','38.93','428.0','427.31','414.01','96.04','96.6','584.9','250.00','96.71','272.4','518.81','99.04','39.61','599.0','530.81','96.72','272.0', '285.9','88.56','244.9','486','38.91', '285.1','36.15','276.2','496','99.15','995.92','V58.61','507.0','038.9','88.72','585.9','403.90','311','305.1','37.22','412','33.24','39.95','287.5','410.71','276.1','V45.81','424.0', '45.13','V15.82','511.9','37.23']


In [None]:
# output variables for storing the ablation runs (note id -> predicted codes)
base_icd = defaultdict(list)       # codes for the diseases extracted directly
prune_icd = defaultdict(list)      # codes after evidence + entailment filtering
omissions_icd = defaultdict(list)  # codes after adding the omitted diseases
op_icd = defaultdict(list)         # omissions followed by evidence + entailment filtering

### Initialize the OpenAI models

In [None]:
# One client per model under comparison; all use temperature 0.1 and at most 500 completion tokens.
# The first two are chat models (ChatOpenAI), the third is a completion model (OpenAI).
llm_4 = ChatOpenAI(model_name = 'gpt-4-32k-0314', openai_api_key=openai.api_key,temperature=0.1,max_tokens=500, deployment_id='gpt-4-32k-0314')
llm_35 = ChatOpenAI(model_name = 'gpt-35-turbo-0301', openai_api_key=openai.api_key,temperature=0.1,max_tokens=500,deployment_id='gpt-35-turbo-0301')
llm_3 = OpenAI(model_name = 'text-davinci-003', openai_api_key=openai.api_key,temperature=0.1,max_tokens=500, deployment_id='text-davinci-003')

# models are evaluated in this order by the coding loop
llm_list = [llm_4,llm_35,llm_3]

In [None]:
# MIMIC III ICD-9 coding
# For every model: run the four ablations (base, prune, omissions, omissions + prune)
# on every note, then print the metrics on the top-10 and top-50 code sets.
# The four result dicts are shared between models, so after the loop they hold the
# predictions of the last model in `llm_list`.

# Largest note size (in tokens) passed to each model; unlisted models use the default.
# NOTE(review): confirm that these limits fit the context window of the deployed models.
MODEL_MAX_TOKENS = {'text-davinci-003': 3000, 'gpt-35-turbo-0301': 7000}
DEFAULT_MAX_TOKENS = 30000

for model in llm_list:
  # the get_* helpers read the global `llm`, so it is rebound for every model
  llm = model
  max_tokens = MODEL_MAX_TOKENS.get(llm.model_name, DEFAULT_MAX_TOKENS)

  for k in test_keys:
    diseases,entail,all_omissions = [],[],[]
    note = extract_notes(k)

    # trim the note from the end until it fits the model's budget
    # (3 characters are cut per surplus token, then the length is re-measured)
    num_tokens = tiktoken_len(note)
    while num_tokens > max_tokens:
      delta =  num_tokens - max_tokens
      note = note[:-delta*3]
      num_tokens = tiktoken_len(note)

    #Extract diseases
    diseases = get_diseases(note)
    base_icd[k] = get_icds(note, diseases) if diseases else []

    #Generate Evidence (skipped when there is nothing to verify)
    evidence = get_evidence(note,diseases) if diseases else None
    entail = does_ential(evidence) if evidence else None
    prune_icd[k] = get_icds(note,entail) if entail else []

    #Find omissions: ask up to three times for items missing from the running list
    all_omissions = diseases.copy() if diseases else []
    for _ in range(3):
      omissions = find_omissions(note,all_omissions)
      if not omissions:
        break
      all_omissions.extend(omissions)
    omissions_icd[k] = get_icds(note, all_omissions) if all_omissions else []

    # Verify found omissions
    evidence = get_evidence(note,all_omissions) if all_omissions else None
    entail = does_ential(evidence) if evidence else None
    op_icd[k] = get_icds(note,  entail) if entail else []

  # The top-10 "base" row used to be computed from omissions_icd; every label now
  # reports its own result dict. calculate_metrics_9 prints its metrics itself, so
  # the label is printed first instead of being printed after them together with
  # the function's return value.
  ablations = (
    ("base ", base_icd),
    ("prune ", prune_icd),
    ("omissions ", omissions_icd),
    (" omissions + prune", op_icd),
  )
  for title, icd in (("Top 10", icd_top10), ("Top 50", icd_top50)):
    print(f'{title} Extractions Using {llm.model_name} model:')
    for label, predictions in ablations:
      print(label)
      calculate_metrics_9(predictions, icd)

In [None]:
# Merge the four ablation outputs into one record per note id (keys taken from op_icd).
mimiciii_icd_pred = dict()
for kk, vv in op_icd.items():
  mimiciii_icd_pred.update({
    kk: dict(base=base_icd[kk], prune=prune_icd[kk], omissions=omissions_icd[kk], op=vv)
  })

## Save the predictions to a file

In [None]:
# Write the merged predictions as pretty-printed JSON; placeholder path, replace before running.
filename = "YOUR/OUTUT/PATH"
out_dir = os.path.dirname(filename)
# a bare file name has no directory part, and os.makedirs('') would raise
if out_dir:
  os.makedirs(out_dir, exist_ok = True)
with open(filename,'w') as outfile:
  # default=str serialises any value json cannot handle natively
  json.dump(mimiciii_icd_pred, outfile, indent=4, sort_keys = True, default = str)