@@ -2531,35 +2531,45 @@ def quiver(self, *args,
25312531 Any additional keyword arguments are delegated to
25322532 :class:`~matplotlib.collections.LineCollection`
25332533 """
2534- def calc_arrow (uvw , angle = 15 ):
2535- """
2536- To calculate the arrow head. uvw should be a unit vector.
2537- We normalize it here:
2538- """
2539- # get unit direction vector perpendicular to (u, v, w)
2540- norm = np .linalg .norm (uvw [:2 ])
2541- if norm > 0 :
2542- x = uvw [1 ] / norm
2543- y = - uvw [0 ] / norm
2544- else :
2545- x , y = 0 , 1
2534+ def calc_arrows (UVW , angle = 15 ):
2535+ # get unit direction vector perpendicular to (u,v,w)
2536+ x = UVW [:, 0 ]
2537+ y = UVW [:, 1 ]
2538+ norm = np .linalg .norm (UVW [:, :2 ], axis = 1 )
2539+ x_p = np .divide (y , norm , where = norm != 0 , out = np .zeros_like (x ))
2540+ y_p = np .divide (- x , norm , where = norm != 0 , out = np .ones_like (x ))
25462541
25472542 # compute the two arrowhead direction unit vectors
25482543 ra = math .radians (angle )
25492544 c = math .cos (ra )
25502545 s = math .sin (ra )
25512546
25522547 # construct the rotation matrices
2553- Rpos = np .array ([[c + (x ** 2 )* (1 - c ), x * y * (1 - c ), y * s ],
2554- [y * x * (1 - c ), c + (y ** 2 )* (1 - c ), - x * s ],
2555- [- y * s , x * s , c ]])
2548+ Rpos = np .array (
2549+ [[c + (x_p ** 2 ) * (1 - c ), x_p * y_p * (1 - c ), y_p * s ],
2550+ [y_p * x_p * (1 - c ), c + (y_p ** 2 ) * (1 - c ), - x_p * s ],
2551+ [- y_p * s , x_p * s , np .full_like (x_p , c )]])
2552+ Rpos = Rpos .transpose (2 , 0 , 1 )
2553+
25562554 # opposite rotation negates all the sin terms
25572555 Rneg = Rpos .copy ()
2558- Rneg [[0 , 1 , 2 , 2 ], [2 , 2 , 0 , 1 ]] = \
2559- - Rneg [[0 , 1 , 2 , 2 ], [2 , 2 , 0 , 1 ]]
2556+ Rneg [:, [0 , 1 , 2 , 2 ], [2 , 2 , 0 , 1 ]] = \
2557+ - Rneg [:, [0 , 1 , 2 , 2 ], [2 , 2 , 0 , 1 ]]
2558+
2559+ # expand dimensions for batched matrix multiplication
2560+ UVW = np .expand_dims (UVW , axis = - 1 )
25602561
25612562 # multiply them to get the rotated vector
2562- return Rpos .dot (uvw ), Rneg .dot (uvw )
2563+ Rpos_vecs = np .matmul (Rpos , UVW )
2564+ Rneg_vecs = np .matmul (Rneg , UVW )
2565+
2566+ # transpose for concatenation
2567+ Rpos_vecs = Rpos_vecs .transpose (0 , 2 , 1 )
2568+ Rneg_vecs = Rneg_vecs .transpose (0 , 2 , 1 )
2569+
2570+ head_dirs = np .concatenate ([Rpos_vecs , Rneg_vecs ], axis = 1 )
2571+
2572+ return head_dirs
25632573
25642574 had_data = self .has_data ()
25652575
@@ -2621,7 +2631,7 @@ def calc_arrow(uvw, angle=15):
26212631 # compute the shaft lines all at once with an outer product
26222632 shafts = (XYZ - np .multiply .outer (shaft_dt , UVW )).swapaxes (0 , 1 )
26232633 # compute head direction vectors, n heads x 2 sides x 3 dimensions
2624- head_dirs = np . array ([ calc_arrow ( d ) for d in UVW ] )
2634+ head_dirs = calc_arrows ( UVW )
26252635 # compute all head lines at once, starting from the shaft ends
26262636 heads = shafts [:, :1 ] - np .multiply .outer (arrow_dt , head_dirs )
26272637 # stack left and right head lines together
0 commit comments