In [None]:
from mpi4py import MPI
import numpy as np
import matplotlib.pyplot as plt

# MPI setup: every launched process joins the world communicator and records
# its own index (rank) and the total number of processes (size).
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

# Parameters
alpha = 1.11e-4            # diffusion coefficient in the heat-equation update
nx, ny_total = 50, 50      # grid points in x and in the global y direction
nt = 10000                 # number of time levels (the loop takes nt - 1 steps)
Lx, Ly = 1.0, 1.0          # domain extent in x and y
dx = Lx / (nx - 1)
dy = Ly / (ny_total - 1)
dt = 0.1

# Stability check for the explicit forward-time / centred-space update.
# Raise instead of `assert`: asserts are removed under `python -O`, which
# would let an unstable configuration run silently.
stability = alpha * dt * (1 / dx**2 + 1 / dy**2)
if stability >= 0.5:
    raise ValueError(
        "Stability condition violated! "
        f"alpha*dt*(1/dx^2 + 1/dy^2) = {stability:.4g} must be < 0.5"
    )

# Divide the y-dimension among processes. Every rank gets ny_total // size
# rows, and the first (ny_total % size) ranks take one extra row. The bands
# are therefore contiguous, ordered by rank, and cover [0, ny_total) exactly once.
rows_per_rank, remainder = divmod(ny_total, size)
local_ny = rows_per_rank + int(rank < remainder)
# Each earlier rank contributes rows_per_rank rows, plus one extra for each
# earlier rank that received a remainder row.
start_y = rank * rows_per_rank + min(rank, remainder)
end_y = start_y + local_ny

# Allocate the local temperature field: local_ny owned rows plus one ghost row
# below (index 0) and one above (index -1) for the halo exchange.
# Only the current and next time levels are kept, not all nt snapshots. Each
# row of a C-ordered 2D array is contiguous, so rows can be handed straight to
# MPI as communication buffers. (A (rows, nx, nt) layout gives rows with
# stride nt, which MPI buffer calls cannot use.)
T_local = np.full((local_ny + 2, nx), 300.0)

# Apply boundary conditions. These values are never updated by the stencil
# below, so they stay fixed for the whole run.
if start_y == 0:
    T_local[1, :] = 20     # global row 0 (y = 0)
if end_y == ny_total:
    T_local[-2, :] = 50    # global row ny_total - 1 (y = Ly)
T_local[:, 0] = 400
T_local[:, -1] = 300
# Second buffer for the next time level. It starts as a copy so both buffers
# carry the same fixed boundary values.
T_next = T_local.copy()

# Owned rows the stencil updates (half-open range [row_lo, row_hi)).
# A rank that owns a physical y-boundary row must keep it at its fixed value
# rather than update it from the ghost row beyond the domain, which no
# neighbour ever fills.
row_lo = 2 if start_y == 0 else 1
row_hi = local_ny if end_y == ny_total else local_ny + 1
row_hi = max(row_lo, row_hi)  # empty range for ranks with no interior rows

# Time-stepping loop: nt - 1 explicit forward-time / centred-space steps.
for k in range(nt - 1):
    # Halo exchange. Send the lowest owned row down and receive the lower
    # neighbour's top row into ghost row 0; then the mirror image upward.
    if rank > 0:
        comm.Sendrecv(T_local[1], dest=rank - 1, sendtag=0,
                      recvbuf=T_local[0], source=rank - 1, recvtag=1)
    if rank < size - 1:
        comm.Sendrecv(T_local[-2], dest=rank + 1, sendtag=1,
                      recvbuf=T_local[-1], source=rank + 1, recvtag=0)

    # Vectorised 5-point stencil over the updatable rows and interior columns.
    centre = T_local[row_lo:row_hi, 1:-1]
    T_next[row_lo:row_hi, 1:-1] = centre + alpha * dt * (
        (T_local[row_lo + 1:row_hi + 1, 1:-1] - 2 * centre
         + T_local[row_lo - 1:row_hi - 1, 1:-1]) / dy**2
        + (T_local[row_lo:row_hi, 2:] - 2 * centre
           + T_local[row_lo:row_hi, :-2]) / dx**2
    )
    # Swap buffers. Rows and columns outside the updated region are either
    # fixed boundary values (identical in both buffers) or ghost rows, which
    # are refreshed by the next halo exchange before they are read.
    T_local, T_next = T_next, T_local

# Owned rows at the last time level, without ghost rows. Collect them on rank 0.
final_local = T_local[1:-1, :]
gathered = comm.gather(final_local, root=0)

if rank == 0:
    # Only the gather root (rank 0) holds the list of per-rank pieces. Ranks
    # own increasing, contiguous bands of y, so joining the pieces in rank
    # order rebuilds the full grid with global row 0 first.
    T_final = np.concatenate(gathered, axis=0)
    plt.imshow(
        T_final,
        cmap='hot',
        origin='lower',          # draw global row 0 (y = 0) at the bottom
        extent=[0, Lx, 0, Ly],   # label the axes in physical coordinates
        vmin=20,                 # fixed colour range so runs are comparable
        vmax=400,
    )
    plt.colorbar(label='Temperature')
    plt.title('Final Temperature Distribution (MPI)')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.tight_layout()
    plt.show()
