In [1]:
import warnings
from typing import Any, Literal

import numpy as np
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

custom_theme = Theme(
    {
        "white": "#FFFFFF",  # Bright white
        "info": "#00FF00",  # Bright green
        "warning": "#FFD700",  # Bright gold
        "error": "#FF1493",  # Deep pink
        "success": "#00FFFF",  # Cyan
        "highlight": "#FF4500",  # Orange-red
    }
)
console = Console(theme=custom_theme)

# Visualization
# import matplotlib.pyplot as plt

# NumPy settings
np.set_printoptions(precision=4)

# Pandas settings
pd.options.display.max_rows = 1_000
pd.options.display.max_columns = 1_000
pd.options.display.max_colwidth = 600

# Polars settings
pl.Config.set_fmt_str_lengths(1_000)
pl.Config.set_tbl_cols(n=1_000)
pl.Config.set_tbl_rows(n=200)

warnings.filterwarnings("ignore")

# Black code formatter (Optional)
%load_ext lab_black

# auto reload imports
%load_ext autoreload
%autoreload 2

In [2]:
def go_up_from_current_directory(*, go_up: int = 1) -> None:
    """Put an ancestor of the current working directory at the front of ``sys.path``.

    Parameters
    ----------
    go_up : int, default=1
        Number of directory levels to climb from the current working directory.
        Zero (or a negative value) resolves to the working directory itself.

    Returns
    -------
    None
        The resolved absolute path is inserted at index 0 of ``sys.path`` and
        printed.
    """
    import os
    import sys

    # Anchor explicitly on the working directory. The earlier version used
    # os.path.dirname(__name__), which is always an empty string for a module
    # name and therefore only resolved relative to the cwd by accident.
    parent_hops = [os.pardir] * max(go_up, 0)
    target_directory = os.path.abspath(os.path.join(os.getcwd(), *parent_hops))

    # Notebook cells get re-run; drop stale copies so sys.path does not grow.
    sys.path[:] = [entry for entry in sys.path if entry != target_directory]

    # Index 0 gives modules in this directory precedence on import.
    sys.path.insert(0, target_directory)
    print(target_directory)

In [3]:
fp: str = "../../../../Documents/data_dump/bike_data/database.parquet"
data: pl.DataFrame = pl.read_parquet(fp)
console.print(f"Shape: {data.shape}", style="info")

data.head()

datetime,season,yr,mnth,hr,holiday,weekday,workingday,weathersit,temp,atemp,hum,windspeed,casual,registered,cnt
str,i64,i64,i64,i64,i64,i64,i64,i64,f64,f64,f64,f64,i64,i64,i64
"""2011-01-01 00:00:00""",1,0,1,0,0,6,0,1,0.24,0.2879,0.81,0.0,3,13,16
"""2011-01-01 01:00:00""",1,0,1,1,0,6,0,1,0.22,0.2727,0.8,0.0,8,32,40
"""2011-01-01 02:00:00""",1,0,1,2,0,6,0,1,0.22,0.2727,0.8,0.0,5,27,32
"""2011-01-01 03:00:00""",1,0,1,3,0,6,0,1,0.24,0.2879,0.75,0.0,3,10,13
"""2011-01-01 04:00:00""",1,0,1,4,0,6,0,1,0.24,0.2879,0.75,0.0,0,1,1


In [4]:
import narwhals as nw
from narwhals.typing import IntoDataFrameT, IntoFrameT


def func(df: IntoDataFrameT, s: nw.Series, col_name: str) -> int:
    """Count the rows of ``df`` whose ``col_name`` value is contained in ``s``."""
    frame = nw.from_native(df)
    is_member = nw.col(col_name).is_in(s)
    matching_rows = frame.filter(is_member)
    return matching_rows.shape[0]


df = pd.DataFrame({"a": [1, 1, 2, 2, 3], "b": [4, 5, 6, 7, 8]})
s = pd.Series([1, 3])
print(func(df, s.to_numpy(), "a"))

3


In [5]:
go_up_from_current_directory(go_up=1)
from eda import EDA  # noqa: E402

/Users/mac/Desktop/Projects/Bike-Rental-Prediction


In [6]:
# Create sample data
rng = np.random.default_rng(42)
dummy_data = pd.DataFrame(
    {
        "age": rng.normal(35, 10, 1000),
        "salary": rng.exponential(50000, 1000),
        "score": rng.uniform(0, 100, 1000),
        "department": rng.choice(["Sales", "Engineering", "Marketing"], 1000),
        "experience": rng.choice(["Junior", "Mid", "Senior"], 1000, p=[0.4, 0.4, 0.2]),
        "target": rng.normal(75, 15, 1000),
    }
)

# Initialize EDA
eda = EDA(dummy_data, target_column="target")

# Print summary
eda.print_summary()

# Get numeric statistics
print("\nNumeric Statistics:")
print("--" * 10)
numeric_stats = eda.numeric_summary()
display(numeric_stats)

# Get categorical statistics
print("\nCategorical Statistics:")
print("--" * 10)
cat_stats = eda.categorical_summary()
for col, stats in cat_stats.items():
    print(f"\n{col}:")
    display(stats)

EXPLORATORY DATA ANALYSIS SUMMARY
* Dataset Shape: (1000, 6)
* Numeric Columns: 4
* Categorical Columns: 2
* Total Missing Values: 0
* Memory Usage: 0.14 MB
* Target Column: target

* Numeric Columns:
  - age
  - salary
  - score
  - target

* Categorical Columns:
  - department
  - experience


Numeric Statistics:
--------------------


