# Brushcut ACCESS-OM2-01 year-2 output for use as boundary forcing for panan-005

In [1]:
import os
import subprocess
from glob import glob
from itertools import cycle

import cmocean.cm as cmocean
import dask
import dask.array as da
import dask.bag as db
import matplotlib.pyplot as plt
import netCDF4
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
from dask.diagnostics import ProgressBar
from dask.distributed import progress
from pykdtree.kdtree import KDTree

# Keep attributes (units etc.) through xarray arithmetic, e.g. the
# Kelvin -> Celsius conversion further down.
xr.set_options(keep_attrs=True)

from dask.distributed import Client

# Local dask.distributed cluster for the lazy computations below; the bare
# `client` expression at the end displays its summary and dashboard link.
client = Client()
client

Perhaps you already have a cluster running?
Hosting the HTTP server on port 44651 instead


0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /proxy/44651/status,

0,1
Dashboard: /proxy/44651/status,Workers: 7
Total threads: 28,Total memory: 125.20 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:33001,Workers: 7
Dashboard: /proxy/44651/status,Total threads: 28
Started: Just now,Total memory: 125.20 GiB

0,1
Comm: tcp://127.0.0.1:38683,Total threads: 4
Dashboard: /proxy/37279/status,Memory: 17.89 GiB
Nanny: tcp://127.0.0.1:34211,
Local directory: /jobfs/78482592.gadi-pbs/dask-worker-space/worker-yudqmmra,Local directory: /jobfs/78482592.gadi-pbs/dask-worker-space/worker-yudqmmra

0,1
Comm: tcp://127.0.0.1:36093,Total threads: 4
Dashboard: /proxy/35285/status,Memory: 17.89 GiB
Nanny: tcp://127.0.0.1:32775,
Local directory: /jobfs/78482592.gadi-pbs/dask-worker-space/worker-9upug3oj,Local directory: /jobfs/78482592.gadi-pbs/dask-worker-space/worker-9upug3oj

0,1
Comm: tcp://127.0.0.1:46367,Total threads: 4
Dashboard: /proxy/32807/status,Memory: 17.89 GiB
Nanny: tcp://127.0.0.1:40101,
Local directory: /jobfs/78482592.gadi-pbs/dask-worker-space/worker-uh688xa0,Local directory: /jobfs/78482592.gadi-pbs/dask-worker-space/worker-uh688xa0

0,1
Comm: tcp://127.0.0.1:34993,Total threads: 4
Dashboard: /proxy/40961/status,Memory: 17.89 GiB
Nanny: tcp://127.0.0.1:45975,
Local directory: /jobfs/78482592.gadi-pbs/dask-worker-space/worker-nqwj2f9n,Local directory: /jobfs/78482592.gadi-pbs/dask-worker-space/worker-nqwj2f9n

0,1
Comm: tcp://127.0.0.1:46711,Total threads: 4
Dashboard: /proxy/34119/status,Memory: 17.89 GiB
Nanny: tcp://127.0.0.1:45007,
Local directory: /jobfs/78482592.gadi-pbs/dask-worker-space/worker-uloh_iz_,Local directory: /jobfs/78482592.gadi-pbs/dask-worker-space/worker-uloh_iz_

0,1
Comm: tcp://127.0.0.1:34287,Total threads: 4
Dashboard: /proxy/37919/status,Memory: 17.89 GiB
Nanny: tcp://127.0.0.1:34689,
Local directory: /jobfs/78482592.gadi-pbs/dask-worker-space/worker-clyafbo7,Local directory: /jobfs/78482592.gadi-pbs/dask-worker-space/worker-clyafbo7

0,1
Comm: tcp://127.0.0.1:43751,Total threads: 4
Dashboard: /proxy/35605/status,Memory: 17.89 GiB
Nanny: tcp://127.0.0.1:38927,
Local directory: /jobfs/78482592.gadi-pbs/dask-worker-space/worker-z3_s2lex,Local directory: /jobfs/78482592.gadi-pbs/dask-worker-space/worker-z3_s2lex


