# 2022 Flatiron Machine Learning x Science Summer School

## Step 15: Study symbolic discriminator without generator training

Steps:

1. Create a generator that outputs non-symbolic activations:

    * GP activations
    
    * Untrained MLP activations
    
2. Define simple function library

3. Monitor classification accuracy

4. Increase function library until classification fails

5. Provide additional information to SD

6. Increase function library further

7. Utilize trained SD to regularize DSN

### Step 15.1: Create a generator that outputs non-symbolic activations

The first idea is utilizing the activations of an untrained DSN that has the architecture of interest.

What is the input data and its dimensionality?

* We want to add more library functions over time and some of these should ideally also depend on two or more input features

* However, we probably also want to start as simple as possible, so one-dimensional input would be best

* However, we could also simply mask all unwanted input features via the input mask $\alpha$

Let's create two-dimensional input data `X10`

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
from sdnet import SDNet, SDData
import srnet_utils as ut

In [2]:
# dataset location and the variables to read from it
data_path = "data_1k"

in_var = "X10"
lat_var = None
target_var = None

# the split masks live next to the data in "<in_var><mask_ext>"
mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))

# one SRData object per split, built with identical arguments except for the mask
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

What do the activations of an untrained DSN look like?

In [3]:
torch.manual_seed(0);

In [4]:
x_data = train_data.in_data[:,:1]

In [5]:
n_sample = 10

In [6]:
hp = {
    "arch": {
        "in_size": 1,
        "out_size": 1,
        "hid_num": (2,0),
        "hid_size": 32, 
        "hid_type": ("DSN", "MLP"),
        "hid_kwargs": {
            "alpha": None,
            "norm": None,
            "prune": None,
            },
        "lat_size": 1,
    },
}

In [7]:
# Overlay the activations of `n_sample` freshly initialized (never trained)
# networks built from `hp['arch']` in a single scatter plot.
fig, ax = plt.subplots()

for _ in range(n_sample):
    
    # a new model per iteration, so each scatter shows a different random init
    model = SRNet(**hp['arch'])

    # forward pass only; with get_lat=True the model returns a second value
    # next to the predictions, used here as the latent activations
    with torch.no_grad():
        preds, acts = model(x_data, get_lat=True)
        
    # first input column vs. first column of the returned activations
    ax.scatter(x_data[:,0], acts[:,0])

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [8]:
hp = {
    "arch": {
        "in_size": 1,
        "out_size": 1,
        "hid_num": (2,0),
        "hid_size": 32, 
        "hid_type": ("MLP", "MLP"),
        "hid_kwargs": {
            "alpha": None,
            "norm": None,
            "prune": None,
            },
        "lat_size": 1,
    },
}

In [9]:
# Same plot as for the previous architecture, now with the `hp['arch']`
# redefined in the preceding cell: activations of `n_sample` untrained networks.
fig, ax = plt.subplots()

for _ in range(n_sample):
    
    # a new model per iteration, so each scatter shows a different random init
    model = SRNet(**hp['arch'])

    # forward pass only; the second return value is used as latent activations
    with torch.no_grad():
        preds, acts = model(x_data, get_lat=True)
        
    # first input column vs. first column of the returned activations
    ax.scatter(x_data[:,0], acts[:,0])

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Why does the untrained DSN yield linear latent feature activations?

The weight initialization of the MLP and the DSN differs (uniform with bias vs. normal without bias, respectively).

Due to having no bias, the ReLU kink of the DSN is always at $x=0$ and all positive and negative weights are simply added up to yield the positive and negative slope, respectively.

Let's use an untrained MLP.

### Step 15.2: Define simple function library

The activations above look most similar to a quadratic function, so this could be the first library function:

In [10]:
fun_path = "funs/F10_v1.lib"
in_var = "X10"
shuffle = False

In [11]:
disc_data = SDData(fun_path, in_var, shuffle=shuffle)

In [12]:
disc_data.funs

[['N0*0.05*(X10[:,0] + 0.5*N1)**2 + 0.15*N2', '2*N0*0.05*(X10[:,0] + 0.5*N1)']]

In [13]:
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

In [14]:
# Compare samples drawn from the function library (first color of the cycle)
# with activations of untrained networks (second color) on the same inputs.
fig, ax = plt.subplots()

for _ in range(n_sample):
    # disc_data.get is called anew in every iteration
    # NOTE(review): the [0,0,:,0] index presumably selects the first dataset, first
    # function and the function value (not its derivative) -- confirm against SDData.get
    ax.scatter(x_data[:,0], disc_data.get(in_data=x_data)[0,0,:,0], color=colors[0], alpha=0.5)
    
    # freshly initialized network; forward pass without gradient tracking
    model = SRNet(**hp['arch'])
    with torch.no_grad():
        preds, acts = model(x_data, get_lat=True)
    
    ax.scatter(x_data[:,0], acts[:,0], color=colors[1], alpha=0.5)
    
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

