# Step 4
## Animation

What's better than a picture?

A bunch of them used to make an animation!

This notebook demonstrates:
- how to use platipy to make animations
- different types of animation

In [None]:
from pathlib import Path

import numpy as np
import SimpleITK as sitk

import matplotlib.pyplot as plt

from platipy.imaging import ImageVisualiser
from platipy.dicom.io.crawl import process_dicom_directory
from platipy.imaging.label.utils import get_com
from platipy.imaging.utils.crop import crop_to_label_extent

from platipy.imaging.registration.linear import linear_registration
from platipy.imaging.registration.deformable import fast_symmetric_forces_demons_registration
from platipy.imaging.registration.utils import apply_transform

from platipy.imaging.visualisation.comparison import contour_comparison

from platipy.imaging.generation.dvf import generate_field_expand

from platipy.imaging.visualisation.utils import project_onto_arbitrary_plane
from platipy.imaging.visualisation.animation import generate_animation_from_image_sequence

import matplotlib.colors as mcolors
import matplotlib.patches as mpatches

import seaborn as sns

%matplotlib notebook

In [None]:
"""
Let's start by loading in some data
"""

data_dir = Path("./input/NIfTI/RTMAC_LIVE_003/")
data = {}

# Path.glob yields matches in filesystem-dependent order, so sort them to make
# the choice of image deterministic, and fail with a clear message (instead of
# a bare IndexError) when no image is present.
image_files = sorted(data_dir.glob("IMAGES/*.nii.gz"))
if not image_files:
    raise FileNotFoundError(f"No NIfTI image (*.nii.gz) found in {data_dir / 'IMAGES'}")
data["MRI"] = sitk.ReadImage(str(image_files[0]), sitk.sitkUInt32)

# Structure names are taken from each file name by dropping the first 26
# characters (presumably a fixed patient/series prefix - TODO confirm against
# the data set) and the trailing ".nii.gz" (7 characters).
# Sorting keeps the label order (and hence contour colours) reproducible.
data["LABELS"] = {}
for s_file in sorted(data_dir.glob("STRUCTURES/*.nii.gz")):
    data["LABELS"][s_file.name[26:-7]] = sitk.ReadImage(str(s_file))

In [None]:
"""
Let's visualise the image
"""

# Centre the orthogonal views on the right submandibular gland
submand_r_com = get_com(data["LABELS"]["GLND_SUBMAND_R"])

vis = ImageVisualiser(
    data["MRI"],
    cut=submand_r_com,
    window=(0, 400),
    figure_size_in=6,
)
fig = vis.show()

print(submand_r_com)

### Animate the image slices

Useful if you want to look through the image, or just to show off!

For all these animations, we need a list of the images/overlays to display, with one entry per frame.

Each entry should be two-dimensional, so that it can be shown as a single frame of the animation!

In [None]:
# One 2D frame per index along the image's third axis
img_list = [
    data["MRI"][:, :, int(slice_index)]
    for slice_index in range(data["MRI"].GetSize()[2])
]

In [None]:
"""
Generating an image is not so difficult
"""

animation = generate_animation_from_image_sequence(
    # Frames and output file
    image_list=img_list,
    output_file='animation.gif',
    fps=30,
    figure_size_in=6,
    # How the background image is displayed
    image_cmap=plt.cm.Greys_r,
    image_window=[0, 400],
    image_origin='upper',
    # No overlays on this first pass (so the settings below are presumably unused)
    contour_list=None,
    contour_cmap=plt.cm.rainbow,
    scalar_list=None,
    scalar_cmap=plt.cm.magma,
    scalar_min=None,
    scalar_max=None,
    scalar_alpha=0.5,
)

In [None]:
"""
We can just add contours to this!
"""

# Slice the MRI and every label at the same index, so that frame k of
# img_list lines up with frame k of ctr_list.
n_axial_slices = data["MRI"].GetSize()[2]

img_list = [data["MRI"][:, :, int(k)] for k in range(n_axial_slices)]
ctr_list = [
    {name: label[:, :, int(k)] for name, label in data["LABELS"].items()}
    for k in range(n_axial_slices)
]

In [None]:
"""
Similar to before!
"""

