# 2022 Flatiron Machine Learning x Science Summer School

## Step 3: Train plain MLP

In this step, we train plain multilayer perceptrons (MLP) to approximate the generic data of the various functions $f \circ g$ created in Step 1.

We set up the training pipeline and explore hyperparameters.

### Step 3.1: Check convergence

First, we define baseline hyperparameters and inspect the convergence of the resulting baseline models on the (so far) five datasets of generic data.

In [57]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
import srnet_utils as ut



In [2]:
# set wandb project
wandb_project = "31-check-convergence"

In [3]:
# define hyperparameters
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": 2,
#         "hid_size": 32, 
#         "hid_type": "MLP",
#         "lat_size": 16,
#         },
#     "epochs": 10000,
#     "runtime": None,
#     "batch_size": 64,
#     "lr": 1e-4,
#     "wd": 1e-4,
#     # "l1": 1e-4,
#     "shuffle": True,
# }

In [22]:
# download data from wandb
file_ext = ".pkl"

api = wandb.Api()

runs = api.runs(wandb_project)
for run in runs:
    for f in run.files():
        # download only model files that are not yet present locally
        if f.name.endswith(file_ext) and not os.path.isfile(f.name):
            print(f"Downloading {os.path.basename(f.name)}.")
            run.file(f.name).download()

Downloading srnet_model_F04_conv.pkl.
Downloading srnet_model_F02_conv.pkl.


In [24]:
# plot losses
ut.plot_losses("conv", save_path="models")


Notes:

* `F01` and `F05` seem well converged with a reasonable validation loss

* `F02` has a significantly higher validation loss than `F01`, despite having the same underlying target function (however, `X01` and `X02` are different)

* `F04` also has a significant validation loss that oscillates

* `F03` shows a massive validation loss

Analyzing the models for all datasets and optimizing their hyperparameters might be difficult. Let's start with `F01`.
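
Visual judgments such as "well converged" or "oscillates" can be backed by a few numbers. The following is a minimal sketch, assuming the loss histories are available as plain sequences. The function name `tail_stats` and the window length are illustrative choices and are not part of `srnet_utils`; the two curves are synthetic.

```python
import numpy as np

def tail_stats(losses, window=100):
    """Summarize the last `window` entries of a loss curve."""
    tail = np.asarray(losses, dtype=float)[-window:]
    mean = tail.mean()
    # relative spread: large values indicate an oscillating loss
    rel_std = tail.std() / mean
    # relative drift between the first and second half of the tail:
    # values far from zero indicate that the loss is still changing
    half = len(tail) // 2
    drift = (tail[half:].mean() - tail[:half].mean()) / mean
    return {"mean": mean, "rel_std": rel_std, "drift": drift}

# synthetic curves for illustration only (not the wandb runs)
epochs = np.arange(1000)
smooth = 1e-3 + np.exp(-epochs / 50)
noisy = smooth * (1 + 0.5 * np.sin(epochs))

smooth_stats = tail_stats(smooth)
noisy_stats = tail_stats(noisy)
```

A small `rel_std` together with a `drift` close to zero would correspond to a curve that looks flat and stable in the loss plot.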

### Step 3.2: Analyze `F01` model

Let's check how well the current model for `F01` performs.

In [38]:
# load data
data_path = "data"

in_var = "X01"
lat_var = "G01"
target_var = "F01"

mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))     # TODO: create mask if file does not exist

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"])
val_data = SRData(data_path, in_var, lat_var, target_var, masks["val"])
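
The TODO in the cell above could be addressed along the following lines. This is only a sketch: it assumes that the mask file holds a dictionary of boolean arrays keyed by `"train"` and `"val"`, which the notebook does not confirm. The helper `make_masks`, the validation fraction, and the seed are hypothetical.

```python
import numpy as np

def make_masks(n_samples, val_frac=0.2, seed=0):
    """Return boolean train/val masks that partition `n_samples` rows."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    n_val = int(round(val_frac * n_samples))
    val = np.zeros(n_samples, dtype=bool)
    val[perm[:n_val]] = True
    # every row belongs to exactly one of the two splits
    return {"train": ~val, "val": val}

masks_sketch = make_masks(1000)
```

If this format matches what `SRData` expects, the result could be stored with `joblib.dump` whenever the mask file is missing.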

In [30]:
# load model
model_name = "srnet_model_F01_conv.pkl"
model_path = "models"

model = ut.load_model(model_name, model_path, SRNet)

In [44]:
# get predictions
with torch.no_grad():
    preds = model(train_data.in_data)

In [45]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [train_data.target_data, preds]
plot_size = train_data.target_data.shape[0]

In [46]:
ut.plot_acts(x_data, y_data, z_data, plot_size=plot_size)


The predictions on the training data seem to be good enough. What about the validation data?

In [47]:
# get predictions
with torch.no_grad():
    preds = model(val_data.in_data)

In [48]:
# select plotting data
x_data = val_data.in_data[:,0]
y_data = val_data.in_data[:,1]
z_data = [val_data.target_data, preds]
plot_size = val_data.target_data.shape[0]

In [49]:
ut.plot_acts(x_data, y_data, z_data, plot_size=plot_size)


Except for a few outliers at the edges, the validation data is also approximated accurately.
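
The impression that the outliers sit at the edges can be checked numerically. The sketch below flags the samples with the largest absolute error and tests whether they lie close to the boundary of the input range. The function `edge_outliers`, the quantile, and the edge fraction are assumptions made for illustration, and the data is a synthetic stand-in rather than the `F01` dataset.

```python
import numpy as np

def edge_outliers(inputs, targets, preds, quantile=0.99, edge_frac=0.05):
    """Flag the largest-error samples and mark samples near the input boundary."""
    inputs = np.asarray(inputs, dtype=float)
    diff = np.asarray(preds, dtype=float) - np.asarray(targets, dtype=float)
    # one absolute error per sample (maximum over output dimensions)
    err = np.abs(diff).reshape(len(inputs), -1).max(axis=1)
    is_outlier = err > np.quantile(err, quantile)
    # "near the edge": any coordinate within edge_frac of the range limits
    lo, hi = inputs.min(axis=0), inputs.max(axis=0)
    margin = edge_frac * (hi - lo)
    near_edge = ((inputs < lo + margin) | (inputs > hi - margin)).any(axis=1)
    return is_outlier, near_edge

# synthetic stand-in data: errors are deliberately inflated near the boundary
rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=(2000, 2))
y_true = (x[:, 0] ** 2 + np.cos(x[:, 1]))[:, None]
at_edge = (np.abs(x).max(axis=1) > 0.95)[:, None]
y_pred = y_true + 0.01 * rng.normal(size=y_true.shape) + 0.5 * at_edge

is_outlier, near_edge = edge_outliers(x, y_true, y_pred)
print(f"{is_outlier.sum()} outliers, "
      f"{near_edge[is_outlier].mean():.0%} of them near an edge")
```

With the notebook's tensors, the arguments would presumably be `val_data.in_data`, `val_data.target_data`, and `preds`, converted to NumPy arrays first.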

**TODO**: Rerun the pipeline with a dataset size of 10,000.

What do the **latent features** look like?

In [50]:
# get predictions
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

In [64]:
# get latent feature variance
all_nodes = ut.get_node_order(acts, show=True)

[0.442663, 0.3568648, 0.34265625, 0.31251407, 0.30543056, 0.285317, 0.26871365, 0.25194448, 0.20889138, 0.19993809, 0.18855587, 0.17806222, 0.17064506, 0.14985456, 0.1324731, 0.13221803]
[3, 6, 10, 4, 15, 7, 5, 8, 2, 13, 12, 1, 9, 14, 11, 0]
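
The output above lists one value per latent node in descending order, followed by the corresponding node indices. The implementation of `ut.get_node_order` is not shown here; a plausible reading, given the cell comment, is a ranking of nodes by the variance of their activations over the dataset. The sketch below illustrates that idea with a hypothetical function and synthetic activations.

```python
import numpy as np

def node_order_by_variance(acts):
    """Return (sorted variances, node indices), highest variance first."""
    variances = np.var(np.asarray(acts, dtype=float), axis=0)
    order = np.argsort(variances)[::-1]
    return variances[order], order

# synthetic activations with three nodes of different scale
rng = np.random.default_rng(0)
scales = np.array([0.1, 1.0, 0.5])
demo_acts = rng.normal(size=(5000, 3)) * scales

demo_vars, demo_order = node_order_by_variance(demo_acts)
```

Selecting the first entries of the index list, as done in the next cell, then keeps the nodes that vary most across the data.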


In [70]:
nodes = all_nodes[:8]

In [71]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    train_data.lat_data[:,0], 
    # x_data**2, 
    # np.cos(y_data), 
    # x_data * y_data,
]
plot_size = train_data.target_data.shape[0]

In [72]:
ut.plot_acts(x_data, y_data, z_data, model=model, acts=acts, nodes=nodes, plot_size=plot_size)


How do we know that $g$ is approximated by the first part and $f$ by the second part of the network?
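
One possible way to probe this question is a linear fit from the latent activations to the known latent variable (`G01` in this notebook). The sketch below computes the $R^2$ of such a fit. The function `linear_probe_r2` is hypothetical and the data is synthetic. A high $R^2$ would only show that $g$ is linearly decodable from the latent layer; it would not rule out that the network uses a different but equivalent decomposition.

```python
import numpy as np

def linear_probe_r2(acts, target):
    """R^2 of a least-squares linear map (with bias) from activations to target."""
    acts = np.asarray(acts, dtype=float)
    A = np.column_stack([acts, np.ones(len(acts))])
    y = np.asarray(target, dtype=float).reshape(len(acts), -1)
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    resid = y - A @ coef
    ss_res = (resid ** 2).sum()
    ss_tot = ((y - y.mean(axis=0)) ** 2).sum()
    return 1.0 - ss_res / ss_tot

# synthetic example: one latent layer encodes g linearly, the other does not
rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=(2000, 2))
g_true = x[:, 0] * x[:, 1]
acts_encoding_g = np.outer(g_true, rng.normal(size=4)) \
    + 0.01 * rng.normal(size=(2000, 4))
acts_unrelated = rng.normal(size=(2000, 4))

r2_good = linear_probe_r2(acts_encoding_g, g_true)
r2_bad = linear_probe_r2(acts_unrelated, g_true)
```

Applied to this notebook, the inputs would presumably be `acts` and `train_data.lat_data`, converted to NumPy arrays.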

### Step 3.X: Optimize hyperparameters for `F01`

We investigate the effects of the following hyperparameters:

* Architecture (`hid_num`, `hid_size`, `lat_size`)

* Dataset size

* Batch size

* Learning rate

* Weight decay

Sweeps:
```
sweep_config = {
    "method": "random",  # or "grid" or "bayes"
}

metric = {
    "name": "loss",
    "goal": "minimize",
}

sweep_config["metric"] = metric

parameters_dict = {
    "layer_size": {
        "values": [128, 256, 512],
    },
    "learning_rate": {
        "distribution": "uniform",  # alternative: "q_log_uniform"
        "min": 0,
        "max": 0.1,
    },
}

sweep_config["parameters"] = parameters_dict

sweep_id = wandb.sweep(sweep_config, project="name")

# `train` is the training function that each sweep run calls
wandb.agent(sweep_id, train, count=5)
```

Are the results on `F04(X04)` accurate enough for reconstruction?

Sweep:

* Initialize the sweep via this notebook.

* Run the sweep on this PC and on Colab.

* Analyze the results on wandb and here?

In [None]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import joblib
import matplotlib.pyplot as plt

import torch
import wandb
from srnet import SRNet, SRData, run_training

In [None]:
hyperparams = {
    "arch": {
        "in_size": train_data.in_data.shape[1],
        "out_size": train_data.target_data.shape[1],
        "hid_num": 3,
        "hid_size": 25, 
        "hid_type": "MLP",
        "lat_size": 10,
        },
    "epochs": 10000,
    "runtime": None,
    "batch_size": 50,
    "lr": 1e-4,                                                     # TODO: adaptive learning rate?
    "wd": 1e-4,
    # "l1": 1e-4,
    "shuffle": True,
}

In [None]:
sweep_config = {
    "method": "random",  # alternatives: "grid", "bayes"
    "metric": {
        "name": "val_loss",
        "goal": "minimize",
    },
    # swept hyperparameters must be nested under "parameters"
    "parameters": {
        "lr": {
            "values": [1e-3, 5e-4, 1e-4]
        },
        "batch_size": {
            "values": [16, 32, 64]
        },
        "hid_num1": {
            "values": [2, 4, 8]
        },
        "hid_num2": {
            "values": [1, 2, 4]
        },
        "hid_size": {
            "values": [32, 64, 128]
        },
    },
}

sweep_id = wandb.sweep(sweep_config, project="name")
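
The sweep samples flat keys such as `lr` and `hid_size`, whereas the `hyperparams` dictionary nests the architecture settings under `"arch"`. A training function for the sweep therefore needs a small mapping step. The following is a minimal sketch with a hypothetical helper `apply_sweep_config`; which keys count as architecture keys is an assumption based on the `hyperparams` cells in this notebook.

```python
import copy

# keys that live in hyperparams["arch"] (assumed from the notebook's dict)
ARCH_KEYS = {"hid_num", "hid_size", "lat_size"}

def apply_sweep_config(base, sampled):
    """Return a copy of `base` with the sampled sweep values applied."""
    params = copy.deepcopy(base)
    for key, value in sampled.items():
        if key in ARCH_KEYS:
            params["arch"][key] = value
        elif key in params:
            params[key] = value
        else:
            # fail loudly on keys without a counterpart in `base`
            raise KeyError(f"unknown hyperparameter: {key}")
    return params

base = {
    "arch": {"hid_num": 3, "hid_size": 25, "lat_size": 10},
    "lr": 1e-4,
    "batch_size": 50,
}
updated = apply_sweep_config(base, {"lr": 1e-3, "hid_size": 64})
```

Raising on unknown keys makes mismatches visible early: sweep parameters such as `hid_num1` and `hid_num2` have no counterpart in the current `hyperparams` dictionary and would need to be handled explicitly.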

In [None]:
# set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load data
data_path = "data"

in_var = "X04"
lat_var = None
target_var = "F04"

mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))     # TODO: create mask if file does not exist

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"], device=device)
val_data = SRData(data_path, in_var, lat_var, target_var, masks["val"], device=device)

# define hyperparameters
hyperparams = {
    "arch": {
        "in_size": train_data.in_data.shape[1],
        "out_size": train_data.target_data.shape[1],
        "hid_num": 3,
        "hid_size": 25, 
        "hid_type": "MLP",
        "lat_size": 10,
        },
    "epochs": 5000,
    "runtime": None,
    "batch_size": 50,
    "lr": 1e-4,                                                     # TODO: adaptive learning rate?
    "wd": 1e-4,
    # "l1": 1e-4,
    "shuffle": True,
}

res = run_training(SRNet, hyperparams, train_data, val_data, device=device)

In [None]:
_, ax = plt.subplots()

# plot training loss
lines = ax.plot(res['train_loss'])

# plot validation loss
ax.plot(res['val_loss'], '--', color=lines[-1].get_color())

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
plt.show()

Running the training on a GPU is actually slower than on a CPU (19.26 it/s vs. 31.04 it/s).

`wandb.watch` slows down the training process.

`num_workers` also slows down the training process (7.38 it/s vs. 17.88 it/s for `num_workers=2`).
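
Throughput figures like the ones above are easier to compare when they are measured the same way for each configuration. The sketch below times a callable and reports iterations per second. The names `iterations_per_second` and `dummy_step` are placeholders; in practice, `step_fn` would wrap one training step.

```python
import time

def iterations_per_second(step_fn, n_iters=200, warmup=20):
    """Measure how many calls of `step_fn` complete per second."""
    # warm-up iterations are excluded from the measurement
    for _ in range(warmup):
        step_fn()
    start = time.perf_counter()
    for _ in range(n_iters):
        step_fn()
    elapsed = time.perf_counter() - start
    return n_iters / elapsed

def dummy_step():
    # stand-in workload; replace with a real training step
    sum(i * i for i in range(1000))

rate = iterations_per_second(dummy_step)
print(f"{rate:.2f} it/s")
```

When timing GPU code, `torch.cuda.synchronize()` should be called before reading the clock, because CUDA kernels are launched asynchronously and the measurement would otherwise capture only the launch overhead.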

Check:

* `torch.backends.cudnn.benchmark` (for constant batch sizes)

* `accelerate`

* `lightning`