That looks good, let's train this.

### Step 15.3: Monitor classification accuracy

In [None]:
# Pick the progress bar that fits the frontend: the widget-based `trange`
# when running under a ZMQInteractiveShell (the shell class of Jupyter
# kernels), the plain console `trange` in every other case.
try:
    if get_ipython().__class__.__name__ == "ZMQInteractiveShell":
        from tqdm.notebook import trange
    else:
        raise RuntimeWarning()
except (NameError, ImportError, RuntimeWarning):
    # NameError: `get_ipython` does not exist outside of IPython
    # ImportError: the notebook flavour of tqdm cannot be imported
    # RuntimeWarning: raised above for any other kind of shell
    # (a bare `except:` would also swallow KeyboardInterrupt / SystemExit)
    from tqdm import trange
    
from IPython import display

In [None]:
hp = {
    "arch": {
        "in_size": 1,
        "out_size": 1,
        "hid_num": (2,0),
        "hid_size": 32, 
        "hid_type": ("MLP", "MLP"),
        "hid_kwargs": {
            "alpha": None,
            "norm": None,
            "prune": None,
            },
        "lat_size": 5,
    },
    "epochs": 1000,
    "batch_size": train_data.in_data.shape[0],
    "disc": {
        "hid_num": 6,
        "hid_size": 128,
        "emb_size": None,
        "lr": 1e-4,
        "wd": 1e-4,
        "betas": (0.9,0.999),
        "iters": 5,
        "gp": 1e-5,
        },
    }

In [None]:
avg_hor = 100
plot_freq = 100
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

In [None]:
# training loop
#
# Trains the symbolic discriminator (SD) only: every epoch a brand-new,
# untrained SRNet provides the "fake" samples and the function library
# provides the "real" ones. The per-epoch accuracies are collected in `tot_accs`.

# set seed
torch.manual_seed(0)

# data / library configuration
# (assigned before first use: SDData below needs `in_var`, which previously
# was only available as a leftover from an earlier cell)
data_path = "data_1k"
in_var = "X10"
lat_var = None
target_var = None

# create SD
critic = SDNet(hp['batch_size'], **hp['disc'])
critic.train()

# create function library
fun_path = "funs/F10_v1.lib"
shuffle = True
iter_sample = False
disc_data = SDData(fun_path, in_var, shuffle=shuffle, iter_sample=iter_sample)

# load data
mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"])
# only the first input column is used
in_data = train_data.in_data[:,:1]

tot_accs = []

# a single figure that is refreshed in place every `plot_freq` epochs
if plot_freq:
    fig, ax = plt.subplots()
    display.display(fig, display_id="fig")

t = trange(hp['epochs'], desc="Epoch")
for epoch in t:
    
    # get generator data
    # (a new model is instantiated each epoch and never optimized)
    model = SRNet(**hp['arch'])
    model.train()
    
    with torch.no_grad():
        _, lat_acts = model(in_data, get_lat=True)
        
    # transposed so that the first axis runs over the latent features
    # (assumes lat_acts is (samples, lat_size) -- TODO confirm)
    data_fake = lat_acts.detach().T
    
    # get real data
    # (with iter_sample the number of critic iterations is passed on as well)
    if disc_data.iter_sample:
        datasets_real = disc_data.get(lat_acts.shape[1], in_data, critic.iters)
    else:
        datasets_real = disc_data.get(lat_acts.shape[1], in_data)
    # NOTE(review): index 0 of the last axis is presumably the function value
    # (as opposed to its derivative) -- confirm against SDData.get
    dataset_real = datasets_real[...,0]
                
    # real samples in the first cycle color, fake samples in the second
    if plot_freq and epoch % plot_freq == 0:
        ax.clear()
        for i in range(data_fake.shape[0]):
            ax.scatter(in_data[:,0], dataset_real[0,i], color=colors[0])
            ax.scatter(in_data[:,0], data_fake[i], color=colors[1])
        display.update_display(fig, display_id="fig")
          
    accs = critic.fit(dataset_real, data_fake)
        
    tot_accs.append(np.mean(accs))
    
    # progress bar shows the latest accuracy and its mean over the last `avg_hor` epochs
    t_update = {"acc": tot_accs[-1], "avg_acc": np.mean(tot_accs[-avg_hor:])}
    t.set_postfix({k: f"{v:.2f}" for k, v in t_update.items()})

