diff --git a/tests/data/test_image.py b/tests/data/test_image.py index 01012cc7a..1d4b5e1f9 100644 --- a/tests/data/test_image.py +++ b/tests/data/test_image.py @@ -201,3 +201,12 @@ def assert_shape(shape_in, shape_out): assert_shape((5, 5, 5), (1, 5, 5, 5)) assert_shape((1, 5, 5, 5), (1, 5, 5, 5)) assert_shape((4, 5, 5, 5), (4, 5, 5, 5)) + + def test_fast_gif(self): + with self.assertWarns(UserWarning): + with tempfile.NamedTemporaryFile(suffix='.gif') as f: + self.sample_subject.t1.to_gif(0, 0.0001, f.name) + + def test_gif_rgb(self): + with tempfile.NamedTemporaryFile(suffix='.gif') as f: + tio.ScalarImage(tensor=torch.rand(3, 4, 5, 6)).to_gif(0, 1, f.name) diff --git a/torchio/data/image.py b/torchio/data/image.py index 97851c14a..eaad2c52b 100644 --- a/torchio/data/image.py +++ b/torchio/data/image.py @@ -584,6 +584,43 @@ def as_pil(self, transpose=True): array = tensor.clamp(0, 255).numpy()[0] return ImagePIL.fromarray(array.astype(np.uint8)) + def to_gif( + self, + axis: int, + duration: float, # of full gif + output_path: TypePath, + loop: int = 0, + rescale: bool = True, + optimize: bool = True, + reverse: bool = False, + ) -> None: + """Save an animated GIF of the image. + + Args: + axis: Spatial axis (0, 1 or 2). + duration: Duration of the full animation in seconds. + output_path: Path to the output GIF file. + loop: Number of times the GIF should loop. + ``0`` means that it will loop forever. + rescale: Use :class:`~torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity` + to rescale the intensity values to :math:`[0, 255]`. + optimize: If ``True``, attempt to compress the palette by + eliminating unused colors. This is only useful if the palette + can be compressed to the next smaller power of 2 elements. + reverse: Reverse the temporal order of frames. + """ # noqa: E501 + from ..visualization import make_gif # avoid circular import + make_gif( + self.data, + axis, + duration, + output_path, + loop=loop, + rescale=rescale, + optimize=optimize, + reverse=reverse, + ) + def get_center(self, lps: bool = False) -> TypeTripletFloat: """Get image center in RAS+ or LPS+ coordinates. diff --git a/torchio/visualization.py b/torchio/visualization.py index 4ec85cd7f..89f745b5f 100644 --- a/torchio/visualization.py +++ b/torchio/visualization.py @@ -1,8 +1,13 @@ +import warnings + +import torch import numpy as np -from .data.image import Image, LabelMap +from .typing import TypePath from .data.subject import Subject +from .data.image import Image, LabelMap from .transforms.preprocessing.spatial.to_canonical import ToCanonical +from .transforms.preprocessing.intensity.rescale import RescaleIntensity def import_mpl_plt(): @@ -14,9 +19,9 @@ def import_mpl_plt(): return mpl, plt -def rotate(image, radiological=True): +def rotate(image, radiological=True, n=-1): # Rotate for visualization purposes - image = np.rot90(image, -1) + image = np.rot90(image, n) if radiological: image = np.fliplr(image) return image @@ -163,3 +168,60 @@ def color_labels(arrays, cmap_dict): rgb[array == label] = color results.append(rgb) return results + + +def make_gif( + tensor: torch.Tensor, + axis: int, + duration: float, # of full gif + output_path: TypePath, + loop: int = 0, + optimize: bool = True, + rescale: bool = True, + reverse: bool = False, + ) -> None: + try: + from PIL import Image as ImagePIL + except ModuleNotFoundError as e: + message = ( + 'Please install Pillow to use Image.to_gif():' + ' pip install Pillow' + ) + raise RuntimeError(message) from e + tensor = RescaleIntensity((0, 255))(tensor) if rescale else tensor + single_channel = len(tensor) == 1 + + # Move channels dimension to the end and bring selected axis to 0 + axes = np.roll(range(1, 4), -axis) + tensor = tensor.permute(*axes, 0) + + if single_channel: + mode = 'P' + tensor = tensor[..., 0] + else: + mode = 'RGB' + array = tensor.byte().numpy() + n = 2 if axis == 1 else 1 + images = [ImagePIL.fromarray(rotate(i, n=n)).convert(mode) for i in array] + num_images = len(images) + images = list(reversed(images)) if reverse else images + frame_duration_ms = duration / num_images * 1000 + if frame_duration_ms < 10: + fps = round(1000 / frame_duration_ms) + frame_duration_ms = 10 + new_duration = frame_duration_ms * num_images / 1000 + message = ( + 'The computed frame rate from the given duration is too high' + f' ({fps} fps). The highest possible frame rate in the GIF' + ' file format specification is 100 fps. The duration has been set' + f' to {new_duration:.1f} seconds, instead of {duration:.1f}' + ) + warnings.warn(message) + images[0].save( + output_path, + save_all=True, + append_images=images[1:], + optimize=optimize, + duration=frame_duration_ms, + loop=loop, + )