Description: this notebook serve as the test bench for calculating tele-connection using Cuda cores. It requires configuration to the working envs

```powershell
conda activate Cudapy
```

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import scipy.signal as signal
import rasterio
import scipy.stats as stats
from scipy.stats import hypergeom
from scipy.interpolate import RegularGridInterpolator
from matplotlib.colors import LinearSegmentedColormap
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.colors import ListedColormap
import matplotlib.colors as mcolors
from tqdm import tqdm
import geopandas as gpd
import pickle
import os
from numba import jit, cuda

In [2]:
def label_la_nina_sequences(df):
    """
    Labels La Nina sequences where the maximum sequence number is greater than 5.
    
    Parameters:
    df (pandas.DataFrame): DataFrame containing La_Nina_Seq column
    
    Returns:
    pandas.DataFrame: Original data with additional label column for significant sequences
    """
    # Create a copy of the dataframe
    df_copy = df.copy()
    
    # Initialize the label column
    df_copy['sig_elnino'] = 0
    df_copy['sig_lanina'] = 0
    
    # Function to find max in a sequence
    def get_sequence_max(start_idx, seq_in):
        sequence = []
        idx = start_idx
        while idx < len(df_copy) and df_copy[seq_in].iloc[idx] != 0:
            sequence.append(df_copy[seq_in].iloc[idx])
            idx += 1
        return max(sequence) if sequence else 0

    # Iterate through the dataframe
    i = 0
    seq = 'La_Nina_Seq'
    while i < len(df_copy):
        if df_copy[seq].iloc[i] == 1:  # Start of a sequence
            max_in_sequence = get_sequence_max(i, seq)
            if max_in_sequence > 5:
                # Label all numbers in this sequence
                j = i
                while j < len(df_copy) and df_copy[seq].iloc[j] != 0:
                    df_copy.loc[df_copy.index[j], 'sig_lanina'] = 1
                    j += 1
            # Skip to end of current sequence
            while i < len(df_copy) and df_copy[seq].iloc[i] != 0:
                i += 1
        else:
            i += 1

     # Iterate through the dataframe
    i = 0
    seq = 'El_Nino_Seq'
    while i < len(df_copy):
        if df_copy[seq].iloc[i] == 1:  # Start of a sequence
            max_in_sequence = get_sequence_max(i, seq)
            if max_in_sequence > 5:
                # Label all numbers in this sequence
                j = i
                while j < len(df_copy) and df_copy[seq].iloc[j] != 0:
                    df_copy.loc[df_copy.index[j], 'sig_elnino'] = 1
                    j += 1
            # Skip to end of current sequence
            while i < len(df_copy) and df_copy[seq].iloc[i] != 0:
                i += 1
        else:
            i += 1
    
    return df_copy

In [3]:
# Define column names manually
column_monthly = ["Year", "Month", "NINO1+2", "ANOM_NINO1+2", "NINO3", "ANOM_NINO3",
                "NINO4", "ANOM_NINO4", "NINO3.4", "ANOM_NINO3.4"]
column_seasonal = ['SEAS', 'YR', 'TOTAL', 'ANOM']

# Read the file
# https://www.cpc.ncep.noaa.gov/data/indices/ersst5.nino.mth.91-20.ascii
monthly = pd.read_csv("data/ersst5_nino_monthly.txt", sep='\s+', names=column_monthly, skiprows=1)[['Year', 'Month', 'NINO3.4', 'ANOM_NINO3.4']]
seasonal = pd.read_csv("data/oni_seasonal.txt", sep='\s+', names=column_seasonal, skiprows=1)

# Identify El Niño and La Niña periods
seasonal['El_Nino'] = (seasonal['ANOM'] >= 0.5).astype(int)
seasonal['La_Nina'] = (seasonal['ANOM'] <= -0.5).astype(int)

# Compute running count of consecutive months where ONI exceeds 0.5 or is below -0.5
seasonal['El_Nino_Seq'] = seasonal['El_Nino'] * seasonal['El_Nino'].groupby((seasonal['El_Nino'] != seasonal['El_Nino'].shift()).cumsum()).transform('cumsum')
seasonal['La_Nina_Seq'] = seasonal['La_Nina'] * seasonal['La_Nina'].groupby((seasonal['La_Nina'] != seasonal['La_Nina'].shift()).cumsum()).transform('cumsum')

