In [1]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import zarr
import gcsfs
from tqdm.autonotebook import tqdm
import os
import cftime
import json
from dask import array

# Plot configuration: inline figures, 12x6-inch default size, retina (2x) rendering.
# Comments are kept on their own lines because text after a line magic may be
# passed to the magic as arguments.
%matplotlib inline
plt.rcParams['figure.figsize'] = 12, 6
%config InlineBackend.figure_format = 'retina' 

  import sys


In [None]:
from dask.distributed import Client
from dask_kubernetes import KubeCluster

# Adaptive Dask cluster on Kubernetes; the worker pod spec comes from worker-spec.yml.
cluster = KubeCluster.from_yaml('worker-spec.yml')
# Scale between 1 and 20 workers, with a 2 s adaptive-scaling check interval.
cluster.adapt(interval='2s', minimum=1, maximum=20)
client = Client(cluster)
client

In [2]:
# Anonymous (read-only) access to Google Cloud Storage for opening the stores.
gcs = gcsfs.GCSFileSystem(token='anon')

# CMIP6 Zarr store catalog: one row per store, with its gs:// path in `zstore`.
df = pd.read_csv('https://storage.googleapis.com/cmip6/cmip6-zarr-consolidated-stores.csv')
df.head()

Unnamed: 0,activity_id,institution_id,source_id,experiment_id,member_id,table_id,variable_id,grid_label,zstore,dcpp_init_year
0,AerChemMIP,BCC,BCC-ESM1,histSST,r1i1p1f1,AERmon,mmrbc,gn,gs://cmip6/AerChemMIP/BCC/BCC-ESM1/histSST/r1i...,
1,AerChemMIP,BCC,BCC-ESM1,histSST,r1i1p1f1,AERmon,mmrdust,gn,gs://cmip6/AerChemMIP/BCC/BCC-ESM1/histSST/r1i...,
2,AerChemMIP,BCC,BCC-ESM1,histSST,r1i1p1f1,AERmon,mmroa,gn,gs://cmip6/AerChemMIP/BCC/BCC-ESM1/histSST/r1i...,
3,AerChemMIP,BCC,BCC-ESM1,histSST,r1i1p1f1,AERmon,mmrso4,gn,gs://cmip6/AerChemMIP/BCC/BCC-ESM1/histSST/r1i...,
4,AerChemMIP,BCC,BCC-ESM1,histSST,r1i1p1f1,AERmon,mmrss,gn,gs://cmip6/AerChemMIP/BCC/BCC-ESM1/histSST/r1i...,


In [3]:
# Catalog rows selected for processing in the loop below.
# index_col=0 restores the index that was saved with the file, instead of
# exposing it as a spurious "Unnamed: 0.1" column.
dfs = pd.read_csv('pangeo.csv', index_col=0)
dfs

Unnamed: 0.1,Unnamed: 0,activity_id,institution_id,source_id,experiment_id,member_id,table_id,variable_id,grid_label,zstore,dcpp_init_year
0,4961,CMIP,AWI,AWI-CM-1-1-MR,1pctCO2,r1i1p1f1,Amon,rlut,gn,gs://cmip6/CMIP/AWI/AWI-CM-1-1-MR/1pctCO2/r1i1...,
1,4963,CMIP,AWI,AWI-CM-1-1-MR,1pctCO2,r1i1p1f1,Amon,rsut,gn,gs://cmip6/CMIP/AWI/AWI-CM-1-1-MR/1pctCO2/r1i1...,
2,4969,CMIP,AWI,AWI-CM-1-1-MR,1pctCO2,r1i1p1f1,Amon,ts,gn,gs://cmip6/CMIP/AWI/AWI-CM-1-1-MR/1pctCO2/r1i1...,
3,4975,CMIP,AWI,AWI-CM-1-1-MR,1pctCO2,r1i1p1f1,Omon,tos,gn,gs://cmip6/CMIP/AWI/AWI-CM-1-1-MR/1pctCO2/r1i1...,
4,4994,CMIP,AWI,AWI-CM-1-1-MR,abrupt-4xCO2,r1i1p1f1,Amon,rlut,gn,gs://cmip6/CMIP/AWI/AWI-CM-1-1-MR/abrupt-4xCO2...,
...,...,...,...,...,...,...,...,...,...,...,...
426,50450,CMIP,NUIST,NESM3,abrupt-4xCO2,r1i1p1f1,Omon,mlotst,gn,gs://cmip6/CMIP/NUIST/NESM3/abrupt-4xCO2/r1i1p...,
427,50453,CMIP,NUIST,NESM3,abrupt-4xCO2,r1i1p1f1,Omon,tos,gn,gs://cmip6/CMIP/NUIST/NESM3/abrupt-4xCO2/r1i1p...,
428,50866,CMIP,NUIST,NESM3,piControl,r1i1p1f1,Amon,ts,gn,gs://cmip6/CMIP/NUIST/NESM3/piControl/r1i1p1f1...,
429,50873,CMIP,NUIST,NESM3,piControl,r1i1p1f1,Omon,mlotst,gn,gs://cmip6/CMIP/NUIST/NESM3/piControl/r1i1p1f1...,


