<a href="https://colab.research.google.com/github/delfinodjaja/InteractiveDiseasePredictor/blob/main/InteractiveDiseasePrediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time
import random
from datetime import datetime
from collections import defaultdict
import math
import pickle
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch import nn
from nltk.corpus import stopwords
import nltk
from nltk.stem.wordnet import WordNetLemmatizer
import string
import time
import torch.nn.functional as F
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# The three reference tables are decoded as Windows-1252; the training matrix
# (better_dummy.csv) is read with pandas' default encoding.
_LEGACY_ENCODING = 'windows-1252'
df_med, df_rf, df_precaution = (
    pd.read_csv(csv_name, encoding=_LEGACY_ENCODING)
    for csv_name in ('cleaned_med.csv', 'cleaned_riskFactor.csv', 'cleaned_precaution.csv')
)
df = pd.read_csv('better_dummy.csv')

In [4]:
# Two column views of the same DID/DNAME pairs from the risk-factor table.
Id_Name, Name_id = df_rf[["DID", "DNAME"]], df_rf[["DNAME", "DID"]]
# NOTE: despite the names, Id_Name_dict is {'DID': {disease name -> DID}} and
# Name_Id_dict is {'DNAME': {DID -> disease name}}. Duplicate index values would
# collapse to a single entry in the resulting dicts.
Id_Name_dict, Name_Id_dict = (
    Id_Name.set_index("DNAME").to_dict(),
    Name_id.set_index("DID").to_dict(),
)

In [5]:
# Sanity check of the name -> DID lookup (the cell output shows 119 for Typhoid).
Id_Name_dict['DID']['Typhoid']

119

In [6]:
# Assumes column 0 of better_dummy.csv is the label `disease_code` and every
# later column is one symptom feature — TODO confirm against the CSV header.
feature = df.columns[1:]
label = df['disease_code'].unique()
y = df['disease_code'].to_numpy(copy=True)
x = df.iloc[:, 1:].to_numpy(copy=True)

In [7]:
# Peek at the symptom matrix; NumPy elides the middle rows/columns in its repr.
x

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [8]:
# SECURITY NOTE: pickle.load can execute arbitrary code; only load label_encoder.pkl
# from a trusted source. The object is used later via `inverse_transform`
# (presumably a fitted scikit-learn LabelEncoder — confirm).
with open('label_encoder.pkl', 'rb') as encoder_file:
    le = pickle.load(encoder_file)

In [9]:
from tensorflow.keras.preprocessing.text import tokenizer_from_json
import json

# Load the saved Keras tokenizer (JSON written by `Tokenizer.to_json()`), presumably
# the vocabulary the transformer was trained with. JSON is UTF-8 by spec, so decode
# it explicitly instead of relying on the platform's default encoding (e.g. cp1252 on Windows).
with open("tokenizer_v2.json", encoding="utf-8") as f:
    tokenizer_data = f.read()

tokenizer = tokenizer_from_json(tokenizer_data)

In [10]:
class TransformerModel(nn.Module):
    """Siamese Transformer encoder producing unit-length embeddings for two token batches.

    Both inputs share one embedding table, one ``TransformerEncoder`` and one
    projection head, so the two returned vectors live in the same space and can
    be compared with cosine similarity.

    Args:
        hidden: base width; the encoder feed-forward size is ``4 * hidden`` and
            the projection head's inner layer is ``2 * hidden``.
        num_layers: number of stacked encoder layers.
        vocab_size: rows of the embedding table; index 0 is the padding token.
        embedding_dim: token embedding size, also the encoder ``d_model`` and the
            output dimension.
        dropout_rate: dropout on the embeddings and inside each encoder layer.
        nhead: number of attention heads (PyTorch requires it to divide ``embedding_dim``).
    """

    def __init__(self, hidden, num_layers, vocab_size, embedding_dim, dropout_rate, nhead=4):
        super().__init__()
        ff_width = hidden * 4
        head_width = hidden * 2

        # Submodules are created in a fixed order (and keep these attribute names)
        # so checkpoint keys and seeded initialization stay stable.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        layer_template = TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=ff_width,
            dropout=dropout_rate,
            batch_first=True,
        )
        self.transformer = TransformerEncoder(layer_template, num_layers=num_layers)
        self.dense = nn.Sequential(
            nn.Linear(embedding_dim, head_width),
            nn.ReLU(),
            nn.Linear(head_width, embedding_dim),
        )
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x1, x2, mask1=None, mask2=None):
        """Encode ``x1`` and ``x2`` (token ids, batch-first) and return L2-normalized embeddings.

        ``mask1``/``mask2`` are optional key-padding masks (True = ignore that position),
        e.g. from :meth:`generate_mask`.
        """
        # Embed and drop out both inputs before either goes through the encoder.
        embedded = [self.dropout(self.embedding(tokens)) for tokens in (x1, x2)]
        encoded = [
            self.transformer(emb, src_key_padding_mask=pad_mask)
            for emb, pad_mask in zip(embedded, (mask1, mask2))
        ]
        # Mean over the sequence axis. NOTE: the mean is unmasked, so padded
        # positions are included in the average.
        projected = [self.dense(enc.mean(dim=1)) for enc in encoded]
        return tuple(F.normalize(vec, p=2, dim=1) for vec in projected)

    @staticmethod
    def generate_mask(x, pad_idx=0):
        """Return a boolean mask that is True wherever ``x`` holds the padding id."""
        return x == pad_idx

