In [None]:
import numpy as np
import xarray as xr
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import sys
import tripyview as tpv
import eddytools as et
import math
from cmocean import cm as cmo
import matplotlib.path as mpath
import cartopy.crs as ccrs

In [None]:
# Placeholder directories -- point these at your own locations. They are used as plain
# string prefixes (e.g. datapath + 'file.nc'), so keep the trailing slash.
figpath, datapath = (
    '/PATH/TO/OUTPUT/',  # saved figures
    '/PATH/TO/DATA/',    # input NetCDF files
)

In [None]:
#plotting functions

In [None]:
def simple_plot(dpi=300,
                ygridlocs=[-80,-75,-70,-65,-60],
                figsize=(6,6),
                box=[-180,180,-80,-60],
                cols=1,
                rows=1,
               ):
    """Create a figure of South Polar Stereographic map panels with circular boundaries.

    Every panel gets the extent ``box``, a circular clip boundary, dashed meridians every
    30 degrees and dashed parallels at ``ygridlocs`` (grid lines are not labelled).

    Parameters
    ----------
    dpi : int
        Figure resolution.
    ygridlocs : sequence of float
        Latitudes (degrees) of the parallels to draw. Only read, never modified.
    figsize : tuple of float
        Figure size in inches.
    box : sequence of 4 float
        ``[lon_min, lon_max, lat_min, lat_max]`` in PlateCarree degrees. Only read.
    cols, rows : int
        Number of panel columns and rows.

    Returns
    -------
    fig : matplotlib Figure
    axes : GeoAxes or numpy.ndarray of GeoAxes
        Exactly what ``plt.subplots`` returns: a single axes for a 1x1 grid, an array otherwise.
    """
    fig, axes = plt.subplots(
                rows,cols,
                subplot_kw=dict(projection=ccrs.SouthPolarStereo()),
                constrained_layout=True,
                figsize=figsize,
                dpi=dpi,
            )
    # plt.subplots returns a bare Axes (not an array) for a 1x1 grid; wrap that Axes so the
    # loop below handles every grid shape.
    axesflat = axes.flatten() if isinstance(axes, np.ndarray) else [axes]

    # Circle vertices in axes-fraction coordinates; identical for every panel.
    theta = np.linspace(0, 2*np.pi, 100)
    center, radius = [0.5, 0.505], 0.5
    verts = np.vstack([np.sin(theta), np.cos(theta)]).T

    for ax in axesflat:
        ax.set_extent(box, crs=ccrs.PlateCarree())

        # Clip the square stereographic panel to a circle.
        circle = mpath.Path(verts * radius + center)
        ax.set_boundary(circle, transform=ax.transAxes)

        # Meridians every 30 degrees, no labels.
        ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False, linewidth=0.5,
                     xlocs=range(-180,171,30), ylocs=[],
                     color='gray', alpha=0.5, linestyle='--', zorder=10)
        # Parallels (concentric circles) at ygridlocs, no labels.
        ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=False, linewidth=0.5,
                     xlocs=[], ylocs=ygridlocs,
                     color='grey', alpha=0.5, linestyle='--', zorder=10)

    return fig, axes

In [None]:
# Colormaps with their dark ends cropped off (20 percent, via cmocean's crop_by_percent):
#   rm  -- cmocean 'amp', cropped at the max end; used for absolute fields
#   rbm -- cmocean 'balance', cropped at both ends; used for diverging / change fields
rm, rbm = (
    cmo.tools.crop_by_percent(cmo.amp, 20, which='max', N=None),
    cmo.tools.crop_by_percent(cmo.balance, 20, which='both', N=None),
)

In [None]:
# FESOM2 mesh, loaded only for its land-sea mask polygons (mesh.lsmask_p), which are drawn
# on every map panel. Area/resolution computations, info output and pickling are off.
# NOTE(review): do_rot is the *string* 'None', not None -- presumably what tripyview expects; confirm.
mesh = tpv.load_mesh_fesom2(
    '/PATH/TO/MESH/',
    do_rot='None',
    focus=0,
    do_info=False,
    do_pickle=False,
    do_earea=False,
    do_narea=False,
    do_eresol=[False, 'mean'],
    do_nresol=[False, 'eresol'],
)


