In [None]:
import numpy as np
from scipy.fftpack import fft, ifft
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Sample Code

In [None]:
N=24
x1 = (2.*np.pi/N)*np.arange(N)
f = np.sin(x1)**2.*np.cos(x1) + np.exp(2.*np.sin(x1+1))

k = np.concatenate(( np.arange(0,N/2) ,
                    np.array([0]) , # Because hat{f}'(k) at k = N/2 is zero.
                    np.arange(-N/2+1,0) ))

# Approximates the derivative using the pseudospectral method
f_hat = fft(f)
fp_hat = ((1j*k)*f_hat)
fp = np.real(ifft(fp_hat))

# Calculates the derivative analytically
x2 = np.linspace(0,2*np.pi,200)
derivative = (2.*np.sin(x2)*np.cos(x2)**2. -
              np.sin(x2)**3. +
              2*np.cos(x2+1)*np.exp(2*np.sin(x2+1))
             )

plt.plot(x2,derivative,'-k',linewidth=2.)
plt.plot(x1,fp,'*b')
plt.show()

# Problem 1

In [None]:
d2 = (-2*np.exp(2*np.sin(x2+1))*np.sin(x2+1) +
      2*np.cos(x2)**3 -
      2*np.cos(x2)*np.sin(x2)**2 +
      4*np.exp(2*np.sin(x2+1))*np.cos(x2+1)**2 -
      5*np.sin(x2)**2*np.cos(x2))

fpp_hat = ((1j*k)**2*f_hat)
fpp = np.real(ifft(fpp_hat))

plt.plot(x1,fpp,'*b')
plt.plot(x2,d2,'-k')
plt.show()

In [None]:
plt.plot(x2,.5*d2 - derivative,'-k')
plt.plot(x1,.5*fpp - fp,'*b')
plt.show()

# Problem 2 - setup

In [None]:
def initialize_all(a,b,y0,h):
    n = int((b-a)/h+1)
    X = np.linspace(a,b,n)
    if isinstance(y0,np.ndarray):
        Y = np.empty((n, y0.size))
    else:
        Y = np.empty(n)
    Y[0] = y0
    return X,Y,h,int(n)

def RK4(f,X,Y,h,n):
    for i in range(n-1):
        K1 = f(X[i],Y[i])
        K2 = f(X[i]+h/2,Y[i]+h/2*K1)
        K3 = f(X[i]+h/2,Y[i]+h/2*K2)
        K4 = f(X[i+1],Y[i]+h*K3)
        Y[i+1] = Y[i] + h/6*(K1+2*K2+2*K3+K4)
    return Y

# Problem 2.

In [None]:
t_steps = 150
x_steps = 100
N = x_steps
a = 0
b = 8
h = float(b-a)/t_steps
k = np.hstack((np.arange(0, N//2), np.array([0]), np.arange(-N//2+1,0,1)))

x_domain = np.arange(x_steps)*2*np.pi/x_steps
t_domain = np.linspace(a,b,t_steps)
c = .2+np.sin(x_domain-1)**2
y0 = np.exp(-100*(x_domain-1)**2)

f = lambda u,y:np.real(-c*ifft(1j*k*fft(y)))

X, Y, h, n = initialize_all(a, b, y0, h)
sol = RK4(f,X,Y,h,n)[:-1]

X,Y = np.meshgrid(x_domain, t_domain)
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.plot_wireframe(X,Y,sol)
ax.set_zlim(0,3)
plt.show()