# 2. Exploratory Data Analysis

This notebook performs exploratory data analysis on the processed symptom-disease dataset.

## Objectives
1. Analyze dataset statistics
2. Visualize class distribution
3. Analyze symptom frequency
4. Create a correlation heatmap of the most frequent symptoms
5. Analyze disease-symptom associations
6. Save the visualizations and a summary report

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from pathlib import Path

# Set style for visualizations.
# Matplotlib 3.6 renamed its bundled seaborn style sheets to "seaborn-v0_8*",
# so pick whichever name this installation actually provides instead of
# failing on the first cell.
seaborn_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(seaborn_style)
sns.set_palette('husl')

# Create reports directory if it doesn't exist
Path('../reports/eda').mkdir(parents=True, exist_ok=True)

## 1. Load Processed Data

In [None]:
# Input files produced by the preprocessing step
processed_csv = '../data/processed/processed_data.csv'
mapping_json = '../data/processed/disease_mapping.json'

# Processed symptom/disease table
df = pd.read_csv(processed_csv)

# Disease mapping as stored on disk
with open(mapping_json, 'r') as mapping_file:
    disease_mapping = json.load(mapping_file)

# Inverted mapping (stored value -> stored key) for readable labels
id_to_disease = dict(zip(disease_mapping.values(), disease_mapping.keys()))

print(f"Dataset shape: {df.shape}")

## 2. Dataset Statistics

In [None]:
def analyze_dataset_statistics(df, target_col='disease'):
    """Calculate and display basic dataset statistics.

    Args:
        df: DataFrame containing the label column `target_col`; every other
            column is treated as a feature and summed per row.
        target_col: Name of the label column. Defaults to 'disease'.

    Returns:
        dict mapping each statistic's display name to its numeric value.

    Raises:
        KeyError: If `target_col` is not a column of `df`.
    """
    # Select the features by name rather than by position so the label column
    # is never summed as if it were a symptom, wherever it sits in the frame.
    feature_df = df.drop(columns=target_col)

    stats = {
        'Total samples': len(df),
        'Number of features': feature_df.shape[1],
        'Number of diseases': len(df[target_col].unique()),
        'Average symptoms per case': feature_df.sum(axis=1).mean(),
        'Memory usage (MB)': df.memory_usage(deep=True).sum() / 1024**2
    }

    # Every value is numeric, so one float format works for all rows
    for label, value in stats.items():
        print(f"{label}: {value:.2f}")

    return stats

# Keep the returned dict: the report cell at the end of the notebook reads `dataset_stats`.
dataset_stats = analyze_dataset_statistics(df)

## 3. Class Distribution Analysis

In [None]:
def plot_class_distribution(df, id_to_disease, top_n=20):
    """Draw a horizontal bar chart of the most frequent diseases and save it as a PNG.

    Args:
        df: DataFrame with a 'disease' column.
        id_to_disease: Mapping from the values in df['disease'] to display names.
        top_n: Number of most frequent diseases to include in the chart.
    """
    fig, ax = plt.subplots(figsize=(15, 8))

    # value_counts() returns counts already sorted from most to least frequent
    counts = df['disease'].value_counts()
    top_counts = counts.iloc[:top_n]
    labels = [id_to_disease[disease_id] for disease_id in counts.index[:top_n]]

    sns.barplot(x=top_counts, y=labels, ax=ax)

    ax.set_title(f'Top {top_n} Most Common Diseases')
    ax.set_xlabel('Number of Cases')
    ax.set_ylabel('Disease')

    # Persist the figure before showing it
    fig.tight_layout()
    fig.savefig('../reports/eda/class_distribution.png')
    plt.show()

# Uses the function's default top_n, so the chart covers the 20 most frequent diseases.
plot_class_distribution(df, id_to_disease)

## 4. Symptom Frequency Analysis

In [None]:
def analyze_symptom_frequency(df, top_n=20, target_col='disease'):
    """Analyze and plot symptom frequencies.

    Args:
        df: DataFrame containing the label column `target_col`; every other
            column is summed as a symptom indicator.
        top_n: Number of most frequent symptoms to show in the bar chart.
        target_col: Name of the label column to exclude from the totals.

    Returns:
        Series of per-symptom totals for all symptoms, sorted from most to
        least frequent.

    Raises:
        KeyError: If `target_col` is not a column of `df`.
    """
    # Exclude the label by name so the result does not depend on column order
    symptom_freq = df.drop(columns=target_col).sum().sort_values(ascending=False)
    top_freq = symptom_freq.head(top_n)

    fig, ax = plt.subplots(figsize=(15, 8))

    # Pass plain arrays so seaborn has no pandas index to align on
    sns.barplot(x=top_freq.to_numpy(), y=list(top_freq.index), ax=ax)

    ax.set_title(f'Top {top_n} Most Common Symptoms')
    ax.set_xlabel('Frequency')
    ax.set_ylabel('Symptom')

    fig.tight_layout()
    fig.savefig('../reports/eda/symptom_frequency.png')
    plt.show()

    return symptom_freq

# Keep the sorted totals: the report cell at the end of the notebook reads `symptom_frequencies`.
symptom_frequencies = analyze_symptom_frequency(df)

## 5. Correlation Analysis

