In [1]:
import itertools
import os

import pandas as pd
from scipy.stats import mannwhitneyu
import numpy as np

import utils

In [2]:
# Input parameters: three-letter country codes, DEM products to compare,
# and the feature types (basin / stream) to analyse.
country_codes = [
    'ESP',
    'EST',
    'ETH',
    'USA',
]
dem_names = [
    'AW3D30',
    'HydroSHEDS',
    'MERIT',
    'NASADEM',
    'TanDEM',
]
feature_types = [
    'basin',
    'stream',
]

In [3]:
stat_name = "forest_pct"  # statistic whose classes are compared (presumably forest cover %, TODO confirm)

In [4]:
# Pairwise Mann-Whitney U tests: for every DEM, compare the `dist_to_ref`
# distributions of each pair of `{stat_name}` classes. One Excel workbook is
# written per feature type, with one sheet per catchment (country).
out_dir = 'D:/dem_comparison/data'  # NOTE(review): absolute local path; consider moving to the config cell
alpha = 0.05  # significance level for the two-sided test
class_col = f'{stat_name}_class'
out_df_columns = [
    'catchment_name',
    'country_code',
    'dem_name',
    'feature_type',
    'stat_name',
    'class_1',
    'count_1',
    'class_2',
    'count_2',
    'U',
    'p',
    'significant'
]

for feature_type in feature_types:
    out_fp = f'{out_dir}/mannwhitneyu_{stat_name}_{feature_type}.xlsx'
    # Start from a fresh workbook; one sheet per country is appended below.
    if os.path.exists(out_fp):
        os.remove(out_fp)
    for country_code in country_codes:
        # Loop-invariant per country: look it up once. This also guarantees the
        # sheet name is defined even when no test rows are produced.
        catchment_name = utils.get_catchment_name(country_code)
        merged = utils.merge_stats_for_plot(country_code, dem_names, feature_type, stat_name)
        merged[class_col] = merged.apply(utils.get_class_func(stat_name), axis=1)
        stat_classes = merged[class_col].unique()
        row_list = []
        for dem_name in dem_names:
            # Filter by DEM once, instead of once per class pair.
            dem_subset = merged.loc[merged['dem_name'] == dem_name]
            for class_1, class_2 in itertools.combinations(stat_classes, 2):
                x = dem_subset.loc[dem_subset[class_col] == class_1, 'dist_to_ref'].to_list()
                y = dem_subset.loc[dem_subset[class_col] == class_2, 'dist_to_ref'].to_list()
                try:
                    # Explicit two-sided test: do not rely on the SciPy version's default.
                    U1, p_raw = mannwhitneyu(x, y, alternative='two-sided')
                except ValueError:
                    # Test could not be computed (e.g. an empty sample).
                    U1, p_raw = np.nan, np.nan
                # Decide significance on the UNROUNDED p-value. Otherwise
                # p=0.0496 rounds to 0.05 and is misclassified. An untestable
                # pair (NaN p) stays NaN instead of being reported as 0.
                if np.isnan(p_raw):
                    significance = np.nan
                else:
                    significance = int(p_raw < alpha)
                U1 = round(U1, 1)
                p = round(p_raw, 3)
                row_list.append((
                    catchment_name,
                    country_code,
                    dem_name,
                    feature_type,
                    stat_name,
                    class_1,
                    len(x),
                    class_2,
                    len(y),
                    U1,
                    p,
                    significance
                ))
        out_df = pd.DataFrame(row_list, columns=out_df_columns)
        display(out_df)
        if not os.path.exists(out_fp):
            out_df.to_excel(out_fp, sheet_name=catchment_name, index=False)
        else:
            with pd.ExcelWriter(out_fp, mode='a') as writer:
                out_df.to_excel(writer, sheet_name=catchment_name, index=False)