In [None]:
import numpy as np

# Re-create the `np.int` / `np.float` / `np.bool` aliases that recent NumPy releases
# removed, presumably so older dependencies (e.g. design_bench, imported in the next
# cell) keep working. Only names that are actually missing are filled in: NumPy 2
# reintroduced `np.bool` as the real boolean scalar type (`np.bool_`), and overwriting
# it with the Python builtin would silently change its meaning for every other library.
for _alias, _builtin in (("int", int), ("float", float), ("bool", bool)):
    if not hasattr(np, _alias):
        setattr(np, _alias, _builtin)
del _alias, _builtin

In [None]:
# Imports
import functools as ft
from pathlib import Path
from typing import List

import design_bench
import diffrax as dfx
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.random as random
import matplotlib.pyplot as plt
from jaxtyping import Array, Float, Integer, PRNGKeyArray, PyTree
from sklearn.model_selection import train_test_split
from diffusionlib.sampler import DDIMVP
from diffusionlib.optimizer import SMCDiffOptOptimizer

import optax  # manually fix some modules that use tuple type-hints

In [None]:
# Constants and Config
T = jnp.array(1000)  # number of discrete diffusion steps
BETA_MIN = jnp.array(0.1) / T  # per-step beta at t = 0 (see beta_t)
BETA_MAX = jnp.array(20) / T  # per-step beta at t = T (see beta_t)

# Matches the `lr` the training cell passes to the optimiser.
LEARNING_RATE = 3e-4

key = random.PRNGKey(100)
task = design_bench.make("Superconductor-RandomForest-v0")

In [None]:
# Train-Test split
# NOTE: y is not used for training; it is split alongside x only so every design keeps its label.
train_x, val_x, train_y, val_y = train_test_split(
    task.x,
    task.y,
    test_size=0.1,
    random_state=0,
)

In [None]:
# Define the diffusion
def beta_t(
    t: Float[Array, " batch"],
    beta_min: Float[Array, ""] = BETA_MIN,
    beta_max: Float[Array, ""] = BETA_MAX,
) -> Float[Array, " batch"]:
    """Linear noise schedule.

    Interpolates linearly from ``beta_min`` at ``t = 0`` to ``beta_max`` at ``t = T``.
    """
    slope = beta_max - beta_min
    return beta_min + t * slope / T


def alpha_t(t: Float[Array, " batch"]) -> Float[Array, " batch"]:
    """Per-step retention factor ``1 - beta_t(t)`` under the default linear schedule."""
    beta = beta_t(t)
    return 1 - beta


# Per-step alphas for t = 0, ..., T (inclusive) and their running product (alpha-bar),
# which c_t / d_t index by integer step.
alpha = alpha_t(jnp.arange(0, T + 1))
cumulative_alpha_values = jnp.cumprod(alpha, axis=0)

def c_t(t: Integer[Array, " batch"]) -> Float[Array, " batch"]:
    """Signal scale ``sqrt(alpha_bar_t)`` of the forward marginal at integer step(s) ``t``."""
    alpha_bar = cumulative_alpha_values[t]
    return jnp.sqrt(alpha_bar)


def d_t(t: Integer[Array, " batch"]) -> Float[Array, " batch"]:
    """Noise standard deviation ``sqrt(1 - alpha_bar_t)`` of the forward marginal at step(s) ``t``."""
    one_minus_alpha_bar = 1 - cumulative_alpha_values[t]
    return jnp.sqrt(one_minus_alpha_bar)

def forward_marginal(
    key: PRNGKeyArray, x_0: Float[Array, "batch dim"], t: Integer[Array, " batch"]
) -> Float[Array, "batch dim"]:
    """Sample ``x_t`` from the forward (noising) marginal ``q(x_t | x_0)``.

    ``x_t = c_t(t) * x_0 + d_t(t) * eps`` with ``eps ~ N(0, I)``. Because
    ``c_t(t)**2 + d_t(t)**2 == 1`` the process is variance preserving.

    Args:
        key: PRNG key. ``eps`` is drawn as ``random.normal(key, x_0.shape)``, so a caller
            can reproduce the exact noise by drawing with the same key and shape.
        x_0: clean data.
        t: integer diffusion step(s); must broadcast against ``x_0`` (e.g. shape
            ``(batch, 1)`` for a batch, or a scalar for a single example).

    Returns:
        The noised sample ``x_t``, same shape as ``x_0``.
    """
    noise = random.normal(key, x_0.shape)
    # d_t already returns sqrt(1 - alpha_bar_t), i.e. the noise *standard deviation*;
    # squaring it again would shrink the noise and break the VP / epsilon-model contract.
    return c_t(t) * x_0 + d_t(t) * noise

In [None]:
# Define network, loss, and training loop


