# Import libraries

In [1]:
import numpy as np
from scipy.spatial import KDTree
import os
import h5py
import pandas as pd
import xarray as xr
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.colors import ListedColormap
import matplotlib as mpl
from tqdm import tqdm

# Set file paths

In [10]:
# Root directory holding the GMET station CSV files
gmet_root_dir = r'C:\Users\c.kwa\Desktop\meteosat_retrieval\data_downloads\GMET\GMet_AWSdata\GMet_AWSdata'
# Root directory holding the model output files (one file per time step)
model_root_dir = r'D:\Ghana\Output_data\experiments_adapted131\repotest\testOutput' #r'Z:\cluster_projects\ro\1149_10\earthformer-multisource-to-inca\experiments_adapted130\repotest\testOutput'
# Alternative model output directories, currently disabled:
#model_root_dir = r'D:\Ghana\Output_data\experiments_adapted130\repotest\testOutput'
#model_root_dir = r"D:/Ghana/Output_data/Earthformer_corrected/2022"

# Set threshold list to analyse

In [3]:
# Thresholds to analyse; presumably precipitation amounts -- units are not shown here, TODO confirm
thresholds_list = [0.6, 1.7, 2.7, 5, 8.6, 15]

# Import GMET station data

In [5]:
# Read every CSV file found in the GMET directory.
station_data_list_GMET = []
station_name_list_files = []
for station in os.listdir(gmet_root_dir):
    if station.endswith('.csv'):
        station_file = os.path.join(gmet_root_dir, station)
        station_data = pd.read_csv(station_file)
        station_data_list_GMET.append(station_data)
        station_name_list_files.append(station.split('.')[0])

# NOTE(review): the station metadata table is selected purely by its position
# (index 7) in the os.listdir() order, and the remaining frames are then paired
# with the 'STN' names by position as well. Both steps silently assume a fixed
# directory listing order -- confirm against the actual folder contents.
# station_name_list_files still contains the metadata file's name, so it is
# not aligned with station_data_list_GMET after the pop.
stations_meta_data_GMET = station_data_list_GMET.pop(7)
station_name_list = stations_meta_data_GMET['STN'].tolist()

station_data_gmet_dict = dict(zip(station_name_list, station_data_list_GMET))

for station_name, station_frame in station_data_gmet_dict.items():
    # Combine the 'Date' and 'Time' columns into a single datetime
    timestamps = pd.to_datetime(
        station_frame['Date'].astype(str) + ' ' + station_frame['Time'].astype(str),
        format='%m/%d/%Y %I:%M:%S %p'
    )

    # Index by timestamp and drop the original Date and Time columns
    station_frame = (
        station_frame
        .assign(Timestamp=timestamps)
        .set_index('Timestamp')
        .drop(['Date', 'Time'], axis=1)
    )

    # Ensure every remaining column is numeric (unparseable values become NaN)
    station_frame = station_frame.apply(pd.to_numeric, errors='coerce')

    # Accumulate over 30 minutes, exactly once per station and only after all
    # columns are numeric. Resampling inside the per-column loop would
    # re-resample the already 30-minute frame, and with min_count=2 every
    # single-row bin would then become NaN.
    station_data_gmet_dict[station_name] = station_frame.resample(
        '30min', label='right', closed='right'
    ).sum(min_count=2)

# Get the longitude/latitude grid

In [6]:
# Open a single SEVIRI file; only its coordinate vectors are needed here
file_name = 'MSG4-SEVI-MSG15-0100-NA-20200501001242.772000000Z-NA.hdf5'
file_path = fr'C:\Users\c.kwa\Desktop\meteosat_retrieval\SEVIRI_retrieval\Test_batch\Native_to_h5\hdf5\2020\05\{file_name}'
seviri_ds = xr.open_dataset(file_path, engine='netcdf4')

# 1-D coordinate vectors as NumPy arrays: 'x' is used as longitude and
# 'y' (in reversed order) as latitude
grid_lon = seviri_ds['x'].values
grid_lat = seviri_ds['y'].values[::-1]

# Expand the two vectors into 2-D longitude/latitude grids
grid_lon_2d, grid_lat_2d = np.meshgrid(grid_lon, grid_lat)

# Flip the 2-D latitude grid along its first axis.
# NOTE(review): together with the reversal of the 1-D vector above, this puts
# the rows of grid_lat_2d back in the file's original 'y' order -- confirm
# that this is the intended orientation.
grid_lat_2d = grid_lat_2d[::-1]

