In [1]:
# --- Imports and notebook-wide configuration -------------------------------
from math import floor
import os

import numpy as np
import torch
import sys
# Restrict torch to a single CPU thread for this kernel.
# NOTE(review): presumably to avoid oversubscription when work is spread over
# several processes — confirm against the cells that use CPU_COUNT.
torch.set_num_threads(1)

import multiprocessing
# Number of CPUs reported by the OS; not referenced in the cells visible here.
CPU_COUNT = multiprocessing.cpu_count()

from bingo.evaluation.fitness_function import VectorBasedFunction
from bingo.evolutionary_algorithms.age_fitness import AgeFitnessEA
from bingo.evaluation.evaluation import Evaluation
from bingo.evolutionary_optimizers.island import Island
from bingo.stats.pareto_front import ParetoFront

from bingo.local_optimizers.continuous_local_opt_md import ContinuousLocalOptimizationMD
from bingo.symbolic_regression.agraphMD.component_generator_md import ComponentGeneratorMD
from bingo.symbolic_regression.agraphMD.crossover_md import AGraphCrossoverMD
from bingo.symbolic_regression.agraphMD.generator_md import AGraphGeneratorMD
from bingo.symbolic_regression.agraphMD.mutation_md import AGraphMutationMD
from bingo.symbolic_regression.explicit_regression_md import ExplicitRegressionMD, ExplicitTrainingDataMD
from bingo.symbolic_regression.implicit_regression_md import ImplicitRegressionMD, ImplicitTrainingDataMD, \
    _calculate_partials

# Tunable constants; not referenced in the cells visible here.
# NOTE(review): presumably the evolutionary population size and the maximum
# AGraph stack size used by later cells — confirm.
POP_SIZE = 100
STACK_SIZE = 10

# Turn off numpy's "..." summarization so arrays are always printed in full.
np.set_printoptions(threshold=sys.maxsize)


In [2]:
# Whitespace-delimited text files read with np.loadtxt: the dataset itself and
# a second file holding its "transposed" counterpart.
dataset_path = "../data/processed_data/vm_1_bingo_format.txt"
transposed_dataset_path = "../data/processed_data/vm_1_transpose_bingo_format.txt"
data, transposed_data = (
    np.loadtxt(path) for path in (dataset_path, transposed_dataset_path)
)

# Dimension settings; not referenced in the cells visible here.
state_param_dims = [(0, 0)]
output_dim = (3, 3)

# Show both arrays as the cell output.
data, transposed_data

