Skip to content

Commit

Permalink
🐛 Fix MDSI runtime on CUDA
Browse files Browse the repository at this point in the history
Mean_cuda is not implemented for complex types (pytorch/pytorch#46982)
  • Loading branch information
francois-rozet committed Dec 10, 2020
1 parent d277888 commit 0e89000
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion piqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
specific image quality assessement metric.
"""

__version__ = '1.0.5'
__version__ = '1.0.6'
4 changes: 3 additions & 1 deletion piqa/mdsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def mdsi(

# Mean deviation similarity
gcs_q = gcs ** q
score = (gcs_q - gcs_q.mean((-1, -2), keepdim=True)).abs()
gcs_q_avg = torch.view_as_real(gcs_q).mean((-2, -3), True)
gcs_q_avg = torch.view_as_complex(gcs_q_avg)
score = (gcs_q - gcs_q_avg).abs()
mds = (score ** rho).mean((-1, -2)) ** (o / rho)

return mds
Expand Down

0 comments on commit 0e89000

Please sign in to comment.