# 2022 Flatiron Machine Learning x Science Summer School

## Step 13: Investigate symbolic discriminator

Next steps:

* Define `F00_v5`

* Define `F07` with extended input range

* Define `F08` with noise

* Run hyperparameter study

* Consider embedding (derivatives?)

* Consider GhostAdam?

* Restart from trained DSN

* Select library functions depending on the number of input features

* Train discriminator independently from `SRNet`?

Discussion with Miles:

* SD optimizer: Adam with `betas=(0.5, 0.9)`

* Resample real function data every SD iteration

* Resample input data every batch

### Step 13.1: Create data with extended input range and noise

In [35]:
%matplotlib widget
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import joblib

import torch
import wandb

from srnet import SRNet, SRData
from sdnet import SDData
import srnet_utils as ut

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [36]:
# Discriminator function library for F07_v1 and the input-variable name its
# expressions refer to; shuffle is passed through to SDData.
in_var = "X07"
fun_path = "funs/F07_v1.lib"
shuffle = False

In [37]:
# Build the symbolic-discriminator function library from fun_path; its
# expressions are written in terms of the input variable named by in_var.
disc_data = SDData(fun_path, in_var, shuffle=shuffle)

In [38]:
data_size = 1000  # number of random input samples to draw

In [39]:
# standard-normal inputs: data_size rows, two features
x_data = torch.randn(data_size, 2)

In [40]:
# largest sampled input value (checks the input range)
torch.max(x_data)

tensor(4.1430)

In [41]:
# Display the library's expression strings. N and U appear as placeholders,
# presumably random coefficients substituted by SDData (confirm in sdnet).
disc_data.funs

['2.7*N*X07[:,0]**2',
 '2.7*N*X07[:,1]**2',
 '0.45*N*X07[:,0]**3',
 '0.45*N*X07[:,1]**3',
 '5*N*np.sin(3*U*X07[:,0])',
 '5*N*np.sin(3*U*X07[:,1])',
 '5*N*np.cos(6*U*X07[:,0])',
 '5*N*np.cos(6*U*X07[:,1])',
 '4.5*N*X07[:,0] * X07[:,1]']

In [42]:
# default matplotlib color cycle, and how many random library draws to overlay
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
n_samp = 10

In [13]:
fig, ax = plt.subplots()

# library entries built on X07[:,0] (see disc_data.funs): x**2, x**3, sin, cos
fun_idx = [0, 2, 4, 6]

for _ in range(n_samp):
    # Evaluate the library once per sample and plot the selected entries from
    # that single result, instead of re-evaluating the whole library per entry.
    samp = disc_data.get(in_data=x_data)
    for k, i in enumerate(fun_idx):
        ax.scatter(x_data[:, 0], samp[0, i, :], color=colors[k], alpha=0.5)

# fixed reference curves in black
ax.scatter(x_data[:, 0], 2.7 * x_data[:, 0]**2, color='k')
ax.scatter(x_data[:, 0], 5 * torch.cos(3 * x_data[:, 0]), color='k')

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [14]:
# Candidate latent features matching the F07 definition, evaluated on the
# sampled inputs and plotted over (x_data[:, 0], x_data[:, 1]).
z_data = [
    ("2.7*x*x", 2.7 * x_data[:, 0] * x_data[:, 0]),
    ("5*cos(3*y)", 5 * torch.cos(3 * x_data[:, 1])),
    ("4.5*x*y", 4.5 * x_data[:, 0] * x_data[:, 1]),
]
ut.plot_acts(x_data[:, 0], x_data[:, 1], z_data)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Created `F07` and `F07_v1`.

Let's create some data with noise.

Option 1:

In [58]:
fig, ax = plt.subplots()

# perturbed curve: sin(x) + 0.05 * (x + 0.75) * cos(10 x)
ax.scatter(
    x_data[:, 0],
    torch.sin(x_data[:, 0]) + 0.05 * (x_data[:, 0] + 0.75) * torch.cos(10 * x_data[:, 0]),
)
# unperturbed reference: sin(x)
ax.scatter(x_data[:, 0], torch.sin(x_data[:, 0]))

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Option 2:

