# AI Diagnosis Model Training for MediChain

## Objective
Train an AI diagnosis model on the symptoms dataset to predict medical conditions from patient symptoms.

## Dataset
- **symptoms_dataset.csv**: Contains binary symptom indicators and corresponding diagnoses
- **Features**: fever, cough, fatigue, shortness_of_breath, headache, sore_throat
- **Target**: diagnosis (medical condition)

## Model
- **Algorithm**: Random Forest Classifier
- **Evaluation**: Accuracy score and classification report
- **Output**: diagnosis_model.pkl (plus label_encoder.pkl and feature_names.pkl) for Flask backend integration

## 1. Import Required Libraries

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder
import joblib
import warnings
# Silence every warning category for the rest of the session so the cell
# outputs stay compact.
# NOTE(review): this also hides deprecation and data-validation warnings
# raised by pandas / scikit-learn in later cells — consider narrowing it.
warnings.filterwarnings('ignore')

# Seed NumPy's global random number generator for reproducibility.
# The split and the forest below additionally pass random_state=42 explicitly.
np.random.seed(42)

## 2. Load and Explore the Dataset

In [None]:
# Read the symptom/diagnosis table from the CSV in the working directory
df = pd.read_csv('symptoms_dataset.csv')

# Quick overview: dimensions, column labels, then a preview of the top rows
print("Dataset Shape:", df.shape)
print("\nColumn Names:", list(df.columns))
print("\nFirst 5 rows:")
# Bare trailing expression so the notebook renders it as the cell output
df.head(5)

In [None]:
# Count the null entries found in every column
print("Missing values per column:")
print(df.isna().sum())

# Show the dtype pandas assigned to each column
print("\nData Types:")
print(df.dtypes)

## 3. Exploratory Data Analysis

In [None]:
# Tally how many rows carry each diagnosis label
diagnosis_counts = df['diagnosis'].value_counts()
print("Diagnosis Distribution:")
print(diagnosis_counts)

# Draw the tally as a bar chart on a wide figure; the tick labels are
# rotated so longer diagnosis names do not overlap
plt.figure(figsize=(12, 6))
diagnosis_counts.plot.bar()
plt.title('Distribution of Diagnoses')
plt.xlabel('Diagnosis')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

In [None]:
# Correlation heatmap of the symptom indicators.
# Only numeric/boolean columns are correlated: DataFrame.corr() on a frame
# that still contains a non-numeric column (here presumably the text
# `diagnosis` labels) raises on pandas >= 2.0, where non-numeric columns are
# no longer dropped silently.
plt.figure(figsize=(10, 8))
correlation_matrix = df.select_dtypes(include=['number', 'bool']).corr()
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0)
plt.title('Correlation Matrix of Symptoms')
plt.tight_layout()
plt.show()

## 4. Prepare Data for Training

In [None]:
# Feature matrix: every column other than the diagnosis label
X = df.drop(columns='diagnosis')
# Target: the diagnosis label of each row
y = df['diagnosis']

# Learn the label -> integer code mapping, then apply it to the target
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit(y).transform(y)

# Report the resulting shapes and the label/code vocabulary
print("Features shape:", X.shape)
print("Target shape:", y_encoded.shape)
print("\nUnique diagnoses:", label_encoder.classes_)
print("\nEncoded labels:", np.unique(y_encoded))

In [None]:
# Hold out 20% of the rows for testing. The split is stratified on the
# encoded diagnosis and uses a fixed seed so it can be reproduced.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y_encoded,
    stratify=y_encoded,
    test_size=0.2,
    random_state=42,
)

# Report partition sizes and how the training labels are distributed
print("Training set size:", X_train.shape)
print("Test set size:", X_test.shape)
print("\nTraining set class distribution:")
print(pd.Series(y_train).value_counts())

## 5. Train the Random Forest Model