In [2]:
# Source run output directories to read: output008 ... output011.
t = range(8, 12)

# Fields to extract: 3-D tracers (T points), 2-D line tracers, 3-D velocities (U points).
surface_tracer_vars = ["pot_temp", "salt", "dzt"]
line_tracer_vars = ["eta_t"]
surface_velocity_vars = ["u", "v", "dzu"]

# Dask chunking used when opening the T- and U-grid files.
# NOTE(review): the depth dimension is called st_ocean later in this notebook;
# confirm that "zl" actually matches a dimension in these files.
chunks = {
    "T": {"time": 1, "zl": 7, "yt_ocean": 300, "xt_ocean": None},
    "U": {"time": 1, "zl": 7, "yu_ocean": 300, "xu_ocean": None},
}

# Target points: the last (nyp = -1) row of the hgrid supergrid.
HGRID_PATH = "/g/data/e14/cs6673/panan-005/domain-tools/new_topog/ocean_hgrid_0025.nc"
dg = xr.open_dataset(HGRID_PATH).isel(nyp=[-1])

# Present those points to xesmf as a location stream.
boundary_lat = dg.y.squeeze().data
boundary_lon = dg.x.squeeze().data
dg_out = xr.Dataset({"lat": (["location"], boundary_lat), "lon": (["location"], boundary_lon)})

In [3]:
# Open the source daily output for every run directory listed in `t`.
surface_vars = surface_velocity_vars + surface_tracer_vars

ARCHIVE_DIR = '/scratch/v45/akm157/access-om2/archive/01deg_jra55v13_ryf9091'


def output_paths(file_name):
    """Return the path of ocean/`file_name` in each output directory listed in `t`."""
    return [f'{ARCHIVE_DIR}/output{run:03d}/ocean/{file_name}' for run in t]


# pot_temp is read from the file named after "temp"; every other variable's
# file is named after the variable itself.
file_names = {"pot_temp": "temp"}
staggered_vars = [(name, "T") for name in surface_tracer_vars] + [
    (name, "U") for name in surface_velocity_vars
]

in_datasets = {}
for var, staggering in staggered_vars:
    file_var = file_names.get(var, var)
    field = xr.open_mfdataset(
        output_paths(f'ocean_daily_3d_{file_var}.nc'),
        chunks=chunks[staggering],
        combine="by_coords",
        parallel=False,
    )[var]
    in_datasets[var] = staggering, field

# line datasets, assume they all come from ocean_daily
d_2d = xr.open_mfdataset(
    output_paths('ocean_daily.nc'),
    chunks={"time": 1, "yt_ocean": 300, "xt_ocean": None},
    combine="by_coords",
    parallel=False,
)[line_tracer_vars]

d_tracer = xr.merge([arr for stag, arr in in_datasets.values() if stag == "T"] + [d_2d])
d_velocity = xr.merge([arr for stag, arr in in_datasets.values() if stag == "U"])

In [4]:
# Make January come first:

def time_rotate(d, run_year=1901, target_year=1991):
    """Reorder a record that starts part-way through ``run_year`` so it begins in January.

    The record is split at the end of ``run_year``. Samples from
    ``run_year + 1``-01-01 onwards are moved to the front, and samples up to
    ``run_year``-12-31 are appended after them. The time axis is then
    relabelled with consecutive daily pandas timestamps at 12:00, starting on
    1 January of ``target_year``.

    Parameters
    ----------
    d : xarray.Dataset or xarray.DataArray
        Data with a ``time`` dimension that can be sliced with date strings.
    run_year : int
        Calendar year whose last day marks the split point.
    target_year : int
        Year used for the new time labels (default 1991, as before).

    Returns
    -------
    The same kind of object as ``d``, concatenated along ``time``.
    """
    before_start_time = f"{run_year}-12-31"
    after_end_time = f"{run_year + 1}-01-01"

    left = d.sel(time=slice(after_end_time, None))
    right = d.sel(time=slice(None, before_start_time))

    # Derive the label counts from the data rather than hard-coding 120/245
    # days, so the two halves always abut with no gap or overlap.
    n_left = left.sizes["time"]
    start = pd.Timestamp(f"{target_year}-01-01 12:00:00")
    left["time"] = pd.date_range(start, periods=n_left, freq="D")
    right["time"] = pd.date_range(
        start + pd.Timedelta(days=n_left), periods=right.sizes["time"], freq="D"
    )

    return xr.concat([left, right], "time")

