In [27]:
import os
import json
import openai
import pandas as pd

from llama_index.embeddings import OpenAIEmbedding
from llama_index.schema import Document, MetadataMode
from llama_index.node_parser.simple import SimpleNodeParser

from config import MAIN_DIR
from utils import (
    convert_doc_to_dict,
    generate_vectorindex,
    load_vectorindex,
    query_wrapper,
    remove_final_sentence
    )

In [2]:
# Directory layout, all rooted at MAIN_DIR:
#   data/                      -> DATA_DIR
#   data/emb_store/            -> EMB_DIR
#   data/emb_store/finetune/   -> FINETUNE_DIR
DATA_DIR = os.path.join(MAIN_DIR, "data")
EMB_DIR = os.path.join(DATA_DIR, "emb_store")
FINETUNE_DIR = os.path.join(EMB_DIR, "finetune")

# The OpenAI key is read from a JSON file under MAIN_DIR/auth rather than
# being written into the notebook.
api_keys_path = os.path.join(MAIN_DIR, "auth", "api_keys.json")
with open(api_keys_path, "r") as key_file:
    api_keys = json.load(key_file)

# Publish the key both through the process environment and on the openai module.
openai_api_key = api_keys["OPENAI_API_KEY"]
os.environ["OPENAI_API_KEY"] = openai_api_key
openai.api_key = openai_api_key

In [3]:
# Load the stored text records; each record is later read through its
# "text" and "metadata" keys.
texts_path = os.path.join(EMB_DIR, "texts.json")
with open(texts_path, "r") as texts_file:
    text_list_by_page = json.load(texts_file)

In [4]:
# Embedding client shared by the sweep below: it embeds every chunk in batch
# and is also handed to generate_vectorindex when the index is built.
embed_model = OpenAIEmbedding(model="text-embedding-ada-002")

In [5]:
# Metadata keys passed as excluded_embed_metadata_keys / excluded_llm_metadata_keys.
# Note that 'mode' is excluded on the embedding side only.
embed_excluded_keys = ('file_name', 'page_label', 'variant', 'mode')
llm_excluded_keys = ('file_name', 'page_label', 'variant')

# Wrap every loaded text record in a Document. The "mode" tag is written into
# the record's own metadata dict, i.e. text_list_by_page is modified in place.
page_nodes = []
for page in text_list_by_page:
    page_metadata = page["metadata"]
    page_metadata["mode"] = "text"
    page_nodes.append(
        Document(
            text=page["text"],
            metadata=page_metadata,
            # Fresh list per Document, as the original literal lists were.
            excluded_embed_metadata_keys=list(embed_excluded_keys),
            excluded_llm_metadata_keys=list(llm_excluded_keys),
        )
    )

In [28]:
# Load the test cases, keeping only the columns this notebook reads.
testcase_df = pd.read_csv(
    os.path.join(DATA_DIR, "queries", "MSK LLM Fictitious Case Files Full.csv"),
    usecols=['ACR scenario', 'Guideline', 'Variant', 'Appropriateness Category',
             'Scan Order', 'Clinical File'],
)

patient_profiles = testcase_df["Clinical File"]
# NOTE(review): `scan_orders` is not used when building the queries below; the
# scan order placed in each query comes from remove_final_sentence instead.
# Kept because other cells may rely on it -- confirm which source is intended.
scan_orders = testcase_df["Scan Order"]

question_template = "Patient Profile: {profile}\nScan ordered: {scan_order}"

# Split each clinical file once (previously the helper was called twice per row,
# once for each element). Element 0 fills {profile}, element 1 fills {scan_order}.
# Presumably these are (profile without its last sentence, the removed sentence);
# verify against utils.remove_final_sentence.
split_profiles = [
    remove_final_sentence(patient_profile, True)
    for patient_profile in patient_profiles
]

testcase_df["queries"] = [
    query_wrapper(question_template, {"profile": parts[0], "scan_order": parts[1]})
    for parts in split_profiles
]

In [56]:
# Grid explored by the sweep: every chunk size is combined with every
# retriever top-k value.
search_space = dict(
    similarity_top_k=[3, 5, 7],
    chunk_size=[256, 512, 1024],
)

