In [None]:
import os
import glob
import xarray as xr
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature

In [None]:
home = Path.home()

# Reference ("correct") ERA5 files that the CMORized output is compared against.
#   True  -> shared Tier3 archive (1997 files)
#   False -> local copies in ~/correct_data (2000 files)
# In both sets the evspsbl reference is a 1994 file, unlike the other variables.
USE_SHARED_REFERENCE = True

SHARED_REFERENCE_DIR = Path("/data/shared/climate-data/obs6/Tier3/ERA5")
LOCAL_REFERENCE_DIR = home / "correct_data"

if USE_SHARED_REFERENCE:
    correct_data_dict = {
        "rsds": SHARED_REFERENCE_DIR / "OBS6_ERA5_reanaly_1_day_rsds_1997-1997.nc",
        "tas": SHARED_REFERENCE_DIR / "OBS6_ERA5_reanaly_1_day_tas_1997-1997.nc",
        "pr": SHARED_REFERENCE_DIR / "OBS6_ERA5_reanaly_1_day_pr_1997-1997.nc",
        "evspsblpot": SHARED_REFERENCE_DIR / "OBS6_ERA5_reanaly_1_Eday_evspsblpot_1997-1997.nc",
        "evspsbl": SHARED_REFERENCE_DIR / "OBS6_ERA5_reanaly_1_Eday_evspsbl_1994-1994.nc",
    }
else:
    correct_data_dict = {
        "rsds": LOCAL_REFERENCE_DIR / "OBS6_ERA5_reanaly_1_day_rsds_2000-2000.nc",
        "tas": LOCAL_REFERENCE_DIR / "OBS6_ERA5_reanaly_1_day_tas_2000-2000.nc",
        "pr": LOCAL_REFERENCE_DIR / "OBS6_ERA5_reanaly_1_day_pr_2000-2000.nc",
        "evspsblpot": LOCAL_REFERENCE_DIR / "OBS6_ERA5_reanaly_1_Eday_evspsblpot_2000-2000.nc",
        "evspsbl": LOCAL_REFERENCE_DIR / "OBS6_ERA5_reanaly_1_Eday_evspsbl_1994-1994.nc",
    }

# Variables to compare; switch to the commented list to check evspsbl on its own.
test_variables = ["rsds", "tas", "pr", "evspsblpot"]
# test_variables = ["evspsbl"]

# Folder holding the CMORized files under test.
input_folder = home / "cmorized_output"
# input_folder = home / "cmorized_output/not_yet_converted/temp"

In [None]:
# # Scratch cell for quick tests: open a single raw (not yet CMORized) ERA5 file for inspection.
# input_folder_test = home / "data_to_be_cmorized"
# test_file = input_folder_test / "era5_evaporation_1994_hourly.nc"

# ds = xr.open_dataset(test_file, engine='netcdf4')

# ds

In [None]:
# Compare every CMORized variable against its reference ("correct") ERA5 file:
# check that the units match and that the time bounds contain no NaT, then plot
# the unweighted lat/lon mean of both series against day of year.


def _plot_by_day_of_year(ax, series, label_prefix, **plot_kwargs):
    """Plot ``series`` on ``ax`` once per calendar year, with day of year on the x axis."""
    for year, data in series.groupby("time.year"):
        ax.plot(data["time"].dt.dayofyear, data, label=f"{label_prefix} {year}", **plot_kwargs)


