<a href="https://colab.research.google.com/github/justinfmccarty/epwmorph/blob/main/general_cmip6_query.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install --upgrade xesmf xarray zarr gcsfs cftime nc-time-axis
! pip install xclim
! pip install intake-esm
! pip install gcsfs 

In [82]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from xclim import ensembles
# import xesmf as xe
# import cartopy
from dask.diagnostics import progress
from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!
import intake
import dask
%matplotlib inline
plt.rcParams['figure.figsize'] = 12, 6
%config InlineBackend.figure_format = 'retina' 
sys.setrecursionlimit(1500)

In [2]:
col = intake.open_esm_datastore("https://storage.googleapis.com/cmip6/pangeo-cmip6.json")

  exec(code_obj, self.user_global_ns, self.user_ns)


In [83]:
sources = pd.DataFrame(pd.read_csv('/content/modelsources.csv'))
sources['in_ensemble'] = 'Yes'
sources.index = sources['source_id']
sources['in_ensemble']['NorESM2-LM'] = 'Never'
sources['in_ensemble']['MPI-ESM1-2-HR'] = 'Never'
sources['in_ensemble']['NorESM2-MM'] = 'Never'
# sources['in_ensemble'] = 'No'
# sources['in_ensemble']['BCC-CSM2-MR'] = 'Yes'

source_list = sources[sources['in_ensemble'] == 'Yes']['source_id'].values.tolist()


In [73]:
variable_list = ['tas','tasmax','tasmin','clt','psl','pr','huss','vas','uas','rsds']
# variable_list = ['tas','tasmax','tasmin','clt']

In [84]:
cat = col.search(experiment_id=['historical'],  # pick the `historical` forcing experiment
                 table_id='Amon',             # choose to look at atmospheric variables (A) saved at monthly resolution (mon)
                 variable_id= variable_list[9],           # choose to look at near-surface air temperature (tas) as our variable
                 member_id = 'r1i1p1f1',
                 source_id = source_list)      # arbitrarily pick one realization for each model (i.e. just one set of initial conditions)

In [85]:
time_slice = slice('1950','2100') # specific years that bracket our period of interest
longitude = -122.36 
latitude = 49.0253

if longitude < 180:
  longitude = longitude + 360
else:
  longitude = longitude

In [None]:
# convert data catalog into a dictionary of xarray datasets
dset_dict = cat.to_dataset_dict(zarr_kwargs={'consolidated': True, 'decode_times': False})

ds_dict = {}
for name, ds in tqdm(dset_dict.items()):
    print(name)
    # rename spatial dimensions if necessary
    if ('longitude' in ds.dims) and ('latitude' in ds.dims):
        ds = ds.rename({'longitude':'lon', 'latitude': 'lat'}) # some models labelled dimensions differently...
        
    ds = xr.decode_cf(ds) # temporary hack, not sure why I need this but has to do with calendar-aware metadata on the time variable
    
    ds = ds.sel(time=time_slice) # subset the data for the time period of interest

    ds = ds.sel(lat=latitude, lon=longitude, method='nearest')
    
    # drop redundant variables (like "height: 2m")
    for coord in ds.coords:
        if coord not in ['lat','lon','time']:
            ds = ds.drop(coord)

    ds.coords['year'] = ds.time.dt.year
    ds.coords['time'] = xr.cftime_range(start=str(ds.time.dt.year.values[0]),
                                        periods=len(ds.time.dt.year.values),
                                        freq="MS", calendar="noleap")

    print(ds.time.dt.floor("D"))
    # Add variable array to dictionary
    ds_dict[name] = ds

with progress.ProgressBar():
  dsets_aligned_ = dask.compute(ds_dict)[0]

In [147]:
ens = ensembles.create_ensemble([ds.reset_coords(drop=True) for ds in dsets_aligned_.values()])
percentile_list = [15, 50, 85]
ens_perc = ensembles.ensemble_percentiles(ens, values=percentile_list, split=False)
ens_stats = ensembles.ensemble_mean_std_max_min(ens)
ens

  output_sizes={"percentiles": len(values)},


In [90]:
source_ids = list(dsets_aligned_.keys())
source_da = xr.DataArray(source_ids, dims='source_id', name='source_id',
                         coords={'source_id': source_ids})


