Skip to content

Commit

Permalink
🩹 Fix assert with upper/lower bound lists (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Jul 26, 2023
1 parent ee7dfc7 commit 06cc019
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torchist/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def quantize(x: Tensor, bins: Tensor, low: Tensor, upp: Tensor) -> Tensor:

x = (x - low) / (upp - low) # in [0.0, 1.0]
x = (bins * x).long() # in [0, bins]
x = torch.clip(x, min=0, max=bins - 1) # in [0, bins)
x = torch.clip(x, min=None, max=bins - 1) # in [0, bins)

return x

Expand Down Expand Up @@ -149,12 +149,12 @@ def histogramdd(

edges = pack

assert torch.all(upp > low), "The upper bound must be strictly larger than the lower bound"

bins = torch.as_tensor(bins).squeeze().long()
low = torch.as_tensor(low).squeeze().to(x)
upp = torch.as_tensor(upp).squeeze().to(x)

assert torch.all(upp > low), "The upper bound must be strictly larger than the lower bound"

if weights is not None:
weights = weights.flatten()

Expand Down

0 comments on commit 06cc019

Please sign in to comment.