In [2]:
### IMPORTS ###
from typing import Callable, Sequence, Any
from functools import partial
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import optax
import jaxopt
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import os
import pickle
from functions import Fourier, Mixture, Slope, Polynomial, WhiteNoise, Shift
from networks import MixtureNeuralProcess, MLP, MeanAggregator, SequenceAggregator, NonLinearMVN, ResBlock
# XLA only honours these client settings if they are in the environment
# before the backend is initialised, so they are exported ahead of the first
# `jax.devices()` call (which triggers that initialisation).
# NOTE(review): moving them above `import jax` would be safer still.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

# The cell used to export the literal documentation placeholder ".XX" as the
# memory fraction, which is not a parsable number. Set a real fraction here
# (e.g. "0.50") to opt in; leaving it as None keeps the variable unset.
xla_mem_fraction = None
if xla_mem_fraction is not None:
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = xla_mem_fraction

print('cuda?', jax.devices(), jax.devices()[0].device_kind)

# Root PRNG key shared by the batch-creation cells below.
fixed_seed = 12345
rng = jax.random.PRNGKey(fixed_seed)

cuda? [CpuDevice(id=0)] cpu


In [3]:
# Number of steps every pacing schedule below covers.
# NOTE(review): the original comment said "actually 5000" — presumably because
# each step yields two mini-batches; confirm before relying on it.
epochs = 2500

# Each schedule stores, per step, the fraction of the dataset that may be
# sampled from at that step. All lengths are derived from `epochs` so the
# schedules cannot drift out of sync with it.
no_pacing_function = jnp.ones(epochs)

