In [None]:
import os
import sys
import pandas as pd
import numpy as np
import xarray as xr
import metpy.calc as mpcalc
from metpy.units import units
import dask
import dask.array as da
import datetime
from datetime import datetime
import math
from time import strptime
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.units as munits
import matplotlib.patheffects as path_effects
import matplotlib.ticker as ticker
import matplotlib.colors as colors
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import FuncFormatter
from matplotlib.ticker import (MultipleLocator, AutoLocator, MaxNLocator,ScalarFormatter, LinearLocator )
# The bundled seaborn styles were renamed to 'seaborn-v0_8-*' in matplotlib 3.6 and the
# old names were removed in later releases. Use whichever name this installation provides
# so the notebook runs on both old and new matplotlib.
plt.style.use(
    'seaborn-v0_8-whitegrid'
    if 'seaborn-v0_8-whitegrid' in plt.style.available
    else 'seaborn-whitegrid'
)

#cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.geoaxes import GeoAxes
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1 import AxesGrid

from dask.diagnostics import ProgressBar
# Select the inline matplotlib backend (equivalent to the `%matplotlib inline` magic).
get_ipython().run_line_magic('matplotlib', 'inline')

# Suppress scientific notation when numpy prints arrays.
np.set_printoptions(suppress=True)

# Tick formatter that labels every other tick with one decimal place.
def custom_tick_formatter(x, pos):
    """Format tick labels at even positions with one decimal; blank the odd ones.

    Written for use with ``matplotlib.ticker.FuncFormatter``. For tick values
    ``[0, 1.3, 2.6, 3.9, 5.2]`` the even positions (0, 2, 4) are labelled
    ``'0.0'``, ``'2.6'`` and ``'5.2'``; positions 1 and 3 get an empty label.

    Parameters
    ----------
    x : float
        Tick value.
    pos : int or None
        Tick index. Matplotlib calls formatters with ``pos=None`` when a value
        is formatted outside of a tick position; in that case the value is
        always formatted instead of raising ``TypeError`` on ``None % 2``.

    Returns
    -------
    str
        ``x`` with one decimal place, or ``""`` for odd positions.
    """
    if pos is None or pos % 2 == 0:
        return f"{x:.1f}"
    return ""

# pandas-style ``.describe()`` summary for an xarray object
def xarray_describe(da):
    """Summarise ``da`` in a one-row DataFrame, mimicking ``pandas`` ``.describe()``.

    Each statistic is obtained by calling the corresponding reduction on ``da``
    without arguments and reading the result's ``.values``.

    Parameters
    ----------
    da : object exposing ``count``, ``mean``, ``std``, ``min``, ``max``,
        ``quantile`` and ``name`` (e.g. an ``xarray.DataArray``).
        Note: inside this function the parameter shadows the module-level
        ``dask.array`` alias ``da``.

    Returns
    -------
    pandas.DataFrame
        Columns ``count, mean, std, min, 25%, 50%, 75%, max``; the single row
        is indexed by ``da.name``.
    """
    # (column label, reduction) pairs, in the column order of the output.
    reducers = (
        ('count', lambda arr: arr.count()),
        ('mean', lambda arr: arr.mean()),
        ('std', lambda arr: arr.std()),
        ('min', lambda arr: arr.min()),
        ('25%', lambda arr: arr.quantile(0.25)),
        ('50%', lambda arr: arr.quantile(0.5)),
        ('75%', lambda arr: arr.quantile(0.75)),
        ('max', lambda arr: arr.max()),
    )
    summary = {label: reduce(da).values for label, reduce in reducers}
    return pd.DataFrame(summary, index=[da.name])


### TEMP plots

In [None]:
#========================================================================================================
# Process data
#========================================================================================================
# Variable 'T' from the nudged-E3SM TBOT files, restricted to lat 45-90 and loaded into memory.
Nug_E3SM = xr.open_mfdataset(
    '/lcrc/project/land_atmos_modeling/caghili/NSA/Nudged_E3SM/TBOT_ndg/TBOT*.nc'
).sel(lat=slice(45, 90))
Nug_E3SM_T = Nug_E3SM['T'].compute()

# Cluster labels.
# The timestamps are parsed to datetime BEFORE sorting so that the sort is chronological
# (sorting the raw strings first would order them lexicographically). The chain builds a
# new frame at each step instead of mutating in place, so the cell can be re-run safely.
colleague = (
    pd.read_csv('/lcrc/globalscratch/yfeng/era5_nudgede3sm_johannes_labled/clusters-new.dat')
    .dropna()
    .assign(time=lambda df: pd.to_datetime(df['time']))
    .sort_values(by='time', ascending=True)
    .loc[lambda df: df['time'] <= '2015-01-01']
    .rename(columns={'class': 'label'})  # the 'class' column becomes 'label'
)
# Convert the pandas DataFrame to an xarray Dataset indexed by time
class_dataset = colleague.set_index('time').to_xarray()

# Replace the model's 'time' index with a pandas DatetimeIndex so it can be
# aligned with the label table
Nug_E3SM_T['time'] = Nug_E3SM_T.indexes['time'].to_datetimeindex()
# Merge, keeping only time stamps present in both inputs, then cut the analysis window
dsE3SM = xr.merge([Nug_E3SM_T, class_dataset], join='inner')
dsE3SM = dsE3SM.sel(time=slice('2005-01-01', '2007-01-01'))

# Mean of T for each calendar month over the selected window.
# Item access (dsE3SM['T']) is used rather than attribute access because `.T`
# is also the conventional name for a transpose property.
monthly_avg = dsE3SM['T'].groupby('time.month').mean(dim='time')

# Re-wrap the monthly means with explicit dims/coords ('lev' is part of the dims)
tbot_2yr_monthly_mean = xr.DataArray(monthly_avg.values, dims=('month', 'lev', 'lat', 'lon'),
                                     coords={'month': monthly_avg.month.values, 'lev': dsE3SM.lev,
                                             'lat': dsE3SM.lat, 'lon': dsE3SM.lon})

# Pick, for every time step, the monthly mean of that time step's calendar month
month_indices = dsE3SM['time.month'] - 1  # Convert from 1-12 to 0-11
tbot_2yr_monthly_mean_broadcasted = tbot_2yr_monthly_mean.isel(month=month_indices)
dsE3SM['T_MonthlyMean'] = tbot_2yr_monthly_mean_broadcasted

# Anomaly = value minus the monthly mean of its calendar month
anomaly = dsE3SM['T'] - tbot_2yr_monthly_mean_broadcasted
dsE3SM['T_Anomaly'] = anomaly

# Boolean time masks, one per cluster label
bool_class1 = (dsE3SM['label'] == 1)
bool_class2 = (dsE3SM['label'] == 2)
bool_class3 = (dsE3SM['label'] == 3)
bool_class4 = (dsE3SM['label'] == 4)

# Keep only the time steps belonging to each cluster
dsE3SM_class1 = dsE3SM.where(bool_class1, drop=True)
dsE3SM_class2 = dsE3SM.where(bool_class2, drop=True)
dsE3SM_class3 = dsE3SM.where(bool_class3, drop=True)
dsE3SM_class4 = dsE3SM.where(bool_class4, drop=True)

