In [None]:
import dash
from dash import dcc, html, Input, Output
import plotly.express as px
import pandas as pd
from jax import vmap

# Radio selector that chooses which animation the callback renders.
animation_picker = dcc.RadioItems(
    options=["λ weights", "Prediction"],
    value="Prediction",
    id="selection",
)

# Graph placeholder wrapped in a loading indicator shown while the figure
# is being produced.
graph_panel = dcc.Loading(children=dcc.Graph(id="graph"), type="cube")

# Dash application for this cell, with a simple top-to-bottom layout:
# heading, prompt, selector, graph.
dash_app = dash.Dash()
dash_app.layout = html.Div(
    children=[
        html.H4("Animated self-adaptive weights and predictions"),
        html.P("Select an animation:"),
        animation_picker,
        graph_panel,
    ],
    style={"backgroundColor": "white"},
)

# Evaluate the trained model on the full (t, x) grid for one input function:
# the inner vmap sweeps the second argument (x), the outer vmap sweeps the
# third argument (t), so the leading axis of the result follows t.
predict_over_x = vmap(trainer.model, (None, 0, None))
predict_over_grid = vmap(predict_over_x, (None, None, 0))

# Long-format table for the first test sample, one row per (t, x) grid point.
# NOTE(review): assumes u_test[0] is stored t-major so it lines up with the
# flattened prediction grid and the tiled/repeated coordinates — confirm.
df = pd.DataFrame(
    dict(
        ground_truth=u_test[0].ravel(),
        pred=predict_over_grid(a_test[0], x, t).ravel(),
        x=jnp.tile(x, len(t)),
        t=jnp.repeat(t, len(x)),
    )
)

@dash_app.callback(Output("graph", "figure"), Input("selection", "value"))
def display_animated_graph(selection):
    """Build the animated figure for the currently selected radio option.

    Parameters
    ----------
    selection : str
        Value of the "selection" radio group: "λ weights" or "Prediction".

    Returns
    -------
    plotly.graph_objects.Figure
        The animated figure, with the play button's transition duration set
        to 0 when the figure has animation controls.

    Raises
    ------
    KeyError
        If ``selection`` is not one of the known options.
    """
    # Map each option to a builder so only the requested figure is created
    # (building both on every callback is wasted work).
    builders = {
        "λ weights": lambda: px.imshow(
            jnp.array(trainer.λ_list),
            animation_frame=0,
        ),
        "Prediction": lambda: px.line(
            df,
            x="x",
            y=["ground_truth", "pred"],
            animation_frame="t",
            # With a list passed as `y`, Plotly Express names the melted value
            # column "value"; that is the key needed to label the y-axis.
            labels={"t": "Time (t)", "value": "u(x, t)"},
        ),
    }
    fig = builders[selection]()

    # Plotly Express only adds the play/pause menu when there is more than
    # one frame, so guard the lookup instead of assuming it exists.
    if fig.layout.updatemenus:
        fig.layout.updatemenus[0].buttons[0].args[1]["transition"]["duration"] = 0

    return fig

# Start the development server. `run_server` is the legacy entry point and is
# no longer available in recent Dash releases, so prefer `run` when the
# installed version provides it and fall back to `run_server` otherwise.
(getattr(dash_app, "run", None) or dash_app.run_server)()

In [None]:
import plotly.express as px

# Standalone (non-Dash) animation of ground truth vs. prediction over time.
# Relies on `df` built in an earlier cell.
fig = px.line(
    df,
    x="x",
    y=["ground_truth", "pred"],
    animation_frame="t",
    # With a list passed as `y`, Plotly Express names the melted value column
    # "value"; that is the key needed to label the y-axis.
    labels={"t": "Time (t)", "value": "u(x, t)"},
)

# The play/pause menu only exists when there is more than one frame.
if fig.layout.updatemenus:
    fig.layout.updatemenus[0].buttons[0].args[1]["transition"]["duration"] = 0

# Draw the ground-truth curve dashed everywhere it appears: in the initial
# traces and in every animation frame. Matching by trace name in both places
# avoids depending on the order of the traces.
trace_groups = [fig.data, *(frame.data for frame in fig.frames)]
for traces in trace_groups:
    for trace in traces:
        if trace.name == "ground_truth":
            trace.line.dash = "dash"

