### Saving data along contours

Contours defined in make_contour.ipynb

Extract volume transport, dzu and pot_rho_1 along contours

Then bin into sigma_1 bins

Alternative: save salt and temp along the contours instead, and convert them to potential density (pot_rho_x) afterwards

Each step is submitted as a PBS job script on Gadi

### Save along contours

In [None]:
"""
Calculates transport, rho1 and dzu across contour looped over years for one Southern Ocean contour
"""

# Load modules

# Standard modules
import cosima_cookbook as cc
import matplotlib.pyplot as plt
import netCDF4 as nc
import xarray as xr
import numpy as np
from dask.distributed import Client
import cftime
import glob
import dask.array as dsa
from cosima_cookbook import distributed as ccd
# Ignore warnings
import logging
logging.captureWarnings(True)
logging.getLogger('py.warnings').setLevel(logging.ERROR)


if __name__ == '__main__':

    #### get the year and contour number that were passed to the python script ####
    import sys

    # Number of SSH contours handled by this script (SSH = -0.1 m ... -1.4 m).
    # contour_no indexes every per-contour list below, so it must be 0..13.
    n_contours = 14

    # Validate the command line before starting the cluster, so that a bad
    # submission fails immediately with a readable message. Without the range
    # check a negative contour number would silently pick a contour from the
    # end of the lists.
    if len(sys.argv) < 3:
        sys.exit('Usage: python ' + sys.argv[0] + ' YEAR CONTOUR_NO')
    year = str(sys.argv[1])
    contour_no = int(sys.argv[2])
    if not 0 <= contour_no < n_contours:
        sys.exit('CONTOUR_NO must be between 0 and ' + str(n_contours - 1)
                 + ', got ' + str(contour_no))

    # Start a dask cluster with multiple cores
    client = Client(n_workers=8, local_directory='/scratch/x77/cy8964/dask_dump/dask_worker_space')
    # Load database
    session = cc.database.create_session('/g/data/ik11/databases/cosima_master.db')

    expt = '01deg_jra55v13_ryf9091'

    # The whole calendar year is processed in one job.
    start_time = year + '-01-01'
    end_time = year + '-12-31'

    # reference density value (used to turn mass transport into volume transport):
    rho_0 = 1035.0

    # Note: change these ranges so they match the size of your contour arrays.
    # For the full Southern Ocean the ranges would be
    #     lat_range = slice(-70,-29.99)
    #     lat_range_big = slice(-70.05,-29.9)
    lat_range = [slice(-60,-34.99),slice(-60,-34.99),slice(-60,-34.99),slice(-60,-34.99),slice(-60,-34.99),
                 slice(-60,-39.98),slice(-60,-39.98),slice(-62.91,-45),slice(-62.91,-45),slice(-62.91,-45),
                 slice(-64.99,-47),slice(-64.99,-47),slice(-70,-47),slice(-70,-47)][contour_no]
    # Slightly wider version of lat_range (kept for reference; not used further
    # in the visible part of this script).
    lat_range_big = [slice(-60.05,-34.90),slice(-60.05,-34.90),slice(-60.05,-34.90),slice(-60.05,-34.90),slice(-60.05,-34.90),
                     slice(-60.05,-39.90),slice(-60.05,-39.90),slice(-62.96,-44.90),slice(-62.96,-44.90),slice(-62.96,-44.90),
                     slice(-65.02,-46.93),slice(-65.02,-46.93),slice(-70.05,-46.93),slice(-70.05,-46.93)][contour_no]

    # t-cells are further south and west than u-cells.
    # Grid spacings are read explicitly from ocean_grid.nc. They are queried
    # only once here; they are needed later to convert transports per unit
    # length into transports.
    dyt = cc.querying.getvar(expt, 'dyt', session, n=1, ncfile='ocean_grid.nc')
    dxu = cc.querying.getvar(expt, 'dxu', session, n=1, ncfile='ocean_grid.nc')

    # select latitude range:
    dxu = dxu.sel(yu_ocean=lat_range)
    dyt = dyt.sel(yt_ocean=lat_range)

    # SSH value defining the chosen contour, and the file holding its masks.
    SSH = [-0.1,-0.2,-0.3,-0.4,-0.5,-0.6,-0.7,-0.8,-0.9,-1.0,-1.1,-1.2,-1.3,-1.4][contour_no]

    SO_SSH = 'SO_slope_contour_'+str(SSH)+'m_SSH.npz'

    outfile = '/g/data/x77/cy8964/Post_Process/'+SO_SSH

    # Short label used as the prefix of every output file for this contour.
    choicename = ['SO_A','SO_B','SO_C','SO_D','SO_E','SO_F','SO_G','SO_H','SO_I','SO_J','SO_K','SO_L','SO_M','SO_N'][contour_no]

    # Read the contour masks. Indexing the archive reads each array into
    # memory, so the file can be closed as soon as the four masks are extracted.
    with np.load(outfile) as data:
        mask_y_transport = data['mask_y_transport']
        mask_x_transport = data['mask_x_transport']
        mask_y_transport_numbered = data['mask_y_transport_numbered']
        mask_x_transport_numbered = data['mask_x_transport_numbered']

