In [1]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [2]:
def get_slope_line(m, x0, y0, x):
    """Evaluate the straight line of slope ``m`` through the point ``(x0, y0)``.

    Parameters
    ----------
    m : slope of the line.
    x0, y0 : coordinates of a point the line passes through.
    x : abscissa at which to evaluate the line; numpy arrays broadcast.

    Returns
    -------
    ``m * (x - x0) + y0``, the line's value(s) at ``x``.
    """
    rise = m * (x - x0)
    return rise + y0

In [4]:
#| label: legendre_u_to_f

# Animated Legendre transform U(S) -> F(T) = U - T*S for the toy model
# U(S) = S**2 / 2 + 1.  Left panel: U(S) with its tangent at S = s; the
# tangent's slope is T = dU/dS and its value at S = 0 is F.  Right panel:
# the point (T, F) and the F(T) curve traced out from S_MIN up to s.

# -- Configuration (every tunable number lives here) --
S_MIN, S_MAX = 0.0, 2.0     # S-interval drawn on the left panel and swept by the slider
Y_RANGE = [-2, 4]           # shared vertical axis range of both panels
N_CURVE_POINTS = 100        # samples used for the U(S) and F(T) curves
N_FRAMES = 51               # number of animation frames / slider stops


def U_of_S(S):
    """Toy internal energy U(S) = S**2 / 2 + 1 (scalars or numpy arrays)."""
    return 0.5 * S**2 + 1


def dU_dS(S):
    """Derivative of `U_of_S`, i.e. the temperature T = dU/dS = S."""
    return S


def _dynamic_trace_data(s):
    """Return the (x, y) data of the five s-dependent traces for the state S = s.

    The list order matches the order in which the dynamic traces are added to
    `fig` below: red point (s, U), tangent line, green F-intercept,
    orange point (T, F), and the F(T) curve traced from S_MIN up to s.
    """
    u = U_of_S(s)
    t = dU_dS(s)            # slope of the tangent at (s, u)
    f = u - t * s           # Legendre transform; equals the tangent's value at S = 0
    tangent_x = [S_MIN, S_MAX]
    tangent_y = [get_slope_line(t, s, u, x) for x in tangent_x]
    # Parametrise the F(T) curve by S so it stays correct for any U(S),
    # not only for the linear T(S) of this toy model.
    S_path = np.linspace(S_MIN, s, N_CURVE_POINTS)
    T_path = dU_dS(S_path)
    F_path = U_of_S(S_path) - T_path * S_path
    return [
        ([s], [u]),                 # red point on U(S)
        (tangent_x, tangent_y),     # tangent line across the S-range
        ([0.0], [f]),               # green intercept of the tangent with S = 0
        ([t], [f]),                 # orange point (T, F)
        (T_path, F_path),           # F(T) traced so far
    ]


s_values = np.linspace(S_MIN, S_MAX, N_FRAMES)
frame_names = [f"s={s:.2f}" for s in s_values]
# The slider addresses frames by name, so two s-values rounding to the same
# label would silently make one of them unreachable.
assert len(set(frame_names)) == len(frame_names), "frame names collide; lower N_FRAMES"

# -- Figure with two side-by-side panels --
fig = make_subplots(
    rows=1, cols=2,
    subplot_titles=("U vs. S", "F vs. T"),
    horizontal_spacing=0.1,
)

# Trace 0: the U(S) curve. It never changes with s, so frames leave it alone.
S_fixed = np.linspace(S_MIN, S_MAX, N_CURVE_POINTS)
U_curve = U_of_S(S_fixed)
fig.add_trace(
    go.Scatter(x=S_fixed, y=U_curve, mode="lines", line=dict(color="blue"), showlegend=False),
    row=1, col=1,
)

# s-dependent traces, in the order returned by _dynamic_trace_data.
# Each is seeded with the first frame's data so the figure already matches the
# slider's initial position (active=0) before the user touches it.
dynamic_trace_specs = [
    # (subplot column, trace style)
    (1, dict(mode="markers", marker=dict(color="red", size=8), name="(s, U)")),
    (1, dict(mode="lines", line=dict(color="orange"), name="Tangent line")),
    (1, dict(mode="markers+text", text=["F"], textposition="top center",
             marker=dict(color="green", size=8), name="F-intercept")),
    (2, dict(mode="markers", marker=dict(color="orange", size=8), name="(t, F)")),
    (2, dict(mode="lines", line=dict(color="green"), name="F(T)")),
]
first_dynamic_trace = len(fig.data)
for (col, style), (x, y) in zip(dynamic_trace_specs, _dynamic_trace_data(s_values[0])):
    fig.add_trace(go.Scatter(x=x, y=y, showlegend=False, **style), row=1, col=col)
dynamic_trace_ids = list(range(first_dynamic_trace, len(fig.data)))

# -- One frame per s-value --
# `traces=` maps each frame's data onto the dynamic traces only; styling is
# inherited from the traces above, so frames carry just x/y.
frames = [
    go.Frame(
        data=[go.Scatter(x=x, y=y) for x, y in _dynamic_trace_data(s)],
        traces=dynamic_trace_ids,
        name=name,
    )
    for s, name in zip(s_values, frame_names)
]
fig.frames = frames

# -- Slider: one step per frame, jumping there without animation --
slider_steps = [
    dict(
        method="animate",
        args=[
            [name],  # frame name
            dict(mode="immediate", frame=dict(duration=0, redraw=True), transition=dict(duration=0)),
        ],
        label=f"{s:.2f}",
    )
    for s, name in zip(s_values, frame_names)
]
sliders = [
    dict(
        active=0,
        currentvalue={"prefix": "s = "},
        pad={"t": 50},
        steps=slider_steps,
    )
]

# -- Axes and layout --
fig.update_xaxes(range=[S_MIN, S_MAX], title_text="S", title_font=dict(color="red"), row=1, col=1)
fig.update_yaxes(range=Y_RANGE, title_text="U", title_font=dict(color="blue"), row=1, col=1)
fig.update_xaxes(range=[dU_dS(S_MIN), dU_dS(S_MAX)], title_text="T",
                 title_font=dict(color="orange"), row=1, col=2)
fig.update_yaxes(range=Y_RANGE, title_text="F", title_font=dict(color="green"), row=1, col=2)

fig.update_layout(
    width=1000,
    height=600,
    showlegend=False,  # every trace hides its legend entry; points are identified by axis colours
    sliders=sliders,
    template="simple_white",
)

fig.show()