fig.show()

In [None]:
def animate_predictions(a_test, u_test, x, t, model, *args):
    """Serve a Dash app that animates model predictions for a chosen test sample.

    The page offers a radio selector ("Predictions" / "Errors") and a numeric
    input for the sample index. "Predictions" overlays the ground truth and the
    model output as functions of ``x``, animated over ``t``; "Errors" shows the
    pointwise absolute difference between the two.

    Parameters
    ----------
    a_test : indexable
        Per-sample model inputs; ``a_test[i]`` is passed as the first argument
        of ``model``.
    u_test : indexable
        Per-sample reference solutions; ``u_test[i]`` is flattened and compared
        against the flattened predictions.
        NOTE(review): assumed to be stored t-major, shape (len(t), len(x)) —
        confirm against the data loader.
    x, t : 1-D arrays
        Spatial and temporal evaluation points.
    model : callable
        Called as ``model(a, x_i, t_j)`` through a nested ``vmap``.
    *args
        Accepted for backward compatibility; currently ignored.

    Returns
    -------
    None
        Returns whatever the blocking server call returns once it stops.
    """
    # Only indices valid for both arrays can be displayed.
    n_samples = min(len(a_test), len(u_test))

    dash_app = dash.Dash()
    dash_app.layout = html.Div(
        [
            html.H4("Animated predictions"),
            html.P("Select an animation:"),
            dcc.RadioItems(
                id="selection",
                options=["Predictions", "Errors"],
                value="Predictions",
            ),
            dcc.Input(
                id="sample",
                type="number",
                value=0,
                min=0,
                max=n_samples - 1,
                step=1,
                placeholder="Sample number",
            ),
            dcc.Loading(dcc.Graph(id="graph"), type="cube"),
        ],
        style={"backgroundColor": "white"},
    )

    # Inner vmap sweeps x (second argument), outer vmap sweeps t (third
    # argument), so the prediction grid has t on its leading axis.
    predict_over_grid = vmap(vmap(model, (None, 0, None)), (None, None, 0))

    @dash_app.callback(
        Output("graph", "figure"),
        [Input("selection", "value"), Input("sample", "value")],
    )
    def display_animated_graph(selection, sample):
        """Return the animated figure for the selected view and sample index."""
        # A cleared number box arrives as None; fractional or out-of-range
        # values cannot be used as an index. Keep the current figure instead
        # of raising inside the callback.
        if sample is None or sample != int(sample) or not 0 <= sample < n_samples:
            return dash.no_update
        sample = int(sample)

        truth = u_test[sample].flatten()
        preds = predict_over_grid(a_test[sample], x, t).flatten()

        df = pd.DataFrame({
            "ground_truth": truth,
            "pred": preds,
            "abs_error": jnp.abs(preds - truth),
            "x": jnp.tile(x, len(t)),
            "t": jnp.repeat(t, len(x)),
        })

        # Builders are lazy so only the requested figure is constructed.
        # NOTE: Plotly keeps the axis range of the first frame during an
        # animation; pass `range_y` here if later frames get clipped.
        builders = {
            "Predictions": lambda: px.line(
                df,
                x="x",
                y=["ground_truth", "pred"],
                animation_frame="t",
                # With a list passed as `y`, the melted value column is named
                # "value"; that key labels the y-axis.
                labels={"t": "Time (t)", "value": "u(x, t)"},
            ),
            "Errors": lambda: px.line(
                df,
                x="x",
                y="abs_error",
                animation_frame="t",
                labels={"t": "Time (t)", "abs_error": "|pred - ground truth|"},
            ),
        }
        fig = builders[selection]()

        # The play/pause menu only exists when there is more than one frame.
        if fig.layout.updatemenus:
            fig.layout.updatemenus[0].buttons[0].args[1]["transition"]["duration"] = 0

        return fig

    # `run_server` is the legacy entry point; prefer `run` when available.
    (getattr(dash_app, "run", None) or dash_app.run_server)()
    