In [None]:
# NOTE: This is for when the notebook is converted to a python script
# NOTE: Must come before everything else
# Only install the stand-in when no real `get_ipython` exists (i.e. when running as a plain
#   python script).  Defining it unconditionally would shadow IPython's own `get_ipython` inside
#   the notebook, turning every `%magic` below into a silent no-op.
try:
    get_ipython
except NameError:
    def get_ipython():
        """Return a stand-in shell object whose `run_line_magic` accepts anything and does nothing.

        The returned object is a dynamically created class named 'Dummy'; `run_line_magic` is
        called directly on that class and always returns `None`.
        """
        return type('Dummy', (object,), dict(run_line_magic=lambda *args, **kwargs: None))

%reload_ext autoreload
%autoreload 2
from importlib import reload

import matplotlib.pyplot as plt
import matplotlib as mpl

# Global matplotlib style for every figure in this notebook, applied as a single update.
# NOTE(review): 'Times' is listed under `font.sans-serif` while the active family is 'serif';
#   confirm whether `font.serif` was intended.
mpl.rcParams.update({
    'font.family': 'serif',
    'font.sans-serif': ['Times'],
    'font.size': 12,
    'lines.solid_capstyle': 'round',
    'mathtext.fontset': 'cm',
    'figure.figsize': [8, 4],
    'axes.grid': True,
    'grid.alpha': 0.25,
})


In [None]:
# Import the package under test and invoke its private `_reload` hook.
# NOTE(review): presumably this refreshes already-imported kalepy submodules while developing;
#   confirm against the kalepy source.
import kalepy as kale
kale._reload()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import kalepy as kale

from kalepy.plot import nbshow

In [None]:
NUM = 1000
NOISE_FRAC = 0.4
# Fixed seed so that a fresh "Restart & Run All" regenerates the same data set, and the same
#   subsequent draws from the global `np.random` state used further down the notebook
SEED = 12345

# Per-dimension scales and the correlation matrix of the "signal" distribution
sigma = [1.0, 0.1]
corr = [
    [+1.0, +0.5],
    [+0.5, +1.0],
]

np.random.seed(SEED)

# Convert correlations to covariances: cov[i, j] = corr[i, j] * sigma[i] * sigma[j]
cov = np.zeros_like(corr)
for (ii, jj), cc in np.ndenumerate(corr):
    cov[ii, jj] = cc * sigma[ii] * sigma[jj]

# Draw NUM correlated, zero-mean samples; transpose so that `data` has shape (2, NUM)
data = np.random.multivariate_normal(np.zeros_like(sigma), cov, NUM).T
# Replace a random NOISE_FRAC of the samples with draws from an uncorrelated normal centred on
#   the sample mean, using `sigma` as the diagonal of its covariance matrix
sub = int(NOISE_FRAC*NUM)
idx = np.random.choice(NUM, sub, replace=False)
temp = np.random.multivariate_normal(
    np.mean(data, axis=-1), np.diag(sigma), size=sub)
# `data.T` is a view, so this assignment overwrites the selected samples of `data` in place
data.T[idx, :] = temp

plt.scatter(*data, alpha=0.1)
nbshow()

# Top-Level API

## `kalepy.density`

In [None]:
def _shape_of(obj):
    """Return `np.shape(obj)`, or `(len(obj),)` when `obj` is a ragged sequence.

    Recent NumPy versions raise `ValueError` when asked for the shape of a list of arrays with
    different lengths; older versions reported only the outer length for such input.
    """
    try:
        return np.shape(obj)
    except ValueError:
        return (len(obj),)


