In [None]:
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax

from differt2d.geometry import Point, normalize
from differt2d.scene import Scene

In [None]:
# Fixed seed so Restart & Run All reproduces the same scenes; `key_example`
# is reserved for the example scene stream created below.
key, key_example = jax.random.split(jax.random.PRNGKey(12345), 2)


def random_scenes(key):
    """Yield an endless stream of random ``(scene, points)`` pairs.

    Each iteration splits fresh sub-keys from ``key`` and:

    * builds ``Scene.square_scene_with_obstacle`` with a ``ratio`` drawn
      uniformly in [0, 1);
    * draws the TX and RX positions uniformly in [0, 1)^2;
    * keeps a random subset of the scene's objects: first a count in
      ``[0, len(scene.objects)]`` is drawn, then that many distinct objects.

    ``points`` stacks TX, RX and then each kept object's ``points`` row-wise.
    NOTE(review): ``Model`` reshapes rows ``2:`` to ``[num_walls, 2, 2]``, which
    assumes every object's ``points`` is ``(2, 2)`` — confirm against differt2d.
    """
    while True:
        key, k_ratio, k_txrx, k_count, k_pick = jax.random.split(key, 5)

        ratio = jax.random.uniform(k_ratio)
        txrx = jax.random.uniform(k_txrx, (2, 2))
        tx, rx = Point(point=txrx[0, :]), Point(point=txrx[1, :])

        base = Scene.square_scene_with_obstacle(ratio=ratio)
        n_objects = len(base.objects)
        # `maxval` is exclusive, so the +1 allows keeping every object.
        n_kept = jax.random.randint(k_count, (), minval=0, maxval=n_objects + 1)
        picked = jax.random.choice(
            k_pick,
            jnp.arange(n_objects, dtype=jnp.uint32),
            shape=(n_kept,),
            replace=False,
        )
        kept = [base.objects[i] for i in picked]

        stacked = jnp.vstack([txrx, *(obj.points for obj in kept)])
        subscene = (
            base.with_transmitters(tx=tx).with_receivers(rx=rx).with_objects(*kept)
        )
        yield subscene, stacked


def samples(key, order=1):
    """Yield ``(points, y)`` training pairs built from ``random_scenes(key)``.

    ``points`` is passed through unchanged from ``random_scenes``. ``y`` stacks
    the ``points`` of every path returned by
    ``scene.all_valid_paths(min_order=order, max_order=order, approx=False)``.
    It has shape ``(num_paths, order + 2, 2)``, presumably TX, the ``order``
    interaction points and RX. When no path exists it is an empty
    ``(0, order + 2, 2)`` array.
    """
    path_shape = (order + 2, 2)
    for scene, points in random_scenes(key):
        valid = scene.all_valid_paths(min_order=order, max_order=order, approx=False)
        path_points = [found.points for _, _, found, _ in valid]

        if not path_points:
            yield points, jnp.zeros((0, *path_shape))
            continue

        y = jnp.stack(path_points)
        assert y.shape == (len(path_points), *path_shape)
        yield points, y


# Infinite generator of example scenes; the plotting cell below draws from it.
scenes = random_scenes(key=key_example)

In [None]:
# Plot one example scene together with every path from `all_valid_paths`.
# Explicit figure/axes instead of the pyplot state machine (`plt.gca()`).
fig, ax = plt.subplots()
scene, _ = next(scenes)
scene.plot(ax)

for _, _, path, _ in scene.all_valid_paths(approx=False):
    path.plot(ax)

plt.show()

In [None]:
# Independent keys for model initialisation and for the train / test streams.
key_model, key_train, key_test = jax.random.split(key, 3)
train_samples, test_samples = samples(key_train), samples(key_test)

In [None]:
# Peek at one (points, paths) training pair. This consumes it from the
# `train_samples` generator, so `train` will never see this pair.
next(train_samples)

In [None]:
class Phi(eqx.Module):
    """Per-wall encoder: ``Conv1d -> ReLU -> MLP``, applied in sequence.

    NOTE(review): ``Model.__call__`` vmaps this over walls of shape ``[2, 2]``,
    i.e. 2 channels of length 2. A ``kernel_size=5`` convolution with no
    padding cannot be applied to a length-2 signal, and the MLP's
    ``in_size=2`` does not obviously match the conv output. Check these sizes.
    """

    layers: list

    def __init__(self, *, key):
        conv_key, mlp_key = jax.random.split(key, 2)
        conv = eqx.nn.Conv1d(
            in_channels=2, out_channels=1, kernel_size=5, key=conv_key
        )
        mlp = eqx.nn.MLP(
            in_size=2, out_size=10, width_size=500, depth=3, key=mlp_key
        )
        self.layers = [conv, jax.nn.relu, mlp]

    def __call__(self, x):
        out = x
        for fn in self.layers:
            out = fn(out)
        return out


