# LLM Powered Medical Case Sheet Ingestion
## Outline
1. Data Cleansing
2. Prompt Definition
3. Entity & Relationship Extraction
4. Neo4j Cypher Generation
5. Data Ingestion

In [1]:
from google.colab import auth as google_auth
google_auth.authenticate_user()

In [2]:
%%capture
%pip install graphdatascience
%pip install python-dotenv
%pip install retry
!gsutil cp gs://vertex_sdk_llm_private_releases/SDK/google_cloud_aiplatform-1.25.dev20230413+language.models-py2.py3-none-any.whl .
%pip install ./google_cloud_aiplatform-1.25.dev20230413+language.models-py2.py3-none-any.whl "shapely<2.0.0" --force-reinstall

In [3]:
import os
from retry import retry
import re
from string import Template
import json 
import ast
import time
import pandas as pd
from graphdatascience import GraphDataScience
import glob
from timeit import default_timer as timer
from dotenv import load_dotenv

from google.cloud import aiplatform
from google.cloud.aiplatform.private_preview.language_models import ChatModel, InputOutputTextPair

The cell below shows how to instruction-tune a text-bison model. The chat-bison model, which we use in the ingestion process, is currently tunable. The code below is only meant as an example of fine-tuning.

In [4]:
from typing import Union

import pandas as pd

from google.cloud.aiplatform.private_preview.language_models import TextGenerationModel
from google.cloud import aiplatform


def tune_model(
    project_id: str,
    location: str,
    training_data: Union[pd.DataFrame, str],
    train_steps: int = 10,
):
  """Tune a new model based on prompt-response data.

  "training_data" can be either the GCS URI of a file formatted in JSONL format
  (for example: training_data=f'gs://{bucket}/{filename}.jsonl'), or a pandas
  DataFrame. Each training example should be a JSONL record with two keys, for
  example:
    {
      "input_text": <input prompt>,
      "output_text": <associated output>
    },
  or the pandas DataFrame should contain two columns:
    ['input_text', 'output_text']
  with rows for each training example.

  Args:
    project_id: GCP Project ID, used to initialize aiplatform
    location: GCP Region, used to initialize aiplatform
    training_data: GCS URI of training file or pandas dataframe of training data
    train_steps: Number of training steps to use when tuning the model.
  """
  aiplatform.init(project=project_id, location=location)
  model = TextGenerationModel.from_pretrained("text-bison-001")

  model.tune_model(
      training_data=training_data,
      train_steps=train_steps,
      tuning_job_location="europe-west4",
      tuned_model_location="us-central1",
  )

  # Test the tuned model:
  print(
      model.predict("Tell me some ideas combining VR and fitness:")
  )

In [None]:
tune_model('neo4jbusinessdev', 'us-central1',
    'gs://gs_vertex_ai/eng2cypher/eng2cypher.jsonl',10
           )

In [12]:
from google.cloud.aiplatform.private_preview.language_models import TextGenerationModel
from google.cloud import aiplatform


def list_tuned_models(project_id, location):
  """List tuned models."""
  aiplatform.init(project=project_id, location=location)
  model = TextGenerationModel.from_pretrained("text-bison-001")
  tuned_model_names = model.list_tuned_model_names()
  print(tuned_model_names)

## Data Cleansing

First, let's define a function that cleans the input data. For the sake of simplicity, the cleaning is minimal. The corpus refers to figures such as scan images. We don't have them, so we remove any such references.

In [5]:
def clean_text(text):
  clean = "\n".join([row for row in text.split("\n")])
  clean = re.sub(r'\(fig[^)]*\)', '', clean, flags=re.IGNORECASE)
  return clean
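
To see what the figure-stripping regex does, here is a minimal standalone sketch. `strip_figure_refs` is a hypothetical name, not part of the notebook, and the extra `\s*` is an assumed refinement that also removes the space left in front of the reference.

```python
import re

def strip_figure_refs(text: str) -> str:
    """Remove parenthesised figure references such as "(Fig.1)" (hypothetical helper)."""
    # \s* also swallows the whitespace in front of the reference
    return re.sub(r'\s*\(fig[^)]*\)', '', text, flags=re.IGNORECASE)

line = "They formed small aggregations around the bronchioles (Fig.1)."
print(strip_figure_refs(line))
# They formed small aggregations around the bronchioles.
```

Note that the pattern matches any parenthesised text starting with "fig", so it relies on the corpus not using such parentheses for anything else.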

Let's take this case sheet and extract entities and relationships from it using the LLM.

In [6]:
sample_que = """The patient was a 34-yr-old man who presented with complaints of fever and a chronic cough.
He was a smoker and had a history of pulmonary tuberculosis that had been treated and cured.
A computed tomographic (CT) scan revealed multiple tiny nodules in both lungs.
A thoracoscopic lung biopsy was taken from the right upper lobe.
The microscopic examination revealed a typical LCH.
The tumor cells had vesicular and grooved nuclei, and they formed small aggregations around the bronchioles (Fig.1).
The tumor cells were strongly positive for S-100 protein, vimentin, CD68 and CD1a.
There were infiltrations of lymphocytes and eosinophils around the tumor cells.
With performing additional radiologic examinations, no other organs were thought to be involved.
He quit smoking, but he received no other specific treatment.
He was well for the following one year.
After this, a follow-up CT scan was performed and it showed a 4 cm-sized mass in the left lower lobe, in addition to the multiple tiny nodules in both lungs (Fig.2).
A needle biopsy specimen revealed the possibility of a sarcoma; therefore, a lobectomy was performed.
Grossly, a 4 cm-sized poorly-circumscribed lobulated gray-white mass was found (Fig.3), and there were a few small satellite nodules around the main mass.
Microscopically, the tumor cells were aggregated in large sheets and they showed an infiltrative growth.
The cytologic features of some of the tumor cells were similar to those seen in a typical LCH.
However, many tumor cells showed overtly malignant cytologic features such as pleomorphic/hyperchromatic nuclei and prominent nucleoli (Fig.4), and multinucleated tumor giant cells were also found.
There were numerous mitotic figures ranging from 30 to 60 per 10 high power fields, and some of them were abnormal.
A few foci of typical LCH remained around the main tumor mass.
Immunohistochemically, the tumor cells were strongly positive for S-100 protein (Fig.5) and vimentin; they were also positive for CD68 (Dako N1577, Clone KPI), and focally positive for CD1a (Fig.6), and they were negative for cytokeratin, epithelial membrane antigen, CD3, CD20 and HMB45.
The ultrastructural analysis failed to demonstrate any Birbeck granules in the cytoplasm of the tumor cells.
Now, at five months after lobectomy, the patient is doing well with no significant change in the radiologic findings.
"""

sample_ans = """
{'entities': [{'label': 'Case',
    'id': 'case1',
    'summary': '34-yr-old man with fever, chronic cough, history of pulmonary tuberculosis, LCH diagnosis, and sarcoma. Underwent lobectomy and is doing well.'},
   {'label': 'Person',
    'id': 'person1',
    'age': '34',
    'location': '',
    'gender': 'male'},
   {'label': 'Symptom', 'id': 'fever', 'description': 'Fever'},
   {'label': 'Symptom', 'id': 'chronicCough', 'description': 'Chronic cough'},
   {'label': 'Disease',
    'id': 'pulmonaryTuberculosis',
    'name': 'Pulmonary Tuberculosis'},
   {'label': 'Disease',
    'id': 'langerhansCellHistiocytosis',
    'name': 'Langerhans Cell Histiocytosis'},
   {'label': 'Disease', 'id': 'sarcoma', 'name': 'Sarcoma'},
   {'label': 'BodySystem', 'id': 'lungs', 'name': 'Lungs'},
   {'label': 'BodySystem', 'id': 'heart', 'name': 'Heart'},
   {'label': 'Diagnosis',
    'id': 'ctScan',
    'name': 'CT Scan',
    'description': 'Computed Tomographic (CT) scan',
    'when': 'initial'},
   {'label': 'Diagnosis',
    'id': 'thoracoscopicLungBiopsy',
    'name': 'Thoracoscopic Lung Biopsy',
    'description': 'Thoracoscopic lung biopsy from the right upper lobe',
    'when': 'initial'},
   {'label': 'Diagnosis',
    'id': 'followUpCtScan',
    'name': 'Follow-up CT Scan',
    'description': 'Follow-up CT scan showing a 4 cm-sized mass in the left lower lobe',
    'when': 'one year later'},
   {'label': 'Diagnosis',
    'id': 'needleBiopsy',
    'name': 'Needle Biopsy',
    'description': 'Needle biopsy specimen revealing the possibility of a sarcoma',
    'when': 'one year later'},
   {'label': 'Diagnosis',
    'id': 'lobectomy',
    'name': 'Lobectomy',
    'description': 'Lobectomy performed to remove the mass',
    'when': 'one year later'},
   {'label': 'Biological',
    'id': 'multipleTinyNodules',
    'name': 'Multiple Tiny Nodules',
    'description': 'Multiple tiny nodules in both lungs'},
   {'label': 'Biological',
    'id': 'lchCells',
    'name': 'LCH Cells',
    'description': 'Typical LCH cells with vesicular and grooved nuclei'},
   {'label': 'Biological',
    'id': 'tumorCells',
    'name': 'Tumor Cells',
    'description': 'Tumor cells with malignant cytologic features'}],
  'relationships': ['case1|FOR|person1',
   "person1|HAS_SYMPTOM{when:'initial',frequency:'',span:''}|fever",
   "person1|HAS_SYMPTOM{when:'initial',frequency:'',span:''}|chronicCough",
   "person1|HAS_DISEASE{when:'past'}|pulmonaryTuberculosis",
   "person1|HAS_DISEASE{when:'initial'}|langerhansCellHistiocytosis",
   "person1|HAS_DISEASE{when:'one year later'}|sarcoma",
   'chronicCough|SEEN_ON|lungs',
   'langerhansCellHistiocytosis|AFFECTS|lungs',
   'sarcoma|AFFECTS|lungs',
   'person1|HAS_DIAGNOSIS|ctScan',
   'person1|HAS_DIAGNOSIS|thoracoscopicLungBiopsy',
   'person1|HAS_DIAGNOSIS|followUpCtScan',
   'person1|HAS_DIAGNOSIS|needleBiopsy',
   'person1|HAS_DIAGNOSIS|lobectomy',
   'ctScan|SHOWED|multipleTinyNodules',
   'thoracoscopicLungBiopsy|SHOWED|lchCells',
   'lobectomy|SHOWED|tumorCells']}
"""

que = """A 28-year-old previously healthy man presented with a 6-week history of palpitations.
The symptoms occurred during rest, 2–3 times per week, lasted up to 30 minutes at a time and were associated with dyspnea.
Except for a grade 2/6 holosystolic tricuspid regurgitation murmur (best heard at the left sternal border with inspiratory accentuation), physical examination yielded unremarkable findings.
An electrocardiogram (ECG) revealed normal sinus rhythm and a Wolff– Parkinson– White pre-excitation pattern (Fig.1: Top), produced by a right-sided accessory pathway.
Transthoracic echocardiography demonstrated the presence of Ebstein's anomaly of the tricuspid valve, with apical displacement of the valve and formation of an “atrialized” right ventricle (a functional unit between the right atrium and the inlet [inflow] portion of the right ventricle) (Fig.2).
The anterior tricuspid valve leaflet was elongated (Fig.2C, arrow), whereas the septal leaflet was rudimentary (Fig.2C, arrowhead).
Contrast echocardiography using saline revealed a patent foramen ovale with right-to-left shunting and bubbles in the left atrium (Fig.2D).
The patient underwent an electrophysiologic study with mapping of the accessory pathway, followed by radiofrequency ablation (interruption of the pathway using the heat generated by electromagnetic waves at the tip of an ablation catheter).
His post-ablation ECG showed a prolonged PR interval and an odd “second” QRS complex in leads III, aVF and V2–V4 (Fig.1Bottom), a consequence of abnormal impulse conduction in the “atrialized” right ventricle.
The patient reported no recurrence of palpitations at follow-up 6 months after the ablation.

"""

## Prompt Definition

**⚠️** Duplicate the `config.env.example` file from the file browser on the left and rename the copy to `config.env`. Then edit the file and provide the values for the API keys and Neo4j credentials.
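
A quick check like the one below can catch a missing setting before any API call is made. It is a minimal sketch: the variable names are the ones this notebook reads with `os.getenv`, while `missing_settings` is a hypothetical helper, and it should be run after `load_dotenv` has loaded the file.

```python
import os

# Settings this notebook reads from the environment
REQUIRED_VARS = ["PROJECT_ID", "NEO4J_CONN_URL", "NEO4J_USER", "NEO4J_PASSWORD"]

def missing_settings(env=None):
    """Return the names of required settings that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]

missing = missing_settings()
if missing:
    print("Missing settings:", ", ".join(missing))
```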

In [7]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [8]:
load_dotenv('/content/drive/MyDrive/Colab Notebooks/GenAI-Playground/config-gcp.env', override=True)

shell_output = ! gcloud config list --format 'value(core.project)' 2>/dev/null
PROJECT_ID = os.getenv('PROJECT_ID')
os.environ["GCLOUD_PROJECT"] = PROJECT_ID
os.environ['GCLOUD_REGION'] = 'us-central1'

In [None]:
list_tuned_models(PROJECT_ID, 'europe-west4')

This is a helper function that sends our prompt and input text to the LLM.

In [9]:
# Bison Prompt to complete
@retry(tries=2, delay=5)
def process_gpt(
    project_id: str,
    model_name: str,
    temperature: float,
    max_output_tokens: int,
    top_p: float,
    top_k: int,
    prompt: str,
    que: str,
    location: str = "us-central1",
    ) :
    """Predict using a Large Language Model."""
    aiplatform.init(project=project_id, location=location)

    chat_model = ChatModel.from_pretrained(model_name)
    parameters = {
      "temperature": temperature,
      "max_output_tokens": max_output_tokens,
      "top_p": top_p,
      "top_k": top_k,
    }

    chat = chat_model.start_chat(
      context='''You are a helpful Medical Case Sheet expert who extracts relevant information which will be eventually used to store them on a Neo4j Knowledge Graph after processing''',
      examples=[
        InputOutputTextPair(
          input_text=prompt+sample_que,
          output_text=sample_ans
        )
      ]
    )
    return chat.send_message(prompt+que,**parameters)


This is a simple prompt to start with. If the processing is very complex, you can chain prompts as required. I am going to use a single prompt here that extracts the text strictly according to the entities and relationships defined. This is a simplification. In a real scenario, especially with medical records, you need domain experts to define the ontology systematically and capture the important information. You might also fine-tune the model as required.

Also, instead of one single large model, you can also consider chaining a number of smaller ones as per your needs.
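
Prompt chaining can be sketched roughly as follows. This is only an illustration under stated assumptions: `chain_prompts` and `fake_llm` are hypothetical names, the two step templates are made up, and the stand-in model simply upper-cases its input so the example runs offline.

```python
from string import Template
from typing import Callable, List

def chain_prompts(templates: List[str], call_llm: Callable[[str], str], text: str) -> str:
    """Feed the output of each prompt into the next one via the $ctext placeholder."""
    current = text
    for tpl in templates:
        current = call_llm(Template(tpl).substitute(ctext=current))
    return current

# Stand-in for a real model call so the sketch runs offline
def fake_llm(prompt: str) -> str:
    return prompt.upper()

steps = ["Summarize: $ctext", "Extract entities from: $ctext"]
print(chain_prompts(steps, fake_llm, "fever and cough"))
# EXTRACT ENTITIES FROM: SUMMARIZE: FEVER AND COUGH
```

In practice `call_llm` would wrap a real model call, and each step could use a different, smaller model.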

We are going with this Graph Schema for our Case Sheet:
![schema.png](schema.png)

In [10]:
prompt="""From the Case sheet for a patient below, extract the following Entities & relationships described in the mentioned format 
0. ALWAYS FINISH THE OUTPUT. Never send partial responses
1. First, look for these Entity types in the text and generate as comma-separated format similar to entity type.
   `id` property of each entity must be alphanumeric and must be unique among the entities. You will be referring this property to define the relationship between entities. Do not create new entity types that aren't mentioned below. Document must be summarized and stored inside Case entity under `summary` property. You will have to generate as many entities as needed as per the types below:
    Entity Types:
    label:'Case',id:string,summary:string //Case
    label:'Person',id:string,age:string,location:string,gender:string //Patient mentioned in the case
    label:'Symptom',id:string,description:string //Symptom Entity; `id` property is the name of the symptom, in lowercase & camel-case & should always start with an alphabet
    label:'Disease',id:string,name:string //Disease diagnosed now or previously as per the Case sheet; `id` property is the name of the disease, in lowercase & camel-case & should always start with an alphabet
    label:'BodySystem',id:string,name:string //Body Part affected. Eg: Chest, Lungs; id property is the name of the part, in lowercase & camel-case & should always start with an alphabet
    label:'Diagnosis',id:string,name:string,description:string,when:string //Diagnostic procedure conducted; `id` property is the summary of the Diagnosis, in lowercase & camel-case & should always start with an alphabet
    label:'Biological',id:string,name:string,description:string //Results identified from Diagnosis; `id` property is the summary of the Biological, in lowercase & camel-case & should always start with an alphabet
    
2. Next, generate each relationship as a triple of head, relationship and tail. To refer to the head and tail entities, use their respective `id` property. Relationship properties should be mentioned within curly braces as comma-separated values. They should follow the relationship types below. You will have to generate as many relationships as needed as defined below:
    Relationship types:
    case|FOR|person
    person|HAS_SYMPTOM{when:string,frequency:string,span:string}|symptom //the properties inside HAS_SYMPTOM gets populated from the Case sheet
    person|HAS_DISEASE{when:string}|disease //the properties inside HAS_DISEASE gets populated from the Case sheet
    symptom|SEEN_ON|chest
    disease|AFFECTS|heart
    person|HAS_DIAGNOSIS|diagnosis
    diagnosis|SHOWED|biological
3. Do not send any response other than code block in the response

The output should look like :
{
    "entities": [{"label":"Case","id":string,"summary":string}],
    "relationships": ["disease|AFFECTS|heart"]
}

Case Sheet:
$ctext
"""
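
The `head|RELATIONSHIP{props}|tail` format requested in the prompt can be taken apart as sketched below. `parse_relationship` is a hypothetical helper, and it assumes property values are single-quoted and contain no `|` character, as in the sample answer above.

```python
import re

def parse_relationship(triple: str):
    """Split 'head|TYPE{props}|tail' into (head, rel_type, props, tail)."""
    head, rel, tail = (part.strip() for part in triple.split("|"))
    match = re.fullmatch(r"(\w+)(?:\{(.*)\})?", rel)
    if match is None:
        raise ValueError(f"Unrecognized relationship: {rel!r}")
    rel_type, raw_props = match.group(1), match.group(2) or ""
    # Pull out key:'value' pairs; values are assumed to be single-quoted
    props = dict(re.findall(r"(\w+):'([^']*)'", raw_props))
    return head, rel_type, props, tail

print(parse_relationship("person1|HAS_DISEASE{when:'past'}|pulmonaryTuberculosis"))
# ('person1', 'HAS_DISEASE', {'when': 'past'}, 'pulmonaryTuberculosis')
```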

Let's run our completion task with our LLM

In [19]:
%%time
def run_completion(prompt, results, ctext):
    try:
      # process_gpt appends the case text to the prompt, mirroring the few-shot example
      res = process_gpt(PROJECT_ID,
                        'chat-bison-001'
                        , 0, 1024, 0.8, 40, prompt, ctext, location="us-central1")
      results.append(res)
    except Exception as e:
        print(e)
    # Always return the list so a failed call does not replace `results` with None
    return results

prompts = [prompt]
results = []
for p in prompts:
  results = run_completion(p, results, clean_text(que))
    



Unknown model name 'projects/803648085855/locations/us-central1/models/7947351409425383424'. Available model names are: ['text-bison-001', 'text-bison-alpha', 'embedding-gecko-001', 'chat-bison-001']
CPU times: user 33.2 ms, sys: 974 µs, total: 34.2 ms
Wall time: 5.01 s


In [13]:
results[0].text

"Sure, here are the entities and relationships extracted from the case sheet:\n\nEntities:\n\n* Case:\n    * label: 'Case'\n    * id: 'case1'\n    * summary: '28-year-old man with palpitations, Ebstein's anomaly, and Wolff– Parkinson– White pre-excitation pattern. Underwent radiofrequency ablation and is doing well.'\n* Person:\n    * label: 'Person'\n    * id: 'person1'\n    * age: '28'\n    * location: ''\n    * gender: 'male'\n* Symptom:\n    * label: 'Palpitations'\n    * id: 'palpitations'\n    * description: 'Palpitations occurred during rest, 2–3 times per week, lasted up to 30 minutes at a time and were associated with dyspnea.'\n* Disease:\n    * label: 'Ebstein's anomaly'\n    * id: 'ebsteinsAnomaly'\n    * name: 'Ebstein's anomaly'\n* BodySystem:\n    * label: 'Heart'\n    * id: 'heart'\n    * name: 'Heart'\n* Diagnosis:\n    * label: 'Electrophysiologic study'\n    * id: 'electrophysiologicStudy'\n    * name: 'Electrophysiologic study'\n    * description: 'Mapping of the ac

## Neo4j Cypher Generation

The entities and relationships returned by the LLM have to be transformed into Cypher so that we can ingest them into Neo4j.
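
Because the model may wrap its answer in extra prose, or use single quotes as in `sample_ans`, a tolerant parsing step can help before generating Cypher. The following is a minimal sketch; `parse_llm_reply` is a hypothetical helper and assumes the reply contains exactly one dictionary-like block.

```python
import ast
import json

def parse_llm_reply(reply: str) -> dict:
    """Extract the outermost {...} block from a reply and parse it into a dict."""
    start, end = reply.find("{"), reply.rfind("}")
    if start == -1 or end <= start:
        raise ValueError("No dictionary-like block found in the reply")
    block = reply[start:end + 1]
    try:
        return json.loads(block)          # strict JSON (double quotes)
    except json.JSONDecodeError:
        return ast.literal_eval(block)    # Python literal (single quotes)

reply = 'Sure, here it is:\n{"entities": [], "relationships": []}'
print(parse_llm_reply(reply))
# {'entities': [], 'relationships': []}
```

A reply that is a bulleted list rather than a dictionary will still fail, in which case the prompt or the sampling parameters need adjusting.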

In [None]:
#pre-processing results for uploading into Neo4j - helper function:
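def get_prop_str(prop_dict, _id):
    s = []
    for key, val in prop_dict.items():
      if key != 'label' and key != 'id':
         # Escape backslashes first, then double quotes, to get a valid Cypher string literal
         escaped = str(val).replace('\\', '\\\\').replace('"', '\\"')
         s.append(_id+"."+key+' = "'+escaped+'"')
    # Omit the SET clause entirely when there are no properties to set
    return (' ON CREATE SET ' + ','.join(s)) if s else ''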

def get_cypher_compliant_var(_id):
    return "_"+ re.sub(r'[\W_]', '', _id)

def generate_cypher(in_json):
    e_map = {}
    e_stmt = []
    r_stmt = []
    e_stmt_tpl = Template("($id:$label{id:'$key'})")
    r_stmt_tpl = Template("""
      MATCH $src
      MATCH $tgt
      MERGE ($src_id)-[:$rel]->($tgt_id)
    """)
    for obj in in_json:
      for j in obj['entities']:
          props = ''
          label = j['label']
          id = j['id']
          if label == 'Case':
                id = 'c'+str(time.time_ns())
          elif label == 'Person':
                id = 'p'+str(time.time_ns())
          varname = get_cypher_compliant_var(j['id'])
          stmt = e_stmt_tpl.substitute(id=varname, label=label, key=id)
          e_map[varname] = stmt
          e_stmt.append('MERGE '+ stmt + get_prop_str(j, varname))

      for st in obj['relationships']:
          rels = st.split("|")
          src_id = get_cypher_compliant_var(rels[0].strip())
          rel = rels[1].strip()
          tgt_id = get_cypher_compliant_var(rels[2].strip())
          stmt = r_stmt_tpl.substitute(
              src_id=src_id, tgt_id=tgt_id, src=e_map[src_id], tgt=e_map[tgt_id], rel=rel)
          
          r_stmt.append(stmt)

    return e_stmt, r_stmt
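
As an alternative to building Cypher by string concatenation, entity values can be passed as query parameters, which avoids manual quote escaping. This is a sketch of a different technique, not the approach used in this notebook; `entity_to_query` and `ALLOWED_LABELS` are hypothetical names, and the label list is taken from the prompt's entity types.

```python
ALLOWED_LABELS = {"Case", "Person", "Symptom", "Disease", "BodySystem", "Diagnosis", "Biological"}

def entity_to_query(entity: dict):
    """Return (query, params) that MERGE one entity using query parameters."""
    label = entity["label"]
    # Labels cannot be passed as parameters, so validate before interpolating
    if label not in ALLOWED_LABELS:
        raise ValueError(f"Unexpected label: {label!r}")
    props = {k: v for k, v in entity.items() if k not in ("label", "id")}
    query = f"MERGE (n:{label} {{id: $id}}) ON CREATE SET n += $props"
    return query, {"id": entity["id"], "props": props}

query, params = entity_to_query({"label": "Symptom", "id": "fever", "description": "Fever"})
print(query)
# MERGE (n:Symptom {id: $id}) ON CREATE SET n += $props
```

The resulting pair should be usable with `gds.run_cypher(query, params)`, which takes a parameter dictionary as its second argument.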

In [None]:
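# generate_cypher expects dicts, so parse each reply first (assumes the reply text is valid JSON)
ent_cyp, rel_cyp = generate_cypher([json.loads(r.text) for r in results])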

## Data Ingestion

In [None]:
connectionUrl = os.getenv('NEO4J_CONN_URL')
username = os.getenv('NEO4J_USER')
password = os.getenv('NEO4J_PASSWORD')

In [None]:
gds = GraphDataScience(connectionUrl, auth=(username, password))
gds.version()

'2.3.4+17'

Before loading the data, create constraints as below

In [None]:
gds.run_cypher('CREATE CONSTRAINT unique_case_id IF NOT EXISTS FOR (n:Case) REQUIRE n.id IS UNIQUE')
gds.run_cypher('CREATE CONSTRAINT unique_person_id IF NOT EXISTS FOR (n:Person) REQUIRE (n.id) IS UNIQUE')
gds.run_cypher('CREATE CONSTRAINT unique_symptom_id IF NOT EXISTS FOR (n:Symptom) REQUIRE (n.id) IS UNIQUE')
gds.run_cypher('CREATE CONSTRAINT unique_disease_id IF NOT EXISTS FOR (n:Disease) REQUIRE n.id IS UNIQUE')
gds.run_cypher('CREATE CONSTRAINT unique_bodysys_id IF NOT EXISTS FOR (n:BodySystem) REQUIRE n.id IS UNIQUE')
gds.run_cypher('CREATE CONSTRAINT unique_diag_id IF NOT EXISTS FOR (n:Diagnosis) REQUIRE n.id IS UNIQUE')
gds.run_cypher('CREATE CONSTRAINT unique_biological_id IF NOT EXISTS FOR (n:Biological) REQUIRE n.id IS UNIQUE')

Ingest the entities

In [None]:
%%time
for e in ent_cyp:
    gds.run_cypher(e)


CPU times: user 35.4 ms, sys: 0 ns, total: 35.4 ms
Wall time: 1.49 s


Ingest relationships now

In [None]:
%%time
for r in rel_cyp:
    gds.run_cypher(r)

CPU times: user 51.5 ms, sys: 0 ns, total: 51.5 ms
Wall time: 2.72 s


These helper functions ingest all case sheets inside the `data/case_sheets/` directory.

In [None]:
def run_pipeline(count=191):
    txt_files = glob.glob("data/case_sheets/*.txt")[0:count]
    print(f"Running pipeline for {len(txt_files)} files")
    failed_files = process_pipeline(txt_files)
    print(failed_files)
    return failed_files

def process_pipeline(files):
    failed_files = []
    for f in files:
        try:
            with open(f, 'r') as file:
                print(f"  {f}: Reading File...")
                data = file.read().rstrip()
                text = clean_text(data)
                print(f"    {f}: Extracting E & R")
                results = extract_entities_relationships(f, text)
                print(f"    {f}: Generating Cypher")
                ent_cyp, rel_cyp = generate_cypher(results)
                print(f"    {f}: Ingesting Entities")
                for e in ent_cyp:
                    gds.run_cypher(e)
                print(f"    {f}: Ingesting Relationships")
                for r in rel_cyp:
                    gds.run_cypher(r)
                print(f"    {f}: Processing DONE")
        except Exception as e:
            print(f"    {f}: Processing Failed with exception {e}")
            failed_files.append(f)
    return failed_files
            
def extract_entities_relationships(f, text):
    start = timer()
    prompts = [prompt]
    results = []
    for p in prompts:
      # process_gpt appends the case text to the prompt and returns a response object
      res = process_gpt(PROJECT_ID, 'chat-bison-001', 0, 1024, 0.8, 40, p, text)
      # Assumes the reply text is valid JSON, as requested in the prompt
      results.append(json.loads(res.text))
    end = timer()
    elapsed = (end-start)
    print(f"    {f}: E & R took {elapsed}secs")
    return results

In [None]:
%%time
failed_files = run_pipeline(200)

If processing failed for some files because of API rate limits or some other error, you can retry them as shown below.

In [None]:
%%time
failed_files = process_pipeline(failed_files)
failed_files
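
When failures are caused by rate limits, waiting between retry rounds usually helps. The sketch below adds exponential backoff around any function that takes a list of files and returns the ones that failed, which is the shape of `process_pipeline` above. `retry_failed` and its default values are assumptions, not part of the notebook.

```python
import time
from typing import Callable, List

def retry_failed(process: Callable[[List[str]], List[str]], files: List[str],
                 max_rounds: int = 3, base_delay: float = 5.0) -> List[str]:
    """Re-run `process` on the files it reports as failed, waiting longer each round."""
    remaining = list(files)
    for attempt in range(max_rounds):
        if not remaining:
            break
        time.sleep(base_delay * (2 ** attempt))  # 5s, 10s, 20s with the defaults
        remaining = process(remaining)
    return remaining
```

It could be called as `failed_files = retry_failed(process_pipeline, failed_files)`; whatever is still in the returned list after the last round needs manual inspection.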

In [None]:
results