animation = generate_animation_from_image_sequence(
    # Frames and output file
    image_list=img_list,
    output_file='animation_contours.gif',
    fps=30,
    figure_size_in=6,
    # How the background image is displayed
    image_cmap=plt.cm.Greys_r,
    image_window=[0, 400],
    image_origin='upper',
    # Contour overlay: one dict of 2D label slices per frame
    contour_list=ctr_list,
    contour_cmap=plt.cm.rainbow,
    # No scalar overlay
    scalar_list=None,
    scalar_cmap=plt.cm.magma,
    scalar_min=None,
    scalar_max=None,
    scalar_alpha=0.5,
)

### Animate rotation

We can simulate rotating around an image by projecting it from a sequence of viewing angles.

A useful way to examine contours!

In [None]:
# 40 viewpoints spread over a full turn around the z axis
angle_list = np.linspace(0, np.pi*2, 40)
img_list, ctr_list = [], []

for angle in angle_list:
    # Mean intensity projection of the MRI seen from this angle
    img_list.append(
        project_onto_arbitrary_plane(
            data["MRI"],
            projection_name='mean',
            rotation_axis=[0, 0, 1],
            rotation_angle=angle,
            projection_axis=1,
            default_value=0,
        )
    )

    # Max projection of every label from the same viewpoint (its silhouette)
    ctr_list.append({
        name: project_onto_arbitrary_plane(
            label,
            projection_name='max',
            rotation_axis=[0, 0, 1],
            rotation_angle=angle,
            projection_axis=1,
            default_value=0,
        )
        for name, label in data["LABELS"].items()
    })

In [None]:
"""
Once again let's animate
"""

animation = generate_animation_from_image_sequence(
    # Frames and output file (slower frame rate for the rotation)
    image_list=img_list,
    output_file='animation_rotate.gif',
    fps=10,
    figure_size_in=6,
    # How the projected background image is displayed
    image_cmap=plt.cm.Greys_r,
    image_window=[0, 200],
    image_origin='lower',
    # Contour overlay: projected labels for each viewpoint
    contour_list=ctr_list,
    contour_cmap=plt.cm.rainbow,
    # No scalar overlay
    scalar_list=None,
    scalar_cmap=plt.cm.magma,
    scalar_min=None,
    scalar_max=None,
    scalar_alpha=0.5,
)

In [None]:
"""
Sometimes an outline is hard to see
We can also visualise scalar fields
"""

angle_list = np.linspace(0, np.pi*2, 40)
img_list, ctr_list = [], []

# Combined parotid mask; it does not depend on the viewing angle
parotids_union = data["LABELS"]["PAROTID_L"] | data["LABELS"]["PAROTID_R"]

for angle in angle_list:
    img_list.append(
        project_onto_arbitrary_plane(
            data["MRI"],
            projection_name='mean',
            rotation_axis=[0, 0, 1],
            rotation_angle=angle,
            projection_axis=1,
            default_value=0,
        )
    )

    # A mean (rather than max) projection of the mask: assuming a 0/1 mask,
    # each pixel is the fraction of its ray inside the parotids, which
    # effectively shows "thickness"
    ctr_list.append(
        project_onto_arbitrary_plane(
            parotids_union,
            projection_name='mean',
            rotation_axis=[0, 0, 1],
            rotation_angle=angle,
            projection_axis=1,
            default_value=0,
        )
    )

In [None]:
"""
Once again let's animate
"""

animation = generate_animation_from_image_sequence(
    # Frames and output file
    image_list=img_list,
    output_file='animation_rotate_scalar.gif',
    fps=10,
    figure_size_in=6,
    # How the projected background image is displayed
    image_cmap=plt.cm.Greys_r,
    image_window=[0, 200],
    image_origin='lower',
    # No contour overlay this time
    contour_list=None,
    contour_cmap=plt.cm.rainbow,
    # Scalar overlay: projected parotid "thickness"; scalar_min=0.01 sets a
    # small lower bound on the colour scale
    scalar_list=ctr_list,
    scalar_cmap=plt.cm.plasma,
    scalar_min=0.01,
    scalar_max=None,
    scalar_alpha=0.5,
)

### Animate anatomical changes

This is useful for imaging acquired over time.