In [None]:
# Configure the forest and fit it on the training partition in one step
# (fit() returns the estimator itself). n_jobs=-1 uses all available cores;
# random_state pins the forest's randomness.
rf_model = RandomForestClassifier(
    n_estimators=100,
    max_depth=10,
    min_samples_split=5,
    n_jobs=-1,
    random_state=42,
).fit(X_train, y_train)

# Confirm training finished and show what the model learned about its inputs
print("Model training completed!")
print("Number of features:", rf_model.n_features_in_)
print("Feature importances:", rf_model.feature_importances_)

In [None]:
# Pair each feature name with its importance score, most important first
feature_importance = pd.DataFrame(
    {'feature': X.columns, 'importance': rf_model.feature_importances_}
)
feature_importance = feature_importance.sort_values(by='importance', ascending=False)

# Horizontal bar chart: one bar per feature, bar length = importance
plt.figure(figsize=(10, 6))
sns.barplot(x='importance', y='feature', data=feature_importance)
plt.title('Feature Importance in Diagnosis Prediction')
plt.xlabel('Importance Score')
plt.tight_layout()
plt.show()

## 6. Model Evaluation

In [None]:
# Predict the held-out partition and score it
y_pred = rf_model.predict(X_test)

accuracy = accuracy_score(y_test, y_pred)
print(f"Model Accuracy: {accuracy:.4f}")

# Detailed per-class report.
# The full set of encoded labels (LabelEncoder codes are 0..n_classes-1) is
# passed explicitly: `target_names` lists every diagnosis, so without
# `labels` the report fails whenever a rare class is absent from both the
# test partition and the predictions.
all_labels = np.arange(len(label_encoder.classes_))
print("\nClassification Report:")
print(
    classification_report(
        y_test,
        y_pred,
        labels=all_labels,
        target_names=label_encoder.classes_,
    )
)

In [None]:
# Confusion matrix over every known diagnosis.
# `labels` is given explicitly (LabelEncoder codes are 0..n_classes-1) so the
# matrix always has one row/column per class; otherwise a class missing from
# the test partition and the predictions would shrink the matrix and the
# tick labels below would no longer line up with it.
cm = confusion_matrix(
    y_test,
    y_pred,
    labels=np.arange(len(label_encoder.classes_)),
)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=label_encoder.classes_,
            yticklabels=label_encoder.classes_)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()

## 7. Test Model with Sample Cases

In [None]:
# Hand-written binary symptom vectors, one row per hypothetical patient.
# NOTE(review): the values are assumed to follow the column order of X
# (fever, cough, fatigue, shortness_of_breath, headache, sore_throat) —
# confirm against the CSV header.
test_cases = [
    [1, 1, 1, 0, 1, 1],  # fever, cough, fatigue, headache, sore throat
    [1, 0, 1, 1, 0, 0],  # fever, fatigue, shortness of breath
    [0, 1, 0, 0, 1, 1],  # cough, headache, sore throat
    [1, 1, 1, 1, 0, 0],  # everything except headache and sore throat
    [0, 0, 1, 0, 1, 0],  # fatigue and headache only
]

# Wrap the rows in a frame that carries the training feature names
test_df = pd.DataFrame(data=test_cases, columns=X.columns)

# Predict the encoded classes, then map the codes back to diagnosis labels
predictions = rf_model.predict(test_df)
predicted_diagnoses = label_encoder.inverse_transform(predictions)

# Print each case as a {symptom: value} mapping next to its prediction
for i, (case, diagnosis) in enumerate(zip(test_cases, predicted_diagnoses)):
    print(f"Case {i+1}: {dict(zip(X.columns, case))}")
    print(f"Predicted Diagnosis: {diagnosis}")
    print("-" * 50)

## 8. Save the Model and Encoders

