# Import libraries

In [1]:
import numpy as np
from scipy.spatial import KDTree
import os
import h5py
import pandas as pd
import xarray as xr
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.colors import ListedColormap
import matplotlib as mpl
from tqdm import tqdm

# Set file paths

In [9]:
# Root folders used by the rest of the notebook:
#   tahmo_root_dir - TAHMO export folder; every '.csv' file in it is read below
#   model_root_dir - folder with the model output files that are opened with h5py below
tahmo_root_dir = r'C:\Users\c.kwa\Desktop\meteosat_retrieval\data_downloads\TAHMO\TAHMO_export_6704eaa37e81da18c0b7e245'
# Alternative model output location, currently disabled:
#model_root_dir = r'Z:\cluster_projects\ro\1149_10\earthformer-multisource-to-inca\experiments_adapted130\repotest\testOutput'
model_root_dir = r'D:\Ghana\Output_data\experiments_adapted121\repotest\testOutput'

# Set threshold list to analyse

In [3]:
# Thresholds to analyse. Units are not stated in this notebook section;
# presumably precipitation amounts on the same scale as the data below - TODO confirm.
thresholds_list = [0.6, 1.7, 2.7, 5, 8.6, 15]

# Import TAHMO station data

In [4]:
# Read every CSV file of the TAHMO export into memory. File stems and the
# corresponding DataFrames are kept in two parallel lists.
station_name_list = []
station_data_list_TAHMO = []
for station in os.listdir(tahmo_root_dir):
    if not station.endswith('.csv'):
        continue
    station_file = os.path.join(tahmo_root_dir, station)
    station_data = pd.read_csv(station_file)
    station_name_list.append(station.split('.')[0])
    station_data_list_TAHMO.append(station_data)

# The first two entries are taken to be the sensor and station metadata tables
# and are removed from the lists, so only per-station data remains.
# NOTE(review): this relies on the order returned by os.listdir - confirm that
# the two metadata files are always listed first on the target system.
sensors_meta_data_tahmo, stations_meta_data_tahmo = (
    station_data_list_TAHMO.pop(0),
    station_data_list_TAHMO.pop(0),
)
sensors_meta_data_tahmo_name, stations_meta_data_tahmo_name = (
    station_name_list.pop(0),
    station_name_list.pop(0),
)

# Map station name (file stem) -> raw station DataFrame
station_data_tahmo_dict = {
    name: frame for name, frame in zip(station_name_list, station_data_list_TAHMO)
}

# Reduce every station table to a single 'precipitation (mm)' column and
# aggregate it to 30-minute totals.
for station_name in station_data_tahmo_dict:
    station_df = station_data_tahmo_dict[station_name]

    if station_df.shape[1] == 3:
        # Two value columns (positions 1 and 2): take column 2 and fill its gaps
        # from column 1, mirroring the priority used in the 4-column branch.
        # (Previously column 2 was combined with itself, which silently
        # discarded every value that was only present in column 1.)
        station_df['precipitation (mm)'] = station_df.iloc[:, 2].combine_first(station_df.iloc[:, 1])

        # Drop the source columns now that they are merged
        station_df = station_df.drop(station_df.columns[[1, 2]], axis=1)

    elif station_df.shape[1] == 4:
        # Three value columns: column 2 first, then gaps from column 1, then column 3
        station_df['precipitation (mm)'] = (
            station_df.iloc[:, 2]
            .combine_first(station_df.iloc[:, 1])
            .combine_first(station_df.iloc[:, 3])
        )

        # Drop the source columns now that they are merged
        station_df = station_df.drop(station_df.columns[[1, 2, 3]], axis=1)

    if 'timestamp' in station_df.columns:
        # Parse the timestamps and use them as the index (required by resample)
        station_df['timestamp'] = pd.to_datetime(station_df['timestamp'])
        station_df = station_df.set_index('timestamp')

    # Accumulate over 30 minutes; min_count=2 yields NaN for any window that
    # contains fewer than two valid values instead of a misleading 0.
    station_data_tahmo_dict[station_name] = station_df.resample('30min').sum(min_count=2)

