In [30]:
!pip install -q langchain uuid weaviate-client sentence-transformers


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [40]:
import random
import anthropic
import re
from dotenv import load_dotenv
import os
from langchain.chat_models import ChatAnthropic
from langchain.prompts import ChatPromptTemplate
import json
import uuid
import weaviate
from sentence_transformers import SentenceTransformer
from typing import List

load_dotenv()
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
WEAVIATE_KEY = os.getenv("WEAVIATE_API_KEY")
WEVIATE_URL = os.getenv("WEAVIATE_URL")

model = ChatAnthropic(
    model_name="claude-2",
    anthropic_api_key=ANTHROPIC_API_KEY,
    max_tokens_to_sample=2000,
)


def get_patient_record(seen_ids: List) -> dict:
    """Generate a synthetic medical record for a patient using Claude API"""
    if seen_ids and random.randint(1, 10) in [1, 2, 3]:
        patient_id = random.sample(seen_ids, 1)
    else:
        patient_id = str(uuid.uuid4())
        seen_ids.append(patient_id)
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", ""),
            (
                "human",
                """Generate a synthetic medical record for patient {patient_id}. Include date, age,
                  sex, symptoms (a list), diagnoses, medications, 
                  vital signs, lab results, allergies, family 
                  history, and social history. Output as a valid JSON without any additional 
                  explanations or 
                  formatting. Any JSON keys should have whitespace replaced with underscores.
                  """,
            ),
        ]
    )
    chain = prompt | model
    record = chain.invoke({"patient_id": patient_id})
    content = record.content
    pattern = r"```json([\w\W]+?)```"
    code_blocks = re.findall(pattern, content, re.DOTALL)
    my_json = json.loads(code_blocks[0].strip("\n"))
    return {"json": my_json, "seen_ids": seen_ids}


def get_patient_records(num_records: int = 10) -> List[dict]:
    """Loop through the num_records and create a list of patient records by calling Claude API"""
    num_records = num_records
    records = []
    seen_ids = []
    while len(records) < num_records:
        try:
            record = get_patient_record(seen_ids)
            print(record)
            records.append(record["json"])
            seen_ids = record["seen_ids"]
        except Exception as e:
            print(f"exception {e}")
    return records


def loop_and_insert(records: object, endpoint: str, collection_name: str) -> None:
    """Insert list of JSON into a weaviate database and embed only the symptoms"""
    client = weaviate.Client(
        auth_client_secret=weaviate.AuthApiKey(api_key=WEAVIATE_KEY),
        url=endpoint,  # e.g. "https://some-endpoint.weaviate.network/",  # Replace with your endpoint
    )
    model = SentenceTransformer("all-MiniLM-L6-v2")
    # Prepare a batch process
    client.batch.configure(batch_size=10)  # Configure batch
    with client.batch as batch:
        # Batch import all Questions
        for record in records:
            patient_id = record["patient_id"]
            symptoms = record["symptoms"]
            date = record["date"]
            try:
                batch.add_data_object(
                    {"patientID" : patient_id, "symptoms" : symptoms, "date" : date, "record": json.dumps(record)},
                    collection_name,
                    vector=model.encode(", ".join(symptoms)),
                )
            except Exception as e:
                print("Error in inserting: ", e)


def filter_individual_patientID(weaviate_url: str, patientID: str) -> dict:
    client = weaviate.Client(weaviate_url)
    filter = {
        "path": ["Things", "patientID"],  # Things appears to be a dynamic name?
        "operator": "Equal",
        "valueString": patientID,
    }
    # Construct a GraphQL query with a filter
    query = {"Get": {"Things": ["*"]}, "Where": filter}
    result = client.query(query)
    # Extract the matching Things
    filtered_data = result["data"]["Get"]["Things"]
    return filtered_data


def get_unique_patientIDs(weaviate_url: str) -> List[str]:
    client = weaviate.Client(weaviate_url)
    query = {"Get": {"Things": ["patientID"]}}
    result = client.query(query)
    # Extract the matching Things
    filtered_data = result["data"]["Get"]["Things"]
    patient_ids = [i["patientID"] for i in filtered_data]
    return patient_ids


