# Import packages

# In [None]:
import os
import cv2
import cmaps
import cmocean
import seaborn as sns
import numpy as np
import xarray as xr
import scipy.io as sio
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.gridspec as gridspec
from netCDF4 import Dataset
from datetime import date
from pandas import Series, DataFrame
import pandas as pd
from mpl_toolkits.basemap import Basemap
from matplotlib.pyplot import Polygon
from matplotlib import rcParams
from matplotlib.backends.backend_pdf import PdfPages
# Record the matplotlib version in use and make sans-serif the default font family.
version = mpl.__version__
rcParams.update({'font.family': 'sans-serif'})
# Root directory under which the MHW figure data are read and figure plots are written.
directory = '/srv/scratch/z3533156'

# Import time functions

# In [None]:
def datestring_to_serial_day(datestring,epochY=1990,epochm=1,epochd=1,epochH=0,epochM=0):
    import pandas as pd
    import datetime
    serial_day_timedelta = pd.to_datetime(datestring) - datetime.datetime(epochY,epochm,epochd,epochH,epochM)
    corrected_serial_day_number = serial_day_timedelta.days + serial_day_timedelta.seconds/86400
    return corrected_serial_day_number
def serial_day_to_datestring(day,epochY=1990,epochm=1,epochd=1,epochH=0,epochM=0,fmt="%Y-%m"):
    """Convert a (fractional) day count since an epoch to a formatted date string.

    Parameters
    ----------
    day : int or float (Python or NumPy scalar)
        Days elapsed since the epoch; a fractional part is kept as time of day.
    epochY, epochm, epochd, epochH, epochM : int
        Year, month, day, hour and minute of the epoch (default 1990-01-01 00:00).
    fmt : str
        ``strftime`` format of the result. The default "%Y-%m" keeps the
        original year-month output.

    Returns
    -------
    str
        The epoch plus ``day`` days, formatted with ``fmt``.
    """
    import datetime
    epoch = datetime.datetime(epochY,epochm,epochd,epochH,epochM)
    # float() lets NumPy integer scalars (e.g. np.int64) through; datetime.timedelta rejects them.
    corrected_date = epoch + datetime.timedelta(days=float(day))
    return corrected_date.strftime(fmt)

# Read heat budget terms in the mixed layer

# In [None]:
# Daily time axis for DJF 2021/22: every day from 2021-12-01 to 2022-02-28 inclusive.
daily_t           = np.arange(date(2021,12,1).toordinal(),date(2022,2,28).toordinal()+1,1)
daily_dates       = [date.fromordinal(int(tt)) for tt in daily_t]
dataset           = sio.loadmat(directory+'/MHW/Figure_data/Figure4_MLD_ERA5.mat')
lon0              = dataset['lon_LD'][:,0]
lat0              = dataset['lat_LD'][:,0]
# 2-D coordinate grids shaped (len(lon0), len(lat0)): lon[i, j] = lon0[i], lat[i, j] = lat0[j].
lon, lat          = np.meshgrid(lon0, lat0, indexing='ij')
HeatBudget_DJF    = dataset['HeatBudget_DJF'][:,:]
HeatBudget_Dec_LD = dataset['HeatBudget_Dec_LD'][:,:,:]
HeatBudget_Jan_LD = dataset['HeatBudget_Jan_LD'][:,:,:]
HeatBudget_Feb_LD = dataset['HeatBudget_Feb_LD'][:,:,:]
# Monthly map stacks joined along the last axis: Dec slices first, then Jan, then Feb.
HeatBudget_LD     = np.concatenate((HeatBudget_Dec_LD,HeatBudget_Jan_LD,HeatBudget_Feb_LD),axis=2)
################################################################################################################################
# Long-format table for the bar chart: one row per (term, day), labelled '<term>(<month>)'.
# Rows of HeatBudget_DJF shown as bars, with their label prefixes. Rows 3 and 5 (ENTR and
# Residual in the time-series legend of the plotting cell) are not part of the bar chart.
bar_terms  = [(0, 'T'), (1, 'U'), (2, 'V'), (4, 'Q')]
# Day columns of HeatBudget_DJF belonging to each month (Dec and Jan: 31 days; Feb: the rest).
bar_months = [('Dec', slice(0, 31)), ('Jan', slice(31, 62)), ('Feb', slice(62, None))]