# Time mean of every variable for each cluster
dsE3SM_mean_class1 = dsE3SM_class1.mean('time').compute()
dsE3SM_mean_class2 = dsE3SM_class2.mean('time').compute()
dsE3SM_mean_class3 = dsE3SM_class3.mean('time').compute()
dsE3SM_mean_class4 = dsE3SM_class4.mean('time').compute()

In [None]:
#========================================================================================================
# Plotting per-cluster mean temperature (2x2 polar maps, one panel per cluster)
#========================================================================================================
# Colour scale limits and contour interval (colorbar is labelled in K)
vmin0, vmax0, vint0 = 240, 300, 5

# Font sizes: panel titles (fsize) and colorbar text (lsize)
fsize, lsize = 23, 21

bounds = np.arange(vmin0, vmax0 + vint0, vint0)
icmap = 'jet'

# Discrete palette; defined here but not used for plotting in this cell (icmap is).
cmap = colors.ListedColormap([
    '#EF293D', '#EF2354', '#EE206C', '#EE2086',
    '#CB2891', '#942A90', '#662F92',
    '#283A97', '#0058A9', '#0073BC', '#0092D5',
    '#00B4F0', '#00B0C3', '#00AB9C', '#00A878',
    '#00A654', '#00AC4D', '#5EBC46', '#FFF303',
    '#FDB417', '#F6821F', '#F25B22', '#EF3029',
])

# North-polar stereographic projection, central longitude 180, true scale at latitude 70
proj = ccrs.NorthPolarStereo(central_longitude=180, true_scale_latitude=70, globe=None)

# 2x2 grid of map panels; a single colorbar is added below them afterwards
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15, 15), subplot_kw={'projection': proj},
                        sharex=True, sharey=True)

# Location of the star marker drawn on every panel
scatter_lon = -156.8
scatter_lat = 71.3

# (cluster-mean dataset, panel title), in panel order
class_panels = [
    (dsE3SM_mean_class1, "Nudged E3SM Cluster 1"),
    (dsE3SM_mean_class2, "Nudged E3SM Cluster 2"),
    (dsE3SM_mean_class3, "Nudged E3SM Cluster 3"),
    (dsE3SM_mean_class4, "Nudged E3SM Cluster 4"),
]
data_crs = ccrs.PlateCarree()  # the fields are given on lon/lat coordinates

for ax, (ds_mean, title) in zip(axs.flat, class_panels):
    gl = ax.gridlines(draw_labels=False, linewidth=0.5, color='gray', alpha=0.8, linestyle='--')
    gl.xlocator = plt.FixedLocator(np.arange(-180, 181, 30))

    # Longitude labels are placed by hand at latitude 45, every 30 degrees
    for lon in np.arange(-180, 151, 30):
        x, y = ax.projection.transform_point(lon, 45, ccrs.Geodetic())
        ax.text(x, y, f'{lon}°', ha='center', va='center', transform=ax.projection, fontsize=12)

    # Field shown in this panel: first level of the cluster-mean 'T'
    field = ds_mean['T'].isel(lev=0)

    # Filled colours
    im = ax.pcolormesh(ds_mean['lon'], ds_mean['lat'], field, transform=data_crs,
                       vmin=vmin0, vmax=vmax0, cmap=icmap)

    # Black contour lines at the same levels as the colorbar ticks
    cs = ax.contour(ds_mean['lon'], ds_mean['lat'], field, levels=bounds, colors='black',
                    linewidths=1, transform=data_crs)

    ax.coastlines(resolution='10m', linewidth=0.3)
    ax.set_title(title, fontsize=fsize)

    # Star marker at (scatter_lon, scatter_lat)
    ax.scatter(scatter_lon, scatter_lat, marker='*', s=700, color='white',
               edgecolors='black', linewidths=0.5, zorder=5, transform=data_crs)

    # Thick black frame around the panel
    for spine in ax.spines.values():
        spine.set_color('k')
        spine.set_linewidth(2)

# Shared horizontal colorbar in its own axes below the panels
fig.subplots_adjust(left=0.03, right=0.97, bottom=0.1, top=0.95, hspace=0.1, wspace=0.1)
cbar_ax = fig.add_axes([0.10, 0.03, 0.8, 0.03])
cb = fig.colorbar(im, cax=cbar_ax, shrink=0.7, pad=0.02, orientation='horizontal', extend='both')

cb.ax.set_title('Temperature (K)', fontsize=lsize)
# The tick-styling calls below are order-dependent (later calls override earlier ones).
cb.ax.tick_params(labelsize=lsize-5, length=0)
cb_xaxis = cb.ax.xaxis
cb_xaxis.set_ticks_position('bottom')
cb_xaxis.set_label_position('bottom')
cb.ax.tick_params(axis='x', direction='out', pad=5, labelrotation=0)
cb_xaxis.set_tick_params(color='black', width=1.5, which='both', pad=10)
cb_xaxis.set_tick_params(size=2, which='both')
cb.set_ticks(bounds)
#---------------------------------------------------------
cb.ax.tick_params(width=1.5, length=5)
# cb.ax.xaxis.set_major_formatter(FuncFormatter(custom_tick_formatter))

# Saving is currently disabled for this figure; uncomment the savefig line to write it.
figname = '/home/zhengx/Research/NSA/IMG/NdgE3SM_Jlabel_Temp_V2.png'
# plt.savefig(figname,facecolor='white', edgecolor='none')
#=-=-=-=-=-=-=-=-=-=-=-=#=-=-=-=-=-=-=-=-=-=-=-=#=-=-=-=-=-=-=-=-=-=-=-=
#========================================================================================================

In [None]:
#========================================================================================================
# MULMENSTADT 2012 color scheme
index_c2012 = 0 # set to 1 if use M2012 scheme
#========================================================================================================
# Plotting per-cluster mean temperature anomaly (2x2 polar maps, one panel per cluster)
# Font sizes: panel titles (fsize) and colorbar text (lsize)
fsize = 23
lsize = 21
#========================================================================================================
# Colour scale limits and contour interval (colorbar is labelled in K)
vmin0 = -4.0
vmax0 = 4.0
vint0 = 0.5

bounds = np.arange(vmin0, vmax0 + vint0, vint0)
icmap = 'bwr'

# Discrete palette used only when index_c2012 == 1
cmap = colors.ListedColormap([
    '#EF293D', '#EF2354', '#EE206C', '#EE2086',
    '#CB2891', '#942A90', '#662F92',
    '#283A97', '#0058A9', '#0073BC', '#0092D5',
    '#00B4F0', '#00B0C3', '#00AB9C', '#00A878',
    '#00A654', '#00AC4D', '#5EBC46', '#FFF303',
    '#FDB417', '#F6821F', '#F25B22', '#EF3029',
])

if index_c2012 == 1:
    bounds = np.linspace(-4.5, 7, 24)
    norm = colors.BoundaryNorm(bounds, cmap.N)
    # Fixed: this used to read `vmin` / `vmax`, which are never defined and raised
    # NameError as soon as the M2012 scheme was switched on. `vmin0` / `vmax0` are
    # the limits used everywhere else in this cell (and in the SLP cell's branch).
    norm.vmin = vmin0
    norm.vmax = vmax0
    icmap = cmap
    # NOTE(review): `norm` is not passed to pcolormesh below, so colours are still
    # mapped linearly between vmin0 and vmax0 -- confirm whether that is intended.

