# 2022 Flatiron Machine Learning x Science Summer School

## Step 7: Explore regularization for bottleneck DSN

What we have seen so far:

* For problem `F06` without degeneracies, the latent features can be discovered perfectly using the DSN (no GhostAdam required)

* For problem `F00` with degeneracies, the latent features cannot be discovered, even when:

    * Using the DSN with GhostAdam
    
    * Training for 100k epochs
    
    * Increasing $a_2$

* It seems that depending on as few input features as possible is not enforced strongly enough, or that there is a different global minimum, e.g. collapsing into a single latent feature (which happens when training for 100k epochs)

* In other words, it seems that enforcing sparsity (`few_latents`) is too similar to enforcing `few_dependencies`

Let's disentangle the problem by creating a bottleneck, i.e. setting the number of latent features to the correct size a priori.

Then, we can compare the following approaches:

* Training DSN

* Training DSN with $a_2 \ne 0$

* Training DSN with normalized $\alpha$ and $a_2 \ne 0$

* Training DSN with normalized $\alpha$ and minimizing entropy

* Training DSN with normalized $\alpha$ and minimizing entropy x variance

Let's go!

**TODO**: 

* Rename steps

* First `few_latents`, then `few_dependencies`? Truncate? Consider two-step training

* Train for longer

### Step 6.1: Train bottleneck DSN

In [1]:
# notebook setup: interactive (widget) matplotlib figures and automatic reloading of edited modules
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

# project-local modules: the SRNet model / SRData dataset classes and the loading, plotting and analysis helpers (ut)
from srnet import SRNet, SRData
import srnet_utils as ut

In [None]:
# W&B project that holds the runs of step 6.1 (bottleneck DSN)
wandb_project: str = "61-bn-DSN-F00"

In [None]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "lat_size": 3,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "gc": 0.0,
#     "shuffle": True,
# }

In [None]:
# download the model checkpoints (*.pkl) of every run in the project that are not yet present locally
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # only checkpoint files that have not been downloaded before
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            # download through the File handle we already hold; run.file(f.name)
            # would issue an extra API request just to look the same file up again
            f.download()

In [None]:
# plot the losses of the saved models matching save_names; the returned names are inspected below
save_names = ["srnet_model_F00_bn"]
save_path = "models"
# excl_names presumably drops matches containing "norm" (the normalized-alpha runs of later steps)
models = ut.plot_losses(save_names, save_path=save_path, excl_names=["norm"])

In [None]:
# load the dataset (inputs X00, latents G00, targets F00) with its train/val split masks
data_path = "data_1k"
in_var, lat_var, target_var = "X00", "G00", "F00"

mask_ext = ".mask"
# TODO: create mask if file does not exist
masks = joblib.load(os.path.join(data_path, f"{in_var}{mask_ext}"))

train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [None]:
# overview latent feature variance and alpha matrix of every model returned by plot_losses
model_path = "models"
model_ext = ".pkl"

alpha_eps = 1e-2   # magnitude threshold for alpha entries; a falsy value (e.g. 0) skips the alpha printout
alpha_bin = False  # if True, binarize alpha to 0/1 using alpha_eps

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)

    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    # show=True prints the (sorted) latent statistics and the resulting node order
    all_nodes = ut.get_node_order(acts, show=True)

    if alpha_eps:
        # rows reordered to match all_nodes; fancy indexing copies, so the model is untouched
        alpha = model.layers1.alpha.detach().cpu().numpy()[all_nodes]

        if alpha_bin:
            # both masks come from the same |alpha| so entries exactly equal to alpha_eps
            # are binarized too (the former `<` / `>` pair left them unchanged)
            abs_alpha = np.abs(alpha)
            alpha[abs_alpha < alpha_eps] = 0
            alpha[abs_alpha >= alpha_eps] = 1

        # only rows (latent features) whose total |alpha| exceeds the threshold
        print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])

    print("")

In [None]:
model_name = "srnet_model_F00_bn"

# load the saved model and compute its latent activations on the training inputs
model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)
with torch.no_grad():  # inference only, no autograd graph needed
    preds, acts = model(train_data.in_data, get_lat=True)