In [11]:
def prior_probabilty(data):
  """Return the empirical class prior ``P(label)`` for every distinct label in ``data``.

  Args:
    data: a sized iterable of hashable labels (e.g. the 1-D label array ``y``).

  Returns:
    Dict ``label -> count / len(data)``, keyed in order of first appearance.
    An empty ``data`` yields an empty dict.
  """
  # Local import: the notebook's import cell only brings in `defaultdict`.
  from collections import Counter

  size = len(data)
  counts = Counter(data)
  return {cls: count / size for cls, count in counts.items()}

In [12]:
def likelihoods(x, y, num_feature, alpha=1, values=(0, 1)):
  """Estimate additively smoothed ``P(feature_i = v | label)`` for a categorical Naive Bayes.

  Args:
    x: 2-D array-like; row ``k`` holds the feature values of sample ``k``.
    y: class labels aligned with the rows of ``x``.
    num_feature: number of leading columns of each row to model.
    alpha: additive (Laplace) smoothing constant. The default 1 with the default
      binary ``values`` reproduces ``(count + 1) / (n + 2)``.
    values: the feature values to produce probabilities for.

  Returns:
    Nested ``defaultdict`` ``label -> feature index -> value -> probability``.
    Only the entries in ``values`` are filled. Any other lookup yields 0 rather
    than a leftover raw count.
  """
  label_counts = defaultdict(int)
  value_counts = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
  for row, label in zip(x, y):
    label_counts[label] += 1
    for i in range(num_feature):
      value_counts[label][i][row[i]] += 1

  # Counts and probabilities are kept in separate tables so that counts of values
  # outside `values` can never leak into the probability table.
  num_values = len(values)
  probs = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
  for label, n in label_counts.items():
    for i in range(num_feature):
      for val in values:
        probs[label][i][val] = (value_counts[label][i][val] + alpha) / (n + alpha * num_values)
  return probs

In [13]:
# Fit the Naive Bayes tables (class priors + per-feature likelihoods) on the full symptom matrix.
prior = prior_probabilty(y)
likelihoods_dict = likelihoods(x, y, num_feature=x.shape[1])

In [14]:
def predictv2(symptom, prior, likelihood, num_feature):
  """Naive Bayes MAP prediction from a partially answered symptom vector.

  Args:
    symptom: sequence indexed ``0..num_feature-1``; ``-1`` marks an unanswered
      feature (skipped), any other value is looked up in ``likelihood``.
    prior: mapping ``class -> P(class)``.
    likelihood: nested mapping ``class -> feature -> value -> P(value | class)``.
    num_feature: number of features to read from ``symptom``.

  Returns:
    ``(best_class, log_scores)``: the class with the highest unnormalized log
    posterior (first one wins on ties) and the dict of all log scores.
  """
  log_scores = {}
  for cls, cls_prior in prior.items():
    score = math.log(cls_prior)
    for feat_idx in range(num_feature):
      observed = symptom[feat_idx]
      if observed == -1:
        continue
      score += math.log(likelihood[cls][feat_idx][observed])
    log_scores[cls] = score
  winner = max(log_scores.items(), key=lambda item: item[1])[0]
  return winner, log_scores

In [15]:
def calculate_information_gain(current_state, priors, likelihoods, num_features):
  """Choose the unanswered feature whose answer most reduces expected posterior entropy.

  For every feature still at ``-1`` this computes
  ``H(current) - sum_v P(feature = v | current) * H(current + {feature = v})``
  over binary answers ``v in {0, 1}`` and keeps the largest gain.

  Args:
    current_state: list with one entry per feature; ``-1`` = not asked yet.
    priors, likelihoods, num_features: the Naive Bayes tables (see
      ``get_disease_probabilities``).

  Returns:
    Index of the most informative unanswered feature (first one on ties), or
    ``None`` when no feature is left to ask.
  """
  answered = {i for i, val in enumerate(current_state) if val != -1}
  candidates = [f for f in range(num_features) if f not in answered]
  if not candidates:
    return None

  # Loop-invariant: the current posterior and its entropy do not depend on the
  # candidate feature, so compute them once instead of once per candidate.
  entropy_before = calculate_entropy(current_state, priors, likelihoods, num_features)
  current_probs = get_disease_probabilities(current_state, priors, likelihoods, num_features)

  best_gain = -float('inf')
  best_feature = None
  for feature in candidates:
    entropy_after = 0
    for val in (0, 1):
      # Predictive probability P(feature = val | answers so far), marginalized over diseases.
      p_val = 0
      for disease, prob in current_probs.items():
        p_val += prob * likelihoods[disease][feature][val]
      temp_state = current_state.copy()
      temp_state[feature] = val
      entropy_after += p_val * calculate_entropy(temp_state, priors, likelihoods, num_features)

    gain = entropy_before - entropy_after
    if gain > best_gain:
      best_gain = gain
      best_feature = feature

  return best_feature

In [16]:
def calculate_entropy(state, priors, likelihoods, num_features):
  """Shannon entropy (natural log) of the disease posterior implied by ``state``.

  Zero-probability diseases contribute nothing (0 * log 0 is taken as 0).
  """
  posterior = get_disease_probabilities(state, priors, likelihoods, num_features)
  h = 0
  for p in filter(lambda q: q > 0, posterior.values()):
    h -= p * math.log(p)
  return h

