In [15]:
from ipywidgets import interactive, interact
import ipywidgets as widgets
from ipywidgets import HBox, VBox, IntSlider, Play, jslink
import numpy as np

import time

from bokeh.io import push_notebook, show, output_notebook
from bokeh.plotting import figure
output_notebook()

In [31]:
class Simulation:
    """Lazily evaluated sequence of 1-D temperature profiles for one solver.

    ``sim[0]`` is the initial profile and ``sim[n]`` the profile after ``n``
    applications of ``solver``.  Frames are generated on demand by
    ``__getitem__`` and cached in ``self.frames``.

    Parameters
    ----------
    initialT : sequence of float
        Initial temperature at each grid node.  A private float copy is
        stored, so the caller's sequence is never aliased.
    frameCnt : int
        Unused; kept for backward compatibility (frames are unbounded).
    dt, dx : float
        Time step and grid spacing.
    u, k : float
        Advection velocity and diffusion coefficient.
    solver : callable
        ``solver(simulation, dt, dx, u, k)`` returning the next profile.  It
        reads earlier frames through negative indices (``simulation[-1]``,
        ``simulation[-2]``).
    color : str
        Unused; kept for backward compatibility.

    Attributes
    ----------
    x : numpy.ndarray
        Node coordinates ``0, dx, 2*dx, ...``.
    frames : list
        Frames computed so far; ``frames[0]`` is the initial profile.
    makeStep : callable
        Computes the next frame with ``solver`` and appends it to ``frames``.
    """

    def __init__(self, initialT, frameCnt, dt, dx, u, k, solver, color):
        initial = np.array(initialT, dtype=float)  # private copy, float dtype
        # Nodes are exactly dx apart, matching the dx the solvers use.  (The
        # former np.linspace(0, dx * N, N) spaced them dx * N / (N - 1) apart.)
        self.x = np.arange(len(initial)) * dx

        s = u * dt / dx
        r = k * dt / (dx * dx)
        # Warn when s + 2r is outside (0, 1).  For explicitUpstream the stencil
        # weight on T[i] is 1 - s - 2r, which becomes negative once s + 2r > 1.
        if (s + 2 * r) <= 0 or (s + 2 * r) >= 1:
            print(f"warning: s + 2r = {s + 2 * r:.4g} is outside (0, 1) "
                  f"(s = {s:.4g}, r = {r:.4g}); explicit schemes may be unstable")

        self.frames = [initial]

        def makeStep():
            self.frames.append(solver(self, dt, dx, u, k))

        self.makeStep = makeStep

    def __getitem__(self, frameId):
        """Return frame ``frameId``, computing missing frames on demand.

        Non-negative indices extend the simulation as far as needed.  Negative
        indices address already computed frames from the end; indices reaching
        past the first frame are clamped to frame 0, which lets two-level
        solvers ask for ``simulation[-2]`` on the very first step.
        """
        if frameId < -len(self.frames):
            return self.frames[0]
        if frameId < 0:
            return self.frames[frameId]
        while len(self.frames) <= frameId:
            self.makeStep()
        return self.frames[frameId]

In [36]:
from scipy.sparse import linalg, csr_matrix

def explicitUpstream(simulation, dt, dx, u, k):
    """Advance one step with the explicit (forward-Euler) upstream scheme.

    Advection uses the backward difference ``T[i] - T[i-1]``; diffusion uses
    the central second difference.  The two end nodes are copied unchanged.

    Parameters
    ----------
    simulation : indexable
        ``simulation[-1]`` is the latest profile (a Simulation or a list of
        profiles both work).
    dt, dx : float
        Time step and grid spacing.
    u, k : float
        Advection velocity and diffusion coefficient.

    Returns
    -------
    numpy.ndarray
        New float profile; the input profile is not modified.
    """
    prev = np.asarray(simulation[-1], dtype=float)
    s = u * dt / dx
    r = k * dt / (dx * dx)

    cur = prev.copy()
    # Vectorised over interior nodes 1..N-2 with the same arithmetic as the
    # per-node formula; empty slices make N < 3 a no-op.
    left, centre, right = prev[:-2], prev[1:-1], prev[2:]
    cur[1:-1] = centre - s * (centre - left) + r * (right + left - 2 * centre)
    return cur