class Rho(eqx.Module):
    """MLP head that maps a flat 6480-vector to a scalar.

    The layers are ``Linear(6480, 500) -> ReLU -> Linear(500, 50) -> ReLU ->
    Linear(50, 10) -> ReLU``. The 10 outputs ``h`` are then reduced to
    ``sum(i * h[i] for i in range(10))``.
    NOTE(review): 6480 is a hard-coded input size whose origin is not visible here.
    """

    layers: list

    def __init__(self, *, key):
        k_in, k_mid, k_out = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(6480, 500, key=k_in),
            jax.nn.relu,
            eqx.nn.Linear(500, 50, key=k_mid),
            jax.nn.relu,
            eqx.nn.Linear(50, 10, key=k_out),
            jax.nn.relu,
        ]

    def __call__(self, x):
        h = x
        for fn in self.layers:
            h = fn(h)

        # Each of the 10 outputs is assigned a digit (its index) as weight.
        digit_weights = jnp.arange(10)
        return jnp.sum(h * digit_weights)


class Tau(eqx.Module):
    """MLP head that maps a flat 6480-vector to a scalar (same architecture as ``Rho``).

    The layers are ``Linear(6480, 500) -> ReLU -> Linear(500, 50) -> ReLU ->
    Linear(50, 10) -> ReLU``, followed by the digit-weighted sum
    ``sum(i * h[i])``.
    NOTE(review): this is currently a verbatim copy of ``Rho``, and
    ``Model.__call__`` never uses ``self.tau``. Consider sharing one definition
    once the intended role is settled.
    """

    layers: list

    def __init__(self, *, key):
        first_key, second_key, third_key = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(6480, 500, key=first_key),
            jax.nn.relu,
            eqx.nn.Linear(500, 50, key=second_key),
            jax.nn.relu,
            eqx.nn.Linear(50, 10, key=third_key),
            jax.nn.relu,
        ]

    def __call__(self, x):
        activations = x
        for transform in self.layers:
            activations = transform(activations)

        # Each of the 10 outputs is assigned a digit (its index) as weight.
        weighted = activations * jnp.arange(10)
        return jnp.sum(weighted)


class Validity(eqx.Module):
    """CNN stack, presumably meant to score whether a candidate path is valid (WIP).

    Layers: ``Conv2d(1 -> 10, k=5) -> MaxPool2d(2) -> ReLU -> Conv2d(10 -> 20, k=5)
    -> MaxPool2d(2) -> ReLU``. Not referenced by ``Model`` in the visible source.
    """

    layers: list

    def __init__(self, *, key):
        key1, key2 = jax.random.split(key, 2)
        self.layers = [
            eqx.nn.Conv2d(1, 10, kernel_size=5, key=key1),
            # `MaxPool2d` takes `kernel_size`; the previous `input_size=2`
            # is not a valid argument and made construction fail.
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
            eqx.nn.Conv2d(10, 20, kernel_size=5, key=key2),
            eqx.nn.MaxPool2d(kernel_size=2),
            jax.nn.relu,
        ]

    def __call__(self, path, walls_embedding):
        # TODO(review): `x` is never bound from `path` / `walls_embedding`, so
        # this raises UnboundLocalError on the first layer. How the two inputs
        # should be combined into the single-channel image the first Conv2d
        # expects is not defined yet.
        for layer in self.layers:
            x = layer(x)

        return x


