In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
from numba import jit, prange
from PIL import Image
from tqdm import tqdm

In [None]:
@jit(nopython=True, parallel=True)
def step_heat(u, alpha, dt, dx, dy, nx, ny):
    """Advance an image field by one explicit finite-difference heat step.

    Parameters
    ----------
    u : ndarray
        Field indexed as ``u[row, col, channel]``. Only channels 0..2 of the
        interior cells (rows ``1..ny-2``, columns ``1..nx-2``) are updated.
    alpha : float
        Diffusivity multiplying both second differences.
    dt : float
        Time step.
    dx, dy : float
        Grid spacing along the column axis (``nx`` cells) and the row axis
        (``ny`` cells) respectively.
    nx, ny : int
        Number of columns and rows to iterate over.

    Returns
    -------
    ndarray
        New array of the same shape as ``u``, clipped to ``[0, 1]``.

    Notes
    -----
    The output starts from ``np.zeros_like(u)`` and the loops skip the outer
    ring of cells, so every border cell of the result is 0 regardless of the
    input border values.
    """
    u_next = np.zeros_like(u)
    coeff_x = alpha * dt / dx**2
    coeff_y = alpha * dt / dy**2

    # Rows are independent of each other, so the outer loop is parallelised.
    for i in prange(1, ny - 1):
        for j in range(1, nx - 1):
            for k in range(3):  # RGB channels
                centre = u[i, j, k]
                # i walks the row (y) axis and j the column (x) axis, so the
                # dy-based coefficient belongs to the i-difference and the
                # dx-based coefficient to the j-difference.
                u_next[i, j, k] = (
                    centre
                    + coeff_y * (u[i + 1, j, k] + u[i - 1, j, k] - 2 * centre)
                    + coeff_x * (u[i, j + 1, k] + u[i, j - 1, k] - 2 * centre)
                )
    return np.clip(u_next, 0, 1)

def solve_heat_equation(image_path, alpha, dt, dx, dy, T, save_interval=10, dpi=None):
    """Diffuse an image with the heat equation and save snapshots as PNG files.

    The image at ``image_path`` is loaded as RGB, scaled to ``[0, 1]`` and
    advanced with ``step_heat`` for ``round(T / dt)`` steps. Every
    ``save_interval``-th step (starting with step 0) the current field is
    written to ``images/frame-XXXX.png``, where ``XXXX`` is the zero-padded
    running snapshot index.

    Parameters
    ----------
    image_path : str
        Path of the input image.
    alpha : float
        Diffusivity forwarded to ``step_heat``.
    dt : float
        Time step.
    dx, dy : float
        Grid spacings forwarded to ``step_heat``.
    T : float
        Total simulated time; the number of steps is ``round(T / dt)``.
    save_interval : int, optional
        A snapshot is saved whenever ``n % save_interval == 0``.
    dpi : float, optional
        Resolution used to size and save each figure. When omitted, a
        module-level ``dpi`` variable is used if one exists, otherwise 72.

    Returns
    -------
    None
        The function is used for its side effect of writing PNG files.
    """
    if dpi is None:
        # Earlier versions read ``dpi`` straight from the notebook namespace;
        # keep honouring that value so existing cells behave the same.
        dpi = globals().get("dpi", 72)

    # savefig does not create missing directories.
    os.makedirs("images", exist_ok=True)

    # Close the file handle deterministically once the pixels are copied.
    with Image.open(image_path) as image:
        u = np.array(image.convert("RGB"), dtype=np.float64) / 255.0
    ny, nx, _ = u.shape

    # round() before int(): a quotient such as 0.3 / 0.1 evaluates to
    # 2.9999999999999996 and plain truncation would drop the last step.
    frames = int(round(T / dt))

    for n in tqdm(range(frames)):
        u = step_heat(u, alpha, dt, dx, dy, nx, ny)

        if n % save_interval == 0:
            # One figure pixel per image pixel, with the axes filling the canvas.
            fig = plt.figure(figsize=(nx / dpi, ny / dpi), dpi=dpi)
            ax = fig.add_axes([0, 0, 1, 1])
            ax.imshow(u, aspect='auto')
            ax.axis("off")
            # Number files by snapshot index: truncating n * dt to an int
            # repeats (and overwrites) names whenever save_interval * dt < 1.
            fig.savefig(f"images/frame-{n // save_interval:04d}.png", dpi=dpi, bbox_inches='tight', pad_inches=0, transparent=True)
            plt.close(fig)

# Simulation configuration
image_path = "thebeatles.jpg"  # input picture to diffuse
dpi = 72        # figure resolution; solve_heat_equation reads this module-level name
alpha = 0.02    # thermal diffusivity
dx = 0.1        # grid spacing
dy = dx         # same spacing on the second axis
dt = 0.1        # time step
T = 2000        # total simulated time

# Run the solver; snapshots are written as PNG files under images/
solve_heat_equation(
    image_path=image_path,
    alpha=alpha,
    dt=dt,
    dx=dx,
    dy=dy,
    T=T,
)