#         #pad masks to help with interpolation
#         mask_x_transport = np.pad(mask_x_transport, ((0, 0), (1, 1)),'constant', constant_values=((0,0), (0,0)))
#         mask_y_transport = np.pad(mask_y_transport, ((0, 0), (1, 1)),'constant', constant_values=((0,0), (0,0)))
#         mask_x_transport_numbered = np.pad(mask_x_transport_numbered, ((0, 0), (1, 1)),'constant', constant_values=((0,0), (0,0)))
#         mask_y_transport_numbered = np.pad(mask_y_transport_numbered, ((0, 0), (1, 1)),'constant', constant_values=((0,0), (0,0)))


    # Coordinate axes of the tracer (t) and velocity (u) grids. Only latitude
    # is cut down to lat_range; the longitude axes are used in full.
    yt_ocean = cc.querying.getvar(expt, 'yt_ocean', session, n=1).sel(yt_ocean=lat_range)
    yu_ocean = cc.querying.getvar(expt, 'yu_ocean', session, n=1).sel(yu_ocean=lat_range)
    xt_ocean = cc.querying.getvar(expt, 'xt_ocean', session, n=1)
    xu_ocean = cc.querying.getvar(expt, 'xu_ocean', session, n=1)

    # Wrap the contour masks in DataArrays so they can be multiplied with the
    # transports later. The coordinate *values* follow the data location:
    #   - x masks go with uhrho, i.e. they sit on (yt_ocean, xu_ocean);
    #   - y masks go with vhrho, i.e. they sit on (yu_ocean, xt_ocean).
    # The dimension *names* are always the generic y_ocean/x_ocean, so that
    # transports in both directions can be concatenated and sorted together.
    x_face_coords = [('y_ocean', yt_ocean), ('x_ocean', xu_ocean)]
    y_face_coords = [('y_ocean', yu_ocean), ('x_ocean', xt_ocean)]
    mask_x_transport = xr.DataArray(mask_x_transport, coords=x_face_coords)
    mask_y_transport = xr.DataArray(mask_y_transport, coords=y_face_coords)
    mask_x_transport_numbered = xr.DataArray(mask_x_transport_numbered, coords=x_face_coords)
    mask_y_transport_numbered = xr.DataArray(mask_y_transport_numbered, coords=y_face_coords)

    # Build the contour ordering: flatten each numbered mask to 1d, keep only
    # the points with a positive contour number, join the x and y points and
    # sort them by that number.
    stacked_dims = ['y_ocean', 'x_ocean']
    x_numbered_stacked = mask_x_transport_numbered.stack(contour_index=stacked_dims)
    mask_x_numbered_1d = x_numbered_stacked.where(x_numbered_stacked > 0, drop=True)
    y_numbered_stacked = mask_y_transport_numbered.stack(contour_index=stacked_dims)
    mask_y_numbered_1d = y_numbered_stacked.where(y_numbered_stacked > 0, drop=True)

    contour_ordering = xr.concat((mask_x_numbered_1d, mask_y_numbered_1d), dim='contour_index')
    contour_ordering = contour_ordering.sortby(contour_ordering)

    # Plain 1..N index that replaces the stacked multi-index in the outputs.
    contour_index_array = np.arange(1, len(contour_ordering) + 1)

    ########## volume transport across the contour for this year ##########
    # Note vhrho_nt is v*dz*1035 and is positioned on north centre edge of t-cell.
    # uhrho_et is the zonal counterpart, treated below as sitting on the eastern edge.
    vhrho = cc.querying.getvar(expt,'vhrho_nt',session,start_time=start_time, end_time=end_time)
    uhrho = cc.querying.getvar(expt,'uhrho_et',session,start_time=start_time, end_time=end_time)

    # select latitude range and this year (start_time/end_time span 1 Jan - 31 Dec):
    vhrho = vhrho.sel(yt_ocean=lat_range).sel(time=slice(start_time,end_time))
    uhrho = uhrho.sel(yt_ocean=lat_range).sel(time=slice(start_time,end_time))

    # Note that vhrho is defined as the transport across the northern edge of a tracer cell so its coordinates
    #       should be (yu_ocean, xt_ocean).
    #  uhrho is defined as the transport across the eastern edge of a tracer cell so its coordinates should
    #       be (yt_ocean, xu_ocean).
    #  However we will keep the actual name as simply y_ocean/x_ocean irrespective of the variable
    #       to make concatenation and sorting possible.
    # From here on yt_ocean/yu_ocean/xu_ocean/xt_ocean are plain numpy arrays taken
    # from the (latitude-subset) grid-spacing arrays.
    yt_ocean = dyt.yt_ocean.values
    yu_ocean = dxu.yu_ocean.values
    xu_ocean = dxu.xu_ocean.values
    xt_ocean = dyt.xt_ocean.values
    # Overwrite the coordinate values first (still under the old dimension names),
    # then rename the dimensions to the generic names. The order matters.
    vhrho.coords['yt_ocean'] = yu_ocean
    uhrho.coords['xt_ocean'] = xu_ocean
    vhrho = vhrho.rename({'yt_ocean':'y_ocean', 'xt_ocean':'x_ocean'})
    uhrho = uhrho.rename({'yt_ocean':'y_ocean', 'xt_ocean':'x_ocean'})

    # First we also need to change coords on dxu, dyt, so we can multiply the transports:
    # dxu is relabelled onto (yu, xt) to match vhrho; dyt onto (yt, xu) to match uhrho.
    dyt = dyt.reset_coords().dyt # remove geolon_t/geolat_t coordinates
    dxu = dxu.reset_coords().dxu # remove geolon_t/geolat_t coordinates
    dxu.coords['xu_ocean'] = xt_ocean
    dxu = dxu.rename({'yu_ocean':'y_ocean', 'xu_ocean':'x_ocean'})
    dyt.coords['xt_ocean'] = xu_ocean
    dyt = dyt.rename({'yt_ocean':'y_ocean','xt_ocean':'x_ocean'})

    # convert to transports and multiply by contour masks:
    # (multiply by the cell-face length, divide by the reference density rho_0)
    vhrho = vhrho*dxu*mask_y_transport/rho_0
    uhrho = uhrho*dyt*mask_x_transport/rho_0

    ## initialise an empty data array with dimensions (time, st_ocean, contour_index)
    vol_trans_across_contour = xr.DataArray(np.zeros((len(uhrho.time),len(uhrho.st_ocean),len(contour_index_array))),
                                        coords = [uhrho.time,uhrho.st_ocean, contour_index_array],
                                        dims = ['time','st_ocean', 'contour_index'],
                                        name = 'vol_trans_across_contour')

    # Process one time step at a time so that only one step is held in memory.
    for time_step in range(len(uhrho.time)):
        print(time_step)
        # load one timestep of transport data:
        # loading here speeds it up a lot:
        # (missing values are replaced with zero transport before loading)
        uhrho_i = uhrho[time_step,...]
        uhrho_i = uhrho_i.fillna(0)
        uhrho_i = uhrho_i.load()
        vhrho_i = vhrho[time_step,...]
        vhrho_i = vhrho_i.fillna(0)
        vhrho_i = vhrho_i.load()

        # stack transports into 1d and drop any points not on contour:
        x_transport_1d_i = uhrho_i.stack(contour_index = ['y_ocean', 'x_ocean'])
        x_transport_1d_i = x_transport_1d_i.where(mask_x_numbered_1d>0, drop = True)
        y_transport_1d_i = vhrho_i.stack(contour_index = ['y_ocean', 'x_ocean'])
        y_transport_1d_i = y_transport_1d_i.where(mask_y_numbered_1d>0, drop = True)

        # combine all points on contour, order them by their contour number and
        # replace the stacked index with the plain 1..N contour index:
        vol_trans_across_contour_i = xr.concat((x_transport_1d_i, y_transport_1d_i), dim = 'contour_index')
        vol_trans_across_contour_i = vol_trans_across_contour_i.sortby(contour_ordering)
        vol_trans_across_contour_i.coords['contour_index'] = contour_index_array
        vol_trans_across_contour_i = vol_trans_across_contour_i.load()

        # write into larger array:
        vol_trans_across_contour[time_step,:,:] = vol_trans_across_contour_i

        # free the per-step arrays before the next iteration
        del uhrho_i, vhrho_i, x_transport_1d_i, y_transport_1d_i, vol_trans_across_contour_i

    # Save one netCDF file per contour and year.
    save_dir = '/g/data/x77/cy8964/Post_Process/New_SO/'+choicename+'_'
    ds_vol_trans_across_contour = xr.Dataset({'vol_trans_across_contour': vol_trans_across_contour})
    ds_vol_trans_across_contour.to_netcdf(save_dir+'vol_trans_across_contour_'+year+'.nc')

    ########## now save dzu for this year ###############
    # Re-query the grid axes as DataArrays (the transport section above replaced
    # these names with plain numpy arrays). Longitude is not subset.
    yt_ocean = cc.querying.getvar(expt,'yt_ocean',session,n=1)
    yt_ocean = yt_ocean.sel(yt_ocean=lat_range)
    yu_ocean = cc.querying.getvar(expt,'yu_ocean',session,n=1)
    yu_ocean = yu_ocean.sel(yu_ocean=lat_range)
    xt_ocean = cc.querying.getvar(expt,'xt_ocean',session,n=1)
    xu_ocean = cc.querying.getvar(expt,'xu_ocean',session,n=1)