# Single step: 30% of the data for the first half, everything afterwards.
step_pacing_function = jnp.concatenate(
    [jnp.full(epochs // 2, 0.3), jnp.ones(epochs - epochs // 2)]
)

# Fixed exponential: start at 7.5%, double at each of five equal-length
# stages, and finish on the full dataset.
stage_len = epochs // 5
fixedexp_pacing_function = jnp.concatenate(
    [jnp.full(stage_len, frac) for frac in (0.075, 0.15, 0.3, 0.6)]
    + [jnp.full(epochs - 4 * stage_len, 1)]
)

def create_curriculum(all_xs_sorted, all_ys_sorted, pacing_function, last_key, batch_size=128):
    """Pre-sample the mini-batches for one run under a pacing schedule.

    At step ``i`` only the first ``int(pacing_function[i] * len(all_xs_sorted))``
    entries of the data are eligible. Two batches are drawn from that prefix,
    each without replacement and each with its own PRNG key.

    Args:
        all_xs_sorted: Array of inputs, indexed along its first axis.
            Presumably ordered from easiest to hardest — verify against the
            code that produces the "ordered" datasets.
        all_ys_sorted: Targets aligned element-wise with ``all_xs_sorted``.
        pacing_function: One entry per step giving the fraction of the data
            that may be accessed at that step. Its length sets the number of
            steps (previously hard-coded to 2500).
        last_key: JAX PRNG key that seeds all sampling in this call.
        batch_size: Maximum number of examples per batch; a batch is smaller
            when the eligible prefix holds fewer examples.

    Returns:
        A jnp array stacking ``(xs1, ys1, xs2, ys2)`` for every step, i.e.
        with leading axes ``(len(pacing_function), 4)``.
    """
    num_examples = len(all_xs_sorted)
    mini_batches = []
    for i in tqdm.trange(len(pacing_function), desc='Creating batches.'):
        # One carried key plus one fresh key per batch. The carried key is
        # never used for sampling: a key that is consumed and then split
        # again violates JAX's PRNG contract.
        last_key, key1, key2 = jax.random.split(last_key, 3)
        pool_size = int(pacing_function[i] * num_examples)
        indices1 = jax.random.permutation(key1, pool_size)[:batch_size]
        indices2 = jax.random.permutation(key2, pool_size)[:batch_size]
        mini_batches.append((
            all_xs_sorted[indices1],
            all_ys_sorted[indices1],
            all_xs_sorted[indices2],
            all_ys_sorted[indices2],
        ))
    return jnp.asarray(mini_batches)

In [8]:
# Build the un-paced training and validation batches for each raw dataset.
for i in range(1, 2):
    # NOTE: pickle executes arbitrary code on load; only open trusted files.
    with open(f"saved_datasets/training_data_{i}.pkl", "rb") as file:
        xss_yss_unordered = pickle.load(file)
    all_xs_unordered = jnp.concatenate([xs for xs, ys in xss_yss_unordered])
    all_ys_unordered = jnp.concatenate([ys for xs, ys in xss_yss_unordered])

    # Derive one key per dataset index and per split. Passing the same `rng`
    # to both calls (as this cell used to) makes the "validation" batches an
    # exact copy of the training batches, because data and pacing are
    # identical too. Files regenerated with this cell will therefore differ
    # from ones produced by the old version.
    train_key, validate_key = jax.random.split(jax.random.fold_in(rng, i))

    unordered_mini_batches_train = create_curriculum(
        all_xs_unordered, all_ys_unordered, no_pacing_function, train_key
    )
    with open(f"saved_batches/train_no_curr_{i}.pkl", "wb") as file:
        pickle.dump(unordered_mini_batches_train, file)
    print(f"{i} train batches done.")

    unordered_mini_batches_validate = create_curriculum(
        all_xs_unordered, all_ys_unordered, no_pacing_function, validate_key
    )
    with open(f"saved_batches/validation_data_{i}.pkl", "wb") as file:
        pickle.dump(unordered_mini_batches_validate, file)
    print(f"{i} validate batches done.")

Creating batches.: 100%|██████████| 2500/2500 [49:49<00:00,  1.20s/it] 


1 train batches done.


Creating batches.: 100%|██████████| 2500/2500 [46:53<00:00,  1.13s/it] 


1 validate batches done.


In [5]:
# Build the un-paced evaluation batches from the held-out evaluation dataset.
with open("saved_datasets/evaluation_data.pkl", "rb") as file:
    xss_yss_unordered = pickle.load(file)

# Every entry of the pickle unpacks to an (xs, ys) pair; flatten each side.
all_xs_unordered = jnp.concatenate([pair_xs for pair_xs, _ in xss_yss_unordered])
all_ys_unordered = jnp.concatenate([pair_ys for _, pair_ys in xss_yss_unordered])

# NOTE(review): the variable name says "train" but it holds the evaluation
# batches; the name is kept because other cells may still refer to it.
unordered_mini_batches_train = create_curriculum(
    all_xs_unordered,
    all_ys_unordered,
    no_pacing_function,
    rng,
)

with open("saved_batches/evaluation_batches.pkl", "wb") as file:
    pickle.dump(unordered_mini_batches_train, file)
print("evaluation batches done.")

Creating batches.: 100%|██████████| 2500/2500 [51:30<00:00,  1.24s/it]


evaluation batches done.


In [5]:
# Single-step curriculum over the ordered "mini" training data.
for i in range(1):
    xs_path = f"saved_datasets/ordered_mini_training_data_xs_{i}.pkl"
    ys_path = f"saved_datasets/ordered_mini_training_data_ys_{i}.pkl"
    with open(xs_path, "rb") as file:
        all_xs_ordered = pickle.load(file)
    with open(ys_path, "rb") as file:
        all_ys_ordered = pickle.load(file)

    # The step pacing function restricts sampling to the first 30% of the
    # data for the first half of the steps.
    single_step_mini_batches_train = create_curriculum(
        all_xs_ordered,
        all_ys_ordered,
        step_pacing_function,
        rng,
    )

    with open(f"saved_batches/train_step_curr_{i}.pkl", "wb") as file:
        pickle.dump(single_step_mini_batches_train, file)
    print(f"{i} train batches done.")

Creating batches.: 100%|██████████| 2500/2500 [28:49<00:00,  1.45it/s]


0 train batches done.


In [5]:
# Re-save of the single-step curriculum batches.
# NOTE(review): this cell relies on `i` and `single_step_mini_batches_train`
# left over from the loop cell above and writes the same file again; it looks
# like a leftover and can probably be removed.
step_curr_out_path = f"saved_batches/train_step_curr_{i}.pkl"
with open(step_curr_out_path, "wb") as file:
    pickle.dump(single_step_mini_batches_train, file)
print(f"{i} train batches done.")

0 train batches done.


In [4]:
# Fixed-exponential curriculum over the ordered "mini" training data.
for i in range(1):
    xs_path = f"saved_datasets/ordered_mini_training_data_xs_{i}.pkl"
    ys_path = f"saved_datasets/ordered_mini_training_data_ys_{i}.pkl"
    with open(xs_path, "rb") as file:
        all_xs_ordered = pickle.load(file)
    with open(ys_path, "rb") as file:
        all_ys_ordered = pickle.load(file)

    # The accessible fraction of the data grows in five stages
    # (0.075, 0.15, 0.3, 0.6, 1).
    multi_step_mini_batches_train = create_curriculum(
        all_xs_ordered,
        all_ys_ordered,
        fixedexp_pacing_function,
        rng,
    )

    with open(f"saved_batches/train_fixedexp_curr_{i}.pkl", "wb") as file:
        pickle.dump(multi_step_mini_batches_train, file)
    print(f"{i} train batches done.")

Creating batches.: 100%|██████████| 2500/2500 [17:14<00:00,  2.42it/s]


0 train batches done.


In [4]:
# Fixed-exponential curriculum over the ordered "normal" training data,
# saved under the "bootstrap" prefix.
for i in range(1):
    xs_path = f"saved_datasets/ordered_normal_training_data_xs_{i}.pkl"
    ys_path = f"saved_datasets/ordered_normal_training_data_ys_{i}.pkl"
    with open(xs_path, "rb") as file:
        all_xs_ordered = pickle.load(file)
    with open(ys_path, "rb") as file:
        all_ys_ordered = pickle.load(file)

    # Same five-stage pacing schedule as the "mini" variant; only the input
    # dataset and the output file differ.
    multi_step_normal_batches_train = create_curriculum(
        all_xs_ordered,
        all_ys_ordered,
        fixedexp_pacing_function,
        rng,
    )

    with open(f"saved_batches/bootstrap_fixedexp_curr_{i}.pkl", "wb") as file:
        pickle.dump(multi_step_normal_batches_train, file)
    print(f"{i} train batches done.")

Creating batches.: 100%|██████████| 2500/2500 [22:01<00:00,  1.89it/s]


0 train batches done.


In [6]:
# Single-step curriculum over the ordered "normal" training data,
# saved under the "bootstrap" prefix.
for i in range(1):
    xs_path = f"saved_datasets/ordered_normal_training_data_xs_{i}.pkl"
    ys_path = f"saved_datasets/ordered_normal_training_data_ys_{i}.pkl"
    with open(xs_path, "rb") as file:
        all_xs_ordered = pickle.load(file)
    with open(ys_path, "rb") as file:
        all_ys_ordered = pickle.load(file)

    # The first half of the steps samples from 30% of the data, the second
    # half from all of it.
    single_step_normal_batches_train = create_curriculum(
        all_xs_ordered,
        all_ys_ordered,
        step_pacing_function,
        rng,
    )

    with open(f"saved_batches/bootstrap_step_curr_{i}.pkl", "wb") as file:
        pickle.dump(single_step_normal_batches_train, file)
    print(f"{i} train batches done.")

Creating batches.: 100%|██████████| 2500/2500 [34:29<00:00,  1.21it/s]


0 train batches done.