def check_grid_shapes(pnts, vals, data):
    """Check that grid-mode output (`pnts`, `vals`) is consistent with itself and with `data`.

    Parameters
    ----------
    pnts : sequence of two 1D array_like
        Points along each dimension; the two arrays may have different lengths.
    vals : array_like
        Density values; its shape must equal `(len(pnts[0]), len(pnts[1]))`.
    data : sequence of two 1D array_like
        Input data, one row per dimension.

    Raises
    ------
    ValueError
        If `pnts` does not have length 2, if the shape of `vals` does not match the lengths of
        `pnts`, or if the dimension with the smaller iqr/span ratio does not have more points.

    """
    # Default behavior for N>1 dimensions should be 'grid' densities
    pnts_shape = _shape_of(pnts)
    print("Input data has shape: {}, returned pnts: {}, vals: {}".format(
        np.shape(data), pnts_shape, np.shape(vals)))

    # Make sure the shapes of points and values match
    if pnts_shape[0] != 2:
        raise ValueError("`pnts` (shape: {}) should have length 2!".format(pnts_shape))

    shape = np.shape(vals)
    psh = tuple([len(pp) for pp in pnts])
    print("Number of points in each dim: {}".format(psh))
    # Shape of `vals` should match the combined dimensions of `pnts`
    if shape != psh:
        raise ValueError("Shape of density '{}' does not match points '{}'!".format(shape, psh))

    # If the relative-density of the input data are very different, then the length of points in
    #   each dimension should also be different
    frac_var = [kale.utils.iqrange(dd)/np.subtract(*kale.utils.minmax(dd)[::-1]) for dd in data]
    frac_var = np.array(frac_var)
    print("Data iqr/span : ", frac_var)
    # The dimension with the larger span/iqr ratio (i.e. `1/frac_var`) must have more points
    if np.sign(np.subtract(*1/frac_var)) != np.sign(np.subtract(*psh)):
        err = "fractional variations of data {} do not match points shape {}!".format(frac_var, psh)
        raise ValueError(err)


def check_pnts_span(pnts, data):
    """Verify, dimension by dimension, that the extrema of `pnts` enclose those of `data`.

    Prints the extrema of both for each dimension, and raises `ValueError` as soon as the data
    minimum lies below the points minimum or the data maximum lies above the points maximum.
    """
    template = "dim {}, edges: [{:.4e}, {:.4e}], data: [{:.4e}, {:.4e}]"
    for dim, (edges, values) in enumerate(zip(pnts, data)):
        edge_lims = kale.utils.minmax(edges)
        data_lims = kale.utils.minmax(values)
        print(template.format(dim, *edge_lims, *data_lims))

        # Nothing to report when the data stay inside the span of the points
        if not ((data_lims[0] < edge_lims[0]) or (data_lims[1] > edge_lims[1])):
            continue

        raise ValueError(
            "Span of edges ({}) is smaller than data ({})!".format(edge_lims, data_lims))


def check_in_match_out(vin, vout):
    """Check that the output points `vout` are identical to the input points `vin`.

    Parameters
    ----------
    vin, vout : sequence of array_like
        Input and output points, one entry per dimension.

    Raises
    ------
    ValueError
        If the number of dimensions differs, or if the points in any dimension differ in shape
        or in value.

    """
    # `zip` silently truncates, so a missing dimension must be caught explicitly
    if len(vin) != len(vout):
        err = "Input has {} dimensions but output has {}!".format(len(vin), len(vout))
        raise ValueError(err)

    # Output points should match input
    for ii, (pin, pout) in enumerate(zip(vin, vout)):
        pin_stats = kale.utils.stats(pin, stats=False)
        pout_stats = kale.utils.stats(pout, stats=False)
        print("dim {}".format(ii))
        print("\tInput  points: ", pin_stats)
        print("\tOutput points: ", pout_stats)

        # `array_equal` reports differing shapes as a mismatch, without broadcasting
        if not np.array_equal(pin, pout):
            err = "Input points: {} do not match output points: {}!".format(pin_stats, pout_stats)
            raise ValueError(err)

### No Args

In [None]:
# Call `kale.density` with only the data, leaving every other argument at its default
pnts, vals = kale.density(data)

In [None]:
# Validate the shapes of the default-argument results against each other and the input data
check_grid_shapes(pnts, vals, data)

In [None]:
# Verify that the returned points cover the full extent of the data in each dimension
check_pnts_span(pnts, data)

In [None]:
# Overlay contours of the returned density on a scatter plot of the input data
# NOTE(review): `kale.utils.meshgrid` presumably expands the per-dimension points into 2D
#   coordinate arrays matching the shape of `vals` — confirm indexing convention
xx, yy = kale.utils.meshgrid(*pnts)
plt.contour(xx, yy, vals, alpha=1.0, cmap='Reds')
plt.scatter(*data, alpha=0.2, s=5, color='b')
nbshow()

### Given `points`

In `grid` mode

In [None]:
# Use the 'auto' histogram bin edges of each dimension as the requested evaluation points
_pnts = [np.histogram(dd, bins='auto')[1] for dd in data]
# Request the density at those points with `grid=True`
pnts_grid, vals_grid = kale.density(data, points=_pnts, grid=True)

