In [1]:
# if running in colab, install the git repo
# (prefer %pip over !pip so the install targets the running kernel's environment,
#  and consider pinning a commit/tag, e.g. ...summa@<sha>, for reproducibility)
# %pip install git+https://github.com/jeffreypike/summa

In [1]:
# Reload modules before executing user code, so edits to imported packages
# (e.g. the local `summa` package) take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

In [2]:
# imports
import polars as pl
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
# import plotly.express as px
import plotly.graph_objects as go

# summa
from summa.models import NAM, FeatureNet
from summa.training import get_optimal_params

In [39]:
seed = 1056
n_samples = 100  # number of synthetic observations

# Features: x is standard normal scaled by 3.1415 (~pi; the literal is kept so the
# generated data are unchanged), y is standard normal. Each draw uses its own key.
x = jax.random.normal(key = jax.random.PRNGKey(seed), shape = (n_samples,)) * 3.1415
y = jax.random.normal(key = jax.random.PRNGKey(seed + 1), shape = (n_samples,))
X = jnp.stack([x, y], axis = 1)  # design matrix, shape (n_samples, 2)

# Additive ground truth plus Gaussian noise with standard deviation 0.2:
#   z = 3 sin(x) - 2 cos(y) + 0.2 * eps,  eps ~ N(0, 1)
z = (3 * jnp.sin(x) - 2 * jnp.cos(y)
     + 0.2 * jax.random.normal(key = jax.random.PRNGKey(seed + 2), shape = (n_samples,)))

# Build the frame directly from a column mapping (instead of .with_columns on an
# empty frame); np.asarray keeps the float32 dtype of the JAX arrays.
df_3d = pl.DataFrame({'x': np.asarray(x),
                      'y': np.asarray(y),
                      'z': np.asarray(z)})
df_3d.head()

x,y,z
f32,f32,f32
-2.650661,1.524662,-1.570335
3.475486,-0.946816,-2.007555
-6.11919,0.393396,-1.365766
0.806682,-0.426975,0.270358
-1.787984,1.336585,-3.065454


In [15]:
# 3-D scatter of the observed data. Coordinates are read from the columns of X
# (identical to x and y) rather than the short global `y`, so this cell stays
# correct even if that name is rebound elsewhere in the session.
scatter_trace = go.Scatter3d(x = np.asarray(X[:, 0]),
                             y = np.asarray(X[:, 1]),
                             z = np.asarray(z),
                             mode = "markers",
                             name = "Observed",
                             marker = {"size": 2,
                                       "opacity": 0.8})

fig_scatter3d = go.Figure(data = [scatter_trace],
                          layout = go.Layout(title = "Synthetic data: z = 3 sin(x) - 2 cos(y) + noise",
                                             scene = {"xaxis": {"title": "x"},
                                                      "yaxis": {"title": "y"},
                                                      "zaxis": {"title": "z"}}))
fig_scatter3d.show()

In [8]:
# Hidden sizes for the NAM. The parameter shapes printed below show one subnet per
# input feature (subnets_0, subnets_1), each with a single 1024-unit ExU layer, so
# this list appears to hold one entry per feature.
# NOTE(review): inferred from those shapes — TODO confirm against summa.models.NAM.
hidden_units = [1024, 1024]
nam = NAM(hidden_units)

key1, key2 = jax.random.split(jax.random.key(0))
# Placeholder input for shape inference: one row with as many columns as X has
# features, so the model stays consistent with the data if features change.
dummy = jax.random.normal(key1, (1, X.shape[1]))
params_init = nam.init(key2, dummy)  # initialise parameters with key2
jax.tree_util.tree_map(lambda t: t.shape, params_init)  # inspect parameter shapes

{'params': {'subnets_0': {'Dense_0': {'bias': (1,), 'kernel': (1024, 1)},
   'ExuLayer_0': {'bias': (1,), 'kernel': (1, 1024)}},
  'subnets_1': {'Dense_0': {'bias': (1,), 'kernel': (1024, 1)},
   'ExuLayer_0': {'bias': (1,), 'kernel': (1, 1024)}}}}

In [18]:
def loss_fn(y, yhat, delta = 1.0):
    """Mean Huber loss between targets and predictions.

    Args:
        y: target values.
        yhat: model predictions, same shape as ``y``.
        delta: threshold between the quadratic and linear regimes of the Huber
            loss; 1.0 matches optax's default, so existing calls are unchanged.

    Returns:
        Scalar mean of the element-wise Huber loss.
    """
    # optax.huber_loss takes (predictions, targets, delta); pass by keyword so the
    # roles are explicit. The loss is symmetric in the error, so values are unchanged.
    return optax.huber_loss(predictions = yhat, targets = y, delta = delta).mean()

In [None]:
# Alternative, regularized training run (kept for reference, not executed).
# Same model, data, optimizer and epoch count as the active run in the next cell,
# but with explicit PRNG keys for 'dropout', 'feature_dropout' and 'batching', and
# hyperparameters for dropout rates, batch size, weight decay and an output penalty.
# NOTE(review): commented-out code drifts out of sync with the summa API unnoticed —
# consider deleting this cell or turning it into a real, flag-guarded alternative.
# params_optimal, history = get_optimal_params(model = nam,
#                                              params = params_init,
#                                              X_train = X,
#                                              y_train = jnp.expand_dims(z, axis = 1),
#                                              loss_fn = loss_fn,
#                                              optimizer = optax.adam(learning_rate = 1e-2),
#                                              epochs = 2500,
#                                              rngs = {'dropout': jax.random.key(1),
#                                                      'feature_dropout': jax.random.key(2),
#                                                      'batching': jax.random.key(3)},
#                                              hyperparams = {'dropout_rate': 0.2,
#                                                             'feature_dropout_rate': 0.2,
#                                                             'batch_size': 100,
#                                                             'weight_decay': 0.00001,
#                                                             'output_penalty': 0.001}
#                                      )

