In [1]:
### IMPORTS ###
from typing import Callable, Sequence, Any
from functools import partial
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import optax
import jaxopt
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import os
import pickle
from functions import Fourier, Mixture, Slope, Polynomial, WhiteNoise, Shift
from networks import MixtureNeuralProcess, MLP, MeanAggregator, SequenceAggregator, NonLinearMVN, ResBlock
# These allocator settings only take effect if they are exported before JAX
# creates its backend client. The jax.devices() call below creates it, so the
# variables are set first (the original set them afterwards, where they were
# silently ignored).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
# NOTE(review): the original also assigned XLA_PYTHON_CLIENT_MEM_FRACTION the
# literal ".XX". That is the placeholder used in the JAX documentation, not a
# number, and the fraction only matters when preallocation is enabled (it is
# disabled above). It is therefore not exported here; set a real fraction
# (e.g. ".50") if preallocation is ever turned back on.

# Report which devices JAX picked up.
print('cuda?', jax.devices(), jax.devices()[0].device_kind)

cuda? [CpuDevice(id=0)] cpu


In [2]:
### CONFIGURATION ###
# --- Evaluation settings ---
dataset_size, test_resolution = 128, 512

# --- Training settings ---
# How many latents are drawn from the posterior p(Z | X, Y).
num_posterior_mc = 1
# How many functions are drawn from the prior p(Z) per batch.
batch_size = 128

# Weight on the KL term. Note to self: the magnitude of the KL divergence
# can otherwise take over in the loss.
kl_penalty = 1e-4

# Number of points per sampled function, by role.
num_target_samples, num_context_samples = 32, 64

In [3]:
### DATA TRAINING DISTRIBUTION/DATA GENERATION ###
# Four Fourier function families. f4 is defined here but is not referenced
# by any of the cells visible in this part of the notebook.
f1 = Fourier(n=4, amplitude=.5, period=1.0)
f2 = Fourier(n=2, amplitude=.5, period=1.0)
f3 = Fourier(n=6, amplitude=.5, period=2.0)
f4 = Fourier(n=3, amplitude=1.0, period=2.0)

# Noise-free mixture: f1, f2 and f3, each wrapped in a Shift with its own y_shift.
m = Mixture([
    Shift(family, y_shift=offset)
    for family, offset in ((f1, -2), (f2, 0.0), (f3, 2))
])

# Noisy mixture: every branch of `m` wrapped in WhiteNoise with its own scale.
nm = Mixture([
    WhiteNoise(m.branches[index], scale)
    for index, scale in enumerate((0.05, 0.2, 0.1))
])

# Root PRNG key; all later keys in the notebook are split off from `rng`.
fixed_seed = 12345
rng = jax.random.PRNGKey(fixed_seed)

In [4]:
### DEFINING SAMPLERS ###
def joint(
    module: nn.Module,
    data_sampler: Callable[
        [nn.Module, flax.typing.VariableDict, flax.typing.PRNGKey],
        tuple[jax.Array, jax.Array]
    ],
    key: flax.typing.PRNGKey,
    return_params: bool = False
) -> tuple[jax.Array, jax.Array]:
    """Sample from the joint p(Z, X, Y).

    The variables returned by ``module.init`` play the role of Z; the
    ``data_sampler`` then produces (X, Y) given those variables.

    Args:
        module: Flax module to initialise.
        data_sampler: Callable invoked as ``data_sampler(module, params, key)``
            that returns an ``(xs, ys)`` pair.
        key: PRNG key; split three ways (two for ``init``, one for the sampler).
        return_params: When true, the initialised variables are returned too.

    Returns:
        ``(xs, ys)``, or ``(xs, ys, params)`` when ``return_params`` is true.
        Note that the return annotation only describes the two-element form.
    """
    init_key, default_key, sampler_key = jax.random.split(key, 3)

    # NOTE(review): the init stream is named 'param' here, whereas Flax's
    # conventional name is 'params' - confirm against the modules in
    # functions.py before changing it.
    init_rngs = {'param': init_key, 'default': default_key}
    params = module.init(init_rngs, jnp.zeros(()))

    xs, ys = data_sampler(module, params, sampler_key)

    return (xs, ys, params) if return_params else (xs, ys)


def uniform(
    module: nn.Module,
    params: flax.typing.VariableDict,
    key: flax.typing.PRNGKey,
    n: int,
    bounds: tuple[float, float]
) -> tuple[jax.Array, jax.Array]:
    """Sample from p(X, Y | Z) = p(Y | Z, X) p(X) with a uniform p(X).

    Args:
        module: Flax module whose ``apply`` is evaluated at every input.
        params: Variables passed unchanged to every ``apply`` call.
        key: PRNG key; split into one key for the inputs and one for the outputs.
        n: Number of input points to draw.
        bounds: ``(lower, upper)`` limits used to rescale the uniform draws.

    Returns:
        ``(xs, ys)`` where ``xs`` has shape ``(n,)`` and ``ys`` holds the
        module output for each entry of ``xs``.
    """
    input_key, output_key = jax.random.split(key)

    # Rescale the unit-interval draws to the requested range.
    lower, upper = bounds[0], bounds[1]
    xs = jax.random.uniform(input_key, (n,)) * (upper - lower) + lower

    # Every input point gets its own key on the 'default' stream; vmap maps
    # the keyword argument along its leading axis together with xs.
    per_point_keys = jax.random.split(output_key, n)
    apply_pointwise = jax.vmap(module.apply, in_axes=(None, 0))
    ys = apply_pointwise(params, xs, rngs={'default': per_point_keys})

    return xs, ys