# North-polar stereographic projection, central longitude 180, true scale at latitude 70
proj = ccrs.NorthPolarStereo(central_longitude=180, true_scale_latitude=70, globe=None)

# 2x2 grid of map panels; a single colorbar is added below them afterwards
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15, 15), subplot_kw={'projection': proj},
                        sharex=True, sharey=True)

# Location of the star marker drawn on every panel
scatter_lon = -156.8
scatter_lat = 71.3

# (cluster-mean dataset, panel title), in panel order
class_panels = [
    (dsE3SM_mean_class1, "Nudged E3SM Cluster 1"),
    (dsE3SM_mean_class2, "Nudged E3SM Cluster 2"),
    (dsE3SM_mean_class3, "Nudged E3SM Cluster 3"),
    (dsE3SM_mean_class4, "Nudged E3SM Cluster 4"),
]
data_crs = ccrs.PlateCarree()  # the fields are given on lon/lat coordinates

for ax, (ds_mean, title) in zip(axs.flat, class_panels):
    gl = ax.gridlines(draw_labels=False, linewidth=0.5, color='gray', alpha=0.8, linestyle='--')
    gl.xlocator = plt.FixedLocator(np.arange(-180, 181, 30))

    # Longitude labels are placed by hand at latitude 45, every 30 degrees
    for lon in np.arange(-180, 151, 30):
        x, y = ax.projection.transform_point(lon, 45, ccrs.Geodetic())
        ax.text(x, y, f'{lon}°', ha='center', va='center', transform=ax.projection, fontsize=12)

    # Field shown in this panel: first level of the cluster-mean 'T_Anomaly'
    field = ds_mean['T_Anomaly'].isel(lev=0)

    # Filled colours
    im = ax.pcolormesh(ds_mean['lon'], ds_mean['lat'], field, transform=data_crs,
                       vmin=vmin0, vmax=vmax0, cmap=icmap)

    # Black contour lines at the `bounds` levels
    cs = ax.contour(ds_mean['lon'], ds_mean['lat'], field, levels=bounds, colors='black',
                    linewidths=1, transform=data_crs)

    ax.coastlines(resolution='10m', linewidth=0.3)
    ax.set_title(title, fontsize=fsize)

    # Star marker at (scatter_lon, scatter_lat)
    ax.scatter(scatter_lon, scatter_lat, marker='*', s=700, color='white',
               edgecolors='k', linewidths=0.5, zorder=5, transform=data_crs)

    # Thick black frame around the panel
    for spine in ax.spines.values():
        spine.set_color('k')
        spine.set_linewidth(2)

# Shared horizontal colorbar in its own axes below the panels
fig.subplots_adjust(left=0.03, right=0.97, bottom=0.1, top=0.95, hspace=0.1, wspace=0.1)
cbar_ax = fig.add_axes([0.10, 0.03, 0.8, 0.03])

cb = fig.colorbar(im, cax=cbar_ax, shrink=0.7, pad=0.02, orientation='horizontal', extend='both')

cb.ax.set_title('Temperature anomaly (K)', fontsize=lsize)
# The tick-styling calls below are order-dependent (later calls override earlier ones).
cb.ax.tick_params(labelsize=lsize-5, length=0)
cb.ax.xaxis.set_ticks_position('bottom')
cb.ax.xaxis.set_label_position('bottom')
cb.ax.tick_params(axis='x', direction='out', pad=5, labelrotation=0)
cb.ax.xaxis.set_tick_params(color='black', width=1.5, which='both', pad=10)
cb.ax.xaxis.set_tick_params(size=2, which='both')
cb.set_ticks(bounds)
# Label whole-number ticks only; half-step ticks stay blank
cb.set_ticklabels(['' if i % 1.0 != 0 else str(i) for i in bounds])
#---------------------------------------------------------
# specify ticks matching 2012 paper (one label per entry of the 24-value M2012 bounds)
string_values = np.array(['', '-4', '', '-3', '', '-2', '', '-1', '', '0', '', '1', '', '2',
       '', '3', '', '4', '', '5', '', '6', '', '7'], dtype='<U2')
if index_c2012 == 1:
    cb.set_ticks(bounds, labels=string_values)
#---------------------------------------------------------
cb.ax.tick_params(width=1.5, length=5)
# cb.ax.xaxis.set_major_formatter(FuncFormatter(custom_tick_formatter))

figname = '/home/zhengx/Research/NSA/IMG/NdgE3SM_Jlabel_TempAno_V2.png'
if index_c2012 == 1:
    figname = '/home/zhengx/Research/NSA/IMG/NdgE3SM_Jlabel_TempAno_C2012.png'
plt.savefig(figname,facecolor='white', edgecolor='none')
#=-=-=-=-=-=-=-=-=-=-=-=#=-=-=-=-=-=-=-=-=-=-=-=#=-=-=-=-=-=-=-=-=-=-=-=
#========================================================================================================

### PWV

In [None]:
#========================================================================================================
# Variable 'TMQ' from the nudged-E3SM TMQ files, restricted to lat 45-90 and loaded into memory.
# The array keeps the name Nug_E3SM_T used in the other sections although it holds TMQ here.
Nug_E3SM = xr.open_mfdataset(
    '/lcrc/project/land_atmos_modeling/caghili/NSA/Nudged_E3SM/TMQ_ndg/TMQ*.nc'
).sel(lat=slice(45, 90))
Nug_E3SM_T = Nug_E3SM['TMQ'].compute()

# Cluster labels.
# The timestamps are parsed to datetime BEFORE sorting so that the sort is chronological
# (sorting the raw strings first would order them lexicographically). The chain builds a
# new frame at each step instead of mutating in place, so the cell can be re-run safely.
colleague = (
    pd.read_csv('/lcrc/project/land_atmos_modeling/caghili/NSA/Obs/clusters-new.dat')
    .dropna()
    .assign(time=lambda df: pd.to_datetime(df['time']))
    .sort_values(by='time', ascending=True)
    .loc[lambda df: df['time'] <= '2015-01-01']
    .rename(columns={'class': 'label'})  # the 'class' column becomes 'label'
)
# Convert the pandas DataFrame to an xarray Dataset indexed by time
class_dataset = colleague.set_index('time').to_xarray()

# Replace the model's 'time' index with a pandas DatetimeIndex so it can be
# aligned with the label table
Nug_E3SM_T['time'] = Nug_E3SM_T.indexes['time'].to_datetimeindex()
# Merge, keeping only time stamps present in both inputs, then cut the analysis window
dsE3SM = xr.merge([Nug_E3SM_T, class_dataset], join='inner')
dsE3SM = dsE3SM.sel(time=slice('2005-01-01', '2007-01-01'))

# Mean of TMQ for each calendar month over the selected window
monthly_avg = dsE3SM['TMQ'].groupby('time.month').mean(dim='time')

