In [None]:
%matplotlib inline

In [None]:
import numpy as np
from numpy import newaxis as na
import scipy
import scipy.sparse as sps
from scipy.sparse.linalg import spsolve
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

In [None]:
from pyamg.classical import ruge_stuben_solver
from pyfem.sem import SEMhat
from pyfem.topo import Interval
def norm(x):
    """Max-norm of ``x``: the largest absolute entry, or 0.0 when ``x`` is empty."""
    if len(x) == 0:
        return 0.0
    return np.max(np.abs(x))


def kron3(x, y, z):
    """Sparse Kronecker product x (x) y (x) z of three matrices.

    The index of the last factor varies fastest, so ``kron3(Az, Ay, Ax)`` acts
    on arrays laid out as (z, y, x) and flattened in C order.
    """
    yz = sps.kron(y, z)
    return sps.kron(x, yz)

# Set up the mesh

In [None]:
# Polynomial order of each spectral element and element counts per direction.
N  = 8
Ex = 4
Ey = 4
Ez = 4

# When True the node at +1 is identified with the node at -1 along every
# axis (the duplicated end node is dropped from the global numbering).
periodic = True

# GLL nodes per element along each axis.
nx = ny = nz = N + 1

# Global DOFs per axis: a line of E elements has N*E + 1 nodes, one fewer
# once its two end nodes are identified.
nx_dofs = N*Ex + (0 if periodic else 1)
ny_dofs = N*Ey + (0 if periodic else 1)
nz_dofs = N*Ez + (0 if periodic else 1)
n_elem  = Ex*Ey*Ez

# Reference-element SEM data (xgll, wgll and Dh are read by later cells).
semh = SEMhat(N)

In [None]:
def f(X):
    """Exact solution u(x, y, z) = sin(pi x) sin(pi y) sin(pi z).

    ``X`` is a 2-D array whose first three columns hold x, y and z; one value
    is returned per row.
    """
    x, y, z = (X[:, k] for k in range(3))
    return np.sin(np.pi*x)*np.sin(np.pi*y)*np.sin(np.pi*z)

def f2(X):
    """Source term -Laplace(u) = 3 pi^2 u for the ``u`` returned by ``f``."""
    return f(X)*3*np.pi**2

# def f(X):
#     x = X[:,0]
#     y = X[:,1]
    
#     return np.cos(np.pi*x)*np.cos(np.pi*y)

# def f2(X):
#     x = X[:,0]
#     y = X[:,1]
    
#     return np.cos(np.pi*x)*np.cos(np.pi*y)*2*(np.pi)**2

# Warping of the reference cube (chi, eta, zeta) into physical space:
#   x = chi + sx sin(pi chi) sin(pi eta)
#   y = eta + sy sin(pi chi) sin(pi eta)
#   z = zeta
# dm<c><k> is the partial derivative of coordinate <c> with respect to the
# k-th reference coordinate (1 = chi, 2 = eta, 3 = zeta).
sx = 0.1
sy = 0.1

def mx(chi, eta, zeta):
    """Physical x-coordinate."""
    sin_chi = np.sin(np.pi*chi)
    sin_eta = np.sin(np.pi*eta)
    return chi + sx*sin_chi*sin_eta

def dmx1(chi, eta, zeta):
    """d x / d chi."""
    cos_chi = np.cos(np.pi*chi)
    sin_eta = np.sin(np.pi*eta)
    return np.ones_like(chi) + sx*np.pi*cos_chi*sin_eta

def dmx2(chi, eta, zeta):
    """d x / d eta."""
    sin_chi = np.sin(np.pi*chi)
    cos_eta = np.cos(np.pi*eta)
    return np.zeros_like(chi) + sx*np.pi*sin_chi*cos_eta

def dmx3(chi, eta, zeta):
    """d x / d zeta (x does not depend on zeta)."""
    return np.zeros_like(chi)

def my(chi, eta, zeta):
    """Physical y-coordinate."""
    sin_chi = np.sin(np.pi*chi)
    sin_eta = np.sin(np.pi*eta)
    return eta + sy*sin_chi*sin_eta

def dmy1(chi, eta, zeta):
    """d y / d chi."""
    cos_chi = np.cos(np.pi*chi)
    sin_eta = np.sin(np.pi*eta)
    return np.zeros_like(eta) + sy*np.pi*cos_chi*sin_eta

def dmy2(chi, eta, zeta):
    """d y / d eta."""
    sin_chi = np.sin(np.pi*chi)
    cos_eta = np.cos(np.pi*eta)
    return np.ones_like(eta) + sy*np.pi*sin_chi*cos_eta

def dmy3(chi, eta, zeta):
    """d y / d zeta (y does not depend on zeta)."""
    return np.zeros_like(eta)

