# Step 2: depth-average TF
This notebook tests an efficient implementation of *Step 2* of Verjans's workflow: depth-averaging the ocean thermal forcing (TF) that we produced in Step 1.

The full workflow is outlined in Vincent's Readme1.txt in [this Zenodo archive](https://zenodo.org/records/7931326).  We are modifying the workflow to deploy it efficiently for ISMIP7.

14 Nov 2024 | EHU

Edits:
- Applied xarray `sel` and `where` to streamline this computation. Removed unused read-in commands.
- TODO 14 Nov: Correct metadata in NC file being written out

### Imports and run settings

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Depth-averaging TF in given DepthRange at all grid points of given Model
Choose scenario of interest

@author: vincent
"""

import os
import sys
import copy
import csv
import numpy as np
import netCDF4 as nc
import xarray as xr
import dask
from datetime import datetime

from verjansFunctions import freezingPoint
from verjansFunctions import calcDpAveraged

In [None]:
## Settings for this run
savingTF         = True
cwd              = os.getcwd()+'/'

SelModel         = 'CESM2'
DepthRange       = [0,500] #depth range of interest, [200:500] is from Slater et al. (2019, 2020)
ShallowThreshold = 100 #bathymetry threshold: if bathymetry is shallower, gridpoint is discarded

DirModNC   = '/home/theghub/ehultee/data/'
DirSave    = '/home/theghub/ehultee/data'


### EHU: I suspect we don't need any of the below from VV, but note we'll need to think through 
###     settings for looping over multiple GCMs in the production run.
To2015hist                 = False
To2100histssp585           = False
To2100histssp126           = True

# if(To2015hist):
#     partname = 'hist'
# elif(To2100histssp585):
#     partname = 'hist2100ssp585'
# elif(To2100histssp126):
#     partname = 'hist2100ssp126'
    
# if(SelModel=='MIROCES2L'):
#     dim2d              = True
#     if(To2015hist):
#         ls_members     = [f'r{id}' for id in range(1,30+1)]
#     elif(To2100histssp585 or To2100histssp126):
#         ls_members     = [f'r{id}' for id in range(1,10+1)]
# elif(SelModel=='IPSLCM6A'):
#     dim2d              = True
#     if(To2015hist):
#         ls_members     = [f'r{id}' for id in range(1,32+1)]
#         ls_members.remove('r2') #no r2 member for IPSLCM6A
#     elif(To2100histssp585 or To2100histssp126):
#         ls_members     = ['r1','r3','r4','r6','r14']

### Read in dataset

In [None]:
path_o = DirModNC + 'tf-CESM2-200001-201412-v4_no_intermed_compute.nc'

ds = xr.open_dataset(path_o)
ds

### Average over a depth slice

The goal is to compute the average TF over the depth range defined in `DepthRange` above, keeping only grid cells whose bathymetry is deeper than `ShallowThreshold`.  Vincent does this by reading various NC variables into empty arrays, then applying if-else tests to find the range over which to average.  We should be able to do the same on dask arrays using xarray's `sel` method.

NOTE: Selecting the relevant depth range is easy. The `ShallowThreshold` test is less obvious: I think Vincent's approach assumes that TF is only defined down to the maximum depth of the grid cell (the bathymetry).  I believe the xarray equivalent is to discard cells where TF is NaN at the first level deeper than `ShallowThreshold` (100 m).  Should double-check.
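
To sanity-check that idea, here is a minimal sketch on a tiny synthetic array rather than the CESM2 file. It assumes that TF is NaN below the seabed and that `lev` increases downward; the values and the names `tf`, `in_range`, and `has_deep_data` are invented for illustration.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for ds.TF: 4 depth levels, 2 grid cells (values are invented).
lev = np.array([50.0, 125.0, 300.0, 700.0])
tf = xr.DataArray(
    np.array([[1.0, 2.0],       # 50 m
              [3.0, np.nan],    # 125 m: cell 1 is already below its seabed
              [5.0, np.nan],    # 300 m
              [7.0, np.nan]]),  # 700 m
    coords={"lev": lev, "cell": [0, 1]},
    dims=("lev", "cell"),
    name="TF",
)

# Label-based selection: slice bounds are inclusive and need not match a level exactly.
in_range = tf.sel(lev=slice(0, 500))

# A cell counts as "deep enough" if it has data at the first level below 100 m.
# drop=True discards the scalar lev coordinate left behind by the selection.
has_deep_data = tf.sel(lev=125.0, drop=True).notnull()

# Mask shallow cells at every level, then average over depth, skipping NaNs.
tf_avg = in_range.where(has_deep_data).mean(dim="lev", skipna=True)
print(tf_avg.values)
```

Here cell 0 averages the three levels inside the range (1, 3, and 5, giving 3), while cell 1 has no data at 125 m and so comes out as NaN.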

In [None]:
depth_slice = ds.sel(lev=slice(DepthRange[0], DepthRange[1]))
depth_slice.mean(dim='lev', skipna=True)

### Apply depth condition

In [None]:
no_shallow = ds.TF.sel(lev=125.0) ## first set up what to test - the depth level just below ShallowThreshold
## TODO: automate this better? Currently it is hard-coded and a user change to ShallowThreshold would have to be manually applied here as well

deep_only = ds.TF.where(no_shallow.notnull())   ## now select TF in the whole dataset wherever it is *not NaN* below ShallowThreshold
deep_only.max() ## reality check: is this a float and not a nan? is it a reasonable value?

For me, the value output here is about 19. That seems very high, but it is the maximum over the whole dataset rather than a typical value.  At least it's a float and not a NaN!
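
The hard-coded `lev=125.0` in the cell above could instead be derived from `ShallowThreshold`. The following is a minimal sketch on a synthetic `lev` coordinate, assuming the levels are sorted by increasing depth. The helper name `first_level_below` is hypothetical; it is not part of Vincent's code or of xarray.

```python
import numpy as np
import xarray as xr


def first_level_below(lev, threshold):
    """Return the shallowest level strictly deeper than `threshold`.

    Hypothetical helper; assumes `lev` is a 1-D coordinate sorted by increasing depth.
    """
    deeper = lev.where(lev > threshold, drop=True)
    if deeper.size == 0:
        raise ValueError(f"No level deeper than {threshold}")
    return float(deeper[0])


ShallowThreshold = 100
vals = np.array([5.0, 50.0, 100.0, 125.0, 300.0])  # invented levels
lev = xr.DataArray(vals, coords={"lev": vals}, dims="lev")

test_level = first_level_below(lev, ShallowThreshold)
print(test_level)
```

With the real dataset, the test level would then be selected as `ds.TF.sel(lev=first_level_below(ds.lev, ShallowThreshold))`, so that a change to `ShallowThreshold` propagates automatically.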

In [None]:
deep_sliced = deep_only.sel(lev=slice(DepthRange[0], DepthRange[1]))  ## use the masked TF so that shallow cells stay NaN
dsm = deep_sliced.mean(dim='lev', skipna=True)
dsm

As far as I can tell, the above is the data Vincent wants: depth-averaged over `DepthRange`, trimmed to include only cells with data deeper than `ShallowThreshold`.  The next step is to write it out to a NetCDF file.

### Write NetCDF out

In [None]:
from datetime import date
out_fn = DirSave + '/tfdpavg-{}-{}.nc'.format(SelModel, date.today())

from dask.diagnostics import ProgressBar

with ProgressBar():
    dsm.to_netcdf(path=out_fn)

In [None]:
## test read-in

ds2 = xr.open_dataset(out_fn)
ds2

This shows that we have successfully computed and written out depth-averaged TF -- note that the `lev` coordinate present in the original TF dataset has now disappeared, because that is the dimension over which we averaged.  

TODO: Correct metadata for the TF variable here to indicate that it is depth-averaged.
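
One way to address this TODO is to set attributes on the variable before calling `to_netcdf`. Below is a minimal sketch on a synthetic stand-in for `dsm`. The attribute wording, the `degC` units (taken from Vincent's code below), and the use of a CF-style `cell_methods` attribute are assumptions that should be checked against what ISMIP7 requires.

```python
import numpy as np
import xarray as xr

DepthRange = [0, 500]
ShallowThreshold = 100

# Synthetic stand-in for the depth-averaged dataset (values are invented).
dsm = xr.Dataset({"TF": (("time", "y", "x"), np.ones((2, 3, 4)))})

# Describe the variable: what was averaged, over which range, and what was masked.
dsm["TF"].attrs.update(
    long_name=f"Thermal forcing averaged over {DepthRange[0]}-{DepthRange[1]} m depth",
    units="degC",
    cell_methods="lev: mean",
    comment=f"Cells with no data below {ShallowThreshold} m are masked",
)

# A global attribute recording how the file was produced.
dsm.attrs["history"] = "Depth-averaged with xarray"

print(dsm["TF"].attrs)
```

If `dsm` is a `DataArray` rather than a `Dataset`, the same keys can be set directly on `dsm.attrs`.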

---
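For reference, Vincent's raw code is below. It will not run in this notebook as-is: `ls_members`, `partname`, `dim2d`, `timefull`, `depthfull`, `izth`, `nny`, `nnx`, `latsfull`, and `lonsfull` are not defined above.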

In [None]:
## The below is Vincent's raw code, for reference.

### Depth indices ###
dpmin = DepthRange[0]
### Compute 1 member at a time ###
for mm,member in enumerate(ls_members):
    print(f'Member: {member}')
    tfdpavg    = np.zeros((len(timefull),nny,nnx))
    memberfile = f'thetao_tf_{SelModel}{partname}_{member}.nc'
    ds         = nc.Dataset(DirModNC+memberfile)
    tfprofilefull = np.array(ds.variables['thermalforcing'])
    ds.close()
    for indy in range(nny):
        print(f'indy: {indy}')
        for indx in range(nnx):
            # Extract entire tf profile #
            tf0   = tfprofilefull[:,:,indy,indx]
            # Check if bathy is at least deeper than ShallowThreshold #
            if(tf0[0,izth]<1e10):
                # Find depth index of bathymetry #
                izmax = np.where(tf0[0,:]<1e10)[0][-1]
                # Constrain max depth by DepthRange or bathymetry #
                dpmax = min(DepthRange[1],depthfull[izmax])
                # Calculate average TF over depth range #
                for tt in range(len(timefull)):
                    tfdpavg[tt,indy,indx] = calcDpAveraged(tf0[tt,:],depthfull,dmin=dpmin,dmax=dpmax)
            else:
                # Bathymetry does not go deep enough #
                tfdpavg[:,indy,indx] = 1.1e20
            
    if(savingTF):
        nameout = f'ensemble{SelModel}_{partname}_M{member}_TFdpavg_Dp{DepthRange[0]}to{DepthRange[1]}_bathymin{ShallowThreshold}.nc'
        ### Open netcdf ###
        outnc        = nc.Dataset(DirSave+nameout,'w',format='NETCDF4')
        timedim      = outnc.createDimension('timeDim',size=len(timefull)) 
        zdim         = outnc.createDimension('depthDim',size=len(depthfull)) 
        latdim       = outnc.createDimension('latDim',nny) 
        londim       = outnc.createDimension('lonDim',nnx) 
        
        time_nc      = outnc.createVariable('time','f4',('timeDim',))
        depth_nc     = outnc.createVariable('depth','f4',('depthDim',))
        if(dim2d==True):
            lat_nc   = outnc.createVariable('lat','f4',('latDim','lonDim',))
            lon_nc   = outnc.createVariable('lon','f4',('latDim','lonDim',))
        elif(dim2d==False):
            lat_nc   = outnc.createVariable('lat','f4',('latDim',))
            lon_nc   = outnc.createVariable('lon','f4',('lonDim',))
        tfdpavg_nc   = outnc.createVariable(f'tfdpavg{DepthRange[0]}to{DepthRange[1]}_bathymin{ShallowThreshold}','f4',('timeDim','latDim','lonDim',))
            
        time_nc[:]          = timefull
        depth_nc[:]         = depthfull
        if(dim2d==True):
            lat_nc[:,:]     = latsfull
            lon_nc[:,:]     = lonsfull
        elif(dim2d==False):
            lat_nc[:]       = latsfull
            lon_nc[:]       = lonsfull
        tfdpavg_nc[:,:,:]   = tfdpavg
        
        depth_nc.units     = 'meter'
        time_nc.units      = 'yr'
        tfdpavg_nc.units   = 'degC'
        outnc.close()

            
     
# print('End of python job')
# os._exit(0) 