In [1]:
from discretization_1d import GridType1D

In [2]:
from discretization_1d import Grid1D

# Build a small grid of every GridType1D on [-0.5, 5] and show its points;
# each grid must contain exactly N of them.
N = 5  # same point count for every grid type
for k in GridType1D:
  print(k)
  g = Grid1D(a=-.5, b=5, N=N, grid_type=k)
  print(g.X)
  assert g.X.size == N

GridType1D.LEFT_CLOSED
[-0.5  0.6  1.7  2.8  3.9]
GridType1D.RIGHT_CLOSED
[0.6 1.7 2.8 3.9 5. ]
GridType1D.OPEN
[0.6 1.7 2.8 3.9 5. ]
GridType1D.INTERIOR
[0.41666667 1.33333333 2.25       3.16666667 4.08333333]
GridType1D.MIDPOINT
[0.05 1.15 2.25 3.35 4.45]
GridType1D.CLOSED
[-0.5    0.875  2.25   3.625  5.   ]


In [None]:
import numpy as np

def test_closed():
  """CLOSED grid on [0, 1] with N=5: an ndarray of shape (N,) that hits both
  endpoints and is uniformly spaced by dx == (b - a) / (N - 1)."""
  a, b, N = 0, 1, 5
  g = Grid1D(a=a, b=b, N=N, grid_type=GridType1D.CLOSED)
  X = g.X

  assert isinstance(X, np.ndarray)
  assert X.shape == (N,)

  # Both interval ends are grid points.
  assert np.isclose(X[0], a)
  assert np.isclose(X[-1], b)

  # Spacing is strictly positive, uniform, and matches the reported dx.
  steps = np.diff(X)
  assert np.all(steps > 0.0)
  assert np.allclose(steps, g.dx)

  assert np.isclose(g.dx, (b - a) / (N - 1))

def test_closed_requires_N_ge_2():
  """A CLOSED grid with N=1 must be rejected with ValueError (its spacing
  (b - a) / (N - 1) needs at least two points)."""
  # Plain try/except/else rather than `pytest.raises`: `pytest` is not imported
  # in this notebook, so that form fails with NameError instead of testing.
  try:
    Grid1D(a=0.0, b=1.0, N=1, grid_type=GridType1D.CLOSED)
  except ValueError:
    return
  raise AssertionError("Grid1D(N=1, grid_type=CLOSED) should raise ValueError")


In [None]:
from dataclasses import dataclass
import numpy as np
from enum import Enum, auto

class GridType1D(Enum):
  """How the N points of a Grid1D are placed relative to the ends of [a, b].

  The exact point formulas for each member are listed on Grid1D.
  """
  CLOSED = auto()
  OPEN = auto()
  LEFT_CLOSED = auto()
  RIGHT_CLOSED = auto()
  MIDPOINT = auto()
  INTERIOR = auto()


@dataclass(frozen=True)
class Grid1D:
  """Immutable uniform 1-D grid of ``N`` points on the interval ``[a, b]``.

  With ``L = b - a`` and ``k = 0, ..., N-1`` the points ``X[k]`` are:

  * ``CLOSED``       -- ``linspace(a, b, N)``; dx = L / (N - 1); includes a and b.
  * ``LEFT_CLOSED``  -- ``a + dx * k``;         dx = L / N;       includes a, not b.
  * ``RIGHT_CLOSED`` -- ``a + dx * (k + 1)``;   dx = L / N;       includes b, not a.
  * ``OPEN``         -- ``a + dx * (k + 1)``;   dx = L / N.
    NOTE(review): identical to RIGHT_CLOSED (b is included); confirm intent.
  * ``INTERIOR``     -- ``a + dx * (k + 1)``;   dx = L / (N + 1); excludes a and b.
  * ``MIDPOINT``     -- ``a + dx * (k + 0.5)``; dx = L / N;       centres of N cells.

  Raises:
    ValueError: if ``b`` is not strictly greater than ``a`` (this also
      rejects NaN bounds), if ``N <= 0``, or if ``grid_type`` is ``CLOSED``
      and ``N < 2``.
  """
  a: float
  b: float
  N: int
  grid_type: GridType1D = GridType1D.CLOSED

  def __post_init__(self):
    # `not b > a` instead of `b <= a` so NaN bounds are rejected too.
    if not self.b > self.a:
      raise ValueError("Require b > a")
    if self.N <= 0:
      raise ValueError("Require N > 0")
    # CLOSED spacing is L / (N - 1): a single point would divide by zero.
    if self.grid_type is GridType1D.CLOSED and self.N < 2:
      raise ValueError("Require N >= 2 for a CLOSED grid")

  @property
  def L(self) -> float:
    """Length of the interval, ``b - a``."""
    return self.b - self.a

  @property
  def dx(self) -> float:
    """Distance between neighbouring grid points."""
    if self.grid_type is GridType1D.CLOSED:
      return self.L / (self.N - 1)  # N points, N - 1 gaps, both ends included
    if self.grid_type is GridType1D.INTERIOR:
      return self.L / (self.N + 1)  # both ends excluded: N + 1 gaps
    return self.L / self.N

  @property
  def X(self) -> np.ndarray:
    """Grid point coordinates: a float array of shape ``(N,)``, increasing."""
    if self.grid_type is GridType1D.CLOSED:
      # linspace lands exactly on both a and b, with no round-off at b.
      return np.linspace(self.a, self.b, self.N)

    dx = self.dx
    k = np.arange(self.N)
    if self.grid_type is GridType1D.LEFT_CLOSED:
      return self.a + dx * k
    if self.grid_type in (GridType1D.RIGHT_CLOSED, GridType1D.OPEN,
                          GridType1D.INTERIOR):
      return self.a + dx * (k + 1)
    if self.grid_type is GridType1D.MIDPOINT:
      return self.a + dx * (k + 0.5)
    raise RuntimeError("Unhandled GridType1D")

  @property
  def x(self) -> np.ndarray:
    """Lower-case alias of ``X``."""
    return self.X


# Backward-compatible alias for the original lower-case class name.
Grid1d = Grid1D