In [1]:
# Environment and import-path setup for the notebook, followed by the project imports.
import os, sys
# Expose only GPU index 0 to this process. This is assigned before the project
# modules below are imported (presumably so it is in place before any of them
# initialises CUDA — TODO confirm).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Make the sibling ../src directory importable. The path is resolved against the
# kernel's current working directory, so the notebook must be started from its own folder.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../src")))

from tqdm import tqdm
# Project modules located in ../src.
# NOTE(review): later cells use names that are never imported explicitly in this
# notebook (e.g. `LLM`, `EMBED_MODEL`, `torch`, `pd`, `AutoModelForCausalLM`,
# `quantization_config`, `chat_template`), so they presumably arrive through these
# wildcard imports. Explicit imports would make the dependencies visible — verify
# which module provides each name before changing anything here.
from rag_prompt_template import *
from rag_util import *
from rag_moduler import *
from rag_extraction import *
from grammar_llm_utils import *
import json

Initialise RAG pipeline
------

In [None]:
# Experiment selectors. `using_llm` / `using_embed` are keys into the LLM and
# EMBED_MODEL lookup tables; `task` and `eval_dataset` only feed the result-file
# identifier (`test_id`) built at the end of this cell.
using_llm = "mistralsmall"
using_embed = "hitsnomed"
task = "tripleextraction"
eval_dataset = "processed_data_pubmed_aggregated_annotations_5"

# Central configuration shared by the cells below.
PARAMETERS = dict(
    # Model identifiers: the same LLM entry is used for weights and tokenizer.
    llm_model_name=LLM[using_llm],
    tokenizer_name=LLM[using_llm],
    embed_model_name=EMBED_MODEL[using_embed],
    # Partial KG index used for testing. The full KG index lives at
    # "index/snomed_all_dataset_nodoc_hitsnomed" (swap it in here to use it).
    storage_dir="../index/snomed_dataset_nodoc_commandr_hitsnomed",
    input_text_dir="../data/humandx_data/humandx_findings.json",
    # Generation limits.
    context_window=32768,
    max_new_tokens=1024,
    case_num=50,
    verbose=True,
    # Retrieval settings.
    # NOTE(review): graph_store_query_depth is 5 here, while the retriever cell
    # passes a literal depth of 2 — confirm which value is intended.
    similarity_top_k=30,
    graph_store_query_depth=5,
    retriever_mode="hybrid",
    # Suffix used to name the result CSV written by the evaluation cell.
    test_id=f"_test_{task}_{eval_dataset}_{using_llm}_{using_embed}_grammar_predicate_constrained",
)

In [None]:
# Create the LLM service context from the shared configuration.
# Only the settings named below are forwarded as keyword arguments;
# `quantization_config` is intentionally not passed, so whatever
# init_llm_service_context does without it applies.
llm_config_keys = (
    "llm_model_name",
    "tokenizer_name",
    "embed_model_name",
    "context_window",
    "max_new_tokens",
)
llm = init_llm_service_context(**{key: PARAMETERS[key] for key in llm_config_keys})



Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

In [None]:
# Load the knowledge-graph index from the configured storage directory,
# wiring it to the LLM service context created above.
kg_index = init_kg_storage_context(
    llm,
    storage_dir=PARAMETERS["storage_dir"],
)

In [None]:
# Build the knowledge-graph retriever on top of the loaded index.
# `similarity_top_k` is read from PARAMETERS so the value is defined in exactly
# one place (it was previously repeated here as a literal 30).
# NOTE(review): the depth and verbosity below are kept as the literals this cell
# has always used (2 / False); PARAMETERS declares graph_store_query_depth=5 and
# verbose=True, which are NOT applied here — confirm which values are intended
# before unifying them, since changing the depth changes what is retrieved.
retriever = init_retriever(
    kg_index=kg_index,
    similarity_top_k=PARAMETERS["similarity_top_k"],
    graph_store_query_depth=2,
    verbose=False,
)

Retriever created, retriever: <class 'llama_index.core.indices.knowledge_graph.retrievers.KGTableRetriever'>, retriever_mode: hybrid


Triple Extraction example
------

In [None]:
# Worked example: retrieve reference triples for one abstract and assemble the
# extraction prompt. The cell ends by displaying the prompt; nothing is generated here.