# order the latent nodes (show=True also prints the ordering)
all_nodes = ut.get_node_order(acts, show=True)

In [None]:
nodes = all_nodes[0:3]  # the first three nodes of the ordering are plotted below

In [None]:
# select plotting data: x and y are the first two input columns
x_data, y_data = train_data.in_data[:, 0], train_data.in_data[:, 1]
# surface(s) to draw; other candidates: ("x**2", x_data**2),
# ("cos(y)", np.cos(y_data)), ("x*y", x_data * y_data)
z_data = [("target", train_data.target_data)]
plot_size = train_data.target_data.shape[0]  # number of training samples

In [None]:
# plot the activations of the selected nodes (presumably over the x/y inputs) next to z_data
ut.plot_acts(
    x_data,
    y_data,
    z_data,
    acts=acts,
    nodes=nodes,
    model=model,
    bias=True,
    nonzero=False,
    agg=False,
    plot_size=plot_size,
)

In [None]:
# candidate closed-form expressions to correlate each latent node with
corr_data = list(zip(
    ["x**2", "cos(y)", "x*y", "x", "y"],
    [x_data**2, np.cos(y_data), x_data * y_data, x_data, y_data],
))

In [None]:
# print the correlation of each selected node with every candidate expression in corr_data
ut.node_correlations(acts, nodes, corr_data, nonzero=True)

Notes:

* Very interesting

* One $\alpha$ value goes to zero automatically

* Latent feature shapes are quite different from those seen previously

### Step 6.2: Train bottleneck DSN with $a_2$

In [None]:
# W&B project that holds the a2 study of step 6.2
wandb_project: str = "62-bn-DSN-a2-study-F00"

In [None]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "lat_size": 3,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "gc": 0.0,
#     "shuffle": True,
# }

In [None]:
# grid sweep over a2 (other W&B search methods: "random", "bayes");
# optional extras: "metric": {"name": "val_loss", "goal": "minimize"}
# and an "l1" grid such as {"values": [1e-4, 1e-3, 1e-2]}
hp_study = dict(
    method="grid",
    parameters=dict(
        a2=dict(values=[1e-3, 1e-2, 1e-1, 1e0]),
    ),
)

In [None]:
# register the sweep on W&B and keep its id (no agent is started in this section)
sweep_id = wandb.sweep(sweep=hp_study, project=wandb_project)

In [None]:
# download the model checkpoints (*.pkl) of every run in the project that are not yet present locally
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # only checkpoint files that have not been downloaded before
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            # download through the File handle we already hold; run.file(f.name)
            # would issue an extra API request just to look the same file up again
            f.download()

In [None]:
# plot the losses of the saved models matching save_names; the returned names are inspected below
save_names = ["F00_bn_a2"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path)

In [None]:
# load the dataset (inputs X00, latents G00, targets F00) with its train/val split masks
data_path = "data_1k"
in_var, lat_var, target_var = "X00", "G00", "F00"

mask_ext = ".mask"
# TODO: create mask if file does not exist
masks = joblib.load(os.path.join(data_path, f"{in_var}{mask_ext}"))

train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [None]:
# overview latent feature variance and alpha matrix of every model returned by plot_losses
model_path = "models"
model_ext = ".pkl"

alpha_eps = 1e-2   # magnitude threshold for alpha entries; a falsy value (e.g. 0) skips the alpha printout
alpha_bin = False  # if True, binarize alpha to 0/1 using alpha_eps

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)

    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    # show=True prints the (sorted) latent statistics and the resulting node order
    all_nodes = ut.get_node_order(acts, show=True)

    if alpha_eps:
        # rows reordered to match all_nodes; fancy indexing copies, so the model is untouched
        alpha = model.layers1.alpha.detach().cpu().numpy()[all_nodes]

        if alpha_bin:
            # both masks come from the same |alpha| so entries exactly equal to alpha_eps
            # are binarized too (the former `<` / `>` pair left them unchanged)
            abs_alpha = np.abs(alpha)
            alpha[abs_alpha < alpha_eps] = 0
            alpha[abs_alpha >= alpha_eps] = 1

        # only rows (latent features) whose total |alpha| exceeds the threshold
        print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])

    print("")