In [40]:
# Fit the NAM on the full dataset: Adam with learning rate 1e-2 for 2500 epochs and
# no extra PRNG streams (rngs=None). Targets are passed as a column vector,
# shape (n, 1), via z[:, None].
params_optimal, history = get_optimal_params(
    model = nam,
    params = params_init,
    X_train = X,
    y_train = z[:, None],
    loss_fn = loss_fn,
    optimizer = optax.adam(1e-2),
    epochs = 2500,
    rngs = None,
)

step 0, loss: 1.554315447807312
step 100, loss: 0.717690110206604
step 200, loss: 0.9960762858390808
step 300, loss: 0.6486291885375977
step 400, loss: 0.6399644613265991
step 500, loss: 0.6188598275184631
step 600, loss: 0.6347697377204895
step 700, loss: 0.6070959568023682
step 800, loss: 0.5934430956840515
step 900, loss: 0.6153350472450256
step 1000, loss: 0.5904744863510132
step 1100, loss: 0.5797809362411499
step 1200, loss: 0.6039299368858337
step 1300, loss: 0.5808226466178894
step 1400, loss: 0.5714450478553772
step 1500, loss: 0.5956602096557617
step 1600, loss: 0.574262797832489
step 1700, loss: 0.5656495690345764
step 1800, loss: 0.5895890593528748
step 1900, loss: 0.5695161819458008
step 2000, loss: 0.5613783597946167
step 2100, loss: 0.5848990678787231
step 2200, loss: 0.5658913254737854
step 2300, loss: 0.5581461787223816
step 2400, loss: 0.5811840295791626


In [41]:
# Training-loss curve. The x-axis is the index into history["train_loss"], presumably
# one entry per step/epoch of the run above — TODO confirm against get_optimal_params.
loss_fig = go.Figure(data = [go.Scatter(y = history["train_loss"], name = "train_loss")],
                     layout = go.Layout(title = "Training loss",
                                        xaxis = {"title": "Step"},
                                        yaxis = {"title": "Loss"}))

loss_fig.show()

In [42]:
# In-sample predictions of the fitted NAM on the training inputs X
# (presumably shape (n, 1), like the y_train column used for training — TODO confirm).
zhat = nam.apply(params_optimal, X)

In [43]:
# Overlay the fitted surface on the observed points. Reuse the predictions already
# stored in `zhat` instead of running the model a second time, and read the inputs
# from the columns of X so this cell does not depend on the short global `y`
# (which is easy to rebind in an interactive session).
model_trace = go.Mesh3d(x = np.asarray(X[:, 0]),
                        y = np.asarray(X[:, 1]),
                        z = np.asarray(zhat).ravel(),
                        color = 'orange',
                        opacity = 0.50,
                        name = "NAM fit")

fig_results = go.Figure(data = [scatter_trace, model_trace],
                        layout = go.Layout(title = "Observed data vs. fitted NAM surface",
                                           scene = {"xaxis": {"title": "x"},
                                                    "yaxis": {"title": "y"},
                                                    "zaxis": {"title": "z"}}))
fig_results.show()


In [60]:
# Per-feature check of the fitted shape functions: for each input feature, compare the
# corresponding subnet's output with the true additive component used to generate z.
#
# An additive model's components are only identified up to constants (each subnet has
# its own output bias, so an offset can move between subnets without changing the
# prediction). Both curves are therefore mean-centred, so the plot compares shapes
# rather than arbitrary levels.
feature_names = ["x", "y"]
true_effects = [lambda v: 3 * jnp.sin(v),   # feature 0: 3 sin(x)
                lambda v: -2 * jnp.cos(v)]  # feature 1: -2 cos(y)
assert len(true_effects) == len(feature_names) == X.shape[1], "need one true effect per feature"

for j in range(X.shape[1]):
    feature = X[:, j]
    params_subnet = {"params": params_optimal["params"][f"subnets_{j}"]}

    # Named effect_* rather than y / yhat so the loop does not rebind the global `y`
    # (the second feature generated in the data cell).
    effect_true = true_effects[j](feature)
    # hidden_units[j] is assumed to be feature j's subnet width — TODO confirm vs NAM.
    # reshape(-1, 1) turns the feature into a single-column input for any sample count.
    effect_hat = FeatureNet(hidden_units[j]).apply(params_subnet, feature.reshape(-1, 1)).flatten()
    effect_true = effect_true - effect_true.mean()
    effect_hat = effect_hat - effect_hat.mean()

    # Sort by feature value so the line traces are drawn left to right.
    sorter = jnp.argsort(feature)
    feature_sorted = feature[sorter]

    true_trace = go.Scatter(x = feature_sorted,
                            y = effect_true[sorter],
                            name = "True Relationship")

    approx_trace = go.Scatter(x = feature_sorted,
                              y = effect_hat[sorter],
                              name = "FeatureNet Approximation")

    fig_subnet = go.Figure(data = [true_trace, approx_trace],
                           layout = go.Layout(title = f"Feature {j} Partial Dependence",
                                              xaxis = {"title": feature_names[j]},
                                              yaxis = {"title": "centred effect on z"},
                                              legend = {"title": "Type"}))
    fig_subnet.show()