In [1]:
# install the git repo
!pip install git+https://github.com/jeffreypike/summa

Collecting git+https://github.com/jeffreypike/summa
  Cloning https://github.com/jeffreypike/summa to /tmp/pip-req-build-96u3wbbi
  Running command git clone --filter=blob:none --quiet https://github.com/jeffreypike/summa /tmp/pip-req-build-96u3wbbi
  Resolved https://github.com/jeffreypike/summa to commit 8139504469bf403b52698cb238cc53dbc76e4f88
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [2]:
# Reload modules before executing user code
%load_ext autoreload
%autoreload 2

In [3]:
# imports
import polars as pl
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

# summa
from summa.models import NAM, FeatureNet
from summa.training import get_optimal_params

In [4]:
seed = 1056
x = jax.random.normal(key = jax.random.PRNGKey(seed), shape = (100, 1)) * 3.1415
y = jax.random.normal(key = jax.random.PRNGKey(seed + 1), shape = (100, 1))
z = 3 * jnp.sin(x) - 2 * jnp.cos(y) + 0.2 * jax.random.normal(key = jax.random.PRNGKey(seed + 2), shape = (100, 1))
df_3d = pl.DataFrame().with_columns(pl.Series(name = 'x', values = np.array(x).flatten()),
                                    pl.Series(name = 'y', values = np.array(y).flatten()),
                                    pl.Series(name = 'z', values = np.array(z).flatten()))
df_3d.head()

x,y,z
f32,f32,f32
-2.650661,1.524662,-1.570335
3.475486,-0.946816,-2.007555
-6.11919,0.393396,-1.365766
0.806682,-0.426975,0.270358
-1.787984,1.336585,-3.065454


In [5]:
scatter_3d = px.scatter_3d(df_3d,
                           x = 'x',
                           y = 'y',
                           z = 'z')

scatter_3d.update_traces(marker_size = 2)
scatter_3d.show()

In [17]:
hidden_units = [1024, 1024]
nam = NAM(hidden_units)
key1, key2 = jax.random.split(jax.random.key(0))
dummy = jax.random.normal(key1, (1,2)) # Dummy input data
params_init = nam.init(key2, dummy) # Initialization call
jax.tree_util.tree_map(lambda t: t.shape, params_init) # Checking output shapes

{'params': {'subnets_0': {'Dense_0': {'bias': (1,), 'kernel': (1024, 1)},
   'ExuLayer_0': {'bias': (1,), 'kernel': (1, 1024)}},
  'subnets_1': {'Dense_0': {'bias': (1,), 'kernel': (1024, 1)},
   'ExuLayer_0': {'bias': (1,), 'kernel': (1, 1024)}}}}

In [18]:
X = jnp.array(df_3d.drop('z'))

In [19]:
def loss_fn(y, yhat):
    return optax.huber_loss(y, yhat).mean()

In [20]:
# params_optimal, history = get_optimal_params(model = nam,
#                                              params = params_init,
#                                              X_train = X,
#                                              y_train = z,
#                                              loss_fn = loss_fn,
#                                              optimizer = optax.adam(learning_rate = 1e-2),
#                                              epochs = 2500,
#                                              rngs = {'dropout': jax.random.key(1),
#                                                      'feature_dropout': jax.random.key(2),
#                                                      'batching': jax.random.key(3)},
#                                              hyperparams = {'dropout_rate': 0.2,
#                                                             'feature_dropout_rate': 0.2,
#                                                             'batch_size': 100,
#                                                             'weight_decay': 0.00001,
#                                                             'output_penalty': 0.001}
#                                      )

In [21]:
params_optimal, history = get_optimal_params(model = nam,
                                             params = params_init,
                                             X_train = X,
                                             y_train = z,
                                             loss_fn = loss_fn,
                                             optimizer = optax.adam(learning_rate = 1e-2),
                                             epochs = 2500,
                                             rngs = None,
                                     )

