In [9]:
#| default_exp sharedUtilities
#| default_cls_lvl 3

# Shared utilities

In [1]:
#| export
import xarray as xr
from glob import glob
from tqdm import tqdm 
import numpy as np

import pandas as pd
import matplotlib.pyplot as plt

In [2]:
#| export
def preprocess(ds):
    """
    Trim a dataset to its central year and average it to daily values.

    The input files overlap in time: each one also carries one month before
    and one month after the year it actually represents. The year of interest
    is read from the time stamp at the middle index of the time axis, only
    that year is selected, and the selection is resampled to daily ("1D")
    means.

    Args:
        ds (xr.Dataset): netCDF4 data with a ``time`` coordinate

    Returns:
        xr.Dataset: daily-mean data restricted to the year of interest
    """
    # The middle time stamp always falls inside the year of interest, even
    # with the extra month on either side.
    middle_index = len(ds.time) // 2
    year_label = str(ds.time[middle_index].dt.year.data)

    single_year = ds.sel(time=year_label)
    daily = single_year.resample(time="1D")
    return daily.mean()


In [3]:
#| export
def read_netcdfs(files, dim, transform_func, transform_calendar=None, cftime = True):
    """
    Read multiple netCDF files and concatenate them along one dimension.

    Files that cannot be opened or processed are skipped: their path and the
    raised exception are printed and the remaining files are still combined.

    Args:
        files (str): glob pattern selecting the files to read; matches are
            processed in sorted order
        dim (str): dimension along which the datasets are concatenated. When
            ``transform_calendar`` is given, the ``calendar`` attribute is
            also written on the variable with this name.
        transform_func (callable or None): applied to every opened dataset
            before it is loaded into memory; skipped when None
        transform_calendar (str, optional): calendar name to force on
            ``ds[dim]``. When given, the file is opened without time decoding,
            the attribute is overwritten, and the dataset is then decoded with
            ``xr.decode_cf``.
        cftime (bool): forwarded as ``use_cftime`` to ``xr.open_dataset`` and
            ``xr.decode_cf``

    Returns:
        xr.Dataset: all successfully read datasets concatenated along ``dim``

    Raises:
        ValueError: if the pattern matches no files, or if none of the
            matched files could be read
    """
    # Decoding on open must be disabled when the calendar attribute is going
    # to be replaced first; decoding then happens explicitly afterwards.
    decode_on_open = transform_calendar is None

    def process_one_path(path):
        try:
            with xr.open_dataset(path, decode_times = decode_on_open, use_cftime = cftime) as ds:
                if transform_calendar is not None:
                    ds[dim].attrs['calendar'] = transform_calendar
                    ds = xr.decode_cf(ds, use_cftime = cftime)
                if transform_func is not None:
                    ds = transform_func(ds)
                # Pull the data into memory before the file handle is closed.
                ds.load()
                return ds
        except Exception as e:
            # Best effort: report the failing file and keep going.
            print(path)
            print(e)
            return None

    paths = sorted(glob(files))
    if not paths:
        # xr.concat on an empty list raises a ValueError that does not
        # mention the pattern; fail early with an actionable message.
        raise ValueError(f"No files match pattern: {files!r}")

    datasets = [process_one_path(p) for p in tqdm(paths)]
    datasets = [x for x in datasets if x is not None]
    if not datasets:
        raise ValueError(
            f"None of the {len(paths)} files matching {files!r} could be read"
        )
    return xr.concat(datasets, dim)

In [4]:
#| export
def transform_calendar(ds,
                       timedim="time",
                       calendarname="proleptic_gregorian"):
    """
    Set the ``calendar`` attribute of a dataset's time variable.

    The dataset is modified in place and also returned, so the function can
    be used as a transform callback.

    Args:
        ds (xr.Dataset): dataset whose time variable is updated
        timedim (str): name of the time variable in ``ds``
        calendarname (str): value written to the ``calendar`` attribute;
            defaults to the CF calendar name "proleptic_gregorian"

    Returns:
        xr.Dataset: the same ``ds`` object with the attribute updated
    """
    ds[timedim].attrs['calendar'] = calendarname
    return ds

In [5]:
#| export

