In [169]:
import itertools
import numpy as np
from numba import njit

from utilities import chol_params_to_lower_triangular_matrix
from utilities import cov_matrix_to_sdcorr_params

from jax import jacfwd
from kernel_transformations_jax import covariance_from_internal as covariance_from_internal_jax
from kernel_transformations_jax import sdcorr_from_internal as sdcorr_from_internal_jax
from kernel_transformations_jax import probability_from_internal as probability_from_internal_jax

from numpy.testing import assert_array_almost_equal

$$
\tilde{\text{vec}}
\left (
\begin{matrix}
(0,0)  &        &        &        \\
(1, 0) & (1,1)  &        &        \\
(2, 0) & (2, 1) & (2, 2) &        \\
(3, 0) & (3, 1) & (3, 2) & (3, 3) \\
\end{matrix}
\right ) =:
\tilde{\text{vec}}(L) = 
\big ( (0,0), (1,0), (1,1), (2,0), (2, 1), (2, 2), (3, 0), (3, 1), (3, 2), (3, 3) \big )^\top := v
$$

The following two functions allow us to move between these two representation of a (lower-triangular) matrix in a bijective fashion.

In [2]:
MAX_VALUE = 500

SEQUENCE_I = list(itertools.chain.from_iterable(itertools.repeat(i-1, i) for i in range(1, MAX_VALUE)))
SEQUENCE_J = list(itertools.chain.from_iterable(range(i-1) for i in range(1, MAX_VALUE)))

SEQUENCE_I = np.array(SEQUENCE_I)
SEQUENCE_J = np.array(SEQUENCE_J)

@njit
def _vectorized_index_to_matrix_index(index):
    return np.array([SEQUENCE_I[index], SEQUENCE_J[index]])

@njit
def _matrix_index_to_vectorized_index(i, j):
    return int(i * (i + 1) / 2) +  j

for k in range(100):
    assert _matrix_index_to_vectorized_index(*_vectorized_index_to_matrix_index(k)) == k

## Derivative of ``covariance_from_internal``

In [66]:
@njit
def derivative_covariance_from_internal(internal_values):
    dim = len(internal_values)
    
    deriv = np.zeros((dim, dim))
    
    for i in range(dim):
        outer_index = _vectorized_index_to_matrix_index(i)
        
        a = outer_index[0]
        b = outer_index[1]
        min_j = _matrix_index_to_vectorized_index(a, b)
        for j in range(min_j, dim):
            inner_index = _vectorized_index_to_matrix_index(j)
            n = inner_index[0]
            m = inner_index[1]
            
            deriv[i, j] = _derivative_covariance_from_internal_inner(n, m, a, b, internal_values)
                
    return deriv.T

In [67]:
@njit
def _derivative_covariance_from_internal_inner(n, m, a, b, internal_values):
    if b > n:
        deriv = 0
    elif a == n and n == m:
        deriv = 2 * internal_values[_matrix_index_to_vectorized_index(n, b)]
    elif a == n and n != m:
        deriv = internal_values[_matrix_index_to_vectorized_index(m, b)]
    elif a == m and n != m:
        deriv = internal_values[_matrix_index_to_vectorized_index(n, b)]
    else:
        deriv = 0
        
    return deriv

## Example / Testing

In [72]:
J = jacfwd(covariance_from_internal_jax)

In [76]:
def get_random_internal(dim, seed=0):
    np.random.seed(seed)
    chol = np.tril(np.random.randn(dim, dim))
    internal = chol[np.tril_indices(len(chol))]
    return internal

In [77]:
for dim in range(10, 100):
    internal = get_random_internal(dim)

    jax_deriv = J(internal)

    my_deriv = derivative_covariance_from_internal(internal)

    assert_array_almost_equal(jax_deriv, my_deriv)

## Timeit

In [78]:
internal = get_random_internal(100)

In [79]:
%timeit J(internal)

1.89 s ± 75.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [80]:
%timeit derivative_covariance_from_internal(internal)

993 ms ± 145 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Derivative of ``sdcorr_from_internal``

In [128]:
def number_of_triangular_elements_to_dimension(num):
    """Calculate the dimension of a square matrix from number of triangular elements.
    Args:
        num (int): The number of upper or lower triangular elements in the matrix.
    Examples:
        >>> number_of_triangular_elements_to_dimension(6)
        3
        >>> number_of_triangular_elements_to_dimension(10)
        4
    """
    return int(np.sqrt(8 * num + 1) / 2 - 0.5)

In [129]:
def chol_params_to_lower_triangular_matrix(params):
    dim = number_of_triangular_elements_to_dimension(len(params))
    mat = np.zeros((dim, dim))
    mat[np.tril_indices(dim)] = params
    return mat

