In [4]:
import pandas as pd
import numpy as np
from pgmpy.models import DiscreteBayesianNetwork
from pgmpy.estimators import BayesianEstimator
from pgmpy.inference import VariableElimination
import pickle
import warnings
warnings.filterwarnings('ignore')

In [5]:
data = pd.read_csv('datasets/training.csv')
data = data.dropna()
for col in data.columns:
    data[col] = data[col].astype(str)

In [6]:
skin_cluster = [
    "itching", "skin_rash", "nodal_skin_eruptions", "ulcers_on_tongue", "patches_in_throat",
    "bruising", "brittle_nails", "dischromic _patches", "pus_filled_pimples", "blackheads",
    "scurring", "skin_peeling", "silver_like_dusting", "small_dents_in_nails", "inflammatory_nails",
    "blister", "red_sore_around_nose", "yellow_crust_ooze"
]
respiratory_cluster = [
    "continuous_sneezing", "cough", "breathlessness", "phlegm", "throat_irritation",
    "sinus_pressure", "runny_nose", "congestion", "chest_pain", "mucoid_sputum",
    "rusty_sputum", "blood_in_sputum"
]
gastrointestinal_cluster = [
    "stomach_pain", "acidity", "vomiting", "burning_micturition", "spotting_ urination",
    "diarrhoea", "constipation", "abdominal_pain", "mild_fever", "yellow_urine",
    "yellowing_of_eyes", "acute_liver_failure", "fluid_overload", "swelling_of_stomach",
    "pain_during_bowel_movements", "pain_in_anal_region", "bloody_stool", "irritation_in_anus",
    "bladder_discomfort", "foul_smell_of urine", "continuous_feel_of_urine", "passage_of_gases",
    "internal_itching", "toxic_look_(typhos)", "stomach_bleeding", "distention_of_abdomen",
    "belly_pain", "Heartburn"
]
neuro_muscle_cluster = [
    "joint_pain", "muscle_wasting", "back_pain", "headache", "pain_behind_the_eyes",
    "weakness_in_limbs", "cramps", "knee_pain", "hip_joint_pain", "muscle_weakness",
    "stiff_neck", "swelling_joints", "movement_stiffness", "spinning_movements",
    "loss_of_balance", "unsteadiness", "weakness_of_one_body_side", "loss_of_smell",
    "neck_pain", "dizziness", "slurred_speech", "altered_sensorium", "coma", "painful_walking"
]
general_cluster = [
    "shivering", "chills", "fatigue", "weight_gain", "anxiety", "cold_hands_and_feets",
    "mood_swings", "weight_loss", "restlessness", "lethargy", "high_fever", "sunken_eyes",
    "sweating", "dehydration", "indigestion", "nausea", "loss_of_appetite", "malaise",
    "blurred_and_distorted_vision", "redness_of_eyes", "watering_from_eyes", "visual_disturbances",
    "lack_of_concentration", "depression", "irritability", "muscle_pain", "red_spots_over_body",
    "increased_appetite", "polyuria", "family_history", "history_of_alcohol_consumption"
]
endocrine_cluster = [
    "irregular_sugar_level", "obesity", "swollen_legs", "swollen_blood_vessels",
    "puffy_face_and_eyes", "enlarged_thyroid", "swollen_extremeties", "excessive_hunger"
]
genitourinary_cluster = [
    "burning_micturition", "spotting_ urination", "abnormal_menstruation", "Urinating_a_lot"
]
cardio_cluster = [
    "fast_heart_rate", "palpitations", "prominent_veins_on_calf"
]
other_cluster = [
    "extra_marital_contacts", "receiving_blood_transfusion", "receiving_unsterile_injections",
    "fluid_overload.1"
]

In [7]:
def create_cluster_column(df, cluster, name):
    # 1 if any symptom in cluster is 1, else 0
    df[name] = df[ [col for col in cluster if col in df.columns] ].astype(int).max(axis=1).astype(str)
    return df

data = create_cluster_column(data, skin_cluster, 'skin_cluster')
data = create_cluster_column(data, respiratory_cluster, 'respiratory_cluster')
data = create_cluster_column(data, gastrointestinal_cluster, 'gastrointestinal_cluster')
data = create_cluster_column(data, neuro_muscle_cluster, 'neuro_muscle_cluster')
data = create_cluster_column(data, general_cluster, 'general_cluster')
data = create_cluster_column(data, endocrine_cluster, 'endocrine_cluster')
data = create_cluster_column(data, genitourinary_cluster, 'genitourinary_cluster')
data = create_cluster_column(data, cardio_cluster, 'cardio_cluster')
data = create_cluster_column(data, other_cluster, 'other_cluster')

In [8]:
disease_col = 'prognosis'
medicine_col = 'medicine'
clusters = [
    'skin_cluster', 'respiratory_cluster', 'gastrointestinal_cluster', 'neuro_muscle_cluster',
    'general_cluster', 'endocrine_cluster', 'genitourinary_cluster', 'cardio_cluster', 'other_cluster'
]

edges = []
# Disease → clusters
for cluster in clusters:
    edges.append((disease_col, cluster))
