# Playground for optimization with neural networks

In [None]:
%matplotlib inline
from functools import partial
import sys
import GPyOpt
import numpy as np
from shinyutils.matwrap import MatWrap as mw
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tensorboardX import SummaryWriter

from infopt.nnmodel import NNModel
from infopt.nrmodel import NRModel
from exputils.models import NRNet
from infopt.ihvp import IterativeIHVP, LowRankIHVP

# Value passed as the `lda` argument of the NRModel instances built in the
# offline and online cells of this notebook.
# NOTE(review): presumably a regularization strength (lambda) -- confirm
# against infopt.nrmodel.NRModel.
LDA = 1e-2

## Helper plot function

In [None]:
# Configure the MatWrap plotting wrapper with the "notebook" context; this
# runs before any of the figures in this notebook are created.
mw.configure(context="notebook")


def plot_comp_1d(
    f_true,
    f_noisy,
    model,
    X_tr,
    Y_tr=None,
    mu=0,
    sig=1,
    X_te=None,
):
    """Plot the true 1-D objective against the model's prediction.

    Draws the true function on a test grid, the training points, the model's
    predicted curve, and a shaded band of plus/minus the model's second
    ``predict`` output around the prediction.

    Args:
        f_true: Object exposing ``f(X)``; evaluated on ``X_te`` for the
            reference curve.
        f_noisy: Object exposing ``f(X)``; only used when ``Y_tr`` is None,
            to generate training targets at ``X_tr``.
        model: Object exposing ``updateModel(X, Y, None, None)`` and
            ``predict(X)``. ``predict`` must return two 2-D arrays (each is
            indexed with ``[:, 0]``); the second is presumably a predictive
            standard deviation -- confirm against the model class.
        X_tr: Training inputs, plotted on the x-axis of the scatter.
        Y_tr: Training targets. When None, targets are sampled from
            ``f_noisy`` and the model is fitted on them before plotting.
        mu: Offset added to the model's predictions (used to undo target
            normalization).
        sig: Scale applied to the model's predictions and uncertainty when
            positive (used to undo target normalization).
        X_te: Test inputs as a column array. Defaults to 200 evenly spaced
            points on [0, 1].

    Returns:
        None. The figure is created through ``mw.plt()``.
    """
    # Build the default grid per call instead of sharing one array object
    # created at function-definition time.
    if X_te is None:
        X_te = np.linspace(0, 1, 200)[:, np.newaxis]

    if Y_tr is None:
        Y_tr = f_noisy.f(X_tr)
        model.updateModel(X_tr, Y_tr, None, None)

    Y_te = f_true.f(X_te)
    Yhat_te, s_te = model.predict(X_te)
    if sig > 0:
        # Undo the scaling for both outputs: a spread expressed in normalized
        # units has to be stretched by the same factor as the mean, otherwise
        # the band is drawn in the wrong units.
        Yhat_te = sig * Yhat_te
        s_te = sig * s_te
    Yhat_te = Yhat_te + mu

    fig = mw.plt().figure()
    ax = fig.add_subplot(111)
    ax.plot(X_te, Y_te, label="$f$", color="r")
    ax.scatter(
        X_tr, Y_tr, marker="x", color="r", label=f"Training data ($n={len(X_tr)}$)"
    )
    # Raw string: "\h" is not a valid Python escape sequence.
    ax.plot(X_te, Yhat_te, label=r"$\hat{f}$", color="b", ls="--")
    # fill_between needs 1-D arrays, so take the single column of each.
    X_te, Yhat_te, s_te = X_te[:, 0], Yhat_te[:, 0], s_te[:, 0]
    ax.fill_between(X_te, Yhat_te - s_te, Yhat_te + s_te, color="b", alpha=0.25)
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    ax.legend()
    mw.sns().despine(fig=fig, ax=ax, trim=True)


## Objective function

In [None]:
# Forrester 1-D test function from GPyOpt's example objectives, built twice:
# with sd=0 (reference curve for the plots) and with sd=1 (source of the
# observations the models are trained on).
# NOTE(review): `sd` is presumably the observation-noise standard deviation
# -- confirm in the GPyOpt documentation.
f_true = GPyOpt.objective_examples.experiments1d.forrester(sd=0)
f_noisy = GPyOpt.objective_examples.experiments1d.forrester(sd=1)
# Search space: a single continuous variable with domain (0, 1).
bounds = [{"name": "var_1", "type": "continuous", "domain": (0, 1)}]
# Minimum value exposed by the objective object; subtracted from f_true
# evaluations in the "Distance from the minimum" plot.
f_min = f_true.fmin

