# 2022 Flatiron Machine Learning x Science Summer School

## Step 5: Train DSN with $L_1$ regularization on latent features

### Step 5.1: Check $a_1$ and $a_2$ parameters

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
import srnet_utils as ut

In [2]:
# set wandb project
# set wandb project (holds the runs of the F00 a1/a2 sweep; also used below to download their results)
wandb_project = "51-a1-a2-study-F00"

In [3]:
# define hyperparameters
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "lat_size": 16,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "shuffle": True,
# }

In [4]:
# Define the hyperparameter study: an exhaustive grid over the two DSN
# regularization weights a1 and a2, each taking the same four candidate values.
# (Other wandb sweep methods are "random" and "bayes"; a "metric" entry such as
#  {"name": "val_loss", "goal": "minimize"} would be needed for "bayes".)
reg_grid = [0.0, 1e-5, 1e-3, 1e-1]
hp_study = {
    "method": "grid",
    # list(...) gives each parameter its own copy of the value list
    "parameters": {name: {"values": list(reg_grid)} for name in ("a1", "a2")},
}

In [None]:
# create sweep
# NOTE(review): every execution of this cell registers a *new* sweep in `wandb_project`;
# the runs themselves are presumably launched separately (e.g. via a wandb agent) — confirm.
sweep_id = wandb.sweep(hp_study, project=wandb_project)

In [None]:
# download data from wandb
# Fetch every `file_ext` file attached to any run of `wandb_project`, skipping
# files that already exist locally, so re-running this cell is cheap.
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # str.endswith is the idiomatic suffix test; the manual slice
        # `f.name[-len(file_ext):]` breaks for an empty extension
        if not f.name.endswith(file_ext) or os.path.isfile(f.name):
            continue
        print(f"Downloading {os.path.basename(f.name)}.")
        run.file(f.name).download()

In [5]:
# plot losses
save_names = ["F00_a1"]
excl_names = ["gc"]
save_path = "models"
models = ut.plot_losses(save_names, save_path="models", excl_names=excl_names)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [6]:
# get states
states = {}

model_ext = ".pkl"
for model_name in models:
    states[model_name + model_ext] = joblib.load(os.path.join(save_path, model_name + model_ext))

In [7]:
# plot losses
save_temp = "srnet_model_F00_a1_{a1:.0e}_a2_{a2:.0e}.pkl"

train_losses = []
val_losses = []
for a2 in hp_study['parameters']['a2']['values']:
    train_loss = []
    val_loss = []
    for a1 in hp_study['parameters']['a1']['values']:
        save_name = save_temp.format(a1=a1, a2=a2)
        train_loss.append(states[save_name]['total_train_loss'])
        val_loss.append(states[save_name]['total_val_loss'])
    train_losses.append(train_loss)
    val_losses.append(val_loss)

In [8]:
# Heatmap of the final training losses over the (a2, a1) grid.
fig, ax = plt.subplots()

a1s = [f"{a1:.0e}" for a1 in hp_study['parameters']['a1']['values']]
a2s = [f"{a2:.0e}" for a2 in hp_study['parameters']['a2']['values']]

# rows of the loss grid correspond to a2 values, columns to a1 values
psm = ax.pcolor(train_losses)

fig.colorbar(psm, ax=ax)
ax.set_xlabel("Parameter a1")
ax.set_ylabel("Parameter a2")

# Without X/Y, pcolor draws cell (i, j) on [j, j+1] x [i, i+1]; place one tick at
# each cell centre explicitly instead of relying on every second auto-generated
# tick landing there (which breaks whenever the locator picks other positions).
ax.set_xticks(np.arange(len(a1s)) + 0.5)
ax.set_xticklabels(a1s)

ax.set_yticks(np.arange(len(a2s)) + 0.5)
ax.set_yticklabels(a2s)

ax.set_title("Training losses")

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
# Heatmap of the final validation losses over the (a2, a1) grid.
fig, ax = plt.subplots()

a1s = [f"{a1:.0e}" for a1 in hp_study['parameters']['a1']['values']]
a2s = [f"{a2:.0e}" for a2 in hp_study['parameters']['a2']['values']]

# rows of the loss grid correspond to a2 values, columns to a1 values
psm = ax.pcolor(val_losses)

fig.colorbar(psm, ax=ax)
ax.set_xlabel("Parameter a1")
ax.set_ylabel("Parameter a2")