In [None]:
#data

In [None]:
# Horizontal barotropic streamfunction, precomputed with tripyview
# (https://github.com/FESOM/tripyview), for the two 6-year periods.
streamf50, streamf90 = (
    xr.open_dataset(datapath+'hbstreamf_1951-1956.nc'),
    xr.open_dataset(datapath+'hbstreamf_2091-2096.nc'),
)

In [None]:
#surface wind
# Near-surface wind components from CMIP6 AWI-CM-1-1-MR (member r1i1p1f1), one file per year:
# the historical experiment for 1951-1956 and ssp370 for 2091-2096.
u50files = [datapath+'uas_Amon_AWI-CM-1-1-MR_historical_r1i1p1f1_gn_'+str(year)+'01-'+str(year)+'12.nc'
            for year in range(1951, 1957)]
v50files = [datapath+'vas_Amon_AWI-CM-1-1-MR_historical_r1i1p1f1_gn_'+str(year)+'01-'+str(year)+'12.nc'
            for year in range(1951, 1957)]
u90files = [datapath+'uas_Amon_AWI-CM-1-1-MR_ssp370_r1i1p1f1_gn_'+str(year)+'01-'+str(year)+'12.nc'
            for year in range(2091, 2097)]
v90files = [datapath+'vas_Amon_AWI-CM-1-1-MR_ssp370_r1i1p1f1_gn_'+str(year)+'01-'+str(year)+'12.nc'
            for year in range(2091, 2097)]

# Each file set is opened as one dataset, averaged over its whole time axis and loaded
# into memory with .compute().
uw50, uw90, vw50, vw90 = (xr.open_mfdataset(files).mean(dim='time').compute()
                          for files in (u50files, u90files, v50files, v90files))

In [None]:
# Mean kinetic energy (annual mean, regular grid, 95 m -- per the file names).
# NOTE(review): these files cover 1951-1955 / 2091-2095, while the other inputs and the
# figure titles use 1951-1956 / 2091-2096 -- confirm this is intended.
mkeann50, mkeann90 = (
    xr.open_dataset(datapath+'mke_ann_reg_95m_1951-1955.nc'),
    xr.open_dataset(datapath+'mke_ann_reg_95m_2091-2095.nc'),
)

In [None]:
# Eddy kinetic energy (period mean, regular grid, 95 m -- per the file names).
# NOTE(review): these files cover 1951-1955 / 2091-2095, while the other inputs and the
# figure titles use 1951-1956 / 2091-2096 -- confirm this is intended.
ekeann50, ekeann90 = (
    xr.open_dataset(datapath+'eke_mean_reg_95m_1951-1955.nc'),
    xr.open_dataset(datapath+'eke_mean_reg_95m_2091-2095.nc'),
)

In [None]:
# Sea ice concentration, multi-year monthly-mean climatology ("ymonmean" files); presumably
# one step per calendar month -- confirm the ordering before indexing by month.
icemons90, icemons50 = (
    xr.open_dataset(datapath+'a_ice_iso_reg_ymonmean_2091-2096.nc'),
    xr.open_dataset(datapath+'a_ice_iso_reg_ymonmean_1951-1956.nc'),
)

In [None]:
# Sea ice concentration, monthly means for every month of each 6-year period ("monmean"
# files), i.e. a full time series rather than a climatology.
icemons_ann90, icemons_ann50 = (
    xr.open_dataset(datapath+'a_ice_iso_reg_monmean_2091-2096.nc'),
    xr.open_dataset(datapath+'a_ice_iso_reg_monmean_1951-1956.nc'),
)