# Re-wrap the monthly means with explicit dims/coords (no 'lev' dimension here).
# The tbot_* names are kept from the temperature section; they hold TMQ values in this cell.
tbot_2yr_monthly_mean = xr.DataArray(monthly_avg.values, dims=('month', 'lat', 'lon'),
                                     coords={'month': monthly_avg.month.values,
                                             'lat': dsE3SM.lat, 'lon': dsE3SM.lon})

# Pick, for every time step, the monthly mean of that time step's calendar month
month_indices = dsE3SM['time.month'] - 1  # Convert from 1-12 to 0-11
tbot_2yr_monthly_mean_broadcasted = tbot_2yr_monthly_mean.isel(month=month_indices)
dsE3SM['TMQ_MonthlyMean'] = tbot_2yr_monthly_mean_broadcasted

# Anomaly = value minus the monthly mean of its calendar month
anomaly = dsE3SM['TMQ'] - tbot_2yr_monthly_mean_broadcasted
dsE3SM['TMQ_Anomaly'] = anomaly

# Boolean time masks, one per cluster label
bool_class1 = (dsE3SM['label'] == 1)
bool_class2 = (dsE3SM['label'] == 2)
bool_class3 = (dsE3SM['label'] == 3)
bool_class4 = (dsE3SM['label'] == 4)

# Keep only the time steps belonging to each cluster
dsE3SM_class1 = dsE3SM.where(bool_class1, drop=True)
dsE3SM_class2 = dsE3SM.where(bool_class2, drop=True)
dsE3SM_class3 = dsE3SM.where(bool_class3, drop=True)
dsE3SM_class4 = dsE3SM.where(bool_class4, drop=True)

# Time mean of every variable for each cluster
dsE3SM_mean_class1 = dsE3SM_class1.mean('time').compute()
dsE3SM_mean_class2 = dsE3SM_class2.mean('time').compute()
dsE3SM_mean_class3 = dsE3SM_class3.mean('time').compute()
dsE3SM_mean_class4 = dsE3SM_class4.mean('time').compute()


In [None]:
#========================================================================================================
# Plotting per-cluster mean TMQ (2x2 polar maps, one panel per cluster)
#========================================================================================================
# Colour scale limits and contour interval (colorbar is labelled in kg m^-2)
vmin0, vmax0, vint0 = 0, 20, 2

# Font sizes: panel titles (fsize) and colorbar text (lsize)
fsize, lsize = 23, 21

bounds = np.arange(vmin0, vmax0 + vint0, vint0)
icmap = 'jet'

# Discrete palette; defined here but not used for plotting in this cell (icmap is).
cmap = colors.ListedColormap([
    '#EF293D', '#EF2354', '#EE206C', '#EE2086',
    '#CB2891', '#942A90', '#662F92',
    '#283A97', '#0058A9', '#0073BC', '#0092D5',
    '#00B4F0', '#00B0C3', '#00AB9C', '#00A878',
    '#00A654', '#00AC4D', '#5EBC46', '#FFF303',
    '#FDB417', '#F6821F', '#F25B22', '#EF3029',
])

# North-polar stereographic projection, central longitude 180, true scale at latitude 70
proj = ccrs.NorthPolarStereo(central_longitude=180, true_scale_latitude=70, globe=None)

# 2x2 grid of map panels; a single colorbar is added below them afterwards
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15, 15), subplot_kw={'projection': proj},
                        sharex=True, sharey=True)

# Location of the star marker drawn on every panel
scatter_lon = -156.8
scatter_lat = 71.3

# (cluster-mean dataset, panel title), in panel order
class_panels = [
    (dsE3SM_mean_class1, "Nudged E3SM Cluster 1"),
    (dsE3SM_mean_class2, "Nudged E3SM Cluster 2"),
    (dsE3SM_mean_class3, "Nudged E3SM Cluster 3"),
    (dsE3SM_mean_class4, "Nudged E3SM Cluster 4"),
]
data_crs = ccrs.PlateCarree()  # the fields are given on lon/lat coordinates

for ax, (ds_mean, title) in zip(axs.flat, class_panels):
    gl = ax.gridlines(draw_labels=False, linewidth=0.5, color='gray', alpha=0.8, linestyle='--')
    gl.xlocator = plt.FixedLocator(np.arange(-180, 181, 30))

    # Longitude labels are placed by hand at latitude 45, every 30 degrees
    for lon in np.arange(-180, 151, 30):
        x, y = ax.projection.transform_point(lon, 45, ccrs.Geodetic())
        ax.text(x, y, f'{lon}°', ha='center', va='center', transform=ax.projection, fontsize=12)

    # Field shown in this panel: cluster-mean 'TMQ'
    field = ds_mean['TMQ']

    # Filled colours
    im = ax.pcolormesh(ds_mean['lon'], ds_mean['lat'], field, transform=data_crs,
                       vmin=vmin0, vmax=vmax0, cmap=icmap)

    # Black contour lines at the `bounds` levels
    cs = ax.contour(ds_mean['lon'], ds_mean['lat'], field, levels=bounds, colors='black',
                    linewidths=1, transform=data_crs)

    ax.coastlines(resolution='10m', linewidth=0.3)
    ax.set_title(title, fontsize=fsize)

    # Star marker at (scatter_lon, scatter_lat); smaller and solid black in this figure
    ax.scatter(scatter_lon, scatter_lat, marker='*', s=200, color='black',
               edgecolors='black', linewidths=0.5, zorder=5, transform=data_crs)

    # Thick black frame around the panel
    for spine in ax.spines.values():
        spine.set_color('k')
        spine.set_linewidth(2)

# Shared horizontal colorbar in its own axes below the panels
fig.subplots_adjust(left=0.03, right=0.97, bottom=0.1, top=0.95, hspace=0.1, wspace=0.1)
cbar_ax = fig.add_axes([0.10, 0.03, 0.8, 0.03])
cb = fig.colorbar(im, cax=cbar_ax, shrink=0.7, pad=0.02, orientation='horizontal', extend='both')

cb.ax.set_title('Precipitable Water Vapor (kg $m^{-2}$)', fontsize=lsize)
# The tick-styling calls below are order-dependent (later calls override earlier ones).
cb.ax.tick_params(labelsize=lsize-5, length=0)
cb_xaxis = cb.ax.xaxis
cb_xaxis.set_ticks_position('bottom')
cb_xaxis.set_label_position('bottom')
cb.ax.tick_params(axis='x', direction='out', pad=5, labelrotation=0)
cb_xaxis.set_tick_params(color='black', width=1.5, which='both', pad=10)
cb_xaxis.set_tick_params(size=2, which='both')
# One tick per unit, but only the even ones carry a label
cb_bounds = np.arange(vmin0, vmax0 + 1, 1)
cb.set_ticks(cb_bounds)
cb.set_ticklabels(['' if i % 2.0 != 0 else str(i) for i in cb_bounds])
#---------------------------------------------------------
cb.ax.tick_params(width=1.5, length=5)
# cb.ax.xaxis.set_major_formatter(FuncFormatter(custom_tick_formatter))

