# Download Wikipedia and unzip the bz2 file

# Cut the file down just to the infoboxes we care about
```
grep -A 50 -B 50 "Infobox medical condition (new)" enwiki-20230820-pages-articles-multistream.xml > medical-condition-box.txt
```

# Extract all Symptom <=> Name pairs

In [None]:
import time

# The Wikipedia dump is UTF-8 encoded XML; decode it explicitly so the read
# does not depend on the platform's default text encoding.
with open("./medical-condition-box.txt", 'r', encoding="utf-8") as wikipedia_file:
    wiki_data = wikipedia_file.read()

In [None]:
import re
# Record the character offset of every occurrence of the infobox template name
# in the extracted text.
infobox_marker = re.compile(r'Infobox medical condition \(new\)')
infobox_start_indexes = []
for marker_match in infobox_marker.finditer(wiki_data):
    infobox_start_indexes.append(marker_match.start())
len(infobox_start_indexes)

In [None]:
import re

# Matches a {{...}} template that opens and closes on the same line (non-greedy).
curly_bracket_regex = re.compile(r"((\{\{).*?(\}\}))")
# Matches an HTML-escaped tag: everything from "&lt;" up to the next "gt;".
refrence_regex = re.compile(r"((&lt;).*?(gt;))")
# Matches the "| symptoms =" field label of the infobox.
info_box_section_regex = re.compile(r"\| symptoms\s+=")

j = 0


def clean_wikipedia_string(string):
    """Strip wiki markup from a single infobox line.

    Removes escaped tags, ``{{...}}`` templates and the ``| symptoms =``
    label, then replaces every ``[[target|label]]`` link with ``label`` and
    every ``[[target]]`` link with ``target``.

    Unbalanced link markup (a ``[[`` without a following ``]]``) is left in
    the text as-is; previously such input made the loop spin forever because
    the slice between the markers was empty and nothing was ever replaced.

    Args:
        string: One line of infobox wikitext.

    Returns:
        The line with the markup described above removed.
    """
    string = refrence_regex.sub("", string)
    string = curly_bracket_regex.sub("", string)
    string = info_box_section_regex.sub("", string)

    while True:
        start_bracket = string.find("[[")
        if start_bracket == -1:
            break
        # Only a closing marker *after* the opening one can belong to this link.
        end_bracket = string.find("]]", start_bracket + 2)
        if end_bracket == -1:
            break

        bracket_string = string[start_bracket:end_bracket + 2]
        bracket_text = bracket_string.replace("[[", "").replace("]]", "")
        if "|" in bracket_text:
            # Piped link: keep the display label, drop the target.
            bracket_text = bracket_text.split("|")[-1]

        string = string.replace(bracket_string, bracket_text)

    return string

# Only this many characters after each infobox marker are searched.
MAX_CHARACTERS = 5000

# Collect {"name", "symptoms"} records, one per infobox that has exactly one
# non-empty symptoms line and exactly one non-empty name line.
all_medical_conditions = []
for infobox_start_index in infobox_start_indexes:
    info_box_text = wiki_data[infobox_start_index:infobox_start_index + MAX_CHARACTERS]

    # Find the "}}" that closes the infobox: whenever a "{{" opens before the
    # next "}}", that "}}" belongs to an inner template, so skip past it.
    start_index = 0
    while True:
        next_close = info_box_text.find("}}", start_index)
        next_open = info_box_text.find("{{", start_index)
        if next_close == -1:
            # No closing marker inside the window: use the whole window
            # instead of slicing with a bogus negative offset.
            end_index = len(info_box_text)
            break
        if next_open != -1 and next_open < next_close:
            start_index = next_close + 2
        else:
            end_index = next_close
            break

    info_box_focused = info_box_text[:end_index]
    provided_info = info_box_focused.split("\n")

    symptoms = [line for line in provided_info if "| symptoms" in line]
    if len(symptoms) != 1:
        continue
    symptoms_string = clean_wikipedia_string(symptoms[0]).strip()
    if symptoms_string == "":
        continue

    names = [line for line in provided_info if "| name" in line]
    if len(names) != 1:
        continue
    name_string = clean_wikipedia_string(names[0])
    # The value is whatever follows the last "=" on the name line.
    condition_name = name_string.split("=")[-1]
    if condition_name.strip() != "":
        all_medical_conditions.append({"name": condition_name, "symptoms": symptoms_string})
        
