In [7]:
from prettytable import PrettyTable
import pandas as pd

class interrogate_single_dataset:
    '''
    Class summarises a single DataFrame on construction: its shape, DataFrame.describe()
    statistics, a PrettyTable of per-column dtype / non-null / missing / unique counts,
    and lists of column names grouped by kind.

    Class attributes:
        ► describe_params (dict) | Keyword arguments forwarded to DataFrame.describe()
        ► cardinality_limit (int) | Object columns with more unique values than this are
          treated as high cardinality; the remaining object columns are categorical
    '''

    describe_params = {'percentiles': None, 'include': None, 'exclude': None}
    cardinality_limit = 10

    def __init__(self, df, display=False):
        '''
        Params:
            ► df (DataFrame) | Pandas DataFrame to inspect
            ► display (bool) | If True, print the summary immediately. Default = False
        '''

        # Get shape of the dataframe (rows, columns)
        self.shape = df.shape

        # Get statistical info on dataframe
        self.stats = pd.DataFrame(df.describe(**self.describe_params))

        # Get info on dataframe (uses self.cardinality_limit by default)
        self.table = self.pretty_table_missing_counts(df)

        # Get lists of feature names from the dataset
        self.columns = self._classify_columns(df)

        if display:
            self.print_output()

    def _classify_columns(self, df):
        '''
        Group column names into 'Numerical', 'Object', 'Categorical' and 'High Cardinal' lists,
        each in DataFrame column order.
        Params:
            ► df (DataFrame) | Pandas DataFrame to inspect
        Return:
            ► dict mapping group name -> list of column names
        '''
        limit = self.cardinality_limit
        numerical, objects, categorical, high_cardinal = [], [], [], []

        for col in df.columns:
            dtype = df[col].dtype
            # Only int64 / float64 are counted as numerical (e.g. int32 is not)
            if dtype in ['int64', 'float64']:
                numerical.append(col)
            if dtype == 'object':
                objects.append(col)
                # Every object column goes in exactly one of the two lists; the split
                # matches the "High Cardinality" column of the PrettyTable (unique > limit)
                if df[col].nunique() > limit:
                    high_cardinal.append(col)
                else:
                    categorical.append(col)

        return {
            'Numerical': numerical,
            'Object': objects,
            'Categorical': categorical,
            'High Cardinal': high_cardinal,
        }

    @staticmethod
    def print_divider():
        print("\n","-"*100, "\n")

    def print_output(self):
        '''Print shape, describe() statistics and the per-column PrettyTable between dividers.'''
        self.print_divider()
        print(f"Shape of the DataFrame is: {self.shape[0]} rows | {self.shape[1]} columns", "\n")
        print(f"DataFrame statistical information: \n{self.stats}", "\n")
        print("DataFrame dtypes and Counts:\n", self.table,"\n")
        self.print_divider()

    def pretty_table_missing_counts(self, df, cardinality_limit=None):

        '''
        Method takes 2 parameters (1 required and 1 optional) and creates a visually pleasing table using the PrettyTable package
        that displays all columns, datatypes, non-null/null counts, unique counts and whether cardinality is high
        Params:
            ► df (DataFrame) | Pandas DataFrame
            ► cardinality_limit (int) | Integer representing the limit for considering a column to have high cardinality
              (unique count > limit). Default = None, meaning use the class attribute cardinality_limit
        Return:
            ► PrettyTable table
        '''

        if cardinality_limit is None:
            cardinality_limit = self.cardinality_limit

        table = PrettyTable()

        table.field_names = ['Column Name', 'Data Type', 'Non-Null Count', 'Missing Count', 'Unique', 'High Cardinality']

        row_count = df.shape[0]
        for column in df.columns:
            series = df[column]
            non_null_count = series.count()
            uniques = series.nunique()
            table.add_row([
                column,
                str(series.dtype),
                non_null_count,
                row_count - non_null_count,
                uniques,
                uniques > cardinality_limit,
            ])

        return table

    

