# 2022 Flatiron Machine Learning x Science Summer School

## Step 7: Explore regularization for bottleneck DSN

What we have seen so far:

* For problem `F06` without degeneracies, the latent features can be discovered perfectly using the DSN (no GhostAdam required)

* For problem `F00` with degeneracies, the latent features cannot be discovered, even when:

    * Using the DSN with GhostAdam
    
    * Training for 100k epochs
    
    * Increasing $a_2$

It seems that depending on as few input features as possible is not enforced enough or there is a different (undesired) global minimum, e.g. collapsing into a single latent feature (which happens when training for 100k epochs). In other words, it seems that enforcing sparsity (`few_latents`) is too similar to enforcing `few_dependencies`.

Let's disentangle the problem by creating a bottleneck, i.e. setting the number of latent features to the correct size a priori.

Then, we can compare the following approaches:

1. Training DSN

2. Training DSN with $a_2 \ne 0$

3. Training DSN with normalized $\alpha$ and $a_2 \ne 0$

4. Training DSN with normalized $\alpha$ and minimizing entropy

5. Training DSN with normalized $\alpha$ and minimizing entropy $\times$ variance

Let's go!

### Step 7.1: Train bottleneck DSN

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
import srnet_utils as ut

In [2]:
# set wandb project
wandb_project = "71-bn-DSN-F00"

In [3]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "lat_size": 3,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "gc": 0.0,
#     "shuffle": True,
# }

In [None]:
# download data from wandb
# Fetches every file with extension `file_ext` from all runs of `wandb_project`
# unless a file with the same relative path already exists locally.
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # str.endswith replaces the manual slice comparison (which can never
        # match when file_ext is empty, because name[-0:] is the whole name).
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            run.file(f.name).download()

In [4]:
# plot losses
# `save_path` is now actually passed on (it used to be assigned but ignored in
# favour of a second hard-coded "models" literal).
save_names = ["srnet_model_F00_bn"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path, excl_names=["norm", "a2"])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [5]:
# Data directory and the variable names for inputs, latents and targets.
data_path = "data_1k"
in_var, lat_var, target_var = "X00", "G00", "F00"

# The masks are read from "<data_path>/<in_var><mask_ext>".
# TODO: create mask if file does not exist
mask_ext = ".mask"
mask_file = os.path.join(data_path, f"{in_var}{mask_ext}")
masks = joblib.load(mask_file)

# One SRData object per split, selected via the "train" / "val" mask entries.
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [6]:
# overview latent feature variance and alpha matrix
# For every model returned by the loss plot: load it, run the training inputs
# through it, print the node order, then print the raw alpha rows (optionally
# binarised / filtered by `alpha_eps`) and the alpha matrix after
# `model.layers1.norm`.
model_path = "models"
model_ext = ".pkl"

alpha_eps = 1e-6    # falsy value (None / 0) disables the raw-alpha printout
alpha_bin = False   # True: print a 0/1 mask instead of the raw alpha values

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)

    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    all_nodes = ut.get_node_order(acts, show=True)

    # Detach once and reuse for both printouts.
    raw_alpha = model.layers1.alpha.detach().cpu()

    if alpha_eps:
        alpha = raw_alpha.numpy()[all_nodes]

        if alpha_bin:
            # Build a fresh 0/1 array instead of writing in place: nothing can
            # leak back into the model's parameter storage, and entries exactly
            # equal to alpha_eps are binarised too (previously left unchanged).
            alpha = (np.abs(alpha) >= alpha_eps).astype(alpha.dtype)

        # Only show rows whose summed magnitude exceeds the threshold.
        print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])

    print(model.layers1.norm(raw_alpha).numpy()[all_nodes])

    print("")

srnet_model_F00_bn
[5.482872, 0.33377925, 1.2634401e-06]
[1, 0, 2]
[[-2.1264930e+00  9.6783280e-01]
 [ 1.5003376e+00 -6.6390908e-01]
 [-3.6289787e-01 -7.9980089e-10]]
[[-2.1264930e+00  9.6783280e-01]
 [ 1.5003376e+00 -6.6390908e-01]
 [-3.6289787e-01 -7.9980089e-10]]



In [7]:
# Reload one model for closer inspection (uses model_ext / model_path set above).
model_name = "srnet_model_F00_bn"
model_file = f"{model_name}{model_ext}"
model = ut.load_model(model_file, model_path, SRNet)

# Forward pass on the training inputs without gradient tracking.
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

# NOTE(review): presumably orders the latent nodes by activation variance
# (cf. the printed output) -- confirm in srnet_utils.
all_nodes = ut.get_node_order(acts, show=True)

[5.482872, 0.33377925, 1.2634401e-06]
[1, 0, 2]


In [8]:
nodes = all_nodes[:3]

In [9]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [10]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [12]:
corr_data = [
    ("x**2", x_data**2), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
    ("x", x_data),
    ("y", y_data),
]

In [12]:
ut.node_correlations(acts, nodes, corr_data);


Node 1
corr(n1, x**2): -0.8694
corr(n1, cos(y)): -0.0501
corr(n1, x*y): -0.5165
corr(n1, x): -0.1020
corr(n1, y): -0.0796

Node 0
corr(n0, x**2): -0.1374
corr(n0, cos(y)): -0.3930
corr(n0, x*y): -0.4840
corr(n0, x): -0.4720
corr(n0, y): 0.0973

