In [1]:
import jax
from jax import numpy as np
from jax import scipy as sp
from jax import grad

import paragami


import time



In [2]:
k_approx = 30
dim = 4

In [3]:
# pattern for an array of covariances
covar_array_pattern = \
        paragami.pattern_containers.PatternArray(array_shape = (k_approx, ), \
                    base_pattern = paragami.PSDSymmetricMatrixPattern(size=dim))

# covar_array_pattern = \
#         paragami.PSDSymmetricMatrixPattern(size=dim)

In [4]:
# randomly set value
covar_array = covar_array_pattern.random()

use_free = True

covar_array_flattened = covar_array_pattern.flatten(covar_array, free = use_free)


In [5]:
# define function
def fun(covar_array): 
    return (covar_array**2).sum()

In [6]:
# flattened function
fun_flattened = paragami.FlattenFunctionInput(original_fun=fun, 
                                patterns = covar_array_pattern,
                                free = use_free,
                                argnums = 0) 

fun_flattened = jax.jit(fun_flattened)


# Function times

In [7]:
t0 = time.time()
fun_flattened(covar_array_flattened)
print(time.time() - t0)

0.48675060272216797


# gradient times

In [8]:
grad_fun_flattened = jax.jit(grad(fun_flattened))

In [9]:
# compile time (in seconds)
t0 = time.time()
grad_fun_flattened(covar_array_flattened)
print(time.time() - t0)

1.281162977218628


In [10]:
# fast after compiling

In [11]:
t0 = time.time()
grad_fun_flattened(covar_array_flattened)
print(time.time() - t0)

0.0004546642303466797


# Hessian time

In [12]:
hess_fun_flattened = jax.jit(jax.hessian(fun_flattened))

In [13]:
# compile time (in seconds)
t0 = time.time()
hess_fun_flattened(covar_array_flattened)
print(time.time() - t0)

3.269158124923706


In [14]:
# fast after compiling

In [15]:
t0 = time.time()
grad_fun_flattened(covar_array_flattened)
print(time.time() - t0)

0.0004525184631347656
