In [1]:
# !pip install -e /home/jovyan/PROJECTS/scale-aware-air-sea

In [2]:
import gcsfs
import xarray as xr
import numpy as np
from scale_aware_air_sea.utils import smooth_inputs_dataset, to_zarr_split

In [7]:
# new scale separation
def decomposition(ds, filter_scale=50, filter_dims=('yt_ocean', 'xt_ocean')):
    """Build the scale-decomposition terms and stack them along a new ``term`` dimension.

    Parameters
    ----------
    ds : xarray.Dataset
        Input dataset. It must be selectable along ``smoothing`` with the labels
        'smooth_full', 'smooth_all', 'smooth_vel_tracer_ocean' and
        'smooth_vel_tracer_atmos'.
    filter_scale : optional
        Forwarded as the third positional argument of ``smooth_inputs_dataset``
        (presumably the filter scale -- confirm against that helper). Defaults to
        the previously hard-coded value of 50.
    filter_dims : sequence of str, optional
        Dimension names forwarded (as a list) to ``smooth_inputs_dataset``.
        Defaults to the previously hard-coded ('yt_ocean', 'xt_ocean').

    Returns
    -------
    xarray.Dataset
        All terms concatenated along ``term``, in the order they are defined below,
        with the name of each term stored as the ``term`` coordinate.
    """
    def filt(dataset):
        # Single place where the spatial filter is applied ("bar" terms below).
        return smooth_inputs_dataset(dataset, list(filter_dims), filter_scale)

    decomp = {}
    # Q_H (AB) - high resolution input
    decomp['Q_H'] = ds.sel(smoothing='smooth_full')
    decomp['Q_H_bar'] = filt(decomp['Q_H'])
    # Q_L low resolution input
    decomp['Q_L'] = ds.sel(smoothing='smooth_all')
    # Reuse the selection made above instead of selecting 'smooth_all' a second time.
    decomp['Q_L_bar'] = filt(decomp['Q_L'])
    decomp['Q_L_prime'] = decomp['Q_L'] - decomp['Q_L_bar']  # TODO: I could potentially compute this on the fly...

    # mixed low resolution input
    decomp['Q_L_ocean'] = ds.sel(smoothing='smooth_vel_tracer_ocean')
    decomp['Q_L_ocean_bar'] = filt(decomp['Q_L_ocean'])

    decomp['Q_L_atmos'] = ds.sel(smoothing='smooth_vel_tracer_atmos')
    decomp['Q_L_atmos_bar'] = filt(decomp['Q_L_atmos'])

    # Inferred Small scale
    decomp['Q_star'] = decomp['Q_H_bar'] - decomp['Q_L']
    decomp['Q_star_star'] = decomp['Q_H_bar'] - decomp['Q_L_bar']

    decomp['Q_star_ocean'] = decomp['Q_H_bar'] - decomp['Q_L_ocean']
    decomp['Q_star_ocean_bar'] = filt(decomp['Q_star_ocean'])
    decomp['Q_star_star_ocean'] = decomp['Q_H_bar'] - decomp['Q_L_ocean_bar']

    decomp['Q_star_atmos'] = decomp['Q_H_bar'] - decomp['Q_L_atmos']
    decomp['Q_star_atmos_bar'] = filt(decomp['Q_star_atmos'])
    decomp['Q_star_star_atmos'] = decomp['Q_H_bar'] - decomp['Q_L_atmos_bar']

    # Residuals. 'Q_star_res_wrong' mixes Q_star with the star-star ocean/atmos terms;
    # it is kept under that name so existing consumers of the term still find it.
    decomp['Q_star_res_wrong'] = decomp['Q_star'] - decomp['Q_star_star_ocean'] - decomp['Q_star_star_atmos']
    decomp['Q_star_res'] = decomp['Q_star'] - decomp['Q_star_ocean'] - decomp['Q_star_atmos']
    decomp['Q_star_star_res'] = decomp['Q_star_star'] - decomp['Q_star_star_ocean'] - decomp['Q_star_star_atmos']

    # for testing
    # decomp['Q_H_bar_bar'] = filt(decomp['Q_H_bar'])
    # decomp['Q_star_star_star'] = decomp['Q_H_bar_bar'] - decomp['Q_L_bar']

    # concat into a single dataset
    # `drop_vars(..., errors='ignore')` replaces the deprecated `Dataset.drop` plus the
    # manual membership check: 'smoothing' is removed where present, ignored otherwise.
    # The loop variable no longer shadows the `ds` argument.
    datasets = [
        term_ds.drop_vars('smoothing', errors='ignore').assign_coords(term=term_name)
        for term_name, term_ds in decomp.items()
    ]
    ds_out = xr.concat(datasets, dim='term')
    return ds_out

