diff --git a/libs/spandrel/spandrel/__helpers/main_registry.py b/libs/spandrel/spandrel/__helpers/main_registry.py index 2866ceb6..2e4f29e8 100644 --- a/libs/spandrel/spandrel/__helpers/main_registry.py +++ b/libs/spandrel/spandrel/__helpers/main_registry.py @@ -13,6 +13,7 @@ GRL, HAT, IPT, + PLKSR, RGT, SAFMN, SAFMNBCIE, @@ -78,4 +79,5 @@ ArchSupport.from_architecture(IPT.IPTArch()), ArchSupport.from_architecture(DRCT.DRCTArch()), ArchSupport.from_architecture(ESRGAN.ESRGANArch()), + ArchSupport.from_architecture(PLKSR.PLKSRArch()), ) diff --git a/libs/spandrel/spandrel/architectures/PLKSR/__init__.py b/libs/spandrel/spandrel/architectures/PLKSR/__init__.py new file mode 100644 index 00000000..5b7ea009 --- /dev/null +++ b/libs/spandrel/spandrel/architectures/PLKSR/__init__.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import math +from typing import Union + +from typing_extensions import override + +from spandrel.util import KeyCondition, get_seq_len + +from ...__helpers.model_descriptor import Architecture, ImageModelDescriptor, StateDict +from .arch.PLKSR import PLKSR +from .arch.RealPLKSR import RealPLKSR + +_PLKSR = Union[PLKSR, RealPLKSR] + + +class PLKSRArch(Architecture[_PLKSR]): + def __init__(self) -> None: + super().__init__( + id="PLKSR", + detect=KeyCondition.has_all( + "feats.0.weight", + "feats.1.lk.conv.weight", + "feats.1.refine.weight", + KeyCondition.has_any( + "feats.1.channe_mixer.0.weight", + "feats.1.channel_mixer.0.weight", + ), + ), + ) + + @override + def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]: + dim = 64 + n_blocks = 28 + scale = 4 + kernel_size = 17 + split_ratio = 0.25 + use_ea = True + + # RealPLKSR only + norm_groups = 4 # un-detectable + dropout = 0 # un-detectable + + dim = state_dict["feats.0.weight"].shape[0] + + total_feat_layers = get_seq_len(state_dict, "feats") + scale = math.isqrt( + state_dict[f"feats.{total_feat_layers - 1}.weight"].shape[0] // 3 + ) + + kernel_size = state_dict["feats.1.lk.conv.weight"].shape[2] + split_ratio = state_dict["feats.1.lk.conv.weight"].shape[0] / dim + + use_ea = "feats.1.attn.f.0.weight" in state_dict + + # Yes, the normal version has this typo. + if "feats.1.channe_mixer.0.weight" in state_dict: + n_blocks = total_feat_layers - 2 + + mixer_0_shape = state_dict["feats.1.channe_mixer.0.weight"].shape[2] + mixer_2_shape = state_dict["feats.1.channe_mixer.2.weight"].shape[2] + if mixer_0_shape == 3 and mixer_2_shape == 1: + ccm_type = "CCM" + elif mixer_0_shape == 3 and mixer_2_shape == 3: + ccm_type = "DCCM" + elif mixer_0_shape == 1 and mixer_2_shape == 3: + ccm_type = "ICCM" + else: + raise ValueError("Unknown CCM type") + more_tags = [ccm_type] + + model = PLKSR( + dim=dim, + upscaling_factor=scale, + n_blocks=n_blocks, + kernel_size=kernel_size, + split_ratio=split_ratio, + use_ea=use_ea, + ccm_type=ccm_type, + ) + # and RealPLKSR doesn't. This makes it really convenient to detect. + elif "feats.1.channel_mixer.0.weight" in state_dict: + more_tags = ["Real"] + + n_blocks = total_feat_layers - 3 + model = RealPLKSR( + dim=dim, + upscaling_factor=scale, + n_blocks=n_blocks, + kernel_size=kernel_size, + split_ratio=split_ratio, + use_ea=use_ea, + norm_groups=norm_groups, + dropout=dropout, + ) + else: + raise ValueError("Unknown model type") + + return ImageModelDescriptor( + model, + state_dict, + architecture=self, + purpose="Restoration" if scale == 1 else "SR", + tags=[f"{dim}dim", f"{n_blocks}nb", f"{kernel_size}ks", *more_tags], + supports_half=False, + supports_bfloat16=True, + scale=scale, + input_channels=3, + output_channels=3, + ) diff --git a/libs/spandrel/spandrel/architectures/PLKSR/arch/LICENSE b/libs/spandrel/spandrel/architectures/PLKSR/arch/LICENSE new file mode 100644 index 00000000..bf3ea526 --- /dev/null +++ b/libs/spandrel/spandrel/architectures/PLKSR/arch/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Dongheon Lee + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/libs/spandrel/spandrel/architectures/PLKSR/arch/PLKSR.py b/libs/spandrel/spandrel/architectures/PLKSR/arch/PLKSR.py new file mode 100644 index 00000000..b030271e --- /dev/null +++ b/libs/spandrel/spandrel/architectures/PLKSR/arch/PLKSR.py @@ -0,0 +1,336 @@ +from functools import partial +from typing import Literal, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ + +from spandrel.util import store_hyperparameters + + +# Since Pytorch's interleave is not supported by CoreML, we use this function instead for mobile conversion +def repeat_interleave(x, n): + x = x.unsqueeze(2) + x = x.repeat(1, 1, n, 1, 1) + x = x.reshape(x.shape[0], x.shape[1] * n, x.shape[3], x.shape[4]) + return x + + +class CCM(nn.Sequential): + "Convolutional Channel Mixer" + + def __init__(self, dim: int): + super().__init__( + nn.Conv2d(dim, dim * 2, 3, 1, 1), + nn.GELU(), + nn.Conv2d(dim * 2, dim, 1, 1, 0), + ) + trunc_normal_(self[-1].weight, std=0.02) + + +class ICCM(nn.Sequential): + "Inverted Convolutional Channel Mixer" + + def __init__(self, dim: int): + super().__init__( + nn.Conv2d(dim, dim * 2, 1, 1, 0), + nn.GELU(), + nn.Conv2d(dim * 2, dim, 3, 1, 1), + ) + trunc_normal_(self[-1].weight, std=0.02) + + +class DCCM(nn.Sequential): + "Doubled Convolutional Channel Mixer" + + def __init__(self, dim: int): + super().__init__( + nn.Conv2d(dim, dim * 2, 3, 1, 1), + nn.GELU(), + nn.Conv2d(dim * 2, dim, 3, 1, 1), + ) + trunc_normal_(self[-1].weight, std=0.02) + + +class PLKConv2d(nn.Module): + "Partial Large Kernel Convolutional Layer" + + def __init__(self, dim, kernel_size, with_idt): + super().__init__() + self.with_idt = with_idt + self.conv = nn.Conv2d(dim, dim, kernel_size, 1, kernel_size // 2) + trunc_normal_(self.conv.weight, std=0.02) + self.idx = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.with_idt: + x[:, : self.idx] = x[:, : self.idx] + self.conv(x[:, : self.idx]) + else: + x[:, : self.idx] = self.conv(x[:, : self.idx]) + return x + + +class RectSparsePLKConv2d(nn.Module): + "Rectangular Sparse Partial Large Kernel Convolutional Layer (SLaK style)" + + def __init__(self, dim, kernel_size): + super().__init__() + self.idx = dim + m = kernel_size + n = kernel_size // 3 + self.mn_conv = nn.Conv2d(dim, dim, (m, n), 1, (m // 2, n // 2)) + self.nm_conv = nn.Conv2d(dim, dim, (n, m), 1, (n // 2, m // 2)) + self.nn_conv = nn.Conv2d(dim, dim, (n, n), 1, (n // 2, n // 2)) + + trunc_normal_(self.mn_conv.weight, std=0.02) + trunc_normal_(self.nm_conv.weight, std=0.02) + trunc_normal_(self.nn_conv.weight, std=0.02) + + def forward( + self, x: torch.Tensor + ) -> torch.Tensor: # No reparametrization since this is for a ablative study + if self.training: + x1, x2 = x[:, : self.idx], x[:, self.idx :] + x1 = self.mn_conv(x1) + self.nm_conv(x1) + self.nn_conv(x1) + return torch.cat([x1, x2], dim=1) + + else: + x[:, : self.idx] = ( + self.mn_conv(x[:, : self.idx]) + + self.nm_conv(x[:, : self.idx]) + + self.nn_conv(x[:, : self.idx]) + ) + return x + + +class SparsePLKConv2d(nn.Module): + "Sparse Partial Large Kernel Convolutional Layer (RepLKNet and UniRepLKNet style)" + + def __init__( + self, + dim, + max_kernel_size, + sub_kernel_sizes, + dilations, + use_max_kernel, + with_idt, + ): + super().__init__() + self.use_max_kernel = use_max_kernel + self.max_kernel_size = max_kernel_size + for k, d in zip(sub_kernel_sizes, dilations): + m_k = self._calc_rep_kernel_size(k, d) + if m_k > self.max_kernel_size: + self.max_kernel_size = m_k + self.with_idt = with_idt + + convs = [ + nn.Conv2d( + dim, dim, sub_kernel_size, 1, (sub_kernel_size // 2) * d, dilation=d + ) + for sub_kernel_size, d in zip(sub_kernel_sizes, dilations) + ] + if use_max_kernel: + convs.append( + nn.Conv2d(dim, dim, self.max_kernel_size, 1, self.max_kernel_size // 2) + ) + self.convs = nn.ModuleList(convs) + for m in self.convs: + trunc_normal_(m.weight, std=0.02) + self.idx = dim + self.is_convert = False + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_convert: + x[:, : self.idx, :, :] = self.conv(x[:, : self.idx, :, :]) + return x + else: + x1, x2 = torch.split(x, [self.idx, x.size(1) - self.idx], dim=1) + if self.with_idt: + out = x1 + else: + out = 0.0 + for conv in self.convs: + out = out + conv(x1) + return torch.cat([out, x2], dim=1) # type: ignore + + @staticmethod + def _calc_rep_kernel_size(ks, dilation): + return (ks - 1) * dilation + 1 + + @staticmethod + def _get_origin_kernel(kernel, dilation=1, p=0): + I = torch.ones((1, 1, 1, 1)).to(kernel.device) # noqa: E741 + if kernel.size(1) == 1: # Depth-wise Convolution + dilated = F.conv_transpose2d(kernel, I, stride=dilation) + else: + slices = [] # Dense or Group + for i in range(kernel.size(1)): + dilated = F.conv_transpose2d( + kernel[:, i : i + 1, :, :], I, stride=dilation + ) + slices.append(dilated) + dilated = torch.cat(slices, dim=1) + + # Pad boundary + if p != 0: + dilated = F.pad(dilated, (p, p, p, p)) + return dilated + + @staticmethod + def _dwc_to_dense(kernel): + n_groups = kernel.size(0) + + kernels = [] + for g in range(n_groups): + kernels.append( + torch.cat( + [ + kernel[g : (g + 1)] + if g == i + else torch.zeros_like(kernel[g : (g + 1)]) + for i in range(n_groups) + ], + dim=1, + ) + ) + return torch.cat(kernels, dim=0) + + +class EA(nn.Module): + "Element-wise Attention" + + def __init__(self, dim: int): + super().__init__() + self.f = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1), nn.Sigmoid()) + trunc_normal_(self.f[0].weight, std=0.02) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * self.f(x) + + +class PLKBlock(nn.Module): + def __init__( + self, + dim: int, + # CCM Rep options + ccm_type: Literal["CCM", "ICCM", "DCCM"], + # LK Options + max_kernel_size: int, + split_ratio: float, + lk_type: Literal["PLK", "SparsePLK", "RectSparsePLK"] = "PLK", + # Sparse Rep options + use_max_kernel: bool = False, + sparse_kernels: Sequence[int] = [5, 5, 5], + sparse_dilations: Sequence[int] = [2, 3, 4], + with_idt: bool = False, + # EA ablation + use_ea: bool = True, + ): + super().__init__() + + # Local Texture + if ccm_type == "CCM": + self.channe_mixer = CCM(dim) + elif ccm_type == "ICCM": + self.channe_mixer = ICCM(dim) + elif ccm_type == "DCCM": + self.channe_mixer = DCCM(dim) + else: + raise ValueError(f"Unknown CCM type: {ccm_type}") + + # Long-range Dependency + pdim = int(dim * split_ratio) + if lk_type == "PLK": + self.lk = PLKConv2d(pdim, max_kernel_size, with_idt) + elif lk_type == "SparsePLK": + self.lk = SparsePLKConv2d( + pdim, + max_kernel_size, + sparse_kernels, + sparse_dilations, + use_max_kernel, + with_idt, + ) + elif lk_type == "RectSparsePLK": + self.lk = RectSparsePLKConv2d(pdim, max_kernel_size) + else: + raise ValueError(f"Unknown LK type: {lk_type}") + + # Instance-dependent modulation + if use_ea: + self.attn = EA(dim) + else: + self.attn = nn.Identity() + + # Refinement + self.refine = nn.Conv2d(dim, dim, 1, 1, 0) + trunc_normal_(self.refine.weight, std=0.02) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_skip = x + x = self.channe_mixer(x) + x = self.lk(x) + x = self.attn(x) + x = self.refine(x) + return x + x_skip + + +@store_hyperparameters() +class PLKSR(nn.Module): + hyperparameters = {} + + def __init__( + self, + dim: int = 64, + n_blocks: int = 28, + upscaling_factor: int = 4, + # CCM options + ccm_type: Literal["CCM", "ICCM", "DCCM"] = "CCM", + # LK Options + kernel_size: int = 17, + split_ratio: float = 0.25, + lk_type: Literal["PLK", "SparsePLK", "RectSparsePLK"] = "PLK", + # LK Rep options + use_max_kernel: bool = False, + sparse_kernels: Sequence[int] = [5, 5, 5, 5], + sparse_dilations: Sequence[int] = [1, 2, 3, 4], + with_idt: bool = False, + # EA ablation + use_ea: bool = True, + ): + super().__init__() + + self.feats = nn.Sequential( + *[nn.Conv2d(3, dim, 3, 1, 1)] + + [ + PLKBlock( + dim, + ccm_type, + kernel_size, + split_ratio, + lk_type, + use_max_kernel, + sparse_kernels, + sparse_dilations, + with_idt, + use_ea, + ) + for _ in range(n_blocks) + ] + + [nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1)] + ) + trunc_normal_(self.feats[0].weight, std=0.02) + trunc_normal_(self.feats[-1].weight, std=0.02) + + self.to_img = nn.PixelShuffle(upscaling_factor) + + self.repeat_op = partial( + torch.repeat_interleave, repeats=upscaling_factor**2, dim=1 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.feats(x) + self.repeat_op(x) + x = self.to_img(x) + return x diff --git a/libs/spandrel/spandrel/architectures/PLKSR/arch/RealPLKSR.py b/libs/spandrel/spandrel/architectures/PLKSR/arch/RealPLKSR.py new file mode 100644 index 00000000..e8b078c1 --- /dev/null +++ b/libs/spandrel/spandrel/architectures/PLKSR/arch/RealPLKSR.py @@ -0,0 +1,134 @@ +from functools import partial + +import torch +from torch import nn +from torch.nn.init import trunc_normal_ + +from spandrel.util import store_hyperparameters + + +class DCCM(nn.Sequential): + "Doubled Convolutional Channel Mixer" + + def __init__(self, dim: int): + super().__init__( + nn.Conv2d(dim, dim * 2, 3, 1, 1), + nn.Mish(), + nn.Conv2d(dim * 2, dim, 3, 1, 1), + ) + trunc_normal_(self[-1].weight, std=0.02) + + +class PLKConv2d(nn.Module): + "Partial Large Kernel Convolutional Layer" + + def __init__(self, dim: int, kernel_size: int): + super().__init__() + self.conv = nn.Conv2d(dim, dim, kernel_size, 1, kernel_size // 2) + trunc_normal_(self.conv.weight, std=0.02) + self.idx = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x[:, : self.idx] = self.conv(x[:, : self.idx]) + return x + + +class EA(nn.Module): + "Element-wise Attention" + + def __init__(self, dim: int): + super().__init__() + self.f = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1), nn.Sigmoid()) + trunc_normal_(self.f[0].weight, std=0.02) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * self.f(x) + + +class PLKBlock(nn.Module): + def __init__( + self, + dim: int, + kernel_size: int, + split_ratio: float, + norm_groups: int, + use_ea: bool = True, + ): + super().__init__() + + # Local Texture + self.channel_mixer = DCCM(dim) + + # Long-range Dependency + pdim = int(dim * split_ratio) + + # Conv Layer + self.lk = PLKConv2d(pdim, kernel_size) + + # Instance-dependent modulation + if use_ea: + self.attn = EA(dim) + else: + self.attn = nn.Identity() + + # Refinement + self.refine = nn.Conv2d(dim, dim, 1, 1, 0) + trunc_normal_(self.refine.weight, std=0.02) + + # Group Normalization + self.norm = nn.GroupNorm(norm_groups, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_skip = x + x = self.channel_mixer(x) + x = self.lk(x) + x = self.attn(x) + x = self.refine(x) + x = self.norm(x) + + return x + x_skip + + +@store_hyperparameters() +class RealPLKSR(nn.Module): + """Partial Large Kernel CNNs for Efficient Super-Resolution: + https://arxiv.org/abs/2404.11848 + """ + + hyperparameters = {} + + def __init__( + self, + dim: int = 64, + n_blocks: int = 28, + upscaling_factor: int = 4, + kernel_size: int = 17, + split_ratio: float = 0.25, + use_ea: bool = True, + norm_groups: int = 4, + dropout: float = 0, + ): + super().__init__() + dropout = 0 + + self.feats = nn.Sequential( + *[nn.Conv2d(3, dim, 3, 1, 1)] + + [ + PLKBlock(dim, kernel_size, split_ratio, norm_groups, use_ea) + for _ in range(n_blocks) + ] + + [nn.Dropout2d(dropout)] + + [nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1)] + ) + trunc_normal_(self.feats[0].weight, std=0.02) + trunc_normal_(self.feats[-1].weight, std=0.02) + + self.repeat_op = partial( + torch.repeat_interleave, repeats=upscaling_factor**2, dim=1 + ) + + self.to_img = nn.PixelShuffle(upscaling_factor) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.feats(x) + self.repeat_op(x) + return self.to_img(x) diff --git a/tests/__snapshots__/test_PLKSR.ambr b/tests/__snapshots__/test_PLKSR.ambr new file mode 100644 index 00000000..89494faa --- /dev/null +++ b/tests/__snapshots__/test_PLKSR.ambr @@ -0,0 +1,133 @@ +# serializer version: 1 +# name: test_PLKSR_official_tiny_x4 + ImageModelDescriptor( + architecture=PLKSRArch( + id='PLKSR', + name='PLKSR', + ), + input_channels=3, + output_channels=3, + purpose='SR', + scale=4, + size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=False, + tags=list([ + '64dim', + '12nb', + '13ks', + 'DCCM', + ]), + tiling=, + ) +# --- +# name: test_PLKSR_official_x2 + ImageModelDescriptor( + architecture=PLKSRArch( + id='PLKSR', + name='PLKSR', + ), + input_channels=3, + output_channels=3, + purpose='SR', + scale=2, + size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=False, + tags=list([ + '64dim', + '28nb', + '17ks', + 'DCCM', + ]), + tiling=, + ) +# --- +# name: test_PLKSR_official_x3 + ImageModelDescriptor( + architecture=PLKSRArch( + id='PLKSR', + name='PLKSR', + ), + input_channels=3, + output_channels=3, + purpose='SR', + scale=3, + size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=False, + tags=list([ + '64dim', + '28nb', + '17ks', + 'DCCM', + ]), + tiling=, + ) +# --- +# name: test_PLKSR_official_x4 + ImageModelDescriptor( + architecture=PLKSRArch( + id='PLKSR', + name='PLKSR', + ), + input_channels=3, + output_channels=3, + purpose='SR', + scale=4, + size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=False, + tags=list([ + '64dim', + '28nb', + '17ks', + 'DCCM', + ]), + tiling=, + ) +# --- +# name: test_RealPLKSR_2x + ImageModelDescriptor( + architecture=PLKSRArch( + id='PLKSR', + name='PLKSR', + ), + input_channels=3, + output_channels=3, + purpose='SR', + scale=2, + size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=False, + tags=list([ + '64dim', + '28nb', + '17ks', + 'Real', + ]), + tiling=, + ) +# --- +# name: test_RealPLKSR_4x + ImageModelDescriptor( + architecture=PLKSRArch( + id='PLKSR', + name='PLKSR', + ), + input_channels=3, + output_channels=3, + purpose='SR', + scale=4, + size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False), + supports_bfloat16=True, + supports_half=False, + tags=list([ + '64dim', + '28nb', + '17ks', + 'Real', + ]), + tiling=, + ) +# --- diff --git a/tests/images/outputs/16x16/2x_realplksr_mssim_pretrain.png b/tests/images/outputs/16x16/2x_realplksr_mssim_pretrain.png new file mode 100644 index 00000000..a52a4972 Binary files /dev/null and b/tests/images/outputs/16x16/2x_realplksr_mssim_pretrain.png differ diff --git a/tests/images/outputs/16x16/4x_realplksr_mssim_pretrain.png b/tests/images/outputs/16x16/4x_realplksr_mssim_pretrain.png new file mode 100644 index 00000000..48fbe802 Binary files /dev/null and b/tests/images/outputs/16x16/4x_realplksr_mssim_pretrain.png differ diff --git a/tests/images/outputs/16x16/PLKSR_X2_DIV2K.png b/tests/images/outputs/16x16/PLKSR_X2_DIV2K.png new file mode 100644 index 00000000..0038e051 Binary files /dev/null and b/tests/images/outputs/16x16/PLKSR_X2_DIV2K.png differ diff --git a/tests/images/outputs/16x16/PLKSR_X3_DIV2K.png b/tests/images/outputs/16x16/PLKSR_X3_DIV2K.png new file mode 100644 index 00000000..d1a9a4cf Binary files /dev/null and b/tests/images/outputs/16x16/PLKSR_X3_DIV2K.png differ diff --git a/tests/images/outputs/16x16/PLKSR_X4_DIV2K.png b/tests/images/outputs/16x16/PLKSR_X4_DIV2K.png new file mode 100644 index 00000000..a73fd221 Binary files /dev/null and b/tests/images/outputs/16x16/PLKSR_X4_DIV2K.png differ diff --git a/tests/images/outputs/16x16/PLKSR_tiny_X4_DIV2K.png b/tests/images/outputs/16x16/PLKSR_tiny_X4_DIV2K.png new file mode 100644 index 00000000..d49feeb7 Binary files /dev/null and b/tests/images/outputs/16x16/PLKSR_tiny_X4_DIV2K.png differ diff --git a/tests/images/outputs/32x32/2x_realplksr_mssim_pretrain.png b/tests/images/outputs/32x32/2x_realplksr_mssim_pretrain.png new file mode 100644 index 00000000..88238ef4 Binary files /dev/null and b/tests/images/outputs/32x32/2x_realplksr_mssim_pretrain.png differ diff --git a/tests/images/outputs/32x32/4x_realplksr_mssim_pretrain.png b/tests/images/outputs/32x32/4x_realplksr_mssim_pretrain.png new file mode 100644 index 00000000..7e2fa15b Binary files /dev/null and b/tests/images/outputs/32x32/4x_realplksr_mssim_pretrain.png differ diff --git a/tests/images/outputs/32x32/PLKSR_X2_DIV2K.png b/tests/images/outputs/32x32/PLKSR_X2_DIV2K.png new file mode 100644 index 00000000..8af8090f Binary files /dev/null and b/tests/images/outputs/32x32/PLKSR_X2_DIV2K.png differ diff --git a/tests/images/outputs/32x32/PLKSR_X3_DIV2K.png b/tests/images/outputs/32x32/PLKSR_X3_DIV2K.png new file mode 100644 index 00000000..91b41be9 Binary files /dev/null and b/tests/images/outputs/32x32/PLKSR_X3_DIV2K.png differ diff --git a/tests/images/outputs/32x32/PLKSR_X4_DIV2K.png b/tests/images/outputs/32x32/PLKSR_X4_DIV2K.png new file mode 100644 index 00000000..0ffc273b Binary files /dev/null and b/tests/images/outputs/32x32/PLKSR_X4_DIV2K.png differ diff --git a/tests/images/outputs/32x32/PLKSR_tiny_X4_DIV2K.png b/tests/images/outputs/32x32/PLKSR_tiny_X4_DIV2K.png new file mode 100644 index 00000000..0045f108 Binary files /dev/null and b/tests/images/outputs/32x32/PLKSR_tiny_X4_DIV2K.png differ diff --git a/tests/images/outputs/64x64/2x_realplksr_mssim_pretrain.png b/tests/images/outputs/64x64/2x_realplksr_mssim_pretrain.png new file mode 100644 index 00000000..0639451e Binary files /dev/null and b/tests/images/outputs/64x64/2x_realplksr_mssim_pretrain.png differ diff --git a/tests/images/outputs/64x64/4x_realplksr_mssim_pretrain.png b/tests/images/outputs/64x64/4x_realplksr_mssim_pretrain.png new file mode 100644 index 00000000..aef34988 Binary files /dev/null and b/tests/images/outputs/64x64/4x_realplksr_mssim_pretrain.png differ diff --git a/tests/images/outputs/64x64/PLKSR_X2_DIV2K.png b/tests/images/outputs/64x64/PLKSR_X2_DIV2K.png new file mode 100644 index 00000000..40582dbe Binary files /dev/null and b/tests/images/outputs/64x64/PLKSR_X2_DIV2K.png differ diff --git a/tests/images/outputs/64x64/PLKSR_X3_DIV2K.png b/tests/images/outputs/64x64/PLKSR_X3_DIV2K.png new file mode 100644 index 00000000..23e016fa Binary files /dev/null and b/tests/images/outputs/64x64/PLKSR_X3_DIV2K.png differ diff --git a/tests/images/outputs/64x64/PLKSR_X4_DIV2K.png b/tests/images/outputs/64x64/PLKSR_X4_DIV2K.png new file mode 100644 index 00000000..84975f33 Binary files /dev/null and b/tests/images/outputs/64x64/PLKSR_X4_DIV2K.png differ diff --git a/tests/images/outputs/64x64/PLKSR_tiny_X4_DIV2K.png b/tests/images/outputs/64x64/PLKSR_tiny_X4_DIV2K.png new file mode 100644 index 00000000..92882694 Binary files /dev/null and b/tests/images/outputs/64x64/PLKSR_tiny_X4_DIV2K.png differ diff --git a/tests/test_PLKSR.py b/tests/test_PLKSR.py new file mode 100644 index 00000000..d4029e70 --- /dev/null +++ b/tests/test_PLKSR.py @@ -0,0 +1,150 @@ +from spandrel.architectures.PLKSR import PLKSR, PLKSRArch, RealPLKSR + +from .util import ( + ModelFile, + TestImage, + assert_image_inference, + assert_loads_correctly, + assert_size_requirements, + disallowed_props, + skip_if_unchanged, +) + +skip_if_unchanged(__file__) + + +def test_load(): + assert_loads_correctly( + PLKSRArch(), + lambda: PLKSR(), + lambda: PLKSR(dim=32), + lambda: PLKSR(dim=96), + lambda: PLKSR(n_blocks=6), + lambda: PLKSR(n_blocks=35), + lambda: PLKSR(upscaling_factor=2), + lambda: PLKSR(upscaling_factor=6), + lambda: PLKSR(ccm_type="DCCM"), + lambda: PLKSR(ccm_type="CCM"), + lambda: PLKSR(ccm_type="ICCM"), + lambda: PLKSR(kernel_size=9), + lambda: PLKSR(kernel_size=27), + lambda: PLKSR(split_ratio=0.5), + lambda: PLKSR(split_ratio=0.75), + lambda: PLKSR(use_ea=False), + lambda: RealPLKSR(), + lambda: RealPLKSR(dim=32), + lambda: RealPLKSR(dim=96), + lambda: RealPLKSR(n_blocks=6), + lambda: RealPLKSR(n_blocks=35), + lambda: RealPLKSR(upscaling_factor=2), + lambda: RealPLKSR(upscaling_factor=6), + lambda: RealPLKSR(kernel_size=9), + lambda: RealPLKSR(kernel_size=27), + lambda: RealPLKSR(split_ratio=0.5), + lambda: RealPLKSR(split_ratio=0.75), + lambda: RealPLKSR(use_ea=False), + ) + + +def test_size_requirements(): + file = ModelFile.from_url( + "https://drive.google.com/file/d/12ek1vitEporWc5qqaYo6AMy0-RYlRqu8/view", + name="4x_realplksr_mssim_pretrain.pth", + ) + assert_size_requirements(file.load_model()) + + file = ModelFile.from_url( + "https://drive.google.com/file/d/1PA3QElJYlgpPYKl0zQ9_D1pnuYfC1vtt/view", + name="PLKSR_X2_DIV2K.pth", + ) + assert_size_requirements(file.load_model()) + + +def test_PLKSR_official_x4(snapshot): + file = ModelFile.from_url( + "https://drive.google.com/file/d/1YtoOjK7vsfVrHFqWGvdjFNcOay8WUuJ7/view", + name="PLKSR_X4_DIV2K.pth", + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, PLKSR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) + + +def test_PLKSR_official_tiny_x4(snapshot): + file = ModelFile.from_url( + "https://drive.google.com/file/d/1d8d_6TF0SrEMiX1jrnLqKnPdKjDJahOK/view", + name="PLKSR_tiny_X4_DIV2K.pth", + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, PLKSR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) + + +def test_PLKSR_official_x3(snapshot): + file = ModelFile.from_url( + "https://drive.google.com/file/d/1W8phbKFTOYL-AnlMJnjx2NWDHZY8jVWW/view", + name="PLKSR_X3_DIV2K.pth", + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, PLKSR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) + + +def test_PLKSR_official_x2(snapshot): + file = ModelFile.from_url( + "https://drive.google.com/file/d/1PA3QElJYlgpPYKl0zQ9_D1pnuYfC1vtt/view", + name="PLKSR_X2_DIV2K.pth", + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, PLKSR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) + + +def test_RealPLKSR_4x(snapshot): + file = ModelFile.from_url( + "https://drive.google.com/file/d/12ek1vitEporWc5qqaYo6AMy0-RYlRqu8/view", + name="4x_realplksr_mssim_pretrain.pth", + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, RealPLKSR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + ) + + +def test_RealPLKSR_2x(snapshot): + file = ModelFile.from_url( + "https://drive.google.com/file/d/1GAdf5VOqYa5ntswT9sYsKKZ2Z7OQp7gO/view", + name="2x_realplksr_mssim_pretrain.pth", + ) + model = file.load_model() + assert model == snapshot(exclude=disallowed_props) + assert isinstance(model.model, RealPLKSR) + assert_image_inference( + file, + model, + [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64], + )