In [57]:
# Grid search: the outer loop varies the chunk size, the inner loop the retriever top-k.
# For each chunk size the page documents are re-chunked and embedded, the embeddings
# and a Chroma-backed index are written under FINETUNE_DIR/<description>/, and one
# retrieval dataset JSON is written per top-k value.
for chunk_size in search_space["chunk_size"]:

    description = "FreeText-Chunk_size={}".format(chunk_size)
    save_folder = os.path.join(FINETUNE_DIR, description)
    db_directory = os.path.join(save_folder, "db")

    # exist_ok=True already tolerates an existing folder; no separate existence check needed.
    os.makedirs(save_folder, exist_ok=True)

    node_parser = SimpleNodeParser.from_defaults(chunk_size=chunk_size)
    text_nodes = node_parser.get_nodes_from_documents(page_nodes)

    # Re-wrap every chunk as a Document carrying the chunk's text and metadata.
    text_docs = [
        Document(
            text=node.text,
            metadata=node.metadata,
            excluded_embed_metadata_keys=['file_name', 'page_label', 'variant', 'mode'],
            excluded_llm_metadata_keys=['file_name', 'page_label', 'variant'],
        )
        for node in text_nodes
    ]

    # The strings sent to the embedding model are rendered from the parser's nodes
    # (not from text_docs) in EMBED metadata mode.
    text_contents_for_embs = [
        text_node.get_content(metadata_mode=MetadataMode.EMBED)
        for text_node in text_nodes
    ]

    text_dicts = [convert_doc_to_dict(doc) for doc in text_docs]
    text_embs = embed_model.get_text_embedding_batch(text_contents_for_embs, show_progress=True)

    # zip() below would silently drop trailing documents on a length mismatch,
    # leaving them without an embedding; fail loudly instead.
    assert len(text_embs) == len(text_docs), (
        "expected one embedding per document, got {} embeddings for {} documents".format(
            len(text_embs), len(text_docs)
        )
    )

    text_info_for_save = []

    for text_dict, text_emb, text_doc in zip(text_dicts, text_embs, text_docs):
        text_info_for_save.append({"text_doc": text_dict, "text_emb": text_emb})
        # Attach the embedding to the document before the index is generated.
        text_doc.embedding = text_emb

    with open(os.path.join(save_folder, f"text-embs-chunk_size={chunk_size}.json"), "w") as f:
        json.dump(text_info_for_save, f)

    # NOTE(review): emb_size=1536 is presumably the output dimension of embed_model;
    # confirm and update if the embedding model changes.
    generate_vectorindex(
        embeddings=embed_model,
        emb_size=1536,
        documents=text_docs,
        output_directory=db_directory,
        emb_store_type="chroma",
        chunk_size=chunk_size,
        index_name="texts"
    )

    # Reload the index from disk and query through that loaded copy.
    texts_index = load_vectorindex(
        db_directory=db_directory,
        emb_store_type="chroma", index_name="texts",
    )

    for top_k in search_space["similarity_top_k"]:
        text_retriever = texts_index.as_retriever(similarity_top_k=top_k)

        # Three parallel lists: entry i of each list belongs to test case i.
        retrieval_dataset = {
            "question": [], "contexts": [], "ground_truths": []
        }

        # NOTE(review): `variant` is bound to the 'ACR scenario' column, not the
        # 'Variant' column -- confirm this is the intended ground-truth source.
        for query, variant, guideline in zip(testcase_df["queries"], testcase_df["ACR scenario"], testcase_df["Guideline"]):
            correct_variant = "Condition: {}\nPatient Category: {}".format(guideline, variant)
            retrieved_nodes = text_retriever.retrieve(query)
            retrieval_dataset["question"].append(query)
            retrieval_dataset["ground_truths"].append(correct_variant)
            retrieval_dataset["contexts"].append(
                [node_with_score.node.text for node_with_score in retrieved_nodes]
            )

        with open(os.path.join(save_folder, f"retrieval_dataset_FreeText-Chunk_size={chunk_size}-K={top_k}.json"), "w") as f:
            json.dump(retrieval_dataset, f)

Generating embeddings:   2%|▏         | 30/1827 [00:02<02:18, 13.02it/s]