Node 2
corr(n2, x**2): -0.2492
corr(n2, cos(y)): -0.0037
corr(n2, x*y): -0.1246
corr(n2, x): -0.9958
corr(n2, y): -0.0252


Notes:

* Latent feature shapes are quite different from previously seen

* One $\alpha$ value goes to zero automatically

* Still no dependence (of the high-variance latent features) on individual input features

### Step 7.2: Train bottleneck DSN with $a_2$

In [13]:
# set wandb project
wandb_project = "72-bn-DSN-a2-study-F00"

In [14]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "lat_size": 3,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "gc": 0.0,
#     "shuffle": True,
# }

In [15]:
# define hyperparameter study
hp_study = {
    "method": "grid", # random, bayesian
    #"metric": {
    #    "name": "val_loss",
    #    "goal": "minimize",
    #},
    "parameters": {
        #"l1": {
        #    "values": [1e-4, 1e-3, 1e-2]
        #},
        "a2": {
            "values": [1e-3, 1e-2, 1e-1, 1e0]
        }
    }
}

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

In [None]:
# download data from wandb
# Fetches every file with extension `file_ext` from all runs of `wandb_project`
# unless a file with the same relative path already exists locally.
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # str.endswith replaces the manual slice comparison (which can never
        # match when file_ext is empty, because name[-0:] is the whole name).
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            run.file(f.name).download()

In [16]:
# plot losses
# `save_path` is now actually passed on (it used to be assigned but ignored in
# favour of a second hard-coded "models" literal).
save_names = ["F00_bn_a2"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [17]:
# Data directory and the variable names for inputs, latents and targets.
data_path = "data_1k"
in_var, lat_var, target_var = "X00", "G00", "F00"

# The masks are read from "<data_path>/<in_var><mask_ext>".
# TODO: create mask if file does not exist
mask_ext = ".mask"
mask_file = os.path.join(data_path, f"{in_var}{mask_ext}")
masks = joblib.load(mask_file)

# One SRData object per split, selected via the "train" / "val" mask entries.
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [18]:
# overview latent feature variance and alpha matrix
# For every model returned by the loss plot: load it, run the training inputs
# through it, print the node order, then print the raw alpha rows (optionally
# binarised / filtered by `alpha_eps`) and the alpha matrix after
# `model.layers1.norm`.
model_path = "models"
model_ext = ".pkl"

alpha_eps = 1e-2    # falsy value (None / 0) disables the raw-alpha printout
alpha_bin = False   # True: print a 0/1 mask instead of the raw alpha values

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)

    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    all_nodes = ut.get_node_order(acts, show=True)

    # Detach once and reuse for both printouts.
    raw_alpha = model.layers1.alpha.detach().cpu()

    if alpha_eps:
        alpha = raw_alpha.numpy()[all_nodes]

        if alpha_bin:
            # Build a fresh 0/1 array instead of writing in place: nothing can
            # leak back into the model's parameter storage, and entries exactly
            # equal to alpha_eps are binarised too (previously left unchanged).
            alpha = (np.abs(alpha) >= alpha_eps).astype(alpha.dtype)

        # Only show rows whose summed magnitude exceeds the threshold.
        print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])

    print(model.layers1.norm(raw_alpha).numpy()[all_nodes])

    print("")

srnet_model_F00_bn_a2_1e+00
[0.39600012, 8.881784e-16, 1.3877788e-17]
[2, 0, 1]
[[-0.00895066 -0.00502903]]
[[-8.9506609e-03 -5.0290320e-03]
 [-1.2541069e-18  3.3487419e-18]
 [ 5.4481973e-07  7.0096348e-08]]

srnet_model_F00_bn_a2_1e-01
[0.6369252, 0.0004733717, 3.5527137e-15]
[2, 0, 1]
[[-4.1295648e-02 -2.2897277e-02]
 [ 1.0456526e-02  8.0807167e-06]]
[[-4.1295648e-02 -2.2897277e-02]
 [ 1.0456526e-02  8.0807167e-06]
 [ 2.4354176e-09 -1.9725133e-09]]

srnet_model_F00_bn_a2_1e-02
[0.6763779, 0.17751801, 5.551115e-17]
[1, 0, 2]
[[-0.20615605  0.08827081]
 [ 0.19110097 -0.09128037]]
[[-2.0615605e-01  8.8270806e-02]
 [ 1.9110097e-01 -9.1280371e-02]
 [ 1.3835078e-10  6.2785566e-10]]

srnet_model_F00_bn_a2_1e-03
[3.9267466, 0.24354573, 2.220446e-16]
[1, 0, 2]
[[-1.5844834   0.46944278]
 [ 1.0996279  -0.41996077]]
[[-1.5844834e+00  4.6944278e-01]
 [ 1.0996279e+00 -4.1996077e-01]
 [-2.6066029e-39 -9.1203791e-40]]



Notes:

* Using $a_2$ to enforce `few_dependencies` has a larger effect on enforcing `few_latents` than on `few_dependencies`

* For $a_2$ being `1e-02`, the resulting model has only two high-variance features, but these still each depend on both input features

In [19]:
# Reload the a2 = 1e-02 model for closer inspection.
model_name = "srnet_model_F00_bn_a2_1e-02"
model_file = f"{model_name}{model_ext}"
model = ut.load_model(model_file, model_path, SRNet)