In [17]:
def get_disease_probabilities(state, priors, likelihoods, num_features):
  """Normalized Naive Bayes posterior ``P(disease | answered features)``.

  Args:
    state: sequence indexed ``0..num_features-1``; ``-1`` marks an unanswered
      feature (ignored), any other value is looked up in ``likelihoods``.
    priors: mapping ``disease -> P(disease)`` (must be > 0).
    likelihoods: nested mapping ``disease -> feature -> value -> P(value | disease)``.
    num_features: number of features to read from ``state``.

  Returns:
    Dict ``disease -> posterior probability`` summing to 1 (empty if ``priors`` is empty).
  """
  log_scores = {}
  for disease in priors:
    score = math.log(priors[disease])
    for feature in range(num_features):
      val = state[feature]
      if val != -1:
        score += math.log(likelihoods[disease][feature][val])
    log_scores[disease] = score

  if not log_scores:
    return {}

  # Log-sum-exp: shift by the largest log score before exponentiating. Without
  # the shift, exp() of very negative scores underflows to 0.0 for every disease
  # once many features are answered, and the normalization divides by zero.
  peak = max(log_scores.values())
  unnormalized = {disease: math.exp(s - peak) for disease, s in log_scores.items()}
  total = sum(unnormalized.values())
  return {disease: p / total for disease, p in unnormalized.items()}

In [18]:
def estimate_feature_probability(feature, value, state, priors, likelihoods, num_features):
  """Predictive probability ``P(feature = value | state)``.

  Computed as ``sum over diseases d of P(d | state) * P(feature = value | d)``
  using the current posterior from ``get_disease_probabilities``.
  """
  posterior = get_disease_probabilities(state, priors, likelihoods, num_features)
  expected = 0
  for disease in posterior:
    expected += posterior[disease] * likelihoods[disease][feature][value]
  return expected

In [19]:
# Fetch the NLTK corpora the text normalizer needs (WordNet for lemmatization,
# the stopword lists), then build the shared lemmatizer and English stopword set.
for _corpus in ('wordnet', 'stopwords'):
    nltk.download(_corpus)
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))

def cleaning(text):
    """Normalize free text before tokenization.

    Steps: strip ASCII punctuation, lowercase, drop English stopwords, and
    lemmatize every remaining word as a verb. Returns the words joined by single spaces.

    NOTE(review): NLTK's English stopword list contains negations such as "no"/"not",
    so "no fever" becomes "fever". Confirm this matches how the tokenizer/model were trained.
    """
    punctuation_table = str.maketrans('', '', string.punctuation)
    tokens = text.translate(punctuation_table).lower().split()
    content_words = (tok for tok in tokens if tok not in stop_words)
    return ' '.join(lemmatizer.lemmatize(tok, 'v') for tok in content_words)

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


In [20]:
# Rebuild the architecture with the fine-tuning hyper-parameters; they must match the
# checkpoint or load_state_dict reports missing/unexpected keys.
model = TransformerModel(
    hidden=128,
    num_layers=2,
    vocab_size=9000,
    embedding_dim=300,
    dropout_rate=0.3,
    nhead=4
)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
load_report = model.load_state_dict(torch.load('fine_tuned_model_1.pt', map_location=torch.device(device)))
# map_location only decides where the checkpoint tensors are materialized. load_state_dict
# copies them into the module's existing (CPU) parameters. Move the module explicitly so
# it lives on the same device the chatbot sends its inputs to, and disable dropout.
model.to(device)
model.eval()
load_report


<All keys matched successfully>