# rotate time axis so January is first:
d_tracer = time_rotate(d_tracer)
d_velocity = time_rotate(d_velocity)

In [5]:
# Reduce the selection to a band around the target latitude and make each
# horizontal dimension a single dask chunk (required for xesmf).
# A chunk size of None is interpreted differently across dask/xarray versions:
# in dask's rechunk it keeps the existing chunks, which could leave the band
# split across the 300-row y chunks set when the files were opened.
# -1 unambiguously means "one chunk spanning the whole dimension".
d_tracer = d_tracer.sel(yt_ocean=slice(-38, -36)).chunk(
    {"yt_ocean": -1, "xt_ocean": -1}
)
d_velocity = d_velocity.sel(yu_ocean=slice(-38, -36)).chunk(
    {"yu_ocean": -1, "xu_ocean": -1}
)

In [6]:
# Build bilinear regridders from the source T and U grids onto the target
# location stream, then regrid both datasets and merge the results.
def _make_regridder(source, lon_dim, lat_dim, weights_file):
    """Return a bilinear xesmf regridder from `source` onto `dg_out`, saving weights to `weights_file`."""
    return xe.Regridder(
        source.rename({lon_dim: "lon", lat_dim: "lat"}),
        dg_out,
        "bilinear",
        periodic=True,
        locstream_out=True,
        reuse_weights=False,
        filename=weights_file,
    )


regridder_tracer = _make_regridder(d_tracer, "xt_ocean", "yt_ocean", "bilinear_tracer_weights.nc")
regridder_velocity = _make_regridder(d_velocity, "xu_ocean", "yu_ocean", "bilinear_velocity_weights.nc")

ds_out = xr.merge([regridder_tracer(d_tracer), regridder_velocity(d_velocity)])



In [7]:
# Set missing layer thicknesses to 0 first, so the flood fill below does not
# touch them.
for thickness in ("dzt", "dzu"):
    ds_out[thickness] = ds_out[thickness].fillna(0)

# Flood-fill the remaining NaNs:
# - forward-fill along st_ocean;
# - linearly interpolate along the boundary points;
# - forward- then back-fill along the boundary to cover anything left at either end.
ds_out = ds_out.ffill("st_ocean").interpolate_na("location").ffill("location").bfill("location")

In [8]:
# Fix up the coordinate metadata to the OBC segment naming used in this file:
# - the boundary points become nx_segment_001;
# - each field becomes <var>_segment_001;
# - each 3-D field gets its own vertical dimension nz_segment_001_<var>,
#   indexed 0..N-1.
ds_out = ds_out.rename(location="nx_segment_001")
for var in surface_vars:
    # per-variable vertical dimension in place of the shared st_ocean
    ds_out[var] = ds_out[var].rename(st_ocean=f"nz_segment_001_{var}")
    ds_out = ds_out.rename({var: f"{var}_segment_001"})
    # replace the vertical coordinate values with an integer level index
    ds_out[f"nz_segment_001_{var}"] = np.arange(ds_out[f"nz_segment_001_{var}"].size)

# line (2-D) fields only need renaming
for var in line_tracer_vars:
    ds_out = ds_out.rename({var: f"{var}_segment_001"})

# segment coordinates (x, y): integer indices along the boundary and a single y index
ds_out["nx_segment_001"] = np.arange(ds_out["nx_segment_001"].size)
ds_out["ny_segment_001"] = [0]

