In [173]:
import enum
from typing import List, Dict, Tuple, Optional, Union
import pandas as pd
from matplotlib import pyplot as plt

from dataclasses import dataclass

import findspark
findspark.init('/home/ubuntu/spark-3.2.1-bin-hadoop2.7')
import pyspark
spark = SparkSession.builder.appName('Iteration4').getOrCreate()

databases_path = '../datasets/'
PATH_IMAGES = '../tex/iterations/iteration_4/images/'


In [202]:
class DataFramesCSV(enum.Enum):
    ALCOHOL_CONSUMPTION_CSV = f"{databases_path}4_total-alcohol-consumption-per-capita-litres-of-pure-alcohol.csv"
    COUNTRY_MASTER_CSV = f"{databases_path}0_master_country_codes.csv"
    WHO_OBESITY_CSV = f"{databases_path}1_who_obesity.csv"
    MEAT_CONSUMPTION_CSV = f"{databases_path}2_meat_consumption.csv"
    HUNGER_CSV = f"{databases_path}5_global_hunger_index.csv"
    SMOKING_CSV = f"{databases_path}6_share-of-adults-who-smoke.csv"
    HAPPINESS_REPORT_CSV = f"{databases_path}3_happiness_report.csv"

class DataFramePreviousFieldNameOptions(enum.Enum):
    IS_NULL = 'isnull'
    D_TYPES = 'dtypes'
    COUNT = 'count'

COLUMN_RENAME_BY_DATASET = {
    DataFramesCSV.WHO_OBESITY_CSV: {
        'Numeric': 'percentage_obesity',
        'Countries, territories and areas': 'country',
        'WHO region': 'region',
        'Year': 'year',
    },
    DataFramesCSV.HAPPINESS_REPORT_CSV: {
        'year': 'year',
        'Country name': 'country',
        "Life Ladder": 'life_ladder',
        "Social support": 'social_support',
        "Freedom to make life choices": "freedom_to_make_life_choices",
        "Generosity": "generosity",
        "Perceptions of corruption": "perceptions_of_corruption",
        "Positive affect": "positive_affect",
        "Negative affect": "negative_affect",
    },
    DataFramesCSV.MEAT_CONSUMPTION_CSV: {
        'Code': 'country_code',
        'Year': 'year',
        "Meat, poultry | 00002734 || Food available for consumption | 0645pc || kilograms per year per capita": "poultry",
        "Meat, beef | 00002731 || Food available for consumption | 0645pc || kilograms per year per capita": "beef",
        "Meat, sheep and goat | 00002732 || Food available for consumption | 0645pc || kilograms per year per capita": "sheep_and_goat",
        "Meat, pig | 00002733 || Food available for consumption | 0645pc || kilograms per year per capita": "pig",
        "Fish and seafood | 00002960 || Food available for consumption | 0645pc || kilograms per year per capita": "fish_and_seafood",
    },
    DataFramesCSV.COUNTRY_MASTER_CSV: {
        'alpha-3': 'country_code',
        'name': 'country'
    },
    DataFramesCSV.HUNGER_CSV: {
        'Entity': 'country',
        'Year': 'year',
        'Global Hunger Index (2021)': 'hunger_index',
    },
    DataFramesCSV.SMOKING_CSV: {
        'Entity': 'country',
        'Year': 'year',
        'Prevalence of current tobacco use (% of adults)': 'prevalence_smoking',
    },
    DataFramesCSV.ALCOHOL_CONSUMPTION_CSV: {
        'Entity': 'country',
        'Year': 'year',
        'liters_of_pure_alcohol_per_capita': 'liters_of_pure_alcohol_per_capita',
    },
}