# Without X/Y, pcolor draws cell (i, j) on [j, j+1] x [i, i+1]; place one tick at
# each cell centre explicitly instead of relying on every second auto-generated
# tick landing there (which breaks whenever the locator picks other positions).
ax.set_xticks(np.arange(len(a1s)) + 0.5)
ax.set_xticklabels(a1s)

ax.set_yticks(np.arange(len(a2s)) + 0.5)
ax.set_yticklabels(a2s)

ax.set_title("Validation losses")

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Notes:

* Values of `1e-01` for the DSN parameters $a_1$ and $a_2$ lead to significant jumps in the validation loss

In [10]:
# load data: F00 dataset (inputs X00, latent G00, targets F00) split by a stored mask
data_path = "data_1k"

in_var, lat_var, target_var = "X00", "G00", "F00"

mask_ext = ".mask"
mask_file = os.path.join(data_path, in_var + mask_ext)
masks = joblib.load(mask_file)     # TODO: create mask if file does not exist

# one SRData per split, built in the order train, val
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [11]:
# overview latent feature variance and alpha matrix
model_path = "models"
model_ext = ".pkl"

models = [
    "srnet_model_F00_a1_0e+00_a2_1e-05",
    "srnet_model_F00_a1_0e+00_a2_1e-03",
    "srnet_model_F00_a1_1e-05_a2_1e-03",
    "srnet_model_F00_a1_1e-03_a2_1e-03",
    "srnet_model_F00_a1_1e-01_a2_1e-03",
]

alpha_eps = 1e-4

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)
    
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)
        
    ut.get_node_order(acts, show=True)
    
    alpha = model.layers1.alpha.detach().cpu().numpy()
    print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])
    
    print("")

srnet_model_F00_a1_0e+00_a2_1e-05
[0.6663477, 0.22754952, 0.13303775, 0.08057226, 0.07956417, 0.008945408, 3.1225023e-17, 1.3877788e-17, 1.3877788e-17, 1.3877788e-17, 1.3877788e-17, 0.0, 0.0, 0.0, 0.0, 0.0]
[0, 7, 15, 9, 1, 13, 2, 4, 8, 10, 14, 3, 5, 6, 11, 12]
[[-1.1827385  -1.1604899 ]
 [-0.5680361  -0.7426323 ]
 [ 1.1502272  -0.39936063]
 [ 0.7791239   0.77259094]
 [-0.5597268   0.48970914]
 [ 1.0314969  -0.8139722 ]]

srnet_model_F00_a1_0e+00_a2_1e-03
[0.6579603, 0.16474427, 0.0218615, 0.017693432, 0.004262521, 1.0098834e-06, 3.1225023e-17, 1.3877788e-17, 3.469447e-18, 8.6736174e-19, 2.1684043e-19, 0.0, 0.0, 0.0, 0.0, 0.0]
[0, 1, 15, 9, 13, 7, 12, 8, 4, 2, 14, 3, 5, 6, 10, 11]
[[-6.9626713e-01 -5.4629576e-01]
 [-4.4165763e-01 -3.9943978e-01]
 [ 1.0977784e-01  3.2188484e-06]
 [ 4.6725744e-01  5.8694594e-02]
 [-4.0847880e-01  1.4354388e-01]
 [ 5.2012444e-01 -1.4175722e-01]]

srnet_model_F00_a1_1e-05_a2_1e-03
[0.7007249, 0.16599274, 0.022339601, 0.010651008, 0.0034628494, 1.5755894e-0

In [12]:
# overview latent feature variance and alpha matrix
model_path = "models"
model_ext = ".pkl"

models = [    
    "srnet_model_F00_a1_1e-05_a2_0e+00",
    "srnet_model_F00_a1_1e-03_a2_0e+00",
    "srnet_model_F00_a1_1e-05_a2_1e-05",
    "srnet_model_F00_a1_1e-05_a2_1e-03",
    "srnet_model_F00_a1_1e-05_a2_1e-01",
]

alpha_eps = 1e-4

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)
    
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)
        
    ut.get_node_order(acts, show=True)
    
    alpha = model.layers1.alpha.detach().cpu().numpy()
    print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])
    
    print("")

