In [2]:
#import clik
#import camb
#import ipyparallel as ipp

import autograd.numpy as np
from autograd import grad, jacobian, hessian, elementwise_grad
from autograd.scipy.linalg import sqrtm
from autograd.scipy.stats import norm
from scipy.optimize import minimize, least_squares
from matplotlib import pyplot as plt
from ipywidgets import interact
from scipy.stats import mvn


def is_psd(xx, tol=0.0):
    """Return True when the square matrix ``xx`` is positive semi-definite.

    Only the symmetric part ``(xx + xx^T) / 2`` affects the quadratic form
    ``v @ xx @ v`` (the way precision matrices are used by ``mvn_logpdf``), so
    the test is done on that part with ``eigvalsh``. Unlike the general
    ``eigvals`` solver, it returns real eigenvalues and is the accurate routine
    for symmetric input.

    Parameters
    ----------
    xx : array_like, shape (..., d, d)
        Matrix (or stack of matrices) to test.
    tol : float, optional
        Non-negative slack: eigenvalues ``>= -tol`` are accepted, which lets a
        caller absorb round-off such as ``-1e-17``. The default 0 is the exact
        check.

    Returns
    -------
    numpy.bool_
        True if every eigenvalue of the symmetric part is ``>= -tol``.
    """
    xx = np.asarray(xx)
    sym = 0.5 * (xx + np.swapaxes(xx, -1, -2))
    return np.all(np.linalg.eigvalsh(sym) >= -tol)


def elementwise_hessian2(f, ii=0, *args, **kwargs):
    """Stack second derivatives of ``f`` w.r.t. its ``ii``-th positional argument.

    For each column index ``j`` of that argument, the first derivative
    ``elementwise_grad(f, ii)`` is transposed and its ``j``-th row is taken.
    That slice is differentiated again with ``elementwise_grad(..., ii)``, and
    the per-``j`` results are stacked along ``axis=1``.

    NOTE(review): ``elementwise_grad`` differentiates the *sum* of its
    function's output. The result is therefore a per-row Hessian only if each
    output of ``f`` depends on a single row of argument ``ii``. Presumably
    ``args[ii]`` is an ``(n, d)`` batch of points and the result is
    ``(n, d, d)``; TODO confirm against callers.

    Parameters
    ----------
    f : callable
        Autograd-traceable function to differentiate.
    ii : int
        Index of the positional argument to differentiate with respect to.
    *args, **kwargs
        Point at which the derivatives are evaluated (forwarded to ``f``).
    """
    # ``(*args, kwargs.items())[ii]`` equals ``args[ii]`` whenever ii < len(args);
    # its size along axis 1 is the number of columns ``j`` iterated below.
    # The lambda's late-bound ``j`` is safe: each lambda is differentiated and
    # evaluated within its own comprehension iteration.
    return np.stack([elementwise_grad(lambda *aa, **kk: np.transpose(elementwise_grad(f, ii)(*aa, **kk))[j], ii)(*args, **kwargs) 
                     for j in range(np.shape((*args, kwargs.items())[ii])[1])], axis=1)


def elementwise_hessian(f, ii=0):
    """Curried form of ``elementwise_hessian2``.

    Returns a function that accepts the same ``*args, **kwargs`` as ``f`` and
    evaluates ``elementwise_hessian2(f, ii, *args, **kwargs)``.
    """
    def hessian_of_f(*args, **kwargs):
        return elementwise_hessian2(f, ii, *args, **kwargs)

    return hessian_of_f