def capture_get_dataframe_info_image(
        table_name: str,
        name_file: str,
        df_spark: pyspark.sql.dataframe.DataFrame,
        previous_data: pd.Series = None,
        previous_data_name: DataFramePreviousFieldNameOptions = None,
        figure_size_height=10.0,
        figure_size_width=5.0,
):
    dataframe = df_spark.toPandas()
    info_object = {
        'columns': dataframe.columns.str[0:30].tolist(),
        'dtypes': dataframe.dtypes.tolist(),
        'count': dataframe.count().tolist(),
        'isnull': dataframe.isnull().sum().tolist(),
    }
    dataframe_info = pd.DataFrame(info_object)

    if previous_data is not None:
        current_data = dataframe_info[previous_data_name.value]
        dataframe_info[f'old {previous_data_name.value}'] = previous_data
        dataframe_info['change'] = previous_data - current_data

    fig, ax = plt.subplots(figsize=(figure_size_width, figure_size_height))
    ax.axis('off')
    ax.axis('tight')
    dataframe_info.reset_index(inplace=True)
    col_widths = [0.08, 0.35, 0.15, 0.15, 0.15, 0.15, 0.15]
    data_table = ax.table(
        cellText=dataframe_info.values,
        colLabels=[' '.join(col.split('_')) for col in dataframe_info.columns],
        colWidths=col_widths,
        loc='center'
    )

    if previous_data is not None:
        for (i, j), val in np.ndenumerate(dataframe_info.values):
            if j == 6 and val != 0:  # We look into the second column (j==1), and search for zero values
                data_table[(i + 1, j)].set_facecolor("red")
                data_table[(i + 1, j)].set_text_props(color='white', weight='bold')

    num_rows = dataframe.shape[0] - 1
    data_table.auto_set_font_size(False)
    data_table.set_fontsize(6)
    plt.title(f'{table_name} (Records: {num_rows})')
    plt.tight_layout()
    plt.savefig(f"{PATH_IMAGES}{name_file}.png", dpi=200, bbox_inches='tight')
    plt.close()

def capture_get_dataframe_info_image_new(
        table_name: str,
        name_file: str,
        df_spark: pyspark.sql.dataframe.DataFrame,
        previous_data: list = None,
        previous_data_name: str = None,
        figure_size_height=10.0,
        figure_size_width=5.0,
        PATH_IMAGES="./"
):

    # Collect required statistics from PySpark DataFrame
    column_names = [col[:30] for col in df_spark.columns]
    column_types = [dtype for _, dtype in df_spark.dtypes]
    row_count = df_spark.count()
    column_counts = [row_count for _ in df_spark.columns]
    column_null_counts = df_spark.agg(*[F.sum(F.when(F.col(c).isNull(), 1).otherwise(0)).alias(c) for c in df_spark.columns]).collect()[0]
    column_non_null_counts = [row_count - null_count for null_count in column_null_counts]

    info_object = {
        'columns': column_names,
        'dtypes': column_types,
        'count': column_non_null_counts,
        'isnull': column_null_counts
    }

    if previous_data is not None:
        info_object[f'old {previous_data_name}'] = previous_data
        current_data = info_object[previous_data_name]
        info_object['change'] = [prev - curr for prev, curr in zip(previous_data, current_data)]

    dataframe_info = list(zip(*info_object.values()))

    # Plotting the table image
    fig, ax = plt.subplots(figsize=(figure_size_width, figure_size_height))
    ax.axis('off')
    ax.axis('tight')
    col_widths = [0.08, 0.35, 0.15, 0.15, 0.15, 0.15, 0.15]
    data_table = ax.table(
        cellText=dataframe_info,
        colLabels=[' '.join(col.split('_')) for col in info_object.keys()],
        colWidths=col_widths,
        loc='center'
    )

    if previous_data is not None:
        for (i, j), val in np.ndenumerate(np.array(dataframe_info)):
            if j == 6 and val != 0:  # We look into the last column (j==6), and search for non-zero values
                data_table[(i + 1, j)].set_facecolor("red")
                data_table[(i + 1, j)].set_text_props(color='white', weight='bold')

    num_rows = row_count - 1
    data_table.auto_set_font_size(False)
    data_table.set_fontsize(6)
    plt.title(f'{table_name} (Records: {num_rows})')
    plt.tight_layout()
    plt.savefig(f"{PATH_IMAGES}{name_file}.png", dpi=200, bbox_inches='tight')
    plt.close()

