In [1]:
import jax
import jax.numpy as jnp
import optax
import pandas
import numpy as np
import pickle
from tqdm import tqdm

In [2]:
# read data
df = pandas.read_csv("data/dataset.csv")
# keep only train/val data for modelling
df = df[df.fold != -1]

features = sorted([
    "H_modelled",
    "H_modelled_grad", 
    "H_modelled_lapl", 
    "v_modelled",
    "v_modelled_grad",
    "v_modelled_lapl",
    "boundary_proximity",
    "elevation1",
    "elevation1_gradient", 
    "elevation1_laplacian",
    "elevation2_gradient",
    "elevation2_laplacian",
    "elevation_diff",
])

feature_idxs = {feature: idx for idx, feature in enumerate(features)}
target = "H"

In [3]:
# calculate min and max values of features 
feature_mins = {_: float(df[_].min()) for _ in features + ["H"]}
feature_maxs = {_: float(df[_].max()) for _ in features + ["H"]}

# some fixes to ensure consistency
feature_mins["boundary_proximity"] = 0.0
feature_mins["H_modelled"] = 0.0
feature_mins["H"] = 0.0
feature_maxs["H"] = feature_maxs["H_modelled"]

In [4]:
# save these stats to reuse later
with open("data/stats.pickle", "wb") as dst:
    pickle.dump((feature_mins, feature_maxs), dst)

In [5]:
# normalise the data to be in range from 0 to 1
df_norm = df.copy()

for feature in features + ["H"]:
    vmin = feature_mins[feature]
    vmax = feature_maxs[feature]
    feature_norm = (df_norm[feature] - vmin) / (vmax - vmin)
    df_norm[feature] = feature_norm

In [6]:
# define data generator
def data_generator(df, batch_size, noise=0, rng_key=None):
    if noise and rng_key is None:
        raise ValueError("rng_key should be provided if noise is used")
    if rng_key is not None and isinstance(rng_key, int):
        rng_key = jax.random.PRNGKey(rng_key)
    df = df.copy()
    df = df.sample(frac=1).reset_index() # shuffle
    for idx in range(0, len(df), batch_size):
        batch = df.iloc[idx:min(idx + batch_size, len(df))]
        x, y = convert_to_jax_xy_pair(batch)
        if noise:
            for feature in x:
                rng_key, new_rng_key = jax.random.split(rng_key, 2)
                feature_noise = jax.random.normal(rng_key, x[feature].shape) * noise
                x[feature] = x[feature] + feature_noise
                rng_key = new_rng_key
            x["boundary_proximity"] = jnp.clip(x["boundary_proximity"], a_min=0.0)
        yield x, y


def convert_to_jax_xy_pair(df, features=features, target=target):
    x = {_: jnp.array(df[_]) for _ in features}
    y = jnp.array(df[target])    
    return x, y

In [7]:
# define a general multi-layer perceptron
def initialise_mlp_params(
    in_size, out_size, n_hidden_layers, n_hidden_units, 
    initialiser=jax.nn.initializers.he_uniform(),
    rng_key=jax.random.PRNGKey(42)
):
    params = []
    
    for l_idx, _ in enumerate(range(n_hidden_layers + 1)):
        key_w, rng_key = jax.random.split(rng_key, 2)
        if n_hidden_layers == 0:
            w_shape = (out_size, in_size)
            b_shape = out_size
        elif l_idx == 0:
            w_shape = (n_hidden_units, in_size)
            b_shape = n_hidden_units
        elif l_idx == n_hidden_layers:
            w_shape = (out_size, n_hidden_units)
            b_shape = out_size
        else:
            w_shape = (n_hidden_units, n_hidden_units)
            b_shape = n_hidden_units
        w = initialiser(key_w, w_shape)
        b = jnp.zeros(b_shape)
        params.append((w, b))

    return params


def build_mlp(activation, final_activation=None, squeeze=False):
    if final_activation is None:
        def linear(x):
            return x
        final_activation = jax.jit(linear)
        
    def mlp(params, x):
        for w, b in params[:-1]:
            x = jnp.dot(w, x) + b
            x = activation(x)
        w, b = params[-1]
        y = jnp.dot(w, x) + b
        y = final_activation(y)
        if squeeze:
            y = jnp.squeeze(y)
        return y

    return jax.jit(mlp)

In [8]:
# define DATHICE model which is boundary_proximity * f1(x) * ReLU(H_modelled + f2(x))
def build_dathice(mlp_bcs, mlp):
    def dathice(params, x):
        mlp_bcs_params, mlp_params = params
        p = jnp.array([x[_] for _ in sorted(x.keys())])
        mlp_bcs_output = mlp_bcs(mlp_bcs_params, p)
        mlp_output = mlp(mlp_params, p)
        d = x["boundary_proximity"]
        H_modelled = x["H_modelled"]
        bcs_factor = d * mlp_bcs_output
        H_corrected = H_modelled + mlp_output
        H_constrained = jax.nn.relu(H_corrected)
        y = bcs_factor * H_constrained
        return y
        
    return jax.jit(dathice)

In [9]:
# define loss functions
@jax.jit
def ae(true, pred):
    return jnp.abs(true - pred)


@jax.jit
def se(true, pred):
    return (true - pred)**2.0
    
    
@jax.jit
def logcosh(true, pred):
    return jnp.log(jnp.cosh(true - pred))


@jax.jit
def sle(true, pred): # only for positive values, which works here
    return (jnp.log(true + 1.0) - jnp.log(pred + 1.0))**2.0


def mean_aggregate(loss_metric):
    @jax.jit
    def aggregated(true, pred):
        loss_values = loss_metric(true, pred)
        return jnp.mean(loss_values)
    return aggregated