In [None]:
def plot_correlation_heatmap(df, n_symptoms=30, target_col='disease'):
    """Create and plot correlation heatmap for top symptoms.

    Args:
        df: DataFrame containing the label column `target_col`; every other
            column is treated as a symptom indicator.
        n_symptoms: Number of most frequent symptoms to include.
        target_col: Name of the label column to exclude from the ranking.

    Raises:
        KeyError: If `target_col` is not a column of `df`.
    """
    # Rank symptoms from the frame that was passed in rather than from a
    # notebook-level variable, so the result always matches `df` and the cell
    # does not depend on which other cells have been run.
    symptom_totals = df.drop(columns=target_col).sum().sort_values(ascending=False)
    top_symptoms = symptom_totals.head(n_symptoms).index

    # Calculate correlation matrix
    corr_matrix = df[top_symptoms].corr()

    fig, ax = plt.subplots(figsize=(15, 12))

    # center=0 keeps the diverging colormap neutral at zero correlation
    sns.heatmap(
        corr_matrix,
        cmap='RdBu_r',
        center=0,
        annot=True,
        fmt='.2f',
        square=True,
        cbar_kws={'label': 'Correlation'},
        ax=ax
    )

    ax.set_title('Symptom Correlation Heatmap')
    fig.tight_layout()
    fig.savefig('../reports/eda/correlation_heatmap.png')
    plt.show()

# Uses the function's default n_symptoms, so the heatmap covers 30 symptoms.
plot_correlation_heatmap(df)

## 6. Disease-Symptom Association Analysis

In [None]:
def analyze_disease_symptom_associations(
    df,
    id_to_disease,
    top_n=10,
    output_path='../reports/eda/disease_symptom_associations.json',
):
    """Analyze which symptoms are most associated with each disease.

    Args:
        df: DataFrame with a 'disease' column; every other column is summed
            as a symptom indicator.
        id_to_disease: Mapping from the values in df['disease'] to display names.
        top_n: Number of highest-count symptoms to keep per disease.
        output_path: Where the JSON result is written.

    Returns:
        dict mapping disease name -> {symptom name: total count}, keeping only
        the `top_n` symptoms with the largest totals for that disease.

    Raises:
        KeyError: If df has no 'disease' column or a disease value is missing
            from `id_to_disease`.
    """
    associations = {}

    # A single groupby pass replaces re-filtering the whole frame per disease;
    # sort=False keeps the diseases in order of first appearance.
    for disease_id, disease_cases in df.groupby('disease', sort=False):
        # Drop the label by name so its values are never counted as a symptom
        symptom_freq = disease_cases.drop(columns='disease').sum()

        top_symptoms = symptom_freq.nlargest(top_n)

        # tolist() yields built-in Python numbers; json.dump raises TypeError
        # on the numpy integers a plain dict(series) would contain.
        associations[id_to_disease[disease_id]] = dict(
            zip(top_symptoms.index, top_symptoms.tolist())
        )

    # Save associations to file
    with open(output_path, 'w') as f:
        json.dump(associations, f, indent=2)

    return associations

# Build the per-disease symptom table; the call also writes it to a JSON file under ../reports/eda.
disease_symptom_assoc = analyze_disease_symptom_associations(df, id_to_disease)

# Display example associations for one disease
# (the first key in the dict; nothing else singles this disease out)
example_disease = list(disease_symptom_assoc.keys())[0]
print(f"\nTop symptoms for {example_disease}:")
for symptom, freq in disease_symptom_assoc[example_disease].items():
    print(f"{symptom}: {freq}")

## 7. Save EDA Report

In [None]:
def generate_eda_report(stats, disease_mapping, symptom_freq,
                        output_path='../reports/eda/eda_report.md'):
    """Generate a markdown report summarizing EDA findings.

    Args:
        stats: dict with the numeric entries 'Total samples',
            'Number of features', 'Number of diseases' and
            'Average symptoms per case'.
        disease_mapping: Sized mapping of diseases; only its length is used.
        symptom_freq: Series of per-symptom totals; its first five entries are
            listed as the most common symptoms.
        output_path: Where the Markdown report is written.

    Returns:
        None. The report is written to `output_path` and a confirmation line
        is printed.
    """
    # Build the nested bullets up front so every item carries the same
    # two-space indent and the list renders as one nested list.
    top_symptom_lines = [
        f"  - {symptom}: {count:.0f} cases"
        for symptom, count in symptom_freq.head().items()
    ]

    report_lines = [
        "# Exploratory Data Analysis Report",
        "",
        "## Dataset Overview",
        f"- Total samples: {stats['Total samples']:.0f}",
        f"- Number of features: {stats['Number of features']:.0f}",
        f"- Number of unique diseases: {stats['Number of diseases']:.0f}",
        f"- Average symptoms per case: {stats['Average symptoms per case']:.2f}",
        "",
        "## Disease Distribution",
        f"- Total unique diseases: {len(disease_mapping)}",
        "- See visualization: `class_distribution.png`",
        "",
        "## Symptom Analysis",
        f"- Total unique symptoms: {len(symptom_freq)}",
        "- Most common symptoms:",
        *top_symptom_lines,
        "",
        "## Visualizations",
        "1. `class_distribution.png`: Distribution of diseases",
        "2. `symptom_frequency.png`: Frequency of symptoms",
        "3. `correlation_heatmap.png`: Symptom correlations",
        "",
        "## Disease-Symptom Associations",
        "Detailed associations can be found in `disease_symptom_associations.json`",
    ]
    report = "\n".join(report_lines) + "\n"

    # Explicit encoding keeps the file identical across platforms even when
    # symptom names contain non-ASCII characters.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(report)

    print("EDA report generated successfully!")

# Write the Markdown summary from the statistics, mapping and symptom totals computed in earlier cells.
generate_eda_report(dataset_stats, disease_mapping, symptom_frequencies)