Skip to content

Commit

Permalink
make new image visualizer for label map
Browse files Browse the repository at this point in the history
  • Loading branch information
leVirve committed Jul 23, 2018
1 parent 8dfeb12 commit 883fdd9
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions onegan/visualizer/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import torch


DEFAULT_COLORS = [[.2, .8, .4], [.6, .4, .8], [.8, .4, .2]]


def img_normalize(img, val_range=None):
''' Normalize the tensor into (0, 1)
Expand Down Expand Up @@ -45,7 +48,7 @@ def make_valid_batched_dim(x):
return torch.cat(channels, dim=1)


def as_rgb_visual(tensor, vallina=False):
def as_rgb_visual(tensor, vallina=False, colors=None):
''' Make tensor into colorful image
Args:
tensor: shape in (C, H, W) or (N, C, H, W)
Expand All @@ -60,13 +63,34 @@ def batched_colorize(batched_x):
return stack_visuals(*channels)
else:
dtype = channels[0].type()
colors = torch.tensor([[.2, .8, .4], [.6, .4, .8], [.8, .4, .2]]).type(dtype)
palette = torch.tensor(colors or DEFAULT_COLORS).type(dtype)

canvas = torch.zeros(n, h, w, 3).to(batched_x)
for i in range(c):
canvas += channels[i].unsqueeze(-1) * colors[i]
canvas += channels[i].unsqueeze(-1) * palette[i]
return canvas.permute(0, 3, 1, 2)

if tensor.dim() == 3:
return batched_colorize(tensor.unsqueeze(0)).squeeze(-1)

return batched_colorize(tensor)


def label_as_rgb_visual(x, colors):
''' Make segment tensor into colorful image
Args:
tensor: shape in (N, H, W) or (N, 1, H, W)
'''
if x.dim() == 4:
x = x.squeeze(1)

n, h, w = x.size()
palette = torch.tensor(colors).to(x.device)
canvas = torch.zeros(n, h, w, 3).to(x.device)

for i, lbl_id in enumerate(torch.unique(x.cpu())):
lbl_id = lbl_id.to(x)
if canvas[x == lbl_id].size(0):
canvas[x == lbl_id] = palette[i]

return canvas.permute(0, 3, 1, 2)

0 comments on commit 883fdd9

Please sign in to comment.