Skip to content

Commit

Permalink
👽️ API changes in PyTorch 1.11 (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Mar 29, 2022
1 parent 48ec8c4 commit 7a56439
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 4 deletions.
2 changes: 1 addition & 1 deletion piqa/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
assert_type(
input, target,
device=self.kernel.device,
dim_range=(3, -1),
dim_range=(3, 5),
n_channels=self.kernel.size(0),
value_range=(0., self.value_range),
)
Expand Down
2 changes: 1 addition & 1 deletion piqa/utils/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def color_conv(
weight: A weight kernel, :math:`(C', C)`.
"""

return F.conv1d(x, weight.view(weight.shape + (1,) * spatial(x)))
return F.linear(x.transpose(1, -1), weight).transpose(1, -1)


RGB_TO_YIQ = torch.tensor([
Expand Down
13 changes: 12 additions & 1 deletion piqa/utils/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,18 @@ def channel_conv(
[144., 153., 162.]]]])
"""

return F.conv1d(x, kernel, padding=padding, groups=x.size(1))
D = len(kernel.shape) - 2

assert D <= 3, "PyTorch only supports 1D, 2D or 3D convolutions."

if D == 3:
return F.conv3d(x, kernel, padding=padding, groups=x.size(-4))
elif D == 2:
return F.conv2d(x, kernel, padding=padding, groups=x.size(-3))
elif D == 1:
return F.conv1d(x, kernel, padding=padding, groups=x.size(-2))
else:
return F.linear(x, kernel.expand(x.size(-1)))


def channel_convs(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

setuptools.setup(
name='piqa',
version='1.2.1',
version='1.2.2',
packages=setuptools.find_packages(),
description='PyTorch Image Quality Assessment',
keywords='image quality processing metrics torch vision',
Expand Down

0 comments on commit 7a56439

Please sign in to comment.