#         xt_ocean = xt_ocean.sel(xt_ocean = lon_range_big)
#         xu_ocean = xu_ocean.sel(xu_ocean = lon_range_big)

    # t-cell thickness, read from the files matching '%daily%', for this year and latitude range.
    dzt = cc.querying.getvar(expt,'dzt',session,start_time=start_time, end_time=end_time,ncfile='%daily%')
    dzt = dzt.sel(yt_ocean=lat_range).sel(time=slice(start_time,end_time))

    # initialise empty array with dimensions (time, st_ocean, contour_index)
    dzu_along_contour = xr.DataArray(np.zeros((len(dzt.time),len(dzt.st_ocean),len(contour_index_array))),
                                      coords = [dzt.time,dzt.st_ocean, contour_index_array],
                                      dims = ['time','st_ocean', 'contour_index'],
                                      name = 'dzu_along_contour')

    for time_step in range(len(dzt.time)):
        print(time_step)
    #     if time_step == 3:
    #         break

        # This is faster if we load first here:
        # (missing values are set to zero thickness)
        dzt_i = dzt[time_step,...]
        dzt_i = dzt_i.fillna(0)
        dzt_i = dzt_i.load()
        dzt_i = dzt_i.rename({'yt_ocean':'y_ocean', 'xt_ocean':'x_ocean'})

        # The neighbour values below are obtained by rolling the array rather than by
        #    a generic interpolation (e.g. .interp()); the original notes say this is much faster.
        # Be careful that your latitude range extends at least one point either direction beyond your contour.
        # If your domain is not the full longitude range, you will need to adapt this, because rolling
        #    wraps around at the edges of the domain (it assumes it is reentrant).
        # Coordinates are overwritten afterwards so the result lines up with the contour masks.

        # First create dzu: the minimum thickness of a t-cell and its eastern,
        # northern and north-eastern neighbours (roll by -1 brings the next point in).

        dzt_i_right = dzt_i.roll(x_ocean = -1, roll_coords = False)
        dzt_i_up = dzt_i.roll(y_ocean = -1,roll_coords = False)
        dzt_i_up_right=dzt_i.roll(y_ocean = -1,roll_coords = False).roll(x_ocean = -1, roll_coords = False)
        dzu = np.fmin(np.fmin(np.fmin(dzt_i,dzt_i_right),dzt_i_up),dzt_i_up_right)

        #now the xgrid needs BAY(dzu) while ygrid needs BAX(dzu) so that they are on uhrho and vhrho grids

        # BAY: mean of dzu and its southern neighbour, relabelled with xu longitudes.
        dzu_n = dzu.copy()
        dzu_s = dzu.roll(y_ocean = 1, roll_coords=False)
        BAY_dzu = (dzu_n+dzu_s)/2
        BAY_dzu['x_ocean'] = xu_ocean.values

        # BAX: mean of dzu and its western neighbour, relabelled with yu latitudes.
        dzu_e = dzu.copy()
        dzu_w = dzu.roll(x_ocean = 1, roll_coords=False)
        BAX_dzu = (dzu_w+dzu_e)/2
        BAX_dzu['y_ocean'] = yu_ocean.values

        # mask to the contour, stack the thicknesses into 1d and drop any points not on contour:
        BAY_dzu = BAY_dzu.where(mask_x_transport_numbered>0)
        BAX_dzu = BAX_dzu.where(mask_y_transport_numbered>0)
        x_dzu_1d = BAY_dzu.stack(contour_index = ['y_ocean', 'x_ocean'])
        y_dzu_1d = BAX_dzu.stack(contour_index = ['y_ocean', 'x_ocean'])
        x_dzu_1d = x_dzu_1d.where(mask_x_numbered_1d>0,drop=True)
        y_dzu_1d = y_dzu_1d.where(mask_y_numbered_1d>0,drop=True)

        # combine all points on contour, order them by their contour number and
        # replace the stacked index with the plain 1..N contour index:
        dzu_along_contour_i = xr.concat((x_dzu_1d, y_dzu_1d), dim = 'contour_index')
        dzu_along_contour_i = dzu_along_contour_i.sortby(contour_ordering)
        dzu_along_contour_i.coords['contour_index'] = contour_index_array
        dzu_along_contour_i = dzu_along_contour_i.load()

        # write into larger array:
        dzu_along_contour[time_step,:,:] = dzu_along_contour_i

        # free the per-step arrays before the next iteration
        del dzt_i,dzu,dzu_w,dzu_e, dzu_s, dzu_n, BAY_dzu, BAX_dzu, x_dzu_1d, y_dzu_1d, dzu_along_contour_i

    ### Save one netCDF file per contour and year:
    save_dir = '/g/data/x77/cy8964/Post_Process/New_SO/'+choicename+'_'

    ds_dzu_along_contour = xr.Dataset({'dzu_along_contour': dzu_along_contour})
    ds_dzu_along_contour.to_netcdf(save_dir+'dzu_along_contour_'+year+'.nc')

    ########## now save pot_rho_1 for this year ##########
    # Re-query the grid axes. The u-grid axes are renamed to the generic
    # y_ocean/x_ocean dimension names; only their values are used below.
    yt_ocean = cc.querying.getvar(expt,'yt_ocean',session,n=1)
    yt_ocean = yt_ocean.sel(yt_ocean=lat_range)
    yu_ocean = cc.querying.getvar(expt,'yu_ocean',session,n=1)
    yu_ocean = yu_ocean.sel(yu_ocean=lat_range)
    yu_ocean = yu_ocean.rename({'yu_ocean':'y_ocean'})
    xt_ocean = cc.querying.getvar(expt,'xt_ocean',session,n=1)
    xu_ocean = cc.querying.getvar(expt,'xu_ocean',session,n=1)
