# 3 dimensional RRT*


In [1]:
import numpy as np #math
from itertools import tee #iterator
import plotly as py #plotting
from plotly import graph_objs as go #graphic objects
from rtree import index
from operator import itemgetter
import random
import uuid

### Auxiliary geometry functions

In [2]:
def distance(a, b):
    """
    Returns the euclidian distance between 2 points
    :param a: 1st point
    :param b: 2nd point
    :return: euclidian distance
    """
    dist = np.linalg.norm(np.array(b) - np.array(a))
    return dist

def step(start, goal, stepSize):
    """
    Function to take a step from a start position in the direction of a goal position
    :param start: start point
    :param goal: goal point
    :param step: step size (distance)
    :return: steered point in direction of goal
    """
    start, goal = np.array(start), np.array(goal)
    vec = goal - start
    unit = vec / (np.sqrt(np.sum(vec ** 2))) # normalized unit vector
    steered = start + unit * stepSize # resized to step size
    return tuple(steered)

def pairwiseIteration(iterable):
    """
    Makes pairwise iteration over iterable
    :param iterable: iterable
    :return: s(array) -> (s0,s1), (s1,s2), (s2,s3)...
    """
    a, b = tee(iterable) # iterators
    next(b, None)
    return zip(a, b)

def linspacePoints(start, end, space):
    """
    Returns evenly spaced points along a line
    :param start: start of line
    :param end: end of line
    :param space: max distance between points
    :return: points along the line start-end spaced in 'space' distances
    """
    dist = distance(start, end)
    points = int(np.ceil(dist/space)) #number of points possible in space
    if points > 1:
        stepSize = dist / points - 1 #number of 'space' sized spaces posible
        for i in range(points):
            next = step(start, end, i* stepSize)
            yield next

### Tree class

In [3]:
class Tree(object):
    def __init__(self,X) -> None:
        """
        Tree structure
        :param X: search space 
        """
        prop = index.Property()
        prop.dimension = X.dimensions
        self.Vertices = index.Index(interleaved=True, properties=prop)
        self.vCount = 0
        self.Edges = {} # edge[child] = parent

### Heuristics

In [4]:
def costToGo(a: tuple, b: tuple) -> float:
    """
    :param a: current location
    :param b: next location
    :return: estimated segment_cost-to-go from a to b
    """
    return distance(a,b)

def pathCost(Edges, a, b):
    """
    Cost of the unique path from x_init to x in tree edges
    :param E: edges, in form of E[child] = parent
    :param a: initial location
    :param b: goal location
    :return: segment_cost of unique path (in tree) from x_init to x
    """
    cost = 0
    while not b == a:
        p = Edges[b]
        cost += distance(b, p) #distance of edge vertices
        b = p #go to next adjascent edge
        
    return cost

def segCost(a, b):
    """
    Cost function of the line between x_near and x_new (cost to move)
    :param a: start of line
    :param b: end of line
    :return: segment_cost function between a and b
    """
    return distance(a,b)

### RRT