In [None]:
avg_hor = 50

In [None]:
# trailing moving average: the entry for window end `end` is the mean of the
# (at most) `avg_hor` accuracies that precede it, the current one included
avg_accs = [
    np.mean(tot_accs[max(0, end - avg_hor):end])
    for end in range(1, len(tot_accs) + 1)
]

In [None]:
fig, ax = plt.subplots()

ax.plot(avg_accs)

plt.show()

Great, this works. Let's run a hyperparameter study.

In [15]:
# set wandb project
wandb_project = "153-sd-study-F10_v1"

In [16]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": 1,
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("MLP", "MLP"),
#         "hid_kwargs": {
#             "alpha": None,
#             "norm": None,
#             "prune": None,
#             },
#         "lat_size": 5,
#     },
#     "epochs": 100000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "disc": {
#         "hid_num": 6,
#         "hid_size": 128,
#         "emb_size": None,
#         "lr": 1e-4,
#         "wd": 1e-7,
#         "betas": (0.9,0.999),
#         "iters": 5,
#         "gp": 1e-5,
#     },
# }

In [17]:
# define hyperparameter study
hp_study = {
    "method": "random",
    "parameters": {
        "arch": {
            "parameters": {
                "in_size": {
                    "values": [1]
                },
                "out_size": {
                    "values": [1]
                },
                "hid_num": {
                    "values": [(2,0)]
                },
                "hid_size": {
                    "values": [32]
                },
                "lat_size": {
                    "values": [1, 3, 5, 10]
                },
            }
        },
        "disc": {
            "parameters": {
                "hid_num": {
                    "values": [2, 4, 6, 8]
                },
                "hid_size": {
                    "values": [64, 128, 256, 512]
                },
                "lr": {
                    "values": [1e-5, 1e-4, 1e-3, 1e-2]
                },
                "iters": {
                    "values": [1, 3, 5, 10]
                },
                "gp": {
                    "values": [0.0, 1e-6, 1e-5, 1e-4, 1e-3]
                },
            }
        }
    }
}

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

<img src="results/153-sd-study-F10_v1_conv.png">

Notes:

* This is the SD accuracy over the respective last 500 epochs

* The results are smoothed a lot

* Different trends are observable:

    1. Models that learn quickly and then maintain a constant performance
    
    2. Models that learn quickly and then drop in performance
    
    3. Models that learn slowly and keep improving        
    
    4. Models that basically don't learn anything
    
    
Which hyperparameters characterize the individual trends?

1. Small architecture, low `lr`:

    * `time`: 0.5
    * `lat_size`: 3
    * `gp`: 1e-4
    * `hid`: (2,64)
    * `iters`: 5
    * `lr`: 1e-5
    

2. Low `lat_size` and `iters` values:

    * `time`: 0.5/0.2
    * `lat_size`: 3/1
    * `gp`: 1e-5/1e-4
    * `hid`: (2,512)/(4,256)
    * `iters`: 3/1
    * `lr`: 1e-4/1e-5


3. Large architecture, rather high `lat_size` or `iters` values, rather low `lr`:

    * `time`: 0.75/3/2/5
    * `lat_size`: 1/3/10/5
    * `gp`: 1e-5/1e-6/1e-6/1e-5
    * `hid`: (2,512)/(8,256)/(6,128)/(8,512)
    * `iters`: 3/10/1/5
    * `lr`: 1e-2/1e-4/1e-4/1e-4


4. No gradient penalty, rather low `lat_size` and `iters` values:

    * `time`: 2/0.2/5/0.5/3
    * `lat_size`: 3/1/3/3/1
    * `gp`: 0/0/0/0/1e-6
    * `hid`: (8,512)/(2,64)/(4,128)/(4,256)/(8,512)
    * `iters`: 5/5/3/3/3
    * `lr`: 1e-5/1e-3/1e-4/1e-3/1e-2

Let's specifically analyze the impact of depth on the network. We start from the model of trend 1 above:

* v0: `hid_num`: 2, `hid_size`: 64, `lr`: 1e-5, `wd`: 1e-7, `betas`: (0.9,0.999), `iters`: 5, `gp`: 1e-4

* v1: `hid_num`: 4

* v2: `hid_num`: 8

* v3: `gp`: 1e-5

* v4: `lr`: 1e-4

* v5: `gp`: 1e-6

* v6: `hid_size`: 256

* v7: `hid_num`: 16, `hid_size`: 64

In [18]:
# plot accuracies
avg_hor = 500
save_names = ["disc_model_F10_v1_depth_check"]
save_path = "models"

