In [None]:
import pandas as pd
import numpy as np

In [None]:
class DataPreprocessor:
    """Report zero/missing values in a DataFrame and apply simple clean-ups.

    The frame passed to the constructor is stored by reference, not copied.
    ``fill_with_group_median`` assigns columns on the stored frame, so it
    writes into the caller's object unless ``remove_columns`` was called
    first (``DataFrame.drop`` returns a new frame).
    """

    # Number of leading values shown per feature in the summary table.
    _SAMPLE_COUNT = 7
    # Shared by the header and the data rows so the columns line up.
    _ROW_FORMAT = "{:15} {:14} {:20}"

    def __init__(self, df):
        """Store ``df`` (a pandas DataFrame) and snapshot its column dtypes."""
        self.__df = df
        self.__feature_dtypes = self.__df.dtypes

    def _get_missing_values(self):
        """Return per-column null counts (> 0 only), largest first."""
        missing_values = self.__df.isnull().sum()
        missing_values = missing_values[missing_values > 0]
        return missing_values.sort_values(ascending=False)

    def _get_zero_values(self):
        """Return per-column counts of cells equal to 0 (> 0 only), largest first."""
        zero_values = (self.__df == 0).sum()
        zero_values = zero_values[zero_values > 0]
        return zero_values.sort_values(ascending=False)

    def _info_table(self, feature_names, feature_dtypes, values, value_type='zero', hide_header=False):
        """Print one table row per feature: name, dtype, count/percentage, samples.

        Args:
            feature_names: Column labels to report (list or array).
            feature_dtypes: Series of dtypes indexed by column label.
            values: Count for each entry of ``feature_names``, in the same order.
            value_type: Label used in the messages and the column title.
            hide_header: When true, suppress the "====> This data ..." line.
        """
        print("=" * 50)
        # len() rather than truthiness: feature_names may be a numpy array.
        if len(feature_names) == 0:
            if not hide_header:
                print(f'====> This data does not contain any {value_type} values')
            print()
            return

        if not hide_header:
            print(f'====> This data contains {value_type} values')
        print()

        print((self._ROW_FORMAT + "{}").format('FEATURE NAME',
                                               'DATA FORMAT',
                                               f'{value_type} (Num-Perc)'.upper(),
                                               'SEVEN SAMPLES'))

        total_rows = self.__df.shape[0]
        for feature_name, dtype, value in zip(feature_names, feature_dtypes[feature_names], values):
            # Guard the percentage against an empty frame.
            percent = round(100 * value / total_rows, 3) if total_rows else 0.0
            print(self._ROW_FORMAT.format(feature_name,
                                          str(dtype),
                                          f'{value} - {percent} %'), end="")

            # head() yields at most _SAMPLE_COUNT values, so frames shorter
            # than that no longer raise IndexError.
            samples = self.__df[feature_name].head(self._SAMPLE_COUNT)
            print("".join(f"{sample}," for sample in samples))

        print()

    def _info(self):
        """Print the frame shape, then the zero-value and null-value tables."""
        missing_values = self._get_missing_values()
        feature_names_missing = missing_values.index.values
        missing_values = missing_values.values

        zero_values = self._get_zero_values()
        feature_names_zero = zero_values.index.values
        zero_values = zero_values.values

        rows, columns = self.__df.shape

        print("=" * 50)
        print('====> This data contains {} rows and {} columns'.format(rows, columns))
        print()

        self._info_table(feature_names_zero,
                         self.__feature_dtypes,
                         zero_values)

        self._info_table(feature_names_missing,
                         self.__feature_dtypes,
                         missing_values,
                         'null')

        print("=" * 50)

    def remove_columns(self, columns):
        """Drop ``columns`` from the stored frame (rebinds to the new frame)."""
        self.__df = self.__df.drop(columns, axis=1)
        self.__feature_dtypes = self.__df.dtypes

    def fill_with_group_median(self, columns, group_column='group'):
        """Replace zeros in ``columns`` with the per-group median of non-zero values.

        Medians are rounded to 2 decimals. If a group has no non-zero values
        its median is NaN, so the zeros of that group become NaN.

        Args:
            columns: Column labels to impute.
            group_column: Column whose values define the groups.
        """
        columns = list(columns)

        for column in columns:
            group_medians = self.__df.groupby(group_column)[column].transform(
                lambda x: round(x[x != 0].median(), 2))
            self.__df[column] = self.__df[column].mask(self.__df[column] == 0, group_medians)

        # Imputation can change a column's dtype (e.g. int -> float), so
        # refresh the snapshot before reporting.
        self.__feature_dtypes = self.__df.dtypes

        # Report the zeros that actually remain instead of assuming none do.
        remaining_zeros = [int((self.__df[column] == 0).sum()) for column in columns]
        self._info_table(columns,
                         self.__feature_dtypes,
                         remaining_zeros,
                         'zero',
                         hide_header=True)

    def get_data(self):
        """Return the stored DataFrame (the object itself, not a copy)."""
        return self.__df

    def save_data(self, path):
        """Write the stored frame to ``path`` as CSV without the index."""
        self.__df.to_csv(path, index=False)

    def information(self):
        """Print the zero/missing-value report for the stored frame."""
        return self._info()
        

