# EDA
The objective of this notebook is to briefly visualize the distribution of each feature in the training data (`data/train.csv`) and how it relates to the `income` target.

In [None]:
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Training split, resolved relative to the notebook's working directory.
DATA_DIR = "data"
DATA_PATH = os.path.join(DATA_DIR, "train.csv")

df = pd.read_csv(DATA_PATH)

In [None]:
# Peek at the first rows to sanity-check column names and value formats.
df.head(n=5)

## Categorical Data

In [None]:
# Bar chart of category frequencies for every text/categorical column with
# 42 or fewer distinct values (NaN counted as one value), most frequent first.
for col_name in df.select_dtypes(include=['object', 'category']).columns:
    if df[col_name].nunique(dropna=False) > 42:
        continue

    fig, ax = plt.subplots(figsize=(10, 6))
    df[col_name].value_counts().plot(kind='bar', ax=ax)
    ax.set_title(f'Distribution of {col_name}')
    ax.set_xlabel(col_name)
    ax.set_ylabel('Count')
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()
    plt.show()

## Categorical Data by Income

In [None]:
# For each text/categorical feature with 42 or fewer distinct values (other
# than the target itself), show what share of each category falls into each
# income class as horizontal stacked bars summing to 100%. Categories are
# sorted by their '>50K' share, ascending.
# NOTE(review): assumes 'income' contains a '>50K' label (used for sorting).
for column in df.select_dtypes(include=['object', 'category']).columns:
    if column == 'income':
        continue
    if df[column].nunique(dropna=False) > 42:
        continue

    # normalize='index' yields row proportions in [0, 1]; scale to percent so
    # the '%' labels below are correct (0.76 -> "76.0%", not "0.8%").
    pct = pd.crosstab(df[column], df['income'], normalize='index') * 100
    pct_sorted = pct.sort_values(by='>50K', ascending=True)

    ax = pct_sorted.plot(
        kind='barh',
        stacked=True,
        figsize=(12, max(6, len(pct_sorted) * 0.4)),
    )

    ax.set_title(f'Distribution of {column} by Income')
    ax.set_xlabel('Percentage (%)')
    ax.set_ylabel(column)
    ax.set_xlim(0, 100)
    ax.legend(title='Income', bbox_to_anchor=(1.05, 1), loc='upper left')

    # Label each segment with its percentage; leave empty segments unlabeled
    # so their "0.0%" text does not overlap the neighbouring label.
    for container in ax.containers:
        labels = [f'{value:.1f}%' if value > 0 else '' for value in container.datavalues]
        ax.bar_label(container, labels=labels, label_type='center')

    plt.tight_layout()
    plt.show()

## Continuous Data

In [None]:
# Histogram + KDE for every numeric feature with more than 42 distinct values,
# with its mean and median marked (and their values shown in the legend).
# Restrict to numeric dtypes: a text column with > 42 distinct values would
# otherwise reach .mean()/.median() below and fail.
for column in df.select_dtypes(include='number').columns:
    if df[column].nunique(dropna=False) <= 42:
        continue

    mean = df[column].mean()
    median = df[column].median()

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.histplot(df[column], kde=True, ax=ax)
    ax.set_title(f'Distribution of {column}')
    ax.set_xlabel(column)
    ax.set_ylabel('Count')

    # Vertical reference lines for the central tendency.
    ax.axvline(mean, color='r', linestyle='--', label=f'Mean ({mean:.2f})')
    ax.axvline(median, color='g', linestyle='-.', label=f'Median ({median:.2f})')

    ax.legend()
    fig.tight_layout()
    plt.show()

In [None]:
# Overlay the histograms of the two income classes for every numeric feature
# with more than 42 distinct values. A single hue-mapped histplot call makes
# both classes share one set of bin edges (common_bins defaults to True);
# two separate calls would bin each subset independently and misalign bars.
# hue_order limits the plot to the two income labels compared here.
for column in df.select_dtypes(include='number').columns:
    if df[column].nunique(dropna=False) <= 42:
        continue

    fig, ax = plt.subplots(figsize=(12, 6))
    sns.histplot(
        data=df,
        x=column,
        hue='income',
        hue_order=['<=50K', '>50K'],
        palette={'<=50K': 'blue', '>50K': 'red'},
        kde=True,
        alpha=0.5,
        ax=ax,
    )

    ax.set_title(f'Distribution of {column} by Income')
    ax.set_xlabel(column)
    ax.set_ylabel('Count')
    # No ax.legend() call: seaborn draws the income legend itself when hue is set.
    fig.tight_layout()
    plt.show()