Skip to content

Commit

Permalink
update A
Browse files Browse the repository at this point in the history
  • Loading branch information
leVirve committed Aug 3, 2018
1 parent d2ca602 commit f25ac10
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 48 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Expand Up @@ -15,7 +15,6 @@ install:
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION pytorch scipy opencv -c pytorch
- source activate test-environment
- pip install --upgrade pytest pytest-cov flake8
- pip install -r requirements.txt

before_script:
- flake8 .
Expand Down
3 changes: 1 addition & 2 deletions onegan/extension/__init__.py
Expand Up @@ -7,9 +7,8 @@
from .tensorboard import * # noqa
from .imagesaver import * # noqa
from .history import * # noqa
from .colorize import * # noqa
from .tensorcollect import * # noqa


__all__ = ('Checkpoint', 'TensorBoardLogger', 'TensorCollector',
'ImageSaver', 'Colorizer', 'History')
'ImageSaver', 'History')
38 changes: 0 additions & 38 deletions onegan/extension/colorize.py

This file was deleted.

4 changes: 2 additions & 2 deletions onegan/extension/tensorboard.py
Expand Up @@ -5,7 +5,6 @@

import tensorboardX

from onegan.visualizer import image as oneimage
from .base import Extension, unique_experiment_name


Expand Down Expand Up @@ -80,7 +79,8 @@ def image(self, images_dict, epoch, prefix='') -> None:
num_summaried_img = len(next(iter(images_dict.values())))
self._tag_base_counter += num_summaried_img

[self.writer.add_image(f'{prefix}{tag}/{self._tag_base_counter + i}', oneimage.img_normalize(image), epoch)
[self.writer.add_image(f'{prefix}{tag}/{self._tag_base_counter + i}',
image, epoch)
for tag, images in images_dict.items()
for i, image in enumerate(images)]

Expand Down
30 changes: 25 additions & 5 deletions onegan/visualizer/image.py
Expand Up @@ -87,18 +87,38 @@ def label_as_rgb_visual(x, colors):
Args:
x (torch.Tensor): shape in (N, H, W) or (N, 1, H, W)
colors (tuple or list): list of RGB colors, range from 0 to 1.
Returns:
canvas (torch.Tensor): colorized tensor in the shape of (N, C, H, W)
"""
if x.dim() == 4:
x = x.squeeze(1)
assert x.dim() == 3

n, h, w = x.size()
palette = torch.tensor(colors).to(x.device)
canvas = torch.zeros(n, h, w, 3).to(x.device)
labels = torch.arange(x.max() + 1).to(x)

# for i, lbl_id in enumerate(range(x.max() + 1)):
# lbl_id = torch.tensor(lbl_id).to(x)
for i, lbl_id in enumerate(torch.arange(x.max() + 1).to(x)):
canvas = torch.zeros(n, h, w, 3).to(x.device)
for color, lbl_id in zip(palette, labels):
if canvas[x == lbl_id].size(0):
canvas[x == lbl_id] = palette[i]
canvas[x == lbl_id] = color

return canvas.permute(0, 3, 1, 2)


def make_bar(images):
""" Make a list of iamges turn to a long thumbnail.
"""
img = images[0]

n, c, h, w = img.size()
pad = torch.ones(n, c, h, 5)

outs = []
for im in images:
outs.append(im)
outs.append(pad)

return torch.cat(outs[:-1], dim=3)

0 comments on commit f25ac10

Please sign in to comment.