models = ut.plot_accuracies(save_names, save_path=save_path, excl_names=[], avg_hor=avg_hor, uncertainty=False)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

### Step 15.4: Increase function library until classification fails

Done.

### Step 15.5: Provide additional information to SD

Types of information:

* Input data

* Gradient information

Methodologies:

* Stacking

* Embedding

* <s>Convolution</s>

**TODO**: 

* Check correct gradient calculation

* Discuss computational graph for regularizing DSN

* Double check that input and parameter gradients do not impact each other (check impact in both directions)

Let's run the hyperparameter studies:

1. Embed input data

2. Embed gradient data

3. Stack gradient data

**NOTE**: We use an arbitrary embedding network with 2 hidden layers and 64 hidden nodes.

#### Step 15.5.1: Embed input data

In [19]:
# set wandb project
wandb_project = "155-ext1-study-F10_v1"

In [20]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": 1,
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("MLP", "MLP"),
#         "hid_kwargs": {
#             "alpha": None,
#             "norm": None,
#             "prune": None,
#             },
#         "lat_size": 3,
#     },
#     "epochs": 100000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "ext": ["input"],
#     "ext_type": "embed",
#     "ext_size": 1,
#     "disc": {
#         "hid_num": 2,
#         "hid_size": 64,
#         "lr": 1e-4,
#         "wd": 1e-7,
#         "betas": (0.9,0.999),
#         "iters": 5,
#         "gp": 1e-5,
#     },
# }

In [21]:
# define hyperparameter study
hp_study = {
    "method": "random",
    "parameters": {
        "arch": {
            "parameters": {
                "in_size": {
                    "values": [1]
                },
                "out_size": {
                    "values": [1]
                },
                "hid_num": {
                    "values": [(2,0)]
                },
                "hid_size": {
                    "values": [32]
                },
                "lat_size": {
                    "values": [1, 3, 5]
                },
            }
        },
        "disc": {
            "parameters": {
                "hid_num": {
                    "values": [(2,2), (2,4), (2,8), (2,16)]
                },
                "hid_size": {
                    "values": [(64,64), (64,128), (64,256), (64,512)]
                },
                "lr": {
                    "values": [1e-5, 1e-4, 1e-3]
                },
                "iters": {
                    "values": [1, 3, 5]
                },
                "gp": {
                    "values": [1e-6, 1e-5, 1e-4]
                },
            }
        }
    }
}

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

<img src="results/155-ext1-study-F10_v1_conv.png">

#### Step 15.5.2: Embed gradient data

In [22]:
# set wandb project
wandb_project = "155-ext2-study-F10_v1"

In [23]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": 1,
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("MLP", "MLP"),
#         "hid_kwargs": {
#             "alpha": None,
#             "norm": None,
#             "prune": None,
#             },
#         "lat_size": 3,
#     },
#     "epochs": 100000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "ext": ["grad"],
#     "ext_type": "embed",
#     "ext_size": 1,
#     "disc": {
#         "hid_num": 2,
#         "hid_size": 64,
#         "lr": 1e-4,
#         "wd": 1e-7,
#         "betas": (0.9,0.999),
#         "iters": 5,
#         "gp": 1e-5,
#     },
# }

In [24]:
# define hyperparameter study
hp_study = {
    "method": "random",
    "parameters": {
        "arch": {
            "parameters": {
                "in_size": {
                    "values": [1]
                },
                "out_size": {
                    "values": [1]
                },
                "hid_num": {
                    "values": [(2,0)]
                },
                "hid_size": {
                    "values": [32]
                },
                "lat_size": {
                    "values": [1, 3, 5]
                },
            }
        },
        "disc": {
            "parameters": {
                "hid_num": {
                    "values": [(2,2), (2,4), (2,8), (2,16)]
                },
                "hid_size": {
                    "values": [(64,64), (64,128), (64,256), (64,512)]
                },
                "lr": {
                    "values": [1e-5, 1e-4, 1e-3]
                },
                "iters": {
                    "values": [1, 3, 5]
                },
                "gp": {
                    "values": [1e-6, 1e-5, 1e-4]
                },
            }
        }
    }
}

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

<img src="results/155-ext2-study-F10_v1_conv.png">

#### Step 15.5.3: Stack gradient data

In [25]:
# set wandb project
wandb_project = "155-ext3-study-F10_v1"

In [26]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": 1,
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("MLP", "MLP"),
#         "hid_kwargs": {
#             "alpha": None,
#             "norm": None,
#             "prune": None,
#             },
#         "lat_size": 3,
#     },
#     "epochs": 100000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "ext": ["grad"],
#     "ext_type": "stack",
#     "ext_size": 1,
#     "disc": {
#         "hid_num": 2,
#         "hid_size": 64,
#         "lr": 1e-4,
#         "wd": 1e-7,
#         "betas": (0.9,0.999),
#         "iters": 5,
#         "gp": 1e-5,
#     },
# }

