# 2022 Flatiron Machine Learning x Science Summer School

## Step 14: Apply symbolic discriminator to advanced problems

Next steps:

* Define `F09` with larger latent dimensions

* <s>Define `F10` with nonlinear function $f(x)$</s>

* <s>Define `F11` with larger input and latent dimensions</s>

* Define suitable function libraries `F09_v1` <s>and `F10_v1`</s>

* Run hyperparameter studies

### Step 14.1: Create data with larger latent dimensions

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
from sdnet import SDData
import srnet_utils as ut

What would be a suitable extension?

Currently, we have a latent dimension of 3. Let's define a more complex problem with a latent dimension of 5:

In [2]:
data_size = 1_000  # number of synthetic samples to draw

In [3]:
# Standard-normal inputs with two columns: column 0 plays the role of x and
# column 1 the role of y in the feature labels below.
x_data = torch.randn(data_size, 2)

In [4]:
# Candidate latent features as (label, values) pairs, where x = x_data[:, 0]
# and y = x_data[:, 1]; five features give a latent dimension of 5.
z_data = [
    ("1.5*x*x", 1.5 * x_data[:, 0] * x_data[:, 0]),
    ("3.5*sin(2.5*y)", 3.5 * torch.sin(2.5 * x_data[:, 1])),
    ("3.0*x*cos(0.5*x)", 3.0 * x_data[:, 0] * torch.cos(0.5 * x_data[:, 0])),
    ("x*y", x_data[:, 0] * x_data[:, 1]),
    ("0.5*y*exp(x)", 0.5 * x_data[:, 1] * torch.exp(x_data[:, 0])),
]

ut.plot_acts(x_data[:, 0], x_data[:, 1], z_data)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Create function library `F09_v1`:

In [5]:
# Function library F09_v1; its expressions refer to the input as X09.
in_var = "X09"
fun_path = "funs/F09_v1.lib"
shuffle = False  # forwarded to SDData(shuffle=...)

In [6]:
# Build the discriminator data source from the function library at `fun_path`.
# The library expressions index the input as `in_var` (the `X09[...]` terms
# listed by `disc_data.funs`). NOTE(review): shuffle=False presumably keeps the
# library order, which the row indices used for plotting rely on — confirm in SDData.
disc_data = SDData(fun_path, in_var, shuffle=shuffle)

In [7]:
# Show the library expressions. NOTE(review): N and U look like random factors
# substituted on each draw — TODO confirm in sdnet.SDData.
disc_data.funs

['N*1.5*X09[:,0]**2',
 'N*1.5*X09[:,1]**2',
 'N*0.25*X09[:,0]**3',
 'N*0.25*X09[:,1]**3',
 'N*3.5*np.sin(2*U*2.5*X09[:,0])',
 'N*3.5*np.sin(2*U*2.5*X09[:,1])',
 'N*3.5*np.cos(2*U*2.5*X09[:,0])',
 'N*3.5*np.cos(2*U*2.5*X09[:,1])',
 'N*3.0*X09[:,0]*np.sin(2*U*0.5*X09[:,0])',
 'N*3.0*X09[:,1]*np.sin(2*U*0.5*X09[:,1])',
 'N*3.0*X09[:,0]*np.cos(2*U*0.5*X09[:,0])',
 'N*3.0*X09[:,1]*np.cos(2*U*0.5*X09[:,1])',
 'N*X09[:,0]*X09[:,1]',
 'N*X09[:,1]*X09[:,0]',
 'N*0.5*X09[:,0]*np.exp(U*X09[:,1])',
 'N*0.5*X09[:,1]*np.exp(U*X09[:,0])']

In [8]:
# Colors of matplotlib's default property cycle, indexed per plotted library row.
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
n_samp = 10  # number of random library draws to overlay

In [9]:
fig, ax = plt.subplots()

