In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

import re
import json
import os
from typing import Dict, List, Tuple
import gc

import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='transformers.generation.utils')

In [2]:
# model_path = "meta-llama/Llama-2-13b-chat-hf"
# tokenizer = AutoTokenizer.from_pretrained(model_path)
# model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", load_in_4bit=True)


# Load the GPTQ-quantised Llama-2 13B chat checkpoint and wrap it in a
# text-generation pipeline.
#
# This cell rebinds the module-level name `pipeline` to the built pipeline
# object (later cells call `pipeline(prompt)`). The factory function is therefore
# re-imported here under a private alias: without it, running this cell a second
# time would call the previously built pipeline object instead of the factory.
from transformers import pipeline as _build_pipeline

model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="auto",
    trust_remote_code=False,
    revision="gptq-8bit-128g-actorder_True",
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
pipeline = _build_pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # NOTE(review): the recorded output of this notebook warns that generation
    # can exceed the model's predefined maximum length (2048); confirm that a
    # 4000-token budget is intended.
    max_new_tokens=4000,
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    top_k=40,
    repetition_penalty=1.1,
    max_time=300,
)

In [3]:
def ehr_kg_prompting_question(term, topics, text, category):
    """Build the triple-extraction prompt for one crawled passage.

    The prompt uses ``[INST]`` / ``<<SYS>>`` tags: the system section holds the
    extraction rules and a worked "Heart Failure" example, and the instruction
    section carries the actual inputs under a "Given information:" header.

    :param term: Name of the entity the passage is about (e.g. a disease name).
    :param topics: Topic / section title of the passage.
    :param text: The crawled passage itself.
    :param category: Kind of entity ``term`` is (e.g. ``'disease'``); it is
        interpolated into the instructions and the "<category> name:" label.
    :return: The complete prompt string.
    """
    # The template is runtime input to the model, so its wording and whitespace
    # are kept exactly as they are (including the leading indentation).
    return f"""
    <s>[INST] <<SYS>>
    Given a crawled text about specific topic of certain {category}, please find tripples relatied to the given {category} in terms of crawled text.
    • Filling triples in updates based on given information and strictly following output style of example updates.
    • Each update should be exactly in format of [ENTITY 1, RELATIONSHIP, ENTITY 2], and the relationship is directed.
    • Both ENTITY 1 and ENTITY 2 should be noun, and one of them must be or highly similar to the current {category}.
    • Just output triples once, don't output duplicately.
    • It is possible that {category} name not exactly matched in crawled text (abbreviated or partly matched), consider it as the same thing.

    Example:
    disease name: Heart Failure
    topics: Overview
    crawled text:
    Heart failure occurs when the heart muscle doesn't pump blood as well as it should. When this happens, blood often backs up and fluid can build up in the lungs, causing shortness of breath. Certain heart conditions gradually leave the heart too weak or stiff to fill and pump blood properly. These conditions include narrowed arteries in the heart and high blood pressure. Proper treatment may improve the symptoms of heart failure and may help some people live longer. Lifestyle changes can improve quality of life. Try to lose weight, exercise, use less salt and manage stress. But heart failure can be life-threatening. People with heart failure may have severe symptoms. Some may need a heart transplant or a device to help the heart pump blood. Heart failure is sometimes called congestive heart failure.
    updates:
    [Heart Failure, IS_CAUSED_BY, Narrowed Arteries], 
    [Heart Failure, IS_CAUSED_BY, High Blood Pressure],
    [Heart Failure, HAS_SYMPTOMS, Shortness of Breath],
    [Heart Failure, HAS_SYMPTOMS, Fluid Build-up in Lungs],
    [Heart Failure, NEEDS_TREATMENT, Proper Treatment],
    [Heart Failure, NEEDS_TREATMENT, Lifestyle Changes]
    <</SYS>>
    Given a crawled text about specific topic of certain {category}, please find tripples relatied to the given {category} in terms of crawled text.
    Given information:
    {category} name: {term}
    topics: {topics}
    crawled text: {text}
    updates:
    [/INST]
    """

    
