In [None]:
import glob
import os
import ipywidgets

import geopandas as gpd
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

from functools import reduce
from mpl_toolkits.basemap import Basemap
from shapely.geometry import Point

In [None]:
# When True, the data-quality and elevation grids are loaded and an additional
# "_hatched" variant of every continuous map is written with those cells overlaid.
QUALITY_HATCH = False
# Grid cells whose quality value is missing or below this threshold are selected
# for the hatch overlay (only used when QUALITY_HATCH is True).
DATA_QUALITY_THRESHOLD = 0.4

In [None]:
# Locations of the input data, relative to the notebook's working directory.
PATH_DATA = os.path.join('..', '..', 'data')
PATH_DATA_SOURCE = os.path.join(PATH_DATA, 'raw_data')
PATH_MODELS = os.path.join(PATH_DATA, 'models', 'mev_nn')
PATH_SHAPEFILE = os.path.join(PATH_DATA_SOURCE, 'geodata', 'AUT_adm0.shp')

# Every sub-directory of PATH_MODELS is offered as a training cycle.
# os.listdir also returns plain files, which would otherwise show up as
# selectable cycles that contain no return-level CSVs, so keep directories only.
FOLDERS = [
    entry
    for entry in os.listdir(PATH_MODELS)
    if os.path.isdir(os.path.join(PATH_MODELS, entry))
]

In [None]:
# Order the training cycles in descending name order, so the entry that sorts
# last alphabetically becomes the first (and preselected) option of the widget.
FOLDERS = sorted(FOLDERS, reverse=True)

In [None]:
## Select the desired training cycle
# The widget lists every entry of FOLDERS and preselects the first one.
# Later cells read the current choice through `cycle.value`.
cycle = ipywidgets.Select(
    description='Choose Training cycle:',
    options=FOLDERS,
    value=FOLDERS[0],
    disabled=False,
)

# Bare expression as the last statement so that Jupyter renders the widget.
cycle

In [None]:
## Select the file(s) for the desired return level(s)
# Glob once and derive the base names from that single result, so FILES and
# FILE_NAMES always describe the same set of files. glob returns matches in
# arbitrary order; sorting keeps the widget options (and the column order of
# the merged table built from them) reproducible between runs.
FILES = sorted(glob.glob(os.path.join(PATH_MODELS, cycle.value, 'returns*.csv')))
FILE_NAMES = [os.path.basename(path) for path in FILES]

# All return-level files are preselected; later cells read `csv_files.value`.
csv_files = ipywidgets.SelectMultiple(
    options=FILE_NAMES,
    value=FILE_NAMES,
    description='Choose return period:',
    disabled=False
)
csv_files

In [None]:
# Load every selected return-level CSV into its own frame. The 'target' column
# is divided by 10 (presumably a mm -> cm conversion -- TODO confirm) and then
# renamed to the file name, so each file contributes one uniquely named value
# column to the merged table built from this list.
df_append = [
    pd.read_csv(os.path.join(PATH_MODELS, cycle.value, file))
    .assign(target=lambda frame: frame['target'] / 10)
    .rename(columns={'target': file})
    for file in csv_files.value
]

In [None]:
def _merge_on_coordinates(left, right):
    """Outer-join two return-level frames on their 'lon'/'lat' grid coordinates."""
    return pd.merge(left, right, how='outer', on=['lon', 'lat'])


# Fold all per-file frames into one wide table: one row per grid point and one
# value column per selected file. The outer join keeps points that are missing
# from some of the files (their value is NaN there).
return_data = reduce(_merge_on_coordinates, df_append)

In [None]:
# Display the merged table (lon, lat and one column per selected return-level file).
return_data