text = "Results: Clinical data from our institution reveals that Leak of cranial cerebrospinal fluid due to and following procedure on central nervous system shows pathological morphology typical of Cerebrospinal fluid leakage, establishing a critical pathophysiological relationship. Moreover, Cerebrospinal fluid leakage belongs to the category of Morphologically Abnormal Structure. Clinical manifestations develop through well-defined pathological processes affecting specific organ systems and cellular functions. Accurate diagnosis depends on recognition of characteristic clinical patterns, appropriate use of diagnostic testing, and careful interpretation of results within the clinical context. Clinical management requires systematic approach including accurate diagnosis, appropriate treatment selection, and ongoing monitoring of therapeutic response. Treatment protocols emphasize patient safety, efficacy optimization, and quality of life considerations. Comprehensive care includes patient education, support services, and coordination with healthcare team members. Long-term prognosis is generally positive with early intervention and appropriate ongoing management. Success depends on multifactorial considerations including patient characteristics, disease severity, treatment response, and adherence to recommended care protocols. Systematic follow-up ensures continued treatment effectiveness."

# Output-format instructions shared with the evaluation cell further down.
system_prompt= """You are a syntax generator that produces sentences conforming to the following format:
    [ (SUBJECT; PREDICATE; OBJECT), (SUBJECT; PREDICATE; OBJECT), ... ]\n\n
    Rules to follow:
    Each triple must be enclosed in parentheses ()
    Each triple must be separated by a comma
    The entire list of triples must be enclosed in square brackets [ ... ]
    Generate at least two triples per output, optionally more
"""

retrieved_results = retriever.retrieve(text)
# Select the node that actually carries the relationship texts by its metadata key
# instead of assuming it sits at list position 1: a positional lookup raises
# IndexError/KeyError as soon as the retriever returns a different number of nodes.
# Falls back to an empty list when no returned node has relationship texts.
# (Presumably only the relationship summary node carries this key — TODO confirm
# against the retriever implementation.)
retrieved_triples = next(
    (
        hit.node.metadata['kg_rel_texts']
        for hit in retrieved_results
        if 'kg_rel_texts' in hit.node.metadata
    ),
    [],
)

prompt = f"""{system_prompt}

Here are the triples to use as a reference: {retrieved_triples}

Now extract the triples from the following text: {text}
"""

prompt

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


'You are a syntax generator that produces sentences conforming to the following format:\n    [ (SUBJECT; PREDICATE; OBJECT), (SUBJECT; PREDICATE; OBJECT), ... ]\n\n\n    Rules to follow:\n    Each triple must be enclosed in parentheses ()\n    Each triple must be separated by a comma\n    The entire list of triples must be enclosed in square brackets [ ... ]\n    Generate at least two triples per output, optionally more\n\n\nHere are the triples to use as a reference: ["(\'International neuroblastoma pathology classification: Favorable histology group, patient of any age with ganglioneuroma (Schwannian stroma-dominant) maturing, or mature (finding)\', \'associated morphology\', \'Lesion (morphologic abnormality)\')", "(\'Intracranial hemorrhage following injury with prolonged loss of consciousness AND return to pre-existing conscious level (disorder)\', \'type\', \'Disorder\')", "(\'Cerebrovascular accident due to thrombus of left middle cerebral artery (disorder)\', \'due to\', \'Thro

In [6]:
# Select the compute device and weight precision, then load the generation model
# and its tokenizer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == "cuda":
    # Only query the GPU when one exists: torch.cuda.get_device_name raises
    # on a machine without CUDA, which defeated the CPU fallback above.
    gpu_name = torch.cuda.get_device_name(0)
    print(f"Using GPU: {gpu_name}")
    # GPUs whose name contains one of these tags are loaded in bfloat16;
    # every other GPU uses float16.
    bf16_gpu_tags = ("A100", "L4", "H100")
    if any(tag in gpu_name for tag in bf16_gpu_tags):
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
else:
    gpu_name = None
    # Full precision on CPU. NOTE(review): the device_map / quantization settings
    # below may still require a GPU — confirm before relying on a CPU run.
    dtype = torch.float32

