In [None]:
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

def pde_solver(h,k,T,c,f,g,period=False, a=0, b=0, l=0 ,r=0):
    if period:
        return periodic_pde_sovlver(h,k,T,c,f,g)
    else:
        if a == b == 1:
            return dirichlet_pde_solver(h,k,T,c,f,g,l,r)
        elif a == b == 0:
            return neumann_pde_solver(h,k,T,c,f,g,l,r)
        elif a == 1 and b == 0:
            return mixed10_pde_solver(h,k,T,c,f,g,l,r)
        elif a == 0 and b == 1:
            return mixed01_pde_solver(h,k,T,c,f,g,l,r)
        else:
            return robin_pde_solver(h,k,T,c,f,g,a,b,l,r)
        
def dirichlet_pde_solver(h,k,T,c,f,g,l,r):
    s = 1.0*c*k/h
    if(c>1):
        print "inestable"
        return
    xs = np.arange(-1,1+h,h)
    ts = np.arange(0,T+k,k)
    Nx_plus_1 = xs.shape[0]
    Nt_plus_1 = ts.shape[0]
    W = np.zeros((Nt_plus_1,Nx_plus_1))
    vec1 = np.array([g(xi) for xi in xs[1:-1]])
    vec2 = np.zeros(Nx_plus_1-2)
    vec2[0] = l(ts[0])
    vec2[-1] = r(ts[0])
    cons_vec = k*vec1+(s**2)/2*vec2
    w0 = np.array([f(xi) for xi in xs[1:-1]])
    W[0,1:-1] = w0
    A = np.zeros((Nx_plus_1-2,Nx_plus_1-2))
    for i in range(Nx_plus_1-2):
        if i == 0:
            A[0,0] = 2-2*s**2
            A[0,1] = s**2
        elif i == Nx_plus_1-3:
            A[i,i-1] = s**2
            A[i,i] = 2-2*s**2
        else:
            A[i,i-1] = s**2
            A[i,i] = 2-2*s**2
            A[i,i+1] = s**2
    W[:,0] = np.array([l(t) for t in ts])
    W[:,-1] = np.array([r(t) for t in ts])
    W[1,1:-1] = (1/2)*(np.dot(A,w0)) + cons_vec
    for i in range(2,Nt_plus_1):
        vec = np.zeros(Nx_plus_1-2)
        vec[0] = l(ts[i-1])
        vec[-1] = r(ts[i-1])
        wk_1 = W[i-1,1:-1]
        wk_2 = W[i-2,1:-1]
        W[i,1:-1] = np.dot(A,wk_1) - wk_2 + s**2*vec
    return W,xs,ts

def f(x):
    return -(x+1)*(x-1)

def g(x):
    return -(x+1)*(x-1)

def l(t):
    return t

def r(t):
    return t

def plot_discrete_PDE(xs,ts,W,z_lims):
    fig = plt.figure()
    ax = fig.gca(projection='3d')
    X, T = np.meshgrid(xs, ts)
    surf = ax.plot_surface(X, T, W, rstride=1, cstride=1, linewidth=0, antialiased=False)
    ax.set_zlim(z_lims[0],z_lims[1])
    plt.show()
                
def neumann_pde_solver(h,k,T,c,f,g,l,r):
    s = 1.0*c*k/h
    if(c>1):
        print "inestable"
        return
    xs = np.arange(-1,1+h,h)
    ts = np.arange(0,T+k,k)
    Nx_plus_1 = xs.shape[0]
    Nt_plus_1 = ts.shape[0]
    W = np.zeros((Nt_plus_1,Nx_plus_1))
    vec1 = np.array([g(xi) for xi in xs[1:-1]])
    vec2 = np.zeros(Nx_plus_1-2)
    vec2[0] = l(ts[0])
    vec2[-1] = r(ts[0])
    cons_vec = k*vec1+(s**2)/2*vec2
    w0 = np.array([f(xi) for xi in xs[1:-1]])
    W[0,1:-1] = w0
    A = np.zeros((Nx_plus_1-2,Nx_plus_1-2))
    for i in range(Nx_plus_1-2):
        if i == 0:
            A[0,0] = 2-2*s**2
            A[0,1] = s**2
        elif i == Nx_plus_1-3:
            A[i,i-1] = s**2
            A[i,i] = 2-2*s**2
        else:
            A[i,i-1] = s**2
            A[i,i] = 2-2*s**2
            A[i,i+1] = s**2
    W[1,1:-1] = (1/2)*(np.dot(A,w0)) + cons_vec
    print W
    x1 = xs[1]
    x2 = xs[2]
    c1 = (-1-x1)*(-1-x2)/(-2-x1-x2)
    c2 = (-1-x2)/((x1+1)*(x1-x2))
    c3 = (-1-x1)/((x2+1)*(x2-x1))
    xNx_1 = xs[-2]
    xNx_2 = xs[-3]
    c4 = (1-xNx_2)*(1-xNx_1)/(2-xNx_2-xNx_1)
    c5 = (1-xNx_2)/((xNx_1-xNx_2)*(xNx_1-1))
    c6 = (1-xNx_1)/((xNx_2-xNx_1)*(xNx_2-1))
    for i in range(2,Nt_plus_1):
        vec = np.zeros(Nx_plus_1-2)
        vec[0] = l(ts[i-1])
        vec[-1] = r(ts[i-1])
        wk_1 = W[i-1,1:-1]
        wk_2 = W[i-2,1:-1]
        W[i,1:-1] = np.dot(A,wk_1) - wk_2 + s**2*vec
        ltk = l(ts[i])
        W[i,0]  = c1*(ltk - W[i,1]*c2 - W[i,2]*c3)
        rtk = r(ts[i])
        W[i,-1] = c4*(rtk - W[i,-2]*c5 - W[i,-3]*c6) 
        print W
    return W,xs,ts
k = 0.2
h = 0.2
T = 3
c = 1
W,xs,ts = pde_solver(h,k,T,c,f,g,a=0, b=0, l=l ,r=r)
plot_discrete_PDE(xs,ts,W,(-6,6))

[[ 0.     0.36   0.64   0.84   0.96   1.     0.96   0.84   0.64   0.36   0.   ]
 [ 0.     0.072  0.128  0.168  0.192  0.2    0.192  0.168  0.128  0.072  0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.     0.     0.     0.     0.     0.   ]
 [ 0.     0.     0.     0.     0.     0.

In [79]:
np.array([1,2,3])[-3]

1