diff --git a/neurolib/control/optimal_control/cost_functions.py b/neurolib/control/optimal_control/cost_functions.py
index 6037861d..4b8cc61b 100644
--- a/neurolib/control/optimal_control/cost_functions.py
+++ b/neurolib/control/optimal_control/cost_functions.py
@@ -19,7 +19,7 @@ def accuracy_cost(x, target_timeseries, weights, cost_matrix, dt, interval=(0, N
     :param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
         dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
     :type interval: tuple, optional
-    
+
     :return: Accuracy cost.
     :rtype: float
     """
@@ -56,7 +56,7 @@ def derivative_accuracy_cost(x, target_timeseries, weights, cost_matrix, interva
     :param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
         dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
     :type interval: tuple, optional
-    
+
     :return: Accuracy cost derivative.
     :rtype: ndarray
     """
@@ -84,7 +84,7 @@ def precision_cost(x_sim, x_target, cost_matrix, interval=(0, None)):
     :param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
         dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
     :type interval: tuple
-    
+
     :return: Precision cost for time interval.
     :rtype: float
     """
@@ -114,7 +114,7 @@ def derivative_precision_cost(x_sim, x_target, cost_matrix, interval):
     :param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
         dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
     :type interval: tuple
-    
+
     :return: Control-dimensions x T array of precision cost gradients.
     :rtype: np.ndarray
     """
@@ -140,7 +140,7 @@ def control_strength_cost(u, weights, dt):
     :type weights: dictionary
     :param dt: Time step.
     :type dt: float
-    
+
     :return: control strength cost of the control.
     :rtype: float
     """
@@ -159,6 +159,9 @@ def control_strength_cost(u, weights, dt):
                 for t in range(u.shape[2]):
                     cost += cost_timeseries[n, v, t] * dt
 
+    if weights["w_1D"] != 0.0:
+        cost += weights["w_1D"] * L1D_cost_integral(u, dt)
+
     return cost
 
 
@@ -166,1 +169,1 @@ def derivative_control_strength_cost(u, weights):
-def derivative_control_strength_cost(u, weights):
+def derivative_control_strength_cost(u, weights, dt):
@@ -179,6 +182,8 @@ def derivative_control_strength_cost(u, weights):
     if weights["w_2"] != 0.0:
         der += weights["w_2"] * derivative_L2_cost(u)
+    if weights["w_1D"] != 0.0:
+        der += weights["w_1D"] * derivative_L1D_cost(u, dt)
 
     return der
 
 
@@ -189,7 +194,7 @@ def L2_cost(u):
 
     :param u: Control-dimensions x T array. Control signals.
     :type u: np.ndarray
-    
+
     :return: L2 cost of the control.
     :rtype: float
     """
@@ -203,8 +208,49 @@ def derivative_L2_cost(u):
 
     :param u: Control-dimensions x T array. Control signals.
     :type u: np.ndarray
-    
+
     :return: Control-dimensions x T array of L2-cost gradients.
     :rtype: np.ndarray
     """
     return u
+
+
+@numba.njit
+def L1D_cost_integral(
+    u,
+    dt,
+):
+    """'Directional sparsity' or 'L1D' cost integrated over time. Penalizes control strength.
+    :param u: Control-dimensions x T array. Control signals.
+    :type u: np.ndarray
+    :param dt: Time step.
+    :type dt: float
+    :return: L1D cost of the control.
+    :rtype: float
+    """
+
+    return np.sum(np.sum(np.sqrt(np.sum(u**2, axis=2) * dt), axis=1), axis=0)
+
+
+@numba.njit
+def derivative_L1D_cost(
+    u,
+    dt,
+):
+    """Derivative of the 'L1D' cost wrt. the control signals.
+    :param u: Control-dimensions x T array. Control signals.
+    :type u: np.ndarray
+    :param dt: Time step.
+    :type dt: float
+    :return: Control-dimensions x T array of L1D-cost gradients.
+    :rtype: np.ndarray
+    """
+
+    denominator = np.sqrt(np.sum(u**2, axis=2) * dt)
+    der = np.zeros(u.shape)
+    for n in range(der.shape[0]):
+        for v in range(der.shape[1]):
+            if denominator[n, v] != 0.0:
+                der[n, v, :] = u[n, v, :] / denominator[n, v]
+
+    return der
diff --git a/neurolib/control/optimal_control/oc.py b/neurolib/control/optimal_control/oc.py
index 82fcc145..7e400d68 100644
--- a/neurolib/control/optimal_control/oc.py
+++ b/neurolib/control/optimal_control/oc.py
@@ -17,6 +17,7 @@ def getdefaultweights():
     )
     weights["w_p"] = 1.0
     weights["w_2"] = 0.0
+    weights["w_1D"] = 0.0
 
     return weights
 
@@ -471,14 +472,14 @@ def __init__(
             for v, iv in enumerate(self.model.input_vars):
                 control[:, v, :] = self.model.params[iv]
 
-        self.control = control.copy() 
+        self.control = control.copy()
         self.check_params()
 
         self.control = update_control_with_limit(
            self.N, self.dim_in, self.T, control, 0.0, np.zeros(control.shape), self.maximum_control_strength
         )
 
-        self.model_params = self.get_model_params() 
+        self.model_params = self.get_model_params()
 
     def check_params(self):
         """Checks a subset of parameters and throws an error if a wrong dimension is found."""
diff --git a/tests/control/optimal_control/test_oc_cost_functions.py b/tests/control/optimal_control/test_oc_cost_functions.py
index 0a580110..a96018b7 100644
--- a/tests/control/optimal_control/test_oc_cost_functions.py
+++ b/tests/control/optimal_control/test_oc_cost_functions.py
@@ -160,6 +160,30 @@ def test_derivative_L2_cost(self):
         desired_output = u
         self.assertTrue(np.all(cost_functions.derivative_L2_cost(u) == desired_output))
 
+    def test_L1D_cost(self):
+        print(" Test L1D cost")
+        dt = 0.1
+        reference_result = 2.0 * np.sum(np.sqrt(np.sum(p.TEST_INPUT_1N_6**2 * dt, axis=1)))
+        weights = getdefaultweights()
+        weights["w_1D"] = 1.0
+        u = np.concatenate([p.TEST_INPUT_1N_6[:, np.newaxis, :], p.TEST_INPUT_1N_6[:, np.newaxis, :]], axis=1)
+        L1D_cost = cost_functions.control_strength_cost(u, weights, dt)
+
+        self.assertAlmostEqual(L1D_cost, reference_result, places=8)
+
+    def test_derivative_L1D_cost(self):
+        print(" Test L1D cost derivative")
+        dt = 0.1
+        denominator = np.sqrt(np.sum(p.TEST_INPUT_1N_6**2, axis=1) * dt)
+
+        u = np.concatenate([p.TEST_INPUT_1N_6[:, np.newaxis, :], p.TEST_INPUT_1N_6[:, np.newaxis, :]], axis=1)
+        reference_result = np.zeros(u.shape)
+        for n in range(u.shape[0]):
+            for v in range(u.shape[1]):
+                reference_result[n, v, :] = u[n, v, :] / denominator[n]
+
+        self.assertTrue(np.all(cost_functions.derivative_L1D_cost(u, dt) == reference_result))
+
     def test_weights_dictionary(self):
         print("Test dictionary of cost weights")
         model = FHNModel()
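For intuition, a minimal standalone sketch of what the new directional-sparsity (L1D) penalty computes, mirroring the formula in L1D_cost_integral; the toy array shape and the dt value are illustrative, not taken from the patch:

```python
import numpy as np

# Toy control signal: N=1 node, 2 input channels, T=5 time steps.
u = np.zeros((1, 2, 5))
u[0, 0, :] = 1.0  # only the first channel carries a signal
dt = 0.1

# L1D cost: integrate u^2 over time per (node, channel), take the square
# root, then sum over nodes and channels. A channel that is identically
# zero contributes nothing, so the penalty favors switching whole channels
# off rather than shrinking every channel a little.
l1d = np.sum(np.sqrt(np.sum(u**2, axis=2) * dt))
print(l1d)  # sqrt(5 * 1.0**2 * 0.1) = sqrt(0.5) ~= 0.7071
```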