In [2]:
import dask
import fsspec
import glob
import intake
import requests
import datetime
import gc  

import matplotlib as mpl
import numpy as np
import pandas as pd
import xarray as xr

from collections import defaultdict
from matplotlib import pyplot as plt
from params import allnames

from dask.diagnostics import progress
from scipy.stats import norm
from tqdm.autonotebook import tqdm
import xml.etree.ElementTree as ET

from params import allnames
from params import homedir
from params import experiment_ids, years, table_ids, labels, variables, savepath

figdir = homedir + 'figures/'


import xml.etree.ElementTree as ET

from dask.diagnostics import progress
from scipy.stats import norm
from tqdm.autonotebook import tqdm

from params import allnames, experiment_ids, variables, years, table_ids, labels
from params import homedir, savepath

## ## style ## ##
xr.set_options(display_style='html')
plt.style.use('./science.mplstyle')
mpl.rcParams['axes.linewidth'] = 2
mpl.rcParams['xtick.major.width'] = 2
mpl.rcParams['ytick.major.width'] = 2
mpl.rcParams['xtick.top']= False
mpl.rcParams['ytick.right']= False


  from tqdm.autonotebook import tqdm


In [3]:
import os
## 
def wrfread(modeldir, gcm, exp, variant, domain, var):
    all_files = sorted(os.listdir(modeldir))
    read_files = []
    for ii in all_files:
        if (
            ii.startswith(var + ".")
            #and gcm in ii
            and variant in ii
            and domain in ii
            and exp in ii
        ):
            if domain in ii:
                read_files.append(os.path.join(modeldir, str(ii)))
    assert len(read_files) > 0, f"No matching files found in {modeldir}"

    del all_files

    data = xr.open_mfdataset(read_files, combine="by_coords")
    var_read = data.variables[var]

    dates = []
    for val in data["day"].data:
        try:
            dates.append(datetime.datetime.strptime(str(val)[0:-2], "%Y%m%d").date())
        except ValueError:
            dates.append(datetime.datetime(int(str(val)[0:4]), int(str(val)[4:6]), 28))


    var_read = xr.DataArray(var_read, dims=["day", "lat2d", "lon2d"])
    var_read["day"] = dates
    return var_read



In [4]:

def get_swei(ds):
    swe = ds['snow']
    ntime = swe.shape[0]
    nlat = swe.shape[2]
    nlon = swe.shape[1]
    nyr = int(ntime / 12)
    nd = nlat * nlon
    nm = 12

    # Compute the 3-month cumulative sum for each pixel
    ds_cumsum = ds.rolling(time=3, min_periods=3).sum()

    years = np.unique(ds.time.dt.year)
    months = np.unique(ds.time.dt.month)

    # Reshape the data back into a 4D array of (year, month, lat, lon)
    ds_new = xr.DataArray(
        ds_cumsum['snow'].data.reshape((-1, 12, ds.sizes['lat'], ds.sizes['lon'])),
        dims=('year', 'month', 'lat', 'lon'),
        coords={'year': years, 'month': months, 'lat': ds['lat'], 'lon': ds['lon']}
    )
    categ = np.zeros((nyr, nm, nlon, nlat))
    nsample = nyr
    sweix = droughtindx(nsample)  # all values for each pixel.
    sweix = np.array(sweix)

    aindx = np.argsort(ds_new.data, axis=0)

    # Create a broadcasting version of sweix
    sweix_broadcasted = sweix[:, np.newaxis, np.newaxis, np.newaxis]

    # Assign sorted sweix values to categ based on sorted indices (array, indices, values, axis)
    np.put_along_axis(categ, aindx,sweix_broadcasted, axis=0)

    # Create the new xarray Dataset
    ds_swei = xr.Dataset(
        {'swei':(('year','month','lat','lon'), categ)},
        coords={'year': years, 'month': months,'lat': ds['lat'], 'lon': ds['lon'], }
    )
    return ds_swei


## use this
def get_highest_month(ds,var = 'snow'):
    tmp = ds[var].data.reshape(-1,12,ds[var].shape[1],ds[var].shape[2])
    highest_month = np.argmax(tmp, axis=1)
    coords = {'year': np.unique(ds.time.dt.year), 'lat' : ds.lat, 'lon' : ds.lon}
    data = xr.DataArray(highest_month, dims=('year', 'lat','lon'), coords=coords)
    return data

