In [None]:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

class PSICalculator:
    """
    Calculates the Population Stability Index (PSI) between two periods of a dataset.

    The input DataFrame must contain a 'period' column whose values identify the
    training and new periods. When ``channel`` is given, the DataFrame must also
    contain a 'channel' column; only rows whose 'channel' equals that value are used.

    Attributes:
        data (pd.DataFrame): The input dataset containing the feature of interest.
        quantiles (int): Number of quantiles (buckets) to divide the training data into.
        training_period (str): Value of the 'period' column marking the training rows.
        new_period (str): Value of the 'period' column marking the new/test rows.
        channel (optional): Value of the 'channel' column to keep; None disables filtering.
        column (str): The column on which the PSI calculation will be performed.
    """

    # Floor applied to bucket shares inside the PSI formula only. An empty bucket
    # would otherwise produce log(0) or a division by zero; with the floor it
    # contributes a large but finite amount instead of inf/NaN.
    _EPSILON = 1e-6

    def __init__(self, data, quantiles=10, training_period='train', new_period='test', channel=None, column=None):
        """
        Stores the dataset and the parameters used by the PSI calculation.

        Args:
            data (pd.DataFrame): The input dataset.
            quantiles (int): Number of quantiles (default is 10).
            training_period (str): Label for the training period (default is 'train').
            new_period (str): Label for the new period (default is 'test').
            channel (optional): Value of the 'channel' column to filter on (default is None,
                meaning no filtering).
            column (str): Column to perform PSI calculation on.
        """
        self.data = data
        self.quantiles = quantiles
        self.training_period = training_period
        self.new_period = new_period
        self.channel = channel
        self.column = column

    def filter_data(self):
        """
        Restricts the data to the configured channel, if one was given.

        Returns:
            pd.DataFrame: Rows whose 'channel' equals ``self.channel``, or the full
            dataset when ``self.channel`` is None.
        """
        # Compare against None explicitly so falsy channel values (e.g. 0) still filter.
        if self.channel is not None:
            return self.data[self.data['channel'] == self.channel]
        return self.data

    def calculate_quantiles(self, data, column):
        """
        Divides the data into quantiles and returns the bins and breakpoints.

        Args:
            data (pd.DataFrame): Dataset for the training period.
            column (str): Column to calculate quantiles for.

        Returns:
            pd.Series, np.ndarray: Quantile bin of each row and the breakpoints.
            Duplicate breakpoints are dropped, so there may be fewer buckets
            than ``self.quantiles``.
        """
        return pd.qcut(data[column], q=self.quantiles, duplicates='drop', retbins=True)

    @staticmethod
    def _count_per_bucket(values, interior_edges, n_buckets):
        """Counts values per right-closed bucket; the first and last buckets are open-ended."""
        # side='left' maps a value equal to an edge to the bucket that ends at that edge,
        # which matches the right-closed intervals produced by pd.qcut.
        bucket_index = np.searchsorted(interior_edges, values.to_numpy(), side='left')
        return np.bincount(bucket_index, minlength=n_buckets)

    def calculate_psi(self):
        """
        Calculates the Population Stability Index (PSI) between the training and new periods.

        Buckets are defined by the training-period quantiles. The lowest and highest
        buckets are treated as open-ended so that new-period values outside the
        training range are counted instead of being dropped. Null values in the
        column are excluded from both the counts and the percentages.

        Returns:
            tuple:
                - psi (float): The overall PSI value.
                - df (pd.DataFrame): A summary DataFrame with bucket details
                  ('Bucket', 'Breakpoint Value', 'Initial Count', 'New Count',
                  'Initial Percent', 'New Percent', 'PSI Value').

        Raises:
            ValueError: If no column was configured, if either period has no
                non-null observations, or if the training data yields fewer
                than two quantile edges.
        """
        if self.column is None:
            raise ValueError("A column must be specified to calculate PSI.")

        # Step 1: Filter data by channel if applicable
        filtered_data = self.filter_data()

        # Step 2: Separate the training and new data
        train_data = filtered_data[filtered_data['period'] == self.training_period]
        new_data = filtered_data[filtered_data['period'] == self.new_period]

        train_values = train_data[self.column].dropna()
        new_values = new_data[self.column].dropna()
        if train_values.empty or new_values.empty:
            # Without this guard an empty period yields NaN shares that .sum()
            # silently skips, reporting a misleading PSI of 0.
            raise ValueError(
                "Both the training and new periods need at least one non-null "
                f"observation in column {self.column!r}."
            )

        # Step 3: Derive bucket edges from the training period
        _, breakpoints = self.calculate_quantiles(train_data, self.column)
        breakpoints = np.asarray(breakpoints)
        if len(breakpoints) < 2:
            raise ValueError("Training data does not yield at least two distinct quantile edges.")
        n_buckets = len(breakpoints) - 1
        # Only the interior edges separate buckets; the outer buckets are open-ended.
        interior_edges = breakpoints[1:-1]

        # Step 4: Count observations in each bucket
        initial_counts = self._count_per_bucket(train_values, interior_edges, n_buckets)
        new_counts = self._count_per_bucket(new_values, interior_edges, n_buckets)

        # Step 5: Create a summary DataFrame
        df = pd.DataFrame({
            'Bucket': np.arange(1, len(breakpoints)),  # Bucket labels
            'Breakpoint Value': breakpoints[1:],       # Upper training edge of each bucket
            'Initial Count': initial_counts,           # Training period counts
            'New Count': new_counts                    # New period counts
        })
        # Divide by the number of bucketed observations so each distribution sums to 1.
        df['Initial Percent'] = df['Initial Count'] / len(train_values)
        df['New Percent'] = df['New Count'] / len(new_values)

        # Step 6: Calculate PSI for each bucket (shares floored to keep the log finite)
        initial_share = df['Initial Percent'].clip(lower=self._EPSILON)
        new_share = df['New Percent'].clip(lower=self._EPSILON)
        psi_values = (initial_share - new_share) * np.log(initial_share / new_share)
        df['PSI Value'] = psi_values
        psi = psi_values.sum()

        # Step 7: Raise alert if PSI exceeds the threshold
        if psi > 0.25:
            print("ALERT: Significant population shift detected (PSI > 0.25). Model recalibration is recommended.")

        return psi, df

    def plot_buckets(self, df):
        """
        Visualizes the initial and new buckets' counts using a grouped bar plot.

        Args:
            df (pd.DataFrame): The summary DataFrame from calculate_psi(); the
                'Bucket', 'Initial Count' and 'New Count' columns are used.

        Returns:
            None: Displays the bar plot.
        """
        # Reshape to long format: one row per (bucket, period) pair
        plot_df = df.melt(id_vars="Bucket",
                          value_vars=["Initial Count", "New Count"],
                          var_name="Period",
                          value_name="Count")

        # Create bar plot
        plt.figure(figsize=(12, 6))
        sns.barplot(data=plot_df, x="Bucket", y="Count", hue="Period", palette="viridis")
        plt.title("Bucket Counts: Initial vs New Period", fontsize=16)
        plt.xlabel("Bucket", fontsize=14)
        plt.ylabel("Count", fontsize=14)
        plt.legend(title="Period", fontsize=12)
        plt.show()