In [None]:
from collections.abc import Iterator

import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from jaxtyping import Array, Float, PRNGKeyArray

from differt2d.geometry import Point, normalize, path_length
from differt2d.scene import Scene

In [None]:
# Draw the default square scene with its obstacle, then overlay every valid
# path returned by `all_valid_paths(approx=False)`.
scene = Scene.square_scene_with_obstacle()
ax = plt.gca()
scene.plot(ax)

valid_paths = [path for _, _, path, _ in scene.all_valid_paths(approx=False)]
for path in valid_paths:
    path.plot(ax)

plt.show()

In [None]:
# Root PRNG key of the notebook (fixed seed for reproducibility), split once:
# `key` is kept for later splits, `key_example` is a separate key derived
# from the same seed.
key, key_example = jax.random.split(jax.random.PRNGKey(12345), 2)


def random_scenes(
    key: PRNGKeyArray,
) -> Iterator[tuple[Scene, Float[Array, "2+num_walls*2 2"]]]:
    """Yield an endless stream of randomly generated scenes.

    Each iteration draws a random ``ratio`` for
    ``Scene.square_scene_with_obstacle``, random transmitter and receiver
    positions, and keeps a random subset (possibly empty, possibly all) of
    the scene's objects.

    Args:
        key: PRNG key; it is re-split on every iteration.

    Yields:
        A ``(scene, points)`` pair, where ``scene`` holds the sampled
        transmitter, receiver and objects, and ``points`` stacks the
        transmitter position, the receiver position, then the ``points`` of
        each retained object, in that order.
    """
    while True:
        key, ratio_key, endpoints_key, count_key, pick_key = jax.random.split(key, 5)

        # Row 0 is the transmitter, row 1 the receiver.
        tx_rx = jax.random.uniform(endpoints_key, (2, 2))

        base_scene = Scene.square_scene_with_obstacle(
            ratio=jax.random.uniform(ratio_key)
        )
        num_objects = len(base_scene.objects)

        # `maxval` is exclusive, hence the `+ 1` to allow keeping every object.
        num_walls = jax.random.randint(
            count_key, (), minval=0, maxval=num_objects + 1
        )
        wall_indices = jax.random.choice(
            pick_key,
            jnp.arange(num_objects, dtype=jnp.uint32),
            shape=(num_walls,),
            replace=False,
        )
        objects = [base_scene.objects[wall_index] for wall_index in wall_indices]

        points = jnp.vstack([tx_rx, *(obj.points for obj in objects)])

        scene = (
            base_scene.with_transmitters(tx=Point(point=tx_rx[0, :]))
            .with_receivers(rx=Point(point=tx_rx[1, :]))
            .with_objects(*objects)
        )

        yield scene, points


def samples(
    key: PRNGKeyArray, order: int = 1
) -> Iterator[
    tuple[Float[Array, "2+num_walls*2 2"], Float[Array, "num_paths {order}+2 2"]]
]:
    """Yield an endless stream of ``(input, target)`` training pairs.

    Args:
        key: PRNG key forwarded to ``random_scenes``.
        order: Path order requested from ``all_valid_paths`` (used as both
            ``min_order`` and ``max_order``).

    Yields:
        ``(points, y)`` where ``points`` is the array produced by
        ``random_scenes`` and ``y`` stacks the points of every valid path of
        the requested order. When the scene has no such path, ``y`` is an
        empty array of shape ``(0, order + 2, 2)``.
    """
    for scene, points in random_scenes(key):
        valid = scene.all_valid_paths(min_order=order, max_order=order, approx=False)
        paths = [path.points for _, _, path, _ in valid]

        # `jnp.stack` rejects an empty list, so the empty case is built by hand.
        y = jnp.stack(paths) if paths else jnp.zeros((0, order + 2, 2))

        yield points, y


# Example scene generator, seeded with its own key (`key_example`).
scenes = random_scenes(key_example)

In [None]:
# Draw the next random scene from `scenes` (the accompanying points array is
# not needed here) and overlay its valid paths.
scene, _ = next(scenes)
ax = plt.gca()
scene.plot(ax)

for path in [p for _, _, p, _ in scene.all_valid_paths(approx=False)]:
    path.plot(ax)

plt.show()

In [None]:
# Derive separate keys for model initialisation and for the training and
# test sample streams.
key_model, key_train, key_test = jax.random.split(key, 3)
# Both streams rely on the default `order=1` of `samples`.
train_samples = samples(key_train)
test_samples = samples(key_test)