In [59]:
# Inputs for F08: one standard-normal feature column. The sample count reuses
# data_size (defined above) instead of repeating the literal 1000.
x_data = np.random.randn(data_size, 1)

In [60]:
fig, ax = plt.subplots()

a, c = 1.0, 1.5  # parameters passed to ut.triangle_cos / ut.triangle

# noisy variant first, then the plain triangle wave
ax.scatter(x_data, ut.triangle_cos(x_data, a, c))
ax.scatter(x_data, ut.triangle(x_data, a, c))

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [61]:
# dataset directory and the extensions for data arrays and their descriptions
data_path, data_ext, info_ext = "data_1k", ".gz", ".info"

In [62]:
a, c = 1.0, 1.5  # triangle-wave parameters used to generate G08/F08

In [63]:
# Latent data G08: ut.triangle_cos applied to the inputs with parameters a, c.
g_data = ut.triangle_cos(x_data, a, c)

In [64]:
# Target data F08: ut.triangle applied to the inputs with parameters a, c.
f_data = ut.triangle(x_data, a, c)

In [None]:
# write the F08 inputs to data_1k/X08.gz
np.savetxt(os.path.join(data_path, f"X08{data_ext}"), x_data)

In [None]:
# write the latent data to data_1k/G08.gz
np.savetxt(os.path.join(data_path, f"G08{data_ext}"), g_data)

In [None]:
# Describe G08 by its generating expression. The parameters are taken from
# a and c, the values used to compute g_data, rather than hard-coded, so the
# description cannot drift from the saved data. The second line ("0") is
# written as-is; its meaning is defined by the data loader (not shown here).
with open(os.path.join(data_path, "G08" + info_ext), 'w', encoding="utf-8") as f:
    f.write(f"ut.triangle_cos(X08[:,0], {a}, {c})\n")
    f.write("0\n")

In [None]:
# write the target data to data_1k/F08.gz
np.savetxt(os.path.join(data_path, f"F08{data_ext}"), f_data)

In [None]:
# Describe F08 by its generating expression. The parameters are taken from
# a and c, the values used to compute f_data, rather than hard-coded, so the
# description cannot drift from the saved data. The second line ("0") is
# written as-is; its meaning is defined by the data loader (not shown here).
with open(os.path.join(data_path, "F08" + info_ext), 'w', encoding="utf-8") as f:
    f.write(f"ut.triangle(X08[:,0], {a}, {c})\n")
    f.write("0\n")

Created `F08`. What should the function library look like?

In [65]:
# Convert the numpy arrays to float32 tensors so they can be passed to
# disc_data.get and plotted alongside its outputs. torch.as_tensor is the
# recommended factory; the legacy torch.Tensor constructor is discouraged.
x_data = torch.as_tensor(x_data, dtype=torch.float32)
g_data = torch.as_tensor(g_data, dtype=torch.float32)
f_data = torch.as_tensor(f_data, dtype=torch.float32)

In [66]:
# function library for F08_v1, its input variable, and the SDData shuffle flag
fun_path, in_var, shuffle = "funs/F08_v1.lib", "X08", False

In [67]:
# Rebuild the discriminator library from the F08_v1 function file.
disc_data = SDData(fun_path, in_var, shuffle=shuffle)

In [68]:
# Display the F08_v1 library expressions.
disc_data.funs

['0.75*N*np.sin(4*U*X08[:,0])', '0.75*N*np.cos(4*U*X08[:,0])']

In [69]:
# Explicitly (re)set the library to the sine and cosine terms; edit the tuple
# below to experiment with other library functions.
disc_data.funs = [f"0.75*N*np.{op}(4*U*X08[:,0])" for op in ("sin", "cos")]

In [70]:
# Keep the stored library size in sync with the reassigned funs list.
# Presumably SDData uses .len as the number of library functions; confirm.
disc_data.len = len(disc_data.funs)

In [71]:
# default matplotlib color cycle, and how many random library draws to overlay
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
n_samp = 5

In [76]:
fig, ax = plt.subplots()

