# 2022 Flatiron Machine Learning x Science Summer School

This notebook trains SRNets on Colab.

In [None]:
# Fetch the symrep repository into ./symrep (shell command run via IPython's `!`).
!git clone https://github.com/fabxy/symrep.git

In [None]:
# Make the cloned repository the working directory for the rest of the notebook.
# Presumably required so that `srnet` can be imported and the relative data/model
# paths used below resolve inside the clone — TODO confirm.
%cd symrep

In [None]:
# Install or upgrade the extra packages: wandb is imported below; einops is
# presumably a dependency of srnet — TODO confirm.
# NOTE(review): versions are not pinned, so results may differ between sessions.
!pip install wandb einops --upgrade

In [None]:
# Authenticate this session with Weights & Biases via its CLI; the project and
# sweep used further down are W&B resources.
!wandb login

In [None]:
import os
import torch
import joblib
from srnet import SRNet, SRData, run_training
import wandb

In [None]:
# compute device: the first CUDA GPU when one is visible, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Weights & Biases settings
wandb_project = "51-a1-a2-study-F00"  # project passed to run_training and wandb.agent
sweep_id = "2rtgnxtt"                 # sweep for wandb.agent; a falsy value selects a single run
sweep_num = 16                        # passed to wandb.agent as `count`

In [None]:
# dataset directory and the names of the input / latent / target variables
data_path = "data_1k"
in_var, lat_var, target_var = "X00", None, "F00"

# the masks live in "<data_path>/<in_var><mask_ext>" and are indexed by split name
# NOTE(review): joblib.load unpickles the file — only load masks from a trusted source
mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, f"{in_var}{mask_ext}"))

# both splits share every argument except the mask
_srdata_args = (data_path, in_var, lat_var, target_var)
train_data = SRData(*_srdata_args, masks["train"], device=device)
val_data = SRData(*_srdata_args, masks["val"], device=device)

In [None]:
# set save file
# Path template for the saved model: `{a1:.0e}` and `{a2:.0e}` are str.format-style
# fields (exponent notation with zero decimals). They are presumably filled in by
# run_training with each run's a1 / a2 hyperparameters — TODO confirm.
save_file = "models/srnet_model_F00_a1_{a1:.0e}_a2_{a2:.0e}.pkl"

In [None]:
# Hyperparameters handed to run_training.
# "arch" holds the network architecture settings; its input and output sizes are
# read from the second dimension of the training set's in_data / target_data.
# NOTE(review): a1 and a2 also appear in the save-file template; presumably the
# W&B sweep overrides them per run — confirm in run_training.
hyperparams = dict(
    arch=dict(
        in_size=train_data.in_data.shape[1],
        out_size=train_data.target_data.shape[1],
        hid_num=(2, 0),
        hid_size=32,
        hid_type=("DSN", "MLP"),
        lat_size=16,
    ),
    epochs=10000,
    runtime=None,
    batch_size=64,
    lr=1e-4,
    wd=1e-4,
    l1=0.0,
    a1=0.0,
    a2=0.0,
    shuffle=True,
)

In [None]:
def train():
    """Run one SRNet training job with the notebook-level configuration.

    Takes no arguments so it can be handed directly to `wandb.agent` as the
    function to execute; the model class, hyperparameters, datasets, save-file
    template, device and W&B project are all read from the notebook globals.
    """
    run_training(
        SRNet,
        hyperparams,
        train_data,
        val_data,
        save_file=save_file,
        device=device,
        wandb_project=wandb_project,
    )

In [None]:
# Without a sweep id, perform a single training run; otherwise start a W&B agent
# that calls train() up to `sweep_num` times for the configured sweep.
if not sweep_id:
    train()
else:
    wandb.agent(sweep_id, train, count=sweep_num, project=wandb_project)