# 2022 Flatiron Machine Learning x Science Summer School

## Step 4: Train MLP with $L_1$ regularization on latent features

### Step 4.1: Check $L_1$ regularization parameter

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
import srnet_utils as ut

In [2]:
# set wandb project
# W&B project name; the cells below pass it to wandb.sweep() and wandb.Api().runs()
wandb_project = "41-l1-study-F00"

In [3]:
# baseline hyperparameters, kept for reference only: every line in this cell is commented out, so nothing is defined here
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": "MLP",
#         "lat_size": 16,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "shuffle": True,
# }

In [4]:
# define hyperparameter study: a grid sweep over the L1 regularization strength
# (W&B also offers the "random" and "bayes" search methods; an optional "metric"
#  block, e.g. minimizing "val_loss", is intentionally left out here)
hp_study = dict(
    method="grid",
    parameters=dict(
        l1=dict(values=[1e-6, 1e-4, 1e-2, 1e-1]),
    ),
)

In [None]:
# create sweep
# registers the study defined in `hp_study` under `wandb_project` and keeps the returned ID
# NOTE(review): re-running this cell presumably creates an additional sweep — confirm before re-executing
sweep_id = wandb.sweep(hp_study, project=wandb_project)

In [None]:
# download data from wandb
# Fetch every file ending in `file_ext` that the runs of `wandb_project` logged
# and that does not exist locally yet, so repeated executions skip existing files.
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # `str.endswith` states the suffix test directly (and, unlike slicing with
        # `-len(file_ext)`, stays correct for an empty extension)
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            # `f` already is the run's file object, so no second lookup via run.file() is needed
            f.download()

In [6]:
# plot losses
save_names = ["F00_l1", "F00_conv1k"]
save_path = "models"
# pass the variable instead of repeating the "models" literal, so the directory is
# defined in exactly one place (the loss-printing cell below reads `save_path` too);
# the trailing ";" suppresses the cell's output repr
ut.plot_losses(save_names, save_path=save_path, excl_names=["gc"]);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
# print losses
# For every file in `save_path` whose name contains one of `save_names`, print the
# last "_"-separated tag of the file stem, then the stored train and validation loss.
states = {}  # NOTE(review): never filled or read in the visible cells
for save_name in save_names:
    for file_name in sorted(os.listdir(save_path)):
        if save_name in file_name:
            state = joblib.load(os.path.join(save_path, file_name))
            # strip only the extension: cutting at the first "." would reduce a tag
            # such as "1.5e-03" to "1"
            label = os.path.splitext(file_name)[0].split('_')[-1]
            print(f"{label}:\t{state['total_train_loss']:.3e} {state['total_val_loss']:.3e}")

Notes:

* An $L_1$ regularization parameter of `1e-3` appears to minimize the validation error

In [None]:
# load data
data_path = "data_1k"

# names handed to SRData as input, latent and target variable
in_var, lat_var, target_var = "X00", "G00", "F00"

# the mask file lives in the data folder and is named after the input variable
mask_ext = ".mask"
mask_file = os.path.join(data_path, f"{in_var}{mask_ext}")
masks = joblib.load(mask_file)     # TODO: create mask if file does not exist

# one dataset per split, each built with the matching entry of `masks`
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [None]:
# latent feature variance overview
# For each saved model: print its name, run it on the training inputs and let
# ut.get_node_order display the result (show=True), followed by an empty line.
model_path = "models"
model_ext = ".pkl"

models = [
    f"srnet_model_F00_{tag}"
    for tag in ("conv1k", "l1_1e-06", "l1_1e-04", "l1_1e-03", "l1_1e-02", "l1_1e-01")
]

for model_name in models:
    print(model_name)
    model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)

    # forward pass only, so gradient tracking is switched off
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    all_nodes = ut.get_node_order(acts, show=True)
    print()

Notes:

* An $L_1$ regularization parameter of `1e-3` yields three high-variance latent features

In [None]:
# take a closer look at the model saved under the `l1_1e-03` tag
# (`model_ext` and `model_path` come from the overview cell above)
model_name = "srnet_model_F00_l1_1e-03"
model_file = f"{model_name}{model_ext}"

model = ut.load_model(model_file, model_path, SRNet)

# forward pass on the training inputs without gradient tracking
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

all_nodes = ut.get_node_order(acts, show=True)

In [None]:
# keep the first three entries of the node order (the notes above report three high variance latent features)
nodes = all_nodes[:3]

In [None]:
# select plotting data
# columns 0 and 1 of the training inputs are handed to ut.plot_acts as x and y data
x_data, y_data = train_data.in_data[:,0], train_data.in_data[:,1]

# (label, values) pairs to plot; further candidates, currently disabled:
#   ("x**2", x_data**2), ("cos(y)", np.cos(y_data)), ("x*y", x_data * y_data)
z_data = [("target", train_data.target_data)]

# size of the first dimension of the target data
plot_size = train_data.target_data.shape[0]