#     print("\n\n ========================== \n\n")


# Prompt ChatGPT to clean up symptoms

In [None]:
from openai import OpenAI
import pandas as pd
import time
import json
from tqdm import tqdm

# Credentials come from the environment instead of being hard-coded, so that
# secrets are never stored in the notebook or committed to version control.
import os

client = OpenAI(
  organization=os.environ.get("OPENAI_ORG_ID"),
  api_key=os.environ.get("OPENAI_API_KEY"),
)

def get_prompt(symptoms):
    """Build the user prompt asking the model to list the symptoms in *symptoms*.

    The prompt consists of the fixed instructions, a blank line, a
    ``SYMPTOMS:`` header and finally the given text.
    """
    instructions = (
        "Given the brief symptom text, extract a list of basically worded symptoms. "
        "Use as few words as possible. "
        "Make a list with each symptom on a new line and a dash to start each symptom. "
        "Only list the symptoms in the text."
    )
    return f"{instructions}\n\nSYMPTOMS:\n{symptoms}"

def load_dictionary_from_file(filename):
    """Return the JSON content of *filename*, or ``{}`` if the file does not exist."""
    try:
        file = open(filename, 'r')
    except FileNotFoundError:
        # Nothing saved yet: start from an empty dictionary.
        return {}
    with file:
        return json.load(file)

def save_dictionary_to_file(dictionary, filename):
    """Write *dictionary* to *filename* as JSON.

    The object is serialised before the file is opened: opening with ``'w'``
    truncates the file, so serialising first means an object that cannot be
    encoded raises ``TypeError`` without destroying the previously saved data.
    """
    serialized = json.dumps(dictionary)
    with open(filename, 'w') as file:
        file.write(serialized)

CACHE_FILE = "openai_cache.json"
openai_cache = load_dictionary_from_file(CACHE_FILE)

# Ask the model to turn each raw symptom text into a dash-prefixed list.
# Responses are cached by the raw symptom text and the cache is written to
# disk after every request, so an interrupted run can resume where it stopped.
for medical_condition in tqdm(all_medical_conditions):
    name_string = medical_condition["name"]
    symptom_string = medical_condition["symptoms"]

    if symptom_string in openai_cache:
        continue

    # Build the prompt once from the symptom text. The previous code formatted
    # the entire cache dictionary into a prompt that was never used.
    prompt = get_prompt(symptom_string)
    completion = client.chat.completions.create(
      model="gpt-3.5-turbo",
      messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
      ],
      temperature=0.0
    )
    content = completion.choices[0].message.content
    openai_cache[symptom_string] = content

    print("")
    print(name_string)
    print(content)
    save_dictionary_to_file(openai_cache, CACHE_FILE)

    # Short pause between requests.
    time.sleep(0.1)
    

# completion = client.chat.completions.create(
#   model="gpt-3.5-turbo",
#   messages=[
#     {"role": "system", "content": "You are a helpful assistant."},
#     {"role": "user", "content": get_prompt("'''Acute''': vomiting, abdominal pain, watery diarrhea'''Chronic''': thickened skin, darker skin, cancer")}
#   ]
# )

# print(completion.choices[0].message.content)

In [None]:
# Turn each cached model response (one "- symptom" per line) into a list of
# lower-cased symptom strings stored under "symptom_list".
for medical_condition in all_medical_conditions:
    raw_response = openai_cache[medical_condition["symptoms"]]
    parsed_symptoms = []
    for line in raw_response.split("\n"):
        # Remove only the leading list dash; hyphens inside a symptom
        # (e.g. "non-productive cough") are part of the text and are kept.
        symptom = line.strip().lstrip("-").strip().lower()
        if symptom:
            parsed_symptoms.append(symptom)
    medical_condition["symptom_list"] = parsed_symptoms

In [None]:
# Find most common symptoms
# Tally how often each symptom string occurs across all "symptom_list" entries.
symptom_frequency = {}

for medical_condition in all_medical_conditions:
    for symptom in medical_condition["symptom_list"]:
        symptom_frequency[symptom] = symptom_frequency.get(symptom, 0) + 1