def _month_terms_frame(budget, days, month, terms):
    """Return the bar-chart terms of one month as a long-format DataFrame.

    budget : 2-D array indexed as budget[term_row, day].
    days   : slice selecting the month's days along axis 1.
    month  : label suffix, e.g. 'Dec'.
    terms  : sequence of (term_row, label_prefix) pairs, stacked in this order.

    The 'T_rate' column holds the selected daily values, term after term. The
    'Terms' column holds the matching '<prefix>(<month>)' label, e.g. 'T(Dec)'.
    """
    rates  = [budget[row, days] for row, _ in terms]
    labels = [np.full(np.size(rate), prefix + '(' + month + ')', dtype=object)
              for rate, (_, prefix) in zip(rates, terms)]
    return DataFrame({'T_rate': np.concatenate(rates), 'Terms': np.concatenate(labels)})


# Build the string labels directly. The former chained assignment
# (df['Terms'][mask] = '...') wrote strings into a float column through a
# possible copy: it warns on pandas 2.x and silently does nothing under Copy-on-Write.
df0, df1, df2 = [_month_terms_frame(HeatBudget_DJF, days, month, bar_terms)
                 for month, days in bar_months]
data = [df0, df1,df2]
df   = pd.concat(data)

# Plot the spatial distribution and time series of heat budget terms in the mixed layer

# In [None]:
# ---- Figure-wide layout constants ----
fig_ratio     = 0.9
labelfont     = 20     # base font size (pt) for labels, ticks and titles
padspacescale = 25
labelpadscale = 4      # also used as the spine / outline line width
linefont      = 2      # base line-width multiplier
scale1        = 1.85   # enlargement of the two top panels relative to their gridspec cells
scale2        = 0.97   # shrink factor applied to each of the 12 map panels
fig           = plt.figure(figsize=(24, 30))
gs            = gridspec.GridSpec(4,4)
labels        = ['a','b','c','d','e','f','g','h','i','j','k','l','m','n']


def _thicken_spines(ax, width):
    """Set the line width of all four spines of *ax*."""
    for side in ('bottom', 'left', 'top', 'right'):
        ax.spines[side].set_linewidth(width)


#%%%%%%%%%%%%%%%%%%%% Panel a: daily time series of every heat-budget term over DJF %%%%%%%%%%%%%%%%%%%%
ax1 = fig.add_subplot(gs[0:1])
l, b, w, h = ax1.get_position().bounds
ax1.set_position([l, b, scale1*w, 0.45*scale1*h])
legends = ['TTEND','UADV','VADV','ENTR','Q','Residual']
colors  = ['xkcd:orange','xkcd:cherry red','xkcd:blue purple','xkcd:gray','xkcd:bluish','xkcd:forest green']
# Row k of HeatBudget_DJF is drawn in colors[k] and labelled legends[k].
for row, color in enumerate(colors):
    plt.plot(daily_dates, HeatBudget_DJF[row,:], color=color, linewidth=1.5*linefont, linestyle='solid')
leg=plt.legend(legends, loc = 2, ncol = 3, fontsize=15, frameon=True, facecolor='xkcd:pale blue', framealpha=0.3)
for line in leg.get_lines():
    line.set_linewidth(6.0)   # legend handles thicker than the plotted lines