# Forward pass on the training inputs without gradient tracking.
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

# NOTE(review): presumably orders the latent nodes by activation variance
# (cf. the printed output) -- confirm in srnet_utils.
all_nodes = ut.get_node_order(acts, show=True)

[0.6763779, 0.17751801, 5.551115e-17]
[1, 0, 2]


In [20]:
nodes = all_nodes[:2]

In [21]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [22]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [23]:
corr_data = [
    ("x**2", x_data**2), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
    ("x", x_data),
    ("y", y_data),
]

In [24]:
ut.node_correlations(acts, nodes, corr_data);


Node 1
corr(n1, x**2): -0.8579
corr(n1, cos(y)): -0.1630
corr(n1, x*y): -0.5906
corr(n1, x): -0.2864
corr(n1, y): -0.0202

Node 0
corr(n0, x**2): -0.7611
corr(n0, cos(y)): -0.2188
corr(n0, x*y): -0.6844
corr(n0, x): -0.2229
corr(n0, y): -0.0552


### Step 7.3: Train bottleneck DSN with $\hat{\alpha}$ and $a_2$

In [25]:
# set wandb project
wandb_project = "73-bn-DSN-norm-a2-study-F00"

In [26]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": ({"norm": "softmax"}, {}),
#         "lat_size": 3,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "gc": 0.0,
#     "shuffle": True,
# }

In [27]:
# define hyperparameter study
hp_study = {
    "method": "grid", # random, bayesian
    #"metric": {
    #    "name": "val_loss",
    #    "goal": "minimize",
    #},
    "parameters": {
        #"l1": {
        #    "values": [1e-4, 1e-3, 1e-2]
        #},
        "a2": {
            "values": [0.0, 1e-3, 1e-2, 1e-1, 1e0]
        }
    }
}

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

In [None]:
# download data from wandb
# Fetches every file with extension `file_ext` from all runs of `wandb_project`
# unless a file with the same relative path already exists locally.
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # str.endswith replaces the manual slice comparison (which can never
        # match when file_ext is empty, because name[-0:] is the whole name).
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            run.file(f.name).download()

In [28]:
# plot losses
# `save_path` is now actually passed on (it used to be assigned but ignored in
# favour of a second hard-coded "models" literal).
save_names = ["srnet_model_F00_bn.pkl", "F00_bn_norm_a2"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [29]:
# Data directory and the variable names for inputs, latents and targets.
data_path = "data_1k"
in_var, lat_var, target_var = "X00", "G00", "F00"

# The masks are read from "<data_path>/<in_var><mask_ext>".
# TODO: create mask if file does not exist
mask_ext = ".mask"
mask_file = os.path.join(data_path, f"{in_var}{mask_ext}")
masks = joblib.load(mask_file)

# One SRData object per split, selected via the "train" / "val" mask entries.
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [30]:
# overview latent feature variance and alpha matrix
# For every model returned by the loss plot: load it, run the training inputs
# through it, print the node order, then print the raw alpha rows (optionally
# binarised / filtered by `alpha_eps`) and the alpha matrix after
# `model.layers1.norm`.
model_path = "models"
model_ext = ".pkl"

alpha_eps = 1e-6    # falsy value (None / 0) disables the raw-alpha printout
alpha_bin = False   # True: print a 0/1 mask instead of the raw alpha values

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)

    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    all_nodes = ut.get_node_order(acts, show=True)

    # Detach once and reuse for both printouts.
    raw_alpha = model.layers1.alpha.detach().cpu()

    if alpha_eps:
        alpha = raw_alpha.numpy()[all_nodes]

        if alpha_bin:
            # Build a fresh 0/1 array instead of writing in place: nothing can
            # leak back into the model's parameter storage, and entries exactly
            # equal to alpha_eps are binarised too (previously left unchanged).
            alpha = (np.abs(alpha) >= alpha_eps).astype(alpha.dtype)

        # Only show rows whose summed magnitude exceeds the threshold.
        print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])

    print(model.layers1.norm(raw_alpha).numpy()[all_nodes])

    print("")

srnet_model_F00_bn
[5.482872, 0.33377925, 1.2634401e-06]
[1, 0, 2]
[[-2.1264930e+00  9.6783280e-01]
 [ 1.5003376e+00 -6.6390908e-01]
 [-3.6289787e-01 -7.9980089e-10]]
[[-2.1264930e+00  9.6783280e-01]
 [ 1.5003376e+00 -6.6390908e-01]
 [-3.6289787e-01 -7.9980089e-10]]

srnet_model_F00_bn_norm_a2_0e+00
[0.7364508, 0.48084727, 0.23385297]
[1, 0, 2]
[[-1.5839932   0.51744026]
 [ 0.97438246 -0.5205308 ]
 [-0.8817944  -0.97543824]]
[[0.7439408  0.25605917]
 [0.6115546  0.3884454 ]
 [0.47660613 0.52339387]]

srnet_model_F00_bn_norm_a2_1e+00
[0.78763723, 0.40730256, 0.172619]
[1, 0, 2]
[[ 1.32152272e-05 -1.13696697e-05]
 [ 3.54378926e-06 -2.50317498e-06]
 [ 1.25902352e-06  1.42729195e-05]]
[[0.5000005  0.49999955]
 [0.50000024 0.49999973]
 [0.49999678 0.5000033 ]]