#         xt_ocean = xt_ocean.sel(xt_ocean = lon_range_big)
#         xu_ocean = xu_ocean.sel(xu_ocean = lon_range_big)
    xu_ocean = xu_ocean.rename({'xu_ocean':'x_ocean'})

    # pot_rho_1 read from the files matching '%daily%', for this year and latitude range.
    rho = cc.querying.getvar(expt,'pot_rho_1',session,start_time=start_time, end_time=end_time,ncfile='%daily%')
    rho = rho.sel(yt_ocean=lat_range).sel(time=slice(start_time,end_time))
    # initialise empty array with dimensions (time, st_ocean, contour_index)
    rho_along_contour = xr.DataArray(np.zeros((len(rho.time),len(rho.st_ocean),len(contour_index_array))),
                                      coords = [rho.time,rho.st_ocean, contour_index_array],
                                      dims = ['time','st_ocean', 'contour_index'],
                                      name = 'rho_along_contour')

    for time_step in range(len(rho.time)):
        print(time_step)
    #     if time_step == 5:
    #         break

        # This is faster if we load first here:
        # NOTE(review): missing values are filled with 0 before averaging, so a face next to
        #    a missing point gets the mean of 0 and its neighbour - confirm this is intended.
        rho_i = rho[time_step,...]
        rho_i = rho_i.fillna(0)
        rho_i = rho_i.load()
        rho_i = rho_i.rename({'yt_ocean':'y_ocean', 'xt_ocean':'x_ocean'})

        # Note that this interpolation does not work as generically as e.g. rho.interp(),
        #    but it is much faster and doesn't require removing chunking (which also slow things down).
        # Be careful that your latitude range extends at least one point either direction beyond your contour.
        # If your domain is not the full longitude range, you will need to adapt this, so you have the correct interpolation
        #    only the edges of your domain (it assumes it is reentrant).
        # Need to overwrite coords, so these two variables can be added together:
        # rho_w is the cell itself, rho_e its eastern neighbour (roll by -1 in x);
        # both are relabelled with the xu longitudes.
        rho_w = rho_i.copy()
        rho_w.coords['x_ocean'] = xu_ocean.values
        rho_e = rho_i.roll(x_ocean=-1)
        rho_e.coords['x_ocean'] = xu_ocean.values
        # rho_xgrid will be on the uhrho grid:
        rho_xgrid = (rho_e + rho_w)/2

        # rho_s is the cell itself, rho_n its northern neighbour (roll by -1 in y);
        # both are relabelled with the yu latitudes.
        rho_s = rho_i.copy()
        rho_s.coords['y_ocean'] = yu_ocean.values
        rho_n = rho_i.roll(y_ocean=-1)
        rho_n.coords['y_ocean'] = yu_ocean.values
        # rho_ygrid will be on the vhrho grid:
        rho_ygrid = (rho_s + rho_n)/2

        # mask to the contour, stack the densities into 1d and drop any points not on contour:
        rho_xgrid = rho_xgrid.where(mask_x_transport_numbered>0)
        rho_ygrid = rho_ygrid.where(mask_y_transport_numbered>0)
        x_rho_1d = rho_xgrid.stack(contour_index = ['y_ocean', 'x_ocean'])
        y_rho_1d = rho_ygrid.stack(contour_index = ['y_ocean', 'x_ocean'])
        x_rho_1d = x_rho_1d.where(mask_x_numbered_1d>0,drop=True)
        y_rho_1d = y_rho_1d.where(mask_y_numbered_1d>0,drop=True)

        # combine all points on contour, order them by their contour number and
        # replace the stacked index with the plain 1..N contour index:
        rho_along_contour_i = xr.concat((x_rho_1d, y_rho_1d), dim = 'contour_index')
        rho_along_contour_i = rho_along_contour_i.sortby(contour_ordering)
        rho_along_contour_i.coords['contour_index'] = contour_index_array
        rho_along_contour_i = rho_along_contour_i.load()

        # write into larger array:
        rho_along_contour[time_step,:,:] = rho_along_contour_i

        # free the per-step arrays before the next iteration
        del rho_i,rho_w,rho_e, rho_s, rho_n, rho_xgrid, rho_ygrid, x_rho_1d, y_rho_1d, rho_along_contour_i

    ### Save one netCDF file per contour and year:
    save_dir = '/g/data/x77/cy8964/Post_Process/New_SO/'+choicename+'_'

    ds_rho_along_contour = xr.Dataset({'pot_rho_1_along_contour': rho_along_contour})
    ds_rho_along_contour.to_netcdf(save_dir+'pot_rho_1_along_contour_'+year+'.nc')

