In [1]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import zarr
import gcsfs
#from tqdm.autonotebook import tqdm
import os
import cftime
import json
from dask import array

%matplotlib inline
plt.rcParams['figure.figsize'] = 12, 6
%config InlineBackend.figure_format = 'retina' 

In [2]:
# Notebook progress bars. `tqdm.notebook.tqdm` replaces the deprecated
# `tqdm.tqdm_notebook`; calling the classmethod `tqdm.pandas()` registers the
# pandas `progress_apply` helpers without instantiating (and displaying) an
# empty progress bar, so no `%%capture` is needed.
from tqdm.notebook import tqdm
tqdm.pandas()

In [3]:
from dask.distributed import Client
from dask_kubernetes import KubeCluster

# Kubernetes-backed Dask cluster that adaptively scales between 1 and 20
# workers, re-evaluating every 2 s; a client is attached and displayed.
cluster = KubeCluster()
cluster.adapt(interval='2s', minimum=1, maximum=20)
client = Client(cluster)
client

distributed.scheduler - INFO - Clear task state
distributed.scheduler - INFO -   Scheduler at:   tcp://10.48.177.9:33833
distributed.scheduler - INFO -   dashboard at:                     :8787
distributed.scheduler - INFO - Receive client connection: Client-1d6f106c-7f96-11ea-8195-8e6cb5068046
distributed.core - INFO - Starting established connection


0,1
Client  Scheduler: tcp://10.48.177.9:33833  Dashboard: /user/ghall3-pangeo_tests-m5bmicbe/proxy/8787/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


distributed.scheduler - INFO - Register worker <Worker 'tcp://10.48.178.9:34881', name: 0, memory: 0, processing: 0>
distributed.scheduler - INFO - Starting worker compute stream, tcp://10.48.178.9:34881
distributed.core - INFO - Starting established connection


In [4]:
# Catalog of CMIP6 zarr stores; its `zstore` column holds the store URIs
# opened by the loaders below.
df = pd.read_csv(filepath_or_buffer='https://storage.googleapis.com/cmip6/cmip6-zarr-consolidated-stores.csv')
# Anonymous (token-less) filesystem handle for reading those stores.
gcs = gcsfs.GCSFileSystem(token='anon')

In [8]:
# NOTE(review): `dfs` is re-read from 'saved-data.csv' in a later cell before
# any visible use, so this value looks unused — confirm before relying on it.
dfs = pd.read_csv(filepath_or_buffer='records/pangeo.csv')

In [8]:
def load_srch_data(df, source_id, expt_id):
    """Open the first catalog store matching a model and an experiment.

    Parameters
    ----------
    df : pandas.DataFrame
        Catalog with ``source_id``, ``experiment_id`` and ``zstore`` columns.
    source_id, expt_id : str
        Values to match in ``source_id`` / ``experiment_id``.

    Returns
    -------
    The dataset opened from the first matching ``zstore`` (consolidated
    metadata). Raises IndexError when nothing matches.
    """
    match = (df.source_id == source_id) & (df.experiment_id == expt_id)
    first_uri = df.loc[match, 'zstore'].values[0]
    store = gcs.get_mapper(first_uri)
    return xr.open_zarr(store, consolidated=True)

def load_data(series):
    """Open the zarr store referenced by ``series.zstore`` (consolidated metadata)."""
    store = gcs.get_mapper(series.zstore)
    return xr.open_zarr(store, consolidated=True)

def get_dims(ds):
    """Locate the latitude and longitude coordinates of ``ds``.

    Coordinate names containing ``'bnds'`` or ``'vert'`` (presumably cell
    bounds / vertex coordinates) are skipped. Among the remaining names, the
    first containing ``'lat'`` and the first containing ``'lon'`` are used.

    Returns
    -------
    (lat, lon, dims)
        The two coordinates and the list ``[lat_name, lon_name]``.
        Raises IndexError when no suitable name exists.
    """
    usable = [name for name in ds.coords.keys()
              if not ('bnds' in name or 'vert' in name)]
    lat_name = [name for name in usable if 'lat' in name][0]
    lon_name = [name for name in usable if 'lon' in name][0]
    return ds.coords.get(lat_name), ds.coords.get(lon_name), [lat_name, lon_name]

