# In [3]:
import numpy as np
import matplotlib.pyplot as plt

class Mesh2D:
    """Uniform Cartesian mesh built by refining a coarse material map.

    The coarse map is indexed ``[iy, ix]`` (rows run along y, columns along
    x). Each coarse cell is split into ``refinement_factor`` fine cells along
    every refined direction, and each fine cell inherits the material ID of
    the coarse cell it lies in.

    A single-row map is treated as a 1D problem along x, and a single-column
    map as a 1D problem along y. Only the varying direction is refined. A
    1x1 map falls into the single-column case and is refined along y.

    Parameters
    ----------
    coarse_mat : array_like
        Coarse material IDs, shape ``(Ny0, Nx0)``. A 1D sequence is treated
        as a single row. Values are cast to ``int`` in the refined map.
    domain : tuple of float
        Physical extents ``(Lx, Ly)``. Both must be positive.
    refinement_factor : int
        Number of fine cells per coarse cell in each refined direction (>= 1).

    Attributes
    ----------
    Nx, Ny : int
        Number of fine cells along x and y.
    dx, dy : ndarray
        Per-cell widths along x and y (uniform).
    xc, yc : ndarray
        Fine-cell center coordinates along x and y.
    mat : ndarray of int, shape (Ny, Nx)
        Refined material map.

    Raises
    ------
    TypeError
        If ``refinement_factor`` is not an integer.
    ValueError
        If ``refinement_factor < 1``, ``coarse_mat`` is empty or has more
        than two dimensions, or a domain extent is not positive.
    """

    def __init__(self, coarse_mat, domain, refinement_factor):
        if not isinstance(refinement_factor, (int, np.integer)):
            raise TypeError(
                "refinement_factor must be an integer, got "
                f"{type(refinement_factor).__name__}"
            )
        if refinement_factor < 1:
            raise ValueError(
                f"refinement_factor must be >= 1, got {refinement_factor}"
            )
        # np.array copies, so later edits to the caller's array do not leak
        # into the mesh. atleast_2d lets a plain 1D list act as a single row.
        self.coarse_mat = np.atleast_2d(np.array(coarse_mat))
        if self.coarse_mat.ndim != 2 or self.coarse_mat.size == 0:
            raise ValueError(
                "coarse_mat must be non-empty with at most 2 dimensions, "
                f"got shape {self.coarse_mat.shape}"
            )
        self.Ny0, self.Nx0 = self.coarse_mat.shape
        self.Lx, self.Ly = domain
        # Written as a negated conjunction so NaN extents are rejected too.
        if not (self.Lx > 0 and self.Ly > 0):
            raise ValueError(f"domain extents must be positive, got {domain!r}")
        self.r = int(refinement_factor)

        # Determine refined dims. Both coarse dims are >= 1 at this point.
        if self.Nx0 > 1 and self.Ny0 > 1:
            # Genuine 2D map: refine both directions.
            self.Nx = self.r * self.Nx0
            self.Ny = self.r * self.Ny0
        elif self.Nx0 == 1:
            # Single column (including 1x1): 1D along y, refine y only.
            self.Nx = 1
            self.Ny = self.r * self.Ny0
        else:
            # Single row with Nx0 > 1: 1D along x, refine x only.
            self.Nx = self.r * self.Nx0
            self.Ny = 1

        # Uniform spacings
        self.dx = np.full(self.Nx, self.Lx / self.Nx)
        self.dy = np.full(self.Ny, self.Ly / self.Ny)
        # Cell centers: half a cell in from each boundary
        self.xc = np.linspace(self.dx[0] / 2, self.Lx - self.dx[0] / 2, self.Nx)
        self.yc = np.linspace(self.dy[0] / 2, self.Ly - self.dy[0] / 2, self.Ny)
        # Refine material map
        self._create_refined_material_map()

    def _create_refined_material_map(self):
        """Fill ``self.mat`` with the coarse material ID of every fine cell.

        Fine index ``k`` maps to coarse index ``k // r``, clamped to the last
        coarse cell. For an unrefined direction (coarse size 1), the clamp
        pins every fine cell to coarse index 0.
        """
        rows = np.minimum(np.arange(self.Ny) // self.r, self.Ny0 - 1)
        cols = np.minimum(np.arange(self.Nx) // self.r, self.Nx0 - 1)
        # np.ix_ builds the (Ny, Nx) outer-product gather in one step.
        # astype(int) matches the old int-array assignment and always copies.
        self.mat = self.coarse_mat[np.ix_(rows, cols)].astype(int)

    def cell_to_index(self, i, j):
        """Return the row-major flat index of fine cell (row ``i``, column ``j``)."""
        return i * self.Nx + j

    def index_to_cell(self, idx):
        """Return ``(i, j)`` for a row-major flat index (inverse of ``cell_to_index``)."""
        return divmod(idx, self.Nx)

    def plot_material_map(self, title="Material Map", ax=None):
        """Plot the refined material map.

        A 2D mesh is drawn as an image with a colorbar. A 1D mesh is drawn as
        material ID versus cell-center coordinate.

        Parameters
        ----------
        title : str
            Axes title.
        ax : matplotlib Axes, optional
            Axes to draw on. Defaults to the current pyplot axes.
        """
        if ax is None:
            ax = plt.gca()
        if self.Nx > 1 and self.Ny > 1:
            # Material IDs are categorical; nearest-neighbour resampling keeps
            # the default image smoothing from blending IDs at material interfaces.
            im = ax.imshow(self.mat, origin='lower',
                           extent=[0, self.Lx, 0, self.Ly],
                           cmap='tab10', interpolation='nearest')
            ax.figure.colorbar(im, ax=ax, label='Material ID')
            ax.set_xlabel('x')
            ax.set_ylabel('y')
        elif self.Ny == 1:
            ax.plot(self.xc, self.mat[0, :], 'o-')
            ax.set_xlabel('x')
            ax.set_ylabel('Material ID')
        else:
            # Nx == 1 and Ny > 1: the mesh varies along y only.
            ax.plot(self.yc, self.mat[:, 0], 'o-')
            ax.set_xlabel('y')
            ax.set_ylabel('Material ID')
        ax.set_title(title)
        ax.grid(alpha=0.3)
        ax.figure.tight_layout()
