In [1]:
import os
import numpy as np
import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt
from localization import datasets
from localization import models
from localization import samplers
from localization.experiments.batched_online import simulate, make_key

from utils import ipr, entropy, entropy_sort, mean_sort, var_sort, plot_receptive_fields, plot_rf_evolution

def simulate_or_load(**kwargs):
    """Return ``(weights, metrics)`` for a config, loading a cached result if present.

    The cache file is ``<make_key(**kwargs)>.npz`` inside
    ``../localization/results/weights`` (relative to the current working
    directory) and must contain ``weights`` and ``metrics`` entries. When no
    such file exists (including when the directory itself is missing),
    ``simulate(**kwargs)`` is run and its result returned. Whether
    ``simulate`` writes the cache file is decided by ``simulate``
    (presumably via the ``save_`` entry of ``kwargs`` -- verify).

    Args:
        **kwargs: experiment configuration, forwarded unchanged to both
            ``make_key`` and ``simulate``.

    Returns:
        A ``(weights, metrics)`` pair.
    """
    results_dir = os.path.join('..', 'localization', 'results', 'weights')
    path = os.path.join(results_dir, make_key(**kwargs) + '.npz')
    # isfile() is False (rather than raising) when the directory is missing.
    if os.path.isfile(path):
        print('Already simulated')
        # An NpzFile keeps the archive open; the context manager closes it.
        # Indexing inside the block materialises the arrays before closing.
        # allow_pickle=True permits object arrays but unpickles file contents,
        # so only load result files you produced yourself.
        with np.load(path, allow_pickle=True) as data:
            weights_, metrics_ = data['weights'], data['metrics']
    else:
        print('Simulating')
        weights_, metrics_ = simulate(**kwargs)
    return weights_, metrics_

def build_gaussian_covariance(L, xi):
    """Build an ``L x L`` Gaussian covariance matrix over ``L`` sites on a ring.

    Entry ``(i, j)`` is ``exp(-d**2 / xi**2)`` with
    ``d = min(|i - j|, L - |i - j|)``, the wrap-around distance between sites
    ``i`` and ``j``. The matrix is therefore symmetric and circulant with ones
    on the diagonal. Note the length scale enters as ``xi**2`` (no factor 2).

    Computed entirely in NumPy (the previous version round-tripped through
    JAX and returned float32; this returns float64).

    Args:
        L: number of sites, i.e. the matrix side length.
        xi: correlation length; must be nonzero.

    Returns:
        ``np.ndarray`` of shape ``(L, L)``.
    """
    idx = np.arange(L)
    dist = np.abs(idx[:, np.newaxis] - idx[np.newaxis, :])
    # Periodic boundary: site 0 and site L-1 are neighbours.
    dist = np.minimum(dist, L - dist)
    return np.exp(-dist ** 2 / xi ** 2)

# Default experiment configuration, grouped by concern.
# Alternative larger setting (not active):
#   num_dimensions=100, xi1=6, xi2=3, num_hiddens=100, num_epochs=20000
config_ = {
    # -- data --
    'num_dimensions': 40,
    'xi1': 2,
    'xi2': 1,
    'dataset_cls': datasets.NonlinearGPDataset,
    'batch_size': 1000,
    'support': (-1, 1),  # defunct
    'class_proportion': 0.5,
    # -- model --
    'model_cls': models.SimpleNet,
    'num_hiddens': 40,
    'activation': 'relu',
    'use_bias': False,
    'sampler_cls': samplers.EpochSampler,
    'init_fn': models.xavier_normal_init,
    'init_scale': 1.0,
    # -- learning --
    'num_epochs': 5000,
    'evaluation_interval': 100,
    'optimizer_fn': optax.sgd,
    'learning_rate': 5.0,
    # -- experiment --
    'seed': 0,
    'save_': True,
    'wandb_': False,
}

from scipy.optimize import curve_fit
def gabor_real(x, c, a, x0, k0):
    """Real Gabor profile: a cosine carrier under a Gaussian envelope.

    Evaluates ``c * cos(k0 * (x - x0)) * exp(-(x - x0)**2 / a**2)``
    elementwise, so ``x`` may be a scalar or a NumPy array.

    Args:
        x: evaluation position(s).
        c: amplitude.
        a: envelope width; must be nonzero.
        x0: centre shared by the carrier and the envelope.
        k0: angular frequency of the carrier.

    Returns:
        The profile evaluated at ``x`` (same shape as ``x``).
    """
    dx = x - x0
    return c * np.cos(k0 * dx) * np.exp(-dx ** 2 / a ** 2)

def fit(weights):
    """Fit ``gabor_real`` independently to each row of ``weights``.

    Row ``k`` is treated as samples at integer positions ``0 .. n-1`` and
    fitted with ``scipy.optimize.curve_fit`` under the bounds
    ``c in [-1, 1]``, ``a >= 0``, ``x0`` unbounded, ``k0 in [-2, 2]``,
    starting from ``c=0.5, a=1, k0=0.5`` and ``x0`` at the position of the
    row's largest-magnitude entry.

    Args:
        weights: 2-D array-like of shape ``(K, n)`` (presumably one receptive
            field per row -- confirm with callers).

    Returns:
        ``(p, var)``: two ``(K, 4)`` float arrays holding, per row, the fitted
        parameters ``(c, a, x0, k0)`` and the diagonal of their estimated
        covariance. Rows whose fit fails are ``NaN`` in both arrays.
    """
    # Convert once up front; slicing e.g. a JAX array row by row is slow.
    weights = np.asarray(weights)
    K, n = weights.shape
    p, var = np.zeros((K, 4)), np.zeros((K, 4))
    x = np.arange(n)  # sample positions are identical for every row
    bounds = ([-1, 0, -np.inf, -2], [1, np.inf, np.inf, 2])

    for k, y in enumerate(weights):
        x0_init = np.argmax(np.abs(y))
        try:
            p_, cov = curve_fit(gabor_real, x, y, p0=[0.5, 1, x0_init, 0.5], bounds=bounds)
        except (RuntimeError, ValueError):
            # RuntimeError: no convergence within the evaluation budget;
            # ValueError: e.g. non-finite data. Mark this row as unfitted
            # and keep going instead of aborting the whole batch.
            p[k] = var[k] = np.nan
            continue
        p[k] = p_
        var[k] = np.diag(cov)

    return p, var