def convert_year_month(ds,var):
    years = np.unique(ds.time.dt.year)
    months = np.unique(ds.time.dt.month)
    
    data = ds[var].data.reshape(-1,12,ds[var].shape[1],ds[var].shape[2])
    # We then create a new dataset with year, month, lat, lon coordinates
    ds_new = xr.Dataset(
        {
            var: (("year", "month", "lat", "lon"), data),
        },
        coords={
            "year": years,
            "month": months,
            "lat": ds.lat,
            "lon": ds.lon,
        },
    )
    ds_new.attrs = ds.attrs
    return ds_new
    
## use this
def collapse_to_highest_month(ds, var, snw_ds, convert = True):
    highest_month = get_highest_month(snw_ds)
    if convert:
        ds = convert_year_month(ds, var)
    highest_month = highest_month.drop('lat')
    highest_month = highest_month.drop('lon')
    result = ds.sel(month=highest_month, method='nearest').drop('month')
    return result

## usethis
def get_3m_sum(data):
    rolling_sum = data.rolling(time = 3, min_periods=3).sum()
    return rolling_sum

def get_sd_categ(swei, pr, tas, attrs):
    swei_cond = swei < -0.8  # snow drought
    pr_cond = pr < 0  # dry 
    tas_cond = tas > 0  # warm
    
    ds_new = xr.Dataset(
        {
            "swei_cond": (("year", "lat", "lon"), swei_cond.swei.values),
            "pr_cond": (("year", "lat", "lon"), pr_cond.prec.values),
            "tas_cond": (("year", "lat", "lon"), tas_cond.t2.values),
        },
        coords={
            "year": pr.year,
            "lat": pr.lat,
            "lon": pr.lon,
        },
    )
    ds_new.attrs = attrs

    return ds_new
from scipy.stats import norm
def droughtindx(nsample):
    indx = []
    for i in range(nsample):
        px = (i+1-0.44)/(nsample+0.12)
        indx.append(norm.ppf(px))
    return indx


def polarCentral_set_latlim(lat_lims, ax):
    ax.set_extent([-180, 180, lat_lims[0], lat_lims[1]], ccrs.PlateCarree())
    theta = np.linspace(0, 2*np.pi, 100)
    center, radius = [0.5, 0.5], 0.5
    verts = np.vstack([np.sin(theta), np.cos(theta)]).T
    circle = mpath.Path(verts * radius + center)
    ax.set_boundary(circle, transform=ax.transAxes)
def reshape_3d(ds, var):
    # Convert daily data to monthly data by taking the maximum value for each month
    ds_sorted = ds.sortby('time')
    ds_monthly = ds_sorted.resample(time='1M').max(dim='time')
    # Extract year and month from the time coordinate
    ds_monthly['year'] = ds_monthly['time.year']
    ds_monthly['month'] = ds_monthly['time.month']
    prepend_times = pd.date_range(start='1980-01-31', periods=8, freq='M')
    append_times = pd.date_range(end='2100-09-30', periods=4, freq='M')
    time_dataarray = ds_monthly.time
    prepend_dataarray = xr.DataArray(prepend_times, dims=['time'])
    append_dataarray = xr.DataArray(append_times, dims=['time'])
    time_to_use = xr.concat([prepend_dataarray, time_dataarray, append_dataarray], dim='time')

    month = np.unique(ds_monthly.month.data) #[9,10,11,12,1,  2,  3,  4,  5,  6,  7,  8]
    year = np.unique(ds_monthly.year.data)
    lat = ds_monthly.lat2d.data
    lon = ds_monthly.lon2d.data
    data =ds_monthly[var]
    # Number of layers to add
    n_front_layers = 8
    n_end_layers = 4
    pad_widths = [(n_front_layers, n_end_layers)] + [(0, 0)] * (data.ndim - 1)
    arr_padded = np.pad(data, pad_widths, mode='constant', constant_values=np.nan)

    #reshaped_data = arr_padded.reshape((-1,12,arr_padded.shape[1], arr_padded.shape[2]))
    # Create xarray dataset
    reshaped = xr.Dataset(
        {
            var: (['time', 'lat', 'lon'], arr_padded),
        },
        coords={
            'time': time_to_use,
            'lat': lat,
            'lon': lon,
        },
    )
    return reshaped

In [15]:
import util

domains = ['d01']
variables = ['snow', 't2', 'prec']
gcms = ['cesm2','mpi-esm1-2-lr','cnrm-esm2-1',
        'ec-earth3-veg','fgoals-g3','ukesm1-0-ll',
        'canesm5','access-cm2','ec-earth3',]

## TO DO : get data transferred for d01 for these: access, earth3 [no veg]
## and run
variants = ['r11i1p1f1','r7i1p1f1','r1i1p1f2',
            'r1i1p1f1','r1i1p1f1','r2i1p1f2',
            'r1i1p2f1','r5i1p1f1','r1i1p1f1',]
      

