In [1]:
%pip install bqplot

In [63]:
import math
import sympy
from sympy import Matrix, Symbol, lambdify, simplify, latex
from IPython.display import Math

def least_squares(points, n, debug=False):
    """Fit a polynomial with ``n`` coefficients to ``points`` by least squares.

    The model is f(x) = b_0 + b_1 x + ... + b_{n-1} x^{n-1}. The sum of squared
    residuals E(b_0, ..., b_{n-1}) is built symbolically and the coefficients
    are obtained by solving the linear system grad E = 0.

    Args:
        points: iterable of (x, y) pairs, e.g. a list of lists or the rows of a
            2-D numpy array.
        n: number of polynomial coefficients (degree + 1).
        debug: if True, render f(x), E and the solution as LaTeX via IPython's
            ``display`` (available as a builtin inside Jupyter).

    Returns:
        A pair ``(coefficients, fn)``: a sympy ``Tuple`` with the fitted values
        of b_0..b_{n-1}, and a numeric function of x produced by ``lambdify``.

    Raises:
        ValueError: if ``linsolve`` reports no stationary point.
    """
    # Symbolic coefficients b_0, ..., b_{n-1}
    parameters = [Symbol(f"b_{i}") for i in range(n)]

    # Model: f(x) = b_0 x^0 + b_1 x^1 + b_2 x^2 + ...
    x = Symbol("x")
    f_expr = sum((parameters[i] * x**i for i in range(n)), sympy.Integer(0))

    # Sum of squared residuals. Expanding after every term keeps E a flat
    # quadratic polynomial in the b_i, so its size stays bounded no matter how
    # many points there are; expand is much cheaper than simplify for this.
    error_squares = sympy.Integer(0)
    for X, Y in points:
        residual = sympy.sympify(Y) - f_expr.subs(x, sympy.sympify(X))
        error_squares = sympy.expand(error_squares + residual**2)

    gradient = sympy.derive_by_array(error_squares, parameters)
    stationary_points = sympy.linsolve(list(gradient), parameters)

    # Render the derivation as LaTeX in Jupyter
    if debug: display(Math(f"""
        \\begin{{aligned}}
        f(x)&={latex(f_expr)} \\\\
        E &= \\sum_i \\left(y_i - f(x_i)\\right)^2 = {latex(error_squares)} \\\\
        \\nabla E &= \\mathbf 0 \\iff ({",".join([latex(p) for p in parameters])}) \\in {latex(stationary_points)}
        \\end{{aligned}}
    """))

    # E is a sum of squares of affine functions of the b_i, hence convex: every
    # stationary point is a global minimum, and linsolve returns a single
    # (possibly parametric) solution tuple, so no search over candidates is needed.
    solution = next(iter(stationary_points), None)
    if solution is None:
        raise ValueError("least_squares: no stationary point found")

    # Underdetermined fits (fewer distinct x values than coefficients) leave some
    # b_i free; pin them to 0 so the returned function is purely numeric.
    result = solution.subs({s: 0 for s in solution.free_symbols})
    result_f = f_expr.subs(dict(zip(parameters, result)))
    if debug: display(Math("\\text{solution: \\(f(x)=" + latex(result_f) + "\\)}"))
    return result, lambdify(x, result_f)

# Sanity check: the line through (1, 1) and (2, 2) is y = x, i.e. b_0 = 0, b_1 = 1.
least_squares(points=[(1, 1), (2, 2)], n=2, debug=True)

<IPython.core.display.Math object>

<IPython.core.display.Math object>

((0, 1), <function _lambdifygenerated(x)>)

In [None]:
from random import random
import bqplot.pyplot as plt
import numpy as np

# Interactive demo: fit a polynomial to random points; dragging a point refits it.
N_POINTS = 6      # number of random data points
N_COEFFS = 4      # polynomial coefficients (4 -> cubic fit)
N_SAMPLES = 50    # number of x samples used to draw the fitted curve

x = np.linspace(0, 1, N_SAMPLES)

# Random points in the unit square (unseeded, so each run differs)
points = np.array([[random() for _ in range(2)] for _ in range(N_POINTS)])

def compute(pts=None):
    """Fit ``least_squares`` to ``pts`` and sample the fitted polynomial on ``x``.

    Args:
        pts: sequence of (x, y) pairs; defaults to the module-level ``points``.

    Returns:
        List of fitted y-values, one per sample in the module-level ``x`` grid.
    """
    if pts is None:
        pts = points  # only reads the global, so no ``global`` statement is needed
    _, f = least_squares(pts, N_COEFFS)
    return [f(v) for v in x]

plt.figure(title="Calculating Least Squares")

# Draw the points as a scatter plot
points_plot = plt.scatter(points[:, 0], points[:, 1])

# Draw the fitted curve y = f(x)
fn_plot = plt.plot(x, compute())

plt.xlim(0, 1)
plt.ylim(0, 1)
plt.xlabel("x")
plt.ylabel("y")
plt.show()

def update(change):
    """Observer callback: re-read the point positions and refit the curve.

    ``change`` is the trait-change notification passed by ``observe``; it is
    unused because both coordinate arrays are read from ``points_plot``.
    """
    global points  # rebinds the module-level name that compute() defaults to
    points = np.column_stack([points_plot.x, points_plot.y])
    fn_plot.y = compute(points)

# Make the scatter points draggable; refit whenever their x or y changes.
points_plot.observe(update, ["x", "y"])
points_plot.enable_move = True

VBox(children=(Figure(axes=[Axis(scale=LinearScale(max=1.0, min=0.0)), Axis(orientation='vertical', scale=Line…