In [None]:
import os
import gsw

import numpy as np

from netCDF4 import Dataset
import matplotlib.pyplot as plt

In [None]:
# Directories: EN4 monthly analyses are read from `datadir`;
# the derived mixed-layer-depth files are written to its `density/` subdirectory.
datadir = '/gws/nopw/j04/nemo_vol1/jmecking/en4/v4.2.2_c14/'
savedir = datadir + 'density/'

In [None]:
# Compute mixed layer depth:
#
# For every monthly EN4 analysis file, the mixed layer depth (MLD) is the depth at which
# potential density sigma0 first exceeds its value at the reference level (depth index
# MLD_REF_LEVEL) by more than 0.03 kg/m3 (hence the '_d003' output suffix). Outputs that
# already exist are skipped, so the cell can be re-run to resume an interrupted job.

MLD_THRESHOLD   = -0.03  # threshold on sigma0(reference) - sigma0(z); below it = under the mixed layer
MLD_REF_LEVEL   = 1      # depth index of the reference density
MLD_FIRST_LEVEL = 2      # first depth index searched for the threshold crossing
MLD_SURFACE     = 15     # MLD [m] used when no searched level lies within the threshold


def _column_mld(diff, depth, dep_bnds, nd,
                threshold=MLD_THRESHOLD, first=MLD_FIRST_LEVEL, surface_mld=MLD_SURFACE):
    """Return (mld, mld_top, mld_bot, mld_diff) for a single water column.

    diff     : 1-D profile of sigma0(reference) - sigma0(z), NaN where invalid.
    depth    : 1-D depths of the levels (same length as diff).
    dep_bnds : (n_levels, 2) level bounds; column 1 is used as the bottom of a level.
    nd       : number of valid levels counted from the surface.

    mld_top is 1 when no searched level lies within the threshold (MLD = surface_mld),
    mld_bot is 1 when the threshold is never exceeded (MLD = bottom of level nd-1),
    mld_diff is the index gap between the first level past and the last level within
    the threshold. All four are NaN when no level in diff[first:nd] compares on either
    side of the threshold (e.g. all NaN).
    """
    inds_b = np.where(diff[first:nd] < threshold)[0] + first  # levels past the threshold
    inds_t = np.where(diff[first:nd] > threshold)[0] + first  # levels still within it

    if inds_b.size == 0 and inds_t.size == 0:
        return np.nan, np.nan, np.nan, np.nan
    if inds_b.size == 0:
        # Threshold never exceeded: mixed layer reaches the bottom of the valid column.
        return dep_bnds[nd - 1, 1], 0, 1, 0
    if inds_t.size == 0:
        # Threshold already exceeded at every searched level: mixed layer stays at surface.
        return surface_mld, 1, 0, 0

    below, above = inds_b[0], inds_t[-1]
    if below > above:
        # Single crossing: interpolate between the last level within and first level past it.
        upper, lower = above, below
    elif above + 1 == nd:
        # Profile returns within the threshold at the deepest valid level.
        return dep_bnds[nd - 1, 1], 0, 0, below - above
    else:
        # Non-monotonic profile: interpolate from the deepest level within the threshold
        # to the level directly below it.
        upper, lower = above, above + 1

    # Linear interpolation of diff(z) to the depth where it equals the threshold.
    m = (diff[lower] - diff[upper]) / (depth[lower] - depth[upper])
    b = diff[lower] - m * depth[lower]
    return (threshold - b) / m, 0, 0, below - above


yr = np.arange(1900, 2025)
ny = np.size(yr, axis=0)

