# Heat equation example in 1D

This notebook considers one of the simplest examples of solving PDEs with neural networks. The problem is set up so that it admits an exact solution, which can be compared to a physics-informed neural network (PINN) approximation. The heat transfer example is based on the **one-dimensional heat equation**
$$
\frac{\partial u}{\partial t} - \alpha \frac{\partial^2 u}{\partial x^2} = 0, \quad
t \in [0, T], \quad x \in [0, L].
$$
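Here $\alpha > 0$ is the thermal diffusivity. Dirichlet boundary conditions $u(t, 0) = u(t, L) = 0$ for $t \in [0, T]$ and the initial condition $u(0, x) = u_0(x) = \sin \left( \frac{n \pi x}{L} \right)$ for $x \in [0, L]$ and some fixed $n \in \{1, 2, \ldots\}$ are imposed. Writing $u(t, x) = X(x)\,\Theta(t)$, the boundary conditions admit the spatial modes $X(x) = \sin \left( \frac{m \pi x}{L} \right)$, $m = 1, 2, \ldots$; since $u_0$ is itself one of these modes, **separation of variables** yields a single factorized term: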
$$
u(t, x) = \sin \left( \frac{n \pi}{L} x \right) \exp \left( -\frac{n^2 \pi^2}{L^2} \alpha t \right).
$$
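To check the closed form independently of `SimpleHeatConduction1D` (whose source is not shown here), the sketch below evaluates it directly and approximates $u_t - \alpha u_{xx}$ with central finite differences at interior points. `exact_solution` and `pde_residual_fd` are hypothetical helpers, not part of `utils`, and the parameter values simply mirror the problem setup.

```python
import numpy as np


def exact_solution(t, x, alpha, length, n):
    """Evaluate u(t, x) = sin(n pi x / L) * exp(-n^2 pi^2 alpha t / L^2)."""
    k = n * np.pi / length  # spatial wavenumber n*pi/L
    return np.sin(k * x) * np.exp(-(k ** 2) * alpha * t)


def pde_residual_fd(t, x, alpha, length, n, h=1e-4):
    """Approximate u_t - alpha * u_xx with central finite differences of step h."""
    def u(tt, xx):
        return exact_solution(tt, xx, alpha, length, n)

    u_t = (u(t + h, x) - u(t - h, x)) / (2 * h)
    u_xx = (u(t, x + h) - 2 * u(t, x) + u(t, x - h)) / h ** 2
    return u_t - alpha * u_xx


alpha, length, n = 0.07, 1.0, 3
t = np.linspace(0.1, 1.0, 10).reshape(-1, 1)  # interior times, shape (10, 1)
x = np.linspace(0.1, 0.9, 9).reshape(1, -1)   # interior positions, shape (1, 9)

residual = pde_residual_fd(t, x, alpha, length, n)
print('max |u_t - alpha u_xx| :', np.abs(residual).max())
print('max |u(t, 0)|, |u(t, L)|:',
      np.abs(exact_solution(t, 0.0, alpha, length, n)).max(),
      np.abs(exact_solution(t, length, alpha, length, n)).max())
```

The residual differs from zero only by the $O(h^2)$ finite-difference error, and the boundary values vanish up to floating-point rounding of $\sin(n\pi)$.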

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.append('..')  # allow imports from the parent directory (e.g. `utils`)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

from utils import SimpleHeatConduction1D

## Problem setup

In [None]:
alpha = 0.07
length = 1.0
n = 3

simple_problem = SimpleHeatConduction1D(
    alpha=alpha,
    length=length,
    n=n
)
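The setup fixes $\alpha$, $L$ and $n$. For the PINN comparison mentioned at the top, a network is typically trained to drive the PDE residual, the boundary mismatch and the initial-condition mismatch to zero. The sketch below is a minimal illustration under stated assumptions, not the notebook's own model: `net` is a hypothetical fully connected network taking $(t, x)$ and returning $u$, `pde_residual` and `pinn_loss` are illustrative names, the architecture, uniform random sampling and equal loss weights are assumptions, and $T = 1$ is assumed to match the time grid of the exact solution. Derivatives come from `torch.autograd.grad` with `create_graph=True`, so $u_{xx}$ can be taken through $u_x$ and the loss can be backpropagated.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
alpha, length, n, t_max = 0.07, 1.0, 3, 1.0  # t_max plays the role of T (assumed 1)

# Hypothetical network mapping (t, x) -> u; the architecture is an assumption.
net = nn.Sequential(
    nn.Linear(2, 32), nn.Tanh(),
    nn.Linear(32, 32), nn.Tanh(),
    nn.Linear(32, 1),
)


def pde_residual(net, t, x, alpha):
    """Return u_t - alpha * u_xx of the network output at points (t, x), each of shape (N, 1)."""
    t = t.clone().requires_grad_(True)
    x = x.clone().requires_grad_(True)
    u = net(torch.cat([t, x], dim=1))
    # create_graph=True keeps the graph so u_x can be differentiated again.
    u_t, u_x = torch.autograd.grad(
        u, (t, x), grad_outputs=torch.ones_like(u), create_graph=True
    )
    u_xx = torch.autograd.grad(
        u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True
    )[0]
    return u_t - alpha * u_xx


def pinn_loss(net, num_points=256):
    """Mean-squared PDE residual plus boundary and initial mismatch, equally weighted."""
    # Interior collocation points sampled uniformly in [0, T] x [0, L].
    t_int = t_max * torch.rand(num_points, 1)
    x_int = length * torch.rand(num_points, 1)
    loss_pde = pde_residual(net, t_int, x_int, alpha).pow(2).mean()

    # Dirichlet boundary points: x is either 0 or L, at random times.
    t_bc = t_max * torch.rand(num_points, 1)
    x_bc = length * torch.randint(0, 2, (num_points, 1)).float()
    loss_bc = net(torch.cat([t_bc, x_bc], dim=1)).pow(2).mean()

    # Initial condition u(0, x) = sin(n pi x / L).
    x_ic = length * torch.rand(num_points, 1)
    u_ic = net(torch.cat([torch.zeros_like(x_ic), x_ic], dim=1))
    loss_ic = (u_ic - torch.sin(n * math.pi * x_ic / length)).pow(2).mean()

    return loss_pde + loss_bc + loss_ic


optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
for step in range(200):
    optimizer.zero_grad()
    loss = pinn_loss(net)
    loss.backward()
    optimizer.step()
print(f'loss after {step + 1} steps: {loss.item():.4e}')
```

Equal weighting of the three terms is the simplest choice; in practice the weights and the number of training steps usually need tuning before the PINN output is compared with the exact solution.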

## Exact solution

In [None]:
t_values = np.linspace(0, 1, 1001)  # time grid on [0, T] with T = 1
x_values = np.linspace(0, length, 1001)  # space grid on [0, L]

t = t_values.reshape(-1, 1) # (time, 1)
x = x_values.reshape(1, -1) # (1, space)

u_values = simple_problem(t=t, x=x) # (time, space)
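Reshaping `t` into a column and `x` into a row lets NumPy broadcasting evaluate $u$ on the full (time, space) grid in a single call; presumably this is how `simple_problem` returns the `(time, space)` array noted in the comment, although its source is not shown here. The small sketch below, with the notebook's values $\alpha = 0.07$, $n = 3$, $L = 1$ written in directly and hypothetical variable names, shows the shapes involved and the equivalent explicit `np.meshgrid` construction.

```python
import numpy as np

t_col = np.linspace(0, 1, 5).reshape(-1, 1)  # shape (5, 1): one row per time
x_row = np.linspace(0, 1, 4).reshape(1, -1)  # shape (1, 4): one column per position

# Broadcasting expands both operands to a common (5, 4) grid.
u_bcast = np.sin(3 * np.pi * x_row) * np.exp(-0.07 * 9 * np.pi ** 2 * t_col)
print(u_bcast.shape)  # (5, 4)

# The same grid built explicitly; indexing='ij' keeps time along axis 0.
T_mesh, X_mesh = np.meshgrid(t_col.ravel(), x_row.ravel(), indexing='ij')
u_mesh = np.sin(3 * np.pi * X_mesh) * np.exp(-0.07 * 9 * np.pi ** 2 * T_mesh)
print(np.allclose(u_bcast, u_mesh))  # True
```

Broadcasting avoids materializing the two coordinate arrays, which matters little at this size but keeps the call site short.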

In [None]:
fig, ax = plt.subplots(figsize=(6, 4))
for tidx in (0, 50, 100, 150, 200, -1):
    label = 't={:.2f}'.format(t_values[tidx])
    ax.plot(x_values, u_values[tidx,:], alpha=0.8, label=label)
ax.set(xlabel='x', ylabel='u(t, x)')
ax.set_xlim((x_values.min(), x_values.max()))
ax.legend()
ax.grid(visible=True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()
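The snapshots shrink without changing shape because each one is the initial sine profile scaled by $\exp\left(-\frac{n^2 \pi^2}{L^2} \alpha t\right)$. A short sketch, recomputed from the closed form rather than through `SimpleHeatConduction1D` so that it runs on its own, compares the peak of each plotted snapshot with that factor; `decay_rate` and `tau` are names introduced here for illustration.

```python
import numpy as np

alpha, length, n = 0.07, 1.0, 3
decay_rate = alpha * (n * np.pi / length) ** 2  # n^2 pi^2 alpha / L^2
tau = 1.0 / decay_rate                          # e-folding time of the mode

t_grid = np.linspace(0, 1, 1001)
x_grid = np.linspace(0, length, 1001)
u_grid = np.sin(n * np.pi * x_grid[None, :] / length) * np.exp(-decay_rate * t_grid[:, None])

for tidx in (0, 50, 100, 150, 200, -1):
    predicted = np.exp(-decay_rate * t_grid[tidx])  # closed-form amplitude
    observed = np.abs(u_grid[tidx]).max()           # peak |u| on the x grid
    print(f't={t_grid[tidx]:.2f}  exp(-rate*t)={predicted:.4f}  max|u|={observed:.4f}')
print(f'tau = {tau:.3f}')
```

With $\alpha = 0.07$, $n = 3$ and $L = 1$, the e-folding time $\tau = L^2 / (n^2 \pi^2 \alpha)$ is about 0.16, so by $t = 1$ the profile has decayed to a fraction of a percent of its initial amplitude.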

In [None]:
fig, ax = plt.subplots(figsize=(6, 4))
img = ax.imshow(
    u_values.T,
    cmap='PiYG',
    aspect='auto',
    interpolation='bilinear',
    vmin=-1, vmax=1,
    origin='lower',
    extent=(t_values.min(), t_values.max(), x_values.min(), x_values.max())
)
ax.set(xlabel='t', ylabel='x')
fig.colorbar(img, ax=ax)
fig.tight_layout()