# 2022 Flatiron Machine Learning x Science Summer School

## Step 4: Train an MLP with $L_1$ regularization on latent features

### Step 4.1: Check $L_1$ regularization parameter

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
import srnet_utils as ut

In [2]:
# W&B project that hosts the L1 sweep and its runs
wandb_project: str = "41-l1-study-F00"

In [3]:
# define hyperparameters
# NOTE: reference only — this cell is entirely commented out and not executed.
# It refers to `train_data`, which is only defined in the data-loading cell further below,
# so it cannot be uncommented as-is at this position.
# Presumably these are the baseline settings the sweep runs start from — confirm against the training script.
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": "MLP",
#         "lat_size": 16,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,    # overridden per run by the "l1" sweep parameter
#     "shuffle": True,
# }

In [4]:
# define hyperparameter study (W&B sweep configuration)
hp_study = {
    # exhaustive search over all listed values; W&B also supports "random" and
    # "bayes" (the latter requires the "metric" entry below to be enabled)
    "method": "grid",
    # "metric": {
    #     "name": "val_loss",
    #     "goal": "minimize",
    # },
    "parameters": {
        # L1 regularization strengths; 1e-3 is part of the grid because the later
        # cells load and analyse the srnet_model_F00_l1_1e-03 model
        "l1": {
            "values": [1e-6, 1e-4, 1e-3, 1e-2, 1e-1],
        },
    },
}

In [15]:
# register the sweep defined by `hp_study` in the W&B project; `sweep_id` identifies it
sweep_id = wandb.sweep(
    hp_study,
    project=wandb_project,
)

Create sweep with ID: klh56zo4
Sweep URL: https://wandb.ai/fabxy/41-l1-study-F00/sweeps/klh56zo4


In [58]:
# download all *.pkl files of the project's runs that are not yet present locally
# (files are written relative to the current directory under their W&B file name)
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # skip run files with another extension and files that were already downloaded
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            # `f` already is the File handle; no need to look it up again via run.file()
            f.download()

Downloading srnet_model_F00_l1_1e-03.pkl.


In [59]:
# plot losses of all saved models whose file names contain one of `save_names`
save_names = ["F00_l1", "F00_conv1k"]
save_path = "models"
# pass the variable (not a duplicated literal) so changing `save_path` affects the plot too
ut.plot_losses(save_names, save_path=save_path)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [60]:
# print final total train/val losses of every saved state whose file name contains
# one of `save_names`; loaded states are kept in `states` (keyed by file name)
states = {}
for save_name in save_names:
    for file_name in sorted(os.listdir(save_path)):
        if save_name not in file_name:
            continue
        state = joblib.load(os.path.join(save_path, file_name))
        states[file_name] = state
        # e.g. "srnet_model_F00_l1_1e-03.pkl" -> "1e-03"; splitext only strips the
        # extension, so labels containing a dot are not truncated
        label = os.path.splitext(file_name)[0].split("_")[-1]
        print(f"{label}:\t{state['total_train_loss']:.3e} {state['total_val_loss']:.3e}")

1e-01:	8.017e-04 1.166e-03
1e-02:	1.823e-04 1.749e-03
1e-03:	7.821e-05 4.486e-04
1e-04:	1.783e-04 1.232e-03
1e-06:	1.147e-04 3.317e-03
conv1k:	1.155e-04 3.428e-03


Notes:

* An $L_1$ regularization parameter of `1e-3` yields the lowest validation loss among the tested values (`4.486e-04`)

In [40]:
# load the train/validation splits of the 1k dataset
data_path = "data_1k"

in_var = "X00"      # input variable file
lat_var = "G00"     # latent variable file
target_var = "F00"  # target variable file

mask_ext = ".mask"
mask_file = os.path.join(data_path, f"{in_var}{mask_ext}")
masks = joblib.load(mask_file)  # TODO: create mask if file does not exist

data_vars = (data_path, in_var, lat_var, target_var)
train_data = SRData(*data_vars, masks["train"])
val_data = SRData(*data_vars, masks["val"])

In [61]:
# latent feature variance overview: unregularized baseline followed by each L1 strength
model_path = "models"
model_ext = ".pkl"

models = [
    "srnet_model_F00_conv1k",
    *(f"srnet_model_F00_l1_{l1}" for l1 in ("1e-06", "1e-04", "1e-03", "1e-02", "1e-01")),
]

for model_name in models:
    print(model_name)
    model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)

    # inference only: no autograd graph needed
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    all_nodes = ut.get_node_order(acts, show=True)
    print()

srnet_model_F00_conv1k
[0.10614656, 0.09515473, 0.08990478, 0.08980098, 0.07525699, 0.072339885, 0.06941165, 0.06478065, 0.06317273, 0.06313105, 0.059894454, 0.056416783, 0.0487215, 0.04010966, 0.039172273, 0.031783156]
[1, 10, 15, 0, 14, 5, 12, 8, 7, 6, 9, 3, 13, 4, 2, 11]

srnet_model_F00_l1_1e-06
[0.10398815, 0.093214676, 0.08905195, 0.08769247, 0.073902845, 0.07085094, 0.06706689, 0.06349006, 0.062067565, 0.061193217, 0.05973298, 0.05828404, 0.047263596, 0.038578775, 0.03761107, 0.030239193]
[1, 10, 15, 0, 14, 5, 12, 8, 6, 7, 3, 9, 13, 4, 2, 11]

srnet_model_F00_l1_1e-04
[0.07022424, 0.05906304, 0.058847193, 0.056146547, 0.04448867, 0.038455795, 0.03106859, 0.028268447, 0.017816655, 0.015585285, 0.0128462305, 0.010780148, 0.009062688, 0.0036060382, 0.003577155, 0.0005865455]
[15, 0, 10, 1, 8, 5, 14, 7, 9, 6, 3, 12, 13, 2, 11, 4]

srnet_model_F00_l1_1e-03
[0.1238546, 0.08738362, 0.01592117, 0.00017046831, 1.50748965e-05, 1.2677921e-07, 1.1294611e-07, 4.7373774e-08, 4.062948e-08, 4.3

Notes:

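* An $L_1$ regularization parameter of `1e-3` yields three high-variance latent features (nodes 15, 0 and 8); the variances of all other latent features are about two orders of magnitude smaller or less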

In [62]:
# inspect the model with the lowest validation loss (l1 = 1e-3) in detail
model_name = "srnet_model_F00_l1_1e-03"
model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)

# inference only: no autograd graph needed
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

all_nodes = ut.get_node_order(acts, show=True)

[0.1238546, 0.08738362, 0.01592117, 0.00017046831, 1.50748965e-05, 1.2677921e-07, 1.1294611e-07, 4.7373774e-08, 4.062948e-08, 4.3782458e-09, 4.0273433e-09, 3.2008614e-09, 2.0540378e-09, 1.9341104e-09, 9.1654095e-10, 9.0483304e-10]
[15, 0, 8, 10, 1, 5, 14, 3, 6, 9, 12, 7, 13, 11, 4, 2]


In [63]:
# keep the three highest-variance latent nodes (first entries of the ordering above)
n_top_nodes = 3
nodes = all_nodes[:n_top_nodes]

In [64]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [65]:
# plot the activations of the selected latent nodes against the x/y inputs together with z_data
ut.plot_acts(
    x_data, y_data, z_data,
    acts=acts, nodes=nodes, model=model,
    agg=False, plot_size=plot_size,
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [66]:
# candidate latent functions that the selected nodes are correlated against
corr_data = list({
    "x**2": x_data**2,
    "cos(y)": np.cos(y_data),
    "x*y": x_data * y_data,
}.items())

In [67]:
# correlate each selected node with each candidate function; the output shows two values
# per pair (presumably all samples / nonzero activations only — confirm in srnet_utils)
ut.node_correlations(
    acts,
    nodes,
    corr_data,
    nonzero=True,
)


Node 15
corr(n15, x**2): 0.8422/0.8422
corr(n15, cos(y)): 0.1718/0.1718
corr(n15, x*y): 0.6315/0.6315

Node 0
corr(n0, x**2): -0.8289/-0.8289
corr(n0, cos(y)): -0.1898/-0.1898
corr(n0, x*y): -0.6422/-0.6422

Node 8
corr(n8, x**2): -0.7825/-0.7825
corr(n8, cos(y)): -0.2553/-0.2553
corr(n8, x*y): -0.6632/-0.6632


Notes:

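* The three high-variance latent features do not separate into the desired latent functions ($x^2$, $\cos(y)$, $xy$). Instead, all three correlate mainly with $x^2$ (|corr| ≈ 0.78–0.84) and partially with $xy$ (|corr| ≈ 0.63–0.66). They correlate only weakly with $\cos(y)$ (|corr| ≤ 0.26)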