In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'

import jax
import jax.numpy as jnp
import nifty8.re as jft
import matplotlib.pyplot as plt
from functools import partial


from jax import random

In [2]:
seed = 42
key = random.PRNGKey(seed)

### Prior models

In [3]:
# dims = (128, 128)
dims = (1000, 1000)

# correlated field zero mode GP offset and stddev
cf_zm = dict(offset_mean=0.0, offset_std=(1e-3, 1e-4))
# correlated field fluctuations (mostly don't need tuning)
cf_fl = dict(
    fluctuations=(1e0, 5e-2), # fluctuations: y-offset in power spectrum in fourier space (zero mode)
    loglogavgslope=(-3e0, 1e-2), # power-spectrum slope in log-log space in frequency domain (Fourier space)
    flexibility=(1e0, 5e-1), # deviation from simple power-law
    asperity=(5e-1, 5e-2), # small scale features in power-law
)

# put together in correlated field model
cfm = jft.CorrelatedFieldMaker("cf")
cfm.set_amplitude_total_offset(**cf_zm)
cfm.add_fluctuations(
    dims, distances=1.0 / dims[0], **cf_fl, prefix="ax1", non_parametric_kind="power"
)
gp = cfm.finalize()

# put together in correlated field model
cfm_atmos = jft.CorrelatedFieldMaker("cf_atmos")
cfm_atmos.set_amplitude_total_offset(**cf_zm)
cfm_atmos.add_fluctuations(
    dims, distances=1.0 / dims[0], **cf_fl, prefix="ax1", non_parametric_kind="power"
)
gp_atmos = cfm_atmos.finalize()

In [4]:
gp_atmos.domain

{'cf_atmoszeromode': ShapeWithDtype(shape=(), dtype=<class 'jax.numpy.float64'>),
 'cf_atmosax1fluctuations': ShapeWithDtype(shape=(), dtype=<class 'jax.numpy.float64'>),
 'cf_atmosax1loglogavgslope': ShapeWithDtype(shape=(), dtype=<class 'jax.numpy.float64'>),
 'cf_atmosax1spectrum': ShapeWithDtype(shape=(79132, 2), dtype=<class 'jax.numpy.float64'>),
 'cf_atmosax1flexibility': ShapeWithDtype(shape=(), dtype=<class 'jax.numpy.float64'>),
 'cf_atmosax1asperity': ShapeWithDtype(shape=(), dtype=<class 'jax.numpy.float64'>),
 'cf_atmosxi': ShapeWithDtype(shape=(1000, 1000), dtype=<class 'jax.numpy.float64'>)}

In [5]:
import numpy as np
noise_std = 2
test = np.random.standard_normal()
import operator



_config = jft.config._config

def hartley(p, axes=None):
    from jax.numpy import fft

    tmp = fft.fftn(p, axes=axes)
    c = _config.get("hartley_convention")
    add_or_sub = operator.add if c == "non_canonical_hartley" else operator.sub
    return add_or_sub(tmp.real, tmp.imag)


def inv_hartley(p, axes=None):
    from jax.numpy import fft
    tmp = fft.fftn(p, axes=axes)  # Forward transform (not IFFT!)
    c = _config.get("hartley_convention")
    add_or_sub = operator.add if c == "non_canonical_hartley" else operator.sub
    return add_or_sub(tmp.real, tmp.imag) / p.size  # Normalize


a = jnp.float64(np.random.normal(size = (100,100)))
b = hartley(inv_hartley(a))
print(jnp.std(b))

1.0123557536142516


In [6]:
# key, sub = random.split(key)
# xi = jft.random_like(sub, gp.domain) # generate std normal parameters 
# res = gp(xi) # draw sample from gp

# pspec = cfm_atmos.power_spectrum(xi)
# pspec.shape

