# Create Iron Sediment and Vent Forcing files

This notebook is based on an IDL script from J. Keith Moore (UCI),
received Sept 11, 2024.
It reads `POC_FLUX_IN`, `UVEL`, `VVEL`, `KVMIX`, and `TEMP` from a previous model run,
as well as `PERCENTSED` from a file of unknown provenance.

## Step 0: Python Imports

In [None]:
import glob
import os

from dask.distributed import Client, wait
from dask_jobqueue import PBSCluster
import numpy as np
import xarray as xr

## Step 1: Read in Data

In [None]:
# Tuning parameters carried over from the IDL script
xfactor = 0.005639    # multiplier for the reducing-sediment (POC-driven) iron flux
minval = 0.1          # floor applied to percent sed in land-adjacent cells
minoxic = 0.0005343   # multiplier for the oxic-sediment iron flux

# Set process_LENS = True to recompute the means from CESM2-LE time series
# instead of reading a precomputed mean file
process_LENS = False

# CESM output fields needed below
var_list = ['POC_FLUX_IN', 'UVEL', 'VVEL', 'KVMIX', 'TEMP']

# Output file names
grid = 'gx1v7'
percsedfile_out = f'percentsed_{grid}_cesm2_ecos1.1_2024.nc'
fesedflux_out = f'fesedflux_{grid}_CESM2_ecos1.1_2024.nc'
fesedredflux_out = f'fesedfluxRed_{grid}_CESM2_ecos1.1_2024.nc'

# Remove output left over from an earlier run so every file is written fresh
for stale_file in filter(os.path.exists, (percsedfile_out, fesedflux_out, fesedredflux_out)):
    os.remove(stale_file)

In [None]:
# Containers filled in by later cells:
#   ds_means   -- time-mean model fields plus derived masks
#   local_vars -- intermediate quantities
#   outputs    -- fields written to the forcing files
ds_means, local_vars, outputs = xr.Dataset(), xr.Dataset(), xr.Dataset()

In [None]:
# Directory holding the precomputed inputs (model means and percent sed)
inputdir = os.path.join(os.path.sep,
                        'glade',
                        'work',
                        'mlevy',
                        'cesm_inputdata',
                        'inputs_for_fe_forcing',
                       )

if process_LENS:
    # Testing with LENS2, which is in time series -- reading history output will be different
    CASENAME = 'b.e21.BHISTcmip6.f09_g17.LE2-1001.001'
    outroot = os.path.join(os.path.sep,
                           "glade",
                           "campaign",
                           "cgd",
                           "cesm",
                           "CESM2-LE",
                           "ocn",
                           "proc",
                           "tseries",
                           "month_1"
                          )

    # Collect every variable into one dataset.  Start from an empty Dataset
    # rather than probing `'ds_in' in locals()`: with that probe, re-running
    # this cell (e.g. after changing CASENAME or var_list) silently added to a
    # stale ds_in left over from the previous run.
    ds_in = xr.Dataset()
    for var in var_list:
        # 3-D fields use short time chunks and one 60-level vertical chunk
        chunk_size = {'time': 120, 'nlat': 384, 'nlon': 320}
        if var in ['UVEL', 'VVEL', 'TEMP']:
            chunk_size['time'] = 4
            chunk_size['z_t'] = 60
        elif var == "KVMIX":
            chunk_size['time'] = 4
            chunk_size['z_w_bot'] = 60
        ds_in[var] = xr.open_mfdataset(os.path.join(outroot, var, f"{CASENAME}*.nc"),
                                       chunks=chunk_size)[var]

    # Set up PBS Cluster for analysis
    cluster = PBSCluster(cores=1,
                         processes=1,
                         memory='20GB',
                         account='P93300606',
                         queue='casper',
                         walltime='00:20:00',
                         interface='ext',
                        )

    client = Client(cluster)
    cluster.scale(16)