In [None]:
# Persist the three artifacts the prediction helper reloads later: the
# fitted forest, the diagnosis label encoder, and the ordered feature names.
# Each entry is (object to dump, target file, confirmation message).
for _artifact, _path, _saved_msg in (
    (rf_model, 'diagnosis_model.pkl', "Model saved as 'diagnosis_model.pkl'"),
    (label_encoder, 'label_encoder.pkl', "Label encoder saved as 'label_encoder.pkl'"),
    (X.columns.tolist(), 'feature_names.pkl', "Feature names saved as 'feature_names.pkl'"),
):
    joblib.dump(_artifact, _path)
    print(_saved_msg)

## 9. Create Model Summary Report

In [None]:
import json

# Collect the headline facts about the trained model in one dictionary.
# NumPy scalars (accuracy, importances) are converted to built-in floats so
# the printed summary shows plain numbers (NumPy 2 reprs read
# `np.float64(...)`) and JSON serialisation does not depend on NumPy types.
summary = {
    'model_type': 'Random Forest Classifier',
    'accuracy': float(accuracy),
    'num_features': len(X.columns),
    'feature_names': X.columns.tolist(),
    'num_classes': len(label_encoder.classes_),
    'class_names': label_encoder.classes_.tolist(),
    'training_samples': len(X_train),
    'test_samples': len(X_test),
    'feature_importance': {
        name: float(score)
        for name, score in zip(X.columns, rf_model.feature_importances_)
    },
}

print("=== MODEL SUMMARY ===")
for key, value in summary.items():
    print(f"{key}: {value}")

# Save summary to JSON
with open('model_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)
print("\nModel summary saved to 'model_summary.json'")

## 10. Integration with Flask Backend

In [None]:
# Create a simple prediction function for Flask integration
def predict_diagnosis(symptoms_dict):
    """
    Predict diagnosis from symptoms dictionary

    The model, label encoder and ordered feature list are reloaded from
    'diagnosis_model.pkl', 'label_encoder.pkl' and 'feature_names.pkl' in the
    working directory on every call.

    Args:
        symptoms_dict: Dictionary with symptom names as keys and binary values.
            Features missing from the dictionary default to 0; keys that are
            not known features are ignored.

    Returns:
        Dictionary with:
            'diagnosis': the predicted diagnosis label,
            'confidence': the model's probability for that label (float),
            'all_probabilities': mapping of every diagnosis label to its
                probability (float).
    """
    # Load model and encoder
    # NOTE(review): joblib.load unpickles the files — only load artifacts
    # that come from a trusted source.
    model = joblib.load('diagnosis_model.pkl')
    encoder = joblib.load('label_encoder.pkl')
    features = joblib.load('feature_names.pkl')

    # Build a one-row frame carrying the feature names, in the saved order,
    # so the input matches the named DataFrame the notebook trains on.
    input_frame = pd.DataFrame(
        [[symptoms_dict.get(feature, 0) for feature in features]],
        columns=features,
    )

    # Make prediction
    prediction = model.predict(input_frame)[0]
    probabilities = model.predict_proba(input_frame)[0]

    # predict_proba columns follow model.classes_, so locate the predicted
    # label there rather than using the label value itself as an index.
    encoded_classes = list(model.classes_)
    position = encoded_classes.index(prediction)

    # Decode every class once; tolist() yields built-in Python values, which
    # keeps the result directly JSON-serialisable.
    class_names = encoder.inverse_transform(model.classes_).tolist()

    return {
        'diagnosis': class_names[position],
        'confidence': float(probabilities[position]),
        'all_probabilities': {
            name: float(prob)
            for name, prob in zip(class_names, probabilities)
        },
    }

# Smoke-test the helper with one hand-written symptom profile
test_symptoms = {
    'fever': 1, 'cough': 1, 'fatigue': 1,
    'shortness_of_breath': 0, 'headache': 1, 'sore_throat': 1,
}

# Run the profile through the saved model and show the returned dictionary
result = predict_diagnosis(test_symptoms)
print("Test Result:", result)