In [44]:
import numpy as np
import math
from env import *
import copy


In [2]:
# (lower, upper) extents of the workspace along each axis; handed to Env below.
x_bounds, y_bounds = (0, 50), (0, 30)

In [59]:
class Node:
    """A point of the search tree.

    Holds the planar position together with the bookkeeping fields that the
    planner fills in later: ``parent`` (link to another node) and ``cost``
    (accumulated cost). Both start out as ``None``.
    """

    def __init__(self, x, y):
        # Position of the node.
        self.x, self.y = x, y
        # Not connected and not costed until the planner assigns them.
        self.parent = self.cost = None

class RRT_STAR:
    """Building blocks of an RRT* planner over a tree of 2-D nodes.

    The tree is kept as a flat list (``self.Tree``). The class provides the
    sampling, nearest/near lookup, parent-selection and rewiring steps.
    Steering, trajectory computation and collision checking are not
    implemented yet (see the TODO notes in the methods).

    Nodes are expected to expose ``x``, ``y``, ``cost`` and ``parent``
    attributes.
    """

    def __init__(self, start, goal, env, prob_gs):
        """Create the planner and seed the tree with the start node.

        Args:
            start: Indexable ``(x, y)`` of the start position.
            goal: Indexable ``(x, y)`` of the goal position.
            env: Environment object; ``env.x_range`` and ``env.y_range`` are
                read as indexable bounds whose last element is the upper
                limit.
            prob_gs: Threshold used by ``sample``: the goal node is returned
                whenever a uniform draw from [0, 1) is not above it.
        """
        self.node_s = Node(start[0], start[1])
        self.node_g = Node(goal[0], goal[1])
        self.env = env

        # TODO: check that the start and goal nodes lie inside the environment.

        self.prob_gs = prob_gs  # Empirical value

        # The start node is the root of the tree and costs nothing to reach.
        self.node_s.cost = 0
        self.Tree = [self.node_s]

        # Empirical value: a tenth (floored) of the larger upper bound.
        self.radius = max(self.env.x_range[-1], self.env.y_range[-1]) // 10

    def choose_parent(self, N_near, n_nearest, n_new):
        """Return the node through which ``n_new`` is cheapest to reach.

        ``n_nearest`` is the initial guess. A node of ``N_near`` replaces it
        when the cost through that node is strictly lower than the best found
        so far and, if ``n_new`` is already part of the tree, also strictly
        lower than ``n_new``'s current cost.

        Args:
            N_near: Iterable of candidate parent nodes.
            n_nearest: Fallback parent; its ``cost`` must be set.
            n_new: Node that needs a parent.

        Returns:
            The selected parent node. ``n_new`` itself is not modified.
        """
        n_parent = n_nearest
        cost_min = n_nearest.cost + self.compute_cost(n_nearest, n_new)

        # Membership cannot change while the candidates are scanned, so the
        # linear search through the tree is done once instead of per candidate.
        already_in_tree = n_new in self.Tree

        for n_near in N_near:
            # TODO: compute the trajectory/control between the two nodes.
            # TODO: collision check.
            inter_cost = n_near.cost + self.compute_cost(n_near, n_new)
            if inter_cost >= cost_min:
                continue
            if already_in_tree and inter_cost >= n_new.cost:
                continue
            n_parent, cost_min = n_near, inter_cost  # New best alternative

        return n_parent

    def rewire(self, N_near, n_parent, n_new):
        """Re-parent near nodes onto ``n_new`` where that lowers their cost.

        Every node of ``N_near`` other than ``n_parent`` whose cost through
        ``n_new`` is strictly lower than its current cost gets ``n_new`` as
        parent and the lower cost. Nodes are updated in place; because the
        tree stores the very same objects, ``self.Tree`` needs no update.

        ``N_near`` is left untouched and ``n_parent`` does not have to be a
        member of it.

        NOTE: only the rewired node's own cost is updated; costs of nodes
        further down that branch are not recomputed here.

        Args:
            N_near: Iterable of nodes in the neighbourhood of ``n_new``.
            n_parent: Parent of ``n_new``; never rewired.
            n_new: Node whose ``cost`` is already set.
        """
        for n_near in N_near:
            if n_near is n_parent:
                continue
            # TODO: steer.
            # TODO: obstacle check.
            cost_via_new = n_new.cost + self.compute_cost(n_near, n_new)
            if cost_via_new < n_near.cost:
                n_near.parent = n_new
                n_near.cost = cost_via_new

    def compute_NN(self, n_curr, near=False):
        """Look up tree nodes by distance to ``n_curr``.

        Args:
            n_curr: Query node.
            near: When false (default) return the single closest tree node
                (the first one on ties). When true return the list of all
                tree nodes strictly closer than ``self.radius``, in tree
                order.

        Returns:
            A node, or a (possibly empty) list of nodes when ``near`` is true.

        Raises:
            ValueError: In nearest mode when the tree is empty.
        """
        if not near:  # Nearest
            return min(
                self.Tree,
                key=lambda n_tree: self.compute_dist(n_curr, n_tree),
            )

        # Near
        return [
            n_tree
            for n_tree in self.Tree
            if self.compute_dist(n_curr, n_tree) < self.radius
        ]

    def sample(self):
        """Sample a node, returning the goal node with a low probability.

        Returns:
            A new node drawn uniformly inside the environment bounds when a
            uniform draw from [0, 1) exceeds ``self.prob_gs``; otherwise the
            goal node ``self.node_g`` itself (not a copy).
        """
        if np.random.random() > self.prob_gs:
            # Both bounds are passed explicitly: the first positional argument
            # of np.random.uniform is `low`, so passing only the upper bound
            # leaves `high` at its default of 1.0 and the strip below 1 is
            # never sampled.
            # NOTE(review): assumes element 0 of x_range / y_range is the
            # lower bound - confirm against Env.
            x_new = np.random.uniform(self.env.x_range[0], self.env.x_range[-1])
            y_new = np.random.uniform(self.env.y_range[0], self.env.y_range[-1])
            return Node(x_new, y_new)

        return self.node_g

    def compute_dist(self, n1, n2):
        """Return the Euclidean distance between two nodes."""
        return math.hypot(n1.x - n2.x, n1.y - n2.y)

    def compute_cost(self, n1, n2):
        """Return the cost of the edge between two nodes.

        Currently this is the Euclidean distance.
        """
        # TODO: add traversability to the cost.
        return self.compute_dist(n1, n2)
    
    

        




In [60]:
# Demo set-up: start/goal positions, the environment, and the planner itself.
start, goal = (2, 3), (7, 8)
prob_gs = 0.1  # goal-sampling threshold forwarded to RRT_STAR
env = Env(x_bounds=x_bounds, y_bounds=y_bounds)

rrt = RRT_STAR(start=start, goal=goal, env=env, prob_gs=prob_gs)

In [61]:
# Hand-built fixture for choose_parent: one candidate node plus three tree
# nodes with preset costs.
n_new = Node(10, 0)

n_nearest = Node(8, 0)
n_nearest.cost = 10

n1 = Node(7, 0)
n1.cost = 3

n2 = Node(12, 0)
n2.cost = 12

# Same insertion order as before: n_nearest, n1, n2.
rrt.Tree.extend([n_nearest, n1, n2])


In [62]:
# Near set handed to choose_parent: the three hand-built tree nodes.
N_near = [n_nearest, n1, n2]

In [63]:
# Pick the cheapest parent for n_new among the near nodes (n_nearest is the fallback).
parent = rrt.choose_parent(N_near=N_near, n_nearest=n_nearest, n_new=n_new)

In [65]:
# Display the chosen parent's x coordinate; the recorded output is 7, which is n1's x.
parent.x

7