# Demo: Neural Implicit Functions (NIF)

In [1]:
import sys, os
from pyprojroot import here


# spyder up to find the root, local
root = here(project_files=[".root"])
local = here(project_files=[".local"])

# append to path
sys.path.append(str(root))
sys.path.append(str(local))

In [2]:
from pathlib import Path
import argparse
import wandb
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

# # Ensure TF does not see GPU and grab all GPU memory.
# import tensorflow as tf
# tf.config.set_visible_devices([], device_type='GPU')

# os.environ["JAX_PLATFORM_NAME"] = "CPU"
# # ENSURE JAX DOESNT PREALLOCATE
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = str(False)

import jax
import jax.random as jrandom
import jax.numpy as jnp
import equinox as eqx
from ml4ssh._src.data import make_mini_batcher
from ml4ssh._src.io import load_object, save_object
from ml4ssh._src.viz import create_movie, plot_psd_spectrum, plot_psd_score
from ml4ssh._src.utils import get_meshgrid, calculate_gradient, calculate_laplacian

# import parsers
from data import get_data_args, load_data
from preprocess import add_preprocess_args, preprocess_data
from features import add_feature_args, feature_transform
from split import add_split_args, split_data
from model import add_model_args, get_model
from loss import add_loss_args, get_loss_fn
from logger import add_logger_args
from optimizer import add_optimizer_args, get_optimizer
from postprocess import add_postprocess_args, postprocess_data, generate_eval_data
from evaluation import add_eval_args, get_rmse_metrics, get_psd_metrics

%matplotlib inline
%load_ext autoreload
%autoreload 2

2022-06-02 17:58:46.598280: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory


In [3]:
parser = argparse.ArgumentParser()

# logger
parser = add_logger_args(parser)

# data
parser = get_data_args(parser)

# preprocessing, feature transform, split
parser = add_preprocess_args(parser)
parser = add_feature_args(parser)
parser = add_split_args(parser)

# model, optimizer, loss
parser = add_model_args(parser)
parser = add_optimizer_args(parser)
parser = add_loss_args(parser)

# postprocessing, metrics
parser = add_postprocess_args(parser)
parser = add_eval_args(parser)

# parse args
args = parser.parse_args([])

# # jeanzay specific
# args.train_data_dir = "/gpfsdswork/projects/rech/cli/uvo53rl/data/data_challenges/ssh_mapping_2021/train/"
# args.ref_data_dir = "/gpfsdswork/projects/rech/cli/uvo53rl/data/data_challenges/ssh_mapping_2021/ref/"
# args.test_data_dir = "/gpfsdswork/projects/rech/cli/uvo53rl/data/data_challenges/ssh_mapping_2021/test/"
# args.log_dir = "/gpfswork/rech/cli/uvo53rl/logs"

# training args
args.batch_size = 4096
args.n_epochs = 1

# model args
args.model = "nif"
args.block = "siren"
args.activation = "sine"
args.julian_time = True

# smoke test
args.smoke_test = True

# logging stuff
args.wandb_mode = "disabled"
args.wandb_resume = True
# ige/nerf4ssh/kx2nr6qb
args.id = None # "mikf2n1v" # "2uuq7tks" "kx2nr6qb"
# args.entity = "ige"

