# Example Workflow for CM2.6

- Interpolate atmosphere onto ocean
- Recalculate flux terms from ocean res (`full_res_*`)
- Coarsen the recalculated flux + flux input fields
- Recompute flux terms once again from coarsened input fields (`coarse_res_*`)
- Look at the difference (`full_res_* - coarse_res_*`)

In [1]:
# !mamba install aerobulk-python -y

In [2]:
import fsspec
import xarray as xr
import numpy as np
import xesmf as xe
import os
from intake import open_catalog
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import json
import gcsfs
from dask.diagnostics import ProgressBar
from cm26_utils import write_split_zarr, noskin_ds_wrapper

# 👇 replace with your key 
key_path = '/home/jovyan/keys/pangeo-forge-ocean-transport-4967-347e2048c5a1.json'
with open(key_path) as token_file:
    token = json.load(token_file)
fs = gcsfs.GCSFileSystem(token=token)

# Bucket locations: intermediate full-resolution output goes into `temp/`,
# the final (analysis-ready) stores directly into `outputs/`.
subfolder_full = 'ocean-transport-group/scale-aware-air-sea/outputs/temp/'
subfolder_final = 'ocean-transport-group/scale-aware-air-sea/outputs/'

# Suffix appended to every store name: '_test' while testing, '' otherwise.
appendix = '_test'
# appendix = ''

# Bulk algorithm used for the flux computation.
# Other options: 'coare3p6', 'ncar', 'ecmwf', 'andreas'
algo = 'coare3p0'

In [3]:
# from multiprocessing.pool import ThreadPool
# import dask
# # dask.config.set(pool=ThreadPool(32))# blows out the memory?
# # dask.config.set(pool=ThreadPool(24))# coare3p6 needs more memory?
# dask.config.set(pool=ThreadPool(8))# this worked for ecmwf and ncar
# # dask.config.set(pool=ThreadPool(2))

# from dask.distributed import LocalCluster, Client
# cluster = LocalCluster(n_workers=4, threads_per_worker=2)
# client = Client(cluster)

In [4]:
# Open the CM2.6 control-run ocean surface, boundary-flux, atmosphere and grid data.
kwargs = dict(consolidated=True, use_cftime=True)
cat = open_catalog("https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/master/intake-catalogs/ocean/GFDL_CM2.6.yaml")
ds_ocean  = cat["GFDL_CM2_6_control_ocean_surface"].to_dask()
ds_flux  = cat["GFDL_CM2_6_control_ocean_boundary_flux"].to_dask()
# xarray warns against passing chunks={'time':1} directly to open_zarr here,
# so the atmosphere is opened with its stored chunking and rechunked below.
# ds_atmos = xr.open_zarr('gs://cmip6/GFDL_CM2_6/control/atmos_daily.zarr', chunks={'time':1}, **kwargs)
ds_atmos = xr.open_zarr('gs://cmip6/GFDL_CM2_6/control/atmos_daily.zarr', **kwargs)
ds_oc_grid  = cat["GFDL_CM2_6_grid"].to_dask()

# Cut ocean and atmosphere to their common time steps. Every other dimension
# is excluded from alignment (the two components live on different horizontal
# grids). `exclude` is a concrete set rather than a one-shot generator, so it
# stays valid no matter how often xarray iterates or tests membership in it.
all_dims = set(ds_ocean.dims) | set(ds_atmos.dims)
ds_ocean, ds_atmos = xr.align(
    ds_ocean,
    ds_atmos,
    join='inner',
    exclude=all_dims - {'time'},
)
# rechunk to one time step per chunk only after aligning
ds_atmos = ds_atmos.chunk({'time': 1})

## Regridding the atmos variables onto the ocean grid

In [5]:
# Load the pre-computed bilinear atmosphere -> ocean weights from the bucket.
fs = gcsfs.GCSFileSystem(token=token)
path = 'ocean-transport-group/scale-aware-air-sea/regridding_weights/CM26_atmos2ocean.zarr'
mapper = fs.get_mapper(path)
ds_regridder = xr.open_zarr(mapper).load()

# Source/target grids are passed as single-variable, single-time-step dummy
# datasets with all non-index coordinates dropped via reset_coords(drop=True)
# (presumably to work around a recurring coordinate problem - original note was cut off).
grid_in = ds_atmos.olr.to_dataset(name='dummy').isel(time=0).reset_coords(drop=True)
grid_out = ds_ocean.surface_temp.to_dataset(name='dummy').isel(time=0).reset_coords(drop=True)
regridder = xe.Regridder(grid_in, grid_out, 'bilinear', weights=ds_regridder, periodic=True)
regridder