# Matches "[A, REL, B]". The first two fields may not contain commas or
# brackets; the last field may contain commas. The negative lookaheads
# (case-insensitive) skip the literal template "[ENTITY 1, RELATIONSHIP, ENTITY 2]"
# that appears in the prompt instructions.
_TRIPLE_PATTERN = re.compile(
    r'\[(?!ENTITY 1)([^\[\],]+), (?!RELATIONSHIP)([^\[\],]+), (?!ENTITY 2)([^\[\]]+)\]',
    re.IGNORECASE,
)


def extract_triples(text):
    """Pull ``(entity_1, relationship, entity_2)`` triples out of generated text.

    Only the part of ``text`` from the first "Given information:" marker onwards
    is scanned, so the worked example that precedes the marker in the prompt is
    not picked up.

    :param text: Text to scan (prompt followed by the model's answer).
    :return: List of 3-tuples of strings, in order of appearance. An empty list
        is returned when the marker is missing; callers iterate the result, so
        returning a message string here would be iterated character by character.
    """
    start_index = text.find("Given information:")
    if start_index == -1:
        return []
    return _TRIPLE_PATTERN.findall(text[start_index:])

def clear_cuda_memory():
    """Free Python garbage, then release PyTorch's unused cached CUDA memory.

    The collector runs first: tensors that are only reclaimed by ``gc.collect()``
    (e.g. ones held in reference cycles) must be freed before ``empty_cache()``
    can hand their now-unoccupied blocks back to the device.

    :return: None
    """
    gc.collect()
    torch.cuda.empty_cache()

In [4]:
# Smoke test: run the extraction prompt once on a single sample passage.
# `term`, `topics` and `text` stay in the notebook namespace and are reused by
# the later cell that exercises `prompt_triple_updates`.
term = 'Cyclothymia (cyclothymic disorder)'
topics = 'Overview'
text = \
"""
Cyclothymia (sy-kloe-THIE-me-uh), also called cyclothymic disorder, is a rare mood disorder. Cyclothymia causes emotional ups and downs, but they're not as extreme as those in bipolar I or II disorder. With cyclothymia, you experience periods when your mood noticeably shifts up and down from your baseline. You may feel on top of the world for a time, followed by a low period when you feel somewhat down. Between these cyclothymic highs and lows, you may feel stable and fine. Although the highs and lows of cyclothymia are less extreme than those of bipolar disorder, it's critical to seek help managing these symptoms because they can interfere with your ability to function and increase your risk of bipolar I or II disorder. Treatment options for cyclothymia include talk therapy (psychotherapy), medications and close, ongoing follow-up with your doctor.
"""
category = 'disease'

# Build the prompt, generate once, and parse the triples out of the first
# returned sequence.
# NOTE(review): `extract_triples` searches for the "Given information:" marker
# that comes from the prompt, so 'generated_text' presumably contains the prompt
# followed by the completion — confirm against the pipeline settings.
qs = ehr_kg_prompting_question(term, topics, text, category)
sequences = pipeline(f'{qs}\n')
extracted_triples = extract_triples(sequences[0]['generated_text'])
for triple in extracted_triples:
    print(triple)


# Collect garbage and empty the CUDA cache after the generation call.
clear_cuda_memory()

('Cyclothymia', 'IS_A', 'Mood Disorder')
('Cyclothymia', 'CAUSES', 'Emotional Ups and Downs')
('Cyclothymia', 'HAS_SYMPTOMS', 'Shift in Mood')
('Cyclothymia', 'HAS_SYMPTOMS', 'Interference with Functioning')
('Cyclothymia', 'INCREASES_RISK', 'Bipolar I or II Disorder')
('Cyclothymia', 'NEEDS_TREATMENT', 'Talk Therapy')
('Cyclothymia', 'NEEDS_TREATMENT', 'Medications')
('Cyclothymia', 'NEEDS_FOLLOW_UP', 'Close Ongoing Follow-up with Doctor')