figname = '/home/zhengx/Research/NSA/IMG/NdgE3SM_Jlabel_TMQ_V2.png'
plt.savefig(figname,facecolor='white', edgecolor='none')
#=-=-=-=-=-=-=-=-=-=-=-=#=-=-=-=-=-=-=-=-=-=-=-=#=-=-=-=-=-=-=-=-=-=-=-=
#========================================================================================================

In [None]:

#========================================================================================================
# Plotting per-cluster mean TMQ anomaly (2x2 polar maps, one panel per cluster)
#========================================================================================================
# Colour scale limits and contour interval (colorbar is labelled in kg m^-2)
vmin0 = -3
vmax0 = 3
vint0 = 0.2
# Font sizes: panel titles (fsize) and colorbar text (lsize)
fsize = 23
lsize = 21

bounds = np.arange(vmin0, vmax0 + vint0, vint0)
icmap = 'BrBG'

# Discrete palette used only when index_c2012 == 1
cmap = colors.ListedColormap(['#EE2B38','#EE294D','#EE2562','#EE2776','#E9278C','#C12B90','#912A8E','#692B8E','#3C3492',
                              '#124DA1','#0064B0','#007BC1','#0097D9','#00B4F0','#00B0C9','#00ADA8','#00A887','#00A666',
                              '#00A650','#00AE4D','#60BB47','#A2CD3A','#F1E913','#FFD209','#F99C1C','#F47620','#F05526','#EE3229',])

norm = colors.Normalize(vmin=-2, vmax=3.6)

# NOTE: index_c2012 is not set in this cell; it keeps whatever value an earlier-run
# cell assigned to it.
if index_c2012 == 1:
    bounds = np.arange(-2, 3.8, 0.2)
    norm = colors.BoundaryNorm(bounds, cmap.N)
    # Fixed: this used to read `vmin` / `vmax`, which are never defined and raised
    # NameError as soon as the M2012 scheme was switched on. `vmin0` / `vmax0` are
    # the limits used everywhere else in this cell (and in the SLP cell's branch).
    norm.vmin = vmin0
    norm.vmax = vmax0
    icmap = cmap
    # NOTE(review): `norm` is not passed to pcolormesh below, so colours are still
    # mapped linearly between vmin0 and vmax0 -- confirm whether that is intended.

# North-polar stereographic projection, central longitude 180, true scale at latitude 70
proj = ccrs.NorthPolarStereo(central_longitude=180, true_scale_latitude=70, globe=None)

# 2x2 grid of map panels; a single colorbar is added below them afterwards
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15, 15), subplot_kw={'projection': proj},
                        sharex=True, sharey=True)

# Location of the star marker drawn on every panel
scatter_lon = -156.8
scatter_lat = 71.3

# (cluster-mean dataset, panel title), in panel order
class_panels = [
    (dsE3SM_mean_class1, "Nudged E3SM Cluster 1"),
    (dsE3SM_mean_class2, "Nudged E3SM Cluster 2"),
    (dsE3SM_mean_class3, "Nudged E3SM Cluster 3"),
    (dsE3SM_mean_class4, "Nudged E3SM Cluster 4"),
]
data_crs = ccrs.PlateCarree()  # the fields are given on lon/lat coordinates

for ax, (ds_mean, title) in zip(axs.flat, class_panels):
    gl = ax.gridlines(draw_labels=False, linewidth=0.5, color='gray', alpha=0.8, linestyle='--')
    gl.xlocator = plt.FixedLocator(np.arange(-180, 181, 30))

    # Longitude labels are placed by hand at latitude 45, every 30 degrees
    for lon in np.arange(-180, 151, 30):
        x, y = ax.projection.transform_point(lon, 45, ccrs.Geodetic())
        ax.text(x, y, f'{lon}°', ha='center', va='center', transform=ax.projection, fontsize=12)

    # Field shown in this panel: cluster-mean 'TMQ_Anomaly'
    field = ds_mean['TMQ_Anomaly']

    # Filled colours
    im = ax.pcolormesh(ds_mean['lon'], ds_mean['lat'], field, transform=data_crs,
                       cmap=icmap, vmin=vmin0, vmax=vmax0)

    # Thin black contour lines at the `bounds` levels
    cs = ax.contour(ds_mean['lon'], ds_mean['lat'], field, levels=bounds, colors='black',
                    linewidths=0.5, transform=data_crs)

    ax.coastlines(resolution='10m', linewidth=0.3)
    ax.set_title(title, fontsize=fsize)

    # Star marker at (scatter_lon, scatter_lat)
    ax.scatter(scatter_lon, scatter_lat, marker='*', s=700, color='w',
               edgecolors='k', linewidths=0.5, zorder=5, transform=data_crs)

    # Thick black frame around the panel
    for spine in ax.spines.values():
        spine.set_color('k')
        spine.set_linewidth(2)

# Shared horizontal colorbar in its own axes below the panels
fig.subplots_adjust(left=0.03, right=0.97, bottom=0.1, top=0.95, hspace=0.1, wspace=0.1)
cbar_ax = fig.add_axes([0.10, 0.03, 0.8, 0.03])

cb = fig.colorbar(im, cax=cbar_ax, shrink=0.7, pad=0.02, orientation='horizontal', extend='both')

cb.ax.set_title('Precipitable Water Vapor anomaly (kg $m^{-2}$)', fontsize=lsize)
# The tick-styling calls below are order-dependent (later calls override earlier ones).
cb.ax.tick_params(labelsize=lsize-5, length=0)
cb.ax.xaxis.set_ticks_position('bottom')
cb.ax.xaxis.set_label_position('bottom')
cb.ax.tick_params(axis='x', direction='out', pad=5, labelrotation=0)
cb.ax.xaxis.set_tick_params(color='black', width=1.5, which='both', pad=10)
cb.ax.xaxis.set_tick_params(size=2, which='both')
cb.set_ticks(bounds)
# Label only whole-number ticks. `bounds` comes from a float arange, so the test is a
# two-sided tolerance around the nearest integer: the previous `i % 1 > 0.001` check
# silently dropped a tick whose rounding error fell just BELOW an integer.
# The M2012 scheme uses bare integers ('-2', '-1', ...); the default keeps '%4.1f'.
# Labels are built from the current `bounds` rather than from a `string_values` array
# left over from another cell, whose length cannot match the M2012 bounds of this cell.
tick_fmt = '%d' if index_c2012 == 1 else '%4.1f'
cb.set_ticklabels([
    tick_fmt % int(round(b)) if abs(b - round(b)) < 1e-3 else ''
    for b in bounds
])
#---------------------------------------------------------
cb.ax.tick_params(width=1.5, length=5)

figname = '/home/zhengx/Research/NSA/IMG/NdgE3SM_Jlabel_TMQAno_V2.png'
if index_c2012 == 1:
    figname = '/home/zhengx/Research/NSA/IMG/NdgE3SM_Jlabel_TMQano_C2012.png'
plt.savefig(figname,facecolor='white', edgecolor='none')
#=-=-=-=-=-=-=-=-=-=-=-=#=-=-=-=-=-=-=-=-=-=-=-=#=-=-=-=-=-=-=-=-=-=-=-=

### SLP