srnet_model_F00_bn_norm_a2_1e-01
[0.6218816, 0.37553275, 0.2425685]
[1, 0, 2]
[[-3.2783086e-05 -1.7600223e-05]
 [ 6.1090641e-06  8.7805711e-06]
 [-9.0448593e-06  2.8484001e-06]]
[[0.5000038  0.49999622]
 [0.4999993  0.50000066]
 [0.5

Notes:

* Just `softmax` normalization, i.e. $a_2 = 0$, does not yield any improvements over the standard bottleneck DSN

* Actually, regularization of $\alpha$ via $a_2$ does not make any sense when `softmax` normalization is applied afterwards

**TODO**:

* Apply regularization via $a_2$ after `softmax`

### Step 7.4: Train bottleneck DSN with $\hat{\alpha}$ and $e_1$

In [31]:
# set wandb project
wandb_project = "74-bn-DSN-norm-e1-study-F00"

In [32]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": ({"norm": "softmax"}, {}),
#         "lat_size": 3,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "gc": 0.0,
#     "shuffle": True,
# }

In [33]:
# define hyperparameter study
hp_study = {
    "method": "grid", # random, bayesian
    #"metric": {
    #    "name": "val_loss",
    #    "goal": "minimize",
    #},
    "parameters": {
        #"l1": {
        #    "values": [1e-4, 1e-3, 1e-2]
        #},
        "e1": {
            "values": [0.0, 1e-5, 1e-3, 1e-2, 1e-1, 1e1]
        }
    }
}

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

In [None]:
# download data from wandb
# Fetches every file with extension `file_ext` from all runs of `wandb_project`
# unless a file with the same relative path already exists locally.
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # str.endswith replaces the manual slice comparison (which can never
        # match when file_ext is empty, because name[-0:] is the whole name).
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            run.file(f.name).download()

In [34]:
# plot losses
# `save_path` is now actually passed on (it used to be assigned but ignored in
# favour of a second hard-coded "models" literal).
save_names = ["F00_bn_norm_e1"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path, excl_names=["max"])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [35]:
# Data directory and the variable names for inputs, latents and targets.
data_path = "data_1k"
in_var, lat_var, target_var = "X00", "G00", "F00"

# The masks are read from "<data_path>/<in_var><mask_ext>".
# TODO: create mask if file does not exist
mask_ext = ".mask"
mask_file = os.path.join(data_path, f"{in_var}{mask_ext}")
masks = joblib.load(mask_file)

# One SRData object per split, selected via the "train" / "val" mask entries.
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [36]:
# overview latent feature variance and alpha matrix
# For every model returned by the loss plot: load it, run the training inputs
# through it, print the node order, then print the raw alpha rows (optionally
# binarised / filtered by `alpha_eps`) and the alpha matrix after
# `model.layers1.norm`.
model_path = "models"
model_ext = ".pkl"

alpha_eps = None    # falsy value (None / 0) disables the raw-alpha printout
alpha_bin = False   # True: print a 0/1 mask instead of the raw alpha values

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)

    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    all_nodes = ut.get_node_order(acts, show=True)

    # Detach once and reuse for both printouts.
    raw_alpha = model.layers1.alpha.detach().cpu()

    if alpha_eps:
        alpha = raw_alpha.numpy()[all_nodes]

        if alpha_bin:
            # Build a fresh 0/1 array instead of writing in place: nothing can
            # leak back into the model's parameter storage, and entries exactly
            # equal to alpha_eps are binarised too (previously left unchanged).
            alpha = (np.abs(alpha) >= alpha_eps).astype(alpha.dtype)

        # Only show rows whose summed magnitude exceeds the threshold.
        print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])

    print(model.layers1.norm(raw_alpha).numpy()[all_nodes])

    print("")

srnet_model_F00_bn_norm_e1_0e+00
[0.7364508, 0.48084727, 0.23385297]
[1, 0, 2]
[[0.7439408  0.25605917]
 [0.6115546  0.3884454 ]
 [0.47660613 0.52339387]]

srnet_model_F00_bn_norm_e1_1e+01
[2.45539, 0.38972935, 0.14927715]
[1, 0, 2]
[[9.9946922e-01 5.3078489e-04]
 [9.9723595e-01 2.7640408e-03]
 [1.6764498e-03 9.9832350e-01]]

srnet_model_F00_bn_norm_e1_1e-01
[0.7481753, 0.7479166, 0.62162703]
[2, 1, 0]
[[0.0083454  0.99165463]
 [0.9936801  0.00631982]
 [0.97826666 0.02173341]]

srnet_model_F00_bn_norm_e1_1e-02
[1.1345829, 0.62815154, 0.2887137]
[1, 0, 2]
[[0.97732645 0.02267359]
 [0.93166333 0.06833664]
 [0.0288407  0.9711593 ]]

srnet_model_F00_bn_norm_e1_1e-03
[0.66432446, 0.51293874, 0.23554066]
[1, 0, 2]
[[0.9051772  0.09482283]
 [0.75741494 0.2425851 ]
 [0.3511344  0.6488656 ]]

srnet_model_F00_bn_norm_e1_1e-05
[0.7657546, 0.46918085, 0.23475884]
[1, 0, 2]
[[0.74743396 0.25256613]
 [0.61274755 0.38725242]
 [0.47625598 0.52374405]]