In [5]:
def prompt_triple_updates(term, topics, text):
    """Prompt the model for one passage and return the extracted triples.

    The prompt is always built with ``category='disease'``.

    :param term: Disease name the passage is about.
    :param topics: Topic / section title of the passage.
    :param text: The crawled passage.
    :return: Whatever ``extract_triples`` returns for the generated text, or the
        sentinel list ``['Prompting Error']`` when generation or parsing raised.
    """
    qs = ehr_kg_prompting_question(term, topics, text, category='disease')
    try:
        sequences = pipeline(f'{qs}\n')
        extracted_triples = extract_triples(sequences[0]['generated_text'])
    except Exception as exc:
        # Catch `Exception` rather than everything: a bare `except:` also
        # swallows KeyboardInterrupt, which makes a long batch run impossible
        # to stop from the keyboard.
        print(f"Prompting Error founded in {term} # {topics}")
        # Report the cause as well, so failed items can be diagnosed later.
        print(f"\t{type(exc).__name__}: {exc}")
        extracted_triples = ['Prompting Error']

    return extracted_triples

In [6]:
# Run the sample passage through the error-handling wrapper.
# Relies on `term`, `topics` and `text` defined by the earlier sample cell.
extracted_triples = prompt_triple_updates(term, topics, text)
for triple in extracted_triples:
    print(triple)

# Collect garbage and empty the CUDA cache after the generation call.
clear_cuda_memory()

('Cyclothymia', 'IS_CAUSED_BY', 'Rare Mood Disorder')
('Cyclothymia', 'HAS_SYMPTOMS', 'Emotional Ups and Downs')
('Cyclothymia', 'HAS_SYMPTOMS', 'Less Extreme Than Bipolar Disorder')
('Cyclothymia', 'NEEDS_TREATMENT', 'Talk Therapy (Psychotherapy)')
('Cyclothymia', 'NEEDS_TREATMENT', 'Medications')
('Cyclothymia', 'NEEDS_TREATMENT', 'Close, Ongoing Follow-up with Doctor')


In [7]:
# Function for OrphaNet
def save_processed_results(data: List, output_path: str, start_point: str = None):
    """Extract triples for OrphaNet disease records and append them to a text file.

    For each record, the texts stored under the keys 'definition', 'prevalence',
    'epidemiology', 'clinical_description' and 'management_and_treatment' are
    sent through ``prompt_triple_updates``; empty values and the placeholders
    '-' and 'Unknown' are skipped. All triples of one disease are appended to
    ``OrphaNet_triples.txt`` under a "<disease> # <id>" header, and
    ``OrphaNet_triples_progress.txt`` is overwritten with the last finished
    "<disease>,<id>" pair.

    :param data: List of dicts, each with 'disease_id' and 'disease_name' keys
        plus the optional text fields listed above.
    :param output_path: Directory that receives the two output files.
    :param start_point: Disease name to resume from (inclusive). ``None``
        processes every record.
    :return: None
    """
    prompted_topics = ('definition', 'prevalence', 'epidemiology',
                       'clinical_description', 'management_and_treatment')

    starting = start_point is None

    for disease_info in data:
        oid = disease_info['disease_id']
        disease = disease_info['disease_name']
        if not starting and start_point and start_point == disease:
            starting = True
        if not starting:
            continue

        d_list = []
        for topic, text in disease_info.items():
            if topic not in prompted_topics:
                continue
            if not text or text in ('-', 'Unknown'):
                continue
            d_list.extend(prompt_triple_updates(disease, topic, text))
            clear_cuda_memory()

        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
        # raises UnicodeEncodeError on characters it cannot represent, which
        # would abort the batch in the middle of a write.
        # NOTE(review): files already written with a different encoding will
        # end up mixed if they contain non-ASCII text — confirm before appending.
        with open(f"{output_path}/OrphaNet_triples.txt", 'a', encoding='utf-8') as file:
            file.write(f"{disease} # {oid}\n")
            for item in d_list:
                file.write(f"\t{item}\n")
            file.write("\n")

        with open(f"{output_path}/OrphaNet_triples_progress.txt", 'w', encoding='utf-8') as progress_file:
            progress_file.write(f"{disease},{oid}")

    if not starting:
        # A resume point that never matched used to look like a successful run.
        print(f"Start point {start_point!r} was not found; nothing was processed.")

