# 2022 Flatiron Machine Learning x Science Summer School

## Step 18: Compare training DSN on `F07` with fixed vs. trainable symbolic discriminator

First, we train a Disentangled Sparsity Network (DSN) on a problem with a larger function library, namely `F07`, using a symbolic discriminator (SD) for regularization. The SD is pre-trained with the BCE loss to classify function library data vs. untrained MLP activations.

Next, we train the DSN and the SD in parallel in an adversarial setting.

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import ipywidgets as ipw
import joblib

import torch
import wandb

from srnet import SRNet, SRData
from sdnet import SDNet, SDData
import srnet_utils as ut

### Step 18.1: Train symbolic discriminator with BCE loss

In [2]:
ut.plot_disc_accuracies("disc_model_F07_v2_fixed_BCE", "models", excl_names=[], avg_hor=50, uncertainty=True);

The accuracy is suspiciously high. Did the training work correctly?

In [3]:
# load data
data_path = "data_1k"
in_var = "X07"
lat_var = "G07"
target_var = "F07"

mask_ext = ".mask"
try:
    masks = joblib.load(os.path.join(data_path, in_var + mask_ext))
    mask = masks['train']
except Exception:  # e.g. missing mask file or missing 'train' key
    mask = None
    print("Warning: No mask for training loaded.")

train_data = SRData(data_path, in_var, lat_var, target_var, data_mask=mask)

In [4]:
# load function library
fun_path = "funs/F07_v2.lib"
shuffle = False
iter_sample = False
disc_data = SDData(fun_path, in_var, shuffle=shuffle, iter_sample=iter_sample)

In [5]:
# load trained critic
critic = ut.load_disc("disc_model_F07_v2_fixed_BCE.pkl", save_path="models", disc_cls=SDNet)

In [None]:
# create fake data
hp = {
    "arch": {
        "in_size": 2,
        "out_size": 1,
        "hid_num": (2,0),
        "hid_size": 32, 
        "hid_type": ("MLP", "MLP"),
        "hid_kwargs": {
            "alpha": None,
            "norm": None,
            "prune": None,
            },
        "lat_size": 3,
    },
}

In [None]:
model = SRNet(**hp['arch'])
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

In [None]:
critic(acts.T)

In [None]:
# create real data
data_real = disc_data.get(disc_data.len, train_data.in_data)

In [None]:
critic(data_real[0,:,:,0])

### Step 18.2: Train DSN with fixed SD

Let's use the trained SD to regularize the DSN when training on `F07` data.

We also want to track the critic predictions.

Currently, we track:

* Training loss per epoch and averaged over batches

* Validation loss every `log_freq` epochs

* `min_corr` on validation data every `log_freq` epochs (a more frequent computation might be expensive)

Let's also track the critic prediction per epoch averaged over latent features and batches.
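
For reference, here is a minimal sketch of how a `min_corr`-style metric could be computed. It assumes that `min_corr` is the smallest absolute Pearson correlation between each learned latent feature and its ground-truth counterpart; the actual definition lives in the training code and may differ. The helper names `pearson` and `min_abs_corr` are hypothetical.

```python
import math

def pearson(x, y):
    """Pearson correlation of two equally long sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)

def min_abs_corr(true_lat, learned_lat):
    """Smallest |correlation| over latent nodes (assumed definition)."""
    return min(abs(pearson(t, l)) for t, l in zip(true_lat, learned_lat))

# toy data: node 0 is recovered up to sign and scale, node 1 only partly
true_lat = [[0.0, 1.0, 2.0, 3.0], [1.0, 0.0, 1.0, 0.0]]
learned_lat = [[0.0, -2.0, -4.0, -6.0], [1.0, 0.0, 0.0, 0.0]]
score = min_abs_corr(true_lat, learned_lat)
```

Under this assumption, a single poorly recovered latent node is enough to pull the metric down, which is why it is a strict indicator of disentanglement.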

In [6]:
models = ut.plot_losses("srnet_model_F07_v2_critic_check_lr_1e-4_sd_1e-4", "models", excl_names=[])

In [7]:
models = ut.plot_corrs("srnet_model_F07_v2_critic_check_lr_1e-4_sd_1e-4", "models", excl_names=[])

In [8]:
models = ut.plot_disc_preds("srnet_model_F07_v2_critic_check_lr_1e-4_sd_1e-4", "models", excl_names=[])

Notes:

* Up to 5000 epochs, the training with linear and sigmoid predictions is basically equivalent

* At this point, the average critic prediction is already around 1e2, so the sigmoid of the negative prediction is effectively 0

* This only makes sense if the critic regularization is not contributing to the total training loss
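
The saturation argument can be checked numerically. The sketch below uses only the standard library; `p = 100.0` is a stand-in for an average critic prediction of order 1e2, and the `sigmoid` helper is written here for illustration rather than taken from the training code.

```python
import math

def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

p = 100.0                  # stand-in for an average critic prediction of order 1e2
reg_loss = sigmoid(-p)     # `sig` regularization term
# derivative of sigmoid(-p) with respect to p is -s * (1 - s)
reg_grad = -reg_loss * (1.0 - reg_loss)

print(f"loss = {reg_loss:.3e}, gradient = {reg_grad:.3e}")
```

Both the loss and its gradient are far below any realistic prediction loss, so with sigmoid predictions the critic term cannot influence the DSN weights at this stage.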

Let's track the critic regularization loss.

In [None]:
models = ut.plot_losses("srnet_model_F07_v2_critic_check_lr_1e-4_sd_1e-4", "models", excl_names=[], disc_loss=True)

In [None]:
models = ut.plot_reg_percentage("srnet_model_F07_v2_critic_check_lr_1e-4_sd_1e-4", "models", excl_names=[])

Notes:

* Using sigmoid predictions (`sig`), the SD is not regularizing the DSN at all

* Using linear predictions (`lin`), the prediction and regularization losses have roughly the same magnitude at around 11,000 epochs. Afterwards, `min_corr` decreases rapidly while `disc_preds` increases rapidly

How much do we need to increase `sd`, so that the sigmoid predictions of the critic are relevant?

In [None]:
models = ut.plot_losses("srnet_model_F07_v2_critic_check_lr_1e-4", "models", excl_names=["lin"], log=True, disc_loss=True)

In [None]:
models = ut.plot_reg_percentage("srnet_model_F07_v2_critic_check_lr_1e-4", "models", excl_names=["lin"], log=True)

In [None]:
models = ut.plot_disc_preds("srnet_model_F07_v2_critic_check_lr_1e-4", "models", excl_names=["lin"])

In [None]:
models = ut.plot_corrs("srnet_model_F07_v2_critic_check_lr_1e-4", "models", excl_names=["lin"])

Notes:

* Even for `sd = 1e2`, the regularization loss is only really relevant for the first ten epochs

* Afterwards, it's 3 to 4 orders of magnitude smaller, until it vanishes entirely somewhere between 1000 and 5000 epochs

* Interestingly, these small regularization losses seem sufficient to yield different, i.e. worse, `min_corr` values for high `sd` values

Anyway, we will need to train the DSN and the SD in parallel.

### Step 18.3: Train DSN and SD in parallel

Let's first explore training the DSN and SD in parallel by comparing different approaches to utilize the critic predictions for regularization.

Note that the DSN wants to maximize critic predictions, i.e. to learn latent features that are similar to data from the function library.

1. `linear`: Simply take the negative value of the critic predictions, so minimizing the total loss is maximizing critic predictions

2. `sigmoid`: Take the sigmoid of the negative critic predictions, so positive predictions yield regularization losses between 0 and 0.5 and negative predictions yield losses between 0.5 and 1

3. `logsigmoid`: Take the negative of the logarithm of the sigmoid of the critic predictions. Positive predictions yield sigmoid values between 0.5 and 1, whose logarithm is a small negative number, so the loss is small. Negative predictions yield sigmoid values between 0 and 0.5, whose logarithm is a large negative number (nearly linear in the original critic prediction), so the loss is large
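
The three options above can be sketched as plain functions of a single critic prediction `p`. The function names are hypothetical stand-ins and not the implementation behind the `sd_fun` setting; they only reproduce the formulas described in the list.

```python
import math

def reg_linear(p):
    """Negative critic prediction."""
    return -p

def reg_sigmoid(p):
    """Sigmoid of the negative critic prediction, always in (0, 1)."""
    if p >= 0:
        z = math.exp(-p)
        return z / (1.0 + z)
    return 1.0 / (1.0 + math.exp(p))

def reg_logsigmoid(p):
    """-log(sigmoid(p)), i.e. softplus(-p), computed in a stable way."""
    return max(-p, 0.0) + math.log1p(math.exp(-abs(p)))

# compare the three losses for a negative, neutral and positive prediction
for p in (-10.0, 0.0, 10.0):
    print(p, reg_linear(p), reg_sigmoid(p), reg_logsigmoid(p))
```

Only `linear` keeps rewarding ever larger critic predictions; `sigmoid` flattens out on both sides, and `logsigmoid` flattens out only for large positive predictions.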

The generator in standard GAN training minimizes `logsigmoid` (https://youtu.be/OljTVUVzPpM):

* `min log(1 - sig(D(fake)))` is replaced by `max log(sig(D(fake)))` due to saturating gradients, where:

    * `min log(1 - sig(D(fake)))` is equivalent to `max D(fake) + log(1 + exp(-D(fake)))`

    * `max log(sig(D(fake)))` is equivalent to `min log(1 + exp(-D(fake)))`
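
For completeness, here is a short derivation of the two identities and their gradients, writing $d = D(\text{fake})$ for the raw critic output and $\sigma$ for the sigmoid (notation introduced here for brevity):

```latex
\begin{aligned}
\log\bigl(1 - \sigma(d)\bigr) &= \log \sigma(-d) = -\log\bigl(1 + e^{d}\bigr) = -\bigl(d + \log(1 + e^{-d})\bigr), \\
\log \sigma(d) &= -\log\bigl(1 + e^{-d}\bigr), \\
\frac{\partial}{\partial d} \log\bigl(1 - \sigma(d)\bigr) &= -\sigma(d), \qquad
\frac{\partial}{\partial d} \log \sigma(d) = 1 - \sigma(d).
\end{aligned}
```

When the critic confidently rejects the fake data ($d \ll 0$), the first gradient vanishes while the second approaches 1, which is the saturation problem that motivates the switch.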

Also standard GAN training:

* Uses `LeakyReLU` for generator and discriminator

* Uses the same learning rate for generator and discriminator (try magic `3e-4`)

* Uses `BCELoss` for the discriminator (as we do now): `max log(sig(D(real))) + log(1 - sig(D(fake)))`

* Runs only one iteration of discriminator training per epoch
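
The recipe in this list can be sketched as a toy alternating training loop. Everything below is a hypothetical stand-in: `gen` and `critic` are small generic networks rather than `SRNet` and `SDNet`, the data is random, and `BCEWithLogitsLoss` is used in place of a sigmoid followed by `BCELoss` because it computes the same quantity more stably.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# hypothetical stand-ins for generator and critic, both with LeakyReLU
gen = nn.Sequential(nn.Linear(2, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
critic = nn.Sequential(nn.Linear(1, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))

# same learning rate for both networks
opt_g = torch.optim.Adam(gen.parameters(), lr=3e-4)
opt_d = torch.optim.Adam(critic.parameters(), lr=3e-4)
bce = nn.BCEWithLogitsLoss()  # applies the sigmoid internally

x = torch.rand(64, 2)                       # toy generator inputs
real = torch.sin(3.0 * torch.rand(64, 1))   # toy "real" samples

for epoch in range(5):
    # one critic step: max log(sig(D(real))) + log(1 - sig(D(fake)))
    fake = gen(x).detach()
    loss_d = bce(critic(real), torch.ones(64, 1)) + bce(critic(fake), torch.zeros(64, 1))
    opt_d.zero_grad()
    loss_d.backward()
    opt_d.step()

    # one generator step with the non-saturating loss: min -log(sig(D(fake)))
    loss_g = -F.logsigmoid(critic(gen(x))).mean()
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()
```

In the DSN setting, the generator loss would be added to the prediction loss with a weight such as `sd` rather than used on its own.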

**Q**: One difference of our training approach is that the fake data is basically fixed, i.e. it changes only slightly per epoch and cannot be sampled. Is that a problem? Could we do something about it, e.g. dropout?

In [9]:
models = ut.plot_losses("srnet_model_F07_v2_comb_check", "models", excl_names=["arch"], disc_loss=False, log=True)

In [10]:
models = ut.plot_reg_percentage("srnet_model_F07_v2_comb_check", "models", excl_names=["arch"])

In [11]:
models = ut.plot_corrs("srnet_model_F07_v2_comb_check", "models", excl_names=["arch"])

In [None]:
models = ut.plot_disc_preds("srnet_model_F07_v2_comb_check", "models", excl_names=["arch"])

* Using sigmoid predictions (`sig`):

    * The regularization loss dominates the total loss after 12000 epochs.
    
    * However, the prediction loss keeps decreasing, while the regularization loss stays constant at 1 and `min_corr` decreases.
    
    * The average SD predictions are very low. Does the critic not learn anything?

* Using linear predictions (`lin`):

    * The results are very promising.
    
    * While the training is quite unstable (is this normal for GAN training?), `min_corr` reaches near 100% correlation at an acceptable validation loss.
    
    * Compared to the `sig` results, the average SD predictions are moderate and the regularization loss never completely dominates the total loss.

* Using `logsigmoid` predictions (`log`):

    * Similar to `lin` for large negative critic predictions, but converges to 0 instead of negative values for large positive critic predictions. 
    
    * While the average SD predictions and the regularization to prediction loss ratio are roughly similar and slightly lower validation errors are achieved, a similar level of `min_corr` is not reached.

Note that all of these observations are specific to the one set of selected hyperparameters.

Let's inspect the recorded latent activations.

In [12]:
# load data
data_path = "data_1k"
in_var = "X07"
lat_var = "G07"
target_var = "F07"

mask_ext = ".mask"
try:
    masks = joblib.load(os.path.join(data_path, in_var + mask_ext))
    mask = masks['val']
except Exception:  # e.g. missing mask file or missing 'val' key
    mask = None
    print("Warning: No mask for validation loaded.")

val_data = SRData(data_path, in_var, lat_var, target_var, data_mask=mask)

In [13]:
state = joblib.load("models/srnet_model_F07_v2_comb_check_lr_1e-4_sd_1e0_lin.pkl")
rec_acts = joblib.load("models/srnet_model_F07_v2_comb_check_lr_1e-4_sd_1e0_lin.rec")

epochs = len(state['train_loss'])
logs = len(rec_acts)

add_log = epochs % (logs - 1) == 0
log_freq = int(epochs / (logs-add_log))

log_epochs = (np.arange(logs) * log_freq).tolist()
log_epochs[-1] = epochs - 1
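
To make the reconstruction of the logged epochs easier to follow, the same arithmetic is wrapped in a function below and applied to illustrative numbers. The sketch assumes that activations are recorded every `log_freq` epochs plus once at the final epoch; the function name and the example values are hypothetical.

```python
def reconstruct_log_epochs(epochs, logs):
    """Epoch index of each recorded snapshot, mirroring the cell above."""
    add_log = epochs % (logs - 1) == 0    # True if one snapshot is the extra final one
    log_freq = int(epochs / (logs - add_log))
    log_epochs = [i * log_freq for i in range(logs)]
    log_epochs[-1] = epochs - 1           # last snapshot belongs to the final epoch
    return log_freq, log_epochs

# illustrative run: 50,000 epochs with 2,001 recorded snapshots
log_freq, log_epochs = reconstruct_log_epochs(epochs=50000, logs=2001)
print(log_freq, log_epochs[:4], log_epochs[-1])
```

Note that the heuristic relies on the number of epochs being a multiple of the logging frequency, so it is worth checking the result against the known `log_freq` of a run.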

In [14]:
node_inputs = {
    0: [0],
    1: [1],
    2: [0, 1],
}

fig_width = 9
view = (6, -92)

w_epoch = ipw.SelectionSlider(
    options=log_epochs,
    value=0,
    description="Epoch",
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True
)

fig_num = len(node_inputs)
fig_ar = plt.rcParams['figure.figsize'][0] / plt.rcParams['figure.figsize'][1]
fig = plt.figure(figsize=(fig_width, fig_width/fig_ar/fig_num))

for n in node_inputs:
    if len(node_inputs[n]) == 1:
        ax = fig.add_subplot(1, fig_num, n+1)
    else:
        ax = fig.add_subplot(1, fig_num, n+1, projection='3d')
        ax.view_init(elev=view[0], azim=view[1])

def update_plot(epoch):
    
    e = log_epochs.index(epoch)
       
    for n in node_inputs:
        
        ax = fig.axes[n]
        ax.clear()
        
        if len(node_inputs[n]) == 1:
            
            i = node_inputs[n][0]
            
            ax.scatter(val_data.in_data[:,i], val_data.lat_data[:,n])
            ax.scatter(val_data.in_data[:,i], rec_acts[e][:,n])
            
        else:
                                    
            i = node_inputs[n][0]
            j = node_inputs[n][1]
                    
            ax.scatter3D(val_data.in_data[:,i], val_data.in_data[:,j], val_data.lat_data[:,n])
            ax.scatter3D(val_data.in_data[:,i], val_data.in_data[:,j], rec_acts[e][:,n])

ipw.interact(update_plot, epoch=w_epoch);

In [15]:
models = ut.plot_corrs("srnet_model_F07_v2_comb_check_lr_1e-4_sd_1e0_lin.pkl", "models", excl_names=[])

Notes:

* Between epoch 20000 and 40000, the quadratic function (node 0) and the $x_0 \cdot x_1$ function (node 2) are approximated very accurately and the training effort seems to go into modeling the cosine function (node 1)

* Is the capacity of the DSN too low to accurately capture the cosine function?

Let's enlarge the DSN architecture from two hidden layers with 32 nodes (`arch_S`) to three hidden layers with 64 nodes (`arch_M`).

In [None]:
models = ut.plot_losses("srnet_model_F07_v2_comb_check_lr_1e-4_sd_1e0_lin", "models", excl_names=[], disc_loss=True)

In [None]:
models = ut.plot_reg_percentage("srnet_model_F07_v2_comb_check_lr_1e-4_sd_1e0_lin", "models", excl_names=[])

In [16]:
models = ut.plot_corrs("srnet_model_F07_v2_comb_check_lr_1e-4_sd_1e0_lin", "models", excl_names=[])

In [None]:
models = ut.plot_disc_preds("srnet_model_F07_v2_comb_check_lr_1e-4_sd_1e0_lin", "models", excl_names=[])

Notes:

* These results are great: `min_corr` values of nearly 100% are achieved!

* However, there is instability in the training process. Can this be improved by adapting the hyperparameters?

* Also, can similar results be achieved with `sig` or `log` predictions?

Let's run a hyperparameter study.

### Step 18.4: Run hyperparameter study

In [None]:
# set wandb project
wandb_project = "184-DSN-SD-comb-study-F07_v2"

In [None]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2, 0),
#         "hid_size": 32,
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": {
#             "alpha": [[1,0],[0,1],[1,1]],
#             "norm": None,
#             "prune": None,
#             },
#         "lat_size": 3,
#         },
#     "epochs": 50000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "shuffle": False,
#     "lr": 1e-4,
#     "wd": 1e-7,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "e3": 0.0,
#     "gc": 0.0,
#     "sd": 1.0,
#     "sd_fun": "linear",
#     "ext": None,
#     "ext_type": None,
#     "ext_size": 0,
#     "disc": {
#         "hid_num": 2,
#         "hid_size": 64,
#         "lr": 1e-4,
#         "wd": 1e-7,
#         "betas": (0.9,0.999),
#         "iters": 5,
#         "gp": 0.0,
#         "loss_fun": "BCE",
#     },
# }

In [None]:
# define hyperparameter study
hp_study = {
    "method": "random",
    "parameters": {
        "arch": {
            "parameters": {
                "in_size": {
                    "values": [2]
                },
                "out_size": {
                    "values": [1]
                },
                "hid_num": {
                    "values": [(2,0), (3,0), (4,0)]
                },
                "hid_size": {
                    "values": [32, 64, 128]
                },
                "hid_type": {
                    "values": [("DSN", "MLP")]
                },
                "hid_kwargs": {
                    "values": [{"alpha": [[1,0],[0,1],[1,1]]}]
                },
                "lat_size": {
                    "values": [3]
                },
            }
        },
        "lr": {
            "values": [1e-5, 1e-4, 1e-3]
        },
        "sd": {
            "values": [1e-4, 1e-2, 1e0, 1e2]
        },
        "sd_fun": {
            "values": ["linear", "sigmoid", "logsigmoid"]
        },
        "disc": {
            "parameters": {
                "hid_num": {
                    "values": [2]
                },
                "hid_size": {
                    "values": [64]
                },
                "lr": {
                    "values": [1e-5, 1e-4, 1e-3]
                },
                "iters": {
                    "values": [1, 5, 10]
                },
                "loss_fun": {
                    "values": ["BCE"]
                },
            }
        },
    }
}
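
To get a feeling for the size of this search space, the number of distinct configurations can be counted from the value lists above. The dictionary below is written by hand from the sweep definition and only includes parameters with more than one value.

```python
import math

# number of candidate values per varied parameter, copied from hp_study
n_values = {
    "arch.hid_num": 3,
    "arch.hid_size": 3,
    "lr": 3,
    "sd": 4,
    "sd_fun": 3,
    "disc.lr": 3,
    "disc.iters": 3,
}

n_combinations = math.prod(n_values.values())
print(n_combinations)  # 2916
```

Since the sweep uses `"method": "random"`, each run draws one configuration at random, so the study samples only a subset of these combinations.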

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

In [22]:
df = pd.read_csv("results/184-DSN-SD-comb-study-F07_v2.csv")

model_base_name = "srnet_model_F07_v2_DSN_SD_comb_study"
df["Name"] = df["Name"].str.replace(model_base_name + "_", "", regex=False)  # literal match, not a regex

In [None]:
df.style.format({'sd': '{:.0e}'.format,'lr': '{:.0e}'.format,'disc.lr': '{:.0e}'.format})

Top results:

* `sigmoid`: 

    * Steady convergence to a `min_corr` value of 0.85

* `linear`: 

    * Steady convergence to a `min_corr` value of 0.99
    
    * 17/34 runs achieve `min_corr > 0.95` during training:
    
        * Stable (3x): `sd = 1e-4`, `lr <= disc.lr`, `disc.lr = 1e-3`
        
        * Late (3x): `sd = 1e-2`, `lr = 1e-5`
        
        * Unstable: `sd = 1e-2` or `sd = 1e0`

* `logsigmoid`: 

    * Unstable or late convergence to a `min_corr` value of 0.99
    
    * Only 5/34 runs achieve `min_corr > 0.95` during training:
    
        * Late (2x): `sd = 1e-2`, `lr <= disc.lr`, `disc.lr = 1e-3`
        
        * Semi-stable (1x): `sd = 1e-2`, `lr = 1e-5`
        
        * Unstable (2x): `sd = 1e0`, `lr > disc.lr`
        
Notes:

* `hid_num >= 3` required

* `hid_size` and `iters` inconclusive

Let's plot the best model!

In [17]:
model_id = "76k8u8dh"

In [18]:
models = ut.plot_losses(model_id, save_path="models", excl_names=[], disc_loss=True)

In [None]:
models = ut.plot_reg_percentage(model_id, "models", excl_names=[])

In [19]:
models = ut.plot_corrs(model_id, "models", excl_names=[])

In [None]:
models = ut.plot_disc_preds(model_id, "models", excl_names=[])

In [23]:
model_name = '_'.join([model_base_name, model_id])
model_path = "models"
model_ext = ".pkl"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

In [24]:
fig, ax = plt.subplots()

n = 0
bias = True

ax.scatter(train_data.in_data[:,n], train_data.lat_data[:,n])
ax.scatter(train_data.in_data[:,n], model.layers2[0].weight[0,n].item()*acts[:,n] + bias * model.layers2[0].bias.item())

plt.show()

In [25]:
fig, ax = plt.subplots()

n = 1
bias = True

ax.scatter(train_data.in_data[:,n], train_data.lat_data[:,n])
ax.scatter(train_data.in_data[:,n], model.layers2[0].weight[0,n].item()*acts[:,n] + bias * model.layers2[0].bias.item())

plt.show()

In [26]:
n = 2
z_data = [("x*y", train_data.lat_data[:,n])]
plot_size = train_data.target_data.shape[0]

ut.plot_acts(train_data.in_data[:,0], train_data.in_data[:,1], z_data, acts=acts, nodes=[n], model=model, bias=False, nonzero=False, agg=False, plot_size=plot_size)

These results are great!

What are reasonable next steps?

Let's check the current assumptions and limitations:

* Latent size is known (bottleneck DSN)

* Input dependency is known (entropy regularization)

* Definition of function library (functions, coefficients)

* Linear function $f(x)$

* No noise

I do think that we should keep it a three-step framework for now.

More interesting next steps could be:

* Removing coefficients in the function library

* Extending function library

* Building hierarchical models

Technical ideas:

* Include gradients

* Add linear transformations to input and output of cell

* Gradient clipping