# Library rows that depend on X09[:,0] only (indices into `disc_data.funs`
# with shuffle=False): x**2, x**3, sin(x), cos(x), x*sin(x), x*cos(x).
lib_rows = [0, 2, 4, 6, 8, 10]

for _ in range(n_samp):
    # A single `get` call already returns every library row, so draw once per
    # sample instead of evaluating the whole library again for each plotted row.
    samp = disc_data.get(in_data=x_data)
    for i, row in enumerate(lib_rows):
        ax.scatter(x_data[:, 0], samp[0, row, :], color=colors[i], alpha=0.5)

# Reference curves (black): target latent shapes as functions of x.
ax.scatter(x_data[:, 0], 1.5 * x_data[:, 0] ** 2, color="k")
ax.scatter(x_data[:, 0], 3.5 * torch.sin(2.5 * x_data[:, 0]), color="k")
ax.scatter(x_data[:, 0], 3.0 * x_data[:, 0] * torch.cos(0.5 * x_data[:, 0]), color="k")
plt.show()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### Step 14.2: Train bottleneck masked DSN with SD regularization on `F09_v1`

In [10]:
# W&B project that collects the runs of this sweep
wandb_project = "142-bn-mask-DSN-sd-study-F09_v1"

In [11]:
# Reference hyperparameter set, kept commented out: running this cell defines
# nothing. NOTE(review): the entries read `train_data`, which this notebook only
# loads in a later cell (In [13]), so uncommenting it here as-is would fail.
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": {
#             "alpha": [[1,0],[0,1],[1,0],[1,1],[1,1]],
#             "norm": None,
#             "prune": None,
#             },
#         "lat_size": 5,
#         },
#     "epochs": 30000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "shuffle": False,
#     "lr": 1e-4,
#     "wd": 1e-6,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "e3": 0.0,
#     "gc": 0.0,
#     "sd": 1e-6,
#     "disc": {
#         "hid_num": 6,
#         "hid_size": 128,
#         "emb_size": None,
#         "lr": 1e-3,
#         "wd": 1e-7,
#         "betas": (0.9,0.999),
#         "iters": 5,
#         "gp": 1e-5,
#     },
# }

In [12]:
# W&B sweep definition: random search over the top-level learning rate `lr`,
# the SD-regularization weight `sd`, and the nested discriminator settings `disc`.
hp_study = {
    "method": "random",
    "parameters": {
        "lr": {"values": [1e-5, 1e-4, 1e-3]},
        "sd": {"values": [1e-7, 1e-6, 1e-5, 1e-4]},
        "disc": {
            "parameters": {
                "hid_num": {"values": [2, 4, 6, 8]},
                "hid_size": {"values": [64, 128, 256, 512]},
                "lr": {"values": [1e-5, 1e-4, 1e-3, 1e-2]},
                "iters": {"values": [2, 5, 8]},
                "gp": {"values": [1e-6, 1e-5, 1e-4, 1e-3]},
            },
        },
    },
}

In [None]:
# create sweep
# Register `hp_study` as a W&B sweep in `wandb_project` and keep the returned id.
# NOTE: this cell is marked `In [None]`, i.e. it was not executed in the saved
# notebook, and no `wandb.agent` call appears in this notebook to run the sweep.
sweep_id = wandb.sweep(hp_study, project=wandb_project)

<img src="results/142-bn-mask-DSN-sd-study-F09_v1.png">

Quick notes (`sd.*` entries below are the discriminator settings, i.e. `disc.*` in `hp_study`):

* Only 2 out of 60 training runs improve `min_corr` without large oscillations:

    Values are listed as: good run, better run

    * lr: 1e-5, 1e-5

    * sd: 1e-6, 1e-7
       
    * sd.lr: 1e-3, 1e-2
    
    * sd.iters: 2, 8
    
    * sd.gp: 1e-4, 1e-4
    
    * sd.hid_num: 2, 6

    * sd.hid_size: 256, 64
    
