Skip to content

Commit

Permalink
Fix ind device (#88)
Browse files Browse the repository at this point in the history
Co-authored-by: sanjays <sanjayss34@users.noreply.github.com>
  • Loading branch information
sanjayss34 and sanjayss34 committed Jan 14, 2021
1 parent 9f93432 commit 8372901
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torch_struct/semirings/sparse_max.py
Expand Up @@ -53,7 +53,7 @@ def backward(ctx, grad_output):
def project_simplex(v, dim, z=1):
v_sorted, _ = torch.sort(v, dim=dim, descending=True)
cssv = torch.cumsum(v_sorted, dim=dim) - z
ind = torch.arange(1, 1 + v.shape[dim]).to(dtype=v.dtype)
ind = torch.arange(1, 1 + v.shape[dim]).to(dtype=v.dtype).to(v.device)
cond = v_sorted - cssv / ind >= 0
k = cond.sum(dim=dim, keepdim=True)
tau = cssv.gather(dim, k - 1) / k.to(dtype=v.dtype)
Expand Down

0 comments on commit 8372901

Please sign in to comment.