calendar = ['365_day','proleptic_gregorian','proleptic_gregorian',
            'proleptic_gregorian','365_day','360_day',
             '365_day','proleptic_gregorian', 'proleptic_gregorian',]

ssps = ['ssp370','ssp370','ssp370','ssp370',
        'ssp370','ssp370','ssp370','ssp370',
        'ssp370']
'''  

domains = ['d03', 'd04']
variables = ['snow', 't2', 'prec']
gcms = ['ec-earth3-veg',]
variants = ['r1i1p1f1',]
calendar = ['proleptic_gregorian']
ssps = ['ssp370']
'''

basedir = '/global/cfs/cdirs/m4099/fate-of-snotel/wrfdata/'
model = None
n = 0
for domain in domains:
    for idx, gcm in enumerate(gcms[n:]):
        variant = variants[idx+n]
        mod_historical = gcm +'_'+ variant + '_historical_bc'
        mod_future = gcm +'_' + variant+ '_ssp370_bc'
        datadir = basedir + mod_historical + '/postprocess/' + domain + '/'
        for var in variables:
            # historical
            gcm = mod_historical
            datadir = basedir + mod_historical + '/postprocess/' + domain + '/'
            date_start_pd, date_end_pd = [1980, 1, 1], [2013, 12, 31]  # 30 years, historical
            exp = "hist"
            var_wrf = wrfread(datadir, gcm, exp, variant, domain, var)
            var_wrf = util.screen_times_wrf(var_wrf, date_start_pd, date_end_pd)

            # future 
            date_start_pd, date_end_pd = [2014, 1, 1], [2100, 12, 31]
            gcm = mod_future
            exp = "ssp370"
            datadir = basedir + mod_future + '/postprocess/' + domain + '/'

            var_wrf_ssp370 = wrfread(datadir, gcm, exp, variant, domain, var)
            var_wrf_ssp370 = util.screen_times_wrf(var_wrf_ssp370, date_start_pd, date_end_pd)
            wrfdata = [var_wrf, var_wrf_ssp370]
            ds_concat = xr.concat(wrfdata, dim = 'day').to_dataset(name = var)
            ds_concat = ds_concat.rename({'day':'time'})
            ds_concat['time'] = ds_concat['time'].astype('datetime64')
            try:
                ds_concat.to_netcdf(f'{savepath}{var}_{gcms[idx+n]}_{domain}_bc.nc')
                print('saved' + f'{savepath}{var}_{gcms[idx+n]}_{domain}_bc.nc')
            except:
                print('already loaded')

saved/global/cfs/cdirs/m4099/cowherd/snow_c_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/t2_c_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/prec_c_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/snow_p_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/t2_p_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/prec_p_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/snow_r_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/t2_r_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/prec_r_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/snow_e_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/t2_e_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/prec_e_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/snow_l_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/t2_l_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/prec_l_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/snow_1_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/t2_1_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/prec_1_d01_bc.nc
saved/global/cfs/cdirs/m4099/cowherd/sno

FileNotFoundError: [Errno 2] No such file or directory: '/global/cfs/cdirs/m4099/fate-of-snotel/wrfdata/access-cm2_r5i1p1f1_ssp370_bc/postprocess/d01/'

In [16]:

domains = ['d01']
variables = ['snow', 't2', 'prec']
gcms = ['cesm2','mpi-esm1-2-lr','cnrm-esm2-1',
        'ec-earth3-veg','fgoals-g3','ukesm1-0-ll',
        'canesm5']#,'access-cm2','ec-earth3',]


variants = ['r11i1p1f1','r7i1p1f1','r1i1p1f2',
            'r1i1p1f1','r1i1p1f1','r2i1p1f2',
            'r1i1p2f1','r5i1p1f1','r1i1p1f1',]
      

calendar = ['365_day','proleptic_gregorian','proleptic_gregorian',
            'proleptic_gregorian','365_day','360_day',
             '365_day','proleptic_gregorian', 'proleptic_gregorian',]

ssps = ['ssp370','ssp370','ssp370','ssp370',
        'ssp370','ssp370','ssp370','ssp370',
        'ssp370']

'''
domains = ['d03', 'd04']
variables = ['snow', 't2', 'prec']
gcms = ['ec-earth3-veg',]
variants = ['r1i1p1f1',]
calendar = ['proleptic_gregorian']
ssps = ['ssp370']
'''
gc.collect()
datasets = {}