srnet_model_F00_a1_1e-05_a2_0e+00
[0.6776226, 0.22553541, 0.13323744, 0.07859352, 0.06688577, 0.010254435, 1.3877788e-17, 1.3877788e-17, 1.3877788e-17, 1.3877788e-17, 3.469447e-18, 3.469447e-18, 0.0, 0.0, 0.0, 0.0]
[0, 7, 15, 1, 9, 13, 4, 5, 8, 10, 2, 14, 3, 6, 11, 12]
[[-1.1732184  -1.1535953 ]
 [-0.5607364  -0.71192485]
 [ 1.135986   -0.36854565]
 [ 0.7656966   0.7304388 ]
 [-0.5511235   0.47454327]
 [ 1.0226839  -0.8039612 ]]

srnet_model_F00_a1_1e-03_a2_0e+00
[0.6201504, 0.1632466, 0.07770574, 5.3266916e-16, 5.551115e-17, 5.551115e-17, 5.551115e-17, 1.3877788e-17, 1.3877788e-17, 8.6736174e-19, 2.1684043e-19, 0.0, 0.0, 0.0, 0.0, 0.0]
[0, 15, 1, 9, 8, 10, 13, 4, 11, 5, 14, 2, 3, 6, 7, 12]
[[-0.52056277 -0.42505166]
 [-0.289202   -0.15479507]
 [ 0.44180116 -0.09080692]]

srnet_model_F00_a1_1e-05_a2_1e-05
[0.6577685, 0.19030869, 0.12078818, 0.080249384, 0.07855751, 0.008998795, 5.551115e-17, 1.3877788e-17, 1.3877788e-17, 1.3877788e-17, 3.469447e-18, 0.0, 0.0, 0.0, 0.0, 0.0]
[0, 7, 15, 

Notes:

* Interestingly, sparsity is already enforced for $a_2 > 0$ while $a_1 = 0$ (although it is enforced more strongly by $a_1$):

    * $a_1 = 1\text{e-}5$, $a_2 = 0$: `0.6776226, 0.22553541, 0.13323744, 0.07859352, 0.06688577, 0.010254435`
    * $a_1 = 0$, $a_2 = 1\text{e-}5$: `0.6663477, 0.22754952, 0.13303775, 0.08057226, 0.07956417, 0.008945408`
    * $a_1 = 1\text{e-}3$, $a_2 = 0$: `0.6201504, 0.1632466, 0.07770574, 5.3266916e-16, 5.551115e-17, 5.551115e-17`
    * $a_1 = 0$, $a_2 = 1\text{e-}3$: `0.6579603, 0.16474427, 0.0218615, 0.017693432, 0.004262521, 1.0098834e-06`

* Generally, $a_1$ and $a_2$ seem to have a surprisingly similar effect

* However, increasing $a_2$ does not seem to reduce the number of input features each latent feature depends on

Was the original loss function definition (sparsity over columns of `alpha`) better than the new one (sparsity over rows of `alpha`) after all?

In [13]:
# overview latent feature variance and alpha matrix
model_path = "models"
model_ext = ".pkl"

models = [
    "srnet_model_F00_a1_1e-05_a2_1e-03",
    "extra/srnet_model_F00_a1_1e-05_a2_1e-03_dim0",
    # "extra/srnet_model_F00_a1_1e-03_a2_1e-03_newreg2",
]

alpha_eps = 1e-4

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)
    
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)
        
    ut.get_node_order(acts, show=True)
    
    alpha = model.layers1.alpha.detach().cpu().numpy()
    print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])
    
    print("")

srnet_model_F00_a1_1e-05_a2_1e-03
[0.7007249, 0.16599274, 0.022339601, 0.010651008, 0.0034628494, 1.5755894e-08, 3.1225023e-17, 1.3877788e-17, 8.6736174e-19, 8.6736174e-19, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0, 1, 9, 15, 13, 7, 12, 8, 2, 14, 3, 4, 5, 6, 10, 11]
[[-6.9769609e-01 -5.4090250e-01]
 [-4.3820301e-01 -3.8749045e-01]
 [ 5.5093076e-02  4.8218321e-06]
 [ 4.6923199e-01  8.0085538e-02]
 [-4.0097216e-01  1.4224520e-01]
 [ 4.9253431e-01 -4.2777900e-02]]