* 1 training run with low validation error but high oscillations:

    * lr: 1e-3
    
    * sd: 1e-7
       
    * sd.lr: 1e-3
    
    * sd.iters: 8
    
    * sd.gp: 1e-4
    
    * sd.hid_num: 4
    
    * sd.hid_size: 128

* 2 training runs in between:

    Values are listed as: run with more oscillation, run with less oscillation

    * lr: 1e-4, 1e-4

    * sd: 1e-6, 1e-7
       
    * sd.lr: 1e-3, 1e-5
    
    * sd.iters: 5, 8
    
    * sd.gp: 1e-3, 1e-3
    
    * sd.hid_num: 8, 2

    * sd.hid_size: 128, 256

* 1 training run with large training error:

    * lr: 1e-4

    * sd: 1e-6
       
    * sd.lr: 1e-3
    
    * sd.iters: 5
    
    * sd.gp: 1e-5
    
    * sd.hid_num: 2

    * sd.hid_size: 256


In [13]:
# load data
data_path = "data_1k"

# input / latent / target variable names passed to SRData
in_var, lat_var, target_var = "X09", "G09", "F09"

# saved index masks that select the train and validation samples
mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, f"{in_var}{mask_ext}"))

train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [14]:
# plot losses
# Plot the loss histories of the saved sweep model(s) via ut.plot_losses; the
# returned `models` is iterated by later cells, which load each entry by name.
save_names = [
    "srnet_model_F09_v1_bn_mask_sd_study_v1",
]
save_path = "models"
# Use `save_path` rather than repeating the literal, so changing the directory
# above actually takes effect.
models = ut.plot_losses(save_names, save_path=save_path, excl_names=[])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# Shape of the training latent targets; the output shows 700 samples x 5 latents.
train_data.lat_data.shape

torch.Size([700, 5])

In [16]:
# One ("g<i>", column) pair per ground-truth latent feature, for node correlations.
corr_data = [(f"g{i}", col) for i, col in enumerate(train_data.lat_data.T)]

In [17]:
model_path = "models"
model_ext = ".pkl"

# For each model returned by ut.plot_losses: print its name, its `sd` and `disc`
# hyperparameters, the node ordering, the validation loss, and the correlation
# of every latent node with every ground-truth feature in `corr_data`.
for model_name in models:
    print(model_name)
    
    # Raw saved state, read directly to access the stored hyperparams and
    # total_val_loss; the model itself is reloaded by ut.load_model below.
    state = joblib.load(os.path.join(model_path, model_name + model_ext))
    
    print(state['hyperparams']['sd'])
    print(state['hyperparams']['disc'])
    
    model = ut.load_model(model_name + model_ext, model_path, SRNet)
    
    # Forward pass without gradient tracking; get_lat=True also returns `acts`.
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)
        
    # NOTE(review): with show=True this prints per-node values and the resulting
    # node order (see the cell output); the ordering criterion is defined in ut.
    all_nodes = ut.get_node_order(acts, show=True)
    
    print(f"Validation error: {state['total_val_loss']:.4e}")
        
    corr_mat = ut.node_correlations(acts, all_nodes, corr_data, show=True)
        
    print("")

srnet_model_F09_v1_bn_mask_sd_study_v1
1e-07
{'gp': 0.0001, 'hid_num': 6, 'hid_size': 64, 'iters': 8, 'lr': 0.01}
[9.079713, 7.470909, 6.1827435, 3.0480175, 1.0701797]
[4, 3, 0, 1, 2]
Validation error: 1.0992e+00

Node 4
corr(n4, g0): -0.2870
corr(n4, g1): -0.6009
corr(n4, g2): -0.4044
corr(n4, g3): -0.5032
corr(n4, g4): -0.5091

Node 3
corr(n3, g0): -0.4710
corr(n3, g1): -0.3507
corr(n3, g2): -0.3667
corr(n3, g3): -0.6553
corr(n3, g4): -0.5799