In [4]:
# file_attrs = json.load(open('file_attrs.txt','r'))
# all_attrs = set(['_'.join([a['source_id'],a['member_id'],a['experiment_id'],a['table_id'],a['variable_id']]) for a in file_attrs])

# def want(s, m, e, t, v):
#     key = s+'_'+m+'_'+e+'_'+t+'_'+v
#     return s+'_'+m+'_'+e+'_'+t+'_'+v

# dfs = df[df[['source_id','member_id','experiment_id','table_id','variable_id']].apply(lambda x: want(*x) in all_attrs, axis=1)]

# pangeo_attrs = set(dfs[['source_id','member_id','experiment_id','table_id','variable_id']].apply(lambda x: want(*x), axis=1).values)
# manual_attrs = [a for a in all_attrs if a not in pangeo_attrs]
# json.dump(list(pangeo_attrs), open('pangeo_loads.txt', 'w'))
# json.dump(list(manual_attrs), open('manual_loads.txt', 'w'))

# dfs.to_csv('pangeo_loads.csv')

In [5]:
def load_srch_data(df, source_id, expt_id):
    """Open the first catalog store matching a model and an experiment.

    Parameters
    ----------
    df : pandas.DataFrame
        Store catalog with ``source_id``, ``experiment_id`` and ``zstore`` columns.
    source_id : str
        Model name to match.
    expt_id : str
        Experiment name to match.

    Returns
    -------
    xarray.Dataset
        The consolidated Zarr store at the first matching ``zstore`` path.

    Raises
    ------
    IndexError
        If no catalog row matches.
    """
    is_match = (df.source_id == source_id) & (df.experiment_id == expt_id)
    store_path = df.loc[is_match, 'zstore'].values[0]
    return xr.open_zarr(gcs.get_mapper(store_path), consolidated=True)

def load_data(series):
    """Open the consolidated Zarr store referenced by one catalog row.

    ``series`` must expose a ``zstore`` attribute holding the store path.
    """
    store = gcs.get_mapper(series.zstore)
    return xr.open_zarr(store, consolidated=True)

def get_dims(ds):
    """Locate the latitude and longitude coordinates of ``ds``.

    Coordinate names containing 'bnds' or 'vert' (typically cell bounds or
    vertices) are ignored. Among the remaining names, in coordinate order, the
    first containing 'lat' and the first containing 'lon' are chosen.

    Returns
    -------
    lat, lon
        The ``.data`` of the chosen latitude and longitude coordinates.
    dims : list of str
        ``[lat_name, lon_name]``.

    Raises
    ------
    IndexError
        If no candidate latitude or longitude coordinate exists.
    """
    candidates = []
    for name in ds.coords.keys():
        if 'bnds' in name or 'vert' in name:
            continue
        candidates.append(name)

    lat_name = [name for name in candidates if 'lat' in name][0]
    lon_name = [name for name in candidates if 'lon' in name][0]
    return ds.coords.get(lat_name).data, ds.coords.get(lon_name).data, [lat_name, lon_name]