Since we don't have that here, let's simulate it!

In [None]:
"""
Let's simulate a change in the parotids
Here we generate a deformation that expands the combined parotid label
"""

# Unpacked as (expanded label, transform, displacement field), going by how the
# values are used later: parotids_expand is drawn as the "Expanded" contour and
# dvf_expand is split into vector components. expand=20 is the expansion
# amount (units not shown here - presumably mm, TODO confirm).
parotids_expand, tfm_expand, dvf_expand = generate_field_expand(
    data["LABELS"]["PAROTID_L"] | data["LABELS"]["PAROTID_R"],
    expand=20,
)

In [None]:
# Compare the expanded label with the original parotids, centred on the right parotid
original_parotids = data["LABELS"]["PAROTID_L"] | data["LABELS"]["PAROTID_R"]
parotid_r_com = get_com(data["LABELS"]["PAROTID_R"])

vis = ImageVisualiser(data["MRI"], cut=parotid_r_com, window=(0, 400), figure_size_in=6)
vis.add_contour({"Expanded": parotids_expand, "Original": original_parotids})
fig = vis.show()

In [None]:
"""
Now we interpolate the deformation field
Applying this to the image and contours
"""

interp_list = np.linspace(0, 1, 10)
img_list = []
ctr_list = []

# Loop-invariant inputs: the combined parotid mask and the three scalar
# components of the displacement field only need to be built once.
parotids_union = data["LABELS"]["PAROTID_L"] | data["LABELS"]["PAROTID_R"]
dvf_components = [
    sitk.VectorIndexSelectionCast(dvf_expand, i, sitk.sitkFloat64) for i in range(3)
]

for index, interp_value in enumerate(interp_list):

    print(index, interp_value)

    # Generate the interpolated DVF: scale every component by interp_value
    # (0 = no deformation, 1 = the full field). The vector image is recomposed
    # each iteration because SimpleITK's DisplacementFieldTransform constructor
    # takes ownership of the image it is given.
    expand_dvf_interpolate = sitk.Compose(
        *[float(interp_value) * component for component in dvf_components]
    )
    expand_transform_interpolate = sitk.DisplacementFieldTransform(expand_dvf_interpolate)

    # Deform the MRI and take a mean intensity projection from a fixed viewpoint.
    # default_value=0 matches every other projection in this notebook.
    im_deform = apply_transform(
        data["MRI"],
        transform=expand_transform_interpolate,
        default_value=0,
        interpolator=sitk.sitkLinear,
    )
    im_project = project_onto_arbitrary_plane(
        im_deform,
        projection_name='mean',
        rotation_axis=[0, 0, 1],
        rotation_angle=0,
        projection_axis=1,
        default_value=0,
    )
    img_list.append(im_project)

    # Deform the parotid mask the same way (nearest neighbour does not create
    # new label values). Project it from the SAME viewpoint as the image
    # (rotation_angle=0): this previously used `angle`, a loop variable left
    # over from an earlier cell, which relied on hidden kernel state
    # (NameError on a fresh kernel).
    c_deform = apply_transform(
        parotids_union,
        transform=expand_transform_interpolate,
        default_value=0,
        interpolator=sitk.sitkNearestNeighbor,
    )
    c_project = project_onto_arbitrary_plane(
        c_deform,
        projection_name='mean',
        rotation_axis=[0, 0, 1],
        rotation_angle=0,
        projection_axis=1,
        default_value=0,
    )
    ctr_list.append(c_project)

In [None]:
"""
Once again let's animate
"""

animation = generate_animation_from_image_sequence(
    # Frames and output file (few frames, so a low frame rate)
    image_list=img_list,
    output_file='animation_expand.gif',
    fps=2,
    figure_size_in=6,
    # How the projected, deformed image is displayed
    image_cmap=plt.cm.Greys_r,
    image_window=[0, 200],
    image_origin='lower',
    # No contour overlay
    contour_list=None,
    contour_cmap=plt.cm.rainbow,
    # Scalar overlay: projected thickness of the deformed parotids
    scalar_list=ctr_list,
    scalar_cmap=plt.cm.plasma,
    scalar_min=0.01,
    scalar_max=None,
    scalar_alpha=0.5,
)