In [5]:
class RRTBasic(object):
    def __init__(self, X, Q, x_init, x_goal, max, space, prob=0.01):
        """ 
        RRT Planner (basic ver)
        :param X: search space
        :param Q: list of tree's edges lenghts
        :param x_init: tuple, initial location
        :param x_goal: tuple, goal
        :param max: maximum of samples
        :param space: space between points used in collision checking
        :param prob: probability of checking if there is a solution
        """
        self.X = X
        self.samples = 0
        self.max = max
        self.Q = Q
        self.space = space
        self.prob = prob
        self.x_init = x_init
        self.x_goal = x_goal
        self.trees = []  # list of all trees
        self.addTree()  # add initial tree
        
    def addTree(self):
        self.trees.append(Tree(self.X)) #empty tree init for trees[]
        
    def addVertex(self, tree, v):
        """
        Add vertex to corresponding tree
        :param tree: int, tree to which to add vertex
        :param v: tuple, vertex to add
        """
        self.trees[tree].Vertices.insert(0, v + v, v)
        self.trees[tree].vCount += 1
        self.samples += 1
        
    def addEdge(self, tree, child, parent):
        """
        Add edge to corresponding tree
        :param tree: int, tree to which to add vertex
        :param child: tuple, child vertex
        :param parent: tuple, parent vertex
        """
        self.trees[tree].Edges[child] = parent 
        
    def VerticesNear(self, tree, x, n):
        """
        Return nearby vertices
        :param tree: int, tree being searched
        :param x: tuple, vertex around which searching
        :param n: int, max number of neighbors to return
        :return: list of nearby vertices
        """
        return self.trees[tree].Vertices.nearest(x, num_results=n, objects="raw")
    
    def getNearest(self, tree, x):
        """
        Return vertex nearest to x
        :param tree: int, tree being searched
        :param x: tuple, vertex around which searching
        :return: tuple, nearest vertex to x
        """
        return next(self.VerticesNear(tree, x, 1))
    
    def newNear(self, tree, q):
        """
        Return a new steered vertex and the vertex in tree that is nearest
        :param tree: int, tree being searched
        :param q: length of edge when steering (step)
        :return: vertex, new steered vertex, vertex, nearest vertex in tree to new vertex
        """
        xRand = self.X.sampleFree()
        xNearest = self.getNearest(tree, xRand)
        xNew = self.boundPoint(step(xNearest, xRand, q[0])) # get random point in step radius
        # check if new point is free in search space and not in Vertices already
        if not self.trees[0].Vertices.count(xNew) == 0 or not self.X.obstacleFree(xNew):
            return None, None
        self.samples += 1
        return xNew, xNearest
    
    def boundPoint(self, point):
        # set point to bound if it is out-of-bounds
        point = np.maximum(point, self.X.dimensionLenghts[:,0])
        point = np.minimum(point, self.X.dimensionLenghts[:,1])
        return tuple(point)
    
    def connect(self, tree, xA, xB):
        """
        Connect vertex x_a in tree to vertex x_b
        :param tree: int, tree to which to add edge
        :param xA: tuple, vertex
        :param xB: tuple, vertex
        :return: bool, True if able to add edge, False if prohibited by an obstacle
        """
        if self.trees[tree].Vertices.count(xB) == 0 and self.X.collisionFree(xA, xB, self.space):
            self.addVertex(tree, xB)
            self.addEdge(tree, xB, xA)
            return True
        return False
    
    def canReachGoal(self, tree):
        """
        Check if the goal can be connected to the graph
        :param tree: rtree of all Vertices
        :return: True if can be added, False otherwise
        """
        xNear = self.getNearest(tree, self.x_goal)
        if self.x_goal in self.trees[tree].Edges and xNear in self.trees[tree].Edges[self.x_goal]:
            # nearest vertex is goal
            return True
        if self.X.collisionFree(xNear, self.x_goal, self.space): # check obstacle free
            return True
        return False
    
    def getPath(self):
        """
        Return path through tree from start to goal
        :return: path if possible, None otherwise
        """
        if self.canReachGoal(0):
            print("Can connect to goal")
            self.connectGoal(0)
            return self.reconstructPath(0, self.x_init, self.x_goal)
        print("Could not connect to goal")
        return None
    
    def connectGoal(self, tree):
        """
        Connect x_goal to graph
        (does not check if this should be possible, for that use: can_connect_to_goal)
        :param tree: rtree of all Vertices
        """
        x_nearest = self.getNearest(tree, self.x_goal)
        self.trees[tree].Edges[self.x_goal] = x_nearest

    
    def reconstructPath(self, tree, x_init, x_goal):
        """
        Reconstruct path from start to goal
        :param tree: int, tree in which to find path
        :param x_init: tuple, starting vertex
        :param x_goal: tuple, ending vertex
        :return: sequence of vertices from start to goal
        """
        path = [x_goal]
        current = x_goal
        if x_init == x_goal:
            return path
        while not self.trees[tree].Edges[current] == x_init: #backtrack tree
            path.append(self.trees[tree].Edges[current])
            current = self.trees[tree].Edges[current]
        path.append(x_init)
        path.reverse()
        return path
    
    def checkSolution(self):
        # probabilistically check if solution found
        if self.prob and random.random() < self.prob:
            print("Checking if can connect to goal at", str(self.samples), "samples")
            path = self.getPath()
            if path is not None:
                return True, path
        # check if can connect to goal after generating max_samples
        if self.samples >= self.max:
            return True, self.getPath()
        return False, None
    