Unnamed: 0,Column,Count,Missing,Missing_Pct,Unique,Mean,Median,Mode,Std,Variance,Min,Max,Range,IQR,Q25,Q50,Q75,Skewness,Kurtosis,Outliers_IQR,Outliers_ZScore
0,age,1000,0,0.0,1000,34.711084,35.061779,-1.484128,9.892171,97.85504,-1.484128,66.788537,68.272665,12.861997,28.036871,35.061779,40.898867,-0.043754,0.091884,9,2
1,salary,1000,0,0.0,1000,50779.113101,35946.365196,3.561308,51331.619471,2634935000.0,3.561308,380733.295997,380729.734689,52320.270647,15872.055529,35946.365196,68192.326176,2.195049,6.813502,59,20
2,score,1000,0,0.0,1000,49.574605,49.894015,0.098999,28.779592,828.2649,0.098999,99.976534,99.877534,50.159562,24.523089,49.894015,74.68265,0.022874,-1.192196,0,0
3,target,1000,0,0.0,1000,74.829598,75.007343,27.845005,15.03318,225.9965,27.845005,123.618886,95.773881,20.539317,64.520023,75.007343,85.05934,-0.050504,0.041808,9,5



Categorical Statistics:
--------------------

department:


Unnamed: 0,Category,Count,Percentage
0,Sales,367,36.7
1,Marketing,319,31.9
2,Engineering,314,31.4



experience:


Unnamed: 0,Category,Count,Percentage
0,Junior,415,41.5
1,Mid,398,39.8
2,Senior,187,18.7


In [7]:
eda._calculate_outliers_iqr(series=dummy_data["salary"])
pl.int_range(0, 5)

In [8]:
# Create visualizations
fig1 = eda.plot_numeric_distribution()
fig1.show()

fig2 = eda.plot_categorical_distribution()
fig2.show()

fig3 = eda.plot_correlation_heatmap()
fig3.show()

# Group analysis
group_stats = eda.group_analysis("department", ["age", "salary"])
print(group_stats)

# Outlier detection
fig4 = eda.plot_outliers(method="iqr")
fig4.show()

              age                                        salary             \
            count    mean  median     std    min     max  count       mean   
department                                                                   
Engineering   314  34.775  34.914   9.638  8.272  64.051    314  52495.812   
Marketing     319  34.871  34.720  10.331  5.355  66.789    319  48969.550   
Sales         367  34.517  35.309   9.740 -1.484  60.493    367  50883.220   

                                                        
                median        std      min         max  
department                                              
Engineering  36020.482  54388.395  450.713  330203.937  
Marketing    35823.306  47278.564   24.211  311165.033  
Sales        35973.482  52104.043    3.561  380733.296  


In [152]:
from typing import cast

import narwhals as nw
import narwhals.selectors as n_cs
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import polars as pl
import pyarrow as pa
from narwhals.typing import IntoDataFrameT
from plotly.subplots import make_subplots
from scipy.stats import entropy, spearmanr


