In [159]:
import numpy as np
import plotly.graph_objects as go

# Shared colour palette used by every figure below.
bg_color, line_color, grid_color = "#FAF9F6", "#121212", "#D0D0D0"
blue = "#3399FF"

# L is the length constant that appears in the wave formula; x holds the
# 1000 sample positions spanning [0, L].
L = 10
x = np.linspace(0, L, num=1000)


def u(x, t, n):
    """Evaluate sin(n*pi*x/L) * cos(n*pi*t/L) for mode number ``n``.

    ``L`` is read from module scope. ``x`` and ``t`` are handed directly to
    NumPy ufuncs, so whatever those accept (scalars, arrays, meshgrids) works.
    """
    spatial_part = np.sin(n * np.pi * x / L)
    temporal_part = np.cos(n * np.pi * t / L)
    return spatial_part * temporal_part


def plot_wave_2d(function, title):
    """Display an animated line plot of ``function(x, t)`` with a time slider.

    ``function`` is called as ``function(x, t)`` using the module-level ``x``
    array and each of 250 time values evenly spaced over ``[0, 2 * L]``.
    One animation frame is created per time value and each slider step
    targets the frame with the matching name. The figure is shown with
    ``fig.show()``; nothing is returned.
    """
    t_values = np.linspace(0, 2 * L, 250)

    fig = go.Figure()

    # Curve visible before the slider is moved: the wave at the first time.
    fig.add_trace(
        go.Scatter(
            x=x,
            y=function(x, t_values[0]),
            mode="lines",
            line=dict(color=blue, width=3),
        )
    )

    # One frame per time value; the name is the key the slider steps refer to.
    fig.frames = [
        go.Frame(
            data=[
                go.Scatter(
                    x=x,
                    y=function(x, t),
                    mode="lines",
                    line=dict(color=blue, width=3),
                )
            ],
            name=f"t={t:.2f}",
        )
        for t in t_values
    ]

    # Each step jumps straight to its frame (no transition time).
    slider_steps = []
    for t in t_values:
        slider_steps.append(
            dict(
                args=[
                    [f"t={t:.2f}"],
                    dict(frame=dict(duration=0, redraw=True), mode="immediate"),
                ],
                label=f"{t:.1f}",
                method="animate",
            )
        )

    time_slider = dict(
        active=0,
        len=0.4,
        x=0,
        y=0,
        xanchor="left",
        yanchor="top",
        currentvalue=dict(prefix="t = ", visible=True, xanchor="right"),
        pad=dict(b=10, t=50),
        minorticklen=0,
        steps=slider_steps,
    )

    # Grid/zero-line colours shared by both axes.
    axis_style = dict(gridcolor=grid_color, zerolinecolor=line_color)

    fig.update_layout(
        title=title,
        xaxis_title="x",
        yaxis_title="y(x,t)",
        autosize=True,
        margin=dict(t=60, b=50, l=50, r=50),
        paper_bgcolor=bg_color,
        plot_bgcolor=bg_color,
        font=dict(family="monospace", size=14, color=line_color),
        xaxis=dict(axis_style),
        # Fixed vertical range so the axes do not rescale between frames.
        yaxis=dict(axis_style, range=[-2, 2]),
        sliders=[time_slider],
    )

    fig.show()


def plot_wave_3d(Z, title, x_grid=None, t_grid=None):
    """Display a 3D surface of the wave over space (x) and time (t).

    Parameters
    ----------
    Z : array-like
        Surface heights, passed as ``z`` to ``go.Surface``. In this notebook
        it is a wave function evaluated on the ``X, T`` meshgrid.
    title : str
        Figure title.
    x_grid, t_grid : array-like, optional
        Coordinates for the surface's x and y axes. When omitted, the
        module-level ``X`` and ``T`` are used, which is what this function
        always did before these parameters existed. Passing them explicitly
        avoids pairing a freshly computed ``Z`` with a stale global grid.

    The figure is shown with ``fig.show()``; nothing is returned.
    """
    # Fall back to the notebook-level meshgrid to stay compatible with
    # existing two-argument calls.
    if x_grid is None:
        x_grid = X
    if t_grid is None:
        t_grid = T

    fig = go.Figure(
        data=[
            go.Surface(
                z=Z,
                x=x_grid,
                y=t_grid,
                showscale=False,
                colorscale="Plasma",
            )
        ]
    )

    # All three scene axes share the same colours.
    axis_style = dict(
        gridcolor=grid_color,
        zerolinecolor=line_color,
        backgroundcolor=bg_color,
    )

    fig.update_layout(
        title=title,
        autosize=True,
        margin=dict(t=60, b=50, l=50, r=50),
        paper_bgcolor=bg_color,
        font=dict(
            family="monospace",
            size=14,
            color=line_color,
        ),
        scene=dict(
            bgcolor=bg_color,
            xaxis_title="x",
            yaxis_title="t",
            zaxis_title="y(x,t)",
            xaxis=dict(axis_style),
            yaxis=dict(axis_style),
            zaxis=dict(axis_style),
        ),
    )

    fig.show()