def mz(chi, eta, zeta):
    """Physical z-coordinate (identity in zeta)."""
    return zeta

def dmz1(chi, eta, zeta):
    """d z / d chi."""
    return np.zeros_like(zeta)

def dmz2(chi, eta, zeta):
    """d z / d eta."""
    return np.zeros_like(zeta)

def dmz3(chi, eta, zeta):
    """d z / d zeta."""
    return np.ones_like(zeta)

In [None]:
topo  = Interval()


def _interval_mesh(n_el):
    """Uniform 1-D mesh of ``n_el`` elements on [-1, 1].

    Returns ``(vertices, etv, q, jacb_det0)``:
      vertices  -- the n_el+1 element end points,
      etv       -- (n_el, 2) element-to-vertex table,
      q         -- physical GLL node coordinates from ``topo.ref_to_phys``;
                   when ``periodic`` it is flattened and the last node (which
                   coincides with the first) is dropped,
      jacb_det0 -- entry 0 of ``topo.calc_jacb`` for the mesh (presumably the
                   element Jacobian, the same for every element of this
                   uniform mesh -- TODO confirm against pyfem).
    """
    verts = np.linspace(-1, 1, n_el+1)
    # Builtin ``int``: the ``np.int`` alias was removed in NumPy 1.24.
    etv = np.zeros((n_el, 2), dtype=int)
    etv[:, 0] = np.arange(n_el)
    etv[:, 1] = np.arange(n_el) + 1
    q = topo.ref_to_phys(verts[etv], semh.xgll)
    jdet0 = topo.calc_jacb(verts[etv])[0]
    if periodic:
        q = q.ravel()[:-1]
    return verts, etv, q, jdet0


vertices, etvx, xq, jacb_det0x = _interval_mesh(Ex)
vertices, etvy, yq, jacb_det0y = _interval_mesh(Ey)
vertices, etvz, zq, jacb_det0z = _interval_mesh(Ez)

In [None]:
# Build restriction operator R, which keeps only the unknowns that are solved for.
# Periodic: every DOF is free, so each 1-D factor is the identity.
# Otherwise: the first and last DOF along each axis are Dirichlet nodes.
def _restriction_1d(n_dofs):
    """1-D restriction: identity if ``periodic``, else the (n-2) x n interior selector."""
    if periodic:
        return sps.eye(n_dofs)
    # Row i has its single 1 in column i+1, i.e. it selects DOFs 1 .. n-2.
    return sps.dia_matrix((np.ones(n_dofs), 1), shape=(n_dofs-2, n_dofs))

R0x = _restriction_1d(nx_dofs)
R0y = _restriction_1d(ny_dofs)
R0z = _restriction_1d(nz_dofs)

R = kron3(R0z, R0y, R0x)

rngx = np.arange(nx_dofs)
rngy = np.arange(ny_dofs)
rngz = np.arange(nz_dofs)

if periodic:
    # No Dirichlet nodes. Builtin ``int`` replaces ``np.int`` (removed in NumPy 1.24).
    boundary_dofs = np.array([], dtype=int)
else:
    # Global DOFs are numbered in C order over a (nz_dofs, ny_dofs, nx_dofs)
    # grid (x fastest), matching R = kron3(R0z, R0y, R0x). A DOF lies on the
    # boundary of the cube when any of its three axis indices is first or last;
    # this is exactly the set of DOFs that R drops. (The previous enumeration
    # only covered the four edges of a 2-D grid.)
    iz, iy, ix = np.meshgrid(rngz, rngy, rngx, indexing='ij')
    on_boundary = ((ix == 0) | (ix == nx_dofs-1) |
                   (iy == 0) | (iy == ny_dofs-1) |
                   (iz == 0) | (iz == nz_dofs-1))
    boundary_dofs = np.flatnonzero(on_boundary)

boundary_dofs.sort()

In [None]:
# Distinct reference node coordinates along each axis.
dx, dy, dz = np.unique(xq), np.unique(yq), np.unique(zq)

# Tensor-product grid of reference coordinates laid out as
# (z, y, x, component); flattening it in C order numbers DOFs with x fastest.
XYZ = np.empty((nz_dofs, ny_dofs, nx_dofs, 3))
for comp, axis_coords in enumerate((dx[na, na, :], dy[na, :, na], dz[:, na, na])):
    XYZ[:, :, :, comp] = axis_coords

dof_ref = XYZ.reshape((-1, 3))

# Push each reference point through the mesh mapping (mx, my, mz).
chi, eta, zeta = dof_ref[:, 0], dof_ref[:, 1], dof_ref[:, 2]
dof_phys = np.zeros_like(dof_ref)
for comp, mapping in enumerate((mx, my, mz)):
    dof_phys[:, comp] = mapping(chi, eta, zeta)

