In [None]:
import glob
import ipywidgets
import os

import cmasher as cmr
import matplotlib.pyplot as plt
import pandas as pd

from functools import reduce
from mpl_toolkits.basemap import Basemap

In [None]:
# Data directory, located one level above the current working directory.
PATH_DATA = os.path.join(os.pardir, 'data')

# Folder holding the return-level CSV files; the figures produced further
# down in this notebook are written to the same folder.
_RESULTS_SUBDIRS = ('models', 'mev_nn', 'final_ensemble', 'results')
PATH_RETURNLVLS = os.path.join(PATH_DATA, *_RESULTS_SUBDIRS)

In [None]:
## Select the file(s) for the desired return level(s)
# Glob the results folder once and derive the display names from that single
# listing, so FILES and FILE_NAMES always describe the same set of files.
FILES = glob.glob(os.path.join(PATH_RETURNLVLS, '*.csv'))
FILE_NAMES = sorted(os.path.basename(path) for path in FILES)

# Multi-select widget; every CSV file found is pre-selected.
csv_files = ipywidgets.SelectMultiple(
    options=FILE_NAMES,
    value=FILE_NAMES,
    description='Choose return period:',
    disabled=False
)
csv_files

In [None]:
# The column chooser is populated from the first selected file, so fail with a
# clear message instead of a bare IndexError when nothing is selected.
if not csv_files.value:
    raise ValueError(
        f"No CSV file selected (or none found in {PATH_RETURNLVLS}); "
        "select at least one file before choosing columns."
    )

# Only the header is needed here, so read zero data rows (df_temp is an empty
# frame that just carries the column names).
# NOTE(review): assumes every selected file has the same columns as the first
# one — confirm.
df_temp = pd.read_csv(os.path.join(PATH_RETURNLVLS, csv_files.value[0]), nrows=0)

# Every column except the coordinates can be selected.
avail_cols = sorted(set(df_temp.columns) - {'lon', 'lat'})

# Multi-select widget; all available columns are pre-selected.
sel_cols = ipywidgets.SelectMultiple(
    options=avail_cols,
    value=avail_cols,
    description='Choose columns:',
    disabled=False
)

sel_cols

In [None]:
# One frame per selected file: the coordinates plus the chosen value columns.
df_append = []

for file in csv_files.value:
    df_file = pd.read_csv(os.path.join(PATH_RETURNLVLS, file))

    # Prefix each value column with the file name (without extension) so that
    # columns coming from different files stay distinguishable when the frames
    # are combined.
    file_stem = os.path.splitext(file)[0]

    # Scale all chosen columns in one vectorised step on a fresh frame rather
    # than assigning into a column subset (which can trigger pandas'
    # SettingWithCopyWarning) and renaming in place once per column.
    # NOTE(review): the division by 10 looks like a unit conversion — confirm.
    df_values = df_file[list(sel_cols.value)].div(10).add_prefix(f"{file_stem}_")

    df_append.append(pd.concat([df_file[['lon', 'lat']], df_values], axis=1))

In [None]:
def _merge_on_coords(left, right):
    """Outer-join two frames on their shared 'lon'/'lat' columns."""
    return pd.merge(left, right, on=['lon', 'lat'], how='outer')


# Fold the per-file frames into a single table, pairwise from left to right.
return_data = reduce(_merge_on_coords, df_append)

In [None]:
# Show the merged table as this cell's output.
return_data

In [None]:
# Sub-range of the 'plasma' colormap used for every figure.
cmap = cmr.get_sub_cmap('plasma', 0.05, 0.9)

# Initialize the Basemap once: the projection and extent do not depend on the
# plotted column, and rebuilding a full-resolution map for every figure is
# needlessly slow. Basemap draws onto the current axes when no `ax` is given,
# so the same instance can serve each new figure.
m = Basemap(projection = 'lcc', resolution='f', lat_0=47.5, lon_0=13.3, width=0.6E6, height=3.7E5)

# One map per value column; the coordinate columns are skipped.
for column in return_data.columns:
    if column in ['lon', 'lat']:
        continue

    print(f"Processing {column}", flush=True)

    # Columns with "std" in their name get a colour scale from 0 up to the
    # column maximum (but at least 2); all others use a fixed 1..6 scale.
    if "std" in column:
        vmin = 0
        vmax = max(return_data[column].max(), 2)
    else:
        vmin = 1
        vmax = 6

    fig = plt.figure(figsize=(15, 10))

    m.drawmapboundary()
    m.drawcountries(linewidth=2)

    points = m.scatter(return_data['lon'], return_data['lat'], c=return_data[column], cmap=cmap, marker=',', s=0.7, latlon=True, vmin=vmin, vmax=vmax)

    # Pass the scatter explicitly instead of relying on the implicit
    # "current mappable" of the pyplot state machine.
    plt.colorbar(points, label='MEHS', extend="max")

    # Save next to the input CSV files, as both PDF and PNG.
    fig.savefig(os.path.join(PATH_RETURNLVLS, f"hailriskat_{column}.pdf"), bbox_inches="tight")
    fig.savefig(os.path.join(PATH_RETURNLVLS, f"hailriskat_{column}.png"), bbox_inches="tight")

    # Close this specific figure so figures do not pile up in memory.
    plt.close(fig)