def get_area(ds, df):
    """Return horizontal cell-area weights for the variable in ``ds``.

    Looks up the ``areacell<realm>`` store for the same ``source_id`` in the
    catalog ``df``. The realm letter comes from ``table_id``: ``areacella`` for
    ``A*`` tables and ``areacello`` for ``O*`` tables. When no such store is
    listed, falls back to cos(latitude) weights. These are proportional to cell
    area on a uniformly spaced lat-lon grid and only approximate it on other
    grids.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with ``table_id`` and ``source_id`` attributes and
        latitude/longitude coordinates (located by ``get_dims``).
    df : pandas.DataFrame
        Store catalog with ``variable_id``, ``source_id`` and ``zstore`` columns.

    Returns
    -------
    area : xarray.DataArray
        Weights over the horizontal dimensions only (no time axis).
    dims : tuple of str
        Dimension names spanned by ``area``; sum over these to reduce a field
        to a time series.
    total_area : xarray.DataArray
        0-d sum of ``area`` over ``dims``.
    """
    realm = ds.table_id[0].lower()
    area_name = 'areacell' + realm
    _, _, (lat_name, lon_name) = get_dims(ds)

    # Boolean masks avoid building a query string from dataset attributes.
    matches = df[(df.variable_id == area_name) & (df.source_id == ds.source_id)]
    if len(matches) == 0:
        # Fixes two bugs in the previous fallback: the meshgrid unpacking put
        # longitude values into `area`, and cos() was applied to degrees.
        coslat = np.cos(np.deg2rad(ds[lat_name]))
        # Broadcasting against the longitude coordinate yields (lat, lon) on a
        # regular grid. 1-D coordinates sharing one dimension (unstructured
        # mesh) and 2-D (curvilinear) coordinates keep their own dimensions.
        area, _ = xr.broadcast(coslat, ds[lon_name])
    else:
        ds_area = xr.open_zarr(gcs.get_mapper(matches.zstore.values[0]), consolidated=True)
        area = ds_area[area_name]

    dims = area.dims
    total_area = area.sum(dims)
    return area, dims, total_area

def avg_var(ds, df):
    """Area-weighted horizontal mean of ``ds[ds.variable_id]`` at each time step.

    Cells where the variable is missing (NaN), such as land points of an ocean
    field, are excluded from both the weighted sum and the normalising area.
    The previous version divided by the total area of all cells, which biases
    masked fields toward zero whenever the area weights also cover those cells.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset whose ``variable_id`` attribute names the field to average.
    df : pandas.DataFrame
        Store catalog passed to ``get_area`` for the cell-area lookup.

    Returns
    -------
    xarray.DataArray
        The weighted mean, reduced over the horizontal dimensions reported by
        ``get_area``. It is NaN at time steps where no cell holds data.
    """
    area, dims, _ = get_area(ds, df)
    var = ds[ds.variable_id]

    weighted_sum = (var * area).sum(dim=dims)
    # Normalise only by the area of the cells that actually hold data.
    valid_area = (var.notnull() * area).sum(dim=dims)
    return weighted_sum / valid_area



In [6]:
# Empty result container. NOTE(review): not referenced by any cell visible here.
results = dict()

In [None]:
# Compute the area-weighted mean time series for every row of `dfs` and cache it
# as data/<source>_<experiment>_<member>_<variable>.npy. Rows whose file already
# exists are skipped, so re-running the cell resumes where it stopped.
DATA_DIR = 'data'
# Time steps kept per run (150 and 200 years, assuming monthly data):
# piControl and 1pctCO2 use the short limit, all other experiments the long one.
N_STEPS_SHORT = 1800
N_STEPS_LONG = 2400
os.makedirs(DATA_DIR, exist_ok=True)

for num in tqdm(range(len(dfs))):
    s = dfs.iloc[num]
    name = '_'.join([s.source_id, s.experiment_id, s.member_id, s.variable_id])
    out_path = os.path.join(DATA_DIR, name + '.npy')
    if os.path.exists(out_path):
        continue

    try:
        tqdm.write(str(num) + ' : ' + name)
        # Opening the store can fail as well, so it lives inside the guarded block.
        ds = load_data(s)

        if ds.experiment_id in ('piControl', '1pctCO2'):
            n_keep = N_STEPS_SHORT
        else:
            n_keep = N_STEPS_LONG
        # Positional slicing keeps the first n_keep steps (all of them if fewer).
        ds = ds.isel(time=slice(0, n_keep))
        m = avg_var(ds, df)

        # NOTE(review): np.datetime64(t) presumably fails for cftime dates in
        # non-standard calendars (e.g. 360_day) — confirm; such rows end up FAILED.
        times = np.array([np.datetime64(t) for t in m.time.values])
        np.save(out_path, np.array([m.values, times]))
    except Exception as err:
        # Exception rather than a bare except: Ctrl-C can still stop the loop,
        # and the cause of each failure is reported.
        tqdm.write('FAILED on ' + str(num) + ' : ' + name + ' (' + repr(err) + ')')