for _ in range(n_samp):
    # Evaluate the library once per sample and plot both entries from that
    # single result, instead of re-evaluating the whole library per entry.
    samp = disc_data.get(in_data=x_data)
    ax.scatter(x_data, samp[0, 0, :], color=colors[0], alpha=0.5)
    ax.scatter(x_data, samp[0, 1, :], color=colors[1], alpha=0.5)

# reference curves (latent G08 and target F08) in black, shifted down by 0.75
ax.scatter(x_data, g_data - 0.75, color='k')
ax.scatter(x_data, f_data - 0.75, color='k')

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

`F08_v1` created.

Let's run some initial tests:

In [79]:
# plot losses
# Plot training curves of the saved F08_v1 SD models, excluding sweep ("study")
# runs. Pass the save_path variable rather than repeating the literal.
save_names = ["srnet_model_F08_v1_bn_sd"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path, excl_names=["study"])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [80]:
# load data
# F08 train/validation splits: input X08, latent G08, target F08, selected
# by the boolean-or-index masks stored in data_1k/X08.mask.
data_path = "data_1k"
in_var, lat_var, target_var = "X08", "G08", "F08"
mask_ext = ".mask"

masks = joblib.load(os.path.join(data_path, f"{in_var}{mask_ext}"))
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [82]:
# Load a saved F08 model and evaluate it on the training inputs, keeping both
# predictions and latent activations.
model_path, model_ext = "models", ".pkl"
model_name = "srnet_model_F08_v1_bn_sd_1e-08_5e-08_restart"

model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

In [83]:
fig, ax = plt.subplots()

# target F08, model predictions, and latent G08 against the training inputs
for series in (train_data.target_data, preds, train_data.lat_data):
    ax.scatter(train_data.in_data, series)

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [84]:
# mean squared difference between the target F08 and the latent G08
torch.mean((train_data.target_data - train_data.lat_data) ** 2)

tensor(0.0130)

### Step 13.2: Train bottleneck masked DSN with SD regularization on `F07_v1`

In [85]:
# set wandb project
# W&B project for the Step 13.2 sweep; also used below to download its runs.
wandb_project = "132-bn-mask-DSN-sd-study-F07_v1"

In [86]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": ("DSN", "MLP"),
#         "hid_kwargs": {
#             "alpha": [[1,0],[0,1],[1,1]],
#             "norm": None,
#             "prune": None,
#             },
#         "lat_size": 3,
#         },
#     "epochs": 30000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "shuffle": False,
#     "lr": 1e-4,
#     "wd": 1e-6,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "e3": 0.0,
#     "gc": 0.0,
#     "sd": 1e-6,
#     "disc": {
#         "hid_num": 6,
#         "hid_size": 128,
#         "emb_size": None,
#         "lr": 1e-3,
#         "wd": 1e-4,
#         "iters": 5,
#         "gp": 1e-5,
#     },
# }

In [87]:
# define hyperparameter study
# Random search over the SRNet learning rate, the SD regularization strength,
# and the discriminator's depth, width, learning rate, iterations and gp.
hp_study = dict(
    method="random",
    parameters=dict(
        lr=dict(values=[1e-5, 1e-4, 1e-3]),
        sd=dict(values=[1e-7, 1e-6, 1e-5, 1e-4]),
        disc=dict(
            parameters=dict(
                hid_num=dict(values=[2, 4, 6, 8]),
                hid_size=dict(values=[64, 128, 256, 512]),
                lr=dict(values=[1e-5, 1e-4, 1e-3, 1e-2]),
                iters=dict(values=[2, 5, 8]),
                gp=dict(values=[1e-6, 1e-5, 1e-4, 1e-3]),
            ),
        ),
    ),
)

In [None]:
# create sweep
# Register hp_study as a new sweep in wandb_project; sweep_id identifies it.
sweep_id = wandb.sweep(hp_study, project=wandb_project)

<img src="results/132-bn-mask-DSN-sd-study-F07_v1.png">

In [None]:
# download data from wandb
# Download the .pkl files of every run whose logged 'min_corr' exceeds the
# threshold. Each file is renamed to <stem>_v<k>.pkl, where k is the run's
# 1-based position in the project's run list.
file_ext = ".pkl"
min_corr_threshold = 0.7