def capture_summary_dataset_to_image(
        name_file: str,
        df_spark: pyspark.sql.dataframe.DataFrame,
        dataset_name: str,
        figure_size_height=3,
        font_size=7,
):
    dataset = df_spark.toPandas()
    desc = dataset.describe().round(2).T

    fig, ax = plt.subplots(figsize=(7, figure_size_height))
    new_order = ['min', '25%', '50%', '75%', 'max', 'mean', 'count', 'std']
    desc = desc[new_order]
    # Hide axes
    ax.axis('off')
    ax.axis('tight')
    col_widths = [0.08, 0.08, 0.08, 0.08, 0.08, 0.08, 0.08, 0.08]

    data_table = ax.table(
        cellText=desc.values,
        colLabels=desc.columns,
        rowLabels=[name[:30] for name in desc.index],
        cellLoc='center',
        loc='center',
        colWidths=col_widths,
    )
    data_table.auto_set_font_size(False)
    data_table.set_fontsize(font_size)
    fig.tight_layout()

    plt.title(f"Descriptive Statistics {dataset_name}")
    plt.savefig(f"{PATH_IMAGES}{name_file}.png", dpi=200, bbox_inches='tight')
    plt.close()

    
def rename_columns(df, map_columns: Dict[str,str]) -> pyspark.sql.dataframe.DataFrame:
    for old_name, new_name in map_columns.items():
        df = df.withColumnRenamed(old_name, new_name)
    return df

def read_csv(file_path_enum: DataFramesCSV, sep=",") -> pyspark.sql.dataframe.DataFrame:
    return spark.read.csv(file_path_enum.value, header=True, inferSchema=True, sep=sep)
    