In [None]:
# Display one `(points, paths)` training sample; note that this advances the
# `train_samples` generator by one item.
next(train_samples)

In [None]:
class WallsEmbed(eqx.Module):
    """A DeepSets model that extracts information about walls.

    Every wall is flattened and encoded independently by ``mlp_phi``; the
    encodings are summed over the wall axis and the sum is mapped to the final
    embedding by ``mlp_rho``.
    """

    mlp_phi: eqx.nn.Sequential
    mlp_rho: eqx.nn.MLP

    def __init__(
        self,
        in_size: int = 4,
        intermediate_size: int = 500,
        out_size: int = 100,
        width_size: int = 500,
        depth: int = 3,
        *,
        key: PRNGKeyArray,
    ):
        """Build both MLPs.

        Args:
            in_size: Number of scalars describing one wall once flattened.
            intermediate_size: Size of the per-wall encoding.
            out_size: Size of the final embedding.
            width_size: Hidden width shared by both MLPs.
            depth: Depth shared by both MLPs.
            key: PRNG key, split between the two MLPs.
        """
        phi_key, rho_key = jax.random.split(key, 2)

        per_wall_mlp = eqx.nn.MLP(
            in_size=in_size,
            out_size=intermediate_size,
            width_size=width_size,
            depth=depth,
            key=phi_key,
        )
        # Walls are flattened before entering the MLP.
        self.mlp_phi = eqx.nn.Sequential([eqx.nn.Lambda(jnp.ravel), per_wall_mlp])
        self.mlp_rho = eqx.nn.MLP(
            in_size=intermediate_size,
            out_size=out_size,
            width_size=width_size,
            depth=depth,
            key=rho_key,
        )

    def __call__(self, x):
        """Embed a set of walls.

        Args:
            x: Walls stacked along the first axis; each entry must flatten to
                ``in_size`` scalars.

        Returns:
            The output of ``mlp_rho`` applied to the summed per-wall encodings.
        """
        per_wall = jax.vmap(self.mlp_phi)(x)
        # Summation over the wall axis makes the result independent of the
        # order in which walls are listed.
        pooled = per_wall.sum(axis=0)
        return self.mlp_rho(pooled)


class GeneratePath(eqx.Module):
    """A recurrent model that returns a path of a given order.

    A GRU cell is unrolled ``order`` times. At every step the cell receives
    the same input vector, and the updated hidden state is decoded into one
    2D point by ``state2xy``.
    """

    order: int
    cell: eqx.nn.GRUCell
    state2xy: eqx.nn.MLP

    def __init__(
        self,
        order: int = 1,
        input_size: int = 100,
        hidden_size: int = 10,
        width_size: int = 100,
        depth: int = 3,
        *,
        key: PRNGKeyArray,
    ):
        """Build the GRU cell and the state decoder.

        Args:
            order: Number of intermediate points to generate.
            input_size: Size of the vector fed to the GRU cell.
            hidden_size: Size of the GRU hidden state.
            width_size: Hidden width of the decoder MLP.
            depth: Depth of the decoder MLP.
            key: PRNG key, split between the cell and the decoder.
        """
        key1, key2 = jax.random.split(key, 2)

        self.order = order
        self.cell = eqx.nn.GRUCell(
            input_size=input_size, hidden_size=hidden_size, key=key1
        )
        self.state2xy = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=2,
            width_size=width_size,
            depth=depth,
            key=key2,
        )

    def __call__(
        self,
        x: Float[Array, " input_size"],
        start: Float[Array, "2"],
        end: Float[Array, "2"],
    ) -> Float[Array, "{self.order}+2 2"]:
        """Generate one path going from ``start`` to ``end``.

        Args:
            x: Conditioning vector of size ``input_size``, fed to the GRU cell
                at every step.
            start: First point of the path.
            end: Last point of the path.

        Returns:
            The path as ``start``, followed by ``order`` generated points,
            followed by ``end``.
        """

        def scan_fn(state, _):
            # The carry is the GRU hidden state; `x` is the (constant) input.
            # Decoding happens after the update so that the very first point
            # already depends on `x`.
            state = self.cell(x, state)
            return state, self.state2xy(state)

        # The hidden state must have `hidden_size` entries (this is what both
        # the cell and `state2xy` expect), so it cannot be initialised with
        # `x`, whose size is `input_size`.
        init_state = jnp.zeros((self.cell.hidden_size,), dtype=x.dtype)

        # No per-step inputs: `xs` is None and the number of steps is given
        # by `length`.
        _, path = jax.lax.scan(scan_fn, init_state, None, length=self.order)

        return jnp.vstack((start, path, end))