In [None]:
if QUALITY_HATCH:
    # Locate the data-quality grid. glob returns matches in arbitrary order, so
    # sort to make the choice of file reproducible when several are present.
    quality_reports = sorted(glob.glob(os.path.join(PATH_DATA_SOURCE, 'ZAMG', 'dataquality_map', '2023_12_04', '*.h5')))
    if not quality_reports:
        raise FileNotFoundError(
            "No data-quality .h5 file found in "
            + os.path.join(PATH_DATA_SOURCE, 'ZAMG', 'dataquality_map', '2023_12_04')
        )

    # Open read-only and close the handle as soon as the arrays are in memory;
    # `[:]` copies each dataset into a NumPy array, so nothing below needs the
    # file to stay open.
    with h5py.File(quality_reports[0], 'r') as quality_file:
        data_lat = quality_file["atnt_grid_lat"][:]
        data_lon = quality_file["atnt_grid_lon"][:]
        data_quality = quality_file['data'][:]

    # Attach the 2-D lon/lat grids as coordinates of the quality values.
    data_xr_quality = xr.DataArray(
                            data_quality,
                            coords=dict(
                                lon=(["y", "x"], data_lon),
                                lat=(["y", "x"], data_lat),
                            ),
                            dims=["y", "x"],
                            name = 'quality'
                        )

    # One row per grid cell with columns lon, lat and quality.
    data_pandas_quality = data_xr_quality.to_dataframe()
    # Cells to hatch: quality is missing or below the configured threshold.
    data_schraffur_quality = data_pandas_quality[(data_pandas_quality['quality'].isna()) | (data_pandas_quality['quality'] < DATA_QUALITY_THRESHOLD)]

In [None]:
if QUALITY_HATCH:
    # NOTE(review): the elevation grid is read from the 'dataquality_map' folder
    # (a different date sub-folder than the quality grid) -- confirm that this
    # is the intended location.
    # glob returns matches in arbitrary order, so sort to make the choice of
    # file reproducible when several are present.
    elevation_reports = sorted(glob.glob(os.path.join(PATH_DATA_SOURCE, 'ZAMG', 'dataquality_map', '2023_11_20', '*.h5')))
    if not elevation_reports:
        raise FileNotFoundError(
            "No elevation .h5 file found in "
            + os.path.join(PATH_DATA_SOURCE, 'ZAMG', 'dataquality_map', '2023_11_20')
        )

    # Open read-only and close the handle as soon as the arrays are in memory;
    # `[:]` copies each dataset into a NumPy array, so nothing below needs the
    # file to stay open.
    with h5py.File(elevation_reports[0], 'r') as elevation_file:
        data_lat_elevation = elevation_file["atnt_grid_lat"][:]
        data_lon_elevation = elevation_file["atnt_grid_lon"][:]
        data_elevation = elevation_file['data'][:]

    # Attach the 2-D lon/lat grids as coordinates of the elevation values.
    data_xr_elevation = xr.DataArray(
                            data_elevation,
                            coords=dict(
                                lon=(["y", "x"], data_lon_elevation),
                                lat=(["y", "x"], data_lat_elevation),
                            ),
                            dims=["y", "x"],
                            name = 'elevation'
                        )

    # One row per grid cell with columns lon, lat and elevation.
    data_pandas_elevation = data_xr_elevation.to_dataframe()
    # Cells to hatch: no elevation value available.
    data_schraffur_elevation = data_pandas_elevation[(data_pandas_elevation['elevation'].isna())]

In [None]:
# Build one point geometry per grid cell from its lon/lat columns.
# points_from_xy creates the geometries in a single vectorised call instead of
# constructing one shapely Point per row in a Python loop.
geometry = gpd.points_from_xy(return_data['lon'], return_data['lat'])
# The coordinates are treated as WGS84 longitude/latitude.
gdf = gpd.GeoDataFrame(return_data, geometry=geometry, crs="EPSG:4326")

# Country outline used for clipping.
austria = gpd.read_file(PATH_SHAPEFILE)
# Inner spatial join with 'within': keeps only the grid points that lie inside
# a shapefile geometry (and appends the shapefile's attribute columns).
return_data_clipped = gpd.sjoin(gdf, austria, how="inner", predicate='within')

