In [39]:
import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import netCDF4
import scipy
from scipy import stats
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import seaborn as sns

import os
from matplotlib.colors import TwoSlopeNorm

In [None]:
def determine_global_ranges(data_dict):
    """
    Determine appropriate global min/max ranges for each variable across all years.
    
    Parameters:
    -----------
    data_dict : dict
        Dictionary with variable names as keys and xarray DataArrays as values
        
    Returns:
    --------
    dict
        Dictionary with variable names as keys and (min, max) tuples as values
    """
    global_ranges = {}
    
    for var_name, da in data_dict.items():
        try:
            if var_name in ['vimfc', 'vidq_dt']:
                # For diverging variables, use symmetric range
                max_abs = float(np.nanmax(np.abs(da.values)))
                # Add a small buffer (5%) to prevent edge cases
                max_abs = max_abs * 1.05
                global_ranges[var_name] = (-max_abs, max_abs)
            else:
                # For sequential variables
                vmin = float(np.nanmin(da.values))
                vmax = float(np.nanmax(da.values))
                # Add a small buffer on the upper end
                # vmax = vmax + (vmax - vmin) * 0.005
                global_ranges[var_name] = (vmin, vmax)
                
            print(f"Global range for {var_name}: {global_ranges[var_name]}")
        except Exception as e:
            print(f"Error determining range for {var_name}: {e}")
            # Fallback to default ranges
            if var_name in ['vimfc', 'vidq_dt']:
                global_ranges[var_name] = (-1, 1)
            elif var_name == 'sst':
                global_ranges[var_name] = (270, 310)  # Typical SST range in Kelvin
            elif var_name == 'tp':
                global_ranges[var_name] = (0, 0.05)  # Typical precipitation in m/day
            else:
                global_ranges[var_name] = (0, 1)
            
        global_ranges['vimfc'] = (-0.0015, 0.0015)
        global_ranges['tp'] = (0, 0.15)

    return global_ranges