class TaylorDiagram(object):
    """
    Taylor diagram.
    Plot model standard deviation and correlation to reference (data)
    sample in a single-quadrant polar plot, with r=stddev and
    theta=arccos(correlation).

    The statistics of every sample are kept in three dictionaries keyed by
    label (``model_std``, ``corrcoef`` and ``rms``); the reference data is
    stored under the key "reference". When ``ax`` is None nothing is drawn
    and only the statistics are collected.
    """

    def __init__(self, refdata, ax,  *args, **kwargs):
        """
        Store the reference data and, if an axes is given, draw the
        reference point.

        Args:
            refdata: reference sample; must provide ``std(ddof=1)`` and be
                accepted by ``np.corrcoef`` (e.g. a 1-D numpy array)
            ax: axes to draw on, or None to skip all plotting. The code
                passes the angle as x and the radius as y, so this is
                presumably a matplotlib polar axes — confirm with the caller.
            *args, **kwargs: forwarded to ``ax.plot`` for the reference
                marker. ``marker``, ``color``, ``ms`` and ``label`` fall back
                to 'o', "black", 10 and "reference" when not supplied.
        """
        # memorize reference data
        self.refdata = refdata
        self.ax = ax
        refstd = refdata.std(ddof=1)
        # Axis limits are tracked even without an axes so the attributes
        # always exist; they are widened by add_sample only when plotting.
        self.rlim = [0.0, 1.2*refstd]
        self.thetalim = [0.0, 0.5*np.pi]
        if ax is not None:
            # Add reference point and stddev contour
            kwargs.setdefault("marker", 'o')
            kwargs.setdefault("color", "black")
            kwargs.setdefault("ms", 10)
            kwargs.setdefault("label", "reference")
            ax.plot([0], refstd, *args, ls='', zorder=0, **kwargs)
            ax.grid(True, linestyle="--")

            ax.set_xlabel("St. dev.")
            ax.xaxis.set_label_coords(0.5, -0.1)
            ax.set_xlim(self.thetalim)
            ax.set_ylim(self.rlim)
        self.model_std = {"reference" : refstd}
        self.corrcoef = {"reference" : 1.0}
        self.rms = {"reference" : 0.0}

    def add_sample(self, model_data, *args, **kwargs):
        """
        Add sample to the Taylor
        diagram. *args* and *kwargs* are directly propagated to the
        `Figure.plot` command.

        The sample's standard deviation, its correlation with the reference
        and the resulting RMS difference are stored under ``kwargs["label"]``,
        or under "model<N>" when no label is given, where N is the number of
        entries stored so far (the reference included).
        """
        refstd = self.model_std["reference"]
        model_std = model_data.std(ddof=1)
        corrcoef = np.corrcoef(model_data, self.refdata)[0, 1]
        theta = np.arccos(corrcoef)
        # Law of cosines in the (stddev, arccos(correlation)) plane. The
        # radicand is clamped at zero so that round-off for samples that
        # nearly coincide with the reference cannot produce NaN.
        rms = np.sqrt(np.maximum(
            refstd**2 + model_std**2 - 2.0*refstd*model_std*np.cos(theta),
            0.0))
        label = kwargs.get("label", "model"+str(len(self.model_std.keys())))
        if self.ax is not None:
            self.ax.plot(theta, model_std,
                            *args, **kwargs)  # (theta, radius)
            # Grow the visible range so the new point is not cut off.
            if model_std > self.rlim[1]:
                self.rlim = [0, 1.2 * model_std]
                self.ax.set_ylim(*self.rlim)
            if theta > self.thetalim[1]:
                self.thetalim[1] = 1.2 * theta
                self.ax.set_xlim(*self.thetalim)
        self.model_std[label] = model_std
        self.corrcoef[label] = corrcoef
        self.rms[label] = rms

    def get_samples(self):
        """Return the ``(model_std, corrcoef, rms)`` dictionaries keyed by label."""
        return self.model_std, self.corrcoef, self.rms

    def finalize(self):
        """
        Draw the decorations that depend on the final axis limits: the
        "Correlation" label, the RMS contours, the reference stddev arc and
        the correlation tick labels. Does nothing when no axes was given.
        """
        if self.ax is None:
            return
        refstd = self.model_std["reference"]
        theta_span = self.thetalim[1]-self.thetalim[0]
        self.ax.text(0.5*theta_span, (1.0+0.03*theta_span**2)*self.rlim[1],"Correlation", rotation=0.5*theta_span*180.0/np.pi-90.0)

        # RMS difference to the reference over the whole visible region.
        rs, ts = np.meshgrid(np.linspace(*self.rlim), np.linspace(*self.thetalim))
        rms = np.sqrt(refstd**2 + rs**2 - 2.0*refstd*rs*np.cos(ts))
        self.ax.contourf(ts, rs, rms, 9, alpha=0.3, cmap="RdYlGn_r")
        contours = self.ax.contour(ts, rs, rms, 9, linestyles="--", linewidths=1, alpha=0.5, colors="grey")
        # Scientific notation for very large or very small reference stddev.
        if refstd > 100.0 or refstd < 0.1:
            fmt = '%3.2e'
        else:
            fmt = '%3.2f'
        self.ax.clabel(contours, inline=False, fmt=fmt, colors="black")

        # Arc at the reference standard deviation.
        t = np.linspace(*self.thetalim)
        r = np.zeros_like(t) + refstd
        self.ax.plot(t, r, 'k--', label='_')

        # Angular ticks are labelled with the correlation they correspond to.
        xticks = [1.0, 0.99, 0.95, 0.9, 0.8, 0.7, 0.6, 0.4, 0.2, 0.0]
        if self.thetalim[1] > 0.5*np.pi:
            # Negative correlations are only shown as far as the axis extends.
            for x in [-0.2, -0.4, -0.6, -0.8, -0.9, -0.95, -0.99, -1.0]:
                if np.arccos(x) > self.thetalim[1]:
                    break
                xticks += [x]
        self.ax.set_xticks(np.arccos(xticks))
        self.ax.set_xticklabels(xticks)