class RRT(RRTBasic):
    def __init__(self, X, Q, x_init, x_goal, max, space, prob=0.01):
        super().__init__(X, Q, x_init, x_goal, max, space, prob)
        
    def rrtSearch(self):
        """
        Create and return a Rapidly-exploring Random Tree, keeps expanding until can connect to goal
        :return: list representation of path, dict representing edges of tree in form E[child] = parent
        """
        self.addVertex(0, self.x_init)
        self.addEdge(0, self.x_init, None)
        
        while True:
            for q in self.Q: # iterate edge leghths until solution is found or timeout
                for i in range(q[1]): # iterate over edges of given length to add
                    xNew, xNear = self.newNear(0, q)
                    if xNew is None:
                        continue
                    
                    #connect shortest valid edge
                    self.connect(0, xNear, xNew)
                    
                    sol = self.checkSolution()
                    if sol[0]:
                        return sol[1]
                    
class RRTStar(RRT):
    def __init__(self, X, Q, x_init, x_goal, max, space, prob=0.01, rewire_count=None):
        super().__init__(X, Q, x_init, x_goal, max, space, prob)
        self.rewireCount = rewire_count if rewire_count is not None else 0
        self.c_best = float('inf') # length of best solution this far
        
    def getNearby(self, tree, x_init, x_new):
        """
        Get nearby vertices to new vertex and their associated path costs from the root of tree
        as if new vertex is connected to each one separately.

        :param tree: tree in which to search
        :param x_init: starting vertex used to calculate path cost
        :param x_new: vertex around which to find nearby vertices
        :return: list of nearby vertices and their costs, sorted in ascending order by cost
        """
        xNear = self.VerticesNear(tree, x_new, self.currentRewireCount(tree))
        lNear = [(pathCost(self.trees[tree].Edges, x_init, x_near) + segCost(x_near, x_new), x_near) for x_near in xNear]
        lNear.sort(key=itemgetter(0))
        return lNear
    
    def rewire(self, tree, x_new, lNear):
        """
        Rewire tree to shorten edges if possible
        Only rewires vertices according to rewire count
        :param tree: int, tree to rewire
        :param x_new: tuple, newly added vertex
        :param lNear: list of nearby vertices used to rewire
        :return:
        """
        for cNear, xNear in lNear:
            currCost = pathCost(self.trees[tree].Edges, self.x_init, xNear)
            tentCost = pathCost(self.trees[tree].Edges, self.x_init, x_new) + segCost(x_new, xNear)
            if tentCost < currCost and self.X.collisionFree(xNear, x_new, self.space):
                self.trees[tree].Edges[xNear] = x_new

    def connectShortestValid(self, tree, x_new, lNear):
        """
        Connect to nearest vertex that has an unobstructed path
        :param tree: int, tree being added to
        :param x_new: tuple, vertex being added
        :param lNear: list of nearby vertices
        """
        # check nearby vertices for total cost and connect shortest valid edge
        for cNear, xNear in lNear:
            if cNear + costToGo(xNear, self.x_goal) < self.c_best and self.connect(tree, xNear, x_new):
                break
            
    def currentRewireCount(self, tree):
        """
        Return rewire count
        :param tree: tree being rewired
        :return: rewire count
        """
        # if no rewire count specified, set rewire count to be all vertices
        if self.rewireCount is None:
            return self.trees[tree].vCount
        
        # max valid rewire count
        return min(self.trees[tree].vCount, self.rewireCount)
    
    def rrtStar(self):
        """
        Based on algorithm found in: Incremental Sampling-based Algorithms for Optimal Motion Planning
        :return: set of Vertices; Edges in form: vertex: [neighbor_1, neighbor_2, ...]
        """
        self.addVertex(0, self.x_init)
        self.addEdge(0,self.x_init, None)
        
        while True:
            for q in self.Q: #iterate over different edge lenghts
                for i in range(q[1]): #iterate over # of edges of given lenght
                    xNew, xNearest = self.newNear(0,q)
                    if xNew is None:
                        continue
                    
                    lNear = self.getNearby(0, self.x_init, xNew)
                    
                    self.connectShortestValid(0, xNew, lNear)
                    if xNew in self.trees[0].Edges:
                        self.rewire(0, xNew, lNear)
                        
                    solution = self.checkSolution()
                    if solution[0]:
                        return solution[1]