xESMF Regridder 
Regridding algorithm:       bilinear 
Weight filename:            bilinear_360x576_2700x3600_peri.nc 
Reuse pre-computed weights? False 
Input grid shape:           (360, 576) 
Output grid shape:          (2700, 3600) 
Periodic in longitude?      True

In [6]:
# Regrid only the atmospheric inputs needed for the noskin flux computation
# ('swdn_sfc' and 'lwdn_sfc' are left out for now).
noskin_atmos_vars = ['slp', 'v_ref', 'u_ref', 't_ref', 'q_ref', 'wind']
ds_atmos_regridded = regridder(ds_atmos[noskin_atmos_vars])

## combine into merged dataset
merge_inputs = [ds_atmos_regridded, ds_ocean[['surface_temp']]]
ds_merged = xr.merge(merge_inputs).transpose('xt_ocean', 'yt_ocean', 'time')
ds_merged

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,8 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 264.51 GiB 37.08 MiB Shape (3600, 2700, 7305) (3600, 2700, 1) Count 8 Graph Layers 7305 Chunks Type float32 numpy.ndarray",7305  2700  3600,

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,8 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,8 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 264.51 GiB 37.08 MiB Shape (3600, 2700, 7305) (3600, 2700, 1) Count 8 Graph Layers 7305 Chunks Type float32 numpy.ndarray",7305  2700  3600,

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,8 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,8 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 264.51 GiB 37.08 MiB Shape (3600, 2700, 7305) (3600, 2700, 1) Count 8 Graph Layers 7305 Chunks Type float32 numpy.ndarray",7305  2700  3600,

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,8 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,8 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 264.51 GiB 37.08 MiB Shape (3600, 2700, 7305) (3600, 2700, 1) Count 8 Graph Layers 7305 Chunks Type float32 numpy.ndarray",7305  2700  3600,

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,8 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,8 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 264.51 GiB 37.08 MiB Shape (3600, 2700, 7305) (3600, 2700, 1) Count 8 Graph Layers 7305 Chunks Type float32 numpy.ndarray",7305  2700  3600,

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,8 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,8 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 264.51 GiB 37.08 MiB Shape (3600, 2700, 7305) (3600, 2700, 1) Count 8 Graph Layers 7305 Chunks Type float32 numpy.ndarray",7305  2700  3600,

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,8 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,3 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 264.51 GiB 37.08 MiB Shape (3600, 2700, 7305) (3600, 2700, 1) Count 3 Graph Layers 7305 Chunks Type float32 numpy.ndarray",7305  2700  3600,

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,3 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray


## Recalculate Fluxes on the native ocean grid (most resource intense step)
For now I am writing the aerobulk-python results out to a cloud bucket, reloading them, and then continuing with the next step. 

TODO: Eventually it would be nice to skip this step altogether and stream the full computation (this + coarsening), so that we do not have to write this huge amount of data to the bucket.

TODO: add the algo to the datasets in this step. 

In [7]:
# Full-resolution flux recomputation (builds a lazy dask graph, nothing is
# computed yet). Input range checking is switched off for this pass.
# ds_out = noskin_ds_wrapper(ds_merged, algo=algo, input_range_check=True)
ds_out = noskin_ds_wrapper(
    ds_merged,
    algo=algo,
    input_range_check=False,
)
ds_out

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,40 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 264.51 GiB 37.08 MiB Shape (3600, 2700, 7305) (3600, 2700, 1) Count 40 Graph Layers 7305 Chunks Type float32 numpy.ndarray",7305  2700  3600,

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,40 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,40 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 264.51 GiB 37.08 MiB Shape (3600, 2700, 7305) (3600, 2700, 1) Count 40 Graph Layers 7305 Chunks Type float32 numpy.ndarray",7305  2700  3600,

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,40 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,40 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 264.51 GiB 37.08 MiB Shape (3600, 2700, 7305) (3600, 2700, 1) Count 40 Graph Layers 7305 Chunks Type float32 numpy.ndarray",7305  2700  3600,

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,40 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,40 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 264.51 GiB 37.08 MiB Shape (3600, 2700, 7305) (3600, 2700, 1) Count 40 Graph Layers 7305 Chunks Type float32 numpy.ndarray",7305  2700  3600,

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,40 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,40 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 264.51 GiB 37.08 MiB Shape (3600, 2700, 7305) (3600, 2700, 1) Count 40 Graph Layers 7305 Chunks Type float32 numpy.ndarray",7305  2700  3600,

