# 2022 Flatiron Machine Learning x Science Summer School

## Step 3: Train plain MLP

In this step, we train plain multilayer perceptrons (MLPs) to approximate the generic data of the various functions $f \circ g$ created in Step 1.

We set up the training pipeline and explore hyperparameters.

### Step 3.1: Check convergence

First, we define baseline hyperparameters and inspect the convergence of the resulting baseline models on the (so far) five datasets of generic data.

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
import srnet_utils as ut

In [2]:
# set wandb project
wandb_project = "31-check-convergence"

In [3]:
# define hyperparameters
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": 2,
#         "hid_size": 32, 
#         "hid_type": "MLP",
#         "lat_size": 16,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     # "l1": 1e-4,
#     "shuffle": True,
# }

In [4]:
# download data from wandb
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            run.file(f.name).download()

Downloading srnet_model_F06_conv1k.pkl.
Downloading srnet_model_F05_conv.pkl.
Downloading srnet_model_F04_conv.pkl.
Downloading srnet_model_F03_conv.pkl.
Downloading srnet_model_F02_conv.pkl.
Downloading srnet_model_F01_conv.pkl.


In [5]:
# plot losses
ut.plot_losses("conv1k", save_path="models");

['srnet_model_F00_conv1k',
 'srnet_model_F01_conv1k',
 'srnet_model_F02_conv1k',
 'srnet_model_F03_conv1k',
 'srnet_model_F04_conv1k',
 'srnet_model_F05_conv1k',
 'srnet_model_F06_conv1k']

Notes:

* `F01` and `F05` seem well converged with a reasonable validation loss

* `F02` has a significantly higher validation loss than `F01`, despite having the same underlying target function (however, `X01` and `X02` are different)

* `F04` also has a significant validation loss that oscillates

* `F03` shows a massive validation loss

* `F00` as the simplest expression shows the lowest training and validation errors (however, some overfitting seems to occur)

* `F06` has a low training error, but the validation loss is not very low. Are we overfitting?
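
One way to make the overfitting question above more concrete is to compare training and validation loss on the same scale. The following is a minimal sketch, assuming mean squared error as the loss (the notebook does not state which loss is used); `mse` and `generalization_gap` are hypothetical helpers, not part of `srnet_utils`, and the arrays are synthetic stand-ins for model predictions.

```python
import numpy as np


def mse(preds, targets):
    """Mean squared error between two arrays of the same shape."""
    preds = np.asarray(preds, dtype=float)
    targets = np.asarray(targets, dtype=float)
    return float(np.mean((preds - targets) ** 2))


def generalization_gap(train_loss, val_loss):
    """Ratio of validation to training loss (hypothetical diagnostic).

    Values close to 1 suggest a similar fit on both splits; values much
    larger than 1 hint at overfitting.
    """
    return val_loss / max(train_loss, 1e-12)


# synthetic stand-ins for targets and predictions (not real model outputs)
train_targets = np.array([0.0, 1.0, 2.0, 3.0])
train_preds = np.array([0.1, 0.9, 2.1, 2.9])
val_targets = np.array([0.5, 1.5, 2.5])
val_preds = np.array([0.9, 1.0, 3.0])

train_loss = mse(train_preds, train_targets)
val_loss = mse(val_preds, val_targets)
gap = generalization_gap(train_loss, val_loss)
print(f"train={train_loss:.4f}, val={val_loss:.4f}, gap={gap:.1f}")
```

A large ratio alone does not prove overfitting; with only 1,000 samples, a validation set that covers the input range poorly could produce a similar picture.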

Analyzing the models for all datasets and optimizing their hyperparameters might be difficult. 

And how do we know that $g$ is approximated by the first part and $f$ by the second part of the network?

Let's start with `F00`.

### Step 3.2: Analyze `F00` model

Let's check how well the current model for `F00` performs.

In [None]:
# load data
data_path = "data_1k"

in_var = "X00"
lat_var = "G00"
target_var = "F00"

mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))     # TODO: create mask if file does not exist

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"])
val_data = SRData(data_path, in_var, lat_var, target_var, masks["val"])
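
The `TODO` in the cell above could be addressed along the following lines. This is a minimal sketch, assuming that a mask file stores a dictionary of boolean arrays under the keys `"train"` and `"val"`. The keys are taken from the cell above; the boolean format, the 80/20 split, and the helper name `make_masks` are assumptions.

```python
import numpy as np


def make_masks(n_samples, val_frac=0.2, seed=0):
    """Create disjoint boolean train/val masks over n_samples rows."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    n_val = int(round(val_frac * n_samples))

    # mark a random subset of rows as validation data
    val_mask = np.zeros(n_samples, dtype=bool)
    val_mask[perm[:n_val]] = True

    # everything that is not validation data is training data
    return {"train": ~val_mask, "val": val_mask}


masks_sketch = make_masks(1000)
print(masks_sketch["train"].sum(), masks_sketch["val"].sum())
```

The resulting dictionary could then be written with `joblib.dump` whenever `os.path.isfile` reports that the mask file is missing.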

In [None]:
# load model
model_name = "srnet_model_F00_conv1k.pkl"
model_path = "models"

model = ut.load_model(model_name, model_path, SRNet)

In [None]:
# get predictions
with torch.no_grad():
    preds = model(train_data.in_data)

In [None]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [("target", train_data.target_data), ("pred", preds)]
plot_size = train_data.target_data.shape[0]

In [None]:
ut.plot_acts(x_data, y_data, z_data, plot_size=plot_size)

The predictions on the training data seem to be good enough. What about the validation data?

In [None]:
# get predictions
with torch.no_grad():
    preds = model(val_data.in_data)

In [None]:
# select plotting data
x_data = val_data.in_data[:,0]
y_data = val_data.in_data[:,1]
z_data = [("target", val_data.target_data), ("pred", preds)]
plot_size = val_data.target_data.shape[0]

In [None]:
ut.plot_acts(x_data, y_data, z_data, plot_size=plot_size)

Except for a few outliers at the edges, the validation data is also approximated accurately.

**TODO**: Rerun the pipeline with a dataset size of 10,000.

What do the **latent features** look like?

In [None]:
# get predictions
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

In [None]:
# get latent feature variance
all_nodes = ut.get_node_order(acts, show=True)

In [None]:
nodes = all_nodes[:8]

In [None]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", train_data.lat_data[:,0]), 
    #("cos(y)", train_data.lat_data[:,1]), 
    #("x*y", train_data.lat_data[:,2]),
]
plot_size = train_data.target_data.shape[0]

In [None]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, agg=True, plot_size=plot_size)

### Step 3.3: Optimize hyperparameters for `F00`

**STILL OPEN**

We investigate the effects of the following hyperparameters:

* Architecture (`hid_num`, `hid_size`, `lat_size`)

* Dataset size

* Batch size

* Learning rate

* Weight decay

In [None]:
# hyperparameters to sweep over
parameters_dict = {
    "lr": {
        "values": [1e-3, 5e-4, 1e-4]
    },
    "batch_size": {
        "values": [16, 32, 64]
    },
    "hid_num1": {
        "values": [2, 4, 8]
    },
    "hid_num2": {
        "values": [1, 2, 4]
    },
    "hid_size": {
        "values": [32, 64, 128]
    },
}

# wandb expects the hyperparameters under the "parameters" key
sweep_config = {
    "method": "random", # grid, bayes
    "metric": {
        "name": "val_loss",
        "goal": "minimize",
    },
    "parameters": parameters_dict,
}

sweep_id = wandb.sweep(sweep_config, project="name")    # placeholder project name
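
For intuition, a random search draws one value per hyperparameter from its list of candidates for each run. The sketch below mimics that locally in plain Python, assuming uniform sampling from each list. It does not call wandb, and `sample_config` is a hypothetical helper.

```python
import random

# candidate values copied from the sweep configuration above
parameters = {
    "lr": {"values": [1e-3, 5e-4, 1e-4]},
    "batch_size": {"values": [16, 32, 64]},
    "hid_num1": {"values": [2, 4, 8]},
    "hid_num2": {"values": [1, 2, 4]},
    "hid_size": {"values": [32, 64, 128]},
}


def sample_config(parameters, rng):
    """Pick one value per hyperparameter, uniformly at random."""
    return {name: rng.choice(spec["values"]) for name, spec in parameters.items()}


rng = random.Random(0)
for _ in range(3):
    print(sample_config(parameters, rng))
```

A full grid over these lists would contain $3^5 = 243$ combinations, which makes random sampling a reasonable first pass.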

**Notes**:

* Running the training on a GPU is actually slower than on a CPU (19.26it/s vs. 31.04it/s)

* `wandb.watch` slows down the training process

* `num_workers` also slows down the training process (7.38it/s vs. 17.88it/s for `num_workers=2`)

* `torch.backends.cudnn.benchmark` does not impact the GPU training speed
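
Throughput figures like the ones above can be reproduced by timing a fixed number of training steps. The sketch below uses a dummy workload in place of a real training step; `iterations_per_second` and `dummy_step` are hypothetical names, not part of the training pipeline.

```python
import time


def iterations_per_second(step_fn, n_iters=200):
    """Call step_fn n_iters times and return the measured iterations per second."""
    start = time.perf_counter()
    for _ in range(n_iters):
        step_fn()
    elapsed = time.perf_counter() - start
    return n_iters / max(elapsed, 1e-12)


def dummy_step():
    """Stand-in for one training step (not a real model update)."""
    sum(i * i for i in range(1000))


rate = iterations_per_second(dummy_step)
print(f"{rate:.2f} it/s")
```

When timing GPU code this way, CUDA kernels run asynchronously, so calling `torch.cuda.synchronize()` before reading the timer is usually needed for a fair comparison.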


Check:

* `accelerate`

* `lightning`

### Step 3.4: Analyze `F06` model

With 8 input features, the latent features can no longer be plotted over the input space as easily as before.

In [11]:
# load data
data_path = "data_1k"

in_var = "X06"
lat_var = "G06"
target_var = "F06"

mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))     # TODO: create mask if file does not exist

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"])
val_data = SRData(data_path, in_var, lat_var, target_var, masks["val"])

In [12]:
# load model
model_name = "srnet_model_F06_conv1k.pkl"
model_path = "models"

model = ut.load_model(model_name, model_path, SRNet)

In [13]:
# get predictions
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

In [14]:
# get latent feature variance
all_nodes = ut.get_node_order(acts, show=True)

[0.07345882, 0.038932025, 0.038396627, 0.030609325, 0.022553695, 0.021590559, 0.021060195, 0.018421583, 0.016839076, 0.014656181, 0.013953692, 0.013143861, 0.012904241, 0.011364459, 0.008470869, 0.0072448477]
[4, 5, 2, 10, 7, 0, 13, 12, 3, 14, 9, 1, 6, 11, 15, 8]
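
The two lists above appear to be the per-node activation variances in descending order, followed by the corresponding node indices. A minimal sketch of such a ranking follows, assuming `acts` has shape `(n_samples, n_nodes)`. `node_order_by_variance` is a hypothetical stand-in and may differ from what `ut.get_node_order` does internally.

```python
import numpy as np


def node_order_by_variance(acts):
    """Return (sorted variances, node indices), highest variance first."""
    acts = np.asarray(acts, dtype=float)
    variances = acts.var(axis=0)
    order = np.argsort(variances)[::-1]
    return variances[order].tolist(), order.tolist()


# synthetic activations: node 1 varies most, node 2 is constant
acts_demo = np.array([
    [0.0, -2.0, 1.0],
    [1.0, 2.0, 1.0],
    [0.0, -2.0, 1.0],
    [1.0, 2.0, 1.0],
])
variances, order = node_order_by_variance(acts_demo)
print(variances)
print(order)
```

Ranking by variance puts nodes that barely react to the input at the end of the list, so slicing the first few entries keeps the most active nodes.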


In [15]:
nodes = all_nodes[:8]

In [16]:
# select data
x0_data = train_data.in_data[:,0]
x3_data = train_data.in_data[:,3]
x5_data = train_data.in_data[:,5]
x7_data = train_data.in_data[:,7]

corr_data = [
    ("x0**2", x0_data**2), 
    ("cos(x3)", np.cos(x3_data)), 
    ("x5*x7", x5_data * x7_data),
    ("x0", x0_data),
    ("x3", x3_data),
    ("x5", x5_data),
    ("x7", x7_data),
]
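
Comparing a latent node with a candidate feature such as `x0**2` can be sketched with `np.corrcoef`. This is an illustrative example on synthetic data, assuming Pearson correlation. `pearson` is a hypothetical helper, and `ut.node_correlations` may compute something different.

```python
import numpy as np


def pearson(a, b):
    """Pearson correlation coefficient between two 1-D arrays."""
    return float(np.corrcoef(a, b)[0, 1])


rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=1000)

# synthetic "node" that encodes x0**2 up to sign, scale, and offset
node = -3.0 * x0**2 + 0.5

print(f"corr(node, x0**2): {pearson(node, x0**2):.4f}")
print(f"corr(node, x0):    {pearson(node, x0):.4f}")
```

Because a node can encode a feature up to sign, scale, and offset, the magnitude of the correlation is more informative than its sign.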

In [17]:
ut.node_correlations(acts, nodes, corr_data, nonzero=True)


Node 4
corr(n4, x0**2): 0.6778/0.6778
corr(n4, cos(x3)): -0.1845/-0.1845
corr(n4, x5*x7): 0.2032/0.2032
corr(n4, x0): -0.0778/-0.0778
corr(n4, x3): -0.0098/-0.0098
corr(n4, x5): 0.0490/0.0490
corr(n4, x7): -0.0724/-0.0724

Node 5
corr(n5, x0**2): -0.7653/-0.7653
corr(n5, cos(x3)): 0.0262/0.0262
corr(n5, x5*x7): -0.3672/-0.3672
corr(n5, x0): 0.1303/0.1303
corr(n5, x3): 0.0474/0.0474
corr(n5, x5): -0.0803/-0.0803
corr(n5, x7): -0.1144/-0.1144

Node 2
corr(n2, x0**2): -0.8322/-0.8322
corr(n2, cos(x3)): -0.1152/-0.1152
corr(n2, x5*x7): -0.4042/-0.4042
corr(n2, x0): -0.0667/-0.0667
corr(n2, x3): -0.0849/-0.0849
corr(n2, x5): -0.0361/-0.0361
corr(n2, x7): 0.0767/0.0767

Node 10
corr(n10, x0**2): 0.5816/0.5816
corr(n10, cos(x3)): -0.1766/-0.1766
corr(n10, x5*x7): -0.0120/-0.0120
corr(n10, x0): 0.1843/0.1843
corr(n10, x3): 0.0886/0.0886
corr(n10, x5): -0.0155/-0.0155
corr(n10, x7): -0.0887/-0.0887

Node 7
corr(n7, x0**2): 0.8282/0.8282
corr(n7, cos(x3)): 0.0887/0.0887
corr(n7, x5*x7): 0.4154/