step 0, loss: 1.554315447807312
step 100, loss: 0.7177487015724182
step 200, loss: 0.6815364956855774
step 300, loss: 0.6191128492355347
step 400, loss: 0.643756091594696
step 500, loss: 0.6147710084915161
step 600, loss: 0.6596361994743347
step 700, loss: 0.609609067440033
step 800, loss: 0.5919407606124878
step 900, loss: 0.6176027655601501
step 1000, loss: 0.5889319181442261
step 1100, loss: 0.5781468152999878
step 1200, loss: 0.6037138104438782
step 1300, loss: 0.5795531868934631
step 1400, loss: 0.5701568722724915
step 1500, loss: 0.5949071049690247
step 1600, loss: 0.5731588006019592
step 1700, loss: 0.5645212531089783
step 1800, loss: 0.5887130498886108
step 1900, loss: 0.5685817003250122
step 2000, loss: 0.5605329275131226
step 2100, loss: 0.583942711353302
step 2200, loss: 0.5651670694351196
step 2300, loss: 0.5574885010719299
step 2400, loss: 0.5803300142288208


In [22]:
loss_fig = px.line(x = range(len(history['train_loss'])), y = history['train_loss'])

loss_fig.show()

In [23]:
zhat = nam.apply(params_optimal, X)

In [24]:
scatter_3d = px.scatter_3d(df_3d,
                           x = 'x',
                           y = 'y',
                           z = 'z')

scatter_3d.update_traces(marker_size = 2)

nam_fig = go.Figure(data = [go.Mesh3d(x = np.array(x).flatten(),
                                      y = np.array(y).flatten(),
                                      z = np.array(zhat).flatten(),
                                      color = 'orange',
                                      opacity = 0.50)])

results_fig_3d = scatter_3d.add_trace(nam_fig.data[0])
results_fig_3d.show()


In [25]:
for j in range(X.shape[1]):
  params_subnet = {"params": params_optimal["params"][f"subnets_{j}"]}
  df_net = (pl.DataFrame().with_columns(pl.Series(name = "x", values = np.array(X[:, j])),
                                        pl.Series(name = "True Relationship", values = 3 * np.sin(np.array(X[:, j]))),
                                        pl.Series(name = "FeatureNet Approximation", values = np.array(FeatureNet(hidden_units[j]).apply(params_subnet, X[:, j].reshape(100,1))).reshape(100,)))

                          .melt(id_vars = "x", value_vars = ["True Relationship", "FeatureNet Approximation"], variable_name = "Type")
                          .sort('x'))

  fig = px.line(df_net,
                x = 'x',
                y = 'value',
                color = "Type",
                line_dash = "Type",
                title = f"Feature {j} Partial Dependence",
                color_discrete_sequence = px.colors.qualitative.T10)
  fig.show()

In [26]:
X

Array([[-2.6506612 ,  1.5246624 ],
       [ 3.4754856 , -0.9468156 ],
       [-6.1191897 ,  0.3933959 ],
       [ 0.8066819 , -0.42697453],
       [-1.7879845 ,  1.3365847 ],
       [-1.9792763 ,  0.997208  ],
       [-1.4893208 , -0.70357025],
       [-2.3258204 , -1.5022509 ],
       [ 2.049179  ,  0.11408991],
       [-3.7891552 ,  0.39140534],
       [-1.709596  ,  0.07756606],
       [-0.00919644, -0.49695536],
       [ 1.6299365 ,  0.18339503],
       [ 1.3645695 , -1.0362006 ],
       [ 1.4399596 , -1.6225322 ],
       [ 2.937155  ,  0.20450577],
       [-0.05754102, -0.9235363 ],
       [ 0.2865657 , -0.94140446],
       [-1.0612054 ,  0.7628329 ],
       [-0.64206135,  0.34849542],
       [ 1.4673072 ,  0.0251117 ],
       [-0.61959904,  1.7373608 ],
       [-2.6852565 , -1.7521209 ],
       [ 1.116574  , -0.10521378],
       [-1.1049211 ,  0.18297812],
       [ 3.6906767 , -0.49451113],
       [ 2.742604  , -0.39046142],
       [-2.1049192 ,  0.906245  ],
       [-1.8071306 ,

In [27]:
df_3d

x,y,z
f32,f32,f32
-2.650661,1.524662,-1.570335
3.475486,-0.946816,-2.007555
-6.11919,0.393396,-1.365766
0.806682,-0.426975,0.270358
-1.787984,1.336585,-3.065454
-1.979276,0.997208,-4.175055
-1.489321,-0.70357,-4.651304
-2.32582,-1.502251,-1.912734
2.049179,0.11409,0.919418
-3.789155,0.391405,0.150974