def explicitDownstream(simulation, dt, dx, u, k):
    """Advance one step with the explicit (forward-Euler) downstream scheme.

    Advection uses the forward difference ``T[i+1] - T[i]``; diffusion uses
    the central second difference.  The two end nodes are copied unchanged.

    Parameters
    ----------
    simulation : indexable
        ``simulation[-1]`` is the latest profile (a Simulation or a list of
        profiles both work).
    dt, dx : float
        Time step and grid spacing.
    u, k : float
        Advection velocity and diffusion coefficient.

    Returns
    -------
    numpy.ndarray
        New float profile; the input profile is not modified.
    """
    prev = np.asarray(simulation[-1], dtype=float)
    s = u * dt / dx
    r = k * dt / (dx * dx)

    cur = prev.copy()
    # Vectorised over interior nodes 1..N-2 with the same arithmetic as the
    # per-node formula; empty slices make N < 3 a no-op.
    left, centre, right = prev[:-2], prev[1:-1], prev[2:]
    cur[1:-1] = centre - s * (right - centre) + r * (right + left - 2 * centre)
    return cur

def implicitUpstream(simulation, dt, dx, u, k):
    """Advance one step with the implicit (backward-Euler) upstream scheme.

    Solves, for every interior node i,
        cur[i] = prev[i] - s * (cur[i] - cur[i-1])
                 + r * (cur[i+1] + cur[i-1] - 2 * cur[i])
    i.e.
        prev[i] = cur[i-1] * (-s - r) + cur[i] * (s + 2r + 1) + cur[i+1] * (-r),
    with the two end nodes held at their previous values.

    The tridiagonal system is assembled directly in sparse form: building a
    dense N x N matrix first would need O(N^2) memory (80 GB for N = 1e5).

    Parameters
    ----------
    simulation : indexable
        ``simulation[-1]`` is the latest profile.
    dt, dx : float
        Time step and grid spacing.
    u, k : float
        Advection velocity and diffusion coefficient.

    Returns
    -------
    numpy.ndarray
        New float profile.
    """
    prev = np.asarray(simulation[-1], dtype=float)
    s = u * dt / dx
    r = k * dt / (dx * dx)

    N = len(prev)
    if N < 3:
        # Only boundary rows exist: the system is the identity.
        return prev.copy()

    interior = np.arange(1, N - 1)
    m = N - 2
    rows = np.concatenate(([0, N - 1], interior, interior, interior))
    cols = np.concatenate(([0, N - 1], interior - 1, interior, interior + 1))
    vals = np.concatenate((
        [1.0, 1.0],                  # fixed boundary nodes
        np.full(m, -s - r),          # coefficient of cur[i-1]
        np.full(m, s + 2 * r + 1),   # coefficient of cur[i]
        np.full(m, -r),              # coefficient of cur[i+1]
    ))
    M = csr_matrix((vals, (rows, cols)), shape=(N, N))
    return linalg.spsolve(M, prev)