extra/srnet_model_F00_a1_1e-05_a2_1e-03_dim0
[0.8296081, 0.101969145, 0.016981106, 0.003976508, 5.551115e-17, 5.551115e-17, 1.3877788e-17, 1.3877788e-17, 1.3877788e-17, 3.469447e-18, 5.421011e-20, 5.421011e-20, 2.1175824e-22, 0.0, 0.0, 0.0]
[0, 1, 15, 9, 7, 13, 8, 10, 12, 4, 2, 5, 14, 3, 6, 11]
[[-6.1068958e-01 -5.8874309e-01]
 [-3.2437617e-01 -2.6313716e-01]
 [ 2.8077862e-01 -5.6885201e-06]
 [ 3.3674622e-01 -3.9087702e-03]]



Not really. There is more sparsity, but the two highest-variance latent features still each depend on both input features.

In [14]:
# inspect a single F00 model (a1 = 1e-3, a2 = 1e-3); model_ext / model_path come from the cells above
model_name = "srnet_model_F00_a1_1e-03_a2_1e-03"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

# inference only: no gradients needed to collect the latent activations
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)
    
# node order of this model (show=True also prints it, see output)
all_nodes = ut.get_node_order(acts, show=True)

[0.54778475, 0.121214814, 0.06913294, 5.551115e-17, 1.3877788e-17, 1.3877788e-17, 8.6736174e-19, 8.6736174e-19, 8.6736174e-19, 1.9058241e-21, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[0, 1, 15, 9, 7, 10, 2, 4, 11, 5, 3, 6, 8, 12, 13, 14]


In [15]:
# keep the first three nodes of the order above (the only ones with non-negligible printed values)
nodes = all_nodes[:3]

In [16]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [17]:
# plot the selected nodes' activations over (x_data, y_data) together with the z_data surfaces
# (see srnet_utils.plot_acts for the meaning of bias / nonzero / agg)
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
# candidate terms to correlate the latent nodes with (label -> values, in display order)
candidate_terms = {
    "x**2": x_data**2,
    "cos(y)": np.cos(y_data),
    "x*y": x_data * y_data,
    "x": x_data,
    "y": y_data,
}
corr_data = list(candidate_terms.items())

In [21]:
# print the correlation of each selected node with every candidate term; ';' suppresses the returned value
ut.node_correlations(acts, nodes, corr_data);


Node 0
corr(n0, x**2): -0.6975
corr(n0, cos(y)): -0.2926
corr(n0, x*y): -0.7199
corr(n0, x): -0.2426
corr(n0, y): -0.1217

Node 1
corr(n1, x**2): -0.8321
corr(n1, cos(y)): 0.0010
corr(n1, x*y): -0.4605
corr(n1, x): 0.0079
corr(n1, y): 0.1556

Node 15
corr(n15, x**2): 0.7999
corr(n15, cos(y)): -0.0074
corr(n15, x*y): 0.1791
corr(n15, x): 0.7783
corr(n15, y): 0.0390


In [22]:
# inspect the "dim0" variant of the a1 = 1e-5, a2 = 1e-3 F00 model
# NOTE(review): presumably trained with the regularization taken over the other alpha dimension — confirm
model_name = "extra/srnet_model_F00_a1_1e-05_a2_1e-03_dim0"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

# inference only: no gradients needed to collect the latent activations
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)
    
# node order of this model (show=True also prints it, see output)
all_nodes = ut.get_node_order(acts, show=True)

[0.8296081, 0.101969145, 0.016981106, 0.003976508, 5.551115e-17, 5.551115e-17, 1.3877788e-17, 1.3877788e-17, 1.3877788e-17, 3.469447e-18, 5.421011e-20, 5.421011e-20, 2.1175824e-22, 0.0, 0.0, 0.0]
[0, 1, 15, 9, 7, 13, 8, 10, 12, 4, 2, 5, 14, 3, 6, 11]


In [23]:
# keep the first four nodes of the order above (the only ones with non-negligible printed values)
nodes = all_nodes[:4]

In [24]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [25]:
# plot the selected nodes' activations over (x_data, y_data) together with the z_data surfaces
# (see srnet_utils.plot_acts for the meaning of bias / nonzero / agg)
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [26]:
# candidate terms to correlate the latent nodes with (label -> values, in display order)
candidate_terms = {
    "x**2": x_data**2,
    "cos(y)": np.cos(y_data),
    "x*y": x_data * y_data,
    "x": x_data,
    "y": y_data,
}
corr_data = list(candidate_terms.items())