# Driver for the OrphaNet data: load the records and extract triples, passing
# `start_from` as the disease name to resume from.
# NOTE(review): absolute, machine-specific path — consider a configurable data
# directory so the notebook runs elsewhere.
output_path = 'C:/Users/Humphrey/Desktop/Extra knowledge for disease/OrphaNet'
with open(f"{output_path}/diseases_data.json", 'r') as file:
    data = json.load(file)
start_from = 'Nail anomaly' # earlier resume point: 'Congenital dyserythropoietic anemia type I'
# earlier resume points: 'Hereditary motor and sensory neuropathy type 6', 'Dysostosis, Stanescu type'
save_processed_results(data, output_path, start_from)
# Prompting Error founded in Dysostosis, Stanescu type # clinical_description

This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (2048). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.


In [None]:
# Function for RareDisease
def save_processed_results(data: Dict, output_path: str, start_point: Tuple[str, str] = None):
    """Extract triples for every (disease, topic) text of the RareDisease data.

    Each non-empty text is sent through ``prompt_triple_updates`` and the result
    is appended to ``RareDisease_triples.txt`` under a "<disease> # <topic>"
    header. After every visited topic, ``RareDisease_triples_progress.txt`` is
    overwritten with the "<disease>,<topic>" pair just handled.

    :param data: Mapping of disease name -> {topic: text}.
    :param output_path: Directory that receives the two output files.
    :param start_point: Optional ``(disease, topic)`` pair to resume from
        (inclusive). Diseases before the given disease and topics before the
        given topic inside that disease are skipped. If the topic does not exist
        for that disease, the whole disease is processed. ``None`` (or an empty
        value) processes everything.
    :return: None
    """
    if start_point:
        resume_disease, resume_topic = start_point[0], start_point[1]
        disease_found = topic_found = False
    else:
        resume_disease = resume_topic = None
        disease_found = topic_found = True

    for disease, topics in data.items():
        if not disease_found:
            if disease != resume_disease:
                continue
            disease_found = True
            if resume_topic not in topics:
                # Unknown topic: fall back to processing this disease from its
                # first topic instead of silently skipping all of it.
                print(f"Topic {resume_topic!r} not found for {disease!r}; starting at its first topic.")
                topic_found = True

        for topic, text in topics.items():
            # The topic half of the resume point must be tracked separately from
            # the disease half; a single shared flag is already set once the
            # disease matches, so the topic would never be honoured and already
            # processed topics would be appended to the output a second time.
            if not topic_found:
                if topic != resume_topic:
                    continue
                topic_found = True

            if text:
                processed_list = prompt_triple_updates(disease, topic, text)
                # Explicit UTF-8 so characters outside the platform default
                # encoding (e.g. cp1252 on Windows) cannot abort the write.
                with open(f"{output_path}/RareDisease_triples.txt", 'a', encoding='utf-8') as file:
                    file.write(f"{disease} # {topic}\n")
                    for item in processed_list:
                        file.write(f"\t{item}\n")
                    file.write("\n")
                clear_cuda_memory()

            with open(f"{output_path}/RareDisease_triples_progress.txt", 'w', encoding='utf-8') as progress_file:
                progress_file.write(f"{disease},{topic}")

    if not disease_found:
        print(f"Start point {start_point!r} was not found; nothing was processed.")


# Driver for the RareDisease data: load the JSON and extract triples for it.
# `start_from = None` means no resume point is given.
# NOTE(review): absolute, machine-specific path — consider a configurable data
# directory so the notebook runs elsewhere.
output_path = 'C:/Users/Humphrey/Desktop/Extra knowledge for disease/Rare Disease'
with open(f"{output_path}/RareDisease_info.json", 'r') as file:
    data = json.load(file)
start_from = None
save_processed_results(data, output_path, start_from)

