# Direct Solvers

## Kronecker Product - Sparse Solver
### General algorithm
Rewrite $AX+XB=C$ as 
- $\mathcal{A}x = c$

where $\mathcal{A} = I \otimes A + B^\top \otimes I$, $x = \text{vec}(X)$ and $c = \text{vec}(C)$, with $\text{vec}(\cdot)$ stacking the columns of its argument into a single vector

### TU + UT = F
Rewrite as
- $\mathcal{T}u = f$

where $\mathcal{T} = I \otimes T + T^\top \otimes I = I \otimes T + T \otimes I$ (the second form holds because $T$ is symmetric), $u = \text{vec}(U)$ and $f = \text{vec}(F)$

In [5]:
import numpy as np
from scipy.sparse import kron
from scipy.sparse.linalg import spsolve
import time

def kron_prod_dir(T,F):
    """Solve the matrix equation ``T U + U T = F`` with a sparse direct solver.

    The equation is rewritten as the linear system ``A u = f`` with
    ``A = I kron T + T^t kron I``, ``u = vec(U)`` and ``f = vec(F)``, where
    ``vec`` stacks the columns of its argument.

    Parameters
    ----------
    T : (n, n) ndarray or scipy sparse matrix
        Coefficient matrix; must provide a ``transpose()`` method.
    F : (n, n) array_like
        Right-hand side.

    Returns
    -------
    U : (n, n) ndarray
        Solution of ``T U + U T = F``.
    timings : list of float
        ``[kron_time, solve_time, total_time]`` in seconds: time spent
        assembling ``A``, time spent in the solver, and their sum.
    """
    n = len(F)
    start_time = time.perf_counter()
    # vec() stacks columns, so flatten in column-major order. A row-major
    # flatten only matches this operator when T is symmetric.
    f = np.reshape(F, -1, order='F')
    eye = np.eye(n)
    # A = I kron T + T^t kron I. kron may return BSR for small/dense inputs;
    # convert to CSC, the format the sparse direct solver works with natively.
    A = (kron(eye, T) + kron(T.transpose(), eye)).tocsc()
    assembled_time = time.perf_counter()
    u = spsolve(A, f) #Solve using scipy sparse solver
    end_time = time.perf_counter()
    kron_time = assembled_time - start_time
    solve_time = end_time - assembled_time
    total_time = end_time - start_time
    timings = [kron_time, solve_time, total_time]
    # Undo the column-major flatten so U[i][j] lines up with F[i][j].
    U = np.reshape(u, (n, n), order='F')
    return U, timings

In [8]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

def plot_sol(X,Y,U):
    """Draw U as a colour image over the unit square, with a colour bar.

    Parameters
    ----------
    X, Y : array_like
        Mesh coordinates. Accepted for interface compatibility; they are not
        used to place the image, whose extent is fixed to [0,1] x [0,1].
    U : 2-D array_like
        Values to display.

    The plot is drawn on figure 0; nothing is returned.
    """
    plt.figure(0)
    # NOTE(review): imshow draws the first axis of U vertically. If U was built
    # on an indexing='ij' mesh (first index = x), the picture is transposed
    # relative to the axis labels -- confirm against the caller.
    plt.imshow(U, extent=[0,1,0,1], origin='lower')
    plt.colorbar()
    # 'image' = equal scaling with limits tight to the data. It must be passed
    # positionally; axis() has no 'aspect' keyword.
    plt.axis('image')
    plt.xlabel('x')
    plt.ylabel('y')

def compute_err(U, U_exact, h=None):
    """Return the max-norm and discrete L2-norm errors of U against U_exact.

    Parameters
    ----------
    U, U_exact : array_like of the same shape
        Computed and reference solutions.
    h : float, optional
        Mesh step size used to scale the L2 norm. Defaults to
        ``1/(n+1)`` with ``n = len(U)``, i.e. n internal nodes per direction
        on the unit interval.

    Returns
    -------
    err_inf : float
        Largest absolute entry-wise difference (0.0 for empty input).
    err_sq : float
        ``sqrt(h**2 * sum(|U_exact - U|**2))``.
    """
    if h is None:
        # Derive the step from the grid size instead of relying on a global.
        h = 1.0 / (len(U) + 1)
    diff = np.abs(np.asarray(U_exact) - np.asarray(U)).astype(float)
    err_inf = float(diff.max()) if diff.size else 0.0
    err_sq = float(np.sqrt(np.sum(diff**2) * h**2))
    return err_inf, err_sq

In [10]:
from scipy.sparse import diags

#Define parameters
n = 125 #number of internal nodes in each direction (n*n unknowns in total)
h = 1/(n+1) #step size

#Placeholder for the solution; it is overwritten by the solver call below
U = np.zeros([n,n])

#Define x and y as arrays between 0 and 1 with n evenly spaced points (internal nodes)
x = np.linspace(h, 1-h, n)
y = np.linspace(h, 1-h, n)

#Create internal mesh (excludes boundaries)
#indexing='ij' gives X[i][j] = x[i] and Y[i][j] = y[j]
X, Y = np.meshgrid(x, y, indexing='ij')

#Define F: the right-hand side, equal to minus the Laplacian of U_exact below
F = 2 * np.pi**2 * np.sin(np.pi*X) * np.sin(np.pi*Y)

#Define tridiagonal matrix T as a dense n x n array:
#2/h^2 on the main diagonal and -1/h^2 on the first sub- and super-diagonals
diagonals = [[-2],[1],[1]]
T = np.multiply((-1)/(h**2), diags(diagonals, [0, -1, 1], shape=(n, n)).toarray())

#Compute exact solution for comparison
U_exact = np.sin(np.pi*X) * np.sin(np.pi*Y)

#Solve system
#Note: despite its name, total_time receives the whole list
#[kron_time, solve_time, total_time] returned by kron_prod_dir
U, total_time = kron_prod_dir(T,F)
err_inf, err_sq = compute_err(U, U_exact)
print('Time taken:', total_time)
print('Error inf:', err_inf, '\nError sq:', err_sq)

Time taken: [0.005593299865722656, 0.0912163257598877, 0.09680962562561035]
Error inf: 5.18072938893166e-05 
Error sq: 2.590364694610492e-05