In [27]:
# define hyperparameter study
hp_study = {
    "method": "random",
    "parameters": {
        "arch": {
            "parameters": {
                "in_size": {
                    "values": [1]
                },
                "out_size": {
                    "values": [1]
                },
                "hid_num": {
                    "values": [(2,0)]
                },
                "hid_size": {
                    "values": [32]
                },
                "lat_size": {
                    "values": [1, 3, 5]
                },
            }
        },
        "disc": {
            "parameters": {
                "hid_num": {
                    "values": [2, 4, 8, 16]
                },
                "hid_size": {
                    "values": [64, 128, 256, 512]
                },
                "lr": {
                    "values": [1e-5, 1e-4, 1e-3]
                },
                "iters": {
                    "values": [1, 3, 5]
                },
                "gp": {
                    "values": [1e-6, 1e-5, 1e-4]
                },
            }
        }
    }
}

In [None]:
# create sweep
sweep_id = wandb.sweep(hp_study, project=wandb_project)

<img src="results/155-ext3-study-F10_v1_conv.png">

Let's compare the respective best results:

In [28]:
# plot accuracies
avg_hor = 500
save_names = ["disc_model_F10_v1_sd_study", "disc_model_F10_v1_ext"]
save_path = "models"

models = ut.plot_accuracies(save_names, save_path=save_path, excl_names=[], avg_hor=avg_hor, uncertainty=False)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Let's load the best model:

In [29]:
model_name = "disc_model_F10_v1_ext2_study_v1"

In [30]:
critic = ut.load_disc(model_name + ".pkl", "models", SDNet)

In [31]:
critic

SDNet(
  (layers1): Sequential(
    (0): Linear(in_features=2, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=1, bias=True)
  )
  (layers2): Sequential(
    (0): Linear(in_features=700, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=512, bias=True)
    (5): ReLU()
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): ReLU()
    (8): Linear(in_features=512, out_features=1, bias=True)
  )
)

What do correct and false predictions look like?

In [32]:
state = joblib.load(os.path.join("models", model_name + ".pkl"))
hp = state['hyperparams']

In [33]:
false_num = 5
corr_list = []
false_list = []

# keep drawing untrained networks until `false_num` of their samples
# have been scored as "real" (pred > 0) by the critic
while len(false_list) < false_num:

    # fake data: latent activations of a freshly initialized model
    model = SRNet(**hp['arch'])
    model.train()

    with torch.no_grad():
        _, lat_acts = model(x_data, get_lat=True)

    data_fake = lat_acts.detach().T

    # gather the extension data requested by the stored hyperparameters
    ext_data_fake = []
    if 'ext' in hp and hp['ext'] is not None:
        for ext_name in hp['ext']:
            if ext_name == "input":
                ext_data_fake.append(x_data)
            elif ext_name == "grad":
                grad_data_fake = model.jacobian(x_data, get_lat=True).transpose(0,1)
                ext_data_fake.append(grad_data_fake.detach())
            else:
                raise KeyError(f"Extension {ext_name} is not defined.")

        data_fake = ut.extend(data_fake, *ext_data_fake, ext_type=hp['ext_type'])

    preds = critic(data_fake).squeeze()

    # fake samples are classified correctly when the score is not positive
    corr_pred = (preds <= 0).nonzero().flatten().tolist()
    false_pred = (preds > 0).nonzero().flatten().tolist()

    corr_list.extend(data_fake[idx] for idx in corr_pred)
    false_list.extend(data_fake[idx] for idx in false_pred)

In [34]:
print(len(corr_list))

19


In [35]:
# Plot up to `false_num` of the fake samples collected in `corr_list`.
fig, ax = plt.subplots()

for y_data in corr_list[:false_num]:
    # entries with more than one axis are reduced to their first column;
    # 1-D entries are plotted as they are (same handling as in the plot of
    # `false_list`)
    # NOTE(review): presumably entries are 2-D only when extension data was
    # attached -- confirm against ut.extend
    if len(y_data.shape) > 1:
        y_data = y_data[:,0]
    ax.scatter(x_data, y_data)
    
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [36]:
fig, ax = plt.subplots()

for y_data in false_list[:false_num]:
    if len(y_data.shape) > 1:
        y_data = y_data[:,0]
    ax.scatter(x_data, y_data)
    
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

**TODO**: Analyze optimal hyperparameters in Step 15.5