In [None]:
# Function for Wiki Extra Knowledge
def save_processed_results(data: Dict, output_path: str, start_point: Tuple[str, str] = None):
    """Extract triples from the 'Extra Knowledge' sections of the Wiki data.

    Each non-empty text under a disease's 'Extra Knowledge' mapping is sent
    through ``prompt_triple_updates`` and the result is appended to
    ``wiki_triples.txt`` under a "<disease key> # <topic>" header. After every
    visited topic, ``wiki_triples_progress.txt`` is overwritten with the
    "<disease key>,<topic>" pair just handled. Diseases without
    'Extra Knowledge' are skipped.

    :param data: Mapping of disease key -> content dict. Keys of the form
        "<prefix>#<name>" are prompted with only the part after the '#'; the
        full key is still what gets written to the output files.
    :param output_path: Directory that receives the two output files.
    :param start_point: Optional ``(disease key, topic)`` pair to resume from
        (inclusive). Diseases before the given key and topics before the given
        topic inside that disease are skipped. If the topic does not exist for
        that disease, the whole disease is processed. ``None`` (or an empty
        value) processes everything.
    :return: None
    """
    if start_point:
        resume_disease, resume_topic = start_point[0], start_point[1]
        disease_found = topic_found = False
    else:
        resume_disease = resume_topic = None
        disease_found = topic_found = True

    for disease, content in data.items():
        extra_knowledge = content.get('Extra Knowledge', {}) or {}

        # Match the resume disease before the "no extra knowledge" skip;
        # otherwise a resume disease without extra knowledge is never matched
        # and nothing at all gets processed.
        if not disease_found:
            if disease != resume_disease:
                continue
            disease_found = True
            if resume_topic not in extra_knowledge:
                # Unknown topic: fall back to processing this disease from its
                # first topic instead of silently skipping all of it.
                print(f"Topic {resume_topic!r} not found for {disease!r}; starting at its first topic.")
                topic_found = True

        if not extra_knowledge:
            continue

        for topic, text in extra_knowledge.items():
            # The topic half of the resume point must be tracked separately from
            # the disease half; a single shared flag is already set once the
            # disease matches, so the topic would never be honoured and already
            # processed topics would be appended to the output a second time.
            if not topic_found:
                if topic != resume_topic:
                    continue
                topic_found = True

            if text:
                disease_name = disease.split('#')[1] if '#' in disease else disease
                processed_list = prompt_triple_updates(disease_name, topic, text)
                # Explicit UTF-8 so characters outside the platform default
                # encoding (e.g. cp1252 on Windows) cannot abort the write.
                with open(f"{output_path}/wiki_triples.txt", 'a', encoding='utf-8') as file:
                    file.write(f"{disease} # {topic}\n")
                    for item in processed_list:
                        file.write(f"\t{item}\n")
                    file.write("\n")
                clear_cuda_memory()

            with open(f"{output_path}/wiki_triples_progress.txt", 'w', encoding='utf-8') as progress_file:
                progress_file.write(f"{disease},{topic}")

    if not disease_found:
        print(f"Start point {start_point!r} was not found; nothing was processed.")

# Driver for the Wiki data: load one extra-knowledge JSON file and extract
# triples for it. `start_from = None` means no resume point is given.
# NOTE(review): absolute, machine-specific path — consider a configurable data
# directory so the notebook runs elsewhere.
output_path = 'C:/Users/Humphrey/Desktop/Extra knowledge for disease/Wiki'
with open(f"{output_path}/D4_extra_knowledge.json", 'r') as file:
    data = json.load(file)
start_from = None
# Earlier resume points:
# ('437.3#Cerebral aneurysm nonruptured', 'Overview')
# ('345.6#Infantile spasms', 'Overview')
save_processed_results(data, output_path, start_from)

# D1, D2 completed
# Prompting Error founded in 493#Asthma # Notes
# Prompting Error founded in 200.3#Marginal zone lymphoma # Extranodal marginal zone lymphoma
# Prompting Error founded in Extrinsic asthma # Notes
# Prompting Error founded in Intrinsic asthma # Notes
# Prompting Error founded in Threatened abortion # Citations
# Prompting Error founded in Extrinsic asthma # Notes
# Prompting Error founded in Intrinsic asthma # Notes
# Prompting Error founded in Threatened abortion # Citations

This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (2048). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.