In [None]:
model_name = "srnet_model_F00_bn_a2_1e-01"

# load the saved model and compute its latent activations on the training inputs
model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)
with torch.no_grad():  # inference only, no autograd graph needed
    preds, acts = model(train_data.in_data, get_lat=True)

# order the latent nodes (show=True also prints the ordering)
all_nodes = ut.get_node_order(acts, show=True)

In [None]:
nodes = all_nodes[0:2]  # the first two nodes of the ordering are plotted below

In [None]:
# select plotting data: x and y are the first two input columns
x_data, y_data = train_data.in_data[:, 0], train_data.in_data[:, 1]
# surface(s) to draw; other candidates: ("x**2", x_data**2),
# ("cos(y)", np.cos(y_data)), ("x*y", x_data * y_data)
z_data = [("target", train_data.target_data)]
plot_size = train_data.target_data.shape[0]  # number of training samples

In [None]:
# plot the activations of the selected nodes (presumably over the x/y inputs) next to z_data
ut.plot_acts(
    x_data,
    y_data,
    z_data,
    acts=acts,
    nodes=nodes,
    model=model,
    bias=True,
    nonzero=False,
    agg=False,
    plot_size=plot_size,
)

In [None]:
# candidate closed-form expressions to correlate each latent node with
corr_data = list(zip(
    ["x**2", "cos(y)", "x*y", "x", "y"],
    [x_data**2, np.cos(y_data), x_data * y_data, x_data, y_data],
))

In [None]:
# print the correlation of each selected node with every candidate expression in corr_data
ut.node_correlations(acts, nodes, corr_data, nonzero=True)

### Step 6.3: Train bottleneck DSN with $\hat{\alpha}$ and $a_2$

In [None]:
# W&B project that holds the normalized-alpha a2 study of step 6.3
wandb_project: str = "63-bn-DSN-norm-a2-study-F00"

In [None]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": ({"norm": "softmax"}, {}),
#         "lat_size": 3,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "gc": 0.0,
#     "shuffle": True,
# }

In [None]:
# grid sweep over a2, including the a2 = 0 baseline (other W&B search methods: "random", "bayes");
# optional extras: "metric": {"name": "val_loss", "goal": "minimize"}
# and an "l1" grid such as {"values": [1e-4, 1e-3, 1e-2]}
hp_study = dict(
    method="grid",
    parameters=dict(
        a2=dict(values=[0.0, 1e-3, 1e-2, 1e-1, 1e0]),
    ),
)

In [None]:
# register the sweep on W&B and keep its id (no agent is started in this section)
sweep_id = wandb.sweep(sweep=hp_study, project=wandb_project)

In [None]:
# download the model checkpoints (*.pkl) of every run in the project that are not yet present locally
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # only checkpoint files that have not been downloaded before
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            # download through the File handle we already hold; run.file(f.name)
            # would issue an extra API request just to look the same file up again
            f.download()

In [None]:
# plot the losses of the saved models matching save_names; the returned names are inspected below
save_names = ["F00_bn_norm_a2"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path)

In [None]:
# load the dataset (inputs X00, latents G00, targets F00) with its train/val split masks
data_path = "data_1k"
in_var, lat_var, target_var = "X00", "G00", "F00"

mask_ext = ".mask"
# TODO: create mask if file does not exist
masks = joblib.load(os.path.join(data_path, f"{in_var}{mask_ext}"))

train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [None]:
# overview latent feature variance and alpha matrix of every model returned by plot_losses
model_path = "models"
model_ext = ".pkl"

alpha_eps = 1e-6   # magnitude threshold for alpha entries; a falsy value (e.g. 0) skips the alpha printout
alpha_bin = False  # if True, binarize alpha to 0/1 using alpha_eps

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)

    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    # show=True prints the (sorted) latent statistics and the resulting node order
    all_nodes = ut.get_node_order(acts, show=True)

    if alpha_eps:
        # rows reordered to match all_nodes; fancy indexing copies, so the model is untouched
        # NOTE(review): this is the raw alpha parameter; with norm="softmax" the weights the
        # layer actually applies may be a normalized version of it — confirm in srnet
        alpha = model.layers1.alpha.detach().cpu().numpy()[all_nodes]

        if alpha_bin:
            # both masks come from the same |alpha| so entries exactly equal to alpha_eps
            # are binarized too (the former `<` / `>` pair left them unchanged)
            abs_alpha = np.abs(alpha)
            alpha[abs_alpha < alpha_eps] = 0
            alpha[abs_alpha >= alpha_eps] = 1

        # only rows (latent features) whose total |alpha| exceeds the threshold
        print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])

    print("")