big_ds = xr.concat([ds.reset_coords(drop=True)
                    for ds in dsets_aligned_.values()],
                    dim=source_da)

big_ds_lite = big_ds#.sel(time=slice('1950','2024'))

var = 'rsds'

x = getattr(big_ds_lite, var)
ens = ensembles.create_ensemble(x)

RecursionError: ignored

# Above for Faster GRab

In [22]:
#splice together the hist and spp

ssp_dict = {}
for name, ds in ds_dict.items():
    if 'ssp' not in name: continue
        
    add_hist_ds = None
    
    # Loop through dictionary to find matching historical simulation
    for hist_name, hist_ds in ds_dict.items():
        model_name = name.split(".")[2] # extract model name from name of simulation
        
        # find historical simulation that matches SSP simulation (same model)
        if (model_name in hist_name) and ('historical' in hist_name):
            add_hist_ds = hist_ds
    
    # if we found one, slice it together with the SSP and add it to the dictionary
    if add_hist_ds is not None:
        ds_ssp = xr.concat([add_hist_ds, ds],dim='time')
        # ds_ssp = ds_ssp.sel(time=slice('2065','2095'))
        ssp_dict[name] = ds_ssp

In [None]:
percentile_list = [15, 50, 85]
ens_perc = ensembles.ensemble_percentiles(ens, values=percentile_list, split=False)
ens_stats = ensembles.ensemble_mean_std_max_min(ens)

In [None]:
def drop_all_bounds(ds):
    drop_vars = [vname for vname in ds.coords
                 if (('_bounds') in vname ) or ('_bnds') in vname]
    return ds.drop(drop_vars)

def open_dset(df):
    # assert len(df) == 1
    ds = xr.open_zarr(fsspec.get_mapper(df.zstore.values[0]), consolidated=True)
    
    if 'lat' in ds.coords:
      ds = ds.sel(
          lat=50, lon=-121,method='nearest')
    else:
      ds = ds.sel(
          latitude=50, longitude=-121,method='nearest')


    ds.coords['year'] = ds.time.dt.year
    ds.coords['time'] = xr.cftime_range(start=str(ds.time.dt.year.values[0]), 
                                    periods=len(ds.time.dt.year.values), 
                                    freq="MS", calendar="noleap")
    ds = ds.sel(time=slice('1950','2100'))
    return drop_all_bounds(ds)

def open_delayed(df):
    return dask.delayed(open_dset)(df)

from collections import defaultdict
dsets = defaultdict(dict)

for group, df in tqdm.tqdm(col_subset.df.groupby(by=['source_id', 'experiment_id'])):
    # print(group)
    dsets[group[0]][group[1]] = open_delayed(df)

with progress.ProgressBar():
    dsets_ = dask.compute(dict(dsets))[0]

# calculate global means

def get_lat_name(ds):
    for lat_name in ['lat', 'latitude']:
        if lat_name in ds.coords:
            return lat_name
    raise RuntimeError("Couldn't find a latitude coordinate")

def global_mean(ds):
    lat = ds[get_lat_name(ds)]
    weight = np.cos(np.deg2rad(lat))
    weight /= weight.mean()
    other_dims = set(ds.dims) - {'time'}
    return (ds * weight).mean(other_dims)

expt_da = xr.DataArray(expts, dims='experiment_id', name='experiment_id',
                       coords={'experiment_id': expts})

dsets_aligned = {}

for k, v in tqdm.tqdm(dsets_.items()):
    expt_dsets = v.values()
    if any([d is None for d in expt_dsets]):
        print(f"Missing experiment for {k}")
        continue

    
    for ds in expt_dsets:
        ds.coords['year'] = ds.time.dt.year

    # workaround for
    # https://github.com/pydata/xarray/issues/2237#issuecomment-620961663
    dsets_ann_mean = [v[expt]
                      for expt in expts]

    # align everything with the 4xCO2 experiment
    dsets_aligned[k] = xr.concat(dsets_ann_mean, join='outer',
                                 dim='historical')

with progress.ProgressBar():
    dsets_aligned_ = dask.compute(dsets_aligned)[0]


100%|██████████| 49/49 [00:00<00:00, 4357.49it/s]

[                                        ] | 0% Completed |  0.1s