(array([[-3.33333587e+01,  6.66667174e+01, -3.33333587e+01,
          0.00000000e+00],
        [-4.00000254e+01,  8.00000508e+01, -4.00000254e+01,
          2.00000000e-03],
        [-4.66666920e+01,  9.33333841e+01, -4.66666920e+01,
          4.00000000e-03],
        [-5.33333587e+01,  1.06666717e+02, -5.33333587e+01,
          6.00000000e-03],
        [-6.00000254e+01,  1.20000051e+02, -6.00000254e+01,
          8.00000000e-03],
        [-6.66666920e+01,  1.33333384e+02, -6.66666920e+01,
          1.00000000e-02],
        [-7.33333587e+01,  1.46666717e+02, -7.33333587e+01,
          1.20000000e-02],
        [            nan,             nan,             nan,
                     nan],
        [ 6.68305169e+01,  7.96493727e+01, -1.46479890e+02,
          0.00000000e+00],
        [ 6.07542347e+01,  7.24092626e+01, -1.33163497e+02,
          2.00000000e-03],
        [ 5.46779525e+01,  6.51691526e+01, -1.19847105e+02,
          4.00000000e-03],
        [ 4.86016703e+01,  5.79290425e+01, 

In [3]:
def _calculate_partials_local(X, window_size=7, trim_edges=False):
    """Calculate derivatives with respect to time (first dimension).

    Local variant of bingo's ``_calculate_partials`` that, by default, keeps
    every sample of every trajectory instead of discarding the rows at the
    segment edges. This makes it usable with trajectories shorter than the
    filter window.

    Parameters
    ----------
    X : 2d numpy array
        array for which derivatives will be calculated in the first dimension.
        Distinct trajectories can be specified by separating the datasets
        within X by rows of np.nan
    window_size : int
        window length forwarded to ``_savitzky_golay_gram``
    trim_edges : bool
        if False (default) all rows of each trajectory are returned.
        If True, ``window_size // 2`` rows are dropped from the start of each
        trajectory and ``window_size // 2 + window_size % 2`` rows from its
        end (for ``window_size=7`` this is the ``[3:-4]`` slice).
        NOTE(review): with trimming disabled the derivative estimates near
        the ends of a trajectory may be affected by filter edge effects.

    Returns
    -------
    tuple of numpy arrays :
        ``(x_all, time_deriv_all, inds_all)``: the retained rows of X, the
        corresponding derivative estimates, and the row indices into the
        original X of the retained rows (``X[inds_all]`` equals ``x_all``).
    """
    # Imported lazily so that defining this function does not depend on the
    # private helper being importable.
    from bingo.symbolic_regression.implicit_regression_md import \
        _savitzky_golay_gram

    # find splits: any row containing a NaN separates two trajectories
    break_points = np.where(np.any(np.isnan(X), axis=1))[0].tolist()
    break_points.append(X.shape[0])

    # number of rows discarded at each end of every trajectory
    if trim_edges:
        n_trim_start = window_size // 2
        n_trim_end = window_size // 2 + window_size % 2
    else:
        n_trim_start = 0
        n_trim_end = 0

    x_segments = []
    deriv_segments = []
    ind_segments = []

    start = 0
    for end in break_points:
        x_seg = np.copy(X[start:end, :])
        n_rows = x_seg.shape[0]

        # consecutive (or leading/trailing) NaN rows yield empty segments;
        # they contribute nothing, so never hand them to the filter
        if n_rows > 0:
            # calculate time derivs using filter, one column at a time
            # (presumably polynomial order 3, first derivative — confirm
            # against _savitzky_golay_gram)
            time_deriv = np.empty(x_seg.shape)
            for i in range(n_cols := x_seg.shape[1]):
                time_deriv[:, i] = _savitzky_golay_gram(x_seg[:, i],
                                                        window_size, 3, 1)

            # rows [first, last) of this segment are kept
            first = min(n_trim_start, n_rows)
            last = max(first, n_rows - n_trim_end)

            x_segments.append(x_seg[first:last, :])
            deriv_segments.append(time_deriv[first:last, :])
            # indices are expressed relative to the original X so that they
            # always line up with the rows actually returned
            ind_segments.append(np.arange(start + first, start + last))

        # skip over the NaN separator row
        start = end + 1

    if not x_segments:
        # nothing but separators (or no rows at all)
        return (X[:0].copy(),
                np.empty((0, X.shape[1])),
                np.arange(0))

    x_all = np.vstack(x_segments)
    time_deriv_all = np.vstack(deriv_segments)
    inds_all = np.hstack(ind_segments)

    return x_all, time_deriv_all, inds_all


In [4]:
# Stack the trajectories in `data` and estimate their derivatives with the
# local helper, using a filter window of 5; the returned row indices are discarded.
x, dx_dt, _ = _calculate_partials_local(data, window_size=5)

In [5]:
# Inspect the stacked samples; the NaN separator rows of `data` do not appear here.
x

array([[-3.33333587e+01,  6.66667174e+01, -3.33333587e+01,
         0.00000000e+00],
       [-4.00000254e+01,  8.00000508e+01, -4.00000254e+01,
         2.00000000e-03],
       [-4.66666920e+01,  9.33333841e+01, -4.66666920e+01,
         4.00000000e-03],
       [-5.33333587e+01,  1.06666717e+02, -5.33333587e+01,
         6.00000000e-03],
       [-6.00000254e+01,  1.20000051e+02, -6.00000254e+01,
         8.00000000e-03],
       [-6.66666920e+01,  1.33333384e+02, -6.66666920e+01,
         1.00000000e-02],
       [-7.33333587e+01,  1.46666717e+02, -7.33333587e+01,
         1.20000000e-02],
       [ 6.68305169e+01,  7.96493727e+01, -1.46479890e+02,
         0.00000000e+00],
       [ 6.07542347e+01,  7.24092626e+01, -1.33163497e+02,
         2.00000000e-03],
       [ 5.46779525e+01,  6.51691526e+01, -1.19847105e+02,
         4.00000000e-03],
       [ 4.86016703e+01,  5.79290425e+01, -1.06530713e+02,
         6.00000000e-03],
       [ 4.25253881e+01,  5.06889325e+01, -9.32143206e+01,
      

In [6]:
# Inspect the derivative estimates.
# NOTE(review): the last column is a constant 2e-3, equal to the per-row
# increment of that column in `x`, which suggests the derivatives are taken
# per sample index rather than per unit of that column — confirm.
dx_dt

array([[-6.66666667e+00,  1.33333333e+01, -6.66666667e+00,
         2.00000000e-03],
       [-6.66666667e+00,  1.33333333e+01, -6.66666667e+00,
         2.00000000e-03],
       [-6.66666667e+00,  1.33333333e+01, -6.66666667e+00,
         2.00000000e-03],
       [-6.66666667e+00,  1.33333333e+01, -6.66666667e+00,
         2.00000000e-03],
       [-6.66666667e+00,  1.33333333e+01, -6.66666667e+00,
         2.00000000e-03],
       [-6.66666667e+00,  1.33333333e+01, -6.66666667e+00,
         2.00000000e-03],
       [-6.66666667e+00,  1.33333333e+01, -6.66666667e+00,
         2.00000000e-03],
       [-6.07628220e+00, -7.24011004e+00,  1.33163922e+01,
         2.00000000e-03],
       [-6.07628220e+00, -7.24011004e+00,  1.33163922e+01,
         2.00000000e-03],
       [-6.07628220e+00, -7.24011004e+00,  1.33163922e+01,
         2.00000000e-03],
       [-6.07628220e+00, -7.24011004e+00,  1.33163922e+01,
         2.00000000e-03],
       [-6.07628220e+00, -7.24011004e+00,  1.33163922e+01,
      

In [7]:
# Same computation for the transposed dataset, but through bingo's own
# `_calculate_partials` (imported in the first cell) instead of the local helper.
# NOTE(review): whether the library version trims trajectory edges is not visible
# here, so these results may not be row-for-row comparable with `x` / `dx_dt` — confirm.
x_transposed, dx_dt_transposed, _ = _calculate_partials(transposed_data, window_size=5)

In [8]:
# Inspect the samples returned for the transposed dataset.
x_transposed

array([[-3.33333587e+01,  6.66667174e+01, -3.33333587e+01,
         0.00000000e+00],
       [ 6.68305169e+01,  7.96493727e+01, -1.46479890e+02,
         0.00000000e+00],
       [ 5.25713507e+01,  9.21796752e+00, -6.17893183e+01,
         0.00000000e+00],
       [ 6.64006803e+01, -3.83529338e+01, -2.80477464e+01,
         0.00000000e+00],
       [ 5.27592226e+01, -1.44894142e+02,  9.21349193e+01,
         0.00000000e+00],
       [-3.77843531e+01, -1.03837661e+02,  1.41622014e+02,
         0.00000000e+00],
       [-4.17828811e+01, -2.40972269e+01,  6.58801080e+01,
         0.00000000e+00],
       [-6.24549025e+01,  1.10314858e+01,  5.14234168e+01,
         0.00000000e+00],
       [-1.14599839e+02,  1.36568454e+02, -2.19686153e+01,
         0.00000000e+00],
       [-3.33333587e+01,  6.66667174e+01, -3.33333587e+01,
         0.00000000e+00],
       [-3.33333587e+01, -3.33333587e+01,  6.66667174e+01,
         0.00000000e+00],
       [ 3.03728237e+01, -6.65815361e+01,  3.62087124e+01,
      

In [9]:
# Inspect the derivative estimates for the transposed dataset.
dx_dt_transposed

array([[ 1.36497650e+02,  2.51733989e+01, -1.61671049e+02,
         0.00000000e+00],
       [ 4.27802914e+01, -2.40433053e+01, -1.87369861e+01,
         0.00000000e+00],
       [-7.46093956e+00, -6.10381327e+01,  6.84990723e+01,
         0.00000000e+00],
       [ 8.84315376e+00, -8.74508202e+01,  7.86076664e+01,
         0.00000000e+00],
       [-6.15938363e+01, -4.08802184e+01,  1.02474055e+02,
         0.00000000e+00],
       [-5.22901039e+01,  7.64159084e+01, -2.41258044e+01,
         0.00000000e+00],
       [-2.50044446e+00,  5.31242146e+01, -5.06237702e+01,
         0.00000000e+00],
       [-4.89155548e+01,  9.29017560e+01, -4.39862012e+01,
         0.00000000e+00],
       [ 1.87102357e+01,  3.78598321e+01, -5.65700678e+01,
         0.00000000e+00],
       [ 4.64420097e+01, -1.06800124e+02,  6.03581139e+01,
         0.00000000e+00],
       [ 2.32876529e+01, -6.61203593e+01,  4.28327064e+01,
         0.00000000e+00],
       [ 8.43360102e+01, -5.77270847e+01, -2.66089255e+01,
      