In [None]:
def define_grid(gridtype='reg',bounds=[-180,180,-80,90],dx=1,dy=1,periodic=True):
    """Build a staggered lat/lon grid and its edge distances, stored as notebook globals.

    Nothing is returned. Coordinates (``lons_c``, ``lats_c``, ``lons_gl`` ... and the 2-D
    ``*points*`` meshes) and distances (``dxC``, ``dyC``, ``dxG``, ``dyG``, ``dxF``, ``dyF``,
    ``dxV``, ``dyU``) are written to the module-level names in the ``global`` statement.
    Distances come from ``et.tracking.get_distance_latlon2m`` (metres, judging by its name).

    Point naming: ``c`` = cell centres (lons_c, lats_c); ``u`` = (left lon edge, centre lat);
    ``v`` = (centre lon, lower lat edge); ``f`` = cell corners. Suffix ``l``/``r`` is the
    lower/upper edge, e.g. ``lons_gr = lons_gl + dx``.

    Parameters
    ----------
    gridtype : {'reg', 'iso'}
        ``'reg'``: constant spacing of ``dx`` by ``dy`` degrees.
        ``'iso'``: constant ``dx`` in longitude; each latitude step from an edge at ``lat``
        is ``dy * cos(lat)``, so with ``dx == dy`` cells are roughly square in metres.
    bounds : sequence of 4 float
        ``[left, right, bottom, top]`` in degrees. Only read.
    dx, dy : float
        Grid spacing in degrees.
    periodic : bool
        If True, append the wrap-around distance between the last and first column to
        ``dxC`` and ``dxV``.

    Raises
    ------
    ValueError
        If ``gridtype`` is neither ``'reg'`` nor ``'iso'``. Continuing would either fail on
        undefined names or silently reuse coordinate globals left over from a previous call.
    """
    #variables we want to keep (written as globals, see docstring)
    global lons_c, lats_c, lons_gl, lats_gl, lons_gr, lats_gr, lats_g, \
    vpointslonl, vpointslatl, vpointslonr, vpointslatr, upointslonl, upointslatl, upointslonr, \
    upointslatr, cpointslon, cpointslat, fpointslonl, fpointslatl, fpointslonr, fpointslatr, \
    dxF, dyF, dxC, dyC, dxG, dyG, dxV, dyU

    if gridtype not in ('reg', 'iso'):
        raise ValueError(f"grid type {gridtype!r} not prepared with this function; "
                         "expected 'reg' or 'iso'")

    left,right,bottom,top=bounds

    if gridtype=='reg':
        # NOTE(review): unlike 'iso', this branch does not add the extra centre-latitude row
        # that the trimming at the end of this function assumes, so e.g. dxV ends up with one
        # more row than dyU -- confirm before relying on gridtype='reg'.
        # the x center and f points on either side
        lons_c = np.arange(left+(0.5*dx), right, dx)
        lons_gl = np.arange(left, right, dx)
        lons_gr = lons_gl+(dx)

        lats_c = np.arange(bottom+(0.5*dy), top, dy)
        lats_gl = np.arange(bottom, top, dy)
        lats_gr = lats_gl+(dy)

    else:  # 'iso'
        # the x center and f points on either side
        lons_c = np.arange(left+(0.5*dx), right, dx)
        lons_gl = np.arange(left, right, dx)
        lons_gr = lons_gl+(dx)

        # latitude edges: each step is dy scaled by cos(latitude of the previous edge)
        lats_g=[bottom]

        #we need one extra center lat point in order to differentiate the correct number of v points
        while lats_g[-1] < top+(dy*(np.cos(np.radians(top)))):
            lats_g.append(lats_g[-1]+dy*(np.cos(np.radians(lats_g[-1]))))

        lats_gl=lats_g[:-1]
        lats_gr=lats_g[1:]
        lats_c=(np.asarray(lats_gl)+np.asarray(lats_gr))/2

        # drop the extra edge again; lats_c keeps one more entry than lats_gl/lats_gr
        lats_g,lats_gl,lats_gr=lats_g[:-1],lats_gl[:-1],lats_gr[:-1]

    #latlon points in 2d (rows = latitude, columns = longitude)
    vpointslonl, vpointslatl = np.meshgrid(lons_c,lats_gl)
    vpointslonr, vpointslatr = np.meshgrid(lons_c,lats_gr)

    upointslonl, upointslatl = np.meshgrid(lons_gl,lats_c)
    upointslonr, upointslatr = np.meshgrid(lons_gr,lats_c)

    cpointslon, cpointslat = np.meshgrid(lons_c,lats_c)

    fpointslonl, fpointslatl = np.meshgrid(lons_gl,lats_gl)
    fpointslonr, fpointslatr = np.meshgrid(lons_gr,lats_gr)

    #distances
    dxC = et.tracking.get_distance_latlon2m(cpointslon[:-1,1:],cpointslat[:-1,1:],
                                            cpointslon[:-1,:-1],cpointslat[:-1,:-1])

    #we use the left lat point because we lose the top row
    dxV = et.tracking.get_distance_latlon2m(vpointslonl[:,1:],vpointslatl[:,1:],
                                            vpointslonl[:,:-1],vpointslatl[:,:-1])

    #for dxC and dxV we need to manually add the periodic distance
    if periodic:
        endcolumn=et.tracking.get_distance_latlon2m(cpointslon[:-1,0],cpointslat[:-1,0],
                                                    cpointslon[:-1,-1],cpointslat[:-1,-1])
        dxC = np.append(dxC, endcolumn.reshape((endcolumn.shape[0],1)),axis=1)
        endcolumn=et.tracking.get_distance_latlon2m(vpointslonl[:,0],vpointslatl[:,0],
                                                    vpointslonl[:,-1],vpointslatl[:,-1])
        dxV = np.append(dxV, endcolumn.reshape((endcolumn.shape[0],1)),axis=1)

    #for dyC and dyU, we lose the top row that we added above
    #*** to do: change to lose the bottom row, so that this works better in the northern hemisphere
    #the bottom row can be on land over antarctica
    dyC = et.tracking.get_distance_latlon2m(cpointslon[1:,:],cpointslat[1:,:],
                                            cpointslon[:-1,:],cpointslat[:-1,:])
    dyU = et.tracking.get_distance_latlon2m(upointslonl[1:,:],upointslatl[1:,:],
                                            upointslonl[:-1,:],upointslatl[:-1,:])

    #dxG and dyG (along cell edges between corner points)
    dxG = et.tracking.get_distance_latlon2m(fpointslonr,fpointslatl,fpointslonl,fpointslatl)
    dyG = et.tracking.get_distance_latlon2m(fpointslonl,fpointslatr,fpointslonl,fpointslatl)

    #dxF and dyF. for dxF we have to manually remove the top layer
    dxF = et.tracking.get_distance_latlon2m(upointslonr,upointslatl,upointslonl,upointslatl)
    dxF = dxF[:-1,:]
    dyF = et.tracking.get_distance_latlon2m(vpointslonl,vpointslatr,vpointslonl,vpointslatl)

    # drop the extra top centre row from the centre/u-point coordinates as well
    lats_c=lats_c[:-1]
    cpointslat=cpointslat[:-1,:]
    upointslonl=upointslonl[:-1,:]
    upointslatl=upointslatl[:-1,:]
    upointslonr=upointslonr[:-1,:]
    upointslatr=upointslatr[:-1,:]
    