def create_summary_database(
    endpoint: str, collection_name: str, patient_ids: List, doctor_type: str = "General Practicioner (GP)"
) -> None:
    for id in set(patient_ids):
        data = filter_individual_patientID(endpoint, id)
        # Join all the text of the records together into a single string
        id_string = ""
        for record in data:
            joined_text = json.dumps(record)
            id_string += joined_text

        # Use Claude to summarise the record string
        prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    f"""You are a medical summariser with experience in summarising complex medical histories for patients
                 into a concise summary that be understood by a medical profressional, in this instance a {doctor_type}""",
                ),
                (
                    "human",
                    """The following is a concatentation of one or more medical records for an individual.
                    Summarise the following concatenated medical record into a single string that captures
                    the variance of the individual concatenated records: {record_string}""",
                ),
            ]
        )
        chain = prompt | model
        summary = chain.invoke({"record_string": id_string}).content

        # Vectorize summary
        model = SentenceTransformer("all-MiniLM-L6-v2")
        vector = model.encode(summary)
        # Create a new object with id, summary, and vector
        new_obj = {"patientID": id, "summary": summary, "vector": vector}
        # Add to new database
        client = weaviate.Client(url=endpoint)
        client.batch.add_data_object(new_obj, collection_name)
    client.batch.execute()


def search_for_similar_summaries(endpoint: str, patientID: str, num_results: int = 10) -> tuple:
    """Given a patients summary, return a list of similar summaries including their diagnoses and medications. Give a confidence
    score for the probability that the patient has the same diagnosis and medications as the similar patient"""
    # Get the patient summary
    client = weaviate.Client(url=endpoint)
    filter = {
        "path": ["Things", "patientID"],  # Things appears to be a dynamic name?
        "operator": "Equal",
        "valueString": patientID,
    }
    # Construct a GraphQL query with a filter
    query = {"Get": {"Things": ["*"]}, "Where": filter}
    result = client.query(query)
    # Extract the matching Things
    filtered_data = result["data"]["Get"]["Things"]
    patient_summary = filtered_data[0]["summary"]
    patient_vector = filtered_data[0]["vector"]  # Potentially correct up to this point
    # Search for similar summaries
    query = {
        "Aggregate": {
            "Things": [
                {
                    "similarSummaries": {  # Is this mean to be the collection name?
                        "certainty": 0.8,
                        "limit": num_results,
                        "vector": patient_vector,
                    }
                }
            ]
        }
    }
    result = client.query(query)
    # Extract the matching Things
    similar_summaries = result["data"]["Aggregate"]["Things"][0]["similarSummaries"]  # Also need to change this if so
    print("check the format of similar_summaries to see how to call it")
    print(similar_summaries)
    pids = [s["patientID"] for s in similar_summaries["result"]]
    return (pids, patient_summary)


def filter_for_symptoms(
    patientID: str, symptoms: str, endpoint: str, num_results: int, t1_collection_name: str, t2_collection_name: str
) -> tuple:
    """Function that takes patientID, and a list of symptoms and looks up the patient summary based on ID,
    Then calls the search_for_similar_summaries function to return a list of similar summaries and suggested diagnoses/medications
    """
    model = SentenceTransformer("all-MiniLM-L6-v2")
    symptoms_vector = model.encode(symptoms)
    relevant_summary_pids, summary = search_for_similar_summaries(endpoint, patientID, num_results, t2_collection_name)
    print(relevant_summary_pids)
    client = weaviate.Client(url=endpoint)
    response = (
        client.query.get(t1_collection_name, ["symptoms"])
        .with_where({"path": [patientID], "operator": "In", "valueString": relevant_summary_pids})
        .with_near_vector(
            {
                "vector": symptoms_vector,
            }
        )
        .with_limit(10)
        .with_additional(["distance"])
        .do()
    )
    return (response, summary)


def get_final_output(response: tuple) -> dict:
    """Function that takes the response from the filter_for_symptoms function and returns a final output
    with the patientID, symptoms, diagnoses, medications, and distance"""
    final_output = []
    for i in response["data"]["Get"]["Things"]:
        patientID = i["patientID"]
        symptoms = i["symptoms"]
        diagnoses = i["diagnoses"]
        medications = i["medications"]
        distance = i["distance"]
        final_output.append(
            {
                "patientID": patientID,
                "symptoms": symptoms,
                "diagnoses": diagnoses,
                "medications": medications,
                "distance": distance,
            }
        )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", """"""),
            (
                "human",
                """You will be provided with a list of diagnoses and medications of patients
                 who experienced similar symptoms. Based on how often these occur, provide the most likely 
                 diagnoses and medications. The output should be JSON format, with medications grouped with 
                 the diagnoses they treat. Do not include any explanation or other information in the output 
                 other than the JSON. 
                 <example_output>{"diagnoses": \[{"diagnosis": "treatment"}\]}</example_output>
                <records>"""
                + f"""{final_output}"""
                + """</records>
                Assistant: {"diagnoses":[]}""",
            ),
        ]
    )
    chain = prompt | model
    final_output = chain.invoke().content
    return final_output


def create_natural_language_summary(summary: str, final_output: dict, target_comprehension: str = "non-specialist audiences") -> str:
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", f"""You are an expert at communicating complex medical topics to {target_comprehension}"""),
            (
                "human",
                f"""Create a summary of an individual's medical history and likely diagnoses and treatments.
                medical history summary: "{summary}"
                likely diagnoses and treatments: {json.dumps(final_output)}.
                If a particular language is specified, translate the summary into that language.""",
            ),
        ]
    )

    chain = prompt | model
    translated_output = chain.invoke().content
    return translated_output

  """You will be provided with a list of diagnoses and medications of patients