In [23]:
class v3_1_0:
    """Interactive, console-based symptom checker (version 3.1.0).

    It combines two components built elsewhere in the notebook:

    * a Naive Bayes posterior over diseases (``priors`` / ``likelihoods``), used to
      rank diagnoses and, via information gain, to choose the next yes/no question;
    * a text encoder (``model`` + ``tokenizer``) that maps a free-text description
      to the closest rows of ``symptom_description``.

    The consultation state is a list with one entry per feature:
    ``-1`` = not asked yet, ``0`` = absent, ``1`` = present.

    Args:
        symptom_names: ordered symptom names, one per feature; must support
            ``.get_loc(name)`` (e.g. a pandas ``Index``).
        priors: mapping ``disease_id -> P(disease)``.
        likelihoods: nested mapping ``disease_id -> feature -> value -> P``.
        num_features: length of the state vector.
        df_med: table with ``DID`` and ``Medicine_Name`` columns.
        df_precaution: table with a ``DID`` column; the columns between the first
            and the last are printed as precautions.
        le: encoder whose ``inverse_transform`` maps disease ids to names.
        Id_Name_dict: ``{'DID': {disease_name: DID}}`` lookup.
        symptom_description: DataFrame with ``symptoms`` and ``description`` columns.
        tokenizer: tokenizer exposing ``texts_to_sequences``.
        model: encoder called as ``model(x1, x2, mask1, mask2)`` returning two
            embedding batches, and exposing ``generate_mask``.
        device: torch device the model's inputs are moved to.
        cleaning_func: text normalizer applied before tokenization.
    """

    # Token length used when padding both the reference descriptions and user input.
    _SEQ_LEN = 50

    def __init__(self, symptom_names, priors, likelihoods, num_features, df_med, df_precaution, le, Id_Name_dict,
                 symptom_description, tokenizer, model, device, cleaning_func):
        self.symptom_names = symptom_names
        self.priors = priors
        self.likelihoods = likelihoods
        self.num_features = num_features
        self.df_med = df_med
        self.df_precaution = df_precaution
        self.le = le
        self.Id_Name_dict = Id_Name_dict
        self.conversation_history = []

        self.symptom_description = symptom_description
        self.tokenizer = tokenizer
        self.model = model
        self.device = device
        self.cleaning_func = cleaning_func
        # Clean, tokenize and pad every reference description once; reused for each query.
        self.comparison = self._prepare_comparison_data()

    def _prepare_comparison_data(self):
        """Return the reference descriptions as a padded ``(n_refs, _SEQ_LEN)`` token tensor."""
        comparison = self.symptom_description['description'].apply(self.cleaning_func)
        comparison = self.tokenizer.texts_to_sequences(comparison)
        comparison = pad_sequences(comparison, maxlen=self._SEQ_LEN, padding='post')
        return torch.from_numpy(comparison)

    def desc_to_symptom_top3(self, description, top_k=3):
        """Rank the reference symptom descriptions by similarity to ``description``.

        The text is cleaned, tokenized and padded like the references. All
        (reference, description) pairs are then scored in one batched forward pass.

        Args:
            description: raw text typed by the user.
            top_k: number of best matches to return (default 3).

        Returns:
            Up to ``top_k`` dicts with keys ``symptom``, ``description`` and ``score``
            (cosine similarity of the two embeddings), best first.
        """
        if top_k <= 0 or len(self.comparison) == 0:
            return []

        cleaned_desc = self.cleaning_func(description)
        tokenized_desc = self.tokenizer.texts_to_sequences([cleaned_desc])
        padded_desc = pad_sequences(tokenized_desc, maxlen=self._SEQ_LEN, padding='post')

        self.model.eval()
        with torch.no_grad():
            references = self.comparison.to(self.device)
            # Pair the single description with every reference row.
            query = torch.from_numpy(padded_desc).to(self.device).repeat(references.size(0), 1)
            out1, out2 = self.model(
                references,
                query,
                self.model.generate_mask(references),
                self.model.generate_mask(query),
            )
            scores = F.cosine_similarity(out1, out2).cpu().tolist()

        top_indices = np.argsort(scores)[-top_k:][::-1]  # best first

        top_results = []
        for idx in top_indices:
            top_results.append({
                'symptom': self.symptom_description['symptoms'].iloc[idx],
                'description': self.symptom_description['description'].iloc[idx],
                'score': scores[idx]
            })

        return top_results

    def display_header(self):
        """Print the welcome banner and usage notes."""
        print("=" * 60)
        print("🏥 AI MEDICAL ASSISTANT")
        print("=" * 60)
        print("Hello! I'm your AI medical assistant.")
        print("I'll ask you some questions about your symptoms to help identify")
        print("possible conditions. Please answer honestly and accurately.\n")
        print("💡 FEATURES:")
        print("   • Describe symptoms in your own words anytime")
        print("   • Switch between description and structured questions")
        print("   • Get multiple symptom suggestions\n")
        print("⚠️  IMPORTANT: This is for informational purposes only.")
        print("   Always consult a real doctor for proper medical advice.\n")

    def get_user_input(self, question, valid_responses=None, allow_quit=True, allow_describe=False, allow_switch=False):
        """Prompt until an acceptable answer is typed.

        Returns:
            The stripped, lowercased response. Returns ``None`` on ``'quit'``
            (when allowed), and ``'describe'``/``'switch'`` when those keywords
            are enabled. If ``valid_responses`` is falsy, any text (even empty)
            is accepted.
        """
        while True:
            prompt = f"{question}"
            options = []

            if allow_quit:
                options.append("'quit' to exit")
            if allow_describe:
                options.append("'describe' to describe your symptom")
            if allow_switch:
                options.append("'switch' to change question mode")

            if options:
                prompt += f" ({', '.join(options)})"
            prompt += ": "

            response = input(prompt).strip().lower()

            if allow_quit and response == 'quit':
                print("\n👋 Thank you for using the medical assistant. Stay healthy!")
                return None

            if allow_describe and response == 'describe':
                return 'describe'

            if allow_switch and response == 'switch':
                return 'switch'

            if valid_responses:
                if response in valid_responses:
                    return response
                valid_options = ', '.join(valid_responses)
                if allow_describe:
                    valid_options += ", 'describe'"
                if allow_switch:
                    valid_options += ", 'switch'"
                print(f"Please respond with one of: {valid_options}")
            else:
                return response

    def handle_symptom_description(self, show_alternatives=True, description=None):
        """Match a free-text description to a known symptom and let the user confirm it.

        Args:
            show_alternatives: when the user rejects every match, offer the
                alternative flow (re-describe / common list / skip).
            description: text already typed by the user. When ``None`` the user
                is prompted for it.

        Returns:
            The chosen symptom name, or ``None`` if nothing was selected. Errors
            while matching are reported and also yield ``None``.
        """
        if description is None:
            print("\n🗣️  Describe your symptom in your own words:")
            print("   (e.g., 'I have a splitting headache', 'My stomach really hurts', etc.)")
            description = input("Your description: ")
        description = description.strip()

        if not description:
            print("❌ Please provide a description.")
            return None

        print("🔍 Analyzing your description...")

        # Best-effort: any failure while matching is reported instead of ending the session.
        try:
            top_matches = self.desc_to_symptom_top3(description)
            if not top_matches:
                print("❌ No known symptoms to compare against.")
                return None

            n_matches = len(top_matches)
            print(f"\n✨ Here are the top {n_matches} symptoms that match your description:")
            print("-" * 50)

            for i, match in enumerate(top_matches, 1):
                confidence = match['score'] * 100
                confidence_icon = "🔴" if confidence > 70 else "🟡" if confidence > 50 else "🟢"

                print(f"{i}. {match['symptom']}")
                print(f"   💭 Description: {match['description']}")
                print(f"   🎯 Confidence: {confidence:.1f}% {confidence_icon}")
                print()

            # Only offer numbers that actually exist in the list shown above.
            choice = self.get_user_input(
                f"Which symptom matches best? (1-{n_matches}, or 'none' if no match)",
                [str(i) for i in range(1, n_matches + 1)] + ['none'],
                allow_quit=False
            )

            if choice == 'none':
                if show_alternatives:
                    return self.handle_alternative_description()
                return None

            selected_match = top_matches[int(choice) - 1]
            print(f"✅ Selected: {selected_match['symptom']}")
            return selected_match['symptom']

        except Exception as e:
            print(f"❌ Error processing description: {str(e)}")
            return None

    def handle_alternative_description(self):
        """Handle cases where initial description matching fails"""
        print("\n💭 Let's try a different approach:")
        print("1. Try describing it differently")
        print("2. Choose from common symptoms")
        print("3. Skip this symptom")

        choice = self.get_user_input(
            "What would you like to do?",
            ['1', '2', '3'],
            allow_quit=False
        )

        if choice == '1':
            return self.handle_symptom_description(show_alternatives=False)
        elif choice == '2':
            return self.show_common_symptoms()
        else:
            return None

    def show_common_symptoms(self):
        """Offer a short list of common symptoms that exist in ``symptom_names``.

        Matching ignores case and treats underscores as spaces, so the entry
        "chest pain" matches a feature named ``chest_pain``.

        Returns:
            The first ``symptom_names`` entry that contains the chosen common
            symptom, or ``None``.
        """
        common_symptoms = [
            "headache", "fever", "cough", "nausea", "fatigue",
            "chest pain", "abdominal pain", "diarrhea", "vomiting", "dizziness"
        ]

        # Readable form of each feature name: "chest_pain" -> "chest pain".
        readable_names = [(name, str(name).replace('_', ' ').lower()) for name in self.symptom_names]

        available_symptoms = [s for s in common_symptoms
                              if any(s.lower() in readable for _, readable in readable_names)]

        if not available_symptoms:
            print("❌ No common symptoms available.")
            return None

        shown = available_symptoms[:10]
        print("\n📋 Common symptoms:")
        for i, symptom in enumerate(shown, 1):
            print(f"   {i}. {symptom.title()}")

        choice = self.get_user_input(
            f"Select a number (1-{len(shown)})",
            [str(i) for i in range(1, len(shown) + 1)],
            allow_quit=False
        )

        if choice:
            selected = shown[int(choice) - 1].lower()
            for name, readable in readable_names:
                if selected in readable:
                    return name

        return None

    def format_symptom_question(self, symptom):
        """Turn a symptom name into a natural-sounding yes/no question."""
        symptom_lower = symptom.lower()

        if 'pain' in symptom_lower:
            return f"Are you experiencing {symptom_lower}?"
        elif 'fever' in symptom_lower:
            return f"Do you currently have {symptom_lower}?"
        elif any(word in symptom_lower for word in ['nausea', 'vomiting', 'diarrhea']):
            return f"Are you experiencing {symptom_lower}?"
        elif any(word in symptom_lower for word in ['cough', 'sneeze']):
            return f"Do you have a {symptom_lower}?"
        else:
            return f"Do you have {symptom_lower}?"

    def show_progress(self, asked_questions, max_questions, confidence):
        """Print a 20-character progress bar and the current top-disease probability."""
        progress = asked_questions / max_questions
        bar_length = 20
        filled = int(bar_length * progress)
        bar = "█" * filled + "░" * (bar_length - filled)

        print(f"\n📊 Progress: [{bar}] {asked_questions}/{max_questions} questions")
        print(f"🎯 Current confidence: {confidence*100:.1f}%")
        print("-" * 50)

    def show_intermediate_results(self, disease_probs, asked_questions, max_questions):
        """Print the three most probable diseases; pause briefly if more questions remain."""
        print(f"\n🔍 Analysis after {asked_questions} questions:")

        sorted_diseases = sorted(disease_probs.items(), key=lambda x: x[1], reverse=True)[:3]

        for i, (disease_id, prob) in enumerate(sorted_diseases, 1):
            disease_name = self.le.inverse_transform([disease_id])[0]
            confidence_level = "🔴 High" if prob > 0.6 else "🟡 Medium" if prob > 0.3 else "🟢 Low"
            print(f"  {i}. {disease_name}: {prob*100:.1f}% {confidence_level}")

        if asked_questions < max_questions:
            print(f"\n💭 I need more information to be more certain...")
            time.sleep(1.5)

    def get_severity_info(self, symptom):
        """Ask for severity ('1'-'5') and free-text duration; return ``(severity, duration)``."""
        severity_q = f"On a scale of 1-5, how severe is your {symptom.lower()}? (1=mild, 5=severe)"
        severity = self.get_user_input(severity_q, ['1', '2', '3', '4', '5'], allow_quit=False)

        duration_q = f"How long have you had {symptom.lower()}? (hours/days/weeks)"
        duration = self.get_user_input(duration_q, allow_quit=False)

        return severity, duration

    def display_final_results(self, final_diagnosis, disease_probs, current_state):
        """Print the diagnosis, its (capped at 95%) confidence, reported symptoms and treatment info."""
        disease_name = self.le.inverse_transform([final_diagnosis])[0]
        probability = disease_probs[final_diagnosis] * 100
        probability = 95 if probability > 95 else probability
        print("\n" + "=" * 60)
        print("🏥 DIAGNOSIS RESULTS")
        print("=" * 60)

        if probability > 80:
            confidence_text = "🔴 Very High Confidence"
            recommendation = "Strongly recommend seeing a doctor soon."
        elif probability > 60:
            confidence_text = "🟡 High Confidence"
            recommendation = "Recommend consulting with a healthcare provider."
        elif probability > 40:
            confidence_text = "🟢 Moderate Confidence"
            recommendation = "Consider seeing a doctor if symptoms persist."
        else:
            confidence_text = "⚪ Low Confidence"
            recommendation = "Multiple conditions possible. See a doctor for proper diagnosis."

        print(f"📋 Most Likely Condition: {disease_name}")
        print(f"🎯 Confidence Level: {probability:.1f}% ({confidence_text})")
        print(f"💡 Recommendation: {recommendation}\n")

        reported_symptoms = [self.symptom_names[i] for i, val in enumerate(current_state) if val == 1]
        if reported_symptoms:
            print("✅ Symptoms you reported:")
            for symptom in reported_symptoms:
                print(f"   • {symptom}")
            print()

        self.show_treatment_info(disease_name)

    def show_treatment_info(self, disease_name):
        """Print medicines and precautions recorded for ``disease_name``, then a disclaimer.

        Lookup failures are reported but do not interrupt the session.
        """
        try:
            disease_id = self.Id_Name_dict['DID'][disease_name]
            # The DID columns may be parsed as int in one CSV and as str in another;
            # compare as stripped strings so both tables match regardless of dtype.
            target = str(disease_id).strip()
            suggested_medicine = self.df_med[self.df_med['DID'].astype(str).str.strip() == target]
            precautions = self.df_precaution[self.df_precaution['DID'].astype(str).str.strip() == target]

            print("💊 SUGGESTED TREATMENTS:")
            if not suggested_medicine.empty:
                for _, row in suggested_medicine.iterrows():
                    print(f"   • {row['Medicine_Name']}")
            else:
                print("   No specific medications suggested.")

            print("\n🛡️  RECOMMENDED PRECAUTIONS:")
            if not precautions.empty:
                for _, row in precautions.iterrows():
                    # Positional slice: skip the first and last columns of the precaution table.
                    precaution_list = [str(p) for p in row.iloc[1:-1].values if str(p).strip() and str(p) != 'nan']
                    for precaution in precaution_list:
                        print(f"   • {precaution}")
            else:
                print("   No specific precautions available.")

        except Exception as e:
            print(f"   Unable to retrieve treatment information ({e}).")

        print("\n⚠️  IMPORTANT REMINDER:")
        print("   This AI assistant provides general information only.")
        print("   Always consult with a qualified healthcare professional")
        print("   for proper medical diagnosis and treatment.")

    def choose_question_mode(self):
        """Let user choose between description and structured questions"""
        print("\n🔄 Choose your preferred question mode:")
        print("1. Describe symptoms in your own words")
        print("2. Answer structured yes/no questions")

        choice = self.get_user_input(
            "Choose mode (1 or 2)",
            ['1', '2'],
            allow_quit=False
        )

        return 'describe' if choice == '1' else 'structured'

    def ask_initial_symptoms(self):
        """Ask for the question mode and, in description mode, the user's main symptoms.

        Returns:
            ``(initial_symptoms, mode)``: a list of distinct symptom names and
            ``'describe'`` or ``'structured'``.
        """
        print("🚀 Let's start your medical consultation!")

        mode = self.choose_question_mode()
        initial_symptoms = []

        if mode == 'describe':
            print("\n🗣️  Please describe your main symptoms:")
            while True:
                symptom = self.handle_symptom_description()
                if symptom and symptom not in initial_symptoms:
                    initial_symptoms.append(symptom)
                    print(f"✅ Added: {symptom}")
                elif symptom:
                    # Avoid asking severity twice and double-counting the same symptom.
                    print(f"⚠️  You've already provided information about {symptom}")

                more = self.get_user_input(
                    "\nDo you have any other symptoms to describe?",
                    ['yes', 'y', 'no', 'n'],
                    allow_quit=False
                )

                if more in ['no', 'n']:
                    break

        return initial_symptoms, mode

    def handle_description_mode_questioning(self, current_state, asked_questions, max_questions):
        """Ask for one more described symptom in description mode.

        Mutates ``current_state`` and ``conversation_history`` when a new symptom is added.

        Returns:
            ``(None, None)`` on quit, ``('switch', None)``, ``('done', None)``,
            or ``('added', symptom_name)``.
        """
        print(f"\n🗣️  Question {asked_questions + 1}/{max_questions} (Description Mode):")
        print("Please describe any additional symptoms you're experiencing, or type 'done' if you have no more symptoms to describe.")

        while True:
            response = self.get_user_input(
                "Describe your symptom (or 'done' to finish, 'switch' to change mode)",
                allow_switch=True,
                allow_quit=True
            )

            if response is None:  # quit
                return None, None

            if response == 'switch':
                return 'switch', None

            if response == 'done':
                return 'done', None

            if not response:
                print("❌ Please provide a description.")
                continue

            # The user already typed the description at the prompt above; match that
            # text directly instead of asking for it a second time.
            described_symptom = self.handle_symptom_description(description=response)
            if described_symptom:
                try:
                    symptom_idx = self.symptom_names.get_loc(described_symptom)
                    if current_state[symptom_idx] == -1:  # Not already answered
                        current_state[symptom_idx] = 1
                        severity, duration = self.get_severity_info(described_symptom)
                        self.conversation_history.append({
                            'symptom': described_symptom,
                            'present': True,
                            'severity': severity,
                            'duration': duration
                        })
                        print(f"✅ Added new symptom: {described_symptom}")
                        return 'added', described_symptom
                    else:
                        print(f"⚠️  You've already provided information about {described_symptom}")
                        continue
                except (KeyError, ValueError):
                    print(f"⚠️  Symptom '{described_symptom}' not found in database.")
                    continue
            else:
                print("❌ Could not identify a symptom from your description. Please try again or type 'done'.")
                continue

    def handle_structured_mode_questioning(self, current_state, next_feature):
        """Ask the yes/no question for ``next_feature``; return ``(response, symptom_name)``."""
        symptom = self.symptom_names[next_feature]
        question = self.format_symptom_question(symptom)

        response = self.get_user_input(
            question,
            ['yes', 'y', 'no', 'n'],
            allow_describe=True,
            allow_switch=True
        )

        return response, symptom

    def run_diagnosis(self, max_questions=15, confidence_threshold=0.85):
        """Run one full interactive consultation and print the final diagnosis.

        Args:
            max_questions: upper bound on answered questions (initial symptoms count too).
            confidence_threshold: top posterior probability above which questioning
                stops early (only after at least 4 questions).

        Returns:
            The predicted disease id, or ``None`` if the user quit.
        """
        self.display_header()
        # Each consultation starts with an empty history.
        self.conversation_history = []

        # Ask for initial symptoms and mode
        initial_symptoms, current_mode = self.ask_initial_symptoms()
        if initial_symptoms is None:
            return None

        current_state = [-1] * self.num_features
        asked_questions = 0

        # Set initial symptoms
        for symptom in initial_symptoms:
            try:
                symptom_idx = self.symptom_names.get_loc(symptom)
                current_state[symptom_idx] = 1
                severity, duration = self.get_severity_info(symptom)
                self.conversation_history.append({
                    'symptom': symptom,
                    'present': True,
                    'severity': severity,
                    'duration': duration
                })
                asked_questions += 1
            except (ValueError, KeyError):
                print(f"⚠️  Symptom '{symptom}' not found in database.")

        # Main diagnosis loop with proper mode handling
        while asked_questions < max_questions:
            # Calculate disease probabilities
            disease_probs = get_disease_probabilities(current_state, self.priors, self.likelihoods, self.num_features)
            most_likely_disease = max(disease_probs, key=disease_probs.get)
            confidence = disease_probs[most_likely_disease]

            if asked_questions > 0:
                self.show_progress(asked_questions, max_questions, confidence)
                self.show_intermediate_results(disease_probs, asked_questions, max_questions)

            # Check if confident enough
            if confidence > confidence_threshold and asked_questions >= 4:
                capped = 0.95 if confidence > 0.95 else confidence
                print(f"\n✅ I'm confident enough ({capped*100:.1f}%) to provide a diagnosis.")
                break

            print(f"\n🔧 Current mode: {'Description' if current_mode == 'describe' else 'Structured'}")

            # Handle different modes
            if current_mode == 'describe':
                result, symptom_info = self.handle_description_mode_questioning(
                    current_state, asked_questions, max_questions
                )

                if result is None:  # quit
                    return None
                elif result == 'switch':
                    current_mode = self.choose_question_mode()
                    continue
                elif result == 'done':
                    print("✅ Moving to final diagnosis based on provided symptoms.")
                    break
                elif result == 'added':
                    asked_questions += 1
                    continue

            else:  # structured mode
                # Find next best question using information gain
                next_feature = calculate_information_gain(current_state, self.priors, self.likelihoods, self.num_features)
                if next_feature is None:
                    break

                # Skip if already answered
                if current_state[next_feature] != -1:
                    continue

                print(f"\n❓ Question {asked_questions + 1}/{max_questions}:")
                response, symptom = self.handle_structured_mode_questioning(current_state, next_feature)

                if response is None:  # quit
                    return None
                elif response == 'switch':
                    current_mode = self.choose_question_mode()
                    continue
                elif response == 'describe':
                    # Switch to description mode for this question
                    described_symptom = self.handle_symptom_description()
                    if described_symptom:
                        # Check if it matches the current question
                        if described_symptom == symptom:
                            response = 'yes'
                        else:
                            # Add the described symptom if it's different and not already answered
                            try:
                                desc_symptom_idx = self.symptom_names.get_loc(described_symptom)
                                if current_state[desc_symptom_idx] == -1:
                                    current_state[desc_symptom_idx] = 1
                                    severity, duration = self.get_severity_info(described_symptom)
                                    self.conversation_history.append({
                                        'symptom': described_symptom,
                                        'present': True,
                                        'severity': severity,
                                        'duration': duration
                                    })
                                    print(f"✅ Added new symptom: {described_symptom}")
                                else:
                                    print(f"⚠️  You've already provided information about {described_symptom}")
                            except (KeyError, ValueError):
                                print(f"⚠️  Symptom '{described_symptom}' not found in database.")

                            # Still need to answer the original question
                            original_response = self.get_user_input(
                                f"Now, back to the original question: {self.format_symptom_question(symptom)}",
                                ['yes', 'y', 'no', 'n'],
                                allow_quit=False
                            )
                            response = original_response
                    else:
                        continue

                # Process the structured response
                if response in ['yes', 'y']:
                    current_state[next_feature] = 1
                    record = {'symptom': symptom, 'present': True}
                    # Severity details are only requested early on, but every positive
                    # answer is still recorded in the history.
                    if asked_questions < max_questions - 5:
                        print("📝 Getting more details...")
                        severity, duration = self.get_severity_info(symptom)
                        record.update(severity=severity, duration=duration)
                    self.conversation_history.append(record)
                elif response in ['no', 'n']:
                    current_state[next_feature] = 0
                    self.conversation_history.append({
                        'symptom': symptom,
                        'present': False
                    })

                asked_questions += 1

        # Recompute the posterior from the final state. The copy from the top of the
        # loop predates the last answer, and is never assigned at all if the loop body
        # never ran (e.g. initial symptoms already used up max_questions).
        disease_probs = get_disease_probabilities(current_state, self.priors, self.likelihoods, self.num_features)
        final_diagnosis, _ = predictv2(current_state, self.priors, self.likelihoods, self.num_features)
        self.display_final_results(final_diagnosis, disease_probs, current_state)

        return final_diagnosis