# Isotropic grid over 180W-180E, 80S-40S at 0.05 degrees, periodic in longitude.
# define_grid returns nothing; it writes dyU, dxV, ... as globals.
define_grid(gridtype='iso', bounds=[-180,180,-80,-40], dx=0.05, dy=0.05, periodic=True)

# Cell areas as a (lat, lon) array: meridional spacing at u points times zonal spacing at v points.
areas = dyU * dxV

In [None]:
# Area-weighted mean ice concentration per time step over the first ICE_NLAT latitude rows.
# NOTE(review): 1284 rows appears to be roughly 80S..60S on the 0.05-degree iso grid
# (matching the "south of 60S" axis label in figure 4) -- confirm.
# `areas` is (lat, lon); it is transposed and given a leading singleton axis so it broadcasts
# positionally against a_ice, which is indexed here as (time, lon, lat).
ICE_NLAT = 1284
ice_weights = areas.T.reshape(1, *areas.T.shape)[:, :, :ICE_NLAT]
ice_weight_total = np.sum(ice_weights)

iceawm50 = (icemons_ann50.a_ice[:, :, :ICE_NLAT] * ice_weights).sum(dim=['lat', 'lon']) / ice_weight_total
iceawm90 = (icemons_ann90.a_ice[:, :, :ICE_NLAT] * ice_weights).sum(dim=['lat', 'lon']) / ice_weight_total