class FullyConnectedWithTime(eqx.Module):
    """MLP epsilon-network conditioned on a scalar time.

    The time ``t`` is expanded into four Fourier-style features (``t - 0.5``,
    ``cos 2πt``, ``sin 2πt``, ``-cos 4πt``) that are appended to the input vector. The
    result passes through three ReLU-activated linear layers of width 256 and a final
    linear layer back to the input dimension. The module handles a single (unbatched)
    example; callers ``jax.vmap`` it over a batch.
    """

    layers: List[eqx.nn.Linear]

    def __init__(self, in_size: int, key: PRNGKeyArray):
        # One PRNG key per layer; the first layer also sees the 4 time features.
        layer_keys = jax.random.split(key, 4)
        widths = [in_size + 4, 256, 256, 256, in_size]
        self.layers = [
            eqx.nn.Linear(fan_in, fan_out, key=layer_key)
            for fan_in, fan_out, layer_key in zip(widths[:-1], widths[1:], layer_keys)
        ]

    def __call__(self, x: Array, t: Array) -> Array:
        # Fourier-style embedding of the scalar time (presumably rescaled to [0, 1] by the caller).
        time_features = jnp.array(
            [t - 0.5, jnp.cos(2 * jnp.pi * t), jnp.sin(2 * jnp.pi * t), -jnp.cos(4 * jnp.pi * t)],
        )
        h = jnp.concatenate([x, time_features])

        *hidden_layers, readout = self.layers
        for layer in hidden_layers:
            h = jax.nn.relu(layer(h))

        return readout(h)


@jax.jit
@jax.value_and_grad
def loss(model: FullyConnectedWithTime, data: Array, key: PRNGKeyArray) -> Array:
    """Batched epsilon-prediction MSE; the decorators make it return ``(value, grads)``.

    Each example gets a uniformly drawn integer step in ``[0, T)``; the network sees
    that step rescaled by ``1 / (T - 1)``.
    """
    time_key, noise_key = random.split(key, 2)

    batch_size = data.shape[0]
    steps = random.randint(time_key, (batch_size,), minval=0, maxval=T)

    # Same key and shape as the draw inside forward_marginal, so this is exactly the added noise.
    target_noise = random.normal(noise_key, data.shape)
    noisy_data = forward_marginal(noise_key, data, steps[:, None])

    predicted_noise = jax.vmap(model)(noisy_data, steps / (T - 1))

    return jnp.mean((target_noise - predicted_noise) ** 2)


def single_loss_fn(model, data, t, key):
    """Epsilon-prediction MSE for one example noised to integer step ``t``.

    The target noise is drawn with the same key and shape that ``forward_marginal``
    uses, so it is exactly the noise that was added to ``data``.
    """
    noisy_example = forward_marginal(key, data, t)
    target_noise = random.normal(key, data.shape)

    predicted_noise = model(noisy_example, t / (T - 1))
    return jnp.mean((target_noise - predicted_noise) ** 2)


def batch_loss_fn(model, data, key):
    """Mean epsilon-prediction loss over a batch.

    Each example gets its own diffusion step ``t`` and its own noise key. Steps use
    stratified (low-discrepancy) sampling: ``[0, T)`` is cut into ``batch_size`` equal
    strata and one step is drawn uniformly from each. Every batch therefore spans the
    whole noise schedule, which lowers the variance of the loss estimate compared with
    i.i.d. uniform steps.

    Args:
        model: epsilon network, called per example as ``model(x, t / (T - 1))``.
        data: clean examples; the leading axis is the batch axis.
        key: PRNG key for the step draws and the per-example noise.

    Returns:
        Scalar mean loss.
    """
    batch_size = data.shape[0]
    t_key, loss_key = jr.split(key)
    loss_key = jr.split(loss_key, batch_size)

    # Low-discrepancy sampling over t to reduce variance: stratum i covers
    # [i * T / batch_size, (i + 1) * T / batch_size). The clamp guards against float
    # rounding landing exactly on T for the last stratum.
    stratum_width = T / batch_size
    offsets = jr.uniform(t_key, (batch_size,), minval=0, maxval=stratum_width)
    t = jnp.floor(offsets + stratum_width * jnp.arange(batch_size)).astype(jnp.int32)
    t = jnp.minimum(t, T - 1)

    loss_fn = jax.vmap(ft.partial(single_loss_fn, model))

    return jnp.mean(loss_fn(data, t, loss_key))


def dataloader(data, batch_size, *, key):
    """Yield shuffled mini-batches of ``data`` forever.

    Each epoch draws a fresh permutation of the rows and yields every complete batch of
    ``batch_size`` rows. A trailing partial batch (fewer than ``batch_size`` rows) is
    dropped, so all yielded batches have the same shape.

    Args:
        data: array whose leading axis indexes examples.
        batch_size: rows per batch; must satisfy ``0 < batch_size <= len(data)``.
        key: PRNG key driving the per-epoch shuffles.

    Raises:
        ValueError: (on first ``next``) if ``batch_size`` is out of range. Previously
            such a value made the generator spin forever without yielding.
    """
    dataset_size = data.shape[0]
    if not 0 < batch_size <= dataset_size:
        raise ValueError(
            f"batch_size must be in [1, {dataset_size}] for a dataset of {dataset_size} rows; "
            f"got {batch_size}"
        )

    indices = jnp.arange(dataset_size)
    while True:
        key, subkey = jr.split(key, 2)
        perm = jr.permutation(subkey, indices)
        # `dataset_size - batch_size + 1` keeps the last *full* batch (the old
        # `end < dataset_size` test skipped it when the size divided evenly).
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            yield data[perm[start : start + batch_size]]


