# 1. Data Loading & Preprocessing

This notebook handles the initial data loading and preprocessing steps for our symptom-to-disease prediction pipeline.

## Objectives
1. Load the large dataset in chunks (with a progress bar)
2. Validate data quality
3. Normalize disease labels
4. Handle multi-label cases
5. Save processed dataset

## Setup and Dependencies

In [None]:
# Install required packages
!python -m pip install --upgrade pip
!pip install pandas numpy scikit-learn joblib python-dotenv tqdm

In [None]:
import json
import re
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from tqdm.notebook import tqdm
# Silence all warnings for the rest of the session.
# NOTE(review): this also hides pandas FutureWarning / SettingWithCopyWarning,
# which can point at real bugs in the preprocessing below -- consider narrowing.
warnings.filterwarnings('ignore')

# Display options: never truncate columns, show up to 100 rows.
pd.options.display.max_columns = None
pd.options.display.max_rows = 100

## 1. Data Loading

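We load the CSV in chunks so a progress bar can show how far loading has got. The chunks are concatenated into a single DataFrame at the end, so the full dataset must still fit in memory.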

In [None]:
def _count_data_rows(file_path):
    """Return the number of data lines in ``file_path`` (header excluded, never negative)."""
    # Binary mode: counting lines needs no decoding, so unusual bytes cannot
    # raise UnicodeDecodeError. One record per line is assumed (quoted fields
    # with embedded newlines inflate the count), which is fine for a progress total.
    with open(file_path, 'rb') as fh:
        return max(sum(1 for _ in fh) - 1, 0)


def load_data_in_chunks(file_path, chunksize=10000, **read_csv_kwargs):
    """Load a large CSV in chunks, reporting progress, and combine the chunks.

    Chunking drives the progress bar only: all chunks are concatenated at the
    end, so the whole dataset still has to fit in memory. Pass e.g. ``dtype=``
    or ``usecols=`` via ``read_csv_kwargs`` to shrink that footprint.

    Args:
        file_path (str | Path): Path to the CSV file (first line is the header).
        chunksize (int): Number of rows per chunk.
        **read_csv_kwargs: Extra keyword arguments forwarded to ``pd.read_csv``.

    Returns:
        pd.DataFrame: Combined dataset with a fresh 0..n-1 index. A header-only
        file yields an empty frame that still carries the column names.
    """
    total_rows = _count_data_rows(file_path)

    chunks = []
    with tqdm(total=total_rows, desc="Loading data") as pbar:
        for chunk in pd.read_csv(file_path, chunksize=chunksize, **read_csv_kwargs):
            chunks.append(chunk)
            pbar.update(len(chunk))

    if not chunks:
        # pd.concat([]) raises; return an empty frame with the header's columns.
        return pd.read_csv(file_path, **{**read_csv_kwargs, 'nrows': 0})
    return pd.concat(chunks, ignore_index=True)

# Load the dataset
DATA_PATH = "../data/raw/disease_symptom_dataset.csv"  # Update with your file path

# Fail fast when the file is missing: every later cell needs `df`, so only
# printing a hint would resurface as a confusing NameError further down.
try:
    df = load_data_in_chunks(DATA_PATH)
except FileNotFoundError as err:
    raise FileNotFoundError(f"Please place the dataset at {DATA_PATH}") from err
print(f"Dataset loaded successfully with shape: {df.shape}")

## 2. Data Validation

Check for data quality issues: duplicate rows, missing values, and symptom columns that contain values other than 0/1.

In [None]:
def validate_dataset(df, label_col='disease'):
    """Perform data validation checks.

    Args:
        df (pd.DataFrame): Input dataset.
        label_col (str): Name of the disease label column. Every other column
            is treated as a symptom column expected to hold only 0/1; missing
            values are tolerated there and counted under ``missing_values``.

    Returns:
        dict: Validation results with keys ``total_rows``, ``duplicates``
        (number of fully duplicated rows), ``missing_values`` (NaN cells in the
        whole frame) and ``non_binary_columns`` (names of symptom columns that
        hold any non-missing value other than 0/1).
    """
    results = {
        'total_rows': len(df),
        # Cast numpy scalars to built-in int so the dict is JSON-serializable.
        'duplicates': int(df.duplicated().sum()),
        'missing_values': int(df.isnull().sum().sum()),
        'non_binary_columns': [],
    }

    # Select symptom columns by name, not position: the label column is not
    # guaranteed to be the first column of the CSV.
    symptom_cols = [col for col in df.columns if col != label_col]
    for col in symptom_cols:
        observed = df[col].dropna().unique()
        if not all(val in (0, 1) for val in observed):
            results['non_binary_columns'].append(col)

    return results

validation_results = validate_dataset(df)

# Emit a short human-readable summary of the checks (one line per metric).
print("\n".join([
    "\nValidation Results:",
    f"Total rows: {validation_results['total_rows']}",
    f"Duplicate rows: {validation_results['duplicates']}",
    f"Missing values: {validation_results['missing_values']}",
    f"Columns with non-binary values: {len(validation_results['non_binary_columns'])}",
]))

## 3. Disease Label Normalization

In [None]:
def normalize_disease_labels(df, label_col='disease',
                             mapping_path='../data/processed/disease_mapping.json'):
    """Encode disease names as integer class ids and save the name -> id mapping.

    Args:
        df (pd.DataFrame): Input dataset; it is not modified.
        label_col (str): Column holding the disease names.
        mapping_path (str | Path): Destination of the JSON mapping
            ``{disease_name: class_id}``. Missing parent directories are created.

    Returns:
        tuple: (copy of ``df`` with ``label_col`` integer-encoded, fitted LabelEncoder)
    """
    le = LabelEncoder()
    df_processed = df.copy()

    # Fit and transform disease labels
    df_processed[label_col] = le.fit_transform(df[label_col])

    # le.transform returns numpy integers, which json cannot serialize; cast
    # keys and values to built-in str/int before dumping.
    label_mapping = {
        str(name): int(code)
        for name, code in zip(le.classes_, le.transform(le.classes_))
    }

    mapping_path = Path(mapping_path)
    mapping_path.parent.mkdir(parents=True, exist_ok=True)
    with open(mapping_path, 'w') as f:
        json.dump(label_mapping, f, indent=2)

    print(f"Normalized {len(label_mapping)} unique disease labels")
    return df_processed, le

# Integer-encode the disease column (this also writes the name -> id mapping JSON).
df_normalized, label_encoder = normalize_disease_labels(df=df)

## 4. Handle Multi-label Cases

In [None]:
def handle_multi_label_diseases(df, label_col='disease', separators=',;'):
    """Split rows whose label lists several diseases into one row per disease.

    A label is multi-label when it contains any character in ``separators``
    (by default a comma or a semicolon), e.g. ``"flu, cold"`` or ``"flu; cold"``.
    Each such row is replaced by one copy per disease. Fragments are
    whitespace-stripped and empty fragments (e.g. from a trailing comma) are
    dropped; a label made only of separators becomes NaN. All other columns
    are copied unchanged and keep their dtypes, and split rows stay at the
    position of the row they came from.

    Args:
        df (pd.DataFrame): Input dataset with raw (string) disease labels.
        label_col (str): Name of the disease label column.
        separators (str): Characters that separate diseases within one label.

    Returns:
        pd.DataFrame: ``df`` itself when no multi-label rows are found,
        otherwise a new frame with a fresh 0..n-1 index.

    Raises:
        TypeError: If ``label_col`` is numeric (e.g. already label-encoded),
            because the original disease names can no longer be split.
    """
    labels = df[label_col]
    if pd.api.types.is_numeric_dtype(labels):
        raise TypeError(
            f"Column '{label_col}' is numeric; split multi-label rows on the raw "
            "disease names before label-encoding them"
        )

    separator_re = re.compile(f"[{re.escape(separators)}]")

    # na=False: a missing label is not multi-label (a NaN in the boolean mask
    # would otherwise make it unusable for selection).
    is_multi = labels.str.contains(separator_re.pattern, regex=True, na=False)
    if not is_multi.any():
        return df

    print(f"Found {is_multi.sum()} multi-label rows")

    def _split(label):
        # Split on *every* separator used for detection, not just the comma.
        return [part.strip() for part in separator_re.split(label) if part.strip()]

    # Build an object Series explicitly so lists of equal length are never
    # coerced into a 2-D array.
    split_labels = pd.Series(
        [_split(label) if multi else label for label, multi in zip(labels, is_multi)],
        index=df.index,
        dtype=object,
    )

    # explode() emits one row per list element and repeats the other columns,
    # preserving their dtypes (row-wise copies would upcast them to object).
    return df.assign(**{label_col: split_labels}).explode(label_col, ignore_index=True)

# Split multi-label rows on the raw disease names in `df`: `df_normalized`
# already holds integer codes, on which `.str` operations fail. The split
# labels are then encoded afresh so each individual disease gets its own
# class id (this rebinds `label_encoder` and rewrites the saved mapping file).
df_split = handle_multi_label_diseases(df)
df_processed, label_encoder = normalize_disease_labels(df_split)

## 5. Save Processed Dataset

In [None]:
# Save processed dataset
output_path = '../data/processed/processed_data.csv'
# Create the output directory if needed so this cell also works on a fresh checkout.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
df_processed.to_csv(output_path, index=False)
print(f"Saved processed dataset to {output_path}")

# Save dataset statistics. Values are cast to built-in int/float so json.dump
# never has to deal with numpy scalar types.
stats = {
    'n_samples': len(df_processed),
    'n_features': len(df_processed.columns) - 1,  # Exclude disease column
    'n_classes': int(df_processed['disease'].nunique(dropna=False)),
    'memory_usage': float(df_processed.memory_usage(deep=True).sum() / 1024**2),  # MB
}

with open('../data/processed/dataset_stats.json', 'w') as f:
    json.dump(stats, f, indent=2)

print("\nDataset Statistics:")
for k, v in stats.items():
    print(f"{k}: {v}")