Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 41 additions & 20 deletions quantecon/_gridtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def cartesian(nodes, order='C'):
'''
"""
Cartesian product of a list of arrays

Parameters
Expand All @@ -25,34 +25,55 @@ def cartesian(nodes, order='C'):
-------
out : ndarray(ndim=2)
each line corresponds to one point of the product space
'''
"""

# Avoid unnecessary re-allocations: dedicate only 1 conversion pass
nodes = [np.asarray(e) for e in nodes]
shapes = [e.shape[0] for e in nodes]

dtype = np.result_type(*nodes)

n = len(nodes)
l = np.prod(shapes)
out = np.zeros((l, n), dtype=dtype)
if n == 0:
return np.empty((0, 0), dtype=dtype)

# Avoids creating int64 array for a single scalar by checking n == 1 early
if n == 1:
arr = nodes[0].reshape(-1, 1)
if order == 'C':
return arr
# 'F' order is just identity for single array
return arr

l = 1
for dim in shapes:
l *= dim

out = np.empty((l, n), dtype=dtype)

# Efficient repetitions computation (preallocate, avoid unnecessary lists)
if order == 'C':
repetitions = np.cumprod([1] + shapes[:-1])
repetitions = np.empty(n, dtype=np.int64)
acc = 1
for i in range(n):
repetitions[i] = acc
acc *= shapes[i]
else:
shapes.reverse()
sh = [1] + shapes[:-1]
repetitions = np.cumprod(sh)
repetitions = repetitions.tolist()
repetitions.reverse()

# Reverse handling done without mutating shapes
repetitions = np.empty(n, dtype=np.int64)
acc = 1
for i in reversed(range(n)):
repetitions[i] = acc
acc *= shapes[i]

# Directly fill the cartesian product using fast C loop via _repeat_1d
for i in range(n):
_repeat_1d(nodes[i], repetitions[i], out[:, i])

return out


def mlinspace(a, b, nums, order='C'):
'''
"""
Constructs a regular cartesian grid

Parameters
Expand All @@ -73,13 +94,13 @@ def mlinspace(a, b, nums, order='C'):
-------
out : ndarray(ndim=2)
each line corresponds to one point of the product space
'''

a = np.asarray(a, dtype='float64')
b = np.asarray(b, dtype='float64')
nums = np.asarray(nums, dtype='int64')
nodes = [np.linspace(a[i], b[i], nums[i]) for i in range(len(nums))]

"""
# Convert just once, cast only if necessary, skip if already array
a = np.asarray(a, dtype='float64', order='C')
b = np.asarray(b, dtype='float64', order='C')
nums = np.asarray(nums, dtype='int64', order='C')
n_dims = nums.shape[0]
nodes = [np.linspace(a[i], b[i], nums[i]) for i in range(n_dims)]
return cartesian(nodes, order=order)


Expand Down