In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pointbiserialr

# ----------- STEP 1: LOAD AND INSPECT THE DATA -----------
# Read the processed, combined per-user dataset and print a first-look overview.
df = pd.read_parquet("../data/processed/all_user_combined_data_processed.parquet")

print("=== Dataset Overview ===")
print(f"Shape: {df.shape}")

# Each section is (heading, builder); builders are called lazily so every
# report is computed right before it is printed, in this order.
_overview_sections = (
    ("\n--- Data Types ---", lambda: df.dtypes),
    ("\n--- First 5 rows ---", lambda: df.head()),
    ("\n--- Statistical Summary ---", lambda: df.describe(include='all')),
)
for _heading, _build_report in _overview_sections:
    print(_heading)
    print(_build_report())


# Name of the label column; later steps check for its presence before using it.
target_col = 'Depression_label'
if target_col not in df.columns:
    print(f"Target column '{target_col}' not found!")
else:
    print(f"\n--- Target Variable ('{target_col}') Distribution ---")
    print(df[target_col].value_counts(normalize=True))

# ----------- STEP 2: CHECK FOR MISSING DATA -----------
print("\n=== Missing Data ===")
# Cell-level boolean mask: True wherever a value is missing.
_null_mask = df.isna()
# Per-column missing counts, keeping only columns that have any, largest first.
missing = (
    _null_mask.sum()
    .loc[lambda counts: counts > 0]
    .sort_values(ascending=False)
)
print(missing)

# Visualize missing data: heatmap rows are records, columns are features.
plt.figure(figsize=(10, 6))
sns.heatmap(_null_mask, cbar=False, yticklabels=False, cmap='viridis')
plt.title('Missing Values Heatmap')
plt.show()

# ----------- STEP 3: STATISTICAL SUMMARY & TARGET BALANCE -----------
# Numeric columns summary. np.number excludes bool columns; the target is
# included here whenever it is stored with a numeric dtype.
numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
print("\n=== Numeric Features Summary ===")
if numeric_cols:
    print(df[numeric_cols].describe())
else:
    # describe() raises ValueError on a frame that has no columns.
    print("No numeric columns found.")

# Categorical columns summary. Select pandas 'category' columns as well as
# 'object' ones: include='object' alone silently skips categoricals, which
# would then also be missing from the "to encode" list printed later.
cat_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
print("\n=== Categorical Features Summary ===")
for col in cat_cols:
    print(f"\nValue counts for {col}:")
    print(df[col].value_counts())

# Plot target variable balance (only when the target column exists).
if target_col in df.columns:
    plt.figure(figsize=(6, 4))
    sns.countplot(data=df, x=target_col)
    plt.title('Target Variable Distribution')
    plt.show()

# ----------- STEP 4: FEATURE RELATIONSHIPS -----------

# Heatmap of pairwise Pearson correlations (pandas' default method) across all
# numeric columns, annotated with two-decimal coefficients.
_corr_fig, _corr_ax = plt.subplots(figsize=(12, 10))
corr = df.loc[:, numeric_cols].corr(method='pearson')
sns.heatmap(data=corr, annot=True, fmt=".2f", cmap='coolwarm', ax=_corr_ax)
_corr_ax.set_title('Correlation Matrix for Numeric Features')
plt.show()

# Point-biserial correlations of numeric features with binary target
def compute_pointbiserial_correlations(df, feature_cols, label_col):
    """Correlate each feature with a binary label using the point-biserial coefficient.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame holding the features and the label.
    feature_cols : iterable of column labels
        Candidate features; ``label_col`` is skipped if it appears here.
    label_col : column label
        The label column (assumed binary, e.g. 0/1 or bool -- confirm for the
        dataset's target).

    Returns
    -------
    pandas.DataFrame
        Indexed by feature, with columns ``'Correlation'`` and ``'p-value'``,
        sorted by correlation in descending order (so strong negative
        correlations appear at the bottom). A feature is omitted when, on the
        rows where both it and the label are present, either column has fewer
        than two distinct values, or when the computation raises (the error is
        printed).
    """
    correlations = {}
    for col in feature_cols:
        if col == label_col:
            continue
        try:
            # pointbiserialr does not drop missing values itself; a NaN in either
            # column would poison the coefficient. Use only complete pairs.
            pair = df[[label_col, col]].dropna()
            # The coefficient is undefined when either side is constant.
            if pair[col].nunique() < 2 or pair[label_col].nunique() < 2:
                continue
            corr, pval = pointbiserialr(pair[label_col], pair[col])
        except Exception as e:  # best-effort EDA: report the column and move on
            print(f"Skipping {col} due to error: {e}")
            continue
        correlations[col] = (corr, pval)
    return pd.DataFrame.from_dict(
        correlations, orient='index', columns=['Correlation', 'p-value']
    ).sort_values('Correlation', ascending=False)

# Point-biserial table against the target; show its first ten rows in the
# order returned by compute_pointbiserial_correlations.
if target_col in df.columns:
    pb_corr_df = compute_pointbiserial_correlations(
        df=df,
        feature_cols=numeric_cols,
        label_col=target_col,
    )
    print("\n=== Point-Biserial Correlations with Target ===")
    print(pb_corr_df.iloc[:10])

# ----------- STEP 5: VISUALIZATIONS -----------