class ExploratoryDataAnalysis:
    EMPTY_DATAFRAME: str = "🚫 Empty dataframe"
    NO_NUMERIC_COLUMNS: str = "🚫 No numeric columns available"
    NO_CATEGORICAL_COLUMNS: str = "🚫 No categorical columns available"

    def __init__(self, data: IntoDataFrameT, target_column: str | None = None) -> None:
        """Wrap ``data`` and resolve its numeric, string and boolean column names.

        Parameters
        ----------
        data : IntoDataFrameT
            The native dataframe to analyse; it is wrapped with ``nw.from_native``.
        target_column : str | None, default None
            Optional name of the target column; stored as given.
        """
        self.target_column = target_column

        # The untouched input is kept so results can later be converted back to
        # the same dataframe type.
        self._original_data = data
        self.data = nw.from_native(data)

        # The column groups are resolved from the wrapped frame, so they must be
        # computed after `self.data` has been assigned.
        numeric = self._get_numeric_columns()
        categorical = self._get_categorical_columns()
        boolean = self._get_boolean_columns()
        self.numeric_columns = numeric
        self.categorical_columns = categorical
        self.boolean_columns = boolean

    def _convert_to_native(self, df: pl.DataFrame | nw.DataFrame) -> IntoDataFrameT:
        """Return ``df`` in the dataframe flavour of the data given at construction.

        Parameters
        ----------
        df : pl.DataFrame | nw.DataFrame
            A Polars frame, or a Narwhals frame (which is unwrapped first).

        Returns
        -------
        IntoDataFrameT
            A pandas frame when the original input was pandas, an Arrow table
            when it was a ``pa.Table``; in every other case the (unwrapped)
            frame is returned unchanged.
        """
        # Strip the Narwhals wrapper, if there is one.
        native = df.to_native() if isinstance(df, nw.DataFrame) else df  # type: ignore

        # Only a Polars frame ever needs an actual conversion; anything else
        # is handed back as-is.
        if isinstance(native, pl.DataFrame):
            if isinstance(self._original_data, pd.DataFrame):
                return cast(IntoDataFrameT, native.to_pandas())
            if isinstance(self._original_data, pa.Table):
                return cast(IntoDataFrameT, native.to_arrow())

        return cast(IntoDataFrameT, native)

    def _get_numeric_columns(self) -> list[str]:
        """Return the names of the columns matched by the numeric selector."""
        numeric_frame = self.data.select(n_cs.numeric())
        return numeric_frame.columns

    def _get_categorical_columns(self) -> list[str]:
        """Return the names of string and categorical-dtype columns.

        Columns with a dedicated categorical dtype (for example a pandas
        ``category`` column) are included alongside plain strings; selecting
        strings only left such columns in no column group at all.
        """
        categorical_like = n_cs.string() | n_cs.categorical()
        return self.data.select(categorical_like).columns

    def _get_boolean_columns(self) -> list[str]:
        """Return the names of the columns matched by the boolean selector."""
        boolean_frame = self.data.select(n_cs.boolean())
        return boolean_frame.columns

    @staticmethod
    def _select_valid_columns(
        actual_cols: list[str], selected_cols: list[str]
    ) -> list[str]:
        return list(set(actual_cols) & set(selected_cols))

    def _calculate_outliers_iqr(self, series: nw.Series) -> tuple[nw.Series, nw.Series]:
        """Calculate outliers using the Interquartile Range (IQR) method.

        The IQR method identifies outliers as data points that fall below Q1 - 1.5*IQR
        or above Q3 + 1.5*IQR, where Q1 is the 25th percentile, Q3 is the 75th percentile,
        and IQR = Q3 - Q1.

        Parameters
        ----------
        series : nw.Series
            The numeric series to analyze for outliers.

        Returns
        -------
        tuple[nw.Series, nw.Series]
            A tuple containing:
            - normal_points: Series with values within the normal range
            - outliers: Series with values identified as outliers

        Notes
        -----
        This method uses the "nearest" interpolation method for quantile calculations
        to ensure consistent results across different dataframe backends.
        """
        Q1: float = series.quantile(0.25, interpolation="nearest")
        Q3: float = series.quantile(0.75, interpolation="nearest")
        IQR: float = Q3 - Q1
        lower_bound: float = Q1 - 1.5 * IQR
        upper_bound: float = Q3 + 1.5 * IQR
        normal_points: nw.Series = series.filter(
            (series >= lower_bound) & (series <= upper_bound)
        )
        outliers: nw.Series = series.filter(
            (series < lower_bound) | (series > upper_bound)
        )

        return normal_points, outliers

    def _calculate_outliers_zscore(
        self, series: nw.Series, threshold: float = 3.0
    ) -> tuple[nw.Series, nw.Series]:
        """Calculate outliers using the Z-score method.

        The Z-score method identifies outliers as data points whose absolute Z-score
        exceeds the specified threshold. Z-score is calculated as (x - mean) / std.

        Parameters
        ----------
        series : nw.Series
            The numeric series to analyze for outliers.
        threshold : float, default 3.0
            The Z-score threshold above which points are considered outliers.
            Common values are 2.0 (95% confidence) or 3.0 (99.7% confidence).

        Returns
        -------
        tuple[nw.Series, nw.Series]
            A tuple containing:
            - normal_points: Series with Z-scores within the threshold
            - outliers: Series with Z-scores exceeding the threshold

        Notes
        -----
        This method assumes the data follows a normal distribution. For non-normal
        distributions, the IQR method may be more appropriate.
        """
        mean: float = series.mean()
        std: float = series.std()
        z_scores: nw.Series = (series - mean) / std
        normal_points: nw.Series = series.filter(z_scores.abs() <= threshold)
        outliers: nw.Series = series.filter(z_scores.abs() > threshold)

        return normal_points, outliers

    def correlation_analysis(
        self,
        method: Literal["pearson", "spearman"] = "pearson",
        **kwargs: Any,
    ) -> IntoFrameT:
        """Calculate correlation matrix for numeric columns."""

        if len(self.numeric_columns) < 2:
            raise ValueError(
                "🚫 At least two numeric columns are required for correlation analysis."
            )

        X: np.ndarray[Any, Any] = self.data.select(self.numeric_columns).to_numpy()

        if method == "pearson":
            matrix: np.ndarray = np.corrcoef(X, rowvar=False, **kwargs)

        else:  # spearman
            matrix = spearmanr(X, axis=0, **kwargs).correlation
            if np.isscalar(matrix):
                # Convert to a matrix
                matrix = np.array([[1.0, matrix], [matrix, 1.0]])

        result: pl.DataFrame = pl.DataFrame(matrix, schema=self.numeric_columns)
        return self._convert_to_native(result)

    def numeric_summary(self, columns: list[str] | None = None) -> IntoFrameT:
        """Compute descriptive statistics for numeric columns.

        Parameters
        ----------
        columns : list[str] | None, default None
            Numeric columns to summarise. Names that are not numeric columns are
            dropped; when nothing is supplied every numeric column is used.

        Returns
        -------
        IntoFrameT
            One row per column with central tendency, spread, shape, outlier and
            missing-value statistics, converted to the original dataframe type.
            ``missing_pct`` is a percentage (0-100), matching
            ``categorical_summary``. An empty frame is returned when no column
            produced a row.
        """
        columns = (
            self._select_valid_columns(self.numeric_columns, columns)
            if columns
            else self.numeric_columns
        )

        def _round2(value: Any) -> Any:
            # NOTE(review): presumably some backends report None for statistics
            # of an all-null column; None is passed through instead of failing.
            return None if value is None else round(value, 2)

        def _span(high: Any, low: Any) -> Any:
            # Difference of two statistics, tolerating a missing operand.
            return None if high is None or low is None else high - low

        summary_stats: list[dict[str, Any]] = []

        for col in columns:
            series = self.data[col]
            n_rows: int = len(series)

            if n_rows == 0:
                print(self.EMPTY_DATAFRAME)
                continue

            # Extremes and quartiles are fetched once and reused below.
            min_value = series.min()
            max_value = series.max()
            q25 = series.quantile(0.25, interpolation="nearest")
            q75 = series.quantile(0.75, interpolation="nearest")

            missing_values: int = series.is_null().sum()

            # The outlier helpers need at least one real value to derive their
            # bounds, so they are only run on the non-null part of the column.
            non_null = series.drop_nulls()
            if len(non_null) > 0:
                _, outlier_series_iqr = self._calculate_outliers_iqr(non_null)
                _, outlier_series_zscore = self._calculate_outliers_zscore(non_null)
            else:
                outlier_series_iqr = outlier_series_zscore = non_null

            summary_stats.append(
                {
                    "column": col,
                    # Central tendency
                    "mean": _round2(series.mean()),
                    "median": _round2(series.median()),
                    "mode": series.mode().to_list(),
                    # Spread
                    "std": _round2(series.std()),
                    "variance": _round2(series.var()),
                    "range": _span(max_value, min_value),
                    "iqr_value": _span(q75, q25),
                    "min": min_value,
                    "max": max_value,
                    # Distribution shape
                    "skewness": _round2(series.skew()),
                    "kurtosis": _round2(series.kurtosis()),
                    # Outliers
                    "outlier_series_iqr": outlier_series_iqr.to_list(),
                    "outlier_count_iqr": outlier_series_iqr.count(),
                    "outlier_series_zscore": outlier_series_zscore.to_list(),
                    "outlier_count_zscore": outlier_series_zscore.count(),
                    # Counts
                    "total_count": series.count(),
                    "missing_values": missing_values,
                    # Expressed as a percentage, like categorical_summary does.
                    "missing_pct": round(missing_values / n_rows * 100, 2),
                    "unique_values": series.n_unique(),
                }
            )

        # Nothing to summarise: hand back an empty frame rather than asking
        # Polars to infer a schema from zero records.
        if not summary_stats:
            return self._convert_to_native(pl.DataFrame())

        # Create summary as Polars DataFrame first
        summary_df: pl.DataFrame = pl.from_records(summary_stats)

        return self._convert_to_native(summary_df)

    def categorical_summary(self, columns: list[str] | None = None) -> IntoFrameT:
        """Compute summary statistics for categorical and boolean columns.

        Parameters
        ----------
        columns : list[str] | None, default None
            Columns to summarise. Names that are neither categorical nor boolean
            columns are dropped; when nothing is supplied every categorical and
            boolean column is used.

        Returns
        -------
        IntoFrameT
            One row per column with counts, number of unique values, base-10
            entropy of the non-null value counts, the raw value counts and
            missing-value statistics (``missing_pct`` is a percentage),
            converted to the original dataframe type. An empty frame is
            returned when no column produced a row.
        """
        # Boolean columns are part of the default selection, so they must also be
        # accepted when the caller names them explicitly.
        summarisable: list[str] = self.categorical_columns + self.boolean_columns
        columns = (
            self._select_valid_columns(summarisable, columns)
            if columns
            else summarisable
        )

        summary_stats: list[dict[str, Any]] = []

        for col in columns:
            series = self.data[col]
            n_rows: int = len(series)

            if n_rows == 0:
                continue

            # Frequency counts for every distinct value
            value_counts = series.value_counts().to_numpy()

            # Basic stats: count, missing_values, missing_pct, unique_values
            count: int = series.count()
            missing_values: int = series.is_null().sum()
            missing_pct: float = round(missing_values / n_rows * 100, 2)
            unique_values: int = series.n_unique()

            # Entropy (measure of uncertainty or randomness) of the observed,
            # non-null categories; 0.0 when the column holds no values at all.
            non_null_series = series.drop_nulls()
            if len(non_null_series) > 0:
                vc_non_null = non_null_series.value_counts(sort=True, normalize=False)
                entropy_value: float = round(
                    entropy(vc_non_null["count"], base=10), 2
                )
            else:
                entropy_value = 0.0

            summary_stats.append(
                {
                    "column": col,
                    "total_count": count,
                    "unique_values": unique_values,
                    "entropy": entropy_value,
                    "value_counts": value_counts,
                    "missing_values": missing_values,
                    "missing_pct": missing_pct,
                }
            )

        # Nothing to summarise: hand back an empty frame rather than asking
        # Polars to infer a schema from zero records.
        if not summary_stats:
            return self._convert_to_native(pl.DataFrame())

        summary_df: pl.DataFrame = pl.from_records(summary_stats)

        return self._convert_to_native(summary_df)

    def group_analysis(
        self, groupby: str, numeric_cols: list[str] | None = None
    ) -> IntoFrameT:
        """Aggregate numeric columns per category of ``groupby``.

        Parameters
        ----------
        groupby : str
            Categorical column whose values define the groups.
        numeric_cols : list[str] | None, default None
            Numeric columns to aggregate. Names that are not numeric columns are
            dropped; when nothing is supplied every numeric column is used.

        Returns
        -------
        IntoFrameT
            A native frame with one row per group, sorted by ``groupby``, holding
            ``<col>_count``, ``<col>_mean``, ``<col>_median``, ``<col>_std``,
            ``<col>_min`` and ``<col>_max`` for each numeric column. Mean, median
            and standard deviation are rounded to two decimals.

        Raises
        ------
        ValueError
            If ``groupby`` is not a categorical column, or no valid numeric
            column remains.
        """
        # Ensure the column is a valid cat column
        if groupby not in self.categorical_columns:
            raise ValueError(f"🚫 {groupby!r} must be a categorical variable")
        numeric_cols = (
            self._select_valid_columns(self.numeric_columns, numeric_cols)
            if numeric_cols
            else self.numeric_columns
        )
        if len(numeric_cols) == 0:
            raise ValueError("🚫 No valid numeric columns found.")

        # Only plain aggregations go inside `agg`: chaining `.round()` onto an
        # aggregation makes it a non-elementary group-by expression, which
        # pandas-like backends do not accept. Rounding happens afterwards.
        aggregated = (
            self.data.select([*numeric_cols, groupby])
            .group_by(groupby)
            .agg(
                n_cs.numeric().count().name.suffix("_count"),
                n_cs.numeric().mean().name.suffix("_mean"),
                n_cs.numeric().median().name.suffix("_median"),
                n_cs.numeric().std().name.suffix("_std"),
                n_cs.numeric().min().name.suffix("_min"),
                n_cs.numeric().max().name.suffix("_max"),
            )
        )

        rounded_names: list[str] = [
            f"{name}{suffix}"
            for name in numeric_cols
            for suffix in ("_mean", "_median", "_std")
        ]

        return (
            aggregated.with_columns(nw.col(*rounded_names).round(2))
            # Group order is otherwise backend dependent; sort for stable output.
            .sort(groupby)
            .to_native()
        )

    def plot_numeric_distribution(
        self,
        columns: list[str] | None = None,
        plot_type: Literal["all", "histogram", "box", "violin"] = "all",
    ) -> go.Figure:
        """Plot the distribution of numeric columns.

        Parameters
        ----------
        columns : list[str] | None, default None
            Numeric columns to plot. Names that are not numeric columns are
            dropped; when nothing is supplied every numeric column is plotted.
        plot_type : {"all", "histogram", "box", "violin"}, default "all"
            ``"all"`` stacks a histogram, a box plot and a violin plot for each
            variable (one subplot column per variable). Any other value draws a
            single plot kind on a grid that is at most three subplots wide.

        Returns
        -------
        go.Figure
            The assembled figure, or an empty figure when there are no numeric
            columns to plot.
        """
        columns = (
            self._select_valid_columns(self.numeric_columns, columns)
            if columns
            else self.numeric_columns
        )
        if not columns:
            print(self.NO_NUMERIC_COLUMNS)
            return go.Figure()

        if plot_type == "all":
            fig = make_subplots(
                rows=3,
                cols=len(columns),
                subplot_titles=[f"{col} - Histogram" for col in columns]
                + [f"{col} - Box Plot" for col in columns]
                + [f"{col} - Violin Plot" for col in columns],
                vertical_spacing=0.1,
            )

            for i, col in enumerate(columns):
                # Hand plotly a plain array, as the other plotting methods do.
                values = self.data[col].to_numpy()
                position = i + 1

                fig.add_trace(
                    go.Histogram(x=values, name=f"{col}_hist", showlegend=False),
                    row=1,
                    col=position,
                )
                fig.add_trace(
                    go.Box(y=values, name=f"{col}_box", showlegend=False),
                    row=2,
                    col=position,
                )
                fig.add_trace(
                    go.Violin(y=values, name=f"{col}_violin", showlegend=False),
                    row=3,
                    col=position,
                )

            fig.update_layout(
                height=800,
                width=1000,
                title_text="Numeric Distributions - All Plot Types",
            )
            return fig

        # Single plot kind: wrap the variables onto a grid of up to 3 columns.
        n_cols: int = min(3, len(columns))
        n_rows: int = (len(columns) + n_cols - 1) // n_cols

        fig = make_subplots(
            rows=n_rows, cols=n_cols, subplot_titles=[f"{col}" for col in columns]
        )

        for i, col in enumerate(columns):
            values = self.data[col].to_numpy()

            if plot_type == "histogram":
                trace = go.Histogram(x=values, name=col, showlegend=False)
            elif plot_type == "box":
                trace = go.Box(y=values, name=col, showlegend=False)
            elif plot_type == "violin":
                trace = go.Violin(y=values, name=col, showlegend=False)
            else:
                # Unrecognised plot types leave the subplot empty.
                continue

            fig.add_trace(trace, row=i // n_cols + 1, col=i % n_cols + 1)

        fig.update_layout(
            height=300 * n_rows,
            # Width follows the number of subplot columns (it used to be scaled
            # by the row count, squeezing wide grids into a narrow figure).
            width=400 * n_cols,
            title_text=f"Numeric Distributions - {plot_type.title()}",
        )

        return fig

    def plot_categorical_distribution(
        self,
        columns: list[str] | None = None,
        plot_type: Literal["all", "bar", "pie"] = "all",
    ) -> go.Figure:
        """Plot the share of each category for categorical columns.

        Parameters
        ----------
        columns : list[str] | None, default None
            Categorical columns to plot. Names that are not categorical columns
            are dropped; when nothing is supplied every categorical column is
            plotted.
        plot_type : {"all", "bar", "pie"}, default "all"
            ``"all"`` draws one row per variable with a bar chart on the left and
            a pie chart on the right. ``"bar"`` / ``"pie"`` draw a single chart
            kind on a grid that is at most three subplots wide.

        Returns
        -------
        go.Figure
            The assembled figure, or an empty figure when there are no
            categorical columns to plot. Values are percentages rounded to one
            decimal.
        """
        columns = (
            self._select_valid_columns(self.categorical_columns, columns)
            if columns
            else self.categorical_columns
        )
        if not columns:
            print(self.NO_CATEGORICAL_COLUMNS)
            return go.Figure()

        def _category_percentages(name: str) -> tuple[np.ndarray, np.ndarray]:
            # Labels and their share of rows (in percent), most frequent first.
            table = (
                self.data[name]
                .value_counts(sort=True, normalize=True)
                .with_columns((nw.col("proportion") * 100).round(1))
            )
            return table[name].to_numpy(), table["proportion"].to_numpy()

        if plot_type == "all":
            n_rows = len(columns)
            # Create proper subplot titles for alternating bar and pie charts
            subplot_titles = []
            for col in columns:
                subplot_titles.extend([f"{col} - Bar Chart", f"{col} - Pie Chart"])

            fig = make_subplots(
                rows=n_rows,
                cols=2,
                # Pie charts need a "domain" cell; bar charts a cartesian one.
                specs=[[{"type": "xy"}, {"type": "domain"}] for _ in range(n_rows)],
                subplot_titles=subplot_titles,
            )

            for i, col in enumerate(columns):
                labels, percentages = _category_percentages(col)

                fig.add_trace(
                    go.Bar(
                        x=labels, y=percentages, name=f"{col}_bar", showlegend=False
                    ),
                    row=i + 1,
                    col=1,
                )
                fig.add_trace(
                    go.Pie(
                        labels=labels,
                        values=percentages,
                        name=f"{col}_pie",
                        showlegend=False,
                    ),
                    row=i + 1,
                    col=2,
                )

            fig.update_layout(
                height=400 * n_rows,
                # This layout always has exactly two subplot columns, so the
                # width must not grow (or shrink) with the number of variables.
                width=420 * 2,
                title_text="Categorical Distributions",
            )
            return fig

        n_cols = min(3, len(columns))
        n_rows = (len(columns) + n_cols - 1) // n_cols

        if plot_type == "pie":
            specs = [
                [{"type": "domain"} for _ in range(n_cols)] for _ in range(n_rows)
            ]
        else:
            specs = None

        fig = make_subplots(
            rows=n_rows,
            cols=n_cols,
            specs=specs,
            subplot_titles=[f"{col}" for col in columns],
        )

        for i, col in enumerate(columns):
            labels, percentages = _category_percentages(col)

            if plot_type == "bar":
                trace = go.Bar(x=labels, y=percentages, name=col, showlegend=False)
            elif plot_type == "pie":
                trace = go.Pie(
                    labels=labels, values=percentages, name=col, showlegend=False
                )
            else:
                # Unrecognised plot types leave the subplot empty.
                continue

            fig.add_trace(trace, row=i // n_cols + 1, col=i % n_cols + 1)

        fig.update_layout(
            height=300 * n_rows,
            # Width follows the number of subplot columns, not rows.
            width=400 * n_cols,
            title_text=f"Categorical Distributions - {plot_type.title()}",
        )

        return fig

    def plot_correlation_heatmap(
        self, method: Literal["pearson", "spearman"] = "pearson"
    ) -> go.Figure:
        """Draw the correlation matrix of the numeric columns as a heatmap.

        Parameters
        ----------
        method : {"pearson", "spearman"}, default "pearson"
            Correlation method, forwarded to ``correlation_analysis``.

        Returns
        -------
        go.Figure
            The heatmap, or an empty figure when the correlation matrix has no
            rows.

        Raises
        ------
        ValueError
            Propagated from ``correlation_analysis`` when fewer than two numeric
            columns are available.
        """
        # `correlation_analysis` returns the caller's native frame type. Going
        # through Narwhals gives one uniform API (`to_numpy`, `columns`) whatever
        # that type is; a pyarrow Table, for instance, offers neither directly.
        corr_frame = nw.from_native(self.correlation_analysis(method=method))

        if len(corr_frame) == 0:
            print(self.EMPTY_DATAFRAME)
            return go.Figure()

        coefficients = corr_frame.to_numpy()
        labels = corr_frame.columns

        fig = go.Figure(
            data=go.Heatmap(
                z=coefficients,
                x=labels,
                y=labels,
                colorscale="RdBu",
                # Centre the diverging colour scale on "no correlation".
                zmid=0,
                text=np.round(coefficients, 3),
                texttemplate="%{text}",
                textfont={"size": 10},
                hovertemplate="%{x} vs %{y}<br>Correlation: %{z:.3f}<extra></extra>",
            )
        )

        fig.update_layout(
            height=600,
            width=800,
            title=f"{method.title()} Correlation Matrix",
            xaxis_title="Variables",
            yaxis_title="Variables",
        )

        return fig

    def plot_outliers(
        self, columns: list[str] | None = None, method: Literal["iqr", "zscore"] = "iqr"
    ) -> go.Figure:
        """Scatter the normal points and the outliers of numeric columns.

        Parameters
        ----------
        columns : list[str] | None, default None
            Numeric columns to inspect. Names that are not numeric columns are
            dropped; when nothing is supplied every numeric column is used.
        method : {"iqr", "zscore"}, default "iqr"
            ``"iqr"`` uses the interquartile-range rule; any other value uses
            the Z-score rule.

        Returns
        -------
        go.Figure
            One subplot per column on a grid that is at most three subplots
            wide, or an empty figure when there are no numeric columns.

        Notes
        -----
        Within each subplot the x axis is a running index inside the respective
        group (normal points and outliers are each numbered from zero), not the
        row position in the original data.
        """
        columns = (
            self._select_valid_columns(self.numeric_columns, columns)
            if columns
            else self.numeric_columns
        )

        if not columns:
            # This method works on numeric columns, so report those as missing.
            print(self.NO_NUMERIC_COLUMNS)
            return go.Figure()

        n_cols = min(3, len(columns))
        n_rows = (len(columns) + n_cols - 1) // n_cols

        fig = make_subplots(
            rows=n_rows,
            cols=n_cols,
            subplot_titles=[f"{col} - Outliers ({method})" for col in columns],
            vertical_spacing=0.1,
            horizontal_spacing=0.05,
        )

        for i, col in enumerate(columns):
            row: int = i // n_cols + 1
            col_idx: int = i % n_cols + 1

            # Nulls carry no information for outlier detection.
            series = self.data[col].drop_nulls()

            if method == "iqr":
                normal_points, outliers = self._calculate_outliers_iqr(series)
            else:  # zscore
                normal_points, outliers = self._calculate_outliers_zscore(series)

            # Plot normal points
            fig.add_trace(
                go.Scatter(
                    x=np.arange(0, len(normal_points)),
                    y=normal_points.to_numpy(),
                    mode="markers",
                    name=f"{col}_normal",
                    marker={
                        "color": "lightblue",
                        "size": 2,
                        "opacity": 0.7,
                        "line": {"width": 1, "color": "blue"},
                    },
                    showlegend=False,
                ),
                row=row,
                col=col_idx,
            )

            # Plot outliers (skipped entirely when there are none)
            if len(outliers) > 0:
                fig.add_trace(
                    go.Scatter(
                        x=np.arange(0, len(outliers)),
                        y=outliers.to_numpy(),
                        mode="markers",
                        name=f"{col}_outliers",
                        marker={
                            "color": "red",
                            "size": 4,
                            "symbol": "diamond",
                            "line": {"width": 2, "color": "darkred"},
                        },
                        showlegend=False,
                    ),
                    row=row,
                    col=col_idx,
                )

        fig.update_layout(
            height=300 * n_rows,
            # Width follows the number of subplot columns, not rows.
            width=450 * n_cols,
            title_text="Outliers Detection",
        )
        return fig

    def plot_group_analysis(
        self,
        groupby: str,
        numeric_col: str,
        plot_type: Literal["bar", "box", "scatter", "violin"] = "bar",
    ) -> go.Figure:
        """Plot one numeric column against the categories of ``groupby``.

        Parameters
        ----------
        groupby : str
            Categorical column placed on the x axis.
        numeric_col : str
            Numeric column placed on the y axis.
        plot_type : {"bar", "box", "scatter", "violin"}, default "bar"
            ``"bar"`` plots the per-category mean of ``numeric_col`` (categories
            sorted); the other kinds plot the row-level values.

        Returns
        -------
        go.Figure
            The plotly express figure.

        Raises
        ------
        ValueError
            If ``groupby`` is not a categorical column, ``numeric_col`` is not a
            numeric column, or ``plot_type`` is not one of the supported kinds.
        """
        # Validate before doing any conversion or aggregation work.
        if groupby not in self.categorical_columns:
            raise ValueError(f"🚫 {groupby!r} is not a categorical column")

        if numeric_col not in self.numeric_columns:
            raise ValueError(f"🚫 {numeric_col!r} is not a numeric column")

        if plot_type == "bar":
            # Only the plotted column is aggregated, rather than every numeric
            # column in the frame.
            grouped_data = (
                self.data.select(groupby, numeric_col)
                .group_by(groupby)
                .agg(nw.col(numeric_col).mean())
                .sort(groupby)
            )
            return px.bar(
                self._convert_to_native(grouped_data),
                x=groupby,
                y=numeric_col,
                title=f"Average {numeric_col} by {groupby}",
            )

        # Row-level plots: each kind maps to its plotter and title suffix. All of
        # them, scatter included, put the grouping column on the x axis.
        row_level_plots = {
            "box": (px.box, "Box Plot"),
            "violin": (px.violin, "Violin Plot"),
            "scatter": (px.scatter, "Scatter Plot"),
        }
        if plot_type not in row_level_plots:
            raise ValueError(
                '🚫 plot_type must be "box", "violin", "bar", or "scatter"'
            )

        plotter, label = row_level_plots[plot_type]
        return plotter(
            self._convert_to_native(self.data),
            x=groupby,
            y=numeric_col,
            title=f"{numeric_col} by {groupby} - {label}",
        )

    def generate_full_report(self) -> dict[str, Any]:
        return {
            "dataset_info": {
                "shape": self.data.shape,
                "numeric_columns": len(self.numeric_columns),
                "categorical_columns": len(self.categorical_columns),
                "boolean_columns": len(self.boolean_columns),
                "total_columns": self.data.shape[1],
                "total_rows": self.data.shape[0],
                "missing_values_total": self.data.null_count().to_numpy().sum().item(),
                "memory_usage": f"{round(self.data.estimated_size(unit='mb'), 2)} MB",
            },
            "numeric_summary": self.numeric_summary(),
            "categorical_summary": self.categorical_summary(),
            "correlation_matrix": self.correlation_analysis(),
        }

    def display_all_plots(
        self,
        outlier_method: Literal["iqr", "zscore"] = "iqr",
        numeric_cols: list[str] | None = None,
    ) -> None:
        numeric_cols = (
            self._select_valid_columns(self.numeric_columns, numeric_cols)
            if numeric_cols
            else self.numeric_columns
        )

        # Create visualizations
        fig1 = self.plot_numeric_distribution()
        fig1.show()

        fig2 = self.plot_categorical_distribution()
        fig2.show()

        fig3 = self.plot_correlation_heatmap()
        fig3.show()

        fig4 = self.plot_outliers(method=outlier_method)
        fig4.show()

    def print_summary(self) -> None:
        """Print a quick summary of the dataset."""
        print("=" * 60)
        print("🚀 EXPLORATORY DATA ANALYSIS SUMMARY")
        print("=" * 60)
        print(f"* Dataset Shape: {self.data.shape}")
        print(f"* Total Rows: {self.data.shape[0]}")
        print(f"* Total Columns: {self.data.shape[1]}")
        print(f"* Numeric Columns: {len(self.numeric_columns)}")
        print(f"* Categorical Columns: {len(self.categorical_columns)}")
        print(f"* Boolean Columns: {len(self.boolean_columns)}")
        print(
            f"* Total Missing Values: {self.data.null_count().to_numpy().sum().item()}"
        )
        print(f"* Memory Usage: {round(self.data.estimated_size(unit='mb'), 2)}  MB")

        if self.target_column:
            print(f"* Target Column: {self.target_column}")

        print("\n* Numeric Columns:")
        for col in self.numeric_columns:
            print(f"  - {col}")

        print("\n* Categorical Columns:")
        for col in self.categorical_columns:
            print(f"  - {col}")

        print("=" * 60)
        print()

In [22]:
actual_num_cols: list[str] = ["name", "age", "salary", "score", "target"]
num_col: list[str] = ["age", "salary", "score", "not_included"]

valid_cols = list(set(actual_num_cols) & set(num_col))
valid_cols


def _select_valid_columns(
    actual_cols: list[str], selected_cols: list[str]
) -> list[str]:
    return list(set(actual_cols) & set(selected_cols))

In [153]:
# Build the same six-row sample table once and load it into both a Pandas and
# a Polars frame.
sample_columns: dict[str, list] = {
    "name": ["Alice", "Bob", "Charlie", "Anya", "Eve", "Olivia"],
    "age": [25, 30, 35, 28, 25, 38],
    "role": ["Engineer", "Manager", "Engineer", "Engineer", "HR", "Engineer"],
    "salary": [85000, 55000, 65000, 55000, 48000, 72000],
    "self_employed": [True, False, True, False, True, False],
}

df_pandas = pd.DataFrame(sample_columns)
df_polars = pl.DataFrame(sample_columns)

# print("Testing with Pandas DataFrame:")
# print("=" * 40)
# eda_pandas = ExploratoryDataAnalysis(df_pandas)
# result_pandas = eda_pandas.numeric_summary()
# print(f"Result type: {type(result_pandas)}")
# print()

# print("Testing with Polars DataFrame:")
# print("=" * 40)
# eda_polars = ExploratoryDataAnalysis(df_polars)
# result_polars = eda_polars.numeric_summary()
# print(f"Result type: {type(result_polars)}")

# print("\nPandas result:")
# display(result_pandas)

# print("\nPolars result:")
# display(result_polars)


# NOTE(review): despite its name, `eda_pandas` wraps the Polars frame here.
eda_pandas = ExploratoryDataAnalysis(df_polars)

# eda_pandas.correlation_analysis()
eda_pandas.display_all_plots()

In [24]:
# Display the Polars sample frame (the cell's last expression is rendered).
df_polars

name,age,role,salary,self_employed
str,i64,str,i64,bool
"""Alice""",25,"""Engineer""",85000,True
"""Bob""",30,"""Manager""",55000,False
"""Charlie""",35,"""Engineer""",65000,True
"""Anya""",28,"""Engineer""",55000,False
"""Eve""",25,"""HR""",48000,True
"""Olivia""",38,"""Engineer""",72000,False


In [25]:
# Wrap the Polars frame in a Narwhals DataFrame and display it.
# NOTE(review): relies on `nw` being imported in an earlier cell
# (presumably `import narwhals as nw`) — confirm.
nw_df = nw.from_native(df_polars)
nw_df

┌─────────────────────────────────────────────────────┐
|                 Narwhals DataFrame                  |
|-----------------------------------------------------|
|shape: (6, 5)                                        |
|┌─────────┬─────┬──────────┬────────┬───────────────┐|
|│ name    ┆ age ┆ role     ┆ salary ┆ self_employed │|
|│ ---     ┆ --- ┆ ---      ┆ ---    ┆ ---           │|
|│ str     ┆ i64 ┆ str      ┆ i64    ┆ bool          │|
|╞═════════╪═════╪══════════╪════════╪═══════════════╡|
|│ Alice   ┆ 25  ┆ Engineer ┆ 85000  ┆ true          │|
|│ Bob     ┆ 30  ┆ Manager  ┆ 55000  ┆ false         │|
|│ Charlie ┆ 35  ┆ Engineer ┆ 65000  ┆ true          │|
|│ Anya    ┆ 28  ┆ Engineer ┆ 55000  ┆ false         │|
|│ Eve     ┆ 25  ┆ HR       ┆ 48000  ┆ true          │|
|│ Olivia  ┆ 38  ┆ Engineer ┆ 72000  ┆ false         │|
|└─────────┴─────┴──────────┴────────┴───────────────┘|
└─────────────────────────────────────────────────────┘

In [114]:
# Only the last expression of a cell is displayed, so the size estimate on the
# next line is computed and then discarded.
df_polars.estimated_size(unit="mb")
# Total missing values across all columns, as a plain Python number
# (the same expression `print_summary` uses).
nw_df.null_count().to_numpy().sum().item()

0

In [103]:
# Polars-only way to total the missing values: per-column null counts, summed
# per column, then added across columns into a single-value Series.
df_polars.null_count().sum().sum_horizontal()

sum
u32
0


In [82]:
# Scratch cell: build the analyser on the Polars sample frame and try one plot.
# The commented-out lines are earlier experiments.
eda = ExploratoryDataAnalysis(data=df_polars)
# eda.group_analysis(groupby="role", numeric_cols=["age"])

# eda.plot_categorical_distribution(plot_type="all")
# eda.plot_correlation_heatmap(method="pearson")

# eda._calculate_outliers_iqr(series=nw_df["age"])[0]

# pl.int_range(0, 5)
# np.arange(0, 5)


# Group plot of `age` by `role`, requested with plot_type="scatter".
eda.plot_group_analysis(groupby="role", numeric_col="age", plot_type="scatter")

In [None]:
# Read the parquet file at `fp` into a Polars frame and report its shape.
# NOTE: `fp` was already assigned in an earlier cell (In [3]); this rebinds it.
fp: str = "../data/bikes_2024.parquet"
data_table: pl.DataFrame = pl.read_parquet(fp)
console.print(f"Shape: {data_table.shape}", style="info")

# Show up to the first 200 rows.
data_table.head(200)