In [8]:
# Open the flux results store on GCS as a dataset with dask-backed variables.
fs = gcsfs.GCSFileSystem()

results_dir = 'leap-persistent/jbusecke/scale-aware-air-sea/results'
# Listing of the results directory; the return value is not displayed because
# it is not the last expression of the cell.
fs.ls(results_dir)

flux_store = 'leap-persistent/jbusecke/scale-aware-air-sea/results/CM26_fluxes_v0.5.zarr'
mapper = fs.get_mapper(flux_store)

# chunks={} keeps the variables as dask arrays instead of loading them.
ds = xr.open_dataset(mapper, engine='zarr', chunks={})
ds

Unnamed: 0,Array,Chunk
Bytes,5.42 TiB,111.24 MiB
Shape,"(3, 7, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,51135 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 5.42 TiB 111.24 MiB Shape (3, 7, 7305, 2700, 3600) (1, 1, 3, 2700, 3600) Count 2 Graph Layers 51135 Chunks Type float32 numpy.ndarray",7  3  3600  2700  7305,

Unnamed: 0,Array,Chunk
Bytes,5.42 TiB,111.24 MiB
Shape,"(3, 7, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,51135 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,5.42 TiB,111.24 MiB
Shape,"(3, 7, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,51135 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 5.42 TiB 111.24 MiB Shape (3, 7, 7305, 2700, 3600) (1, 1, 3, 2700, 3600) Count 2 Graph Layers 51135 Chunks Type float32 numpy.ndarray",7  3  3600  2700  7305,

Unnamed: 0,Array,Chunk
Bytes,5.42 TiB,111.24 MiB
Shape,"(3, 7, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,51135 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,5.42 TiB,111.24 MiB
Shape,"(3, 7, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,51135 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 5.42 TiB 111.24 MiB Shape (3, 7, 7305, 2700, 3600) (1, 1, 3, 2700, 3600) Count 2 Graph Layers 51135 Chunks Type float32 numpy.ndarray",7  3  3600  2700  7305,

Unnamed: 0,Array,Chunk
Bytes,5.42 TiB,111.24 MiB
Shape,"(3, 7, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,51135 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,5.42 TiB,111.24 MiB
Shape,"(3, 7, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,51135 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 5.42 TiB 111.24 MiB Shape (3, 7, 7305, 2700, 3600) (1, 1, 3, 2700, 3600) Count 2 Graph Layers 51135 Chunks Type float32 numpy.ndarray",7  3  3600  2700  7305,

Unnamed: 0,Array,Chunk
Bytes,5.42 TiB,111.24 MiB
Shape,"(3, 7, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,51135 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,5.42 TiB,111.24 MiB
Shape,"(3, 7, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,51135 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 5.42 TiB 111.24 MiB Shape (3, 7, 7305, 2700, 3600) (1, 1, 3, 2700, 3600) Count 2 Graph Layers 51135 Chunks Type float32 numpy.ndarray",7  3  3600  2700  7305,

Unnamed: 0,Array,Chunk
Bytes,5.42 TiB,111.24 MiB
Shape,"(3, 7, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,2 Graph Layers,51135 Chunks
Type,float32,numpy.ndarray


In [9]:
# Apply the scale decomposition to the full flux dataset (`ds` was opened with
# chunks={}, so its variables are dask-backed).
decomp = decomposition(ds)

In [10]:
# Show the repr of the decomposed dataset (terms stacked along the `term` dimension).
decomp

Unnamed: 0,Array,Chunk
Bytes,15.50 TiB,111.24 MiB
Shape,"(20, 3, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,146100 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 15.50 TiB 111.24 MiB Shape (20, 3, 7305, 2700, 3600) (1, 1, 3, 2700, 3600) Count 258 Graph Layers 146100 Chunks Type float32 numpy.ndarray",3  20  3600  2700  7305,

Unnamed: 0,Array,Chunk
Bytes,15.50 TiB,111.24 MiB
Shape,"(20, 3, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,146100 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,15.50 TiB,111.24 MiB
Shape,"(20, 3, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,146100 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 15.50 TiB 111.24 MiB Shape (20, 3, 7305, 2700, 3600) (1, 1, 3, 2700, 3600) Count 258 Graph Layers 146100 Chunks Type float32 numpy.ndarray",3  20  3600  2700  7305,