In [160]:
def u1(x, t):
    """Return the n=1 mode, i.e. ``u(x, t, 1)``."""
    return u(x, t, n=1)


# Surface grid: 100 samples in space over [0, L] and 100 in time over
# [0, 2 * L]. The bounds are derived from L instead of repeating the
# literals 10 and 20, so they stay in step with the wave formula.
# Note: plot_wave_2d reads the module-level `x`, so rebinding it here means
# the 2D animation below is drawn on this 100-point grid.
x = np.linspace(0, L, 100)
t = np.linspace(0, 2 * L, 100)
X, T = np.meshgrid(x, t)
Z = u1(X, T)

plot_wave_2d(u1, "Harmonic n=1")
plot_wave_3d(Z, "Harmonic n=1 over space and time")

In [161]:
def u2(x, t):
    """Return the n=2 mode, i.e. ``u(x, t, 2)``."""
    return u(x, t, n=2)


def u3(x, t):
    """Return the n=3 mode, i.e. ``u(x, t, 3)``."""
    return u(x, t, n=3)


def superposition(x, t):
    """Return the sum of the n=2 and n=3 modes at ``(x, t)``."""
    second_mode = u2(x, t)
    third_mode = u3(x, t)
    return second_mode + third_mode

# Surface grid: 100 samples in space over [0, L] and 100 in time over
# [0, 2 * L]. The bounds are derived from L instead of repeating the
# literals 10 and 20, so they stay in step with the wave formula.
# Note: plot_wave_2d reads the module-level `x`, so the 2D animation below
# is drawn on this 100-point grid.
x = np.linspace(0, L, 100)
t = np.linspace(0, 2 * L, 100)
X, T = np.meshgrid(x, t)
Z = superposition(X, T)

plot_wave_2d(superposition, "Superposition of harmonics n=2 and n=3")
plot_wave_3d(Z, "Superposition of harmonics n=2 and n=3 over space and time")