In [None]:
# Variable 'PSL' from the nudged-E3SM PSL files, restricted to lat 45-90 and loaded into memory.
# The array keeps the name Nug_E3SM_T used in the other sections although it holds PSL here.
Nug_E3SM = xr.open_mfdataset(
    '/lcrc/project/land_atmos_modeling/caghili/NSA/Nudged_E3SM/PSL_ndg/*.nc'
).sel(lat=slice(45, 90))
Nug_E3SM_T = Nug_E3SM['PSL'].compute()
Nug_E3SM_T = Nug_E3SM_T / 100  # Pa to hPa


# Cluster labels.
# The timestamps are parsed to datetime BEFORE sorting so that the sort is chronological
# (sorting the raw strings first would order them lexicographically). The chain builds a
# new frame at each step instead of mutating in place, so the cell can be re-run safely.
colleague = (
    pd.read_csv('/lcrc/project/land_atmos_modeling/caghili/NSA/Obs/clusters-new.dat')
    .dropna()
    .assign(time=lambda df: pd.to_datetime(df['time']))
    .sort_values(by='time', ascending=True)
    .loc[lambda df: df['time'] <= '2015-01-01']
    .rename(columns={'class': 'label'})  # the 'class' column becomes 'label'
)
# Convert the pandas DataFrame to an xarray Dataset indexed by time
class_dataset = colleague.set_index('time').to_xarray()

# Replace the model's 'time' index with a pandas DatetimeIndex so it can be
# aligned with the label table
Nug_E3SM_T['time'] = Nug_E3SM_T.indexes['time'].to_datetimeindex()
# Merge, keeping only time stamps present in both inputs, then cut the analysis window
dsE3SM = xr.merge([Nug_E3SM_T, class_dataset], join='inner')
dsE3SM = dsE3SM.sel(time=slice('2005-01-01', '2007-01-01'))

# Boolean time masks, one per cluster label
bool_class1 = (dsE3SM['label'] == 1)
bool_class2 = (dsE3SM['label'] == 2)
bool_class3 = (dsE3SM['label'] == 3)
bool_class4 = (dsE3SM['label'] == 4)

# Keep only the time steps belonging to each cluster
dsE3SM_class1 = dsE3SM.where(bool_class1, drop=True)
dsE3SM_class2 = dsE3SM.where(bool_class2, drop=True)
dsE3SM_class3 = dsE3SM.where(bool_class3, drop=True)
dsE3SM_class4 = dsE3SM.where(bool_class4, drop=True)

# Time mean of every variable for each cluster
dsE3SM_mean_class1 = dsE3SM_class1.mean('time').compute()
dsE3SM_mean_class2 = dsE3SM_class2.mean('time').compute()
dsE3SM_mean_class3 = dsE3SM_class3.mean('time').compute()
dsE3SM_mean_class4 = dsE3SM_class4.mean('time').compute()

In [None]:
#========================================================================================================
# Plotting per-cluster mean PSL (2x2 polar maps, one panel per cluster)
index_c2012 = 1
#========================================================================================================
# Colour scale limits and contour interval (colorbar is labelled in hPa)
vmin0, vmax0, vint0 = 999, 1026, 1
# Font sizes: panel titles (fsize) and colorbar text (lsize)
fsize, lsize = 23, 21

bounds = np.arange(vmin0, vmax0 + vint0, vint0)
icmap = 'turbo'

if index_c2012 == 1:
    # M2012 scheme: same bounds, but a discrete palette instead of 'turbo'
    bounds = np.arange(vmin0, vmax0 + vint0, vint0)

    cmap = colors.ListedColormap([
        '#EE2846', '#EE245A', '#ED1C70', '#EC1D8B', '#CA2690', '#9F2990', '#7A2B90', '#543393', '#2B3D99', '#2757A7',
        '#286DB6', '#2583C5', '#229BD7', '#0EB0E0', '#0FB0BC', '#13B09B', '#19B381',
        '#13B15F', '#3CB54D', '#6DBE46', '#9DCB3D', '#DCE232', '#F9DA2B',
        '#FEAC2B', '#F57F29', '#F1582A', '#EE2B32',
    ])

    # `norm` is built here but is not handed to pcolormesh below
    norm = colors.BoundaryNorm(bounds, cmap.N)
    icmap = cmap
    #norm = colors.Normalize(vmin=-1, vmax=1, clip=False)
    norm.vmin = vmin0
    norm.vmax = vmax0

# North-polar stereographic projection, central longitude 180, true scale at latitude 70
proj = ccrs.NorthPolarStereo(central_longitude=180, true_scale_latitude=70, globe=None)

# 2x2 grid of map panels; a single colorbar is added below them afterwards
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15, 15), subplot_kw={'projection': proj},
                        sharex=True, sharey=True)

# Location of the star marker drawn on every panel
scatter_lon = -156.8
scatter_lat = 71.3

# (cluster-mean dataset, panel title), in panel order
class_panels = [
    (dsE3SM_mean_class1, "Nudged E3SM Cluster 1"),
    (dsE3SM_mean_class2, "Nudged E3SM Cluster 2"),
    (dsE3SM_mean_class3, "Nudged E3SM Cluster 3"),
    (dsE3SM_mean_class4, "Nudged E3SM Cluster 4"),
]
data_crs = ccrs.PlateCarree()  # the fields are given on lon/lat coordinates

for ax, (ds_mean, title) in zip(axs.flat, class_panels):
    gl = ax.gridlines(draw_labels=False, linewidth=0.5, color='gray', alpha=0.8, linestyle='--')
    gl.xlocator = plt.FixedLocator(np.arange(-180, 181, 30))

    # Longitude labels are placed by hand at latitude 45, every 30 degrees
    for lon in np.arange(-180, 151, 30):
        x, y = ax.projection.transform_point(lon, 45, ccrs.Geodetic())
        ax.text(x, y, f'{lon}°', ha='center', va='center', transform=ax.projection, fontsize=12)

    # Field shown in this panel: cluster-mean 'PSL'
    field = ds_mean['PSL']

    # Filled colours
    im = ax.pcolormesh(ds_mean['lon'], ds_mean['lat'], field, transform=data_crs,
                       cmap=icmap, vmin=vmin0, vmax=vmax0)

    # Thin black contour lines at the `bounds` levels
    cs = ax.contour(ds_mean['lon'], ds_mean['lat'], field, levels=bounds, colors='black',
                    linewidths=0.5, transform=data_crs)

    ax.coastlines(resolution='10m', linewidth=0.3)
    ax.set_title(title, fontsize=fsize)

    # Star marker at (scatter_lon, scatter_lat)
    ax.scatter(scatter_lon, scatter_lat, marker='*', s=700, color='w',
               edgecolors='k', linewidths=0.5, zorder=5, transform=data_crs)

    # Thick black frame around the panel
    for spine in ax.spines.values():
        spine.set_color('k')
        spine.set_linewidth(2)

# Shared horizontal colorbar in its own axes below the panels
fig.subplots_adjust(left=0.03, right=0.97, bottom=0.1, top=0.95, hspace=0.1, wspace=0.1)
cbar_ax = fig.add_axes([0.10, 0.03, 0.8, 0.03])
cb = fig.colorbar(im, cax=cbar_ax, shrink=0.7, pad=0.02, orientation='horizontal', extend='both')

