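Checking numerically, in 2D, that the "inverse aDM" — computed from the Jacobian of the unmixing $g$ evaluated at the observations $X = f(S)$ — agrees with the aDM computed from the Jacobian of the mixing $f$ evaluated at the sources $S$.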

In [14]:
import jax
from jax import numpy as np
from jax import vmap
from jax import jacfwd

from jax import random
SEED = 42  # fixed seed so the random draws in this notebook are reproducible
key = random.PRNGKey(SEED)

import numpy as onp

from metrics import aDM

In [15]:
# N:=Number of samples
N = 6000
# D:=Number of dimensions
D = 2

# Generate the samples
S = random.uniform(key, shape=(N, D), minval=0.0, maxval=1.0)

In [30]:
# Derive a fresh key for the mixing matrix instead of reusing the one that drew S.
# JAX keys should be derived with `random.split` / `random.fold_in`; integer
# arithmetic on the key (`key += 1`) only works on legacy raw uint32 keys and is
# not a supported way to obtain a new key.
# NOTE: like the original, re-running this cell advances `key` again.
key = random.fold_in(key, 1)

In [31]:
# Let's use the post nonlinear model
from mixing_functions import post_nonlinear_model

# D x D mixing matrix with i.i.d. standard-normal entries. Draw D*D values so the
# reshape is valid for any D (drawing D*2 values only fits a (D, D) matrix when D == 2).
A = random.multivariate_normal(key, mean=np.zeros(D * D), cov=np.eye(D * D)).reshape(D, D)

f, g = post_nonlinear_model(A, nonlinearity='cube')

In [32]:
# Map every source sample through the mixing f (vectorised over the sample axis)
# to obtain the observations X.
f_batched = jax.vmap(f)
X = f_batched(S)

In [33]:
# Per-sample Jacobians of f and of g (g is presumably the inverse of f — see the
# post_nonlinear_model helper), then vectorised over the leading sample axis.
Jf = jacfwd(f)
Jg = jacfwd(g)

Jf_batched = vmap(Jf)
Jg_batched = vmap(Jg)

In [34]:
# Reference value: aDM of the mixing f, from its batched Jacobian evaluated at the sources S.
aDM(Jf_batched, S)

DeviceArray(2.8009524, dtype=float32)

In [35]:
def aDM_inverse(Jg, x):
    '''
    Inverse anti-Darmois metric (aDM), computed from the Jacobian of the unmixing g.

    For every observation x_n it evaluates
        sum_i log ||row i of Jg(x_n)||  -  log |det Jg(x_n)|
    and returns the average of this quantity over the N observations.

    Input:
    Jg: batched Jacobian (function)
        Applied to x, which has shape (N,D), returns Jacobians with shape (N,D,D)

    x: a collection of samples of observations (=mixed sources), which has shape (N,D)

    Output:

    aDM_inverse metric: a scalar
    '''
    # Number of observations, used to average over the batch
    N = x.shape[0]

    # Jacobian of g at every observation: shape (N, D, D)
    Jacg = Jg(x)

    # Euclidean norm of each row of the Jacobian (the gradient of each output
    # component of g): shape (N, D)
    grad_norms = np.linalg.norm(Jacg, axis=2)
    log_grad_norms = np.log(grad_norms)

    # Sum of the log row norms for each observation: shape (N,)
    sum_log_norms = np.sum(log_grad_norms, axis=-1)

    # log|det Jg| for each observation. slogdet works on stacked square matrices
    # of shape (..., D, D) and avoids computing log(abs(det(.))) directly.
    jac_log_dets = np.linalg.slogdet(Jacg)[1]

    # Average over the N observations
    return np.sum(sum_log_norms - jac_log_dets) / N

In [36]:
# The check this notebook is about: the inverse aDM (from Jg at the observations X)
# should agree with the aDM of f (from Jf at the sources S). Assert it explicitly
# rather than comparing the two printed values by eye; rtol allows for float32 round-off.
adm_inverse_value = aDM_inverse(Jg_batched, X)
onp.testing.assert_allclose(adm_inverse_value, aDM(Jf_batched, S), rtol=1e-4)
adm_inverse_value

DeviceArray(2.8009524, dtype=float32)