# 2022 Flatiron Machine Learning x Science Summer School

## Step 17: Train MLP with fixed symbolic discriminator

Can we utilize a fixed symbolic discriminator (SD) to incentivize quadratic latent features?

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import torch.nn as nn
import wandb

from srnet import SRNet, SRData
from sdnet import SDNet, SDData
from csdnet import CSDNet
import srnet_utils as ut

### Step 17.1: Define latent and target data on `X11`

Function library `F11_v1`: `N0*0.05*(X11[:,0] + 0.5*N1)**2 + 0.15*N2`

Sample coefficients for latent functions:

1. `N0`: -1.60, `N1`: 0.33, `N2`: -1.40
2. `N0`: 1.57, `N1`: -1.26, `N2`: -1.68
3. `N0`: 1.79, `N1`: 0.42, `N2`: -0.25

Resulting latent functions $g(x)$:

1. `-0.08*(X11[:,0] + 0.165)**2 - 0.21`
2. `0.0785*(X11[:,0] - 0.63)**2 - 0.252`
3. `0.0895*(X11[:,0] + 0.21)**2 - 0.0375`

A linear target function $f(x)$ would result in a quadratic composition $f(g(x))$.

Thus, let's define a nonlinear target function instead: $f(x) = x_0 \cdot x_1 + \sin(x_2)$.

In [2]:
# dataset folder plus the names of the input, latent and target variables to read
data_path = "data_1k"
in_var, lat_var, target_var = "X11", "G11", "F11"

# build the dataset object without a data mask
train_data = SRData(data_path, in_var, lat_var, target_var, data_mask=None)

In [3]:
# Plot the target data (solid line) and, optionally, each latent component
# (non-solid lines) against the first input column.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
styles = ['--', '-.', ':', '-']  # last entry is reserved for the target curve
print_comp = True  # set to False to hide the latent components

fig, ax = plt.subplots()

ax.plot(train_data.in_data[:,0], train_data.target_data[:,0], ls=styles[-1], color=colors[0], label="target")

if print_comp:
    # Cycle through the non-target styles only: extra latent columns then neither
    # raise an IndexError nor reuse the solid line of the target curve.
    lat_styles = styles[:-1]
    for i in range(train_data.lat_data.shape[1]):
        ax.plot(
            train_data.in_data[:,0],
            train_data.lat_data[:,i],
            ls=lat_styles[i % len(lat_styles)],
            color=colors[0],
            alpha=0.5,
            label=f"g{i}",
        )

ax.set_xlabel("input (column 0)")
ax.set_ylabel("value")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### Step 17.2: Explore MLP training

Linear SD predictions:

In [4]:
# Plot the losses of the linear-SD check runs; the returned `models` is iterated
# as a collection of model names by the cells that follow.
save_names = ["srnet_model_F11_v1_critic_check_lin"]
save_path = "models"
# use the variable so the directory is defined in exactly one place
models = ut.plot_losses(save_names, save_path=save_path, excl_names=[])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [5]:
# label every column of the true latent data as g0, g1, ... for the correlation report
n_lat = train_data.lat_data.shape[1]
corr_data = [(f"g{k}", train_data.lat_data[:, k]) for k in range(n_lat)]

In [6]:
# For every run returned by plot_losses: load it, evaluate it on the training
# inputs and print how its latent nodes correlate with the true latent data.
model_path, model_ext = "models", ".pkl"

for model_name in models:
    print(model_name)

    model_file = f"{model_name}{model_ext}"
    model = ut.load_model(model_file, model_path, SRNet)

    # evaluation only, so gradient tracking is switched off
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    # show=True: the helpers print the node order / correlations (see cell output)
    all_nodes = ut.get_node_order(acts, show=True)
    corr_mat = ut.node_correlations(acts, all_nodes, corr_data, show=True)

    print()

srnet_model_F11_v1_critic_check_lin_lr_1e-4_sd_0e0
[0.29544127, 0.09934967, 0.081448585]
[1, 2, 0]

Node 1
corr(n1, g0): -0.2106
corr(n1, g1): -0.6267
corr(n1, g2): 0.2640

Node 2
corr(n2, g0): 0.9275
corr(n2, g1): -0.3686
corr(n2, g2): -0.9412

Node 0
corr(n0, g0): 0.3988
corr(n0, g1): -0.9334
corr(n0, g2): -0.3499

srnet_model_F11_v1_critic_check_lin_lr_1e-4_sd_1e-4
[2484.077, 1887.9333, 1343.1079]
[2, 0, 1]

Node 2
corr(n2, g0): 0.8873
corr(n2, g1): -0.2011
corr(n2, g2): -0.9112

Node 0
corr(n0, g0): 0.8876
corr(n0, g1): -0.2016
corr(n0, g2): -0.9115

Node 1
corr(n1, g0): 0.7392
corr(n1, g1): -0.9873
corr(n1, g2): -0.7011

srnet_model_F11_v1_critic_check_lin_lr_1e-5_sd_1e-2
[0.40147242, 0.10085495, 0.00064590725]
[1, 0, 2]

Node 1
corr(n1, g0): 0.4345
corr(n1, g1): -0.9728
corr(n1, g2): -0.3843

Node 0
corr(n0, g0): 0.9776
corr(n0, g1): -0.7717
corr(n0, g2): -0.9651

Node 2
corr(n2, g0): -0.8806
corr(n2, g1): 0.8234
corr(n2, g2): 0.8601

srnet_model_F11_v1_critic_check_lin_lr_1e-5_s

In [7]:
# Baseline run (sd = 0e0): keep its predictions and latent activations under
# separate names so later cells can compare against them.
model_name = "srnet_model_F11_v1_critic_check_lin_lr_1e-4_sd_0e0"
model_path, model_ext = "models", ".pkl"

model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)

# evaluation only, so gradient tracking is switched off
with torch.no_grad():
    bsl_preds, bsl_acts = model(train_data.in_data, get_lat=True)

In [8]:
# Run trained with the linear SD term (lr = 1e-5, sd = 1e-4): its outputs are
# stored in `preds` / `acts`, which the plotting cells below read.
model_name = "srnet_model_F11_v1_critic_check_lin_lr_1e-5_sd_1e-4"
model_path, model_ext = "models", ".pkl"

model_file = f"{model_name}{model_ext}"
model = ut.load_model(model_file, model_path, SRNet)

# evaluation only, so gradient tracking is switched off
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

In [9]:
# Compare one true latent function with one latent node of the current model
# (`acts`) and one latent node of the baseline model (`bsl_acts`).
fig, ax = plt.subplots()

nt = 0  # column of the true latent data
na = 0  # latent node of the current model
nb = 0  # latent node of the baseline model (sd = 0e0)

# labels + legend: without them the three point clouds cannot be told apart
ax.scatter(train_data.in_data[:,0], train_data.lat_data[:,nt], label=f"true g{nt}")
ax.scatter(train_data.in_data[:,0], acts[:,na], label=f"node {na} (with SD)")
ax.scatter(train_data.in_data[:,0], bsl_acts[:,nb], label=f"node {nb} (baseline)")

ax.set_xlabel("input (column 0)")
ax.set_ylabel("latent value")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
# Compare one true latent function with one latent node of the current model
# (`acts`) and one latent node of the baseline model (`bsl_acts`).
fig, ax = plt.subplots()

nt = 1  # column of the true latent data
na = 1  # latent node of the current model
nb = 1  # latent node of the baseline model (sd = 0e0)

# labels + legend: without them the three point clouds cannot be told apart
ax.scatter(train_data.in_data[:,0], train_data.lat_data[:,nt], label=f"true g{nt}")
ax.scatter(train_data.in_data[:,0], acts[:,na], label=f"node {na} (with SD)")
ax.scatter(train_data.in_data[:,0], bsl_acts[:,nb], label=f"node {nb} (baseline)")

ax.set_xlabel("input (column 0)")
ax.set_ylabel("latent value")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [11]:
# Compare one true latent function with one latent node of the current model
# (`acts`) and one latent node of the baseline model (`bsl_acts`).
fig, ax = plt.subplots()

nt = 2  # column of the true latent data
na = 2  # latent node of the current model
nb = 2  # latent node of the baseline model (sd = 0e0)

# labels + legend: without them the three point clouds cannot be told apart
ax.scatter(train_data.in_data[:,0], train_data.lat_data[:,nt], label=f"true g{nt}")
ax.scatter(train_data.in_data[:,0], acts[:,na], label=f"node {na} (with SD)")
ax.scatter(train_data.in_data[:,0], bsl_acts[:,nb], label=f"node {nb} (baseline)")

ax.set_xlabel("input (column 0)")
ax.set_ylabel("latent value")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [12]:
# Target data vs. the predictions currently stored in `preds`.
fig, ax = plt.subplots()

# labels + legend so the two point clouds can be told apart
ax.scatter(train_data.in_data[:,0], train_data.target_data[:,0], label="target")
ax.scatter(train_data.in_data[:,0], preds[:,0], label="prediction")

ax.set_xlabel("input (column 0)")
ax.set_ylabel("output (column 0)")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Sigmoid SD predictions:

In [13]:
# Plot the losses of the sigmoid-SD check runs; the returned `models` is iterated
# as a collection of model names by the cells that follow.
save_names = ["srnet_model_F11_v1_critic_check_sig"]
save_path = "models"
# use the variable so the directory is defined in exactly one place
models = ut.plot_losses(save_names, save_path=save_path, excl_names=[])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [14]:
# (label, column) pairs of the true latent data, used as correlation references
corr_data = []
for col in range(train_data.lat_data.shape[1]):
    corr_data.append((f"g{col}", train_data.lat_data[:, col]))

In [15]:
# Same report as for the linear runs, now for the sigmoid-SD runs listed in `models`:
# load each run, evaluate it and print node order and node/latent correlations.
model_path, model_ext = "models", ".pkl"

for model_name in models:
    print(model_name)

    model_file = f"{model_name}{model_ext}"
    model = ut.load_model(model_file, model_path, SRNet)

    # evaluation only, so gradient tracking is switched off
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    # show=True: the helpers print the node order / correlations (see cell output)
    all_nodes = ut.get_node_order(acts, show=True)
    corr_mat = ut.node_correlations(acts, all_nodes, corr_data, show=True)

    print()

srnet_model_F11_v1_critic_check_sig_lr_1e-4_sd_1e0
[0.24387582, 0.018124828, 0.0014003412]
[1, 0, 2]

Node 1
corr(n1, g0): -0.1024
corr(n1, g1): -0.7083
corr(n1, g2): 0.1569

Node 0
corr(n0, g0): 0.9688
corr(n0, g1): -0.7944
corr(n0, g2): -0.9542

Node 2
corr(n2, g0): -0.8762
corr(n2, g1): 0.8488
corr(n2, g2): 0.8538

srnet_model_F11_v1_critic_check_sig_lr_1e-4_sd_1e2
[0.13236833, 0.020032948, 0.001114716]
[1, 0, 2]

Node 1
corr(n1, g0): 0.3005
corr(n1, g1): -0.9292
corr(n1, g2): -0.2477

Node 0
corr(n0, g0): 0.9613
corr(n0, g1): -0.8098
corr(n0, g2): -0.9453

Node 2
corr(n2, g0): -0.8464
corr(n2, g1): 0.8506
corr(n2, g2): 0.8226

srnet_model_F11_v1_critic_check_sig_lr_1e-4_sd_1e3
[0.124692306, 0.01973096, 0.001017363]
[1, 0, 2]

Node 1
corr(n1, g0): 0.3125
corr(n1, g1): -0.9339
corr(n1, g2): -0.2599

Node 0
corr(n0, g0): 0.9591
corr(n0, g1): -0.8151
corr(n0, g2): -0.9427

Node 2
corr(n2, g0): -0.8447
corr(n2, g1): 0.8452
corr(n2, g2): 0.8213

srnet_model_F11_v1_critic_check_sig_lr_1e-

In [16]:
# Sigmoid-SD run (lr = 1e-4, sd = 1e3): its outputs are stored in `preds` / `acts`,
# which the plotting cells below read.
model_name = "srnet_model_F11_v1_critic_check_sig_lr_1e-4_sd_1e3"
model_path, model_ext = "models", ".pkl"

model_file = f"{model_name}{model_ext}"
model = ut.load_model(model_file, model_path, SRNet)

# evaluation only, so gradient tracking is switched off
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

In [17]:
# Compare one true latent function with one latent node of the current model (`acts`).
fig, ax = plt.subplots()

nt = 0  # column of the true latent data
na = 0  # latent node of the current model

# labels + legend so the two point clouds can be told apart
ax.scatter(train_data.in_data[:,0], train_data.lat_data[:,nt], label=f"true g{nt}")
ax.scatter(train_data.in_data[:,0], acts[:,na], label=f"node {na}")

ax.set_xlabel("input (column 0)")
ax.set_ylabel("latent value")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
# Compare one true latent function with one latent node of the current model (`acts`).
fig, ax = plt.subplots()

nt = 1  # column of the true latent data
na = 1  # latent node of the current model

# labels + legend so the two point clouds can be told apart
ax.scatter(train_data.in_data[:,0], train_data.lat_data[:,nt], label=f"true g{nt}")
ax.scatter(train_data.in_data[:,0], acts[:,na], label=f"node {na}")

ax.set_xlabel("input (column 0)")
ax.set_ylabel("latent value")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [19]:
# Compare one true latent function with one latent node of the current model (`acts`).
fig, ax = plt.subplots()

nt = 2  # column of the true latent data
na = 2  # latent node of the current model

# labels + legend so the two point clouds can be told apart
ax.scatter(train_data.in_data[:,0], train_data.lat_data[:,nt], label=f"true g{nt}")
ax.scatter(train_data.in_data[:,0], acts[:,na], label=f"node {na}")

ax.set_xlabel("input (column 0)")
ax.set_ylabel("latent value")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [20]:
# Target data vs. the predictions currently stored in `preds`.
fig, ax = plt.subplots()

# labels + legend so the two point clouds can be told apart
ax.scatter(train_data.in_data[:,0], train_data.target_data[:,0], label="target")
ax.scatter(train_data.in_data[:,0], preds[:,0], label="prediction")

ax.set_xlabel("input (column 0)")
ax.set_ylabel("output (column 0)")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### Step 17.3: Run hyperparameter study

In [21]:
# Weights & Biases project that the sweeps created in this step are filed under
wandb_project = "173-fixed-critic-study-F11_v1"

In [22]:
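# Inactive reference configuration ("sd_fun": "linear"), kept commented out;
# presumably the base settings for the linear-SD sweep defined below - TODO confirm.
# hyperparams = {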
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2, 2),
#         "hid_size": (32, 32), 
#         "hid_type": ("MLP", "MLP"),
#         "hid_kwargs": {
#             "alpha": None,
#             "norm": None,
#             "prune": None,
#             },
#         "lat_size": 3,
#         },
#     "epochs": 20000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "shuffle": False,
#     "lr": 1e-4,
#     "wd": 1e-7,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "e3": 0.0,
#     "gc": 0.0,
#     "sd": 1e-4,
#     "sd_fun": "linear",
#     # "ext": None,
#     # "ext_type": None,
#     # "ext_size": 0,
#     # "disc": {
#     #     "hid_num": 1,
#     #     "hid_size": 64,
#     #     "lr": 1e-4,
#     #     "wd": 1e-7,
#     #     "betas": (0.9,0.999),
#     #     "iters": 5,
#     #     "gp": 0.0,
#     #     "loss_fun": "BCE",
#     # },
# }

In [23]:
# Hyperparameter study (random search). Every searched parameter is given as a
# list of discrete choices and wrapped into the {"values": [...]} form below.
arch_choices = {
    "in_size": [1],
    "out_size": [1],
    "hid_num": [1, 2, 4],
    "hid_size": [32, 64],
    "lat_size": [3],
}
train_choices = {
    "lr": [1e-6, 1e-5, 1e-4, 1e-3],
    "sd": [1e-5, 1e-4, 1e-3, 1e-2],
}

hp_study = {
    "method": "random",
    "parameters": {
        # architecture options live in a nested parameter group
        "arch": {
            "parameters": {key: {"values": vals} for key, vals in arch_choices.items()}
        },
        **{key: {"values": vals} for key, vals in train_choices.items()},
    },
}

In [None]:
# register the study defined in `hp_study` as a sweep and keep the returned id
sweep_id = wandb.sweep(sweep=hp_study, project=wandb_project)

In [24]:
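# Inactive reference configuration ("sd_fun": "sigmoid"), kept commented out;
# presumably the base settings for the sigmoid-SD sweep defined below - TODO confirm.
# hyperparams = {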
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2, 2),
#         "hid_size": (32, 32), 
#         "hid_type": ("MLP", "MLP"),
#         "hid_kwargs": {
#             "alpha": None,
#             "norm": None,
#             "prune": None,
#             },
#         "lat_size": 3,
#         },
#     "epochs": 20000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "shuffle": False,
#     "lr": 1e-4,
#     "wd": 1e-7,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "e3": 0.0,
#     "gc": 0.0,
#     "sd": 1e-4,
#     "sd_fun": "sigmoid",
#     # "ext": None,
#     # "ext_type": None,
#     # "ext_size": 0,
#     # "disc": {
#     #     "hid_num": 1,
#     #     "hid_size": 64,
#     #     "lr": 1e-4,
#     #     "wd": 1e-7,
#     #     "betas": (0.9,0.999),
#     #     "iters": 5,
#     #     "gp": 0.0,
#     #     "loss_fun": "BCE",
#     # },
# }

In [25]:
# Hyperparameter study (random search). Every searched parameter is given as a
# list of discrete choices and wrapped into the {"values": [...]} form below.
arch_choices = {
    "in_size": [1],
    "out_size": [1],
    "hid_num": [1, 2, 4],
    "hid_size": [32, 64],
    "lat_size": [3],
}
train_choices = {
    "lr": [1e-6, 1e-5, 1e-4, 1e-3],
    "sd": [1e0, 1e1, 1e2, 1e3],
}

hp_study = {
    "method": "random",
    "parameters": {
        # architecture options live in a nested parameter group
        "arch": {
            "parameters": {key: {"values": vals} for key, vals in arch_choices.items()}
        },
        **{key: {"values": vals} for key, vals in train_choices.items()},
    },
}

In [None]:
# register the study defined in `hp_study` as a sweep and keep the returned id
sweep_id = wandb.sweep(sweep=hp_study, project=wandb_project)

**Linear**:

<img src="results/173-fixed-critic-study-F11_v1_lin2.png">

**Sigmoid**:

<img src="results/173-fixed-critic-study-F11_v1_sig2.png">

Utilizing **sigmoid** predictions yields lower validation errors and more runs with a `min_corr` value of 87% or higher.

One concern with **linear** predictions is that training may focus on maximizing positive critic predictions instead of minimizing the validation loss. Additionally, because a sigmoid was used during critic training, a raw critic prediction of 1000 is not more certain than one of 100.

On the other hand, with **sigmoid** predictions there is basically no feedback from the critic anymore once the latent features reach a certain similarity to the real data.

In [34]:
# Plot the losses of the critic-study runs (log=False is passed through to the helper);
# the returned `models` is iterated as a collection of model names by the cells that follow.
save_names = ["srnet_model_F11_v1_critic_study"]
save_path = "models"
# use the variable so the directory is defined in exactly one place
models = ut.plot_losses(save_names, save_path=save_path, excl_names=[], log=False)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [28]:
# reference columns for the correlation report: ("g0", column 0), ("g1", column 1), ...
true_latents = train_data.lat_data
corr_data = [(f"g{idx}", true_latents[:, idx]) for idx in range(true_latents.shape[1])]

In [29]:
# Correlation report for the critic-study runs listed in `models`:
# load each run, evaluate it and print node order and node/latent correlations.
model_path, model_ext = "models", ".pkl"

for model_name in models:
    print(model_name)

    model_file = f"{model_name}{model_ext}"
    model = ut.load_model(model_file, model_path, SRNet)

    # evaluation only, so gradient tracking is switched off
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    # show=True: the helpers print the node order / correlations (see cell output)
    all_nodes = ut.get_node_order(acts, show=True)
    corr_mat = ut.node_correlations(acts, all_nodes, corr_data, show=True)

    print()

srnet_model_F11_v1_critic_study_lin_v1
[43267.44, 18236.11, 6458.543]
[0, 2, 1]

Node 0
corr(n0, g0): 0.5264
corr(n0, g1): -0.9900
corr(n0, g2): -0.4790

Node 2
corr(n2, g0): 0.4867
corr(n2, g1): -0.9836
corr(n2, g2): -0.4380

Node 1
corr(n1, g0): 0.4786
corr(n1, g1): -0.9820
corr(n1, g2): -0.4297

srnet_model_F11_v1_critic_study_lin_v2
[1.0286546, 0.9936166, 0.5156132]
[2, 1, 0]

Node 2
corr(n2, g0): 0.9013
corr(n2, g1): -0.2320
corr(n2, g2): -0.9237

Node 1
corr(n1, g0): 0.4617
corr(n1, g1): -0.9793
corr(n1, g2): -0.4123

Node 0
corr(n0, g0): 0.8877
corr(n0, g1): -0.2018
corr(n0, g2): -0.9116

srnet_model_F11_v1_critic_study_sig_v1
[0.041821864, 0.03000773, 0.024630899]
[0, 2, 1]

Node 0
corr(n0, g0): -0.9895
corr(n0, g1): 0.5193
corr(n0, g2): 0.9953

Node 2
corr(n2, g0): 0.9962
corr(n2, g1): -0.6546
corr(n2, g2): -0.9927

Node 1
corr(n1, g0): -0.9944
corr(n1, g1): 0.5685
corr(n1, g2): 0.9970

srnet_model_F11_v1_critic_study_sig_v2
[0.025904426, 0.021181995, 0.020423802]
[0, 1, 2]

N

In [35]:
# Critic-study run "sig_v1": its outputs are stored in `preds` / `acts`,
# which the plotting cells below read.
model_name = "srnet_model_F11_v1_critic_study_sig_v1"
model_path, model_ext = "models", ".pkl"

model_file = f"{model_name}{model_ext}"
model = ut.load_model(model_file, model_path, SRNet)

# evaluation only, so gradient tracking is switched off
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

In [36]:
# Compare one true latent function with one latent node of the current model (`acts`).
fig, ax = plt.subplots()

nt = 0  # column of the true latent data
na = 0  # latent node of the current model

# labels + legend so the two point clouds can be told apart
ax.scatter(train_data.in_data[:,0], train_data.lat_data[:,nt], label=f"true g{nt}")
ax.scatter(train_data.in_data[:,0], acts[:,na], label=f"node {na}")

ax.set_xlabel("input (column 0)")
ax.set_ylabel("latent value")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [37]:
# Compare one true latent function with one latent node of the current model (`acts`).
fig, ax = plt.subplots()

nt = 1  # column of the true latent data
na = 1  # latent node of the current model

# labels + legend so the two point clouds can be told apart
ax.scatter(train_data.in_data[:,0], train_data.lat_data[:,nt], label=f"true g{nt}")
ax.scatter(train_data.in_data[:,0], acts[:,na], label=f"node {na}")

ax.set_xlabel("input (column 0)")
ax.set_ylabel("latent value")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [38]:
# Compare one true latent function with one latent node of the current model (`acts`).
fig, ax = plt.subplots()

nt = 2  # column of the true latent data
na = 2  # latent node of the current model

# labels + legend so the two point clouds can be told apart
ax.scatter(train_data.in_data[:,0], train_data.lat_data[:,nt], label=f"true g{nt}")
ax.scatter(train_data.in_data[:,0], acts[:,na], label=f"node {na}")

ax.set_xlabel("input (column 0)")
ax.set_ylabel("latent value")
ax.legend()

plt.show()

  fig, ax = plt.subplots()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [39]:
# Target data vs. the predictions currently stored in `preds`.
fig, ax = plt.subplots()

# labels + legend so the two point clouds can be told apart
ax.scatter(train_data.in_data[:,0], train_data.target_data[:,0], label="target")
ax.scatter(train_data.in_data[:,0], preds[:,0], label="prediction")

ax.set_xlabel("input (column 0)")
ax.set_ylabel("output (column 0)")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Learning quadratic latent features works well. However, since we are not learning the exact underlying quadratic functions (but transformations of them), discovering the nonlinear function $f(x)$ might be difficult.

In [None]:
# ut.save_preds(preds, "F11_p1", "data_1k", model_name)

In [None]:
# ut.save_preds(acts, "G11_p1", "data_1k", model_name)