In [1]:
import sympy as sp

In [2]:
A_ext = sp.Matrix([
 [ 5,  7, -5, -47],
 [ 0, -2,  2,  10],
 [-4, -8, -7,  63],
 [ 1,  1,  2,  -1],
 [ 2, -1,  2,  -4],
 [ 4,  1,  4,  -2]
])
A_ext

Matrix([
[ 5,  7, -5, -47],
[ 0, -2,  2,  10],
[-4, -8, -7,  63],
[ 1,  1,  2,  -1],
[ 2, -1,  2,  -4],
[ 4,  1,  4,  -2]])

In [3]:
B = A_ext[:, -1]
B

Matrix([
[-47],
[ 10],
[ 63],
[ -1],
[ -4],
[ -2]])

In [4]:
A = A_ext[:, :-1]
A

Matrix([
[ 5,  7, -5],
[ 0, -2,  2],
[-4, -8, -7],
[ 1,  1,  2],
[ 2, -1,  2],
[ 4,  1,  4]])

In [5]:
X = sp.Matrix(sp.symbols("x1:4"))
X

Matrix([
[x1],
[x2],
[x3]])

In [6]:
mse = A*X - B
mse = mse.T * mse / A.rows
mse = mse[0]
mse

(-2*x2 + 2*x3 - 10)**2/6 + (-4*x1 - 8*x2 - 7*x3 - 63)**2/6 + (x1 + x2 + 2*x3 + 1)**2/6 + (2*x1 - x2 + 2*x3 + 4)**2/6 + (4*x1 + x2 + 4*x3 + 2)**2/6 + (5*x1 + 7*x2 - 5*x3 + 47)**2/6

In [7]:
from collections import namedtuple

Result = namedtuple("Result", "argmin,min,steps,path")

def grad_decent(f, start, vars, step, eps=0.01):
  for_subs = lambda x: dict(zip(vars, x))
  d_f = sp.Matrix([sp.diff(f, v) for v in vars])

  X = sp.Matrix(start)
  f_prev = f.subs(for_subs(X))
  path = [(list(X), f_prev)]
  f_curr = None

  while True:
    X -= step * d_f.subs(for_subs(X))
    f_curr = f.subs(for_subs(X))
    
    if abs(f_curr - f_prev) < eps:
      break
    
    if f_curr < f_prev:
      path.append((list(X), f_curr))
      f_prev = f_curr
    else:
      break
  
  return Result(
      argmin=path[-1][0], 
      min=f.subs(for_subs(X)),
      steps=len(path) - 1,
      path=path
  )


In [8]:
argmin, min, steps, path = grad_decent(mse, (0, 0, 0), X, 0.01)

In [9]:
print(f"Got mse = {min:0.4f} on {argmin} in {steps} steps")

Got mse = 36.6162 on [-0.167388149990381, -6.89548865273941, -0.536104978582283] in 60 steps


In [10]:
A * sp.Matrix(argmin) - B

Matrix([
[0.575163573783669],
[ 2.71876734831424],
[-3.41380332804727],
[-7.13508675989435],
[ 9.48850239559408],
[-7.70946116703006]])

In [11]:
from pprint import pprint
pprint(path)

[([0, 0, 0], 6299/6),
 ([-1.68000000000000, -2.84000000000000, -0.680000000000000], 247.497333333333),
 ([-2.29346666666667, -4.10440000000000, -0.790000000000000], 104.486005072593),
 ([-2.47595688888889, -4.71219777777778, -0.722969777777778], 73.7744351503798),
 ([-2.48449883555556, -5.03898750814815, -0.620976468148148], 64.1857120851198),
 ([-2.42352395196049, -5.24020775715556, -0.530073773777778], 59.5591801838142),
 ([-2.33577437740421, -5.38153056800477, -0.461073818362403], 56.5686545495930),
 ([-2.23893438867603, -5.49162915212318, -0.412953715575168], 54.3188045027200),
 ([-2.14042833655631, -5.58365270715924, -0.381557545907986], 52.4834821824143),
 ([-2.04342438650519, -5.66404931755218, -0.362603262751765], 50.9155148661358),
 ([-1.94928823396929, -5.73624833862081, -0.352549335645413], 49.5397376262838),
 ([-1.85859827530033, -5.80223662841780, -0.348704491658409], 48.3141676628190),
 ([-1.77157404413591, -5.86325973173118, -0.349105210896942], 47.2131528644480),
 ([-1.