In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import comb
import ipywidgets as widgets
from ipywidgets import interact

In [None]:
# ------------------------------------------------------------
# Bernstein basis
# ------------------------------------------------------------
def bernstein_poly(n, i, t):
    """
    Evaluate the Bernstein polynomial B_i^n(t) = C(n, i) * t**i * (1 - t)**(n - i).

    ``i`` and ``t`` may each be a scalar or a NumPy array; array arguments
    broadcast against each other elementwise.
    """
    binomial = comb(n, i)
    rising = t ** i
    falling = (1 - t) ** (n - i)
    # Same left-to-right product order as the closed-form definition.
    return binomial * rising * falling


def bernstein_basis(n, t):
    """
    Returns the full Bernstein basis matrix:
    shape = (len(t), n+1) where column i is B_i^n(t).

    Parameters
    ----------
    n : int
        Polynomial degree; the basis contains n + 1 functions.
    t : float or array_like
        Parameter value(s). Lists and scalars are accepted: the input is
        converted to a flat float array, so a scalar yields a (1, n+1) matrix.

    Returns
    -------
    B : (len(t), n+1) ndarray
        B[j, i] = B_i^n(t[j]).
    """
    # asarray makes plain lists usable with the [:, None] indexing below;
    # reshape(-1) turns a scalar into a length-1 vector instead of failing
    # on 0-d indexing.
    t = np.asarray(t, dtype=float).reshape(-1)
    i = np.arange(n + 1)
    # Broadcast (1, n+1) degrees against (m, 1) parameters -> (m, n+1).
    return bernstein_poly(n, i[None, :], t[:, None])


# ------------------------------------------------------------
# Interactive plot: Bernstein basis polynomials
# ------------------------------------------------------------
def plot_bernstein(n):
    """
    Draw every Bernstein basis polynomial B_{n,i}(t), i = 0..n, on [0, 1].

    Intended as the slider callback: each call samples 200 points, builds a
    fresh 12x7 figure with one labelled curve per basis function, and shows it.
    """
    samples = np.linspace(0, 1, 200)
    basis = bernstein_basis(n, samples)

    _, ax = plt.subplots(figsize=(12, 7))
    # Each column of the basis matrix is one polynomial sampled over t.
    for idx, column in enumerate(basis.T):
        ax.plot(samples, column, label=fr"$B_{{{n},{idx}}}(t)$")

    ax.set_title(f"Bernstein Basis Polynomials (n={n})")
    ax.set_xlabel("t")
    ax.set_ylabel(r"$B_{n,i}(t)$")
    ax.grid(True)
    # Legend sits outside the axes so it never covers the curves.
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.show()


# Degree slider: integer n in [1, 10], starting at a cubic basis (n = 3).
interact(
    plot_bernstein,
    n=widgets.IntSlider(min=1, max=10, step=1, value=3),
);


# ------------------------------------------------------------
# Bézier curve evaluation using Bernstein basis
# ------------------------------------------------------------
def bezier_curve(control_points, t):
    """
    Evaluates a Bézier curve at parameter values t using Bernstein basis.

    Parameters
    ----------
    control_points : (n+1, d) array_like
        Control points, one per row; the curve degree is n = len - 1.
    t : (m,) array_like in [0,1], or a scalar
        Parameter values. A scalar is treated as a length-1 array, so the
        outputs then have m = 1.

    Returns
    -------
    curve : (m, d) array
        curve[j] = sum_i B[j, i] * control_points[i].
    B : (m, n+1) Bernstein weight matrix

    Raises
    ------
    ValueError
        If control_points is empty (no degree can be defined, and the
        weighted sum would silently be all zeros).
    """
    control_points = np.asarray(control_points, dtype=float)
    # atleast_1d lets a single scalar parameter be evaluated directly.
    t = np.atleast_1d(np.asarray(t, dtype=float))

    if len(control_points) == 0:
        raise ValueError("bezier_curve requires at least one control point")

    n = len(control_points) - 1
    B = bernstein_basis(n, t)  # (m, n+1)
    curve = B @ control_points  # (m, d)

    return curve, B


# ------------------------------------------------------------
# Example: Curve + Bernstein visualization
# ------------------------------------------------------------
# Control points of the example cubic Bézier curve, one (x, y) row each.
cps = np.array(
    [
        [0, 0],
        [3, 2],
        [6, 2],
        [8, 0],
    ],
    dtype=float,
)
curve_degree = cps.shape[0] - 1

# Dense parameter samples, plus the curve and its Bernstein weights there.
t = np.linspace(0, 1, 200)
curve, B = bezier_curve(cps, t)

# One parameter value highlighted in both plots below: its basis weights
# and the corresponding point on the curve (weighted sum of control points).
t_selected = 0.30
B_selected = bernstein_poly(curve_degree, np.arange(curve_degree + 1), t_selected)
point_selected = B_selected @ cps


# ------------------------------------------------------------
# Plot 1: Bézier curve and control polygon
# ------------------------------------------------------------
fig_curve, ax_curve = plt.subplots(figsize=(12, 7))
# Control polygon first, then the curve it defines, then the highlighted point.
ax_curve.plot(cps[:, 0], cps[:, 1], "ro--", label="Control Polygon")
ax_curve.plot(curve[:, 0], curve[:, 1], "b-", label="Bézier Curve")
ax_curve.scatter(
    point_selected[0],
    point_selected[1],
    color="g",
    s=100,
    label=f"Point at t={t_selected:.2f}",
)
ax_curve.set_xlabel("x")
ax_curve.set_ylabel("y")
ax_curve.grid(True)
ax_curve.legend()
ax_curve.set_title("Bézier Curve and Control Polygon")
plt.show()


# ------------------------------------------------------------
# Plot 2: Bernstein basis functions + values at t_selected
# ------------------------------------------------------------
# One distinct colour per basis function, shared by curve, marker and label.
colors = plt.cm.gist_rainbow(np.linspace(0, 1, curve_degree + 1))

fig_basis, ax_basis = plt.subplots(figsize=(12, 7))
for i, (color, weight) in enumerate(zip(colors, B_selected)):
    ax_basis.plot(t, B[:, i], label=fr"$B_{{{curve_degree},{i}}}(t)$", color=color)
    # Mark and annotate the weight this basis function contributes at t_selected.
    ax_basis.scatter(t_selected, weight, color=color, s=60)
    ax_basis.text(
        t_selected,
        weight,
        f"{weight:.2f}",
        verticalalignment="bottom",
        color=color,
    )

ax_basis.set_xlabel("t")
ax_basis.set_ylabel("Bernstein basis value")
ax_basis.grid(True)
ax_basis.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
ax_basis.set_title(f"Bernstein Basis Polynomials for Bézier Curve (n={curve_degree})")
plt.show()