In [None]:
def _axis_weights(n_el):
    """Quadrature weight attached to each global DOF along one axis.

    Every element contributes its GLL weights without the last node; for a
    non-periodic line the closing end node reuses the first weight (GLL
    weights are symmetric, so w_N == w_0).
    """
    w = np.tile(semh.wgll[:-1], n_el)
    if not periodic:
        w = np.append(w, w[0])
    return w

ax = _axis_weights(Ex)
ay = _axis_weights(Ey)
az = _axis_weights(Ez)

# Tensor-product weights on the (z, y, x) DOF grid, scaled by the 1-D mesh
# Jacobian factors, then flattened in the same C order as dof_ref.
wvals  = np.multiply.outer(np.multiply.outer(az, ay), ax)
wvals *= jacb_det0x*jacb_det0y*jacb_det0z
wvals  = wvals.ravel()

In [None]:
# Make Q, the 0/1 gather matrix from global DOFs to the stacked element-local
# nodes: Q.dot(u) lists every element's nodal values, and Q.T sums
# element-local contributions back onto the global DOFs.
def _gather_1d(n_el):
    """1-D element-to-DOF table and its gather matrix for ``n_el`` elements.

    Element e owns global DOFs e*N .. e*N + N (neighbours share an end node);
    when ``periodic`` the last node of the last element wraps to DOF 0.
    """
    etd_1d = np.arange(n_el*(N+1)).reshape((n_el, -1))
    etd_1d -= np.arange(n_el)[:, na]
    if periodic:
        etd_1d[-1, -1] = etd_1d[0, 0]
    cols_1d = etd_1d.ravel()
    rows_1d = np.arange(len(cols_1d))
    vals_1d = np.ones(len(cols_1d))
    return etd_1d, sps.coo_matrix((vals_1d, (rows_1d, cols_1d))).tocsr()

etdx, Q0x = _gather_1d(Ex)
etdy, Q0y = _gather_1d(Ey)
etdz, Q0z = _gather_1d(Ez)

# Global DOF id of every element-local node, arranged on the full tensor grid
# of local nodes with shape (Ez*nz, Ey*ny, Ex*nx).
Q_tensor = kron3(Q0z, Q0y, Q0x)
local_dof_ids = Q_tensor.dot(np.arange(nx_dofs*ny_dofs*nz_dofs))
local_dof_ids = local_dof_ids.reshape((nz*Ez, ny*Ey, nx*Ex))

# Element-to-DOF table: one row per element (x index fastest, then y, then z),
# holding its nz*ny*nx local nodes in C order. Each axis is split into
# (element, local node), the element axes are moved to the front, and the
# result is flattened; this replaces an explicit triple loop. The ids are
# exact integers stored as floats. Builtin ``int`` replaces ``np.int``, which
# was removed in NumPy 1.24.
etd = (local_dof_ids
       .reshape((Ez, nz, Ey, ny, Ex, nx))
       .transpose((0, 2, 4, 1, 3, 5))
       .reshape((n_elem, nz*ny*nx))
       .astype(int))

cols = etd.ravel()
rows = np.arange(len(cols))
vals = np.ones(len(cols))
Q = sps.coo_matrix((vals, (rows, cols))).tocsr()

In [None]:
# Jacobian of the mesh mapping at every DOF (reference coordinates dof_ref):
# Jacb[n, c, k] = d x_c / d xi_k with x = (mx, my, mz), xi = (chi, eta, zeta).
dr = dof_ref
_chi, _eta, _zeta = dr[:, 0], dr[:, 1], dr[:, 2]
_dmap = ((dmx1, dmx2, dmx3),
         (dmy1, dmy2, dmy3),
         (dmz1, dmz2, dmz3))

Jacb = np.zeros((dr.shape[0], 3, 3))
for c_idx, deriv_row in enumerate(_dmap):
    for k_idx, dm in enumerate(deriv_row):
        Jacb[:, c_idx, k_idx] = dm(_chi, _eta, _zeta)

jacb_det = np.linalg.det(Jacb).ravel()

Jacb_inv = np.linalg.inv(Jacb)

# Poisson problem: assemble the operators and solve

In [None]:
# Build the geometric factors G_ij at every DOF:
#   G0[n] = w_n * det(J_n) * J_n^{-1} J_n^{-T}
# i.e. the quadrature-weighted metric used in the stiffness matrix.
assert len(Jacb) == len(wvals), "Jacobian and weights must use the same DOF ordering"

# Batched J^{-1} J^{-T} (replaces a per-DOF Python loop).
G0 = np.einsum('nik,njk->nij', Jacb_inv, Jacb_inv)
G0 *= (wvals*jacb_det)[:, na, na]

nn = nx*ny*nz
s  = (nn, nn)

