### Conjugate gradient for linear regression
#### General objective as a quadratic form

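$ f(w) = \frac{1}{2} \sum_{n}(y_n - w^{T}x_n)^2 + \frac{1}{2}\lambda w^{T}w = \frac{1}{2}w^{T}Aw - w^{T}b + \frac{1}{2}Y^{T}Y $

Here $X$ is the design matrix whose $n$-th row is $x_n^T$, and $Y$ stacks the targets $y_n$. The constant $\frac{1}{2}Y^{T}Y$ does not depend on $w$, so it does not affect the minimizer. The matrix and vector are:

$ A = \lambda I + X^T X $

$ b = X^T Y $

A unique minimum exists if and only if $A$ is positive definite, which always holds when $\lambda > 0$. The minimizer is then the solution of $Aw = b$.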
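As a quick numerical check of these identities, here is a sketch on randomly generated data. The sizes `N` and `D`, the ridge strength `lam`, and the helper names `f_direct` and `f_quadratic` are illustrative assumptions, not part of the original notebook. The sketch builds $A$ and $b$, confirms that both forms of $f$ agree once the constant $\frac{1}{2}Y^TY$ is included, and checks that $\lambda > 0$ keeps $A$ positive definite.

```python
import numpy as np

# Illustrative random data: N samples, D features, ridge strength lam (all assumed)
rng = np.random.default_rng(0)
N, D, lam = 50, 5, 0.1
X = rng.standard_normal((N, D))  # n-th row is x_n^T
Y = rng.standard_normal((N, 1))

A = lam * np.eye(D) + X.T @ X
b = X.T @ Y

def f_direct(w):
    # 1/2 * sum_n (y_n - w^T x_n)^2 + 1/2 * lam * w^T w
    r = Y - X @ w
    return 0.5 * (r.T @ r).item() + 0.5 * lam * (w.T @ w).item()

def f_quadratic(w):
    # 1/2 * w^T A w - w^T b + 1/2 * Y^T Y
    return 0.5 * (w.T @ A @ w).item() - (w.T @ b).item() + 0.5 * (Y.T @ Y).item()

w = rng.standard_normal((D, 1))
print(np.isclose(f_direct(w), f_quadratic(w)))

# Eigenvalues of X^T X are >= 0, so every eigenvalue of A is >= lam > 0
print(np.linalg.eigvalsh(A).min() >= lam)

# The unique minimizer solves A w = b
w_star = np.linalg.solve(A, b)
print(f_direct(w_star) <= f_direct(w))
```

The direct solve is only practical for small $D$. It is useful here as a reference answer for the iterative method below.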

#### Line Search
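$ w_{k+1} = w_k + \alpha_k p_k $

Here $p_k$ is the current search direction. The step size $\alpha_k$ is chosen to minimize $f$ along $p_k$ (exact line search):

$ f(w_{k+1}) = \frac{1}{2}(w_k + \alpha_k p_k)^T A (w_k + \alpha_k p_k) - (w_k + \alpha_k p_k)^T b $

Setting the derivative with respect to $\alpha_k$ to zero (using $A = A^T$):

$ \frac{\partial f(w_{k+1})}{\partial \alpha_k} = p_k^T A w_k + \alpha_k\, p_k^T A p_k - p_k^T b = 0 $

$ \alpha_k = \frac{p_k^T b - p_k^T A w_k}{p_k^T A p_k} = -\frac{p_k^T g_k}{p_k^T A p_k} $

Here $g_k = A w_k - b = \nabla f(w_k)$ is the gradient at the current iterate.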
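The step-size formula can be checked numerically. The sketch below uses a random symmetric positive definite matrix standing in for $A$, and an arbitrary iterate and direction; these are all assumptions made for illustration. At the exact step, the new gradient is orthogonal to $p_k$, and nudging the step either way along $p_k$ does not lower $f$.

```python
import numpy as np

# Assumed SPD system standing in for A = lam*I + X^T X and b = X^T Y
rng = np.random.default_rng(1)
D = 4
M = rng.standard_normal((D, D))
A = M.T @ M + 0.5 * np.eye(D)
b = rng.standard_normal((D, 1))

def f(w):
    # f(w) = 1/2 w^T A w - w^T b
    return (0.5 * w.T @ A @ w - w.T @ b).item()

w_k = rng.standard_normal((D, 1))  # arbitrary current iterate
g_k = A @ w_k - b                  # gradient at w_k
p_k = rng.standard_normal((D, 1))  # arbitrary search direction

# alpha_k = -p_k^T g_k / p_k^T A p_k
alpha_k = (-(p_k.T @ g_k) / (p_k.T @ A @ p_k)).item()
w_next = w_k + alpha_k * p_k

# The gradient at w_next is orthogonal to p_k (directional derivative is zero)
print(abs((p_k.T @ (A @ w_next - b)).item()) < 1e-9)
# Moving slightly in either direction along p_k does not decrease f
print(f(w_next) <= f(w_next + 1e-3 * p_k), f(w_next) <= f(w_next - 1e-3 * p_k))
```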

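```python
import numpy as np

def conjugate_gradient(A, b, tol=1e-8, max_iter=None):
    """Minimize f(w) = 1/2 w^T A w - w^T b, i.e. solve A w = b.

    A: symmetric positive definite (dim, dim) array; b: (dim, 1) array.
    """
    dim = A.shape[0]
    if max_iter is None:
        max_iter = dim  # in exact arithmetic CG converges in at most dim steps
    # init w with zero vector
    w = np.zeros((dim, 1))
    # init p with negative gradient at initial w
    g = A @ w - b
    p = -g
    b_norm = np.linalg.norm(b)
    for _ in range(max_iter):
        # stop when the gradient is small relative to b (scale-free test)
        if np.linalg.norm(g) <= tol * b_norm:
            break
        # exact line search step: alpha = -p^T g / p^T A p
        alpha = (p.T @ -g) / (p.T @ A @ p)
        w = w + alpha * p
        new_g = A @ w - b
        # Fletcher-Reeves coefficient for the next conjugate direction
        beta = (new_g.T @ new_g) / (g.T @ g)
        p = -new_g + beta * p
        g = new_g
    return w
```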
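The function above picks each new direction with the Fletcher-Reeves rule $\beta_k = \frac{g_{k+1}^T g_{k+1}}{g_k^T g_k}$ and $p_{k+1} = -g_{k+1} + \beta_k p_k$. On a quadratic with exact line search, this rule makes the directions mutually $A$-conjugate ($p_i^T A p_j = 0$ for $i \ne j$). That is why CG reaches the exact minimizer in at most $D$ steps in exact arithmetic. The standalone sketch below checks both properties on an assumed random SPD problem. It reimplements the same loop inline, recording the directions, rather than calling `conjugate_gradient`.

```python
import numpy as np

# Assumed small SPD problem
rng = np.random.default_rng(2)
D = 6
M = rng.standard_normal((D, D))
A = M.T @ M + np.eye(D)
b = rng.standard_normal((D, 1))

w = np.zeros((D, 1))
g = A @ w - b
p = -g
directions = []
for _ in range(D):  # D steps suffice in exact arithmetic
    directions.append(p)
    Ap = A @ p
    alpha = (-(p.T @ g) / (p.T @ Ap)).item()       # exact line search
    w = w + alpha * p
    new_g = A @ w - b
    beta = ((new_g.T @ new_g) / (g.T @ g)).item()  # Fletcher-Reeves
    p = -new_g + beta * p
    g = new_g

P = np.hstack(directions)  # columns p_0, ..., p_{D-1}
C = P.T @ A @ P            # near-diagonal if the directions are A-conjugate
off_diag = C - np.diag(np.diag(C))
print(np.max(np.abs(off_diag)) / np.max(np.abs(np.diag(C))))
print(np.allclose(w, np.linalg.solve(A, b)))
```

In floating point, conjugacy degrades slowly on ill-conditioned problems. This is one reason the tolerance test, rather than a fixed count of $D$ steps, is the practical stopping rule.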

### Sparse linear regression
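#### Conjugate gradient without forming $A$ (saves $O(D^2)$ memory)

Here $D$ is the number of features, so $A$ is $D \times D$. Because $A = \lambda I + \sum_n x_n x_n^T$, the two quantities CG needs can be computed from the data vectors $x_n$ alone:

$ p^T A p = \lambda\, p^T p + \sum_n (p^T x_n)^2, \qquad A w - b = \lambda w + \sum_n x_n (x_n^T w) - b $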

```python
import numpy as np

def sparse_conjugate_gradient(datas, b, lamda, tol=1e-8, max_iter=None):
    """Solve (lamda * I + sum_n x_n x_n^T) w = b without forming A.

    datas: sequence of (dim, 1) arrays x_n; b: (dim, 1) array, b = X^T Y.
    """
    dim = b.shape[0]  # A is never built, so take the dimension from b
    if max_iter is None:
        max_iter = dim
    # init w with zero vector
    w = np.zeros((dim, 1))
    # gradient at w = 0 is A @ 0 - b = -b; first direction is its negative
    g = -b
    p = -g
    b_norm = np.linalg.norm(b)
    for _ in range(max_iter):
        if np.linalg.norm(g) <= tol * b_norm:
            break
        # p^T A p = lamda * p^T p + sum_n (p^T x_n)^2
        pap = lamda * (p.T @ p)
        for x_n in datas:
            pap += (p.T @ x_n) ** 2
        alpha = (p.T @ -g) / pap
        w = w + alpha * p
        # A @ w - b = lamda * w + sum_n x_n (x_n^T w) - b
        new_g = lamda * w - b
        for x_n in datas:
            new_g += x_n * (x_n.T @ w)
        beta = (new_g.T @ new_g) / (g.T @ g)
        p = -new_g + beta * p
        g = new_g
    return w
```
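To sanity-check the matrix-free products, one can compare them against an explicit $A$ on a small random problem. The data, sizes and variable names below are assumptions for illustration. The sketch also shows a vectorized alternative to the Python loops. If the $x_n$ are stacked as rows of an $N \times D$ array `X`, each product needs only `X @ v` and `X.T @ u`, which uses $O(ND)$ memory and never builds a $D \times D$ matrix.

```python
import numpy as np

# Assumed data: datas is a list of (D, 1) column vectors x_n
rng = np.random.default_rng(3)
N, D, lamda = 30, 8, 0.5
datas = [rng.standard_normal((D, 1)) for _ in range(N)]
X = np.hstack(datas).T  # (N, D), rows are x_n^T

A = lamda * np.eye(D) + X.T @ X  # explicit A, formed here only for comparison
b = rng.standard_normal((D, 1))
p = rng.standard_normal((D, 1))
w = rng.standard_normal((D, 1))

# Loop form, as in sparse_conjugate_gradient
pap_loop = (lamda * (p.T @ p)).item() + sum((p.T @ x_n).item() ** 2 for x_n in datas)
grad_loop = lamda * w - b
for x_n in datas:
    grad_loop += x_n * (x_n.T @ w)

# Vectorized form: matrix-vector products with X only
Xp = X @ p
pap_vec = (lamda * (p.T @ p) + Xp.T @ Xp).item()
grad_vec = lamda * w + X.T @ (X @ w) - b

print(np.isclose(pap_loop, (p.T @ A @ p).item()), np.isclose(pap_vec, pap_loop))
print(np.allclose(grad_loop, A @ w - b), np.allclose(grad_vec, grad_loop))
```

When the feature vectors are genuinely sparse, storing `X` as a `scipy.sparse` CSR matrix should allow the same `X @ v` and `X.T @ u` products without densifying, which keeps the cost proportional to the number of nonzeros.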