In [None]:
# For each calendar month: min, max and mean across years of the area-weighted ice
# concentration (used for the envelopes and lines of the seasonal-cycle panel).
by_month50 = iceawm50.groupby('time.month')
by_month90 = iceawm90.groupby('time.month')

icemins50, icemaxs50, icemean50 = (by_month50.min(dim='time'),
                                   by_month50.max(dim='time'),
                                   by_month50.mean(dim='time'))
icemins90, icemaxs90, icemean90 = (by_month90.min(dim='time'),
                                   by_month90.max(dim='time'),
                                   by_month90.mean(dim='time'))

In [None]:
#plotting

In [None]:
#contourlabel format
def fmtx(x):
    """Format a contour level as a label in Sverdrups, e.g. for ``ax.clabel(cs, fmt=fmtx)``.

    One decimal is shown unless it is zero: ``1.5 -> '1.5 Sv'``, ``2.0 -> '2 Sv'``.
    The label contains no LaTeX-special characters, so the same string is valid whether
    or not ``text.usetex`` is enabled. (A ``\\Sv`` form would be an undefined LaTeX command.)
    """
    s = f"{x:.1f}"
    if s.endswith(".0"):
        # drop a zero decimal: "2.0" -> "2"
        s = f"{x:.0f}"
    return f"{s} Sv"

In [None]:
# Panel letters annotated on the first column of each figure, and the month abbreviations
# used as x values of the seasonal-cycle panel.
labels = list('abcdefg')
monstrings = 'Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec'.split()

In [None]:
#figure 3
# Figure 3: 4x3 map panels. Rows: barotropic streamfunction, EKE, MKE, zonal surface wind.
# Columns: 1951-1956, 2091-2096, change.
fig, axes = simple_plot(dpi=500,
                        ygridlocs=[-80, -75, -70, -65, -60],
                        figsize=(12, 20),
                        box=[-180, 180, -80, -59],
                        rows=4,
                        cols=3)

# Column titles on the first row only.
titlesize = 20
for ax, title in zip(axes[0, :], ['1951-1956', '2091-2096', 'Change']):
    ax.set_title(title, fontsize=titlesize)

# Global rc font sizes for tick labels and axis labels.
for rc_group in ('xtick', 'ytick', 'axes'):
    plt.rc(rc_group, labelsize=14)

# Grey land mask from the FESOM mesh on every panel, above the data (zorder=100).
for ax in axes.ravel():
    ax.add_geometries(mesh.lsmask_p, crs=ccrs.PlateCarree(),
                      facecolor=[0.6, 0.6, 0.6], edgecolor='none', linewidth=1, zorder=100)

# Bold panel letter in the top-left corner of each row's first panel.
for letter, ax in zip(labels, axes[:, 0]):
    ax.annotate(letter, xy=(0.05, 0.9), xycoords='axes fraction',
                horizontalalignment='left', verticalalignment='bottom',
                fontsize=22, weight='bold')

#streamfunction
# Row 0: horizontal barotropic streamfunction (Sv) for each period and their difference.
# The difference is plotted on the 1951-1956 coordinates, i.e. both files are assumed to
# share one grid.
pc = ccrs.PlateCarree()
psi_levels = np.linspace(120, 210, 10)

cax_0 = axes[0,0].contourf(streamf50.nlon, streamf50.nlat, streamf50.hbstreamf,
                           levels=psi_levels, cmap=rm, extend='both', transform=pc)
axes[0,1].contourf(streamf90.nlon, streamf90.nlat, streamf90.hbstreamf,
                   levels=psi_levels, cmap=rm, extend='both', transform=pc)
dax_0 = axes[0,2].contourf(streamf50.nlon, streamf50.nlat,
                           streamf90.hbstreamf - streamf50.hbstreamf,
                           levels=np.linspace(-10, 10, 9), cmap=rbm, extend='both', transform=pc)

# Black contour lines: solid = 1951-1956, dashed = 2091-2096; both overlaid on the change
# panel. The contour sets are kept in variables (no clabel call is made here).
line_kw = dict(levels=psi_levels, colors='k', alpha=1, linewidths=0.7, transform=pc)
conax50_00 = axes[0,0].contour(streamf50.nlon, streamf50.nlat, streamf50.hbstreamf, **line_kw)
conax90_01 = axes[0,1].contour(streamf90.nlon, streamf90.nlat, streamf90.hbstreamf, linestyles='--', **line_kw)
conax50_02 = axes[0,2].contour(streamf50.nlon, streamf50.nlat, streamf50.hbstreamf, **line_kw)
conax90_02 = axes[0,2].contour(streamf90.nlon, streamf90.nlat, streamf90.hbstreamf, linestyles='--', **line_kw)