display(stations_meta_data_tahmo)

Unnamed: 0,station code,name,country,installation height (m),latitude,longitude,elevation (m),timezone
0,TA00005,Asankragwa SHS,GH,2.0,5.807731,-2.426395,125.1,Africa/Accra
1,TA00007,Nana Yaa Kesse SHS Duayaw Nkwanta,GH,2.0,7.188273,-2.097477,341.1,Africa/Accra
2,TA00010,Chiraa SHS,GH,2.0,7.389595,-2.185991,337.1,Africa/Accra
3,TA00016,"Accra Academy School, Accra",GH,2.0,5.573104,-0.2445,32.4,Africa/Accra
4,TA00045,"Asesewaa Senior High School, Asesewaa",GH,2.0,6.400626,-0.146577,372.3,Africa/Accra
5,TA00113,Nkwanta SHS,GH,2.0,8.271124,0.515265,213.7,Africa/Accra
6,TA00116,Amedzofe Technical Institute,GH,2.0,6.845815,0.440698,731.8,Africa/Accra
7,TA00117,Keta SHS,GH,2.0,5.895083,0.989567,10.0,Africa/Accra
8,TA00118,Tema Secondary School,GH,2.0,5.641413,-0.01187,18.4,Africa/Accra
9,TA00120,Nkroful Agric SHS,GH,2.0,4.971861,-2.322676,28.0,Africa/Accra


# Get the longitude/latitude grid

In [5]:
# Open a single SEVIRI file; only its coordinate axes are used here
file_name = 'MSG4-SEVI-MSG15-0100-NA-20200501001242.772000000Z-NA.hdf5'
file_path = fr'C:\Users\c.kwa\Desktop\meteosat_retrieval\SEVIRI_retrieval\Test_batch\Native_to_h5\hdf5\2020\05\{file_name}'
seviri_ds = xr.open_dataset(file_path, engine = 'netcdf4')

# 1-D coordinate vectors as NumPy arrays. The 'y' axis is treated as latitude
# and stored reversed; the 'x' axis is treated as longitude.
grid_lat = seviri_ds['y'].values[::-1]
grid_lon = seviri_ds['x'].values

# Expand to 2-D coordinate grids (rows follow grid_lat, columns follow grid_lon)
grid_lon_2d, grid_lat_2d = np.meshgrid(grid_lon, grid_lat)

# Flip the latitude grid vertically. Together with the reversal above, the rows
# of grid_lat_2d end up in the dataset's original 'y' order, while the 1-D
# grid_lat stays reversed.
grid_lat_2d = np.flipud(grid_lat_2d)

# Show the dataset summary (variables and dimensions) as the cell output
seviri_ds

# Define function to find the nearest grid cell for each station