cb.ax.set_title('Sea level pressure (hPa)', fontsize=lsize)
# The tick-styling calls below are order-dependent (later calls override earlier ones).
cb.ax.tick_params(labelsize=lsize-5, length=0)
cb_xaxis = cb.ax.xaxis
cb_xaxis.set_ticks_position('bottom')
cb_xaxis.set_label_position('bottom')
cb.ax.tick_params(axis='x', direction='out', pad=5, labelrotation=0)
cb_xaxis.set_tick_params(color='black', width=1.5, which='both', pad=10)
cb_xaxis.set_tick_params(size=2, which='both')
cb.set_ticks(bounds)
# Label every fifth hPa; the ticks in between stay blank
cb.set_ticklabels(['' if i % 5 != 0 else str(i) for i in bounds])
#---------------------------------------------------------
# specify ticks matching 2012 paper (one entry per value in `bounds`)
string_values = np.array([
    '', '1000', '', '', '', '',
    '1005', '', '', '', '',
    '1010', '', '', '', '',
    '1015', '', '', '', '',
    '1020', '', '', '', '',
    '1025', '',
], dtype='<U4')
if index_c2012 == 1:
    cb.set_ticks(bounds, labels=string_values)
#---------------------------------------------------------
cb.ax.tick_params(width=1.5, length=5)

figname = '/home/zhengx/Research/NSA/IMG/NdgE3SM_Jlabel_SLP_V2.png'
if index_c2012 == 1:
    figname = '/home/zhengx/Research/NSA/IMG/NdgE3SM_Jlabel_SLP_C2012.png'
plt.savefig(figname,facecolor='white', edgecolor='none')
#=-=-=-=-=-=-=-=-=-=-=-=#=-=-=-=-=-=-=-=-=-=-=-=#=-=-=-=-=-=-=-=-=-=-=-=


### 850mb Wind

In [None]:
#========================================================================================================
# Process data
#========================================================================================================
# ---- 850 hPa wind components, restricted to 45N-90N -------------------------------------------
Nug_E3SM_U0 = xr.open_mfdataset(
    '/lcrc/globalscratch/yfeng/NSA/Data/U850/U850*.nc'
).sel(lat=slice(45, 90))
Nug_E3SM_V0 = xr.open_mfdataset(
    '/lcrc/globalscratch/yfeng/NSA/Data/V850/V850*.nc'
).sel(lat=slice(45, 90))
Nug_E3SM_U = Nug_E3SM_U0['U850'].compute()
Nug_E3SM_V = Nug_E3SM_V0['V850'].compute()
# Wind speed from the two components.
Nug_E3SM_Wspd = np.sqrt(Nug_E3SM_U**2 + Nug_E3SM_V**2).rename('WSPD850')

# ---- Cluster labels ----------------------------------------------------------------------------
# Parse 'time' to datetimes *before* sorting so the order is chronological rather than
# lexical, keep rows up to 2015-01-01, and rename 'class' -> 'label'. Built as one chain
# (no in-place edits on a slice), so re-running the cell is idempotent and does not
# trigger SettingWithCopyWarning.
colleague = (
    pd.read_csv('/lcrc/globalscratch/yfeng/era5_nudgede3sm_johannes_labled/clusters-new.dat')
    .dropna()
    .assign(time=lambda df: pd.to_datetime(df['time']))
    .sort_values(by='time', ascending=True)
    .loc[lambda df: df['time'] <= '2015-01-01']
    .rename(columns={'class': 'label'})
)
# Convert the pandas DataFrame to an xarray Dataset indexed by time
class_dataset = colleague.set_index('time').to_xarray()

# Convert the 'time' coordinate of the wind fields to a pandas DatetimeIndex so it
# can be aligned with the label times.
Nug_E3SM_U['time'] = Nug_E3SM_U.indexes['time'].to_datetimeindex()
Nug_E3SM_V['time'] = Nug_E3SM_V.indexes['time'].to_datetimeindex()
Nug_E3SM_Wspd['time'] = Nug_E3SM_Wspd.indexes['time'].to_datetimeindex()

# Merge winds and labels on their common time stamps, then keep the analysis period
dsE3SM = xr.merge([Nug_E3SM_U, Nug_E3SM_V, Nug_E3SM_Wspd, class_dataset], join='inner')
dsE3SM = dsE3SM.sel(time=slice('2005-01-01', '2007-01-01'))


def _monthly_climatology(field):
    """Return the calendar-month mean of `field` and that mean mapped onto `field`'s time axis.

    Parameters
    ----------
    field : xarray.DataArray
        Variable with a datetime 'time' dimension.

    Returns
    -------
    climatology : xarray.DataArray
        Mean over 'time' for each calendar month (new 'month' dimension).
    on_time_axis : xarray.DataArray
        `climatology` selected by label for the month of every time step, i.e. one
        value per original time stamp. Label-based selection is used so the lookup
        stays correct even if some calendar months are missing from the period.
    """
    climatology = field.groupby('time.month').mean(dim='time')
    on_time_axis = climatology.sel(month=field['time'].dt.month)
    return climatology, on_time_axis


# Monthly climatology of each wind variable over the selected period. All three are
# computed before anything is written back to dsE3SM, so the auxiliary 'month'
# coordinate added by the assignments below cannot interfere with the groupby.
U_monthly_avg, U_monthly_mean_broadcasted = _monthly_climatology(dsE3SM.U850)
V_monthly_avg, V_monthly_mean_broadcasted = _monthly_climatology(dsE3SM.V850)
WSPD_monthly_avg, WSPD_monthly_mean_broadcasted = _monthly_climatology(dsE3SM.WSPD850)

# Names defined by earlier versions of this cell, retained in case later cells use them.
U_monthly_mean = U_monthly_avg
V_monthly_mean = V_monthly_avg
WSPD_monthly_mean = WSPD_monthly_avg
month_indices = dsE3SM['time.month'] - 1  # 0-11 month index; no longer used for the lookup

dsE3SM['U_MonthlyMean'] = U_monthly_mean_broadcasted
dsE3SM['V_MonthlyMean'] = V_monthly_mean_broadcasted
dsE3SM['WSPD_MonthlyMean'] = WSPD_monthly_mean_broadcasted

# Anomaly = value minus the climatological mean of its calendar month
U_Anomaly = dsE3SM.U850 - U_monthly_mean_broadcasted
dsE3SM['U_Anomaly'] = U_Anomaly
V_Anomaly = dsE3SM.V850 - V_monthly_mean_broadcasted
dsE3SM['V_Anomaly'] = V_Anomaly
WSPD_Anomaly = dsE3SM.WSPD850 - WSPD_monthly_mean_broadcasted
dsE3SM['WSPD_Anomaly'] = WSPD_Anomaly

# ---- Per-cluster subsets and time means --------------------------------------------------------
bool_class1, bool_class2, bool_class3, bool_class4 = [
    dsE3SM['label'] == cluster_id for cluster_id in (1, 2, 3, 4)
]

# Keep only the time steps assigned to each cluster
dsE3SM_class1, dsE3SM_class2, dsE3SM_class3, dsE3SM_class4 = [
    dsE3SM.where(cluster_mask, drop=True)
    for cluster_mask in (bool_class1, bool_class2, bool_class3, bool_class4)
]