Node 0
corr(n0, g0): -0.8660
corr(n0, g1): 0.0208
corr(n0, g2): -0.5016
corr(n0, g3): -0.1287
corr(n0, g4): 0.0008

Node 1
corr(n1, g0): 0.0235
corr(n1, g1): -0.8218
corr(n1, g2): 0.0028
corr(n1, g3): 0.0068
corr(n1, g4): -0.3690

Node 2
corr(n2, g0): -0.3999
corr(n2, g1): 0.0272
corr(n2, g2): -0.8802
corr(n2, g3): -0.0660
corr(n2, g4): -0.0052



In [18]:
model_name = "srnet_model_F09_v1_bn_mask_sd_study_v1"
model_path, model_ext = "models", ".pkl"

# Reload the selected model; the file name and directory are passed separately.
model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)

# Forward pass on the training inputs without gradient tracking; get_lat=True
# also returns `acts` (presumably the latent-node activations).
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

In [19]:
fig, ax = plt.subplots()

bias = True  # add layers2[0]'s bias to the node's weighted activation
n = 0  # latent feature / node index being compared

# Ground-truth latent g_n against input column n.
ax.scatter(train_data.in_data[:, n], train_data.lat_data[:, n])
# Node n's weighted activation w[0, n] * a_n, optionally shifted by the bias.
ax.scatter(
    train_data.in_data[:, n],
    model.layers2[0].weight[0, n].item() * acts[:, n]
    + bias * model.layers2[0].bias.item(),
)

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [20]:
fig, ax = plt.subplots()

bias = True  # add layers2[0]'s bias to the node's weighted activation
n = 1  # latent feature / node index being compared

# Ground-truth latent g_n against input column n (column 1 here).
ax.scatter(train_data.in_data[:, n], train_data.lat_data[:, n])
# Node n's weighted activation w[0, n] * a_n, optionally shifted by the bias.
ax.scatter(
    train_data.in_data[:, n],
    model.layers2[0].weight[0, n].item() * acts[:, n]
    + bias * model.layers2[0].bias.item(),
)

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [21]:
fig, ax = plt.subplots()

bias = True  # add layers2[0]'s bias to the node's weighted activation
n = 2  # latent feature / node index being compared

# Unlike nodes 0 and 1, plot against input column 0 regardless of n.
ax.scatter(train_data.in_data[:, 0], train_data.lat_data[:, n])
# Node n's weighted activation w[0, n] * a_n, optionally shifted by the bias.
ax.scatter(
    train_data.in_data[:, 0],
    model.layers2[0].weight[0, n].item() * acts[:, n]
    + bias * model.layers2[0].bias.item(),
)

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [22]:
n = 3
plot_size = train_data.target_data.shape[0]  # number of training samples
z_data = [(f"g{n}", train_data.lat_data[:, n])]

# Node n's activation over both inputs, compared with ground-truth latent g_n.
ut.plot_acts(
    train_data.in_data[:, 0],
    train_data.in_data[:, 1],
    z_data,
    acts=acts,
    nodes=[n],
    model=model,
    bias=False,
    nonzero=False,
    agg=False,
    plot_size=plot_size,
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [23]:
n = 4
plot_size = train_data.target_data.shape[0]  # number of training samples
z_data = [(f"g{n}", train_data.lat_data[:, n])]

# Node n's activation over both inputs, compared with ground-truth latent g_n.
ut.plot_acts(
    train_data.in_data[:, 0],
    train_data.in_data[:, 1],
    z_data,
    acts=acts,
    nodes=[n],
    model=model,
    bias=False,
    nonzero=False,
    agg=False,
    plot_size=plot_size,
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

**Note**: Training `srnet_model_F09_v1_bn_mask_sd_study_v1` for 150,000 epochs does not increase `min_corr`, and training eventually diverges.