Notes:

* Minimizing entropy seems to be efficient

* For $e_1$ being `1e-01` or larger, the validation error is poor

* For `1e-05`, sparse dependencies on the input features are not enforced

* Even `1e-03` appears to be too low, as two latent features still depend on two input features

* For `1e-02`, the lowest validation error is achieved and all latent features depend largely on a single input feature

In [37]:
# Reload the e1 = 1e-02 model for closer inspection.
model_name = "srnet_model_F00_bn_norm_e1_1e-02"
model_file = f"{model_name}{model_ext}"
model = ut.load_model(model_file, model_path, SRNet)

# Forward pass on the training inputs without gradient tracking.
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

# NOTE(review): presumably orders the latent nodes by activation variance
# (cf. the printed output) -- confirm in srnet_utils.
all_nodes = ut.get_node_order(acts, show=True)

[1.1345829, 0.62815154, 0.2887137]
[1, 0, 2]


In [38]:
nodes = all_nodes[:3]

In [39]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [40]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [41]:
corr_data = [
    ("x**2", x_data**2), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
    ("x", x_data),
    ("y", y_data),
]

In [43]:
ut.node_correlations(acts, nodes, corr_data);


Node 1
corr(n1, x**2): -0.6755
corr(n1, cos(y)): 0.0114
corr(n1, x*y): -0.1712
corr(n1, x): -0.8817
corr(n1, y): -0.0393

Node 0
corr(n0, x**2): -0.6556
corr(n0, cos(y)): -0.0993
corr(n0, x*y): -0.6574
corr(n0, x): 0.1075
corr(n0, y): 0.3171

Node 2
corr(n2, x**2): 0.0314
corr(n2, cos(y)): 0.3564
corr(n2, x*y): 0.0860
corr(n2, x): -0.0281
corr(n2, y): 0.8895


Notes:

* The model resulting from $e_1$ being `1e-02` has clearer dependencies:

    * Node 1: $\approx x^2$
    
    * Node 0: $\approx x \cdot y$
    
    * Node 2: $\approx \text{cos}(y)$

In [44]:
# Overlay x**2 with latent node `n`, scaled by the matching weight of
# model.layers2[0] and (if `bias` is set) shifted by that layer's bias.
n = 1
bias = True

head = model.layers2[0]
scale = head.weight[0, n].item()
shift = bias * head.bias.item()

fig, ax = plt.subplots()
ax.scatter(x_data, x_data**2)
# ax.scatter(x_data, acts[:,n])
ax.scatter(x_data, scale * acts[:, n] + shift)

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [45]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    #("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [46]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=[0], model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [47]:
# Overlay cos(y) with latent node `n`, scaled by the matching weight of
# model.layers2[0] and (if `bias` is set) shifted by that layer's bias.
n = 2
bias = True

head = model.layers2[0]
scale = head.weight[0, n].item()
shift = bias * head.bias.item()

fig, ax = plt.subplots()
ax.scatter(y_data, np.cos(y_data))
# ax.scatter(y_data, acts[:,n])
ax.scatter(y_data, scale * acts[:, n] + shift)

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Although we learn clearer dependencies, we can observe the impact of degeneracies. 

For example, node 0 primarily models $x \cdot y$, but also contains a part of $x^2$. At this point, the symbolic discriminator should come in to incentivize the desired shapes.

A few questions remain:

1. How is $x \cdot y$ realized when all latent features largely depend on a single input feature? Is there a weight that amplifies the signal masked by $\alpha$?

2. Can we prune/limit the $\alpha$ mask to get strict dependencies on single input features?

3. What happens when we train more epochs?

**Question 1**: How is $x \cdot y$ realized when all latent features largely depend on a single input feature? Is there a weight that amplifies the signal masked by $\alpha$?

In [48]:
model.layers1.norm(model.layers1.alpha)

tensor([[0.9317, 0.0683],
        [0.9773, 0.0227],
        [0.0288, 0.9712]], grad_fn=<SoftmaxBackward0>)

In [49]:
model.layers1.w[0][0,...].abs().max(dim=1)

torch.return_types.max(
values=tensor([0.7276, 1.1862], grad_fn=<MaxBackward0>),
indices=tensor([27, 17]))

In [50]:
model.layers1.w[0][1,...].abs().max(dim=1)

torch.return_types.max(
values=tensor([0.8105, 0.0571], grad_fn=<MaxBackward0>),
indices=tensor([16, 16]))

In [51]:
model.layers1.w[0][2,...].abs().max(dim=1)

torch.return_types.max(
values=tensor([0.7636, 0.6659], grad_fn=<MaxBackward0>),
indices=tensor([16,  0]))

Node 0 shows a slight amplification of $y$.

**Question 2**: Can we prune/limit the $\alpha$ mask to get strict dependencies on single input features?

Implement as threshold? https://pytorch.org/docs/stable/generated/torch.nn.Threshold.html#torch.nn.Threshold

In [52]:
model_org = model
acts_org = acts

In [53]:
# Load the model trained with pruning; the earlier model stays available as
# model_org / acts_org for the comparisons further down.
model_name = "srnet_model_F00_bn_norm_prune_5e-02_e2_1e-02"
model_file = f"{model_name}{model_ext}"
model = ut.load_model(model_file, model_path, SRNet)