def elementwise_jacobian2(f, ii=0, *args, **kwargs):
    """Stack per-output-column gradients of ``f`` w.r.t. argument ``ii``.

    ``f(*args, **kwargs)`` is evaluated first to read the size of its last
    axis. For every output column ``j``, ``elementwise_grad`` is applied to
    ``f(...)[:, j]`` and the resulting gradients are stacked along ``axis=1``.

    NOTE(review): the slice ``[:, j]`` indexes axis 1 while the loop counts
    ``shape[-1]``. These agree only for 2-D outputs, so presumably ``f``
    returns ``(n, m)``; TODO confirm.
    """
    n_out = f(*args, **kwargs).shape[-1]
    columns = []
    for j in range(n_out):
        # ``column_j`` is differentiated and evaluated in this same iteration,
        # so it always sees the current ``j``.
        def column_j(*aa, **kk):
            return f(*aa, **kk)[:, j]

        columns.append(elementwise_grad(column_j, ii)(*args, **kwargs))
    return np.stack(columns, axis=1)


def elementwise_jacobian(f, ii=0):
    """Curried form of ``elementwise_jacobian2``.

    Returns a function that accepts the same ``*args, **kwargs`` as ``f`` and
    evaluates ``elementwise_jacobian2(f, ii, *args, **kwargs)``.
    """
    def jacobian_of_f(*args, **kwargs):
        return elementwise_jacobian2(f, ii, *args, **kwargs)

    return jacobian_of_f


def mvn_logpdf(x, mean, hess, normalized=False):
    """Log-density of a multivariate normal parameterised by its precision matrix.

    Parameters
    ----------
    x : array, shape (..., d)
        Evaluation point(s); the quadratic form is reduced over the last axis.
    mean : array with ``d`` elements
        Location vector (``d = mean.size``).
    hess : array, shape (d, d)
        Precision (inverse covariance) matrix. It must be positive definite when
        ``normalized`` is True because its Cholesky factor is taken.
    normalized : bool
        If False, return only ``-0.5 * (x - mean)^T hess (x - mean)``. If True,
        also add ``0.5 * log det(hess) - d/2 * log(2 pi)``.

    Returns
    -------
    array, shape (...)
    """
    dim = mean.size
    delta = x - mean
    result = -0.5 * np.sum((delta @ hess) * delta, axis=-1)
    if normalized:
        lh = np.linalg.cholesky(hess)
        # For hess = L @ L.T, 0.5 * log det(hess) == sum(log(diag(L))). Summing
        # the logs avoids the overflow/underflow of log(prod(diag(L))) when d is
        # a few hundred.
        result = result + np.sum(np.log(np.diag(lh))) - dim / 2 * np.log(2 * np.pi)
    return result


def mvn_pdf(x, mean, hess, normalized=False):
    """Density of a multivariate normal parameterised by its precision matrix.

    Parameters
    ----------
    x : array, shape (..., d)
        Evaluation point(s); the quadratic form is reduced over the last axis.
    mean : array with ``d`` elements
        Location vector (``d = mean.size``).
    hess : array, shape (d, d)
        Precision (inverse covariance) matrix. It must be positive definite when
        ``normalized`` is True because its Cholesky factor is taken.
    normalized : bool
        If False, return ``exp(-0.5 * (x - mean)^T hess (x - mean))``. If True,
        multiply by ``sqrt(det(hess)) / (2 pi)^(d/2)``.

    Returns
    -------
    array, shape (...)
    """
    dim = mean.size
    delta = x - mean
    result = np.exp(-0.5 * np.sum((delta @ hess) * delta, axis=-1))
    if normalized:
        lh = np.linalg.cholesky(hess)
        # Build the normalising constant in log space. prod(diag(L)) and
        # (2 pi)^(d/2) each overflow for d of a few hundred, giving inf/inf = nan.
        log_norm = np.sum(np.log(np.diag(lh))) - dim / 2 * np.log(2 * np.pi)
        result = result * np.exp(log_norm)
    return result


def norm_logpdf(x, mean, hess, normalized=False):
    """Log-density of a 1-d normal parameterised by its precision ``hess``.

    Returns ``-0.5 * hess * (x - mean)**2``. When ``normalized`` is True, the
    log normalising constant ``0.5 * log(hess) - 0.5 * log(2 pi)`` is added.
    """
    delta = x - mean
    log_density = -0.5 * delta * hess * delta
    if normalized:
        # log of sqrt(hess / (2 pi))
        log_density += 0.5 * np.log(hess) - 0.5 * np.log(2 * np.pi)
    return log_density


