# Assignment 4

## 2. Heat Equation - Crank-Nicolson Method

Solve by **Crank-Nicolson Method**, the equation

\[
\frac{\partial u}{\partial t} = \frac{\partial^2 u}{\partial x^2}, \quad 0 < x < 1,\ t > 0
\]

with the boundary and initial conditions:

\[
u(x, 0) = 100x(1 - x), \quad u(0, t) = 0, \quad u(1, t) = 0
\]

for **one time step**. Take

\( h = \Delta x = 0.25 \)

**Task**:
- Write Python code using **Crank-Nicolson Method** to solve the heat equation.
- Use the function format as shown in the instructions.
- Compute the interior values \(u_1, u_2, u_3\) (at \(x = 0.25, 0.5, 0.75\)) after the first time step.
- Display the results in tabular form.


# Solution

### Q.2
> #### Notes for output
>
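> **Parameters:**
> - `L` – length of the spatial domain;
> - `T` – total time to simulate;
> - `h` – time step \(\Delta t\) (note: the problem statement uses \(h\) for the *space* step);
> - `k` – space step \(\Delta x\);
> - `alpha` – diffusion coefficient;
> - `u_initial` – initial condition function \(u(x, 0)\);
> - `u_left` – left boundary condition function \(u(0, t)\);
> - `u_right` – right boundary condition function \(u(L, t)\).
>
> **Returns:**
> - `x` – spatial grid;
> - `t` – time grid;
> - `u` – solution matrix of shape `(len(x), len(t))`, where `u[i, j]` approximates \(u(x_i, t_j)\).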

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd

def crank_nicolson_1D(L, T, h, k, alpha, u_initial, u_left, u_right):
    """Solve u_t = alpha * u_xx on [0, L] x [0, T] with the Crank-Nicolson scheme.

    Dirichlet boundary values are imposed at x = 0 and x = L.

    Parameters
    ----------
    L : float
        Length of the spatial domain (x runs from 0 to L). Must be positive.
    T : float
        Total simulated time (t runs from 0 to T). Must be non-negative.
    h : float
        Time step (Delta t). Note: many textbooks use h for the *space* step;
        here it is the time step.
    k : float
        Space step (Delta x).
    alpha : float
        Diffusion coefficient.
    u_initial : callable
        ``u_initial(x)`` returns the initial value u(x, 0); called once per grid point.
    u_left, u_right : callable
        ``u_left(t)`` / ``u_right(t)`` return the boundary values u(0, t) / u(L, t);
        called once per time level.

    Returns
    -------
    x : ndarray, shape (nx,)
        Spatial grid.
    t : ndarray, shape (nt,)
        Time grid.
    u : ndarray, shape (nx, nt)
        Solution matrix; ``u[i, j]`` approximates u(x[i], t[j]).

    Raises
    ------
    ValueError
        If h or k is not positive, L is not positive, T is negative, or the
        spatial grid has no interior point (k too large relative to L).

    Notes
    -----
    The numbers of intervals are ``round(L / k)`` and ``round(T / h)``. The scheme
    then uses the actual grid spacing ``L / (nx - 1)`` and ``T / (nt - 1)``, so
    the returned grids and the discretisation always agree.
    """
    if h <= 0 or k <= 0:
        raise ValueError("time step h and space step k must be positive")
    if L <= 0 or T < 0:
        raise ValueError("L must be positive and T must be non-negative")

    # round() rather than int(): e.g. 0.3 / 0.1 == 2.9999999999999996,
    # which int() would truncate to 2 and silently drop a grid point.
    nx = int(round(L / k)) + 1
    nt = int(round(T / h)) + 1
    if nx < 3:
        raise ValueError("space step k is too large: need at least one interior grid point")

    # Spatial and time grids
    x = np.linspace(0, L, nx)
    t = np.linspace(0, T, nt)

    # Actual spacing of the grids above (equals k / h when they divide L / T).
    dx = L / (nx - 1)
    dt = T / (nt - 1) if nt > 1 else h  # dt is unused when no step is taken

    u = np.zeros((nx, nt))

    # Initial condition; user functions are called with scalars, one per point.
    u[:, 0] = [u_initial(xi) for xi in x]

    # Boundary conditions for all time levels. Written after the initial
    # condition, so the boundary values win at the two corners of t = 0.
    u[0, :] = [u_left(tj) for tj in t]
    u[-1, :] = [u_right(tj) for tj in t]

    # lambda = r / 2 with mesh ratio r = alpha * dt / dx**2
    lam = alpha * dt / (2 * dx**2)

    # Constant tridiagonal matrix of the implicit half:
    # (1 + 2*lam) on the diagonal, -lam on the first sub/super-diagonals.
    m = nx - 2
    A = (np.diag(np.full(m, 1 + 2 * lam))
         + np.diag(np.full(m - 1, -lam), 1)
         + np.diag(np.full(m - 1, -lam), -1))

    for j in range(nt - 1):
        # Explicit half evaluated at time level j for all interior points
        b = u[1:-1, j] + lam * (u[2:, j] - 2 * u[1:-1, j] + u[:-2, j])

        # Known boundary values at the new level j + 1 move to the RHS
        b[0] += lam * u[0, j + 1]
        b[-1] += lam * u[-1, j + 1]

        u[1:-1, j + 1] = np.linalg.solve(A, b)

    return x, t, u

def main():
    """Solve the assignment problem for one Crank-Nicolson time step and report it.

    Problem: u_t = u_xx for 0 < x < 1, with u(x, 0) = 100 x (1 - x) and
    u(0, t) = u(1, t) = 0. The space step is 0.25, which gives the three
    interior nodes u1, u2, u3 at x = 0.25, 0.5, 0.75.

    Prints the full solution grid and a table of u1..u3, then shows a 2D and a
    3D plot of the solution.
    """
    # NOTE: crank_nicolson_1D takes h = TIME step and k = SPACE step, the
    # reverse of the problem statement's notation (where h = Delta x = 0.25).
    L = 1.0       # Length of spatial domain
    alpha = 1.0   # Diffusion coefficient
    k = 0.25      # Space step Delta x -> interior nodes at x = 0.25, 0.5, 0.75
    # The problem does not fix Delta t. The usual textbook choice is the mesh
    # ratio r = alpha * Delta t / Delta x**2 = 1. It reduces the scheme to
    #   -u_{i-1}^{n+1} + 4 u_i^{n+1} - u_{i+1}^{n+1} = u_{i-1}^n + u_{i+1}^n.
    r = 1.0
    h = r * k**2 / alpha   # Time step Delta t (= 0.0625)
    T = h                  # Total simulation time: exactly one time step

    # Initial and boundary conditions
    u_initial = lambda x: 100 * x * (1 - x)
    u_left = lambda t: 0
    u_right = lambda t: 0

    x, t, u = crank_nicolson_1D(L, T, h, k, alpha, u_initial, u_left, u_right)

    # Full solution grid: rows are x, columns are t
    df = pd.DataFrame(u, index=x, columns=t)
    df.index.name = 'x'
    df.columns.name = 't'

    print("Solution DataFrame:")
    print(df)

    # Table of the interior values u1, u2, u3 requested by the task
    interior = pd.DataFrame(
        {
            'x': x[1:-1],
            f'u(x, {t[0]:g})': u[1:-1, 0],
            f'u(x, {t[-1]:g})': u[1:-1, -1],
        },
        index=[f'u{i}' for i in range(1, len(x) - 1)],
    )
    print()
    print("Interior values after one time step:")
    print(interior.to_string(float_format=lambda v: f"{v:.4f}"))

    # 2D plot: u vs x at the initial and final times
    plt.figure(figsize=(10, 6))
    plt.plot(x, u[:, 0], 'b-', label=f't = {t[0]:g}')
    plt.plot(x, u[:, -1], 'r-', label=f't = {t[-1]:g}')
    plt.xlabel('x')
    plt.ylabel('u(x,t)')
    plt.title('Solution of Heat Equation using Crank-Nicolson Method')
    plt.legend()
    plt.grid(True)
    plt.show()

    # 3D surface plot
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')

    # indexing='ij' yields (len(x), len(t)) arrays matching u's layout; the
    # distinct names avoid shadowing the total time T.
    x_grid, t_grid = np.meshgrid(x, t, indexing='ij')
    surf = ax.plot_surface(t_grid, x_grid, u, cmap=cm.viridis, edgecolor='none')

    ax.set_xlabel('Time (t)')
    ax.set_ylabel('Position (x)')
    ax.set_zlabel('Temperature u(x,t)')
    ax.set_title('3D Plot of Heat Equation Solution')
    fig.colorbar(surf, shrink=0.5, aspect=5)
    plt.show()

if __name__ == "__main__":
    # Entry point: run the worked example only when executed as a script,
    # not when this module is imported.
    main()