def _elem_diag(p, q):
    """Per-element diagonal matrices carrying G0[:, p, q] at the element's nodes."""
    return [sps.dia_matrix((G0[etd[e], p, q], 0), shape=s) for e in range(n_elem)]

G11, G12, G13 = _elem_diag(0, 0), _elem_diag(0, 1), _elem_diag(0, 2)
G21, G22, G23 = _elem_diag(1, 0), _elem_diag(1, 1), _elem_diag(1, 2)
G31, G32, G33 = _elem_diag(2, 0), _elem_diag(2, 1), _elem_diag(2, 2)

In [None]:
# Build Poisson stiffness matrix A.
# Reference derivative along x, y, z on one element's (z, y, x) node grid,
# divided by the 1-D mesh Jacobian factor of that axis.
D1 = kron3(sps.eye(nz), sps.eye(ny), semh.Dh)/jacb_det0x
D2 = kron3(sps.eye(nz), semh.Dh,     sps.eye(nx))/jacb_det0y
D3 = kron3(semh.Dh,     sps.eye(ny), sps.eye(nx))/jacb_det0z

# Element stiffness: sum over p, q of D_p^T G_pq D_q.
A0a = [D1.T.dot(G11[i].dot(D1) + G12[i].dot(D2) + G13[i].dot(D3)) +
       D2.T.dot(G21[i].dot(D1) + G22[i].dot(D2) + G23[i].dot(D3)) +
       D3.T.dot(G31[i].dot(D1) + G32[i].dot(D2) + G33[i].dot(D3))
       for i in range(n_elem)]
A0 = sps.block_diag(A0a).tocsr()
A0 = Q.T.dot(A0.dot(Q))        # assemble element blocks onto global DOFs
A  = R.dot(A0.dot(R.T))        # restrict to the unknowns

# Build the diagonal mass matrix from nodal quadrature.
# Note: despite the "l", Bl is the assembled global matrix, not an element one.
nd = nx_dofs*ny_dofs*nz_dofs
b = Q.T.dot((wvals*jacb_det)[etd.ravel()])
Bl = sps.dia_matrix((b, 0), shape=(nd, nd))
Binv_data = (1.0/Bl.data).ravel()
Binv_data = R.dot(Binv_data)

# Smallest singular value of A; only computed for small systems. The singular
# vectors are not needed, so skip them.
if nd <= 1e3:
    print(np.min(np.linalg.svd(A.toarray(), compute_uv=False)))

In [None]:
# Right-hand side: the source f2 sampled at the physical DOFs and weighted by
# the diagonal mass matrix, with known boundary values lifted out.
u_exact = f(dof_phys)
fh  = f2(dof_phys)
fl = fh
rhs = Bl.dot(fl)
radj = np.zeros(nx_dofs*ny_dofs*nz_dofs)
radj[boundary_dofs] = u_exact[boundary_dofs]
rhs = R.dot(rhs-A0.dot(radj))

if periodic:
    # Constants are in the null space of the periodic operator; remove the
    # mean so the singular system is consistent.
    rhs -= np.mean(rhs)

# Ruge-Stuben AMG hierarchy, used with CG acceleration.
ml = ruge_stuben_solver(A)
residuals = []
sol = R.T.dot(ml.solve(rhs, tol=1e-14,
                       maxiter=500, residuals=residuals,
                       accel='cg'))
sol[boundary_dofs] = u_exact[boundary_dofs]

# A periodic solution is only determined up to an additive constant, so
# remove the nodal mean from both the computed and the exact solution
# before comparing them.
u_ref = u_exact
if periodic:
    sol -= np.mean(sol)
    u_ref = u_exact - np.mean(u_exact)

# Iteration count and final residual, then relative max-norm error.
print("{} {}".format(len(residuals), residuals[-1]))
print("")
print(norm(u_ref-sol)/norm(u_ref))

In [None]:
# fig = plt.figure()
# ax = fig.gca(projection='3d')
# s = (ny_dofs,nx_dofs)
# X, Y = dof_phys[:,0], dof_phys[:,1]
# X = X.reshape(s)
# Y = Y.reshape(s)
# ax.plot_wireframe(X, Y, f(dof_phys).reshape(s))
# ax.plot_wireframe(X, Y, sol.reshape(s),
#                   color='g')
# plt.savefig("sol.pdf")

In [None]:
# x-y projection of the physical DOF locations, showing the in-plane warping
# applied by mx/my (the mapping leaves z unchanged).
fig, ax_mesh = plt.subplots()
ax_mesh.scatter(dof_phys[:, 0], dof_phys[:, 1])
ax_mesh.set_xlabel("x")
ax_mesh.set_ylabel("y")
ax_mesh.set_title("Physical DOF locations (x-y projection)")
ax_mesh.set_aspect("equal")
plt.show()