enso_base = label_la_nina_sequences(seasonal)

In [4]:
climate_var = 'snowcover'

In [13]:
if climate_var == 'snowcover':
    # Load NetCDF file
    file_path = "data/snow.nc"
    ds = xr.open_dataset(file_path)
    ds_values = ds['snowc'].values

    lats = ds['latitude'].values
    lons = ds['longitude'].values
    lons = np.where(lons > 180, lons - 360, lons) # same as SPEI and crop
    ds = ds.assign_coords(longitude=lons)

    # Create meshgrid of lat/lon pairs
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    coordinates = np.column_stack([lon_grid.ravel(), lat_grid.ravel()])

elif climate_var == 'snowalbedo':
    # Load NetCDF file
    file_path = "data/snow.nc"
    ds = xr.open_dataset(file_path)
    ds_values = ds['asn'].values

    lats = ds['latitude'].values
    lons = ds['longitude'].values
    lons = np.where(lons > 180, lons - 360, lons) # same as SPEI and crop
    ds = ds.assign_coords(longitude=lons)

    # Create meshgrid of lat/lon pairs
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    coordinates = np.column_stack([lon_grid.ravel(), lat_grid.ravel()])

elif climate_var == 'net-solar':
    file_path = "data/solar1.nc"
    ds1 = xr.open_dataset(file_path)

    file_path = "data/solar2.nc"
    ds2 = xr.open_dataset(file_path)
    # Concatenate along the time dimension

    ds_values = np.concatenate([ds1['ssr'].values[:624], ds2['ssr'].values, ds1['ssr'].values[624:]], axis=0)

    lats = ds1['latitude'].values
    lons = ds1['longitude'].values
    lons = np.where(lons > 180, lons - 360, lons) # same as SPEI and crop
    ds = ds1.assign_coords(longitude=lons)

    # Create meshgrid of lat/lon pairs
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    coordinates = np.column_stack([lon_grid.ravel(), lat_grid.ravel()])

elif climate_var == 'net-thermal':
    file_path = "data/solar1.nc"
    ds1 = xr.open_dataset(file_path)

    file_path = "data/solar2.nc"
    ds2 = xr.open_dataset(file_path)
    # Concatenate along the time dimension

    ds_values = np.concatenate([ds1['str'].values[:624], ds2['str'].values, ds1['str'].values[624:]], axis=0)

    lats = ds1['latitude'].values
    lons = ds1['longitude'].values
    lons = np.where(lons > 180, lons - 360, lons) # same as SPEI and crop
    ds = ds1.assign_coords(longitude=lons)

    # Create meshgrid of lat/lon pairs
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    coordinates = np.column_stack([lon_grid.ravel(), lat_grid.ravel()])


In [None]:
# Load the world shapefile
world = gpd.read_file('data/ne_10m_admin_0_countries/ne_10m_admin_0_countries.shp')

# Create a GeoDataFrame with all mesh grid points (assuming `coordinates` holds all lat/lon)
points_gdf = gpd.GeoDataFrame(pd.DataFrame(coordinates, columns=['Longitude', 'Latitude']),
                              geometry=gpd.points_from_xy(coordinates[:, 0], coordinates[:, 1]), crs="EPSG:4326")

# Perform spatial join to select only land-based points (union of all country polygons)
land_points = gpd.sjoin(points_gdf, world, op='within')

# Extract indices for all valid land points
all_points_idx_climate = []  # This replaces `points_idx_climate`
for lon, lat in land_points[['Longitude', 'Latitude']].values:
    lat_idx = abs(ds['latitude'] - lat).argmin().item()
    lon_idx = abs(ds['longitude'] - lon).argmin().item()
    all_points_idx_climate.append([lat_idx, lon_idx])


In [12]:
# # Save data to a pickle file
# with open("outcome/lon_lat_idx_era5_w.pkl", "wb") as file:
    #pickle.dump(all_points_idx_climate, file)

with open("outcome/lon_lat_idx_era5_w.pkl", "rb") as file:
    all_points_idx_climate = pickle.load(file)

In [10]:
# lats = range(1801)
# lons = range(3600)
# # Create meshgrid of lat/lon pairs
# lon_grid, lat_grid = np.meshgrid(lons, lats)
# # Flatten the grids and create coordinate pairs
# coordinates_idx = list(zip(lat_grid.ravel(), lon_grid.ravel()))

### Seasonal

In [11]:
season_order = ["DJF", "JFM", "FMA", "MAM", "AMJ", "MJJ", "JJA", "JAS", "ASO", "SON", "OND", "NDJ"]
states = ['El Nino', 'La Nina', 'Normal']
tercile_order = ['BN', 'NN', 'AN']
states_lower = ['elnino', 'lanina', 'normal']

seasons_all = enso_base['SEAS'].values
sig_elnino = enso_base['sig_elnino'].values
sig_lanina = enso_base['sig_lanina'].values
enso_masks = [
        sig_elnino == 1,  # El Nino
        sig_lanina == 1,  # La Nina
        (sig_elnino == 0) & (sig_lanina == 0)  # Neutral
    ]

In [14]:
# Initialize global dictionaries for storing results
freq_all_world = {}  
significance_all_world = {}

# Initialize lists to store global data
freq_all = []
significance_all = []

# Process all grid points worldwide
for lat, lon in tqdm(all_points_idx_climate):  # Assuming `all_points_idx_climate` contains global points
    values = cp.array(ds_values[:, lat, lon])  # Extract climate data for this grid point
    
    # If all values are NaN, append NaN-filled arrays and continue
    if np.all(np.isnan(values)):  
        freq_all.append(np.full((3, len(season_order), 3), np.nan))
        significance_all.append(np.full((3, 2), np.nan))
        continue  

    # Reshape data to 75 years × 12 months
    values_mon = cp.reshape(values, (75, 12)).T  

    # Compute terciles (33rd and 67th percentiles) for each month
    tercile = np.nanquantile(values_mon, [0.33, 0.67], axis=1)  

    # Convert values into tercile categories (-1: below, 0: normal, 1: above)
    tercile_binary = cp.select(
        [values_mon.T <= tercile[0], values_mon.T > tercile[1]],
        [-1, 1],
        default=0
    ).flatten()[:-1]  # Flatten and remove last value to match length

    # Initialize frequency array (3 ENSO states × seasons × 3 terciles)
    frequencies = np.zeros((3, len(season_order), 3))  

    # Compute tercile frequency per ENSO state and season
    for state_idx, state_mask in enumerate(enso_masks):  
        for seas_idx, season in enumerate(season_order):  
            season_mask = seasons_all == season  
            combined_mask = state_mask & season_mask  # Mask for ENSO & season
            
            if np.any(combined_mask):  
                for tercile in [-1, 0, 1]:  
                    tercile_count = np.sum((tercile_binary == tercile) & combined_mask)  
                    frequencies[state_idx, seas_idx, tercile + 1] = tercile_count / np.sum(combined_mask)

    # Compute total number of months
    total_months = len(tercile_binary)  

    # Create masks for months above and below normal
    above_mask = tercile_binary == 1  
    below_mask = tercile_binary == -1  

    # Count total months in each category
    total_above = np.sum(above_mask)  
    total_below = np.sum(below_mask)  

    # Initialize arrays to store observed counts and p-values
    enso_months = np.zeros(3)  
    observed_above = np.zeros(3)  
    observed_below = np.zeros(3)  
    p_values = np.zeros((3, 2))  # [above, below] p-values for each ENSO state

    # Compute statistical significance using hypergeometric test
    for i, mask in enumerate(enso_masks):  
        enso_months[i] = np.sum(mask)  # Total months in this ENSO state
        observed_above[i] = np.sum(mask & above_mask)  # Above normal counts
        observed_below[i] = np.sum(mask & below_mask)  # Below normal counts

        # Compute p-value for above-normal months
        p_values[i, 0] = hypergeom.sf(
            observed_above[i] - 1,  # k-1 (observed successes minus 1)
            total_months,  # N (total population)
            total_above,  # K (total successes in population)
            enso_months[i]  # n (sample size)
        )

        # Compute p-value for below-normal months
        p_values[i, 1] = hypergeom.sf(
            observed_below[i] - 1,  # k-1 (observed successes minus 1)
            total_months,  # N (total population)
            total_below,  # K (total successes in population)
            enso_months[i]  # n (sample size)
        )

    # Store frequency and significance results for this grid point
    freq_all.append(frequencies)  
    significance_all.append(p_values)  

