In [None]:
# Standard libraries
import autoroot
import pytorch_lightning
import os
import sys
from typing import Any, Sequence, Optional, Tuple, Iterator, Dict, Callable, Union
import json
import time
from tqdm.auto import tqdm
import numpy as np
from copy import copy
from glob import glob
from collections import defaultdict

# JAX/Flax
# If you run this code on Colab, remember to install flax and optax
# !pip install --quiet --upgrade flax optax
import jax
import jax.numpy as jnp
import jax.random as jrandom
import jax_dataloader as jdl
import optax
import equinox as eqx

# PyTorch for data loading
import torch
import torch.utils.data as data

# Logging with Tensorboard or Weights and Biases
# If you run this code on Colab, remember to install pytorch_lightning
# !pip install --quiet --upgrade pytorch_lightning
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

## Imports for plotting
import matplotlib.pyplot as plt
from IPython.display import set_matplotlib_formats

set_matplotlib_formats("svg", "pdf")  # For export
from matplotlib.colors import to_rgb
import seaborn as sns
import matplotlib

# Start from a clean slate: reset_defaults() restores *all* matplotlib rcParams
# to their defaults, then set_context() applies seaborn's "talk" scaling with
# fonts shrunk to 70%.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Must come after the seaborn calls: reset_defaults() wipes every rcParam and
# set_context() sets lines.linewidth itself, so setting it earlier had no effect.
matplotlib.rcParams["lines.linewidth"] = 2.0
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Automatically re-import edited modules before each cell runs
# (mode 2 reloads all modules, not only those registered with %aimport).
%load_ext autoreload
%autoreload 2

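## Data

A toy 1-D regression problem: inputs $x$ are drawn uniformly from $[-1, 1)$ and the noise-free target is $y = \sin(10x)$.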

In [None]:
def target_function(x):
    """Noise-free regression target ``f(x) = sin(10 * x)``.

    Applied element-wise via ``np.sin``, so it accepts Python scalars as well
    as NumPy arrays.

    Args:
        x: Input value(s).

    Returns:
        The sine of ``x`` scaled by a frequency of 10.
    """
    frequency = 10.0
    scaled = x * frequency
    return np.sin(scaled)


class RegressionDataset(data.Dataset):
    """Map-style dataset of ``num_points`` samples of ``target_function``.

    Inputs are drawn uniformly from [-1, 1) with a seeded NumPy generator, so
    the same ``(num_points, seed)`` pair always yields the same data.

    Args:
        num_points: Number of samples to draw.
        seed: Seed passed to ``np.random.default_rng``.
    """

    def __init__(self, num_points, seed):
        super().__init__()
        rng = np.random.default_rng(seed)
        self.x = rng.uniform(low=-1.0, high=1.0, size=num_points)
        self.y = target_function(self.x)

    def __len__(self):
        """Number of samples in the dataset."""
        return self.x.shape[0]

    def __getitem__(self, idx):
        """Return sample ``idx`` as a pair of length-1 arrays ``(x, y)``.

        Slicing (instead of scalar indexing) keeps a length-1 axis on each
        sample, presumably so batches stack to shape ``(batch, 1)``.

        Args:
            idx: Integer index; negative values count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        requested = idx
        if idx < 0:
            # A raw negative slice start (e.g. x[-1:0]) would be empty; wrap it.
            idx += n
        if not 0 <= idx < n:
            # Slicing never raises on its own, so without this check an
            # out-of-range index returns empty arrays and sequence-protocol
            # iteration (for/list over the dataset) never terminates.
            raise IndexError(
                f"index {requested} is out of range for dataset of size {n}"
            )
        return self.x[idx : idx + 1], self.y[idx : idx + 1]

In [None]:
# Dataset size and the seed shared by data sampling and loader shuffling.
num_points = 1_000
seed = 42
# Mini-batch size for the data loader.
batch_size = 2

# Draw inputs uniformly from [-1, 1) and evaluate the noise-free target there.
rng = np.random.default_rng(seed)
X = rng.uniform(low=-1.0, high=1.0, size=num_points)
y = target_function(X)

# Seed jax_dataloader's global RNG so that shuffle=True yields the same batch
# order on every Restart & Run All (the NumPy seed above does not cover it).
jdl.manual_seed(seed)
# [:, None] adds a trailing feature axis: X and y become (num_points, 1) arrays.
arr_ds = jdl.ArrayDataset(X[:, None], y[:, None])
dataloader = jdl.DataLoader(arr_ds, "jax", batch_size=batch_size, shuffle=True)