In [1]:
# !pip install -e /home/jovyan/PROJECTS/scale-aware-air-sea

In [2]:
import gcsfs
import numpy as np
import xarray as xr
from distributed import Client, LocalCluster

from scale_aware_air_sea.utils import smooth_inputs_dataset, to_zarr_split

In [3]:
# Google Cloud Storage filesystem used to build the zarr store mappers below
fs = gcsfs.GCSFileSystem()

# Run selection: both values are interpolated into the input and output paths.
# Previous run: version = 'v0.5', suffix = ''
version, suffix = 'v0.6.1', '_test'

# Input store: the computed fluxes
bucket = 'gs://leap-persistent/jbusecke'
flux_path = f"{bucket}/scale-aware-air-sea/results/CM26_fluxes_{version}{suffix}.zarr"

# Output store: the decomposed fluxes (save target)
output_path = f'leap-scratch/jbusecke/scale-aware-air-sea/decomposition/CM26_decomposed_{version}{suffix}.zarr'

# One mapper per store, in the same order as the paths
flux_mapper, output_mapper = map(fs.get_mapper, (flux_path, output_path))

In [4]:
# new scale separation
def decomposition(ds, dims=('yt_ocean', 'xt_ocean'), filter_scale=50):
    """Build all terms of the scale decomposition and stack them along ``term``.

    Notation used below: ``bar(X)`` is ``X`` passed through
    ``smooth_inputs_dataset(X, dims, filter_scale)``.

    Parameters
    ----------
    ds : xr.Dataset
        Input fluxes. Must be selectable with ``.sel(smoothing=...)`` for the
        labels ``'smooth_none'``, ``'smooth_all'``,
        ``'smooth_vel_tracer_ocean'`` and ``'smooth_vel_tracer_atmos'``.
    dims : sequence of str, optional
        Dimension names handed to ``smooth_inputs_dataset`` (as a list).
        Defaults to the previously hard-coded ``['yt_ocean', 'xt_ocean']``.
    filter_scale : optional
        Third positional argument handed to ``smooth_inputs_dataset``.
        Defaults to the previously hard-coded ``50``.
        NOTE(review): presumably the filter scale of the helper -- confirm
        name and units against ``scale_aware_air_sea.utils``.

    Returns
    -------
    xr.Dataset
        All terms concatenated along a new ``term`` dimension, labelled with
        the term names, in this order:

        - ``Q_H``, ``Q_H_bar``: ``smooth_none`` input and ``bar(Q_H)``
        - ``Q_L``, ``Q_L_bar``, ``Q_L_prime``: ``smooth_all`` input,
          ``bar(Q_L)`` and ``Q_L - Q_L_bar``
        - ``Q_L_ocean``, ``Q_L_ocean_bar``, ``Q_L_atmos``, ``Q_L_atmos_bar``:
          ``smooth_vel_tracer_{ocean,atmos}`` inputs and their ``bar``
        - ``Q_star = Q_H_bar - Q_L``, ``Q_star_star = Q_H_bar - Q_L_bar``
        - ``Q_star_X = Q_H_bar - Q_L_X``, ``Q_star_X_bar = bar(Q_star_X)``,
          ``Q_star_star_X = Q_H_bar - Q_L_X_bar`` for ``X`` in ocean, atmos
        - ``Q_star_res_wrong``, ``Q_star_res``, ``Q_star_star_res``: residuals
          after removing the ocean and atmos contributions
    """
    def filt(data):
        # list(...) so the helper receives the same type as before (a list)
        return smooth_inputs_dataset(data, list(dims), filter_scale)

    decomp = {}
    # Q_H (AB) - high resolution input
    decomp['Q_H'] = ds.sel(smoothing='smooth_none')
    decomp['Q_H_bar'] = filt(decomp['Q_H'])
    # Q_L low resolution input
    decomp['Q_L'] = ds.sel(smoothing='smooth_all')
    decomp['Q_L_bar'] = filt(decomp['Q_L'])
    decomp['Q_L_prime'] = decomp['Q_L'] - decomp['Q_L_bar']  # TODO: I could potentially compute this on the fly...

    # mixed low resolution input
    decomp['Q_L_ocean'] = ds.sel(smoothing='smooth_vel_tracer_ocean')
    decomp['Q_L_ocean_bar'] = filt(decomp['Q_L_ocean'])

    decomp['Q_L_atmos'] = ds.sel(smoothing='smooth_vel_tracer_atmos')
    decomp['Q_L_atmos_bar'] = filt(decomp['Q_L_atmos'])

    # Inferred Small scale
    decomp['Q_star'] = decomp['Q_H_bar'] - decomp['Q_L']
    decomp['Q_star_star'] = decomp['Q_H_bar'] - decomp['Q_L_bar']

    decomp['Q_star_ocean'] = decomp['Q_H_bar'] - decomp['Q_L_ocean']
    decomp['Q_star_ocean_bar'] = filt(decomp['Q_star_ocean'])
    decomp['Q_star_star_ocean'] = decomp['Q_H_bar'] - decomp['Q_L_ocean_bar']

    decomp['Q_star_atmos'] = decomp['Q_H_bar'] - decomp['Q_L_atmos']
    decomp['Q_star_atmos_bar'] = filt(decomp['Q_star_atmos'])
    decomp['Q_star_star_atmos'] = decomp['Q_H_bar'] - decomp['Q_L_atmos_bar']

    # Residuals. 'Q_star_res_wrong' mixes Q_star with the *_star_star_* terms;
    # it is kept unchanged so the written term list stays the same.
    # NOTE(review): the name suggests it is superseded by 'Q_star_res' -- confirm.
    decomp['Q_star_res_wrong'] = decomp['Q_star'] - decomp['Q_star_star_ocean'] - decomp['Q_star_star_atmos']
    decomp['Q_star_res'] = decomp['Q_star'] - decomp['Q_star_ocean'] - decomp['Q_star_atmos']
    decomp['Q_star_star_res'] = decomp['Q_star_star'] - decomp['Q_star_star_ocean'] - decomp['Q_star_star_atmos']

    # for testing
    # decomp['Q_H_bar_bar'] = filt(decomp['Q_H_bar'])
    # decomp['Q_star_star_star'] = decomp['Q_H_bar_bar'] - decomp['Q_L_bar']
    # decomp['Q_star_res'] = decomp['Q_star'] - decomp['Q_star_ocean'] - decomp['Q_star_atmos']

    # concat into a single dataset
    # 'smoothing' is dropped wherever it is still present; errors='ignore'
    # makes this a no-op for terms that do not carry it.
    datasets = [
        term_ds.drop_vars('smoothing', errors='ignore').assign_coords(term=name)
        for name, term_ds in decomp.items()
    ]
    ds_out = xr.concat(datasets, dim='term', combine_attrs="override")
    return ds_out

