In [None]:
import imageio
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tqdm.notebook as tqdm
import xarray as xr

# import flax

# import equinox as eq

In [None]:
# Load the image, keep the first three (RGB) channels, divide by 255
# (assumes 8-bit channels) and take a square 2r x 2r crop from the center.
# image_url = "https://github.com/AntonBaumannDE/fourier_features_MLP_tf2/raw/main/images/monkey.jpg"
image_url = "https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg"
img = imageio.imread(image_url)[..., :3] / 255.0
c = [img.shape[0] // 2, img.shape[1] // 2]  # center pixel (row, col)
r = 256  # half-width of the crop
img = img[c[0] - r : c[0] + r, c[1] - r : c[1] + r]
# If the image is smaller than 2r on either side, the slice above silently
# truncates (or wraps via a negative start index). The coordinate grid below
# uses img.shape[0] for both axes, so the crop must really be square.
assert img.shape[:2] == (2 * r, 2 * r), f"crop is {img.shape[:2]}, expected {(2 * r, 2 * r)}"

In [None]:
# Display the cropped target image the network is trained to reproduce.
fig, ax = plt.subplots()
ax.imshow(img)
fig.suptitle("Ground Truth", fontsize=14, fontweight="bold")
plt.show()

In [None]:
# Build an (N, N, 2) grid of input coordinates in [0, 1), N = img.shape[0],
# one point per pixel. A single 1-D axis is reused for rows and columns,
# which assumes the crop is square.
coords = np.linspace(0, 1, img.shape[0], endpoint=False)
grid_x, grid_y = np.meshgrid(coords, coords)
x_test = np.stack([grid_x, grid_y], axis=-1)

# Test on every pixel; train on every other pixel along each axis.
test_data = [x_test, img]
train_data = [arr[::2, ::2] for arr in test_data]

In [None]:
# Shapes of (train inputs, train targets, test inputs, test targets).
train_data[0].shape, train_data[1].shape, test_data[0].shape, test_data[1].shape

In [None]:
from flax import linen as nn


class MLP(nn.Module):
    """Coordinate MLP: three 256-unit ReLU hidden layers, then a 3-unit sigmoid output."""

    @nn.compact
    def __call__(self, x):
        # Dense layers are created in a fixed order; with @nn.compact, Flax
        # derives the parameter names from that creation order.
        h = x
        for width in (256, 256, 256):
            h = nn.relu(nn.Dense(features=width)(h))
        # Sigmoid keeps each of the 3 output channels in (0, 1), the same range
        # as the image after dividing by 255.
        return nn.sigmoid(nn.Dense(features=3)(h))

In [None]:
# Sample input used for shape inference: the training coordinate grid.
batch = train_data[0]
model = MLP()

# Initialise parameters with a fixed PRNG seed, then run one forward pass.
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)

In [None]:
# Fourier feature mapping
def input_mapping(x, B):
    """Random Fourier feature embedding ``[sin(2*pi*x B^T), cos(2*pi*x B^T)]``.

    Args:
      x: coordinates with the coordinate dimension last, shape ``(..., d)``.
      B: frequency matrix of shape ``(m, d)``, or ``None`` to skip the mapping.

    Returns:
      ``x`` unchanged when ``B`` is None, otherwise an array of shape ``(..., 2*m)``
      holding the m sine features followed by the m cosine features.
    """
    if B is None:
        return x
    # jax.numpy (not NumPy) so the mapping also works on traced values inside
    # jax.jit / jax.grad instead of forcing a host conversion.
    x_proj = (2.0 * jnp.pi * x) @ B.T
    return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1)

In [None]:
rand_key = jax.random.PRNGKey(123)
mapping_size = 256  # number of random frequencies; features have 2 * mapping_size channels
scale = 10  # standard deviation of the sampled frequencies

# Random frequency matrix B, shape (mapping_size, 2): each row is drawn from
# a standard normal and scaled by `scale`.
B = scale * jax.random.normal(rand_key, (mapping_size, 2))

# Map the training coordinates to [sin, cos] Fourier features.
features = input_mapping(batch, B)

# Every input point gets mapping_size sine + mapping_size cosine features.
assert features.shape == (*batch.shape[:-1], 2 * mapping_size), features.shape

In [None]:
# Shapes of the training (inputs, targets).
tuple(arr.shape for arr in train_data)

In [None]:
# Number of samples along each coordinate axis.
coords.shape

In [None]:
# Squared-error loss for the coordinate MLP.
def mse(params, x_batched, y_batched, apply_fn=None):
    """Mean over all samples of half the squared L2 error per sample.

    Args:
      params: model variables, passed as the first argument to ``apply_fn``.
      x_batched: inputs with the feature axis last, e.g. ``(N, 2)`` or ``(H, W, 2)``.
      y_batched: targets broadcast-compatible with the predictions,
        e.g. ``(N, 3)`` or ``(H, W, 3)``.
      apply_fn: callable ``apply_fn(params, x) -> prediction``; defaults to the
        notebook's global ``model.apply``.

    Returns:
      Scalar ``mean(0.5 * sum((y - pred) ** 2, axis=-1))`` over all leading axes.
    """
    if apply_fn is None:
        apply_fn = model.apply
    # Evaluate the whole batch at once. This assumes apply_fn maps the last
    # axis and preserves the leading ones (true for Dense-based models).
    pred = apply_fn(params, x_batched)
    residual = y_batched - pred
    # Reduce only the channel axis per sample, then average over every leading
    # axis, so the loss is a scalar for (N, C) and (H, W, C) batches alike;
    # jax.value_and_grad requires a scalar output.
    per_sample = 0.5 * jnp.sum(residual * residual, axis=-1)
    return jnp.mean(per_sample)

In [None]:
import optax

alpha = 1e-4  # SGD learning rate
# Optimise the variables dict produced by `model.init`; `mse` passes it
# directly to `model.apply`. Rebinding here also makes re-running this cell
# restart training from the initial parameters.
params = variables
tx = optax.sgd(learning_rate=alpha)
opt_state = tx.init(params)
# Returns (loss, gradient w.r.t. params) for mse(params, x, y).
loss_grad_fn = jax.value_and_grad(mse)

In [None]:
# Full-batch gradient descent on the training pixels; logs the loss every 10 steps.
x_train, y_train = train_data
for i in tqdm.trange(101):
    # mse takes the inputs and targets in addition to the parameters.
    loss_val, grads = loss_grad_fn(params, x_train, y_train)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print("Loss step {}: ".format(i), loss_val)