# Dealing with heat flux mask

After creating the heat flux mask (see https://github.com/dhruvbhagtani/varying-surface-forcing/blob/main/025deg_flux_expts/SSH_or_streamfunction.ipynb), the next step is to ensure that the net heat input into the ocean is zero.

In [None]:
import cartopy.crs as ccrs
import cosima_cookbook as cc
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cmocean as cm
from dask.distributed import Client
import matplotlib.path as mpath
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Avoid the Runtime errors in true_divide encountered when trying to divide by zero
import warnings
warnings.filterwarnings('ignore', category = RuntimeWarning)
warnings.filterwarnings('ignore', category = ResourceWarning)
warnings.filterwarnings('ignore', category = BytesWarning)

# matplotlib stuff:
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib
from mpl_toolkits.mplot3d import axes3d
%matplotlib inline
matplotlib.rcParams['mathtext.fontset'] = 'stix'
matplotlib.rcParams['font.family'] = 'STIXGeneral'
matplotlib.rcParams['lines.linewidth'] = 2.0
matplotlib.rc('xtick', labelsize = 18) 
matplotlib.rc('ytick', labelsize = 18)

import logging
logger = logging.getLogger("distributed.utils_perf")
logger.setLevel(logging.ERROR)

from dask.distributed import Client
client = Client()
client

In [None]:
# Experiment whose output is queried throughout this notebook.
expt = '025deg_jra55_ryf_param_kpp_extended2'

# Database file passed to the COSIMA Cookbook, and the session opened on it.
db = '/scratch/x77/db6174/access-om2/archive/databases/cc_database_param_kpp_extended2.db'
session = cc.database.create_session(db)

In [None]:
# Both variables come from the same experiment/session with n = -1
# (presumably "last output file only" -- confirm against the cookbook docs).
common_query = dict(expt=expt, session=session, n=-1)

area_t = cc.querying.getvar(variable='area_t', **common_query)
T = cc.querying.getvar(variable='temp', **common_query)

# Reduce temperature to a single 2-D snapshot: index 1 (i.e. the second entry)
# along both the depth and time axes.
# NOTE(review): index 1 rather than 0 is used for both -- confirm this is intended.
T = T.isel(st_ocean=1, time=1)

In [None]:
# Load the heat flux mask that was written to disk earlier.
nc_file = '/scratch/x77/db6174/025deg_inputs/flux_forced_uniform_heat/heat_mask.nc'
mask_ds = xr.open_dataset(nc_file)
mask = mask_ds.mask

# Keep cell areas only where the temperature snapshot holds data (presumably
# NaN marks land -- confirm), so that area sums cover those cells only.
# `where` is used instead of the former (area_t*T)/T trick: it leaves the
# area values bit-for-bit unchanged and does not turn cells with T == 0 into
# NaN via 0/0.
area_t = area_t.where(T.notnull())

In [None]:
# Area-weighted integral of the mask over both horizontal dimensions
# (not normalised by the total area).
weighted_mask = mask * area_t
mask_int = weighted_mask.sum(dim=['yt_ocean', 'xt_ocean'])
mask_int.values

In [None]:
# Quick-look plot of `mask` using xarray's default plotting.
mask.plot()

In [None]:
# Area-weighted mask sums poleward of 12 degrees in each hemisphere, each
# divided by the area summed over the whole domain.
# NOTE(review): the numerators cover one latitude band while the denominator
# is the full-domain area -- confirm this normalisation is intended.
horizontal_dims = ['yt_ocean', 'xt_ocean']
total_area = area_t.sum(dim=horizontal_dims)

southern_band = mask.sel(yt_ocean=slice(-90, -12))
northern_band = mask.sel(yt_ocean=slice(12, 90))

mask_avg1 = (southern_band*area_t).sum(dim=horizontal_dims)/total_area
mask_avg2 = (northern_band*area_t).sum(dim=horizontal_dims)/total_area

# Mean of the two hemispheric values.
mask_avg = 0.5*(mask_avg1 + mask_avg2)

In [None]:
def _smooth_meridional_edges(values, row_start=200, row_stop=800, half_width=25, e_fold=8):
    """Replace sharp 0/1 steps along axis 0 with tanh ramps, in place.

    Each column of `values` is scanned over rows ``row_start <= row < row_stop``.
    Wherever ``values[row, col]`` and ``values[row + 1, col]`` are exactly
    (1, 0) or (0, 1), the rows ``row - half_width`` to ``row + half_width - 1``
    of that column are overwritten with ``(1 + tanh(+/-(k - row)/e_fold))/2``,
    where ``k`` is the row index, so the profile passes through 0.5 at ``row``.
    The scan then resumes ``half_width + 1`` rows further on.

    Parameters
    ----------
    values : 2-D numpy array indexed as [row, column]; modified in place.
    row_start, row_stop : row index range that is scanned for steps.
    half_width : number of rows rewritten on each side of a step.
    e_fold : width scale of the tanh ramp, in rows.

    Returns
    -------
    The same array object as `values`.
    """
    n_rows, n_cols = values.shape
    rows = np.arange(n_rows)
    # Never read past the last row when comparing row and row + 1.
    row_stop = min(row_stop, n_rows - 1)

    for col in range(n_cols):
        row = row_start
        while row < row_stop:
            here = values[row, col]
            above = values[row + 1, col]
            if here == 1 and above == 0:
                sign = -1.0   # ramp down from 1 to 0 with increasing row
            elif here == 0 and above == 1:
                sign = 1.0    # ramp up from 0 to 1 with increasing row
            else:
                row += 1
                continue
            window = slice(max(row - half_width, 0), row + half_width)
            values[window, col] = (1 + np.tanh(sign*(rows[window] - row)/e_fold))/2
            # Skip past the rows just rewritten so the new ramp is not rescanned.
            row += half_width + 1
    return values


# Work on an explicit copy: `mask.values` may return the very array that backs
# `mask`, so editing it directly would silently change `mask` itself and make
# this cell behave differently each time it is re-run.
mask_values = mask.values.copy()
n_lat, n_lon = mask_values.shape   # indexed [latitude index, longitude index]

# Index vectors along each horizontal axis.
xt = np.arange(0, n_lon)
yt = np.arange(0, n_lat)

# Not used in this cell; retained in case other cells rely on these names.
Y, X = np.meshgrid(yt,xt)

dy = 50       # full ramp width in rows (= 2*dyby2); not used below
dyby2 = 25    # rows rewritten on each side of a step

# Zero the mask in a band of +/- `width` rows around row `y_mid`
# (presumably the equatorial band -- confirm against the grid).
y_mid = 495
width = 60
mask_values[y_mid - width:y_mid + width, :] = 0

# Smooth every remaining 0/1 step in the meridional direction. The helper
# works in place, so `mask_values_new` and `mask_values` are the same array.
mask_values_new = _smooth_meridional_edges(mask_values, row_start=200, row_stop=800,
                                           half_width=dyby2, e_fold=8)

In [None]:
# Wrap the edited array in a DataArray that reuses the latitude/longitude
# coordinates of the original mask.
mask_new_da = xr.DataArray(
    mask_values_new,
    coords=[mask.yt_ocean, mask.xt_ocean],
    dims=['yt_ocean', 'xt_ocean'],
    name='mask',
    attrs={'units': 'none'},
)

In [None]:
# Quick-look plot of `mask_new_da` using xarray's default plotting.
mask_new_da.plot()

In [None]:
# Profile of the new mask along latitude at the grid longitude nearest to
# xt_ocean = 100, saved to disk as a figure.
plt.figure(figsize=(10, 6))
section = mask_new_da.sel(xt_ocean=100, method='nearest')
section.plot()
plt.grid()
plt.savefig(
    'Figures/one_latitude_corrected.jpeg',
    bbox_inches='tight',
    dpi=900,
    transparent=True,
)

In [None]:
# --- Barotropic streamfunction with the heat flux mask outline on top ---

rho0 = 1026          # reference density (presumably kg m^-3 -- confirm)
st = '2000-01-01'    # start of the averaging period
et = '2009-12-31'    # end of the averaging period

# Depth-integrated zonal transport, monthly output, restricted to [st, et].
tx_trans = cc.querying.getvar(expt = expt, session = session, variable = 'tx_trans_int_z',
                              frequency = '1 monthly').sel(time = slice(st, et))

# Accumulate the transport along latitude and scale by rho0*1e6 (the result is
# labelled in Sv below). Points whose transport magnitude exceeds 1e20 are
# masked out. The previous form, abs(tx_trans <= 1.e20), took the absolute
# value of the boolean comparison and so never masked large negative values.
psi = -tx_trans.cumsum('yt_ocean').where(abs(tx_trans) <= 1.e20)/(rho0*1.e6)

# Minimum of the time-mean streamfunction inside the box 69W-67W, 80S-55S;
# it is subtracted below as a constant offset.
psi_acc = np.nanmin(psi.sel(xu_ocean = slice(-69, -67), yt_ocean = slice(-80, -55)).mean('time'))

psi_g = psi.mean('time') - psi_acc
psi_g = psi_g.rename('Barotropic Stream function')
psi_g.attrs['long_name'] = 'Barotropic Stream function'
psi_g.attrs['units'] = 'Sv'

# Flip the sign wherever yt_ocean <= 0.
psi_g = psi_g.where(psi_g.yt_ocean > 0, -psi_g)

# Grid (used for plotting); open the file once and take both fields from it.
grid = xr.open_dataset('/g/data/ik11/grids/ocean_grid_025.nc')
geolon_c = grid.geolon_c
geolat_t = grid.geolat_t

# Define the levels for the contourf
lvls = np.arange(-80, 90, 10)

fig = plt.figure(figsize = (12, 8))
ax = fig.add_subplot(111, projection = ccrs.Robinson())

# Add land features and gridlines
ax.add_feature(cfeature.LAND, edgecolor = 'black', facecolor = 'gray', zorder = 2)
ax.gridlines(color='grey', linestyle='--')

# Outline the mask with black contour lines at -0.05 and 0.05.
# `add_colorbar` is an xarray plotting argument, not a matplotlib `contour`
# argument, so it is no longer passed here.
ax.contour(geolon_c, geolat_t, mask_new_da, colors = 'black', levels = [-0.05, 0.05],
           transform = ccrs.PlateCarree())

# Plot the barotropic stream function
cf = ax.contourf(geolon_c, geolat_t, psi_g, levels = lvls, cmap = cm.cm.balance, extend = 'both',
                 transform = ccrs.PlateCarree())

# Add a colorbar
cbar = fig.colorbar(cf, ax = ax, orientation = 'vertical', shrink = 0.5)
cbar.set_label('Transport [Sv]', size = 18)
ax.set_title('Barotropic streamfunction', size = 18)
plt.savefig('Figures/heat_map_over_psi_corrected.jpeg', bbox_inches = 'tight', dpi = 900, transparent=True)

In [None]:
# --- Global map of the new heat flux mask ---

# Longitude/latitude fields used to place the data on the map.
# NOTE(review): `mask_new_da` is on (yt_ocean, xt_ocean) but is drawn against
# `geolon_c` -- confirm whether a t-cell longitude field was intended.
grid_ds = xr.open_dataset('/g/data/ik11/grids/ocean_grid_025.nc')
geolon_c = grid_ds.geolon_c
geolat_t = grid_ds.geolat_t

# 21 evenly spaced filled-contour levels between -1.5 and 1.5.
lvls = np.linspace(-1.5, 1.5, 21)

fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111, projection=ccrs.Robinson())

