In [2]:
import pandas as pd
import numpy as np
import gym
from gym import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import torch
import nltk
from nltk.stem import PorterStemmer
import logging
from typing import List, Tuple, Dict, Any
import json
from datetime import datetime


# Timestamped INFO-level logging so the data-loading and training phases can be followed.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Tokenizer data required by nltk.word_tokenize (used in AdvancedDiagnosisSystem.stem).
# 'punkt' covers older NLTK releases; newer releases load 'punkt_tab' instead and raise
# LookupError without it, so both are requested.
# NOTE(review): on an NLTK version that predates 'punkt_tab' the second download is
# expected to just report an unknown package without raising - confirm the noise is acceptable.
for _nltk_resource in ('punkt', 'punkt_tab'):
    nltk.download(_nltk_resource, quiet=True)

# Module-level stemmer shared by every AdvancedDiagnosisSystem.stem call.
ps = PorterStemmer()

class AdvancedDiagnosisSystem:
    """Symptom-based disease predictor with medication, precaution and drug look-ups.

    Constructing an instance reads four CSV files, builds a TF-IDF index over the
    medicine table and trains a PPO policy that maps a symptom vector to a disease
    label. Training happens inside ``__init__`` and can take a long time; lower
    ``total_timesteps`` for quick experiments.
    """

    def __init__(self, disease_data_path: str, medication_data_path: str,
                 precaution_data_path: str, medicine_data_path: str,
                 total_timesteps: int = 2000000, verbose: int = 1):
        """Load all data sets and train the disease model.

        Args:
            disease_data_path: CSV whose last column is the disease label and whose
                other columns are the symptom features.
            medication_data_path: CSV with ``Disease`` and ``Medication`` columns.
            precaution_data_path: CSV with ``Disease`` and ``Precaution_1`` ..
                ``Precaution_4`` columns.
            medicine_data_path: CSV with ``Drug_Name``, ``Reason`` and ``Description``
                columns.
            total_timesteps: Number of PPO training timesteps (default matches the
                previously hard-coded value).
            verbose: Verbosity level forwarded to PPO (default matches the previously
                hard-coded value).

        Raises:
            FileNotFoundError: If any of the CSV files is missing.
        """
        self.disease_data_path = disease_data_path
        self.medication_data_path = medication_data_path
        self.precaution_data_path = precaution_data_path
        self.medicine_data_path = medicine_data_path
        self.total_timesteps = total_timesteps
        self.verbose = verbose

        self.disease_model = None
        self.drug_recommender = None
        self.le = LabelEncoder()

        self.load_and_preprocess_data()
        self.train_models()

    def load_and_preprocess_data(self):
        """Read the CSV files, split the disease data and build the TF-IDF drug index.

        Raises:
            FileNotFoundError: Logged and re-raised when a CSV file cannot be found.
        """
        logger.info("Loading and preprocessing data...")
        try:
            self.disease_data = pd.read_csv(self.disease_data_path)
            self.medication_data = pd.read_csv(self.medication_data_path)
            self.precaution_data = pd.read_csv(self.precaution_data_path)
            self.medicine_data = pd.read_csv(self.medicine_data_path)
        except FileNotFoundError as e:
            logger.error(f"Error loading data: {e}")
            raise

        # Features are every column but the last; the last column is the label.
        self.X = self.disease_data.iloc[:, :-1]
        self.y = self.le.fit_transform(self.disease_data.iloc[:, -1])
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            self.X, self.y, test_size=0.2, random_state=42)

        # Missing Reason/Description cells would make the concatenation NaN and
        # crash the tokenizer, so they are treated as empty text.
        reason = self.medicine_data['Reason'].fillna('').astype(str)
        description = self.medicine_data['Description'].fillna('').astype(str)
        self.medicine_data['combined_text'] = (reason + ' ' + description).apply(self.stem)

        self.vectorizer = TfidfVectorizer(stop_words='english', max_features=5000)
        self.tfidf_matrix = self.vectorizer.fit_transform(self.medicine_data['combined_text'])
        self.cosine_sim = cosine_similarity(self.tfidf_matrix)

    def train_models(self):
        """Train the PPO disease model on the training split."""
        logger.info("Training models...")

        # Size the action space from the encoder, not from the training split, so
        # action indices always line up with LabelEncoder indices even when a
        # class happens to be absent from the split.
        n_classes = len(self.le.classes_)
        env = DummyVecEnv([
            lambda: self.DiseaseDiagnosisEnv(self.X_train, self.y_train, n_classes=n_classes)
        ])
        self.disease_model = PPO("MlpPolicy", env, verbose=self.verbose)
        self.disease_model.learn(total_timesteps=self.total_timesteps)

    def stem(self, text: str) -> str:
        """Tokenize ``text`` and return its Porter-stemmed tokens joined by spaces.

        Missing values are treated as empty text; other non-string values are
        converted with ``str`` before tokenizing.
        """
        if not isinstance(text, str):
            text = '' if pd.isna(text) else str(text)
        words = nltk.word_tokenize(text)
        return ' '.join([ps.stem(word) for word in words])

    class DiseaseDiagnosisEnv(gym.Env):
        """One-pass classification environment: each step presents one sample row.

        The agent's action is the predicted label index; the reward is +1 for a
        match with the true label and -1 otherwise. An episode walks the rows of
        ``X`` in order and ends after the last one.

        NOTE(review): ``reset``/``step`` follow the classic gym signatures (reset
        returns only the observation, step returns a 4-tuple); presumably the
        installed compatibility layer relies on that - confirm before changing.
        """

        def __init__(self, X, y, n_classes=None):
            """
            Args:
                X: DataFrame of feature rows (accessed with ``.iloc``).
                y: Sequence of integer labels aligned positionally with ``X``.
                n_classes: Size of the action space. Defaults to the number of
                    distinct values in ``y`` (the original behaviour).
            """
            super().__init__()
            self.X = X
            self.y = y
            self.current_step = 0
            if n_classes is None:
                n_classes = len(np.unique(y))
            self.action_space = spaces.Discrete(n_classes)
            # assumes every feature lies in [0, 1] - TODO confirm against the data set
            self.observation_space = spaces.Box(low=0, high=1, shape=(X.shape[1],), dtype=np.float32)

        def reset(self):
            """Rewind to the first row and return it as the initial observation."""
            self.current_step = 0
            return self.X.iloc[self.current_step].values.astype(np.float32)

        def step(self, action):
            """Score ``action`` against the current row's label and advance one row.

            Returns:
                ``(obs, reward, done, info)`` where ``obs`` is the next row, or an
                all-zero float32 vector once the data is exhausted.
            """
            reward = 1 if action == self.y[self.current_step] else -1
            self.current_step += 1
            done = self.current_step >= len(self.X)
            if done:
                # Keep the dtype consistent with the declared observation space.
                obs = np.zeros(self.X.shape[1], dtype=np.float32)
            else:
                obs = self.X.iloc[self.current_step].values.astype(np.float32)
            return obs, reward, done, {}

    def get_predicted_proba(self, symptoms: List[str]) -> np.ndarray:
        """Return the policy's action probabilities for a set of symptoms.

        Args:
            symptoms: Symptom names; those matching a feature column are set to 1
                in the observation, the rest are ignored with a warning.

        Returns:
            1-D array with one probability per action of the trained policy.
        """
        columns = self.X.columns
        obs = np.zeros(len(columns), dtype=np.float32)
        unknown = []
        for symptom in symptoms:
            if symptom in columns:
                obs[columns.get_loc(symptom)] = 1
            else:
                unknown.append(symptom)
        if unknown:
            logger.warning(f"Ignoring unrecognised symptoms: {', '.join(unknown)}")

        policy = self.disease_model.policy
        policy.set_training_mode(False)
        # Build the tensor on the policy's device so inference also works when
        # the model lives on a GPU.
        obs_tensor = torch.as_tensor(obs, device=policy.device).unsqueeze(0)
        with torch.no_grad():
            probs = policy.get_distribution(obs_tensor).distribution.probs
        # squeeze(0) drops only the batch axis, so the result stays 1-D.
        return probs.squeeze(0).cpu().numpy()

    def predict_disease(self, symptoms: List[str]) -> List[Tuple[str, float]]:
        """Return every disease with its probability, most probable first."""
        probs = self.get_predicted_proba(symptoms)
        top_indices = probs.argsort()[::-1]
        labels = self.le.inverse_transform(top_indices)
        return [(label, probs[i]) for label, i in zip(labels, top_indices)]

    def get_medications(self, disease: str) -> List[str]:
        """Return the ``Medication`` entries whose ``Disease`` equals ``disease``."""
        return self.medication_data[self.medication_data['Disease'] == disease]['Medication'].tolist()

    def get_precautions(self, disease: str) -> List[str]:
        """Return the non-missing precautions of the first row matching ``disease``.

        An unknown disease yields an empty list instead of raising ``IndexError``.
        """
        matches = self.precaution_data[self.precaution_data['Disease'] == disease]
        if matches.empty:
            return []
        disease_precautions = matches.iloc[0]
        return [p for p in disease_precautions[['Precaution_1', 'Precaution_2', 'Precaution_3', 'Precaution_4']] if pd.notna(p)]

    def get_drug_recommendations(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
        """Rank medicines by TF-IDF cosine similarity to ``query``.

        Args:
            query: Free-text query; it is stemmed the same way as the index.
            top_k: Maximum number of results. Values <= 0 yield an empty list.

        Returns:
            ``(Drug_Name, similarity)`` pairs, most similar first.
        """
        if top_k <= 0:
            # A slice of [-0:] would otherwise return every row.
            return []
        processed_query = self.stem(query)
        query_vec = self.vectorizer.transform([processed_query])
        sim_scores = cosine_similarity(query_vec, self.tfidf_matrix)[0]
        top_indices = sim_scores.argsort()[-top_k:][::-1]
        return [(self.medicine_data.iloc[i]['Drug_Name'], sim_scores[i]) for i in top_indices]

    def analyze_symptoms(self, input_symptoms: str) -> Dict[str, Any]:
        """Run the full pipeline for a comma-separated symptom string.

        Each symptom is trimmed, lower-cased and has spaces replaced by
        underscores; empty entries (e.g. from a trailing comma) are dropped.

        Returns:
            Dict with the normalised symptoms, the top disease and its probability,
            medications, precautions, TF-IDF drug suggestions and the five most
            probable diseases.

        Raises:
            ValueError: If no symptom remains after normalisation.
        """
        symptoms = [symptom.strip().replace(' ', '_').lower() for symptom in input_symptoms.split(',')]
        symptoms = [symptom for symptom in symptoms if symptom]
        if not symptoms:
            raise ValueError("No symptoms provided")

        predicted_diseases = self.predict_disease(symptoms)

        top_disease, probability = predicted_diseases[0]

        medications = self.get_medications(top_disease)
        precautions = self.get_precautions(top_disease)

        query = f"Treatment for {top_disease}"
        drug_recommendations = self.get_drug_recommendations(query)

        return {
            "input_symptoms": symptoms,
            "predicted_disease": top_disease,
            "disease_probability": float(probability),
            "recommended_medications": medications,
            "precautions": precautions,
            "additional_drug_recommendations": drug_recommendations,
            "all_predicted_diseases": [(disease, float(prob)) for disease, prob in predicted_diseases[:5]]
        }

    def generate_report(self, analysis_result: Dict[str, Any]) -> str:
        """Render the dict produced by ``analyze_symptoms`` as a plain-text report."""
        report = f"""
        ======= Disease Diagnosis and Treatment Report =======
        Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}

        Input Symptoms: {', '.join(analysis_result['input_symptoms'])}

        Top Predicted Disease: {analysis_result['predicted_disease']} (Probability: {analysis_result['disease_probability']:.2%})

        Other Possible Diseases:
        {self._format_disease_list(analysis_result['all_predicted_diseases'][1:])}

        Recommended Medications:
        {self._format_list(analysis_result['recommended_medications'])}

        Precautions:
        {self._format_list(analysis_result['precautions'])}

        Additional Drug Recommendations:
        {self._format_drug_list(analysis_result['additional_drug_recommendations'])}

        ====================================================
        Note: This report is generated by an QSTAR system and should be reviewed by a healthcare professional before making any medical decisions.
        """
        return report

    def _format_disease_list(self, diseases: List[Tuple[str, float]]) -> str:
        """Format ``(disease, probability)`` pairs as bullet lines with percentages."""
        return '\n'.join([f"  - {disease}: {probability:.2%}" for disease, probability in diseases])

    def _format_list(self, items: List[str]) -> str:
        """Format items as bullet lines."""
        return '\n'.join([f"  - {item}" for item in items])

    def _format_drug_list(self, drugs: List[Tuple[str, float]]) -> str:
        """Format ``(drug, score)`` pairs as bullet lines with four-decimal scores."""
        return '\n'.join([f"  - {drug}: {score:.4f}" for drug, score in drugs])

    def save_report(self, report: str, filename: str):
        """Write ``report`` to ``filename`` as UTF-8; I/O errors are logged, not raised."""
        try:
            with open(filename, 'w', encoding='utf-8') as f:
                f.write(report)
            logger.info(f"Report saved to {filename}")
        except IOError as e:
            logger.error(f"Error saving report: {e}")

    def process_input(self, input_symptoms: str) -> None:
        """Analyse one line of user input, print the report and save it to disk.

        Writes ``medical_report_<timestamp>.txt`` and ``analysis_result_<timestamp>.json``
        to the working directory. Any failure is logged (with traceback) instead of
        propagating, so an interactive loop can keep running.
        """
        try:
            clean_input = input_symptoms.replace('"', '').replace("'", "")
            # NOTE(review): presumably this unwraps a record-style line of the form
            # `...symptoms,<s1>,<s2>,...expected_disease...`; the first comma-separated
            # piece after "symptoms" is discarded - confirm the intended input format.
            if "symptoms" in clean_input:
                clean_input = clean_input.split("symptoms")[1].split(",")[1:]
                clean_input = ",".join(clean_input).split("expected_disease")[0].strip()

            analysis_result = self.analyze_symptoms(clean_input)
            report = self.generate_report(analysis_result)
            print(report)

            # One timestamp for both artefacts so the report and its JSON always pair up.
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

            filename = f"medical_report_{timestamp}.txt"
            self.save_report(report, filename)

            json_filename = f"analysis_result_{timestamp}.json"
            with open(json_filename, 'w', encoding='utf-8') as f:
                json.dump(analysis_result, f, indent=2, default=float)
            logger.info(f"Analysis result saved to {json_filename}")

        except Exception as e:
            # Deliberate best-effort: log with traceback and return to the caller.
            logger.exception(f"Error processing input: {e}")

if __name__ == "__main__":
    # Constructing the system loads the four CSV files and trains the disease
    # model before the prompt appears.
    system = AdvancedDiagnosisSystem(
        disease_data_path='Training.csv',
        medication_data_path='medications.csv',
        precaution_data_path='precautions_df.csv',
        medicine_data_path='medicine.csv'
    )

    # Interactive loop: one analysis per line until the user types 'quit'.
    while True:
        try:
            input_symptoms = input("Enter symptoms (comma-separated) or 'quit' to exit: ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or an interrupted kernel ends the session cleanly.
            break

        input_symptoms = input_symptoms.strip()
        if input_symptoms.lower() == 'quit':
            break
        if not input_symptoms:
            # Nothing to analyse; ask again instead of querying the model.
            continue
        system.process_input(input_symptoms)

  from jax import xla_computation as _xla_computation


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
|    iterations           | 702         |
|    time_elapsed         | 2201        |
|    total_timesteps      | 1437696     |
| train/                  |             |
|    approx_kl            | 0.008124447 |
|    clip_fraction        | 0.0221      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.0718     |
|    explained_variance   | -0.118      |
|    learning_rate        | 0.0003      |
|    loss                 | 0.198       |
|    n_updates            | 7010        |
|    policy_gradient_loss | -0.00842    |
|    value_loss           | 0.639       |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 653         |
|    iterations           | 703         |
|    time_elapsed         | 2204        |
|    total_timesteps      | 1439744     |
| train/                  |             |
|    approx

In [1]:
# %pip installs into the running kernel's environment; the version is pinned to the
# release this notebook was executed with so reruns stay reproducible.
%pip install -q "stable-baselines3[extra]==2.3.2"

Collecting stable-baselines3[extra]
  Downloading stable_baselines3-2.3.2-py3-none-any.whl.metadata (5.1 kB)
Collecting gymnasium<0.30,>=0.28.1 (from stable-baselines3[extra])
  Downloading gymnasium-0.29.1-py3-none-any.whl.metadata (10 kB)
Collecting shimmy~=1.3.0 (from shimmy[atari]~=1.3.0; extra == "extra"->stable-baselines3[extra])
  Downloading Shimmy-1.3.0-py3-none-any.whl.metadata (3.7 kB)
Collecting autorom~=0.6.1 (from autorom[accept-rom-license]~=0.6.1; extra == "extra"->stable-baselines3[extra])
  Downloading AutoROM-0.6.1-py3-none-any.whl.metadata (2.4 kB)
Collecting AutoROM.accept-rom-license (from autorom[accept-rom-license]~=0.6.1; extra == "extra"->stable-baselines3[extra])
  Downloading AutoROM.accept-rom-license-0.6.1.tar.gz (434 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m434.7/434.7 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdon