In [None]:
import os
import sys
import pandas as pd
import numpy as np
import xarray as xr
import metpy.calc as mpcalc
from metpy.units import units
import dask
import dask.array as da
import datetime
from datetime import datetime
import math
from time import strptime
import seaborn as sns
# from windrose import WindroseAxes
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.units as munits
import matplotlib.patheffects as path_effects
import matplotlib.ticker as ticker
import matplotlib.colors as colors
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FuncFormatter
from matplotlib.ticker import (MultipleLocator, AutoLocator, MaxNLocator,ScalarFormatter, LinearLocator )
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and later
# releases dropped the old names, so use whichever spelling this installation provides.
_WHITEGRID_STYLE = next(
    (style_name for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid')
     if style_name in plt.style.available),
    'seaborn-whitegrid',
)
plt.style.use(_WHITEGRID_STYLE)

#cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.geoaxes import GeoAxes
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1 import AxesGrid

from dask.diagnostics import ProgressBar
get_ipython().run_line_magic('matplotlib', 'inline')

def custom_tick_formatter(x, pos):
    """Label every other tick with one decimal place; leave the rest blank.

    Written for use with ``matplotlib.ticker.FuncFormatter``. For ticks at
    [0, 1.3, 2.6, 3.9, 5.2] the even positions (0, 2, 4) are labelled
    "0.0", "2.6" and "5.2", and the odd positions get an empty label.

    Parameters
    ----------
    x : float
        Tick value.
    pos : int or None
        Tick index. Matplotlib passes ``None`` when it formats a value that is
        not tied to a tick position; the value is then always formatted.

    Returns
    -------
    str
        ``x`` with one decimal place, or ``""`` for odd positions.
    """
    # Guard against pos=None, which would otherwise raise TypeError on `None % 2`.
    if pos is None or pos % 2 == 0:
        return f"{x:.1f}"
    return ""

def xarray_describe(da):
    """Summarise a DataArray the way ``pandas.DataFrame.describe`` does.

    Parameters
    ----------
    da : xarray.DataArray
        Array to summarise; every statistic is taken over all of its values.

    Returns
    -------
    pandas.DataFrame
        One row, indexed by ``da.name``, with the columns
        count, mean, std, min, 25%, 50%, 75%, max (in that order).
    """
    # (column label, zero-argument reducer), listed in output column order.
    reducers = (
        ('count', da.count),
        ('mean', da.mean),
        ('std', da.std),
        ('min', da.min),
        ('25%', lambda: da.quantile(0.25)),
        ('50%', lambda: da.quantile(0.5)),
        ('75%', lambda: da.quantile(0.75)),
        ('max', da.max),
    )
    summary = {label: reduce().values for label, reduce in reducers}
    return pd.DataFrame(summary, index=[da.name])


### TEMP plot

In [None]:
#========================================================================================================
# Process data: ERA5 temperature + cluster labels -> per-cluster time means
#========================================================================================================
# Load the ERA5 file from start_date onwards straight into memory.
start_date = '2004-01-01'
dsERA5_raw = xr.open_mfdataset(
    '/lcrc/globalscratch/yfeng/era5_nudgede3sm_johannes_labled/era5/Daily_ERA5_Temp2m_2000thru2014.nc'
).sel(time=slice(start_date, None)).compute()
# dsERA5_raw is already loaded, so the variable only needs renaming (no second .compute()).
ERA5_T = dsERA5_raw['var167'].rename('T')

# Cluster labels: drop incomplete rows, order by time, then parse the time column once.
colleague = (
    pd.read_csv('/lcrc/globalscratch/yfeng/era5_nudgede3sm_johannes_labled/clusters-new.dat')
    .dropna()
    .sort_values(by='time', ascending=True)
)
colleague['time'] = pd.to_datetime(colleague['time'])
# Keep labels up to 2015-01-01 and expose the 'class' column as 'label'.
# rename() returns a new frame, which avoids mutating a filtered slice in place.
colleague = (
    colleague.loc[colleague['time'] <= '2015-01-01']
    .rename(columns={'class': 'label'})
)
# Convert the pandas DataFrame to an xarray Dataset indexed by time
class_dataset = colleague.set_index('time').to_xarray()

# Merge temperature and labels; the inner join keeps only time stamps present in both.
dsERA5 = xr.merge([ERA5_T, class_dataset], join='inner')
dsERA5 = dsERA5.sel(time=slice('2005-01-01', '2015-01-01'))

# Mean temperature of each calendar month over the selected 2005-2015 period
era5T_monthly_avg = dsERA5['T'].groupby('time.month').mean(dim='time')

# Re-wrap the monthly means with explicit (month, lat, lon) dimensions.
# Note: No 'lev' in the ERA5 dimensions
era5T_monthly_mean = xr.DataArray(
    era5T_monthly_avg.values,
    dims=('month', 'lat', 'lon'),
    coords={
        'month': era5T_monthly_avg.month.values,
        'lat': dsERA5.lat,
        'lon': dsERA5.lon,
    },
)

# Look up the matching calendar-month mean for every time step so the result
# lines up with the time axis (and therefore with the label array).
era5T_month_indices = dsERA5['time.month'] - 1  # Convert from 1-12 to 0-11
era5T_monthly_mean_broadcasted = era5T_monthly_mean.isel(month=era5T_month_indices)
dsERA5['T_MonthlyMean'] = era5T_monthly_mean_broadcasted

# Anomaly = temperature minus the mean of its calendar month
era5T_anomaly = dsERA5['T'] - era5T_monthly_mean_broadcasted
dsERA5['T_Anomaly'] = era5T_anomaly

# Boolean time masks, one per cluster label
bool_class1, bool_class2, bool_class3, bool_class4 = (
    dsERA5['label'] == class_id for class_id in (1, 2, 3, 4)
)

# Subset the Dataset to the time steps of each cluster
dsERA5_class1, dsERA5_class2, dsERA5_class3, dsERA5_class4 = (
    dsERA5.where(class_mask, drop=True)
    for class_mask in (bool_class1, bool_class2, bool_class3, bool_class4)
)

# Time mean of every variable for each cluster
dsERA5_mean_class1, dsERA5_mean_class2, dsERA5_mean_class3, dsERA5_mean_class4 = (
    class_subset.mean('time').compute()
    for class_subset in (dsERA5_class1, dsERA5_class2, dsERA5_class3, dsERA5_class4)
)

### ERA5 temp mean

In [None]:
#========================================================================================================
# Plot the ERA5 cluster-mean temperature: one north-polar map per cluster, shared colourbar
#========================================================================================================
# Colour-scale limits and contour spacing (K)
vmin0 = 240
vmax0 = 300
vint0 = 5

# Font sizes: panel titles / colourbar title
fsize=23
lsize=21

bounds = np.arange(vmin0, vmax0+vint0, vint0)
icmap = 'jet'

# Discrete palette kept available under `cmap`; this figure passes `icmap` to pcolormesh.
cmap = colors.ListedColormap([
    '#EF293D', '#EF2354', '#EE206C', '#EE2086',
    '#CB2891', '#942A90', '#662F92',
    '#283A97', '#0058A9', '#0073BC', '#0092D5',
    '#00B4F0', '#00B0C3', '#00AB9C', '#00A878',
    '#00A654', '#00AC4D', '#5EBC46', '#FFF303',
    '#FDB417', '#F6821F', '#F25B22', '#EF3029',
])

# North-polar stereographic map centred on the date line
proj = ccrs.NorthPolarStereo(central_longitude=180, true_scale_latitude=70, globe=None)

# 2x2 grid of map panels
fig, axs = plt.subplots(
    nrows=2, ncols=2, figsize=(15, 15),
    subplot_kw={'projection': proj}, sharex=True, sharey=True,
)

# Location of the star marker drawn on every panel (degrees)
scatter_lon = -156.8
scatter_lat = 71.3

# (cluster-mean dataset, panel title) in the order of axs.flat
panels = [
    (dsERA5_mean_class1, "ERA5 Cluster 1"),
    (dsERA5_mean_class2, "ERA5 Cluster 2"),
    (dsERA5_mean_class3, "ERA5 Cluster 3"),
    (dsERA5_mean_class4, "ERA5 Cluster 4"),
]
data_crs = ccrs.PlateCarree()  # the fields and the marker are given in lon/lat

for i, ax in enumerate(axs.flat):
    ds_mean, title = panels[i]

    # Dashed graticule with meridians every 30 degrees
    gl = ax.gridlines(draw_labels=False, linewidth=0.5, color='gray', alpha=0.8, linestyle='--')
    gl.xlocator = plt.FixedLocator(np.arange(-180, 181, 30))

    # Longitude labels are placed by hand along the 45N circle
    for lon in np.arange(-180, 151, 30):
        x, y = ax.projection.transform_point(lon, 45, ccrs.Geodetic())
        ax.text(x, y, f'{lon}°', ha='center', va='center', transform=ax.projection,fontsize=12)

    # Filled field plus black isolines at the colourbar tick values
    im = ax.pcolormesh(
        ds_mean['lon'], ds_mean['lat'], ds_mean['T'],
        transform=data_crs, vmin=vmin0, vmax=vmax0, cmap=icmap,
    )
    cs = ax.contour(
        ds_mean['lon'], ds_mean['lat'], ds_mean['T'],
        levels=bounds, colors='black', linewidths=1, transform=data_crs,
    )

    ax.coastlines(resolution='10m', linewidth=0.3)
    ax.set_title(title, fontsize=fsize)

    # Star marker
    ax.scatter(
        scatter_lon, scatter_lat, marker='*', s=200, color='black',
        edgecolors='black', linewidths=0.5, zorder =5, transform=data_crs,
    )

    # Heavy black frame around the panel
    for spine in ax.spines.values():
        spine.set_color('k')
        spine.set_linewidth(2)

# Leave room at the bottom for one horizontal colourbar shared by all panels;
# `im` from the last panel supplies the colour mapping.
fig.subplots_adjust(left = 0.03, right = 0.97, bottom = 0.1, top = 0.95,hspace=0.1,wspace=0.1)
cbar_ax = fig.add_axes([0.10, 0.03, 0.8, 0.03])
cb = fig.colorbar(im, cax=cbar_ax, shrink=0.7, pad=0.02, orientation='horizontal', extend='both')

# Colourbar title and tick styling (the calls are applied in this order)
cb.ax.set_title('Temperature (K)', fontsize=lsize)
cb.ax.tick_params(labelsize=lsize-5, length=0)
cb.ax.xaxis.set_ticks_position('bottom')
cb.ax.xaxis.set_label_position('bottom')
cb.ax.tick_params(axis='x', direction='out', pad=5, labelrotation=0)
cb.ax.xaxis.set_tick_params(color='black', width=1.5, which='both', pad=10)
cb.ax.xaxis.set_tick_params(size=2, which='both')
cb.set_ticks(bounds)
cb.ax.tick_params(width=1.5, length=5)

figname = '/home/zhengx/Research/NSA/IMG/ERA5_05-15_Jlabel_Temp_V2.png'
plt.savefig(figname,facecolor='white', edgecolor='none')
#========================================================================================================

### ERA5 temp anomaly

In [None]:
#========================================================================================================
# Plot the ERA5 cluster-mean temperature anomaly: one north-polar map per cluster, shared colourbar
#========================================================================================================
# MULMENSTADT 2012 color scheme: set to 1 to use the discrete palette and its tick labels
index_c2012 = 0

# Font sizes: panel titles / colourbar title
fsize=23
lsize=21

# Colour-scale limits and contour spacing (K) for the default scheme
vmin0 = -4.0
vmax0 = 4.0
vint0 = 0.5

bounds = np.arange(vmin0, vmax0+vint0, vint0)
icmap = 'bwr'

cmap = colors.ListedColormap([
    '#EF293D', '#EF2354', '#EE206C', '#EE2086',
    '#CB2891', '#942A90', '#662F92',
    '#283A97', '#0058A9', '#0073BC', '#0092D5',
    '#00B4F0', '#00B0C3', '#00AB9C', '#00A878',
    '#00A654', '#00AC4D', '#5EBC46', '#FFF303',
    '#FDB417', '#F6821F', '#F25B22', '#EF3029',
])

# Colour mapping handed to pcolormesh: a linear vmin/vmax scale by default.
color_kwargs = dict(cmap=icmap, vmin=vmin0, vmax=vmax0)
if index_c2012 == 1:
    # 24 boundaries -> 23 bins, one per palette colour. BoundaryNorm takes its own
    # vmin/vmax from the first and last boundary, so nothing else needs to be set
    # (the previous `norm.vmin = vmin` referred to an undefined name).
    bounds = np.linspace(-4.5, 7, 24)
    norm = colors.BoundaryNorm(bounds, cmap.N)
    icmap = cmap
    # pcolormesh does not accept norm together with vmin/vmax, so pass the norm alone.
    color_kwargs = dict(cmap=icmap, norm=norm)

# North-polar stereographic map centred on the date line
proj = ccrs.NorthPolarStereo(central_longitude=180, true_scale_latitude=70, globe=None)

# 2x2 grid of map panels
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15, 15), subplot_kw={'projection': proj},
                        sharex=True, sharey=True)

# Location of the star marker drawn on every panel (degrees)
scatter_lon = -156.8
scatter_lat = 71.3

# (cluster-mean dataset, panel title) in the order of axs.flat
panels = [
    (dsERA5_mean_class1, "ERA5 Cluster 1"),
    (dsERA5_mean_class2, "ERA5 Cluster 2"),
    (dsERA5_mean_class3, "ERA5 Cluster 3"),
    (dsERA5_mean_class4, "ERA5 Cluster 4"),
]

for i, ax in enumerate(axs.flat):
    ds_mean, title = panels[i]

    # Dashed graticule with meridians every 30 degrees
    gl = ax.gridlines(draw_labels=False, linewidth=0.5, color='gray', alpha=0.8, linestyle='--')
    gl.xlocator = plt.FixedLocator(np.arange(-180, 181, 30))
    # Manually add longitude labels around the polar projection (along 45N)
    for lon in np.arange(-180, 151, 30):
        x, y = ax.projection.transform_point(lon, 45, ccrs.Geodetic())
        ax.text(x, y, f'{lon}°', ha='center', va='center', transform=ax.projection,fontsize=12)

    # Filled anomaly field
    im = ax.pcolormesh(ds_mean['lon'], ds_mean['lat'], ds_mean['T_Anomaly'],
                       transform=ccrs.PlateCarree(), **color_kwargs)

    # Black isolines at the colourbar tick values
    cs = ax.contour(ds_mean['lon'], ds_mean['lat'], ds_mean['T_Anomaly'], levels=bounds, colors='black',
                    linewidths=1, transform=ccrs.PlateCarree())

    ax.coastlines(resolution='10m', linewidth=0.3)
    ax.set_title(title, fontsize=fsize)

    # Star marker (white with a black edge so it shows on the diverging colours)
    ax.scatter(scatter_lon, scatter_lat, marker='*', s=700, color='w',
               edgecolors='k', linewidths=0.5, zorder =5, transform=ccrs.PlateCarree())

    # Heavy black frame around the panel
    for spine in ax.spines.values():
        spine.set_color('k')
        spine.set_linewidth(2)

# Leave room at the bottom for one horizontal colourbar shared by all panels;
# `im` from the last panel supplies the colour mapping.
fig.subplots_adjust(left = 0.03, right = 0.97, bottom = 0.1, top = 0.95,hspace=0.1,wspace=0.1)
cbar_ax = fig.add_axes([0.10, 0.03, 0.8, 0.03])

cb = fig.colorbar(im, cax=cbar_ax, shrink=0.7, pad=0.02, orientation='horizontal', extend='both')

# Colourbar title and tick styling (the calls are applied in this order)
cb.ax.set_title('Temperature anomaly (K)', fontsize=lsize)
cb.ax.tick_params(labelsize=lsize-5, length=0)
cb.ax.xaxis.set_ticks_position('bottom')
cb.ax.xaxis.set_label_position('bottom')
cb.ax.tick_params(axis='x', direction='out', pad=5, labelrotation=0)
cb.ax.xaxis.set_tick_params(color='black', width=1.5, which='both', pad=10)
cb.ax.xaxis.set_tick_params(size=2, which='both')
cb.set_ticks(bounds)
# Label only the whole-number ticks; half-degree ticks stay blank
cb.set_ticklabels(['' if i % 1.0 != 0 else str(i) for i in bounds])
#---------------------------------------------------------
# specify ticks matching 2012 paper (one label per entry of the 24-element `bounds`)
string_values = np.array(['', '-4', '', '-3', '', '-2', '', '-1', '', '0', '', '1', '', '2',
       '', '3', '', '4', '', '5', '', '6', '', '7'], dtype='<U2')
if index_c2012 == 1:
    cb.set_ticks(bounds,labels=string_values)
#---------------------------------------------------------
cb.ax.tick_params(width=1.5, length=5)

figname = '/home/zhengx/Research/NSA/IMG/ERA5_05-15_Jlabel_TempAno_V2.png'
if index_c2012 == 1:
    figname = '/home/zhengx/Research/NSA/IMG/ERA5_Jlabel_TempAno_C2012.png'
plt.savefig(figname,facecolor='white', edgecolor='none')
#========================================================================================================

## Difference with E3SM

In [None]:
#========================================================================================================
# Process data: nudged E3SM temperature + cluster labels -> per-cluster time means
#========================================================================================================
# Open all TBOT files lazily, keep 45N-90N, then load the temperature into memory.
Nug_E3SM = xr.open_mfdataset(
    '/lcrc/project/land_atmos_modeling/caghili/NSA/Nudged_E3SM/TBOT_ndg/TBOT*.nc'
).sel(lat=slice(45, 90))
Nug_E3SM_T = Nug_E3SM['T'].compute()

# Cluster labels: drop incomplete rows, order by time, then parse the time column once.
colleague = (
    pd.read_csv('/lcrc/globalscratch/yfeng/era5_nudgede3sm_johannes_labled/clusters-new.dat')
    .dropna()
    .sort_values(by='time', ascending=True)
)
colleague['time'] = pd.to_datetime(colleague['time'])
# Keep labels up to 2015-01-01 and expose the 'class' column as 'label'.
# rename() returns a new frame, which avoids mutating a filtered slice in place.
colleague = (
    colleague.loc[colleague['time'] <= '2015-01-01']
    .rename(columns={'class': 'label'})
)
# Convert the pandas DataFrame to an xarray Dataset indexed by time
class_dataset = colleague.set_index('time').to_xarray()

# Convert the model time index to a pandas DatetimeIndex so it can be joined with the labels
Nug_E3SM_T = Nug_E3SM_T.assign_coords(time=Nug_E3SM_T.indexes['time'].to_datetimeindex())

# Merge temperature and labels; the inner join keeps only time stamps present in both.
dsE3SM = xr.merge([Nug_E3SM_T, class_dataset], join='inner')
# NOTE(review): this keeps 2005-01-01..2007-01-01 only, whereas the ERA5 cell keeps
# 2005-2015 and the difference figure's file name says "05-15" -- confirm that the
# two composites are meant to cover different periods.
dsE3SM = dsE3SM.sel(time=slice('2005-01-01', '2007-01-01'))

# Mean temperature of each calendar month over the selected two-year period
monthly_avg = dsE3SM['T'].groupby('time.month').mean(dim='time')

# Re-wrap the monthly means with explicit (month, lev, lat, lon) dimensions.
# Note: unlike ERA5, the model field carries a 'lev' dimension.
tbot_2yr_monthly_mean = xr.DataArray(
    monthly_avg.values,
    dims=('month', 'lev', 'lat', 'lon'),
    coords={
        'month': monthly_avg.month.values,
        'lev': dsE3SM.lev,
        'lat': dsE3SM.lat,
        'lon': dsE3SM.lon,
    },
)

# Look up the matching calendar-month mean for every time step
month_indices = dsE3SM['time.month'] - 1  # Convert from 1-12 to 0-11
tbot_2yr_monthly_mean_broadcasted = tbot_2yr_monthly_mean.isel(month=month_indices)
dsE3SM['T_MonthlyMean'] = tbot_2yr_monthly_mean_broadcasted

# Anomaly = temperature minus the mean of its calendar month
anomaly = dsE3SM['T'] - tbot_2yr_monthly_mean_broadcasted
dsE3SM['T_Anomaly'] = anomaly

# Boolean time masks, one per cluster label
bool_class1, bool_class2, bool_class3, bool_class4 = (
    dsE3SM['label'] == class_id for class_id in (1, 2, 3, 4)
)

# Subset the Dataset to the time steps of each cluster
dsE3SM_class1, dsE3SM_class2, dsE3SM_class3, dsE3SM_class4 = (
    dsE3SM.where(class_mask, drop=True)
    for class_mask in (bool_class1, bool_class2, bool_class3, bool_class4)
)

# Time mean of every variable for each cluster
dsE3SM_mean_class1, dsE3SM_mean_class2, dsE3SM_mean_class3, dsE3SM_mean_class4 = (
    class_subset.mean('time').compute()
    for class_subset in (dsE3SM_class1, dsE3SM_class2, dsE3SM_class3, dsE3SM_class4)
)

### E3SM-ERA5 temp mean

In [None]:
#=================================================================
# Simple regridding function using box averaging
#=================================================================
def average_to_coarser_grid(A, B):
    """Average the finer field ``B`` onto the lat/lon grid of the coarser field ``A``.

    Every grid point of ``A`` is treated as the centre of a box one mean grid
    spacing wide in latitude and longitude. All points of ``B`` whose
    coordinates fall inside the half-open box ``[centre - half, centre + half)``
    are averaged (unweighted) and written to that point.

    Parameters
    ----------
    A : xarray.DataArray
        Target (coarser) grid with 1-D ``lat`` and ``lon`` coordinates. Only its
        coordinates and shape are used. At least two points are needed along
        each of ``lat`` and ``lon`` for the spacing to be defined.
    B : xarray.DataArray
        Source (finer) field with 1-D ``lat`` and ``lon`` coordinates that use
        the same longitude convention as ``A``; boxes are not wrapped across
        the longitude seam.

    Returns
    -------
    xarray.DataArray
        Array shaped like ``A`` holding the box means of ``B``. Boxes that
        contain no point of ``B`` are NaN.
    """
    # Half-widths from the mean spacing of A; abs() keeps the boxes valid when a
    # coordinate is stored in descending order.
    lat_half_width = abs(np.diff(A.lat.values).mean()) / 2
    lon_half_width = abs(np.diff(A.lon.values).mean()) / 2

    fine_lat = B.lat.values
    fine_lon = B.lon.values

    # NaN needs a floating dtype; keep A's dtype when it already is one.
    out_dtype = A.dtype if np.issubdtype(A.dtype, np.floating) else float
    averaged_B = xr.full_like(A, fill_value=np.nan, dtype=out_dtype)

    # Column indices of B inside each target longitude box. They do not depend
    # on latitude, so they are computed once instead of once per cell.
    lon_members = [
        np.flatnonzero((fine_lon >= lon - lon_half_width) & (fine_lon < lon + lon_half_width))
        for lon in A.lon.values
    ]

    for lat in A.lat.values:
        # Rows of B inside this latitude band, selected once per band
        lat_members = np.flatnonzero((fine_lat >= lat - lat_half_width) &
                                     (fine_lat < lat + lat_half_width))
        band_of_B = B.isel(lat=lat_members)

        for lon, lon_member in zip(A.lon.values, lon_members):
            cells_in_B = band_of_B.isel(lon=lon_member)
            # The mean of an empty selection is NaN, which marks boxes B does not cover
            averaged_B.loc[dict(lat=lat, lon=lon)] = cells_in_B.mean(dim=['lat', 'lon'])

    return averaged_B

# If A and B still carry a time axis, align B to A's time points before regridding, e.g.
# B = B.reindex(time=A.time, method='nearest')
#=================================================================

In [None]:
#========================================================================================================
# Regrid ERA5 onto the E3SM grid and compute E3SM - ERA5 for all four clusters.
# Both fields must use the same longitude convention before they are compared.
def _wrap_lon_180(field):
    """Return ``field`` with 0..360 longitudes relabelled to -180..180, sorted ascending.

    Subtracting 180 from every longitude would only relabel the axis (the
    column at 0 degrees would be called -180), which rotates the field by half
    a turn relative to the other dataset. Wrapping keeps every column at its
    true longitude; sorting restores a monotonic axis, which
    average_to_coarser_grid needs to derive the grid spacing.
    """
    wrapped_lon = ((field['lon'] + 180) % 360) - 180
    return field.assign_coords(lon=wrapped_lon).sortby('lon')

# For each cluster: take the first E3SM level on a -180..180 axis, reverse the
# ERA5 latitude order, box-average ERA5 onto the E3SM grid and subtract.
# cluster 1
tmpE3SM1 = _wrap_lon_180(dsE3SM_mean_class1['T'].isel(lev=0))
tmpERA51 = dsERA5_mean_class1['T'].isel(lat=slice(None, None, -1))
ERA5T_C1_Remap = average_to_coarser_grid(tmpE3SM1, tmpERA51)
diffT_C1 = tmpE3SM1 - ERA5T_C1_Remap
print('Class 1 done')
# cluster 2
tmpE3SM2 = _wrap_lon_180(dsE3SM_mean_class2['T'].isel(lev=0))
tmpERA52 = dsERA5_mean_class2['T'].isel(lat=slice(None, None, -1))
ERA5T_C2_Remap = average_to_coarser_grid(tmpE3SM2, tmpERA52)
diffT_C2 = tmpE3SM2 - ERA5T_C2_Remap
print('Class 2 done')
# cluster 3
tmpE3SM3 = _wrap_lon_180(dsE3SM_mean_class3['T'].isel(lev=0))
tmpERA53 = dsERA5_mean_class3['T'].isel(lat=slice(None, None, -1))
ERA5T_C3_Remap = average_to_coarser_grid(tmpE3SM3, tmpERA53)
diffT_C3 = tmpE3SM3 - ERA5T_C3_Remap
print('Class 3 done')
# cluster 4
tmpE3SM4 = _wrap_lon_180(dsE3SM_mean_class4['T'].isel(lev=0))
tmpERA54 = dsERA5_mean_class4['T'].isel(lat=slice(None, None, -1))
ERA5T_C4_Remap = average_to_coarser_grid(tmpE3SM4, tmpERA54)
diffT_C4 = tmpE3SM4 - ERA5T_C4_Remap
print('Class 4 done')

#========================================================================================================


In [None]:
#========================================================================================================
# Plot the regridded ERA5 cluster-mean temperature: one north-polar map per cluster, shared colourbar
#========================================================================================================
# Colour-scale limits and contour spacing (K)
vmin0 = 240
vmax0 = 300
vint0 = 5

# Font sizes: panel titles / colourbar title
fsize=23
lsize=21

bounds = np.arange(vmin0, vmax0+vint0, vint0)
icmap = 'jet'

# Discrete palette kept available under `cmap`; this figure passes `icmap` to pcolormesh.
cmap = colors.ListedColormap([
    '#EF293D', '#EF2354', '#EE206C', '#EE2086',
    '#CB2891', '#942A90', '#662F92',
    '#283A97', '#0058A9', '#0073BC', '#0092D5',
    '#00B4F0', '#00B0C3', '#00AB9C', '#00A878',
    '#00A654', '#00AC4D', '#5EBC46', '#FFF303',
    '#FDB417', '#F6821F', '#F25B22', '#EF3029',
])

# North-polar stereographic map centred on the date line
proj = ccrs.NorthPolarStereo(central_longitude=180, true_scale_latitude=70, globe=None)

# 2x2 grid of map panels
fig, axs = plt.subplots(
    nrows=2, ncols=2, figsize=(15, 15),
    subplot_kw={'projection': proj}, sharex=True, sharey=True,
)

# Location of the star marker drawn on every panel (degrees)
scatter_lon = -156.8
scatter_lat = 71.3

# (regridded ERA5 field, panel title) in the order of axs.flat
panels = [
    (ERA5T_C1_Remap, "ERA5 Cluster 1"),
    (ERA5T_C2_Remap, "ERA5 Cluster 2"),
    (ERA5T_C3_Remap, "ERA5 Cluster 3"),
    (ERA5T_C4_Remap, "ERA5 Cluster 4"),
]
data_crs = ccrs.PlateCarree()  # the fields and the marker are given in lon/lat

for i, ax in enumerate(axs.flat):
    ds_mean, title = panels[i]

    # Dashed graticule with meridians every 30 degrees
    gl = ax.gridlines(draw_labels=False, linewidth=0.5, color='gray', alpha=0.8, linestyle='--')
    gl.xlocator = plt.FixedLocator(np.arange(-180, 181, 30))

    # Longitude labels are placed by hand along the 45N circle
    for lon in np.arange(-180, 151, 30):
        x, y = ax.projection.transform_point(lon, 45, ccrs.Geodetic())
        ax.text(x, y, f'{lon}°', ha='center', va='center', transform=ax.projection,fontsize=12)

    # Filled field plus black isolines at the colourbar tick values
    im = ax.pcolormesh(
        ds_mean['lon'], ds_mean['lat'], ds_mean,
        transform=data_crs, vmin=vmin0, vmax=vmax0, cmap=icmap,
    )
    cs = ax.contour(
        ds_mean['lon'], ds_mean['lat'], ds_mean,
        levels=bounds, colors='black', linewidths=1, transform=data_crs,
    )

    ax.coastlines(resolution='10m', linewidth=0.3)
    ax.set_title(title, fontsize=fsize)

    # Star marker
    ax.scatter(
        scatter_lon, scatter_lat, marker='*', s=200, color='black',
        edgecolors='black', linewidths=0.5, zorder =5, transform=data_crs,
    )

    # Heavy black frame around the panel
    for spine in ax.spines.values():
        spine.set_color('k')
        spine.set_linewidth(2)

# Leave room at the bottom for one horizontal colourbar shared by all panels;
# `im` from the last panel supplies the colour mapping.
fig.subplots_adjust(left = 0.03, right = 0.97, bottom = 0.1, top = 0.95,hspace=0.1,wspace=0.1)
cbar_ax = fig.add_axes([0.10, 0.03, 0.8, 0.03])
cb = fig.colorbar(im, cax=cbar_ax, shrink=0.7, pad=0.02, orientation='horizontal', extend='both')

# Colourbar title and tick styling (the calls are applied in this order)
cb.ax.set_title('Temperature (K)', fontsize=lsize)
cb.ax.tick_params(labelsize=lsize-5, length=0)
cb.ax.xaxis.set_ticks_position('bottom')
cb.ax.xaxis.set_label_position('bottom')
cb.ax.tick_params(axis='x', direction='out', pad=5, labelrotation=0)
cb.ax.xaxis.set_tick_params(color='black', width=1.5, which='both', pad=10)
cb.ax.xaxis.set_tick_params(size=2, which='both')
cb.set_ticks(bounds)
cb.ax.tick_params(width=1.5, length=5)

figname = '/home/zhengx/Research/NSA/IMG/ERA5_05-15_Jlabel_Temp_Regrid_V2.png'
plt.savefig(figname,facecolor='white', edgecolor='none')
#========================================================================================================

In [None]:
#========================================================================================================
# Plot the E3SM - ERA5 cluster-mean temperature difference: one north-polar map per cluster
#========================================================================================================
# Symmetric colour-scale limits (K)
vmin0 = -16
vmax0 = 16

# Font sizes: panel titles / colourbar title
fsize=23
lsize=21

# Contour levels and colourbar ticks every 4 K
bounds = np.arange(vmin0, vmax0+4, 4)
icmap = 'coolwarm'

# North-polar stereographic map centred on the date line
proj = ccrs.NorthPolarStereo(central_longitude=180, true_scale_latitude=70, globe=None)

# 2x2 grid of map panels
fig, axs = plt.subplots(
    nrows=2, ncols=2, figsize=(15, 15),
    subplot_kw={'projection': proj}, sharex=True, sharey=True,
)

# Location of the star marker drawn on every panel (degrees)
scatter_lon = -156.8
scatter_lat = 71.3

# (difference field, panel title) in the order of axs.flat
panels = [
    (diffT_C1, "E3SM-ERA5 Cluster 1"),
    (diffT_C2, "E3SM-ERA5 Cluster 2"),
    (diffT_C3, "E3SM-ERA5 Cluster 3"),
    (diffT_C4, "E3SM-ERA5 Cluster 4"),
]
data_crs = ccrs.PlateCarree()  # the fields and the marker are given in lon/lat

for i, ax in enumerate(axs.flat):
    pdiff, title = panels[i]

    # Show everything north of 45N
    ax.set_extent([-180, 180, 45, 90], crs=ccrs.PlateCarree())

    # Dashed graticule with meridians every 30 degrees
    gl = ax.gridlines(draw_labels=False, linewidth=0.5, color='gray', alpha=0.8, linestyle='--')
    gl.xlocator = plt.FixedLocator(np.arange(-180, 181, 30))

    # Longitude labels are placed by hand along the 45N circle
    for lon in np.arange(-180, 151, 30):
        x, y = ax.projection.transform_point(lon, 45, ccrs.Geodetic())
        ax.text(x, y, f'{lon}°', ha='center', va='center', transform=ax.projection,fontsize=12)

    # Filled difference field plus black isolines at the colourbar tick values
    im = ax.pcolormesh(
        pdiff['lon'], pdiff['lat'], pdiff,
        transform=data_crs, vmin=vmin0, vmax=vmax0, cmap=icmap,
    )
    cs = ax.contour(
        pdiff['lon'], pdiff['lat'], pdiff,
        levels=bounds, colors='black', linewidths=1, transform=data_crs,
    )

    ax.coastlines(resolution='10m', linewidth=0.3)
    ax.set_title(title, fontsize=fsize)

    # Star marker
    ax.scatter(
        scatter_lon, scatter_lat, marker='*', s=200, color='black',
        edgecolors='black', linewidths=0.5, zorder =5, transform=data_crs,
    )

    # Heavy black frame around the panel
    for spine in ax.spines.values():
        spine.set_color('k')
        spine.set_linewidth(2)

# Leave room at the bottom for one horizontal colourbar shared by all panels;
# `im` from the last panel supplies the colour mapping.
fig.subplots_adjust(left = 0.03, right = 0.97, bottom = 0.1, top = 0.95,hspace=0.1,wspace=0.1)
cbar_ax = fig.add_axes([0.10, 0.03, 0.8, 0.03])
cb = fig.colorbar(im, cax=cbar_ax, shrink=0.7, pad=0.02, orientation='horizontal', extend='both')

# Colourbar title and tick styling (the calls are applied in this order)
cb.ax.set_title('Temperature Difference (K)', fontsize=lsize)
cb.ax.tick_params(labelsize=lsize-5, length=0)
cb.ax.xaxis.set_ticks_position('bottom')
cb.ax.xaxis.set_label_position('bottom')
cb.ax.tick_params(axis='x', direction='out', pad=5, labelrotation=0)
cb.ax.xaxis.set_tick_params(color='black', width=1.5, which='both', pad=10)
cb.ax.xaxis.set_tick_params(size=2, which='both')
cb.set_ticks(bounds)
cb.ax.tick_params(width=1.5, length=5)

figname = '/home/zhengx/Research/NSA/IMG/Diff_E3SM-ERA5_05-15_Jlabel_Temp_Regrid_V2.png'
plt.savefig(figname,facecolor='white', edgecolor='none')
#========================================================================================================

In [None]:
# Debug probe: pull out the part of the cluster-1 difference field that falls inside a
# single grid box, using the same half-open box test as average_to_coarser_grid.
B=diffT_C1

# lat_min/lat_max/lon_min/lon_max exist only as locals inside average_to_coarser_grid,
# so they are defined here; without them this cell raises NameError on a fresh kernel.
# The box is one mean grid spacing wide and centred on the star-marker location used
# in the maps -- change probe_lat/probe_lon to inspect a different box.
probe_lat, probe_lon = 71.3, -156.8
lat_resolution = abs(np.diff(B.lat.values).mean())
lon_resolution = abs(np.diff(B.lon.values).mean())
lat_min = probe_lat - lat_resolution / 2
lat_max = probe_lat + lat_resolution / 2
lon_min = probe_lon - lon_resolution / 2
lon_max = probe_lon + lon_resolution / 2

lloc = B.where((B.lat >= lat_min) & (B.lat < lat_max) &
                     (B.lon >= lon_min) & (B.lon < lon_max), drop=True)