# Grey land with black coastlines, drawn above the data, plus dashed gridlines.
ax.add_feature(cfeature.LAND, edgecolor='black', facecolor='gray', zorder=2)
ax.gridlines(color='grey', linestyle='--')

# Filled contours of the mask.
cf = ax.contourf(
    geolon_c, geolat_t, mask_new_da,
    levels=lvls,
    cmap=cm.cm.balance,
    extend='both',
    transform=ccrs.PlateCarree(),
)

# Colorbar and title.
cbar = fig.colorbar(cf, ax=ax, orientation='vertical', shrink=0.5)
cbar.set_label('Heat flux mask', size=18)
ax.set_title('Heat flux mask', size=18)

plt.savefig('Figures/heat_map_corrected.jpeg', bbox_inches='tight', dpi=900, transparent=True)

In [None]:
# Area-weighted integral of the new mask over both horizontal dimensions
# (not normalised by the total area).
weighted_mask_new = mask_new_da * area_t
mask_int_new = weighted_mask_new.sum(dim=['yt_ocean', 'xt_ocean'])
mask_int_new.values

In [None]:
# Area-weighted mean of the new mask. The numerator must be weighted by the
# cell area; previously the unweighted sum of mask values was divided by the
# total area, which is not an average.
mask_avg_new = ((mask_new_da*area_t).sum(dim = ['yt_ocean','xt_ocean'])
                /area_t.sum(dim = ['yt_ocean','xt_ocean']))
mask_avg_new.values