In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

In [62]:
class FrameAnimator(object):
    """Build and render a matplotlib animation of a concentration field.

    Each frame shows an image (the concentration), a scatter marker at the
    tracked position(s) and a quiver arrow for the gradient at that position.
    Call :meth:`build` first, then :meth:`visualize` to get a notebook-
    displayable object.
    """

    # Accepted ``mode`` values for visualize(). 'jsHTLM' is the historical
    # (misspelled) name and is kept so existing callers do not break.
    _VIDEO_MODES = ('HTML',)
    _JS_MODES = ('jsHTML', 'jsHTLM')

    def __init__(self, figsize=(6, 5)):
        """Store the figure size (width, height) used by :meth:`build`."""
        self.figsize = figsize
        # Set by build(); None means no animation has been built yet.
        self.anim = None

    def build(self, imgs, max_concentration,
              positions_x, positions_y,
              gradient_history,
              frame_start_i, frame_stop_i,
              interval=100, blit=True, text=None):
        """Create the animation and store it on ``self.anim``.

        Parameters
        ----------
        imgs : sequence
            One image per frame, drawn with ``imshow``; ``len(imgs)`` is the
            number of frames.
        max_concentration : float
            Upper bound (``vmax``) of the colour scale; the lower bound is 0.
        positions_x, positions_y : sequence
            Per-frame x / y coordinates. Each entry may be a scalar or an
            array of coordinates (assumed to have matching lengths).
        gradient_history : sequence
            Per-frame pair; ``gradient_history[i][0]`` and
            ``gradient_history[i][1]`` are used as the quiver U and V.
        frame_start_i, frame_stop_i : int
            Only used to label the title as ``Frame {start+i}/{stop}``.
        interval : int
            Delay between frames, passed to ``FuncAnimation``.
        blit : bool
            Passed to ``FuncAnimation``.
        text : sequence of str, optional
            Per-frame text appended to the title.
        """
        fig, ax = plt.subplots(1, figsize=self.figsize)
        concentration_plot = ax.imshow(imgs[0], vmin=0, vmax=max_concentration, cmap="Greens")
        scatter_plot = ax.scatter([], [], c="orange", s=80, edgecolors='black')
        quiver_plot = ax.quiver(positions_x[0], positions_y[0],
                                gradient_history[0][0], gradient_history[0][1],
                                color="black")

        fig.colorbar(concentration_plot, ax=ax, shrink=1)
        ax.axis('on')

        artists = (concentration_plot, scatter_plot, quiver_plot)

        def offsets_at(i):
            # set_offsets expects an (N, 2) array; column_stack gives that for
            # both scalar positions (N == 1) and arrays of positions.
            return np.column_stack((np.ravel(positions_x[i]),
                                    np.ravel(positions_y[i])))

        def init():
            # Blank starting frame: empty field, markers parked at the origin.
            origin = np.zeros((1, 2))
            concentration_plot.set_array(np.zeros_like(imgs[0]))
            scatter_plot.set_offsets(origin)
            quiver_plot.set_offsets(origin)
            return artists

        def animate(i):
            offsets = offsets_at(i)
            concentration_plot.set_array(imgs[i])
            scatter_plot.set_offsets(offsets)
            quiver_plot.set_offsets(offsets)
            quiver_plot.set_UVC(gradient_history[i][0], gradient_history[i][1])
            text_str = text[i] if text is not None else ""
            ax.set_title("Frame {}/{} -- {}".format(frame_start_i+i, frame_stop_i, text_str))
            return artists

        # Close this figure explicitly so the notebook does not also render a
        # static copy of it; the animation keeps its own reference to ``fig``.
        plt.close(fig)
        self.anim = animation.FuncAnimation(fig, animate,
                                            init_func=init,
                                            frames=len(imgs),
                                            interval=interval,
                                            blit=blit,
                                            repeat=False)

    def visualize(self, mode='HTML'):
        """Return the built animation wrapped in an IPython ``HTML`` object.

        Parameters
        ----------
        mode : str
            ``'HTML'`` renders via ``to_html5_video()``; ``'jsHTML'`` (or the
            legacy spelling ``'jsHTLM'``) renders via ``to_jshtml()``.

        Raises
        ------
        ValueError
            If ``mode`` is not one of the accepted values.
        RuntimeError
            If :meth:`build` has not been called yet.
        """
        if mode not in self._VIDEO_MODES + self._JS_MODES:
            raise ValueError(
                "Unknown mode {!r}; expected 'HTML' or 'jsHTML'".format(mode))

        anim = getattr(self, 'anim', None)
        if anim is None:
            raise RuntimeError("No animation available: call build() first")

        if mode in self._VIDEO_MODES:
            return HTML(anim.to_html5_video())
        return HTML(anim.to_jshtml())