# NOTE(review): `quantization_config` is not defined in this notebook; it is
# presumably provided by one of the wildcard imports in the first cell — verify.
model = AutoModelForCausalLM.from_pretrained(
    LLM[using_llm],
    torch_dtype=dtype,
    device_map="balanced",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(LLM[using_llm])

Using device: cuda
Using GPU: NVIDIA A100-PCIE-40GB


Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

In [None]:
# Build the output grammar, then derive the parsing table and the
# terminal-to-token mapping for the tokenizer loaded above.
grammar_mode = "entpred"  # the alternative mode is "pred"
productions, regex_dict = init_grammar(MODE=grammar_mode)

pars_table, map_terminal_tokens = get_parsing_table_and_map_tt(
    tokenizer, productions=productions, regex_dict=regex_dict
)

Initializing grammar in entpred mode...


In [None]:
# Evaluate on the eval dataset: for every row, retrieve reference triples, build
# the prompt, run grammar-constrained generation, and finally write gold vs.
# predicted triples to a timestamped CSV.
#
# These two imports stay local to the cell. NOTE(review): the wildcard imports in
# the first cell could rebind `datetime` / `tqdm`; importing here guarantees the
# expected objects — verify before hoisting them to the top of the notebook.
from datetime import datetime
from tqdm import tqdm
now = datetime.now().strftime("%Y%m%d_%H%M%S")

# Resolve the output location up front and make sure the directory exists, so a
# missing ../results folder cannot discard all generations at the final to_csv.
results_path = f"../results/{PARAMETERS['test_id']}_{now}.csv"
os.makedirs(os.path.dirname(results_path), exist_ok=True)

# NOTE(review): this rebinds `eval_dataset` (a dataset label in the config cell)
# to a file path. PARAMETERS['test_id'] was already built from the label, so the
# output file name is unaffected; the name is kept for backward compatibility.
eval_dataset = "../data/dataset_example/processed_pubmed_annotations_5/csv/triple_extraction_dataset_eval.csv"
eval_df = pd.read_csv(eval_dataset)
# Normalise the gold triples by dropping the space after ';' and after '),'.
# The vectorised, literal (non-regex) replace leaves missing values untouched
# instead of raising on them.
eval_df['output'] = (
    eval_df['output']
    .str.replace("; ", ";", regex=False)
    .str.replace("), ", "),", regex=False)
)

predictions = []

# `total` lets tqdm render a real progress bar; iterrows() has no length of its own.
for index, row in tqdm(eval_df.iterrows(), total=len(eval_df)):
    input_text = row['input']

    print(f"Processing case {index+1}/{len(eval_df)}: {input_text[:50]}...")
    # Pick the node carrying the relationship texts by metadata key rather than by
    # list position, so an unexpected result shape yields an empty list (handled
    # just below) instead of an IndexError/KeyError that aborts the whole run.
    retrieved_nodes = retriever.retrieve(input_text)
    retrieved_triples = next(
        (
            hit.node.metadata['kg_rel_texts']
            for hit in retrieved_nodes
            if 'kg_rel_texts' in hit.node.metadata
        ),
        [],
    )
    print(f"length of retrieved triples: {len(retrieved_triples)}")

    if len(retrieved_triples) == 0:
        # Keep `predictions` aligned with eval_df rows by recording a placeholder.
        print("No triples retrieved, skipping this case.")
        predictions.append("No triples retrieved")
        continue
    
    prompt = f"""
{system_prompt}

Here are the triples to use as a reference: {retrieved_triples}

Now extract the triples from the following text: {input_text}

"""

    # A fresh logit processor / streamer pair is created for every case.
    LogitProcessor, Streamer = generate_grammar_parameters(tokenizer, pars_table, map_terminal_tokens)

    with torch.no_grad():
        # An earlier variant of this call additionally passed
        # do_sample=True, temperature=0.7, top_p=0.9.
        output = generate_text(model, tokenizer, prompt, LogitProcessor, Streamer, chat_template, max_new_tokens=500)
        predictions.append(output)
    
    print(f"Generated output: {output}")

# One prediction (or placeholder) per row, in row order.
eval_df['pred_output'] = predictions
eval_df[['output','pred_output']].to_csv(results_path, index=False)


0it [00:00, ?it/s]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Processing case 1/5: Traditional, trans-cervical thyroidectomy results ...
length of retrieved triples: 30


1it [00:07,  7.22s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Generated output: [('Trandolapril';'type';'Drug or medicament')]
Processing case 2/5: Cardiac sarcoidosis is an inflammatory myocardial ...
length of retrieved triples: 30


2it [00:20, 10.60s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Generated output: [('Cardiac sarcoidosis';'is modification of';'Inflammatory myopathy with abundant macrophages')]
Processing case 3/5: Tracheal resection and anastomosis surgery is a sa...
length of retrieved triples: 30


3it [00:28,  9.69s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Generated output: [('Tracheal stenosis following tracheostomy';'associated morphology';'Stenosis of trachea')]
Processing case 4/5: Airborne aerosol transmission, an established mech...
length of retrieved triples: 30


4it [00:37,  9.31s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Generated output: [('Airway trauma';'due to';'Upper respiratory inflammation caused by chemical fumes')]
Processing case 5/5: Adenoid cystic carcinoma (ACC) is the most common ...
length of retrieved triples: 30


5it [00:47,  9.57s/it]

Generated output: [('Adenoid cystic carcinoma of lacrimal gland';'associated morphology';'Adenoid cystic carcinoma of lacrimal gland')]