# lat/lon of the boundary points, taken from the selected hgrid row
ds_out["lon_segment_001"] = (["ny_segment_001", "nx_segment_001"], dg.x.data)
ds_out["lat_segment_001"] = (["ny_segment_001", "nx_segment_001"], dg.y.data)

# Reset st_ocean so it's not an index coordinate. Keep its values as `depth`
# (re-indexed by level number), then remove st_ocean from ds_out.
# NOTE(review): `depth` is not referenced again in the cells shown here.
ds_out = ds_out.reset_index("st_ocean").reset_coords("st_ocean")
depth = ds_out["st_ocean"]
depth.name = "depth"
depth["st_ocean"] = np.arange(depth["st_ocean"].size)
del ds_out["st_ocean"]

In [9]:
# netCDF encodings for to_netcdf:
# - int32 segment indices;
# - double-precision (float64) storage for the time axis and every boundary field.
# Time units/calendar are not set here. An earlier draft used
# "days since 1900-01-01 12:00:00" with a "noleap" calendar.
# NOTE(review): later cells that assign a new dict to an existing key (rather
# than updating it) discard the dtype requested here.
encoding_dict = {"time": {"dtype": "double"}}
encoding_dict.update(
    {index_name: {"dtype": "int32"} for index_name in ("nx_segment_001", "ny_segment_001")}
)
encoding_dict.update(
    {
        f"{field}_segment_001": {"dtype": "double"}
        for field in (
            "pot_temp",
            "salt",
            "eta_t",
            "u",
            "v",
            "dz_pot_temp",
            "dz_salt",
            "dz_u",
            "dz_v",
        )
    }
)

In [10]:
# Give every boundary field a singleton ny_segment_001 dimension, attach a
# per-variable layer-thickness field dz_<var>, convert temperature to Celsius
# and record the netCDF encodings used when writing.
#
# Encodings are MERGED into any entry already in encoding_dict (e.g. the
# {"dtype": "double"} requests) instead of replacing it. Replacing the entries
# discarded those dtype requests, so the fields were not written as doubles.

# add the y dimension to the thickness fields:
for thickness in ("dzt", "dzu"):
    ds_out[f"{thickness}_segment_001"] = ds_out[f"{thickness}_segment_001"].expand_dims(
        "ny_segment_001", axis=2
    )

for var in line_tracer_vars:
    ds_out[f"{var}_segment_001"] = ds_out[f"{var}_segment_001"].expand_dims(
        "ny_segment_001", axis=1
    )
    encoding_dict.setdefault(f"{var}_segment_001", {}).update(
        {"_FillValue": netCDF4.default_fillvals["f8"]}
    )

# 3-D fields only; the thicknesses themselves are attached as dz_<var> below.
profile_vars = [name for name in surface_vars if not name.startswith("dz")]

for var in profile_vars:
    if var == 'pot_temp':
        ds_out[f"{var}_segment_001"] = ds_out[f"{var}_segment_001"] - 273.15  # Kelvin -> Celsius
    # add the y dimension
    ds_out[f"{var}_segment_001"] = ds_out[f"{var}_segment_001"].expand_dims(
        "ny_segment_001", axis=2
    )

    # tracers use the T-cell thickness, velocities the U-cell thickness
    thickness = "dzt" if var in surface_tracer_vars else "dzu"
    ds_out[f"dz_{var}_segment_001"] = (
        ["time", f"nz_segment_001_{var}", "ny_segment_001", "nx_segment_001"],
        ds_out[f"{thickness}_segment_001"].data,
    )

    for name in (f"{var}_segment_001", f"dz_{var}_segment_001"):
        encoding_dict.setdefault(name, {}).update(
            {"_FillValue": netCDF4.default_fillvals["f8"], "zlib": True}
        )
    encoding_dict.setdefault(f"nz_segment_001_{var}", {}).update({"dtype": "int32"})