In [6]:
class extract_column_lists:
    
    '''
    Class contains a series of methods for extracting lists of columns from one or two datasets.
    No initialisation method. Each method below must be called manually to return the relevant
    list item. All returned lists follow the DataFrame's column order.

    Methods:
        ► get_dtype_cols(df, dtype)                         | Get list of specific dtype cols
        ► get_cat_cols(df, cardinality_limit)               | Get categorical cols
        ► get_numerical_cols(df)                            | Get numerical cols
        ► get_good_cols(train_df, test_df)                  | Get good cols (test values all seen in train)
        ► get_bad_cols(train_df, test_df)                   | Get bad cols (test has unseen values)
        ► get_low_cardinality_cols(df, cardinality_limit)   | Get low cardinality cols
        ► get_high_cardinality_cols(df, cardinality_limit)  | Get high cardinality cols

    '''

    def get_dtype_cols(self, df, dtype):
        '''
        Method returns a list of columns from within a dataframe that meet the dtype criteria
        Params:
            ► df (DataFrame) | Dataframe object
            ► dtype (str) | String for the dtype. Options = "object", "float", "integer"
        Return:
            ► list of matching column names
        Raises:
            ► ValueError if dtype is not one of the options
        '''
        dtype_options = {"object", "float", "integer"}

        if dtype not in dtype_options:
            raise ValueError(f"Invalid dtype provided. Allowed options are: {dtype_options}")

        # Use pandas' dtype predicates: comparing `dtype == "integer"` never matches int64
        # columns, and `dtype == "float"` only matches float64.
        dtype_checks = {
            "object": pd.api.types.is_object_dtype,
            "float": pd.api.types.is_float_dtype,
            "integer": pd.api.types.is_integer_dtype,
        }
        is_match = dtype_checks[dtype]

        return [col for col in df.columns if is_match(df[col].dtype)]
        
    def get_cat_cols(self, df, cardinality_limit=10):
        '''
        Method gets a list of columns that can be considered categorical.
        Specifies object dtypes and columns with low cardinality (unique count < cardinality_limit)
        Params:
            ► df (DataFrame) | Dataframe for method to be applied over
            ► cardinality_limit (int) | Number of unique items within the columns
        '''
        return [col for col in df.columns if df[col].nunique() < cardinality_limit and df[col].dtype == "object"]
    
    def get_numerical_cols(self, df):
        '''
        Method gets a list of columns from a dataframe that are int64 and float64 dtype
        Params:
            ► df (DataFrame) | Dataframe for method to be applied over
        '''
        return [col for col in df.columns if df[col].dtype in ['int64', 'float64']]
        
    def get_good_cols(self, train_df, test_df):
        '''
        Method returns list of object columns that can be safely ordinal encoded: every value
        in the test column also appears in the corresponding train column
        Params:
            ► train_df (DataFrame) | Dataframe of training data
            ► test_df (DataFrame) | Dataframe of testing data (must contain train's object columns)
        '''

        object_cols = self.get_dtype_cols(train_df, 'object')
        
        return [col for col in object_cols if set(test_df[col]).issubset(set(train_df[col]))]
    

    def get_bad_cols(self, train_df, test_df):
        '''
        Method returns a list of problematic object columns that can be dropped from the dataset
        (those whose test column contains values not present in the train column)
        Params:
            ► train_df (DataFrame) | Dataframe of training data
            ► test_df (DataFrame) | Dataframe of testing data
        '''

        object_cols = self.get_dtype_cols(train_df, 'object')
        good_label_cols = set(self.get_good_cols(train_df, test_df))

        # Preserve column order rather than returning arbitrary set order
        return [col for col in object_cols if col not in good_label_cols]
    
    def get_low_cardinality_cols(self, df, cardinality_limit=10):
        '''
        Method gets a list of object columns that can be considered low cardinality
        (unique count < cardinality_limit)
        Params:
            ► df (DataFrame) | Dataframe for method to be applied over
            ► cardinality_limit (int) | Number of unique items within the columns
        '''

        object_cols = self.get_dtype_cols(df, 'object')

        return [col for col in object_cols if df[col].nunique() < cardinality_limit]
    
    def get_high_cardinality_cols(self, df, cardinality_limit=10):
        '''
        Method gets a list of object columns that can be considered high cardinality
        (the object columns that are not low cardinality, i.e. unique count >= cardinality_limit)
        Params:
            ► df (DataFrame) | Dataframe for method to be applied over
            ► cardinality_limit (int) | Number of unique items within the columns
        '''

        object_cols = self.get_dtype_cols(df, 'object')
        low_cardinality_cols = set(self.get_low_cardinality_cols(df, cardinality_limit))

        return [col for col in object_cols if col not in low_cardinality_cols]