api = wandb.Api()

runs = api.runs(wandb_project)
for r, run in enumerate(runs):
    # crashed/unfinished runs may not have logged 'min_corr': skip instead of
    # aborting the whole loop with a KeyError (NaN also fails the comparison)
    min_corr = run.summaryMetrics.get('min_corr')
    if min_corr is None or not min_corr > min_corr_threshold:
        continue
    for f in run.files():
        if not f.name.endswith(file_ext):
            continue
        # only the trailing extension is replaced (str.replace would rewrite
        # every ".pkl" occurrence in the path)
        stem, _ = os.path.splitext(f.name)
        file_name = f"{stem}_v{r+1}{file_ext}"
        print(f"Downloading {os.path.basename(file_name)}.")
        f.download()
        # os.replace overwrites an existing target on every platform
        # (os.rename raises on Windows when the target exists)
        os.replace(f.name, file_name)

In [88]:
# load data
# F07 train/validation splits: input X07, latent G07, target F07, selected
# by the masks stored in data_1k/X07.mask.
data_path = "data_1k"
in_var, lat_var, target_var = "X07", "G07", "F07"
mask_ext = ".mask"

masks = joblib.load(os.path.join(data_path, f"{in_var}{mask_ext}"))
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [89]:
# Reference features for the correlation check: the three F07 latent columns.
corr_data = list(zip(
    ("x**2", "cos(y)", "x*y"),
    (train_data.lat_data[:, i] for i in range(3)),
))

In [95]:
# get validation loss and latent feature correlations
# For each saved model of the F07_v1 SD study, record
#   val_corr[file_name] = (validation MSE, [max |corr| per row of corr_mat])
model_path = "models"
model_ext = ".pkl"
save_name = "srnet_model_F07_v1_bn_mask_sd_study"

# os.listdir order is arbitrary: sort for a reproducible loop/legend order,
# and only consider model pickles
models = sorted(
    f for f in os.listdir(model_path) if save_name in f and f.endswith(model_ext)
)

val_corr = {}

for model_name in models:
    print(f"Loading {model_name}.")
    model = ut.load_model(model_name, model_path, SRNet)

    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    all_nodes = ut.get_node_order(acts, show=False)

    # strongest absolute correlation in each row of corr_mat
    # (presumably one row per corr_data feature; confirm in ut.node_correlations)
    corr_mat = ut.node_correlations(acts, all_nodes, corr_data, show=False)
    corr = [np.abs(c).max() for c in corr_mat]

    with torch.no_grad():
        preds = model(val_data.in_data)

    val_loss = (preds - val_data.target_data).pow(2).mean().item()
    val_corr[model_name] = (val_loss, corr)

Loading srnet_model_F07_v1_bn_mask_sd_study_v10.pkl.
Loading srnet_model_F07_v1_bn_mask_sd_study_v15.pkl.
Loading srnet_model_F07_v1_bn_mask_sd_study_v17.pkl.
Loading srnet_model_F07_v1_bn_mask_sd_study_v18.pkl.
Loading srnet_model_F07_v1_bn_mask_sd_study_v29.pkl.
Loading srnet_model_F07_v1_bn_mask_sd_study_v29_max.pkl.
Loading srnet_model_F07_v1_bn_mask_sd_study_v4.pkl.
Loading srnet_model_F07_v1_bn_mask_sd_study_v5.pkl.
Loading srnet_model_F07_v1_bn_mask_sd_study_v8.pkl.
Loading srnet_model_F07_v1_bn_mask_sd_study_v9.pkl.


In [96]:
fig, ax = plt.subplots()

# One marker per model: validation MSE (x) vs. the smallest of its per-feature
# max |correlation| values (y). Labels use the part of the file name after
# "_study_" (e.g. "v29", "v29_max"); taking only the text after the last "_"
# would label the "_max" variant as just "max".
for v, (v_loss, v_corr) in val_corr.items():
    v_label = os.path.splitext(v)[0].split("_study_")[-1]
    ax.plot(v_loss, np.min(v_corr), 'x', label=v_label)

