Skip to content

Commit

Permalink
Detach indices in Interpolation()
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobrgardner committed Apr 16, 2018
1 parent f70f52f commit 8b50f2f
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions gpytorch/utils/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def interpolate(self, x_grid, x_target, interp_points=range(-2, 2)):
lower_grid_pt_idxs = torch.floor((x_target[:, i] - x_grid[i, 0]) / grid_delta).squeeze()
lower_pt_rel_dists = (x_target[:, i] - x_grid[i, 0]) / grid_delta - lower_grid_pt_idxs
lower_grid_pt_idxs = lower_grid_pt_idxs - interp_points.max()
lower_grid_pt_idxs.detach_()

scaled_dist = lower_pt_rel_dists.unsqueeze(-1) + interp_points_flip.unsqueeze(-2)
dim_interp_values = self._cubic_interpolation_kernel(scaled_dist)
Expand Down

0 comments on commit 8b50f2f

Please sign in to comment.