class Model(eqx.Module):
    """Work-in-progress path-prediction model built from ``Phi``, ``Rho`` and ``Tau``.

    NOTE(review): ``__call__`` cannot run as written. It references ``order``,
    ``new_path`` and ``validity``, none of which is defined in the visible
    source. Its ``while True`` loop has no ``break``, so everything after the
    loop is unreachable.
    """

    phi: Phi  # per-wall encoder, vmapped over walls in __call__
    rho: Rho  # only used in the unreachable tail of __call__
    tau: Tau  # constructed but never used by __call__

    def __init__(self, *, key):
        # One independent sub-key per sub-network.
        key1, key2, key3 = jax.random.split(key, 3)

        self.phi = Phi(key=key1)
        self.rho = Rho(key=key2)
        self.tau = Tau(key=key3)

    def __call__(self, x):
        """Predict paths from a stacked ``points`` array.

        Args:
            x: array with the layout produced by ``random_scenes`` / ``samples``.
                Row 0 is TX, row 1 is RX, and the remaining rows are wall
                endpoints, reshaped below to ``[num_walls, 2, 2]``.

        Returns:
            ``y``, intended to be ``[num_paths, order + 2, 2]`` (mirroring
            ``samples``). This is currently unreachable; see the class docstring.
        """
        tx = x[0, :]  # currently unused
        rx = x[1, :]  # currently unused
        walls = x[2:, :].reshape(-1, 2, 2)  # [num_walls, 2, 2]
        starts = walls[:, 0, :]
        ends = walls[:, 1, :]
        # `normalize` presumably returns (unit_vector, length); the length is
        # discarded — TODO confirm against differt2d.geometry.
        directions, _ = jax.vmap(normalize)(ends - starts)
        # Rotate each direction (dx, dy) to (dy, -dx) to get the wall normal.
        normals = directions.at[:, 0].set(directions[:, 1])
        normals = normals.at[:, 1].set(-directions[:, 0])  # currently unused

        x = walls
        print(x.shape)  # debug shape print

        x = jax.vmap(self.phi)(x)

        print(x.shape)  # debug shape print

        # The empty list is immediately overwritten. `order` is not defined in
        # the visible source, so this line raises NameError.
        paths = []
        paths = jnp.zeros((0, order + 2, 2))

        # NOTE(review): infinite loop (no `break`); `new_path` and `validity`
        # are undefined here, and `p` is never used.
        while True:
            path = new_path(paths, ...)
            p = validity(path)

        # Unreachable: everything below runs only once the loop above can exit.
        x = jnp.expand_dims(x, axis=1)
        x = jax.vmap(self.phi)(x)
        # We sum over `num_images`
        x = jnp.sum(x, axis=0)
        x = jnp.ravel(x)
        x = self.rho(x)  # result is not used in the returned `y`

        # Same stacking/empty-case logic as in `samples`.
        if len(paths) > 0:
            y = jnp.stack(paths)
            assert y.shape == (len(paths), order + 2, 2)
        else:
            y = jnp.zeros((0, order + 2, 2))

        return y


def loss(
    model,
    x,
    y_true,
):
    """Sum of squared errors between ``model(x)`` and ``y_true``.

    Args:
        model: callable mapping ``x`` to a prediction that is broadcast-compatible
            with ``y_true``.
        x: model input.
        y_true: target array.

    Returns:
        Scalar ``sum((y_true - model(x)) ** 2)``.
    """
    y_pred = model(x)
    # Square *before* summing. The previous `(y_true - y_pred).sum() ** 2`
    # let positive and negative errors cancel, so a wrong prediction could
    # have zero loss.
    return jnp.sum((y_true - y_pred) ** 2)

In [None]:
model = Model(key=key_model)
optim = optax.adam(1e-3)

# Smoke-test a forward pass on one training input (its target is discarded).
# NOTE(review): `Model.__call__` is still work in progress and cannot complete yet.
x = next(train_samples)[0]
model(x)

In [None]:
def train(
    model,
    train_samples,
    test_samples,
    optim,
    steps=10_000,
    print_every=1_000,
    num_test_samples=100,
):
    """Optimise ``model`` on a stream of samples, periodically reporting test loss.

    Args:
        model: Equinox module to train.
        train_samples: iterable of ``(x, y)`` training pairs. At most ``steps``
            pairs are consumed.
        test_samples: iterable of ``(x, y)`` pairs. Up to ``num_test_samples``
            pairs are consumed at each evaluation.
        optim: Optax gradient transformation.
        steps: maximum number of optimisation steps.
        print_every: evaluate the test loss every this many steps, and at the
            last step.
        num_test_samples: number of test pairs averaged per evaluation.

    Returns:
        The trained model.
    """
    # NOTE(review): `trange` is not imported in the visible source. It is
    # presumably `tqdm.auto.trange`; make sure it is imported before calling.
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    @eqx.filter_jit
    def make_step(
        model,
        opt_state,
        x,
        y,
    ):
        loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
        # Pass only the array leaves as `params`, matching what `optim.init`
        # saw; the model also holds non-array leaves such as `jax.nn.relu`.
        updates, opt_state = optim.update(
            grads, opt_state, eqx.filter(model, eqx.is_array)
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss_value

    # Compile the evaluation loss rather than running it eagerly every call.
    eval_loss = eqx.filter_jit(loss)

    with trange(steps, desc="", unit=" steps", leave=False) as bar:
        for step, (x_train, y_train) in zip(bar, train_samples, strict=False):
            model, opt_state, train_loss = make_step(model, opt_state, x_train, y_train)
            if (step % print_every) == 0 or (step == steps - 1):
                # `range` comes first in `zip`, so no extra test sample is
                # pulled once `num_test_samples` have been drawn.
                test_losses = [
                    eval_loss(model, x_test, y_test)
                    for _, (x_test, y_test) in zip(
                        range(num_test_samples), test_samples, strict=False
                    )
                ]
                # Average over the samples actually drawn (the stream may run dry).
                test_loss = sum(test_losses) / max(len(test_losses), 1)

                bar.set_description(
                    f"train_loss = {float(train_loss):.1f}, test_loss = {float(test_loss):.1f}"
                )

    return model