Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Diagnosis Levels with correct API response #49 #57

Merged
merged 18 commits into from
Aug 15, 2024
Merged
25 changes: 12 additions & 13 deletions app/categorization/llm_categorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
AssessmentToolPrompt,
DiagnosisPrompt,
)
from categorization.llm_helper import SexLevel, AgeFormat, get_assessment_label
from categorization.llm_helper import SexLevel, AgeFormat, get_assessment_label,Diagnosis_Level,list_terms


def Diagnosis(
Expand All @@ -18,24 +18,23 @@ def Diagnosis(
{"column": key, "content": value}
)
reply = str(llm_response_Diagnosis)
print(reply)

# print(reply)
if "yes" in reply.lower():
values = value.split()
unique_values = list(set(values[1:]))

# Create dictionary for Levels
levels_dict = {val: "" for val in unique_values}

# Create the output dictionary
output = {"TermURL": "nb:Diagnosis", "Levels": levels_dict}

output = {"TermURL": "nb:Diagnosis", "Levels": {}}
unique_entries=list_terms(key,value)
levels={} #the empty dictionary passed to the diagnosis_level function to be filled
level={} # the dictionary which will become the output
level = Diagnosis_Level(unique_entries, code_system,levels)
output["Levels"] = level
print(json.dumps(output))
return output

else:
return AssessmentTool(key, value, code_system)



def AssessmentTool(
key: str, value: str, code_system: str
) -> Optional[Dict[str, Any]]:
Expand Down Expand Up @@ -88,7 +87,7 @@ def llm_invocation(
chainGeneral = GeneralPrompt | llm
key, value = list(result_dict.items())[0]
llm_response = chainGeneral.invoke({"column": key, "content": value})
print(llm_response)
# print(llm_response)
r = str(llm_response)
if "Participant_IDs" in r:
output = {"TermURL": "nb:ParticipantID"}
Expand Down
82 changes: 82 additions & 0 deletions app/categorization/llm_helper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

from datetime import datetime
import json
import re
Expand Down Expand Up @@ -113,6 +114,87 @@ def AgeFormat(result_dict: Dict[str, str], key: str) -> Dict[str, Any]:
return output




def list_terms(key, value):
    """Return the distinct whitespace-separated tokens of ``value``.

    Args:
        key: Token to leave out of the result (the caller passes the column
            name here).
        value: String that is split on whitespace to obtain the tokens.

    Returns:
        A list of the unique tokens in order of first appearance, without
        ``key``. An empty or whitespace-only ``value`` yields ``[]``.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result is reproducible from run to run; iterating a set of strings is
    # not, because string hashing is randomized per process.
    return [term for term in dict.fromkeys(value.split()) if term != key]



def is_score(input_string):
    """Tell whether ``input_string`` looks like a score value.

    Whitespace is ignored. The remaining text counts as a score when it is
    non-empty, consists only of digits and letters, and contains at most two
    letters (so "12", "3a" and "ab12" qualify, "abc1" and "1.5" do not).

    Args:
        input_string: The text to inspect.

    Returns:
        True if the text matches the rule above, otherwise False. An empty or
        whitespace-only string is never a score.
    """
    # Remove all whitespace before inspecting the characters.
    cleaned_string = re.sub(r'\s+', '', input_string)

    # Without this guard the character checks below would be vacuously true
    # for an empty string and report it as a score.
    if not cleaned_string:
        return False

    # Any character that is neither a digit nor a letter disqualifies it.
    if not all(c.isdigit() or c.isalpha() for c in cleaned_string):
        return False

    # Digits only (zero letters) or digits mixed with one or two letters.
    return sum(c.isalpha() for c in cleaned_string) <= 2

def are_all_digits(input_list):
    """Report whether every element of ``input_list`` is a digit-only string.

    An empty list yields True, since no element fails the check.
    """
    for item in input_list:
        # The first element that is not purely digits settles the answer.
        if not item.isdigit():
            return False
    return True

def Diagnosis_Level(
    unique_entries: List[str],
    code_system: str,
    levels: Union[Dict[str, Any], None] = None,
) -> Dict[str, Any]:
    """Fill ``levels`` with a label entry for each value in ``unique_entries``.

    Each entry is looked up in ``app/parsing/abbreviation_to_labels.json``.
    Entries found there get the stored value; unknown entries get
    ``["some score"]`` when they are digit strings and ``["left for user"]``
    otherwise. When every entry is a digit string, the column is treated as
    scores and ``levels`` is returned untouched without reading the file.

    Args:
        unique_entries: The distinct values to map.
        code_system: Accepted for interface compatibility; not used by this
            function.
        levels: Dictionary to fill in place. A new one is created when
            omitted.

    Returns:
        The ``levels`` dictionary (the same object that was passed in, if
        any).

    Raises:
        FileNotFoundError: If the lookup file is needed but cannot be found.
            The path is relative, so it depends on the working directory.
    """
    if levels is None:
        levels = {}

    # A purely numeric column holds scores, not categorical levels, so there
    # is nothing to look up and no reason to read the lookup file.
    if are_all_digits(unique_entries):
        return levels

    # Path to the abbreviation -> labels JSON file.
    file_path = 'app/parsing/abbreviation_to_labels.json'
    with open(file_path, 'r', encoding='utf-8') as file:
        abbreviation_to_label = json.load(file)

    for entry in unique_entries:
        if entry in abbreviation_to_label:
            levels[entry] = abbreviation_to_label[entry]
        elif entry.isdigit():
            # Numeric value inside an otherwise categorical column.
            levels[entry] = ["some score"]
        else:
            # Unknown abbreviation: placeholder for manual annotation.
            levels[entry] = ["left for user"]

    return levels




def get_assessment_label(key: str, code_system: str) -> Union[str, List[str]]:
def load_dictionary(file_path: str) -> Any:
with open(file_path, "r") as file:
Expand Down
1 change: 1 addition & 0 deletions app/parsing/abbreviation_to_labels.json

Large diffs are not rendered by default.

79 changes: 63 additions & 16 deletions app/parsing/json_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,26 @@ class Annotations(BaseModel): # type:ignore
IsAboutAssessmentTool,
]
Identifies: Optional[str] = None
Levels: Optional[Dict[str, Dict[str, str]]] = None

Levels: Optional[
Union[
Dict[str, List[Dict[str, str]]],
Dict[str, Dict[str, str]],
Dict[str, str],
Dict[str, List[str]],
# Add this to allow for lists of strings
]
] = None
Transformation: Optional[Dict[str, str]] = None
IsPartOf: Optional[Union[List[Dict[str, str]], Dict[str, str], str]] = None


class TSVAnnotations(BaseModel):  # type:ignore
    """Description, optional levels and annotations for one column."""

    Description: str
    # Per-level value is either a single string or a list of strings.
    Levels: Optional[Union[Dict[str, str], Dict[str, List[str]]]] = None
    Annotations: Annotations


Expand Down Expand Up @@ -114,7 +126,8 @@ def handle_age(parsed_output: Dict[str, Any]) -> TSVAnnotations:


def handle_categorical(
parsed_output: Dict[str, Any], levels_mapping: Mapping[str, Dict[str, str]]
parsed_output: Dict[str, Any],
levels_mapping: Mapping[str, List[Dict[str, str]]],
) -> TSVAnnotations:
termurl = parsed_output.get("TermURL")

Expand All @@ -127,10 +140,25 @@ def handle_categorical(
else:
raise ValueError(f"Unhandled TermURL: {termurl}")

levels = {
key: levels_mapping.get(value.strip().lower(), {})
for key, value in parsed_output.get("Levels", {}).items()
}
if termurl == "nb:Diagnosis":
levels = {
key: [
levels_mapping.get(item.strip().lower(), {})
for item in (value if isinstance(value, list) else [value])
]
for key, value in parsed_output.get("Levels", {}).items()
}
if termurl == "nb:Sex":
levels = {
key: (
levels_mapping.get(value[0].strip().lower(), {})
if isinstance(value, list)
else levels_mapping.get(value.strip().lower(), {})
)
for key, value in parsed_output.get("Levels", {}).items()
}

print(levels)

annotations = Annotations(IsAbout=annotation_instance, Levels=levels)
return TSVAnnotations(
Expand Down Expand Up @@ -202,14 +230,29 @@ def handle_assessmentTool(
def load_levels_mapping(mapping_file: str) -> Dict[str, Dict[str, str]]:
    """Load a label lookup table from a JSON file.

    The file is expected to contain a list of objects. Each object with a
    non-empty ``"label"`` produces one entry keyed by the stripped,
    lower-cased label.

    Args:
        mapping_file: Path of the JSON file to read.

    Returns:
        A dictionary mapping each normalized label to
        ``{"TermURL": <identifier>, "Label": <original label>}``. Entries
        without an ``"identifier"`` get the placeholder
        ``"default_identifier"``; entries without a usable label are skipped
        with a printed warning.
    """
    with open(mapping_file, "r", encoding="utf-8") as file:
        mappings = json.load(file)

    levels_mapping: Dict[str, Dict[str, str]] = {}
    for entry in mappings:
        # `or ""` also covers an explicit JSON null label.
        label = entry.get("label") or ""
        label_key = label.strip().lower()

        if not label_key:
            print(f"Warning: Missing or empty 'label' in entry: {entry}")
            continue

        levels_mapping[label_key] = {
            # Keep the entry usable even when the identifier is absent.
            "TermURL": entry.get("identifier") or "default_identifier",
            "Label": label,
        }

    return levels_mapping

# noqa: E501
def load_assessmenttool_mapping(
mapping_file: str,
) -> Mapping[str, Dict[str, str]]:
Expand Down Expand Up @@ -242,7 +285,9 @@ def process_parsed_output(
)
elif code_system == "snomed":
print("Using SNOMED CT terms for assessment tool annotation.")
assessmenttool_mapping_file = "app/parsing/measurementTerms.json"
assessmenttool_mapping_file = (
"app/parsing/abbreviations_measurementTerms.json"
)
assessmenttool_mapping = load_levels_mapping(
assessmenttool_mapping_file
)
Expand All @@ -255,7 +300,7 @@ def process_parsed_output(

termurl_to_function_with_levels: Dict[
str,
Callable[[Dict[str, Any], Mapping[str, Dict[str, str]]], Any],
Callable[[Dict[str, Any], Mapping[str, Dict[str, Any]]], Any],
] = {
"nb:Sex": handle_categorical,
"nb:Diagnosis": handle_categorical,
Expand Down Expand Up @@ -294,6 +339,8 @@ def process_parsed_output(
)
else:
return "Error: TermURL is missing from the parsed output"
else:
return "Error: parsed_output is not a dictionary"


def update_json_file(
Expand All @@ -303,7 +350,7 @@ def update_json_file(
data_dict = data.model_dump(exclude_none=True)
else:
data_dict = {"error": data}

# noqa: E501
try:
with open(filename, "r") as file:
file_data: Dict[str, Any] = json.load(file)
Expand Down
3 changes: 2 additions & 1 deletion app/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ def process_file(
try:
input_dict = {key: value}
llm_response = llm_invocation(input_dict, code_system)
print(llm_response)
result = process_parsed_output(llm_response, code_system) # type: ignore # noqa: E501
results[key] = result
update_json_file(result, json_file, key)
except Exception as e:
results[key] = {"error": str(e)}

return results
1 change: 1 addition & 0 deletions rag_documents/abbreviation_to_labels.json

Large diffs are not rendered by default.

Loading
Loading