ax.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [100]:
# plot losses
# Training curves of the selected F07_v1 sweep models. Pass the save_path
# variable rather than repeating the literal.
save_names = [
    "srnet_model_F07_v1_bn_mask_sd_study_v8",
    "srnet_model_F07_v1_bn_mask_sd_study_v15",
    "srnet_model_F07_v1_bn_mask_sd_study_v4",
    "srnet_model_F07_v1_bn_mask_sd_study_v17",
    "srnet_model_F07_v1_bn_mask_sd_study_v29",
]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path, excl_names=[])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [101]:
# Reference features: the three F07 latent columns plus two additional
# candidates (y**2, cos(x)) computed from the inputs.
corr_data = [
    ("x**2", train_data.lat_data[:, 0]),
    ("y**2", train_data.in_data[:, 1] ** 2),
    ("cos(x)", torch.cos(train_data.in_data[:, 0])),
    ("cos(y)", train_data.lat_data[:, 1]),
    ("x*y", train_data.lat_data[:, 2]),
]

In [102]:
model_path, model_ext = "models", ".pkl"

# For each model: print its SD hyperparameters, node order, validation error
# and per-node feature correlations.
for model_name in models:
    print(model_name)

    # saved training state (hyperparameters and final validation loss)
    state = joblib.load(os.path.join(model_path, f"{model_name}{model_ext}"))
    print(state['hyperparams']['sd'])
    print(state['hyperparams']['disc'])

    model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    all_nodes = ut.get_node_order(acts, show=True)
    print(f"Validation error: {state['total_val_loss']:.4e}")
    corr_mat = ut.node_correlations(acts, all_nodes, corr_data, show=True)

    print()

srnet_model_F07_v1_bn_mask_sd_study_v8
1e-07
{'gp': 0.0001, 'hid_num': 2, 'hid_size': 64, 'iters': 8, 'lr': 0.001}
[3.4767916, 1.8724633, 0.99059254]
[2, 0, 1]
Validation error: 3.6128e-02

Node 2
corr(n2, x**2): -0.1143
corr(n2, y**2): 0.3107
corr(n2, cos(x)): 0.1046
corr(n2, cos(y)): -0.3220
corr(n2, x*y): -0.8738

Node 0
corr(n0, x**2): -0.9970
corr(n0, y**2): 0.0535
corr(n0, cos(x)): 0.9752
corr(n0, cos(y)): 0.0281
corr(n0, x*y): 0.0063

Node 1
corr(n1, x**2): 0.0397
corr(n1, y**2): -0.4102
corr(n1, cos(x)): -0.0412
corr(n1, cos(y)): -0.7401
corr(n1, x*y): -0.0058

srnet_model_F07_v1_bn_mask_sd_study_v15
1e-07
{'gp': 0.0001, 'hid_num': 4, 'hid_size': 512, 'iters': 5, 'lr': 0.01}
[3.412355, 1.850282, 1.4938799]
[2, 0, 1]
Validation error: 1.7346e-02

Node 2
corr(n2, x**2): -0.1323
corr(n2, y**2): -0.3480
corr(n2, cos(x)): 0.1235
corr(n2, cos(y)): -0.2577
corr(n2, x*y): -0.8523

Node 0
corr(n0, x**2): -0.9988
corr(n0, y**2): 0.0504
corr(n0, cos(x)): 0.9701
corr(n0, cos(y)): 0.0285
co

Compare models:

In [112]:
# Models to compare (v15, v4 and v17 were also candidates but are left out).
comp_models = [
    f"srnet_model_F07_v1_bn_mask_sd_study_{v}" for v in ("v8", "v29", "v29_max")
]

In [113]:
model_path = "models"
model_ext = ".pkl"

fig = {}
ax = {}
bias = True  # add the layers2[0] bias to the reconstructed feature

# one figure per latent feature n (n = 0, 1), plotted against input n
for n in range(2):
    fig[n], ax[n] = plt.subplots()
    ax[n].scatter(train_data.in_data[:, n], train_data.lat_data[:, n])