plt.title(labels[0], fontsize=1.0*labelfont,loc='left', pad=0.4*padspacescale, weight='bold',family='sans-serif')
# Span the loaded period from its first to its last day (no hard-coded index).
plt.xlim(daily_dates[0], daily_dates[-1])
xmajorLocator   = mpl.dates.WeekdayLocator(byweekday=2, interval=2)
xmajorFormatter = mpl.dates.DateFormatter('%Y-%m-%d')
ax1.xaxis.set_major_locator(xmajorLocator)
ax1.xaxis.set_major_formatter(xmajorFormatter)
plt.xticks(fontsize=0.7*labelfont)
ax1.axes.xaxis.set_tick_params(pad=10)
ax1.set_xlabel('Date',fontsize = labelfont,labelpad=5)
plt.ylim(-2,2)
yminorLocator   = plt.MultipleLocator(0.5)
ymajorLocator   = plt.MultipleLocator(1)
ymajorFormatter = plt.FormatStrFormatter('%2.0f')
ax1.yaxis.set_minor_locator(yminorLocator)
ax1.yaxis.set_major_locator(ymajorLocator)
ax1.yaxis.set_major_formatter(ymajorFormatter)
plt.yticks(fontsize=labelfont)
ax1.axes.yaxis.set_tick_params(pad=10)
# Raw string: '\c' is not a valid Python escape (SyntaxWarning on 3.12+); mathtext needs the backslash.
ax1.set_ylabel(r'($^\circ$C / day)',fontsize = labelfont,labelpad=-5)
_thicken_spines(ax1, labelpadscale)
# Booleans rather than the legacy 'on' strings (those only worked because any non-empty string is truthy).
plt.tick_params(axis='x',which='major',bottom=True,left=True,top=True,right=True,length=8.0,width=4,colors='black',direction='in')
plt.tick_params(axis='y',which='minor',bottom=True,left=True,top=True,right=True,length=4.0,width=4,colors='black',direction='in')
plt.tick_params(axis='y',which='major',bottom=True,left=True,top=True,right=True,length=8.0,width=4,colors='black',direction='in')
#%%%%%%%%%%%%%%%%%%%% Panel b: monthly mean bars (95% CI) of TTEND/UADV/VADV/Q %%%%%%%%%%%%%%%%%%%%
ax2 = fig.add_subplot(gs[2:3])
l, b, w, h = ax2.get_position().bounds
ax2.set_position([l-0.05, b, scale1*w, 0.45*scale1*h])
plt.title(labels[1], fontsize=1.0*labelfont,loc='left', pad=0.4*padspacescale, weight='bold',family='sans-serif')
legends1 = ['TTEND','UADV','VADV','Q']
colors1  = ['xkcd:orange','xkcd:cherry red','xkcd:blue purple','xkcd:bluish']
# Hand-drawn colour key: one thick stroke per term, with its name written underneath.
for i, color in enumerate(colors1):
    plt.plot([2*i,2*i+1.5], [1.0,1.0], color=color, linewidth=10*linefont, linestyle='solid')