class RRTStarImproved(RRTStar):
    def __init__(self, X, Q, x_init, x_goal, max_samples, space, prob=0.01, rewire_count=None):
        super().__init__(X, Q, x_init, x_goal, max_samples, space, prob, rewire_count)
        self.c_best = float('inf')  # length of best solution this far

    def rrtStar(self):
        self.addVertex(0, self.x_init)
        self.addEdge(0, self.x_init, None)

        while True:
            for q in self.Q:
                for i in range(q[1]):
                    xNew, xNearest = self.newNear(0, q)
                    if xNew is None:
                        continue

                    lNear = self.getNearby(0, self.x_init, xNew)

                    self.connectShortestValid(0, xNew, lNear)
                    if xNew in self.trees[0].Edges:
                        self.rewire(0, xNew, lNear)

                    solution = self.checkSolution()
                    if solution[0]:
                        return solution[1]

    def connectShortestValid(self, tree, x_new, lNear):
        for cNear, xNear in lNear:
            if cNear + costToGo(xNear, self.x_goal) < self.c_best:
                if self.X.collisionFree(xNear, x_new, self.space):
                    self.connect(tree, xNear, x_new)

### Obstacle generation

In [6]:
def obstacleGenerator(obstacles):
    """
    Add obstacles to r-tree
    :param obstacles: list of obstacles
    """
    for obs in obstacles:
        yield (uuid.uuid4(), obs, obs)

### Search Space

In [7]:
class SearchSpace(object):
    def __init__(self, dimensionLenghts, O=None):
        """
        Initialize Search Space
        :param dimension_lengths: range of each dimension
        :param O: list of obstacles
        """
        # sanity check
        if len(dimensionLenghts) < 2:
            raise Exception("Must have at least 2D")
        self.dimensions = len(dimensionLenghts) # # of dimensions
        # sanity checks
        if any(len(i) != 2 for i in dimensionLenghts):
            raise Exception("Dimensions must only have a start and end")
        if any(i[0] >= i[1] for i in dimensionLenghts):
            raise Exception("Dimension start must be less than end")
        self.dimensionLenghts = dimensionLenghts
        p = index.Property()
        p.dimension = self.dimensions
        if O is None:
            self.obs = index.Index(interleaved=True, properties=p)
        else:
            # r-tree representation of obstacles
            # sanity check
            if any(len(o) / 2 != len(dimensionLenghts) for o in O):
                raise Exception("Obstacle has incorrect dimension definition")
            if any(o[i] >= o[int(i + len(o) / 2)] for o in O for i in range(int(len(o) / 2))):
                raise Exception("Obstacle start must be less than obstacle end")
            obstacle_data = list(obstacleGenerator(O))
            print(obstacle_data)
            # self.obs = index.Index(bounds=obstacle_data, interleaved=True, properties=p)

            self.obs = index.Index(bounds=obstacleGenerator(O), interleaved=True, properties=p)
            
    def obstacleFree(self, x):
        """
        Check if a location resides inside of an obstacle
        :param x: location to check
        :return: True if not inside an obstacle, False otherwise
        """
        return self.obs.count(x) == 0
    
    def sampleFree(self):
        """
        Sample a location within X_free
        :return: random location within X_free
        """
        while True:  # sample until not inside of an obstacle
            x = self.sample()
            if self.obstacleFree(x):
                return x
            
    def collisionFree(self, start, end, space):
        """
        Check if a line segment intersects an obstacle
        :param start: starting point of line
        :param end: ending point of line
        :param space: resolution of points to sample along edge when checking for collisions
        :return: True if line segment does not intersect an obstacle, False otherwise
        """
        points = linspacePoints(start, end, space)
        cFree = all(map(self.obstacleFree, points))
        return cFree
    
    def sample(self):
        """
        Return a random location within X
        :return: random location within X (not necessarily X_free)
        """
        x = np.random.uniform(self.dimensionLenghts[:, 0], self.dimensionLenghts[:, 1])
        return tuple(x)