def implicitDownstream(simulation, dt, dx, u, k):
    """Advance one step with the implicit (backward-Euler) downstream scheme.

    Solves, for every interior node i,
        cur[i] = prev[i] - s * (cur[i+1] - cur[i])
                 + r * (cur[i+1] + cur[i-1] - 2 * cur[i])
    i.e.
        prev[i] = cur[i-1] * (-r) + cur[i] * (-s + 2r + 1) + cur[i+1] * (s - r),
    with the two end nodes held at their previous values.

    The tridiagonal system is assembled directly in sparse form: building a
    dense N x N matrix first would need O(N^2) memory (80 GB for N = 1e5).

    Parameters
    ----------
    simulation : indexable
        ``simulation[-1]`` is the latest profile.
    dt, dx : float
        Time step and grid spacing.
    u, k : float
        Advection velocity and diffusion coefficient.

    Returns
    -------
    numpy.ndarray
        New float profile.
    """
    prev = np.asarray(simulation[-1], dtype=float)
    s = u * dt / dx
    r = k * dt / (dx * dx)

    N = len(prev)
    if N < 3:
        # Only boundary rows exist: the system is the identity.
        return prev.copy()

    interior = np.arange(1, N - 1)
    m = N - 2
    rows = np.concatenate(([0, N - 1], interior, interior, interior))
    cols = np.concatenate(([0, N - 1], interior - 1, interior, interior + 1))
    vals = np.concatenate((
        [1.0, 1.0],                  # fixed boundary nodes
        np.full(m, -r),              # coefficient of cur[i-1]
        np.full(m, -s + 2 * r + 1),  # coefficient of cur[i]
        np.full(m, s - r),           # coefficient of cur[i+1]
    ))
    M = csr_matrix((vals, (rows, cols)), shape=(N, N))
    return linalg.spsolve(M, prev)

def checkers(simulation, dt, dx, u, k):
    """Advance one step with the three-level leapfrog scheme.

    Centered in time and space:
        cur[i] = prevprev[i] - s * (prev[i+1] - prev[i-1])
                 + 2r * (prev[i+1] + prev[i-1] - 2 * prev[i])
    The two end nodes are copied from the latest profile.

    NOTE(review): with ``Simulation``, ``simulation[-2]`` is clamped to frame 0
    while only one frame exists, so the first step uses prevprev == prev
    (effectively a forward step of size 2*dt).

    Parameters
    ----------
    simulation : indexable
        ``simulation[-1]`` is the latest profile, ``simulation[-2]`` the one
        before it.
    dt, dx : float
        Time step and grid spacing.
    u, k : float
        Advection velocity and diffusion coefficient.

    Returns
    -------
    numpy.ndarray
        New float profile; the input profiles are not modified.
    """
    prev = np.asarray(simulation[-1], dtype=float)
    prevprev = np.asarray(simulation[-2], dtype=float)
    s = u * dt / dx
    r = k * dt / (dx * dx)

    cur = prev.copy()
    # Vectorised over interior nodes 1..N-2 with the same arithmetic as the
    # per-node formula; empty slices make N < 3 a no-op.
    left, centre, right = prev[:-2], prev[1:-1], prev[2:]
    cur[1:-1] = prevprev[1:-1] - s * (right - left) + r * (right + left - 2 * centre) * 2
    return cur

In [37]:
# Shared Bokeh figure; each solver draws one line on it.
# `width`/`height` are used instead of `plot_width`/`plot_height`, which were
# deprecated in Bokeh 2.4 and removed in Bokeh 3.0.
p = figure(title="Heat Equation Solvers", height=300, width=600, y_range=(0, 1),
           x_axis_label="x", y_axis_label="T",
           background_fill_color='#efefef')

def createLine(color):
    """Attach an empty, semi-transparent line of the given colour to figure `p`.

    Returns the line renderer; its ``data_source.data`` is replaced whenever a
    frame is drawn.
    """
    renderer = p.line(x=[], y=[], color=color, line_width=1.5, alpha=0.5)
    return renderer

# One line renderer per solver, each in a fixed colour.  The solver functions
# themselves are the keys, so a solver's line can be looked up directly.
solverToLine = {
    solver: createLine(color)
    for solver, color in (
        (explicitUpstream, "red"),
        (explicitDownstream, "orange"),
        (implicitUpstream, "green"),
        (implicitDownstream, "blue"),
        (checkers, "violet"),
    )
}

# notebook_handle=True lets later push_notebook() calls update this figure in place.
show(p, notebook_handle=True);