# Forward pass on the training inputs without gradient tracking.
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

all_nodes = ut.get_node_order(acts, show=True)

# Alpha matrix after model.layers1.norm, rows arranged in node order.
normed_alpha = model.layers1.norm(model.layers1.alpha.detach().cpu())
print(normed_alpha.numpy()[all_nodes])

[0.9102438, 0.6963518, 0.27245826]
[1, 0, 2]
[[0.9773646  0.0226354 ]
 [0.93092346 0.06907651]
 [0.02262298 0.977377  ]]


In [54]:
nodes = all_nodes[:3]

In [55]:
corr_data = [
    ("x**2", x_data**2), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
    ("x", x_data),
    ("y", y_data),
]

In [56]:
ut.node_correlations(acts, nodes, corr_data);


Node 1
corr(n1, x**2): -0.6816
corr(n1, cos(y)): 0.0121
corr(n1, x*y): -0.1724
corr(n1, x): -0.8793
corr(n1, y): -0.0422

Node 0
corr(n0, x**2): -0.6845
corr(n0, cos(y)): -0.1080
corr(n0, x*y): -0.6801
corr(n0, x): 0.0711
corr(n0, y): 0.2504

Node 2
corr(n2, x**2): 0.0318
corr(n2, cos(y)): 0.3954
corr(n2, x*y): 0.0255
corr(n2, x): 0.0256
corr(n2, y): 0.8748


In [57]:
# Compare latent node `n` of the entropy-only model and of the pruned model
# against x**2, each scaled/shifted with its own layers2[0] weight and bias.
n = 1
bias = True

fig, ax = plt.subplots()
ax.scatter(x_data, x_data**2, label="Target: $x_0^2$")

for curve_label, net, net_acts in [
    ("DSN Entropy", model_org, acts_org),
    ("DSN Entropy + Pruning", model, acts),
]:
    head = net.layers2[0]
    curve = head.weight[0, n].item() * net_acts[:, n] + bias * head.bias.item()
    ax.scatter(x_data, curve, label=curve_label)

ax.set_xlabel("$x_0$")
ax.set_ylabel("$g_0$")
ax.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [58]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    #("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [59]:
# 3D comparison of the x*y target with latent node 0 of the entropy-only model
# (model_org / acts_org) and of the pruned model (model / acts).
_ = plt.figure()
ax = plt.axes(projection='3d')

# NOTE: rebinds x_data / y_data to NumPy arrays limited to `plot_size` rows.
x_data = train_data.in_data[:,0][:plot_size].numpy()
y_data = train_data.in_data[:,1][:plot_size].numpy()
z_data = [("x*y", x_data * y_data)]

# Raw string: "\c" is not a valid escape sequence and triggers a warning in a
# normal string literal. The loop no longer rebinds `z_data` while iterating it.
for z_name, z_values in z_data:
    ax.scatter3D(x_data, y_data, z_values, label=r"Target: $x_0 \cdot x_1$")

# Same scaling for both models: node activation times the layers2[0] weight of
# that node, plus the layer's bias. Public `.weight` / `.bias` attributes are
# used instead of the private `_parameters` dict, as in the other cells.
n = 0
for curve_label, net, net_acts in (
    ("DSN Entropy", model_org, acts_org),
    ("DSN Entropy + Pruning", model, acts),
):
    head = net.layers2[0]
    b = head.bias.item()
    w = head.weight.detach().numpy()[0, n]
    n_data = w * net_acts[:,n][:plot_size].numpy() + b
    ax.scatter3D(x_data, y_data, n_data, label=curve_label)

ax.set_xlabel("$x_0$")
ax.set_ylabel("$x_1$")
ax.set_zlabel("$g_2$")
ax.view_init(elev=5.16234, azim=-56.591)
ax.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [60]:
# Compare latent node `n` of the entropy-only model and of the pruned model
# against cos(y), each scaled/shifted with its own layers2[0] weight and bias.
n = 2
bias = True

fig, ax = plt.subplots()
ax.scatter(y_data, np.cos(y_data), label="Target: cos$(x_1)$")

for curve_label, net, net_acts in [
    ("DSN Entropy", model_org, acts_org),
    ("DSN Entropy + Pruning", model, acts),
]:
    head = net.layers2[0]
    curve = head.weight[0, n].item() * net_acts[:, n] + bias * head.bias.item()
    ax.scatter(y_data, curve, label=curve_label)

ax.set_xlabel("$x_1$")
ax.set_ylabel("$g_1$")
ax.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Pruning yields very similar latent features with slightly clearer dependencies on the input features.

**Question 3**: What happens when we train more epochs?

In [61]:
# download data from wandb
# Fetches every file with extension `file_ext` from all runs of `wandb_project`
# unless a file with the same relative path already exists locally.
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # str.endswith replaces the manual slice comparison (which can never
        # match when file_ext is empty, because name[-0:] is the whole name).
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            run.file(f.name).download()

