Skip to content

Commit

Permalink
Support DiffBIR SwinIR models.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Sep 7, 2023
1 parent cb080e7 commit 8be4643
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
17 changes: 16 additions & 1 deletion comfy_extras/chainner_models/architecture/SwinIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,7 @@ def __init__(
num_in_ch = in_chans
num_out_ch = in_chans
supports_fp16 = True
self.start_unshuffle = 1

self.model_arch = "SwinIR"
self.sub_type = "SR"
Expand Down Expand Up @@ -874,6 +875,11 @@ def __init__(
else 64
)

if "conv_first.1.weight" in self.state:
self.state["conv_first.weight"] = self.state.pop("conv_first.1.weight")
self.state["conv_first.bias"] = self.state.pop("conv_first.1.bias")
self.start_unshuffle = round(math.sqrt(self.state["conv_first.weight"].shape[1] // 3))

num_in_ch = self.state["conv_first.weight"].shape[1]
in_chans = num_in_ch
if "conv_last.weight" in state_keys:
Expand Down Expand Up @@ -968,7 +974,7 @@ def __init__(
self.depths = depths
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.scale = upscale
self.scale = upscale / self.start_unshuffle
self.upsampler = upsampler
self.img_size = img_size
self.img_range = img_range
Expand Down Expand Up @@ -1101,6 +1107,9 @@ def __init__(
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
if self.upscale == 4:
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
elif self.upscale == 8:
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
Expand Down Expand Up @@ -1157,6 +1166,9 @@ def forward(self, x):
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range

if self.start_unshuffle > 1:
x = torch.nn.functional.pixel_unshuffle(x, self.start_unshuffle)

if self.upsampler == "pixelshuffle":
# for classical SR
x = self.conv_first(x)
Expand Down Expand Up @@ -1186,6 +1198,9 @@ def forward(self, x):
)
)
)
elif self.upscale == 8:
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.lrelu(self.conv_up3(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
# for image denoising and JPEG compression artifact reduction
Expand Down
2 changes: 2 additions & 0 deletions comfy_extras/nodes_upscale_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def INPUT_TYPES(s):
def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""})
out = model_loading.load_state_dict(sd).eval()
return (out, )

Expand Down

0 comments on commit 8be4643

Please sign in to comment.