In [None]:
# Function for MayoClinic
def save_processed_results(data: Dict, output_path: str, start_point: Tuple[str, str] = None):
    """
    Extracts triples for every (disease, topic) text of the MayoClinic data and
    saves them to a text file.

    Each non-empty text is sent through ``prompt_triple_updates`` and the result
    is appended to ``mayoclinic_triples.txt`` under a "<disease> # <topic>"
    header. After every visited topic, ``mayoclinic_triples_progress.txt`` is
    overwritten with the "<disease>,<topic>" pair just handled.

    :param data: Nested dictionary with disease topics and texts.
    :param output_path: Directory that receives the two output files.
    :param start_point: A tuple of the disease and topic where processing should
        start or resume (inclusive). Diseases before the given disease and
        topics before the given topic inside that disease are skipped. If the
        topic does not exist for that disease, the whole disease is processed.
        ``None`` (or an empty value) processes everything.
    :return: None
    """
    # Check if we need to start from a specific point.
    if start_point:
        resume_disease, resume_topic = start_point[0], start_point[1]
        disease_found = topic_found = False
    else:
        resume_disease = resume_topic = None
        disease_found = topic_found = True

    for disease, topics in data.items():
        if not disease_found:
            if disease != resume_disease:
                # Skip diseases until we reach the starting point.
                continue
            disease_found = True
            if resume_topic not in topics:
                # Unknown topic: fall back to processing this disease from its
                # first topic instead of silently skipping all of it.
                print(f"Topic {resume_topic!r} not found for {disease!r}; starting at its first topic.")
                topic_found = True

        for topic, text in topics.items():
            # The topic half of the resume point must be tracked separately from
            # the disease half; a single shared flag is already set once the
            # disease matches, so the topic would never be honoured and already
            # processed topics would be appended to the output a second time.
            if not topic_found:
                if topic != resume_topic:
                    # Skip topics until we reach the starting point.
                    continue
                topic_found = True

            # Process the text for the current topic.
            if text:
                processed_list = prompt_triple_updates(disease, topic, text)
                # Append (never overwrite) so results of earlier runs are kept.
                # Explicit UTF-8 so characters outside the platform default
                # encoding (e.g. cp1252 on Windows) cannot abort the write.
                with open(f"{output_path}/mayoclinic_triples.txt", 'a', encoding='utf-8') as file:
                    file.write(f"{disease} # {topic}\n")
                    for item in processed_list:
                        file.write(f"\t{item}\n")
                    file.write("\n")  # Add an extra newline for separation between topics.
                # Dropping previous prompts considering the memory's limitation
                clear_cuda_memory()

            # Save the current position to a separate file for recovery.
            with open(f"{output_path}/mayoclinic_triples_progress.txt", 'w', encoding='utf-8') as progress_file:
                progress_file.write(f"{disease},{topic}")

    if not disease_found:
        print(f"Start point {start_point!r} was not found; nothing was processed.")

# Driver for the MayoClinic data: load the JSON and extract triples for it,
# passing `start_from` as the (disease, topic) resume point.
# Output files are written into the same directory the JSON is read from.
# NOTE(review): absolute, machine-specific path — consider a configurable data
# directory so the notebook runs elsewhere.
output_path = 'C:/Users/Humphrey/Desktop/Extra knowledge for disease/Mayoclinic'
with open(f"{output_path}/Mayoclinic_info.json", 'r') as file:
    data = json.load(file)
start_from = ('Chronic hives', 'Overview') # earlier resume point: ('Dry mouth', 'Overview')
# Earlier resume points:
# ('Myalgic encephalomyelitis/chronic fatigue syndrome (ME/CFS)', 'Overview')
# ('Type 1 diabetes', 'Overview') - Treatment
# ('Genital warts', 'Risk factors') # ('Broken collarbone', 'Risk factors')
# ('Pulmonary hypertension', 'Diagnosis') # ('Cyclothymia (cyclothymic disorder)', 'Symptoms')

# Run the extraction over the loaded data from the chosen resume point.
save_processed_results(data, output_path, start_from)

This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (2048). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
