In [1]:
import xarray as xr
import numpy as np
from time import perf_counter
from pathlib import Path

import xskillscore as xss

In [2]:
def _weighted_corr(x, y, w, dim=None):
    # example explanation: https://stats.stackexchange.com/questions/221246/such-thing-as-a-weighted-correlation
    # NOTE: this function has an error.
    # It follows the correct math, but produces incorrect results. 
    # I think this must be something not behaving as I expect with
    # Xarray's `weighted` class. 
    xw = x.weighted(w)
    yw = y.weighted(y)
    # weighted mean:
    xm = xw.mean(dim=dim)
    ym = yw.mean(dim=dim)
    # weighted variance:
    xv = xw.var(dim=dim)
    yv = yw.var(dim=dim)
    # weighted covariance:
    devx = x - xm
    devy = y - ym
    devxy = devx * devy
    covxy = devxy.weighted(w).mean(dim=dim)
    denom = np.sqrt(xv * yv)
    return covxy / denom


In [3]:
def _weighted_corr_xss(x, y, w, dim=None):
    return xss.pearson_r(x, y, dim=dim, weights=w, skipna=True, keep_attrs=True)



In [4]:
baseloc = Path("/glade/scratch/brianpm/cam_diag_climo/files")
case1loc = baseloc / "b.e20.BHIST.f09_g17.20thC.297_05"
case2loc = baseloc / "b.e20.BHIST.f09_g16.20thC.125.02"
varlist = ["CLDLIQ", "LWCF", "PS"]

In [29]:
%%time
for v in varlist:
    ds1 = xr.open_dataset(list(case1loc.glob(f"*_{v}_*"))[0])
    ds2 = xr.open_dataset(list(case2loc.glob(f"*_{v}_*"))[0])
    x1 = ds1[v]
    x2 = ds2[v]
    assert x1.shape == x2.shape, f"Shapes don't match: {x1.shape} vs. {x2.shape}"
    print(f"Input shape of {v}: {x1.shape}")
    lat = x1.lat
    coslat = np.cos(np.radians(lat))
    starttime = perf_counter()
    correlation = _weighted_corr(x1, x2, coslat, dim=("lat","lon"))
    print(f"Correlation calculation took {perf_counter()-starttime} seconds.")
    print(f"Correlation output shape: {correlation.shape} --- Max correlation: {correlation.max().item()}, Min correlation: {correlation.min().item()}")
    # Xskillscore
    starttime = perf_counter()
    x1dims = x1.dims
    excludedims = list(set(x1dims).difference(["lat","lon"]))
    xwgt = coslat.broadcast_like(x1, exclude=excludedims)
    xcorr = _weighted_corr_xss(x1, x2, xwgt, dim=["lat","lon"])
    print(f"XCorrelation calculation took {perf_counter()-starttime} seconds.")
    print(f"XCorrelation output shape: {xcorr.shape} --- Max correlation: {xcorr.max().item()}, Min correlation: {xcorr.min().item()}")


    # print(correlation)


Input shape of CLDLIQ: (12, 32, 192, 288)
Correlation calculation took 2.6723395637236536 seconds.
Correlation output shape: (12, 32) --- Max correlation: inf, Min correlation: -inf
XCorrelation calculation took 2.476666236296296 seconds.
XCorrelation output shape: (12, 32) --- Max correlation: 0.9549578481704152, Min correlation: -0.010545818459770005
Input shape of LWCF: (12, 192, 288)
Correlation calculation took 0.07073515094816685 seconds.
Correlation output shape: (12,) --- Max correlation: 0.9127461135744436, Min correlation: 0.8170271443060588
XCorrelation calculation took 0.05532099027186632 seconds.
XCorrelation output shape: (12,) --- Max correlation: 0.9683857074238917, Min correlation: 0.9386151870414404
Input shape of PS: (12, 192, 288)
Correlation calculation took 0.06826743390411139 seconds.
Correlation output shape: (12,) --- Max correlation: 0.7957456042825333, Min correlation: 0.7850279916325452
XCorrelation calculation took 0.05228321626782417 seconds.
XCorrelation 

['time', 'lev']


In [21]:
excl

['lev', 'time']

In [12]:
s1 = set(["a", "c", "w"]) # required variables
s2 = set(["z", "x", "y", "a", "b", "c"]) # e.g. all variables

if s1.issubset(s2):
    print(f"You don't have all the data: missing {s1-s2}")
else:
    print("Proceed")

Proceed