In [None]:
# inspect one model of the normalized-alpha a2 study in detail
# (a2 = 1e-1 mirrors the model inspected in step 6.2; the plain "srnet_model_F00_bn"
#  is the step 6.1 model and does not belong to this study)
model_name = "srnet_model_F00_bn_norm_a2_1e-01"
if model_name not in models:
    # catch copy-paste slips: the inspected model must be one of this section's models
    raise ValueError(f"{model_name} is not among this study's models: {models}")

model = ut.load_model(model_name + model_ext, model_path, SRNet)

with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

# order the latent nodes (show=True also prints the ordering)
all_nodes = ut.get_node_order(acts, show=True)

In [None]:
nodes = all_nodes[0:3]  # the first three nodes of the ordering are plotted below

In [None]:
# select plotting data: x and y are the first two input columns
x_data, y_data = train_data.in_data[:, 0], train_data.in_data[:, 1]
# surface(s) to draw; other candidates: ("x**2", x_data**2),
# ("cos(y)", np.cos(y_data)), ("x*y", x_data * y_data)
z_data = [("target", train_data.target_data)]
plot_size = train_data.target_data.shape[0]  # number of training samples

In [None]:
# plot the activations of the selected nodes (presumably over the x/y inputs) next to z_data
ut.plot_acts(
    x_data,
    y_data,
    z_data,
    acts=acts,
    nodes=nodes,
    model=model,
    bias=True,
    nonzero=False,
    agg=False,
    plot_size=plot_size,
)

In [None]:
# candidate closed-form expressions to correlate each latent node with
corr_data = list(zip(
    ["x**2", "cos(y)", "x*y", "x", "y"],
    [x_data**2, np.cos(y_data), x_data * y_data, x_data, y_data],
))

In [None]:
# print the correlation of each selected node with every candidate expression in corr_data
ut.node_correlations(acts, nodes, corr_data, nonzero=True)

### Step 6.4: Train bottleneck DSN with $\hat{\alpha}$ and $e_1$

In [9]:
# W&B project that holds the normalized-alpha e1 (entropy) study of step 6.4
wandb_project: str = "64-bn-DSN-norm-e1-study-F00"

In [10]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": ({"norm": "softmax"}, {}),
#         "lat_size": 3,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "gc": 0.0,
#     "shuffle": True,
# }

In [11]:
# grid sweep over e1, including the e1 = 0 baseline (other W&B search methods: "random", "bayes");
# optional extras: "metric": {"name": "val_loss", "goal": "minimize"}
# and an "l1" grid such as {"values": [1e-4, 1e-3, 1e-2]}
hp_study = dict(
    method="grid",
    parameters=dict(
        e1=dict(values=[0.0, 1e-5, 1e-3, 1e-1, 1e1]),
    ),
)

In [None]:
# register the sweep on W&B and keep its id (no agent is started in this section)
sweep_id = wandb.sweep(sweep=hp_study, project=wandb_project)

In [12]:
# download the model checkpoints (*.pkl) of every run in the project that are not yet present locally
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # only checkpoint files that have not been downloaded before
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            # download through the File handle we already hold; run.file(f.name)
            # would issue an extra API request just to look the same file up again
            f.download()

Downloading srnet_model_F00_bn_norm_e1_1e-02.pkl.


