# 2022 Flatiron Machine Learning x Science Summer School

## Step 4: Train MLP with $L_1$ regularization on latent features

### Step 4.1: Check $L_1$ regularization parameter

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
import srnet_utils as ut

In [2]:
# W&B project name: the sweep below is created in it, and the download cell
# pulls run files from it.
wandb_project = (
    "41-l1-study-F00"
)

In [3]:
# define hyperparameters
# NOTE: reference only, kept commented out. The sweep defined below varies only "l1",
# and `train_data` is not defined until the data-loading cell further down.
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": "MLP",
#         "lat_size": 16,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     "l1": 0.0,
#     "shuffle": True,
# }

In [11]:
# Path template for saved models; fill with str.format(target_var=..., l1=...),
# e.g. target_var="F00", l1=1e-4 -> "models/srnet_model_F00_l1_1e-04.pkl".
strig = (
    "models/"                                # output directory
    "srnet_model_{target_var}_l1_{l1:.0e}"   # target variable + l1 in scientific notation
    ".pkl"                                   # file extension
)

In [7]:
# example arguments for the `strig` path template
dic = dict(target_var="F00", l1=1e-4)

# uncomment to preview the resulting model path:
## strig.format(**dic)

In [4]:
# define hyperparameter study: a W&B sweep configuration that runs one
# configuration per listed L1 regularization value
hp_study = {
    "method": "grid",  # alternatives: "random", "bayes"
    # a metric is only needed to guide the "bayes" method; grid search ignores it
    #"metric": {
    #    "name": "val_loss",
    #    "goal": "minimize",
    #},
    "parameters": {
        "l1": {
            "values": [1e-6, 1e-4, 1e-2, 1e-1]
        }
    }
}

In [5]:
# Register the sweep configuration `hp_study` in the W&B project; the returned
# ID (printed below) identifies the sweep.
sweep_id = wandb.sweep(
    hp_study,
    project=wandb_project,
)

Create sweep with ID: mqkzram7
Sweep URL: https://wandb.ai/fabxy/41-l1-study-F00/sweeps/mqkzram7


In [4]:
# download data from wandb: fetch every run file in `wandb_project` whose name
# ends with `file_ext` and that does not already exist locally at `f.name`
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # skip files with another extension and files already on disk
        if not f.name.endswith(file_ext) or os.path.isfile(f.name):
            continue
        print(f"Downloading {os.path.basename(f.name)}.")
        # `f` is already this run's file handle; no second lookup via run.file() needed
        f.download()

Downloading srnet_model_F01_l1_v_v.pkl.
Downloading srnet_model_F01_l1_v.pkl.
Downloading srnet_model_F01_l1.pkl.


In [17]:
# plot losses of the l1-study models next to the F01_conv1k model
loss_model_tags = ["l1", "F01_conv1k"]
ut.plot_losses(loss_model_tags, save_path="models")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Notes:

* Increasing the $L_1$ regularization parameter leads to higher training errors but similar or smaller validation errors.

In [18]:
# load data: train/validation splits of the dataset in `data_path`, selected
# by the masks stored next to the input file
data_path = "data_1k"

in_var = "X01"
lat_var = "G01"
target_var = "F01"

mask_ext = ".mask"
mask_file = os.path.join(data_path, in_var + mask_ext)
masks = joblib.load(mask_file)     # TODO: create mask if file does not exist

# build the training split first, then the validation split
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [23]:
# latent feature variance overview: for each trained model, print its name and
# the node ordering reported by ut.get_node_order(..., show=True) on the
# training inputs (suffix `1nK` presumably encodes l1 = 1e-K; confirm)
model_path = "models"
model_ext = ".pkl"

models = [
    "srnet_model_F01_conv1k",
    "srnet_model_F01_l1_1n6",
    "srnet_model_F01_l1_1n4",
    "srnet_model_F01_l1_1n2",
    "srnet_model_F01_l1_1n1",
]

for model_name in models:
    print(model_name)
    model_file = f"{model_name}{model_ext}"
    model = ut.load_model(model_file, model_path, SRNet)

    # inference only, so no autograd graph is recorded
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    all_nodes = ut.get_node_order(acts, show=True)
    print()  # blank line between models