# Size the grid from the number of variables (2 columns, or 1 for a single variable)
# so that no variable is silently dropped by zip() and no empty panels remain.
n_vars = len(test_variables)
n_cols = 2 if n_vars > 1 else 1
n_rows = max(1, -(-n_vars // n_cols))  # ceiling division, at least one row
fig, axs = plt.subplots(n_rows, n_cols, figsize=(12 * n_cols, 8 * n_rows), squeeze=False)
axs = axs.ravel()  # flatten into 1D array for easier looping
for unused_ax in axs[n_vars:]:
    unused_ax.set_visible(False)

for ax, variable in zip(axs, test_variables):
    # Pattern for the variable (matches all years); the surrounding underscores
    # keep e.g. "evspsbl" from matching "evspsblpot" files.
    pattern = os.path.join(input_folder, f"*_{variable}_*.nc")
    files = sorted(glob.glob(pattern))

    if not files:
        print(f"⚠️ No files found for {variable}")
        ax.set_title(f"{variable} (no files found)")
        continue

    print(f"📂 Opening {len(files)} files for {variable}")
    # Context managers close the file handles once the means have been loaded.
    with xr.open_mfdataset(files, combine="by_coords", engine="netcdf4") as ds:
        with xr.open_dataset(correct_data_dict[variable]) as correct_ds:
            # --- Units must be present and identical to the reference ---
            test_units = ds[variable].attrs.get("units")
            correct_units = correct_ds[variable].attrs.get("units")
            if test_units is None or test_units != correct_units:
                raise ValueError(
                    f"Units are wrong\nUnit of variable {variable}: {test_units!r} "
                    f"(expected {correct_units!r})"
                )
            print("Units are checked and correct")

            # --- Time bounds must exist and contain no NaT ---
            if "time_bnds" not in ds:
                raise ValueError(f"No 'time_bnds' variable found for {variable}")
            n_nat = int(np.isnat(ds["time_bnds"].values).sum())
            if n_nat:
                raise ValueError(
                    "Time bounds contain 1 or more 'NaT'\n"
                    f"Time bounds of {variable} contain {n_nat} NaT value(s)"
                )
            print("Time bounds are checked and correct")

            # .load() materialises the lazily evaluated means before the files are closed.
            ds_mean = ds[variable].mean(dim=["lat", "lon"]).load()
            correct_mean = correct_ds[variable].mean(dim=["lat", "lon"]).load()

    _plot_by_day_of_year(ax, ds_mean, "test", alpha=0.7)
    _plot_by_day_of_year(ax, correct_mean, "correct", linestyle="--", linewidth=2)

    ax.set_title(variable)
    ax.set_xlabel("Day of year")
    ax.set_ylabel(f"Lat/lon mean [{test_units}]")
    ax.legend()

plt.tight_layout()
plt.show()

## Final manual changes applied to the data (kept for reference; not meant to be run as part of the checks)

In [None]:
# Open one not-yet-converted precipitation file for manual inspection and patching.
test_file = home.joinpath("cmorized_output", "not_yet_converted", "OBS6_ERA5_reanaly_1_day_pr_2020-2020.nc")
test_data = xr.open_dataset(test_file)

In [None]:
# Show the dimensions of the file opened above.
# If 'height' appears as a data variable, it can be promoted to a coordinate with:
# test_data = test_data.set_coords('height')
dims_of_test_data = test_data.dims
display(dims_of_test_data)

In [None]:
# Leap-year (2016) reference files for the main variables; evspsblpot lives in the
# "Eday" table, the others in "day".
correct_file_leap_dict = {
    name: home / "correct_data" / f"OBS6_ERA5_reanaly_1_{table}_{name}_2016-2016.nc"
    for name, table in {"tas": "day", "pr": "day", "evspsblpot": "Eday", "rsds": "day"}.items()
}

# Open the precipitation reference and show its dimensions.
correct = xr.open_dataset(correct_file_leap_dict["pr"])
display(correct.dims)

In [None]:
# Rebuild the time bounds of the 2016 leap-year reference for `new_year`, so they
# can be copied into `test_data` (manual step at the end of this cell).


def shift_datetime_years(values, year_change):
    """Return ``values`` with every timestamp moved by ``year_change`` calendar years.

    Month, day and time of day (to nanosecond precision) are kept and only the year
    changes, e.g. 2016-03-01T00:00 shifted by 4 becomes 2020-03-01T00:00.
    NaT entries stay NaT. The output has the same shape as the input, with dtype
    datetime64[ns].

    Raises:
        ValueError: if a date does not exist in the target year (29 February
            shifted into a non-leap year).
    """
    values = np.asarray(values, dtype="datetime64[ns]")
    month_start = values.astype("datetime64[M]")
    within_month = values - month_start  # day-of-month offset plus time of day
    target_month = month_start + np.timedelta64(12 * int(year_change), "M")
    shifted = target_month.astype("datetime64[ns]") + within_month
    # A day that does not exist in the target month rolls into the next month; reject it.
    valid = ~np.isnat(values)
    if np.any(shifted[valid].astype("datetime64[M]") != target_month[valid]):
        raise ValueError("Shifted date does not exist in the target year (29 February?)")
    return shifted


correct_bounds = correct.data_vars["time_bnds"].values
print(correct_bounds.dtype)

reference_year = 2016  # year of the leap-year reference files opened above
new_year = 2020
year_change = new_year - reference_year

new_bounds = shift_datetime_years(correct_bounds, year_change)
assert new_bounds.shape == correct_bounds.shape
print(new_bounds)

# Manual step (see heading above): write the rebuilt bounds into the test file.
# test_data.data_vars["time_bnds"].values = new_bounds
# print(test_data.data_vars["time_bnds"].values)

In [None]:
# Put time, lat and lon first, in that order; `...` places any other dimensions after them.
leading_dims = ("time", "lat", "lon")
test_data = test_data.transpose(*leading_dims, ...)

In [None]:
# Print the dimensions, then the coordinates, of the (transposed) test dataset.
for attribute_name in ("dims", "coords"):
    print(getattr(test_data, attribute_name))

In [None]:
# Quick look at a single reference field: map of its first time step.
# var_name is defined here so this cell does not depend on a later cell.
var_name = "pr"
plot_file = correct_data_dict[var_name]
# To inspect a CMORized output file instead, set e.g.
# var_name = "evspsblpot"
# plot_file = home / "cmorized_output" / "OBS6_ERA5_reanaly_1_Eday_evspsblpot_2020-2020.nc"

# --- Select a single time step (loaded so the file can be closed) ---
with xr.open_dataset(plot_file) as plot_ds:
    data_single = plot_ds[var_name].isel(time=0).load()
units = data_single.attrs.get("units", "unknown units")

# --- Plotting ---
fig, ax = plt.subplots(figsize=(12, 6), subplot_kw={"projection": ccrs.PlateCarree()})
data_single.plot.pcolormesh(
    ax=ax,
    transform=ccrs.PlateCarree(),
    cmap="coolwarm",
    cbar_kwargs={"label": f"{var_name} [{units}]"},
)

# Add features
ax.coastlines()
ax.add_feature(cfeature.BORDERS, linewidth=0.5)
ax.set_global()
ax.set_title(f"ERA5 {var_name} at time {str(data_single.time.values)}")

plt.show()

## Maps of the first time step of every CMORized file

In [None]:
# Map the first time step of every CMORized file in input_folder.
files = sorted(glob.glob(os.path.join(input_folder, "*.nc")))
for file in files:
    file_name = os.path.basename(file)
    # Match the "_<var>_" token of the file name (same rule as the comparison cell).
    # A plain substring test would let "evspsbl" match "evspsblpot" files, and a
    # file matching nothing would silently reuse var_name from the previous file.
    var_name = next((var for var in test_variables if f"_{var}_" in file_name), None)
    if var_name is None:
        print(f"Skipping {file_name}: no variable from test_variables in its name")
        continue

    # --- Select a single time step (loaded so the file can be closed) ---
    with xr.open_dataset(file) as ds:
        data_single = ds[var_name].isel(time=0).load()
    units = data_single.attrs.get("units", "unknown units")

    # --- Plotting ---
    fig, ax = plt.subplots(figsize=(12, 6), subplot_kw={"projection": ccrs.PlateCarree()})
    data_single.plot.pcolormesh(
        ax=ax,
        transform=ccrs.PlateCarree(),
        cmap="coolwarm",
        cbar_kwargs={"label": f"{var_name} [{units}]"},
    )

    # Add features
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linewidth=0.5)
    ax.set_global()
    ax.set_title(f"ERA5 {var_name} at time {str(data_single.time.values)}")

    plt.show()
    plt.close(fig)  # release each figure; this loop can create many