# the per-variable dz_<var> fields replace dzt/dzu and their vertical coordinates:
ds_out = ds_out.drop_vars(
    ["dzt_segment_001", "nz_segment_001_dzt", "dzu_segment_001", "nz_segment_001_dzu"]
)

In [11]:
# Take the time axis and the segment index lists from the existing panan OBC
# file, and drop the lon/lat added earlier.
ds_old = xr.open_dataset("/g/data/x77/ahg157/inputs/mom6/panan/forcing_obc_shifted.nc", decode_times=False)

# Replace the rotated time axis built above with the existing file's time axis.
ds_out["time"] = ds_old.time

# replace lat/lon:
ds_out = ds_out.drop_vars(['lon_segment_001', 'lat_segment_001'])

# Each index list comes from its own variable; jlist was previously copied
# from ilist by mistake.
ds_out["ilist_segment_001"] = ds_old.ilist_segment_001
ds_out["jlist_segment_001"] = ds_old.jlist_segment_001

In [None]:
# Add modulo attribute so MOM6 treats the time axis as repeat forcing.
ds_out["time"] = ds_out["time"].assign_attrs({"modulo": " "})

# dask.diagnostics.ProgressBar only reports on dask's local schedulers, and this
# notebook computes on the distributed `client`. So build the write lazily and
# track it with the distributed progress bar instead.
write_job = ds_out.to_netcdf(
    "forcing_access_yr2_005.nc",
    encoding=encoding_dict,
    unlimited_dims="time",
    compute=False,
)
write_future = client.compute(write_job)
progress(write_future, notebook=False)
write_future.result()  # block until the file is complete and re-raise any error


# Double-check the fill values; they may not be correct.


In [23]:
# Post-process a previously written boundary file:
# - promote the fields to double precision;
# - make sure temperature is in Celsius;
# - rewrite with explicit, non-NaN fill values.
# NOTE(review): this reads from /g/data/... but writes to the current working
# directory. If those are the same file, write to a new name instead.
ds_out = xr.open_dataset("/g/data/v45/akm157/inputs/mom6/panan/panan_005/forcing_access_yr2_005.nc", decode_times=False)

boundary_fields = ["pot_temp", "salt", "eta_t", "u", "v"]
thickness_fields = ["dz_u", "dz_v", "dz_pot_temp", "dz_salt"]

# make doubles:
for field in boundary_fields + thickness_fields:
    ds_out[f"{field}_segment_001"] = ds_out[f"{field}_segment_001"].astype('double')

# The brushcutting cells already subtract 273.15 before writing. Converting
# unconditionally here would turn a Celsius file into ~-270 degrees. Seawater
# is far above 100 K and far below 100 degC, so use that to detect Kelvin.
if float(ds_out['pot_temp_segment_001'].min()) > 100:
    ds_out['pot_temp_segment_001'] = ds_out['pot_temp_segment_001'] - 273.15  # Celsius

# Give every field and thickness an explicit, non-NaN fill value, so that none
# is left to xarray's float default (NaN).
encoding_dict = {
    f"{field}_segment_001": {"_FillValue": netCDF4.default_fillvals["f8"]}
    for field in boundary_fields + thickness_fields
}

ds_out.to_netcdf("forcing_access_yr2_005.nc", encoding=encoding_dict, unlimited_dims="time")


In [None]:
# List every variable in the output file whose _FillValue is NaN. Any names
# shown need fixing, e.g.
#   ncatted -a _FillValue,dz_u_segment_001,d,,, forcing_access_yr2_005.nc
with netCDF4.Dataset("forcing_access_yr2_005.nc") as nc_file:
    nan_fill_vars = [
        name
        for name, nc_var in nc_file.variables.items()
        if "_FillValue" in nc_var.ncattrs()
        and np.isnan(np.asarray(nc_var.getncattr("_FillValue"), dtype=float)).any()
    ]
nan_fill_vars