In [6]:
def nearest_grid_to_stations(grid_lat, grid_lon, station_lat, station_lon):
    """
    Map stations to nearest grid cells using KDTree.

    The search uses the plain Euclidean distance between (lat, lon) pairs,
    i.e. the default metric of ``KDTree.query``.

    Parameters
    ----------
    grid_lat : array-like
        Latitudes of grid cells (any shape; flattened internally).
    grid_lon : array-like
        Longitudes of grid cells, with the same number of elements as
        ``grid_lat``.
    station_lat : array-like
        Latitudes of stations (1-D, e.g. a NumPy array or pandas Series).
    station_lon : array-like
        Longitudes of stations, with the same number of elements as
        ``station_lat``.

    Returns
    -------
    indices : numpy.ndarray
        For each station, the index of the nearest grid cell in the flattened
        (``ravel()``-ordered) grid. Use it with ``array.flat[indices]`` on
        arrays that share the grid's shape. Empty if no stations are given.

    Raises
    ------
    ValueError
        If the latitude and longitude inputs of the grid, or of the stations,
        do not contain the same number of elements.
    """
    grid_lat_flat = np.asarray(grid_lat, dtype=float).ravel()
    grid_lon_flat = np.asarray(grid_lon, dtype=float).ravel()
    if grid_lat_flat.size != grid_lon_flat.size:
        # zip() would silently truncate to the shorter input and shift indices
        raise ValueError(
            f"grid_lat and grid_lon must have the same number of elements "
            f"({grid_lat_flat.size} != {grid_lon_flat.size})"
        )

    station_lat_flat = np.asarray(station_lat, dtype=float).ravel()
    station_lon_flat = np.asarray(station_lon, dtype=float).ravel()
    if station_lat_flat.size != station_lon_flat.size:
        raise ValueError(
            f"station_lat and station_lon must have the same number of elements "
            f"({station_lat_flat.size} != {station_lon_flat.size})"
        )

    # Nothing to look up: return an empty integer index array
    if station_lat_flat.size == 0:
        return np.empty(0, dtype=np.intp)

    # Vectorised (n, 2) point arrays instead of a Python-level zip over every cell
    grid_points = np.column_stack((grid_lat_flat, grid_lon_flat))
    station_points = np.column_stack((station_lat_flat, station_lon_flat))

    tree = KDTree(grid_points)
    _, indices = tree.query(station_points)
    return indices

In [10]:
# Flat grid index of the cell nearest to each station in the metadata table
indices = nearest_grid_to_stations(
    grid_lat_2d,
    grid_lon_2d,
    stations_meta_data_tahmo['latitude'],
    stations_meta_data_tahmo['longitude'],
)

# List the output directory once, in a deterministic order, so the array sizes
# and the loop below are guaranteed to refer to the same set of files.
model_files = sorted(os.listdir(model_root_dir))
n_files = len(model_files)
n_stations = len(station_name_list)

dt_list = np.empty(n_files, dtype=object)
# Start from NaN so that anything not filled in shows up as missing, not as 0
target_station_cells = np.full((n_files, n_stations), np.nan)
pred_station_cells = np.full((n_files, n_stations), np.nan)

# Wrap the loop with tqdm for progress bar
for i, file in enumerate(tqdm(model_files, desc="Processing files")):
    try:
        # The timestamp is the second '_'-separated token of the file name
        time_str = file.split('_')[1].split('.')[0]
        # Shift by one second to line the timestamp up with the station
        # timestamps (adjust as needed)
        dt_list[i] = datetime.strptime(time_str, "%Y%m%d%H%M%S") + timedelta(seconds=1)
        with h5py.File(os.path.join(model_root_dir, file), 'r') as f:
            # Halve the values (per the original note: mm/h -> mm/30 min),
            # drop size-1 dimensions, then flip the data along the y-axis
            target = np.flipud(np.squeeze(f['y'][:] / 2))
            pred = np.flipud(np.squeeze(f['y_hat'][:] / 2))

        # Pick the station cells in one vectorised lookup on the flattened fields
        target_station_cells[i, :] = target.flat[indices]
        pred_station_cells[i, :] = pred.flat[indices]

    except Exception as e:
        # Best effort: report the file and leave its whole row as missing
        print(f"Error occurred while processing file {file}: {e}")
        target_station_cells[i, :] = np.nan
        pred_station_cells[i, :] = np.nan

Processing files: 100%|█████████████████████████████████████████████████████████████| 728/728 [00:01<00:00, 531.69it/s]


In [11]:
# One row per processed file, indexed by the timestamp parsed from its name
datetime_index = pd.DatetimeIndex(dt_list)

# Target values at the station cells, in chronological order
df_target = pd.DataFrame(
    data=target_station_cells,
    index=datetime_index,
    columns=station_name_list,
).sort_index()
#df_target.to_csv('df_target_balanced_tahmo_2022.csv')

# Predicted values at the station cells, in chronological order, saved to CSV
df_pred = pd.DataFrame(
    data=pred_station_cells,
    index=datetime_index,
    columns=station_name_list,
).sort_index()
df_pred.to_csv('df_output_MSE_tahmo_test.csv')