else:
    # Read the precomputed means directly; no dask cluster is needed
    ds_means_file = os.path.join(inputdir, 'JAMES_8p4z_last20yr_annual_mean.nc')
    ds_means = xr.open_dataset(ds_means_file)[var_list].squeeze()
    client = None

percent_sed_file = os.path.join(inputdir, 'percentSed_stdgrid_gx1v6.nc')
ds_percent_sed = xr.open_dataset(percent_sed_file)

client

In [None]:
%%time

if process_LENS:
    # Collapse the time dimension of each LENS variable and pull the result
    # into memory with .compute()
    for vname, monthly in ds_in.data_vars.items():
        ds_means[vname] = monthly.mean('time').compute()
        # Copy the input variable's attributes onto the mean (xarray
        # reductions do not keep them by default)
        ds_means[vname].attrs = monthly.attrs
    # The cluster is no longer needed once the means are in memory
    client.shutdown()

ds_means

In [None]:
# Ocean mask: True where the mean temperature is finite.
# zeros_like() copies TEMP's attributes (long_name, units), which would
# mislabel plots of the mask, so replace them.  Assigning .name on
# ds_means['mask'] would only rename a temporary DataArray, not the variable.
ds_means['mask'] = xr.zeros_like(ds_means['TEMP'], dtype=bool)
ds_means['mask'].attrs = {'long_name': 'Ocean Mask'}
ds_means['mask'].data = np.isfinite(ds_means['TEMP'].data)
# Temp had some unexpected values of -1 in deep ocean for blocks where TLAT and
# TLON are missing (LBE to blame), so also drop every column whose surface
# cell is not ocean.
ds_means['mask'].data = np.logical_and(ds_means['mask'].data,
                                       ds_means['mask'].isel(z_t=0).data)

# Land mask is the complement of the ocean mask
ds_means['land_mask'] = xr.zeros_like(ds_means['mask'])
ds_means['land_mask'].attrs = {'long_name': 'Land Mask'}
ds_means['land_mask'].data = np.logical_not(ds_means['mask'].data)
ds_means

## Step 2: Compute Mean Horizontal Speed

Keith's script appears to use the $\ell_1$ norm of the horizontal velocity, i.e. speed $= |u| + |v|$.

In [None]:
# Quick look at the l1 speed |UVEL| + |VVEL| at every level (bottom included)
velocity = xr.zeros_like(ds_means['UVEL'])
velocity.name = 'velocity'
# zeros_like() copies UVEL's attributes, and a plot label prefers long_name
# over .name, so relabel it.  Units are kept: |u| + |v| has the units of u.
velocity.attrs['long_name'] = 'velocity'
velocity.data = np.abs(ds_means['UVEL'].data) + np.abs(ds_means['VVEL'].data)
velocity.isel(z_t=0).plot()

In [None]:
# l1 speed |UVEL| + |VVEL| stored for the later steps.
ds_means['velocity'] = xr.zeros_like(ds_means['UVEL'])
# Setting .name on ds_means['velocity'] would only rename a temporary
# DataArray; relabel via long_name instead (UVEL's units still apply).
ds_means['velocity'].attrs['long_name'] = 'velocity'
# Note: IDL script loops through popz-2, not popz-1 => velocity at bottom is 0!
ds_means['velocity'].data[:-1,:,:] = np.abs(ds_means['UVEL'].data[:-1,:,:]) + np.abs(ds_means['VVEL'].data[:-1,:,:])
ds_means['velocity'].isel(z_t=0).plot()

## Step 3: Minimum percent sed when land-adjacent

Also rescale the land-adjacent values in the vertical.

In [None]:
# Neighbour index arrays.  The nlon direction is periodic, so left/right
# neighbours wrap around; the nlat direction is not, so the neighbours one and
# two rows below/above are clamped to the first/last row.
lon_idx = np.arange(ds_means['mask'].sizes['nlon'])
left_ind = np.roll(lon_idx, 1)     # cell to the left, wrapping
right_ind = np.roll(lon_idx, -1)   # cell to the right, wrapping