# Plot every numeric column except the target. Excluding it once here also keeps
# it out of the pairplot's first-five slice, where it would otherwise be appended
# a second time as the hue column and give the plotted frame duplicate labels.
plot_features = [col for col in numeric_cols if col != target_col]
# Grouping by the target requires the column to exist; seaborn raises on an
# unknown hue/x name, so fall back to ungrouped plots when it is absent.
has_target = target_col in df.columns
hue_col = target_col if has_target else None

# Histograms / distributions of features, grouped by target when available
for col in plot_features:
    plt.figure(figsize=(8, 4))
    sns.histplot(data=df, x=col, hue=hue_col, kde=True, element='step', stat='density')
    plt.title(f'Distribution of {col} by {target_col}' if has_target else f'Distribution of {col}')
    plt.show()

# Boxplots of numeric features by target (the target is the x-axis, so skip
# these entirely when it is missing)
if has_target:
    for col in plot_features:
        plt.figure(figsize=(6, 4))
        sns.boxplot(data=df, x=target_col, y=col)
        plt.title(f'{col} by {target_col}')
        plt.show()

# Pairplot limited to five features to avoid overload. Passing `vars` keeps the
# target out of the plotted grid; it is included in the data only for the hue.
pair_vars = plot_features[:5]
if pair_vars:
    subset_features = pair_vars + ([target_col] if has_target else [])
    sns.pairplot(df[subset_features], vars=pair_vars, hue=hue_col, diag_kind='kde')
    plt.show()

# ----------- STEP 6: DETECT OUTLIERS AND ANOMALIES -----------

from scipy.stats import zscore


def _count_zscore_outliers(values, threshold=3):
    """Count non-missing entries of *values* whose absolute Z-score exceeds *threshold*.

    Z-scores are computed over the non-missing entries only.
    """
    observed = values.dropna()
    return int(np.count_nonzero(np.abs(zscore(observed)) > threshold))


# Outlier count for every numeric feature except the target.
outliers = {
    col: _count_zscore_outliers(df[col])
    for col in numeric_cols
    if col != target_col
}

print("\n=== Number of Outliers per Numeric Feature (Z-score > 3) ===")
for feature_name, n_outliers in outliers.items():
    print(f"{feature_name}: {n_outliers}")

# The Step 5 boxplots show each feature's spread and outliers visually.

# ----------- STEP 7: FEATURE ENGINEERING INSIGHTS -----------

import itertools

print("\n--- Feature Engineering Insights ---")

# Constant features: at most one distinct non-missing value. `<= 1` (not `== 1`)
# also catches columns that are entirely missing, whose nunique() is 0.
# Computing nunique() once over the whole frame avoids `df[col]` returning a
# DataFrame (and an ambiguous truth value) when column labels repeat.
constant_features = df.columns[(df.nunique() <= 1).to_numpy()].tolist()
print(f"Constant features (zero variance): {constant_features}")

# Duplicated columns: every pair of columns, compared by position so repeated
# labels are handled, whose values and dtype match exactly (Series.equals treats
# NaNs in the same positions as equal). Pairs are listed in left-to-right order.
duplicated_cols = [
    (df.columns[i], df.columns[j])
    for i, j in itertools.combinations(range(df.shape[1]), 2)
    if df.iloc[:, i].equals(df.iloc[:, j])
]
print(f"Duplicated columns: {duplicated_cols}")

# Suggest scaling for numeric features (can use StandardScaler or MinMaxScaler later)
print("Consider scaling numeric features before modeling.")

# Encoding for categorical features (e.g., one-hot or label encoding)
print(f"Categorical features to encode: {cat_cols}")

# ----------- STEP 8: DATA QUALITY AND BIAS CHECK -----------

print("\n--- Data Quality & Bias Checks ---")


def _resolve_column(frame, name):
    """Return the label of the column in *frame* matching *name*, or None.

    An exact match wins; otherwise the first string label equal to *name*
    ignoring case is returned.
    """
    if name in frame.columns:
        return name
    wanted = name.casefold()
    return next(
        (c for c in frame.columns if isinstance(c, str) and c.casefold() == wanted),
        None,
    )


# Check data distribution for key demographic features like age, gender if present.
# Names are matched case-insensitively: the target column is mixed-case
# ('Depression_label'), so demographics may be spelled 'Age'/'Gender', and an
# exact lowercase lookup would skip them without any message.
for demo_name in ['age', 'gender']:
    demo_col = _resolve_column(df, demo_name)
    if demo_col is None:
        continue
    print(f"\nDistribution of {demo_col}:")
    print(df[demo_col].value_counts(normalize=True))
    plt.figure(figsize=(6, 4))
    sns.countplot(data=df, x=demo_col)
    plt.title(f'Distribution of {demo_col}')
    plt.show()

# Check for class imbalance in target
if target_col in df.columns:
    balance = df[target_col].value_counts(normalize=True)
    print(f"\nClass distribution in target '{target_col}':")
    print(balance)

# Check for duplicate rows
dupes = df.duplicated().sum()
print(f"\nNumber of duplicate rows: {dupes}")

# Check for impossible or suspicious values (negative ages, etc.) if domain relevant.
# Coerce to numeric first so a string-typed age column is checked instead of
# raising TypeError on the `< 0` comparison; unparseable entries become NaN and
# are not counted as negative.
age_col = _resolve_column(df, 'age')
if age_col is not None:
    ages = pd.to_numeric(df[age_col], errors='coerce')
    invalid_ages = int((ages < 0).sum())
    print(f"Number of invalid (negative) age entries: {invalid_ages}")

print("\n--- End of EDA ---")