def norm_pdf(x, mean, hess, normalized=False):
    """Density of a 1-d normal parameterised by its precision ``hess``.

    Returns ``exp(-0.5 * hess * (x - mean)**2)``. When ``normalized`` is True,
    the value is multiplied by ``sqrt(hess / (2 pi))`` so that it integrates
    to one.
    """
    delta = x - mean
    density = np.exp(-0.5 * delta * hess * delta)
    if normalized:
        density *= (hess / (2 * np.pi))**0.5
    return density


def _yj(xx, lamb):
    
    if lamb == 0:
        pos = np.log(np.abs(xx) + 1)
        neg = -((-xx + 1)**2 - 1) / 2
        return np.where(xx>=0, pos, neg)
    
    elif lamb == 2:
        pos = ((xx + 1)**2 - 1) / 2
        neg = -np.log(np.abs(xx) + 1)
        return np.where(xx>=0, pos, neg)
    
    elif 0 <= lamb <= 2:
        pos = (np.abs(xx + 1)**lamb - 1) / lamb
        neg = -((np.abs(xx) + 1)**(2 - lamb) - 1) / (2 - lamb)
        return np.where(xx>=0, pos, neg)
    
    else:
        return ValueError
    

def yj(xx, eps):
    """Composite Yeo-Johnson transform indexed by a real parameter ``eps``.

    * ``-1 <= eps <= 1``: apply ``_yj`` four times with ``lamb = eps + 1``
      (``eps == 0`` is therefore the identity).
    * ``eps > 1``: apply ``yj(., eps - 1)`` first, then ``_yj(., 2)``.
    * ``eps < -1``: apply ``yj(., eps + 1)`` first, then ``_yj(., 0)``.

    Raises
    ------
    ValueError
        If ``eps`` is NaN, so that no branch applies.
    """
    if eps > 1:
        return _yj(yj(xx, eps - 1), 2)
    if eps < -1:
        return _yj(yj(xx, eps + 1), 0)
    if -1 <= eps <= 1:
        lamb = eps + 1
        out = xx
        for _ in range(4):
            out = _yj(out, lamb)
        return out
    # Raise (rather than return the exception class) so NaN fails loudly.
    raise ValueError("yj parameter eps must be a real number, got {!r}".format(eps))

        
def _jy(xx, lamb):
    
    if lamb == 0:
        pos = np.exp(xx) - 1
        neg = 1 - (1 + 2 * np.abs(xx))**(1 / 2)
        return np.where(xx>=0, pos, neg)
    
    elif lamb == 2:
        pos = np.abs(2 * xx + 1)**(1 / 2) - 1
        neg = 1 - np.exp(-xx)
        return np.where(xx>=0, pos, neg)
    
    elif 0 <= lamb <= 2:
        pos = np.abs(lamb * xx + 1)**(1 / lamb) - 1
        neg = 1 - (1 + (2 - lamb) * np.abs(xx))**(1 / (2 - lamb))
        return np.where(xx>=0, pos, neg)
    
    else:
        return ValueError
    

def jy(xx, eps):
    """Inverse of the composite Yeo-Johnson transform ``yj(., eps)``.

    ``yj`` applies its inner transform first and the outer ``_yj(., 2)`` or
    ``_yj(., 0)`` last. The inverse must therefore undo that outer step first:

    * ``eps > 1``: ``jy(y, eps) = jy(_jy(y, 2), eps - 1)``.
    * ``eps < -1``: ``jy(y, eps) = jy(_jy(y, 0), eps + 1)``.
    * ``-1 <= eps <= 1``: ``yj`` is four identical ``_yj(., eps + 1)`` steps,
      so ``_jy(., eps + 1)`` is applied four times.

    Raises
    ------
    ValueError
        If ``eps`` is NaN, so that no branch applies.
    """
    if eps > 1:
        # Peel off the outermost step before recursing. Applying _jy(., 2)
        # after the recursion (the previous order) is not an inverse, because
        # _yj steps with different lamb do not commute.
        return jy(_jy(xx, 2), eps - 1)
    if eps < -1:
        return jy(_jy(xx, 0), eps + 1)
    if -1 <= eps <= 1:
        lamb = eps + 1
        out = xx
        for _ in range(4):
            out = _jy(out, lamb)
        return out
    # Raise (rather than return the exception class) so NaN fails loudly.
    raise ValueError("jy parameter eps must be a real number, got {!r}".format(eps))
        

def sinhn(xx, eta):
    """Scaled sinh / arcsinh family indexed by ``eta``.

    * ``eta > 0``: ``sinh(eta * xx) / eta``
    * ``eta < 0``: ``arcsinh(eta * xx) / eta``
    * ``eta == 0``: ``xx`` unchanged (the same object is returned)

    ``sinhn(., -eta)`` undoes ``sinhn(., eta)``.

    Raises
    ------
    ValueError
        If ``eta`` compares neither greater than, less than nor equal to 0
        (e.g. NaN).
    """
    if eta == 0:
        return xx
    if eta > 0:
        return np.sinh(eta * xx) / eta
    if eta < 0:
        return np.arcsinh(eta * xx) / eta
    raise ValueError

    
def _to_gauss(zz, nl):
    """Forward transform of a single coordinate towards a Gaussian variable.

    ``nl`` unpacks into ``(eta, eps, beta)``. The input is divided by the
    positive scale ``exp(beta)`` and passed through ``yj(., eps)``, then
    ``sinhn(., eta)``. The result is multiplied back by ``exp(beta)``.
    """
    eta, eps, beta = nl
    scale = np.exp(beta)
    warped = yj(zz / scale, eps)
    return sinhn(warped, eta) * scale

    
def _from_gauss(yy, nl):
    """Inverse of ``_to_gauss`` for a single coordinate.

    ``nl`` is ``(eta, eps, beta)``, the same parameters ``_to_gauss`` unpacks.
    A 2-element ``(eta, eps)`` is still accepted and treated as ``beta = 0``
    (unit scale), which matches the earlier two-parameter behaviour.

    Because ``_to_gauss(z) = exp(beta) * sinhn(yj(z / exp(beta), eps), eta)``,
    the inverse is ``exp(beta) * jy(sinhn(y / exp(beta), -eta), eps)``, where
    ``sinhn(., -eta)`` undoes ``sinhn(., eta)``.

    Raises
    ------
    ValueError
        If ``nl`` does not have 2 or 3 entries.
    """
    if len(nl) == 3:
        eta, eps, beta = nl
    elif len(nl) == 2:
        eta, eps = nl
        beta = 0.0
    else:
        raise ValueError("nl must have 2 or 3 entries, got {}".format(len(nl)))
    scale = np.exp(beta)
    return jy(sinhn(yy / scale, -eta), eps) * scale
        
    
def _to_gauss_g(zz, nl):
    """Derivative of ``_to_gauss`` with respect to ``zz``.

    A 0-d ``zz`` is differentiated with ``grad``. A 1-d ``zz`` is
    differentiated with ``elementwise_grad``; since ``_to_gauss`` is built only
    from element-wise operations, this gives the per-element derivative. Any
    other rank raises ``ValueError``.
    """
    if zz.ndim not in (0, 1):
        raise ValueError
    derivative = grad if zz.ndim == 0 else elementwise_grad
    return derivative(_to_gauss, 0)(zz, nl)
        
    
def to_gauss(zz, nl):
    """Apply ``_to_gauss`` independently to each coordinate of ``zz``.

    A 1-d ``zz`` is treated as samples of a single coordinate. Otherwise the
    rows of ``zz.T`` are the coordinates, so presumably ``zz`` is
    ``(n_samples, n_dim)``; TODO confirm. Row ``k`` of ``np.atleast_2d(nl)``
    holds the parameters for coordinate ``k``. The result has ``zz.shape``.
    """
    per_coord = zz.T if zz.ndim != 1 else zz.reshape((1, -1))
    params = np.atleast_2d(nl)
    mapped = [_to_gauss(per_coord[k], params[k]) for k in range(per_coord.shape[0])]
    return np.array(mapped).T.reshape(zz.shape)


