Skip to content

Commit

Permalink
additional cli
Browse files Browse the repository at this point in the history
  • Loading branch information
bryandlee committed Mar 3, 2021
1 parent adc0516 commit 09cc72d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
14 changes: 9 additions & 5 deletions model.py
Expand Up @@ -87,18 +87,22 @@ def __init__(self, ):
nn.Tanh()
)

def forward(self, input):
def forward(self, input, align_corners=True):
out = self.block_a(input)
half_size = out.size()[-2:]
out = self.block_b(out)
out = self.block_c(out)

out = F.interpolate(out, half_size, mode="bilinear", align_corners=True)
# out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
if align_corners:
out = F.interpolate(out, half_size, mode="bilinear", align_corners=True)
else:
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
out = self.block_d(out)

out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True)
# out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
if align_corners:
out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True)
else:
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
out = self.block_e(out)

out = self.out_layer(out)
Expand Down
24 changes: 17 additions & 7 deletions test.py
Expand Up @@ -11,15 +11,16 @@
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

def load_image(image_path):
def load_image(image_path, x32=False):
img = cv2.imread(image_path).astype(np.float32)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
h, w = img.shape[:2]

def to_32s(x):
return 256 if x < 256 else x - x%32
if x32: # resize image to multiple of 32s
def to_32s(x):
return 256 if x < 256 else x - x%32
img = cv2.resize(img, (to_32s(w), to_32s(h)))

img = cv2.resize(img, (to_32s(w), to_32s(h)))
img = torch.from_numpy(img)
img = img/127.5 - 1.0
return img
Expand All @@ -36,14 +37,14 @@ def test(args):
os.makedirs(args.output_dir, exist_ok=True)

for image_name in sorted(os.listdir(args.input_dir)):
if os.path.splitext(image_name)[-1] not in [".jpg", ".png", ".bmp", ".tiff"]:
if os.path.splitext(image_name)[-1].lower() not in [".jpg", ".png", ".bmp", ".tiff"]:
continue

image = load_image(os.path.join(args.input_dir, image_name))
image = load_image(os.path.join(args.input_dir, image_name), args.x32)

with torch.no_grad():
input = image.permute(2, 0, 1).unsqueeze(0).to(device)
out = net(input).squeeze(0).permute(1, 2, 0).cpu().numpy()
out = net(input, args.upsample_align).squeeze(0).permute(1, 2, 0).cpu().numpy()
out = (out + 1)*127.5
out = np.clip(out, 0, 255).astype(np.uint8)

Expand Down Expand Up @@ -74,6 +75,15 @@ def test(args):
type=str,
default='cuda:0',
)
parser.add_argument(
'--upsample_align',
type=bool,
default=False,
)
parser.add_argument(
'--x32',
action="store_true",
)
args = parser.parse_args()

test(args)
Expand Down

0 comments on commit 09cc72d

Please sign in to comment.