In [62]:
# plot losses
# `save_path` is now actually passed on (it used to be assigned but ignored in
# favour of a second hard-coded "models" literal).
save_names = ["F00_bn_norm_e1_1e-02_max"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [63]:
# Data directory and the variable names for inputs, latents and targets.
data_path = "data_1k"
in_var, lat_var, target_var = "X00", "G00", "F00"

# The masks are read from "<data_path>/<in_var><mask_ext>".
# TODO: create mask if file does not exist
mask_ext = ".mask"
mask_file = os.path.join(data_path, f"{in_var}{mask_ext}")
masks = joblib.load(mask_file)

# One SRData object per split, selected via the "train" / "val" mask entries.
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [64]:
# overview latent feature variance and alpha matrix
# For every model returned by the loss plot: load it, run the training inputs
# through it, print the node order, then print the raw alpha rows (optionally
# binarised / filtered by `alpha_eps`) and the alpha matrix after
# `model.layers1.norm`.
model_path = "models"
model_ext = ".pkl"

alpha_eps = None    # falsy value (None / 0) disables the raw-alpha printout
alpha_bin = False   # True: print a 0/1 mask instead of the raw alpha values

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)

    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    all_nodes = ut.get_node_order(acts, show=True)

    # Detach once and reuse for both printouts.
    raw_alpha = model.layers1.alpha.detach().cpu()

    if alpha_eps:
        alpha = raw_alpha.numpy()[all_nodes]

        if alpha_bin:
            # Build a fresh 0/1 array instead of writing in place: nothing can
            # leak back into the model's parameter storage, and entries exactly
            # equal to alpha_eps are binarised too (previously left unchanged).
            alpha = (np.abs(alpha) >= alpha_eps).astype(alpha.dtype)

        # Only show rows whose summed magnitude exceeds the threshold.
        print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])

    print(model.layers1.norm(raw_alpha).numpy()[all_nodes])

    print("")

srnet_model_F00_bn_norm_e1_1e-02_max
[0.28535384, 0.071490675, 0.0]
[0, 2, 1]
[[0.9474744  0.05252557]
 [0.02267693 0.97732306]
 [0.97731006 0.02268998]]



In [65]:
# Reload the long-training ("max") e1 = 1e-02 model for closer inspection.
model_name = "srnet_model_F00_bn_norm_e1_1e-02_max"
model_file = f"{model_name}{model_ext}"
model = ut.load_model(model_file, model_path, SRNet)

# Forward pass on the training inputs without gradient tracking.
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

# NOTE(review): presumably orders the latent nodes by activation variance
# (cf. the printed output) -- confirm in srnet_utils.
all_nodes = ut.get_node_order(acts, show=True)

[0.28535384, 0.071490675, 0.0]
[0, 2, 1]


In [66]:
nodes = all_nodes[:2]

In [67]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [68]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [69]:
corr_data = [
    ("x**2", x_data**2), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
    ("x", x_data),
    ("y", y_data),
]

In [70]:
ut.node_correlations(acts, nodes, corr_data);


Node 0
corr(n0, x**2): -0.8425
corr(n0, cos(y)): -0.0667
corr(n0, x*y): -0.6543
corr(n0, x): -0.2661
corr(n0, y): 0.0567

Node 2
corr(n2, x**2): 0.0138
corr(n2, cos(y)): 0.7685
corr(n2, x*y): -0.0103
corr(n2, x): 0.0122
corr(n2, y): 0.5615


In [71]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    #("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [72]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=[0], model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [73]:
# Plot cos(y) and the raw (unscaled) activations of latent node 2 over y.
fig, ax = plt.subplots()

cos_node = 2
for series in (np.cos(y_data), acts[:, cos_node]):
    ax.scatter(y_data, series)

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

When training for 100k epochs with $e_1$ being `1e-2`, the latent features seem to converge to $x^2 + x \cdot y$ and $\text{cos}(y)$. Why does one latent feature have zero variance? Weight decay? Its $\alpha$ values are not zero.

**Next steps**:

* Symbolic discriminator

* Resolve bottleneck

* Ensure long epoch convergence

* Increase input feature dimension

* Change $f(x)$

* Check AI Feynman: https://github.com/SJ001/AI-Feynman

* Techniques:

    * Consider two step training

    * Regularize mean instead of summed entropy (of high-variance features)
    
    * Entropy $\times$ variance
    
    * Pruning

**Work-in-progress**:

### Step 7.5: Train bottleneck DSN with $\hat{\alpha}$ and $e_2$

We want to penalize `F.softmax(lat_acts.var(dim=0)) * entropy`, not `lat_acts.var(dim=0) * entropy`.

However, this might be more effective for non-bottleneck DSNs.

Takeaway: the resulting dependencies are less clear.

In [None]:
# set wandb project (read below by wandb.sweep and by wandb.Api().runs)
wandb_project = "75-bn-DSN-norm-e2-study-F00"

In [None]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": ({"norm": "softmax"}, {}),
#         "lat_size": 3,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "gc": 0.0,
#     "shuffle": True,
# }

In [None]:
# sweep configuration passed to wandb.sweep: a grid search over e2
hp_study = dict(
    method="grid",  # alternatives: "random", "bayes"
    # metric={"name": "val_loss", "goal": "minimize"},
    parameters=dict(
        # l1={"values": [1e-4, 1e-3, 1e-2]},
        e2=dict(values=[0.0, 1e-5, 1e-3, 1e-1, 1e1]),
    ),
)

In [None]:
# create sweep: register the hp_study configuration under wandb_project and
# keep the value returned by wandb.sweep in sweep_id
sweep_id = wandb.sweep(hp_study, project=wandb_project)

In [None]:
# download data from wandb
file_ext = ".pkl"