plt.text(0.0, 0.78, 'TTEND', color='xkcd:orange',      fontsize=1.0*labelfont)
plt.text(2.1, 0.78, 'UADV',  color='xkcd:cherry red',  fontsize=1.0*labelfont)
plt.text(4.2, 0.78, 'VADV',  color='xkcd:blue purple', fontsize=1.0*labelfont)
plt.text(6.5, 0.78, 'Q',     color='xkcd:bluish',      fontsize=1.0*labelfont)
plt.text(-0.25, -0.9, '2021-12', color='xkcd:black',   fontsize=1.8*labelfont)
plt.text(3.75,  -0.9, '2022-01', color='xkcd:black',   fontsize=1.8*labelfont)
plt.text(7.75,  -0.9, '2022-02', color='xkcd:black',   fontsize=1.8*labelfont)
plt.plot([0,12], [0,0], color='black', linewidth=1.5*linefont, linestyle='dashed')
# Alternating background shading behind each month's group of four bars.
plt.fill_between([-0.25,3.25], -1.2, 1.2, facecolor='xkcd:pale orange', alpha=0.3)
plt.fill_between([3.75,7.25],  -1.2, 1.2, facecolor='xkcd:pale blue',   alpha=0.3)
plt.fill_between([7.75,11.25], -1.2, 1.2, facecolor='xkcd:pale orange', alpha=0.3)
# NOTE(review): `ci`, `errcolor`, `errwidth` and `palette` without `hue` are deprecated in
# recent seaborn releases; migrate to `errorbar=('ci', 95)` / `err_kws` once the seaborn
# version in use allows it.
bar_ax = sns.barplot(x="Terms", y="T_rate", data=df, ci=95, errcolor='black', ec='black',
                     palette=['xkcd:orange','xkcd:cherry red','xkcd:blue purple','xkcd:bluish',
                              'xkcd:orange','xkcd:cherry red','xkcd:blue purple','xkcd:bluish',
                              'xkcd:orange','xkcd:cherry red','xkcd:blue purple','xkcd:bluish'],
                     n_boot=10000,errwidth=2,capsize=0.2,ax=ax2)
# The month is encoded by hatch: first four bars plain (None = no hatch; ' ' is not a valid
# hatch pattern), next four 'x', last four backslash.
hatches = [None,None,None,None, 'x','x','x','x', '\\','\\','\\','\\']
for thisbar, hatch in zip(bar_ax.patches, hatches):
    thisbar.set_hatch(hatch)
# Narrow every bar to 0.5 x-units while keeping it centred on its category position.
widthbars=[0.5]*12
for patch, newwidth in zip(ax2.patches, widthbars):
    centre = patch.get_x() + patch.get_width()/2.0
    patch.set_x(centre-newwidth/2.0)
    patch.set_width(newwidth)