def get_area(ds, df):
    """Build cos(latitude) area weights for the variable named by ``ds.variable_id``.

    Three grid layouts are distinguished:

    * the latitude coordinate's leading axis has more than 2000 points —
      presumably an unstructured or very high-resolution grid; the weights
      are ``cos(lat)`` on the latitude coordinate's own dimensions;
    * ``lat`` has the shape of the variable without its leading axis
      (presumably time) — a 2-D curvilinear grid; the weights are a plain
      array ``cos(lat)`` that broadcasts positionally;
    * otherwise — 1-D ``lat``/``lon``; the weights are ``cos(lat)`` along
      latitude, which xarray broadcasts over longitude by dimension name.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset carrying a ``variable_id`` attribute naming the field.
    df : pandas.DataFrame
        Store catalog. Currently unused; kept for interface compatibility
        (TODO: use the model's ``areacell*`` store when it exists).

    Returns
    -------
    area
        Weights that broadcast against the variable.
    dims
        Dimension names to sum over to collapse space.
    total_area
        Sum of ``area`` over one complete time step.
    """
    var = ds.get(ds.variable_id)
    lat, lon, dims = get_dims(ds)

    if len(lat.data) > 2000:
        area = np.cos(np.deg2rad(lat))
        # Sum over the dimensions the latitude coordinate actually lives on
        # rather than assuming a dimension called "ncells".
        dims = list(lat.dims)
        # The normaliser is the summed weights; summing the latitudes
        # themselves gives ~0 on a globally symmetric grid.
        total_area = area.sum()
    elif np.shape(lat) == np.shape(var)[1:]:
        area = np.cos(np.deg2rad(lat.data))
        total_area = area.sum()
        dims = lat.dims
    else:
        # 1-D weights broadcast by name; this avoids materialising a full
        # (time, lat, lon) weight array in memory.
        area = np.cos(np.deg2rad(lat))
        dims = list(lat.dims) + list(lon.dims)
        # sum over (lat, lon) of cos(lat) == n_lon * sum over lat of cos(lat)
        total_area = area.sum() * lon.size

    return area, dims, total_area

def avg_var(ds, df):
    """Area-weighted spatial mean of ``ds[ds.variable_id]`` at every time step.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset whose ``variable_id`` attribute names the field to average.
    df : pandas.DataFrame
        Store catalog, forwarded to ``get_area``.

    Returns
    -------
    xarray.DataArray
        The field with the spatial dimensions reported by ``get_area``
        reduced away.

    Notes
    -----
    ``sum`` skips NaN cells (for example masked land points of an ocean
    field), so the normaliser must cover only the cells that hold data at
    each time step. Otherwise the mean is scaled down by the masked
    fraction. For fields without missing values, the normaliser is simply
    the sum of all weights.
    """
    area, dims, _total_area = get_area(ds, df)
    var = ds.get(ds.variable_id)

    weighted_sum = (var * area).sum(dim=dims)
    valid_area = (var.notnull() * area).sum(dim=dims)
    return weighted_sum / valid_area



Compute and save area-weighted mean time series for every store available on Pangeo's servers
=======