In [209]:
@dataclass
class ProjectManager:
    missing_countries = None
    country_master = None
    integrated_dataset = None
    generate_images_du_02: bool = False

    def __post_init__(self):
        country_master = read_csv(DataFramesCSV.COUNTRY_MASTER_CSV)
        country_master = rename_columns(
            df=country_master,
            map_columns={
                "alpha-3": "country_code",
                "name": "country"
            }
        )
        self.country_master = country_master

    def _capture_get_dataframe_info_image(
            self,
            table_name: str,
            name_file: str,
            df_spark: pyspark.sql.dataframe.DataFrame,
            figure_size_height=10.0,
            figure_size_width=5.0,
            force_save_image=False,
            previous_data: pd.Series = None,
            previous_data_name: DataFramePreviousFieldNameOptions = None,
    ):
        if self.generate_images_du_02 or force_save_image:
            capture_get_dataframe_info_image(
                table_name=table_name,
                name_file=name_file,
                df_spark=df_spark,
                figure_size_height=figure_size_height,
                figure_size_width=figure_size_width,
                previous_data_name=previous_data_name,
                previous_data=previous_data,
            )
    
    def _capture_summary_dataset_to_image(
            self,
            name_file: str,
            df_spark: pyspark.sql.dataframe.DataFrame,
            dataset_name: str,
            figure_size_height=3,
            font_size=7,
    ):
        if self.generate_images_du_02:
            capture_summary_dataset_to_image(
                name_file=name_file,
                df_spark=df_spark,
                dataset_name=dataset_name,
                figure_size_height=figure_size_height,
                font_size=font_size,
            )

    
    def _du_02_country_master(self):
        country_master = read_csv(DataFramesCSV.COUNTRY_MASTER_CSV)
        self._capture_get_dataframe_info_image(
            table_name='Countries Dataset',
            name_file='du_country_dataset',
            df_spark=country_master,
            figure_size_height=2.5
        )

    def _du_02_meat_consumption(self):
        meat_consumption = read_csv(DataFramesCSV.MEAT_CONSUMPTION_CSV)
        self._capture_get_dataframe_info_image(
            table_name='Meat Consumption Dataset',
            name_file='du_meat_consumption_dataset',
            df_spark=meat_consumption,
            figure_size_height=2
        )
        self._capture_summary_dataset_to_image(
            df_spark=meat_consumption,
            dataset_name='Meat Consumption',
            name_file='du_meat_consumption_summary',
            figure_size_height=2
        )

        meat_consumption = rename_columns(
            df=meat_consumption,
            map_columns=COLUMN_RENAME_BY_DATASET.get(DataFramesCSV.MEAT_CONSUMPTION_CSV)
        )
        """
        self._merge_by_country_code(dataset_name='meat_consumption', target_dataset=meat_consumption)
        self._du_data_exploration_basics(
            dataset=meat_consumption,
            metric_name_plot='Kg./Year per Capita - Beef consumption',
            metric_label='beef',
            dataset_name='meat_beef',
            country_label='country_code',
        )
        self._du_data_exploration_basics(
            dataset=meat_consumption,
            metric_name_plot='Kg./Year per Capita - Poultry consumption',
            metric_label='poultry',
            dataset_name='meat_poultry',
            country_label='country_code',
        )
        self._du_data_exploration_basics(
            dataset=meat_consumption,
            metric_name_plot='Kg./Year per Capita - Sheep and Goat consumption',
            metric_label='sheep_and_goat',
            dataset_name='meat_sheep',
            country_label='country_code',
        )
        self._du_data_exploration_basics(
            dataset=meat_consumption,
            metric_name_plot='Kg./Year per Capita - Pig consumption',
            metric_label='pig',
            dataset_name='meat_pig',
            country_label='country_code',
        )
        self._du_data_exploration_basics(
            dataset=meat_consumption,
            metric_name_plot='Kg./Year per Capita - Fish and Seafood consumption',
            metric_label='fish_and_seafood',
            dataset_name='meat_fish_seafood',
            country_label='country_code',
        )
        """

    def _du_02_hunger(self):
        hunger = read_csv(DataFramesCSV.HUNGER_CSV)
        self._capture_get_dataframe_info_image(
            table_name='Hunger Dataset',
            name_file='du_hunger_dataset',
            df_spark=hunger,
            figure_size_height=1.5
        )
        
        self._capture_summary_dataset_to_image(
            df_spark=hunger,
            dataset_name='Hunger',
            name_file='du_hunger_summary',
            figure_size_height=1
        )
        """
        hunger = hunger[["Entity", "Year", "Global Hunger Index (2021)"]]
        hunger.rename(
            inplace=True,
            columns=COLUMN_RENAME_BY_DATASET.get(DataFramesCSV.HUNGER_CSV)
        )

        self._merge_by_country_name(dataset_name='hunger', target_dataset=hunger)
        self._du_data_exploration_basics(
            dataset=hunger,
            metric_name_plot='Global Hunger Index',
            metric_label='hunger_index',
            dataset_name='hunger',
            country_label='country',
        )
        """

    def _du_02_smoking(self):
        smoking = read_csv(DataFramesCSV.SMOKING_CSV)
        self._capture_get_dataframe_info_image(
            table_name='Smoking Dataset',
            name_file='du_smoking_dataset',
            df_spark=smoking,
            figure_size_height=1
        )
        
        self._capture_summary_dataset_to_image(
            df_spark=smoking,
            dataset_name='Smoking',
            name_file='du_smoking_summary',
            figure_size_height=1
        )
        """
        smoking.drop(inplace=True, columns=["Code"])
        smoking.rename(
            inplace=True,
            columns=COLUMN_RENAME_BY_DATASET.get(DataFramesCSV.SMOKING_CSV)
        )
        self._du_data_exploration_basics(
            dataset=smoking,
            metric_name_plot='Percentage Prevalence Tobacco use Adults',
            metric_label='prevalence_smoking',
            dataset_name='smoking',
        )
        self._merge_by_country_name(dataset_name='smoking', target_dataset=smoking)
        """

    def _du_02_alcohol_consumption(self):
        alcohol_consumption = read_csv(DataFramesCSV.ALCOHOL_CONSUMPTION_CSV)
        self._capture_get_dataframe_info_image(
            table_name='Alcohol Consumption Dataset',
            name_file='du_alcohol_consumption_dataset',
            df_spark=alcohol_consumption,
            figure_size_height=1.2
        )
        
        self._capture_summary_dataset_to_image(
            df_spark=alcohol_consumption,
            dataset_name='Alcohol Consumption',
            name_file='du_alcohol_summary',
            figure_size_height=1
        )
        """
        alcohol_consumption.drop(inplace=True, columns=["Code"])

        alcohol_consumption.rename(
            inplace=True,
            columns=COLUMN_RENAME_BY_DATASET.get(DataFramesCSV.ALCOHOL_CONSUMPTION_CSV)
        )
        self._du_data_exploration_basics(
            dataset=alcohol_consumption,
            metric_name_plot='Liters of Pure Alcohol per Capita',
            metric_label='liters_of_pure_alcohol_per_capita',
            dataset_name='alcohol',
        )
        self._merge_by_country_name(dataset_name='alcohol_consumption', target_dataset=alcohol_consumption)
        """
    
    def _du_02_obesity(self):
        obesity_dataset = read_csv(DataFramesCSV.WHO_OBESITY_CSV)

        self._capture_summary_dataset_to_image(
            df_spark=obesity_dataset,
            dataset_name='obesity',
            name_file='du_obesity_summary'
        )

        self._capture_get_dataframe_info_image(
            table_name='Obesity Dataset',
            name_file='du_obesity_dataset',
            df_spark=obesity_dataset,
            figure_size_height=3.3
        )
        """
        # Sex if filtered to match both
        obesity_dataset = obesity_dataset[obesity_dataset.Sex == "Both sexes"]
        obesity_dataset = obesity_dataset[["Numeric", "Countries, territories and areas", "WHO region", 'Year']]
        obesity_dataset.rename(
            inplace=True,
            columns=pj.COLUMN_RENAME_BY_DATASET.get(pj.DataFramesCSV.WHO_OBESITY_CSV)
        )

        self._du_data_exploration_basics(
            dataset=obesity_dataset,
            metric_name_plot='Percentage Obesity',
            metric_label='percentage_obesity',
            dataset_name='obesity',
        )
        """

    def _du_02_happiness(self):
        happiness_record = read_csv(DataFramesCSV.HAPPINESS_REPORT_CSV, sep=";")
        
        self._capture_get_dataframe_info_image(
            table_name='Happiness Report Dataset',
            name_file='du_happiness_dataset',
            df_spark=happiness_record,
            figure_size_height=2.5
        )

        self._capture_summary_dataset_to_image(
            df_spark=happiness_record,
            dataset_name='Happiness',
            name_file='du_happiness_summary',
            figure_size_height=2.3
        )
        """
        new_columns = [col.split(',', 2)[-1].strip() for col in happiness_record.columns]
        happiness_record.columns = new_columns

        happiness_record.rename(
            inplace=True,
            columns=COLUMN_RENAME_BY_DATASET.get(DataFramesCSV.HAPPINESS_REPORT_CSV)
        )
        self._merge_by_country_name(dataset_name='happiness', target_dataset=happiness_record)
        self._du_data_exploration_basics(
            dataset=happiness_record,
            metric_name_plot='Life Ladder',
            metric_label='life_ladder',
            dataset_name='happiness',
            country_label='country',
        )
        self._du_data_exploration_basics(
            dataset=happiness_record,
            metric_name_plot='Social Support',
            metric_label='social_support',
            dataset_name='happiness',
            country_label='country',
        )
        self._du_data_exploration_basics(
            dataset=happiness_record,
            metric_name_plot='Freedom to Make Life Choices',
            metric_label='freedom_to_make_life_choices',
            dataset_name='happiness',
            country_label='country',
        )
        self._du_data_exploration_basics(
            dataset=happiness_record,
            metric_name_plot='Generosity',
            metric_label='generosity',
            dataset_name='happiness',
            country_label='country',
        )
        self._du_data_exploration_basics(
            dataset=happiness_record,
            metric_name_plot='Perceptions of Corruption',
            metric_label='perceptions_of_corruption',
            dataset_name='happiness',
            country_label='country',
        )
        self._du_data_exploration_basics(
            dataset=happiness_record,
            metric_name_plot='Positive Affect',
            metric_label='positive_affect',
            dataset_name='happiness',
            country_label='country',
        )
        self._du_data_exploration_basics(
            dataset=happiness_record,
            metric_name_plot='Negative affect',
            metric_label='negative_affect',
            dataset_name='happiness',
            country_label='country',
        )
        """
    def _function_mapper_du_02(self, dataset: DataFramesCSV):
        mapper = {
            DataFramesCSV.MEAT_CONSUMPTION_CSV: self._du_02_meat_consumption,
            DataFramesCSV.WHO_OBESITY_CSV: self._du_02_obesity,
            DataFramesCSV.HAPPINESS_REPORT_CSV: self._du_02_happiness,
            DataFramesCSV.HUNGER_CSV: self._du_02_hunger,
            DataFramesCSV.SMOKING_CSV: self._du_02_smoking,
            DataFramesCSV.ALCOHOL_CONSUMPTION_CSV: self._du_02_alcohol_consumption,
        }

        return mapper.get(dataset)
    
    def _du_02_run_processes(self):
        self._function_mapper_du_02(dataset=DataFramesCSV.MEAT_CONSUMPTION_CSV)()
        self._function_mapper_du_02(dataset=DataFramesCSV.WHO_OBESITY_CSV)()
        self._function_mapper_du_02(dataset=DataFramesCSV.HAPPINESS_REPORT_CSV)()
        self._function_mapper_du_02(dataset=DataFramesCSV.HUNGER_CSV)()
        self._function_mapper_du_02(dataset=DataFramesCSV.SMOKING_CSV)()
        self._function_mapper_du_02(dataset=DataFramesCSV.ALCOHOL_CONSUMPTION_CSV)()

    def du_02(self):
        # pd.set_option('display.max_columns', None)
        # pd.set_option('display.max_rows', None)
        # pd.set_option('display.expand_frame_repr', False)

        self._du_02_country_master()
        self._du_02_run_processes()

        # self.get_crosstab_missing_countries(save_table=True).info()

