# In [11]:
import numpy as np
from scipy.sparse import lil_matrix
from scipy.sparse.linalg import spsolve
import matplotlib.pyplot as plt

class DiffusionSolver2D:
    """Cell-centred finite-volume solver for a steady 2-D diffusion problem.

    For every cell ``p`` the assembled balance is::

        sum_faces D_f * |face| / d * (phi_p - phi_nb) + Sigma_a * V * phi_p = Q * V

    i.e. a five-point discretisation of ``-div(D grad phi) + Sigma_a phi = Q``.

    Parameters
    ----------
    mesh
        Structured mesh exposing ``Nx``, ``Ny``, indexable ``dx``/``dy``
        (only element 0 is read, so the spacing is assumed uniform),
        cell-centre coordinates ``xc``/``yc`` (used by :meth:`plot`) and
        ``cell_to_index(i, j)``. Here ``i`` is the row (y) index and ``j``
        is the column (x) index.
    mats
        Material provider with ``get_D(i, j)``, ``get_face_D(i, j, i2, j2)``,
        ``get_Sigma_a(i, j)`` and ``get_Q(i, j)``.
    bcs
        Mapping from side name (``'left'``, ``'right'``, ``'bottom'``,
        ``'top'``) to ``{'type': 'dirichlet' | 'neumann', 'value': float}``.
        The type is matched case-insensitively. Missing sides (or
        ``bcs=None``) default to a zero Neumann condition.
    """

    def __init__(self, mesh, mats, bcs=None):
        self.mesh = mesh
        self.mats = mats
        self.bcs = bcs if bcs is not None else {}
        self.N = mesh.Nx * mesh.Ny

    def assemble(self):
        """Build the sparse linear system ``A phi = b``.

        Returns
        -------
        (scipy.sparse.csr_matrix, numpy.ndarray)
            The system matrix and right-hand side, also stored as
            ``self.A`` and ``self.b``.

        Raises
        ------
        ValueError
            If a boundary condition type is neither ``'dirichlet'`` nor
            ``'neumann'``.
        """
        m, mats = self.mesh, self.mats
        Nx, Ny = m.Nx, m.Ny
        # Keep N in sync with the mesh in case it was replaced after __init__.
        self.N = Nx * Ny
        # Only the first spacing is read: the mesh is assumed uniform.
        dx, dy = m.dx[0], m.dy[0]
        vol = dx * dy
        # (side, di, dj, face length, centre-to-centre distance).
        # Left/right faces are x-normal with length dy; bottom/top are
        # y-normal with length dx.
        sides = (
            ('left', 0, -1, dy, dx),
            ('right', 0, 1, dy, dx),
            ('bottom', -1, 0, dx, dy),
            ('top', 1, 0, dx, dy),
        )
        A = lil_matrix((self.N, self.N))  # LIL: cheap incremental insertion
        b = np.zeros(self.N)
        for i in range(Ny):
            for j in range(Nx):
                p = m.cell_to_index(i, j)
                for side, di, dj, face_len, d in sides:
                    i2, j2 = i + di, j + dj
                    if 0 <= i2 < Ny and 0 <= j2 < Nx:
                        # Interior face: two-point flux between cell centres.
                        tf = mats.get_face_D(i, j, i2, j2) * face_len / d
                        A[p, p] += tf
                        A[p, m.cell_to_index(i2, j2)] -= tf
                    else:
                        self._apply_boundary(A, b, p, i, j, side, face_len, d)
                # Absorption on the diagonal, volumetric source on the RHS.
                A[p, p] += mats.get_Sigma_a(i, j) * vol
                b[p] += mats.get_Q(i, j) * vol
        self.A = A.tocsr()  # CSR for the direct solve
        self.b = b
        return self.A, self.b

    def _apply_boundary(self, A, b, p, i, j, side, face_len, d):
        """Add the boundary-face contribution of cell ``p`` on ``side`` to ``A``/``b``."""
        bc = self.bcs.get(side, {'type': 'neumann', 'value': 0.0})
        kind = str(bc['type']).lower()
        if kind == 'dirichlet':
            # The prescribed value sits on the face, half a cell from the centre.
            tf = self.mats.get_D(i, j) * face_len / (0.5 * d)
            A[p, p] += tf
            b[p] += tf * bc['value']
        elif kind == 'neumann':
            # Prescribed flux per unit face length. With this sign convention,
            # positive values inject into the domain. Only the RHS changes.
            b[p] += bc['value'] * face_len
        else:
            raise ValueError(
                f"unsupported boundary condition type {bc['type']!r} on side "
                f"{side!r}; expected 'dirichlet' or 'neumann'"
            )

    def solve(self, reassemble=False):
        """Solve the system and return the flux on the ``(Ny, Nx)`` grid.

        The system is assembled on first use and cached. Pass
        ``reassemble=True`` after modifying ``mesh``, ``mats`` or ``bcs`` so
        that the cached matrix is rebuilt.

        Returns
        -------
        (numpy.ndarray, dict)
            ``phi`` reshaped to ``(Ny, Nx)`` (also stored as ``self.phi``), and
            a dict with the number of unknowns ``'N'`` and the number of stored
            matrix entries ``'nnz'``.
        """
        if reassemble or not hasattr(self, 'A'):
            self.assemble()
        phi = spsolve(self.A, self.b)
        # The reshape assumes cell_to_index is row-major (i * Nx + j) — TODO confirm.
        self.phi = phi.reshape((self.mesh.Ny, self.mesh.Nx))
        return self.phi, {'N': self.N, 'nnz': self.A.nnz}

    def plot(self, title="Scalar Flux"):
        """Plot the last solution on the current matplotlib axes.

        The solution is drawn as a 2-D colour map when both dimensions exceed
        one cell. Otherwise it is drawn as a 1-D line along the non-trivial axis.

        Raises
        ------
        RuntimeError
            If :meth:`solve` has not been called yet.
        """
        if not hasattr(self, 'phi'):
            raise RuntimeError("no solution to plot; call solve() first")
        m = self.mesh
        if m.Nx > 1 and m.Ny > 1:
            plt.pcolormesh(m.xc, m.yc, self.phi, shading='auto', cmap='viridis')
            plt.colorbar(label='phi')
            plt.xlabel('x')
            plt.ylabel('y')
            plt.title(title)
        elif m.Ny == 1:
            plt.plot(m.xc, self.phi[0, :], '-o')
            plt.xlabel('x')
            plt.ylabel('phi')
            plt.title(title)
        elif m.Nx == 1:
            plt.plot(m.yc, self.phi[:, 0], '-o')
            plt.xlabel('y')
            plt.ylabel('phi')
            plt.title(title)
        plt.grid(alpha=0.3)
        plt.tight_layout()