# Compute the time mean for each cluster
dsE3SM_mean_class1, dsE3SM_mean_class2, dsE3SM_mean_class3, dsE3SM_mean_class4 = [
    cluster_ds.mean('time').compute()
    for cluster_ds in (dsE3SM_class1, dsE3SM_class2, dsE3SM_class3, dsE3SM_class4)
]

In [None]:
from matplotlib.colors import LinearSegmentedColormap
#========================================================================================================
# MULMENSTADT 2012 color scheme
index_c2012 = 0
#========================================================================================================
# Plotting Mean
#========================================================================================================
# Colour range of the wind-speed shading and spacing of the colorbar ticks
vmin0 = 2
vmax0 = 15
vint0 = 1
fsize=23   # panel title font size
lsize=21   # colorbar title font size (tick labels use lsize-5)

bounds = np.arange(vmin0, vmax0+vint0, vint0)

# Default colormap: magma_r sampled over [0, 0.9], i.e. the last 10% of the
# original colormap is dropped.
original_cmap = plt.cm.magma_r
icmap = LinearSegmentedColormap.from_list(
    "truncated_gnuplot", original_cmap(np.linspace(0, 0.9, 256))
)

if index_c2012 == 1:
    # Discrete colour table used for the 2012-paper style figures
    cmap = colors.ListedColormap(['#EF293D','#EF293D', '#EF2354', '#EE206C', '#EE2086',
                                     '#CB2891','#942A90', '#662F92',
                                     '#283A97','#0058A9','#0073BC','#0092D5',
                                     '#00B4F0','#00B0C3','#00AB9C','#00A878',
                                     '#00A654','#00AC4D','#5EBC46','#FFF303',
                                     '#FDB417','#F6821F','#F25B22','#EF3029','#EF3029'])
    # NOTE(review): `norm` is built here but is not passed to pcolormesh below, which
    # maps colours linearly between vmin0 and vmax0 -- confirm whether that is intended.
    norm = colors.BoundaryNorm(bounds, cmap.N)
    norm.vmin = vmin0
    norm.vmax = vmax0
    icmap = cmap

# Set up the projection with a specific central longitude and true scale latitude
proj = ccrs.NorthPolarStereo(central_longitude=180, true_scale_latitude=70, globe=None)

# Create a 2x2 grid of map panels; one shared colorbar is added below
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15, 15), subplot_kw={'projection': proj},
                        sharex=True, sharey=True)

# Location marked with a star on every panel
scatter_lon = -156.8
scatter_lat = 71.3

# Every `skip`-th grid point is handed to quiver
skip = 2

# Longitude labels sit on the 45N circle. Their projected positions are the same for
# every panel (all panels share `proj`), so compute them once.
lon_label_positions = [
    (lon, *proj.transform_point(lon, 45, ccrs.Geodetic()))
    for lon in np.arange(-180, 151, 30)
]

# One panel per cluster, in order 1..4
cluster_means = [dsE3SM_mean_class1, dsE3SM_mean_class2, dsE3SM_mean_class3, dsE3SM_mean_class4]
for cluster_number, (ax, ds_mean) in enumerate(zip(axs.flat, cluster_means), start=1):
    title = f"E3SM Cluster {cluster_number}"

    ax.set_extent([-180, 180, 45, 90], crs=ccrs.PlateCarree())

    gl = ax.gridlines(draw_labels=False, linewidth=0.5, color='gray', alpha=0.8, linestyle='--')
    gl.xlocator = plt.FixedLocator(np.arange(-180, 181, 30))
    # Manually add longitude labels around the polar projection
    for lon, x, y in lon_label_positions:
        ax.text(x, y, f'{lon}°', ha='center', va='center', transform=ax.projection,fontsize=12)

    # Shading: cluster-mean 850 hPa wind speed
    im = ax.pcolormesh(ds_mean['lon'], ds_mean['lat'], ds_mean['WSPD850'], transform=ccrs.PlateCarree(),
                   cmap=icmap, vmin=vmin0, vmax=vmax0)

    # Arrows: cluster-mean 850 hPa wind vectors
    ax.quiver(ds_mean['lon'].values[::skip], ds_mean['lat'].values[::skip],
              ds_mean['U850'].values[::skip, ::skip], ds_mean['V850'].values[::skip, ::skip],
              transform=ccrs.PlateCarree(), width=0.0025, headlength=7, headwidth=7, scale=125, regrid_shape=25, color='black')

    # Add coastlines to the plot
    ax.coastlines(resolution='10m', linewidth=0.3)

    # Set the title and font size
    ax.set_title(title, fontsize=fsize)

    # Mark the reference location
    ax.scatter(scatter_lon, scatter_lat, marker='*', s=200, color='k',
            edgecolors='k', linewidths=0.5, zorder =5, transform=ccrs.PlateCarree())

    for spine in ax.spines.values():
        spine.set_color('k')
        spine.set_linewidth(2)

fig.subplots_adjust(left = 0.03, right = 0.97, bottom = 0.1, top = 0.95,hspace=0.1,wspace=0.1)
# Dedicated axes for the shared colorbar: [left, bottom, width, height] in figure fractions
cbar_ax = fig.add_axes([0.10, 0.03, 0.8, 0.03])

# `shrink`/`pad` only apply when matplotlib creates the colorbar axes itself, so they
# are not passed together with `cax`.
cb = fig.colorbar(im, cax=cbar_ax, orientation='horizontal', extend='both')

cb.ax.set_title('850hPa Wind speed (m/s)', fontsize=lsize)
cb.ax.xaxis.set_ticks_position('bottom')
cb.ax.xaxis.set_label_position('bottom')
cb.ax.tick_params(labelsize=lsize-5)
cb.ax.tick_params(axis='x', direction='out', labelrotation=0)
cb.ax.xaxis.set_tick_params(which='both', color='black', width=1.5, pad=10, size=2)
cb.set_ticks(bounds)

#---------------------------------------------------------
# specify ticks matching 2012 paper: label the even values only.
# The labels are derived from `bounds` so their count always equals the number of
# ticks; a fixed-length label list raises in set_ticks when the two differ.
string_values = np.array([str(b) if b % 2 == 0 else '' for b in bounds])
if index_c2012 == 1: cb.set_ticks(bounds,labels=string_values)
#---------------------------------------------------------
# Final major-tick geometry
cb.ax.tick_params(width=1.5, length=5)

figname = '/home/zhengx/Research/NSA/IMG/NdgE3SM_Jlabel_WSPD850_V2.png'
# NOTE(review): the C2012 file name uses the prefix '10yrE3SM' while the default name
# uses 'NdgE3SM' -- confirm this is intended and does not overwrite another figure.
if index_c2012 == 1: figname = '/home/zhengx/Research/NSA/IMG/10yrE3SM_Jlabel_WSPD850_C2012.png'
fig.savefig(figname,facecolor='white', edgecolor='none')
#=-=-=-=-=-=-=-=-=-=-=-=#=-=-=-=-=-=-=-=-=-=-=-=#=-=-=-=-=-=-=-=-=-=-=-=
#========================================================================================================