# Clusters → symptoms (including overlapping symptoms)
cluster_map = {
    'skin_cluster': skin_cluster,
    'respiratory_cluster': respiratory_cluster,
    'gastrointestinal_cluster': gastrointestinal_cluster,
    'neuro_muscle_cluster': neuro_muscle_cluster,
    'general_cluster': general_cluster,
    'endocrine_cluster': endocrine_cluster,
    'genitourinary_cluster': genitourinary_cluster,
    'cardio_cluster': cardio_cluster,
    'other_cluster': other_cluster
}
for cluster, symptoms in cluster_map.items():
    for symptom in symptoms:
        if symptom in data.columns:
            edges.append((cluster, symptom))
# Disease → medicine
edges.append((disease_col, medicine_col))

In [9]:
model = DiscreteBayesianNetwork(edges)
model.fit(data, estimator=BayesianEstimator, prior_type="BDeu", equivalent_sample_size=10)
infer = VariableElimination(model)

INFO:pgmpy: Datatype (N=numerical, C=Categorical Unordered, O=Categorical Ordered) inferred from data: 
 {'itching': 'C', 'skin_rash': 'C', 'nodal_skin_eruptions': 'C', 'continuous_sneezing': 'C', 'shivering': 'C', 'chills': 'C', 'joint_pain': 'C', 'stomach_pain': 'C', 'acidity': 'C', 'ulcers_on_tongue': 'C', 'muscle_wasting': 'C', 'vomiting': 'C', 'burning_micturition': 'C', 'spotting_ urination': 'C', 'fatigue': 'C', 'weight_gain': 'C', 'anxiety': 'C', 'cold_hands_and_feets': 'C', 'mood_swings': 'C', 'weight_loss': 'C', 'restlessness': 'C', 'lethargy': 'C', 'patches_in_throat': 'C', 'irregular_sugar_level': 'C', 'cough': 'C', 'high_fever': 'C', 'sunken_eyes': 'C', 'breathlessness': 'C', 'sweating': 'C', 'dehydration': 'C', 'indigestion': 'C', 'headache': 'C', 'yellowish_skin': 'C', 'dark_urine': 'C', 'nausea': 'C', 'loss_of_appetite': 'C', 'pain_behind_the_eyes': 'C', 'back_pain': 'C', 'constipation': 'C', 'abdominal_pain': 'C', 'diarrhoea': 'C', 'mild_fever': 'C', 'yellow_urine': 'C

In [10]:
def get_most_informative_symptom(current_evidence, remaining_symptoms, infer, disease_node='prognosis'):
    max_info_gain = -np.inf
    best_symptom = None
    for symptom in remaining_symptoms:
        try:
            cpd = model.get_cpds(symptom)
            states = cpd.state_names[symptom]
            info_gain = 0
            for state in states:
                evidence = current_evidence.copy()
                evidence[symptom] = state
                q = infer.query(variables=[disease_node], evidence=evidence, show_progress=False)
                p = q.values
                entropy = -np.sum(p * np.log2(p + 1e-9))
                info_gain += entropy / len(states)
            if info_gain > max_info_gain:
                max_info_gain = info_gain
                best_symptom = symptom
        except Exception:
            continue
    return best_symptom

In [11]:
def dynamic_predict():
    print("Disease Prediction (Dynamic Symptom Questioning)\n")
    evidence = {}
    asked = set()
    # Only ask about symptoms that are in the data and not clusters/outputs
    all_symptoms = [col for col in data.columns if col not in clusters + [disease_col, medicine_col]]
    remaining = set(all_symptoms)
    max_questions = 10
    confidence_threshold = 0.9

    for i in range(max_questions):
        next_symptom = get_most_informative_symptom(evidence, list(remaining), infer)
        if not next_symptom:
            break
        states = model.get_cpds(next_symptom).state_names[next_symptom]
        ans = None
        while ans not in states:
            ans = input(f"Q{i+1}: Do you have '{next_symptom}'? ({'/'.join(states)}): ").strip()
            if ans in states:
                break
            else:
                print(f"   [!] Please enter one of: {', '.join(states)}")
        evidence[next_symptom] = ans
        asked.add(next_symptom)
        remaining.remove(next_symptom)

        # Inference
        result = infer.query(variables=[disease_col], evidence=evidence, show_progress=False)
        probs = result.values
        top_prob = np.max(probs)
        top_disease = result.state_names[disease_col][np.argmax(probs)]
        print(f"Top prediction so far: {top_disease} (Confidence: {top_prob:.1%})")
        if top_prob >= confidence_threshold:
            break

    print("\nFinal probabilities:")
    for disease, prob in zip(result.state_names[disease_col], result.values):
        print(f"{disease:<20} | Probability: {prob:.2%}")
    print(f"\n✅ Most Likely Diagnosis: {top_disease} (Confidence: {top_prob:.1%})")

    # Suggest medicine for the predicted disease
    med = data.loc[data[disease_col] == top_disease, medicine_col].mode()
    if not med.empty:
        print(f"💊 Suggested Medicine: {med.iloc[0]}")
    else:
        print("No medicine suggestion available for this diagnosis.")

In [14]:
if __name__ == "__main__":
    dynamic_predict()

Disease Prediction (Dynamic Symptom Questioning)

   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
   [!] Please enter one of: 0
Top prediction so far: Dengue (Confidence: 2.5%)
T