# Row 0 colorbars: one shared by the two period panels, one for the change panel.
plt.colorbar(cax_0, ax=axes[0,:2], orientation='horizontal',
             label='Horizontal barotropic streamfunction (Sv)')
plt.colorbar(dax_0, ax=axes[0,2], orientation='horizontal',
             label='Δ Horizontal barotropic streamfunction (Sv)', ticks=[-10,-5,0,5,10])

#eke
# Row 1: eddy kinetic energy. Values are multiplied by 1e4, presumably m^2/s^2 -> cm^2/s^2
# (the unit on the colorbars) -- confirm the file units. The period panels show log10(EKE)
# with linear colour limits; the change panel uses a symmetric-log norm.
eke50_cm2 = ekeann50.eke[0].T*10000
eke90_cm2 = ekeann90.eke[0].T*10000

cax_1 = axes[1,0].pcolormesh(ekeann50.lon, ekeann50.lat, np.log10(eke50_cm2),
                             vmin=-1, vmax=1.5, cmap=rm, transform=ccrs.PlateCarree())
axes[1,1].pcolormesh(ekeann90.lon, ekeann90.lat, np.log10(eke90_cm2),
                     vmin=-1, vmax=1.5, cmap=rm, transform=ccrs.PlateCarree())
dax_1 = axes[1,2].pcolormesh(ekeann50.lon, ekeann50.lat, eke90_cm2 - eke50_cm2,
                             cmap=rbm, transform=ccrs.PlateCarree(),
                             norm=colors.SymLogNorm(linthresh=0.1, linscale=0.01, vmin=-30, vmax=30, base=10))

#row 1 colorbars
cbar = plt.colorbar(cax_1, ax=axes[1,:2], orientation='horizontal',
                    label='Eddy kinetic energy ($cm^2$/$s^2$)', extend='max')
# Colour values are log10(EKE): ticks sit at log10 positions, labelled in cm^2/s^2.
cbar.set_ticks(ticks=[-1,0,1,np.log10(30)], labels=['0.1','1','10','30'])
dbar = plt.colorbar(dax_1, ax=axes[1,2], orientation='horizontal',
                    label='Δ Eddy kinetic energy ($cm^2$/$s^2$)', extend='both')
# Labels are derived from the tick values so every sign matches its tick (e.g. -1 -> '-1').
eke_diff_ticks = [-30, -10, -1, 0, 1, 10, 30]
dbar.set_ticks(ticks=eke_diff_ticks, labels=[f'{t:g}' for t in eke_diff_ticks])

#MKE
# Row 2: mean kinetic energy. Values are multiplied by 1e4, presumably m^2/s^2 -> cm^2/s^2
# (the unit on the colorbars) -- confirm the file units. The period panels show log10(MKE)
# with linear colour limits; the change panel uses a symmetric-log norm.
mke50_cm2 = mkeann50.mke[0].T*10000
mke90_cm2 = mkeann90.mke[0].T*10000

cax_2 = axes[2,0].pcolormesh(mkeann50.lon, mkeann50.lat, np.log10(mke50_cm2),
                             vmin=-1, vmax=np.log10(50), cmap=rm, transform=ccrs.PlateCarree())
axes[2,1].pcolormesh(mkeann90.lon, mkeann90.lat, np.log10(mke90_cm2),
                     vmin=-1, vmax=np.log10(50), cmap=rm, transform=ccrs.PlateCarree())
dax_2 = axes[2,2].pcolormesh(mkeann50.lon, mkeann50.lat, mke90_cm2 - mke50_cm2,
                             cmap=rbm, transform=ccrs.PlateCarree(),
                             norm=colors.SymLogNorm(linthresh=0.1, linscale=0.01, vmin=-50, vmax=50, base=10))