# Store global results
freq_all_world["global"] = freq_all  
significance_all_world["global"] = significance_all  

  0%|          | 22/2141055 [00:03<101:03:08,  5.89it/s]


TypeError: no implementation found for 'numpy.nanquantile' on types that implement __array_function__: [<class 'cupy.ndarray'>]

## Mapping

In [None]:
# Save data to a pickle file
with open(f"outcome/frequency_w_{climate_var}.pkl", "wb") as file:
    pickle.dump(freq_all_world, file)

with open(f"outcome/significance_w_{climate_var}.pkl", "wb") as file:
    pickle.dump(significance_all_world, file)

In [14]:
import pickle

# Load the frequency data
with open(f"outcome/frequency_w_{climate_var}.pkl", "rb") as file:
    freq_all_world = pickle.load(file)

# Load the significance data
with open(f"outcome/significance_w_{climate_var}.pkl", "rb") as file:
    significance_all_world = pickle.load(file)


In [15]:
def create_custom_colormaps():
    """
    Create two custom colormaps for Below Normal and Above Normal based on the image
    """
    # Below Normal colors (yellow to brown)
    below_colors = ['#f9fa04', '#e7b834', '#ce8033', '#a9451d', '#783100']
    
    # Above Normal colors (light green to blue)
    above_colors = ['#d1f8cb', '#adf79f', '#75ba6f', '#4394cc', '#0c3af3']
    
    below_cmap = ListedColormap(below_colors)
    above_cmap = ListedColormap(above_colors)
    
    return below_cmap, above_cmap