for model_name in comp_models:
    model = ut.load_model(model_name + model_ext, model_path, SRNet)

    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)

    # Label by the suffix after "_study_" (e.g. "v29_max"); splitting on the
    # last "_" alone would label the "_max" variant just "max".
    label = model_name.split("_study_")[-1]
    out_layer = model.layers2[0]
    for n in range(2):
        # node n scaled by its weight in layers2[0] (plus that layer's bias)
        ax[n].scatter(
            train_data.in_data[:, n],
            out_layer.weight[0, n].item() * acts[:, n] + bias * out_layer.bias.item(),
            label=label,
        )

for n in range(2):
    ax[n].legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Plot individual model:

In [114]:
# Load one sweep model and evaluate it on the training inputs, keeping both
# predictions and latent activations.
model_path, model_ext = "models", ".pkl"
model_name = "srnet_model_F07_v1_bn_mask_sd_study_v29_max"

model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)
with torch.no_grad():
    preds, acts = model(train_data.in_data, get_lat=True)

In [115]:
n, bias = 0, True
fig, ax = plt.subplots()

# latent feature n vs. node n scaled by its layers2[0] weight (plus that
# layer's bias when `bias` is True)
ax.scatter(train_data.in_data[:, n], train_data.lat_data[:, n])
ax.scatter(
    train_data.in_data[:, n],
    model.layers2[0].weight[0, n].item() * acts[:, n] + bias * model.layers2[0].bias.item(),
)

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [116]:
n, bias = 1, True
fig, ax = plt.subplots()

# latent feature n vs. node n scaled by its layers2[0] weight (plus that
# layer's bias when `bias` is True)
ax.scatter(train_data.in_data[:, n], train_data.lat_data[:, n])
ax.scatter(
    train_data.in_data[:, n],
    model.layers2[0].weight[0, n].item() * acts[:, n] + bias * model.layers2[0].bias.item(),
)

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [117]:
# Compare latent node 2 with the x*y feature over the (x, y) input plane.
n = 2
z_data = [("x*y", train_data.lat_data[:, n])]
plot_size = len(train_data.target_data)