In [13]:
# plot the losses of the saved models matching save_names; the returned names are inspected below
save_names = ["F00_bn_norm_e1"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [14]:
# load the dataset (inputs X00, latents G00, targets F00) with its train/val split masks
data_path = "data_1k"
in_var, lat_var, target_var = "X00", "G00", "F00"

mask_ext = ".mask"
# TODO: create mask if file does not exist
masks = joblib.load(os.path.join(data_path, f"{in_var}{mask_ext}"))

train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [15]:
# overview latent feature variance and alpha matrix of every model returned by plot_losses
model_path = "models"
model_ext = ".pkl"

alpha_eps = 1e-2   # magnitude threshold for alpha entries; a falsy value (e.g. 0) skips the alpha printout
alpha_bin = False  # if True, binarize alpha to 0/1 using alpha_eps

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)

    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    # show=True prints the (sorted) latent statistics and the resulting node order
    all_nodes = ut.get_node_order(acts, show=True)

    if alpha_eps:
        # rows reordered to match all_nodes; fancy indexing copies, so the model is untouched
        # NOTE(review): this is the raw alpha parameter; with norm="softmax" the weights the
        # layer actually applies may be a normalized version of it — confirm in srnet
        alpha = model.layers1.alpha.detach().cpu().numpy()[all_nodes]

        if alpha_bin:
            # both masks come from the same |alpha| so entries exactly equal to alpha_eps
            # are binarized too (the former `<` / `>` pair left them unchanged)
            abs_alpha = np.abs(alpha)
            alpha[abs_alpha < alpha_eps] = 0
            alpha[abs_alpha >= alpha_eps] = 1

        # only rows (latent features) whose total |alpha| exceeds the threshold
        print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])

    print("")

srnet_model_F00_bn_norm_e1_0e+00
[0.7364508, 0.48084727, 0.23385297]
[1, 0, 2]
[[-1.5839932   0.51744026]
 [ 0.97438246 -0.5205308 ]
 [-0.8817944  -0.97543824]]

srnet_model_F00_bn_norm_e1_1e+01
[2.45539, 0.38972935, 0.14927715]
[1, 0, 2]
[[-7.5406394e+00  1.6684064e-05]
 [ 5.8884821e+00 -1.8821185e-04]
 [-2.0058741e-04 -6.3895998e+00]]

srnet_model_F00_bn_norm_e1_1e-01
[0.7481753, 0.7479166, 0.62162703]
[2, 1, 0]
[[ 1.4666026e-05 -4.7776794e+00]
 [-5.0577326e+00  8.3070981e-06]
 [ 3.8075233e+00 -5.9154321e-04]]

srnet_model_F00_bn_norm_e1_1e-02
[1.1345829, 0.62815154, 0.2887137]
[1, 0, 2]
[[-3.7636557e+00  3.5644669e-05]
 [ 2.6125922e+00  6.6711596e-05]
 [-9.8887729e-05 -3.5168018e+00]]

srnet_model_F00_bn_norm_e1_1e-03
[0.66432446, 0.51293874, 0.23554066]
[1, 0, 2]
[[-2.2562172e+00  9.6815296e-05]
 [ 1.2901431e+00 -1.5158442e-01]
 [-6.3947332e-01 -1.2535298e+00]]

