Skip to content

Commit

Permalink
Add support to save GIFs (#680)
Browse files Browse the repository at this point in the history
* Add support to save GIFs

* Improve tests coverage

* Improve tests coverage
  • Loading branch information
fepegar committed Oct 3, 2021
1 parent 6deb015 commit ca4890a
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 3 deletions.
9 changes: 9 additions & 0 deletions tests/data/test_image.py
Expand Up @@ -201,3 +201,12 @@ def assert_shape(shape_in, shape_out):
assert_shape((5, 5, 5), (1, 5, 5, 5))
assert_shape((1, 5, 5, 5), (1, 5, 5, 5))
assert_shape((4, 5, 5, 5), (4, 5, 5, 5))

def test_fast_gif(self):
with self.assertWarns(UserWarning):
with tempfile.NamedTemporaryFile(suffix='.gif') as f:
self.sample_subject.t1.to_gif(0, 0.0001, f.name)

def test_gif_rgb(self):
with tempfile.NamedTemporaryFile(suffix='.gif') as f:
tio.ScalarImage(tensor=torch.rand(3, 4, 5, 6)).to_gif(0, 1, f.name)
37 changes: 37 additions & 0 deletions torchio/data/image.py
Expand Up @@ -584,6 +584,43 @@ def as_pil(self, transpose=True):
array = tensor.clamp(0, 255).numpy()[0]
return ImagePIL.fromarray(array.astype(np.uint8))

def to_gif(
self,
axis: int,
duration: float, # of full gif
output_path: TypePath,
loop: int = 0,
rescale: bool = True,
optimize: bool = True,
reverse: bool = False,
) -> None:
"""Save an animated GIF of the image.
Args:
axis: Spatial axis (0, 1 or 2).
duration: Duration of the full animation in seconds.
output_path: Path to the output GIF file.
loop: Number of times the GIF should loop.
``0`` means that it will loop forever.
rescale: Use :class:`~torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity`
to rescale the intensity values to :math:`[0, 255]`.
optimize: If ``True``, attempt to compress the palette by
eliminating unused colors. This is only useful if the palette
can be compressed to the next smaller power of 2 elements.
reverse: Reverse the temporal order of frames.
""" # noqa: E501
from ..visualization import make_gif # avoid circular import
make_gif(
self.data,
axis,
duration,
output_path,
loop=loop,
rescale=rescale,
optimize=optimize,
reverse=reverse,
)

def get_center(self, lps: bool = False) -> TypeTripletFloat:
"""Get image center in RAS+ or LPS+ coordinates.
Expand Down
68 changes: 65 additions & 3 deletions torchio/visualization.py
@@ -1,8 +1,13 @@
import warnings

import torch
import numpy as np

from .data.image import Image, LabelMap
from .typing import TypePath
from .data.subject import Subject
from .data.image import Image, LabelMap
from .transforms.preprocessing.spatial.to_canonical import ToCanonical
from .transforms.preprocessing.intensity.rescale import RescaleIntensity


def import_mpl_plt():
Expand All @@ -14,9 +19,9 @@ def import_mpl_plt():
return mpl, plt


def rotate(image, radiological=True):
def rotate(image, radiological=True, n=-1):
# Rotate for visualization purposes
image = np.rot90(image, -1)
image = np.rot90(image, n)
if radiological:
image = np.fliplr(image)
return image
Expand Down Expand Up @@ -163,3 +168,60 @@ def color_labels(arrays, cmap_dict):
rgb[array == label] = color
results.append(rgb)
return results


def make_gif(
tensor: torch.Tensor,
axis: int,
duration: float, # of full gif
output_path: TypePath,
loop: int = 0,
optimize: bool = True,
rescale: bool = True,
reverse: bool = False,
) -> None:
try:
from PIL import Image as ImagePIL
except ModuleNotFoundError as e:
message = (
'Please install Pillow to use Image.to_gif():'
' pip install Pillow'
)
raise RuntimeError(message) from e
tensor = RescaleIntensity((0, 255))(tensor) if rescale else tensor
single_channel = len(tensor) == 1

# Move channels dimension to the end and bring selected axis to 0
axes = np.roll(range(1, 4), -axis)
tensor = tensor.permute(*axes, 0)

if single_channel:
mode = 'P'
tensor = tensor[..., 0]
else:
mode = 'RGB'
array = tensor.byte().numpy()
n = 2 if axis == 1 else 1
images = [ImagePIL.fromarray(rotate(i, n=n)).convert(mode) for i in array]
num_images = len(images)
images = list(reversed(images)) if reverse else images
frame_duration_ms = duration / num_images * 1000
if frame_duration_ms < 10:
fps = round(1000 / frame_duration_ms)
frame_duration_ms = 10
new_duration = frame_duration_ms * num_images / 1000
message = (
'The computed frame rate from the given duration is too high'
f' ({fps} fps). The highest possible frame rate in the GIF'
' file format specification is 100 fps. The duration has been set'
f' to {new_duration:.1f} seconds, instead of {duration:.1f}'
)
warnings.warn(message)
images[0].save(
output_path,
save_all=True,
append_images=images[1:],
optimize=optimize,
duration=frame_duration_ms,
loop=loop,
)

0 comments on commit ca4890a

Please sign in to comment.