Unnamed: 0,Array,Chunk
Bytes,264.51 GiB,37.08 MiB
Shape,"(3600, 2700, 7305)","(3600, 2700, 1)"
Count,40 Graph Layers,7305 Chunks
Type,float32,numpy.ndarray


# Setting up a dask beast

## Problems with distributed

Using the threaded scheduler is currently the only way I can finish writing out zarr files. This is however quite slow (a lot more cores would help tremendously in speeding this up). 
I am trying to debug this in `cm26_pipeline-debug-distributed.ipynb`. For now I will use this as a brute force 'leave running all night' way to get anything to analyze.

In [8]:
# hacky stuff (but it works!)
# Number of workers for the full-resolution run. 50 workers did not come up
# for a long time; maybe 100 is meant for adaptive scaling and 85 for longer runs.
n_workers = 45
# Threads per worker (12 is thought to be the maximum). Trying 8 threads did
# not help noticeably. More threads per worker may need more memory per core
# (might not be true anymore) - the right ratio still needs to be figured out.
n_threads = 12



import subprocess
import logging
from distributed import WorkerPlugin

class MambaPlugin(WorkerPlugin):
    """
    Install packages on a worker as it starts up.

    ``setup`` runs ``mamba install -y <packages>`` in a subprocess and logs
    whether the install succeeded.

    Parameters
    ----------
    packages : List[str]
        A list of packages to install with mamba on startup.
    """
    def __init__(self, packages):
        # copy so later changes to the caller's list don't alter what gets installed
        self.packages = list(packages)

    def setup(self, worker):
        logger = logging.getLogger("distributed.worker")
        # -y: workers are non-interactive, so never wait on a confirmation prompt
        # (same flag as the `mamba install ... -y` at the top of the notebook)
        cmd = ['mamba', 'install', '-y'] + self.packages
        returncode = subprocess.call(cmd)
        if returncode == 0:
            logger.info("Installed %s", self.packages)
        else:
            # don't report success for a failed install; surface it in the worker log
            logger.error(
                "Installing %s failed (`%s` exited with code %d)",
                self.packages,
                " ".join(cmd),
                returncode,
            )


# distributed does not like the long tasks (>40s) we produce;
# raising the comm timeouts fixes that (!YES!):
import dask
dask.config.set({"distributed.comm.timeouts.tcp": "60s", "distributed.comm.timeouts.connect": "60s"})

from dask_gateway import Gateway
gateway = Gateway()

# shut down any clusters that are still running
open_clusters = gateway.list_clusters()
print(list(open_clusters))
for report in open_clusters:
    cluster = gateway.connect(report.name)
    cluster.shutdown()

options = gateway.cluster_options()
# set the options programmatically (or through their HTML repr)
options.worker_memory = 52  # somewhere around this (52 previously) is believed to be the max per pod
options.worker_cores = n_threads
cluster = gateway.new_cluster(options)
client = cluster.get_client()

# install aerobulk-python on every worker as it starts
plugin = MambaPlugin(['aerobulk-python'])
client.register_worker_plugin(plugin)
# Optional install check on a small cluster:
#   def check():
#       import aerobulk
#       return aerobulk.__version__
#   cluster.scale(2)
#   client.wait_for_workers(2)
#   client.run(check)
cluster.scale(n_workers)
client

[ClusterReport<name=prod.4d69634f9e134b5ba284efd4259d87b9, status=RUNNING>]


  self.scheduler_comm.close_rpc()


0,1
Connection method: Cluster object,Cluster type: dask_gateway.GatewayCluster
Dashboard: /services/dask-gateway/clusters/prod.6a286b4758174abca629b6a611c1044f/status,


In [9]:
# reduce the amount of set up for testing
# ds_out = ds_out.isel(time=slice(0,1000))

In [10]:
# Sanity check: at this point `path` still holds the regridding-weights store
# path (it is reassigned to the output store in the next cell).
fs.exists(path)

True

In [11]:
# Ad-hoc hack: split the dataset into batches along time and append each
# batch to the zarr store.
path = f'{subfolder_full}CM26_high_res_output_{algo}{appendix}.zarr'
mapper = fs.get_mapper(path)
print(f"Writing to {path}")

# With overwrite=True an existing store at `path` is deleted first, i.e. the
# write starts from scratch. Set overwrite=False to skip the deletion.
overwrite = True
store_exists = fs.exists(path)
if store_exists and overwrite:
    fs.rm(path, recursive=True)

# One split per twice the total number of threads in the cluster, which
# roughly forces a write after each round of calculations.
batch_size = 2 * n_workers * n_threads
write_split_zarr(mapper, ds_out, split_interval=batch_size)
# The memory still slowly overflows: store tasks do not seem to be executed
# before more compute tasks are grabbed (maybe Gabe's dask fix would help here).