In [None]:
# The points returned in grid mode should be exactly the points that were passed in
check_in_match_out(_pnts, pnts_grid)

In [None]:
# Validate the shapes of the grid-mode results against each other and the input data
check_grid_shapes(pnts_grid, vals_grid, data)

In [None]:
# Verify that the grid-mode points cover the full extent of the data in each dimension
check_pnts_span(pnts_grid, data)

In `scatter` mode

In [None]:
UPSAMPLE = 10
BW_SCALE = 0.5

ndim, ndata = data.shape
num = int(UPSAMPLE*ndata)

# Resample the data points (with replacement) to seed the evaluation locations
draw = np.random.choice(ndata, size=num)

# Jitter covariance: the data covariance scaled by BW_SCALE * ndata^(-1/(ndim+4))
bw_factor = BW_SCALE * np.power(ndata, -1.0/(ndim+4))
bw = np.cov(data) * bw_factor

# Offset every resampled point by a zero-mean normal draw with covariance `bw`
jitter = np.random.multivariate_normal([0.0, 0.0], bw, size=num).T
_pnts = data.T[draw].T + jitter

plt.scatter(*_pnts, alpha=0.4, s=5)
plt.scatter(*data, alpha=0.3, s=40)
nbshow()

In [None]:
# Request the density at the jittered, resampled points with `grid=False`
pnts_scat, vals_scat = kale.density(data, points=_pnts, grid=False)

In [None]:
# The points returned in scatter mode should be exactly the points that were passed in
check_in_match_out(_pnts, pnts_scat)

In [None]:
# Plot the scatter-mode points, coloured by their density values
# NOTE(review): the colour scale is built from `vals` (the output of the earlier
#   default-argument call), not from `vals_scat` — confirm this is intentional
smap = kale.plot._get_smap(vals, cmap='plasma')
colors = smap.to_rgba(vals_scat)
plt.scatter(*pnts_scat, color=colors, alpha=0.2)
nbshow()

In [None]:
def compare_scatter_to_grid(pnts_scat, vals_scat, pnts_grid, vals_grid):
    """Compare densities at scattered points with the grid density at the matching grid index.

    Only scatter points strictly inside the extrema of the module-level `data` (first two
    dimensions) are used.  Each one is assigned a grid index per dimension with `np.digitize`,
    and the fractional difference between its value and the grid value at that index is
    computed.  Quantiles of those differences are printed and compared against fixed limits.

    Raises
    ------
    ValueError
        If fewer than half of the scatter points lie inside the data extrema, or if any
        difference quantile exceeds its limit.

    """
    # NOTE: `data` is read from the enclosing (notebook) namespace, it is not an argument
    bounds = [kale.utils.minmax(dd) for dd in data]
    inside = np.all(
        [(bounds[dim][0] < pnts_scat[dim]) & (pnts_scat[dim] < bounds[dim][-1])
         for dim in range(2)],
        axis=0)
    num_inside = np.count_nonzero(inside)
    if num_inside < vals_scat.size//2:
        raise ValueError(
            "Very few points found within data values ({}/{})!".format(num_inside, vals_scat.size))

    # Discard the scatter points outside of the data extrema
    pnts_scat = pnts_scat.T[inside].T
    vals_scat = vals_scat.T[inside].T

    # Grid index, along each dimension, associated with every remaining scatter point
    bins = [np.digitize(pp, gg) - 1 for gg, pp in zip(pnts_grid, pnts_scat)]
    grid_at_scat = vals_grid[bins[0], bins[1]]

    # Fractional difference, relative to the smaller of the two values
    smaller = np.min([vals_scat, grid_at_scat], axis=0)
    diff = np.fabs(vals_scat - grid_at_scat) / smaller
    quants = kale.utils.quantiles(diff, sigmas=[-1, 0, 1])
    print("Difference quantiles between scatter and grid values:", quants)

    # One upper limit per requested quantile
    limits = [0.1, 0.2, 0.5]
    if np.any(quants > limits):
        raise ValueError("Match is unacceptably poor!")
    
# Run the scatter-vs-grid consistency check on the results of the two `points=` calls
compare_scatter_to_grid(pnts_scat, vals_scat, pnts_grid, vals_grid)