In [41]:
records = get_patient_records(50)
print(records)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


exception list index out of range
exception list index out of range


In [37]:
loop_and_insert(records, WEVIATE_URL, "Patient")

{'error': [{'message': "'heart disease' is not a valid nested property name of 'family_history'. NestedProperty names in Weaviate are restricted to valid GraphQL names, which must be “/[_A-Za-z][_0-9A-Za-z]*/”., invalid object property 'family_history' on class 'Patient': property 'family_history.diabetes': invalid boolean property 'family_history.diabetes' on class 'Patient': not a bool, but string"}]}


In [None]:
unique_patient_ids = get_unique_patientIDs(WEVIATE_URL)
print(unique_patient_ids)

In [None]:
create_summary_database(WEVIATE_URL, "SummaryTable", unique_patient_ids)

In [None]:
filtered_pids, summary = filter_for_symptoms("patientID", "symptoms", WEVIATE_URL, 10, "Patient", "SummaryTable")
print(filtered_pids)
print(summary)

In [None]:
# This is presented to the Dr / In app
final_diagnoses_and_medication = get_final_output(filtered_pids)
print(final_diagnoses_and_medication)

In [None]:
# This is presented to the patient
pesonalised_summary = create_natural_language_summary(summary, final_diagnoses_and_medication, "layman")

In [35]:
properties = ["symptoms, diagnoses"]

results = client.query.get("Patient", properties).do()

print(json.dumps(results, indent=2))

{
  "data": {
    "Get": {
      "Patient": [
        {
          "diagnoses": [
            "hypertension",
            "diabetes_mellitus_type_2"
          ],
          "symptoms": [
            "chest_pain",
            "shortness_of_breath",
            "nausea"
          ]
        },
        {
          "diagnoses": [
            "hypertension",
            "coronary_artery_disease"
          ],
          "symptoms": [
            "chest_pain",
            "shortness_of_breath",
            "fatigue"
          ]
        },
        {
          "diagnoses": [
            "acute_myocardial_infarction",
            "hypertension"
          ],
          "symptoms": [
            "chest_pain",
            "shortness_of_breath",
            "nausea"
          ]
        }
      ]
    }
  }
}


In [16]:
import weaviate
import os

auth_config = weaviate.AuthApiKey(api_key="kbypW858drjOi8i6fSTB4cNJb2FboJHVdMp3")
client = weaviate.Client(url="https://diagnosis-ai-generator-kk6m9y1x.weaviate.network", auth_client_secret=auth_config)  

client.schema.get()



{'classes': [{'class': 'SummaryTable',
   'invertedIndexConfig': {'bm25': {'b': 0.75, 'k1': 1.2},
    'cleanupIntervalSeconds': 60,
    'stopwords': {'additions': None, 'preset': 'en', 'removals': None}},
   'multiTenancyConfig': {'enabled': False},
   'properties': [{'dataType': ['uuid'],
     'indexFilterable': True,
     'indexSearchable': False,
     'name': 'patientID'}],
   'replicationConfig': {'factor': 1},
   'shardingConfig': {'virtualPerPhysical': 128,
    'desiredCount': 1,
    'actualCount': 1,
    'desiredVirtualCount': 128,
    'actualVirtualCount': 128,
    'key': '_id',
    'strategy': 'hash',
    'function': 'murmur3'},
   'vectorIndexConfig': {'skip': False,
    'cleanupIntervalSeconds': 300,
    'maxConnections': 64,
    'efConstruction': 128,
    'ef': -1,
    'dynamicEfMin': 100,
    'dynamicEfMax': 500,
    'dynamicEfFactor': 8,
    'vectorCacheMaxObjects': 1000000000000,
    'flatSearchCutoff': 40000,
    'distance': 'cosine',
    'pq': {'enabled': False,
     '

In [39]:
cls = {
            "class": "Patient",
            "description": "Patient medical record",
            "properties": [],
        }

client.schema.create_class(cls)

In [38]:
client.schema.delete_class("Patient")