lat_idx = np.arange(ds_means['mask'].sizes['nlat'])
last_lat = ds_means['mask'].sizes['nlat'] - 1
down2_ind = np.clip(lat_idx - 2, 0, last_lat)   # two rows below
down_ind = np.clip(lat_idx - 1, 0, last_lat)    # one row below
up_ind = np.clip(lat_idx + 1, 0, last_lat)      # one row above
up2_ind = np.clip(lat_idx + 2, 0, last_lat)     # two rows above

In [None]:
# Flag ocean cells that have land somewhere in the surrounding halo:
# one column left/right, one or two rows below/above, and the corners.
land = ds_means['land_mask'].data
adjacent = np.zeros(land.shape, dtype=bool)

# look for land directly left / right of the cell
for lon_ind in [left_ind, right_ind]:
    adjacent |= land[:, :, lon_ind]

for lat_ind in [down2_ind, down_ind, up_ind, up2_ind]:
    # look for land directly below / above the cell (one or two rows away)
    adjacent |= land[:, lat_ind, :]
    # ... and in the corners of the halo
    for lon_ind in [left_ind, right_ind]:
        adjacent |= land[:, lat_ind, :][:, :, lon_ind]

# Actual land points are not considered land-adjacent
adjacent &= np.logical_not(land)

ds_means['land_adj'] = xr.zeros_like(ds_means['mask'])
# zeros_like() copies the mask's attributes; relabel so the plot is correct
# (setting .name on ds_means['land_adj'] would not reach the dataset)
ds_means['land_adj'].attrs = {'long_name': 'Land Adjacent'}
ds_means['land_adj'].data = adjacent

# Plot land-adjacent cells in surface layer
print(f"There are {np.sum(ds_means['land_adj'].isel(z_t=0).data)} land-adjacent cells in the top level")
ds_means['land_adj'].isel(z_t=0).plot()
# For the deepest level instead: ds_means['land_adj'].isel(z_t=-1).plot()

In [None]:
# Sediment fraction on the model grid.  zeros_like(TEMP) supplies the
# dimensions and coordinates; the values come from ds_percent_sed.
# TEMP's attributes (long_name, units) are replaced so that they are not
# written into the PERCENTSED output file.
# NOTE(review): the input file is for gx1v6; assigning .data requires its
# PERCENTSED array to have exactly TEMP's shape.
outputs['percsed'] = xr.zeros_like(ds_means['TEMP'])
outputs['percsed'].attrs = {'long_name': 'Percent Sed'}
outputs['percsed'].data = ds_percent_sed['PERCENTSED'].data

land_adj = ds_means['land_adj'].data

# Two steps in land-adjacent points:
# 1. Set percsed to max(minval, percsed)
outputs['percsed'].data = np.where(land_adj,
                                   np.maximum(minval, outputs['percsed'].data),
                                   outputs['percsed'].data)
# 2. Divide by the column total of percsed.  The divisor is 1 in cells that
#    are not land-adjacent, so those cells are left unchanged.
#    NOTE(review): since only land-adjacent cells are rescaled, a column that
#    mixes land-adjacent and other cells need not sum exactly to 1 (the plot
#    below shows the column sums) -- confirm this matches the IDL script.
percsed_sum = np.where(land_adj, outputs['percsed'].sum('z_t').data, 1)
# percsed_sum is a plain ndarray: divide by it directly (ndarray.data would be
# a raw memoryview, not the array).
outputs['percsed'].data = outputs['percsed'].data / percsed_sum

outputs['percsed'].sum('z_t').plot()

In [None]:
# Write the adjusted sediment fraction to percsedfile_out as PERCENTSED
percsed_ds = outputs['percsed'].to_dataset(name="PERCENTSED")
percsed_ds.to_netcdf(percsedfile_out)