[########################################] | 100% Completed | 14.6s


100%|██████████| 49/49 [00:00<00:00, 141.08it/s]


[########################################] | 100% Completed | 48.6s


In [None]:
source_ids = list(dsets_aligned_.keys())
source_da = xr.DataArray(source_ids, dims='source_id', name='source_id',
                         coords={'source_id': source_ids})

big_ds = xr.concat([ds.reset_coords(drop=True)
                    for ds in dsets_aligned_.values()],
                    dim=source_da)

big_ds_lite = big_ds.sel(time=slice('1950','2024'))

vari = list(big_ds_lite.data_vars.keys())[2]
print(vari)

x = getattr(big_ds_lite, 'tas')
ens = ensembles.create_ensemble(x)

percentile_list = [15, 50, 85]
ens_perc = ensembles.ensemble_percentiles(ens, values=percentile_list, split=False)
ens_stats = ensembles.ensemble_mean_std_max_min(ens)

tas


  output_sizes={"percentiles": len(values)},


### Read EPW


In [None]:
import pandas as pd

In [None]:
path1 = '/content/CAN_BC_Bella.Bella.AP.715820_TMYx.epw'
path2 = '/content/CAN_BC_Ballenas.Island.717690_TMYx.epw'

def query_epw_period(epw_orig_file):
  record = pd.read_csv(epw_orig_file,skiprows=3).iloc[1,1]
  record_years_start = record.split('=')[2].split(';')[0].split('-')[0]
  record_years_end = record.split('=')[2].split(';')[0].split('-')[1]
  return record_years_start, record_years_end

def epw_details(path):
# query_epw_period(path)
  name = pd.read_csv(path,header=None,nrows=1).iloc[0,1]
  latidude = pd.read_csv(path,header=None,nrows=1).iloc[0,7]
  longitude = pd.read_csv(path,header=None,nrows=1).iloc[0,6]
  utc = pd.read_csv(path,header=None,nrows=1).iloc[0,8]
  elevation = pd.read_csv(path,header=None,nrows=1).iloc[0,9]

df = pd.DataFrame()

pathlist = [path1,path2]
for path in pathlist:
  data = pd.read_csv(path,header=None,nrows=1,usecols=[1,2,3,4,5,6,7,8,9]).rename(columns={1:'location',
                                                                                          2:'province',
                                                                                          3:'country',
                                                                                          4:'type',
                                                                                          5:'usaf',
                                                                                          6:'longitude',
                                                                                          7:'latitude',
                                                                                          8:'utc',
                                                                                          9:'elevation'})
  df = df.append(data,ignore_index=True)
df

Unnamed: 0,location,province,country,type,usaf,longitude,latitude,utc,elevation
0,Bella Bella AP,BC,CAN,ISD-TMYx,715820,52.185,-128.1567,-8.0,43.0
1,Ballenas Island,BC,CAN,ISD-TMYx,717690,49.3503,-124.1603,-8.0,10.0


In [None]:
def epw_to_dataframe(weather_path):
    epw_labels = ['year', 'month', 'day', 'hour', 'minute', 'datasource', 'drybulb_C', 'dewpoint_C', 'relhum_percent',
                  'atmos_Pa', 'exthorrad_Whm2', 'extdirrad_Whm2', 'horirsky_Whm2', 'glohorrad_Whm2', 'dirnorrad_Whm2',
                  'difhorrad_Whm2', 'glohorillum_lux', 'dirnorillum_lux', 'difhorillum_lux', 'zenlum_lux', 'winddir_deg',
                  'windspd_ms', 'totskycvr_tenths', 'opaqskycvr_tenths', 'visibility_km', 'ceiling_hgt_m',
                  'presweathobs', 'presweathcodes', 'precip_wtr_mm', 'aerosol_opt_thousandths', 'snowdepth_cm',
                  'days_last_snow', 'Albedo', 'liq_precip_depth_mm', 'liq_precip_rate_Hour']
    return pd.DataFrame(pd.read_csv(weather_path, skiprows=8, header=None, names=epw_labels).drop('datasource', axis=1))


In [None]:
# list(range(len(pd.DataFrame(pd.read_csv('/content/population_centers.csv')))))
epw_to_dataframe('/content/CAN_BC_Abbotsford.Intl.AP.711080_TMYx.epw')['hour']

0        1
1        2
2        3
3        4
4        5
        ..
8755    20
8756    21
8757    22
8758    23
8759    24
Name: hour, Length: 8760, dtype: int64