### Plotting

In [8]:
colors = ['darkblue', 'teal']

class Plot(object):
    def __init__(self, filename):
        """
        Create a plot
        :param filename: filename
        """
        self.filename = "./output_" + filename + ".html"
        self.data = []
        self.layout = {'title': 'Plot', 'showlegend': False}
        self.fig = {'data': self.data, 'layout': self.layout}
        
    def plotTree(self, X, trees):
        """
        Plot tree
        :param X: Search Space
        :param trees: list of trees
        """
        if X.dimensions == 2:  # plot in 2D
            print("Cannot plot in 2 dimensions")
        elif X.dimensions == 3:  # plot in 3D
            self.plotTree3D(trees)
        else:  # can't plot in higher dimensions
            print("Cannot plot in > 3 dimensions")

    def plotTree3D(self, trees):
        """
        Plot 3D trees
        :param trees: trees to plot
        """
        for i, tree in enumerate(trees):
            for start, end in tree.Edges.items():
                if end is not None:
                    trace = go.Scatter3d(
                        x=[start[0], end[0]],
                        y=[start[1], end[1]],
                        z=[start[2], end[2]],
                        line=dict(
                            color=colors[i]
                        ),
                        mode="lines"
                    )
                    self.data.append(trace)
        
    def plotObstacles(self, X, O):
        """
        Plot obstacles
        :param X: Search Space
        :param O: list of obstacles
        """
        if X.dimensions == 2:  # plot in 2D
            print("Cannot plot 2D")
        elif X.dimensions == 3:  # plot in 3D
            for O_i in O:
                obs = go.Mesh3d(
                    x=[O_i[0], O_i[0], O_i[3], O_i[3], O_i[0], O_i[0], O_i[3], O_i[3]],
                    y=[O_i[1], O_i[4], O_i[4], O_i[1], O_i[1], O_i[4], O_i[4], O_i[1]],
                    z=[O_i[2], O_i[2], O_i[2], O_i[2], O_i[5], O_i[5], O_i[5], O_i[5]],
                    i=[7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2],
                    j=[3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3],
                    k=[0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6],
                    color='purple',
                    opacity=0.70
                )
                self.data.append(obs)
        else:  # can't plot in higher dimensions
            print("Cannot plot in > 3 dimensions")
            
    def plotPath(self, X, path):
        """
        Plot path through Search Space
        :param X: Search Space
        :param path: path through space given as a sequence of points
        """
        if X.dimensions == 2:  # plot in 2D
            print("Cannot plot 2D path")
        elif X.dimensions == 3:  # plot in 3D
            x, y, z = [], [], []
            for i in path:
                x.append(i[0])
                y.append(i[1])
                z.append(i[2])
            trace = go.Scatter3d(
                x=x,
                y=y,
                z=z,
                line=dict(
                    color="red",
                    width=4
                ),
                mode="lines"
            )

            self.data.append(trace)
        else:  # can't plot in higher dimensions
            print("Cannot plot in > 3 dimensions")
            
    def plotStart(self, X, x_init):
        """
        Plot starting point
        :param X: Search Space
        :param x_init: starting location
        """
        if X.dimensions == 2:  # plot in 2D
            print("Cannot plot 2D start")

            self.data.append(trace)
        elif X.dimensions == 3:  # plot in 3D
            trace = go.Scatter3d(
                x=[x_init[0]],
                y=[x_init[1]],
                z=[x_init[2]],
                line=dict(
                    color="orange",
                    width=10
                ),
                mode="markers"
            )

            self.data.append(trace)
        else:  # can't plot in higher dimensions
            print("Cannot plot in > 3 dimensions")

    def plotGoal(self, X, x_goal):
        """
        Plot goal point
        :param X: Search Space
        :param x_goal: goal location
        """
        if X.dimensions == 2:  # plot in 2D
            print("Cannot plot 2D goal")

            self.data.append(trace)
        elif X.dimensions == 3:  # plot in 3D
            trace = go.Scatter3d(
                x=[x_goal[0]],
                y=[x_goal[1]],
                z=[x_goal[2]],
                line=dict(
                    color="green",
                    width=10
                ),
                mode="markers"
            )

            self.data.append(trace)
        else:  # can't plot in higher dimensions
            print("Cannot plot in > 3 dimensions")

    def drawPlot(self, auto_open=True):
        """
        Render the plot to a file
        """
        py.offline.plot(self.fig, filename=self.filename, auto_open=auto_open)