class EvaluatePath(eqx.Module):
    """A model that returns a probability that a path is valid."""

    order: int
    mlp: eqx.nn.MLP

    def __init__(
        self,
        width_size: int = 500,
        depth: int = 3,
        order: int = 1,
        *,
        key: PRNGKeyArray,
    ):
        """Build the scoring MLP.

        Args:
            width_size: Hidden width of the MLP.
            depth: Depth of the MLP.
            order: Order of the paths to evaluate; a path of that order has
                ``order + 2`` points.
            key: PRNG key used to initialise the MLP.
        """
        self.order = order

        self.mlp = eqx.nn.MLP(
            # A path holds `order + 2` points (start, `order` interactions,
            # end) with 2 coordinates each, all of which are fed to the MLP.
            in_size=(order + 2) * 2,
            out_size=1,
            width_size=width_size,
            depth=depth,
            final_activation=jax.nn.sigmoid,
            key=key,
        )

    def __call__(
        self,
        path: Float[Array, "{self.order}+2 2"],
        walls_embed: Float[Array, " embed_size"] | None = None,
    ) -> Float[Array, " "]:
        """Score a path.

        Args:
            path: Points of the path to evaluate.
            walls_embed: Embedding of the walls. It is accepted for interface
                purposes but not used by the current implementation.

        Returns:
            A scalar in ``[0, 1]`` (sigmoid output of the MLP).
        """
        x = self.mlp(jnp.ravel(path))

        # The MLP outputs a vector of size 1; drop that axis so the result is
        # the scalar announced by the return annotation.
        return jnp.squeeze(x, axis=-1)


class Model(eqx.Module):
    """Global Deep-Learning model.

    The walls are first summarised by ``walls_embed``. Candidate paths are
    then produced by ``gen_path`` and scored by ``val_path``; generation stops
    at the first candidate whose score falls below ``threshold``.
    """

    order: int
    threshold: float
    walls_embed: WallsEmbed
    gen_path: GeneratePath
    val_path: EvaluatePath

    def __init__(
        self, order: int = 1, *, threshold: float = 0.5, key: PRNGKeyArray
    ):
        """Build the three sub-models.

        Args:
            order: Order of the generated paths.
            threshold: Minimum score returned by ``val_path`` for a candidate
                path to be kept.
            key: PRNG key, split between the three sub-models.
        """
        key1, key2, key3 = jax.random.split(key, 3)

        self.order = order
        self.threshold = threshold
        self.walls_embed = WallsEmbed(key=key1)
        self.gen_path = GeneratePath(order=order, key=key2)
        self.val_path = EvaluatePath(order=order, key=key3)

    def __call__(
        self, x: Float[Array, "2+num_walls*2 2"]
    ) -> Float[Array, "num_paths {self.order}+2 2"]:
        """Predict the paths of a scene.

        Args:
            x: Stacked points: transmitter, receiver, then two end points per
                wall.

        Returns:
            The stacked accepted paths, or an empty array of shape
            ``(0, order + 2, 2)`` when no candidate is accepted.
        """
        # Processing input
        tx = x[0, :]
        rx = x[1, :]
        walls = x[2:, :].reshape(-1, 2, 2)  # [num_walls, 2, 2]
        num_walls = walls.shape[0]
        starts = walls[:, 0, :]
        ends = walls[:, 1, :]

        # todo: pass those as parameters to force using specular reflection
        # (the normals are computed but not used yet)
        directions, _ = jax.vmap(normalize)(ends - starts)
        normals = jnp.stack((directions[:, 1], -directions[:, 0]), axis=-1)  # noqa: F841

        embedding = self.walls_embed(walls)

        # Generate paths

        paths = []

        # question: how to properly handle a variable-sized output?
        # NOTE: the early exit below converts an array to a Python bool, so
        # this method can only run eagerly, not under `jax.jit`.
        #
        # The number of candidates is bounded by the number of wall sequences
        # of length `order`, which guarantees termination even if every
        # candidate is accepted.
        max_num_paths = num_walls**self.order

        for _ in range(max_num_paths):
            path = self.gen_path(embedding, tx, rx)

            # question: how to combine both the path and the walls_embed
            # i.e., the knowledge we have about the geometry
            # do we just 'concat' the inputs altogether?
            is_valid = self.val_path(path, embedding)

            if jnp.squeeze(is_valid) < self.threshold:
                break

            paths.append(path)

            # todo: update the generator state between candidates, otherwise
            # every candidate is identical
            # state = update_state(self)

        if len(paths) > 0:
            y = jnp.stack(paths)
        else:
            y = jnp.zeros((0, self.order + 2, 2))
        return y