### Binning

In [None]:
"""
Bins volume transport and dzu for 10 years for one Southern Ocean contour
"""

# Load modules

# Standard modules
import cosima_cookbook as cc
import matplotlib.pyplot as plt
import netCDF4 as nc
import xarray as xr
import numpy as np
from dask.distributed import Client
import cftime
import glob
import dask.array as dsa
from cosima_cookbook import distributed as ccd
# Ignore warnings
import logging
logging.captureWarnings(True)
logging.getLogger('py.warnings').setLevel(logging.ERROR)


if __name__ == '__main__':

    #### get the contour number that was passed to the python script ####
    import sys

    # Number of SSH contours (SSH = -0.1 m ... -1.4 m); contour_no must be 0..13.
    n_contours = 14

    # Validate the command line before starting the cluster, so that a bad
    # submission fails immediately with a readable message. Without the range
    # check a negative contour number would silently pick a contour from the
    # end of the lists.
    if len(sys.argv) < 2:
        sys.exit('Usage: python ' + sys.argv[0] + ' CONTOUR_NO')
    contour_no = int(sys.argv[1]) ## this is range 0 to 13, defining which contour (SSH=-0.1 to -1.4)
    if not 0 <= contour_no < n_contours:
        sys.exit('CONTOUR_NO must be between 0 and ' + str(n_contours - 1)
                 + ', got ' + str(contour_no))

    # Start a dask cluster with multiple cores
    client = Client(n_workers=8, local_directory='/scratch/x77/cy8964/dask_dump/dask_worker_space')
    # Load database
    session = cc.database.create_session('/g/data/ik11/databases/cosima_master.db')

    expt = '01deg_jra55v13_ryf9091'
    # First year of the 10-year record that is binned below.
    year = '2170'
    start_time = year + '-01-01'
    end_time = year + '-12-31'

    # reference density value:
    rho_0 = 1035.0
    # Note: change this range, so it matches the size of your contour arrays:
    # Note different contours have different ranges for efficiency. Ensure the same range as how contours were made
    lat_range = [slice(-60,-34.99),slice(-60,-34.99),slice(-60,-34.99),slice(-60,-34.99),slice(-60,-34.99),
                 slice(-60,-39.98),slice(-60,-39.98),slice(-62.91,-45),slice(-62.91,-45),slice(-62.91,-45),
                 slice(-64.99,-47),slice(-64.99,-47),slice(-70,-47),slice(-70,-47)][contour_no]
    lat_range_big = [slice(-60.05,-34.90),slice(-60.05,-34.90),slice(-60.05,-34.90),slice(-60.05,-34.90),slice(-60.05,-34.90),
                     slice(-60.05,-39.90),slice(-60.05,-39.90),slice(-62.96,-44.90),slice(-62.96,-44.90),slice(-62.96,-44.90),
                     slice(-65.02,-46.93),slice(-65.02,-46.93),slice(-70.05,-46.93),slice(-70.05,-46.93)][contour_no]

    # t-cells are further south and west than u-cells.
    # Grid spacings, read explicitly from ocean_grid.nc.
    # NOTE(review): dyt/dxu, rho_0, start_time/end_time and lat_range_big are not
    # referenced in the visible remainder of this script; kept so that nothing
    # that may rely on them breaks.
    dyt = cc.querying.getvar(expt, 'dyt', session, n=1, ncfile='ocean_grid.nc')
    dxu = cc.querying.getvar(expt, 'dxu', session, n=1, ncfile='ocean_grid.nc')

    # select latitude range:
    dxu = dxu.sel(yu_ocean=lat_range)
    dyt = dyt.sel(yt_ocean=lat_range)

    # SSH value defining the chosen contour, and the file holding its masks.
    SSH = [-0.1,-0.2,-0.3,-0.4,-0.5,-0.6,-0.7,-0.8,-0.9,-1.0,-1.1,-1.2,-1.3,-1.4][contour_no]

    SO_SSH = 'SO_slope_contour_'+str(SSH)+'m_SSH.npz'

    outfile = '/g/data/x77/cy8964/Post_Process/'+SO_SSH

    # Short label used as the prefix of every input/output file for this contour.
    choicename = ['SO_A','SO_B','SO_C','SO_D','SO_E','SO_F','SO_G','SO_H','SO_I','SO_J','SO_K','SO_L','SO_M','SO_N'][contour_no]

    # Read the contour masks. Indexing the archive reads each array into
    # memory, so the file can be closed as soon as the four masks are extracted.
    with np.load(outfile) as data:
        mask_y_transport = data['mask_y_transport']
        mask_x_transport = data['mask_x_transport']
        mask_y_transport_numbered = data['mask_y_transport_numbered']
        mask_x_transport_numbered = data['mask_x_transport_numbered']