def scale_loss_metric(loss_metric):
    EPSILON = 1e-3
    @jax.jit
    def scaled_loss_metric(true, pred):
        loss_values = loss_metric(true, pred)
        loss_scales = loss_metric(true, 0) + EPSILON
        return loss_values / loss_scales
    return scaled_loss_metric


def combine_loss_metrics(loss_metrics, weights=None):
    if weights is None:
        weights = [1.0 for _ in loss_metrics]
    if len(loss_metrics) != len(weights):
        raise ValueError("The lengths of loss_metrics and weights should be equal.")
    @jax.jit
    def combined_loss_metric(true, pred):
        total = 0
        for loss_metric, weight in zip(loss_metrics, weights):
            total += (weight * loss_metric(true, pred))
        return total
    return combined_loss_metric
    

def get_loss_fn(batched_forward, loss_metric, aggregate=mean_aggregate):
    @jax.jit
    def loss_fn(params, x, true):
        pred = batched_forward(params, x)
        loss_value = aggregate(loss_metric)(true, pred)
        return loss_value
    return loss_fn

In [10]:
# whether to perform hyperparameter tuning or not
hyperparameter_tuning = False

if not hyperparameter_tuning:
    # a mock hyperparameter set
    hyperparameters = dict(
        activation_function="elu",
        nlayers1=4,
        nlayers2=4,
        nneurons1=8,
        nneurons2=8,
        init_method="lecun",
        loss="se",
        batch_size=128,
        grad_norm_clip=3,
        l2_regularization=0,
        base_lr=0.25,
        data_noise=1e-2,
        optimizer="sgd", 
        nepochs=500,
    )

else:
    """
    EXERCISE: Implement hyperparameter tuning
    HINT: You might want to reuse the code below to construct a cost function
    """
    raise NotImplementedError()

In [11]:
# build and train the final model
activation_fuction = {
    "swish": jax.nn.swish,
    "relu": jax.nn.relu,
    "softplus": jax.nn.softplus,
    "elu": jax.nn.elu,
}[hyperparameters["activation_function"]]

init_method = {
    "he": jax.nn.initializers.he_uniform,
    "glorot": jax.nn.initializers.glorot_uniform,
    "lecun": jax.nn.initializers.lecun_uniform,
}[hyperparameters["init_method"]]

# boundary-condition MLP
mlp_bcs = build_mlp(
    activation_fuction, 
    final_activation=jax.nn.softplus, 
    squeeze=True
)
# correction MLP
mlp = build_mlp(
    activation_fuction, 
    squeeze=True
)

dathice = build_dathice(mlp_bcs, mlp)

mlp_bcs_batched = jax.vmap(mlp_bcs, in_axes=(None, 0))
mlp_batched = jax.vmap(mlp, in_axes=(None, 0))
dathice_batched = jax.vmap(dathice, in_axes=(None, 0))

mlp_bcs_params = initialise_mlp_params(
    len(features), 
    1, 
    hyperparameters["nlayers1"], 
    hyperparameters["nneurons1"], 
    init_method()
)
mlp_params = initialise_mlp_params(
    len(features), 
    1, 
    hyperparameters["nlayers2"], 
    hyperparameters["nneurons2"], 
    init_method()
)

dathice_params = (mlp_bcs_params, mlp_params)

In [12]:
lr = hyperparameters["base_lr"] * np.sqrt(hyperparameters["batch_size"] / 32)

# in optax, optimizers are defined as operations on gradients
opt_chain = [
    optax.clip_by_global_norm(hyperparameters["grad_norm_clip"]), 
] + [optax.scale_by_adam()] if hyperparameters["optimizer"] == "adam" else [] + \
    [optax.scale_by_rms()] if hyperparameters["optimizer"] == "rms" else [] + [
    optax.add_decayed_weights(hyperparameters["l2_regularization"]),
    optax.scale(-lr), # !note the minus sign
]
optimizer = optax.chain(*opt_chain)

In [13]:
loss = {
    "ae": ae,
    "se": se,
    "logcosh": logcosh,
    "sle": sle,
}[hyperparameters["loss"]]

params = dathice_params

opt_state = optimizer.init(params)
loss_fn = get_loss_fn(dathice_batched, loss)
loss_fn_grad = jax.value_and_grad(loss_fn)

noise_key = jax.random.PRNGKey(42)


# one training step (i.e., for one batch)
@jax.jit
def step(params, x, y, opt_state):
    loss, grads = loss_fn_grad(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss, grads


with tqdm(total=hyperparameters["nepochs"], desc="") as pbar:
    for epoch in range(hyperparameters["nepochs"]):             
        noise_key, new_noise_key = jax.random.split(noise_key, 2)
    
        train_loss = 0
        n_steps = 0

        # here (assuming after hyperparameter tuning), we merge train and validation subsets
        for x, y in data_generator(
            df_norm, 
            hyperparameters["batch_size"], 
            noise=hyperparameters["data_noise"], 
            rng_key=noise_key
        ):
            params, opt_state, loss, grads = step(params, x, y, opt_state)
            train_loss += loss
            n_steps += 1
        train_loss /= n_steps
            
        noise_key = new_noise_key

        pbar.set_description(f"Train loss = {train_loss:.6f}")
        pbar.update(1)

Train loss = 0.006475: 100%|██████████| 500/500 [02:08<00:00,  3.88it/s]


In [14]:
# save the parameters and hyperparameters
# adjust the paths as needed
with open("params/params.pickle", "wb") as dst:
    pickle.dump(params, dst)
with open("hyperparameters/hyperparameters.pickle", "wb") as dst:
    pickle.dump(hyperparameters, dst)