In [None]:
import os

import jax
from jax import numpy as jnp
import numpy as np
import distrax
import haiku as hk

from residual import TriangularResidual, ConstantScaling
from utils import get_config
from plotting import cart2pol, scatterplot_variables
from mixing_functions import build_moebius_transform, build_automorphism

In [None]:
# Root directory of the trained experiment (contains config/, data/ and
# checkpoints/). Defaults to the original cluster location; set the MODEL_ROOT
# environment variable to use a local copy instead of editing this path.
model_root = os.path.join(
    os.environ.get(
        'MODEL_ROOT',
        '/draco/u/vstimper/projects/ica-flows/experiments/triresflow/2d/0300/'),
    '')  # join with '' guarantees a trailing separator: later cells use `model_root + '...'`
config = get_config(model_root + 'config/config.yaml')

In [None]:
def _load_data(name, **np_load_kwargs):
    """Load ``model_root + 'data/<name>.npy'`` with :func:`numpy.load`."""
    return np.load(model_root + 'data/' + name + '.npy', **np_load_kwargs)


# Sources and observations for the train/test splits, converted once to JAX
# arrays (the original wrapped an already-converted array in a second copy).
S_train = jnp.asarray(_load_data('sources_train'))
S_test = jnp.asarray(_load_data('sources_test'))
X_train = jnp.asarray(_load_data('observation_train'))
X_test = jnp.asarray(_load_data('observation_test'))

# These two files hold pickled Python dicts (object arrays unwrapped with
# `.item()`), so they are loaded with plain NumPy: JAX arrays cannot hold
# Python objects. allow_pickle can execute arbitrary code on load, so only
# point model_root at trusted experiment directories.
mean_std = _load_data('observation_mean_std', allow_pickle=True).item()
mean_train, std_train = mean_std['mean'], mean_std['std']
moeb_params = _load_data('moebius_transform_params', allow_pickle=True).item()

In [None]:
# Moebius mixing rebuilt from the saved parameters (presumably the ground-truth
# mixing that produced the observations). A and a are loaded from disk, while
# alpha and epsilon are fixed here.
# NOTE(review): confirm alpha/epsilon match the values used for data generation.
alpha = 1.0
A = jnp.array(moeb_params['A'])
a = jnp.array(moeb_params['a'])
# Zero offset vector sized from `a` rather than a hard-coded length 2, so the
# cell keeps working for data of other dimensionality.
b = jnp.zeros(a.shape[-1])

mixing_moebius, mixing_moebius_inv = build_moebius_transform(alpha, A, a, b, epsilon=2)
# Batched version of the mixing, mapped over the leading (sample) axis.
mixing_batched = jax.vmap(mixing_moebius)

In [None]:
def _angle_colors(points):
    """Per-sample colour values: the second output of ``cart2pol`` on columns 0 and 1.

    Presumably the polar angle, so each sample keeps a recognisable colour
    across all the scatter plots in this notebook.
    """
    _radius, angle = cart2pol(points[:, 0], points[:, 1])
    return angle


colors_train = _angle_colors(S_train)
colors_test = _angle_colors(S_test)

scatterplot_variables(S_test, 'Sources (test)', colors=colors_test,
                      savefig=False, show=True)

In [None]:
# Trained haiku parameters from checkpoint 'model_100000.npy' (presumably the
# state after 100000 training steps), stored as a pickled dict and frozen into
# an immutable haiku mapping.
checkpoint_file = model_root + 'checkpoints/model_100000.npy'
raw_params = jnp.load(checkpoint_file, allow_pickle=True).item()
params = hk.data_structures.to_immutable_dict(raw_params)

In [None]:
# Setup model: a chain of triangular residual blocks plus a constant scaling
# by the training-data standard deviation.
n_layers = config['model']['flow_layers']
hidden_units = config['model']['nn_layers'] * [config['model']['nn_hidden_units']]


def _build_flow():
    """Construct the flow bijector shared by the forward and inverse maps.

    Must be called inside a haiku transform. The module names 'residual_<i>'
    must stay exactly as they are so they match the keys of the saved params.
    """
    # The final output width 2 is presumably the data dimensionality (2-D experiment).
    residual_blocks = [TriangularResidual(hidden_units + [2], name='residual_' + str(i))
                       for i in range(n_layers)]
    return distrax.Chain(residual_blocks + [ConstantScaling(std_train)])


def inv_map_fn(x):
    """Inverse pass of the flow: observations -> reconstructed sources."""
    return _build_flow().inverse(x)


def fw_map_fn(x):
    """Forward pass of the flow: sources -> observations."""
    return _build_flow().forward(x)


fw_map = hk.transform(fw_map_fn)
inv_map = hk.transform(inv_map_fn)

In [None]:
# Reconstruct latent sources from the test observations. No RNG key is passed,
# so the flow is assumed to be deterministic.
S_rec = inv_map.apply(params, None, X_test)
# Push every coordinate through the standard-normal CDF (elementwise, so any
# number of columns is handled) and centre on zero. Results lie in (-0.5, 0.5),
# presumably the range of the true sources plotted above.
S_rec_uni = jax.scipy.stats.norm.cdf(S_rec) - 0.5

scatterplot_variables(S_rec_uni, 'Reconstructed sources (test)',
                      colors=colors_test, savefig=False, show=True)

In [None]:
# 2x2 counter-clockwise rotation by 25 degrees; the measure-preserving
# automorphism is built from it.
rotation_angle = np.deg2rad(25)
cos_a, sin_a = np.cos(rotation_angle), np.sin(rotation_angle)
R = np.array([[cos_a, -sin_a],
              [sin_a, cos_a]])

measure_preserving, measure_preserving_inv = build_automorphism(R)
# Apply the automorphism independently to every sample (leading axis).
measure_preserving_batched = jax.vmap(measure_preserving)

In [None]:
# Map the true test sources through the automorphism. The +0.5 shift is
# applied first, presumably moving the sources from [-0.5, 0.5] onto the
# unit square.
S_test_shifted = S_test + 0.5
S_ = measure_preserving_batched(S_test_shifted)

scatterplot_variables(
    S_, 'Mapped sources (test)',
    colors=colors_test, savefig=False, show=True,
)

In [None]:
# Apply the same shift and automorphism to the reconstructed sources, for
# comparison with the mapped true sources plotted in the previous cell.
S_rec_uni_ = measure_preserving_batched(S_rec_uni + 0.5)

# Distinct title: this figure shows the *mapped* reconstruction, not the raw
# reconstruction already plotted as 'Reconstructed sources (test)'.
scatterplot_variables(S_rec_uni_, 'Mapped reconstructed sources (test)',
                      colors=colors_test, savefig=False, show=True)