In [210]:
manager = ProjectManager(
    generate_images_du_02=True
)
# manager.country_master.head(10)

In [211]:
manager.du_02()

In [197]:
df_spark = read_csv(DataFramesCSV.COUNTRY_MASTER_CSV)

In [110]:
from pyspark.sql import SparkSession, functions as F
# Get column names truncated to 30 characters
column_names = [col[:30] for col in df_spark.columns]

# Get data types of columns
column_types = [dtype for _, dtype in df_spark.dtypes]

# Get count for each column
row_count = df_spark.count()
column_counts = [row_count for _ in df_spark.columns]

# Get null count for each column
column_null_counts = df_spark.agg(*[F.sum(F.when(F.col(c).isNull(), 1).otherwise(0)).alias(c) for c in df_spark.columns]).collect()[0]

info_object = {
    'columns': column_names,
    'dtypes': column_types,
    'count': column_counts,
    'isnull': column_null_counts
}

print(info_object)

{'columns': ['name', 'alpha-2', 'alpha-3', 'country-code', 'iso_3166-2', 'region', 'sub-region', 'intermediate-region', 'region-code', 'sub-region-code', 'intermediate-region-code'], 'dtypes': ['string', 'string', 'string', 'int', 'string', 'string', 'string', 'string', 'int', 'int', 'int'], 'count': [249, 249, 249, 249, 249, 249, 249, 249, 249, 249, 249], 'isnull': Row(name=0, alpha-2=0, alpha-3=0, country-code=0, iso_3166-2=0, region=1, sub-region=1, intermediate-region=142, region-code=1, sub-region-code=1, intermediate-region-code=142)}