In [156]:
def derivative_sdcorr_from_internal(internal_values):
    dim = len(internal_values)
    chol = chol_params_to_lower_triangular_matrix(internal_values)
    deriv = _derivative_sdcorr_from_internal_njit(dim, chol)
    return deriv

In [155]:
@njit
def _derivative_sdcorr_from_internal_njit(dim, chol):
    deriv = np.zeros((dim, dim))
    
    for i in range(dim):
        outer_index = _vectorized_index_to_matrix_index(i)
        
        a = outer_index[0]
        b = outer_index[1]
        min_j = _matrix_index_to_vectorized_index(a, b)
        for j in range(min_j, dim):
            inner_index = _vectorized_index_to_matrix_index(j)
            n = inner_index[0]
            m = inner_index[1]
            
            deriv[i, j] = _derivative_sdcorr_from_internal_inner(n, m, a, b, chol)
                
    return deriv.T

In [154]:
@njit
def _derivative_sdcorr_from_internal_inner(n, m, a, b, chol):
    ln_norm = np.sqrt(np.sum(chol[n] ** 2))
    lm_norm = np.sqrt(np.sum(chol[m] ** 2))
    lnlm_norm = ln_norm * lm_norm
    
    const = lnlm_norm ** 2
    
    part1 = 0
    if b <= n and a == n:
        part1 += lm_norm
    if b <= m and a == m:
        part1 += ln_norm
    part1 *= chol[a, b]
    
    part2 = 0
    if b <= m:
        if a == n:
            part2 += chol[m, b]
        if a == m:
            part2 += chol[n, b]
            
    factor = np.dot(chol[n], chol[m])
            
    deriv = (lnlm_norm * part2 - factor * part1) / const
    return deriv

## Example / Testing

In [136]:
J = jacfwd(sdcorr_from_internal_jax)

In [139]:
for dim in range(10, 11):
    internal = get_random_internal(dim)

    jax_deriv = J(internal)

    my_deriv = derivative_sdcorr_from_internal(internal)

    assert_array_almost_equal(jax_deriv, my_deriv)

AssertionError: 
Arrays are not almost equal to 6 decimals

Mismatch: 25.2%
Max absolute difference: 1.17721257
Max relative difference: nan
 x: array([[ 1.      ,  0.      ,  0.      , ...,  0.      ,  0.      ,
         0.      ],
       [ 0.      ,  0.098566,  0.99513 , ...,  0.      ,  0.      ,...
 y: array([[0.866247, 0.      , 0.      , ..., 0.      , 0.      , 0.      ],
       [0.994493, 0.031119, 0.      , ..., 0.      , 0.      , 0.      ],
       [0.      , 0.062239, 0.628365, ..., 0.      , 0.      , 0.      ],...

In [182]:
internal = np.array([1, 2, 4, 2, 0.5, 1])

In [167]:
J(internal).round(2)

DeviceArray([[ 1.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
             [ 0.  ,  0.45,  0.89,  0.  ,  0.  ,  0.  ],
             [ 0.  ,  0.  ,  0.  ,  0.87,  0.22,  0.44],
             [ 0.  ,  0.18, -0.09,  0.  ,  0.  ,  0.  ],
             [ 0.  ,  0.  ,  0.  ,  0.1 , -0.08, -0.17],
             [ 0.  ,  0.14, -0.07, -0.03,  0.33, -0.11]], dtype=float32)

In [168]:
derivative_sdcorr_from_internal(internal).round(2)

array([[ 0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.02,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  , -0.69, -1.39,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  , -0.33,  0.  ,  0.  ],
       [ 0.  , -0.07, -0.47, -0.32,  0.26,  0.  ],
       [ 0.  ,  0.  ,  0.  , -0.98, -0.25, -0.49]])

## Derivative of ``probability_from_internal``

In [281]:
@njit
def derivative_probability_from_internal(internal_values):
    dim = len(internal_values)
    
    sum_ = np.sum(internal_values)
    
    part1 = np.eye(dim) / sum_
    
    part2 = np.ones((dim, dim)) * (internal_values / (sum_ ** 2))
    
    deriv = part1 - part2
    return deriv.T

## Example / Testing

In [180]:
J = jacfwd(probability_from_internal_jax)

In [None]:
bad = []
for dim in range(10, 100):
    try:
        internal = get_random_internal(dim)
        jax_deriv = J(internal)
        my_deriv = derivative_probability_from_internal(internal)
        assert_array_almost_equal(jax_deriv, my_deriv)
    except: AssertionError:
        bad.append(dim)
        
print(bad)

[16, 17, 21, 34, 39, 70, 75, 90, 92]

## Timeit

In [266]:
internal = get_random_internal(100)

In [267]:
%timeit J(internal)

175 ms ± 59.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [275]:
%timeit derivative_probability_from_internal(internal)

566 ms ± 5.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
