In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [None]:
import numpy as np
from numpy import newaxis as na
import scipy
import scipy.sparse as sps
from scipy.sparse.linalg import spsolve, LinearOperator
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

In [None]:
from pyamg.classical import ruge_stuben_solver
def norm(x):
    """Return the max-norm (largest absolute entry) of ``x``.

    ``x`` may be a scalar, a sequence or an array of any shape.  Input with
    no entries at all (for example ``[]`` or an array of shape ``(3, 0)``)
    yields ``0.0`` instead of raising.
    """
    arr = np.asarray(x)
    # Test the total number of entries rather than len(): len() fails on
    # scalars / 0-d arrays and only looks at the leading axis.
    if arr.size == 0:
        return 0.0
    return np.max(np.abs(arr))


def kron3(x, y, z):
    """Return the sparse Kronecker product ``kron(x, kron(y, z))``."""
    return sps.kron(x, sps.kron(y, z))

In [None]:
from tensormesh import HexCubePoisson
from maps import LinearIsopMap
from topology import CubicTopology
from poisson import PoissonProblem

# Setup mesh

In [None]:
# First argument handed to CubicTopology and HexCubePoisson below
# (presumably the polynomial order per element -- TODO confirm).
N = 8

# Per-direction counts passed to CubicTopology as the tuple (Ex, Ey, Ez);
# all three directions use the same value.
Ex = Ey = Ez = 8

# Toggles periodic vs. non-periodic handling throughout the notebook.
periodic = False

# When True the AMG cell (which needs an assembled matrix) is executed.
do_assemble = False

In [None]:
# def f(X):
#     x = X[:,0]
#     y = X[:,1]
#     z = X[:,2]
    
#     return np.sin(np.pi*x)*np.sin(np.pi*y)*np.sin(np.pi*z)

# def f2(X):
#     x = X[:,0]
#     y = X[:,1]
#     z = X[:,2]
    
#     return np.sin(np.pi*x)*np.sin(np.pi*y)*np.sin(np.pi*z)*3*(np.pi)**2

def f(X):
    """Evaluate the reference solution at the points ``X``.

    ``X`` is indexed as ``X[:, 0]``, ``X[:, 1]``, ``X[:, 2]`` for the x, y
    and z coordinates.  The result is
    ``cos(2*pi*x)*cos(2*pi*y)*cos(2*pi*z)``, with ``x`` added on top when
    the module-level flag ``periodic`` is false.
    """
    k = np.pi*2
    x, y, z = X[:,0], X[:,1], X[:,2]

    wave = np.cos(k*x)*np.cos(k*y)*np.cos(k*z)
    if periodic:
        return wave
    # Non-periodic case: add the linear term x to the cosine product.
    return wave + x

def f2(X):
    """Evaluate the source term matching ``f`` at the points ``X``.

    Returns ``3*(2*pi)**2 * cos(2*pi*x)*cos(2*pi*y)*cos(2*pi*z)`` with
    x, y, z taken from columns 0, 1, 2 of ``X``; this is minus the
    Laplacian of the cosine product used in ``f``.
    """
    k = np.pi*2
    cos_x, cos_y, cos_z = [np.cos(k*X[:,axis]) for axis in range(3)]
    # Keep the original left-to-right evaluation order of the product.
    return cos_x*cos_y*cos_z*3*(k)**2

In [None]:
# Map object handed to PoissonProblem further down.
lmap = LinearIsopMap()

# Build the topology from the configuration values N, (Ex, Ey, Ez), periodic.
topo = CubicTopology(N, (Ex, Ey, Ez), periodic=periodic)
topo.build()

# Short module-level aliases for the topology data used by later cells.
etn = topo.elem_to_vertex
Q = topo.Q
etd = topo.elem_to_dof
# R is applied as R.dot(...) on the right-hand side and as R.T.dot(...) on
# the computed solutions in the solve cells.
R = topo.R
boundary_dofs = topo.boundary_dofs

In [None]:
# Reference vertex coordinates as provided by the topology.
vertex_ref = topo.get_vertex_ref()

# Physical coordinates live in a copy so vertex_ref itself stays untouched;
# vp is a second name for the same array.
vertex_phys = vertex_ref.copy()
vp = vertex_phys

# Shifted reference coordinates, one array per column of vertex_ref.
shift = 0.5
chi, eta, zeta = vertex_ref.T+shift

# Perturbation amplitude per coordinate direction.
sx = sy = sz = 0.1
sin3  = np.sin(np.pi*chi)*np.sin(np.pi*eta)*np.sin(np.pi*zeta)

# Physical coordinate = shifted reference coordinate + amplitude*sin3.
for axis, (coord, amp) in enumerate(((chi, sx), (eta, sy), (zeta, sz))):
    vp[:,axis] = coord + amp*sin3

# Poisson