HBox(children=(FloatProgress(value=0.0, max=431.0), HTML(value='')))

164 : EC-Earth3-Veg_1pctCO2_r1i1p1f1_ts
165 : EC-Earth3-Veg_1pctCO2_r1i1p1f1_mlotst
166 : EC-Earth3-Veg_1pctCO2_r1i1p1f1_tos
167 : EC-Earth3-Veg_abrupt-4xCO2_r1i1p1f1_ts
168 : EC-Earth3-Veg_abrupt-4xCO2_r1i1p1f1_tos


  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  return array(a, dtype, copy=False, order=order)


169 : EC-Earth3-Veg_piControl_r1i1p1f1_ts


  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  return array(a, dtype, copy=False, order=order)


170 : EC-Earth3-Veg_piControl_r1i1p1f1_mlotst


In [59]:
# Scratch inspection of the dataset last opened by the loop above. It relies on
# the `ds` left over from that loop, so it is not meaningful on a fresh kernel.
# Uses the latitude coordinate chosen by get_dims rather than a hard-coded
# 'latitude' name, which does not exist on every grid.
# NOTE(review): row 324 was picked interactively; on the 2-D latitude grid
# inspected here, the whole row holds a single latitude value.
lat, lon, dims = get_dims(ds)
np.asarray(lat[324])

[-66.41343 -66.41343 -66.41343 ... -66.41343 -66.41343 -66.41343]


In [None]:
def _weighted_percentiles(vals, weights, pcts):
    """Weighted percentiles of one field, as a list of ``[p, value]`` pairs.

    Uses the inverted-CDF definition: for each fraction ``p`` in ``pcts``, the
    result is the smallest value whose cumulative normalised weight is >= p.
    Entries where the value or the weight is NaN are ignored. If nothing usable
    remains, or the weights sum to zero, every returned value is NaN.
    """
    vals = np.ravel(np.asarray(vals, dtype=float))
    weights = np.ravel(np.asarray(weights, dtype=float))
    usable = ~(np.isnan(vals) | np.isnan(weights))
    vals, weights = vals[usable], weights[usable]
    total = weights.sum()
    if vals.size == 0 or total <= 0:
        return [[p, np.nan] for p in pcts]

    order = np.argsort(vals)
    sorted_vals = vals[order]
    cdf = np.cumsum(weights[order]) / total
    # side='left' gives the first index with cdf >= p. The clip guards against
    # cdf[-1] rounding to just below 1.
    idx = np.minimum(np.searchsorted(cdf, pcts, side='left'), sorted_vals.size - 1)
    return [[p, sorted_vals[k]] for p, k in zip(pcts, idx)]


def percentiles(ds, df, pcts=(0.05, 0.17, 0.5, 0.83, 0.95)):
    """Area-weighted percentiles of ``ds[ds.variable_id]`` at every time step.

    Fixes the previous version, which:
    - indexed the ``(area, dims, total_area)`` tuple from ``get_area`` by time step;
    - read past the end of the weights array;
    - skipped percentiles that fell below the first cumulative weight or shared
      a step with another percentile.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset whose ``variable_id`` attribute names the field; it must have
        a ``time`` dimension.
    df : pandas.DataFrame
        Store catalog passed to ``get_area`` for the cell-area lookup.
    pcts : sequence of float, optional
        Fractions in [0, 1] to evaluate, in ascending order.

    Returns
    -------
    list
        One ``[time, [[p, value], ...]]`` entry per time step.
    """
    area, _, _ = get_area(ds, df)
    var = ds[ds.variable_id]
    if isinstance(area, xr.DataArray):
        # Match grid points by coordinate label (inner join), as xarray arithmetic does.
        var, area = xr.align(var, area, join='inner')
    # Expand the weights to the variable's full shape and dimension order, so
    # each time slice flattens element-for-element alongside the data.
    weights = (xr.ones_like(var) * area).transpose(*var.dims)

    results = []
    for t in range(var.sizes['time']):
        pct_vals = _weighted_percentiles(
            var.isel(time=t).values, weights.isel(time=t).values, pcts)
        results.append([var.time.values[t], pct_vals])
    return results