In [None]:
# Load the two merged statistics tables. Both paths are relative to the
# current working directory; the FLAIR table is read from ../stats/.
stats_t1 = pd.read_csv('../stats_t1/merged_data.csv')
stats_flair = pd.read_csv('../stats/merged_data.csv')

In [None]:
# Wrap the T1 table and print its shape plus the zero-value and
# null-value summary tables.
data_preprocessor_t1 = DataPreprocessor(stats_t1)
data_preprocessor_t1.information()

In [None]:
# Drop six columns from the T1 table, then print the report again.
# NOTE(review): presumably these columns are dropped because of the zero
# counts in the report above; confirm against that output.
data_preprocessor_t1.remove_columns(['left_wm_hypointensities',
                                     'right_wm_hypointensities',
                                     'left_non_wm_hypointensities',
                                     'right_non_wm_hypointensities',
                                     'non_wm_hypointensities',
                                     '5th_ventricle'])
data_preprocessor_t1.information()

In [None]:
# Banner for the T1 imputation step. fill_with_group_median prints its own
# summary table for the listed columns after replacing their zeros.
print("="*50)
print('====> T1w data after imputation')

data_preprocessor_t1.fill_with_group_median(['left_vessel', 'right_vessel', 'optic_chiasm'])

In [None]:
# Write the processed T1 table to CSV (save_data omits the index column).
data_preprocessor_t1.save_data('../stats_t1/preprocessed_data.csv')

In [None]:
# Wrap the FLAIR table and print its shape plus the zero-value and
# null-value summary tables.
data_preprocessor_flair = DataPreprocessor(stats_flair)
data_preprocessor_flair.information()

In [None]:
# Drop the same six columns as in the T1 step, then print the report again.
data_preprocessor_flair.remove_columns(['left_wm_hypointensities',
                                     'right_wm_hypointensities',
                                     'left_non_wm_hypointensities',
                                     'right_non_wm_hypointensities',
                                     'non_wm_hypointensities',
                                     '5th_ventricle'])
data_preprocessor_flair.information()

In [None]:
# Banner for the FLAIR imputation step. fill_with_group_median prints its own
# summary table for the listed columns after replacing their zeros.
print("="*50)
print('====> FLAIR data after imputation')

data_preprocessor_flair.fill_with_group_median(['left_vessel', 'right_vessel', 'optic_chiasm'])

In [None]:
# Write the processed FLAIR table to CSV (save_data omits the index column).
data_preprocessor_flair.save_data('../stats/preprocessed_data.csv')