ax2.set_xlabel(' ',fontsize =  0.65*labelfont,labelpad=5)
ax2.set_ylabel(r'($^\circ$C / day)',fontsize =  labelfont,labelpad=-5)
plt.ylim(-1.2, 1.2)
yminorLocator   = plt.MultipleLocator(0.2)
ymajorLocator   = plt.MultipleLocator(0.4)
ymajorFormatter = plt.FormatStrFormatter('%2.1f')
ax2.yaxis.set_minor_locator(yminorLocator)
ax2.yaxis.set_major_locator(ymajorLocator)
ax2.yaxis.set_major_formatter(ymajorFormatter)
ax2.axes.xaxis.set_tick_params(pad=5)
ax2.axes.yaxis.set_tick_params(pad=5)
_thicken_spines(ax2, labelpadscale)
plt.xticks(fontsize=0.65*labelfont)
plt.yticks(fontsize=labelfont)
ax2.axes.xaxis.set_ticklabels([])
plt.tick_params(bottom=True,left=True,length=8.0,width=4,colors='black',direction='in')
#%%%%%%%%%%%%%%%%%%%% Panels c-n: maps, one column per term, one row per month %%%%%%%%%%%%%%%%%%%%
levels1     = np.linspace(-1,1,50)
tick_marks1 = np.linspace(-1,1,11)
cmaps1      = cmaps.cmocean_balance
map_terms   = ['TTEND', 'UADV', 'VADV', 'Q']      # text label of each map column
map_lshift  = [0.0, 0.034, 0.068, 0.102]         # leftward shift of each column (figure fraction)
map_months  = ['2021-12', '2022-01', '2022-02']  # text label of each map row
map_bshift  = [0.02, 0.078, 0.136]               # upward shift of each row (figure fraction)
for i in range(12):
    row, col = divmod(i, 4)   # panels are filled row by row, four per row
    ax3 = fig.add_subplot(gs[i+4])
    l, b, w, h = ax3.get_position().bounds
    m   = Basemap(projection='merc',llcrnrlat=-38.0-0.001,urcrnrlat=-30+0.001,llcrnrlon=150.0-0.001,urcrnrlon=160-0.001,resolution='i')
    m.drawcoastlines(color='0.1',  linewidth=0.5*linefont)
    m.drawmapboundary(color='0.1', linewidth=0.5*linefont)
    m.fillcontinents(color='0.95', lake_color='white')
    m.drawmeridians(np.arange(150,165,4),labels=[0,0,0,0],linewidth=0.4*linefont,dashes=[2,2],color='.7',fontsize=labelfont)
    x, y = m(lon, lat)
    # Panel i shows slice i of the concatenated monthly maps.
    CB1  = m.contourf(x, y,  HeatBudget_LD[:,:,i],cmap=cmaps1,levels=levels1,origin='lower',extend='both')
    cx00, cy00 = m(150.2, -31.8)
    plt.text(cx00, cy00, map_terms[col], color='black', fontsize=1.0*labelfont)
    if col == 0:
        plt.ylabel('Latitude',fontsize=labelfont,labelpad=0.5*padspacescale,family='sans-serif')
    l = l - map_lshift[col]
    # Parallel (latitude) labels only on the left-most column.
    m.drawparallels(np.arange(-40,-10, 2),labels=[int(col == 0),0,0,0],linewidth=0.4*linefont,dashes=[2,2],color='.7',fontsize=labelfont)
    _thicken_spines(ax3, labelpadscale)
    plt.title(labels[i+2], fontsize=labelfont,loc='left', pad=0.4*padspacescale, weight='bold',family='sans-serif')
    cx0,  cy0  = m(150.1305, -30.7)
    b = b + map_bshift[row]
    plt.text(cx0, cy0, map_months[row], color='dodgerblue', fontsize=0.9*labelfont)
    if row == 2:
        # Longitude axis label and meridian labels only on the bottom row.
        plt.xlabel('Longitude',fontsize=labelfont,labelpad=-10.0,family='sans-serif')
        m.drawmeridians(np.arange(150,165, 4),labels=[0,0,0,1],linewidth=0.4*linefont,dashes=[2,2],color='.7',fontsize=labelfont)
    ax3.set_position([l, b, scale2*w, scale2*h])
    plt.tick_params(axis='both',which='major',bottom=True,left=True,length=50.0,width=20,colors='black',direction='out')
    if i==0:
        # First map only: outline a quadrilateral through four grid points (closed by repeating
        # the first corner) and mark Sydney. NOTE(review): presumably the analysis box - confirm.
        lon2     = np.array([lon[56,51],lon[63,51],lon[81,81],lon[74,81],lon[56,51]])
        lat2     = np.array([lat[56,51],lat[63,51],lat[81,81],lat[74,81],lat[56,51]])
        x1, y1   = m(lon2,lat2)
        x2, y2   = m(151.2093, -33.8688)
        cx1, cy1 = m(150.2205, -33.4144)
        m.plot(x1, y1, linewidth=labelpadscale, linestyle='solid', color='green')
        m.plot(x2, y2, marker='^',  color='xkcd:red',   markersize=10)
        plt.text(cx1, cy1,'Sydney', color='xkcd:red', fontsize=0.8*labelfont)
# One shared horizontal colour bar; every map uses the same levels1/cmaps1, so the
# contour set of the last panel serves as the mappable.
cbaxes1 = fig.add_axes([0.135, 0.22, 0.65, 0.01])
cb1     = plt.colorbar(CB1,orientation='horizontal',cax = cbaxes1)
cb1.set_ticks(tick_marks1)
cb1.set_label(r'($^\circ$C day$^{-1}$)', fontsize=labelfont,labelpad=0)
cb1.ax.tick_params(labelsize=labelfont)
fig.savefig(directory+'/MHW/Figure_plots/Figure4_MLD_Budget.png',dpi=300,bbox_inches = 'tight')