In [None]:
import jax
import jax.numpy as jnp
import numpy as np

import gym
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.animation import FuncAnimation
from typing import Iterable, Union, Callable, Tuple, List, Dict, Any
import time

In [None]:
# Fixed seed so the random rollout (and therefore the animation) is reproducible
# under Restart & Run All.
SEED = 0

env = gym.make('CartPole-v1', render_mode="rgb_array")
env.reset(seed=SEED)
# action_space.sample() draws from the space's own RNG, so seed it as well.
env.action_space.seed(SEED)
image_seq = []

In [None]:
# Roll out one episode with uniformly random actions, recording one rendered
# frame per state. Resetting here (instead of relying on the previous cell)
# makes this cell safe to re-run: it always starts a fresh episode and list.
env.reset()
image_seq = [env.render()]  # include the initial state, before any action

done = False
while not done:
  observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
  image_seq.append(env.render())
  # The gym>=0.26 step API splits "episode over" into two flags; stop on either,
  # otherwise the loop keeps stepping past a truncated (time-limited) episode.
  done = terminated or truncated

In [None]:
def animate(image_seq: Iterable, axes=None, close_fig=True, interval=50):
  """Build an HTML5-renderable animation from a sequence of frames.

  Args:
    image_seq: Indexable, non-empty sequence of frames (e.g. a list of RGB
      arrays from ``env.render()``). ``len()`` and ``[i]`` are used, so a bare
      iterator will not work.
    axes: Optional Matplotlib ``Axes`` to draw into. When omitted, a new
      single-axes figure is created.
    close_fig: If True, close the figure after building the animation
      (presumably so the notebook does not also display a static copy).
    interval: Delay between frames in milliseconds.

  Returns:
    The ``FuncAnimation`` object; keep a reference to it (or make it a cell's
    last expression) so it is displayed.
  """
  from matplotlib import rc
  # Global rc change: makes animation objects render as HTML5 video in the notebook.
  rc('animation', html='html5')

  # Animate the figure that actually owns `axes`. Creating an unrelated
  # figure here would animate a blank canvas when `axes` is supplied and
  # leak an extra figure when it is not.
  if axes is None:
    fig, axes = plt.subplots(1, 1)
  else:
    fig = axes.figure

  graphical_element = axes.imshow(image_seq[0])  # first frame sets the image extent

  def animate_frame(i):
    # Swap the pixel data in place rather than re-plotting each frame.
    graphical_element.set_data(image_seq[i])
    return graphical_element,

  animation_handler = animation.FuncAnimation(fig, animate_frame, frames=len(image_seq), interval=interval)

  if close_fig:
    plt.close(fig)

  return animation_handler

In [None]:
# Animate the recorded rollout; the bare last expression renders it inline.
rollout_animation = animate(image_seq)
rollout_animation



In [23]:
# Grid cells are filled in row-major order.
def animation_table(image_seq: Iterable, grids: Iterable[int], close_fig=True, interval=50):
  """Animate several frame sequences simultaneously in a grid of subplots.

  Cells are filled row-major: ``image_seq[r * grids[1] + c]`` is drawn in
  row ``r``, column ``c``.

  Args:
    image_seq: Indexable sequence with exactly ``grids[0] * grids[1]``
      entries. Each entry is an indexable sequence of frames, or ``None``.
      A ``None`` or empty entry leaves its cell blank with the axis hidden.
    grids: ``(n_rows, n_cols)`` of the subplot grid (indexed, not iterated).
    close_fig: If True, close the figure after building the animation
      (presumably so the notebook does not also display a static copy).
    interval: Delay between frames in milliseconds.

  Returns:
    The ``FuncAnimation``. It runs for as many frames as the longest entry;
    shorter entries hold their last frame once exhausted.

  Raises:
    AssertionError: if ``len(image_seq)`` does not match the grid size.
  """
  from matplotlib import rc
  # Global rc change: makes animation objects render as HTML5 video in the notebook.
  rc('animation', html='html5')

  n_rows, n_cols = grids[0], grids[1]
  assert len(image_seq) == n_rows * n_cols, "The number of images must be equal to the number of grids"
  # squeeze=False keeps `axes` 2-D even for a 1xN or Nx1 grid.
  fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)

  # Each live image artist -> the frame sequence it plays.
  frames_by_artist = {}
  for r, row_axes in enumerate(axes):
    for c, ax in enumerate(row_axes):
      frames = image_seq[r * n_cols + c]
      # len() rather than truthiness so numpy arrays of frames work too.
      if frames is None or len(frames) == 0:
        ax.axis('off')
        continue
      frames_by_artist[ax.imshow(frames[0])] = frames  # first frame

  max_num_frames = max((len(frames) for frames in frames_by_artist.values()), default=0)

  def animate_frame(i):
    for artist, frames in frames_by_artist.items():
      # Clamp the index so shorter sequences freeze on their final frame.
      artist.set_data(frames[min(i, len(frames) - 1)])
    return tuple(frames_by_artist)

  animation_handler = animation.FuncAnimation(fig, animate_frame, frames=max_num_frames, interval=interval)

  if close_fig:
    plt.close(fig)

  return animation_handler

In [24]:
# 2x2 grid, filled row-major: three copies of the rollout plus one blank (None) cell.
a = animation_table(
  [image_seq, image_seq, None, image_seq],
  grids=(2, 2),
)
a