# 2022 Flatiron Machine Learning x Science Summer School

## Step 3: Train plain MLP

In this step, we train plain multilayer perceptrons (MLPs) to approximate the generic data of the various functions $f \circ g$ created in Step 1.

We set up the training pipeline and explore hyperparameters.

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import joblib
import matplotlib.pyplot as plt

import torch
from srnet import SRNet, SRData, run_training

In [2]:
# Pick the compute device: the first CUDA GPU when one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Location of the data and the names of the variables to load.
data_path = "data"
in_var, lat_var, target_var = "X04", None, "F04"

# The masks file sits next to the data and is named after the input variable.
# It is indexed by "train" and "val" below.
mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, f"{in_var}{mask_ext}"))     # TODO: create mask if file does not exist

# Build the training and validation datasets from the same variables,
# differing only in the mask that is applied.
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split], device=device)
    for split in ("train", "val")
)

# Seed the global random number generators so that "Restart Kernel -> Run All"
# starts training from the same RNG state every time.
# NOTE(review): assumes SRNet / run_training draw their randomness (weight
# initialisation, batch shuffling) from the global torch / numpy generators -- confirm.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# define hyperparameters
hyperparams = {
    "arch": {
        # input / output widths are read off the second axis of the training data
        "in_size": train_data.in_data.shape[1],
        "out_size": train_data.target_data.shape[1],
        "hid_num": 3,
        "hid_size": 25,
        "hid_type": "MLP",
        "lat_size": 10,
    },
    "epochs": 5000,
    "runtime": None,
    "batch_size": 50,
    "lr": 1e-4,                                                     # TODO: adaptive learning rate?
    "wd": 1e-4,                                                     # presumably weight decay -- verify in srnet
    # "l1": 1e-4,
    "shuffle": True,
}

# Run the training; the plotting cells below read res['train_loss'] and res['val_loss'].
res = run_training(SRNet, hyperparams, train_data, val_data)

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Total training loss: 9.532e+01
Total validation loss: 1.012e+02


In [None]:
fig, ax = plt.subplots()

# plot training loss (solid line)
lines = ax.plot(res['train_loss'], label="train")

# plot validation loss (dashed line, same color as the training curve)
ax.plot(res['val_loss'], '--', color=lines[-1].get_color(), label="validation")

# title and legend so the figure can be read on its own
ax.set_title("Plain MLP: training vs. validation loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [None]:
# NOTE(review): this cell repeats the loss plot of the previous cell verbatim;
# consider deleting one of them or factoring the plot into a helper function.
fig, ax = plt.subplots()

# training loss: solid line in the next color of the default cycle
train_lines = ax.plot(res['train_loss'], label="train")
train_color = train_lines[-1].get_color()

# validation loss: dashed line in the same color as the training curve
ax.plot(res['val_loss'], '--', color=train_color, label="validation")

# title and legend so the figure can be read on its own
ax.set_title("Plain MLP: training vs. validation loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
plt.show()