# 2022 Flatiron Machine Learning x Science Summer School

## Step 12: Explore symbolic discriminator alternatives

In this step, as an alternative to the symbolic discriminator, we try to maximize the probability that the latent features come from the "distribution" of library functions $D = \{f_i\}_{i=1}^N$ with $f_i \in \mathbb{R}^d$.

The indices of library functions $\pi$ come from a categorical distribution with a uniform prior: $p(\pi) = 1/|D|$

The likelihood is the probability of the latent feature given a library function index: $p(f|\pi) = N(f_\pi, I\sigma^2)$

The posterior is the probability of library function indices given the latent feature: $p(\pi|f)$

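We want to find the latent feature $\hat{f}$ that minimizes the entropy of $p(\pi|\hat{f})$: a low-entropy posterior means that $\hat{f}$ is clearly attributed to a single library function.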
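As a worked step, the posterior and its entropy can be written out with Bayes' rule from the uniform prior and Gaussian likelihood defined above. The summation index $\pi'$ is introduced here only for the normalization and is not part of the original notation.

$$
p(\pi \mid f) = \frac{p(f \mid \pi)\, p(\pi)}{\sum_{\pi'} p(f \mid \pi')\, p(\pi')}
= \frac{\exp\left(-\|f - f_\pi\|^2 / (2\sigma^2)\right)}{\sum_{\pi'} \exp\left(-\|f - f_{\pi'}\|^2 / (2\sigma^2)\right)}
$$

$$
H\left[p(\pi \mid f)\right] = -\sum_{\pi} p(\pi \mid f) \log p(\pi \mid f)
$$

The second equality in the first line holds because $p(\pi) = 1/|D|$ and the Gaussian normalization constant are the same for every $\pi$ and cancel in the ratio, so the posterior is a softmax over negative scaled squared distances.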

**Implementation**:

How can we implement this?

* MSE of latent features and library functions, softmax and entropy

* Assume the likelihood to be Gaussian, normalize the posterior and calculate entropy

What are the differences?

* Define $l_{MSE} = \sum_i ||\hat{f}^{(i)} - f^{(i)}_\pi||^2$

* A low MSE should correspond to a high probability and vice versa, so we feed the negative MSE into the softmax

* The Gaussian normalization constant and the uniform prior cancel when the posterior is normalized

* In practice, the only difference is whether $l_{MSE}$ is multiplied by $1/(2\sigma^2)$ or not
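The two variants listed above can be sketched in a few lines of NumPy. This is only an illustration under assumptions: the function name `posterior_entropy`, the toy library, and the `sigma` values are hypothetical and not taken from `srnet` or `sdnet`, and a trainable version would need a differentiable framework such as PyTorch.

```python
import numpy as np

def posterior_entropy(f_hat, library, sigma=None):
    """Entropy of the softmax over negative squared errors (hypothetical helper).

    f_hat:   (d,) latent feature values
    library: (N, d) library function values at the same points
    sigma:   if given, scale the squared error by 1 / (2 * sigma**2)
    """
    sq_err = np.sum((library - f_hat) ** 2, axis=1)  # l_MSE for each library function
    scale = 1.0 if sigma is None else 1.0 / (2.0 * sigma ** 2)
    logits = -scale * sq_err
    logits = logits - logits.max()  # stabilize the softmax
    probs = np.exp(logits) / np.exp(logits).sum()
    entropy = -np.sum(probs * np.log(probs + 1e-12))
    return probs, entropy

# Toy library (assumed): three functions evaluated on a grid
x = np.linspace(-1, 1, 50)
library = np.stack([x ** 2, np.cos(x), x])
f_hat = x ** 2 + 0.01  # latent feature close to the first library function

probs, entropy = posterior_entropy(f_hat, library, sigma=1.0)
probs_wide, entropy_wide = posterior_entropy(f_hat, library, sigma=10.0)
print(probs, entropy)
print(probs_wide, entropy_wide)
```

A larger `sigma` flattens the logits, so the posterior moves toward uniform and the entropy increases; `sigma` therefore acts like a softmax temperature.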

**Issue**:

One issue is that the MSE depends on the scale and offset of the target function: a latent feature that correlates well with one target function can still have a high MSE, even higher than the MSE with respect to a target function it barely correlates with.
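A small synthetic example makes this issue concrete. The data, scale factors, and variable names below are assumptions chosen for illustration and do not come from the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, 1000)
target = x ** 2

# Perfectly correlated with the target, but scaled and shifted
scaled = 5.0 * x ** 2 + 3.0
# Essentially uncorrelated with the target, but in a similar value range
unrelated = 0.1 * x + 1.0 / 3.0

def mse(a, b):
    return np.mean((a - b) ** 2)

def corr(a, b):
    return np.corrcoef(a, b)[0, 1]

mse_scaled, corr_scaled = mse(scaled, target), corr(scaled, target)
mse_unrelated, corr_unrelated = mse(unrelated, target), corr(unrelated, target)

print(f"scaled:    MSE={mse_scaled:.3f}, corr={corr_scaled:.3f}")
print(f"unrelated: MSE={mse_unrelated:.3f}, corr={corr_unrelated:.3f}")
```

Ranking by MSE alone would prefer the unrelated feature here, although the scaled feature has exactly the shape of the target.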

Let's normalize the differences by the absolute real function values, which corresponds to point-wise variances, right?

**TODO**: Double check derivation.

**Correlation**:

A second alternative is calculating and regularizing the correlation values directly, i.e. we minimize the entropy of the softmax-normalized correlations.
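The correlation variant could look roughly like the sketch below. It rests on several assumptions that the text does not specify: the helper name `correlation_entropy` is hypothetical, absolute correlations are used so that a strongly negative correlation also counts as a match, and a `temperature` parameter is added to sharpen the softmax. The actual regularizer would have to be written in PyTorch to be trainable.

```python
import numpy as np

def correlation_entropy(node, library, temperature=1.0, use_abs=True):
    """Entropy of softmax-normalized correlations (hypothetical helper).

    node:    (d,) activations of one latent node
    library: (N, d) library function values at the same inputs
    """
    corrs = np.array([np.corrcoef(node, f)[0, 1] for f in library])
    scores = np.abs(corrs) if use_abs else corrs  # assumption: sign is ignored
    logits = scores / temperature
    logits = logits - logits.max()  # stabilize the softmax
    probs = np.exp(logits) / np.exp(logits).sum()
    entropy = -np.sum(probs * np.log(probs + 1e-12))
    return corrs, probs, entropy

# Synthetic inputs (assumed), same library functions as in the cells below
rng = np.random.default_rng(0)
x = rng.uniform(-5, 5, 1000)
y = rng.uniform(-5, 5, 1000)
library = np.stack([x ** 2, y ** 2, np.cos(x), np.cos(y), x * y])

node = -2.0 * x * y + 0.5  # a node that has learned x*y up to scale and sign
corrs, probs, entropy = correlation_entropy(node, library)
_, probs_sharp, entropy_sharp = correlation_entropy(node, library, temperature=0.1)
print(np.round(corrs, 4), entropy, entropy_sharp)
```

Because correlations are bounded by 1 in magnitude, a plain softmax over them stays fairly flat; a temperature below 1 is one way to make the entropy sensitive to a single dominant match.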

**Work in progress**:

So far, only the correlation approach has yielded somewhat reasonable results, and only when starting from a pre-trained network. More experiments need to be run.

In [1]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
from sdnet import SDData
import srnet_utils as ut

In [2]:
# load data
data_path = "data_1k"

in_var = "X00"
lat_var = "G00"
target_var = "F00"

mask_ext = ".mask"
masks = joblib.load(os.path.join(data_path, in_var + mask_ext))

train_data = SRData(data_path, in_var, lat_var, target_var, masks["train"])
val_data = SRData(data_path, in_var, lat_var, target_var, masks["val"])

In [3]:
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]

In [4]:
model_name = "srnet_model_F00_v3_bn_mask_corr_check_v1"
model_path = "models"
model_ext = ".pkl"

model = ut.load_model(model_name + model_ext, model_path, SRNet)

with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)
    
all_nodes = ut.get_node_order(acts, show=True)

print(model.layers1.norm(model.layers1.alpha).detach().cpu().numpy()[all_nodes])

print("")

[5.4851813, 1.4364765, 0.427018]
[0, 2, 1]
[[1. 0.]
 [1. 1.]
 [0. 1.]]



In [5]:
nodes = all_nodes[:3]

In [6]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    ("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    #("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [7]:
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=nodes, model=model, bias=True, nonzero=False, agg=False, plot_size=plot_size)


In [8]:
corr_data = [
    ("x**2", x_data**2), 
    ("y**2", y_data**2), 
    ("cos(x)", np.cos(x_data)), 
    ("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]

In [9]:
ut.node_correlations(acts, nodes, corr_data);


Node 0
corr(n0, x**2): -0.9934
corr(n0, y**2): -0.0152
corr(n0, cos(x)): 0.9746
corr(n0, cos(y)): 0.0146
corr(n0, x*y): -0.1587

Node 2
corr(n2, x**2): -0.0458
corr(n2, y**2): 0.0001
corr(n2, cos(x)): -0.0004
corr(n2, cos(y)): 0.0078
corr(n2, x*y): -0.9817

Node 1
corr(n1, x**2): 0.0112
corr(n1, y**2): 0.9806
corr(n1, cos(x)): -0.0263
corr(n1, cos(y)): -0.9901
corr(n1, x*y): 0.0559


In [10]:
fig, ax = plt.subplots()

n = 0
bias = True

ax.scatter(x_data, x_data**2 + 1.5)
# ax.scatter(x_data, -np.cos(x_data))
ax.scatter(x_data, model.layers2[0].weight[0,n].item()*acts[:,n] + bias * model.layers2[0].bias.item())

plt.show()


In [11]:
fig, ax = plt.subplots()

n = 1
bias = False

ax.scatter(y_data, np.cos(y_data) - 1.5)
ax.scatter(y_data, model.layers2[0].weight[0,n].item()*acts[:,n] + bias * model.layers2[0].bias.item())

plt.show()


In [12]:
# select plotting data
x_data = train_data.in_data[:,0]
y_data = train_data.in_data[:,1]
z_data = [
    #("target", train_data.target_data),
    #("x**2", x_data**2), 
    #("cos(y)", np.cos(y_data)), 
    ("x*y", x_data * y_data),
]
plot_size = train_data.target_data.shape[0]

In [13]:
n = 2
ut.plot_acts(x_data, y_data, z_data, acts=acts, nodes=[n], model=model, bias=False, nonzero=False, agg=False, plot_size=plot_size)


In [14]:
def Gaussian(x, mu, sigma):
    # Probability density of a 1D normal distribution with mean mu and standard deviation sigma
    return 1/(sigma * np.sqrt(2 * np.pi)) * np.exp(-0.5*((x - mu) / sigma)**2)

In [15]:
p_data = np.linspace(-5,5)

In [16]:
fig, ax = plt.subplots()

ax.plot(p_data, Gaussian(p_data, 0, 1))
ax.plot(p_data, Gaussian(p_data, 1, 1))
ax.plot(p_data, Gaussian(p_data, 1, 2))

plt.show()


What would the ideal correlations look like?

**N(0,1)**:

`x**2`:
```
Node 0
corr(n0, x**2): 1.0000
corr(n0, y**2): 0.0159
corr(n0, cos(x)): -0.9553
corr(n0, cos(y)): -0.0170
corr(n0, x*y): 0.1732
```

`cos(y)`:
```
Node 0
corr(n0, x**2): -0.0170
corr(n0, y**2): -0.9632
corr(n0, cos(x)): 0.0298
corr(n0, cos(y)): 1.0000
corr(n0, x*y): -0.0622
```

`x*y`:
```
Node 0
corr(n0, x**2): 0.1732
corr(n0, y**2): 0.0431
corr(n0, cos(x)): -0.1515
corr(n0, cos(y)): -0.0622
corr(n0, x*y): 1.0000
```

**U(-1,1)**:

`x**2`:
```
Node 0
corr(n0, x**2): 1.0000
corr(n0, y**2): 0.0380
corr(n0, cos(x)): -0.9998
corr(n0, cos(y)): -0.0377
corr(n0, x*y): 0.0492
```

`cos(y)`:
```
Node 0
corr(n0, x**2): -0.0377
corr(n0, y**2): -0.9997
corr(n0, cos(x)): 0.0338
corr(n0, cos(y)): 1.0000
corr(n0, x*y): 0.0192
```

`x*y`:
```
Node 0
corr(n0, x**2): 0.0492
corr(n0, y**2): -0.0243
corr(n0, cos(x)): -0.0494
corr(n0, cos(y)): 0.0192
corr(n0, x*y): 1.0000
```

**U(-5,5)**:

`x**2`:
```
Node 0
corr(n0, x**2): 1.0000
corr(n0, y**2): -0.0238
corr(n0, cos(x)): -0.4797
corr(n0, cos(y)): 0.1668
corr(n0, x*y): -0.0543
```

`cos(y)`:
```
Node 0
corr(n0, x**2): 0.1668
corr(n0, y**2): -0.3535
corr(n0, cos(x)): -0.0979
corr(n0, cos(y)): 1.0000
corr(n0, x*y): -0.0339
```

`x*y`:
```
Node 0
corr(n0, x**2): -0.0543
corr(n0, y**2): -0.0506
corr(n0, cos(x)): 0.1231
corr(n0, cos(y)): -0.0339
corr(n0, x*y): 1.0000
```
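Tables like the ones above can be approximated by sampling inputs from each distribution and correlating the library functions with each other. The sketch below assumes 1000 samples and a fixed seed; since the original data and sampling are not shown, the numbers will differ slightly from those listed, and the helper `library_correlations` is a hypothetical name.

```python
import numpy as np

def library_correlations(x, y):
    """Pairwise Pearson correlations between the library functions on (x, y)."""
    funcs = {
        "x**2": x ** 2,
        "y**2": y ** 2,
        "cos(x)": np.cos(x),
        "cos(y)": np.cos(y),
        "x*y": x * y,
    }
    names = list(funcs)
    corr = np.corrcoef(np.stack([funcs[name] for name in names]))
    return names, corr

rng = np.random.default_rng(0)
n_samples = 1000  # assumed sample size
samplers = {
    "N(0,1)": lambda: rng.normal(0.0, 1.0, n_samples),
    "U(-1,1)": lambda: rng.uniform(-1.0, 1.0, n_samples),
    "U(-5,5)": lambda: rng.uniform(-5.0, 5.0, n_samples),
}

results = {}
for dist_name, sample in samplers.items():
    names, corr = library_correlations(sample(), sample())
    results[dist_name] = corr
    print(dist_name, names)
    print(np.round(corr, 4))
```

On narrow input ranges, `x**2` and `cos(x)` are almost perfectly anti-correlated, so a correlation-based criterion cannot separate them there; a wider range such as U(-5,5) makes the two functions distinguishable.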