In [None]:
################################################################################
#
# Target-distance cost, derived from Cost base class. Implements a cost function
# that depends only on state and equals (max_distance - distance), where distance
# is the Euclidean distance from the state's (x, y) position to a target point.
#
################################################################################

import torch
import numpy as np
import matplotlib.pyplot as plt

from cost import Cost
from point import Point

class TargetDistance(Cost):
    def __init__(self, position_indices, point,
                 max_distance,
                 name="",
                 outside_weight=0.0,
                 apply_after_time=0):
        """
        Initialize a cost that rewards closeness to a target point.

        The cost is (max_distance - distance), where distance is the Euclidean
        distance between the (x, y) entries of the state and `point`. It is
        linear in distance and is not clamped.

        :param position_indices: indices of input corresponding to (x, y)
        :type position_indices: (uint, uint)
        :param point: target point from which to compute distance
        :type point: Point
        :param max_distance: offset subtracted from; also the radius drawn
            by `render`
        :type max_distance: float
        :param name: name of this cost, forwarded to the Cost base class
        :type name: string
        :param outside_weight: stored on the instance; not read by __call__
        :type outside_weight: float
        :param apply_after_time: stored on the instance; not read by __call__
        :type apply_after_time: int
        """
        self._x_index, self._y_index = position_indices
        self._point = point
        self._max_distance = max_distance
        # These used to be assigned from undefined names (NameError on every
        # construction); they are now real, defaulted parameters.
        self._outside_weight = outside_weight
        self._apply_after_time = apply_after_time
        super(TargetDistance, self).__init__(name)

    def __call__(self, x, k=0):
        """
        Evaluate this cost function on the given input state and time.
        NOTE: `x` should be a column vector (it is indexed as x[i, 0]).
        :param x: concatenated state of the two systems
        :type x: torch.Tensor
        :param k: time step; unused, the cost is time-invariant
        :type k: uint
        :return: max_distance minus the Euclidean distance to the point
        :rtype: torch.Tensor
        """
        # Offset of the tracked (x, y) position from the target point.
        dx = x[self._x_index, 0] - self._point.x
        dy = x[self._y_index, 0] - self._point.y

        # NOTE(review): the derivative of sqrt is infinite at 0, so autograd
        # produces NaN gradients when the position coincides with the point.
        distance = torch.sqrt(dx * dx + dy * dy)

        return -(distance - self._max_distance)

    def render(self, ax=None):
        """
        Draw the target as a filled green circle labeled "goal".

        :param ax: axes to draw on; defaults to the current pyplot axes
        :type ax: matplotlib.axes.Axes
        """
        if ax is None:
            ax = plt.gca()

        # An unbounded max_distance has no drawable radius; use a unit disk.
        if np.isinf(self._max_distance):
            radius = 1.0
        else:
            radius = self._max_distance

        circle = plt.Circle(
            (self._point.x, self._point.y), radius,
            color="g", fill=True, alpha=0.75)
        ax.add_artist(circle)
        ax.text(self._point.x + 1.25, self._point.y + 1.25, "goal", fontsize=10)