for domain in domains:
    for gcm in gcms:
        filenames = [f"{savepath}{var}_{gcm}_{domain}_bc.nc" for var in variables]
        for filename in filenames:
            datasets[filename.split('/')[-1].split('.')[0]] = xr.open_dataset(filename)
            
datasets_3m_sum = {}
swei_datasets = {}
for name, ds in datasets.items():
    var = name.split('_')[0]
    ds = reshape_3d(ds,var)
    datasets_3m_sum[name] = get_3m_sum(ds)
    if var == 'snow':
        swei_datasets[name] = get_swei(ds)

datasets_3m_sum_maxsnw = {}
for name, ds in datasets_3m_sum.items():
    var = name.split('_')[0]
    snw_ds = datasets_3m_sum[f'snow_{name.split(f"{var}_")[-1]}']
    datasets_3m_sum_maxsnw[name] = collapse_to_highest_month(ds, var, snw_ds, True)


swei_datasets_maxsnw = {}
for name, ds in swei_datasets.items():
    var = name.split('_')[0]
    snw_long = datasets_3m_sum[f'snow_{name.split(f"{var}_")[-1]}']
    tmp = collapse_to_highest_month(ds, 'swei', snw_long, False)
    tmp.to_netcdf(f'{savepath}swei_max_{name}.nc')
    swei_datasets_maxsnw[f'{name}'] = tmp
categs = {}
for name, swei in swei_datasets_maxsnw.items():
    var = name.split('_')[0]
    
    pr_long = datasets_3m_sum_maxsnw[f'prec_{name.split(f"{var}_")[-1]}']
    pr_anom = pr_long - pr_long.sel(year=slice(pr_long.year[0], pr_long.year[49])).mean(dim='year') ## val minus average

    tas_long = datasets_3m_sum_maxsnw[f't2_{name.split(f"{var}_")[-1]}']
    tas_anom = tas_long - tas_long.sel(year=slice(tas_long.year[0], tas_long.year[49])).mean(dim='year') ## val minus average

    tmp = get_sd_categ(swei, pr_anom ,tas_anom, pr_long.attrs)
    try:
        tmp.to_netcdf(f'{savepath}categs_{name}.nc')
        print('saved ' + f'{savepath}categs_{name}.nc')
    except:
        continue
    categs[name] = tmp

saved /global/cfs/cdirs/m4099/cowherd/categs_snow_cesm2_d01_bc.nc
saved /global/cfs/cdirs/m4099/cowherd/categs_snow_mpi-esm1-2-lr_d01_bc.nc
saved /global/cfs/cdirs/m4099/cowherd/categs_snow_cnrm-esm2-1_d01_bc.nc
saved /global/cfs/cdirs/m4099/cowherd/categs_snow_ec-earth3-veg_d01_bc.nc
saved /global/cfs/cdirs/m4099/cowherd/categs_snow_fgoals-g3_d01_bc.nc
saved /global/cfs/cdirs/m4099/cowherd/categs_snow_ukesm1-0-ll_d01_bc.nc
saved /global/cfs/cdirs/m4099/cowherd/categs_snow_canesm5_d01_bc.nc


In [43]:
def reshape_data(ds, var):
    # Convert daily data to monthly data by taking the maximum value for each month
    ds_sorted = ds.sortby('time')
    ds_monthly = ds_sorted.resample(time='1M').max(dim='time')
    # Extract year and month from the time coordinate
    ds_monthly['year'] = ds_monthly['time.year']
    ds_monthly['month'] = ds_monthly['time.month']
    ds_monthly = ds_monthly.drop('time')
    month = np.unique(ds_monthly.month.data) #[9,10,11,12,1,  2,  3,  4,  5,  6,  7,  8]
    year = np.unique(ds_monthly.year.data)
    lat = ds_monthly.lat2d.data
    lon = ds_monthly.lon2d.data
    data =ds_monthly[var]
    # Number of layers to add
    n_front_layers = 8
    n_end_layers = 4
    pad_widths = [(n_front_layers, n_end_layers)] + [(0, 0)] * (data.ndim - 1)
    arr_padded = np.pad(data, pad_widths, mode='constant', constant_values=np.nan)

    reshaped_data = arr_padded.reshape((-1,12,arr_padded.shape[1], arr_padded.shape[2]))
    # Create xarray dataset
    reshaped = xr.Dataset(
        {
            var: (['year','month', 'lat', 'lon'], reshaped_data),
        },
        coords={
            'year': year,
            'month': month,
            'lat': lat,
            'lon': lon,
        },
    )
    return reshaped