def loss(
    model: Model,
    x: Float[Array, "2+num_walls*2 2"],
    y_true: Float[Array, "num_paths order+2 2"],
) -> Float[Array, " "]:
    """Compare the paths predicted by ``model`` with the expected ones.

    Both sets of paths are sorted by ascending length and matched pairwise in
    that order. Each matched pair contributes its squared error weighted by
    the inverse length of the true path (the longer the path, the less
    important). Every unmatched path, whether missing from or extra in the
    prediction, adds a penalty equal to its inverse length.

    Args:
        model: The model to evaluate.
        x: Model input.
        y_true: Expected paths.

    Returns:
        The scalar loss value.
    """
    y_pred = model(x)

    def sort_by_length(paths):
        # Returns the paths sorted by ascending length, and their lengths.
        lengths = jax.vmap(path_length)(paths)
        ranking = jnp.argsort(lengths)
        return paths[ranking, ...], lengths[ranking]

    y_true, lengths_true = sort_by_length(y_true)
    y_pred, lengths_pred = sort_by_length(y_pred)

    # Only the first `num_matched` paths of each set can be compared; at most
    # one of the two sets has leftover paths, which are penalised.
    num_matched = min(y_true.shape[0], y_pred.shape[0])

    penalty = jnp.sum(1 / lengths_true[num_matched:]) + jnp.sum(
        1 / lengths_pred[num_matched:]
    )

    # the longer the path, the less important
    weights = 1 / lengths_true[:num_matched]

    # Square before summing, so that errors of opposite signs cannot cancel.
    errors = y_true[:num_matched, ...] - y_pred[:num_matched, ...]
    differences = (errors**2).sum(axis=(1, 2))

    return (weights * differences).sum() + penalty

In [None]:
# Instantiate the model with its default hyper-parameters, and an Adam
# optimiser with a fixed learning rate.
model = Model(key=key_model)
optim = optax.adam(learning_rate=1e-3)

In [None]:
# Evaluate the loss of the untrained model on one training sample (this
# advances the `train_samples` generator by one item).
loss(model, *next(train_samples))

In [None]:
# Pretty-print the structure of the model.
eqx.tree_pprint(model)

In [None]:
def train(
    model: Model,
    train_samples: Iterator,
    test_samples: Iterator,
    optim: optax.GradientTransformation,
    steps: int = 10_000,
    print_every: int = 1_000,
    num_test_samples: int = 100,
):
    """Train ``model`` and periodically report the train and test losses.

    Args:
        model: The model to train.
        train_samples: Iterator of ``(x, y)`` training pairs; one pair is
            consumed per step.
        test_samples: Iterator of ``(x, y)`` test pairs; ``num_test_samples``
            pairs are consumed at every evaluation.
        optim: The Optax optimiser.
        steps: Number of training steps.
        print_every: Number of steps between two evaluations on test samples
            (the last step is always evaluated).
        num_test_samples: Number of test pairs averaged per evaluation.

    Returns:
        The trained model.
    """
    # `trange` is not imported by the notebook's import cell, so it is
    # brought into scope here.
    from tqdm.auto import trange

    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    # NOTE(review): `Model.__call__` branches on array values in Python,
    # which cannot be traced by `filter_jit` -- confirm this step can be
    # compiled as is.
    @eqx.filter_jit
    def make_step(
        model,
        opt_state,
        x,
        y,
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        # Only array leaves are parameters; non-array leaves (e.g. integers,
        # activation functions) must not reach the optimiser.
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    with trange(steps, desc="", unit=" steps", leave=False) as bar:
        # `bar` comes first in `zip` so that no extra training sample is
        # consumed once the last step is reached.
        for step, (x_train, y_train) in zip(bar, train_samples, strict=False):
            model, opt_state, train_loss = make_step(model, opt_state, x_train, y_train)
            if (step % print_every) == 0 or (step == steps - 1):
                test_loss = 0
                for _, (x_test, y_test) in zip(
                    range(num_test_samples), test_samples, strict=False
                ):
                    test_loss += loss(model, x_test, y_test)
                test_loss /= num_test_samples

                bar.set_description(
                    f"train_loss = {float(train_loss):.1f}, test_loss = {float(test_loss):.1f}"
                )

    return model