def _mk_expanded_amp(amp, sub_dom):  # Avoid late binding
    def expanded_amp(p):
        return amp(p)[sub_dom.harmonic_grid.power_distributor]

    return expanded_amp

expanded_amplitudes = []
namps = cfm_atmos.get_normalized_amplitudes()
for amp, sgrid in zip(namps, cfm_atmos._target_grids):
    expanded_amplitudes.append(_mk_expanded_amp(amp, sgrid))

def atmos_amplitude(p):
    outer = expanded_amplitudes[0](p)
    for amp in expanded_amplitudes[1:]:
        # NOTE, the order is important here and must match with the
        # excitations
        # TODO, use functions instead and utilize numpy's casting
        outer = cfm_atmos.azm(p) * jnp.tensordot(outer, amp(p), axes=0)
    return outer

noise_std = 42
def modified_noise_std():
    def f(p):
        atmos_amp = atmos_amplitude(p)
        mod_noise_cov = atmos_amp**2 + noise_std**2 / atmos_amp.size # correct for hartley factor
        print("atmos_amp.size", atmos_amp.size)
        # return jnp.sqrt(mod_noise_cov)
        return mod_noise_cov**0.5

    subtree = {k: v for k, v in cfm_atmos._parameter_tree.items() if k != 'cf_atmosxi'}
    # print("HERE:", subtree)
    init = {
        # k: partial(jft.random_like, primals=v) for k, v in cfm_atmos._parameter_tree.items() if k != 'cfxi' 
        k: partial(jft.random_like, primals=v) for k, v in subtree.items()
    }   
    mod_noi_std_model = jft.Model(f, domain=subtree, init=init)

    return mod_noi_std_model

key, sub = random.split(key)
xi = jft.random_like(sub, gp.domain) # generate std normal parameters 
d = gp(xi)

noi_cov = modified_noise_std()
# def data_residual(p):
#     return d - gp(p)
# inp = [data_residual, modified_noise_std]
class Varcov_inp(jft.Model):
    def __init__(self, gp, modified_noise_std):
        self.gp = gp
        self.modified_noise_std = modified_noise_std
        super().__init__(init=self.gp.init | self.modified_noise_std.init)

    def __call__(self, x):
        return inv_hartley(self.gp(x)), self.modified_noise_std(x)


inp = Varcov_inp(gp, noi_cov)

var_cov_gau_llhd = jft.VariableCovarianceGaussian(inv_hartley(d)).amend(inp)

key, sub = random.split(key)
xi = jft.random_like(sub, var_cov_gau_llhd.domain)
print(var_cov_gau_llhd(xi))
# print(atmos_amplitude(xi).shape)
var_cov_gau_llhd.domain

atmos_amp.size 1000000
atmos_amp.size 1000000
atmos_amp.size 1000000
2603193.666037463


{'cf_atmosax1asperity': ShapeDtypeStruct(shape=(), dtype=float64),
 'cf_atmosax1flexibility': ShapeDtypeStruct(shape=(), dtype=float64),
 'cf_atmosax1fluctuations': ShapeDtypeStruct(shape=(), dtype=float64),
 'cf_atmosax1loglogavgslope': ShapeDtypeStruct(shape=(), dtype=float64),
 'cf_atmosax1spectrum': ShapeDtypeStruct(shape=(79132, 2), dtype=float64),
 'cf_atmoszeromode': ShapeDtypeStruct(shape=(), dtype=float64),
 'cfax1asperity': ShapeDtypeStruct(shape=(), dtype=float64),
 'cfax1flexibility': ShapeDtypeStruct(shape=(), dtype=float64),
 'cfax1fluctuations': ShapeDtypeStruct(shape=(), dtype=float64),
 'cfax1loglogavgslope': ShapeDtypeStruct(shape=(), dtype=float64),
 'cfax1spectrum': ShapeDtypeStruct(shape=(79132, 2), dtype=float64),
 'cfxi': ShapeDtypeStruct(shape=(1000, 1000), dtype=float64),
 'cfzeromode': ShapeDtypeStruct(shape=(), dtype=float64)}