In [28]:
# print the correlation of each selected node with every candidate term; ';' suppresses the returned value
ut.node_correlations(acts, nodes, corr_data);


Node 0
corr(n0, x**2): -0.7488
corr(n0, cos(y)): -0.2480
corr(n0, x*y): -0.6872
corr(n0, x): -0.3455
corr(n0, y): -0.0726

Node 1
corr(n1, x**2): -0.7921
corr(n1, cos(y)): -0.0524
corr(n1, x*y): -0.5029
corr(n1, x): 0.0490
corr(n1, y): 0.0655

Node 15
corr(n15, x**2): 0.8090
corr(n15, cos(y)): 0.0037
corr(n15, x*y): 0.1731
corr(n15, x): 0.7224
corr(n15, y): 0.0314

Node 9
corr(n9, x**2): -0.7803
corr(n9, cos(y)): 0.0008
corr(n9, x*y): -0.1040
corr(n9, x): 0.2995
corr(n9, y): -0.0266


### Step 5.2: Check $a_1$ and $a_2$ parameters for `F06`

In [29]:
# set wandb project
# set wandb project (holds the runs of the F06 a1/a2 sweep; also used below to download their results)
wandb_project = "52-a1-a2-study-F06"

In [30]:
# define hyperparameters
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "lat_size": 16,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "gc": 0.0,
#     "shuffle": True,
# }

In [31]:
# define hyperparameter study
hp_study = {
    "method": "grid", # random, bayesian
    #"metric": {
    #    "name": "val_loss",
    #    "goal": "minimize",
    #},
    "parameters": {
        "a1": {
            "values": [0.0, 1e-5, 1e-3, 1e-1]
        },
        "a2": {
            "values": [0.0, 1e-5, 1e-3, 1e-1]
        }
    }
}

In [None]:
# create sweep
# NOTE(review): every execution of this cell registers a *new* sweep in `wandb_project`;
# the runs themselves are presumably launched separately (e.g. via a wandb agent) — confirm.
sweep_id = wandb.sweep(hp_study, project=wandb_project)

In [None]:
# download data from wandb
# Fetch every `file_ext` file attached to any run of `wandb_project`, skipping
# files that already exist locally, so re-running this cell is cheap.
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # str.endswith is the idiomatic suffix test; the manual slice
        # `f.name[-len(file_ext):]` breaks for an empty extension
        if not f.name.endswith(file_ext) or os.path.isfile(f.name):
            continue
        print(f"Downloading {os.path.basename(f.name)}.")
        run.file(f.name).download()

In [32]:
# plot losses
save_names = ["F06_a1"]
excl_names = ["gc"]
save_path = "models"
models = ut.plot_losses(save_names, save_path="models", excl_names=excl_names)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [33]:
# get states: load the saved state of every model returned by plot_losses,
# keyed by its file name (joblib/pickle — only load trusted files)
model_ext = ".pkl"
states = {
    file_name: joblib.load(os.path.join(save_path, file_name))
    for file_name in (name + model_ext for name in models)
}

In [34]:
# plot losses
save_temp = "srnet_model_F06_a1_{a1:.0e}_a2_{a2:.0e}.pkl"

train_losses = []
val_losses = []
for a2 in hp_study['parameters']['a2']['values']:
    train_loss = []
    val_loss = []
    for a1 in hp_study['parameters']['a1']['values']:
        save_name = save_temp.format(a1=a1, a2=a2)
        try:
            train_loss.append(states[save_name]['total_train_loss'])
            val_loss.append(states[save_name]['total_val_loss'])
        except:
            train_loss.append(-1)
            val_loss.append(-1)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
# correct missing losses
train_losses = np.array(train_losses)
val_losses = np.array(val_losses)
train_losses[train_losses < 0] = np.max(train_losses)
val_losses[val_losses < 0] = np.max(val_losses)

In [35]:
# Heatmap of the final training losses over the (a2, a1) grid.
fig, ax = plt.subplots()

a1s = [f"{a1:.0e}" for a1 in hp_study['parameters']['a1']['values']]
a2s = [f"{a2:.0e}" for a2 in hp_study['parameters']['a2']['values']]

# rows of the loss grid correspond to a2 values, columns to a1 values
psm = ax.pcolor(train_losses)

