# Solving 1D Heat Diffusion Using TensorFlow

In [82]:
# Import libraries for simulation
import tensorflow as tf
import numpy as np
# Import for plotting
import plotly.graph_objects as go

A quick helper function to display our solution. We use [Plotly](https://plot.ly/) to convert our arrays into plot traces.

In [142]:
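def DisplayArray(y, name="test"):
    # Plot the values against their grid index and return a Scatter trace
    # (a trace, not a full figure; it is added to a go.Figure later)
    x = list(range(len(y)))
    fig = go.Scatter(x=x, y=y, name=name)
    return fig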

## Convenience Functions

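These functions apply the discrete Laplacian (the three-point stencil `[1, -2, 1]`) to the current solution. The `'SAME'` padding fills the out-of-range neighbours with zeros, which corresponds to homogeneous Dirichlet boundary conditions.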

In [89]:
@tf.function
def make_kernel(a):
    # Convert the stencil coefficients to a float32 constant tensor
    a = np.asarray(a)
    return tf.constant(a, dtype=tf.float32)

@tf.function
def simple_conv(x, k):
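    # conv1d expects input as [batch, width, channels] and filters as [width, in_channels, out_channels]
    data   = tf.reshape(x, [1, int(x.shape[0]), 1], name='data')
    kernel = tf.reshape(k, [int(k.shape[0]), 1, 1], name='kernel')
    # 'SAME' zero-pads the ends so the output has the same length as x
    res = tf.squeeze(tf.nn.conv1d(data, kernel, stride=1, padding='SAME'))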
    return res

@tf.function
def laplace(x):
    k = make_kernel([1., -2., 1.])
    return simple_conv(x, k)
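To make the stencil concrete, here is a minimal NumPy sketch of the same operation. The name `laplace_np` is hypothetical and is not part of the notebook. It assumes that zero padding on both ends matches what `padding='SAME'` does for a length-3 kernel with stride 1.

```python
import numpy as np

def laplace_np(u):
    """Second difference u[i-1] - 2*u[i] + u[i+1], with zeros assumed outside the array.

    Hypothetical NumPy stand-in for the notebook's laplace(); illustration only.
    """
    padded = np.pad(u, 1, mode="constant", constant_values=0.0)
    return padded[:-2] - 2.0 * padded[1:-1] + padded[2:]

# Small grid so the result is easy to inspect by hand
x = np.linspace(0.0, 1.0, 5)
u = x * (1.0 - x)
lap = laplace_np(u)
print(lap)
```

On interior points the result equals $u''(x)\,\Delta x^2 = -2\,\Delta x^2$ up to rounding, since the stencil is exact for a quadratic. The two end entries differ because their missing neighbour is replaced by the padded zero. Note that the stencil is not divided by $\Delta x^2$; that factor is absorbed into the coefficient applied in the time loop.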

The number of grid points is specified here.

In [136]:
N = 500

We use the initial condition $u(x, 0) = x(1 - x)$. It vanishes at $x = 0$ and $x = 1$, so it is consistent with the homogeneous Dirichlet boundary conditions.

In [146]:
# Initial Conditions
init = lambda x : x*(1. - x)
x = np.linspace(0, 1, N)  # spatial grid on [0, 1]
U_init = init(x).astype(np.float32)

# Create variables for simulation state
U  = tf.Variable(U_init)
initFig = DisplayArray(U_init, name="Initial: x(1-x)");

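Here $\alpha$ is the diffusion coefficient and $\epsilon = \alpha \, \Delta t / \Delta x^2$ is the diffusion number, which plays the role of the [CFL](https://en.wikipedia.org/wiki/Courant%E2%80%93Friedrichs%E2%80%93Lewy_condition) number for this problem. For the explicit time stepping used here to remain stable, $\epsilon$ must not exceed $1/2$.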

In [147]:
alpha = 0.1
dt = 1.e-5
dx = 1./N
eps = alpha * dt / (dx * dx)
print("CFL: ", eps)

CFL:  0.25000000000000006
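The effect of this number on stability can be seen in a small experiment. The sketch below is plain NumPy rather than TensorFlow, and the helper names `step` and `max_after` are hypothetical. It applies the explicit update $u_i \leftarrow u_i + \epsilon\,(u_{i-1} - 2u_i + u_{i+1})$ on a coarse grid (50 points and 200 steps, chosen only to keep the run short) for one value of $\epsilon$ below $1/2$ and one above it.

```python
import numpy as np

def step(u, eps):
    """One explicit update, with zeros assumed outside the array (hypothetical helper)."""
    padded = np.pad(u, 1)  # default mode pads with zeros
    return u + eps * (padded[:-2] - 2.0 * padded[1:-1] + padded[2:])

def max_after(eps, n_steps=200, n=50):
    """Largest absolute value of the solution after n_steps updates."""
    x = np.linspace(0.0, 1.0, n)
    u = x * (1.0 - x)
    for _ in range(n_steps):
        u = step(u, eps)
    return np.abs(u).max()

stable = max_after(0.25)    # below the 1/2 limit
unstable = max_after(0.60)  # above the 1/2 limit
print("eps = 0.25 -> max |u| =", stable)
print("eps = 0.60 -> max |u| =", unstable)
```

With $\epsilon = 0.25$ the solution stays bounded by its initial maximum, while with $\epsilon = 0.6$ the highest-frequency grid mode is amplified at every step and the values grow without bound.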


In [148]:
from timeit import default_timer as timer
start = timer()
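# 1e5 steps of dt = 1e-5 advance the solution to t = 1
for i in range(int(1e5)):
    # Explicit (forward Euler) update of the diffusion equation
    U = U + eps * laplace(U)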
end = timer()
print("Elapsed time: ", (end-start))
newFig = DisplayArray(U.numpy(), name="After 1s");

Elapsed time:  30.717940739999904


In [145]:
# Plot initial and final
fig = go.Figure()
fig.add_trace(initFig)
fig.add_trace(newFig)
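One way to sanity-check a result like this is to compare it against the Fourier series solution of the heat equation on $[0, 1]$ with zero boundary values, which for $u(x, 0) = x(1 - x)$ is $u(x, t) = \sum_{k\ \mathrm{odd}} \frac{8}{k^3 \pi^3} \sin(k \pi x)\, e^{-\alpha k^2 \pi^2 t}$. The sketch below is a NumPy illustration, not the notebook's TensorFlow code. It assumes a coarser grid of 51 points with $\Delta x = 1/(n - 1)$, holds the two end points at zero instead of relying on padding, and uses a hypothetical helper `exact`.

```python
import numpy as np

alpha, eps, n = 0.1, 0.25, 51         # same alpha and eps as above; coarser grid (assumption)
x = np.linspace(0.0, 1.0, n)
dx = x[1] - x[0]
dt = eps * dx * dx / alpha            # time step implied by eps
n_steps = int(round(1.0 / dt))        # integrate to t = 1

# Explicit scheme on interior points; u[0] and u[-1] stay at zero
u = x * (1.0 - x)
for _ in range(n_steps):
    u[1:-1] += eps * (u[:-2] - 2.0 * u[1:-1] + u[2:])

def exact(x, t, alpha, n_terms=50):
    """Truncated Fourier sine series for u(x, 0) = x(1 - x) (hypothetical helper)."""
    total = np.zeros_like(x)
    for k in range(1, 2 * n_terms, 2):  # odd k only; even terms vanish
        total += (8.0 / (k * np.pi) ** 3
                  * np.sin(k * np.pi * x)
                  * np.exp(-alpha * (k * np.pi) ** 2 * t))
    return total

err = np.abs(u - exact(x, n_steps * dt, alpha)).max()
print("max abs error at t = 1:", err)
```

A small error here suggests the update rule and the value of $\epsilon$ are consistent. The same comparison could be applied to `U.numpy()` from the simulation above, keeping in mind that the notebook sets `dx = 1./N` and lets the zero padding act as the boundary.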