srnet_model_F01_conv1k
[0.442663, 0.3568648, 0.34265625, 0.31251407, 0.30543056, 0.285317, 0.26871365, 0.25194448, 0.20889138, 0.19993809, 0.18855587, 0.17806222, 0.17064506, 0.14985456, 0.1324731, 0.13221803]
[3, 6, 10, 4, 15, 7, 5, 8, 2, 13, 12, 1, 9, 14, 11, 0]

srnet_model_F01_l1_1n6
[0.4472595, 0.35221803, 0.3384347, 0.31435305, 0.29468605, 0.28430617, 0.26510754, 0.24103655, 0.20872487, 0.2011079, 0.19931003, 0.18311673, 0.15897202, 0.1499768, 0.14157562, 0.12980556]
[3, 6, 10, 4, 15, 7, 5, 8, 2, 13, 12, 1, 9, 14, 0, 11]

srnet_model_F01_l1_1n4
[0.4251673, 0.35212377, 0.3484957, 0.28564206, 0.2754437, 0.22463171, 0.18807544, 0.16734438, 0.15142357, 0.15080051, 0.14650822, 0.14612693, 0.1414289, 0.13368121, 0.107792415, 0.08147589]
[3, 6, 10, 5, 7, 4, 2, 15, 13, 1, 12, 9, 0, 8, 14, 11]

srnet_model_F01_l1_1n2
[0.053635195, 0.03521261, 0.031321988, 0.021988414, 0.007459262, 0.004458526, 0.0022832672, 0.001976531, 0.0018776564, 0.0012145492, 0.00088398193, 0.0007383683, 0.0005758562

In [32]:
# inspect a single model in detail (model_path / model_ext come from the
# variance overview cell above)
model_name = "srnet_model_F01_l1_1n2"
model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)

with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

all_nodes = ut.get_node_order(acts, show=True)

# keep the leading nodes: in the printed output the first four values
# (>= 0.02) are well separated from the remaining ones (<= 0.0075)
num_top_nodes = 4
nodes = all_nodes[:num_top_nodes]

[0.053635195, 0.03521261, 0.031321988, 0.021988414, 0.007459262, 0.004458526, 0.0022832672, 0.001976531, 0.0018776564, 0.0012145492, 0.00088398193, 0.0007383683, 0.00057585625, 0.00034302214, 0.00011891858, 0.00011776732]
[3, 6, 10, 5, 2, 7, 0, 14, 13, 1, 4, 12, 9, 15, 11, 8]


In [40]:
# select plotting data: the first two input columns, plus labelled candidate
# expressions of them to compare the latent node activations against
x_data, y_data = train_data.in_data[:, 0], train_data.in_data[:, 1]
z_data = [
    # train_data.lat_data[:,0],
    ("x**2", x_data**2),
    ("cos(y)", np.cos(y_data)),
    ("x*y", x_data * y_data),
]
# number of samples to plot: the full training set
plot_size = train_data.target_data.shape[0]

In [43]:
# plot the activations of the selected `nodes` together with the candidate
# expressions in `z_data` (layout is defined by srnet_utils.plot_acts)
ut.plot_acts(
    x_data,
    y_data,
    z_data,
    model=model,
    acts=acts,
    nodes=nodes,
    plot_size=plot_size,
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [41]:
# correlation of each selected node with each candidate expression in z_data;
# NOTE(review): nonzero=True presumably restricts to nonzero activations, confirm in srnet_utils
ut.node_correlations(
    acts,
    nodes,
    z_data,
    nonzero=True,
)


Node 3
corr(n3, x**2): 0.7875/0.7875
corr(n3, cos(y)): 0.1547/0.1547
corr(n3, x*y): 0.3645/0.3645

Node 6
corr(n6, x**2): 0.7902/0.7902
corr(n6, cos(y)): 0.1306/0.1306
corr(n6, x*y): 0.3754/0.3754

Node 10
corr(n10, x**2): 0.7879/0.7879
corr(n10, cos(y)): 0.1083/0.1083
corr(n10, x*y): 0.3219/0.3219

Node 5
corr(n5, x**2): -0.7810/-0.7810
corr(n5, cos(y)): -0.1716/-0.1716
corr(n5, x*y): -0.4075/-0.4075
