In [1]:
import numpy as np
import model

import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
mpl_stylesheet.banskt_presentation(splinecolor = 'black')

from gradvi.normal_means import NormalMeansFromPosterior
from gradvi.tests import toy_priors

In [2]:
def fun1(x, sj2, s2, dj, jac = True):
    #h = np.sum(x**2)
    nm = NormalMeansFromPosterior(x, prior, sj2, scale = s2, d = dj, method = 'newton')
    Pb, dPdb, dPdw, dPds2 = nm.penalty_operator(jac = True)
    h = np.sum(Pb)
    if jac:
        dhdx = 2 * x
        return h, dPdb
    else:
        return h

    
def numerical_derivative(fn, a, sj2, s2, dj, eps = 1e-4):
    n = a.shape[0]
    dfdx = np.zeros(n)
    h = fn(a, sj2, s2, dj, jac = False)
    for i in range(n):
        a_eps     = a.copy()
        a_eps[i] += eps
        h_eps1    = fn(a_eps, sj2, s2, dj, jac = False)
        a_eps[i] -= 2 * eps
        h_eps2    = fn(a_eps, sj2, s2, dj, jac = False)
        dfdx[i]   = (h_eps1 - h_eps2) / (2 * eps)
    return dfdx

def fun2(x, A, sj2, s2, dj, jac = True):
    y = np.dot(A, x)
    h, dhdy = fun1(y, sj2, s2, dj, jac = True)
    if jac:
        dhdx = np.dot(A.T, dhdy)
        return h, dhdx
    return h

def numerical_derivative2(fn, a, D, sj2, s2, dj, eps = 1e-4):
    n = a.shape[0]
    dfdx = np.zeros(n)
    h = fn(a, D, sj2, s2, dj, jac = False)
    for i in range(n):
        a_eps     = a.copy()
        a_eps[i] += eps
        h_eps1    = fn(a_eps, D, sj2, s2, dj, jac = False)
        a_eps[i] -= 2 * eps
        h_eps2    = fn(a_eps, D, sj2, s2, dj, jac = False)
        dfdx[i]   = (h_eps1 - h_eps2) / (2 * eps)
    return dfdx

def center_and_scale_tfbasis(Z):
    '''
    Basis matrix Z is always 2D.
    b is the coefficient vector
    The first column of Z is all 1, hence it has zero standard deviation.
    '''
    dim  = Z.ndim
    std  = np.std(Z, axis = 0)
    skip = 0
    if std[0] == 0:
        # do not scale the first column
        print ("The first column has all equal values.")
        std[0] = 1.0
        skip = 1
    Znew = Z / std
    colmeans = np.mean(Znew[:, skip:], axis = 0)
    Znew[:, skip:] = Znew[:, skip:] - colmeans.reshape(1, -1)
    return Znew

In [3]:
n = 100
degree = 3
np.random.seed(100)
b = np.random.normal(1, 4, size = n)

# Trendfiltering matrices
M = model.trendfiltering_basis_matrix(n, degree)
T = model.trendfiltering_basis_matrix_inverse(n, degree)
Ms = center_and_scale_tfbasis(M)
Ts = np.linalg.inv(Ms)

# Prior
prior = toy_priors.get_ash_scaled(k = 4, sparsity = None, skbase = 10)

The first column has all equal values.


In [4]:
B = M.copy()
Binv = T.copy()

In [5]:
mdj = np.sum(np.square(B), axis = 0)
ms2 = 1.2 ** 2
msj2 = ms2 / mdj
print(mdj)

[1.00000000e+02 3.28350000e+05 4.75414170e+08 3.56820159e+11
 3.32218551e+11 3.09085357e+11 2.87346804e+11 2.66932109e+11
 2.47773397e+11 2.29805603e+11 2.12966389e+11 1.97196052e+11
 1.82437447e+11 1.68635897e+11 1.55739114e+11 1.43697125e+11
 1.32462185e+11 1.21988709e+11 1.12233196e+11 1.03154156e+11
 9.47120374e+10 8.68691638e+10 7.95896614e+10 7.28393958e+10
 6.65859075e+10 6.07983498e+10 5.54474273e+10 5.05053373e+10
 4.59457116e+10 4.17435607e+10 3.78752182e+10 3.43182886e+10
 3.10515946e+10 2.80551270e+10 2.53099958e+10 2.27983823e+10
 2.05034933e+10 1.84095157e+10 1.65015733e+10 1.47656844e+10
 1.31887209e+10 1.17583685e+10 1.04630884e+10 9.29207998e+09
 8.23524490e+09 7.28315216e+09 6.42700456e+09 5.65860616e+09
 4.97033094e+09 4.35509252e+09 3.80631504e+09 3.31790504e+09
 2.88422442e+09 2.50006442e+09 2.16062064e+09 1.86146903e+09
 1.59854280e+09 1.36811040e+09 1.16675430e+09 9.91350767e+08
 8.39050486e+08 7.07260086e+08 5.93624486e+08 4.96010086e+08
 4.12488765e+08 3.413226

In [6]:
np.allclose(np.dot(B, Binv), np.eye(n))

True

In [7]:
h, dhdb = fun1(b, msj2, ms2, mdj)
d1 = numerical_derivative(fun1, b, msj2, ms2, mdj)

g, dgdb = fun2(b, Binv, msj2, ms2, mdj)
d2 = numerical_derivative2(fun2, b, Binv, msj2, ms2, mdj)

In [8]:
np.testing.assert_allclose(dhdb, d1, atol = 1e-2, rtol = 1e-8)

In [9]:
np.testing.assert_allclose(dgdb, d2, atol = 1e-2, rtol = 1e-8)