# 2022 Flatiron Machine Learning x Science Summer School

## Step 10: Train symbolic discriminator with embedded information

What if the information provided to the symbolic discriminator (SD) is not sufficient? Intuition might suggest that input feature information is important.

How could additional information be provided to the SD? Currently, the SD input size corresponds to the number of training data points.

To provide the SD with the $x$ data in addition to the $g(x)$ data, we see three options:

* Embed $x^{(i)}$ and $g(x^{(i)})$ using an additional network that outputs a scalar value

* Assuming that $x$ is low-dimensional (and has a grid structure), use a convolutional neural network

* If $x$ is higher dimensional, consider a (convolutional?) graph neural network

Furthermore, consider adding more information, e.g. about derivatives or curvature.
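
The first option could look roughly like the sketch below. It is an illustration under stated assumptions, not the implementation used in this notebook: the class names `PointEmbedding` and `EmbeddedCritic`, the hidden layer sizes, and the random stand-in tensors are all hypothetical. The shapes (3 functions, 700 training points, 2 input features) mirror the ones that appear later in this step.

```python
import torch
import torch.nn as nn


class PointEmbedding(nn.Module):
    """Hypothetical embedding: maps each pair (x_i, g(x_i)) to one scalar."""

    def __init__(self, in_dim, hid_dim=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim + 1, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, 1),
        )

    def forward(self, x, g):
        # x: [n_points, in_dim], g: [n_funs, n_points]
        n_funs = g.shape[0]
        x_rep = x.unsqueeze(0).expand(n_funs, -1, -1)        # [n_funs, n_points, in_dim]
        pairs = torch.cat([x_rep, g.unsqueeze(-1)], dim=-1)  # [n_funs, n_points, in_dim + 1]
        return self.net(pairs).squeeze(-1)                   # [n_funs, n_points]


class EmbeddedCritic(nn.Module):
    """Hypothetical SD: embeds every data point, then scores each function."""

    def __init__(self, in_dim, n_points, hid_dim=64):
        super().__init__()
        self.embed = PointEmbedding(in_dim)
        self.critic = nn.Sequential(
            nn.Linear(n_points, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, 1),
        )

    def forward(self, x, g):
        return self.critic(self.embed(x, g))  # [n_funs, 1]


# random stand-ins for the real data (assumption, for shape checking only)
x = torch.randn(700, 2)  # input features at 700 training points
g = torch.randn(3, 700)  # three functions evaluated at those points
sd = EmbeddedCritic(in_dim=2, n_points=700)
scores = sd(x, g)
```

Because the embedding returns one scalar per data point, the SD input size still corresponds to the number of training data points, so the rest of the SD would not need to change in this variant.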

### Step 10.1: Analyze information available to symbolic discriminator

In [2]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
from sdnet import SDData
import srnet_utils as ut



In [3]:
# load data
data_path = "data_1k"

in_var = "X00"
lat_var = "G00"
target_var = "F00"

mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"])
val_data = SRData(data_path, in_var, lat_var, target_var, masks["val"])

In [4]:
fun_path = "funs/F00_v2.lib"
disc_data = SDData(fun_path, in_var, train_data.in_data)

In [5]:
fig, ax = plt.subplots()

for i in range(disc_data.fun_data.shape[0]):
    ax.plot(disc_data.fun_data[i,:], label=disc_data.funs[i])
    
ax.legend()
plt.show()


How do the latent feature activations look in comparison?

In [13]:
model_name = "srnet_model_F00_v2_bn_mask_sd_study_v11"
model_path = "models"
model_ext = ".pkl"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)
    
all_nodes = ut.get_node_order(acts, show=False)

In [10]:
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]

corr_data = [
    ("x**2", x_data**2), 
    ("y**2", y_data**2), 
    ("cos(x)", np.cos(x_data)), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [11]:
ut.node_correlations(acts, all_nodes, corr_data);


Node 0
corr(n0, x**2): 0.9997
corr(n0, y**2): 0.0161
corr(n0, cos(x)): -0.9573
corr(n0, cos(y)): -0.0175
corr(n0, x*y): 0.1751

Node 1
corr(n1, x**2): 0.0138
corr(n1, y**2): 0.9988
corr(n1, cos(x)): -0.0280
corr(n1, cos(y)): -0.9688
corr(n1, x*y): 0.0449

Node 2
corr(n2, x**2): 0.1615
corr(n2, y**2): -0.0026
corr(n2, cos(x)): -0.1379
corr(n2, cos(y)): -0.0149
corr(n2, x*y): 0.9976
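
A comparison of this kind could be computed as in the sketch below. It assumes that the reported values are Pearson correlation coefficients, which is not confirmed here; `node_correlations_sketch` is a hypothetical stand-in and not the `ut.node_correlations` helper, and the activations are synthetic.

```python
import numpy as np


def node_correlations_sketch(acts, corr_data):
    """Pearson correlation between each latent node and each reference signal."""
    acts = np.asarray(acts)
    result = {}
    for j in range(acts.shape[1]):
        for name, values in corr_data:
            # off-diagonal entry of the 2x2 correlation matrix
            r = np.corrcoef(acts[:, j], np.asarray(values))[0, 1]
            result[(j, name)] = r
    return result


rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=700)
y = rng.uniform(-1, 1, size=700)

# synthetic stand-in activations: affine functions of x**2 and x*y
acts_demo = np.stack([2.0 * x**2 + 0.5, -3.0 * x * y], axis=1)

corr = node_correlations_sketch(acts_demo, [("x**2", x**2), ("x*y", x * y)])
for (j, name), r in corr.items():
    print(f"corr(n{j}, {name}): {r:.4f}")
```

A correlation close to $\pm 1$ only identifies a node up to an affine transformation, and it cannot separate candidates that are themselves strongly correlated on the sampled inputs, as the `x**2` and `cos(x)` values for node 0 show.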


In [15]:
fig, ax = plt.subplots()

i = 3
ax.plot(disc_data.fun_data[i,:], label=disc_data.funs[i])

i = 1
ax.plot(disc_data.fun_data[i,:], label=disc_data.funs[i])
    
j = 1  # latent node to compare against the library functions
ax.plot(acts[:,j], label=f"node {j}")
    
ax.legend()
plt.show()


In [17]:
fig, ax = plt.subplots()

i = 0
ax.plot(disc_data.fun_data[i,:], label=disc_data.funs[i])
    
j = 0  # latent node to compare against the library function
ax.plot(acts[:,j], label=f"node {j}")
    
ax.legend()
plt.show()


If we sort the 1D input:

In [19]:
fig, ax = plt.subplots()

p_data = np.sort(y_data)

ax.plot(p_data**2)
    
ax.plot(-np.cos(p_data))
    
plt.show()


If we plot over the input data:

In [20]:
fig, ax = plt.subplots()

p_data = np.sort(y_data)

ax.scatter(p_data, p_data**2)
    
ax.scatter(p_data, -np.cos(p_data))
    
plt.show()

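
One possible reading of why the two curves are hard to tell apart is the Taylor expansion of the cosine around zero:

$$
-\cos(y) = -1 + \frac{y^2}{2} - \frac{y^4}{24} + \mathcal{O}(y^6)
$$

For inputs of small magnitude, $-\cos(y)$ is therefore approximately an affine function of $y^2$. This would be consistent with the strong negative correlation between node 1 and `cos(y)` ($-0.9688$) in Step 10.1, although the input range is not stated here, so how well the quadratic term dominates on this data is an assumption.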

### Step 10.2: Check SD training with embedded input feature information

How should the gradient penalty be set up correctly once additional embedded information is passed to the SD?

data_real.shape
torch.Size([3, 700])
data_fake.shape
torch.Size([3, 700])
gradient.shape
torch.Size([3, 700])
pred.shape
torch.Size([3, 1])
interp.shape
torch.Size([3, 700])
alpha.shape
torch.Size([3, 1])
(gradient.norm(2, dim=1) - 1).pow(2).mean()
tensor(0.7922, grad_fn=<MeanBackward0>)
(gradient.norm(2, dim=1) - 1).pow(2).mean().item()
0.7922002673149109
    
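
The shapes above are consistent with a gradient penalty in the style of WGAN-GP. The sketch below is one way such a penalty could be written so that it also handles inputs with extra embedding dimensions. The function name `gradient_penalty`, the stand-in critics, and the decision to detach both inputs are assumptions, not the notebook's code.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, data_real, data_fake):
    """Gradient penalty for inputs of shape [batch, ...] (sketch)."""
    batch = data_real.shape[0]
    # one interpolation weight per sample, broadcast over all remaining dims
    alpha_shape = [batch] + [1] * (data_real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=data_real.device)
    interp = alpha * data_real.detach() + (1 - alpha) * data_fake.detach()
    interp.requires_grad_(True)

    pred = critic(interp)  # [batch, 1]
    gradient = torch.autograd.grad(
        outputs=pred,
        inputs=interp,
        grad_outputs=torch.ones_like(pred),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )[0]

    # flatten all non-batch dims so that the norm is taken per sample
    gradient = gradient.reshape(batch, -1)
    return (gradient.norm(2, dim=1) - 1).pow(2).mean()


# without embedding: inputs of shape [3, 700], as in the session above
critic = nn.Sequential(nn.Linear(700, 64), nn.ReLU(), nn.Linear(64, 1))
gp = gradient_penalty(critic, torch.randn(3, 700), torch.randn(3, 700))

# with a hypothetical per-point embedding input of shape [3, 700, 3]
emb_critic = nn.Sequential(
    nn.Flatten(), nn.Linear(700 * 3, 64), nn.ReLU(), nn.Linear(64, 1)
)
gp_emb = gradient_penalty(
    emb_critic, torch.randn(3, 700, 3), torch.randn(3, 700, 3)
)
```

Flattening the gradient before taking the norm keeps the computation vectorized over the batch and yields one norm per sample regardless of how many trailing dimensions the input has.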

This should work now. What are the runtimes?

No embedding, no loop:
00:28<00:00, 34.93it/s, train_loss=1.87, val_loss=3.02
Total training loss: 2.774e+00
Total validation loss: 3.018e+00

No embedding, squeeze:
00:26<00:00, 37.40it/s, train_loss=1.87, val_loss=3.02
Total training loss: 2.774e+00
Total validation loss: 3.018e+00

No embedding, reshape:
00:26<00:00, 37.14it/s, train_loss=1.87, val_loss=3.02
Total training loss: 2.774e+00
Total validation loss: 3.018e+00

No embedding, loop:
00:30<00:00, 32.38it/s, train_loss=1.87, val_loss=3.02
Total training loss: 2.775e+00
Total validation loss: 3.020e+00


---

No embedding, no loop, 0 * gp:
00:27<00:00, 36.85it/s, train_loss=1.74, val_loss=3.11
Total training loss: 2.851e+00
Total validation loss: 3.112e+00

No embedding, squeeze, 0 * gp:
00:29<00:00, 34.26it/s, train_loss=1.74, val_loss=3.11
Total training loss: 2.851e+00
Total validation loss: 3.112e+00


No embedding, loop, 0 * gp:
00:33<00:00, 30.01it/s, train_loss=1.79, val_loss=3.13
Total training loss: 2.861e+00
Total validation loss: 3.126e+00

---

No embedding, no loop, no gp:
00:19<00:00, 50.06it/s, train_loss=3.15, val_loss=2.93
Total training loss: 2.713e+00
Total validation loss: 2.926e+00

No embedding, squeeze, no gp:
00:21<00:00, 46.37it/s, train_loss=3.15, val_loss=2.93
Total training loss: 2.713e+00
Total validation loss: 2.926e+00

No embedding, loop, no gp:
00:21<00:00, 47.30it/s, train_loss=3.15, val_loss=2.93
Total training loss: 2.712e+00
Total validation loss: 2.926e+00

---

No embedding, reshape:
01:54<00:00,  8.77it/s, train_loss=1.00, val_loss=5.17
Total training loss: 4.114e+00
Total validation loss: 5.170e+00

No embedding, loop:
02:13<00:00,  7.48it/s, train_loss=2.02, val_loss=5.18
Total training loss: 4.118e+00
Total validation loss: 5.178e+00

Run for longer.

Increase the library:

* Resample `F00_v2` coefficients

* All simple two-input functions

(Train restarting from trained DSN)

Train with GhostAdam

* Resample input data

* Resample coefficients of library functions

* Ensure long epoch convergence

* **Input data "noise"**

* Resolve bottleneck

* Input data dimension

* Change complexity of $g(x)$

* Change $f(x)$