for yy in range(ny):
    for mm in range(12):
        stem    = 'EN.4.2.2.f.analysis.c14.' + str(yr[yy]) + str(mm + 1).zfill(2)
        infile  = datadir + stem + '.nc'
        outfile = savedir + stem + '_mld_d003.nc'
        if not os.path.isfile(infile):
            continue
        print(infile)
        if os.path.isfile(outfile):
            continue  # already computed (outputs only appear once fully written, see below)

        # Read potential temperature (K -> degC), salinity and coordinates:
        with Dataset(infile, 'r') as ncid:
            temp      = np.squeeze(ncid.variables['temperature'][:, :, :, :]) - 273.15
            sal       = np.squeeze(ncid.variables['salinity'][:, :, :, :])
            lat       = ncid.variables['lat'][:]
            lon       = ncid.variables['lon'][:]
            depth     = ncid.variables['depth'][:]
            dep_bnds  = ncid.variables['depth_bnds'][:, :]
            time      = ncid.variables['time'][:]
            time_bnds = ncid.variables['time_bnds'][:, :]
            # Keep coordinate metadata (units, calendar, ...) for the output file; names
            # starting with '_' are reserved/special in netCDF and cannot be copied verbatim.
            coord_attrs = {
                name: {key: ncid.variables[name].getncattr(key)
                       for key in ncid.variables[name].ncattrs() if not key.startswith('_')}
                for name in ('lat', 'lon', 'time', 'time_bnds')
            }

        # Set required variables (3-D coordinate fields with shape (depth, lat, lon)):
        LON, LAT = np.meshgrid(lon, lat)
        LON   = np.tile(LON, (np.size(depth), 1, 1))
        LAT   = np.tile(LAT, (np.size(depth), 1, 1))
        DEP   = np.swapaxes(np.tile(depth, (np.size(lon), np.size(lat), 1)), 0, 2)
        p     = gsw.p_from_z(-DEP, LAT)
        # getmaskarray always yields a full boolean array, even when nothing is masked
        # (temp.mask can be a scalar False in that case).
        tmask = (~np.ma.getmaskarray(temp)).astype(int)

        # Compute densities:
        SA     = gsw.SA_from_SP(sal, p, LON, LAT)
        CT     = gsw.conversions.CT_from_pt(SA, temp)
        sigma0 = gsw.density.sigma0(SA, CT)

        # Compute differences relative to the reference level. Masked points become NaN so
        # they compare false on both sides of the threshold.
        sigma0      = np.ma.filled(sigma0, np.nan)
        sigma0_diff = sigma0[MLD_REF_LEVEL] - sigma0

        mld      = np.full((len(lat), len(lon)), np.nan)
        mld_top  = np.full_like(mld, np.nan)
        mld_bot  = np.full_like(mld, np.nan)
        mld_diff = np.full_like(mld, np.nan)

        # Compute where thresholds are hit, for columns whose levels 0..MLD_FIRST_LEVEL
        # are all valid; the valid levels are assumed contiguous from the surface.
        nlev  = tmask.sum(axis=0)
        ocean = tmask[:MLD_FIRST_LEVEL + 1].sum(axis=0) == MLD_FIRST_LEVEL + 1
        for jj, ii in zip(*np.nonzero(ocean)):
            mld[jj, ii], mld_top[jj, ii], mld_bot[jj, ii], mld_diff[jj, ii] = _column_mld(
                sigma0_diff[:, jj, ii], depth, dep_bnds, nlev[jj, ii])

        # Write to file. Write to a temporary name and rename only once complete, so an
        # interrupted run never leaves a partial file that the existence check would skip.
        tmpfile = outfile + '.tmp'
        try:
            with Dataset(tmpfile, 'w') as ncid:
                # coordinates:
                ncid.createDimension('lat', len(lat))
                ncid.createDimension('lon', len(lon))
                ncid.createDimension('bnds', 2)
                ncid.createDimension('time', None)

                # variables (coordinate metadata copied from the input file):
                for name, dims in (('lat', ('lat',)), ('lon', ('lon',)),
                                   ('time', ('time',)), ('time_bnds', ('time', 'bnds'))):
                    ncid.createVariable(name, 'f8', dims).setncatts(coord_attrs[name])
                for name, long_name in (
                        ('mld', 'mixed layer depth'),
                        ('mld_diff', 'level index gap bracketing the threshold crossing'),
                        ('mld_top', 'flag: no searched level lies within the density threshold'),
                        ('mld_bot', 'flag: density threshold never exceeded in the column')):
                    ncid.createVariable(name, 'f8', ('time', 'lat', 'lon')).long_name = long_name
                ncid.variables['mld'].units = 'm'

                # fill variables:
                ncid.variables['lat'][:]          = lat
                ncid.variables['lon'][:]          = lon
                ncid.variables['time'][:]         = time
                ncid.variables['time_bnds'][:]    = time_bnds
                ncid.variables['mld'][0, :, :]      = mld
                ncid.variables['mld_diff'][0, :, :] = mld_diff
                ncid.variables['mld_top'][0, :, :]  = mld_top
                ncid.variables['mld_bot'][0, :, :]  = mld_bot
            os.replace(tmpfile, outfile)
        except BaseException:
            if os.path.exists(tmpfile):
                os.remove(tmpfile)
            raise