In [30]:
!pip install -q langchain uuid weaviate-client sentence-transformers


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [33]:
import random
import anthropic
import re
from dotenv import load_dotenv
import os
from langchain.chat_models import ChatAnthropic
from langchain.prompts import ChatPromptTemplate
import json
import uuid
import weaviate
from sentence_transformers import SentenceTransformer
from typing import List

load_dotenv()

ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
model = ChatAnthropic(
    model_name="claude-v2",
    anthropic_api_key=ANTHROPIC_API_KEY,
    max_tokens_to_sample=2000,
)

def get_patient_record(seen_ids: List) -> dict:
    """Generate a synthetic medical record for a patient using Claude API"""
    if seen_ids and random.randint(1, 10) in [1, 2, 3]:
        patient_id = random.sample(seen_ids, 1)
    else:
        patient_id = str(uuid.uuid4())
        seen_ids = seen_ids.append(patient_id)
    
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", ""),
            (
                "human",
                "Generate a synthetic medical record for patient {patient_id}. Include date, age, sex, symptoms, diagnoses, medications, vital signs, lab results, allergies, family history, and social history. Output as a valid JSON without any additional explanations or formatting",
            ),
        ]
    )

    chain = prompt | model
    record = chain.invoke(patient_id)
    content = record.content
    pattern = r"```json([\w\W]+?)```"
    code_blocks = re.findall(pattern, content, re.DOTALL)
    my_json = json.loads(code_blocks[0].strip("\n"))
    return {"json": my_json, "seen_ids": seen_ids}


def get_patient_records(num_records: int = 10) -> List[dict]:
    """Loop through the num_records and create a list of patient records by calling Claude API"""

    num_records = num_records
    records = []
    seen_ids = []

    for i in range(num_records):
        try:
            record = get_patient_record(seen_ids)
            records.append(record["json"])
            seen_ids = record["seen_ids"]
        except:
            record = get_patient_record(seen_ids)
            records.append(record["json"])
            seen_ids = record["seen_ids"]

    return records


def loop_and_insert(records: object, endpoint: str, collection_name: str) -> None:
    """Insert list of JSON into a weaviate database and embed only the symptoms"""
    client = weaviate.Client(
        url=endpoint,  # e.g. "https://some-endpoint.weaviate.network/",  # Replace with your endpoint
    )
    model = SentenceTransformer("all-MiniLM-L6-v2")

    # Prepare a batch process
    client.batch.configure(batch_size=10)  # Configure batch
    with client.batch as batch:
        # Batch import all Questions
        for record in records:
            batch.add_data_object(
                record,
                collection_name,
                vector=model.encode(
                    ", ".join(record["symptoms"])
                ),  # CHANGE COLLECTION NAME
            )


def filter_individual_patientID(weaviate_url: str, patientID: str) -> dict:
    client = weaviate.Client(weaviate_url)
    filter = {
        "path": ["Things", "patientID"],  # Things appears to be a dynamic name?
        "operator": "Equal",
        "valueString": patientID,
    }
    # Construct a GraphQL query with a filter
    query = {"Get": {"Things": ["*"]}, "Where": filter}
    result = client.query(query)
    # Extract the matching Things
    filtered_data = result["data"]["Get"]["Things"]
    return filtered_data


def get_unique_patientIDs(weaviate_url: str) -> List[str]:
    client = weaviate.Client(weaviate_url)
    query = {"Get": {"Things": ["patientID"]}}
    result = client.query(query)
    # Extract the matching Things
    filtered_data = result["data"]["Get"]["Things"]
    patient_ids = [i["patientID"] for i in filtered_data]
    return patient_ids


def create_summary_database(endpoint, collection_name, patient_ids):
    for id in set(patient_ids):
        data = filter_individual_patientID(endpoint, id)
        # Summarize records into a single string
        summary = "\n".join([d["text"] for d in data])
        # Vectorize summary
        model = SentenceTransformer("all-MiniLM-L6-v2")
        vector = model.encode(summary)
        # Create a new object with id, summary, and vector
        new_obj = {"patientID": id, "summary": summary, "vector": vector}
        # Add to new database
        client = weaviate.Client(url=endpoint)
        client.batch.add_data_object(new_obj, collection_name)

    client.batch.execute()

In [34]:
records = get_patient_records(10)

TypeError: list indices must be integers or slices, not str

In [None]:
# # Call functions and write to database
# for i in range(100):
#     records = get_patient_records(num_records=10)
#     loop_and_insert(records, endpoint=, collection_name=)
#     create_summary_database(endpoint=, collection_name=, patient_ids=get_unique_patientIDs(endpoint=)