In [None]:
# presumably plots the entries of `z_data` and the activations of the selected `nodes` over x_data / y_data — confirm in srnet_utils
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, agg=False, plot_size=plot_size)

In [None]:
# candidate functions of the two inputs, as (label, values) pairs, that are
# passed to ut.node_correlations together with the selected nodes
corr_data = [
    ("x**2", x_data**2), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [None]:
# presumably reports, per selected node, its correlation with every entry of `corr_data` — confirm in srnet_utils
ut.node_correlations(acts, nodes, corr_data, nonzero=True)

Notes:

* The high-variance latent features do not split into the desired latent functions

### Step 4.2: Train MLP with $L_1$ regularization for `F06`

In [7]:
# set wandb project
# W&B project name for the F06 study; this overwrites the value set for F00 earlier in the notebook
wandb_project = "42-l1-study-F06"

In [3]:
# baseline hyperparameters, kept for reference only: every line in this cell is commented out, so nothing is defined here
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": "MLP",
#         "lat_size": 16,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "gc": 0.0,
#     "shuffle": True,
# }

In [8]:
# define hyperparameter study: a grid sweep over the L1 regularization strength
# (W&B also offers the "random" and "bayes" search methods; an optional "metric"
#  block, e.g. minimizing "val_loss", is intentionally left out here)
hp_study = dict(
    method="grid",
    parameters=dict(
        l1=dict(values=[1e-4, 1e-3, 1e-2]),
    ),
)

In [9]:
# create sweep
# registers the study defined in `hp_study` under `wandb_project` and keeps the returned ID
# NOTE(review): re-running this cell presumably creates an additional sweep — confirm before re-executing
sweep_id = wandb.sweep(hp_study, project=wandb_project)

Create sweep with ID: cuusilho
Sweep URL: https://wandb.ai/fabxy/42-l1-study-F06/sweeps/cuusilho


In [34]:
# download data from wandb
# Fetch every file ending in `file_ext` that the runs of `wandb_project` logged
# and that does not exist locally yet, so repeated executions skip existing files.
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # `str.endswith` states the suffix test directly (and, unlike slicing with
        # `-len(file_ext)`, stays correct for an empty extension)
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            # `f` already is the run's file object, so no second lookup via run.file() is needed
            f.download()

Downloading srnet_model_F06_l1_3e-03.pkl.


In [35]:
# plot losses
save_names = ["F06_l1", "F06_conv1k"]
save_path = "models"
# pass the variable instead of repeating the "models" literal, so the directory is
# defined in exactly one place; the returned value is kept in `models` and iterated
# over as model names by the following cells
models = ut.plot_losses(save_names, save_path=save_path, excl_names=["gc"])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [36]:
# print losses
# For every model name in `models`, load "<name><model_ext>" from `save_path` and print
# the tag after the last "_" of the name, then the stored train and validation loss.
model_ext = ".pkl"
states = {}  # NOTE(review): never filled or read in the visible cells
for model_name in models:
    state = joblib.load(os.path.join(save_path, model_name + model_ext))
    # the names carry no extension (it is appended above), so the tag is simply the
    # part after the last "_"; cutting at the first "." would reduce a tag such as
    # "1.5e-03" to "1"
    label = model_name.rsplit('_', 1)[-1]
    print(f"{label}:\t{state['total_train_loss']:.3e} {state['total_val_loss']:.3e}")

1e-02:	1.121e-04 9.756e-02
1e-03:	6.536e-05 8.146e-02
1e-04:	9.697e-05 6.264e-02
3e-03:	9.247e-05 8.242e-02
5e-03:	1.455e-04 8.647e-02
conv1k:	4.420e-05 4.926e-02


In [37]:
# load data
data_path = "data_1k"

# names handed to SRData as input, latent and target variable
in_var, lat_var, target_var = "X06", "G06", "F06"

# the mask file lives in the data folder and is named after the input variable
mask_ext = ".mask"
mask_file = os.path.join(data_path, f"{in_var}{mask_ext}")
masks = joblib.load(mask_file)

# one dataset per split, each built with the matching entry of `masks`
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [38]:
# latent feature variance overview
# For each name in `models`: print it, run the loaded model on the training inputs
# and let ut.get_node_order display the result (show=True), followed by an empty line.
model_path = "models"

for model_name in models:
    print(model_name)
    model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)

    # forward pass only, so gradient tracking is switched off
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    all_nodes = ut.get_node_order(acts, show=True)
    print()

srnet_model_F06_l1_1e-02
[0.049321637, 0.004771799, 0.00019846704, 6.6069247e-06, 1.4752286e-06, 4.5226514e-08, 2.13454e-08, 1.7446855e-08, 1.18685435e-08, 5.5786127e-09, 3.9206283e-09, 2.1274518e-09, 8.199103e-10, 6.736814e-10, 2.3913765e-10, 1.6228391e-10]
[9, 8, 11, 10, 15, 5, 7, 13, 3, 2, 6, 0, 4, 14, 12, 1]