## Main

In [9]:
X_dimensions = np.array([(0, 100), (0, 100), (0, 100)])  # dimensions of Search Space
# obstacles
Obstacles = np.array(
    [(20, 20, 20, 40, 40, 40), (20, 20, 60, 40, 40, 80), (20, 60, 20, 40, 80, 40), (60, 60, 20, 80, 80, 40),
     (60, 20, 20, 80, 40, 40), (60, 20, 60, 80, 40, 80), (20, 60, 60, 40, 80, 80), (60, 60, 60, 80, 80, 80)])
x_init = (0, 0, 0)  # starting location
x_goal = (100, 100, 100)  # goal location

Q = np.array([(2, 1)])  # length of tree edges
r = 0.01  # length of smallest edge to check for intersection with obstacles
max_samples = 1024  # max number of samples to take before timing out
rewire_count = 32  # optional, number of nearby branches to rewire
prc = 0.001  # probability of checking for a connection to goal

# create Search Space
X = SearchSpace(X_dimensions, Obstacles)

# create rrt_search
rrt = RRTStarImproved(X, Q, x_init, x_goal, max_samples, r, prc, rewire_count)
path = rrt.rrtStar()

# plot
plot = Plot("rrt_star_3d")
plot.plotTree(X, rrt.trees)
if path is not None:
    plot.plotPath(X, path)
plot.plotObstacles(X, Obstacles)
plot.plotStart(X, x_init)
plot.plotGoal(X, x_goal)
plot.drawPlot(auto_open=True)

[(UUID('dca9d2e7-ea1b-4c62-b1f5-ffb04e552982'), array([20, 20, 20, 40, 40, 40]), array([20, 20, 20, 40, 40, 40])), (UUID('43862818-bd77-4720-9b4e-ca521051f524'), array([20, 20, 60, 40, 40, 80]), array([20, 20, 60, 40, 40, 80])), (UUID('199a62c7-f9c7-48ff-937a-0d85d3b74173'), array([20, 60, 20, 40, 80, 40]), array([20, 60, 20, 40, 80, 40])), (UUID('98584c9a-4bab-40db-aa2f-40e891b7dced'), array([60, 60, 20, 80, 80, 40]), array([60, 60, 20, 80, 80, 40])), (UUID('7130d30a-1961-46df-ac99-7183308030b4'), array([60, 20, 20, 80, 40, 40]), array([60, 20, 20, 80, 40, 40])), (UUID('583f73fa-e20f-434f-9eab-6600f506ba2a'), array([60, 20, 60, 80, 40, 80]), array([60, 20, 60, 80, 40, 80])), (UUID('9bb8f56b-742c-4313-b1d4-1dfe06891775'), array([20, 60, 60, 40, 80, 80]), array([20, 60, 60, 40, 80, 80])), (UUID('107d613a-d75e-453b-93fd-55669a641ddf'), array([60, 60, 60, 80, 80, 80]), array([60, 60, 60, 80, 80, 80]))]


KeyboardInterrupt: 