# Animated plots in Jupyter sessions with the modern Jupyter flavors, specifically JupyterLab and Jupyter Notebook 7+, served via the MyBinder.org system

At present, launches of this notebook into MyBinder-served sessions will open in JupyterLab or Jupyter Notebook 7+. Both of these are currently built upon JupyterLab components.  
(If your Jupyter interface is based on NbClassic or Jupyter Notebook 6.4 or earlier, see [here](https://github.com/fomightez/animated_matplotlib_classic-binder).)  
Not sure which flavor of Jupyter tech you are using for your notebooks? See ['Quickly Navigating the tech of the Jupyter ecosystem post-2023'](https://gist.github.com/fomightez/e873947b502f70388d82644b17196279).  
In particular, see [the '**Modern JupyterLab vs. Jupyter Notebook 7+**' section of 'Quickly Navigating the tech of the Jupyter ecosystem post-2023'](https://gist.github.com/fomightez/e873947b502f70388d82644b17196279#modern-jupyterlab-vs-jupyter-notebook-7) if you are trying to determine which of the modern flavors you are using.  

If you want to switch flavors and you opened this in a MyBinder-served session from the repo:  
If you are presently in JupyterLab, edit the URL of this page, **close to its end**, to remove `lab`. For example, `https://hub.ovh2.mybinder.org/user/fomightez-anima-tplotlib-binder-hc7q3ptr/lab/tree/index.ipynb` would become `https://hub.ovh2.mybinder.org/user/fomightez-anima-tplotlib-binder-hc7q3ptr/tree/index.ipynb`.  

If you are presently in Jupyter Notebook 7+, edit the URL of this page, **close to its end**, to add `lab` in front of `tree`. For example, `https://hub.ovh2.mybinder.org/user/fomightez-anima-tplotlib-binder-hc7q3ptr/tree/index.ipynb` would become `https://hub.ovh2.mybinder.org/user/fomightez-anima-tplotlib-binder-hc7q3ptr/lab/tree/index.ipynb`.

This notebook started out as [this gist on animated plot approaches that worked in JupyterLab](https://gist.github.com/fomightez/e7e70099da1fea17b5e012d79f1d9d30), which was a companion to [the content that now comprises my repo 'animated_matplotlib_classic-binder'](https://github.com/fomightez/animated_matplotlib_classic-binder). It has since been expanded and improved, now that JupyterLab and Jupyter Notebook 7+ are the current centerpieces of the Jupyter ecosystem.

#### The Approaches

Probably the most established approach that works in JupyterLab and Jupyter Notebook 7+ is the use of Matplotlib's `animation.FuncAnimation()`.
However, the other approaches are illustrated first here, because Matplotlib's `animation.FuncAnimation()` requires changing the backend setting and I want to make clear that the other options don't need that. Feel free to skip ahead to the demonstrations involving Matplotlib's `animation.FuncAnimation()`.





## Use of `clear_output()` in conjunction with a delay

This approach has the important feature that it works in both the classic notebook interface and the modern flavors of Jupyter **without requiring any change to the code**.

Here's the basic version of this approach using `clear_output()`, adapted from [this StackOverflow Answer](https://stackoverflow.com/a/52672859/8508004).

In [None]:
from IPython.display import clear_output
from matplotlib import pyplot as plt
import numpy as np
import collections
import time  # needed for time.sleep() in live_plot()

def live_plot(data_dict, figsize=(7,5), title=''):
    clear_output(wait=True)
    plt.figure(figsize=figsize)
    for label,data in data_dict.items():
        plt.plot(data, label=label)
    plt.title(title)
    plt.grid(True)
    plt.xlabel('epoch')
    plt.legend(loc='center left') # the plot evolves to the right
    plt.show()
    time.sleep(0.15) # extend delay between adding next frame in animation

data = collections.defaultdict(list)
for i in range(30):
    data['foo'].append(np.random.random())
    data['bar'].append(np.random.random())
    data['baz'].append(np.random.random())
    live_plot(data)


A more fleshed-out demo of this approach builds on a Jupyter notebook with related code, which can be viewed [here](https://nbviewer.org/gist/fomightez/f539a3770f2f3287bf12de2d2a549e3a).  
 
For the more fleshed-out demo, PyTorch (`torch`) first needs to be installed.  
Run the cell below to get this final preparation of the environment out of the way. This will take several minutes.

In [None]:
%pip install torch

Restart the kernel after running that cell and then try running the cell below.  
**With that preparation complete**, continue on to examine and run this method...

In a previous notebook ([here](https://nbviewer.org/gist/fomightez/f539a3770f2f3287bf12de2d2a549e3a)), the data was collected first and then plotted, using `fig.canvas.draw()` to reset the image each round; that only works in the older, classic notebook interface. Here, a pause is built in so the animation steps through adding each next segment with time in between, instead of blinking through too fast to really see and effectively displaying only the last step.    
In addition to the extra delay, the `live_plot()` function that updates the plot each round uses IPython's display control, `clear_output()`, to clear the output with each frame.  

In [None]:
# based on https://stackoverflow.com/a/52672859/8508004 to help with https://stackoverflow.com/q/75017358/8508004
from IPython.display import clear_output
from matplotlib import pyplot as plt
import numpy as np
import collections
import time

def live_plot(data_dict, figsize=(12,5), title=''):
    clear_output(wait=True)
    plt.figure(figsize=figsize)
    #plt.plot(data_dict["steps"],data_dict["r"] , 'r-', label = "real")
    #plt.plot(data_dict["steps"],data_dict["b"] , 'b-', label = "prediction")
    for i,_ in enumerate(data_dict["steps"]):
        plt.plot(data_dict["steps"][i], list(data_dict["r"][i]) , 'r-', )
        plt.plot(data_dict["steps"][i], list(data_dict["b"][i]) , 'b-', )
    plt.title(title)
    plt.grid(True)
    #plt.legend(loc='center left') # the plot evolves to the right
    plt.show()
    time.sleep(0.2) # extend delay between adding next frame in animation

data = collections.defaultdict(list)
    
import torch
from torch import nn

# torch.manual_seed(1)    # reproducible

# Hyper Parameters
TIME_STEP = 10      # rnn time step
INPUT_SIZE = 1      # rnn input size
LR = 0.02           # learning rate

# data
steps = np.linspace(0, np.pi*2, 100, dtype=np.float32)  # float32 for converting torch FloatTensor
x_np = np.sin(steps)
y_np = np.cos(steps)

class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.RNN(
            input_size=INPUT_SIZE,
            hidden_size=32,     # rnn hidden unit
            num_layers=1,       # number of rnn layer
            batch_first=True,   # input & output will have batch size as the 1st dimension, e.g. (batch, time_step, input_size)
        )
        self.out = nn.Linear(32, 1)

    def forward(self, x, h_state):
        # x (batch, time_step, input_size)
        # h_state (n_layers, batch, hidden_size)
        # r_out (batch, time_step, hidden_size)
        r_out, h_state = self.rnn(x, h_state)

        outs = []    # save all predictions
        for time_step in range(r_out.size(1)):    # calculate output for each time step
            outs.append(self.out(r_out[:, time_step, :]))
        return torch.stack(outs, dim=1), h_state

        # instead, for simplicity, you can replace above codes by follows
        # r_out = r_out.view(-1, 32)
        # outs = self.out(r_out)
        # outs = outs.view(-1, TIME_STEP, 1)
        # return outs, h_state
        
        # or even simpler, since nn.Linear can accept inputs of any dimension 
        # and returns outputs with same dimension except for the last
        # outs = self.out(r_out)
        # return outs

rnn = RNN()
print(rnn)

optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)   # optimize all rnn parameters
loss_func = nn.MSELoss()

h_state = None      # for initial hidden state

for step in range(100):
    start, end = step * np.pi, (step+1)*np.pi   # time range
    # use sin predicts cos
    steps = np.linspace(start, end, TIME_STEP, dtype=np.float32, endpoint=False)  # float32 for converting torch FloatTensor
    x_np = np.sin(steps)
    y_np = np.cos(steps)

    x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])    # shape (batch, time_step, input_size)
    y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])

    prediction, h_state = rnn(x, h_state)   # rnn output
    # !! next step is important !!
    h_state = h_state.data        # repack the hidden state, break the connection from last iteration

    loss = loss_func(prediction, y)         # calculate loss
    optimizer.zero_grad()                   # clear gradients for this training step
    loss.backward()                         # backpropagation, compute gradients
    optimizer.step()                        # apply gradients

    # plotting
    data['steps'].append(list(steps))
    data['r'].append(y_np.flatten())
    data['b'].append(prediction.data.numpy().flatten())
    live_plot(data);

The source of that approach suggested:
>"make sure you have a few cells below the plot, otherwise the view snaps in place each time the plot is redrawn."
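The `live_plot()` pattern above can also be packaged as a small reusable helper. The sketch below is not from the linked answer; `draw_frames()` and its `draw(ax, i)` callback are hypothetical names introduced here for illustration. As in `live_plot()`, `clear_output(wait=True)` holds off clearing the old frame until new output is ready, which reduces flicker. Unlike `live_plot()`, each figure is closed explicitly after it is shown, so figures don't pile up under a backend that doesn't close them for you.

```python
import time

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output


def draw_frames(draw, n_frames, delay=0.15, figsize=(7, 5)):
    """Hypothetical helper: redraw a plot n_frames times using clear_output().

    `draw(ax, i)` is a user-supplied callback that plots frame i onto ax.
    Returns the number of frames drawn.
    """
    for i in range(n_frames):
        # wait=True defers clearing until the next output arrives,
        # which reduces flicker between frames
        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=figsize)
        draw(ax, i)
        plt.show()
        # close explicitly so figures don't accumulate in memory
        plt.close(fig)
        time.sleep(delay)  # pause so each frame stays visible briefly
    return n_frames


# usage: grow a sine curve one point per frame
x = np.linspace(0, 2 * np.pi, 30)


def draw_sine(ax, i):
    ax.plot(x[: i + 1], np.sin(x[: i + 1]), "b-")
    ax.set_xlim(0, 2 * np.pi)
    ax.set_ylim(-1.1, 1.1)
    ax.grid(True)


n = draw_frames(draw_sine, len(x), delay=0.1)
```

Because `draw` receives the Axes and the frame index, the helper doesn't need to know anything about the data; the sine-curve callback is just a stand-in.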

Before getting to the most established way of performing animations in the modern interfaces, there's another that I've come across.

## Use of ipywidgets' Play (Animation) widget combined with plotting under the control of ipywidgets' `interact()`

This comes from [here](https://stackoverflow.com/q/76212356/8508004) and relies on the [Play (Animation) widget](https://ipywidgets.readthedocs.io/en/stable/examples/Widget%20List.html#play-animation-widget) in conjunction with ipywidgets' `interact()`. It [was part of the original content in my original animated plots repo](https://nbviewer.org/github/fomightez/animated_matplotlib_classic-binder/blob/master/Play_Animation_Widget.ipynb), and I noticed it worked in JupyterLab.  
Hence, this approach also has the important feature that it works in both the classic notebook interface and the modern flavors of Jupyter **without requiring any change to the code**.

Run the following cell and then press the '`Play`' button in the upper left side.

In [None]:
# https://stackoverflow.com/q/76212356/8508004
import ipywidgets as widgets
from mpl_toolkits import mplot3d
import numpy as np
import matplotlib.pyplot as plt

def theta(t):
    fig = plt.figure(figsize=(10,10))
    ax  = plt.axes(projection = "3d")
    z   = np.linspace(0, t, 500)
    x   = np.sin(z)
    y   = np.cos(z)
    ax.plot3D(x, y, z, 'red')
    plt.show()
    
widgets.interact(theta, t = widgets.Play(min=0, max = 15, interval = 200), continuous_update = True);
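A variation worth considering, sketched here under the assumption of a recent ipywidgets release: instead of handing the `Play` widget directly to `interact()`, link it to a visible `IntSlider` with `widgets.jslink()` and drive the plot through `widgets.interactive_output()`. That gives a play button plus a slider for picking a frame by hand. The `theta()` function is a trimmed copy of the one above.

```python
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display


def theta(t):
    """Draw the 3D helix up to parameter t (same plot as the cell above)."""
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(projection="3d")
    z = np.linspace(0, t, 500)
    ax.plot3D(np.sin(z), np.cos(z), z, "red")
    plt.show()


play = widgets.Play(min=0, max=15, step=1, interval=200)
slider = widgets.IntSlider(min=0, max=15)
# jslink keeps the two widgets in sync in the browser, so dragging the
# slider moves the Play position and pressing Play moves the slider
widgets.jslink((play, "value"), (slider, "value"))

# interactive_output re-runs theta whenever slider.value changes and
# captures the plot in an Output widget
out = widgets.interactive_output(theta, {"t": slider})
display(widgets.VBox([widgets.HBox([play, slider]), out]))
```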

What follows next is probably the most established way of performing animations in the modern interfaces.

The approaches above can cause the view to snap each time the plot is redrawn, whereas the one below, using Matplotlib's `animation.FuncAnimation()`, doesn't make the notebook jumpy.

## Use of Matplotlib's `animation.FuncAnimation()`

Note that one of the main features of this approach is that it is nearly universal. Very little alteration is necessary to make the exact same code work in the classic notebook interface.

Note that to use this approach, `ipympl` must be installed.  
With that installed, the backend is set with `%matplotlib ipympl`.

Here is a simple example that I also feature in the classic interface demo notebook; this code, slightly modified from [here](https://stackoverflow.com/a/28077104/8508004), works:

In [None]:
%matplotlib ipympl
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig, ax = plt.subplots()
line, = ax.plot(x, y, color='k')

def update(num, x, y, line):
    line.set_data(x[:num], y[:num])
    line.axes.axis([0, 10, 0, 1])
    line.axes.set_ylim(-1.1,1.1)
    return line,

ani = animation.FuncAnimation(fig, update, len(x), fargs=[x, y, line],
                              interval=25, blit=True);

Note that it looks a lot like the style of interactive widget you get when you use `%matplotlib notebook` in the classic Jupyter interface. (However, in that interface you could hit the blue button in the upper right to stop it. I'm not seeing a stop button for this one, and so far even stopping (interrupting) the kernel doesn't work: the animation keeps running and the kernel indicator keeps blinking as active. So far, restarting the kernel via one of the various options under the '`Kernel`' menu is the only way I've found.)
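One thing that may be worth trying before restarting the kernel (I haven't confirmed it here): Matplotlib 3.4 and later give animation objects `pause()` and `resume()` methods that stop and restart the animation's timer, so running `ani.pause()` in a new cell, where `ani` is the variable from the cell above, might stop it. The sketch below builds its own small animation so it runs on its own; with an older Matplotlib, calling `stop()` on the animation's `event_source` timer is the usual workaround.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

x = np.linspace(0, 10, 100)
fig, ax = plt.subplots()
line, = ax.plot([], [], color="k")
ax.set_xlim(0, 10)
ax.set_ylim(-1.1, 1.1)


def update(num):
    line.set_data(x[:num], np.sin(x[:num]))
    return line,


ani = animation.FuncAnimation(fig, update, frames=len(x), interval=25, blit=True)

# In a later cell: stop the animation's timer (Matplotlib >= 3.4) ...
ani.pause()
# ... and, if wanted, start it again
ani.resume()
# With older Matplotlib versions, stop the underlying timer directly
ani.event_source.stop()
```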

Here is another example, adapted from [here](https://discourse.jupyter.org/t/matplotlib-animation-not-appearing-in-jupyter-notebook/24938/3?u=fomightez).

In [None]:
%matplotlib ipympl
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

stepsize = 0.5
num_steps = 20
num_trials = 5

final_position = []

for _ in range(num_trials):
    pos = np.array([0, 0])
    path = []
    for i in range(num_steps):
        pos = pos + np.random.normal(0, stepsize, 2)
        path.append(pos)
    final_position.append(np.array(path))
    
x = [final_position[i][:,0] for i in range(len(final_position))]
y = [final_position[j][:,1] for j in range(len(final_position))]

fig = plt.figure(figsize=(10,6))
ax = fig.add_subplot()
fig.subplots_adjust(left=0.1, right=0.85)

cmap = plt.get_cmap('tab10')

def animate(frame):
    step_num = frame % (num_steps)
    trial_num = frame//(num_steps)
    color = cmap(trial_num % 10)
    if step_num == num_steps-1:
        label = f"Trial = {trial_num+1}"
    else:
        label = None
    ax.plot(x[trial_num][:step_num], y[trial_num][:step_num], color = color, ls = '-',linewidth = 0.5,
            marker = 'o', ms = 8, mfc = color, mec ='k', zorder = trial_num, label = label)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title(f"Number of trials = {trial_num+1} \nNumber of steps = {step_num+1}")  
    if step_num == num_steps-1:
        ax.legend(fontsize=10, loc='upper left', bbox_to_anchor=(1, 1))
    ax.grid(True)
    
    return ax

fig.suptitle(f"2D random walk simulation for {num_steps} steps over {num_trials} trials.")
ani = FuncAnimation(fig, animate, frames= np.arange(0, (num_steps * num_trials)), interval = 100, repeat = False)
ani;

#### Demo of the PyTorch example that used `clear_output()` above, adapted to Matplotlib's `animation.FuncAnimation()`

Having pulled [the original code](https://stackoverflow.com/q/75017358/8508004) (the one involving PyTorch) apart enough above to realize that each step adds a segment, and having implemented a couple of approaches, I also wondered if I could use the method with Matplotlib's `animation.FuncAnimation()` and the associated widget controller illustrated at the bottom of [here](https://nbviewer.org/github/fomightez/animated_matplotlib-binder/blob/master/index.ipynb) (which I had recently used to answer [here](https://stackoverflow.com/a/75009196/8508004)) to make something that would also work in JupyterLab. (**Note that not all animation methods involving `FuncAnimation()` work in JupyterLab**; see [here](https://stackoverflow.com/a/73451172/8508004) for one that, at this time, only works in the classic notebook. The widget is key.)

In [None]:
%matplotlib ipympl
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML, display
import collections
import time


fig = plt.figure(figsize=(12,5))
ax = plt.axes(xlim=(0, 320), ylim=(-1.75, 1.75))
lineplot, = ax.plot([], [], "r-")
lineplot2, = ax.plot([], [], "b-")


import torch
from torch import nn

# torch.manual_seed(1)    # reproducible

# Hyper Parameters
TIME_STEP = 10      # rnn time step
INPUT_SIZE = 1      # rnn input size
LR = 0.02           # learning rate

# data
steps = np.linspace(0, np.pi*2, 100, dtype=np.float32)  # float32 for converting torch FloatTensor
x_np = np.sin(steps)
y_np = np.cos(steps)

class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.RNN(
            input_size=INPUT_SIZE,
            hidden_size=32,     # rnn hidden unit
            num_layers=1,       # number of rnn layer
            batch_first=True,   # input & output will have batch size as the 1st dimension, e.g. (batch, time_step, input_size)
        )
        self.out = nn.Linear(32, 1)

    def forward(self, x, h_state):
        # x (batch, time_step, input_size)
        # h_state (n_layers, batch, hidden_size)
        # r_out (batch, time_step, hidden_size)
        r_out, h_state = self.rnn(x, h_state)

        outs = []    # save all predictions
        for time_step in range(r_out.size(1)):    # calculate output for each time step
            outs.append(self.out(r_out[:, time_step, :]))
        return torch.stack(outs, dim=1), h_state

        # instead, for simplicity, you can replace above codes by follows
        # r_out = r_out.view(-1, 32)
        # outs = self.out(r_out)
        # outs = outs.view(-1, TIME_STEP, 1)
        # return outs, h_state
        
        # or even simpler, since nn.Linear can accept inputs of any dimension 
        # and returns outputs with same dimension except for the last
        # outs = self.out(r_out)
        # return outs

rnn = RNN()
print(rnn)

optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)   # optimize all rnn parameters
loss_func = nn.MSELoss()

h_state = None      # for initial hidden state

data = collections.defaultdict(list)
    
def init():
    global h_state, data 
    lineplot.set_data([], [])
    lineplot2.set_data([], [])
    data = collections.defaultdict(list)
    return lineplot, lineplot2 # returning a list, e.g. [lineplot, lineplot2], also works like in https://nbviewer.org/github/raphaelquast/jupyter_notebook_intro/blob/master/jupyter_nb_introduction.ipynb#pre-render-animations-and-export-to-HTML

def animate(i):
    global h_state, data 
    step = i
    start, end = step * np.pi, (step+1)*np.pi   # time range
    # use sin predicts cos
    steps = np.linspace(start, end, TIME_STEP, dtype=np.float32, endpoint=False)  # float32 for converting torch FloatTensor
    x_np = np.sin(steps)
    y_np = np.cos(steps)

    x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])    # shape (batch, time_step, input_size)
    y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])

    prediction, h_state = rnn(x, h_state)   # rnn output
    # !! next step is important !!
    h_state = h_state.data        # repack the hidden state, break the connection from last iteration

    loss = loss_func(prediction, y)         # calculate loss
    optimizer.zero_grad()                   # clear gradients for this training step
    loss.backward()                         # backpropagation, compute gradients
    optimizer.step()                        # apply gradients

    # plotting
    data['steps'].append(list(steps))
    data['r'].append(y_np.flatten())
    data['b'].append(prediction.data.numpy().flatten())
    #lineplot.set_data([x], [y])
    #lineplot2.set_data([x], [z])
    lineplot.set_data(data["steps"],data["r"])
    lineplot2.set_data(data["steps"],data["b"])
    '''
    for i,_ in enumerate(data_dict["steps"]):
        plt.plot(data_dict["steps"][i], list(data_dict["r"][i]) , 'r-', )
        plt.plot(data_dict["steps"][i], list(data_dict["b"][i]) , 'b-', )
    '''
    return [lineplot, lineplot2] # return every updated artist so blitting redraws both lines

anim = animation.FuncAnimation(fig, animate, init_func=init,
                           frames=100, interval=20, blit=True)
anim;
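The cell above passes the whole list of 10-point segments straight to `set_data()`, and Matplotlib flattens that nested input into one long line. Below is a minimal sketch of the same per-frame update without PyTorch, assuming a hypothetical `fake_prediction()` as a stand-in for the RNN output. It joins the segments explicitly with `np.concatenate`, returns every line it changes so blitting redraws both, and uses an `init()` that empties the shared `data` in place rather than rebinding a global.

```python
import collections

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

TIME_STEP = 10  # points per segment, as in the cell above

fig, ax = plt.subplots(figsize=(12, 5))
ax.set_xlim(0, 320)
ax.set_ylim(-1.75, 1.75)
real_line, = ax.plot([], [], "r-")
pred_line, = ax.plot([], [], "b-")
data = collections.defaultdict(list)


def fake_prediction(y):
    """Hypothetical stand-in for the RNN output: the target plus noise."""
    return y + np.random.normal(0, 0.1, size=y.shape)


def init():
    # reset the collected data so a fresh run starts from scratch
    data.clear()
    real_line.set_data([], [])
    pred_line.set_data([], [])
    return real_line, pred_line


def animate(i):
    # one "training step" per frame: a new segment spanning [i*pi, (i+1)*pi)
    steps = np.linspace(i * np.pi, (i + 1) * np.pi, TIME_STEP,
                        dtype=np.float32, endpoint=False)
    y = np.cos(steps)
    data["steps"].append(steps)
    data["r"].append(y)
    data["b"].append(fake_prediction(y))
    # join all segments collected so far into single 1-D arrays
    xs = np.concatenate(data["steps"])
    real_line.set_data(xs, np.concatenate(data["r"]))
    pred_line.set_data(xs, np.concatenate(data["b"]))
    # return every artist that changed, so blitting redraws both lines
    return real_line, pred_line


anim = animation.FuncAnimation(fig, animate, init_func=init,
                               frames=100, interval=20, blit=True)
```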

#### Use of Matplotlib's `animation.FuncAnimation()` with a player controller widget for the frames

You can also have the animation output as a JavaScript/HTML animation made of frames, and this works in JupyterLab and Jupyter Notebook 7+. It provides a player control widget with a slider for fine-tuning which frame is in view, so it can be a nice feature.  
**An exceptional quality of the animations produced with such a controller is that they remain working and controllable with the widget when the 'static' versions of the saved notebook file are viewed in nbviewer.**

(See the bottom of [here](https://nbviewer.org/github/fomightez/animated_matplotlib_classic-binder/blob/master/index.ipynb) for where this was originally utilized and described in more detail. I've tried to transfer much of it to here; however, I may not have entirely done that yet. I know my main original source of the related information was [a post by Louis Tiao](https://web.archive.org/web/20230330131423/http://louistiao.me/posts/notebooks/embedding-matplotlib-animations-in-jupyter-as-interactive-javascript-widgets/).)

Below is a demo from [here](https://gist.github.com/fomightez/e89bc19ec31d8ad1de1b8071c659e684) that uses that. When it shows up, hit the 'play' button to play it, or drag the slider to pick a frame:

In [None]:
%matplotlib ipympl
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML, display
plt.rcParams["animation.html"] = "jshtml"
plt.ioff() #needed so the second time you run it you get only single plot
plt.style.use('seaborn-v0_8-pastel')

fig = plt.figure()
ax = plt.axes(xlim=(0, 4), ylim=(-2, 2))
lineplot, = ax.plot([], [], lw=3)
    
def init():
    lineplot.set_data([], [])
    return lineplot, #return [lineplot] also works like in https://nbviewer.org/github/raphaelquast/jupyter_notebook_intro/blob/master/jupyter_nb_introduction.ipynb#pre-render-animations-and-export-to-HTML

def animate(i):
    x = np.linspace(0, 4, 1000)
    y = np.sin(2 * np.pi * (x - 0.01 * i))
    lineplot.set_data([x], [y])
    return [lineplot]

anim = animation.FuncAnimation(fig, animate, init_func=init,
                           frames=200, interval=20, blit=True)
anim
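Setting `plt.rcParams["animation.html"] = "jshtml"`, as the cell above does, is what makes a bare `anim` on the last line render as the JavaScript player: Matplotlib then uses `Animation.to_jshtml()` for the object's HTML display. A sketch of calling that method explicitly instead, which avoids changing a global setting; `plt.close(fig)` should keep the static figure from also being shown in most setups.

```python
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display
from matplotlib import animation

fig, ax = plt.subplots()
ax.set_xlim(0, 4)
ax.set_ylim(-2, 2)
lineplot, = ax.plot([], [], lw=3)
x = np.linspace(0, 4, 1000)


def animate(i):
    lineplot.set_data(x, np.sin(2 * np.pi * (x - 0.01 * i)))
    return lineplot,


anim = animation.FuncAnimation(fig, animate, frames=20, interval=20, blit=True)
# Render every frame into a self-contained HTML/JavaScript player
html = anim.to_jshtml()
plt.close(fig)  # avoid an extra static copy of the figure
display(HTML(html))
```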

More use of the widget with the player slider controller:  
Here is each of the demos above, redone in turn in that style. Note that only really subtle changes need to be made at the top.

In [None]:
%matplotlib ipympl
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML, display
plt.rcParams["animation.html"] = "jshtml"
plt.ioff() #needed so the second time you run it you get only single plot

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig, ax = plt.subplots()
line, = ax.plot(x, y, color='k')

def update(num, x, y, line):
    line.set_data(x[:num], y[:num])
    line.axes.axis([0, 10, 0, 1])
    line.axes.set_ylim(-1.1,1.1)
    return line,

ani = animation.FuncAnimation(fig, update, len(x), fargs=[x, y, line],
                              interval=25, blit=True)
ani

In [None]:
%matplotlib ipympl
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
plt.rcParams["animation.html"] = "jshtml"
plt.ioff() #needed so the second time you run it you get only single plot

stepsize = 0.5
num_steps = 20
num_trials = 5

final_position = []

for _ in range(num_trials):
    pos = np.array([0, 0])
    path = []
    for i in range(num_steps):
        pos = pos + np.random.normal(0, stepsize, 2)
        path.append(pos)
    final_position.append(np.array(path))
    
x = [final_position[i][:,0] for i in range(len(final_position))]
y = [final_position[j][:,1] for j in range(len(final_position))]

fig = plt.figure(figsize=(10,6))
ax = fig.add_subplot()
fig.subplots_adjust(left=0.1, right=0.85)

cmap = plt.get_cmap('tab10')

def animate(frame):
    step_num = frame % (num_steps)
    trial_num = frame//(num_steps)
    color = cmap(trial_num % 10)
    if step_num == num_steps-1:
        label = f"Trial = {trial_num+1}"
    else:
        label = None
    ax.plot(x[trial_num][:step_num], y[trial_num][:step_num], color = color, ls = '-',linewidth = 0.5,
            marker = 'o', ms = 8, mfc = color, mec ='k', zorder = trial_num, label = label)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title(f"Number of trials = {trial_num+1} \nNumber of steps = {step_num+1}")  
    if step_num == num_steps-1:
        ax.legend(fontsize=10, loc='upper left', bbox_to_anchor=(1, 1))
    ax.grid(True)
    
    return ax

fig.suptitle(f"2D random walk simulation for {num_steps} steps over {num_trials} trials.")
ani = FuncAnimation(fig, animate, frames= np.arange(0, (num_steps * num_trials)), interval = 100, repeat = False)
ani

In [None]:
# If upon first running, it shows a non-interactive, single static shot of the plot below the interactive one with the widget controller,
# JUST RE-RUN TWICE. Re-run usually fixes that display quirk.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML, display
plt.rcParams["animation.html"] = "jshtml"
#plt.ioff() #needed so the second time you run it you get only single plot
import collections
import time


fig = plt.figure(figsize=(12,5))
ax = plt.axes(xlim=(0, 320), ylim=(-1.75, 1.75))
lineplot, = ax.plot([], [], "r-")
lineplot2, = ax.plot([], [], "b-")


import torch
from torch import nn

# torch.manual_seed(1)    # reproducible

# Hyper Parameters
TIME_STEP = 10      # rnn time step
INPUT_SIZE = 1      # rnn input size
LR = 0.02           # learning rate

# data
steps = np.linspace(0, np.pi*2, 100, dtype=np.float32)  # float32 for converting torch FloatTensor
x_np = np.sin(steps)
y_np = np.cos(steps)

class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.RNN(
            input_size=INPUT_SIZE,
            hidden_size=32,     # rnn hidden unit
            num_layers=1,       # number of rnn layer
            batch_first=True,   # input & output will have batch size as the 1st dimension, e.g. (batch, time_step, input_size)
        )
        self.out = nn.Linear(32, 1)

    def forward(self, x, h_state):
        # x (batch, time_step, input_size)
        # h_state (n_layers, batch, hidden_size)
        # r_out (batch, time_step, hidden_size)
        r_out, h_state = self.rnn(x, h_state)

        outs = []    # save all predictions
        for time_step in range(r_out.size(1)):    # calculate output for each time step
            outs.append(self.out(r_out[:, time_step, :]))
        return torch.stack(outs, dim=1), h_state

        # instead, for simplicity, you can replace above codes by follows
        # r_out = r_out.view(-1, 32)
        # outs = self.out(r_out)
        # outs = outs.view(-1, TIME_STEP, 1)
        # return outs, h_state
        
        # or even simpler, since nn.Linear can accept inputs of any dimension 
        # and returns outputs with same dimension except for the last
        # outs = self.out(r_out)
        # return outs

rnn = RNN()
print(rnn)

optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)   # optimize all rnn parameters
loss_func = nn.MSELoss()

h_state = None      # for initial hidden state

data = collections.defaultdict(list)
    
def init():
    global h_state, data 
    lineplot.set_data([], [])
    lineplot2.set_data([], [])
    data = collections.defaultdict(list)
    return lineplot, lineplot2 # returning a list, e.g. [lineplot, lineplot2], also works like in https://nbviewer.org/github/raphaelquast/jupyter_notebook_intro/blob/master/jupyter_nb_introduction.ipynb#pre-render-animations-and-export-to-HTML

def animate(i):
    global h_state, data 
    step = i
    start, end = step * np.pi, (step+1)*np.pi   # time range
    # use sin predicts cos
    steps = np.linspace(start, end, TIME_STEP, dtype=np.float32, endpoint=False)  # float32 for converting torch FloatTensor
    x_np = np.sin(steps)
    y_np = np.cos(steps)

    x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])    # shape (batch, time_step, input_size)
    y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])

    prediction, h_state = rnn(x, h_state)   # rnn output
    # !! next step is important !!
    h_state = h_state.data        # repack the hidden state, break the connection from last iteration

    loss = loss_func(prediction, y)         # calculate loss
    optimizer.zero_grad()                   # clear gradients for this training step
    loss.backward()                         # backpropagation, compute gradients
    optimizer.step()                        # apply gradients

    # plotting
    data['steps'].append(list(steps))
    data['r'].append(y_np.flatten())
    data['b'].append(prediction.data.numpy().flatten())
    #lineplot.set_data([x], [y])
    #lineplot2.set_data([x], [z])
    lineplot.set_data(data["steps"],data["r"])
    lineplot2.set_data(data["steps"],data["b"])
    '''
    for i,_ in enumerate(data_dict["steps"]):
        plt.plot(data_dict["steps"][i], list(data_dict["r"][i]) , 'r-', )
        plt.plot(data_dict["steps"][i], list(data_dict["b"][i]) , 'b-', )
    '''
    return [lineplot, lineplot2] # return every updated artist so blitting redraws both lines

anim = animation.FuncAnimation(fig, animate, init_func=init,
                           frames=100, interval=20, blit=True)
anim

Manually scrubbing back and forth with the slider allows you to choose a point in the building of the plot.

Note that, for some reason, I've seen the widget's normal looping ability break. I'm not sure what I did to break it, or what makes the one above stop looping; adding to `init()` didn't seem to help. Manually scrubbing back and forth with the slider still worked even when that happened. Oddly, when it broke, **it would also break looping for the simpler one below**.

I checked that the glitch was not simply due to including multiple lines, because this related, simpler code keeps looping fine IN A SEPARATE, or new, NOTEBOOK. (It may or may not loop here after running the one above.)

In [None]:
# If upon first running, it shows a non-interactive, single static shot of the plot below the interactive one with the widget controller,
# JUST RE-RUN TWICE. Re-run usually fixes that display quirk.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML, display
plt.rcParams["animation.html"] = "jshtml"
plt.ioff() #needed so the second time you run it you get only single plot

fig = plt.figure()
ax = plt.axes(xlim=(0, 4), ylim=(-2, 2))
lineplot, = ax.plot([], [], lw=3)
lineplot2, = ax.plot([], [], "r-")
    
def init():
    lineplot.set_data([], [])
    return lineplot, #return [lineplot] also works like in https://nbviewer.org/github/raphaelquast/jupyter_notebook_intro/blob/master/jupyter_nb_introduction.ipynb#pre-render-animations-and-export-to-HTML

def animate(i):
    x = np.linspace(0, 4, 1000)
    y = np.sin(2 * np.pi * (x - 0.01 * i))
    z = np.sin(2.2 * np.pi * (x - 0.31 * i))
    lineplot.set_data([x], [y])
    lineplot2.set_data([x], [z])
    return [lineplot,lineplot2]

anim = animation.FuncAnimation(fig, animate, init_func=init,
                           frames=200, interval=20, blit=True)
anim
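About the looping glitch described above: the JavaScript player's starting loop option comes from the `default_mode` argument of `Animation.to_jshtml()`. When it isn't given, Matplotlib picks 'loop' if the animation was created with `repeat=True` and 'once' otherwise. I haven't confirmed that this explains the glitch, but one thing that may be worth trying is asking for loop mode explicitly, as sketched here for the two-line example above.

```python
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display
from matplotlib import animation

fig, ax = plt.subplots()
ax.set_xlim(0, 4)
ax.set_ylim(-2, 2)
lineplot, = ax.plot([], [], lw=3)
lineplot2, = ax.plot([], [], "r-")
x = np.linspace(0, 4, 1000)


def animate(i):
    lineplot.set_data(x, np.sin(2 * np.pi * (x - 0.01 * i)))
    lineplot2.set_data(x, np.sin(2.2 * np.pi * (x - 0.31 * i)))
    return lineplot, lineplot2


anim = animation.FuncAnimation(fig, animate, frames=20, interval=20,
                               blit=True, repeat=True)
# default_mode must be one of 'loop', 'once', or 'reflect'
html = anim.to_jshtml(default_mode="loop")
plt.close(fig)
display(HTML(html))
```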

More notes on the approach combining the widget player controller with `animation.FuncAnimation()`:  
Sometimes `plt.ioff()` isn't needed; however, if you see an empty plot or the final frame show up in addition to the one with the slider, try including it.
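Since Matplotlib 3.4, `plt.ioff()` also works as a context manager, so interactive mode can be turned off just while the figure is created and restored afterwards, instead of staying off for the rest of the session. A sketch, assuming Matplotlib 3.4 or newer and the jshtml setting used above:

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

plt.rcParams["animation.html"] = "jshtml"

was_interactive = plt.isinteractive()
# Interactive mode is off only inside this block, so creating the figure
# doesn't trigger an extra static (or empty) plot display
with plt.ioff():
    fig, ax = plt.subplots()
    ax.set_xlim(0, 4)
    ax.set_ylim(-2, 2)
    lineplot, = ax.plot([], [], lw=3)

x = np.linspace(0, 4, 1000)


def animate(i):
    lineplot.set_data(x, np.sin(2 * np.pi * (x - 0.01 * i)))
    return lineplot,


anim = animation.FuncAnimation(fig, animate, frames=20, interval=20, blit=True)
anim  # as the last line of a notebook cell, this displays the jshtml player
```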

Unlike the earlier methods demonstrated in this notebook, which play through and don't let you 'pause' at specific points, the widget controller **lets you 'scrub' back and forth to choose points to highlight**. Importantly, **the animation remains playable when the notebook is viewed in a static render at nbviewer**, as [this static view of that notebook](https://nbviewer.org/gist/fomightez/d862333d8eefb94a74a79022840680b1) demonstrates; unlike the animations produced by the earlier methods, there is no need for an actively running notebook. (GitHub's notebook viewer does not presently support that; you must use [nbviewer](https://nbviewer.jupyter.org/).)

A variation on the process that produces the widget controller can also produce a portable HTML5 video animation file that can be embedded elsewhere. Indeed, a related [example here](https://stackoverflow.com/a/70764815/8508004) is set up to make such an HTML5 video and has a line that can be uncommented for saving a file version of the video. Making the HTML5 video is also covered in [a post by Louis Tiao](https://web.archive.org/web/20230330131423/http://louistiao.me/posts/notebooks/embedding-matplotlib-animations-in-jupyter-as-interactive-javascript-widgets/). (Presently, ffmpeg is not installed, so trying to run [the current code](https://stackoverflow.com/a/70764815/8508004) in sessions here results in `RuntimeError: Requested MovieWriter (ffmpeg) not available`. The last three lines of that code can be deleted, and then the animation will be shown with no widget. Alternatively, running `%conda install ffmpeg` in the active session and then restarting the kernel will allow that notebook to save a portable file version of the video.)  
Regarding running the animation with no widget, I have also made comments below an `animation.FuncAnimation()` example [describing how to run that answer code](https://stackoverflow.com/questions/75389311/plotting-a-live-graph-using-matplotlib#comment133024683_75389468) in this session or even in Spyder (using Qt).
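If ffmpeg isn't available, Matplotlib's Pillow-based writer can still save a portable animated GIF, because Pillow is already a Matplotlib dependency. (`to_html5_video()` and MP4 output do need ffmpeg.) A sketch that checks for ffmpeg with `animation.writers.is_available()` and falls back to a GIF otherwise; the file names are placeholders.

```python
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

fig, ax = plt.subplots()
ax.set_xlim(0, 10)
ax.set_ylim(-1.1, 1.1)
line, = ax.plot([], [], color="k")
x = np.linspace(0, 10, 40)


def update(num):
    line.set_data(x[:num], np.sin(x[:num]))
    return line,


ani = animation.FuncAnimation(fig, update, frames=len(x), interval=50)

if animation.writers.is_available("ffmpeg"):
    out_path = "sine_animation.mp4"  # placeholder file name
    ani.save(out_path, writer="ffmpeg", fps=20)
else:
    # Pillow ships with Matplotlib, so this works without ffmpeg
    out_path = "sine_animation.gif"  # placeholder file name
    ani.save(out_path, writer=animation.PillowWriter(fps=20))
plt.close(fig)
print(f"Saved {out_path} ({os.path.getsize(out_path)} bytes)")
```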

Note: if you are seeing evidence of overdrawing, such as "thick lines and 'blocky' texts" or "distorted tick labels", when using `ArtistAnimation` or `FuncAnimation`, see [here](https://stackoverflow.com/q/65654880/8508004).
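One possible source of overdrawing, though not necessarily the one in the linked question, is calling `ax.plot()` inside the update function, as the random-walk examples above do: every frame stacks a new line, with its markers, on top of the earlier ones. A minimal sketch of the alternative, assuming one `Line2D` per trial is enough: the lines are created once, up front, and each frame only updates their data with `set_data()`. The seeded `rng` and precomputed axis limits are additions for reproducibility and stable framing.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

rng = np.random.default_rng(0)
stepsize, num_steps, num_trials = 0.5, 20, 5
# each trial: cumulative sum of random 2-D steps, shape (num_steps, 2)
paths = [np.cumsum(rng.normal(0, stepsize, (num_steps, 2)), axis=0)
         for _ in range(num_trials)]

fig, ax = plt.subplots(figsize=(10, 6))
fig.subplots_adjust(left=0.1, right=0.85)
all_points = np.concatenate(paths)
ax.set_xlim(all_points[:, 0].min() - 1, all_points[:, 0].max() + 1)
ax.set_ylim(all_points[:, 1].min() - 1, all_points[:, 1].max() + 1)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.grid(True)
cmap = plt.get_cmap("tab10")

# create one (initially empty) line per trial, once, before animating
lines = [ax.plot([], [], color=cmap(t % 10), ls="-", lw=0.5, marker="o",
                 ms=8, mec="k", zorder=t, label=f"Trial = {t + 1}")[0]
         for t in range(num_trials)]
ax.legend(fontsize=10, loc="upper left", bbox_to_anchor=(1, 1))


def animate(frame):
    trial_num, step_num = divmod(frame, num_steps)
    path = paths[trial_num][: step_num + 1]
    # update an existing artist instead of adding a new one every frame
    lines[trial_num].set_data(path[:, 0], path[:, 1])
    ax.set_title(f"Number of trials = {trial_num + 1}\n"
                 f"Number of steps = {step_num + 1}")
    return lines


ani = FuncAnimation(fig, animate, frames=num_steps * num_trials,
                    interval=100, repeat=False)
```

One side effect of this design is that the legend lists every trial from the first frame, rather than adding each one as it finishes.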