srnet_model_F06_l1_1e-03
[0.039115362, 0.03129036, 0.023613036, 0.011601761, 0.011366514, 0.009309147, 0.002526981, 6.6599045e-05, 5.3251642e-05, 1.1133867e-09, 6.4825845e-10, 6.21421e-10, 5.41498e-10, 4.3636036e-10, 3.4151335e-10, 1.5669119e-10]
[9, 10, 11, 13, 8, 15, 3, 5, 7, 2, 6, 4, 0, 1, 12, 14]

srnet_model_F06_l1_1e-04
[0.035668433, 0.035033483, 0.034245808, 0.033510864, 0.028321486, 0.02800836, 0.025181893, 0.023803087, 0.021075113, 0.019416295, 0.010689795, 0.01015003, 0.009687535, 0.00847647, 0.0072003896, 0.006011835]
[9, 11, 10, 13, 7, 8, 3, 5, 15, 6, 1, 4, 2, 12, 0, 14]

srnet_model_F06_l1_3e-03
[0.043674815, 0.015250864, 0.015104065, 0.007413745, 5.3274063e-05, 2

Notes:

* Interestingly, while `1e-03` is clearly too little regularization, `5e-03` already seems to be too much

* `3e-03` seems about right; however, this sensitivity to the $L_1$ parameter is surprising

In [39]:
# load model
# NOTE(review): unlike the earlier cells (which append `model_ext`), `model_name` here already includes the ".pkl" extension
model_name = "srnet_model_F06_l1_3e-03.pkl"
model_path = "models"

model = ut.load_model(model_name, model_path, SRNet)

In [40]:
# get predictions
# forward pass on the training inputs without gradient tracking; with get_lat=True
# the model call returns a pair, unpacked here as (preds, acts)
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

In [41]:
# get latent feature variance
# with show=True two lists are printed (see the output below): presumably the sorted
# variances and the matching node indices — confirm in srnet_utils
all_nodes = ut.get_node_order(acts, show=True)

[0.043674815, 0.015250864, 0.015104065, 0.007413745, 5.3274063e-05, 2.1200472e-06, 3.2012056e-07, 4.970898e-08, 4.043706e-09, 6.789294e-10, 6.3081884e-10, 5.966184e-10, 3.210844e-10, 2.8482752e-10, 1.0929242e-10, 7.421286e-11]
[9, 11, 10, 8, 13, 15, 3, 7, 5, 2, 14, 4, 12, 6, 0, 1]


In [42]:
# keep the first four entries of the node order for the correlation analysis below
nodes = all_nodes[:4]

In [43]:
# select data
# input columns 0, 3, 5 and 7 are the ones entering the candidate functions below
inputs = train_data.in_data
x0_data, x3_data, x5_data, x7_data = (inputs[:, col] for col in (0, 3, 5, 7))

# (label, values) pairs: the three candidate latent functions first ...
corr_data = [
    ("x0**2", x0_data**2),
    ("cos(x3)", np.cos(x3_data)),
    ("x5*x7", x5_data * x7_data),
]
# ... followed by the four raw input columns themselves
corr_data.extend(zip(("x0", "x3", "x5", "x7"), (x0_data, x3_data, x5_data, x7_data)))

In [44]:
# prints, per selected node, one correlation line for every entry of `corr_data` (see the output below)
ut.node_correlations(acts, nodes, corr_data, nonzero=True)


Node 9
corr(n9, x0**2): -0.7860/-0.7860
corr(n9, cos(x3)): -0.2444/-0.2444
corr(n9, x5*x7): -0.5961/-0.5961
corr(n9, x0): -0.0272/-0.0272
corr(n9, x3): -0.0642/-0.0642
corr(n9, x5): -0.0232/-0.0232
corr(n9, x7): 0.0006/0.0006

Node 11
corr(n11, x0**2): -0.7805/-0.7805
corr(n11, cos(x3)): -0.2323/-0.2323
corr(n11, x5*x7): -0.6077/-0.6077
corr(n11, x0): -0.0679/-0.0679
corr(n11, x3): -0.0405/-0.0405
corr(n11, x5): -0.0369/-0.0369
corr(n11, x7): -0.0509/-0.0509

Node 10
corr(n10, x0**2): -0.7831/-0.7831
corr(n10, cos(x3)): -0.2328/-0.2328
corr(n10, x5*x7): -0.6039/-0.6039
corr(n10, x0): -0.0373/-0.0373
corr(n10, x3): -0.0484/-0.0484
corr(n10, x5): -0.0465/-0.0465
corr(n10, x7): -0.0213/-0.0213

Node 8
corr(n8, x0**2): 0.7716/0.7716
corr(n8, cos(x3)): 0.2486/0.2486
corr(n8, x5*x7): 0.6097/0.6097
corr(n8, x0): 0.0504/0.0504
corr(n8, x3): 0.0286/0.0286
corr(n8, x5): -0.0029/-0.0029
corr(n8, x7): 0.0443/0.0443


Notes:

* Despite now being separable, the high-variance latent features do not split into the desired latent functions