In [16]:
def create_probability_maps(data_all, index_all, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    
    base_map = np.full((1801, 3600), np.nan)
    below_cmap, above_cmap = create_custom_colormaps()
    
    projection = ccrs.PlateCarree(central_longitude=180)
    data_transform = ccrs.PlateCarree()
    
    for season in tqdm(range(12)):
        for state in range(3):
            category_map = base_map.copy()
            probability_map = base_map.copy()
            data = np.array(data_all['global'])

            for coord_idx, (lat_idx, lon_idx) in enumerate(index_all):
                probs = data[coord_idx, state, season, :]
                max_prob = np.max(probs)
                max_state = np.argmax(probs)
                
                category_map[lat_idx, lon_idx] = max_state
                probability_map[lat_idx, lon_idx] = max_prob * 100
            
            fig = plt.figure(figsize=(15, 8))
            ax = plt.axes(projection=projection)
            ax.set_title(f'State {states[state]}, Season {season_order[season]}', fontsize=14, pad=15, fontweight='bold')
            ax.coastlines(linewidth=0.5)
            ax.add_feature(cfeature.BORDERS, linewidth=0.3)
            
            below_mask = category_map == 0
            above_mask = category_map == 2
            
            below_probs = np.ma.masked_where(~below_mask, probability_map)
            ax.imshow(below_probs, transform=data_transform, extent=[0, 360, -90, 90], cmap=below_cmap, vmin=40, vmax=70)
            
            above_probs = np.ma.masked_where(~above_mask, probability_map)
            ax.imshow(above_probs, transform=data_transform, extent=[0, 360, -90, 90], cmap=above_cmap, vmin=40, vmax=70)
            
            norm = plt.Normalize(40, 70)
            sm_below = plt.cm.ScalarMappable(cmap=below_cmap, norm=norm)
            sm_above = plt.cm.ScalarMappable(cmap=above_cmap, norm=norm)
            
            cbar_below_ax = fig.add_axes([0.125, 0.05, 0.35, 0.02])
            cbar_below = fig.colorbar(sm_below, cax=cbar_below_ax, orientation='horizontal')
            cbar_below.set_label('Below Normal')
            
            cbar_above_ax = fig.add_axes([0.525, 0.05, 0.35, 0.02])
            cbar_above = fig.colorbar(sm_above, cax=cbar_above_ax, orientation='horizontal')
            cbar_above.set_label('Above Normal')
            
            fig.text(0.5, 0.15, 'Probability (%) of Most Likely Category', ha='center', va='center', fontsize=10)
            plt.subplots_adjust(bottom=0.2)
            
            filename = f'prob_{states_lower[state]}_season_{season_order[season]}.png'
            plt.savefig(os.path.join(output_dir, filename), dpi=300, bbox_inches='tight')
            plt.close()


In [20]:
def create_dual_p_value_maps(data_all, index_all, output_dir=f'outcome/map/{climate_var}'):
    """
    Create side-by-side maps showing p-values for above (blue) and below (red)
    """
    # Create colormaps
    # Blue scheme for above
    colors_above = ['#f7fbff', '#deebf7', '#c6dbef', '#9ecae1', '#6baed6', '#4292c6', '#2171b5', '#084594'][::-1]
    cmap_above = LinearSegmentedColormap.from_list('custom_blues', colors_above)
    
    # Red scheme for below
    colors_below = ['#fff5f0', '#fee0d2', '#fcbba1', '#fc9272', '#fb6a4a', '#ef3b2c', '#cb181d', '#99000d'][::-1]
    cmap_below = LinearSegmentedColormap.from_list('custom_reds', colors_below)
    
    
    for state in tqdm(range(3)):
        # Create a new map for this season and state
        base_map_above = np.full((1801, 3600), np.nan)
        base_map_below = np.full((1801, 3600), np.nan)
        
        data = np.array(data_all['global'])
        for coord_idx, (lat_idx, lon_idx) in enumerate(index_all):
            # Get probabilities for all states in this season
            probs_ab = data[coord_idx, state , 0]
            probs_bl = data[coord_idx, state , 1]
            
            base_map_above[lat_idx, lon_idx] = probs_ab
            base_map_below[lat_idx, lon_idx] = probs_bl

        # Create the plot
        fig = plt.figure(figsize=(20, 8))
        
        # Create projection
        projection = ccrs.PlateCarree(central_longitude=180)
        data_transform = ccrs.PlateCarree()
        
        # Above Normal p-values (left plot)
        ax1 = plt.subplot(121, projection=projection)
        ax1.coastlines(linewidth=0.5)
        ax1.add_feature(cfeature.BORDERS, linewidth=0.3)
        
        img1 = ax1.imshow(base_map_above,
                        transform=data_transform,
                        extent=[0, 360, -90, 90],
                        cmap=cmap_above,
                        vmin=0,
                        vmax=1)
        
        # Add colorbar for above
        cbar1 = plt.colorbar(img1, orientation='horizontal', pad=0.1)
        cbar1.set_label('P-value Above Normal')
        
        # Add gridlines
        gl1 = ax1.gridlines(draw_labels=True, linewidth=0.2, color='gray', alpha=0.5)
        gl1.top_labels = False
        gl1.right_labels = False
        
        # Title for left plot
        ax1.set_title('Above Normal P-values', fontsize=12, pad=10)
        
        # Below Normal p-values (right plot)
        ax2 = plt.subplot(122, projection=projection)
        ax2.coastlines(linewidth=0.5)
        ax2.add_feature(cfeature.BORDERS, linewidth=0.3)
        
        img2 = ax2.imshow(base_map_below,
                        transform=data_transform,
                        extent=[0, 360, -90, 90],
                        cmap=cmap_below,
                        vmin=0,
                        vmax=1)
        
        # Add colorbar for below
        cbar2 = plt.colorbar(img2, orientation='horizontal', pad=0.1)
        cbar2.set_label('P-value Below Normal')
        
        # Add gridlines
        gl2 = ax2.gridlines(draw_labels=True, linewidth=0.2, color='gray', alpha=0.5)
        gl2.top_labels = False
        gl2.right_labels = False
        
        # Title for right plot
        ax2.set_title('Below Normal P-values', fontsize=12, pad=10)
        
        filename = f'pvalue_maps_{states_lower[state]}.png'
        filepath = os.path.join(output_dir, filename)
        
        # Save the figure
        plt.savefig(filepath, 
                   dpi=300,              # Resolution
                   bbox_inches='tight',   # Trim extra white space
                   facecolor='white',     # White background
                   format='png')          # File format
        
        # Show the plot (optional)
        # plt.show()
        
        # Close the figure to free memory
        plt.close()

    

In [17]:
create_probability_maps(freq_all_world, all_points_idx_climate, output_dir=f'outcome/map/{climate_var}')

100%|██████████| 12/12 [05:49<00:00, 29.11s/it]


In [21]:
create_dual_p_value_maps(significance_all_world, all_points_idx_climate, output_dir=f'outcome/map/{climate_var}')

100%|██████████| 3/3 [00:12<00:00,  4.18s/it]