In [None]:
# Create the Poisson problem from the topology and the map object, then
# build it with the physical (deformed) vertex coordinates.
poisson = PoissonProblem(topo, lmap)
poisson.build(vertex_phys)

In [None]:
# # Build poisson stiffness matrix A
# p = poisson
# G11, G12, G13 = p.G11, p.G12, p.G13
# G21, G22, G23 = p.G21, p.G22, p.G23
# G31, G32, G33 = p.G31, p.G32, p.G33
# D1, D2, D3    = p.D1, p.D2, p.D3

# if do_assemble:
#     A0a = []
#     for i in range(n_elem):
#         A0a += [D1.T.dot(G11[i].dot(D1)+G12[i].dot(D2)+G13[i].dot(D3))+\
#                 D2.T.dot(G21[i].dot(D1)+G22[i].dot(D2)+G23[i].dot(D3))+\
#                 D3.T.dot(G31[i].dot(D1)+G32[i].dot(D2)+G33[i].dot(D3))]
#     A0 = sps.block_diag(A0a).tocsr()
#     A0 = Q.T.dot(A0.dot(Q))
#     A  = R.dot(A0.dot(R.T))

In [None]:
p = poisson

# Size of the linear system: every dof in the periodic case, otherwise the
# dof grid with two entries removed per direction.
nn = p.n_dofs if periodic else (p.nz_dofs-2)*(p.ny_dofs-2)*(p.nx_dofs-2)
op_shape = (nn, nn)

# Matrix-free operator: products are delegated to p.apply_A.
linOp = LinearOperator(op_shape, matvec=p.apply_A)

# Preconditioner: M.solve of a HexCubePoisson object wrapped as an operator.
M = HexCubePoisson(N, Ex, L=2, periodic=periodic)
M.build_mesh()
precond = LinearOperator(op_shape, matvec=M.solve)

## Solve System

In [None]:
dof_phys = p.dof_phys

# Reference solution at every dof; also the source of the boundary values.
exact = f(dof_phys)

# Source term at the dofs (fl is a second name for the same array).
fh  = f2(dof_phys)
fl = fh

# Vector that carries the boundary values of f and zeros everywhere else.
radj = np.zeros(p.nx_dofs*p.ny_dofs*p.nz_dofs)
radj[boundary_dofs] = exact[boundary_dofs]

# Right-hand side: B applied to the source term, minus A applied to the
# boundary-value vector, then passed through R.
rhs = R.dot(p.B.dot(fl)-p.apply_A(radj, apply_R=False))

if periodic:
    # Remove the mean of the right-hand side in the periodic case.
    rhs -= np.mean(rhs)

In [None]:
# # Check apply_A against full matrix
# if do_assemble:
#     print norm(p.apply_A(rhs)-A.dot(rhs))

### Solve with AMG

In [None]:
if do_assemble:
    # NOTE(review): no live cell visible in this notebook defines `A`; it is
    # only built by the commented-out assembly cell.  Restore that cell
    # before switching do_assemble on, otherwise this raises NameError.
    ml = ruge_stuben_solver(A)
    residuals = []
    sol = R.T.dot(ml.solve(rhs, tol=1e-14,
                           maxiter=500, residuals=residuals,
                           accel='cg'))
    # Put the known boundary values of f back into the full solution vector.
    sol[boundary_dofs] = f(dof_phys)[boundary_dofs]

    # Periodic case: remove the mean from the solution, the same gauge the
    # CG cell uses, so that norm(sol-solcg) compares like with like.  The
    # shared `exact` array is not modified here; a shifted copy is used.
    exact_amg = exact
    if periodic:
        sol -= np.mean(sol)
        exact_amg = exact - np.mean(exact)

    # Iteration count and last residual, blank line, relative max-norm error.
    # Single-argument print(...) gives the same text on Python 2 and 3.
    print("%s %s" % (len(residuals), residuals[-1]))
    print("")
    print(norm(exact_amg-sol)/norm(exact_amg))

### Solve with CG

In [None]:
class CB(object):
    """Callable that counts how many times it has been invoked.

    An instance is passed as ``callback`` to the CG solver below; the
    count is available afterwards as ``n_iter``.
    """

    def __init__(self):
        # Number of calls received so far.
        self.n_iter = 0

    def __call__(self, x):
        # The argument handed over by the caller is not inspected.
        self.n_iter = self.n_iter + 1
        
cb = CB()
# Preconditioned CG on the matrix-free operator; cb counts the iterations.
# NOTE(review): recent SciPy releases replace the `tol` keyword of cg with
# `rtol`; adjust here when upgrading SciPy.
solcg, errc = sps.linalg.cg(linOp, rhs, tol=1e-14,
                            maxiter=2000, callback=cb,
                            M=precond)