In [5]:
# Open the flux store as dask-backed arrays (chunks={}); the repr below shows the resulting chunking
ds = xr.open_dataset(
    flux_mapper,
    chunks={},
    engine='zarr',
)
ds

Unnamed: 0,Array,Chunk
Bytes,2.28 GiB,111.24 MiB
Shape,"(3, 7, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,21 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 2.28 GiB 111.24 MiB Shape (3, 7, 3, 2700, 3600) (1, 1, 3, 2700, 3600) Count 2 Graph Layers 21 Chunks Type float32 numpy.ndarray",7  3  3600  2700  3,

Unnamed: 0,Array,Chunk
Bytes,2.28 GiB,111.24 MiB
Shape,"(3, 7, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,21 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,2.28 GiB,111.24 MiB
Shape,"(3, 7, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,21 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 2.28 GiB 111.24 MiB Shape (3, 7, 3, 2700, 3600) (1, 1, 3, 2700, 3600) Count 2 Graph Layers 21 Chunks Type float32 numpy.ndarray",7  3  3600  2700  3,

Unnamed: 0,Array,Chunk
Bytes,2.28 GiB,111.24 MiB
Shape,"(3, 7, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,21 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,2.28 GiB,111.24 MiB
Shape,"(3, 7, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,21 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 2.28 GiB 111.24 MiB Shape (3, 7, 3, 2700, 3600) (1, 1, 3, 2700, 3600) Count 2 Graph Layers 21 Chunks Type float32 numpy.ndarray",7  3  3600  2700  3,

