# 2022 Flatiron Machine Learning x Science Summer School

## Step 10: Train symbolic discriminator with embedded information

Is the information provided to the symbolic discriminator (SD) sufficient? Intuition suggests that input feature information is important.

How could we provide additional information to the SD? Currently, the SD input size corresponds to the number of training data points.

To provide the SD with the $x$ data in addition to the $g(x)$ data, we see three options:

* Embed $x^{(i)}$ and $g(x^{(i)})$ using an additional network that outputs a scalar value

* Assuming that $x$ is low-dimensional (and has a grid structure), use a convolutional neural network

* If $x$ is higher dimensional, consider a (convolutional?) graph neural network
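
The first option could look roughly like the sketch below. It is an illustration under assumptions, not the `sdnet` implementation: `PointEmbedding` and its layer sizes are hypothetical. A small MLP maps each pair $(x^{(i)}, g(x^{(i)}))$ to a scalar, so the SD input size still equals the number of training data points.

```python
import torch
import torch.nn as nn


class PointEmbedding(nn.Module):
    """Hypothetical embedding: maps each (x_i, g(x_i)) pair to one scalar."""

    def __init__(self, in_dim, hid_size=32):
        super().__init__()
        # input: in_dim features of x plus one value of g(x)
        self.net = nn.Sequential(
            nn.Linear(in_dim + 1, hid_size),
            nn.ReLU(),
            nn.Linear(hid_size, 1),
        )

    def forward(self, x, g):
        # x: (n_points, in_dim), shared by all functions
        # g: (n_funs, n_points), one row per function or latent feature
        n_funs = g.shape[0]
        x_rep = x.unsqueeze(0).expand(n_funs, -1, -1)        # (n_funs, n_points, in_dim)
        pairs = torch.cat([x_rep, g.unsqueeze(-1)], dim=-1)  # (n_funs, n_points, in_dim + 1)
        return self.net(pairs).squeeze(-1)                   # (n_funs, n_points)


x = torch.rand(100, 2)
g = torch.rand(5, 100)
emb = PointEmbedding(in_dim=2)(x, g)
```

The embedded tensor has the same shape as `g`, so a discriminator that previously consumed `g` directly could consume `emb` without changing its input size.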

Furthermore, consider adding more information, e.g. about derivatives or curvature.
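
Derivative information could, for example, be obtained with automatic differentiation. The sketch below is an illustration only: `g` is a hypothetical stand-in for one latent feature, not the SRNet model.

```python
import torch


def g(x):
    # hypothetical latent feature g(x) = x_0 * x_1
    return x[:, 0] * x[:, 1]


x = torch.rand(100, 2, requires_grad=True)
out = g(x)

# each output depends only on its own input row, so the gradient of
# the sum yields the per-point derivatives dg/dx with shape (n_points, in_dim)
(dg_dx,) = torch.autograd.grad(out.sum(), x)
```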

### Step 10.1: Analyze information available to symbolic discriminator

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
from sdnet import SDData
import srnet_utils as ut

In [2]:
# load data
data_path = "data_1k"

in_var = "X00"
lat_var = "G00"
target_var = "F00"

mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"])
val_data = SRData(data_path, in_var, lat_var, target_var, masks["val"])

In [3]:
fun_path = "funs/F00_v2.lib"
disc_data = SDData(fun_path, in_var, train_data.in_data)

In [4]:
fig, ax = plt.subplots()

for i in range(disc_data.fun_data.shape[0]):
    ax.plot(disc_data.fun_data[i,:], label=disc_data.funs[i])
    
ax.legend()
plt.show()


How do latent feature activations look in comparison?

In [5]:
model_name = "srnet_model_F00_v2_bn_mask_sd_study_v21"
model_path = "models"
model_ext = ".pkl"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)
    
all_nodes = ut.get_node_order(acts, show=False)

In [6]:
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]

corr_data = [
    ("x**2", x_data**2), 
    ("y**2", y_data**2), 
    ("cos(x)", np.cos(x_data)), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [7]:
ut.node_correlations(acts, all_nodes, corr_data);


Node 0
corr(n0, x**2): 0.9997
corr(n0, y**2): 0.0161
corr(n0, cos(x)): -0.9573
corr(n0, cos(y)): -0.0175
corr(n0, x*y): 0.1751

Node 1
corr(n1, x**2): 0.0138
corr(n1, y**2): 0.9988
corr(n1, cos(x)): -0.0280
corr(n1, cos(y)): -0.9688
corr(n1, x*y): 0.0449

Node 2
corr(n2, x**2): 0.1615
corr(n2, y**2): -0.0026
corr(n2, cos(x)): -0.1379
corr(n2, cos(y)): -0.0149
corr(n2, x*y): 0.9976
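
The helper `ut.node_correlations` is not shown in this notebook. The sketch below assumes that it reports Pearson correlation coefficients between node activations and candidate functions; `pearson_table` is a hypothetical name and the activations are synthetic.

```python
import numpy as np


def pearson_table(acts, candidates):
    """Pearson correlation between each node (column of acts) and each candidate."""
    table = np.empty((acts.shape[1], len(candidates)))
    for n in range(acts.shape[1]):
        for c, (_, values) in enumerate(candidates):
            table[n, c] = np.corrcoef(acts[:, n], values)[0, 1]
    return table


rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, 500)
y = rng.uniform(-1, 1, 500)

# synthetic activations: affine transforms of x**2 and x*y
acts = np.stack([2.0 * x**2 + 1.0, -0.5 * x * y], axis=1)
candidates = [("x**2", x**2), ("x*y", x * y)]

table = pearson_table(acts, candidates)
# per node: best absolute correlation; overall: the worst node
min_corr = np.abs(table).max(axis=1).min()
```

Because the correlation is invariant to affine rescaling, a node can match a candidate function perfectly in this sense while still differing from it by a scale and an offset.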


In [8]:
fig, ax = plt.subplots()

i = 3
ax.plot(disc_data.fun_data[i,:], label=disc_data.funs[i])

i = 1
ax.plot(disc_data.fun_data[i,:], label=disc_data.funs[i])
    
j = 1
ax.plot(acts[:,j], label=j)
    
ax.legend()
plt.show()


In [9]:
fig, ax = plt.subplots()

i = 0
ax.plot(disc_data.fun_data[i,:], label=disc_data.funs[i])
    
j = 0
ax.plot(acts[:,j], label=j)
    
ax.legend()
plt.show()


Let's sort the 1D input:

In [10]:
fig, ax = plt.subplots()

p_data = np.sort(y_data)

ax.plot(p_data**2)
    
ax.plot(-np.cos(p_data))
    
plt.show()


And let's plot over the input data:

In [11]:
fig, ax = plt.subplots()

p_data = np.sort(y_data)

ax.scatter(p_data, p_data**2)
ax.scatter(p_data, -np.cos(p_data))
    
plt.show()
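
The two curves are similar partly because of the Taylor expansion $-\cos(y) = -1 + y^2/2 - y^4/24 + \dots$: for moderate $|y|$, $-\cos(y)$ is close to an affine function of $y^2$. The sketch below quantifies this for an assumed input range of $[-3, 3]$. The actual range of the data is not stated here, so the number is only illustrative.

```python
import numpy as np

# assumed input range (hypothetical, not taken from the data set)
y = np.linspace(-3.0, 3.0, 10001)

corr = np.corrcoef(y**2, -np.cos(y))[0, 1]
print(f"corr(y**2, -cos(y)) = {corr:.4f}")
```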


### Step 10.2: Check SD training with embedded input feature information

How should the gradient penalty be set up with the additional embedded information?

* Calculate the gradients of the SD output with respect to all SD inputs, e.g. $g(x)$ and $x$, and take the difference between the 2-norm of these gradients and 1. The penalty is the mean of the squared differences over the individual latent features.
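
A minimal PyTorch sketch of this penalty, under assumptions: the critic takes $g(x)$ and $x$ as two separate tensors, and `gradient_penalty` and `ToyCritic` are hypothetical names, not the functions used in `sdnet`.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, g, x):
    """Penalty on the critic's gradient norm w.r.t. all inputs (g and x).

    g: (n_samples, n_points)          function values, one row per latent feature
    x: (n_samples, n_points, in_dim)  input features paired with each row of g
    """
    g = g.detach().clone().requires_grad_(True)
    x = x.detach().clone().requires_grad_(True)

    scores = critic(g, x)  # (n_samples, 1)

    # gradients of the scores w.r.t. both inputs; create_graph=True keeps
    # the penalty differentiable w.r.t. the critic parameters
    grad_g, grad_x = torch.autograd.grad(
        outputs=scores.sum(), inputs=(g, x), create_graph=True
    )

    # one gradient vector per sample, covering g and x jointly
    grads = torch.cat([grad_g.flatten(1), grad_x.flatten(1)], dim=1)
    norms = grads.norm(2, dim=1)

    # mean over samples of the squared deviation of the 2-norm from 1
    return ((norms - 1.0) ** 2).mean()


class ToyCritic(nn.Module):
    """Hypothetical critic used only to exercise the function."""

    def __init__(self, n_points, in_dim):
        super().__init__()
        self.lin = nn.Linear(n_points * (in_dim + 1), 1)

    def forward(self, g, x):
        return self.lin(torch.cat([g, x.flatten(1)], dim=1))


torch.manual_seed(0)
critic = ToyCritic(n_points=10, in_dim=2)
gp = gradient_penalty(critic, torch.rand(3, 10), torch.rand(3, 10, 2))
gp.backward()  # gradients reach the critic weights
```

In WGAN-GP the penalty is usually evaluated at random interpolations between real and generated samples. Whether and how `sdnet` interpolates is not shown here, so the sketch leaves the choice of evaluation points to the caller.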

#### There are some reproducibility issues:

```
local, old head:

00:18<00:00, 53.09it/s, train_loss=7.44, val_loss=18.11, min_corr=0.70]
Total training loss: 1.660e+01
Total validation loss: 1.770e+01

00:18<00:00, 53.49it/s, train_loss=7.44, val_loss=18.11, min_corr=0.70]
Total training loss: 1.660e+01
Total validation loss: 1.770e+01

local, new srnet.py:
00:21<00:00, 45.65it/s, train_loss=8.12, val_loss=17.70, min_corr=0.71]
Total training loss: 1.660e+01
Total validation loss: 1.770e+01

local, new srnet.py with extension commented out:
00:21<00:00, 45.65it/s, train_loss=8.12, val_loss=17.70, min_corr=0.71]
Total training loss: 1.660e+01
Total validation loss: 1.770e+01

local, new srnet.py and srnet_utils.py with write statement commented out:
00:20<00:00, 49.26it/s, train_loss=7.44, val_loss=18.11, min_corr=0.70]
Total training loss: 1.660e+01
Total validation loss: 1.770e+01

new normal:
00:20<00:00, 47.70it/s, train_loss=8.12, val_loss=17.70, min_corr=0.71]
Total training loss: 1.660e+01
Total validation loss: 1.770e+01

new sdnet.py:
00:20<00:00, 47.70it/s, train_loss=0.96, val_loss=0.80, min_corr=0.63]
Total training loss: 8.387e-01
Total validation loss: 8.008e-01

fixed sdnet.py:
00:21<00:00, 46.95it/s, train_loss=8.12, val_loss=17.70, min_corr=0.71]
Total training loss: 1.660e+01
Total validation loss: 1.770e+01


local py38, torch11, numpy18, fixed sdnet.py, v23:
1001/1001 [00:44<00:00, 22.30it/s, train_loss=-2.35e+01, val_loss=10.54, min_corr=0.95]
Total training loss: 9.924e+00
Total validation loss: 1.054e+01


local py37, torch11, numpy17, fixed sdnet.py, v23:
1001/1001 [00:51<00:00, 19.47it/s, train_loss=-1.86e+01, val_loss=10.56, min_corr=0.95]
Total training loss: 9.949e+00
Total validation loss: 1.056e+01

local py37, torch11, numpy18, fixed sdnet.py, v23:
1001/1001 [00:47<00:00, 21.02it/s, train_loss=-1.86e+01, val_loss=10.56, min_corr=0.95]
Total training loss: 9.949e+00
Total validation loss: 1.056e+01

local py37, torch11, numpy21, fixed sdnet.py, v23:
1001/1001 [00:46<00:00, 21.61it/s, train_loss=-2.35e+01, val_loss=10.54, min_corr=0.95]
Total training loss: 9.924e+00
Total validation loss: 1.054e+01

local py37.13 torch12, numpy21, fixed sdnet.py, v23:
1001/1001 [00:48<00:00, 20.76it/s, train_loss=-2.15e+01, val_loss=10.53, min_corr=0.95]
Total training loss: 9.922e+00
Total validation loss: 1.053e+01

spartan local copy py37.4, torch11, numpy17, fixed sdnet.py, v23:
1001/1001 [00:49<00:00, 20.25it/s, train_loss=-1.86e+01, val_loss=10.56, min_corr=0.95]
Total training loss: 9.949e+00
Total validation loss: 1.056e+01

cluster py37.4, torch11, numpy17, fixed sdnet.py, v23:
1001/1001 [00:36<00:00, 27.52it/s, train_loss=-2.19e+01, val_loss=10.49, min_corr=0.95]
Total training loss: 9.882e+00
Total validation loss: 1.049e+01
```

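The results vary with the `torch` and `numpy` versions. In the logs above they also differ between Python 3.7 and 3.8 with otherwise identical listed versions, and between the local and the cluster runs with the same listed versions.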
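
To separate such environment effects from ordinary run-to-run randomness, it can help to fix all seeds. A minimal sketch, assuming the training code draws random numbers only from `random`, `numpy`, and `torch`; `set_seed` is a hypothetical helper, not part of `srnet_utils`. Even with fixed seeds, identical results across different library versions or hardware are not guaranteed.

```python
import random

import numpy as np
import torch


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # trade speed for reproducibility in cuDNN
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(0)
a = torch.rand(3)
set_seed(0)
b = torch.rand(3)
```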

In [12]:
# set wandb project
wandb_project = "102-bn-mask-DSN-emb-study-F00_v2"

In [13]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": {
#             "alpha": [[1,0],[0,1],[1,1]],
#             "norm": None,
#             "prune": None,
#             },
#         "lat_size": 3,
#         },
#     "epochs": 30000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "shuffle": False,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "gc": 0.0,
#     "sd": 1e-7,
#     "disc": {
#         "hid_num": (1,3),
#         "hid_size": (32,128),
#         "emb_size": train_data.in_data.shape[1] + 1,
#         "lr": 1e-3,
#         "wd": 1e-4,
#         "iters": 5,
#         "gp": 1e-5,
#     },
# }

In [14]:
# define hyperparameter study
hp_study = {
    "method": "random",
    "parameters": {
        "sd": {
            "values": [1e-8, 1e-7, 1e-6, 1e-5]
        },
        "disc": {
            "parameters": {
                "hid_num": {
                    "values": [(1,4), (2,4), (1,6), (2,6)]
                },
                "hid_size": {
                    "values": [(32,128), (64,128), (32,256), (64,256)]
                },
                "emb_size": {
                    "values": [3]
                },
                "lr": {
                    "values": [1e-5, 1e-4, 1e-3, 1e-2]
                },
                "wd": {
                    "values": [1e-6]
                },
                "iters": {
                    "values": [5]
                },
                "gp": {
                    "values": [1e-5, 1e-4, 1e-3]
                },
            }
        }
    }
}
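
A quick way to see the size of this search space is to multiply the number of candidate values per parameter. The sketch below repeats the parameter dictionary so that it runs on its own; `count_configs` is a hypothetical helper, not part of `wandb`.

```python
def count_configs(parameters):
    """Count the discrete configurations in a nested sweep parameter dict."""
    total = 1
    for spec in parameters.values():
        if "parameters" in spec:      # nested group, e.g. "disc"
            total *= count_configs(spec["parameters"])
        else:                         # flat list of candidate values
            total *= len(spec["values"])
    return total


# same structure as hp_study["parameters"]
parameters = {
    "sd": {"values": [1e-8, 1e-7, 1e-6, 1e-5]},
    "disc": {
        "parameters": {
            "hid_num": {"values": [(1, 4), (2, 4), (1, 6), (2, 6)]},
            "hid_size": {"values": [(32, 128), (64, 128), (32, 256), (64, 256)]},
            "emb_size": {"values": [3]},
            "lr": {"values": [1e-5, 1e-4, 1e-3, 1e-2]},
            "wd": {"values": [1e-6]},
            "iters": {"values": [5]},
            "gp": {"values": [1e-5, 1e-4, 1e-3]},
        }
    },
}

n_configs = count_configs(parameters)
print(n_configs)  # 768
```

With 768 combinations, a random search samples only a subset of the grid.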

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

<img src="results/102-bn-mask-DSN-emb-study-F00_v2.png">

In [39]:
# download data from wandb
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for r, run in enumerate(runs):
    # keep only runs with min_corr above 0.9
    if run.summaryMetrics['min_corr'] > 0.9:
        for f in run.files():
            if f.name.endswith(file_ext):
                # tag the file with the run index to avoid name clashes
                file_name = f.name.replace(file_ext, f"_v{r+1}{file_ext}")
                print(f"Downloading {os.path.basename(file_name)}.")
                run.file(f.name).download()
                os.rename(f.name, file_name)

Downloading srnet_model_F00_v2_bn_mask_emb_study_v2.pkl.
Downloading srnet_model_F00_v2_bn_mask_emb_study_v13.pkl.
Downloading srnet_model_F00_v2_bn_mask_emb_study_v14.pkl.
Downloading srnet_model_F00_v2_bn_mask_emb_study_v17.pkl.
Downloading srnet_model_F00_v2_bn_mask_emb_study_v20.pkl.


In [15]:
# load data
data_path = "data_1k"

in_var = "X00"
lat_var = "G00"
target_var = "F00"

mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"])
val_data = SRData(data_path, in_var, lat_var, target_var, masks["val"])

In [16]:
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]

In [17]:
corr_data = [
    ("x**2", x_data**2), 
    # ("y**2", y_data**2), 
    # ("cos(x)", np.cos(x_data)), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [18]:
# get validation loss and latent feature correlations
model_path = "models"
save_name = "F00_v2_bn_mask_emb_study"

models = [f for f in os.listdir(model_path) if save_name in f]

val_corr = {}

for model_name in models:
    print(f"Loading {model_name}.")
    model = ut.load_model(model_name, model_path, SRNet)
    
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)
        
    all_nodes = ut.get_node_order(acts, show=False)
        
    corr_mat = ut.node_correlations(acts, all_nodes, corr_data, show=False)
    corr = [np.abs(c).max() for c in corr_mat]
    
    with torch.no_grad():
        preds = model(val_data.in_data)
        
    val_loss = (preds - val_data.target_data).pow(2).mean().item()
    val_corr[model_name] = (val_loss, corr)

Loading srnet_model_F00_v2_bn_mask_emb_study_v13.pkl.
Loading srnet_model_F00_v2_bn_mask_emb_study_v14.pkl.
Loading srnet_model_F00_v2_bn_mask_emb_study_v17.pkl.
Loading srnet_model_F00_v2_bn_mask_emb_study_v2.pkl.
Loading srnet_model_F00_v2_bn_mask_emb_study_v20.pkl.


In [19]:
fig, ax = plt.subplots()

for v in val_corr:
    ax.plot(val_corr[v][0], np.min(val_corr[v][1]), 'x', label=v.split('.')[0].split('_')[-1])

ax.legend()
plt.show()


In [20]:
# plot losses
save_names = [
    "srnet_model_F00_v2_bn_mask_emb_study_v13",
    "srnet_model_F00_v2_bn_mask_emb_study_v20",
    "srnet_model_F00_v2_bn_mask_emb_study_v17",
    "srnet_model_F00_v2_bn_mask_emb_study_v2",
]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path, excl_names=[])


In [21]:
corr_data = [
    ("x**2", x_data**2), 
    ("y**2", y_data**2), 
    ("cos(x)", np.cos(x_data)), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [22]:
model_path = "models"
model_ext = ".pkl"

for model_name in models:
    print(model_name)
    
    state = joblib.load(os.path.join(model_path, model_name + model_ext))
    
    print(state['hyperparams']['sd'])
    print(state['hyperparams']['disc'])
    
    model = ut.load_model(model_name + model_ext, model_path, SRNet)
    
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)
        
    all_nodes = ut.get_node_order(acts, show=True)
    
    print(f"Validation error: {state['total_val_loss']:.4e}")
        
    corr_mat = ut.node_correlations(acts, all_nodes, corr_data, show=True)
        
    print("")

srnet_model_F00_v2_bn_mask_emb_study_v13
1e-06
{'emb_size': 3, 'gp': 0.0001, 'hid_num': [1, 4], 'hid_size': [32, 256], 'iters': 5, 'lr': 0.001, 'wd': 1e-06}
[1.5251404, 1.3298478, 0.102596655]
[1, 0, 2]
Validation error: 2.1712e-02

Node 1
corr(n1, x**2): 0.0173
corr(n1, y**2): 0.9985
corr(n1, cos(x)): -0.0313
corr(n1, cos(y)): -0.9737
corr(n1, x*y): 0.0506

Node 0
corr(n0, x**2): 0.9990
corr(n0, y**2): 0.0150
corr(n0, cos(x)): -0.9493
corr(n0, cos(y)): -0.0165
corr(n0, x*y): 0.1809

Node 2
corr(n2, x**2): -0.1650
corr(n2, y**2): -0.3646
corr(n2, cos(x)): 0.1613
corr(n2, cos(y)): 0.3529
corr(n2, x*y): -0.9401

srnet_model_F00_v2_bn_mask_emb_study_v20
1e-08
{'emb_size': 3, 'gp': 1e-05, 'hid_num': [2, 4], 'hid_size': [64, 256], 'iters': 5, 'lr': 0.0001, 'wd': 1e-06}
[1.8680242, 1.2094715, 0.14276965]
[1, 0, 2]
Validation error: 2.5851e-03

Node 1
corr(n1, x**2): 0.0139
corr(n1, y**2): 0.9908
corr(n1, cos(x)): -0.0287
corr(n1, cos(y)): -0.9785
corr(n1, x*y): 0.0477

Node 0
corr(n0, x**2):

These results are good, but very similar to those obtained without embedding. The issue of converging to $y^2$ instead of $\cos(y)$ remains.

What are possible reasons?

* The signal $\cos(y)$ is small compared to $x^2$ and $x \cdot y$, and thus the associated prediction errors are also small.

* The input data is constant across all SD samples. Does it then provide any additional information to the discriminator?
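
The first point can be checked numerically by comparing the spread of the three terms. The sketch below assumes inputs drawn uniformly from $[-3, 3]$; this range is hypothetical and not taken from the data set, so the values are only indicative.

```python
import numpy as np

# assumed inputs: uniform on [-3, 3] (hypothetical range, not from the data set)
rng = np.random.default_rng(0)
x = rng.uniform(-3.0, 3.0, 100_000)
y = rng.uniform(-3.0, 3.0, 100_000)

terms = {"x**2": x**2, "cos(y)": np.cos(y), "x*y": x * y}
stds = {name: values.std() for name, values in terms.items()}

for name, std in stds.items():
    print(f"std({name}) = {std:.3f}")
```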

In [23]:
model_name = "srnet_model_F00_v2_bn_mask_emb_study_v2"
model_path = "models"
model_ext = ".pkl"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)
    
all_nodes = ut.get_node_order(acts, show=True)

print(model.layers1.norm(model.layers1.alpha).detach().cpu().numpy()[all_nodes])

print("")

[1.3996812, 1.0809922, 0.45018354]
[0, 1, 2]
[[1. 0.]
 [0. 1.]
 [1. 1.]]



In [24]:
nodes = all_nodes[:3]

In [25]:
fig, ax = plt.subplots()

n = 0
bias = True

ax.scatter(x_data, x_data**2)
# ax.scatter(x_data, acts[:,n])
ax.scatter(x_data, model.layers2[0].weight[0,n].item()*acts[:,n] + bias * model.layers2[0].bias.item())

plt.show()


In [26]:
fig, ax = plt.subplots()

n = 1
bias = True

ax.scatter(y_data, np.cos(y_data))
ax.scatter(y_data, model.layers2[0].weight[0,n].item()*acts[:,n] + bias * model.layers2[0].bias.item())

plt.show()


In [27]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    #("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [28]:
n = 2
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=[n], model=model, bias=False, nonzero=False, agg=False, plot_size=plot_size)


**TODO**: Plot embedding