def run_enhanced(symptom_names, priors, likelihoods, num_features, df_med, df_precaution, le, Id_Name_dict,
                symptom_description, tokenizer, model, device, cleaning_func,
                max_questions=15, confidence_threshold=0.85):
    """Build a ``v3_1_0`` chatbot and run one interactive consultation.

    All positional arguments are forwarded to the ``v3_1_0`` constructor.
    ``max_questions`` and ``confidence_threshold`` are forwarded to
    ``run_diagnosis``; their defaults match that method's defaults.

    Returns:
        The predicted disease id, or ``None`` if the user quit.
    """
    chatbot = v3_1_0(symptom_names, priors, likelihoods, num_features, df_med, df_precaution, le, Id_Name_dict,
                     symptom_description, tokenizer, model, device, cleaning_func)
    return chatbot.run_diagnosis(max_questions=max_questions, confidence_threshold=confidence_threshold)

In [25]:
symptom_description = pd.read_csv('symptom_descriptions.csv')
# Take the feature count from the training matrix instead of hard-coding it, so the
# chatbot's state vector always matches the likelihood tables fitted on `x`.
assert len(feature) == x.shape[1], "symptom names and feature matrix disagree"
run_enhanced(feature, prior, likelihoods_dict, x.shape[1], df_med, df_precaution, le, Id_Name_dict,
             symptom_description, tokenizer, model, device, cleaning)