def from_gauss(yy, nl):
    """Apply ``_from_gauss`` independently to each coordinate of ``yy``.

    Uses the same layout as ``to_gauss``. A 1-d ``yy`` is a single coordinate;
    otherwise the rows of ``yy.T`` are the coordinates. Row ``k`` of
    ``np.atleast_2d(nl)`` is forwarded to ``_from_gauss`` for coordinate ``k``.
    The result has ``yy.shape``.
    """
    per_coord = yy.T if yy.ndim != 1 else yy.reshape((1, -1))
    params = np.atleast_2d(nl)
    mapped = [_from_gauss(per_coord[k], params[k]) for k in range(per_coord.shape[0])]
    return np.array(mapped).T.reshape(yy.shape)


def to_gauss_g(zz, nl):
    """Per-coordinate derivative of ``to_gauss`` with respect to ``zz``.

    Uses the same layout as ``to_gauss``. Each coordinate is passed to
    ``_to_gauss_g`` with its row of ``np.atleast_2d(nl)``. The result has
    ``zz.shape`` and holds the diagonal of the (diagonal) Jacobian.
    """
    per_coord = zz.T if zz.ndim != 1 else zz.reshape((1, -1))
    params = np.atleast_2d(nl)
    slopes = [_to_gauss_g(per_coord[k], params[k]) for k in range(per_coord.shape[0])]
    return np.array(slopes).T.reshape(zz.shape)


def log_q(xx, mean, hess, nl, normalized=False):
    """Log-density of the transformed-Gaussian model q(x) via change of variables.

    With ``yy = to_gauss(xx - mean, nl)``, this computes
    ``log q(xx) = log N(yy; 0, hess) + sum log|d yy / d xx|``.
    The transform acts per coordinate, so its Jacobian is diagonal and is
    taken from ``to_gauss_g``.

    When ``mean.size <= 1``, ``norm_logpdf`` is used and the result is
    reshaped to ``xx.shape``. Otherwise ``mvn_logpdf`` is used and the
    log-Jacobian is summed over the last axis.
    """
    centered = xx - mean
    yy = to_gauss(centered, nl)
    multivariate = mean.size > 1
    origin = np.zeros_like(mean)
    log_abs_jac = np.log(np.abs(to_gauss_g(centered, nl)))
    if multivariate:
        return mvn_logpdf(yy, origin, hess, normalized) + np.sum(log_abs_jac, axis=-1)
    return (norm_logpdf(yy, origin, hess, normalized) + log_abs_jac).reshape(xx.shape)


ModuleNotFoundError: No module named 'autograd'

In [2]:
# start from a 1-d Gaussian N(mu, sig^2), transform it with (eta, eps, beta) and plot the transformed pdf. Here the rescaling factor is exp(beta) so it's always positive.
@interact(mu=(-3,3,0.01), sig=(0.01,5,0.01), eta=(-1,1,0.01), eps=(-1,1,0.01), beta=(-1,1,0.01))
def foo(mu, sig, eta, eps, beta):
    """Plot the normalised density q(x) for the slider-selected parameters.

    ``mu`` and ``sig`` define the base Gaussian, which has precision
    ``1 / sig**2``. ``(eta, eps, beta)`` are the transform parameters passed to
    ``log_q``.
    """
    xx = np.linspace(-10, 10, 1000)
    density = np.exp(log_q(xx, np.array(mu), 1 / sig**2, np.array([eta, eps, beta]), True))
    fig, ax = plt.subplots()
    ax.plot(xx, density)
    ax.set(xlabel="x", ylabel="q(x)", title="Transformed 1-d Gaussian density")
    # Render inside the interact output area on every slider update.
    plt.show()

interactive(children=(FloatSlider(value=0.0, description='mu', max=3.0, min=-3.0, step=0.01), FloatSlider(valu…