This notebook is a wrapper for running a few types of models on a few splits of
data. Its outputs are (1) saved features and (2) a trained model,
selected to have the best dev set score.

The main input parameters are the path to the `train_yaml` (relative to the root directory) and the bootstrap index to use. Since we may want to run this notebook as a Python script (using `nbconvert`), we look these arguments up from environment variables.

In [None]:
import os
import json
from addict import Dict
from pathlib import Path
from stability.data import initialize_loader
from stability.models.vae import VAE, vae_loss
from stability.models.cnn import CBRNet, cnn_loss
import stability.train as st
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import pandas as pd
import torch
import torch.optim
import yaml

train_yaml = os.environ["TRAIN_YAML"]
bootstrap = int(os.environ["BOOTSTRAP"])  # environment variables are strings; cast for row lookup
data_dir = Path(os.environ["DATA_DIR"])  # `Path` is imported directly from pathlib above
root_dir = Path(os.environ["ROOT_DIR"])
with open(root_dir / train_yaml, "r") as f:
    opts = Dict(yaml.safe_load(f))
print(opts.train)
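
When the notebook is converted to a script, a missing or malformed environment variable only surfaces as a bare `KeyError` or as an indexing failure later on. A minimal sketch of a stricter lookup follows. The helper `read_run_config` and the `RunConfig` container are hypothetical and not part of the `stability` package; only the four variable names come from the cell above.

```python
import os
from dataclasses import dataclass
from pathlib import Path


@dataclass
class RunConfig:
    """Hypothetical container for the four settings read from the environment."""
    train_yaml: str
    bootstrap: int
    data_dir: Path
    root_dir: Path


def read_run_config(env=None):
    """Read the run settings from `env` (defaults to os.environ)."""
    env = os.environ if env is None else env
    required = ["TRAIN_YAML", "BOOTSTRAP", "DATA_DIR", "ROOT_DIR"]
    missing = [name for name in required if name not in env]
    if missing:
        # Report every missing variable at once instead of failing one by one
        raise KeyError(f"Missing environment variables: {', '.join(missing)}")
    return RunConfig(
        train_yaml=env["TRAIN_YAML"],
        bootstrap=int(env["BOOTSTRAP"]),  # values always arrive as strings
        data_dir=Path(env["DATA_DIR"]),
        root_dir=Path(env["ROOT_DIR"]),
    )
```

Passing a plain dictionary instead of `os.environ` makes the lookup easy to exercise without touching the real environment.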

Let's define the model and the loss function. This is not especially elegant: it is essentially a long switch statement.

In [None]:
if opts.train.model == "cnn":
    model = CBRNet(p_in=opts.train.p_in)
    loss_fn = cnn_loss
elif opts.train.model == "vae":
    model = VAE(z_dim=opts.train.z_dim)
    loss_fn = vae_loss
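elif opts.train.model == "rcf":
    # NOTE: no model or loss is constructed for this branch, although the
    # training cell below passes `model` to rcf.train_rcf.
    pass
else:
    raise NotImplementedError(f"Unknown model type: {opts.train.model}")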
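
One way to avoid the long switch statement is a lookup table from model name to a constructor and a loss. The sketch below is only an illustration: `TinyCNN`, `TinyVAE`, the two loss stubs, `MODEL_REGISTRY`, and `build_model` are hypothetical stand-ins so that the block runs on its own. In this notebook the table entries would be `CBRNet` with `cnn_loss` and `VAE` with `vae_loss`.

```python
class TinyCNN:
    """Hypothetical stand-in for CBRNet."""
    def __init__(self, p_in):
        self.p_in = p_in


class TinyVAE:
    """Hypothetical stand-in for VAE."""
    def __init__(self, z_dim):
        self.z_dim = z_dim


def tiny_cnn_loss(y_hat, y):
    # Placeholder loss, not the real cnn_loss
    return (y_hat - y) ** 2


def tiny_vae_loss(y_hat, y):
    # Placeholder loss, not the real vae_loss
    return abs(y_hat - y)


# model name -> (constructor, name of the option passed to it, loss function)
MODEL_REGISTRY = {
    "cnn": (TinyCNN, "p_in", tiny_cnn_loss),
    "vae": (TinyVAE, "z_dim", tiny_vae_loss),
}


def build_model(train_opts):
    """Look up the requested model and return (model, loss_fn)."""
    name = train_opts["model"]
    if name not in MODEL_REGISTRY:
        raise NotImplementedError(f"Unknown model type: {name!r}")
    constructor, arg_name, loss_fn = MODEL_REGISTRY[name]
    model = constructor(**{arg_name: train_opts[arg_name]})
    return model, loss_fn
```

With this layout, supporting a new model type means adding one entry to the table rather than another `elif` branch.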

Next, we'll create directories for saving all the features. We'll also read in all paths for training / development / testing. This is a bit more involved than the usual training process, since we'll want loaders specifically for looking at changes in feature activations.

In [None]:
features_dir = data_dir / opts.organization.features_dir
os.makedirs(features_dir, exist_ok=True)

splits = pd.read_csv(data_dir / opts.organization.splits)
resample_ix = pd.read_csv(data_dir / opts.bootstrap.path)

paths = {
    "train": splits.loc[splits.split == "train", "path"].values[resample_ix.loc[bootstrap]],
    "dev": splits.loc[splits.split == "dev", "path"].values,
    "test": splits.loc[splits.split == "test", "path"].values,
    "all": splits["path"].values
}

np.random.seed(0)
save_ix = np.random.choice(len(splits), opts.train.save_subset, replace=False)
loaders = {
    "train_fixed": initialize_loader(paths["train"], data_dir, opts, num_samples=20),
    "train": initialize_loader(paths["train"], data_dir, opts, shuffle=True, num_samples=20),
    "dev": initialize_loader(paths["dev"], data_dir, opts, num_samples=20),
    "test": initialize_loader(paths["test"], data_dir, opts, num_samples=20),
    "features": initialize_loader(paths["all"][save_ix], data_dir, opts, num_samples=20)
}
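
The indexing behind `paths["train"]` and `save_ix` can be hard to read at a glance. The sketch below reproduces the same pattern on toy data with plain NumPy arrays (the notebook holds these in pandas objects). It assumes, without confirmation from the source, that each row of the bootstrap table holds one resample stored as integer positions into the training paths. All file names and values are made up for illustration.

```python
import numpy as np

# Toy stand-ins for the "path" and "split" columns of the splits table
paths = np.array(["a.npy", "b.npy", "c.npy", "d.npy", "e.npy"])
split = np.array(["train", "train", "train", "dev", "test"])

# Assumed layout: one row per bootstrap replicate, one column per draw
resample_ix = np.array([[0, 0, 2],
                        [1, 2, 2]])

bootstrap = 1
train_paths = paths[split == "train"]
# Sampling with replacement: the same path may appear more than once
boot_paths = train_paths[resample_ix[bootstrap]]

# A fixed seed makes the saved feature subset identical across runs
rng = np.random.RandomState(0)
save_ix = rng.choice(len(paths), 3, replace=False)
```

Here `boot_paths` repeats `"c.npy"` because the second replicate draws position 2 twice, while `save_ix` never repeats an index because it is drawn without replacement.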

Next, let's prepare a logger to save the training progress. We also save the indices of the samples for which we'll write activations -- it would be too much (and not really necessary) to write activations for all the samples.

In [None]:
subset_path = features_dir / "subset.csv"
splits.iloc[save_ix, :].to_csv(subset_path)
writer = SummaryWriter(features_dir / "logs")
writer.add_text("conf", json.dumps(opts))
out_paths = [
    data_dir / opts.organization.features_dir, # where features are saved
    data_dir / opts.organization.metadata, # metadata for features (e.g., layer name)
    data_dir / opts.organization.model # where model gets saved
]

Finally, we can train our model. Training for the random convolutional features model is just ridge regression -- there are no iterations necessary. For the CNN and VAE, all the real logic is hidden away in the `st.train` function.

In [None]:
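if opts.train.model == "rcf":
    # NOTE: `rcf` is never imported in this notebook and `model` is not built
    # for this branch; both must be defined before this cell can run.
    ridge_model, D, y_hat = rcf.train_rcf(model, loaders["train"])
    metadata, errors = [], {}
    for split in ["dev", "test"]:
        D, y_hat, y = rcf.predict_rcf(model, ridge_model, loaders[split])
        errors[split] = np.mean((y - y_hat) ** 2)  # mean squared error on this split
        # np.save writes NumPy's binary .npy format, so use a matching extension
        np.save(out_paths[0] / f"{split}_features.npy", D)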

else:
    optim = torch.optim.SGD(model.parameters(), lr=opts.train.lr, momentum=opts.train.momentum, nesterov=True)
    st.train(model, optim, loaders, opts, out_paths, writer, loss_fn)
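
To see why the random convolutional features model needs no iterations: once the feature matrix `D` is computed, ridge regression has a closed-form solution. The sketch below is a generic NumPy version, not the `rcf.train_rcf` implementation (which is not shown in this notebook). The function `fit_ridge`, the penalty `lam`, and the toy data are assumptions for illustration.

```python
import numpy as np


def fit_ridge(D, y, lam=1.0):
    """Solve (D^T D + lam * I) w = D^T y for the ridge weights w."""
    n_features = D.shape[1]
    gram = D.T @ D + lam * np.eye(n_features)
    # Solving the linear system avoids forming an explicit matrix inverse
    return np.linalg.solve(gram, D.T @ y)


# Toy data: random features and a noiseless linear response
rng = np.random.RandomState(0)
D = rng.normal(size=(200, 5))
w_true = np.array([1.0, -2.0, 0.5, 0.0, 3.0])
y = D @ w_true

w_hat = fit_ridge(D, y, lam=1e-8)
mse = np.mean((y - D @ w_hat) ** 2)  # same error measure as in the cell above
```

A single linear solve replaces the whole optimization loop, which is why this branch needs neither an optimizer nor `st.train`.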