## Data visualization

In [24]:
import numpy as np
import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt
from localization import datasets
from localization import models
from localization import samplers
from localization.experiments import supervise, autoencode, simulate, simulate_or_load, make_key
from localization.utils import ipr, plot_receptive_fields, plot_rf_evolution, build_gaussian_covariance, build_non_gaussian_covariance, entropy_sort, build_DRT
from scipy.special import erf
from tqdm import tqdm
import itertools
import cblind as cb

In [35]:
# Base dataset configuration passed to the dataset constructors in the cells
# below (the Ising cell copies it and overrides ``xi``).
# Alternative settings that have been tried here:
#   gain=100
#   dataset_cls=datasets.NortaDataset or datasets.NonlinearGPDataset
#   marginal_qdf=datasets.LaplaceQDF(), GaussianQDF(), UniformQDF(),
#                BernoulliQDF() or AlgQDF(4)
config = {
    'key': jax.random.PRNGKey(0),
    'num_dimensions': 100,
    'dim': 1,
    'num_exemplars': 10000,
    'xi': (1,),
    'adjust': (-1.0, 1.0),
    'class_proportion': 0.5,
}

### Ising Model

In [36]:
# Ising samples: reuse the shared config but override ``xi`` with (0.3,).
# The marginal histogram uses 20 bins over [-1.5, 1.5]; raw counts are
# rescaled so the bars sum to one (not a density).
config_ = dict(config, xi=(0.3,))  # alternative value tried: xi=(1.2,)
dataset = datasets.IsingDataset(**config_)
x_ising, _ = dataset[:10000]
cov_ising = jnp.cov(x_ising.T)
ising_counts, bins_ising = np.histogram(
    x_ising, bins=20, range=(-1.5, 1.5), density=False
)
hist_ising = ising_counts / ising_counts.sum()

### Gaussian

In [37]:
# Nonlinear GP count samples (gain=0.01) built from the shared config.
# No ``range`` is given to the histogram, so NumPy spans the data's min..max
# with 100 bins; raw counts are rescaled so the bars sum to one.
dataset = datasets.NonlinearGPCountDataset(**config, gain=0.01)
x_gaussian, _ = dataset[:10000]
cov_gaussian = jnp.cov(x_gaussian.T)
gaussian_counts, bins_gaussian = np.histogram(x_gaussian, bins=100, density=False)
hist_gaussian = gaussian_counts / gaussian_counts.sum()

### Alg(5)

In [38]:
# NORTA samples whose marginal is given by AlgQDF(k=5), built from the shared
# config. The histogram covers [-5, 5] with 100 bins; np.histogram ignores
# samples outside that window, so the normalisation (bars summing to one) is
# over in-range samples only.
alg_qdf = datasets.AlgQDF(k=5)
dataset = datasets.NortaDataset(marginal_qdf=alg_qdf, **config)
x_alg, _ = dataset[:10000]
cov_alg = jnp.cov(x_alg.T)
alg_counts, bins_alg = np.histogram(x_alg, bins=100, range=(-5, 5), density=False)
hist_alg = alg_counts / alg_counts.sum()

Approximate standard deviation: 0.6950598719782084


### Plotting

In [39]:
# Figure 2 panels: for each dataset, save (1) one example sample, (2) the
# empirical covariance matrix and (3) the normalised marginal histogram as
# separate PDFs at fig2/{samples,cov,marginal}/<model>.pdf.
from pathlib import Path

LINE_COLOR = '#00356b'
COV_CMAP = cb.cbmap('cb.solstice')  # loop-invariant: build the colormap once


def _hide_spines(ax, sides=('top', 'right')):
    """Hide the named spines of ``ax``."""
    for side in sides:
        ax.spines[side].set_visible(False)


def _save(fig, kind, model):
    """Write ``fig`` to fig2/<kind>/<model>.pdf (creating the folder) and close it."""
    out_dir = Path('fig2') / kind
    # savefig does not create missing directories; without this a fresh
    # checkout fails with FileNotFoundError.
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / f'{model}.pdf', bbox_inches='tight')
    plt.close(fig)


panels = {
    'ising': (x_ising, cov_ising, hist_ising, bins_ising),
    'gaussian': (x_gaussian, cov_gaussian, hist_gaussian, bins_gaussian),
    'alg5': (x_alg, cov_alg, hist_alg, bins_alg),
}

for model, (x, cov, hist, bins) in panels.items():
    # Sample: alg5 shows its second example so that the gaussian and alg5
    # traces don't look misleadingly similar.
    idx = 1 if model == 'alg5' else 0
    fig, ax = plt.subplots(figsize=(4, 2))
    ax.plot(x[idx], color=LINE_COLOR)
    ax.set_xticks([])
    _hide_spines(ax)
    if model == 'ising':
        ax.set_yticks([-1, -0.5, 0, 0.5, 1])
    _save(fig, 'samples', model)

    # Covariance: fixed colour limits [-1, 1] so all panels share one scale.
    fig, ax = plt.subplots(figsize=(4, 4))
    im = ax.imshow(cov, cmap=COV_CMAP, vmin=-1, vmax=1)
    ax.set_xticks([])
    ax.set_yticks([])
    _hide_spines(ax, ('top', 'right', 'bottom', 'left'))
    _save(fig, 'cov', model)

    # Marginal: each bar starts at its left bin edge and spans one bin width.
    bin_width = bins[1] - bins[0]
    fig, ax = plt.subplots(figsize=(4, 2))
    ax.bar(bins[:-1], hist, width=bin_width, align='edge', color=LINE_COLOR)
    ax.set_yticks([0.0, 0.2, 0.4, 0.6] if model == 'ising' else [0.0, 0.02, 0.04])
    _hide_spines(ax)
    _save(fig, 'marginal', model)

  im = ax.imshow(cov, cmap=cb.cbmap('cb.solstice'), vmin=-1, vmax=1)
  im = ax.imshow(cov, cmap=cb.cbmap('cb.solstice'), vmin=-1, vmax=1)
  im = ax.imshow(cov, cmap=cb.cbmap('cb.solstice'), vmin=-1, vmax=1)