In [None]:
# Render one continuous-colour map per selected return-level file: once for the
# full grid and once for the points clipped to Austria. Each map is written as
# PDF and PNG into the folder of the chosen training cycle.
for file in csv_files.value:
    print(f"Processing {file}", flush=True)

    for clip_to_austria in (False, True):
        if clip_to_austria:
            infix_clipped = "_clipped_AUT"
            return_data_tmp = return_data_clipped
        else:
            infix_clipped = ""
            return_data_tmp = return_data

        # Common path prefix of every output file of this map.
        output_stem = os.path.join(PATH_MODELS, cycle.value, f"hailriskat_{file}{infix_clipped}")

        fig = plt.figure(figsize=(15, 10))

        # Basemap draws onto the current axes of the figure created above.
        m = Basemap(
            projection='lcc',
            resolution='f',
            lat_0=47.5,
            lon_0=13.3,
            width=0.6E6,
            height=3.7E5,
        )
        m.drawmapboundary()
        m.drawcountries(linewidth=2)

        # One pixel-sized marker per grid point, coloured by the return level
        # of the current file on a fixed 1..6 colour scale.
        m.scatter(
            return_data_tmp['lon'],
            return_data_tmp['lat'],
            c=return_data_tmp[file],
            cmap="jet",
            marker=',',
            s=0.7,
            latlon=True,
            vmin=1,
            vmax=6,
        )

        plt.colorbar(label='MEHS', extend="max")

        for extension in (".pdf", ".png"):
            plt.savefig(output_stem + extension, bbox_inches="tight")

        if QUALITY_HATCH:
            print(f"Processing {file} (hatched)", flush=True)
            # Overlay the low-quality cells and the cells without elevation on
            # the same figure, then save it again as the "_hatched" variant.
            for hatch_points in (data_schraffur_quality, data_schraffur_elevation):
                m.scatter(
                    hatch_points['lon'],
                    hatch_points['lat'],
                    s=0.01,
                    edgecolor='black',
                    linewidth=3,
                    latlon=True,
                    facecolor='black',
                    hatch='x',
                )
            for extension in (".pdf", ".png"):
                plt.savefig(output_stem + "_hatched" + extension, bbox_inches="tight")

        # Release the figure; many figures are created in this loop.
        plt.close()

In [None]:
# Colour per hail-size class. The keys double as legend labels and use interval
# notation: "(a, b]" means a < value <= b.
colors = {
        '(0, 2] cm' : '#C875C4',
        '(2, 3] cm' : '#579DF7', 
        '(3, 4] cm' : '#A4F885',
        '(4, 5] cm' : '#F3AE33',
        '> 5cm' : '#8F383F',
}

# (lower, upper] bounds of each class, in the same order as `colors`.
# The bounds are right-closed so that the classification agrees with the legend
# labels: a value of exactly 2 belongs to '(0, 2] cm', and '> 5cm' starts
# strictly above 5.
class_bounds = [(0, 2), (2, 3), (3, 4), (4, 5), (5, np.inf)]

# Render one classified map per selected return-level file: once for the full
# grid and once for the points clipped to Austria. Each map is written as PDF
# and PNG into the folder of the chosen training cycle.
for file in csv_files.value:
    print(f"Processing {file}", flush=True)

    for clip_to_austria in [False, True]:
        infix_clipped = "_clipped_AUT" if clip_to_austria else ""

        source_data = return_data_clipped if clip_to_austria else return_data
        values = source_data[file]

        # Label every grid point with its class. Points that fall into no class
        # (NaN from the outer merge, or values <= 0) get 'ERROR' and are dropped.
        cat_string = np.select(
            [(values > lower) & (values <= upper) for lower, upper in class_bounds],
            list(colors),
            default='ERROR'
        )
        # assign/query return new frames, so the shared source tables stay untouched.
        return_data_tmp = source_data.assign(cat_string=cat_string).query("cat_string != 'ERROR'")

        fig = plt.figure(figsize=(15, 10))

        # Basemap draws onto the current axes of the figure created above.
        m = Basemap(projection='lcc', resolution='f', lat_0=47.5, lon_0=13.3, width=0.6E6, height=3.7E5)
        m.drawmapboundary()
        m.drawcountries(linewidth=2)

        # One scatter call per class so that each class gets its own legend
        # entry; classes without any point are skipped.
        for cat, color in colors.items():
            return_data_cat = return_data_tmp[return_data_tmp['cat_string'] == cat]
            if return_data_cat.empty:
                continue

            m.scatter(return_data_cat['lon'], return_data_cat['lat'], c=color, latlon=True, label=cat, marker=',', s=0.7, alpha=1)

        # Figure-level legend, placed at the right edge; markerscale enlarges
        # the pixel-sized markers so they are visible in the legend.
        fig.legend(loc='center left', bbox_to_anchor=(0.9, 0.5), markerscale=10)

        plt.savefig(os.path.join(PATH_MODELS, cycle.value, f"hailriskat_{file}{infix_clipped}_classification.pdf"), bbox_inches="tight")
        plt.savefig(os.path.join(PATH_MODELS, cycle.value, f"hailriskat_{file}{infix_clipped}_classification.png"), bbox_inches="tight")

        # Release the figure; many figures are created in this loop.
        plt.close(fig)