# 2022 Flatiron Machine Learning x Science Summer School

## Step 9: Implement symbolic discriminator

In this step, we train a symbolic discriminator (SD) to regularize the latent features of our `SRNet`.

The aim is to incentivize the latent features to resemble functions that are easily discoverable with symbolic regression.

The plan is to train a Wasserstein Generative Adversarial Network (WGAN) with gradient penalty. Compared with a regular GAN, a WGAN is less prone to mode collapse and vanishing gradients. The WGAN requires the discriminator to be 1-Lipschitz continuous, which is softly enforced with a gradient penalty (see https://www.youtube.com/watch?v=v6y5qQ0pcg4).
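For reference, the standard WGAN-GP discriminator objective can be written as follows. This is the textbook formulation and not necessarily the exact loss implemented in `srnet`; in particular, identifying the weight $\lambda$ with the `gp` hyperparameter used below is an assumption.

$$
L_D = \mathbb{E}_{\tilde{g}}\left[D(\tilde{g})\right] - \mathbb{E}_{g}\left[D(g)\right] + \lambda \, \mathbb{E}_{\hat{g}}\left[\left(\lVert \nabla_{\hat{g}} D(\hat{g}) \rVert_2 - 1\right)^2\right], \qquad \hat{g} = \epsilon \, g + (1 - \epsilon) \, \tilde{g}, \quad \epsilon \sim U[0, 1]
$$

Here $g$ is a real sample (a function from the SD library), $\tilde{g}$ is a generated sample (a latent feature of the `SRNet`), and $\hat{g}$ is a random interpolation between the two at which the gradient norm is pushed towards 1. The generator side then adds a term proportional to $-\mathbb{E}_{\tilde{g}}\left[D(\tilde{g})\right]$ to its own loss, presumably weighted by `sd`.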

First, we are going to apply a WGAN to the simplest possible formulation of our problem:

* Two input features

* Bottleneck

* <s>Fixed $\alpha$ mask values</s>

* <s>Pretrained weights</s>

* Only target functions as real data

Notes: 

* We probably do not want to shuffle the input data.

Questions:

* Do we feed the discriminator only latent feature values or also input feature values?

* Do we want different batches, i.e. the `SRNet` and real functions evaluated at different positions?

**TODO**:

* Try `LeakyReLU`, `BatchNorm`

* Check https://github.com/aladdinpersson/Machine-Learning-Collection

* Check `prior.jl`

* Check Miles' code (https://colab.research.google.com/drive/1IoJh46JV8EaxWjF3kDS51MZyAdXKzBxc?usp=sharing#scrollTo=4G3ZNAUeCxMX)

### Step 9.1: Check bottleneck DSN with SD regularization

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
import srnet_utils as ut

In [2]:
# plot losses
save_names = ["srnet_model_F00_v1_bn_norm_sd_1e-04_check"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path, excl_names=["max"])


In [4]:
# load data
data_path = "data_1k"

in_var = "X00"
lat_var = "G00"
target_var = "F00"

mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"])
val_data = SRData(data_path, in_var, lat_var, target_var, masks["val"])

In [5]:
model_name = "srnet_model_F00_v1_bn_norm_sd_1e-04_check"
model_path = "models"
model_ext = ".pkl"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)
    
all_nodes = ut.get_node_order(acts, show=True)

print(model.layers1.alpha.detach().cpu().numpy()[all_nodes])
print(model.layers1.norm(model.layers1.alpha).detach().cpu().numpy()[all_nodes])

print("")

[1.8365434, 0.94673306, 0.1891221]
[2, 1, 0]
[[-1.4036835  -1.0781626 ]
 [-1.9185202   0.82522213]
 [ 1.3652941  -0.46780935]]
[[0.58066916 0.41933084]
 [0.7490023  0.25099772]
 [0.7104324  0.28956768]]
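The two printed matrices are consistent with the normalization being a softmax over the absolute $\alpha$ values of each node (row). This is inferred from the numbers above, not from the `srnet` source, so `softmax_abs` below is a hypothetical stand-in for `model.layers1.norm`.

```python
import math

def softmax_abs(row):
    """Softmax over the absolute values of one row of mask parameters."""
    exps = [math.exp(abs(a)) for a in row]
    total = sum(exps)
    return [e / total for e in exps]

# raw alpha rows as printed above (node order 2, 1, 0)
alpha = [
    [-1.4036835, -1.0781626],
    [-1.9185202, 0.82522213],
    [1.3652941, -0.46780935],
]

normed = [softmax_abs(row) for row in alpha]
for row in normed:
    print([round(v, 4) for v in row])
```

Under this assumption, the sign of an $\alpha$ entry does not matter for the mask; only its magnitude relative to the other entry in the same row does.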



In [6]:
nodes = all_nodes[:3]

In [7]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [8]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)


In [9]:
corr_data = [
    ("x**2", x_data**2), 
    ("cos(y)", np.cos(y_data)),
    ("x*y", x_data * y_data),
]

In [10]:
ut.node_correlations(acts, nodes, corr_data);


Node 2
corr(n2, x**2): 0.9999
corr(n2, cos(y)): -0.0170
corr(n2, x*y): 0.1786

Node 1
corr(n1, x**2): 0.1706
corr(n1, cos(y)): -0.0622
corr(n1, x*y): 0.9999

Node 0
corr(n0, x**2): -0.0153
corr(n0, cos(y)): 0.9997
corr(n0, x*y): -0.0672
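A correlation close to $\pm 1$ means that a node matches a library function up to an affine rescaling, which the output layer can absorb through its weight and bias. The sketch below illustrates this with a hand-written Pearson correlation; `pearson` is an illustrative helper and not necessarily how `ut.node_correlations` is implemented, and the input grid is hypothetical.

```python
import math

def pearson(a, b):
    """Pearson correlation coefficient of two equally long sequences."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((u - mean_a) * (v - mean_b) for u, v in zip(a, b))
    var_a = sum((u - mean_a) ** 2 for u in a)
    var_b = sum((v - mean_b) ** 2 for v in b)
    return cov / math.sqrt(var_a * var_b)

# hypothetical inputs on a symmetric grid in [-1, 1]
x = [i / 50 - 1 for i in range(101)]
target = [v ** 2 for v in x]

# a "latent node" that equals the target up to scale and offset
node = [-0.5 * t + 2.0 for t in target]

print(pearson(node, target))  # close to -1: the sign flips, the match is exact
print(pearson(x, target))     # close to 0: x and x**2 are uncorrelated on this grid
```

This is why the plots below rescale a node with the output-layer weight and bias before comparing it with the target function.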


In [11]:
fig, ax = plt.subplots()

n = 2
bias = True

ax.scatter(x_data, x_data**2, label="Target: $x_0^2$")
ax.scatter(x_data, model.layers2[0].weight[0,n].item()*acts[:,n] + bias * model.layers2[0].bias.item(), label="DSN SD")

ax.set_xlabel("$x_0$")
ax.set_ylabel("$g_0$")
ax.legend()
plt.show()


In [18]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    #("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [19]:
_ = plt.figure()
ax = plt.axes(projection='3d')

xp_data = x_data[:plot_size].numpy()
yp_data = y_data[:plot_size].numpy()

for _, values in z_data:
    # slice the target like the inputs so all arrays have the same length
    zp_data = np.asarray(values)[:plot_size]
    ax.scatter3D(xp_data, yp_data, zp_data, label=r"Target: $x_0 \cdot x_1$")

# rescale latent node n with its output-layer weight and bias
n = 1
b = model.layers2[0].bias.item()
w = model.layers2[0].weight[0, n].item()
n_data = w * acts[:, n][:plot_size].numpy() + b
ax.scatter3D(xp_data, yp_data, n_data, label="DSN SD")

ax.set_xlabel("$x_0$")
ax.set_ylabel("$x_1$")
ax.set_zlabel("$g_2$")
ax.view_init(elev=5.16234, azim=-56.591)
ax.legend()
plt.show()


In [20]:
fig, ax = plt.subplots()

n = 0
bias = True

ax.scatter(y_data, np.cos(y_data), label="Target: cos$(x_1)$")
ax.scatter(y_data, model.layers2[0].weight[0,n].item()*acts[:,n] + bias * model.layers2[0].bias.item(), label="DSN SD")

ax.set_xlabel("$x_1$")
ax.set_ylabel("$g_1$")
ax.legend()
plt.show()


### Step 9.1.1: Save predictions of bottleneck DSN with SD regularization

In [21]:
# load data
data_path = "data_1k"

in_var = "X00"
lat_var = None
target_var = "F00"

all_data = SRData(data_path, in_var, lat_var, target_var)

In [24]:
model_name = "srnet_model_F00_v1_bn_norm_sd_1e-04_check"
model_path = "models"
model_ext = ".pkl"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

with torch.no_grad():
    preds, acts = model(all_data.in_data, get_lat=True)

In [27]:
ut.save_preds(preds, "F00_p1", "data_1k", model_name)

In [28]:
ut.save_preds(acts, "G00_p1", "data_1k", model_name)

### Step 9.2: Explore SD hyperparameters

In [29]:
# set wandb project
wandb_project = "92-bn-DSN-sd-study-F00"

In [30]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": {
#             "alpha": None,
#             "norm": "softmax",
#             "prune": None,
#             },
#         "lat_size": 3,
#         },
#     "epochs": 30000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "shuffle": False,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "gc": 0.0,
#     "sd": 1e-4,
#     "disc": {
#         "hid_num": 2,
#         "hid_size": 128,
#         "lr": 1e-4,
#         "wd": 1e-4,
#         "iters": 5,
#         "gp": 1e-4,
#     },
# }

In [31]:
# define hyperparameter study
hp_study = {
    "method": "random",
    "parameters": {
        "sd": {
            "values": [1e-6, 1e-4, 1e-2, 1e0]
        },
        "disc": {
            "parameters": {
                "hid_num": {
                    "values": [1, 2, 4]
                },
                "hid_size": {
                    "values": [32, 64, 128, 256]
                },
                "lr": {
                    "values": [1e-5, 1e-4, 1e-3]
                },
                "iters": {
                    "values": [1, 2, 5, 10]
                },
                "gp": {
                    "values": [1e-5, 1e-4, 1e-3]
                },
            }
        }
    }
}

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

<img src="results/92-bn-DSN-sd-study-F00.png">

Using only target functions in the SD library appears to be robust w.r.t. the studied hyperparameters.

### Step 9.3: Extend SD library

Version 1:

```
X00[:,0]**2
np.cos(X00[:,1])
X00[:,0] * X00[:,1]
```

Version 2:

```
X00[:,0]**2
X00[:,1]**2
np.cos(X00[:,0])
np.cos(X00[:,1])
X00[:,0] * X00[:,1]
```

In [32]:
# plot losses
save_names = ["F00_v2_bn_norm_sd"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path, excl_names=["study"])


In [33]:
# load data
data_path = "data_1k"

in_var = "X00"
lat_var = "G00"
target_var = "F00"

mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"])
val_data = SRData(data_path, in_var, lat_var, target_var, masks["val"])

In [34]:
model_name = "srnet_model_F00_v2_bn_norm_sd_1e-06_check"
model_path = "models"
model_ext = ".pkl"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)
    
all_nodes = ut.get_node_order(acts, show=True)

print(model.layers1.alpha.detach().cpu().numpy()[all_nodes])
print(model.layers1.norm(model.layers1.alpha).detach().cpu().numpy()[all_nodes])

print("")

[1.6364664, 0.9781076, 0.45337594]
[2, 1, 0]
[[-1.2925727 -1.1796976]
 [-1.8483588  0.8866545]
 [ 1.3686377 -0.4594895]]
[[0.5281889  0.47181118]
 [0.7234629  0.2765371 ]
 [0.71282583 0.28717417]]



In [35]:
nodes = all_nodes[:3]

In [36]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    ("pred", preds),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [37]:
ut.plot_acts(x_data, y_data, z_data, plot_size=plot_size)


In [38]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [39]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)


In [40]:
corr_data = [
    ("x**2", x_data**2), 
    ("y**2", y_data**2), 
    ("cos(x)", np.cos(x_data)), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [41]:
ut.node_correlations(acts, nodes, corr_data);


Node 2
corr(n2, x**2): 0.8389
corr(n2, y**2): 0.4236
corr(n2, cos(x)): -0.8093
corr(n2, cos(y)): -0.4279
corr(n2, x*y): 0.4568

Node 1
corr(n1, x**2): 0.0089
corr(n1, y**2): 0.9942
corr(n1, cos(x)): -0.0237
corr(n1, cos(y)): -0.9734
corr(n1, x*y): 0.0456

Node 0
corr(n0, x**2): -0.4328
corr(n0, y**2): -0.6815
corr(n0, cos(x)): 0.4422
corr(n0, cos(y)): 0.7030
corr(n0, x*y): 0.2161


In [42]:
fig, ax = plt.subplots()

ax.scatter(x_data, x_data**2)
ax.scatter(x_data, acts[:,2])

plt.show()


In [43]:
fig, ax = plt.subplots()

ax.scatter(y_data, np.cos(y_data))
ax.scatter(y_data, acts[:,1])

plt.show()


In [44]:
fig, ax = plt.subplots()

ax.scatter(y_data, np.cos(y_data))
ax.scatter(y_data, acts[:,0])

plt.show()


With a slightly extended SD library, the target latent features are no longer clearly identified.

Let's run a hyperparameter study before rushing to conclusions.

In [45]:
# set wandb project
wandb_project = "93-bn-DSN-sd-study-F00_v2"

In [46]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": {
#             "alpha": None,
#             "norm": "softmax",
#             "prune": None,
#             },
#         "lat_size": 3,
#         },
#     "epochs": 30000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "shuffle": False,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "gc": 0.0,
#     "sd": 1e-4,
#     "disc": {
#         "hid_num": 2,
#         "hid_size": 128,
#         "lr": 1e-4,
#         "wd": 1e-4,
#         "iters": 5,
#         "gp": 1e-4,
#     },
# }

In [47]:
# define hyperparameter study
hp_study = {
    "method": "random",
    "parameters": {
        "sd": {
            "values": [1e-7, 1e-6, 1e-5, 1e-4]
        },
        "disc": {
            "parameters": {
                "hid_num": {
                    "values": [1, 2, 4]
                },
                "hid_size": {
                    "values": [32, 128, 256]
                },
                "lr": {
                    "values": [1e-5, 1e-4, 1e-3]
                },
                "iters": {
                    "values": [1, 5, 10]
                },
                "gp": {
                    "values": [1e-5, 1e-4, 1e-3]
                },
            }
        }
    }
}
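With `"method": "random"`, each run draws one value per hyperparameter, so the sweep samples the grid below rather than covering it. The helper `count_combinations` is a hypothetical utility (not part of `wandb`) that counts the grid points of a nested sweep configuration like the one above.

```python
def count_combinations(parameters):
    """Count grid points of a (possibly nested) sweep parameter dict."""
    total = 1
    for spec in parameters.values():
        if "values" in spec:
            total *= len(spec["values"])
        elif "parameters" in spec:
            total *= count_combinations(spec["parameters"])
    return total

# same search space as hp_study["parameters"] above
parameters = {
    "sd": {"values": [1e-7, 1e-6, 1e-5, 1e-4]},
    "disc": {
        "parameters": {
            "hid_num": {"values": [1, 2, 4]},
            "hid_size": {"values": [32, 128, 256]},
            "lr": {"values": [1e-5, 1e-4, 1e-3]},
            "iters": {"values": [1, 5, 10]},
            "gp": {"values": [1e-5, 1e-4, 1e-3]},
        }
    },
}

n_grid = count_combinations(parameters)
print(n_grid)  # 4 * 3**5 = 972
```

The 17 runs analyzed below therefore cover less than 2% of the grid, so trends read off the sweep plot should be treated as tentative.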

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

<img src="results/93-bn-DSN-sd-study-F00_v2.png">

An `sd` value of `1e-6` or lower seems necessary to achieve low validation errors.

In [None]:
# download data from wandb
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for r, run in enumerate(runs):
    for f in run.files():
        if f.name.endswith(file_ext):
            file_name = f.name.replace(file_ext, f"_v{r+1}{file_ext}")
            print(f"Downloading {os.path.basename(file_name)}.")
            run.file(f.name).download()
            os.rename(f.name, file_name)

In [48]:
# load data
data_path = "data_1k"

in_var = "X00"
lat_var = "G00"
target_var = "F00"

mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"])
val_data = SRData(data_path, in_var, lat_var, target_var, masks["val"])

In [49]:
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]

In [50]:
corr_data = [
    ("x**2", x_data**2), 
    # ("y**2", y_data**2), 
    # ("cos(x)", np.cos(x_data)), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [51]:
# get validation loss and latent feature correlations
model_path = "models"
save_name = "F00_v2_bn_norm_sd_study"

models = [f for f in os.listdir(model_path) if save_name in f]

val_corr = {}

for model_name in models:
    print(f"Loading {model_name}.")
    model = ut.load_model(model_name, model_path, SRNet)
    
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)
        
    all_nodes = ut.get_node_order(acts, show=False)
        
    corr_mat = ut.node_correlations(acts, all_nodes, corr_data, show=False)
    corr = [np.abs(c).max() for c in corr_mat]
    
    with torch.no_grad():
        preds = model(val_data.in_data)
        
    val_loss = (preds - val_data.target_data).pow(2).mean().item()
    val_corr[model_name] = (val_loss, corr)

Loading srnet_model_F00_v2_bn_norm_sd_study_v1.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v10.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v11.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v12.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v13.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v14.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v15.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v16.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v17.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v2.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v3.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v4.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v5.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v6.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v7.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v8.pkl.
Loading srnet_model_F00_v2_bn_norm_sd_study_v9.pkl.
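The loop above condenses each run into two numbers: the validation loss and, per latent node, the largest absolute correlation with any library function. The plot that follows then takes the minimum over nodes, i.e. the score of the worst-matched node. As a worked example, the sketch below applies the same reduction to the Step 9.1 correlations printed earlier; the layout of `corr_mat` (one row per node, one column per function) is an assumption about what `ut.node_correlations` returns.

```python
# rows: nodes 2, 1, 0; columns: x**2, cos(y), x*y (values from Step 9.1)
corr_mat = [
    [0.9999, -0.0170, 0.1786],
    [0.1706, -0.0622, 0.9999],
    [-0.0153, 0.9997, -0.0672],
]

# best match per node, ignoring the sign of the correlation
per_node = [max(abs(c) for c in row) for row in corr_mat]

# a run is only as good as its worst-matched node
score = min(per_node)

print(per_node)  # [0.9999, 0.9999, 0.9997]
print(score)     # 0.9997
```

Note that this score does not check whether different nodes match different functions, so a high value is necessary but not sufficient for a clean decomposition.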


In [52]:
fig, ax = plt.subplots()

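# one marker per run: validation loss vs. weakest per-node correlation
for v in val_corr:
    ax.plot(val_corr[v][0], np.min(val_corr[v][1]), 'x', label=v.split('.')[0].split('_')[-1])

ax.set_xlabel("validation loss")
ax.set_ylabel("min. over nodes of max. |corr|")
ax.legend()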
plt.show()


In [53]:
# plot losses
save_names = [
    "srnet_model_F00_v2_bn_norm_sd_study_v3",
    "srnet_model_F00_v2_bn_norm_sd_study_v7",
    "srnet_model_F00_v2_bn_norm_sd_study_v11",
    "srnet_model_F00_v2_bn_norm_sd_study_v13",
]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path, excl_names=[])


In [54]:
corr_data = [
    ("x**2", x_data**2), 
    ("y**2", y_data**2), 
    ("cos(x)", np.cos(x_data)), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [55]:
model_path = "models"
model_ext = ".pkl"

for model_name in models:
    print(model_name)
    
    state = joblib.load(os.path.join(model_path, model_name + model_ext))
    
    print(state['hyperparams']['sd'])
    print(state['hyperparams']['disc'])
    
    model = ut.load_model(model_name + model_ext, model_path, SRNet)
    
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)
        
    all_nodes = ut.get_node_order(acts, show=True)
    
    print(f"Validation error: {state['total_val_loss']:.4e}")
        
    corr_mat = ut.node_correlations(acts, all_nodes, corr_data, show=True)
        
    print("")

srnet_model_F00_v2_bn_norm_sd_study_v3
1e-07
{'gp': 1e-05, 'hid_num': 4, 'hid_size': 32, 'iters': 1, 'lr': 0.001}
[1.9866552, 0.34647146, 0.29016]
[1, 0, 2]
Validation error: 1.9998e-02

Node 1
corr(n1, x**2): 0.9858
corr(n1, y**2): -0.0740
corr(n1, cos(x)): -0.9404
corr(n1, cos(y)): 0.0734
corr(n1, x*y): 0.1362

Node 0
corr(n0, x**2): -0.9077
corr(n0, y**2): 0.3623
corr(n0, cos(x)): 0.8990
corr(n0, cos(y)): -0.3441
corr(n0, x*y): -0.1231

Node 2
corr(n2, x**2): 0.1456
corr(n2, y**2): -0.0883
corr(n2, cos(x)): -0.1178
corr(n2, cos(y)): 0.0850
corr(n2, x*y): 0.9768

srnet_model_F00_v2_bn_norm_sd_study_v7
1e-07
{'gp': 1e-05, 'hid_num': 4, 'hid_size': 32, 'iters': 10, 'lr': 0.001}
[1.3430996, 1.0023396, 0.09775755]
[0, 2, 1]
Validation error: 1.9630e-02

Node 0
corr(n0, x**2): 0.0254
corr(n0, y**2): 0.9975
corr(n0, cos(x)): -0.0362
corr(n0, cos(y)): -0.9699
corr(n0, x*y): -0.0050

Node 2
corr(n2, x**2): 0.8989
corr(n2, y**2): 0.0311
corr(n2, cos(x)): -0.8482
corr(n2, cos(y)): -0.0263
corr

`v3`: No correlation with $\text{cos}(y)$

`v7`: Low correlation with $x \cdot y$

`v11`: Somewhat reasonable correlations

`v13`: Good correlations (due to higher `sd`?), but more strongly correlated with $y^2$ than $\text{cos}(y)$ and rather high validation error

All runs have a high learning rate and a deep architecture in common. 

`gp` is low for low `hid_size`, but this could be a coincidence.

In [56]:
model_name = "srnet_model_F00_v2_bn_norm_sd_study_v13"
model_path = "models"
model_ext = ".pkl"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)
    
all_nodes = ut.get_node_order(acts, show=True)

print(model.layers1.norm(model.layers1.alpha).detach().cpu().numpy()[all_nodes])

print("")

[1.4492056, 1.1712595, 0.68234366]
[1, 2, 0]
[[0.8143813  0.18561867]
 [0.36721858 0.63278145]
 [0.76156855 0.23843145]]



In [57]:
nodes = all_nodes[:3]

In [58]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [59]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)


In [60]:
corr_data = [
    ("x**2", x_data**2), 
    ("y**2", y_data**2), 
    ("cos(x)", np.cos(x_data)), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [61]:
ut.node_correlations(acts, nodes, corr_data);


Node 1
corr(n1, x**2): 0.9979
corr(n1, y**2): -0.0344
corr(n1, cos(x)): -0.9556
corr(n1, cos(y)): 0.0358
corr(n1, x*y): 0.1675

Node 2
corr(n2, x**2): -0.0542
corr(n2, y**2): 0.9968
corr(n2, cos(x)): 0.0438
corr(n2, cos(y)): -0.9611
corr(n2, x*y): 0.0357

Node 0
corr(n0, x**2): 0.2539
corr(n0, y**2): 0.0047
corr(n0, cos(x)): -0.2298
corr(n0, cos(y)): -0.0204
corr(n0, x*y): 0.9936


In [62]:
fig, ax = plt.subplots()

n = 1
bias = False

ax.scatter(x_data, x_data**2)
# ax.scatter(x_data, acts[:,n])
ax.scatter(x_data, model.layers2[0].weight[0,n].item()*acts[:,n] + bias * model.layers2[0].bias.item())

plt.show()


In [63]:
fig, ax = plt.subplots()

n = 2
bias = True

ax.scatter(y_data, np.cos(y_data))
ax.scatter(y_data, model.layers2[0].weight[0,n].item()*acts[:,n] + bias * model.layers2[0].bias.item())

plt.show()


In [64]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    #("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [65]:
n = 0
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=[n], model=model, bias=False, nonzero=False, agg=False, plot_size=plot_size)




For `v13`, the target functions are matched nicely, both in terms of correlations and dependencies. Only the similarity of $y^2$ and $\text{cos}(y)$ leads to some differences.

**Note**: There is an additional bias term, which is added to node 1 and subtracted from node 0.

For other runs, such as `v11`, the matches are worse.

### Step 9.4: Train bottleneck masked DSN with SD regularization

In [66]:
# set wandb project
wandb_project = "94-bn-mask-DSN-sd-study-F00_v2"

In [67]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": {
#             "alpha": [[1,0],[0,1],[1,1]],
#             "norm": None,
#             "prune": None,
#             },
#         "lat_size": 3,
#         },
#     "epochs": 30000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "shuffle": False,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "gc": 0.0,
#     "sd": 1e-4,
#     "disc": {
#         "hid_num": 2,
#         "hid_size": 128,
#         "lr": 1e-4,
#         "wd": 1e-4,
#         "iters": 5,
#         "gp": 1e-4,
#     },
# }

In [68]:
# define hyperparameter study
hp_study = {
    "method": "random",
    "parameters": {
        "sd": {
            "values": [1e-7, 1e-6, 1e-5, 1e-4]
        },
        "disc": {
            "parameters": {
                "hid_num": {
                    "values": [2, 4, 6]
                },
                "hid_size": {
                    "values": [32, 128, 256]
                },
                "lr": {
                    "values": [1e-5, 1e-4, 1e-3, 1e-2]
                },
                "iters": {
                    "values": [1, 5, 10]
                },
                "gp": {
                    "values": [1e-5, 1e-4, 1e-3]
                },
            }
        }
    }
}

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

In [69]:
sweep_name = "hvuomrrb"

<img src="results/94-bn-mask-DSN-sd-study-F00_v2.png">

Quick notes:

* Large `sd` values can lead to high validation errors, while a value of `1e-7` is too low

* A large architecture (`hid_num` $\times$ `hid_size`) seems beneficial

* Large learning rates are not necessary, but work

* A large number of discriminator iterations does not seem to be necessary

In [70]:
# download data from wandb
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for r, run in enumerate(runs):
    if run.sweep.name == sweep_name:
        for f in run.files():
            if f.name.endswith(file_ext):
                file_name = f.name.replace(file_ext, f"_v{r+1}{file_ext}")
                print(f"Downloading {os.path.basename(file_name)}.")
                run.file(f.name).download()
                os.rename(f.name, file_name)

Downloading srnet_model_F00_v2_bn_mask_sd_study_v11.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v12.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v13.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v14.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v15.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v16.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v17.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v18.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v19.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v20.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v21.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v22.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v23.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v24.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v25.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v26.pkl.
Downloading srnet_model_F00_v2_bn_mask_sd_study_v27.pkl.
Downloading srnet_model_F00_v2_

In [71]:
# load data
data_path = "data_1k"

in_var = "X00"
lat_var = "G00"
target_var = "F00"

mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"])
val_data = SRData(data_path, in_var, lat_var, target_var, masks["val"])

In [72]:
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]

In [73]:
corr_data = [
    ("x**2", x_data**2), 
    # ("y**2", y_data**2), 
    # ("cos(x)", np.cos(x_data)), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [76]:
# get validation loss and latent feature correlations
model_path = "models"
save_name = "F00_v2_bn_mask_sd_study"

models = [f for f in os.listdir(model_path) if save_name in f and "ext" not in f]

val_corr = {}

for model_name in models:
    print(f"Loading {model_name}.")
    model = ut.load_model(model_name, model_path, SRNet)
    
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)
        
    all_nodes = ut.get_node_order(acts, show=False)
        
    corr_mat = ut.node_correlations(acts, all_nodes, corr_data, show=False)
    corr = [np.abs(c).max() for c in corr_mat]
    
    with torch.no_grad():
        preds = model(val_data.in_data)
        
    val_loss = (preds - val_data.target_data).pow(2).mean().item()
    val_corr[model_name] = (val_loss, corr)

Loading srnet_model_F00_v2_bn_mask_sd_study_v11.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v12.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v13.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v14.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v15.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v16.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v17.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v18.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v19.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v20.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v21.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v22.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v23.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v24.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v25.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v26.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v27.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v28.pkl.
Loading srnet_model_F00_v2_bn_mask_sd_study_v2

In [77]:
fig, ax = plt.subplots()

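# one marker per run: validation loss vs. weakest per-node correlation
for v in val_corr:
    ax.plot(val_corr[v][0], np.min(val_corr[v][1]), 'x', label=v.split('.')[0].split('_')[-1])

ax.set_xlabel("validation loss")
ax.set_ylabel("min. over nodes of max. |corr|")
ax.legend()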
plt.show()


In [78]:
save_names = [s for s in val_corr if val_corr[s][0] < 0.03 and min(val_corr[s][1]) > 0.95]
save_names

['srnet_model_F00_v2_bn_mask_sd_study_v13.pkl',
 'srnet_model_F00_v2_bn_mask_sd_study_v17.pkl',
 'srnet_model_F00_v2_bn_mask_sd_study_v21.pkl',
 'srnet_model_F00_v2_bn_mask_sd_study_v22.pkl',
 'srnet_model_F00_v2_bn_mask_sd_study_v27.pkl',
 'srnet_model_F00_v2_bn_mask_sd_study_v33.pkl']

In [79]:
# plot losses
# save_names = ["srnet_model_F00_v2_bn_mask_sd_study_v3"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path, excl_names=[])


In [80]:
corr_data = [
    ("x**2", x_data**2), 
    ("y**2", y_data**2), 
    ("cos(x)", np.cos(x_data)), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [81]:
model_path = "models"
model_ext = ".pkl"

for model_name in models:
    print(model_name)
    
    state = joblib.load(os.path.join(model_path, model_name + model_ext))
    
    print(state['hyperparams']['sd'])
    print(state['hyperparams']['disc'])
    
    model = ut.load_model(model_name + model_ext, model_path, SRNet)
    
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)
        
    all_nodes = ut.get_node_order(acts, show=True)
    
    print(f"Validation error: {state['total_val_loss']:.4e}")
        
    corr_mat = ut.node_correlations(acts, all_nodes, corr_data, show=True)
        
    print("")

srnet_model_F00_v2_bn_mask_sd_study_v13
1e-06
{'gp': 0.001, 'hid_num': 2, 'hid_size': 256, 'iters': 5, 'lr': 0.001}
[1.4469135, 1.2598801, 0.13963567]
[1, 0, 2]
Validation error: 3.0945e-03

Node 1
corr(n1, x**2): 0.0168
corr(n1, y**2): 0.9979
corr(n1, cos(x)): -0.0306
corr(n1, cos(y)): -0.9754
corr(n1, x*y): 0.0531

Node 0
corr(n0, x**2): 0.9991
corr(n0, y**2): 0.0163
corr(n0, cos(x)): -0.9536
corr(n0, cos(y)): -0.0178
corr(n0, x*y): 0.1788

Node 2
corr(n2, x**2): -0.3547
corr(n2, y**2): 0.1709
corr(n2, cos(x)): 0.3249
corr(n2, cos(y)): -0.1721
corr(n2, x*y): -0.9509

srnet_model_F00_v2_bn_mask_sd_study_v17
1e-05
{'gp': 1e-05, 'hid_num': 6, 'hid_size': 256, 'iters': 5, 'lr': 0.01}
[1.4544828, 1.2752091, 0.6358797]
[0, 1, 2]
Validation error: 2.3825e-02

Node 0
corr(n0, x**2): 0.9994
corr(n0, y**2): 0.0162
corr(n0, cos(x)): -0.9598
corr(n0, cos(y)): -0.0176
corr(n0, x*y): 0.1743

Node 1
corr(n1, x**2): 0.0146
corr(n1, y**2): 0.9990
corr(n1, cos(x)): -0.0289
corr(n1, cos(y)): -0.9685
co

In general, fixing the $\alpha$ mask, which can be determined beforehand with entropy regularization, improves the results: **6/23** training runs achieve good correlations with a fixed $\alpha$ mask, compared to only **1/17** without it.
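
To make the role of the fixed mask concrete, here is a minimal sketch. It assumes that the mask simply multiplies each latent node's copy of the input features element-wise; the actual DSN layer in `SRNet` may implement this differently. The mask values are the ones used in this notebook, while `apply_mask` and the toy inputs are hypothetical.

```python
import numpy as np

# Fixed alpha mask from this notebook: one row per latent node,
# one column per input feature (x, y).
alpha = np.array([[1, 0],   # node 0 only sees x
                  [0, 1],   # node 1 only sees y
                  [1, 1]])  # node 2 sees x and y

def apply_mask(inputs, alpha):
    """Return per-node masked inputs of shape (samples, nodes, features).

    Hypothetical helper: assumes the mask multiplies the inputs element-wise.
    """
    return inputs[:, None, :] * alpha[None, :, :]

# Two toy samples (x, y)
inputs = np.array([[0.5, -1.0],
                   [2.0,  3.0]])
masked = apply_mask(inputs, alpha)
print(masked[0])  # rows: what nodes 0, 1 and 2 receive for the first sample
```

With such a mask, node 0 cannot depend on $y$ and node 1 cannot depend on $x$, which matches the near-zero cross-correlations in the outputs above.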

Nevertheless, the issue of converging to $y^2$ instead of $\text{cos}(y)$ due to their similarity remains.

Furthermore, there seems to be a trade-off between precision and correlation: the runs with the highest correlations show larger validation errors, presumably because they do not converge to $\text{cos}(y)$.

```
run  |corr(n2, x*y)|  val. error  sd      disc
v17  0.9961           2.3825e-02  1e-05   {'gp': 1e-05, 'hid_num': 6, 'hid_size': 256, 'iters': 5, 'lr': 0.01}
v21  0.9976           2.2615e-02  0.0001  {'gp': 1e-05, 'hid_num': 6, 'hid_size': 32,  'iters': 5, 'lr': 0.01}
v33  0.9969           2.5827e-02  1e-05   {'gp': 1e-05, 'hid_num': 6, 'hid_size': 128, 'iters': 5, 'lr': 0.001}

v13  0.9509           3.0945e-03  1e-06   {'gp': 0.001, 'hid_num': 2, 'hid_size': 256, 'iters': 5, 'lr': 0.001}
v27  0.9618           2.1251e-03  1e-06   {'gp': 0.001, 'hid_num': 4, 'hid_size': 256, 'iters': 1, 'lr': 0.0001}
v22  0.9566           7.2420e-03  1e-05   {'gp': 0.001, 'hid_num': 6, 'hid_size': 256, 'iters': 5, 'lr': 1e-05}
```

Okay, improved correlations (at the cost of precision) go along with a larger `sd` value, a lower `gp` value, a deeper discriminator (larger `hid_num`; `hid_size` shows no clear trend) and a higher discriminator learning rate.
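
The trade-off can be checked directly from the six runs listed above. The following sketch only re-enters those numbers and compares the two groups. It assumes that the first column is the magnitude of the correlation between node 2 and `x*y`, which matches the printed outputs for v13 and v21; the variable and function names are illustrative.

```python
# (|corr(n2, x*y)|, validation error) for the six runs listed above
high_corr = {"v17": (0.9961, 2.3825e-02),
             "v21": (0.9976, 2.2615e-02),
             "v33": (0.9969, 2.5827e-02)}
low_corr = {"v13": (0.9509, 3.0945e-03),
            "v27": (0.9618, 2.1251e-03),
            "v22": (0.9566, 7.2420e-03)}

def group_means(runs):
    """Mean correlation and mean validation error of a group of runs."""
    corrs, errs = zip(*runs.values())
    return sum(corrs) / len(corrs), sum(errs) / len(errs)

hi_corr, hi_err = group_means(high_corr)
lo_corr, lo_err = group_means(low_corr)
print(f"high-correlation group: corr={hi_corr:.4f}, val. error={hi_err:.3e}")
print(f"low-correlation group:  corr={lo_corr:.4f}, val. error={lo_err:.3e}")
```

With only three runs per group, this is an indication rather than a statistically robust result.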

In [82]:
model_name = "srnet_model_F00_v2_bn_mask_sd_study_v21"
model_path = "models"
model_ext = ".pkl"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)
    
all_nodes = ut.get_node_order(acts, show=True)

print(model.layers1.norm(model.layers1.alpha).detach().cpu().numpy()[all_nodes])

print("")

[1.3792901, 1.2907324, 0.71521765]
[0, 1, 2]
[[1. 0.]
 [0. 1.]
 [1. 1.]]



In [83]:
nodes = all_nodes[:3]

In [84]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [85]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)


In [86]:
corr_data = [
    ("x**2", x_data**2), 
    ("y**2", y_data**2), 
    ("cos(x)", np.cos(x_data)), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [87]:
ut.node_correlations(acts, nodes, corr_data);


Node 0
corr(n0, x**2): 0.9997
corr(n0, y**2): 0.0161
corr(n0, cos(x)): -0.9573
corr(n0, cos(y)): -0.0175
corr(n0, x*y): 0.1751

Node 1
corr(n1, x**2): 0.0138
corr(n1, y**2): 0.9988
corr(n1, cos(x)): -0.0280
corr(n1, cos(y)): -0.9688
corr(n1, x*y): 0.0449

Node 2
corr(n2, x**2): 0.1615
corr(n2, y**2): -0.0026
corr(n2, cos(x)): -0.1379
corr(n2, cos(y)): -0.0149
corr(n2, x*y): 0.9976


In [88]:
fig, ax = plt.subplots()

n = 0
bias = False

ax.scatter(x_data, x_data**2)
# ax.scatter(x_data, acts[:,n])
ax.scatter(x_data, model.layers2[0].weight[0,n].item()*acts[:,n] + bias * model.layers2[0].bias.item())

plt.show()


In [89]:
fig, ax = plt.subplots()

n = 1
bias = True

ax.scatter(y_data, np.cos(y_data))
ax.scatter(y_data, model.layers2[0].weight[0,n].item()*acts[:,n] + bias * model.layers2[0].bias.item())

plt.show()


In [90]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    #("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [91]:
n = 2
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=[n], model=model, bias=False, nonzero=False, agg=False, plot_size=plot_size)


Again, the target functions are matched nicely, both in terms of correlations and dependencies. The only deviation stems from the similarity of $y^2$ and $\text{cos}(y)$: node 1 correlates more strongly with $y^2$ (0.9988) than with $\text{cos}(y)$ (-0.9688).
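
For reference, correlations like the ones printed above can be reproduced with plain NumPy. This is a minimal sketch, assuming that `ut.node_correlations` reports Pearson correlation coefficients (its implementation is not shown in this notebook). `pearson_table`, the sampling grid and the synthetic activations are hypothetical stand-ins for `acts` and `corr_data`.

```python
import numpy as np

def pearson_table(acts, candidates):
    """Pearson correlation of every latent node with every candidate function.

    acts: array of shape (samples, nodes); candidates: list of (name, values).
    """
    table = {}
    for node in range(acts.shape[1]):
        for name, values in candidates:
            table[(node, name)] = np.corrcoef(acts[:, node], values)[0, 1]
    return table

# Synthetic stand-in data on an assumed sampling grid
x, y = np.meshgrid(np.linspace(-1.0, 1.0, 21), np.linspace(-1.0, 1.0, 21))
x, y = x.ravel(), y.ravel()
# Latent nodes that match the targets up to scale, shift and sign
acts_demo = np.stack([2.0 * x**2 + 1.0, -0.5 * x * y], axis=1)

table = pearson_table(acts_demo, [("x**2", x**2), ("x*y", x * y)])
for (node, name), corr in table.items():
    print(f"corr(n{node}, {name}): {corr:.4f}")
```

The Pearson coefficient is invariant to scaling and shifting, and only changes sign under a sign flip. If the layer after the bottleneck is linear, as the rescaling with `model.layers2[0].weight` in the plots suggests, a correlation near -1 is therefore as good as one near +1.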

Let's explore more extreme hyperparameters.

In [92]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": {
#             "alpha": [[1,0],[0,1],[1,1]],
#             "norm": None,
#             "prune": None,
#             },
#         "lat_size": 3,
#         },
#     "epochs": 30000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "shuffle": False,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "gc": 0.0,
#     "sd": 1e-4,
#     "disc": {
#         "hid_num": 2,
#         "hid_size": 128,
#         "lr": 1e-4,
#         "wd": 1e-4,
#         "iters": 5,
#         "gp": 1e-4,
#     },
# }

In [93]:
# define hyperparameter study
hp_study = {
    "method": "random",
    "parameters": {
        "sd": {
            "values": [1e-5, 1e-4, 1e-3]
        },
        "disc": {
            "parameters": {
                "hid_num": {
                    "values": [4, 6, 8, 10]
                },
                "hid_size": {
                    "values": [128, 256, 512]
                },
                "lr": {
                    "values": [1e-4, 1e-3, 1e-2, 1e-1]
                },
                "iters": {
                    "values": [5]
                },
                "gp": {
                    "values": [1e-3, 1e-1, 1e1]
                },
            }
        }
    }
}

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

<img src="results/94-bn-mask-DSN-sd-study-F00_v2_ext.png">

More extreme hyperparameters do not lead to improved results.
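
As a reminder of what the swept `gp` coefficient controls, here is a minimal sketch of the gradient-penalty term of a WGAN-GP critic loss. It is not the code used for these runs: the function name, the toy critic and the assumption that `gp` multiplies this term in the discriminator loss are illustrative.

```python
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake):
    """Mean squared deviation of the critic's gradient norm from 1,
    evaluated on random interpolations between real and fake samples."""
    eps = torch.rand(real.shape[0], 1)
    mixed = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)
    scores = critic(mixed)
    grads, = torch.autograd.grad(
        outputs=scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

# Toy check with a linear critic whose gradient norm is 5 everywhere
critic = nn.Linear(2, 1, bias=False)
with torch.no_grad():
    critic.weight.copy_(torch.tensor([[3.0, 4.0]]))

real, fake = torch.randn(8, 2), torch.randn(8, 2)
gp = 1e-3  # assumed to weight the penalty in the critic loss
penalty = gradient_penalty(critic, real, fake)
critic_loss = critic(fake).mean() - critic(real).mean() + gp * penalty
print(penalty.item())  # (5 - 1)**2 = 16 for this critic
```

A larger `gp` pushes the critic harder towards gradient norm 1, i.e. towards the 1-Lipschitz constraint, at the expense of the Wasserstein term.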

### Step 9.5: Train bottleneck masked DSN with SD regularization and failed embedding

In a `wandb` sweep, the nested `disc` dictionary is replaced as a whole rather than merged with the default values. Thus, only the parameters defined in the sweep remain in it. Since the embedding size (`emb_size`) was defined as a `disc` parameter but is not part of the sweep, it is missing and no additional information is embedded.
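
The effect is the same as a shallow dictionary update. The sketch below reproduces it with plain Python dictionaries rather than `wandb` itself, and shows one possible workaround, a recursive merge. `deep_merge` is a hypothetical helper, not part of `wandb` or `srnet_utils`, and `emb_size = 3` corresponds to `train_data.in_data.shape[1] + 1` with two input features.

```python
import copy

defaults = {
    "sd": 1e-7,
    "disc": {"hid_num": (1, 3), "hid_size": (32, 128), "emb_size": 3,
             "lr": 1e-3, "wd": 1e-4, "iters": 5, "gp": 1e-5},
}
# One sampled sweep configuration (values taken from the study below)
sweep_config = {
    "sd": 1e-6,
    "disc": {"hid_num": (2, 4), "hid_size": (64, 256), "lr": 1e-4, "gp": 1e-4},
}

# Shallow update: the whole "disc" sub-dictionary is replaced
shallow = {**defaults, **sweep_config}
print("emb_size" in shallow["disc"])  # False

def deep_merge(base, override):
    """Recursively merge `override` into a copy of `base`."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

merged = deep_merge(defaults, sweep_config)
print(merged["disc"]["emb_size"])  # 3
```

An alternative that avoids merging altogether would be to list every `disc` parameter in the sweep, using a single-element `values` list for the ones that should stay fixed.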

In [94]:
# set wandb project
wandb_project = "95-bn-mask-DSN-sd-study-F00_v2-emb-fail"

In [95]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": {
#             "alpha": [[1,0],[0,1],[1,1]],
#             "norm": None,
#             "prune": None,
#             },
#         "lat_size": 3,
#         },
#     "epochs": 30000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "shuffle": False,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "gc": 0.0,
#     "sd": 1e-7,
#     "disc": {
#         "hid_num": (1,3),
#         "hid_size": (32,128),
#         "emb_size": train_data.in_data.shape[1] + 1,
#         "lr": 1e-3,
#         "wd": 1e-4,
#         "iters": 5,
#         "gp": 1e-5,
#     },
# }

In [96]:
# define hyperparameter study
hp_study = {
    "method": "random",
    "parameters": {
        "sd": {
            "values": [1e-8, 1e-7, 1e-6, 1e-5]
        },
        "disc": {
            "parameters": {
                "hid_num": {
                    "values": [(1,4), (2,4), (1,6), (2,6)]
                },
                "hid_size": {
                    "values": [(32,128), (64,128), (32,256), (64,256)]
                },
                "lr": {
                    "values": [1e-5, 1e-4, 1e-3, 1e-2]
                },
                "gp": {
                    "values": [1e-5, 1e-4, 1e-3]
                },
            }
        }
    }
}

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

<img src="results/95-bn-mask-DSN-sd-study-F00_v2-emb-fail.png">

In [97]:
# define hyperparameter study
hp_study = {
    "method": "random",
    "parameters": {
        "sd": {
            "values": [1e-6, 1e-5, 1e-4, 1e-3]
        },
        "disc": {
            "parameters": {
                "hid_num": {
                    "values": [(1,4), (2,4), (1,6), (2,6)]
                },
                "hid_size": {
                    "values": [(32,128), (64,128), (32,256), (64,256)]
                },
                "lr": {
                    "values": [1e-4, 1e-3, 1e-2]
                },
                "gp": {
                    "values": [1e-5, 1e-4]
                },
            }
        }
    }
}

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

<img src="results/95-bn-mask-DSN-sd-study-F00_v2-emb-fail_ext.png">

Let's download the training runs with high correlations.

In [98]:
# plot losses
save_names = ["srnet_model_F00_v2_bn_mask_sd_study_ext"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path, excl_names=[])


In [99]:
corr_data = [
    ("x**2", x_data**2), 
    ("y**2", y_data**2), 
    ("cos(x)", np.cos(x_data)), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [100]:
model_path = "models"
model_ext = ".pkl"

for model_name in models:
    print(model_name)
    
    state = joblib.load(os.path.join(model_path, model_name + model_ext))
    
    print(state['hyperparams']['sd'])
    print(state['hyperparams']['disc'])
    
    model = ut.load_model(model_name + model_ext, model_path, SRNet)
    
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)
        
    all_nodes = ut.get_node_order(acts, show=True)
    
    print(f"Validation error: {state['total_val_loss']:.4e}")
        
    corr_mat = ut.node_correlations(acts, all_nodes, corr_data, show=True)
        
    print("")

srnet_model_F00_v2_bn_mask_sd_study_ext_v1
1e-06
{'gp': 0.0001, 'hid_num': [2, 4], 'hid_size': [64, 256], 'lr': 0.0001}
[1.3869069, 1.2386308, 0.10921547]
[1, 0, 2]
Validation error: 1.4853e-02

Node 1
corr(n1, x**2): 0.0158
corr(n1, y**2): 0.9992
corr(n1, cos(x)): -0.0300
corr(n1, cos(y)): -0.9694
corr(n1, x*y): 0.0464

Node 0
corr(n0, x**2): 0.9990
corr(n0, y**2): 0.0138
corr(n0, cos(x)): -0.9467
corr(n0, cos(y)): -0.0151
corr(n0, x*y): 0.1800

Node 2
corr(n2, x**2): -0.2454
corr(n2, y**2): -0.2217
corr(n2, cos(x)): 0.2407
corr(n2, cos(y)): 0.2096
corr(n2, x*y): -0.9732

srnet_model_F00_v2_bn_mask_sd_study_ext_v2
1e-06
{'gp': 1e-05, 'hid_num': [1, 4], 'hid_size': [32, 128], 'lr': 0.001}
[1.5793058, 1.1558449, 0.5685844]
[1, 0, 2]
Validation error: 1.9740e-02

Node 1
corr(n1, x**2): 0.0154
corr(n1, y**2): 0.9992
corr(n1, cos(x)): -0.0293
corr(n1, cos(y)): -0.9705
corr(n1, x*y): 0.0474

Node 0
corr(n0, x**2): 0.9998
corr(n0, y**2): 0.0144
corr(n0, cos(x)): -0.9542
corr(n0, cos(y)): -0.

These models are even a bit better, but still suffer from the same issue of converging to $y^2$ instead of $\text{cos}(y)$.
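
The similarity can be quantified directly. On a symmetric interval around zero, $\text{cos}(y) \approx 1 - y^2/2$, so the two functions are almost perfectly anti-correlated and a correlation-based check can barely separate them. The sketch below assumes a uniform grid on $[-1, 1]$, which need not match the sampling range of the training data.

```python
import numpy as np

# Assumed sampling grid; the range of the actual training data may differ
y = np.linspace(-1.0, 1.0, 1001)

corr = np.corrcoef(y**2, np.cos(y))[0, 1]
print(f"corr(y**2, cos(y)) = {corr:.4f}")

# Largest deviation between cos(y) and its second-order Taylor polynomial
max_dev = np.max(np.abs(np.cos(y) - (1.0 - y**2 / 2.0)))
print(f"max |cos(y) - (1 - y**2/2)| = {max_dev:.4f}")
```

On wider intervals the higher-order terms of the cosine become more important and the two functions are easier to tell apart.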

In [101]:
model_name = "srnet_model_F00_v2_bn_mask_sd_study_ext_v2"
model_path = "models"
model_ext = ".pkl"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)
    
all_nodes = ut.get_node_order(acts, show=True)

print(model.layers1.norm(model.layers1.alpha).detach().cpu().numpy()[all_nodes])

print("")

[1.5793058, 1.1558449, 0.5685844]
[1, 0, 2]
[[0. 1.]
 [1. 0.]
 [1. 1.]]



In [102]:
nodes = all_nodes[:3]

In [103]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [104]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)


In [105]:
corr_data = [
    ("x**2", x_data**2), 
    ("y**2", y_data**2), 
    ("cos(x)", np.cos(x_data)), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [106]:
ut.node_correlations(acts, nodes, corr_data);


Node 1
corr(n1, x**2): 0.0154
corr(n1, y**2): 0.9992
corr(n1, cos(x)): -0.0293
corr(n1, cos(y)): -0.9705
corr(n1, x*y): 0.0474

Node 0
corr(n0, x**2): 0.9998
corr(n0, y**2): 0.0144
corr(n0, cos(x)): -0.9542
corr(n0, cos(y)): -0.0154
corr(n0, x*y): 0.1739

Node 2
corr(n2, x**2): 0.1254
corr(n2, y**2): -0.0211
corr(n2, cos(x)): -0.1015
corr(n2, cos(y)): 0.0078
corr(n2, x*y): 0.9948


In [107]:
fig, ax = plt.subplots()

n = 0
bias = False

ax.scatter(x_data, x_data**2)
# ax.scatter(x_data, acts[:,n])
ax.scatter(x_data, model.layers2[0].weight[0,n].item()*acts[:,n] + bias * model.layers2[0].bias.item())

plt.show()


In [108]:
fig, ax = plt.subplots()

n = 1
bias = True

ax.scatter(y_data, np.cos(y_data))
ax.scatter(y_data, model.layers2[0].weight[0,n].item()*acts[:,n] + bias * model.layers2[0].bias.item())

plt.show()


In [109]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    #("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [110]:
n = 2
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=[n], model=model, bias=False, nonzero=False, agg=False, plot_size=plot_size)


Next, let's try to embed additional information in the network.