Unnamed: 0,Array,Chunk
Bytes,2.28 GiB,111.24 MiB
Shape,"(3, 7, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,21 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,2.28 GiB,111.24 MiB
Shape,"(3, 7, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,21 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 2.28 GiB 111.24 MiB Shape (3, 7, 3, 2700, 3600) (1, 1, 3, 2700, 3600) Count 2 Graph Layers 21 Chunks Type float32 numpy.ndarray",7  3  3600  2700  3,

Unnamed: 0,Array,Chunk
Bytes,2.28 GiB,111.24 MiB
Shape,"(3, 7, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,21 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,2.28 GiB,111.24 MiB
Shape,"(3, 7, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,21 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 2.28 GiB 111.24 MiB Shape (3, 7, 3, 2700, 3600) (1, 1, 3, 2700, 3600) Count 2 Graph Layers 21 Chunks Type float32 numpy.ndarray",7  3  3600  2700  3,

Unnamed: 0,Array,Chunk
Bytes,2.28 GiB,111.24 MiB
Shape,"(3, 7, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,21 Chunks
Type,float32,numpy.ndarray


In [6]:
# Build every decomposition term from the flux dataset; the result gains a `term` dimension
decomp = decomposition(ds=ds)
decomp

Unnamed: 0,Array,Chunk
Bytes,6.52 GiB,111.24 MiB
Shape,"(20, 3, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,60 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 6.52 GiB 111.24 MiB Shape (20, 3, 3, 2700, 3600) (1, 1, 3, 2700, 3600) Count 258 Graph Layers 60 Chunks Type float32 numpy.ndarray",3  20  3600  2700  3,

Unnamed: 0,Array,Chunk
Bytes,6.52 GiB,111.24 MiB
Shape,"(20, 3, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,60 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,6.52 GiB,111.24 MiB
Shape,"(20, 3, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,60 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 6.52 GiB 111.24 MiB Shape (20, 3, 3, 2700, 3600) (1, 1, 3, 2700, 3600) Count 258 Graph Layers 60 Chunks Type float32 numpy.ndarray",3  20  3600  2700  3,

Unnamed: 0,Array,Chunk
Bytes,6.52 GiB,111.24 MiB
Shape,"(20, 3, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,60 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,6.52 GiB,111.24 MiB
Shape,"(20, 3, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,60 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 6.52 GiB 111.24 MiB Shape (20, 3, 3, 2700, 3600) (1, 1, 3, 2700, 3600) Count 258 Graph Layers 60 Chunks Type float32 numpy.ndarray",3  20  3600  2700  3,

Unnamed: 0,Array,Chunk
Bytes,6.52 GiB,111.24 MiB
Shape,"(20, 3, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,60 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,6.52 GiB,111.24 MiB
Shape,"(20, 3, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,60 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 6.52 GiB 111.24 MiB Shape (20, 3, 3, 2700, 3600) (1, 1, 3, 2700, 3600) Count 258 Graph Layers 60 Chunks Type float32 numpy.ndarray",3  20  3600  2700  3,

Unnamed: 0,Array,Chunk
Bytes,6.52 GiB,111.24 MiB
Shape,"(20, 3, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,60 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,6.52 GiB,111.24 MiB
Shape,"(20, 3, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,60 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 6.52 GiB 111.24 MiB Shape (20, 3, 3, 2700, 3600) (1, 1, 3, 2700, 3600) Count 258 Graph Layers 60 Chunks Type float32 numpy.ndarray",3  20  3600  2700  3,

Unnamed: 0,Array,Chunk
Bytes,6.52 GiB,111.24 MiB
Shape,"(20, 3, 3, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,60 Chunks
Type,float32,numpy.ndarray


In [7]:
# TODO: make attrs propagate through the decomposition.

## Local Version

In [8]:
# Local dask cluster for testing: 4 workers with 4 threads each.
# `Client` and `LocalCluster` are imported in the imports cell at the top of the notebook.
cluster = LocalCluster(n_workers=4, threads_per_worker=4)
client = Client(cluster)
client