In [None]:
class Net(nn.Module):
    """Fully connected network mapping 1 input to 1 output.

    Two hidden layers of width 3 with tanh activations, followed by a linear
    output layer.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 3)
        self.fc2 = nn.Linear(3, 3)
        self.fc3 = nn.Linear(3, 1)
        # Ordered list of the layers for convenient iteration; the layers are
        # already attached to the module through the fc* attributes above.
        self.layers = [self.fc1, self.fc2, self.fc3]

    def forward(self, x):
        """Return the network output for input ``x`` (last dimension 1)."""
        # torch.tanh matches the `nonlin=torch.tanh` used when building NRNet
        # in this file, instead of the nn.functional alias.
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return self.fc3(x)

## Offline (batch) learning

In [None]:
# Build a 3-layer NRNet on a 1-D input with hidden width 4 and tanh
# nonlinearity.
net = NRNet(num_layers=3, in_dim=1, hidden_dim=4, nonlin=torch.tanh, has_bias=True, init_as_design=True)

# TensorBoard logs for this cell go to logs/nr_offline_play.
tb_writer = SummaryWriter("logs/nr_offline_play")
net_optim = optim.Adam(net.parameters(), lr=0.05)
# Wrap the network and its optimizer in an NRModel.
# NOTE(review): update_batch_size=np.inf presumably means "use all points in
# every update" -- confirm against infopt.nrmodel.NRModel.
nr_model = NRModel(
    net,
    lda=LDA,
    optim=net_optim,
    update_batch_size=np.inf,
    update_iters_per_point=20,
    tb_writer=tb_writer,
)

# Ten uniformly random training inputs in [0, 1), as a (10, 1) column.
X_tr = np.random.rand(10, 1)
# Y_tr is omitted, so plot_comp_1d samples targets from f_noisy and fits
# nr_model on them before plotting.
plot_comp_1d(f_true, f_noisy, nr_model, X_tr)

## Online optimization

In [None]:
# Fresh writer, network, optimizer and model for the online run, so nothing
# is shared with the offline cell. TensorBoard logs go to
# logs/nr_online_play.
tb_writer = SummaryWriter("logs/nr_online_play")
net = NRNet(num_layers=3, in_dim=1, hidden_dim=4, nonlin=torch.tanh, has_bias=True, init_as_design=True)
net_optim = optim.Adam(net.parameters(), lr=0.05)  # eta
nr_model = NRModel(
    net,
    lda=LDA,
    optim=net_optim,
    update_batch_size=np.inf,
    update_iters_per_point=10,
    tb_writer=tb_writer,
)

# Assemble the GPyOpt pieces by hand so that nr_model can be used as the
# surrogate: noisy objective, search space, acquisition optimizer, 5 random
# initial points, LCB acquisition and a sequential evaluator.
objective = GPyOpt.core.task.SingleObjective(f_noisy.f)
space = GPyOpt.Design_space(space=bounds)
acq_optimizer = GPyOpt.optimization.AcquisitionOptimizer(space)
initial_design = GPyOpt.experiment_design.initial_design("random", space, 5)
acq = GPyOpt.acquisitions.AcquisitionLCB(
    nr_model, space, acq_optimizer, exploration_weight=2  # gamma
)
feval = GPyOpt.core.evaluators.Sequential(acq)
# normalize_Y=True: the optimizer is asked to normalize the observations it
# hands to the model.
nr_bo = GPyOpt.methods.ModularBayesianOptimization(
    nr_model, space, objective, acq, feval, initial_design, normalize_Y=True
)

# NOTE(review): eps=-1 presumably disables GPyOpt's stop-when-consecutive-
# points-are-close criterion so that all 20 iterations run -- confirm in the
# GPyOpt documentation.
nr_bo.run_optimization(max_iter=20, verbosity=True, eps=-1)
# Mean and standard deviation of the collected observations; plot_comp_1d
# uses them to rescale the model's predictions back to the original units.
mu, sig = float(nr_bo.Y.mean()), float(nr_bo.Y.std())
plot_comp_1d(f_true, f_noisy, nr_model, nr_bo.X, nr_bo.Y, mu, sig)

# Plot, on a log scale, how far each evaluated point is from the previous one.
fig = mw.plt().figure()
ax = fig.add_subplot(111)
# Euclidean distance between consecutive rows of the evaluated inputs.
diffs = np.linalg.norm(nr_bo.X[1:] - nr_bo.X[:-1], axis=1)
ax.semilogy(diffs)
# Raw string: "\|" is not a valid Python escape sequence.
ax.set_ylabel(r"$\|x_{n+1} - x_n\|$")
ax.set_xlabel("n")
_ = ax.set_title("Distance between consecutive x's")

# Plot, on a log scale, the gap between the noise-free objective value at
# each evaluated point and the objective's minimum value.
fig = mw.plt().figure()
ax = fig.add_subplot(111)
# f_true.f returns a column; take it as 1-D before subtracting f_min.
diffs = np.abs(f_true.f(nr_bo.X)[:, 0] - f_min)
ax.semilogy(diffs)
ax.set_ylabel("$|f(x_n) - f_*|$")
ax.set_xlabel("n")
# Assign to `_` so the notebook does not echo the returned Text object.
_ = ax.set_title("Distance from the minimum")