api = wandb.Api()

# walk over every run of the project and fetch each file with the wanted
# extension that is not already present under the same relative path
runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            # download via the File object already in hand instead of
            # looking the same file up again with run.file(f.name)
            f.download()

In [None]:
# plot losses
save_names = ["F00_bn_norm_e2"]
save_path = "models"
# pass the variable rather than repeating the literal, so the directory is
# configured in exactly one place
models = ut.plot_losses(save_names, save_path=save_path)

In [None]:
# data directory and the names of the input, latent and target variables
data_path = "data_1k"

in_var = "X00"
lat_var = "G00"
target_var = "F00"

# the split masks are loaded from "<data_path>/<in_var><mask_ext>"
mask_ext = ".mask"
# TODO: create mask if file does not exist
masks = joblib.load(os.path.join(data_path, f"{in_var}{mask_ext}"))

# one SRData object per split, each built with the corresponding mask
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [None]:
# overview latent feature variance and alpha matrix
model_path = "models"
model_ext = ".pkl"

alpha_eps = 1e-2    # threshold used for binarizing alpha and for filtering its rows
alpha_bin = False   # if True, print alpha binarized against alpha_eps

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)

    # forward pass without gradient tracking; with get_lat=True the model
    # returns a second value, used below as the latent activations
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    all_nodes = ut.get_node_order(acts, show=True)

    if alpha_eps:
        # alpha rows arranged according to all_nodes
        alpha = model.layers1.alpha.detach().cpu().numpy()[all_nodes]

        if alpha_bin:
            # build a fresh 0/1 array instead of assigning in place: ">="
            # also binarizes entries lying exactly on the threshold, and
            # nothing is written into an array derived from the parameter
            alpha = (np.abs(alpha) >= alpha_eps).astype(alpha.dtype)

        # print only the rows whose summed absolute value exceeds alpha_eps
        print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])

    print(model.layers1.alpha_n.detach().cpu().numpy()[all_nodes])

    print("")

In [None]:
# load one specific saved model for a closer look
model_name = "srnet_model_F00_bn_norm_e2_1e+01"

model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)

# evaluate on the training inputs without tracking gradients
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

# node ordering computed from the activations (show=True is passed through)
all_nodes = ut.get_node_order(acts, show=True)

In [None]:
# keep the first three entries of all_nodes (the ordering returned by
# ut.get_node_order) for the plots and correlations below
nodes = all_nodes[:3]

In [None]:
# choose the plotting inputs: x_data / y_data are the first two input
# columns, z_data holds the (label, values) pairs handed to ut.plot_acts
x_data, y_data = (train_data.in_data[:, col] for col in (0, 1))
z_data = [
    ("target", train_data.target_data),
    # ("x**2", x_data**2),
    # ("cos(y)", np.cos(y_data)),
    # ("x*y", x_data * y_data),
]
# number of samples (rows) in the target data
plot_size = len(train_data.target_data)

In [None]:
# hand the selected inputs, z_data and the chosen nodes' activations to
# ut.plot_acts (one keyword per line for readability)
ut.plot_acts(
    x_data,
    y_data,
    z_data,
    acts=acts,
    nodes=nodes,
    model=model,
    bias=True,
    nonzero=False,
    agg=False,
    plot_size=plot_size,
)

In [None]:
# reference signals to correlate the latent nodes with, as a list of
# (label, values) pairs built from parallel label / value tuples
corr_data = list(zip(
    ("x**2", "cos(y)", "x*y", "x", "y"),
    (x_data**2, np.cos(y_data), x_data * y_data, x_data, y_data),
))

In [None]:
# correlate the selected nodes with every reference signal in corr_data
# NOTE(review): nonzero=True presumably restricts the computation to non-zero
# activations -- confirm against ut.node_correlations
ut.node_correlations(acts, nodes, corr_data, nonzero=True)

In [None]:
# overlay the activation of node 1 on x**2, both plotted against x;
# the series are labelled so the two point clouds can be told apart
fig, ax = plt.subplots()

ax.scatter(x_data, x_data**2, label="x**2")
ax.scatter(x_data, acts[:, 1], label="node 1")
ax.set_xlabel("x")
ax.legend()

plt.show()

In [None]:
# choose the plotting inputs: x_data / y_data are the first two input
# columns, z_data holds the (label, values) pairs handed to ut.plot_acts
x_data, y_data = (train_data.in_data[:, col] for col in (0, 1))
z_data = [
    # ("target", train_data.target_data),
    # ("x**2", x_data**2),
    # ("cos(y)", np.cos(y_data)),
    ("x*y", x_data * y_data),
]
# number of samples (rows) in the target data
plot_size = len(train_data.target_data)

In [None]:
# same plot as before, restricted to node 0 and the currently selected z_data
ut.plot_acts(
    x_data,
    y_data,
    z_data,
    acts=acts,
    nodes=[0],
    model=model,
    bias=True,
    nonzero=False,
    agg=False,
    plot_size=plot_size,
)

In [None]:
# overlay the activation of node 0 on cos(y), both plotted against y;
# the series are labelled so the two point clouds can be told apart
fig, ax = plt.subplots()

ax.scatter(y_data, np.cos(y_data), label="cos(y)")
ax.scatter(y_data, acts[:, 0], label="node 0")
ax.set_xlabel("y")
ax.legend()

plt.show()