# Show the dataset summary as the cell output
seviri_ds

# Define function to find the nearest grid cell to each station

In [7]:
def nearest_grid_to_stations(grid_lat, grid_lon, station_lat, station_lon):
    """
    Map stations to nearest grid cells using KDTree.

    The nearest cell is the one with the smallest Euclidean distance in
    (lat, lon) coordinate space (the KDTree default metric); no
    great-circle correction is applied.

    Parameters
    ----------
    grid_lat : numpy.ndarray
        Latitudes of grid cells (any shape; flattened in C order).
    grid_lon : numpy.ndarray
        Longitudes of grid cells, same shape as ``grid_lat``.
    station_lat : array-like
        Latitudes of stations (e.g. numpy array or pandas Series).
    station_lon : array-like
        Longitudes of stations, same length as ``station_lat``.

    Returns
    -------
    indices : numpy.ndarray
        For each station, the index of the nearest grid cell in the
        flattened (``ravel()``-ed) grid. Empty when no stations are given.
    """
    # Build the (n, 2) coordinate tables directly in NumPy instead of going
    # through a Python list of tuples, which is slow for large grids.
    grid_points = np.column_stack(
        (np.asarray(grid_lat).ravel(), np.asarray(grid_lon).ravel())
    )
    station_points = np.column_stack(
        (np.asarray(station_lat).ravel(), np.asarray(station_lon).ravel())
    )

    # Nothing to look up: return an empty index array rather than querying
    if station_points.shape[0] == 0:
        return np.empty(0, dtype=np.intp)

    tree = KDTree(grid_points)
    _, indices = tree.query(station_points)
    return indices

In [11]:
# Flat grid index of the cell nearest to each GMET station
indices = nearest_grid_to_stations(grid_lat_2d, grid_lon_2d, stations_meta_data_GMET['LAT'], stations_meta_data_GMET['LON'])

# List the model output directory once so that the array sizes and the loop
# below are guaranteed to refer to the same set of files.
model_files = os.listdir(model_root_dir)
n_files = len(model_files)
n_stations = len(station_name_list)

# One row per file, one column per station. Rows start as NaN so that a file
# that cannot be processed never shows up as a (valid-looking) zero.
dt_list = np.empty(n_files, dtype=object)
target_station_cells = np.full((n_files, n_stations), np.nan)
pred_station_cells = np.full((n_files, n_stations), np.nan)

# Wrap the loop with tqdm for progress bar
for i, file in enumerate(tqdm(model_files, desc="Processing files")):
    try:
        # The timestamp is the part of the file name between the first '_' and the first '.' after it
        time_str = file.split('_')[1].split('.')[0]
        # Rearrange timestamp to match nearest timestamp of GMET stations (adjust as needed)
        dt_list[i] = datetime.strptime(time_str, "%Y%m%d%H%M%S") + timedelta(seconds=1)
        with h5py.File(os.path.join(model_root_dir, file), 'r') as f:
            # Convert from mm/h to mm/30 min
            target = f['y'][:] / 2
            pred = f['y_hat'][:] / 2

        # Drop singleton dimensions, then flip the data along the first (y) axis
        target = np.flipud(np.squeeze(target))
        pred = np.flipud(np.squeeze(pred))

        # Pick the value at every station's grid cell in one vectorised lookup
        target_station_cells[i, :] = target.ravel()[indices]
        pred_station_cells[i, :] = pred.ravel()[indices]

    except Exception as e:
        # Best effort: report the problem and mark the whole row as missing,
        # including any part of it that was written before the failure.
        print(f"Error occurred while processing file {file}: {e}")
        target_station_cells[i, :] = np.nan
        pred_station_cells[i, :] = np.nan

Processing files: 100%|██████████████████████████████████████████████████████████████| 728/728 [00:11<00:00, 63.83it/s]


In [12]:
# Shared time index built from the per-file timestamps
datetime_index = pd.DatetimeIndex(dt_list)

# Target values at the station cells, ordered chronologically
# (CSV export of the targets is currently disabled:
#  df_target.to_csv('df_target_gmet_balanced_2022.csv'))
df_target = pd.DataFrame(
    target_station_cells,
    index=datetime_index,
    columns=station_name_list,
).sort_index()

# Predicted values at the station cells, ordered chronologically and saved to CSV
df_pred = pd.DataFrame(
    pred_station_cells,
    index=datetime_index,
    columns=station_name_list,
).sort_index()
df_pred.to_csv('df_output_gmet_balanced_test.csv')