In [64]:
############################################
# Solves du/dt = alpha * Laplacian(u)      #
# with zero Dirichlet boundary conditions  #
############################################
import os

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display_html

import radiant as rad


# Problem parameters: spatial interval [a, b] (per dimension), final time t,
# and the diffusion coefficient alpha multiplying the Laplacian.
a, b = 0., 1.
t = 0.1
alpha = 1.


def f(t, *x):
    """Right-hand side for the interior operator (presumably the source term): zero.

    ``t`` is accepted to match the ``(t, *x)`` calling convention but unused.
    The result has the shape and dtype of the first coordinate ``x[0]``.
    """
    first_coord = x[0]
    return np.zeros_like(first_coord)


def g(t, *x):
    """Boundary data for the boundary operator: homogeneous (all zeros).

    Ignores ``t``; returns zeros shaped and typed like the first coordinate ``x[0]``.
    """
    reference = x[0]
    return np.zeros_like(reference)


def u0(*x):
    """Initial condition u(0, x) = prod_i sin(pi * x_i).

    Each positional argument is one spatial coordinate (scalar or array). The
    coordinates are stacked along a new leading axis, and the sine factors are
    multiplied across that axis.
    """
    stacked = np.asarray(x)
    return np.sin(np.pi * stacked).prod(axis=0)


def exact(t, *x):
    """Exact solution of u_t = alpha * Laplacian(u) for the initial condition u0.

    With u(0, x) = prod_i sin(pi * x_i), each factor satisfies
    d^2/dx_i^2 sin(pi x_i) = -pi**2 sin(pi x_i). The d-fold product is
    therefore a Laplacian eigenfunction with eigenvalue -d * pi**2, and decays
    as exp(-d * pi**2 * alpha * t). This matches the zero Dirichlet data only
    on [0, 1]^d (a = 0, b = 1).

    Parameters
    ----------
    t : float or array_like
        Time(s) at which to evaluate.
    *x : array_like
        One argument per spatial coordinate; the dimension d is ``len(x)``.

    Returns
    -------
    ndarray
        Solution values, broadcast over ``t`` and the coordinate arrays.
    Reads the module-level diffusion coefficient ``alpha``.
    """
    dim = len(x)
    # The decay rate scales with the spatial dimension; a d-independent
    # exp(-pi**2 * alpha * t) is only correct for d == 1.
    decay = np.exp(-dim * np.pi ** 2 * alpha * t)
    return decay * np.prod(np.sin(np.pi * np.asarray(x)), axis=0)

In [65]:
# Parameters
d = 1               # number of spatial dimensions
k = 2               # smoothness parameter passed to rad.Wendland (presumably Wendland's k) -- confirm
levels = 5          # number of levels in the multilevel hierarchy
start_tdelta = 0.1  # coarsest-level kernel scale in time (presumably a support radius)
start_xdelta = 0.5  # coarsest-level kernel scale in each spatial direction
start_tN = 15       # intended number of time nodes on the coarsest level
start_xN = 10       # intended number of nodes per spatial direction on the coarsest level

# Computed Parameters
xranges = ((a, b),) * d
ranges = ((0., t), *xranges)              # space-time box: time first, then each spatial axis
centre_thinning = 2 ** np.arange(levels)  # level j divides the centre spacing by 2**j
delta_thinning = 2 ** np.arange(levels)   # level j divides the kernel scale by 2**j
# Centre spacing per level as (time step, spatial step, ..., spatial step).
# Each step is the interval length divided by thin * (N - 1). The time step must
# be parenthesised exactly like the spatial one; otherwise it is multiplied by
# (start_tN - 1) and exceeds the whole time interval on the coarsest level.
incs = [
    (t / (thin * (start_tN - 1)),) + ((b - a) / (thin * (start_xN - 1)),) * d
    for thin in centre_thinning
]
xcs = [rad.gridinc(ranges, inc, flat=True, unitary=False) for inc in incs]
deltas = [np.array([start_tdelta] + [start_xdelta] * d) / thin for thin in delta_thinning]