#         #pad masks to help with interpolation
#         mask_x_transport = np.pad(mask_x_transport, ((0, 0), (1, 1)),'constant', constant_values=((0,0), (0,0)))
#         mask_y_transport = np.pad(mask_y_transport, ((0, 0), (1, 1)),'constant', constant_values=((0,0), (0,0)))
#         mask_x_transport_numbered = np.pad(mask_x_transport_numbered, ((0, 0), (1, 1)),'constant', constant_values=((0,0), (0,0)))
#         mask_y_transport_numbered = np.pad(mask_y_transport_numbered, ((0, 0), (1, 1)),'constant', constant_values=((0,0), (0,0)))


    # Coordinate axes of the tracer (t) and velocity (u) grids. Only latitude
    # is cut down to lat_range; the longitude axes are used in full.
    yt_ocean = cc.querying.getvar(expt, 'yt_ocean', session, n=1).sel(yt_ocean=lat_range)
    yu_ocean = cc.querying.getvar(expt, 'yu_ocean', session, n=1).sel(yu_ocean=lat_range)
    xt_ocean = cc.querying.getvar(expt, 'xt_ocean', session, n=1)
    xu_ocean = cc.querying.getvar(expt, 'xu_ocean', session, n=1)

    # Wrap the contour masks in DataArrays. The coordinate *values* follow the
    # data location:
    #   - x masks go with uhrho, i.e. they sit on (yt_ocean, xu_ocean);
    #   - y masks go with vhrho, i.e. they sit on (yu_ocean, xt_ocean).
    # The dimension *names* are always the generic y_ocean/x_ocean, so that
    # points from both masks can be concatenated and sorted together.
    x_face_coords = [('y_ocean', yt_ocean), ('x_ocean', xu_ocean)]
    y_face_coords = [('y_ocean', yu_ocean), ('x_ocean', xt_ocean)]
    mask_x_transport = xr.DataArray(mask_x_transport, coords=x_face_coords)
    mask_y_transport = xr.DataArray(mask_y_transport, coords=y_face_coords)
    mask_x_transport_numbered = xr.DataArray(mask_x_transport_numbered, coords=x_face_coords)
    mask_y_transport_numbered = xr.DataArray(mask_y_transport_numbered, coords=y_face_coords)

    # Build the contour ordering: flatten each numbered mask to 1d, keep only
    # the points with a positive contour number, join the x and y points and
    # sort them by that number.
    stacked_dims = ['y_ocean', 'x_ocean']
    x_numbered_stacked = mask_x_transport_numbered.stack(contour_index=stacked_dims)
    mask_x_numbered_1d = x_numbered_stacked.where(x_numbered_stacked > 0, drop=True)
    y_numbered_stacked = mask_y_transport_numbered.stack(contour_index=stacked_dims)
    mask_y_numbered_1d = y_numbered_stacked.where(y_numbered_stacked > 0, drop=True)

    contour_ordering = xr.concat((mask_x_numbered_1d, mask_y_numbered_1d), dim='contour_index')
    contour_ordering = contour_ordering.sortby(contour_ordering)

    # Plain 1..N index that replaces the stacked multi-index in the outputs.
    contour_index_array = np.arange(1, len(contour_ordering) + 1)


    ########## now bin dzu for all 10 years ##########
    # Grid axes; the u-grid axes are renamed to the generic dimension names.
    yt_ocean = cc.querying.getvar(expt,'yt_ocean',session,n=1)
    yt_ocean = yt_ocean.sel(yt_ocean=lat_range)
    yu_ocean = cc.querying.getvar(expt,'yu_ocean',session,n=1)
    yu_ocean = yu_ocean.sel(yu_ocean=lat_range)
    yu_ocean = yu_ocean.rename({'yu_ocean':'y_ocean'})
    xt_ocean = cc.querying.getvar(expt,'xt_ocean',session,n=1)
    xu_ocean = cc.querying.getvar(expt,'xu_ocean',session,n=1)
    xu_ocean = xu_ocean.rename({'xu_ocean':'x_ocean'})

    # get lat and lon along contour, useful for plotting later:
    lat_along_contour = contour_ordering.y_ocean
    lon_along_contour = contour_ordering.x_ocean
    contour_index_array = np.arange(1,len(contour_ordering)+1)
    # don't need the multi-index anymore, replace with the plain contour count
    lat_along_contour.coords['contour_index'] = contour_index_array
    lon_along_contour.coords['contour_index'] = contour_index_array

    # Read the dzu and sigma1 along-contour files, which are saved one per year
    # (2170-2179). Each file is loaded into memory and closed straight away, and
    # the years are concatenated in a single call so the accumulated array is
    # not re-copied once per year.
    year = '2170'
    save_dir = '/g/data/x77/cy8964/Post_Process/New_SO/'+choicename+'_'

    dzu_years = []
    sigma1_years = []
    for file_year in range(2170, 2180):
        with xr.open_dataset(save_dir+'dzu_along_contour_'+str(file_year)+'.nc') as ds_dzu:
            dzu_years.append(ds_dzu.dzu_along_contour.load())
        with xr.open_dataset(save_dir+'pot_rho_1_along_contour_'+str(file_year)+'.nc') as ds_sigma1:
            sigma1_years.append(ds_sigma1.pot_rho_1_along_contour.load())

    # NOTE: despite its name, vol_trans_across_contour holds dzu in this section;
    # the name is kept so the binning loop is the same as for the volume transport.
    vol_trans_across_contour = xr.concat(dzu_years, dim = 'time')
    sigma1_along_contour = xr.concat(sigma1_years, dim = 'time')
    del dzu_years, sigma1_years

    time = sigma1_along_contour.time

    ## define isopycnal bins
    isopycnal_bins_sigma1 = 1000+ np.array([1,28,29,30,31,31.5,31.9,32,32.1,32.2,32.25,
                                                32.3,32.35,32.4,32.42,32.44,32.46,32.48,32.50,32.51,
                                                32.52,32.53,32.54,32.55,32.56,32.58,32.6,32.8,33,34,45])


    ## initialise empty (time, isopycnal_bins, contour_index) array for the binned values
    vol_trans_across_contour_binned = xr.DataArray(np.zeros((len(time),len(isopycnal_bins_sigma1),len(contour_ordering))),
                                                   coords = [time,isopycnal_bins_sigma1, contour_index_array],
                                                   dims = ['time','isopycnal_bins', 'contour_index'],
                                                   name = 'vol_trans_across_contour_binned')

    # loop through density bins:
    # A value whose density lies in (lower, upper] is split linearly between the
    # two bounding bin values: the fraction (upper - sigma)/(upper - lower) goes
    # to index i and the remainder to index i+1.
    for i in range(len(isopycnal_bins_sigma1)-1):
        print(i)
        lower = isopycnal_bins_sigma1[i]
        upper = isopycnal_bins_sigma1[i+1]
        # 1 where the density falls inside this bin, NaN elsewhere
        bin_mask = sigma1_along_contour.where(sigma1_along_contour<=upper).where(sigma1_along_contour>lower)*0+1
        bin_fractions = (upper-sigma1_along_contour * bin_mask)/(upper-lower)
        ## share assigned to the lower bin value
        transport_across_contour_in_sigmalower_bin = (vol_trans_across_contour * bin_mask * bin_fractions).sum(dim = 'st_ocean')
        vol_trans_across_contour_binned[:,i,:] += transport_across_contour_in_sigmalower_bin.fillna(0)
        del transport_across_contour_in_sigmalower_bin
        ## share assigned to the upper bin value
        transport_across_contour_in_sigmaupper_bin = (vol_trans_across_contour * bin_mask * (1-bin_fractions)).sum(dim = 'st_ocean')
        vol_trans_across_contour_binned[:,i+1,:] += transport_across_contour_in_sigmaupper_bin.fillna(0)
        del bin_mask, bin_fractions, transport_across_contour_in_sigmaupper_bin

    # save the binned dzu for this contour (all years in one file)
    ds_vol_trans_across_contour_binned = xr.Dataset({'dzu_across_contour_binned': vol_trans_across_contour_binned})
    ds_vol_trans_across_contour_binned.to_netcdf(save_dir+'dzu_across_contour_binned.nc')

    ########## now bin the volume transport for all 10 years ##########
    year = '2170'
    save_dir = '/g/data/x77/cy8964/Post_Process/New_SO/'+choicename+'_'

    # Read the volume transport and sigma1 along-contour files, which are saved
    # one per year (2170-2179). Each file is loaded into memory and closed
    # straight away (no chunking step: the arrays are loaded in full anyway),
    # and the years are concatenated in a single call so the accumulated array
    # is not re-copied once per year.
    vol_trans_years = []
    sigma1_years = []
    for file_year in range(2170, 2180):
        with xr.open_dataset(save_dir+'vol_trans_across_contour_'+str(file_year)+'.nc') as ds_vol_trans:
            vol_trans_years.append(ds_vol_trans.vol_trans_across_contour.load())
        with xr.open_dataset(save_dir+'pot_rho_1_along_contour_'+str(file_year)+'.nc') as ds_sigma1:
            sigma1_years.append(ds_sigma1.pot_rho_1_along_contour.load())

    vol_trans_across_contour = xr.concat(vol_trans_years, dim = 'time')
    sigma1_along_contour = xr.concat(sigma1_years, dim = 'time')
    del vol_trans_years, sigma1_years

    time = sigma1_along_contour.time

    ## define isopycnal bins
    isopycnal_bins_sigma1 = 1000+ np.array([1,28,29,30,31,31.5,31.9,32,32.1,32.2,32.25,
                                                32.3,32.35,32.4,32.42,32.44,32.46,32.48,32.50,32.51,
                                                32.52,32.53,32.54,32.55,32.56,32.58,32.6,32.8,33,34,45])


    ## initialise empty (time, isopycnal_bins, contour_index) array for the binned transport
    vol_trans_across_contour_binned = xr.DataArray(np.zeros((len(time),len(isopycnal_bins_sigma1),len(contour_ordering))),
                                                   coords = [time,isopycnal_bins_sigma1, contour_index_array],
                                                   dims = ['time','isopycnal_bins', 'contour_index'],
                                                   name = 'vol_trans_across_contour_binned')

    # loop through density bins:
    # Transport whose density lies in (lower, upper] is split linearly between
    # the two bounding bin values: the fraction (upper - sigma)/(upper - lower)
    # goes to index i and the remainder to index i+1.
    for i in range(len(isopycnal_bins_sigma1)-1):
        print(i)
        lower = isopycnal_bins_sigma1[i]
        upper = isopycnal_bins_sigma1[i+1]
        # 1 where the density falls inside this bin, NaN elsewhere
        bin_mask = sigma1_along_contour.where(sigma1_along_contour<=upper).where(sigma1_along_contour>lower)*0+1
        bin_fractions = (upper-sigma1_along_contour * bin_mask)/(upper-lower)
        ## share of the transport assigned to the lower bin value
        transport_across_contour_in_sigmalower_bin = (vol_trans_across_contour * bin_mask * bin_fractions).sum(dim = 'st_ocean')
        vol_trans_across_contour_binned[:,i,:] += transport_across_contour_in_sigmalower_bin.fillna(0)
        del transport_across_contour_in_sigmalower_bin
        ## share of the transport assigned to the upper bin value
        transport_across_contour_in_sigmaupper_bin = (vol_trans_across_contour * bin_mask * (1-bin_fractions)).sum(dim = 'st_ocean')
        vol_trans_across_contour_binned[:,i+1,:] += transport_across_contour_in_sigmaupper_bin.fillna(0)
        del bin_mask, bin_fractions, transport_across_contour_in_sigmaupper_bin

    # save the binned volume transport for this contour (all years in one file)
    ds_vol_trans_across_contour_binned = xr.Dataset({'vol_trans_across_contour_binned': vol_trans_across_contour_binned})
    ds_vol_trans_across_contour_binned.to_netcdf(save_dir+'vol_trans_across_contour_binned.nc')