def create_hovmoller_diagram(data_dict, year, global_ranges, save_path='./', cmap_dict=None):
    """
    Create a Hovmoller diagram with all variables in a row using contourf.
    Use consistent color scales across years for comparison.
    
    Parameters:
    -----------
    data_dict : dict
        Dictionary with variable names as keys and xarray DataArrays as values
    year : int
        Year to plot
    global_ranges : dict
        Dictionary with variable names as keys and (min, max) tuples as values
    save_path : str
        Directory to save the plot
    cmap_dict : dict, optional
        Dictionary with variable names as keys and colormaps as values
    """
    # Define default colormaps if not provided
    if cmap_dict is None:
        cmap_dict = {
            'mer': 'Reds',
            'sst': 'RdBu_r',
            'vimfc': 'RdBu_r',
            'vidq_dt': 'RdBu_r',
            'tp': 'Blues'
        }
    
    # Create a figure with subplots arranged horizontally
    fig, axs = plt.subplots(1, len(data_dict), figsize=(20, 6), sharey=True, squeeze=False)
    axs = axs.flatten()  # Ensure axs is always indexable regardless of subplot count
    
    # Define nice titles for variables
    var_titles = {
        'mer': 'Mean Evaporation Rate',
        'sst': 'Sea Surface Temperature',
        'vimfc': 'Moisture Flux Convergence',
        'vidq_dt': 'Moisture Tendency',
        'tp': 'Total Precipitation'
    }
    
    # Units dictionary
    units_dict = {
        'mer': '(kg m$^{-2}$ s$^{-1}$)',
        'sst': '(K)',
        'vimfc': '(kg m$^{-2}$ s$^{-1}$)',
        'vidq_dt': '(kg s$^{-1}$)',
        'tp': '(m/day)'
    }
    
    # Loop through variables
    for i, (var_name, da) in enumerate(data_dict.items()):
        # Select data for the specified year
        try:
            # Get the specific year data
            da_year = da.sel(time=slice(f"{year}-01-01", f"{year}-12-31"))
            
            # Extract day of year for y-axis
            doy = da_year.time.dt.dayofyear.values
            
            # Get longitudes for x-axis
            lons = da_year.lon.values if 'lon' in da_year.dims else da_year.longitude.values
            
            # Get global range
            vmin, vmax = global_ranges[var_name]
            
            # Create levels based on variable type
            # if var_name in ['vimfc']:
            #     # For diverging variables, use symmetric levels
            #     levels = np.linspace(vmin + 0.0004, vmax - 0.0004, 21)  # 21 levels for smooth visualization
            #     extend = 'both'

            if var_name in ['vidq_dt', 'vimfc']:
                levels = np.linspace(vmin , vmax, 21)  # 21 levels for smooth visualization
                extend = 'both'

            else:
                # For sequential variables
                levels = np.linspace(vmin, vmax, 21)
                extend = 'max'
            
            # Create the contourf plot
            im = axs[i].contourf(lons, doy, da_year.values, 
                                 levels=levels,
                                 cmap=cmap_dict[var_name],
                                 extend=extend)
            
            # Add colorbar with the fixed range
            cbar = fig.colorbar(im, ax=axs[i], orientation='vertical', pad=0.02)
            cbar.set_label(units_dict.get(var_name, ''))
            
            # Set labels and title
            if i == 0:  # Only add y-label to the leftmost plot
                axs[i].set_ylabel('Day of Year')
            
            axs[i].set_xlabel('Longitude')
            axs[i].set_title(f"{var_titles.get(var_name, var_name.upper())}\n{units_dict.get(var_name, '')}")
            
            # Add gridlines for better readability
            axs[i].grid(True, linestyle='--', alpha=0.3)
            
            # Set y-axis limits to ensure all days of year are visible
            axs[i].set_ylim(0, 366)
            
        except Exception as e:
            print(f"Error plotting {var_name} for year {year}: {e}")
            axs[i].text(0.5, 0.5, f"Error plotting {var_name}", 
                       ha='center', va='center', transform=axs[i].transAxes)
    
    # Add a main title for the entire figure
    plt.suptitle(f'Hovmoller Diagrams for {year}', fontsize=16, y=1.02)
    
    # Adjust layout and save plot
    plt.tight_layout()
    plt.savefig(os.path.join(save_path, f'hovmoller_{year}.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"Figure saved as 'hovmoller_{year}.png' in {save_path}")


In [57]:
mer_tropical_box = xr.open_dataarray('/Users/richard_zhang/Library/CloudStorage/OneDrive-Personal/A_Melbourne-Uni/A_Weather_for_21st_Century_RA_Internship/Local_Remote_Influences_on_Coastal_Rainfall/Data_preparation/ERA5_Tropical_Box/mer_tropical_box.nc')
sst_tropical_box = xr.open_dataarray('/Users/richard_zhang/Library/CloudStorage/OneDrive-Personal/A_Melbourne-Uni/A_Weather_for_21st_Century_RA_Internship/Local_Remote_Influences_on_Coastal_Rainfall/Data_preparation/ERA5_Tropical_Box/sst_daily_tropical_band_1998_2022.nc')
vimfc_tropical_box = xr.open_dataarray('/Users/richard_zhang/Library/CloudStorage/OneDrive-Personal/A_Melbourne-Uni/A_Weather_for_21st_Century_RA_Internship/Local_Remote_Influences_on_Coastal_Rainfall/Data_preparation/ERA5_Tropical_Box/vimfc_tropical_box.nc')
viqtend_tropical_box = xr.open_dataarray('/Users/richard_zhang/Library/CloudStorage/OneDrive-Personal/A_Melbourne-Uni/A_Weather_for_21st_Century_RA_Internship/Local_Remote_Influences_on_Coastal_Rainfall/Data_preparation/ERA5_Tropical_Box/viqtend_tropical_box.nc')
T_P_tropical_box = xr.open_dataarray('/Users/richard_zhang/Library/CloudStorage/OneDrive-Personal/A_Melbourne-Uni/A_Weather_for_21st_Century_RA_Internship/Local_Remote_Influences_on_Coastal_Rainfall/Data_preparation/ERA5_Tropical_Box/T_P_tropical_box.nc')

mer_tropical_box = -mer_tropical_box
viqtend_tropical_box = -viqtend_tropical_box

In [58]:
da = mer_tropical_box
region = da
lats = region.latitude.values
weights = np.cos(np.deg2rad(lats))
weights_da = xr.DataArray(weights, coords=[region.latitude], dims=['latitude'])
mer_tropical_box = region.weighted(weights_da).mean(dim='latitude')
da = vimfc_tropical_box
region = da
lats = region.latitude.values
weights = np.cos(np.deg2rad(lats))
weights_da = xr.DataArray(weights, coords=[region.latitude], dims=['latitude'])
vimfc_tropical_box = region.weighted(weights_da).mean(dim='latitude')
da = viqtend_tropical_box
region = da
lats = region.latitude.values
weights = np.cos(np.deg2rad(lats))
weights_da = xr.DataArray(weights, coords=[region.latitude], dims=['latitude'])
viqtend_tropical_box = region.weighted(weights_da).mean(dim='latitude')
da = sst_tropical_box
region = da
lats = region.latitude.values
weights = np.cos(np.deg2rad(lats))
weights_da = xr.DataArray(weights, coords=[region.latitude], dims=['latitude'])
sst_tropical_box = region.weighted(weights_da).mean(dim='latitude')
da = T_P_tropical_box
region = da
lats = region.latitude.values
weights = np.cos(np.deg2rad(lats))
weights_da = xr.DataArray(weights, coords=[region.latitude], dims=['latitude'])
T_P_tropical_box = region.weighted(weights_da).mean(dim='latitude')

In [59]:
def main():
    """Main function to load data and create hovmoller diagrams for each year."""
    try:
        print("Loading datasets...")
        
        mer = mer_tropical_box
        sst = sst_tropical_box
        vimfc = vimfc_tropical_box
        vidq_dt = viqtend_tropical_box
        tp = T_P_tropical_box
        
        # Convert precipitation from m/s to m/day if needed
        if hasattr(tp, 'attrs') and tp.attrs.get('units', '') == 'm s-1':
            tp = tp * 86400  # Convert to m/day
        
        # Create a dictionary with all variables
        data_dict = {
            'mer': mer,
            'sst': sst,
            'vimfc': vimfc,
            'vidq_dt': vidq_dt,
            'tp': tp
        }
        
        # Get the unique years from the time dimension of the first dataset
        time_vars = [var for var in data_dict.values() if 'time' in var.dims]
        if time_vars:
            sample_time = time_vars[0].time
            years = np.unique(sample_time.dt.year.values)
        else:
            # Fallback to default years if no time dimension found
            years = [2020]
        
        print(f"Found years: {years}")
        
        # First determine global ranges across all years
        print("Determining global data ranges for consistent colorbars...")
        global_ranges = determine_global_ranges(data_dict)
        
        # Then create hovmoller diagram for each year with consistent scales
        for year in years:
            print(f"Creating Hovmoller diagram for {year}...")
            create_hovmoller_diagram(data_dict, year, global_ranges)
            
    except Exception as e:
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    main()

Loading datasets...
Found years: [1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011
 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022]
Determining global data ranges for consistent colorbars...
Global range for mer: (6.479545459829056e-06, 0.00022600520972681414)
Global range for sst: (296.6013001692325, 304.15125401986575)
Global range for vimfc: (-0.002679827260323241, 0.002679827260323241)
Global range for vidq_dt: (-0.0005860565289743762, 0.0005860565289743762)
Global range for tp: (-8.940698100409273e-08, 0.18992269983570922)
Creating Hovmoller diagram for 1998...
Figure saved as 'hovmoller_1998.png' in ./
Creating Hovmoller diagram for 1999...
Figure saved as 'hovmoller_1999.png' in ./
Creating Hovmoller diagram for 2000...
Figure saved as 'hovmoller_2000.png' in ./
Creating Hovmoller diagram for 2001...
Figure saved as 'hovmoller_2001.png' in ./
Creating Hovmoller diagram for 2002...
Figure saved as 'hovmoller_2002.png' in ./
Creating Hovmoller diagram 