In [None]:
# NOTE(review): `a_tas` is built in a later cell, and iteration starts at
# row 10 — presumably to resume an interrupted run. Confirm before re-running.
for num in tqdm(range(10, len(a_tas))):
    row = a_tas.iloc[num]
    name = '_'.join([row.source_id, row.experiment_id, row.member_id, row.variable_id])

    ds = load_data(row)
    print(name)
    m = avg_var(ds, df)

    if m is not None:
        # Saved as a 2-row array: spatial-mean values, then timestamps.
        values = m.values[:]
        times = np.array([np.datetime64(t) for t in m.time.values])[:]
        np.save(f'pangeo_data/{name}', np.array([values, times]))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  """Entry point for launching an IPython kernel.


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))

CNRM-ESM2-1_piControl_r1i1p1f2_tas


distributed.scheduler - INFO - Register worker <Worker 'tcp://10.48.178.10:33907', name: 1, memory: 0, processing: 0>
distributed.scheduler - INFO - Starting worker compute stream, tcp://10.48.178.10:33907
distributed.core - INFO - Starting established connection
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  return array(a, dtype, copy=False, order=order)
distributed.scheduler - INFO - Retire worker names (0,)
distributed.scheduler - INFO - Retire workers {<Worker 'tcp://10.48.178.9:34881', name: 0, memory: 0, processing: 0>}
distributed.scheduler - INFO - Closing worker tcp://10.48.178.9:34881
distributed.scheduler - INFO - Remove worker <Worker 'tcp://10.48.178.9:34881', name: 0, memory: 0, processing: 0>
distributed.core - INFO - Removing comms to tcp://10.48.178.9:34881


Map the datasets not available on Pangeo's servers to their downloadable files
=========

In [None]:
from collections import defaultdict

with open('manual_loads.txt', 'r') as fh:
    manual = json.load(fh)
with open('allfiles.txt', 'r') as fh:
    allfiles = json.load(fh)

# Group every file path under the key format used in `manual_loads.txt`:
# path components 9, 11, 10, 12 and 13 joined with '_'.
# NOTE(review): going by the `esgf` column names below, these are
# source_id, member_id, experiment_id, table_id, variable_id — confirm.
mapping = defaultdict(list)
for path in allfiles:
    parts = path.split('/')
    mapping['_'.join([parts[9], parts[11], parts[10]] + parts[12:14])].append(path)

# Datasets with no matching files get [] instead of None, so the sort by
# file count cannot fail with len(None).
to_load = [(m, mapping.get(m, [])) for m in manual]
to_load = sorted(to_load, key=lambda x: len(x[1]))

Collect the metadata needed to download these datasets manually.
=======

In [None]:
saved = pd.read_csv('saved-data.csv')
failed = []
# Split each manual key into its '_'-separated fields and name the first five.
esgf = pd.DataFrame([key.split('_') for key in manual]).rename(
    columns={0: 'source_id', 1: 'member_id', 2: 'experiment_id',
             3: 'table_id', 4: 'variable_id'})
esgf.head()

Merge the downloaded abrupt-4xCO2 mlotst files into one time series per model run
=======

In [None]:
# One row per downloaded file: the '_'-separated fields of its name (without
# '.nc'), followed by the full file name.
files = pd.DataFrame(
    [fname.replace('.nc', '').split('_') + [fname]
     for fname in os.listdir('wgets/abrupt-4xCO2-mlotst')]
).rename(columns=dict(enumerate([
    "variable_id", "table_id", "source_id", "experiment_id",
    "member_id", "grid_label", "time_range", "file_name"])))

In [None]:
from netCDF4 import Dataset
failed = []
for key, group in files.groupby(['experiment_id', 'variable_id', 'source_id', 'member_id']):
    # Assumes time_range strings sort chronologically (fixed-width dates) — confirm.
    group = group.sort_values(by=['time_range'])
    pieces = []

    try:
        for f in group['file_name']:
            # Reduce each file to its mean series and load it into memory,
            # so the file handle can be closed right away.
            with xr.open_dataset('wgets/abrupt-4xCO2-mlotst/' + f) as ds:
                pieces.append(avg_var(ds, df).load())

        # Join every chunk along time, so all files (not just the last one)
        # end up in the saved series.
        merged = xr.concat(pieces, 'time')
        series = group.iloc[0]
        fname = '_'.join([series.source_id, series.experiment_id, series.member_id, series.variable_id])
        saved = pd.concat([saved, group])
        times = np.array([np.datetime64(t) for t in merged.time.values])
        np.save('manual_data/' + fname, np.array([merged.values[:], times[:]]))
    except OSError:
        print("failed on", f)
        failed.append(f)

Check for missing or suspect datasets and look them up in the cloud catalog
===

In [4]:
# Index of series saved to disk (presumably); `find`/`load` filter this frame.
dfs = pd.read_csv(filepath_or_buffer='saved-data.csv')

def find(**keys):
    """Return the rows of the global ``dfs`` matching every ``column=value`` pair.

    With no keywords, a copy of the whole frame is returned.
    """
    selected = dfs.copy()
    for column in keys:
        selected = selected.loc[selected[column] == keys[column]]
    return selected

def load(ndfs=None, **keys):
    """Load the ``.npy`` files listed in a frame's ``file_name`` column.

    Parameters
    ----------
    ndfs : pandas.DataFrame, optional
        Rows whose files should be loaded. When omitted, the rows are
        selected with ``find(**keys)``.

    Returns
    -------
    dict
        Maps each file name to its loaded array.
    """
    rows = find(**keys) if ndfs is None else ndfs
    # allow_pickle: the files written by this notebook appear to be object
    # arrays (values + datetimes). Unpickling can execute code — trusted files only.
    return {path: np.load(path, allow_pickle=True) for path in rows.file_name}

# Flag saved `ts` series whose mean value is below 250 as suspect.
ts_files = find(variable_id="ts")
bad_sources = []
for file, d in ts_files.groupby('file_name'):
    ts = list(load(file_name=file).values())[0]
    if np.mean(ts[0]) < 250:
        bad_sources += [(d.source_id.values[0], d.experiment_id.values[0], 'ts')]

# Variables each experiment must provide.
v = {'piControl': ['tos', 'ts', 'mlotst'],
     '1pctCO2': ['tos', 'ts', 'rtmt', 'mlotst'],
     'abrupt-4xCO2': ['tos', 'ts', 'mlotst']}
incompletes = []
for keys, group in dfs.groupby(['source_id', 'experiment_id']):
    expt = keys[1]
    # `group` already holds only this experiment's rows.
    present = set(group.variable_id.values)
    # Experiments without an entry in `v` have no requirements
    # (previously a KeyError aborted the whole check).
    for x in v.get(expt, []):
        if x not in present:
            incompletes += [(*keys, x)]
redownload = incompletes + bad_sources
redownload2 = []

In [5]:
# For each (source_id, experiment_id, variable_id) to re-fetch, look for a
# cloud store with grid_label 'gn'. Entries without one go to `redownload2`.
redownload2 = []  # reset so re-running this cell does not duplicate entries
matches = []
for r in redownload:
    hits = df[(df.source_id == r[0]) & (df.experiment_id == r[1]) & (df.variable_id == r[2]) & (df.grid_label == 'gn')]
    if not hits.empty:
        matches.append(hits)
    else:
        redownload2 += [r]
# Concatenate once; repeated concat inside the loop is quadratic.
check = pd.concat(matches) if matches else pd.DataFrame()

In [5]:
tas = df[df.variable_id == "tas"]
taspi = tas[tas.experiment_id == "piControl"]
tasppt = tas[tas.experiment_id == "1pctCO2"]
tasab = tas[tas.experiment_id == "abrupt-4xCO2"]

# Models (source_id) that provide tas in all three experiments. The set is
# built from column values: iterating groupby(['source_id']) yields 1-tuple
# keys in pandas >= 2.0, which would make every membership test below False.
ks = [set(sub.source_id.dropna()) for sub in (taspi, tasppt, tasab)]
k = ks[0] & ks[1] & ks[2]


def check_sc(row):
    """Return True when ``row`` (a source_id) is in the common-model set ``k``."""
    return row in k


a_tas = pd.concat([taspi, tasppt, tasab])
a_tas = a_tas[a_tas.source_id.isin(k)]

In [6]:
# Number of models (source_id) with tas in piControl, 1pctCO2 and abrupt-4xCO2.
len(k)

30