Skip to content

Commit

Permalink
Add validation tiling, fix CRAFT inference and tweak default configs
Browse files Browse the repository at this point in the history
  • Loading branch information
muslll committed May 9, 2024
1 parent 00c5fe0 commit 15e1a2d
Show file tree
Hide file tree
Showing 41 changed files with 65 additions and 627 deletions.
14 changes: 13 additions & 1 deletion neosr/archs/craft_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,7 @@ def __init__(self,
super(craft, self).__init__()

self.split_size = (split_size_0, split_size_1)
self.window_size = window_size

num_in_ch = in_chans
num_out_ch = in_chans
Expand Down Expand Up @@ -782,9 +783,16 @@ def forward_features(self, x):

return x

def forward(self, x):
_, _, h_old, w_old = x.size()
h_pad = (h_old // self.window_size + 1) * self.window_size - h_old
w_pad = (w_old // self.window_size + 1) * self.window_size - w_old
pad = h_pad != 0 or w_pad != 0

if pad:
x = torch.cat([x, torch.flip(x, [2])], 2)[:, :, : h_old + h_pad, :]
x = torch.cat([x, torch.flip(x, [3])], 3)[:, :, :, : w_old + w_pad]

def forward(self, x):
self.h, self.w = x.shape[2:]
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
Expand All @@ -794,5 +802,9 @@ def forward(self, x):

x = self.upsample(x)
x = x / self.img_range + self.mean

if pad:
x = x[..., : h_old * self.upscale, : w_old * self.upscale]

return x

62 changes: 16 additions & 46 deletions neosr/models/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,47 +490,18 @@ def update_learning_rate(self, current_iter, warmup_iter=-1):
self._set_lr(warm_up_lr_l)

def test(self):
# pad to multiplication of window_size
if self.opt.get('window_size', None) is not None:
window_size = self.opt.get('window_size')
elif self.opt.get('patch_size', None) is not None:
window_size = self.opt.get('patch_size')
else:
window_size = 8

scale = self.opt.get('scale', 1)
mod_pad_h, mod_pad_w = 0, 0
_, _, h, w = self.lq.size()
if h % window_size != 0:
mod_pad_h = window_size - h % window_size
if w % window_size != 0:
mod_pad_w = window_size - w % window_size
img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')

self.net_g.eval()
with torch.inference_mode():
self.output = self.net_g(img)
self.net_g.train()

_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h -
mod_pad_h * scale, 0:w - mod_pad_w * scale]

'''
# TODO: verify
def test(self):
self.tile = self.opt['val'].get('tile', False)
if not self.tile:
self.tile = self.opt['val'].get('tile', -1)
if self.tile == -1:
self.net_g.eval()
with torch.no_grad():
with torch.inference_mode():
self.output = self.net_g(self.lq)
self.net_g.train()

# test by partitioning
else:
_, C, h, w = self.lq.size()
split_token_h = h // 200 + 1 # number of horizontal cut sections
split_token_w = w // 200 + 1 # number of vertical cut sections
split_token_h = h // self.tile + 1 # number of horizontal cut sections
split_token_w = w // self.tile + 1 # number of vertical cut sections

patch_size_tmp_h = split_token_h
patch_size_tmp_w = split_token_w
Expand All @@ -541,7 +512,7 @@ def test(self):
mod_pad_h = patch_size_tmp_h - h % patch_size_tmp_h
if w % patch_size_tmp_w != 0:
mod_pad_w = patch_size_tmp_w - w % patch_size_tmp_w

img = self.lq
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, :h+mod_pad_h, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, :w+mod_pad_w]
Expand All @@ -553,7 +524,6 @@ def test(self):
# overlapping
shave_h = 16
shave_w = 16
scale = self.scale # self.opt.get('scale', 1)
ral = H // split_h
row = W // split_w
slices = [] # list of partition borders
Expand Down Expand Up @@ -583,31 +553,31 @@ def test(self):
img_chops.append(img[..., top, left])

self.net_g.eval()
with torch.no_grad():
with torch.inference_mode():
outputs = []
for chop in img_chops:
out = self.net_g(chop) # image processing of each partition
outputs.append(out)
_img = torch.zeros(1, C, H * scale, W * scale)
_img = torch.zeros(1, C, H * self.scale, W * self.scale)
# merge
for i in range(ral):
for j in range(row):
top = slice(i * split_h * scale, (i + 1) * split_h * scale)
left = slice(j * split_w * scale, (j + 1) * split_w * scale)
top = slice(i * split_h * self.scale, (i + 1) * split_h * self.scale)
left = slice(j * split_w * self.scale, (j + 1) * split_w * self.scale)
if i == 0:
_top = slice(0, split_h * scale)
_top = slice(0, split_h * self.scale)
else:
_top = slice(shave_h * scale, (shave_h + split_h) * scale)
_top = slice(shave_h * self.scale, (shave_h + split_h) * self.scale)
if j == 0:
_left = slice(0, split_w * scale)
_left = slice(0, split_w * self.scale)
else:
_left = slice(shave_w * scale, (shave_w + split_w) * scale)
_left = slice(shave_w * self.scale, (shave_w + split_w) * self.scale)
_img[..., top, left] = outputs[i * row + j][..., _top, _left]
self.output = _img
self.net_g.train()
_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
'''
self.output = self.output[:, :, 0:h - mod_pad_h * self.scale, 0:w - mod_pad_w * self.scale]


@torch.no_grad()
def feed_data(self, data):
Expand Down
26 changes: 0 additions & 26 deletions options/test_agdn.yml

This file was deleted.

132 changes: 0 additions & 132 deletions options/train_agdn.yml

This file was deleted.

Loading

0 comments on commit 15e1a2d

Please sign in to comment.