# cg reports 0 on success; anything else means the tolerance was not reached
# or the input was invalid, so do not let that pass silently.
if errc != 0:
    print("WARNING: cg did not converge (info=%s, iterations=%s)"
          % (errc, cb.n_iter))

solcg = R.T.dot(solcg)
if periodic:
    # Compare solution and reference with their means removed.
    solcg -= np.mean(solcg)
    exact -= np.mean(exact)
else:
    # Put the known boundary values of f back into the full solution vector.
    solcg[boundary_dofs] = f(dof_phys[boundary_dofs])

# Iteration count and max-norm residual, blank line, relative max-norm error.
# Single-argument print(...) gives the same text on Python 2 and 3.
print("%s %s" % (cb.n_iter, norm(rhs-p.apply_A(R.dot(solcg)))))
print("")
print(norm(exact-solcg)/norm(exact))
if do_assemble:
    # Difference to the AMG solution computed in the previous section.
    print(norm(sol-solcg))

In [None]:
# Dirichlet convergence on twisted domain
# Hard-coded values, one list per series, one entry per value in Ka
# (presumably errors recorded from earlier runs -- TODO confirm origin).
six = [8.15638640052e-05,
       8.31382849419e-07,
       4.30139456801e-09]

five = [0.000962099963069,
        5.61991711754e-06,
        8.71037655601e-08]

four = [0.00285615225999,
        7.57244184609e-05,
        1.94727477124e-06]

pt = plt.loglog
hv = 2.0/(4*np.arange(1,4))  # not used by this cell
Ka = np.array([4, 8, 16])

series = (("four", four), ("five", five), ("six", six))

# One labelled log-log curve per series so the lines can be told apart.
for label, values in series:
    pt(Ka, values, label=label)
plt.legend()

# log2 of the ratio between the last two entries of each series, i.e. the
# observed change when Ka doubles from 8 to 16.
# Single-argument print(...) gives the same text on Python 2 and 3.
for label, values in series:
    print(np.log2(values[-1]/values[-2]))

In [None]:
# # Dirichlet convergence
# six = [0.000101995581293,
#        1.03927870865e-06,
#        5.38245403803e-09]

# five = [0.00120279085549,
#         7.01764904154e-06,
#         1.09136199744e-07]

# four = [0.00359712312981,
#         9.49320670836e-05,
#         2.43538830924e-06]

# pt = plt.loglog
# hv = 2.0/(4*np.arange(1,4))
# Ka = np.array([4, 8, 16])

# pt(Ka, four)
# pt(Ka, five)
# pt(Ka, six)

# print np.log2(four[-1]/four[-2])
# print np.log2(five[-1]/five[-2])
# print np.log2(six[-1]/six[-2])

In [None]:
# # Periodic convergence
# six = [0.000201669449932,
#        2.06566033268e-06,
#        1.06832162949e-08]

# five = [0.00237745836202,
#         1.39552087489e-05,
#         2.16385736726e-07]

# four = [0.00708725698047,
#         0.000188003387049,
#         4.83988801374e-06]

# pt = plt.loglog
# hv = 2.0/(4*np.arange(1,4))
# Ka = np.array([4, 8, 16])

# pt(Ka, four)
# pt(Ka, five)
# pt(Ka, six)

# print np.log2(four[-1]/four[-2])
# print np.log2(five[-1]/five[-2])
# print np.log2(six[-1]/six[-2])

In [None]:
# Arrange the flat dof arrays on the (nz_dofs, ny_dofs, nx_dofs) grid.
grid_shape = (p.nz_dofs, p.ny_dofs, p.nx_dofs)

dp = dof_phys.reshape(grid_shape + (3,))
if periodic:
    # Work on a copy, then replace every coordinate equal to 1.0 by -1.0.
    dp = dp.copy()
    at_one = (dp == 1.0)
    dp[at_one] = -1.0

ds = solcg.reshape(grid_shape)

In [None]:
fig = plt.figure()
# Create the 3D axes with add_subplot: passing projection= to fig.gca() is
# rejected by newer Matplotlib, while this form works on old and new versions.
ax = fig.add_subplot(111, projection='3d')
s = (p.ny_dofs,p.nx_dofs)

# Index of the grid slice (first axis of dp / ds) that is plotted.
k = int(0.4*p.nz_dofs)
X = dp[k,:,:,0]
Y = dp[k,:,:,1]

# Reference values on slice k (default colour) ...
ax.plot_wireframe(X, Y,
                  exact.reshape((p.nz_dofs,p.ny_dofs,p.nx_dofs))[k,:,:])
# ... and the CG solution on the same slice in green.
ax.plot_wireframe(X, Y, ds[k,:,:].reshape(s),
                  color='g')
plt.show()

In [None]:
# Scatter plot of the X and Y coordinate arrays of slice k taken from dp above.
plt.scatter(X,Y)