Skip to content

Commit

Permalink
Made some small changes to improve the speed of computations. Mostly …
Browse files Browse the repository at this point in the history
…just added the numba decorator to some functions. Changed the copy function slightly and the init function so that np.array is only called when it is needed as it requires a significant amount of overhead when called many times.
  • Loading branch information
ckielasjensen committed Apr 24, 2019
1 parent c0074c8 commit 29b263a
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 22 deletions.
51 changes: 36 additions & 15 deletions bezier.py
Expand Up @@ -12,15 +12,14 @@

from collections import defaultdict

from gjk.gjk import gjk
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
#import numba
from numba import njit, jit
import numpy as np
import scipy.optimize
from scipy.special import binom

from gjk.gjk import gjk


#TODO:
# Implement curve using Bernstein basis instead of de cast
Expand Down Expand Up @@ -54,11 +53,19 @@ def __init__(self, cpts=None, tau=None, tf=1.0):
self._tf = float(tf)
self._curve = None

if tau is None:
self._tau = np.linspace(0, self._tf, 1001)
# if tau is None:
# self._tau = np.linspace(0, self._tf, 1001)

if cpts is not None:
self._cpts = np.array(cpts, ndmin=2, dtype=float)
# Checking to see if the cpts are in the desired format. If they
# are, don't call np.array since it causes a bottleneck in certain
# iterative procedures.
if (isinstance(cpts, np.ndarray) and
cpts.dtype == 'float64' and
cpts.ndim == 2):
self._cpts = cpts
else:
self._cpts = np.array(cpts, ndmin=2, dtype=float)
self._dim = self._cpts.shape[0]
self._deg = self._cpts.shape[1] - 1
else:
Expand All @@ -73,7 +80,12 @@ def cpts(self):
def cpts(self, value):
self._curve = None

newCpts = np.array(value, ndmin=2)
if (isinstance(value, np.ndarray) and
value.ndim == 2 and
value.dtype == 'float64'):
newCpts = value
else:
newCpts = np.array(value, ndmin=2, dtype=float)

self._dim = newCpts.shape[0]
self._deg = newCpts.shape[1] - 1
Expand Down Expand Up @@ -108,7 +120,7 @@ def tf(self, value):
def tau(self):
if self._tau is None:
self._tau = np.linspace(0, self._tf, 1001)
else:
elif not isinstance(self._tau, np.ndarray):
self._tau = np.array(self._tau)
return self._tau

Expand Down Expand Up @@ -207,7 +219,8 @@ def copy(self):
:return: Deep copy of Bezier object
:rtype: Bezier
"""
return Bezier(self.cpts, self.tau, self.tf)
# return Bezier(self.cpts, self.tau, self.tf)
return Bezier(self.cpts, None, self.tf)

def plot(self, axisHandle=None, showCpts=True, **kwargs):
"""Plots the Bezier curve in 1D or 2D
Expand Down Expand Up @@ -724,8 +737,14 @@ def normSquare(self):
prodM = prodMatrix(self.deg).T
Bezier.productMatrixCache[self.deg][self.deg] = prodM

return Bezier(_normSquare(self.cpts, 1, self.dim, prodM.T),
tau=self.tau, tf=self.tf)
normCpts = _normSquare(self.cpts, 1, self.dim, prodM.T)

newCurve = self.copy()
newCurve.cpts = normCpts

return newCurve
# return Bezier(_normSquare(self.cpts, 1, self.dim, prodM.T),
# tau=self.tau, tf=self.tf)


class RationalBezier(BezierParams):
Expand Down Expand Up @@ -1173,6 +1192,7 @@ def _norm(x):
# return res


@njit(cache=True)
def _normSquare(x, Nveh, Ndim, prodM):
"""Compute the control points of the square of the norm of a vector
Expand All @@ -1184,16 +1204,17 @@ def _normSquare(x, Nveh, Ndim, prodM):
Code ported over from Venanzio Cichella's MATLAB norm_square function.
NOTE: This only works on 1D or 2D matricies. It will fail for 3 or more.
"""
x = np.array(x)
if x.ndim == 1:
x = x[None]
# x = np.array(x)
# if x.ndim == 1:
# x = x[None]

m, N = x.shape

xsquare = np.zeros((m, prodM.shape[0]))

for i in range(m):
xaug = np.dot(x[i, None].T, x[i, None])
# xaug = np.dot(x[i, None].T, x[i, None])
xaug = np.dot(x.T, x)
xnew = xaug.reshape((N**2, 1))
xsquare[i, :] = np.dot(prodM, xnew).T[0]

Expand Down
14 changes: 7 additions & 7 deletions gjk/gjk.py
Expand Up @@ -28,7 +28,7 @@ def gjk(polygon1, polygon2, method='nearest', *args, **kwargs):
return algo(polygon1, polygon2, *args, **kwargs)


@njit
@njit(cache=True)
def gjkNearest(polygon1, polygon2, maxIter=10):
"""
Finds the shortest distance between two polygons using the GJK algorithm.
Expand Down Expand Up @@ -85,7 +85,7 @@ def gjkCollision(polygon1, polygon2):
raise NotImplementedError(errorMsg)


@njit
@njit(cache=True)
def support(shape, direction):
"""
Returns the point in shape that is furthest in the desired direction.
Expand Down Expand Up @@ -115,7 +115,7 @@ def support(shape, direction):
return supportPoint


@njit
@njit(cache=True)
def closestPointToOrigin2(a, b):
"""
Finds the closest point to the origin on the line AB.
Expand All @@ -138,7 +138,7 @@ def closestPointToOrigin2(a, b):
return closestPoint, distance


@njit
@njit(cache=True)
def closestPointToOrigin(a, b):
"""
Finds the closest point to the origin on the line AB.
Expand Down Expand Up @@ -172,7 +172,7 @@ def closestPointToOrigin(a, b):
return closestPt, np.sqrt(dot(closestPt, closestPt))


@njit
@njit(cache=True)
def dot(a, b):
"""
Fast implementation of the dot product.
Expand All @@ -195,7 +195,7 @@ def dot(a, b):
return a[0]*b[0] + a[1]*b[1] + a[2]*b[2]


@njit
@njit(cache=True)
def tripleProduct(a, b, c):
"""
Fast implementation of the vector triple product.
Expand Down Expand Up @@ -248,6 +248,6 @@ def tripleProduct(a, b, c):
print(retVal)

poly3 = np.random.random((10, 3))
poly4 = np.random.random((10, 3))
poly4 = np.random.random((10, 3))+3

print(gjk(poly3, poly4))

0 comments on commit 29b263a

Please sign in to comment.