fig.colorbar(psm, ax=ax)
ax.set_xlabel("Parameter a1")
ax.set_ylabel("Parameter a2")

# Without X/Y, pcolor draws cell (i, j) on [j, j+1] x [i, i+1]; place one tick at
# each cell centre explicitly instead of relying on every second auto-generated
# tick landing there (which breaks whenever the locator picks other positions).
ax.set_xticks(np.arange(len(a1s)) + 0.5)
ax.set_xticklabels(a1s)

ax.set_yticks(np.arange(len(a2s)) + 0.5)
ax.set_yticklabels(a2s)

ax.set_title("Training losses")

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [36]:
# Heatmap of the final validation losses over the (a2, a1) grid.
fig, ax = plt.subplots()

a1s = [f"{a1:.0e}" for a1 in hp_study['parameters']['a1']['values']]
a2s = [f"{a2:.0e}" for a2 in hp_study['parameters']['a2']['values']]

# rows of the loss grid correspond to a2 values, columns to a1 values
psm = ax.pcolor(val_losses)

fig.colorbar(psm, ax=ax)
ax.set_xlabel("Parameter a1")
ax.set_ylabel("Parameter a2")

# Without X/Y, pcolor draws cell (i, j) on [j, j+1] x [i, i+1]; place one tick at
# each cell centre explicitly instead of relying on every second auto-generated
# tick landing there (which breaks whenever the locator picks other positions).
ax.set_xticks(np.arange(len(a1s)) + 0.5)
ax.set_xticklabels(a1s)

ax.set_yticks(np.arange(len(a2s)) + 0.5)
ax.set_yticklabels(a2s)

ax.set_title("Validation losses")

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Notes:

* In contrast to `F00`, the largest value of $a_2$ (`1e-01`) leads to the lowest validation loss

In [37]:
# load data
data_path = "data_1k"

in_var = "X06"
lat_var = "G06"
target_var = "F06"

mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))     # TODO: create mask if file does not exist

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"])
val_data = SRData(data_path, in_var, lat_var, target_var, masks["val"])

In [38]:
# overview latent feature variance and alpha matrix
model_path = "models"
model_ext = ".pkl"

models = [
    "srnet_model_F06_a1_0e+00_a2_1e-01",
    "srnet_model_F06_a1_1e-05_a2_1e-01",
    "srnet_model_F06_a1_1e-03_a2_1e-01",
    "srnet_model_F06_a1_1e-01_a2_1e-01",
]

alpha_eps = 1e-2
alpha_bin = True

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)
    
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)
        
    ut.get_node_order(acts, show=True)
    
    if alpha_eps:
        alpha = model.layers1.alpha.detach().cpu().numpy()[all_nodes]
        
        if alpha_bin:
            alpha[np.abs(alpha) < alpha_eps] = 0
            alpha[np.abs(alpha) > alpha_eps] = 1
        
        print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])
    
    print("")

srnet_model_F06_a1_0e+00_a2_1e-01
[0.2793835, 0.21531022, 0.07293549, 0.03322749, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[6, 13, 7, 3, 0, 1, 2, 4, 5, 8, 9, 10, 11, 12, 14, 15]
[[0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 1.]
 [1. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0.]]

srnet_model_F06_a1_1e-05_a2_1e-01
[0.231994, 0.1972766, 0.07445352, 0.05141343, 1.3877788e-17, 1.1754944e-38, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[6, 13, 7, 3, 5, 2, 0, 1, 4, 8, 9, 10, 11, 12, 14, 15]
[[0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 1.]
 [1. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0.]]

srnet_model_F06_a1_1e-03_a2_1e-01
[0.20505397, 0.20377514, 0.074096635, 0.071012214, 5.551115e-17, 1.3877788e-17, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[6, 13, 3, 7, 9, 5, 0, 1, 2, 4, 8, 10, 11, 12, 14, 15]
[[0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 1.]
 [1. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0.]]

srnet_model_F06_a1_1e-01_a2_1e-01
[0.46

In [39]:
# overview latent feature variance and alpha matrix
model_path = "models"
model_ext = ".pkl"

models = [
    "srnet_model_F06_a1_1e-03_a2_0e+00",
    "srnet_model_F06_a1_1e-03_a2_1e-05",
    "srnet_model_F06_a1_1e-03_a2_1e-03",
    "srnet_model_F06_a1_1e-03_a2_1e-01",
]

alpha_eps = 1e-2
alpha_bin = True

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)
    
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)
        
    ut.get_node_order(acts, show=True)
    
    if alpha_eps:
        alpha = model.layers1.alpha.detach().cpu().numpy()[all_nodes]
        
        if alpha_bin:
            alpha[np.abs(alpha) < alpha_eps] = 0
            alpha[np.abs(alpha) > alpha_eps] = 1
        
        print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])
    
    print("")

