In [1]:
import itertools
from pathlib import Path

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from scipy.special import erf
from tqdm import tqdm

from localization import datasets
from localization import models
from localization import samplers
from localization.experiments import supervise, autoencode, simulate, simulate_or_load, make_key
from localization.utils import ipr, plot_receptive_fields, plot_rf_evolution, build_gaussian_covariance, build_non_gaussian_covariance, entropy_sort, build_DRT

def gaussian_cdf(x):
    """Standard normal cumulative distribution function Phi(x).

    Computed through the error function:
        Phi(x) = (1 + erf(x / sqrt(2))) / 2

    Args:
        x: scalar or numpy array of evaluation points (``scipy.special.erf``
            is applied element-wise).

    Returns:
        Phi(x), with the same shape as ``x``.
    """
    return 0.5 * (erf(x / np.sqrt(2)) + 1)

In [2]:
# Base experiment configuration, passed as keyword arguments to
# `simulate_or_load`. The sweep cell below overwrites `gain` and adds
# `num_hiddens` / `xi` for every run via `config.update(...)`.
config = dict(
    seed=0,
    num_dimensions=40,
    dim=1,
    # num_hiddens is set per run by the sweep; earlier values tried:
    # num_hiddens=60,
    # num_hiddens=10,
    gain=100,  # overwritten per run by the sweep
    # gain=0.01,
    init_scale=0.1,
    activation='relu',
    model_cls=models.SimpleNet,
    # model_cls=models.MLP,
    use_bias=False,
    optimizer_fn=optax.sgd,
    # learning_rate=0.1,  # (also tried 3.0)
    learning_rate=0.3,
    batch_size=1000,
    num_epochs=1000,
    # dataset_cls=datasets.NonlinearGPDataset,
    dataset_cls=datasets.NonlinearGPCountDataset,
    # xi is set per run by the sweep; earlier values tried:
    # xi=(5, 4,),
    # xi=(0.5, 0.4,),
    # xi=(5, 4, 0.3, 0.2, 0.1,),
    # xi=(3, 0.1),
    num_steps=1000,
    adjust=(-1.0, 1.0),
    class_proportion=0.5,
    sampler_cls=samplers.EpochSampler,
    init_fn=models.xavier_normal_init,
    loss_fn='mse',
    save_=True,
    evaluation_interval=100,
)

# Flag for Weights & Biases (wandb) logging; off by default.
wandb_ = False

In [11]:
# Parameter sweep: for every (gain, num_hiddens, xi) combination, call
# `simulate_or_load` (which prints "Already simulated" when a saved run exists)
# and save a figure of the receptive-field evolution.
# NOTE(review): `xi` is forwarded to the dataset via `config`; presumably GP
# length scales — TODO confirm.
gains = [0.01, 100]
hiddens = [1, 2, 40, 100]
xis = [(3., 0.1,), (0.1, 3,), (3, 3), (0.1, 0.1), (5, 1, 0.1)]

# Output directory for the figures; savefig does not create missing directories.
FIG_DIR = Path("../thoughts/edge_counting")
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Materialise the grid so tqdm knows the total and can show progress / ETA.
sweep = list(itertools.product(gains, hiddens, xis))

for gain, hidden, xi in tqdm(sweep):
    print(gain, hidden, xi)
    # Mutates the shared `config` in place, so after the sweep it holds the
    # last combination's values.
    config.update(dict(gain=gain, num_hiddens=hidden, xi=xi))
    # Honour the notebook-level `wandb_` flag instead of hard-coding False.
    weights, metrics = simulate_or_load(supervise=True, wandb_=wandb_, **config)
    # Plot only index 0 along axis 1 of `weights`
    # (presumably the first hidden unit — TODO confirm).
    fig, axs = plot_rf_evolution(weights[:, [0], :], figsize=(8, 4))
    fig.suptitle(f"gain={gain}, hidden={hidden}, xi={xi}")
    fig.savefig(FIG_DIR / f"gain={gain}_hidden={hidden}_xi={xi}.png")
    plt.close(fig)  # avoid accumulating open figures across iterations

1it [00:00,  5.85it/s]

0.01 1 (3.0, 0.1)
Already simulated
0.01 1 (0.1, 3)
Already simulated


3it [00:00,  6.89it/s]

0.01 1 (3, 3)
Already simulated
0.01 1 (0.1, 0.1)
Already simulated


5it [00:00,  6.75it/s]

0.01 1 (5, 1, 0.1)
Already simulated
0.01 2 (3.0, 0.1)
Already simulated


7it [00:00,  7.85it/s]

0.01 2 (0.1, 3)
Already simulated
0.01 2 (3, 3)
Already simulated


9it [00:01,  7.48it/s]

0.01 2 (0.1, 0.1)
Already simulated
0.01 2 (5, 1, 0.1)
Already simulated


11it [00:01,  7.11it/s]

0.01 40 (3.0, 0.1)
Already simulated
0.01 40 (0.1, 3)
Already simulated


13it [00:01,  7.29it/s]

0.01 40 (3, 3)
Already simulated
0.01 40 (0.1, 0.1)
Already simulated


15it [00:02,  7.87it/s]

0.01 40 (5, 1, 0.1)
Already simulated
0.01 100 (3.0, 0.1)
Already simulated


16it [00:02,  7.96it/s]

0.01 100 (0.1, 3)
Already simulated


17it [00:02,  4.72it/s]

0.01 100 (3, 3)
Already simulated


19it [00:02,  4.99it/s]

0.01 100 (0.1, 0.1)
Already simulated
0.01 100 (5, 1, 0.1)
Already simulated


21it [00:03,  5.30it/s]

100 1 (3.0, 0.1)
Already simulated
100 1 (0.1, 3)
Already simulated


23it [00:03,  5.78it/s]

100 1 (3, 3)
Already simulated
100 1 (0.1, 0.1)
Already simulated


25it [00:03,  6.13it/s]

100 1 (5, 1, 0.1)
Already simulated
100 2 (3.0, 0.1)
Already simulated


27it [00:04,  5.77it/s]

100 2 (0.1, 3)
Already simulated
100 2 (3, 3)
Already simulated


29it [00:04,  6.37it/s]

100 2 (0.1, 0.1)
Already simulated
100 2 (5, 1, 0.1)
Already simulated


31it [00:04,  5.98it/s]

100 40 (3.0, 0.1)
Already simulated
100 40 (0.1, 3)
Already simulated


33it [00:05,  6.12it/s]

100 40 (3, 3)
Already simulated
100 40 (0.1, 0.1)
Already simulated


35it [00:05,  7.17it/s]

100 40 (5, 1, 0.1)
Already simulated
100 100 (3.0, 0.1)
Already simulated


37it [00:05,  7.42it/s]

100 100 (0.1, 3)
Already simulated
100 100 (3, 3)
Already simulated


39it [00:06,  6.43it/s]

100 100 (0.1, 0.1)
Already simulated
100 100 (5, 1, 0.1)


40it [00:06,  6.40it/s]

Already simulated