In [5]:
### SPECIFY WHICH FUNCTION-PRIOR TO LEARN ###
# Inner sampler: `uniform` with the point count fixed to the sum of the target
# and context sample counts, and inputs rescaled to the range (-1, 1).
_draw_points = partial(
    uniform,
    n=num_target_samples + num_context_samples,
    bounds=(-1, 1),
)

# Outer sampler: `joint` bound to f2 wrapped in WhiteNoise(…, 0.1).
# The only argument left to supply is the PRNG key.
data_sampler = partial(joint, WhiteNoise(f2, 0.1), _draw_points)

In [6]:
### DEFINE THE SPECIFIC TEST CASE ###
def f(
    key: flax.typing.PRNGKey,
    x: jax.Array,
    noise_scale: float = 0.2,
    mixture_prob: float = 0.5,
    corrupt: bool = True
):
    """Evaluate the test function ``-2 - cos(2*pi*x)``, optionally with noise.

    Args:
        key: PRNG key used to draw the Gaussian noise.
        x: Input locations; the noise is drawn with the same shape.
        noise_scale: Multiplier applied to the standard-normal noise.
        mixture_prob: Accepted but not used by this function.
        corrupt: Multiplies the noise term, so a false value yields the
            noise-free curve.

    Returns:
        The function values at ``x`` plus ``corrupt * noise``.
    """
    signal = -2 - jnp.cos(2 * jnp.pi * x)
    # The noise is always drawn; `corrupt` only scales it in or out.
    noise = jax.random.normal(key, x.shape) * noise_scale
    return signal + corrupt * noise

# Advance the global key and carve off three sub-keys. key_x is not used in
# the cells visible here. Re-running this cell changes `rng`, so it is not
# idempotent.
rng, key_data, key_test, key_x = jax.random.split(rng, num=4)

# One key per dataset entry and one per test-grid point.
keys_data = jax.random.split(key_data, num=(dataset_size,))
keys_test = jax.random.split(key_test, num=(test_resolution,))

In [7]:
###GENERATE THE DATA###
# open(..., 'wb') below raises FileNotFoundError if the target folder does
# not exist yet, so create it up front (a no-op when it is already there).
os.makedirs("saved_datasets", exist_ok=True)

# The batched sampler is the same on every iteration, so the vmap wrapper is
# built once here instead of inside the loop. It maps over one key per function.
sample_batch = jax.vmap(data_sampler)

for dataset_number in range(1, 2):
    xss_yss_unordered = []
    for i in tqdm.trange(5000, desc='Generating data. '):
        # Advance the global key so every batch uses fresh randomness.
        rng, key = jax.random.split(rng)
        # key_model is not used in this cell; the split is kept so the keys
        # fed to the sampler stay exactly as before.
        key_data, key_model = jax.random.split(key)
        xs, ys = sample_batch(jax.random.split(key_data, num=batch_size))
        # Append a trailing singleton axis to both arrays.
        xs, ys = xs[..., None], ys[..., None]
        xss_yss_unordered.append((xs, ys))

    print(f"{dataset_number} generation done!")
    with open(f"saved_datasets/training_data_{dataset_number}.pkl", 'wb') as document_to_write:
        pickle.dump(xss_yss_unordered, document_to_write)

Generating data. : 100%|██████████| 5000/5000 [03:54<00:00, 21.32it/s]


1 generation done!


In [22]:
### GENERATE THE EVALUATION DATA ###
# Same procedure as the training-data cell, written to a separate file. The
# batches are drawn by continuing to split the global `rng`, so the result
# depends on the state `rng` was left in by the cells run before this one.

# open(..., 'wb') below raises FileNotFoundError if the target folder does
# not exist yet, so create it up front (a no-op when it is already there).
os.makedirs("saved_datasets", exist_ok=True)

# Build the batched sampler once; it maps over one key per function.
sample_batch = jax.vmap(data_sampler)

xss_yss_unordered = []
for i in tqdm.trange(5000, desc='Generating data. '):
    rng, key = jax.random.split(rng)
    # key_model is not used in this cell; the split is kept so the keys fed
    # to the sampler stay exactly as before.
    key_data, key_model = jax.random.split(key)
    xs, ys = sample_batch(jax.random.split(key_data, num=batch_size))
    # Append a trailing singleton axis to both arrays.
    xs, ys = xs[..., None], ys[..., None]
    xss_yss_unordered.append((xs, ys))

with open("saved_datasets/evaluation_data.pkl", 'wb') as document_to_write:
    pickle.dump(xss_yss_unordered, document_to_write)

Generating data. : 100%|██████████| 5000/5000 [03:41<00:00, 22.57it/s]