In [None]:
# Show every distinct symptom, one per line, in first-seen order.
print("\n".join(symptom_frequency))

# Clean up further by combining similar symptoms with an embedding model

In [None]:
import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Average the hidden states along dim 1, ignoring masked-out positions.

    Positions where ``attention_mask`` is zero are zeroed before summing, and
    the sum is divided by the number of non-masked positions per row.
    """
    padding_positions = ~attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(padding_positions, 0.0).sum(dim=1)
    token_counts = attention_mask.sum(dim=1)[..., None]
    return summed / token_counts


# Every distinct symptom, plus the same text with a "Symptoms " prefix that is
# what actually gets embedded.
symptom_names = list(symptom_frequency)
symptom_names_prompt_engineered = ["Symptoms {}".format(symptom_name) for symptom_name in symptom_names]

# Tokenizer and encoder are loaded from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-large-v2')
model = AutoModel.from_pretrained('intfloat/e5-large-v2')

def get_embeddings(texts):
    """Embed *texts* with the module-level ``tokenizer`` and ``model``.

    Args:
        texts: List of strings to embed.

    Returns:
        A tensor with one L2-normalised, mask-aware mean-pooled embedding per
        input text.
    """
    # Local import: this cell only imports torch.nn.functional and Tensor.
    import torch

    # Tokenize the input texts
    batch_dict = tokenizer(texts, max_length=512, padding=True, truncation=True, return_tensors='pt')

    # Inference only: without no_grad() the forward pass keeps the autograd
    # graph for the whole batch alive, which wastes a lot of memory.
    with torch.no_grad():
        outputs = model(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    # normalize embeddings
    return F.normalize(embeddings, p=2, dim=1)


embeddings = get_embeddings(symptom_names_prompt_engineered)
# scores = (embeddings[:2] @ embeddings[2:].T) * 100
# print(scores.tolist())

In [None]:
# Sanity check: expect one embedding row per symptom name.
print(embeddings.shape)

In [None]:
import numpy as np
# Pairwise dot products between all symptom embeddings. Only the upper
# triangle (diagonal included) is kept; everything below it is zeroed.
pairwise_similarity = embeddings @ embeddings.T
sim_matrix = np.triu(pairwise_similarity.detach().numpy())
sim_matrix.shape

In [None]:
# Convert to NumPy and index the rows by symptom text (Text => Embeddings).
embeddings = embeddings.detach().numpy()
embeddings_map = dict(zip(symptom_names, embeddings))

In [None]:
# Entries above 0.999 are zeroed in place first, then every remaining pair
# scoring above 0.97 is flagged with a 1 in the equivalency matrix.
sim_matrix[sim_matrix > 0.999] = 0
equivalency_matrix = np.zeros_like(sim_matrix)
equivalency_matrix[sim_matrix > 0.97] = 1


In [None]:
import networkx as nx
# Graph whose nodes are symptom names and whose edges join symptoms that are
# treated as the same thing.
SymptomG = nx.Graph()
SymptomG.add_nodes_from(symptom_names)

# Edges from the equivalency matrix.
rows, cols = np.where(equivalency_matrix == 1)
for row, col in zip(rows, cols):
    SymptomG.add_edge(symptom_names[row], symptom_names[col])

# Edges between any two symptoms that both contain one of these keywords.
WORDS_TO_INDICATE_EQUIVALENCY_REGARDLESS_OF_EMBEDDING = ["fever", "headache"]
for a, first_symptom in enumerate(symptom_names):
    # Start at index a so each pair is visited once (a symptom is also paired with itself).
    for second_symptom in symptom_names[a:]:
        for word in WORDS_TO_INDICATE_EQUIVALENCY_REGARDLESS_OF_EMBEDDING:
            if word in first_symptom and word in second_symptom:
                SymptomG.add_edge(first_symptom, second_symptom)

In [None]:
# Map each connected component of SymptomG to a representative: its shortest
# member (the first one encountered wins ties). Using min() removes the old
# 999-character sentinel, which produced "" as the representative whenever
# every name in a component was at least that long.
symptom_simplifier_map = {}
for comp in nx.connected_components(SymptomG):
    shortest = min(comp, key=len)
    symptom_simplifier_map[shortest] = comp

In [None]:
# Invert the simplifier map once (member symptom -> representative) so each
# symptom is resolved with one dictionary lookup instead of scanning every
# component for every symptom.
canonical_symptom = {}
for shortest, symptom_set in symptom_simplifier_map.items():
    for member in symptom_set:
        canonical_symptom[member] = shortest

# Replace each condition's symptoms by their representatives, without duplicates.
for medical_condition in all_medical_conditions:
    new_symptom_list = set()
    for symptom in medical_condition["symptom_list"]:
        if symptom in canonical_symptom:
            new_symptom_list.add(canonical_symptom[symptom])

    medical_condition["condensed_symptom_list"] = list(new_symptom_list)

In [None]:
# Find most common symptoms
# Number of conditions listing each condensed symptom.
condensed_symptom_frequency = {}

for medical_condition in all_medical_conditions:
    for symptom in medical_condition["condensed_symptom_list"]:
        condensed_symptom_frequency.setdefault(symptom, 0)
        condensed_symptom_frequency[symptom] += 1

In [None]:
# Split the condensed symptoms into common ones (listed by more than two
# conditions) and uncommon ones.
common_symptom_names = [s for s, o in condensed_symptom_frequency.items() if o > 2]
uncommon_symptom_names = [s for s, o in condensed_symptom_frequency.items() if o <= 2]
# Kept for any later use; it is no longer embedded.
common_symptom_names_prompt_engineered = [f"Symptom {symp}" for symp in common_symptom_names]

# Reuse the embeddings computed earlier instead of running the model again.
# Besides saving an inference pass, this keeps both sides of the later
# comparison in the same space: the old code re-embedded with a "Symptom "
# prefix while embeddings_map was built with "Symptoms ".
common_embeddings = np.stack([embeddings_map[symp] for symp in common_symptom_names])

In [None]:
# If a symptom is rare, combine it with something bigger, else cut it out
SIM_CUTOFF = 0.92
uncommon_to_common_map = {}

for symptom, relevance in condensed_symptom_frequency.items():
    if relevance > 2:
        continue

    # Similarity of this rare symptom to every common symptom. A dedicated
    # name is used so the global pairwise `sim_matrix` is not overwritten.
    similarities = embeddings_map[symptom] @ common_embeddings.T

    # Ignore scores below the cutoff and scores above 0.999.
    usable = (similarities >= SIM_CUTOFF) & (similarities <= 0.999)
    similarities = np.where(usable, similarities, 0)

    # Only the best candidate matters, so argmax replaces the full sort.
    highest_index = np.argmax(similarities)
    if similarities[highest_index] != 0:
        uncommon_to_common_map[symptom] = common_symptom_names[highest_index]


In [None]:
# Build the lookup set once; testing membership against the list inside the
# nested loop was linear in the number of common symptoms every time.
common_symptom_lookup = set(common_symptom_names)

# Keep common symptoms, remap rare ones that have a close common match, and
# drop the rest.
for medical_condition in all_medical_conditions:
    new_symptom_list = set()
    for symptom in medical_condition["condensed_symptom_list"]:
        if symptom in common_symptom_lookup:
            new_symptom_list.add(symptom)
        elif symptom in uncommon_to_common_map:
            new_symptom_list.add(uncommon_to_common_map[symptom])

    medical_condition["condensed_symptom_list2"] = list(new_symptom_list)

In [None]:
# Find most common symptoms
# Number of conditions listing each symptom after the second condensing pass.
condensed_symptom_frequency2 = {}

for medical_condition in all_medical_conditions:
    for symptom in medical_condition["condensed_symptom_list2"]:
        previous_count = condensed_symptom_frequency2.get(symptom, 0)
        condensed_symptom_frequency2[symptom] = previous_count + 1

condensed_symptom_frequency2

In [None]:
# Keep only the conditions that still have at least two symptoms after condensing.
hq_medical_conditions = []
for mc in all_medical_conditions:
    if len(mc["condensed_symptom_list2"]) >= 2:
        hq_medical_conditions.append(mc)

In [None]:
import networkx as nx

medical_condition_names = [mc["name"] for mc in hq_medical_conditions]

# Graph with one node per condition, one node per symptom, and an edge from
# each condition to every symptom in its final symptom list.
G = nx.Graph()
G.add_nodes_from(
    (condition["name"], {"name": condition["name"], "type": "MedicalCondition"})
    for condition in hq_medical_conditions
)
G.add_nodes_from(
    (symptom, {"name": symptom.strip(), "type": "Symptom"})
    for symptom in condensed_symptom_frequency2
)
G.add_edges_from(
    (condition["name"], symptom)
    for condition in hq_medical_conditions
    for symptom in condition["condensed_symptom_list2"]
)

# Write the graph as GEXF text, one generated line per output line.
with open("md-symptom.gexf", 'w') as f:
    gexf_lines = nx.generate_gexf(G)
    f.write("\n".join(gexf_lines))

In [None]:
# NOTE: despite the ".pkl" extension this file contains JSON, because
# save_dictionary_to_file serialises with the json module.
save_dictionary_to_file(hq_medical_conditions, "high_quality_symptoms.pkl")

# Prompt ChatGPT for Medical Condition Prevalence and Symptom Chance

In [None]:
from openai import OpenAI
import pandas as pd
import time
import json
from tqdm import tqdm

# Credentials come from the environment instead of being hard-coded, so that
# secrets are never stored in the notebook or committed to version control.
import os

client = OpenAI(
  organization=os.environ.get("OPENAI_ORG_ID"),
  api_key=os.environ.get("OPENAI_API_KEY"),
)

def get_symp_prob_prevelance_prompt(symptom_list, name):
    """Build the prompt asking for a summary, per-symptom chances and prevalence.

    The result is the fixed instruction line followed by a fill-in template:
    the quoted condition name, an empty ``summary:`` field, one
    ``- <symptom>: ??.??`` row per entry of *symptom_list*, and a
    ``prevalence: ??.?`` field.
    """
    instructions = (
        'Given the following medical condition, provide a quick one sentence summary. '
        'Given the list of symptoms, provide the percentage chance of experiencing the symptom given that you have the illness. '
        'Make sure to use the format below. '
        'Ensure the symptom values only contain digits (no "%" symbol) and make it out of 100. '
        'Please also provide prevalence by stating the percentage chance of an American having this illness in a year (Based out of 100). '
        'THE VALUES OF prevalence and symptoms SHOULD ONLY BE NUMBERS!!!'
    )
    symptom_rows = "\n".join("    - {}: ??.??".format(symp) for symp in symptom_list)
    template_lines = [
        instructions,
        'name: "{}"'.format(name),
        "summary:",
        "symptoms:",
        symptom_rows,
        "prevalence: ??.?",
    ]
    return "\n".join(template_lines)

def load_dictionary_from_file(filename):
    """Parse *filename* as JSON and return the result; return ``{}`` if the file is missing."""
    try:
        with open(filename, 'r') as handle:
            contents = handle.read()
    except FileNotFoundError:
        return {}
    return json.loads(contents)

def save_dictionary_to_file(dictionary, filename):
    """Write *dictionary* to *filename* as JSON.

    The object is serialised before the file is opened: opening with ``'w'``
    truncates the file, so serialising first means an object that cannot be
    encoded raises ``TypeError`` without destroying the previously saved data.
    """
    serialized = json.dumps(dictionary)
    with open(filename, 'w') as file:
        file.write(serialized)

hq_medical_conditions = load_dictionary_from_file("high_quality_symptoms.pkl")

CACHE_FILE = "openai_cache_2.json"
openai_cache = load_dictionary_from_file(CACHE_FILE)

# Query the model once per condition. Responses are cached under the exact
# prompt text and the cache is saved after every new response.
for medical_condition in tqdm(hq_medical_conditions):
    name_string = medical_condition["name"]
    # "death" is appended so the response also carries a mortality figure.
    symptom_list = medical_condition["condensed_symptom_list2"] + ["death"]
    prompt = get_symp_prob_prevelance_prompt(symptom_list, name_string)

    if prompt in openai_cache:
        print("Skipping")
        continue

    chat_messages = [
        {"role": "system", "content": "You are a helpful assistant which response in YAML."},
        {"role": "user", "content": prompt},
    ]
    completion = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=chat_messages,
        temperature=0.0,
    )
    content = completion.choices[0].message.content
    openai_cache[prompt] = content

    print("")
    print(name_string)
    print(content)
    save_dictionary_to_file(openai_cache, CACHE_FILE)

    time.sleep(0.1)

In [None]:
import yaml

def clean_str(string):
    """Return *string* lower-cased with every non-alphanumeric character removed."""
    alphanumeric_only = "".join(filter(str.isalnum, string))
    return alphanumeric_only.lower().strip()

def _symptom_values_by_key(raw_symptoms):
    """Flatten the ``symptoms`` section of a response into ``{clean_str(name): value}``.

    The section is handled either as a mapping or as a list of single-entry
    mappings. Anything else yields an empty dictionary. When a name occurs
    more than once, the first value wins.
    """
    if isinstance(raw_symptoms, dict):
        entries = list(raw_symptoms.items())
    elif isinstance(raw_symptoms, list):
        # Each list item is expected to be a one-key mapping; other items are skipped.
        entries = [next(iter(item.items())) for item in raw_symptoms
                   if isinstance(item, dict) and item]
    else:
        entries = []

    values = {}
    for symp, prev in entries:
        values.setdefault(clean_str(str(symp)), prev)
    return values


# Parse every cached YAML response into prevalence, summary, per-symptom
# chances and a mortality chance, stored back on the condition record.
for mc in hq_medical_conditions:
    name_string = mc["name"]
    symptom_list = mc["condensed_symptom_list2"] + ["death"]
    # Rebuild the exact prompt, because it is the cache key of the response.
    prompt = get_symp_prob_prevelance_prompt(symptom_list, name_string)
    response_yaml = yaml.safe_load(openai_cache[prompt])
    if "?" in str(response_yaml["prevalence"]):
        # The template placeholder was left unfilled: fall back to a small default.
        mc["prevalence"] = 0.001
    else:
        mc["prevalence"] = float(response_yaml["prevalence"])
    mc["summary"] = response_yaml["summary"]
    mc["name"] = mc["name"].strip()

    # Parse out symptom list
    reported = _symptom_values_by_key(response_yaml["symptoms"])

    symptom_dict = {}
    for symptom in mc["condensed_symptom_list2"]:
        try:
            symptom_dict[symptom] = float(reported[clean_str(symptom)])
        except (KeyError, TypeError, ValueError):
            # Symptom missing from the response, or its value is not a number.
            print(symptom)
            print(response_yaml)

    mc["symptom_prevelances"] = symptom_dict
    # Mortality comes only from the "death" entry (0 when absent). It is no
    # longer taken from whichever symptom happened to match last.
    mc["mortality_chance"] = float(reported.get("death", 0))

In [None]:
# Spot-check one parsed record: the 100th entry from the end.
hq_medical_conditions[-100]

In [None]:
# Sorted list of every symptom that received a parsed chance for at least one condition.
all_symptoms = sorted({
    symptom
    for mc in hq_medical_conditions
    for symptom in mc["symptom_prevelances"]
})
all_symptoms[:10]

In [None]:
import networkx as nx

medical_condition_names = [mc["name"] for mc in hq_medical_conditions]

# Rebuild the graph, now with mortality/prevalence on the condition nodes
# and the per-symptom chance on each edge.
G = nx.Graph()

for condition in hq_medical_conditions:
    condition_attributes = {
        "name": condition["name"].strip(),
        "type": "MedicalCondition",
        # A tiny offset is added to the stored mortality value.
        "mortality_rate": float(condition["mortality_chance"]) + 0.0000001,
        "prevalence": condition["prevalence"],
    }
    G.add_nodes_from([(condition["name"], condition_attributes)])

for symptom in all_symptoms:
    G.add_nodes_from([(symptom, {"name": symptom.strip(), "type": "Symptom"})])

for condition in hq_medical_conditions:
    condition_name = condition["name"]
    for symptom, prevalence in condition["symptom_prevelances"].items():
        G.add_edge(condition_name, symptom, prevalence=prevalence)

with open("md-symptom.gexf", 'w') as f:
    gexf_document = "\n".join(nx.generate_gexf(G))
    f.write(gexf_document)

# Test Bayesian Predictions

In [None]:
import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean of the hidden states over dim 1, counting only positions whose mask is non-zero."""
    keep = attention_mask.unsqueeze(-1).bool()
    last_hidden = last_hidden_states.masked_fill(keep.logical_not(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1).unsqueeze(-1)


# Deduplicate but keep a deterministic (sorted) order. list(set(...)) orders
# strings by their per-process randomised hashes, so the rows and columns of
# the similarity matrix saved below would not line up with `all_symptoms`
# when it is loaded in another session.
all_symptoms = sorted(set(all_symptoms))
symptom_names_prompt_engineered = [f"Symptoms {symptom_name}" for symptom_name in all_symptoms]
tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-large-v2')
model = AutoModel.from_pretrained('intfloat/e5-large-v2')

def get_embeddings(texts):
    """Embed *texts* with the module-level ``tokenizer`` and ``model``.

    Args:
        texts: List of strings to embed.

    Returns:
        A tensor with one L2-normalised, mask-aware mean-pooled embedding per
        input text.
    """
    # Local import: this cell only imports torch.nn.functional and Tensor.
    import torch

    # Tokenize the input texts
    batch_dict = tokenizer(texts, max_length=512, padding=True, truncation=True, return_tensors='pt')

    # Inference only: without no_grad() the forward pass keeps the autograd
    # graph for the whole batch alive, which wastes a lot of memory.
    with torch.no_grad():
        outputs = model(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    # normalize embeddings
    return F.normalize(embeddings, p=2, dim=1)


embeddings = get_embeddings(symptom_names_prompt_engineered)
# Pairwise dot products of the normalised embeddings, as a NumPy array.
similarity_matrix = embeddings @ embeddings.T
similarity_matrix = similarity_matrix.detach().numpy()

In [None]:
# Persist the similarity matrix, then read it straight back from disk.
np.save('md-symptom-sim-mat.npy', similarity_matrix)
similarity_matrix = np.load('md-symptom-sim-mat.npy')

In [None]:
def get_medical_condition_from_symptoms(g, symptoms, sim_matrix=None):
    """Rank the conditions adjacent to the given symptoms by normalised score.

    Each candidate starts from its ``prevalence`` node attribute (a
    percentage). For every recognised symptom the score is multiplied by the
    ``prevalence`` attribute of the condition-symptom edge (also a
    percentage), or by 0.01 when there is no such edge. Scores are then
    normalised to sum to 1.

    Each candidate is scored exactly once. Previously a condition adjacent to
    several of the given symptoms was collected once per symptom and had its
    factors multiplied in repeatedly, which penalised the best matches.

    Args:
        g: Graph exposing ``has_node``, ``neighbors``, ``nodes`` and
            ``get_edge_data``.
        symptoms: Symptom node ids; ids missing from the graph are ignored.
        sim_matrix: Currently unused.

    Returns:
        A list of ``(condition name, score)`` pairs sorted by descending
        score. It is empty when no symptom is recognised, and the scores are
        left unnormalised when they sum to zero.
    """
    # TODO: Assert the node that is found is a symptom node
    known_symptoms = [symptom for symptom in symptoms if g.has_node(symptom)]

    # Candidate conditions in first-seen order, without duplicates.
    candidate_ids = []
    seen_ids = set()
    for symptom in known_symptoms:
        for neighbor_id in g.neighbors(symptom):
            if neighbor_id not in seen_ids:
                seen_ids.add(neighbor_id)
                candidate_ids.append(neighbor_id)

    mc_chance = {}
    for condition_id in candidate_ids:
        condition = g.nodes[condition_id]
        # Start from the condition's own prevalence.
        chance = condition["prevalence"] / 100
        for symptom in known_symptoms:
            edge_data = g.get_edge_data(condition_id, symptom)
            if edge_data is not None:
                chance *= edge_data["prevalence"] / 100
            else:
                # Symptom not linked to this condition: apply a fixed small factor.
                chance *= 0.01
        mc_chance[condition["name"]] = chance

    # Normalize (skipped when the total is zero, to avoid dividing by zero)
    chance_sum = sum(mc_chance.values())
    if chance_sum > 0:
        mc_chance = {name: chance / chance_sum for name, chance in mc_chance.items()}

    return sorted(mc_chance.items(), key=lambda x: -x[1])

In [None]:
# Example query. TODO: implement use of the similarity embeddings matrix;
# sim_matrix is passed here but the function does not read it yet.
get_medical_condition_from_symptoms(G, ["fever", "itch", "headache"], sim_matrix=similarity_matrix)