In [4]:
# init wandb logger
wandb.init(
    id=args.id,
    config=args,
    mode=args.wandb_mode,
    project=args.project,
    entity=args.entity,
    dir=args.log_dir,
    resume=args.wandb_resume
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




In [5]:
%%time

# load data
data = load_data(args)

# preprocess data
data = preprocess_data(data, args)


# feature transformation
data, scaler = feature_transform(data, args)

100%|██████████| 6/6 [00:00<00:00, 15.48it/s]


CPU times: user 2.14 s, sys: 632 ms, total: 2.77 s
Wall time: 2.89 s


In [6]:



# split data
xtrain, ytrain, xvalid, yvalid = split_data(data, args)

args.in_dim = xtrain.shape[-1]
args.n_train = xtrain.shape[0]
args.n_valid = xvalid.shape[0]

wandb.config.update(
    {
        "in_dim": args.in_dim,
        "n_train": args.n_train,
        "n_valid": args.n_valid,
    }
)

In [7]:
import equinox as eqx
import numpy as np

In [84]:
n_spatial = 2
n_temporal = 1

total_n_dims = n_spatial + n_temporal
print(f"Total Dims: {total_n_dims}")

# create random matrix
inputs = np.random.randn(total_n_dims)
print(f"Input Dims: {inputs.shape[0]}")

Total Dims: 3
Input Dims: 3


In [131]:
spatial_dims = [2, 5, 10, 2, 1]
print(f"Hidden Dims (Shape Net): {spatial_dims}")
param_dims = list(map(lambda x: x + 1, spatial_dims))
print(f"Param Dims (Param Net): {param_dims}")
params = np.random.randn(sum(param_dims))
print(f"Param Dims (Param Net): {params.shape[0]}")

Hidden Dims (Shape Net): [2, 5, 10, 2, 1]
Param Dims (Param Net): [3, 6, 11, 3, 2]
Param Dims (Param Net): 25


In [135]:
spatial_dims = [2, 5, 10, 2, 1]
print(f"Hidden Dims (Shape Net): {spatial_dims}")
print([2, 5, 10, 2, 1, 0])
print([0, 2, 5, 10, 2, 1, 0])

Hidden Dims (Shape Net): [2, 5, 10, 2, 1]
[2, 5, 10, 2, 1, 0]
[0, 2, 5, 10, 2, 1, 0]


In [161]:
# get total number of dims
n_layers = len(spatial_dims)
layer_dims = np.asarray(spatial_dims + [0]) * np.asarray([0] + spatial_dims)
layer_dims = layer_dims[1:-1]
total_out_dims = np.dot(spatial_dims + [0], [0] + spatial_dims)

assert total_out_dims == 2*5 + 5*10 + 10 * 2 + 2 * 1
assert sorted(layer_dims) == sorted([2*5, 5*10, 10*2, 2*1])
# plus bias sorted([2*5+5, 5*10+10, 10*2+2, 2*1+1])
# todo add bias!

In [160]:
layer_dims

array([10, 50, 20,  2])

In [146]:
spatial_dims + [0]

[2, 5, 10, 2, 1, 0]

In [147]:
[0] + spatial_dims

[0, 2, 5, 10, 2, 1]

In [134]:
all_elems = np.inner(spatial_dims, spatial_dims)
all_elems

134

In [106]:
split_params = np.split(params, np.cumsum(param_dims[:-1]), axis=0)
split_params

[array([-0.45176115,  0.09466085, -2.20262684,  2.15884523,  0.80498531,
         1.0330875 ]),
 array([ 0.06380215,  0.12346026, -0.15383907, -1.168401  , -0.23856133,
         0.34684593,  0.95689255, -0.74685097, -0.94385985, -1.33021461,
         1.2193369 ]),
 array([1.02685605, 0.17787037, 1.49893591]),
 array([ 0.28416399, -0.59514816])]

In [105]:
np.cumsum(param_dims)

array([ 6, 17, 20, 22])

### Method 1: NIFs

In [111]:
def forward(inputs, params):
    # split inputs (x,t)
    l_params =  np.array_split(params, np.cumsum(param_dims[:-1]), axis=0)
    
    for idim, iparam in zip(hidden_dims, l_params):
        print(f"N Hidden Dims: {idim} | N Params: {iparam.shape[0]}")
        assert idim == (iparam.shape[0]-1)
    
    outputs = inputs
    return outputs

In [128]:
def linear(x, params):
    w, b = np.array_split(params, [params.shape[0]-1], axis=0)
    print(x.shape, w.shape, b.shape)
    x = np.einsum("j,j->j", x, w) + b
    return x

In [129]:
params = np.random.randn(6)
x = np.random.randn(5)
x = linear(x, params)
print(x.shape)

(5,) (5,) (1,)
(5,)


In [112]:

# propagate forward
outputs = forward(inputs, params)
print(f"Outputs: {outputs.shape[0]}")

N Hidden Dims: 5 | N Params: 6
N Hidden Dims: 10 | N Params: 11
N Hidden Dims: 2 | N Params: 3
N Hidden Dims: 1 | N Params: 2
Outputs: 3


### Method 2: Neural Flows

In [52]:
def forward(inputs):
    # split inputs (x,t)
    x, t =  np.array_split(inputs, [n_spatial], axis=0)
    assert x.shape[0] == n_spatial
    assert t.shape[0] == n_temporal
    
    outputs = inputs
    return outputs

In [49]:
# create random matrix
inputs = np.random.randn(total_n_dims)
print(f"Inputs: {inputs.shape[0]}")

# propagate forward
outputs = forward(inputs)
print(f"Outputs: {outputs.shape[0]}")

Inputs: 3
Outputs: 3


In [22]:


# split inputs
# x = inputs[
out = np.array_split(inputs, [n_spatial, n_spatial + n_temporal], axis=0)
# x, t, mu
out

[array([-1.87486774, -0.36550761]),
 array([0.72380307]),
 array([0.97598894, 2.54082406, 1.46186936])]

In [23]:
n_temporal

1