# 2022 Flatiron Machine Learning x Science Summer School

## Step 9: Implement symbolic discriminator

In this step, we train a symbolic discriminator to regularize the latent features of our `SRNet`.

The aim is to incentivize the latent features to resemble functions that symbolic regression can discover easily.

The plan is to train a Wasserstein Generative Adversarial Network (WGAN) with gradient penalty. Compared with a regular GAN, a WGAN is less prone to mode collapse and vanishing gradients. The WGAN requires the discriminator (critic) to be 1-Lipschitz continuous, which the gradient penalty enforces softly (see https://www.youtube.com/watch?v=v6y5qQ0pcg4).
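
As a rough illustration of the gradient penalty, the following PyTorch sketch computes the two-sided penalty $(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2$ at random interpolates $\hat{x}$ between real and generated samples. The names `gradient_penalty`, `critic_loss`, and `gp_weight` are illustrative assumptions and are not taken from `srnet`; inputs are assumed to be 2D tensors of shape `(batch, features)`.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """Two-sided WGAN-GP penalty: mean of (||grad critic(x_hat)||_2 - 1)^2."""
    # One random interpolation coefficient per sample
    eps = torch.rand(real.shape[0], 1, device=real.device)
    x_hat = (eps * real + (1.0 - eps) * fake).detach().requires_grad_(True)

    scores = critic(x_hat)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
    )
    return ((grads.norm(2, dim=1) - 1.0) ** 2).mean()


def critic_loss(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor, gp_weight: float) -> torch.Tensor:
    """Critic loss: E[D(fake)] - E[D(real)] + gp_weight * penalty."""
    fake = fake.detach()  # the critic update should not touch the generator
    wasserstein = critic(fake).mean() - critic(real).mean()
    return wasserstein + gp_weight * gradient_penalty(critic, real, fake)
```

A quick sanity check for such a function: for a linear critic without bias, the input gradient equals the weight vector $w$ everywhere, so the penalty should evaluate to $(\lVert w \rVert_2 - 1)^2$ regardless of the data.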

First, we are going to apply a WGAN to the simplest possible formulation of our problem:

* Two input features

* Bottleneck

* Fixed $\alpha$ mask values

* Pretrained weights

* Only target functions as real data

Notes:

* We probably do not want to shuffle the input data

Questions:

* Do we feed the discriminator only latent feature values or also input feature values?

* Do we want different batches, i.e. the `SRNet` and real functions evaluated at different positions?

**TODO**:

* Try `LeakyReLU`, `BatchNorm`

* Check https://github.com/aladdinpersson/Machine-Learning-Collection

* Check `prior.jl`

* Check Miles' code (https://colab.research.google.com/drive/1IoJh46JV8EaxWjF3kDS51MZyAdXKzBxc?usp=sharing#scrollTo=4G3ZNAUeCxMX)

### Step 9.1: Check bottleneck DSN with SD regularization

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
import srnet_utils as ut

In [2]:
# plot losses
save_names = ["srnet_model_F00_bn_norm_sd_1e-04_test.pkl"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path, excl_names=["max"])


In [3]:
# load data
data_path = "data_1k"

in_var = "X00"
lat_var = "G00"
target_var = "F00"

mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"])
val_data = SRData(data_path, in_var, lat_var, target_var, masks["val"])

In [5]:
model_name = "srnet_model_F00_bn_norm_sd_1e-04_test"
model_path = "models"
model_ext = ".pkl"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)
    
all_nodes = ut.get_node_order(acts, show=True)

print(model.layers1.alpha.detach().cpu().numpy()[all_nodes])
print(model.layers1.norm(model.layers1.alpha).detach().cpu().numpy()[all_nodes])

print("")

[1.8365434, 0.94673306, 0.1891221]
[2, 1, 0]
[[0.58066916 0.41933084]
 [0.7490023  0.25099772]
 [0.7104324  0.28956768]]



In [6]:
nodes = all_nodes[:3]

In [7]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [8]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)


In [9]:
corr_data = [
    ("x**2", x_data**2), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [10]:
ut.node_correlations(acts, nodes, corr_data)


Node 2
corr(n2, x**2): 0.9999
corr(n2, cos(y)): -0.0170
corr(n2, x*y): 0.1786

Node 1
corr(n1, x**2): 0.1706
corr(n1, cos(y)): -0.0622
corr(n1, x*y): 0.9999

Node 0
corr(n0, x**2): -0.0153
corr(n0, cos(y)): 0.9997
corr(n0, x*y): -0.0672
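
The internals of `ut.node_correlations` are not shown here. As a minimal sketch of the idea, assuming it reports Pearson correlation coefficients, the snippet below compares a synthetic latent node against the same candidate functions. The helper `pearson` and the synthetic `node` are hypothetical and serve only as an illustration.

```python
import numpy as np


def pearson(a, b) -> float:
    """Pearson correlation coefficient between two 1D samples."""
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    return float(np.corrcoef(a, b)[0, 1])


rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, 1000)
y = rng.uniform(-1.0, 1.0, 1000)

candidates = {"x**2": x**2, "cos(y)": np.cos(y), "x*y": x * y}

# Hypothetical latent node: x**2 up to an arbitrary scale and offset
node = 2.5 * x**2 - 0.3

for name, values in candidates.items():
    print(f"corr(node, {name}): {pearson(node, values):.4f}")
```

Because the Pearson coefficient is invariant to positive scaling and to offsets, a value close to 1 indicates that a node matches a candidate function only up to an affine transformation, which the subsequent layers of the network can absorb.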


In [11]:
fig, ax = plt.subplots()

ax.scatter(x_data, x_data**2)
ax.scatter(x_data, acts[:,2])

plt.show()


In [12]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    #("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [13]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=[1], model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)


In [14]:
fig, ax = plt.subplots()

ax.scatter(y_data, np.cos(y_data))
ax.scatter(y_data, acts[:,0])

plt.show()


### Step 9.2: Explore SD hyperparameters

In [15]:
# set wandb project
wandb_project = "92-bn-DSN-sd-study-F00"

In [16]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": {
#             "alpha": None,
#             "norm": "softmax",
#             "prune": None,
#             },
#         "lat_size": 3,
#         },
#     "epochs": 30000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "shuffle": False,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "gc": 0.0,
#     "sd": 1e-4,
#     "disc": {
#         "hid_num": 2,
#         "hid_size": 128,
#         "lr": 1e-4,
#         "wd": 1e-4,
#         "iters": 5,
#         "gp": 1e-4,
#     },
# }
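
To make the `disc` entries more concrete, here is one way `hid_num` and `hid_size` could parameterize a critic network. This is a sketch under assumptions: the function `build_critic`, the `LeakyReLU` slope, and the choice of `in_size=3` (matching `lat_size`) are hypothetical, and whether the critic also receives the input features is still an open question in this notebook.

```python
import torch
import torch.nn as nn


def build_critic(in_size: int, hid_num: int, hid_size: int, slope: float = 0.2) -> nn.Sequential:
    """MLP critic with `hid_num` hidden layers of width `hid_size` and a scalar output."""
    layers = []
    width = in_size
    for _ in range(hid_num):
        layers += [nn.Linear(width, hid_size), nn.LeakyReLU(slope)]
        width = hid_size
    # No sigmoid at the end: a WGAN critic returns an unbounded score
    layers.append(nn.Linear(width, 1))
    return nn.Sequential(*layers)


# in_size=3 is an assumption (one value per latent feature)
critic = build_critic(in_size=3, hid_num=2, hid_size=128)
scores = critic(torch.randn(16, 3))
print(scores.shape)
```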

In [18]:
# define hyperparameter study
hp_study = {
    "method": "random",
    "parameters": {
        "sd": {
            "values": [1e-6, 1e-4, 1e-2, 1e0]
        },
        "disc": {
            "parameters": {
                "hid_num": {
                    "values": [1, 2, 4]
                },
                "hid_size": {
                    "values": [32, 64, 128, 256]
                },
                "lr": {
                    "values": [1e-5, 1e-4, 1e-3]
                },
                "iters": {
                    "values": [1, 2, 5, 10]
                },
                "gp": {
                    "values": [1e-5, 1e-4, 1e-3]
                },
            }
        }
    }
}
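
To get a feeling for the size of this search space, the following sketch counts the discrete combinations in the nested configuration. The helper `count_combinations` is hypothetical and not part of `wandb`; the dictionary is copied from the cell above so that the snippet runs on its own.

```python
hp_study = {
    "method": "random",
    "parameters": {
        "sd": {"values": [1e-6, 1e-4, 1e-2, 1e0]},
        "disc": {
            "parameters": {
                "hid_num": {"values": [1, 2, 4]},
                "hid_size": {"values": [32, 64, 128, 256]},
                "lr": {"values": [1e-5, 1e-4, 1e-3]},
                "iters": {"values": [1, 2, 5, 10]},
                "gp": {"values": [1e-5, 1e-4, 1e-3]},
            }
        },
    },
}


def count_combinations(parameters: dict) -> int:
    """Count combinations of discrete `values`, recursing into nested `parameters`."""
    total = 1
    for spec in parameters.values():
        if "parameters" in spec:
            total *= count_combinations(spec["parameters"])
        else:
            total *= len(spec["values"])
    return total


n_combinations = count_combinations(hp_study["parameters"])
print(n_combinations)  # 4 * (3 * 4 * 3 * 4 * 3) = 1728
```

With 1728 combinations and runs of up to 30000 epochs each, an exhaustive grid would be expensive, so random sampling covers the space with far fewer runs.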

In [19]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

Create sweep with ID: 3uczgith
Sweep URL: https://wandb.ai/fabxy/92-bn-DSN-sd-study-F00/sweeps/3uczgith


In [None]:
# download data from wandb
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            run.file(f.name).download()

In [6]:
# plot losses
save_names = ["F00_l1_1e-03_gc_1e"]
save_path = "models"
model_names = ut.plot_losses(save_names, save_path=save_path)
