In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
import upsetplot
from statsmodels.graphics.mosaicplot import mosaic

In [None]:


@pd.api.extensions.register_dataframe_accessor("missing")
class DontMissMe:
    """
    Pandas DataFrame accessor to handle missing values analysis and visualization.

    Registered under the name ``missing``, so every DataFrame exposes these
    methods as ``df.missing.<method>(...)``.

    Attributes:
        _obj (pd.DataFrame): The DataFrame to which the accessor is attached.
    """

    def __init__(self, pandas_obj):
        """
        Initialize the DontMissMe accessor with a DataFrame.

        Args:
            pandas_obj (pd.DataFrame): The DataFrame to analyze for missing values.
        """
        self._obj = pandas_obj

    # Internal helpers ---
    @staticmethod
    def _as_variable_axis(per_column: pd.Series) -> pd.Series:
        """
        Name the index of a per-column Series ``"variable"``.

        ``reset_index`` names the new label column after the index name, which
        for a per-column reduction is the DataFrame's ``columns.name`` (set e.g.
        by ``pivot``/``unstack``). Forcing the name keeps the output column
        called ``variable`` in every case. MultiIndex columns are returned
        unchanged because one scalar name cannot label several levels.

        Args:
            per_column (pd.Series): A Series indexed by the DataFrame's columns.

        Returns:
            pd.Series: The same data with its index renamed (flat index only).
        """
        if per_column.index.nlevels == 1:
            return per_column.rename_axis("variable")
        return per_column

    def number_missing(self) -> int:
        """
        Calculate the total number of missing values in the DataFrame.

        Returns:
            int: Total number of missing (NA) cells, as a Python ``int``.
        """
        return int(self._obj.isna().to_numpy().sum())

    def number_complete(self) -> int:
        """
        Calculate the total number of complete (non-missing) values in the DataFrame.

        Returns:
            int: Total number of cells minus the number of missing cells.
        """
        return self._obj.size - self.number_missing()

    def missing_variable_summary(self) -> pd.DataFrame:
        """
        Summarize missing values by variable (column).

        Returns:
            pd.DataFrame: One row per column with ``variable``, ``n_missing``,
            ``n_cases`` (number of rows) and ``pct_missing`` (0-100).
        """
        counts = self._as_variable_axis(self._obj.isna().sum())
        return counts.reset_index(name="n_missing").assign(
            n_cases=len(self._obj),
            pct_missing=lambda df: df["n_missing"] / df["n_cases"] * 100,
        )

    def missing_case_summary(self) -> pd.DataFrame:
        """
        Summarize missing values by case (row).

        Returns:
            pd.DataFrame: Indexed like the original frame, with ``case`` (the
            original index label), ``n_missing`` (NA cells in the row) and
            ``pct_missing`` (share of the row's columns that are NA, 0-100).
        """
        # Denominator is the ORIGINAL column count. Computing it inside a chained
        # ``assign`` on ``self._obj`` would also count the helper columns that
        # ``assign`` has already added by then.
        n_columns = self._obj.shape[1]
        return (
            self._obj.isna()
            .sum(axis="columns")
            .to_frame(name="n_missing")
            .assign(
                case=self._obj.index,
                pct_missing=lambda df: df["n_missing"] / n_columns * 100,
            )[["case", "n_missing", "pct_missing"]]
        )

    def missing_variable_table(self) -> pd.DataFrame:
        """
        Create a table summarizing how many variables share each missing count.

        Returns:
            pd.DataFrame: Columns ``n_missing_in_variable``, ``n_variables`` and
            ``pct_variables`` (0-100), sorted by ``pct_variables`` descending.
        """
        summary_df = self.missing_variable_summary()
        return (
            summary_df
            .groupby("n_missing")
            .size()
            .reset_index(name="n_variables")
            .rename(columns={"n_missing": "n_missing_in_variable"})
            .assign(
                pct_variables=lambda df: df["n_variables"] / df["n_variables"].sum() * 100
            )
            .sort_values("pct_variables", ascending=False)
        )

    def missing_case_table(self) -> pd.DataFrame:
        """
        Create a table summarizing how many cases (rows) share each missing count.

        Returns:
            pd.DataFrame: Columns ``n_missing_in_case``, ``n_cases`` and
            ``pct_case`` (0-100), sorted by ``pct_case`` descending.
        """
        summary_df = self.missing_case_summary()
        return (
            summary_df
            .groupby("n_missing")
            .size()
            .reset_index(name="n_cases")
            .rename(columns={"n_missing": "n_missing_in_case"})
            .assign(pct_case=lambda df: df["n_cases"] / df["n_cases"].sum() * 100)
            .sort_values("pct_case", ascending=False)
        )

    def missing_variable_span(self, variable: str, span_every: int) -> pd.DataFrame:
        """
        Analyze missing values over consecutive spans of rows for one variable.

        Rows are split, in their current order, into blocks of ``span_every``
        rows. The last block may be shorter.

        Args:
            variable (str): The column to analyze.
            span_every (int): Number of rows per span; must be >= 1.

        Returns:
            pd.DataFrame: One row per span with ``span_counter``, ``n_missing``,
            ``n_complete``, ``pct_missing`` and ``pct_complete`` (0-100).

        Raises:
            TypeError: If ``span_every`` is not an integer.
            ValueError: If ``span_every`` is smaller than 1.
        """
        if not isinstance(span_every, (int, np.integer)):
            raise TypeError(
                f"span_every must be an integer, got {type(span_every).__name__}"
            )
        if span_every < 1:
            raise ValueError(f"span_every must be >= 1, got {span_every}")

        is_missing = self._obj[variable].isna().to_numpy()
        # Row i belongs to span i // span_every. This is the same labelling as
        # np.repeat(range(n), span_every)[:n], but allocates one label per row
        # instead of n * span_every.
        span_counter = np.arange(is_missing.size) // span_every
        return (
            pd.DataFrame({"span_counter": span_counter, "is_missing": is_missing})
            .groupby("span_counter")
            .aggregate(
                n_in_span=("is_missing", "size"),
                n_missing=("is_missing", "sum"),
            )
            .assign(
                n_complete=lambda df: df["n_in_span"] - df["n_missing"],
                pct_missing=lambda df: df["n_missing"] / df["n_in_span"] * 100,
                pct_complete=lambda df: 100 - df["pct_missing"],
            )
            .drop(columns=["n_in_span"])
            .reset_index()
        )

    def missing_variable_run(self, variable) -> pd.DataFrame:
        """
        Analyze runs of missing and complete values for a given variable.

        Args:
            variable (str): The column to analyze.

        Returns:
            pd.DataFrame: One row per run (in row order) with ``run_length`` and
            ``is_na`` (``"missing"`` or ``"complete"``).
        """
        # Run-length encode the NA mask: groupby merges consecutive equal flags.
        runs = [
            [len(list(group)), flag]
            for flag, group in itertools.groupby(self._obj[variable].isna())
        ]
        return pd.DataFrame(data=runs, columns=["run_length", "is_na"]).assign(
            # Map only the flag column; leave the integer run lengths untouched.
            is_na=lambda df: df["is_na"].map({False: "complete", True: "missing"})
        )

    def sort_variables_by_missingness(self, ascending=False):
        """
        Sort variables (columns) by their number of missing values.

        Columns with equal counts keep their original relative order.

        Args:
            ascending (bool, optional): Sort order. Defaults to False (most missing first).

        Returns:
            pd.DataFrame: The DataFrame with its columns reordered.
        """
        order = self._obj.isna().sum().sort_values(ascending=ascending, kind="stable")
        return self._obj[order.index]

    def create_shadow_matrix(
        self,
        true_string: str = "Missing",
        false_string: str = "Not Missing",
        only_missing: bool = False,
        suffix: str = "_NA",
    ) -> pd.DataFrame:
        """
        Create a shadow matrix indicating missing values.

        Args:
            true_string (str, optional): Label for missing values. Defaults to "Missing".
            false_string (str, optional): Label for non-missing values. Defaults to "Not Missing".
            only_missing (bool, optional): Include only columns with missing values. Defaults to False.
            suffix (str, optional): Suffix to add to column names. Defaults to "_NA".

        Returns:
            pd.DataFrame: Same shape and index as the data (or only the columns
            containing NAs), holding ``true_string``/``false_string`` labels.
        """
        shadow = self._obj.isna()
        if only_missing:
            shadow = shadow[shadow.columns[shadow.any()]]
        return shadow.replace({False: false_string, True: true_string}).add_suffix(suffix)

    def bind_shadow_matrix(
        self,
        true_string: str = "Missing",
        false_string: str = "Not Missing",
        only_missing: bool = False,
        suffix: str = "_NA",
    ) -> pd.DataFrame:
        """
        Bind the shadow matrix to the original DataFrame.

        Args:
            true_string (str, optional): Label for missing values. Defaults to "Missing".
            false_string (str, optional): Label for non-missing values. Defaults to "Not Missing".
            only_missing (bool, optional): Include only columns with missing values. Defaults to False.
            suffix (str, optional): Suffix to add to column names. Defaults to "_NA".

        Returns:
            pd.DataFrame: The original columns followed by the shadow columns.
        """
        shadow = self.create_shadow_matrix(
            true_string=true_string,
            false_string=false_string,
            only_missing=only_missing,
            suffix=suffix,
        )
        return pd.concat(objs=[self._obj, shadow], axis="columns")

    def missing_scan_count(self, search) -> pd.DataFrame:
        """
        Scan for specific values that encode missingness and count their occurrences.

        Args:
            search (list or set): Values to search for (e.g. ``[-99, "N/A"]``).

        Returns:
            pd.DataFrame: One row per column with ``variable``, ``n`` (number of
            matching cells) and ``original_type`` (the column's dtype).
        """
        hits = self._obj.apply(axis="index", func=lambda column: column.isin(search)).sum()
        return (
            self._as_variable_axis(hits)
            .reset_index(name="n")
            # dtypes are in column order, same as the rows above: align positionally.
            .assign(original_type=self._obj.dtypes.to_numpy())
        )

    # Plotting functions ---
    def missing_variable_plot(self):
        """
        Plot the number of missing values per variable as a lollipop chart.

        Draws on the current matplotlib axes.

        Returns:
            matplotlib.axes.Axes: The axes that were drawn on.
        """
        summary = self.missing_variable_summary().sort_values("n_missing")
        positions = range(1, len(summary.index) + 1)
        plt.hlines(y=positions, xmin=0, xmax=summary["n_missing"], color="black")
        plt.plot(summary["n_missing"], positions, "o", color="black")
        plt.yticks(positions, summary["variable"])
        plt.grid(axis="y")
        plt.xlabel("Number missing")
        plt.ylabel("Variable")
        return plt.gca()

    def missing_case_plot(self):
        """
        Plot the distribution of the number of missing values per case (row).

        Returns:
            seaborn.FacetGrid: The grid created by ``sns.displot``.
        """
        grid = sns.displot(
            data=self.missing_case_summary(), x="n_missing", binwidth=1, color="black"
        )
        plt.grid(axis="x")
        plt.xlabel("Number of missings in case")
        plt.ylabel("Number of cases")
        return grid

    def missing_variable_span_plot(
        self, variable: str, span_every: int, rot: int = 0, figsize=None
    ):
        """
        Plot the percentage of missing values over spans for a given variable.

        Args:
            variable (str): The variable to plot.
            span_every (int): The span size (rows per span).
            rot (int, optional): Rotation of x-axis labels. Defaults to 0.
            figsize (tuple, optional): Size of the figure. Defaults to None.

        Returns:
            matplotlib.axes.Axes: The axes returned by ``DataFrame.plot.bar``.
        """
        ax = self.missing_variable_span(
            variable=variable, span_every=span_every
        ).plot.bar(
            x="span_counter",
            y=["pct_missing", "pct_complete"],
            stacked=True,
            width=1,
            color=["black", "lightgray"],
            rot=rot,
            figsize=figsize,
        )
        plt.xlabel("Span number")
        plt.ylabel("Percentage missing")
        plt.legend(["Missing", "Present"])
        plt.title(
            f"Percentage of missing values\nOver a repeating span of {span_every}",
            loc="left",
        )
        plt.grid(False)
        plt.margins(0)
        plt.tight_layout(pad=0)
        return ax

    def missing_upsetplot(self, variables: "list[str] | None" = None, **kwargs):
        """
        Create an upset plot to visualize the intersections of missing values across variables.

        Args:
            variables (list[str], optional): Columns to include. Defaults to all columns.
            **kwargs: Additional keyword arguments passed to `upsetplot.plot`.

        Returns:
            Whatever ``upsetplot.plot`` returns (in current upsetplot releases,
            a dict of matplotlib Axes keyed by subplot name).
        """
        if variables is None:
            variables = self._obj.columns.tolist()
        # value_counts over the boolean NA mask yields a Series indexed by a
        # MultiIndex of True/False flags, which is the input upsetplot expects.
        return (
            self._obj.isna()
            .value_counts(variables)
            .pipe(lambda counts: upsetplot.plot(counts, **kwargs))
        )

    def scatter_imputation_plot(
        self, x, y, imputation_suffix="_imp", show_marginal=False, **kwargs
    ):
        """
        Create a scatter plot highlighting imputed values.

        Expects indicator columns ``f"{x}{imputation_suffix}"`` and
        ``f"{y}{imputation_suffix}"``, presumably boolean flags marking imputed
        rows. A point is coloured as imputed when either flag is set.

        Args:
            x (str): The x-axis variable.
            y (str): The y-axis variable.
            imputation_suffix (str, optional): Suffix of the indicator columns. Defaults to "_imp".
            show_marginal (bool, optional): Use ``sns.jointplot`` (with marginal plots)
                instead of ``sns.scatterplot``. Defaults to False.
            **kwargs: Additional keyword arguments passed to the plotting function.

        Returns:
            matplotlib.axes.Axes when ``show_marginal`` is False (``sns.scatterplot``),
            otherwise seaborn.JointGrid (``sns.jointplot``).
        """
        x_imputed = f"{x}{imputation_suffix}"
        y_imputed = f"{y}{imputation_suffix}"
        plot_func = sns.jointplot if show_marginal else sns.scatterplot
        return (
            self._obj[[x, y, x_imputed, y_imputed]]
            .assign(is_imputed=lambda df: df[x_imputed] | df[y_imputed])
            .pipe(lambda df: plot_func(data=df, x=x, y=y, hue="is_imputed", **kwargs))
        )

    def missing_mosaic_plot(
        self,
        target_var: str,
        x_categorical_var: str,
        y_categorical_var: str,
        ax=None
    ):
        """
        Create a mosaic plot of missingness in ``target_var`` across two categorical variables.

        Tiles whose key contains ``"NA"`` (target missing) are drawn red, others gray.

        Args:
            target_var (str): The target variable to check for missing values.
            x_categorical_var (str): The first categorical variable.
            y_categorical_var (str): The second categorical variable.
            ax (matplotlib.axes.Axes, optional): The plot axis. Defaults to None.

        Returns:
            tuple: ``(fig, rects)`` as returned by statsmodels' ``mosaic``.
        """
        return (
            self._obj
            .assign(
                **{target_var: lambda df: df[target_var].isna().map({True: "NA", False: "!NA"})}
            )
            .groupby(
                [x_categorical_var, y_categorical_var, target_var],
                dropna=False,
                as_index=True,
            )
            .size()
            .pipe(
                lambda counts: mosaic(
                    data=counts,
                    # key is a tuple of labels, so `in` tests for an exact "NA"
                    # element; "!NA" does not match.
                    properties=lambda key: {"color": "r" if "NA" in key else "gray"},
                    ax=ax,
                    horizontal=True,
                    axes_label=True,
                    title="",
                    labelizer=lambda key: "",
                )
            )
        )