🏥 AI MEDICAL ASSISTANT
Hello! I'm your AI medical assistant.
I'll ask you some questions about your symptoms to help identify
possible conditions. Please answer honestly and accurately.

💡 FEATURES:
   • Describe symptoms in your own words anytime
   • Switch between description and structured questions
   • Get multiple symptom suggestions

⚠️  IMPORTANT: This is for informational purposes only.
   Always consult a real doctor for proper medical advice.

🚀 Let's start your medical consultation!

🔄 Choose your preferred question mode:
1. Describe symptoms in your own words
2. Answer structured yes/no questions
Choose mode (1 or 2): 1

🗣️  Please describe your main symptoms:

🗣️  Describe your symptom in your own words:
   (e.g., 'I have a splitting headache', 'My stomach really hurts', etc.)
Your description: i am feeling pain in chest and cramping pelvis
🔍 Analyzing your description...


  output = torch._nested_tensor_from_mask(



✨ Here are the top 3 symptoms that match your description:
--------------------------------------------------
1. chest_pain
   💭 Description: Pain or discomfort in the chest area can be serious if sudden.
   🎯 Confidence: 99.9% 🔴

2. belly_pain
   💭 Description: General pain in the abdominal area, similar to abdominal pain.
   🎯 Confidence: 99.9% 🔴

3. abdominal_pain
   💭 Description: Discomfort or pain in the area between the chest and pelvis may be cramping, dull, or sharp.
   🎯 Confidence: 99.8% 🔴

Which symptom matches best? (1-3, or 'none' if no match): 1
✅ Selected: chest_pain
✅ Added: chest_pain

Do you have any other symptoms to describe?: n
On a scale of 1-5, how severe is your chest_pain? (1=mild, 5=severe): 1
How long have you had chest_pain? (hours/days/weeks): 4

📊 Progress: [█░░░░░░░░░░░░░░░░░░░] 1/15 questions
🎯 Current confidence: 16.4%
--------------------------------------------------

🔍 Analysis after 1 questions:
  1. Tuberculosis: 16.4% 🟢 Low
  2. Common Cold: 16.

np.int64(36)