srnet_model_F00_bn_norm_e1_1e-05
[0.7657546, 0.46918085, 0.23475884]
[1, 0, 2]
[[-1.5981584   0.5131855 ]
 [ 0.97823876 -

In [16]:
model_name = "srnet_model_F00_bn_norm_e1_1e-02"

# load the saved model and compute its latent activations on the training inputs
model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)
with torch.no_grad():  # inference only, no autograd graph needed
    preds, acts = model(train_data.in_data, get_lat=True)

# order the latent nodes (show=True also prints the ordering)
all_nodes = ut.get_node_order(acts, show=True)

[1.1345829, 0.62815154, 0.2887137]
[1, 0, 2]


In [17]:
nodes = all_nodes[0:3]  # the first three nodes of the ordering are plotted below

In [18]:
# select plotting data: x and y are the first two input columns
x_data, y_data = train_data.in_data[:, 0], train_data.in_data[:, 1]
# surface(s) to draw; other candidates: ("x**2", x_data**2),
# ("cos(y)", np.cos(y_data)), ("x*y", x_data * y_data)
z_data = [("target", train_data.target_data)]
plot_size = train_data.target_data.shape[0]  # number of training samples

In [19]:
# plot the activations of the selected nodes (presumably over the x/y inputs) next to z_data
ut.plot_acts(
    x_data,
    y_data,
    z_data,
    acts=acts,
    nodes=nodes,
    model=model,
    bias=True,
    nonzero=False,
    agg=False,
    plot_size=plot_size,
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [20]:
# candidate closed-form expressions to correlate each latent node with
corr_data = list(zip(
    ["x**2", "cos(y)", "x*y", "x", "y"],
    [x_data**2, np.cos(y_data), x_data * y_data, x_data, y_data],
))

In [21]:
# print the correlation of each selected node with every candidate expression in corr_data
ut.node_correlations(acts, nodes, corr_data, nonzero=True)


Node 1
corr(n1, x**2): -0.6755/-0.6755
corr(n1, cos(y)): 0.0114/0.0114
corr(n1, x*y): -0.1712/-0.1712
corr(n1, x): -0.8817/-0.8817
corr(n1, y): -0.0393/-0.0393

Node 0
corr(n0, x**2): -0.6556/-0.6556
corr(n0, cos(y)): -0.0993/-0.0993
corr(n0, x*y): -0.6574/-0.6574
corr(n0, x): 0.1075/0.1075
corr(n0, y): 0.3171/0.3171

Node 2
corr(n2, x**2): 0.0314/0.0314
corr(n2, cos(y)): 0.3564/0.3564
corr(n2, x*y): 0.0860/0.0860
corr(n2, x): -0.0281/-0.0281
corr(n2, y): 0.8895/0.8895


In [22]:
# compare latent node 2 with cos(y) as a function of y (labels make the overlay readable)
fig, ax = plt.subplots()

ax.scatter(y_data, np.cos(y_data), label="cos(y)")
ax.scatter(y_data, acts[:, 2], label="node 2")
ax.set_xlabel("y")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [23]:
# compare latent node 1 with x**2 as a function of x (labels make the overlay readable)
fig, ax = plt.subplots()

ax.scatter(x_data, x_data**2, label="x**2")
ax.scatter(x_data, acts[:, 1], label="node 1")
ax.set_xlabel("x")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [24]:
# compare latent node 0 with x**2 as a function of x (labels make the overlay readable)
fig, ax = plt.subplots()

ax.scatter(x_data, x_data**2, label="x**2")
ax.scatter(x_data, acts[:, 0], label="node 0")
ax.set_xlabel("x")
ax.legend()

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [25]:
# select plotting data: x and y are the first two input columns
x_data, y_data = train_data.in_data[:, 0], train_data.in_data[:, 1]
# surface(s) to draw; other candidates: ("target", train_data.target_data),
# ("x**2", x_data**2), ("cos(y)", np.cos(y_data))
z_data = [("x*y", x_data * y_data)]
plot_size = train_data.target_data.shape[0]  # number of training samples

In [26]:
# plot node 0 (presumably over the x/y inputs) next to the selected z_data surface
ut.plot_acts(
    x_data,
    y_data,
    z_data,
    acts=acts,
    nodes=[0],
    model=model,
    bias=True,
    nonzero=False,
    agg=False,
    plot_size=plot_size,
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [27]:
# select plotting data: x and y are the first two input columns
x_data, y_data = train_data.in_data[:, 0], train_data.in_data[:, 1]
# surface(s) to draw; other candidates: ("target", train_data.target_data),
# ("cos(y)", np.cos(y_data)), ("x*y", x_data * y_data)
z_data = [("x**2", x_data**2)]
plot_size = train_data.target_data.shape[0]  # number of training samples

In [28]:
# plot node 1 (presumably over the x/y inputs) next to the selected z_data surface
ut.plot_acts(
    x_data,
    y_data,
    z_data,
    acts=acts,
    nodes=[1],
    model=model,
    bias=True,
    nonzero=False,
    agg=False,
    plot_size=plot_size,
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [29]:
# select plotting data: x and y are the first two input columns
x_data, y_data = train_data.in_data[:, 0], train_data.in_data[:, 1]
# surface(s) to draw; other candidates: ("target", train_data.target_data),
# ("x**2", x_data**2), ("x*y", x_data * y_data)
z_data = [("cos(y)", np.cos(y_data))]
plot_size = train_data.target_data.shape[0]  # number of training samples

In [30]:
# plot node 2 (presumably over the x/y inputs) next to the selected z_data surface
ut.plot_acts(
    x_data,
    y_data,
    z_data,
    acts=acts,
    nodes=[2],
    model=model,
    bias=True,
    nonzero=False,
    agg=False,
    plot_size=plot_size,
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### Step 6.5: Train bottleneck DSN with $\hat{\alpha}$ and $e_2$

We want to penalize `F.softmax(lat_acts.var(dim=0)) * entropy`, not `lat_acts.var(dim=0) * entropy`.

However, this might be more effective for non-bottleneck DSNs.

In [31]:
# W&B project that holds the normalized-alpha e2 study of step 6.5
wandb_project: str = "65-bn-DSN-norm-e2-study-F00"

In [32]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": ({"norm": "softmax"}, {}),
#         "lat_size": 3,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "gc": 0.0,
#     "shuffle": True,
# }

In [33]:
# grid sweep over e2, including the e2 = 0 baseline (other W&B search methods: "random", "bayes");
# optional extras: "metric": {"name": "val_loss", "goal": "minimize"}
# and an "l1" grid such as {"values": [1e-4, 1e-3, 1e-2]}
hp_study = dict(
    method="grid",
    parameters=dict(
        e2=dict(values=[0.0, 1e-5, 1e-3, 1e-1, 1e1]),
    ),
)

In [None]:
# register the sweep on W&B and keep its id (no agent is started in this section)
sweep_id = wandb.sweep(sweep=hp_study, project=wandb_project)

In [34]:
# download the model checkpoints (*.pkl) of every run in the project that are not yet present locally
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # only checkpoint files that have not been downloaded before
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            # download through the File handle we already hold; run.file(f.name)
            # would issue an extra API request just to look the same file up again
            f.download()

Downloading srnet_model_F00_bn_norm_e2_1e-03.pkl.


In [35]:
# plot the losses of the saved models matching save_names; the returned names are inspected below
save_names = ["F00_bn_norm_e2"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [7]:
# load the dataset (inputs X00, latents G00, targets F00) with its train/val split masks
data_path = "data_1k"
in_var, lat_var, target_var = "X00", "G00", "F00"

mask_ext = ".mask"
# TODO: create mask if file does not exist
masks = joblib.load(os.path.join(data_path, f"{in_var}{mask_ext}"))

train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [8]:
# overview latent feature variance and alpha matrix of every model returned by plot_losses
model_path = "models"
model_ext = ".pkl"

alpha_eps = 1e-2   # magnitude threshold for alpha entries; a falsy value (e.g. 0) skips the alpha printout
alpha_bin = False  # if True, binarize alpha to 0/1 using alpha_eps

for model_name in models:
    print(model_name)
    model = ut.load_model(model_name + model_ext, model_path, SRNet)

    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    # show=True prints the (sorted) latent statistics and the resulting node order
    all_nodes = ut.get_node_order(acts, show=True)

    if alpha_eps:
        # rows reordered to match all_nodes; fancy indexing copies, so the model is untouched
        # NOTE(review): this is the raw alpha parameter; with norm="softmax" the weights the
        # layer actually applies may be a normalized version of it — confirm in srnet
        alpha = model.layers1.alpha.detach().cpu().numpy()[all_nodes]

        if alpha_bin:
            # both masks come from the same |alpha| so entries exactly equal to alpha_eps
            # are binarized too (the former `<` / `>` pair left them unchanged)
            abs_alpha = np.abs(alpha)
            alpha[abs_alpha < alpha_eps] = 0
            alpha[abs_alpha >= alpha_eps] = 1

        # only rows (latent features) whose total |alpha| exceeds the threshold
        print(alpha[np.abs(alpha).sum(axis=1) > alpha_eps])

    print("")

srnet_model_F00_bn_norm_e2_0e+00
[0.7364508, 0.48084727, 0.23385297]
[1, 0, 2]
[[-1.5839932   0.51744026]
 [ 0.97438246 -0.5205308 ]
 [-0.8817944  -0.97543824]]

srnet_model_F00_bn_norm_e2_1e+01
[10.2357, 1.0083678, 0.6425728]
[1, 0, 2]
[[-7.5407500e+00 -2.0447320e-05]
 [ 1.9879901e+00  1.7829683e-01]
 [-8.0089658e-01 -1.3859421e+00]]

srnet_model_F00_bn_norm_e2_1e-01
[6.978867, 0.48938346, 0.15883532]
[1, 0, 2]
[[-5.0491509e+00  1.8623332e-05]
 [ 1.0000064e+00 -3.9814898e-01]
 [-8.0738872e-01 -9.3687499e-01]]

srnet_model_F00_bn_norm_e2_1e-03
[0.7005332, 0.5055884, 0.21763101]
[1, 0, 2]
[[-1.7058394   0.29456338]
 [ 0.999833   -0.5090959 ]
 [-0.8744926  -1.021463  ]]

srnet_model_F00_bn_norm_e2_1e-05
[0.7393068, 0.49176195, 0.22757514]
[1, 0, 2]
[[-1.585797    0.51597667]
 [ 0.97586644 -0.52029014]
 [-0.87763184 -0.9796664 ]]



In [None]:
model_name = "srnet_model_F00_bn_norm_e2_1e+01"

# load the saved model and compute its latent activations on the training inputs
model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)
with torch.no_grad():  # inference only, no autograd graph needed
    preds, acts = model(train_data.in_data, get_lat=True)

# order the latent nodes (show=True also prints the ordering)
all_nodes = ut.get_node_order(acts, show=True)

In [None]:
nodes = all_nodes[0:3]  # the first three nodes of the ordering are plotted below

In [None]:
# select plotting data: x and y are the first two input columns
x_data, y_data = train_data.in_data[:, 0], train_data.in_data[:, 1]
# surface(s) to draw; other candidates: ("x**2", x_data**2),
# ("cos(y)", np.cos(y_data)), ("x*y", x_data * y_data)
z_data = [("target", train_data.target_data)]
plot_size = train_data.target_data.shape[0]  # number of training samples

In [None]:
# plot the activations of the selected nodes (presumably over the x/y inputs) next to z_data
ut.plot_acts(
    x_data,
    y_data,
    z_data,
    acts=acts,
    nodes=nodes,
    model=model,
    bias=True,
    nonzero=False,
    agg=False,
    plot_size=plot_size,
)

In [None]:
# candidate closed-form expressions to correlate each latent node with
corr_data = list(zip(
    ["x**2", "cos(y)", "x*y", "x", "y"],
    [x_data**2, np.cos(y_data), x_data * y_data, x_data, y_data],
))

In [None]:
# print the correlation of each selected node with every candidate expression in corr_data
ut.node_correlations(acts, nodes, corr_data, nonzero=True)

In [None]:
# compare latent node 0 with x**2 as a function of x (labels make the overlay readable)
fig, ax = plt.subplots()

ax.scatter(x_data, x_data**2, label="x**2")
ax.scatter(x_data, acts[:, 0], label="node 0")
ax.set_xlabel("x")
ax.legend()

plt.show()

In [None]:
# compare latent node 2 with cos(y) as a function of y (labels make the overlay readable)
fig, ax = plt.subplots()

ax.scatter(y_data, np.cos(y_data), label="cos(y)")
ax.scatter(y_data, acts[:, 2], label="node 2")
ax.set_xlabel("y")
ax.legend()

plt.show()

In [None]:
# select plotting data: x and y are the first two input columns
x_data, y_data = train_data.in_data[:, 0], train_data.in_data[:, 1]
# surface(s) to draw; other candidates: ("target", train_data.target_data),
# ("x**2", x_data**2), ("cos(y)", np.cos(y_data))
z_data = [("x*y", x_data * y_data)]
plot_size = train_data.target_data.shape[0]  # number of training samples

In [None]:
# plot node 2 (presumably over the x/y inputs) next to the selected z_data surface
ut.plot_acts(
    x_data,
    y_data,
    z_data,
    acts=acts,
    nodes=[2],
    model=model,
    bias=True,
    nonzero=False,
    agg=False,
    plot_size=plot_size,
)