Unnamed: 0,Array,Chunk
Bytes,15.50 TiB,111.24 MiB
Shape,"(20, 3, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,146100 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,15.50 TiB,111.24 MiB
Shape,"(20, 3, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,146100 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 15.50 TiB 111.24 MiB Shape (20, 3, 7305, 2700, 3600) (1, 1, 3, 2700, 3600) Count 258 Graph Layers 146100 Chunks Type float32 numpy.ndarray",3  20  3600  2700  7305,

Unnamed: 0,Array,Chunk
Bytes,15.50 TiB,111.24 MiB
Shape,"(20, 3, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,146100 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,15.50 TiB,111.24 MiB
Shape,"(20, 3, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,146100 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 15.50 TiB 111.24 MiB Shape (20, 3, 7305, 2700, 3600) (1, 1, 3, 2700, 3600) Count 258 Graph Layers 146100 Chunks Type float32 numpy.ndarray",3  20  3600  2700  7305,

Unnamed: 0,Array,Chunk
Bytes,15.50 TiB,111.24 MiB
Shape,"(20, 3, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,146100 Chunks
Type,float32,numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,15.50 TiB,111.24 MiB
Shape,"(20, 3, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,146100 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 15.50 TiB 111.24 MiB Shape (20, 3, 7305, 2700, 3600) (1, 1, 3, 2700, 3600) Count 258 Graph Layers 146100 Chunks Type float32 numpy.ndarray",3  20  3600  2700  7305,

Unnamed: 0,Array,Chunk
Bytes,15.50 TiB,111.24 MiB
Shape,"(20, 3, 7305, 2700, 3600)","(1, 1, 3, 2700, 3600)"
Count,258 Graph Layers,146100 Chunks
Type,float32,numpy.ndarray


In [12]:
import dask

# Scheduler/communication settings applied on the client side before the
# cluster is created: long comm timeouts and a larger allowed-failures budget.
comm_settings = {
    "distributed.comm.timeouts.tcp": "720s",
    "distributed.comm.timeouts.connect": "720s",
    "distributed.scheduler.allowed-failures": 10,
}
dask.config.set(comm_settings)

# Echo one of the values back to confirm the configuration was picked up.
print(dask.config.get("distributed.scheduler.allowed-failures"))

from dask_gateway import Gateway
gateway = Gateway()

# close existing clusters
# (iterating an empty listing is a no-op, so no explicit length check is needed)
open_clusters = gateway.list_clusters()
print(list(open_clusters))
for report in open_clusters:
    cluster = gateway.connect(report.name)
    cluster.shutdown()

# Worker sizing and environment for the new cluster.
options = gateway.cluster_options()
options.worker_memory = 52
options.worker_cores = 12
options.environment = {
    "DASK_DISTRIBUTED__SCHEDULER__WORKER_SATURATION": "1.0",
}

# Create a cluster with those options
cluster = gateway.new_cluster(options)
client = cluster.get_client()

# Fixed-size cluster; adaptive alternative left for reference:
# cluster.adapt(10, 200)
cluster.scale(200)
client

10
[]


0,1
Connection method: Cluster object,Cluster type: dask_gateway.GatewayCluster
Dashboard: /services/dask-gateway/clusters/prod.2d75a9ffa188402e961a129c3214fe62/status,


In [15]:
# Subset to write: first entry along `algo`, first 360*5 steps along `time`.
# TODO: run this for all time, but then it would be really big
time_window = slice(0, 360*5)
ds_write = decomp.isel(algo=0, time=time_window)
# Optional restriction to a subset of terms, kept for reference (currently disabled):
# ds_write = ds_write.sel(term=['Q_star', 'Q_star_ocean', 'Q_star_atmos','Q_star_res_real','Q_star_res_real', 'Q_star_star', 'Q_star_star_ocean', 'Q_star_star_atmos','Q_star_star_res'])

# Report the size of the selection in TB before writing.
print(f"{ds_write.nbytes/1e12}TB")

# Write out temp data
path = 'leap-scratch/jbusecke/scale-aware-air-sea/visualization/CM26_output_global_v3.zarr'
mapper = fs.get_mapper(path)

# Remove any previous store at this location so the write starts from scratch.
store_exists = fs.exists(path)
if store_exists:
    print('Overwriting existing')
    fs.rm(mapper.root, recursive=True)

to_zarr_split(ds_write, mapper, split_interval=300)

6.998400066192TB
Writing to leap-scratch/jbusecke/scale-aware-air-sea/visualization/CM26_output_global_v3.zarr ...




  0%|          | 0/5 [00:00<?, ?it/s]