Writing to ocean-transport-group/scale-aware-air-sea/outputs/temp/CM26_high_res_output_coare3p0_test.zarr
initializing store


  0%|          | 0/7 [00:00<?, ?it/s]

Writing split 0
Start: 0181-01-01 12:00:00
Stop: 0183-12-16 12:00:00
Writing split 1
Start: 0183-12-17 12:00:00
Stop: 0186-11-30 12:00:00
Writing split 2
Start: 0186-12-01 12:00:00
Stop: 0189-11-14 12:00:00
Writing split 3
Start: 0189-11-15 12:00:00
Stop: 0192-10-29 12:00:00
Writing split 4
Start: 0192-10-30 12:00:00
Stop: 0195-10-14 12:00:00
Writing split 5
Start: 0195-10-15 12:00:00
Stop: 0198-09-28 12:00:00
Writing split 6
Start: 0198-09-29 12:00:00
Stop: 0200-12-31 12:00:00


Little tidbit about timeouts:
Apparently you can do this:
```
import dask
import distributed
dask.config.set({"distributed.comm.timeouts.tcp": "50s"})
```
[source](https://stackoverflow.com/questions/60088134/dask-distributed-client-error-failed-to-reconnect-to-scheduler-after-10-00-s)


Yayyyy, this seems to work!

In [None]:
# close the distributed client (TODO: still need to fix the issue below)

In [None]:
# Shut down the dask-gateway cluster that was used for the full-resolution write.
cluster.shutdown()

In [None]:
# Close the client that was connected to that cluster.
client.close()

# Spin up a new cluster with fewer threads

In [None]:
# hacky stuff (but it works!)
# Smaller cluster: 20 workers with 2 threads each.
# (50 workers did not come up for a long time; 12 threads is thought to be the
# per-worker maximum. More memory per core may be needed - ratio still open.)
n_workers = 20
n_threads = 2

# shut down any clusters that are still running
open_clusters = gateway.list_clusters()
print(list(open_clusters))
for running in open_clusters:
    cluster = gateway.connect(running.name)
    cluster.shutdown()

options = gateway.cluster_options()
# set the options programmatically (or through their HTML repr)
options.worker_memory = 52  # somewhere around this (52 previously) is believed to be the max per pod
options.worker_cores = n_threads
cluster = gateway.new_cluster(options)
client = cluster.get_client()

# install aerobulk-python on every worker as it starts
plugin = MambaPlugin(['aerobulk-python'])
client.register_worker_plugin(plugin)
# Optional install check on a small cluster:
#   def check():
#       import aerobulk
#       return aerobulk.__version__
#   cluster.scale(2)
#   client.wait_for_workers(2)
#   client.run(check)
cluster.scale(n_workers)
client

# So apparently you can get messages from within the Fortran code? 

I got this for coare3p0:
![image.png](attachment:0a343576-aappendix4944-9af5-4cfce5efcdba.png)

Super strange

**Note**: It seems that error was just a glitch...I have now set up the split writing in a way that should avoid having duplicate times. 

If the error happens again, I need to see how I can build in some sort of retry logic.

### Notes:

I could actually write out the SST, u, v, etc. here too. That way the next processing step might be quicker? Then again, this step is already way slower, and we might not want to slow it down even further.

# The large-scale computation ran flawlessly up to this point, but then the cluster got killed... weird. Now trying to skip everything above and pick up here with only a threaded cluster

## Coarsening the input/output and recomputing the 'large scale' output

In [None]:
# Reload the full-resolution flux output. This is very slow, presumably
# because the store is not consolidated (hence consolidated=False).
path = f'{subfolder_full}CM26_high_res_output_{algo}{appendix}.zarr'
mapper = fs.get_mapper(path)
ds_recomputed_full = xr.open_dataset(
    mapper,
    engine='zarr',
    consolidated=False,
    use_cftime=True,
    chunks={'time': 1},
)
ds_recomputed_full

In [None]:
# Put inputs, full-resolution flux output and the ocean cell area (as a lazy,
# not-yet-loaded coordinate) into a single dataset. join='inner' keeps only
# index labels shared by both datasets.
ds = xr.merge([ds_merged, ds_recomputed_full], join='inner').assign_coords(area=ds_oc_grid.area_t)
ds

### Manually do a weighted coarsen
> Would be nice if this could work with the xr.weighted logic!

In [None]:
# 20x20 fine cells -> one coarse cell
coarsen_win = dict(xt_ocean=20, yt_ocean=20)

# Area weights restricted to ocean points: cells where surface_temp is NaN at
# the first time step get NaN area, so they drop out of the weighted sums.
masked_area = ds.area.where(~np.isnan(ds.surface_temp.isel(time=0)))
ds_weighted = ds * masked_area
# Coarsen reductions always reduce over the windowed dims given in
# `coarsen_win`; their only positional parameter is `keep_attrs`, so a list
# of dims must NOT be passed positionally (it would silently become keep_attrs).
# drop_vars removes the scalar `time` coord left over from isel(time=0).
coarsened_area = masked_area.coarsen(**coarsen_win).sum(keep_attrs=True).drop_vars(['time'])
ds_coarsened = ds_weighted.coarsen(**coarsen_win).sum(keep_attrs=True) / coarsened_area
# attach the coarse (ocean-only) cell area
ds_coarsened = ds_coarsened.assign_coords(area=coarsened_area)

# Apply a new (strict) landmask to fields that had NaNs before: a coarse cell
# is masked if ANY of its fine cells is NaN at the first time step.
coarsened_landmask = np.isnan(ds['surface_temp'].isel(time=0)).coarsen(**coarsen_win).sum(keep_attrs=True) > 0
for var in ['qh', 'ql', 'surface_temp', 't_ref', 'slp']:
    ds_coarsened[var] = ds_coarsened[var].where(~coarsened_landmask)
ds_coarsened

In [None]:
# Recompute the fluxes from the coarsened inputs with the SAME bulk algorithm
# as the full-resolution pass (`algo`). Without `algo=` the wrapper falls back
# to its own default, which need not match, and full_res - coarse_res would
# then mix an algorithm change into the resolution effect.
ds_recompute_coarse = noskin_ds_wrapper(ds_coarsened, algo=algo, input_range_check=True)
# add all the fluxes recomputed from the coarsened input back to the coarsened dataset
for var in ds_recompute_coarse.data_vars:
    ds_coarsened[var + '_large_scale'] = ds_recompute_coarse[var]

In [None]:
# Quick look at one time step of `qh` recomputed from the coarsened inputs.
qh_snapshot = ds_coarsened['qh_large_scale'].isel(time=1200)
qh_snapshot.plot()

# TODO: this still takes very long (~4 h?)
I am sure I can get a dask cluster to push this through quicker.

But I keep getting some weird 'cancelled' errors... is this a timeout problem?

In [None]:
# write this out as final store (this is the store that we can use for analysis)
path = f'{subfolder_final}CM26_final_output_full_time_{algo}{appendix}.zarr'
mapper = fs.get_mapper(path)
# record which bulk algorithm produced the fluxes in the store's metadata
ds_coarsened.attrs['algo'] = algo

# NOTE: ProgressBar only reports progress for dask's local schedulers; when a
# distributed client is active, follow the progress on its dashboard instead.
with ProgressBar():
    print(f"Writing to {path}")
    ds_coarsened.to_zarr(mapper, mode='w', consolidated=True)

In [None]:
# Reload the final store. Only the setup cell at the top (imports, `fs`,
# `subfolder_final`, `algo`, `appendix`) has to run first; everything else in
# this notebook can be skipped.
path = f'{subfolder_final}CM26_final_output_full_time_{algo}{appendix}.zarr'
mapper = fs.get_mapper(path)
open_kwargs = dict(engine='zarr', consolidated=True, use_cftime=True, chunks='auto')
ds_plot = xr.open_dataset(mapper, **open_kwargs)
ds_plot

In [None]:
# For each flux: time-mean of the coarsened full-res output, of the output
# recomputed from coarsened input, and their absolute / relative (%) difference.
for var in ['ql', 'qh', 'evap', 'taux', 'tauy']:
    with ProgressBar():
        full = ds_plot[var].mean('time').load()
        large_scale = ds_plot[var + '_large_scale'].mean('time').load()
        small_scale = full - large_scale

        # not quite sure if we should look at the mean of the difference or the difference of the mean over time...
        plot_kwargs = {'y': 'yt_ocean', 'robust': True, 'center': 0}
        panels = [
            (full, 'Full output coarsened', {}),
            (large_scale, 'Output from coarsened input', {}),
            (small_scale, 'Small Scale Absolute', {}),
            (small_scale / full * 100, 'Small Scale Relative', {'vmax': 10}),
        ]
        plt.figure(figsize=[25, 4])
        for panel_idx, (field, panel_title, extra_kwargs) in enumerate(panels, start=1):
            plt.subplot(1, 4, panel_idx)
            field.plot(**extra_kwargs, **plot_kwargs)
            plt.title(panel_title)
        plt.show()