## Step 4: Compute Sediment Input from Reducing and Oxic Sediments

In [None]:
# Depths (z_t, in cm) of levels 39-41, for choosing the 1200 m cutoff below
ds_means['z_t'].data[39:42]

In [None]:
# Local copy of the speed, clipped to [0.2, 20]
local_vars['speed'] = xr.zeros_like(ds_means['velocity'])
# Setting .name on local_vars['speed'] would not reach the dataset; relabel
# the inherited long_name instead so the plot is titled correctly
local_vars['speed'].attrs['long_name'] = 'Local copy of speed'
local_vars['speed'].data = np.minimum(20., np.maximum(0.2, ds_means['velocity'].data))

# Keith's script wants z index <= 40; level 40 is 1106.2 m deep while level 41 is 1244.6 m, hence using 1200m as threshold
# (z_t is in cm).  Above that depth the speed is raised to at least 2.
local_vars['speed'].data = np.where(np.logical_and(local_vars['speed'] < 2., ds_means['velocity'].z_t < 1200. * 100.),
                                    2.,
                                    local_vars['speed'].data
                                   )

# Speed is zero wherever percsed is not positive (or is NaN)
local_vars['speed'].data = np.where(outputs['percsed'].data > 0., local_vars['speed'].data, 0.)
local_vars['speed'].isel(z_t=0).plot()

In [None]:
# Scale factor: starts as the local speed and becomes max(Kd, speed) below.
# Its units are mixed, so drop the attributes inherited from speed.
local_vars['scale'] = xr.zeros_like(local_vars['speed'])
local_vars['scale'].attrs = {'long_name': 'Scale Factor'}
local_vars['scale'].data = local_vars['speed'].data

# Kd at level k (k >= 2) is KVMIX[k-1] where percsed > 0, else 0; levels 0
# and 1 stay 0.  Values are capped at 10.
# NOTE(review): presumably KVMIX sits on the bottom interfaces (z_w_bot), so
# KVMIX[k-1] is the diffusivity at the top of level k -- confirm vs IDL.
local_vars['Kd'] = xr.zeros_like(ds_means['KVMIX'])
# Keep KVMIX's units; relabel via long_name (.name would not reach the dataset)
local_vars['Kd'].attrs['long_name'] = 'Kd'
local_vars['Kd'].data[2:,:,:] = np.where(outputs['percsed'].data[2:,:,:] > 0., ds_means['KVMIX'].data[1:-1,:,:], 0.)
local_vars['Kd'].data = np.minimum(local_vars['Kd'].data, 10.)

local_vars['scale'].data = np.maximum(local_vars['Kd'].data, local_vars['scale'].data)

In [None]:
# Iron flux from oxic sediments: minoxic * percsed * scale
outputs['fesed'] = xr.zeros_like(outputs['percsed'])
# zeros_like() copies percsed's attributes, which do not describe this field
# and would otherwise be written into the FESEDFLUXIN file
outputs['fesed'].attrs = {'long_name': 'Fe Sediment'}
outputs['fesed'].data = minoxic * outputs['percsed'].data * local_vars['scale'].data
# Convert to model units.  Because 864 * 1.1574e-6 ~= 1e-3, the divisor is
# ~0.365, i.e. values are multiplied by ~1000/365.
# NOTE(review): confirm this is the conversion used in the IDL script.
outputs['fesed'].data = outputs['fesed'].data / ((365. * 864.) * 1.1574e-6)

In [None]:
# Write fesed to fesedflux_out as variable FESEDFLUXIN
fesed_ds = outputs['fesed'].to_dataset(name='FESEDFLUXIN')
fesed_ds.to_netcdf(fesedflux_out)