srnet_model_F06_a1_1e-03_a2_0e+00
[0.21198046, 0.1942896, 0.17368312, 0.10920197, 0.06532583, 0.04275698, 0.030567221, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[2, 3, 7, 13, 6, 9, 5, 0, 1, 4, 8, 10, 11, 12, 14, 15]
[[1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0.]
 [1. 0. 0. 1. 0. 1. 0. 1.]
 [1. 0. 0. 0. 0. 1. 0. 1.]
 [0. 0. 0. 0. 0. 1. 0. 1.]
 [1. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0.]]

srnet_model_F06_a1_1e-03_a2_1e-05
[0.27475622, 0.2353662, 0.17284352, 0.105235115, 0.063140385, 0.03077288, 0.001204764, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[3, 2, 7, 13, 6, 5, 9, 0, 1, 4, 8, 10, 11, 12, 14, 15]
[[1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0.]
 [1. 0. 0. 1. 0. 1. 0. 1.]
 [1. 0. 0. 0. 0. 1. 0. 1.]
 [0. 0. 0. 0. 0. 1. 0. 1.]
 [1. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0.]]

srnet_model_F06_a1_1e-03_a2_1e-03
[0.3584506, 0.16501725, 0.1549626, 0.13812137, 0.04587842, 0.002812615, 2.537086e-06, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[3, 7, 13, 2,

Notes:

* For $a_2 =$ `1e-01` and $a_1 <$ `1e-01`, the $x_0^2$ term is split across two latent features

* For $a_1$ and $a_2$ being `1e-01` the correct split is achieved!

* For $a_2$ < `1e-01`, the latent features are not split correctly

In [40]:
# inspect the F06 model with a1 = 1e-1 and a2 = 1e-1 (the setting noted above as giving the correct split)
model_name = "srnet_model_F06_a1_1e-01_a2_1e-01"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

# inference only: no gradients needed to collect the latent activations
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)
    
# node order of this model (show=True also prints it, see output)
all_nodes = ut.get_node_order(acts, show=True)

[0.46574995, 0.13628185, 0.052221406, 2.2394784e-16, 1.3877788e-17, 8.6736174e-19, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[9, 2, 7, 5, 13, 12, 0, 1, 3, 4, 6, 8, 10, 11, 14, 15]


In [41]:
# keep the first three nodes of the order above (the only ones with non-negligible printed values)
nodes = all_nodes[:3]

In [42]:
# select data
x0_data = train_data.in_data[:,0]
x3_data = train_data.in_data[:,3]
x5_data = train_data.in_data[:,5]
x7_data = train_data.in_data[:,7]

corr_data = [
    ("x0**2", x0_data**2), 
    ("cos(x3)", np.cos(x3_data)), 
    ("x5*x7", x5_data * x7_data),
    # ("x0", x0_data),
    # ("x3", x3_data),
    # ("x5", x5_data),
    # ("x7", x7_data),
]

In [44]:
# print the correlation of each selected node with every candidate term; without a trailing ';'
# (unlike the F00 cells) the returned correlation matrix is displayed as well
ut.node_correlations(acts, nodes, corr_data)


Node 9
corr(n9, x0**2): 0.9999
corr(n9, cos(x3)): -0.0003
corr(n9, x5*x7): 0.0462

Node 2
corr(n2, x0**2): -0.0468
corr(n2, cos(x3)): 0.0082
corr(n2, x5*x7): -0.9998

Node 7
corr(n7, x0**2): -0.0001
corr(n7, cos(x3)): -0.9997
corr(n7, x5*x7): 0.0091


[[0.9999155972804934, -0.000313703317180114, 0.04621954791806068],
 [-0.04683840318915195, 0.008199553788688748, -0.999789552463771],
 [-0.0001496457743340833, -0.9997031211410643, 0.00907727638484032]]

Great result! 

Apparently, using GhostAdam is not required for this problem.