diff --git a/floodlight/models/space.py b/floodlight/models/space.py index 7843abca..4977c0d6 100644 --- a/floodlight/models/space.py +++ b/floodlight/models/space.py @@ -192,15 +192,26 @@ def _calc_cell_controls(self, xy1: XY, xy2: XY): np.nan, ) + # stack and reshape mesh coordinates to (M x 2) array + mesh_points = np.stack((self._meshx_, self._meshy_), axis=2).reshape(-1, 2) + # loop for t in range(T): - # stack and reshape player and mesh coordinates to (M x 2) arrays + # stack and reshape player coordinates to (M x 2) array player_points = np.hstack((xy1.frame(t), xy2.frame(t))).reshape(-1, 2) - mesh_points = np.stack((self._meshx_, self._meshy_), axis=2).reshape(-1, 2) # calculate pairwise distances and determine closest player pairwise_distances = cdist(mesh_points, player_points) - closest_player_index = np.nanargmin(pairwise_distances, axis=1) + + # identify valid segments without all-NaN slices + all_nan_mask = np.isnan(pairwise_distances).all(axis=1) + valid_mask = ~all_nan_mask + + # Init closest player index array + closest_player_index = np.full(pairwise_distances.shape[0], np.NaN) + + if np.any(valid_mask): + closest_player_index = np.nanargmin(pairwise_distances, axis=1) self._cell_controls_[t] = closest_player_index.reshape(self._meshx_.shape) def fit(self, xy1: XY, xy2: XY): @@ -352,20 +363,26 @@ def plot( .. image:: ../../_img/sample_dvm_plot_hex.png """ - # get ax - ax = ax or plt.subplots()[1] - # get colors and construct team color vector - team_color1, team_color2 = team_colors - color_vector = [team_color1] * self._N1_ + [team_color2] * self._N2_ + # check if t refers to an all-nan slice in the cell controlls + if np.isnan(self._cell_controls_[t]).all(): + pass + else: - # call plot by mesh type - if self._mesh_type == "square": - ax = self._plot_square(t, color_vector, ax=ax, **kwargs) - elif self._mesh_type == "hexagonal": - ax = self._plot_hexagonal(t, color_vector, ax=ax, **kwargs) + # get ax + ax = ax or plt.subplots()[1] - return ax + # get colors and construct team color vector + team_color1, team_color2 = team_colors + color_vector = [team_color1] * self._N1_ + [team_color2] * self._N2_ + + # call plot by mesh type + if self._mesh_type == "square": + ax = self._plot_square(t, color_vector, ax=ax, **kwargs) + elif self._mesh_type == "hexagonal": + ax = self._plot_hexagonal(t, color_vector, ax=ax, **kwargs) + + return ax def _plot_square( self,