In [None]:
import sys
from pathlib import Path

sys.path.append("../")

from utils.utils import plot_simulated_meshgrid, plot_collage
import model
import gt_sampling

import numpy as np
import torch
import bbobtorch

In [None]:
# Ground-truth setup: problem dimension, sampling budget and seed, plus a
# seeded instance of BBOB F01 built from them.
n_dim, samples, seed = 2, 1000, 42

problem_f01 = bbobtorch.create_f01(n_dim, seed=seed)

In [None]:
# Draw `samples` random points from the ground-truth function. The dimension and
# seed come from the configuration cell rather than repeated literals, so changing
# them there keeps sampling consistent with the problem instance.
# NOTE(review): the result is indexed as [0] (inputs) and [1] (values) further below.
sample_f01 = gt_sampling.get_sample(problem_f01, n_samples=samples, n_dim=n_dim, seed=seed, method='random')

In [None]:
# Inputs of the sampled points; sample_f01[1] presumably holds the matching
# ground-truth values (both are passed to plot_collage further below).
X_input = sample_f01[0]

## Approximate the ground-truth function with a neural network

In [None]:
def higher_order_derivatives(f, wrt, n):
    """Stack function values with their first ``n`` per-sample derivative orders.

    Order 0 is ``f`` itself, reshaped to one row per sample. Order k (k >= 1) is
    the gradient w.r.t. ``wrt`` of the sum of all order-(k-1) entries. For k >= 2
    this is the derivative along the all-ones direction (e.g. row sums of the
    Hessian for k = 2).

    Differentiating the batch sum yields per-sample derivatives only if each
    output row depends solely on its own input row. That holds for row-wise
    networks and point-wise test functions; it is an assumption, not checked here.

    Args:
        f: values evaluated on ``wrt``; first dimension must equal ``wrt.shape[0]``.
        wrt: input tensor, assumed (N, d), with ``requires_grad=True``.
        n: number of derivative orders to append (0 returns just the values).

    Returns:
        Tensor of shape (N, k + n * d), where k = f.numel() // N.
    """
    n_rows = wrt.shape[0]
    current = f.reshape(n_rows, -1)
    derivatives = [current]
    for _ in range(n):
        grads = None
        # A previous order that is constant in `wrt` (e.g. the 2nd derivative of a
        # linear map) is detached from the graph; its derivative is identically zero.
        if current.requires_grad:
            (grads,) = torch.autograd.grad(
                current.sum(), wrt, create_graph=True, allow_unused=True
            )
        if grads is None:
            grads = torch.zeros_like(wrt)
        derivatives.append(grads)
        # Differentiate the newest order next, so each pass raises the order by one.
        current = grads
    return torch.hstack(derivatives)

class ZehleTaylor(torch.nn.Module):
    """Loss that compares values *and* input derivatives of prediction and target.

    Both tensors are expanded with ``higher_order_derivatives`` up to
    ``diff_degree`` orders w.r.t. ``x``. The wrapped ``criterion`` is then applied
    to the two flattened expansions.
    """

    def __init__(self, diff_degree, criterion):
        super().__init__()
        # Number of derivative orders appended to the raw values.
        self.diff_degree = diff_degree
        # Loss applied to the flattened expansions (e.g. torch.nn.MSELoss()).
        self.criterion = criterion

    def forward(self, pred, true, x):
        order = self.diff_degree
        target_expansion = higher_order_derivatives(true, x, order)
        pred_expansion = higher_order_derivatives(pred, x, order)
        return self.criterion(pred_expansion.flatten(), target_expansion.flatten())

In [None]:
# Hyperparameters for the surrogate network and its optimizer.
input_dim = 2
hidden_dim = 16
hidden_layers = 4
output_dim = 1  # NOTE(review): not passed to model.NN — confirm the network's output width
num_epochs = 1000
bs = 256  # NOTE(review): unused — training below is full-batch over X_input
learning_rate = 1e-6

# Seed torch so weight initialisation (and therefore the whole run) is reproducible.
torch.manual_seed(seed)

m = model.NN(input_dim, hidden_dim, hidden_layers)
optimizer = torch.optim.SGD(m.parameters(), lr=learning_rate)
# Match function values plus three derivative orders of the ground truth.
criterion = ZehleTaylor(3, torch.nn.MSELoss())

# Training loop (full batch over all sampled inputs)
for epoch in range(num_epochs):
    # Fresh leaf tensor each epoch so derivatives w.r.t. the inputs can be taken
    # without touching X_input itself.
    x = X_input.clone().detach().requires_grad_(True)

    optimizer.zero_grad()
    # Targets come from the same seeded problem instance that produced the
    # samples; an unseeded create_f01 would yield a different function.
    trues = problem_f01(x)
    preds = m(x)

    loss = criterion(preds, trues, x)

    # Backward and optimize
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 20 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

In [None]:

# Evaluate the trained network on a regular 100 x 100 grid over [-5, 5]^2
# (gradient tracking disabled, since this is inference only).
with torch.no_grad():
    grid_axis = np.linspace(-5.0, 5.0, 100)
    X, Y = np.meshgrid(grid_axis, grid_axis)
    # One (x, y) row per grid point, in the same order as X.ravel()/Y.ravel().
    mesh_samples = np.column_stack((X.ravel(), Y.ravel()))
    mesh_samples_tensor = torch.tensor(mesh_samples, dtype=torch.float32)
    # Fold the flat predictions back onto the grid layout for plotting.
    mesh_results = m(mesh_samples_tensor).reshape(X.shape)

In [None]:
# Plot the network's predictions over the evaluation grid; mesh_results is the
# torch tensor reshaped to X.shape in the mesh-grid cell.
plot_simulated_meshgrid(X, Y, mesh_results, model='NN')

In [None]:
# Compare the sampled ground truth with the network's surface.
# Positional arguments: samples, results, problem, problem_name, model_name, X, Y, mesh_results
# The problem label must name the function actually used (F01, not F24).
# NOTE(review): model_name "Phelipe" differs from model='NN' used above — confirm intended.
plot_collage(sample_f01[0].detach().numpy(), sample_f01[1].detach().numpy(), problem_f01, "BBOB F01", "Phelipe", X, Y, mesh_results)

In [None]:
# Persist the trained weights (state_dict only). sys.path already points at the
# parent directory from the imports cell and does not affect file writes, so no
# path manipulation is needed here. Create the target directory if missing so
# torch.save does not fail on a fresh checkout.
model_path = Path("../models/f01_mse_nn_model.pt")
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(m.state_dict(), model_path)