In [7]:



assert False

AssertionError: 

In [None]:
key, sub = random.split(key)
xi = jft.random_like(sub, gp.domain) # generate std normal parameters 
res = gp(xi) # draw sample from gp

print("res type:", type(res))
print("res:", res)
print("res shape:", res.shape)

plt.imshow(res)
plt.colorbar()
plt.show()

In [5]:
ps = jft.InvGammaPrior(
    a=5.0,
    scale=1.0,
    name="ps",
    shape=dims,
)

In [None]:
key, sub = random.split(key)
xi = jft.random_like(sub, ps.domain)
res = ps(xi)

plt.imshow(res)
plt.colorbar()
plt.show()

In [7]:
class Signal(jft.Model):
    def __init__(self, gp, ps):
        self.gp = gp
        self.ps = ps
        super().__init__(init=self.gp.init | self.ps.init)

    def __call__(self, x):
        return self.gp(x) + self.ps(x)


signal = Signal(gp, ps)

In [None]:
key, sub = random.split(key)
xi = jft.random_like(sub, signal.domain)
res = signal(xi)

plt.imshow(res)
plt.colorbar()
plt.show()

### Create mock data

In [None]:
signal_response = signal

key, subkey = random.split(key)
pos_truth = jft.random_like(subkey, signal_response.domain)
signal_response_truth = signal_response(pos_truth)

noise_truth = 0.1 * jft.random_like(subkey, signal_response.target)
# noise_truth = 1e0 * jft.random_like(subkey, signal_response.target)

data = signal_response_truth + noise_truth

plt.imshow(data)
plt.colorbar()
plt.show()

plt.imshow(signal_response_truth)
plt.colorbar()
plt.show()

### Likelihood

In [None]:
noise_cov_inv = lambda x: 1e-1**-2 * x

lh = jft.Gaussian(data, noise_cov_inv).amend(signal_response)

### Approximate posterior

Maximum A Posteriori (MAP) estimation

In [None]:
n_it = 1
delta = 1e-4
n_samples = 0 # no samples -> maximum aposteriory posterior

key, k_i, k_o = random.split(key, 3)

samples, state = jft.optimize_kl(
    lh, # likelihood
    jft.Vector(lh.init(k_i)), # initial position in model space (initialisation)
    n_total_iterations=n_it, # no of optimisation steps (global)
    n_samples=n_samples, # draw samples
    key=k_o, # random jax init
    draw_linear_kwargs=dict( # sampling parameters
        cg_name="SL",
        cg_kwargs=dict(absdelta=delta * jft.size(lh.domain) / 10.0, maxiter=100),
    ),
    nonlinearly_update_kwargs=dict( # map from multivariate gaussian to more compl. distribution (coordinate transformations)
        minimize_kwargs=dict(
            name="SN",
            xtol=delta,
            cg_kwargs=dict(name=None),
            maxiter=5,
        )
    ),
    kl_kwargs=dict( # shift transformed multivar gauss to best match true posterior
        minimize_kwargs=dict(
            name="M", xtol=delta, cg_kwargs=dict(name=None), maxiter=35
        )
    ),
    sample_mode="nonlinear_resample", # how steps are combined (samples + nonlin + KL)
)

In [None]:
# plot maximum of posterior (mode)
sig_map = signal(samples.pos)

plt.imshow(sig_map - signal_response_truth)
plt.colorbar()
plt.show()

Geometric Variational Inference (GeoVI) -> Jakob's MGVI + nonlin updates

In [None]:
assert False

In [None]:
n_it = 5 # no of SL + SN + M
delta = 1e-4 # degree of tolerance
n_samples = 4 # *2 antithetic samples per n_it

key, k_i, k_o = random.split(key, 3)