@eqx.filter_jit
def make_step(model, data, key, opt_state, opt_update):
    """Run one optimisation step on a batch.

    Computes ``batch_loss_fn`` and its gradients with respect to the model, applies the
    optimiser update, and derives a fresh key for the next step.

    Returns:
        ``(loss, updated_model, next_key, updated_opt_state)``.
    """
    loss_value, grads = eqx.filter_value_and_grad(batch_loss_fn)(model, data, key)

    updates, new_opt_state = opt_update(grads, opt_state)
    new_model = eqx.apply_updates(model, updates)

    (next_key,) = jr.split(key, 1)
    return loss_value, new_model, next_key, new_opt_state

In [None]:
# Optimisation hyperparameters
num_steps = 40_000
lr = 3e-4
batch_size = 256
print_every = 1_000

# Independent key streams for parameter init, training noise, batch shuffling and sampling.
model_key, train_key, loader_key, sample_key = jr.split(key, 4)

# Normalise train/validation designs with the task's own feature normaliser.
data = task.normalize_x(train_x)
val_data = task.normalize_x(val_x)

model = FullyConnectedWithTime(data.shape[1], key=model_key)

# Only the floating-point array leaves of the model are optimised.
opt = optax.adabelief(lr)
opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))

# Running sum / count of training losses since the last report.
total_value, total_size = 0, 0
for step, data_ in zip(range(num_steps), dataloader(data, batch_size, key=loader_key)):
    value, model, train_key, opt_state = make_step(model, data_, train_key, opt_state, opt.update)
    total_value, total_size = total_value + value.item(), total_size + 1

    # Report on the first step, every `print_every` steps, and on the final step.
    if step == num_steps - 1 or step % print_every == 0:
        # NOTE: this advances the notebook-global `key`, which later cells sample with.
        key, sub_key = jr.split(key)
        val_loss = batch_loss_fn(model, val_data, sub_key)

        print(
            f"Step={step:05}",
            f"Train Loss={total_value / total_size:.4f}",
            f"Val Loss={val_loss:.4f}",
            sep="\t|\t",
        )

        total_value, total_size = 0, 0

In [None]:
# DDIM sampler over the trained network; eta=1.0 makes it equivalent to DDPM sampling.
# NOTE(review): BETA_MIN / BETA_MAX are the per-step betas (already divided by T) —
# confirm that DDIMVP expects that convention rather than continuous-time betas.
sampler = DDIMVP(
    model=jax.vmap(model),  # DDIMVP assumes an epsilon model (not a score model), which this is
    shape=(100, train_x.shape[1]),  # 100 samples, each with as many features as the designs
    num_steps=T,
    beta_min=BETA_MIN,
    beta_max=BETA_MAX,
    eta=1.0,
)

In [None]:
# Unconditional samples from the diffusion model, mapped back to raw feature space.
sample_base = sampler.sample(key)
sample = task.denormalize_x(sample_base)
# Clamp denormalised features at zero (presumably they are non-negative quantities).
# jnp.maximum is equivalent to a lower-bound-only clip but avoids jnp.clip's `a_min`
# keyword, which newer JAX releases deprecate.
sample = jnp.maximum(sample, 0)

In [None]:
# SMC-based optimisation layered on the trained diffusion sampler, with gamma_t(t) = 1 - d_t(t).
optimizer = SMCDiffOptOptimizer(
    base_sampler=sampler,
    gamma_t=lambda t: 1 - d_t(t),
)
# Objective: the negated oracle prediction on denormalised designs.
# NOTE(review): the negation presumably matches the optimizer's sign convention — confirm.
particle_samples = optimizer.optimize(
    key,
    lambda x: -task.predict(task.denormalize_x(x)),
)

In [None]:
# Load every raw x / y shard of the Superconductor dataset from the local design-bench checkout.
data_path = Path().absolute() / "design-bench" / "design_bench_data" / "superconductor"

x_files = sorted(data_path.glob("*x*.npy"))
y_files = sorted(data_path.glob("*y*.npy"))
if not x_files or not y_files:
    # jnp.vstack([]) would otherwise fail with an unhelpful "need at least one array" error.
    raise FileNotFoundError(
        f"No *x*.npy / *y*.npy shards found under {data_path} "
        f"({len(x_files)} x-files, {len(y_files)} y-files)."
    )

x_data = jnp.vstack([jnp.load(file) for file in x_files])
y_data = jnp.vstack([jnp.load(file) for file in y_files])
assert x_data.shape[0] == y_data.shape[0], "x and y shards must contain the same number of rows"

In [None]:
# Largest objective value in the full on-disk dataset.
jnp.max(y_data)

In [None]:
# Mean oracle score of the optimised particles, normalised by the largest objective
# value in the full on-disk dataset (`y_data`) rather than a hard-coded constant.
jnp.mean(task.predict(task.denormalize_x(particle_samples)) / y_data.max())