#row 2 colorbars
cbar2 = plt.colorbar(cax_2, ax=axes[2,:2], orientation='horizontal',
                     label='Mean kinetic energy ($cm^2$/$s^2$)', extend='max')
dbar2 = plt.colorbar(dax_2, ax=axes[2,2], orientation='horizontal',
                     label='Δ Mean kinetic energy ($cm^2$/$s^2$)', extend='both')
# Colour values are log10(MKE): ticks sit at log10 positions, labelled in cm^2/s^2.
cbar2.set_ticks(ticks=[-1,0,1,np.log10(50)], labels=['0.1','1','10','50'])
# Labels are derived from the tick values so every sign matches its tick (e.g. -1 -> '-1').
mke_diff_ticks = [-50, -10, -1, 0, 1, 10, 50]
dbar2.set_ticks(ticks=mke_diff_ticks, labels=[f'{t:g}' for t in mke_diff_ticks])

#Zonal winds
# Row 3: time-mean eastward near-surface wind `uas` (m/s) per period (0.5 m/s levels) and
# its change (0.15 m/s levels), then save the whole figure.
wind_kw = dict(cmap=rbm, extend='both', transform=ccrs.PlateCarree())
cax_3 = axes[3,0].contourf(uw50.lon, uw50.lat, uw50.uas, levels=np.linspace(-10,10,41), **wind_kw)
axes[3,1].contourf(uw90.lon, uw90.lat, uw90.uas, levels=np.linspace(-10,10,41), **wind_kw)
dax_3 = axes[3,2].contourf(uw50.lon, uw50.lat, uw90.uas - uw50.uas,
                           levels=np.linspace(-3,3,41), **wind_kw)

# Row 3 colorbars: shared one for the period panels, one for the change panel.
plt.colorbar(cax_3, ax=axes[3,:2], orientation='horizontal',
             label='Zonal surface wind speed (m/s Eastward)')
plt.colorbar(dax_3, ax=axes[3,2], orientation='horizontal',
             label='Δ Zonal surface wind speed (m/s Eastward)', ticks=[-3,-1.5,0,1.5,3])

plt.savefig(figpath+'figure_3.png', bbox_inches='tight')

In [None]:
#figure 4
# Figure 4: 3x3 map grid. Row 0: September and row 1: March sea ice concentration;
# columns 1951-1956, 2091-2096, change. A seasonal-cycle line plot is added on its own axes.
fig, axes = simple_plot(dpi=300,
                        ygridlocs=[-80, -75, -70, -65, -60],
                        figsize=(12, 15),
                        box=[-180, 180, -80, -59],
                        rows=3,
                        cols=3)

# Column titles on the first row only.
titlesize = 20
for ax, title in zip(axes[0, :], ['1951-1956', '2091-2096', 'Change']):
    ax.set_title(title, fontsize=titlesize)

# Global rc font sizes for tick labels and axis labels.
for rc_group in ('xtick', 'ytick', 'axes'):
    plt.rc(rc_group, labelsize=14)

# Grey land mask from the FESOM mesh on every panel, above the data (zorder=100).
for ax in axes.ravel():
    ax.add_geometries(mesh.lsmask_p, crs=ccrs.PlateCarree(),
                      facecolor=[0.6, 0.6, 0.6], edgecolor='none', linewidth=1, zorder=100)

# Bold panel letter in the top-left corner of each row's first panel.
# NOTE(review): row 2 also gets the land mask and letter 'c' here, but the seasonal-cycle
# axes (subplot2grid rows 7-8 of 9) is labelled 'c' too and appears to sit over that row --
# confirm the row-2 map panels are meant to be covered.
for letter, ax in zip(labels, axes[:, 0]):
    ax.annotate(letter, xy=(0.05, 0.9), xycoords='axes fraction',
                horizontalalignment='left', verticalalignment='bottom',
                fontsize=22, weight='bold')

# Sea ice concentration maps from the monthly climatologies, scaled by 100 to percent.
# Index 8 is shown as September and index 2 as March, i.e. the 12 climatology steps are
# presumably ordered Jan..Dec -- confirm against the files' time axis.
pc = ccrs.PlateCarree()
ice_levels = np.linspace(0, 100, 21)
ice_change_levels = np.linspace(-100, 100, 21)