def L(func):
    """Interior operator: map a basis function to -alpha * (sum of second spatial derivatives).

    ``func(*x, m=(i, i))`` is assumed to evaluate ``func`` differentiated twice
    along coordinate ``i``. Coordinate 0 is time (``ranges`` lists time first),
    so only coordinates 1..d are summed.
    NOTE(review): no du/dt term appears here; presumably
    rad.solve.SpaceTimeCollocation supplies it -- confirm.
    Reads the module-level ``alpha`` and ``d`` at call time.
    """
    def apply(*x):
        second_derivs = [func(*x, m=(i, i)) for i in range(1, d + 1)]
        return -alpha * np.sum(second_derivs, axis=0)
    return apply


def B(func):
    """Boundary operator: evaluate the basis function itself (Dirichlet condition)."""
    def apply(*x):
        return func(*x)
    return apply


# Tolerance for deciding that a coordinate lies on a or b.
bndry_eps = 1e-10


def Lidx(cs):
    """Boolean mask of points strictly inside (a, b), by more than bndry_eps, in every coordinate.

    ``cs`` is iterated one coordinate array at a time.
    NOTE(review): if ``cs`` includes the time row, time values are compared
    against the spatial bounds a and b (so t == a counts as boundary); confirm
    what the solver passes.
    """
    return np.all(
        [np.logical_and(np.abs(a - c) > bndry_eps, np.abs(b - c) > bndry_eps) for c in cs],
        axis=0,
    )


def Bidx(cs):
    """Boolean mask of points within bndry_eps of a or b in at least one coordinate.

    For finite coordinates this is the complement of Lidx.
    """
    return np.any(
        [np.logical_or(np.abs(a - c) <= bndry_eps, np.abs(b - c) <= bndry_eps) for c in cs],
        axis=0,
    )

# Solve for the approximate solution: build one Wendland basis per level from
# that level's kernel scale and centres (space-time dimension d + 1), then
# hand the hierarchy to the multilevel space-time collocation solver.
phis = [
    rad.Wendland(d + 1, k, level_delta, level_centres)
    for level_delta, level_centres in zip(deltas, xcs)
]
solver = rad.solve.MultilevelSolver(
    phis, rad.solve.SpaceTimeCollocation, L, Lidx, B, Bidx
)
approx = solver.solve(f, g, u0)

In [66]:
# Animate the approximate and exact solutions on [a, b] for times in [0, t].
fig, ax = plt.subplots(figsize=(12, 5))
ax.margins(x=0.)
ax.set(xlabel="x", ylabel="u(t, x)", title="Heat equation: multilevel space-time collocation")

xs = np.linspace(a, b, 100)
exact_line, = ax.plot(xs, exact(0, xs), label="Exact")
# approx is evaluated on (time, space) arrays from np.meshgrid; presumably it
# takes time first, matching `ranges` -- confirm.
approx_line, = ax.plot(xs, approx(*np.meshgrid(0, xs)), label="Approximate")
ax.legend()

ts = np.linspace(0, t, 100)


def func(i):
    """Redraw both curves at time ts[i]; return the changed artists for blitting."""
    tx = np.meshgrid(ts[i], xs)
    approx_line.set_ydata(approx(*tx))
    exact_line.set_ydata(exact(ts[i], xs))
    return approx_line, exact_line


anim = animation.FuncAnimation(
    fig,
    func,
    frames=len(ts),
    interval=200,
    blit=True,
    repeat=True
)

display_html(HTML(anim.to_jshtml()))
# Saving fails if the output directory is missing, so create it first.
# The GIF is written at 30 fps, faster than the 200 ms preview interval.
gif_path = os.path.join('figures', 'ml_st_heat.gif')
os.makedirs(os.path.dirname(gif_path), exist_ok=True)
anim.save(gif_path, writer=animation.FFMpegWriter(fps=30))
plt.close(fig)