### If salt and temp were saved instead (swap the pot_rho_1 variable for salt and temp):
you can compute potential density using the gsw functions imported below:

from gsw import SA_from_SP, p_from_z, sigma2


In [None]:
    def _load_along_contour(varname, yr):
        """Load one year of ``<varname>_along_contour`` fully into memory.

        Reads ``save_dir + '<varname>_along_contour_<yr>.nc'`` and returns the
        variable ``<varname>_along_contour`` from it. The data are loaded
        inside the ``with`` block, so the file handle is closed on return
        instead of being left open for the rest of the script.
        """
        path = save_dir + varname + '_along_contour_' + str(yr) + '.nc'
        with xr.open_dataset(path) as ds:
            return ds[varname + '_along_contour'].load()

    # First year: `year` must already be set (e.g. '2170').
    temp_by_year = [_load_along_contour('temp', year)]
    salt_by_year = [_load_along_contour('salt', year)]

    # load the other years too to have a 10 year daily dataset
    # (dzu_along_contour can be loaded the same way with
    #  _load_along_contour('dzu', year) if it is needed).
    # NOTE: as in the original code, this loop rebinds `year` to an integer.
    for year in np.arange(2171, 2180):
        temp_by_year.append(_load_along_contour('temp', year))
        salt_by_year.append(_load_along_contour('salt', year))

    # Concatenate once at the end rather than re-copying the growing array
    # on every iteration.
    temp_along_contour = xr.concat(temp_by_year, dim='time')
    salt_along_contour = xr.concat(salt_by_year, dim='time')
    del temp_by_year, salt_by_year

    # Convert salt and temp along the contour into sigma2.
    time = salt_along_contour.time
    st_ocean = cc.querying.getvar(expt, 'st_ocean', session, n=1)

    # Negated st_ocean as a DataArray on the st_ocean axis, then broadcast
    # onto the (st_ocean, contour_index) grid using the first time step of salt
    # (x*0+1 gives 1 where salt is valid and keeps NaN elsewhere).
    # NOTE(review): assumes st_ocean is positive-down depth in metres, so that
    # the negated value is the height that p_from_z expects; confirm.
    depth = xr.DataArray(-st_ocean.values, coords=[st_ocean], dims=['st_ocean'])
    depth_along_contour = (salt_along_contour[0, ...] * 0 + 1) * depth

    # Pressure at every depth level / contour point.
    pressure_along_contour = xr.DataArray(
        p_from_z(depth_along_contour, lat_along_contour),
        coords=[st_ocean, contour_index_array],
        dims=['st_ocean', 'contour_index'],
        name='pressure',
        attrs={'units': 'dbar'},
    )

    # Absolute salinity from the saved salinity, pressure and position.
    # NOTE(review): SA_from_SP expects Practical Salinity; confirm that is
    # what salt_along_contour holds.
    abs_salt_along_contour = xr.DataArray(
        SA_from_SP(salt_along_contour, pressure_along_contour,
                   lon_along_contour, lat_along_contour),
        coords=[time, st_ocean, contour_index_array],
        dims=['time', 'st_ocean', 'contour_index'],
        name='Absolute salinity',
        attrs={'units': 'Absolute Salinity (g/kg)'},
    )

    # sigma2 from absolute salinity and temperature.
    # NOTE(review): the 273.15 offset assumes temp_along_contour is stored in
    # Kelvin, and gsw.sigma2 expects Conservative Temperature; confirm both.
    sigma2_along_contour = xr.DataArray(
        sigma2(abs_salt_along_contour, temp_along_contour - 273.15),
        coords=[time, st_ocean, contour_index_array],
        dims=['time', 'st_ocean', 'contour_index'],
        name='potential density ref 2000dbar',
        attrs={'units': 'kg/m^3 (-1000 kg/m^3)'},
    )