ut.plot_acts(
    train_data.in_data[:, 0],
    train_data.in_data[:, 1],
    z_data,
    acts=acts,
    nodes=[n],
    model=model,
    bias=False,
    nonzero=False,
    agg=False,
    plot_size=plot_size,
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Hyperparameters of the runs whose results are similar to v29:

* `gp`: most `1e-3`, one `1e-4`

* `hid_num`: 2, 4, 4, 4, 8

* `hid_size`: most 256, one 128

* `iters`: 2, 5, 5, 8, 8

* `disc.lr`: 1e-5, 1e-4, 1e-3, 1e-3, 1e-3

* `lr`: all 1e-5

* `sd`: 1e-7, 1e-7, 1e-7, 1e-6, 1e-6

Define baseline setup and test:

* Adam parameters

* Resample coefficients each SD epoch

Next:

* Embed input data

* Resample input data

* Embed gradients

* No weight decay

* Gradients:

    * Stack
    
    * Embed
    
    * CNN

* Restart `max`

### Step 13.3: Train bottleneck DSN with SD regularization on `F08_v1`

In [118]:
# set wandb project
# W&B project for the Step 13.3 sweep; also used below to download its runs.
wandb_project = "133-bn-DSN-sd-study-F08_v1"

In [119]:
# hyperparams = {
#     "arch": {
#         "in_size": train_data.in_data.shape[1],
#         "out_size": train_data.target_data.shape[1],
#         "hid_num": (2,0),
#         "hid_size": 32, 
#         "hid_type": "MLP",
#         "hid_kwargs": {
#             "alpha": None,
#             "norm": None,
#             "prune": None,
#             },
#         "lat_size": 1,
#         },
#     "epochs": 50000,
#     "runtime": None,
#     "batch_size": train_data.in_data.shape[0],
#     "shuffle": False,
#     "lr": 1e-4,
#     "wd": 1e-6,
#     "l1": 0.0,
#     "a1": 0.0,
#     "a2": 0.0,
#     "e1": 0.0,
#     "e2": 0.0,
#     "e3": 0.0,
#     "gc": 0.0,
#     "sd": 1e-8,
#     "disc": {
#         "hid_num": 2,
#         "hid_size": 32,
#         "emb_size": None,
#         "lr": 1e-3,
#         "wd": 1e-7,
#         "iters": 5,
#         "gp": 1e-5,
#     },
# }

In [120]:
# define hyperparameter study
# Random search over the SRNet learning rate, a finer grid of SD strengths,
# and smaller discriminator configurations than in Step 13.2.
hp_study = dict(
    method="random",
    parameters=dict(
        lr=dict(values=[1e-5, 1e-4, 1e-3]),
        sd=dict(values=[1e-8, 5e-8, 1e-7, 5e-7, 1e-6, 1e-5]),
        disc=dict(
            parameters=dict(
                hid_num=dict(values=[2, 4, 6]),
                hid_size=dict(values=[64, 128, 256]),
                lr=dict(values=[1e-5, 1e-4, 1e-3, 1e-2]),
                iters=dict(values=[2, 5]),
                gp=dict(values=[1e-6, 1e-5, 1e-4, 1e-3]),
            ),
        ),
    ),
)

In [None]:
# create sweep
# Register hp_study as a new sweep in wandb_project; sweep_id identifies it.
sweep_id = wandb.sweep(hp_study, project=wandb_project)

<img src="results/133-bn-DSN-sd-study-F08_v1.png">

In [None]:
# download data from wandb
# Download the .pkl files of every run whose logged 'min_corr' exceeds the
# threshold. Each file is renamed to <stem>_v<k>.pkl, where k is the run's
# 1-based position in the project's run list.
file_ext = ".pkl"
min_corr_threshold = 0.9

api = wandb.Api()

runs = api.runs(wandb_project)
for r, run in enumerate(runs):
    # crashed/unfinished runs may not have logged 'min_corr': skip instead of
    # aborting the whole loop with a KeyError (NaN also fails the comparison)
    min_corr = run.summaryMetrics.get('min_corr')
    if min_corr is None or not min_corr > min_corr_threshold:
        continue
    for f in run.files():
        if not f.name.endswith(file_ext):
            continue
        # only the trailing extension is replaced (str.replace would rewrite
        # every ".pkl" occurrence in the path)
        stem, _ = os.path.splitext(f.name)
        file_name = f"{stem}_v{r+1}{file_ext}"
        print(f"Downloading {os.path.basename(file_name)}.")
        f.download()
        # os.replace overwrites an existing target on every platform
        # (os.rename raises on Windows when the target exists)
        os.replace(f.name, file_name)

In [121]:
# plot losses
# Training curves of all downloaded F08_v1 sweep models. Pass the save_path
# variable rather than repeating the literal.
save_names = ["srnet_model_F08_v1_bn_sd_study"]
save_path = "models"
models = ut.plot_losses(save_names, save_path=save_path, excl_names=[])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [122]:
# load data
# F08 train/validation splits: input X08, latent G08, target F08, selected
# by the masks stored in data_1k/X08.mask.
data_path = "data_1k"
in_var, lat_var, target_var = "X08", "G08", "F08"
mask_ext = ".mask"

masks = joblib.load(os.path.join(data_path, f"{in_var}{mask_ext}"))
train_data, val_data = (
    SRData(data_path, in_var, lat_var, target_var, masks[split])
    for split in ("train", "val")
)

In [123]:
model_path, model_ext = "models", ".pkl"

fig, ax = plt.subplots()

# overlay every swept model's predictions on the training inputs
for model_name in models:
    model = ut.load_model(f"{model_name}{model_ext}", model_path, SRNet)
    with torch.no_grad():
        preds, acts = model(train_data.in_data, get_lat=True)
    ax.scatter(train_data.in_data, preds)

# ground truth in black: target F08 and latent G08
ax.scatter(train_data.in_data, train_data.target_data, color='k')
ax.scatter(train_data.in_data, train_data.lat_data, color='k')

plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …