In [16]:
import torch 
import numpy as np 
import jax
import jax.numpy as jnp
from jax import config
config.update("jax_enable_x64", True) # Use double precision in JAX
import matplotlib.pyplot as plt
import torch.optim as optim
torch.set_default_tensor_type(torch.DoubleTensor) 



Consider: $A_\epsilon u=g$, ($A_\epsilon = A_0+\epsilon I$)
$$
A_0 =
\left(
\begin{array}{rrr}
1  &  -1   &  0\\
-1 &   2   &  -1\\
0  &  -1   &  1
\end{array}
\right),\quad
g=
\left(
\begin{array}{r}
-1 \\
-1 \\
2  \\
\end{array}
\right)\in R(A_0), \quad
p=
\begin{pmatrix}
1\\
1\\
1
\end{pmatrix}
\in N(A_0).
$$
<br><br><br>

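For the quadratic function $f(u) = \frac{1}{2}u^T A u - g^T u$ with $A = A_\epsilon$ (symmetric), we have $\nabla f(u) = Au - g$, so the stationary points of $f$ are exactly the solutions of $Au = g$.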

Gradient descent method: 

$$
u^{k+1} = u^{k} - \eta \nabla f(u^{k})
 =u^{k} - \eta (Au^{k}-g)
$$

Scaled gradient descent (Jacobi, i.e. diagonal, preconditioning):
$$
u^{k+1} =u^{k} - \eta [{\rm diag}(A)]^{-1}(Au^{k}-g)
$$

<br>
<br>




In [5]:
%%time 
## GD for 3by3 system

def gd_iteration_count(A, b, x0, lr=0.5, tol=1e-8, max_iters=1_000_000):
    """Count plain gradient-descent iterations for f(x) = 0.5*x^T A x - b^T x.

    Each iteration performs x <- x - lr*(A x - b). Iteration stops once the
    2-norm of the residual A x - b is at most `tol`, or once more than
    `max_iters` iterations have been performed.

    Note: the residual tested after an update is the one that was evaluated
    *before* that update (it is reused as the gradient), so when the start
    point is not already converged one more update is performed and counted
    than is strictly needed.

    Args:
        A: (n, n) tensor, system matrix.
        b: (n, 1) tensor, right-hand side.
        x0: (n, 1) tensor, initial guess (left unmodified).
        lr: step size.
        tol: tolerance on the residual 2-norm.
        max_iters: iteration cap.

    Returns:
        The number of iterations performed. A value of `max_iters + 1`
        means the cap was exceeded without reaching the tolerance.

    Raises:
        AssertionError: if the residual norm becomes NaN (step size too large).
    """
    x = x0
    residual_norm = torch.norm(torch.matmul(A, x) - b, 2)
    iters = 0
    while residual_norm > tol:
        gd = torch.matmul(A, x) - b
        x = x - lr * gd
        residual_norm = torch.norm(gd, 2)
        iters += 1
        # `nan > tol` is False, so without this check a diverged run would
        # leave the loop silently and be reported as converged.
        assert not torch.isnan(residual_norm), "norm is nan, reset learning rate"
        if iters > max_iters:
            break
    return iters


print("Plain GD: number of iterations needed for 3 by 3 system")
tol = 1e-8  # tolerance for residual norm
max_iters = 1_000_000
for eps in [0.1, 0.01, 0.001, 1e-4, 1e-5, 1e-9, 0.]:
    # A_eps = A_0 + eps*I; dtype is explicit so the cell does not depend on
    # the global default tensor type.
    A3 = torch.tensor([[1+eps, -1, 0], [-1, 2+eps, -1], [0, -1, 1+eps]], dtype=torch.float64)
    # b must lie in the range of A_0 (its entries sum to zero, i.e. it is
    # orthogonal to the kernel vector (1,1,1)) for the eps = 0 system to be solvable.
    b = torch.tensor([[-1.], [-1.], [2.]], dtype=torch.float64)
    x0 = torch.tensor([[1.0], [2.], [3.0]], dtype=torch.float64)
    iters = gd_iteration_count(A3, b, x0, lr=0.5, tol=tol, max_iters=max_iters)
    if iters > max_iters:
        print("eps = "+str(eps)+": over {:,}".format(max_iters))
    else:
        print("eps = "+str(eps)+": ", iters)

Plain GD: number of iterations needed for 3 by 3 system
eps = 0.1:  340
eps = 0.01:  3006
eps = 0.001:  25506
eps = 0.0001:  209052
eps = 1e-05: over 1,000,000
eps = 1e-09:  29
eps = 0.0:  29
CPU times: user 41.3 s, sys: 139 ms, total: 41.5 s
Wall time: 41.7 s


eps = 1e-05: over 100 iterations


dtype('float64')

Expanded system:
    
Write $u = u_1e_1+u_2e_2+u_3e_3 \in \mathbb{R}^3$ as
$$
u=\tilde u_1 e_1+\tilde u_2e_2+\tilde u_3e_3+\tilde
    u_4 p =P\tilde u,
$$
where 
$$
P=\begin{pmatrix}
    1 & 0 & 0 & 1\\
    0 & 1 & 0 & 1\\
    0 & 0 & 1 & 1
\end{pmatrix}, \quad p=
\begin{pmatrix}
    1 \\ 1 \\ 1
\end{pmatrix}
\in {\rm ker}(A_0). 
$$

The equation $A_{\epsilon}u=g$ becomes
$$
A_{\epsilon}P\tilde u=g \Longleftrightarrow
(P^TA_{\epsilon}P)\tilde u=P^Tg,
$$

This leads to a symmetric positive semi-definite (singular) system, since $P^TA_{\epsilon}P$ is $4\times 4$ but has rank at most 3:
$$
\begin{pmatrix}
    1+\epsilon  &  -1   &  0&\epsilon\\
    -1 &   2+\epsilon   &  -1&\epsilon\\
    0  &  -1   &  1+\epsilon&\epsilon\\
    \epsilon&\epsilon&\epsilon&3\epsilon
\end{pmatrix}
\tilde u=
\begin{pmatrix}
      -1 \\
    -1 \\
    2  \\
    0\\
\end{pmatrix}.
$$

In [4]:
## GD for 4by4 system, GD

def expanded_gd_iteration_count(A3, A4, P, g, b, x0, lr=0.5, tol=1e-8, max_iters=100_000):
    """Count gradient-descent iterations on the expanded 4x4 system.

    The iteration x <- x - lr*(A4 x - b) is run on the expanded unknown x,
    while convergence is measured on the original 3x3 system through the
    residual A3 (P x) - g.

    Args:
        A3: (3, 3) tensor, original system matrix.
        A4: (4, 4) tensor, expanded system matrix.
        P: (3, 4) tensor mapping the expanded unknown to the original one.
        g: (3, 1) tensor, right-hand side of the original system.
        b: (4, 1) tensor, right-hand side of the expanded system.
        x0: (4, 1) tensor, initial guess (left unmodified).
        lr: step size.
        tol: tolerance on the 2-norm of the original-system residual.
        max_iters: iteration cap.

    Returns:
        The number of iterations performed; `max_iters + 1` means the cap
        was exceeded without reaching the tolerance.

    Raises:
        AssertionError: if the residual norm is NaN when the loop exits.
    """
    x = x0
    residual_norm = torch.norm(A3 @ (P @ x) - g, 2)
    iters = 0
    while residual_norm > tol:
        gd = torch.matmul(A4, x) - b
        x = x - lr * gd
        residual_norm = torch.norm(A3 @ (P @ x) - g, 2)
        iters += 1
        if iters > max_iters:
            break
    # `nan > tol` is False, so a diverged run leaves the loop; catch it here.
    assert not torch.isnan(residual_norm), "norm is nan, reset learning rate"
    return iters


print("GD: 4 by 4 system")
P = torch.tensor([[1., 0., 0., 1.], [0., 1., 0., 1.], [0., 0., 1., 1.]], dtype=torch.float64)
g = torch.tensor([[-1.], [-1.], [2.]], dtype=torch.float64)        # right-hand side of the 3x3 system
b = torch.tensor([[-1.], [-1.], [2.], [0.]], dtype=torch.float64)  # b = P^T g
tol = 1e-8
max_iters = 100_000
# Seeded generator: the random start is reproducible, and the same start is
# used for every eps so the iteration counts are directly comparable.
rng = torch.Generator().manual_seed(0)
x0 = torch.rand(4, 1, generator=rng, dtype=torch.float64)
for eps in [0.1, 0.01, 0.001, 1e-4, 1e-5, 1e-9, 0.]:
    A3 = torch.tensor([[1+eps, -1, 0], [-1, 2+eps, -1], [0, -1, 1+eps]], dtype=torch.float64)
    A4 = torch.tensor([[1+eps, -1.0, 0, eps],
                       [-1, 2+eps, -1, eps],
                       [0, -1, 1+eps, eps],
                       [eps, eps, eps, 3*eps]], dtype=torch.float64)
    iters = expanded_gd_iteration_count(A3, A4, P, g, b, x0, lr=0.5, tol=tol, max_iters=max_iters)
    if iters > max_iters:
        # Report a capped run explicitly instead of printing the raw counter.
        print("eps = "+str(eps)+": over {:,}".format(max_iters))
    else:
        print("eps = "+str(eps)+": ", iters)
print()

GD: 4 by 4 system
eps = 0.1:  76
eps = 0.01:  699
eps = 0.001:  6126
eps = 0.0001:  48390
eps = 1e-05:  100001
eps = 1e-09:  28
eps = 0.0:  29



In [5]:
#GD for 4by4 system, modified Jacobi preconditioner

def scaled_gd_iteration_count(A3, A4, P, g, b, x0, lr=0.7, tol=1e-8, max_iters=100_000):
    """Count diagonally scaled gradient-descent iterations on the expanded system.

    The iteration is x <- x - lr * diag(A4)^{-1} (A4 x - b). Convergence is
    measured on the original 3x3 system through the residual A3 (P x) - g.

    Args:
        A3: (3, 3) tensor, original system matrix.
        A4: (4, 4) tensor, expanded system matrix; its diagonal must be nonzero.
        P: (3, 4) tensor mapping the expanded unknown to the original one.
        g: (3, 1) tensor, right-hand side of the original system.
        b: (4, 1) tensor, right-hand side of the expanded system.
        x0: (4, 1) tensor, initial guess (left unmodified).
        lr: step size.
        tol: tolerance on the 2-norm of the original-system residual.
        max_iters: iteration cap.

    Returns:
        The number of iterations performed; `max_iters + 1` means the cap
        was exceeded without reaching the tolerance.

    Raises:
        AssertionError: if the residual norm is NaN when the loop exits.
    """
    # The preconditioner is diagonal, so its inverse is the elementwise
    # reciprocal of the diagonal; compute it once instead of inverting a
    # matrix in every iteration.
    inv_diag = (1.0 / torch.diag(A4)).view(-1, 1)
    x = x0
    residual_norm = torch.norm(A3 @ (P @ x) - g, 2)
    iters = 0
    while residual_norm > tol:
        gd = torch.matmul(A4, x) - b
        x = x - lr * (inv_diag * gd)
        residual_norm = torch.norm(A3 @ (P @ x) - g, 2)
        iters += 1
        if iters > max_iters:
            break
    # `nan > tol` is False, so a diverged run leaves the loop; catch it here.
    assert not torch.isnan(residual_norm), "norm is nan, reset learning rate"
    return iters


print("Scaled GD: 4 by 4 system")
P = torch.tensor([[1., 0., 0., 1.], [0., 1., 0., 1.], [0., 0., 1., 1.]], dtype=torch.float64)
g = torch.tensor([[-1.], [-1.], [2.]], dtype=torch.float64)        # right-hand side of the 3x3 system
b = torch.tensor([[-1.], [-1.], [2.], [0.]], dtype=torch.float64)  # b = P^T g
tol = 1e-8
max_iters = 100_000
# Seeded generator: the random start is reproducible, and the same start is
# used for every eps so the iteration counts are directly comparable.
rng = torch.Generator().manual_seed(0)
x0 = torch.rand(4, 1, generator=rng, dtype=torch.float64)
# eps = 0 is not part of this sweep: the last diagonal entry of A4 is 3*eps,
# so the diagonal scaling would divide by zero.
for eps in [0.1, 0.01, 0.001, 1e-4, 1e-5, 1e-9]:
    A3 = torch.tensor([[1+eps, -1, 0], [-1, 2+eps, -1], [0, -1, 1+eps]], dtype=torch.float64)
    A4 = torch.tensor([[1+eps, -1.0, 0, eps],
                       [-1, 2+eps, -1, eps],
                       [0, -1, 1+eps, eps],
                       [eps, eps, eps, 3*eps]], dtype=torch.float64)
    iters = scaled_gd_iteration_count(A3, A4, P, g, b, x0, lr=0.7, tol=tol, max_iters=max_iters)
    if iters > max_iters:
        # Report a capped run explicitly instead of printing the raw counter.
        print("eps = "+str(eps)+": over {:,}".format(max_iters))
    else:
        print("eps = "+str(eps)+": ", iters)

    

Scaled GD: 4 by 4 system
eps = 0.1:  19
eps = 0.01:  21
eps = 0.001:  20
eps = 0.0001:  20
eps = 1e-05:  20
eps = 1e-09:  21


## Playground for JAX

This part was added after CEMRACS 2023. On day 4, we were introduced to using JAX to improve code efficiency in Python (by Martin Guerra from UW Madison; see the git repo https://github.com/maguerrap).

- Automatic differentiation with JAX.
- JIT to improve efficiency 

For the 3 by 3 system, we run GD for 1000 steps. 

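1. Use the function `gd_update`: JAX automatic differentiation, no JIT (6.28s)
2. Use the function `gd_update_jit`: the same update compiled with `jax.jit` (0.10s)
3. Explicitly computed gradient in PyTorch, without automatic differentiation (0.056s)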

In [65]:
%%time 
## JAX version GD. For an equivalent problem, we use gradient descent for the quadratic loss
def loss_3by3(x):
    """Quadratic loss 0.5*x^T A3 x - x^T b, returned as a scalar.

    Reads the module-level arrays `A3` and `b` defined further down in this cell.

    x: column vector of shape (3, 1)
    """
    # x.T@A3@x and x.T@b are (1, 1) arrays; jnp.sum collapses the result to
    # the scalar that jax.grad requires.
    return jnp.sum(0.5*x.T@A3@x - x.T@b )

def gd_update(x,lr):
    """Return x after one gradient-descent step of size `lr` on `loss_3by3`.

    The gradient is obtained by automatic differentiation on every call.
    """
    return x - lr*jax.grad(loss_3by3)(x)

def residual_norm(A3,x,b):
    """Return the 2-norm of the residual A3 @ x - b."""
    return jnp.linalg.norm(A3@x - b ,2)

eps = 1e-5
A3 = jnp.asarray([[1+eps,-1,0],[-1,2+eps,-1],[0,-1,1+eps]])
b = jnp.asarray([-1.,-1.,2.]).reshape(3,1)
x = jnp.asarray([1.,2.,3.]).reshape(3,1)
tol = 1e-8
lr = 0.5

num_iters = 1000
for i in range(num_iters):
    x = gd_update(x,lr)
    if residual_norm(A3,x,b) < tol:
        # i is 0-based, so i + 1 updates have been applied at this point.
        print("eps = "+str(eps)+": ", i + 1)
        break
else:
    # Runs only when the loop finished without `break`, i.e. no convergence.
    print("eps = "+str(eps)+": over {} iterations".format(num_iters))
            

eps = 1e-05: over 1000 iterations
CPU times: user 5.95 s, sys: 18.4 ms, total: 5.97 s
Wall time: 5.99 s


In [66]:
%%time 

num_iters = 1000
# JIT-compiled versions of the update and residual functions defined in the
# previous cell. `lr` is declared static, so it is treated as a compile-time
# constant (a different lr value triggers a recompilation).
gd_update_jit = jax.jit(gd_update, static_argnames = 'lr')
residual_norm_jit = jax.jit(residual_norm)

# Restart from the initial guess [1, 2, 3] so this run solves the same problem
# as the un-jitted run instead of continuing from whatever iterate an earlier
# cell left in `x`; this also makes the cell give the same result when re-run.
x = jnp.asarray([1.,2.,3.]).reshape(3,1)

for i in range(num_iters):
    x = gd_update_jit(x,lr)
    if residual_norm_jit(A3,x,b) < tol:
        # i is 0-based, so i + 1 updates have been applied at this point.
        print("eps = "+str(eps)+": ", i + 1)
        break
else:
    # Runs only when the loop finished without `break`, i.e. no convergence.
    print("eps = "+str(eps)+": over {} iterations".format(num_iters))
        

eps = 1e-05: over 1000 iterations
CPU times: user 129 ms, sys: 6 ms, total: 135 ms
Wall time: 134 ms


In [67]:
%%time 
## GD for 3by3 system (PyTorch, gradient A x - b written out explicitly)

print("Plain GD: number of iterations needed for 3 by 3 system")
num_iters = 1000
for eps in [0.00001]:
    # The torch objects get their own names (`A3_torch`, `b_torch`, `x_torch`,
    # `res_norm`) so that this cell does not overwrite the JAX arrays `A3`,
    # `b`, `x` or the JAX function `residual_norm` used by the JAX cells.
    A3_torch = torch.tensor([[1+eps, -1, 0], [-1, 2+eps, -1], [0, -1, 1+eps]], dtype=torch.float64)
    # The right-hand side must lie in the range of A_0 (its entries sum to
    # zero) for the eps = 0 system to be solvable.
    b_torch = torch.tensor([[-1.], [-1.], [2.]], dtype=torch.float64)
    x_torch = torch.tensor([[1.0], [2.], [3.0]], dtype=torch.float64)
    tol = 1e-8 # tolerance for residual norm
    res_norm = torch.norm(torch.matmul(A3_torch, x_torch) - b_torch, 2)
    iters = 0
    while res_norm > tol:
        gd = torch.matmul(A3_torch, x_torch) - b_torch
        x_torch = x_torch - 0.5*gd
        # The gradient is the residual of the iterate *before* the update.
        res_norm = torch.norm(gd, 2)
        iters += 1
        # `nan > tol` is False, so without this check a diverged run would
        # leave the loop silently and be reported as converged.
        assert not torch.isnan(res_norm), "norm is nan, reset learning rate"
        if iters > num_iters:
            break
    if iters > num_iters:
        print("eps = "+str(eps)+": over {}".format(num_iters))
    else:
        print("eps = "+str(eps)+": ", iters)

Plain GD: number of iterations needed for 3 by 3 system
eps = 1e-05: over 1000
CPU times: user 35.6 ms, sys: 2.25 ms, total: 37.8 ms
Wall time: 36.4 ms