2022-12-22 19:37:31,926 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-yek49qss', purging
2022-12-22 19:37:31,927 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-ckl702e5', purging
2022-12-22 19:37:31,927 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-m1m0pikq', purging


0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /user/jbusecke/air_sea_project/proxy/8787/status,

0,1
Dashboard: /user/jbusecke/air_sea_project/proxy/8787/status,Workers: 4
Total threads: 16,Total memory: 58.87 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:34335,Workers: 4
Dashboard: /user/jbusecke/air_sea_project/proxy/8787/status,Total threads: 16
Started: Just now,Total memory: 58.87 GiB

0,1
Comm: tcp://127.0.0.1:39011,Total threads: 4
Dashboard: /user/jbusecke/air_sea_project/proxy/46559/status,Memory: 14.72 GiB
Nanny: tcp://127.0.0.1:42647,
Local directory: /tmp/dask-worker-space/worker-4nq60z_j,Local directory: /tmp/dask-worker-space/worker-4nq60z_j

0,1
Comm: tcp://127.0.0.1:37499,Total threads: 4
Dashboard: /user/jbusecke/air_sea_project/proxy/43127/status,Memory: 14.72 GiB
Nanny: tcp://127.0.0.1:35579,
Local directory: /tmp/dask-worker-space/worker-kfrchl_s,Local directory: /tmp/dask-worker-space/worker-kfrchl_s

0,1
Comm: tcp://127.0.0.1:42443,Total threads: 4
Dashboard: /user/jbusecke/air_sea_project/proxy/36743/status,Memory: 14.72 GiB
Nanny: tcp://127.0.0.1:40541,
Local directory: /tmp/dask-worker-space/worker-qcbe24ud,Local directory: /tmp/dask-worker-space/worker-qcbe24ud

0,1
Comm: tcp://127.0.0.1:42813,Total threads: 4
Dashboard: /user/jbusecke/air_sea_project/proxy/35689/status,Memory: 14.72 GiB
Nanny: tcp://127.0.0.1:44699,
Local directory: /tmp/dask-worker-space/worker-76luueua,Local directory: /tmp/dask-worker-space/worker-76luueua


## Gateway version (currently not working with the taper filter)

In [9]:
# import dask
# dask.config.set(
#     {
#         "distributed.comm.timeouts.tcp": "720s",
#         "distributed.comm.timeouts.connect": "720s",
#         "distributed.scheduler.allowed-failures":10,
#     }
# )

# print(dask.config.get("distributed.scheduler.allowed-failures"))

# from dask_gateway import Gateway
# gateway = Gateway()



# # close existing clusters
# open_clusters = gateway.list_clusters()
# print(list(open_clusters))
# if len(open_clusters)>0:
#     for c in open_clusters:
#         cluster = gateway.connect(c.name)
#         cluster.shutdown()  

# options = gateway.cluster_options()
# options.worker_memory = 52
# options.worker_cores = 12

# options.environment = dict(
#     DASK_DISTRIBUTED__SCHEDULER__WORKER_SATURATION="1.0"
# )

# # Create a cluster with those options
# cluster = gateway.new_cluster(options)
# client = cluster.get_client()

# # cluster.adapt(10, 200)
# cluster.scale(200)
# client

In [10]:
# Subset before writing: first entry along `algo`, first 360*5 positions along `time`.
# TODO: run this for all time, but then it would be really big
ds_save = decomp.isel(time=slice(0, 360 * 5), algo=0)
# ds_write = ds_write.sel(term=['Q_star', 'Q_star_ocean', 'Q_star_atmos','Q_star_res_real','Q_star_res_real', 'Q_star_star', 'Q_star_star_ocean', 'Q_star_star_atmos','Q_star_star_res'])

# Report the size of what is about to be written
print(f"{ds_save.nbytes/1e12}TB")

if suffix != '_test':
    # Anything but the test run goes through the project helper.
    # NOTE(review): presumably writes the store in pieces of `split_interval` -- confirm in the helper.
    # to_zarr_split(ds_save, output_mapper, split_interval=200)
    to_zarr_split(ds_save, output_mapper, split_interval=300)
else:
    # Single-shot write; only works on small subsets.
    ds_save.to_zarr(output_mapper, mode='w')

0.011664051816TB