# September ice concentration (row 0)
cax_0 = axes[0,0].contourf(icemons50.lon, icemons50.lat, icemons50.a_ice[8].T*100,
                           levels=ice_levels, cmap=rm, transform=pc)
axes[0,1].contourf(icemons90.lon, icemons90.lat, icemons90.a_ice[8].T*100,
                   levels=ice_levels, cmap=rm, transform=pc)
dax_0 = axes[0,2].contourf(icemons50.lon, icemons50.lat,
                           icemons90.a_ice[8].T*100 - icemons50.a_ice[8].T*100,
                           levels=ice_change_levels, cmap=rbm, transform=pc)

plt.colorbar(cax_0, ax=axes[0,:2], orientation='horizontal',
             label='September sea ice concentration (%)', ticks=[0,25,50,75,100])
plt.colorbar(dax_0, ax=axes[0,2], orientation='horizontal',
             label='Δ September sea ice concentration (%)', ticks=[-100,-50,0,50,100])

# March ice concentration (row 1)
cax_1 = axes[1,0].contourf(icemons50.lon, icemons50.lat, icemons50.a_ice[2].T*100,
                           levels=ice_levels, cmap=rm, transform=pc)
axes[1,1].contourf(icemons90.lon, icemons90.lat, icemons90.a_ice[2].T*100,
                   levels=ice_levels, cmap=rm, transform=pc)
dax_1 = axes[1,2].contourf(icemons50.lon, icemons50.lat,
                           icemons90.a_ice[2].T*100 - icemons50.a_ice[2].T*100,
                           levels=ice_change_levels, cmap=rbm, transform=pc)

plt.colorbar(cax_1, ax=axes[1,:2], orientation='horizontal',
             label='March sea ice concentration (%)', ticks=[0,25,50,75,100])
plt.colorbar(dax_1, ax=axes[1,2], orientation='horizontal',
             label='Δ March sea ice concentration (%)', ticks=[-100,-50,0,50,100])

#Sea ice monthly
# Panel c: seasonal cycle of area-weighted sea ice concentration (%). Line = mean per month,
# shading = min..max across years. Drawn on a new axes spanning grid rows 7-8 of a 9x3 layout.
# Colours are sampled directly from the cmocean 'balance' Colormap (calling it is equivalent
# to the former mpl.cm.get_cmap(cmo.balance)(...), which matplotlib 3.9 removed).
col50 = cmo.balance(0.2)
col90 = cmo.balance(0.8)

newax = plt.subplot2grid(shape=(9,3), loc=(7, 0), rowspan=2, colspan=3,)
dice0 = newax.plot(monstrings, icemean50*100, color=col50)
dice0f = newax.fill_between(monstrings, icemins50*100, icemaxs50*100, color=col50, alpha=0.5)

dice1 = newax.plot(monstrings, icemean90*100, color=col90)
dice1f = newax.fill_between(monstrings, icemins90*100, icemaxs90*100, color=col90, alpha=0.5)

newax.set_ylim(0,60)
newax.set_xlim(0,11)
newax.grid(axis='y')
# '\\circ' yields the backslash mathtext needs, without the invalid '\c' string escape.
newax.set_ylabel('Sea ice concentration\n south of 60$^\\circ$S (%)')
newax.set_xlabel('Month')

# Proxy patches placed outside the x-limits (x in [-3, -2]) so they are not visible; they only
# carry the shading style. Each legend entry combines a proxy patch with its period's mean line.
legendfill0 = newax.fill([-3,-2], [0,1], linewidth=0, color=col50, alpha=0.5, label='1951-1956')
legendfill1 = newax.fill([-3,-2], [0,1], linewidth=0, color=col90, alpha=0.5, label='2091-2096')

newax.legend(handles=[(legendfill0[0], dice0[0]), (legendfill1[0], dice1[0])],
             labels=['1951-1956','2091-2096'], loc='upper left', fontsize=16)

newax.annotate('c', xy=(-0.06, 1), xycoords='axes fraction', horizontalalignment='left',
               verticalalignment='bottom', fontsize=22, weight='bold')

plt.savefig(figpath+'figure_4.png', bbox_inches='tight')