In [None]:
# Convert POC_FLUX_IN from mmol / m^3 cm / s to g / m^2 / yr:
#   365 * 86400 s / yr, 12.011 g C / mol C, 0.01 m / cm, 0.001 mol / mmol
local_vars['POC'] = xr.zeros_like(ds_means['POC_FLUX_IN'])
# The units attribute copied from POC_FLUX_IN no longer applies after conversion
local_vars['POC'].attrs['units'] = 'g/m^2/yr'
local_vars['POC'].data = (365. * 86400.) * 12.011 * 0.01 * 0.001 * ds_means['POC_FLUX_IN'].data

# Regional enhancements applied to the top 32 levels, as (region, multiplier).
# The lat/lon bounds replace the index ranges noted in each comment.
tlong = ds_means['TLONG'].data
tlat = ds_means['TLAT'].data
poc_regions = [
    # Western Pacific, index range [:32,179:210,138:249] -> x5
    ((tlong > 115.8) & (tlong < 239.6) & (tlat > -2.01) & (tlat < 6.1), 5.),
    # Western Pacific, index range [:32,114:179,160:249] -> x2.5
    ((tlong > 140.5) & (tlong < 239.6) & (tlat > -20.2) & (tlat < -2.25), 2.5),
    # Southern Ocean, index range [:32,:44,:] -> x2.5
    (tlat < -56.2, 2.5),
]
for region, factor in poc_regions:
    local_vars['POC'].data[:32,:,:] = np.where(region,
                                               local_vars['POC'].data[:32,:,:] * factor,
                                               local_vars['POC'].data[:32,:,:])

In [None]:
# Sinking POC: converted POC with a floor of 10 above 1000 m (z_t is in cm)
local_vars['sinkpoc'] = xr.zeros_like(local_vars['POC'])
# Setting .name on local_vars['sinkpoc'] would not reach the dataset
local_vars['sinkpoc'].attrs['long_name'] = 'Sinking POC'

local_vars['sinkpoc'].data = np.where(np.logical_and(local_vars['POC'] < 10.,
                                                     ds_means['POC_FLUX_IN'].z_t < 1000. * 100.),
                                      10.,
                                      local_vars['POC'].data)

In [None]:
# Temperature factor 1.5 ** ((T - 32) / 10), with T in deg C (equal to 1 at
# 32 C).  The factor is dimensionless, so TEMP's inherited attributes
# (including units) are replaced.
local_vars['Tfunc'] = xr.zeros_like(ds_means['TEMP'])
local_vars['Tfunc'].attrs = {'long_name': 'Tfunc'}
local_vars['Tfunc'].data = 1.5**(((ds_means['TEMP'].data + 273.15) - (32.0 + 273.15)) / 10.)

In [None]:
# Iron flux from reducing sediments: sinkpoc * xfactor * percsed * scale * Tfunc
outputs['fesedRed'] = xr.zeros_like(outputs['fesed'])
# zeros_like() copies fesed's attributes; give this field its own so the
# copied ones are not written into the output file
outputs['fesedRed'].attrs = {'long_name': 'Iron Sediment Reduced'}

# NOTE(review): if POC_FLUX_IN or TEMP contain NaN (e.g. at land points),
# sinkpoc/Tfunc are NaN there and 0 * NaN = NaN, so this field is NaN even
# where percsed is 0 -- confirm the model tolerates NaN in this forcing file.
outputs['fesedRed'].data = local_vars['sinkpoc'].data * xfactor * outputs['percsed'].data * local_vars['scale'].data * local_vars['Tfunc'].data
# Convert to model units (divisor ~0.365, since 864 * 1.1574e-6 ~= 1e-3)
outputs['fesedRed'].data = outputs['fesedRed'].data / ((365. * 864.) * 1.1574e-6)

In [None]:
# Write fesedRed to fesedredflux_out, also as variable FESEDFLUXIN
fesed_red_ds = outputs['fesedRed'].to_dataset(name='FESEDFLUXIN')
fesed_red_ds.to_netcdf(fesedredflux_out)