# In [None]:
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def solve_convection_diffusion_one_sided(h, *, bc_order=1):
    """Solve u_xx + 3 u_yy - c u_y = 0 on [0, 1] x [-1, 1] by finite differences.

    Here c = 4*pi**2 - 3 and the boundary conditions are

    * u = 0 on x = 0 and x = 1 (Dirichlet; these columns are not unknowns),
    * u_y = -e * sin(2*pi*x) on y = -1 (Neumann),
    * u_y + u = 0 on y = 1 (Robin).

    One can check that u(x, y) = sin(2*pi*x) * exp(-y) satisfies the PDE and
    all four boundary conditions, so it can serve as a reference solution.

    The interior uses second-order central differences for u_xx, u_yy and u_y.
    The Neumann and Robin conditions use one-sided differences: two-point,
    first-order stencils by default, or three-point, second-order stencils
    with ``bc_order=2``.

    Parameters
    ----------
    h : float
        Grid spacing in both directions. ``1/h`` must be an integer >= 2 so
        that the grid lands exactly on x = 1 and y = +/-1.
    bc_order : {1, 2}, keyword-only
        Accuracy order of the one-sided boundary stencils (default 1).

    Returns
    -------
    x : ndarray, shape (n + 1,)
        Grid in x, with n = 1/h.
    y : ndarray, shape (2n + 1,)
        Grid in y.
    U : ndarray, shape (n + 1, 2n + 1)
        ``U[i, j]`` approximates u(x[i], y[j]); rows i = 0 and i = n hold the
        Dirichlet zeros.

    Raises
    ------
    ValueError
        If ``h`` is not a positive finite number, if ``1/h`` is not an integer
        >= 2, or if ``bc_order`` is not 1 or 2.
    """
    if not (np.isfinite(h) and h > 0):
        raise ValueError(f"h must be a positive finite number, got {h!r}")
    if bc_order not in (1, 2):
        raise ValueError(f"bc_order must be 1 or 2, got {bc_order!r}")

    # np.arange with a float step can overshoot or drop the end point, so the
    # grid is built from an integer number of intervals that must match h.
    n = int(round(1.0 / h))
    if n < 2 or abs(n * h - 1.0) > 1e-9:
        raise ValueError(f"1/h must be an integer >= 2, got h={h!r}")
    h = 1.0 / n  # the exact spacing of the linspace grids below
    x = np.linspace(0.0, 1.0, n + 1)
    y = np.linspace(-1.0, 1.0, 2 * n + 1)
    Nx, Ny = len(x), len(y)
    c = 4 * np.pi**2 - 3

    # Unknowns at i = 1..Nx-2 (interior columns) and j = 0..Ny-1 (all rows),
    # numbered row-major so the solution vector reshapes to (Ni, Nj).
    Ni, Nj = Nx - 2, Ny
    N = Ni * Nj

    def idx(i, j):
        return (i - 1) * Nj + j

    # Neumann data on y = -1: u_y = -e*sin(2*pi*x), at the interior columns.
    g_bottom = -np.e * np.sin(2 * np.pi * x[1:-1])

    # Matrix entries are gathered as COO triplets; duplicate (row, col) pairs
    # are summed on conversion to CSR, accumulating the stencil contributions.
    rows, cols, vals = [], [], []

    def add(row, col, value):
        rows.append(row)
        cols.append(col)
        vals.append(value)

    b = np.zeros(N)

    axx = 1 / h**2
    ayy = 3 / h**2
    adv = c / (2 * h)

    for i in range(1, Nx - 1):
        for j in range(Ny):
            k = idx(i, j)

            if j == 0:
                if bc_order == 1:
                    # forward difference: (u_1 - u_0)/h = g
                    add(k, idx(i, 0), -1 / h)
                    add(k, idx(i, 1), 1 / h)
                else:
                    # second order: (-3 u_0 + 4 u_1 - u_2)/(2h) = g
                    add(k, idx(i, 0), -3 / (2 * h))
                    add(k, idx(i, 1), 4 / (2 * h))
                    add(k, idx(i, 2), -1 / (2 * h))
                b[k] = g_bottom[i - 1]

            elif j == Ny - 1:
                if bc_order == 1:
                    # backward difference + Robin:
                    # (u_{N-1} - u_{N-2})/h + u_{N-1} = 0
                    add(k, idx(i, Ny - 2), -1 / h)
                    add(k, idx(i, Ny - 1), 1 / h + 1)
                else:
                    # second order:
                    # (3 u_{N-1} - 4 u_{N-2} + u_{N-3})/(2h) + u_{N-1} = 0
                    add(k, idx(i, Ny - 3), 1 / (2 * h))
                    add(k, idx(i, Ny - 2), -4 / (2 * h))
                    add(k, idx(i, Ny - 1), 3 / (2 * h) + 1)

            else:
                # interior PDE: u_xx + 3 u_yy - c u_y = 0
                # x-direction; neighbours at i = 0 and i = Nx-1 are the
                # Dirichlet zeros, so they contribute nothing.
                add(k, k, -2 * axx)
                if i + 1 < Nx - 1:
                    add(k, idx(i + 1, j), axx)
                if i - 1 > 0:
                    add(k, idx(i - 1, j), axx)

                # y-direction second derivative plus the central convective
                # term -c * (u_{j+1} - u_{j-1})/(2h)
                add(k, k, -2 * ayy)
                add(k, idx(i, j + 1), ayy - adv)
                add(k, idx(i, j - 1), ayy + adv)

    A = sp.coo_matrix((vals, (rows, cols)), shape=(N, N)).tocsr()
    Uvec = spla.spsolve(A, b)

    # Full solution: Dirichlet zeros on x = 0 and x = 1, unknowns in between.
    U = np.zeros((Nx, Ny))
    U[1:-1, :] = Uvec.reshape(Ni, Nj)
    return x, y, U

if __name__ == "__main__":
    # Solve on a uniform grid and display u as a surface over the (x, y) plane.
    step = 0.1
    xs, ys, sol = solve_convection_diffusion_one_sided(step)
    # 'ij' indexing so grid_x[i, j], grid_y[i, j] line up with sol[i, j].
    grid_x, grid_y = np.meshgrid(xs, ys, indexing='ij')

    figure, axes3d = plt.subplots(figsize=(8, 6), subplot_kw={'projection': '3d'})
    surface = axes3d.plot_surface(grid_x, grid_y, sol, cmap='viridis')
    for set_label, text in (
        (axes3d.set_xlabel, 'x'),
        (axes3d.set_ylabel, 'y'),
        (axes3d.set_zlabel, 'u(x,y)'),
    ):
        set_label(text)
    axes3d.set_title(f'Corrected solution (h={step})')
    figure.colorbar(surface, shrink=0.5, pad=0.1)
    plt.show()