def resetSimultaion(totalTime, dt, dx, u, k, solvers):
    """Rebuild the selected simulations and display a player for their frames.

    Used as the ``interact`` callback, so it runs again whenever a parameter
    widget changes.  The grid has about ``1 / dx`` nodes.  The initial profile
    is 1 at the first node and at the middle node, and 0 elsewhere.

    Parameters
    ----------
    totalTime : float
        Simulated time covered by the animation (``totalTime / dt`` frames).
    dt, dx : float
        Time step and grid spacing passed to every solver.
    u, k : float
        Advection velocity and diffusion coefficient.
    solvers : collection of str
        ``__name__`` of each solver to run; lines of the other solvers are
        cleared.
    """
    from IPython.display import display

    # round() rather than int() so e.g. 0.3 / 0.1 = 2.999... gives 3; at
    # least one node / frame so the widgets below always get a valid range.
    xCnt = max(1, int(round(1 / dx)))
    frameCnt = max(1, int(round(totalTime / dt)))

    initialT = np.zeros(xCnt)
    initialT[0] = 1.0
    initialT[xCnt // 2] = 1.0

    solverToSimulation = {
        solver: (
            Simulation(
                initialT,
                frameCnt=frameCnt,
                dt=dt,
                dx=dx,
                u=u,
                k=k,
                solver=solver,
                color='red'
            )
            if solver.__name__ in solvers else None
        )
        for solver in solverToLine
    }

    def showFrame(frameId):
        """Draw frame ``frameId`` of every active simulation; clear the others."""
        for solver, line in solverToLine.items():
            simulation = solverToSimulation[solver]
            if simulation is not None:
                line.data_source.data = {
                    'x': simulation.x,
                    'y': simulation[frameId]
                }
            else:
                line.data_source.data = {
                    'x': [],
                    'y': []
                }
        # One push per frame (not one per line) sends all line updates together.
        push_notebook()

    showFrame(0)

    play = widgets.Play(
        value=0,
        min=0,
        max=frameCnt - 1,
        # Advance about frameCnt / 60 frames per tick, capped at 20.  Never 0,
        # which would stall the player for animations shorter than 60 frames.
        step=max(1, min(20, frameCnt // 60))
    )
    frameSlider = widgets.IntSlider(
        min=0,
        max=frameCnt - 1,
        description='frame'
    )
    widgets.jslink((play, 'value'), (frameSlider, 'value'))
    # Redraw on every slider change.  A direct observer, unlike an undisplayed
    # `interactive`, does not hide callback errors in an invisible Output widget.
    frameSlider.observe(lambda change: showFrame(change['new']), names='value')

    display(widgets.HBox([play, frameSlider]))


# Parameter widgets for resetSimultaion.  With continuous_update=False the
# simulation is rebuilt only when a slider is released or a text box committed.
totalTime = widgets.FloatSlider(
    value=1,
    min=0.1,
    max=100,
    continuous_update=False,
    description='Animation Time'
)

dt = widgets.FloatLogSlider(
    value=0.01,
    base=10,
    min=-5,  # min exponent of base (dt >= 1e-5)
    max=0,  # max exponent of base (dt <= 1)
    step=0.5,  # exponent step
    continuous_update=False,
    description='dt'
)

dx = widgets.FloatLogSlider(
    value=0.001,
    base=10,
    min=-5,  # min exponent of base (dx >= 1e-5)
    max=-2,  # max exponent of base (dx <= 1e-2)
    step=0.5,  # exponent step
    continuous_update=False,
    description='dx'
)

u = widgets.BoundedFloatText(
    value=0.001,
    min=-1,
    max=1,
    continuous_update=False,
    description='u'
)

k = widgets.BoundedFloatText(
    value=1e-6,
    min=0,
    max=1,
    continuous_update=False,
    description='k'
)

solvers = widgets.SelectMultiple(
    # Derived from solverToLine so the options always match the plotted solvers.
    options=[solver.__name__ for solver in solverToLine],
    value=['checkers', 'implicitDownstream'],
    description='Methods:',
    disabled=False
)

# interact displays its own widget box; the trailing ';' suppresses the repr of
# the function it returns.
interact(resetSimultaion, totalTime=totalTime, dt=dt, dx=dx, u=u, k=k, solvers=solvers);

interactive(children=(FloatSlider(value=1.0, continuous_update=False, description='Animation Time', min=0.1), …

<function __main__.resetSimultaion(totalTime, dt, dx, u, k, solvers)>