In [58]:
import torch
import numpy as np
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import cg
from IPython.display import Image

## basics

求解线性方程组  $Ax=b$ 的**迭代**方法。它特别适合于大型稀疏矩阵的求解，并且当 $A$ 为对称正定矩阵时，共轭梯度法特别有效。
- $x=A^{-1}b$
- 正定矩阵：$x^TAx\geq 0$

TRPO 的关系

- Hvp：hessian vector product

In [35]:
# https://zh.wikipedia.org/wiki/%E5%85%B1%E8%BD%AD%E6%A2%AF%E5%BA%A6%E6%B3%95
Image(url='../../imgs/cg.png', width=400)

In [52]:
def conjugate_gradient(A, b, delta=0., max_iterations=float('inf')):
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()

    i = 0
    while i < max_iterations:
        AVP = A @ p

        dot_old = r @ r
        alpha = dot_old / (p @ AVP)
        
        # 避免由于数值问题导致的分母过小
        if torch.abs(torch.dot(p, AVP)) < 1e-10:
            break
            
        x_new = x + alpha * p

        if (x - x_new).norm() <= delta:
            return x_new

        i += 1
        r = r - alpha * AVP

        beta = (r @ r) / dot_old
        p = r + beta * p

        x = x_new
    return x


### cases 1

In [51]:
# 定义一个对称正定矩阵 A
A = torch.tensor([[4., 1.], [1., 3.]], dtype=torch.float32)
b = torch.tensor([1., 2.], dtype=torch.float32)
    
x = conjugate_gradient(A, b, delta=1e-5, max_iterations=1000)
x

tensor([nan, nan])

In [49]:
A @ x

tensor([1., 2.])

In [45]:
A.inverse() @ b

tensor([0.0909, 0.6364])

### cases 2

In [53]:
M = torch.rand(9).reshape((3, 3))
A = M.T @ M

b = torch.rand(3)

x = conjugate_gradient(A, b, delta=1e-5, max_iterations=1000)
x

tensor([-1.9017,  0.9358,  1.7556])

In [54]:
A.inverse() @ b

tensor([-1.9017,  0.9358,  1.7556])

### cases 3

In [59]:
P = np.array([[4, 0, 1, 0],
              [0, 5, 0, 0],
              [1, 0, 3, 2],
              [0, 0, 2, 4]])
A = csc_matrix(P)
b = np.array([-1, -0.5, -1, 2])
x, exit_code = cg(A, b, atol=1e-5)
print(exit_code)

0


In [60]:
x

array([ 8.05079565e-17, -1.00000000e-01, -1.00000000e+00,  1.00000000e+00])

In [61]:
np.allclose(A.dot(x), b)

True

In [63]:
A = torch.tensor([[4, 0, 1, 0],
              [0, 5, 0, 0],
              [1, 0, 3, 2],
              [0, 0, 2, 4]], dtype=torch.float32)
b = torch.tensor([-1, -0.5, -1, 2], dtype=torch.float32)
x = conjugate_gradient(A, b, delta=1e-5, max_iterations=1000)

In [65]:
torch.allclose(A @ x, b)

True