In [164]:
def plot_wave_2d(functions, labels, colors, title):
    """Animate several wave functions on shared axes with a time slider.

    NOTE: this definition replaces the two-argument
    ``plot_wave_2d(function, title)`` defined earlier in the notebook; calls
    written for that signature fail once this cell has run.

    Parameters
    ----------
    functions : iterable of callables
        Each is called as ``f(x, t)`` with the module-level ``x`` array and
        each of 250 time values evenly spaced over ``[0, 2 * L]``.
    labels : iterable of str
        Legend entry for each function (same order as ``functions``).
    colors : iterable of str
        Line colour for each function (same order as ``functions``).
    title : str
        Figure title.

    Raises
    ------
    ValueError
        If there are fewer labels or colours than functions. Previously the
        surplus functions were silently dropped by ``zip``. Extra labels or
        colours are still ignored.

    The figure is shown with ``fig.show()``; nothing is returned.
    """
    functions = list(functions)
    labels = list(labels)
    colors = list(colors)
    if len(labels) < len(functions) or len(colors) < len(functions):
        raise ValueError(
            "plot_wave_2d needs a label and a color for every function: got "
            f"{len(functions)} functions, {len(labels)} labels, {len(colors)} colors"
        )
    specs = list(zip(functions, labels, colors))

    t_values = np.linspace(0, 2 * L, 250)
    # Frame names are the keys slider steps use to address frames; build them
    # once so frames and steps cannot drift apart.
    frame_names = [f"t={t:.2f}" for t in t_values]

    # Evaluate every curve once per time step. The first entry doubles as the
    # initial traces, and the full set determines the vertical range.
    curves = [[function(x, t) for function, _, _ in specs] for t in t_values]

    # Largest finite |y| over all frames. The axis keeps the original
    # [-2, 2] window unless a curve would be clipped by it (a sum of three
    # unit-amplitude harmonics reaches about 2.5), in which case it grows
    # with 10% headroom. The range stays fixed across frames.
    peak = 0.0
    for frame_curves in curves:
        for y in frame_curves:
            magnitudes = np.abs(np.atleast_1d(np.asarray(y)))
            finite = magnitudes[np.isfinite(magnitudes)]
            if finite.size:
                peak = max(peak, float(finite.max()))
    y_limit = max(2.0, 1.1 * peak)

    def make_traces(frame_curves):
        # One line per function; `name` creates the clickable legend entry.
        return [
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=label,
                line=dict(color=color, width=3),
            )
            for y, (_, label, color) in zip(frame_curves, specs)
        ]

    fig = go.Figure(data=make_traces(curves[0]))

    fig.frames = [
        go.Frame(data=make_traces(frame_curves), name=frame_name)
        for frame_curves, frame_name in zip(curves, frame_names)
    ]

    fig.update_layout(
        title=title,
        xaxis_title="x",
        yaxis_title="y(x,t)",
        autosize=True,
        margin=dict(t=60, b=50, l=50, r=50),
        paper_bgcolor=bg_color,
        plot_bgcolor=bg_color,
        font=dict(
            family="monospace",
            size=14,
            color=line_color,
        ),
        xaxis=dict(
            gridcolor=grid_color,
            zerolinecolor=line_color,
        ),
        yaxis=dict(
            gridcolor=grid_color,
            zerolinecolor=line_color,
            range=[-y_limit, y_limit],
        ),
        showlegend=True,  # legend entries can be clicked to toggle traces
        legend=dict(
            x=1.02,
            y=1,
            xanchor="left",
            yanchor="top",
        ),
        sliders=[
            dict(
                active=0,
                len=0.4,
                x=0,
                y=0,
                xanchor="left",
                yanchor="top",
                currentvalue=dict(prefix="t = ", visible=True, xanchor="right"),
                pad=dict(b=10, t=50),
                minorticklen=0,
                steps=[
                    dict(
                        args=[
                            [frame_name],
                            dict(frame=dict(duration=0, redraw=True), mode="immediate"),
                        ],
                        label=f"{t:.1f}",
                        method="animate",
                    )
                    for t, frame_name in zip(t_values, frame_names)
                ],
            ),
        ],
    )

    fig.show()


# Usage example:
def f1(x, t):
    """Return sin(n*pi*x/L) * cos(n*pi*t/L) with n = 1 (``L`` from module scope)."""
    n = 1
    spatial_part = np.sin(n * np.pi * x / L)
    temporal_part = np.cos(n * np.pi * t / L)
    return spatial_part * temporal_part


def f2(x, t):
    """Return sin(n*pi*x/L) * cos(n*pi*t/L) with n = 2 (``L`` from module scope)."""
    n = 2
    spatial_part = np.sin(n * np.pi * x / L)
    temporal_part = np.cos(n * np.pi * t / L)
    return spatial_part * temporal_part


def f3(x, t):
    """Return sin(n*pi*x/L) * cos(n*pi*t/L) with n = 3 (``L`` from module scope)."""
    n = 3
    spatial_part = np.sin(n * np.pi * x / L)
    temporal_part = np.cos(n * np.pi * t / L)
    return spatial_part * temporal_part

def superposition(x, t):
    """Return ``f1 + f2 + f3`` evaluated at ``(x, t)``.

    This rebinds the name ``superposition``, which an earlier cell defined as
    the sum of the n=2 and n=3 modes only.
    """
    # NOTE(review): the n=1 term is included here, but the plot call below
    # only draws n=2 and n=3 next to this curve - confirm that is intended.
    total = f1(x, t)
    total = total + f2(x, t)
    total = total + f3(x, t)
    return total

# Draw the n=2 and n=3 modes together with the combined curve. The three
# lists are parallel: entry i of each describes the same trace.
plot_wave_2d(
    [f2, f3, superposition],
    ["n=2", "n=3", "superposition"],
    ["#3399FF", "#FF3399", "#33FF99"],
    "Wave Functions",
)