In [2]:
from datetime import datetime

from pathlib import Path
import warnings

import os
# Number of virtual host (CPU) devices requested from XLA.
# The flag is assigned ahead of the `import jax` further down.
# NOTE(review): JAX is expected to read XLA_FLAGS when its backend starts up,
# so keep this assignment before any JAX import or use.
n_host_devices = 2
os.environ['XLA_FLAGS'] = f'--xla_force_host_platform_device_count={n_host_devices}'

import re
import pickle
import pandas as pd

import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
import optax

from src.config.core import Config
from src.config.sampler import Sampler
from src.config.data import DatasetType
import src.dataset as ds
import src.training.utils as train_utils
import src.inference.utils as inf_utils
import src.visualization as viz
from src.config.data import Task
from src.inference.evaluation import evaluate_de, evaluate_bde

import os

# Root PRNG key for the notebook. The trailing tuple is the cell's displayed
# value: a two-way and a three-way split of that same key.
key = jax.random.PRNGKey(seed=0)
(jax.random.split(key, num=2), jax.random.split(key, num=3))

(Array([[4146024105,  967050713],
        [2718843009, 1272950319]], dtype=uint32),
 Array([[2467461003,  428148500],
        [3186719485, 3840466878],
        [2562233961, 1946702221]], dtype=uint32))

In [None]:
# Settings for the cyclical step-size schedule defined below in this cell.
n_cycles = 2
cycle_length = 2_500         # steps per cycle (modulus used by _explore)
n_samples_per_cycle = 500    # trailing steps of each cycle for which _explore() is False
exploration_ratio = 0.1      # NOTE(review): not referenced by the code visible in this cell
step_size_sampling = 0.01    # NOTE(review): not referenced by the code visible in this cell

def _explore(step_count: int) -> bool:
    """Tell whether ``step_count`` falls before the final ``n_samples_per_cycle`` steps of its cycle."""
    steps_into_cycle = step_count % cycle_length
    first_sampling_step = cycle_length - n_samples_per_cycle
    return steps_into_cycle < first_sampling_step

def _scheduler_fn(
    step_count: int,
    exploration_step_size: float = 0.1,
    sampling_step_size: float = 1.e-8,
) -> jax.Array:
    """Piecewise-constant step-size schedule driven by ``_explore``.

    Args:
        step_count: Step index, forwarded to ``_explore`` to pick the branch.
        exploration_step_size: Step size returned when ``_explore(step_count)``
            is true.
        sampling_step_size: Step size returned otherwise.

    Returns:
        The selected step size as a scalar JAX array (``jax.lax.cond`` converts
        the branch outputs to arrays, so this is not a Python ``float``).
    """
    # NOTE(review): this cell also defines `exploration_ratio` and
    # `step_size_sampling`, but neither is used here; confirm whether the
    # defaults above were meant to come from them.
    return jax.lax.cond(
        _explore(step_count),
        lambda: exploration_step_size,
        lambda: sampling_step_size,
    )



In [5]:
from matplotlib import pyplot as plt
import numpy as np

# Run settings for the cosine schedule in this cell. These rebind
# n_samples_per_cycle, n_cycles and cycle_length from the previous cell.
n_samples = 12_000
n_cycles = 1
n_samples_per_cycle = 500
step_size_init = 2e-6  # value the schedule takes at the first step of every cycle
cycle_length = n_samples // n_cycles  # each cycle gets an equal share of the steps

def _explore(step_count: int) -> bool:
    """Return True while ``step_count`` is in the exploration phase of its cycle.

    The exploration phase covers every position in a cycle except the last
    ``n_samples_per_cycle`` steps.
    """
    exploration_length = cycle_length - n_samples_per_cycle
    return step_count % cycle_length < exploration_length

def _scheduler_fn(step_count: int) -> jax.Array:
    """Cosine-annealed step size that restarts every ``cycle_length`` steps.

    Evaluates ``0.5 * (cos(pi * t / cycle_length) + 1) * step_size_init`` with
    ``t = step_count % cycle_length``: the value equals ``step_size_init`` at
    the first step of a cycle and decays towards zero by its end.
    """
    cycle_position = step_count % cycle_length
    phase = jnp.pi * cycle_position / cycle_length
    annealing_factor = jnp.cos(phase) + 1
    return 0.5 * annealing_factor * step_size_init

# Evaluate the schedule and the exploration flag once per step of the run.
step_sizes = np.asarray(list(map(_scheduler_fn, range(n_samples))))
explore = np.asarray(list(map(_explore, range(n_samples))))

# Optional visual check (disabled):
# sampling_points = np.ma.masked_where(explore, step_sizes)
# fig, ax = plt.subplots()
# ax.plot(step_sizes, lw=2, ls="--", color="r", label="Exploration stage")
# ax.plot(sampling_points, lw=2, ls="-", color="k", label="Sampling stage")
# plt.show()

# Step sizes at the steps where `explore` is False.
print(step_sizes[np.logical_not(explore)])

[8.55511395e-09 8.52096083e-09 8.48692672e-09 8.45289261e-09
 8.41897752e-09 8.38512193e-09 8.35132585e-09 8.31759017e-09
 8.28391311e-09 8.25035595e-09 8.21679791e-09 8.18335977e-09
 8.14998113e-09 8.11666290e-09 8.08340328e-09 8.05020317e-09
 8.01706346e-09 7.98404187e-09 7.95102117e-09 7.91811949e-09
 7.88527732e-09 7.85249465e-09 7.81983100e-09 7.78716824e-09
 7.75462361e-09 7.72207986e-09 7.68965513e-09 7.65728991e-09
 7.62498420e-09 7.59273799e-09 7.56055130e-09 7.52848361e-09
 7.49641682e-09 7.46446815e-09 7.43257988e-09 7.40075112e-09
 7.36898187e-09 7.33733163e-09 7.30568184e-09 7.27415062e-09
 7.24267979e-09 7.21126803e-09 7.17991577e-09 7.14862347e-09
 7.11739068e-09 7.08621739e-09 7.05516356e-09 7.02416880e-09
 6.99323399e-09 6.96235913e-09 6.93154334e-09 6.90078750e-09
 6.87015067e-09 6.83951384e-09 6.80899603e-09 6.77853818e-09
 6.74813982e-09 6.71780098e-09 6.68758160e-09 6.65736177e-09
 6.62726141e-09 6.59716148e-09 6.56718013e-09 6.53725873e-09
 6.50739684e-09 6.477653