samples, state = jft.optimize_kl(
    lh,
    jft.Vector(lh.init(k_i)),
    n_total_iterations=n_it,
    n_samples=n_samples,
    key=k_o,
    draw_linear_kwargs=dict(
        cg_name="SL",
        cg_kwargs=dict(absdelta=delta * jft.size(lh.domain) / 10.0, maxiter=100),
    ),
    nonlinearly_update_kwargs=dict(
        minimize_kwargs=dict(
            name="SN",
            xtol=delta,
            cg_kwargs=dict(name=None),
            maxiter=5,
        )
    ),
    kl_kwargs=dict(
        minimize_kwargs=dict(
            name="M", xtol=delta, cg_kwargs=dict(name=None), maxiter=35
        )
    ),
    sample_mode="nonlinear_resample",
)

In [None]:
sig_mean, sig_std = jft.mean_and_std(tuple(signal(s) for s in samples))

plt.imshow(sig_mean)
plt.colorbar()
plt.show()

plt.imshow(sig_std)
plt.colorbar()
plt.show()

In [None]:
gp_mean, gp_std = jft.mean_and_std(tuple(gp(s) for s in samples))

plt.imshow(gp_mean)
plt.colorbar()
plt.show()

plt.imshow(gp_std)
plt.colorbar()
plt.show()

In [None]:
ps_mean, ps_std = jft.mean_and_std(tuple(ps(s) for s in samples))

plt.imshow(ps_mean)
plt.colorbar()
plt.show()

plt.imshow(ps_std)
plt.colorbar()
plt.show()

### Response function

In [None]:
class Response(jft.Model):
    def __init__(self, signal, mask):
        self.signal = signal
        self.mask = mask
        super().__init__(init=self.signal.init)

    def __call__(self, x):
        return self.mask * self.signal(x)


mask = jnp.ones(dims)
mask = mask.at[32:64, 32:64].set(0.0)

signal_response_masked = Response(signal, mask)

In [None]:
key, sub = random.split(key)
xi = jft.random_like(sub, signal_response_masked.domain)
res = signal_response_masked(xi)

plt.imshow(res)
plt.colorbar()
plt.show()

In [None]:
key, subkey = random.split(key)
pos_truth = jft.random_like(subkey, signal_response_masked.domain)
signal_response_masked_truth = signal_response_masked(pos_truth)

noise_truth = 0.1 * jft.random_like(subkey, signal_response_masked.target)

data_masked = signal_response_masked_truth + noise_truth

plt.imshow(data_masked)
plt.colorbar()
plt.show()

In [None]:
noise_cov_inv = lambda x: 0.1**-2 * x

lh_masked = jft.Gaussian(data_masked, noise_cov_inv).amend(signal_response_masked)

In [None]:
n_it = 5
delta = 1e-4
n_samples = 4

key, k_i, k_o = random.split(key, 3)

samples, state = jft.optimize_kl(
    lh_masked,
    jft.Vector(lh.init(k_i)),
    n_total_iterations=n_it,
    n_samples=n_samples,
    key=k_o,
    draw_linear_kwargs=dict(
        cg_name="SL",
        cg_kwargs=dict(absdelta=delta * jft.size(lh.domain) / 10.0, maxiter=100),
    ),
    nonlinearly_update_kwargs=dict(
        minimize_kwargs=dict(
            name="SN",
            xtol=delta,
            cg_kwargs=dict(name=None),
            maxiter=5,
        )
    ),
    kl_kwargs=dict(
        minimize_kwargs=dict(
            name="M", xtol=delta, cg_kwargs=dict(name=None), maxiter=35
        )
    ),
    sample_mode="nonlinear_resample",
)

In [None]:
sig_mean, sig_std = jft.mean_and_std(tuple(signal(s) for s in samples))

plt.imshow(sig_mean)
plt.colorbar()
plt.show()

plt.imshow(sig_std)
plt.colorbar()
plt.show()