# In [2]:
import numpy as np
import scipy
import matplotlib.pyplot as plt

# In [None]:
class Quadrature:
    """Gauss-Legendre angular quadrature for 1D slab geometry.

    Attributes:
        order: Number of discrete ordinates. Must be a positive even integer
            so that no ordinate has mu = 0.
        mus: Direction cosines ordered from most positive to most negative
            (``None`` until ``compute_quadrature`` is called).
        weights: Quadrature weights paired element-wise with ``mus``
            (``None`` until ``compute_quadrature`` is called).
    """

    def __init__(self, order):
        # NOTE: order needs to be even to avoid mu = 0
        self.order = order
        self.mus = None
        self.weights = None

    def compute_quadrature(self):
        """Compute the ordinates and weights, ordered from positive to negative mu.

        In 1D slab geometry the quadrature points are the Legendre roots and
        the weights are the corresponding Gauss-Legendre weights.

        Raises:
            ValueError: If ``order`` is not a positive even integer (an odd
                order would include the ordinate mu = 0).
        """
        if self.order <= 0 or self.order % 2 != 0:
            raise ValueError(
                "quadrature order must be a positive even integer, got %r" % (self.order,)
            )

        mus, weights = scipy.special.roots_legendre(self.order)

        # roots_legendre returns ascending roots; we want positive mu first.
        # Reverse both arrays together so each weight stays with its ordinate.
        self.mus = np.flip(mus)
        self.weights = np.flip(weights)


class Mesh:
    """Uniform 1D mesh of cell edges on [0, max_mesh_width].

    Attributes:
        max_mesh_width: Right end of the slab (the left end is 0).
        num_points: Number of mesh points (cell edges), both endpoints included.
        mesh: Array of the ``num_points`` edge coordinates (``None`` until
            ``create_mesh`` is called).
        dx: Uniform spacing between consecutive edges.
        num_cells: Number of cells, equal to ``num_points - 1``.
    """

    def __init__(self, max_mesh_width, num_points):
        self.max_mesh_width = max_mesh_width
        self.num_points = num_points
        self.mesh = None
        self.dx = None
        self.num_cells = None

    def create_mesh(self):
        """Build the edge coordinates, the mesh spacing and the cell count.

        Raises:
            ValueError: If fewer than two mesh points were requested (no cell
                can be formed).
        """
        if self.num_points < 2:
            raise ValueError(
                "a mesh needs at least 2 points, got %r" % (self.num_points,)
            )

        (self.mesh, self.dx) = np.linspace(0, self.max_mesh_width, self.num_points,
                                           retstep=True)
        # n edge points bound exactly n - 1 cells. Recovering this count as
        # int(width / dx) can truncate to n - 2 when the division rounds down.
        self.num_cells = self.num_points - 1

class MaterialProperties:
    """Piecewise-constant material data defined over spatial zones.

    Each ``zones_*`` argument is a dict whose keys are tuples
    ``zone = (zone_begin, zone_end)`` and whose values are the quantity on
    that zone. Zone bounds are inclusive, so a point on a shared boundary
    matches both neighbouring zones; the matching zone that comes last in the
    dict wins.
    """

    def __init__(self, zones_sigma_t, zones_sigma_a, zones_sigma_s, zones_q):
        self.zones_sigma_t = zones_sigma_t  # {(zone_begin, zone_end): sigma_t}
        self.zones_sigma_a = zones_sigma_a  # {(zone_begin, zone_end): sigma_a}
        self.zones_sigma_s = zones_sigma_s  # {(zone_begin, zone_end): sigma_s}
        self.zones_q       = zones_q        # {(zone_begin, zone_end): external volumetric source}

    @staticmethod
    def _lookup(zones, mesh_point):
        """Return the value of the last zone containing ``mesh_point``, or ``None``."""
        value = None
        for (zone, zone_value) in zones.items():
            if zone[0] <= mesh_point <= zone[1]:
                value = zone_value
        return value

    def return_xs(self, mesh_point):
        """Return ``(sigma_t, sigma_a, sigma_s)`` at ``mesh_point``.

        Each entry is ``None`` when no zone of that quantity covers the point.
        """
        return (self._lookup(self.zones_sigma_t, mesh_point),
                self._lookup(self.zones_sigma_a, mesh_point),
                self._lookup(self.zones_sigma_s, mesh_point))

    def return_q(self, mesh_point):
        """Return the external source at ``mesh_point``, or ``None`` if no zone covers it."""
        return self._lookup(self.zones_q, mesh_point)
             

class Solver:
    """Source-iteration discrete-ordinates (S_N) solver for a 1D slab.

    The slab [0, max_mesh_width] is divided into the cells between the
    ``num_mesh_points`` mesh edges. Each iteration sweeps every ordinate across
    the mesh using a weighted diamond-difference closure and rebuilds the
    cell-averaged scalar flux. Iteration stops once the maximum per-cell
    relative change between successive scalar-flux iterates drops below
    ``eps``.
    """

    # Weight ``a`` on the outgoing edge flux in the cell-average closure.
    _CLOSURE_WEIGHTS = {"diamond-difference": 0.5, "step": 1.0}

    def __init__(self, psi_left, psi_right, zones_sigma_t,
                 zones_sigma_a, zones_sigma_s, zones_q,
                 max_mesh_width, num_mesh_points, quad_order,
                 solver_type="diamond-difference"):
        """Set up the mesh, material data and quadrature, and the initial guess.

        Args:
            psi_left: Incoming angular flux at x = 0. Either a scalar (used for
                every ordinate) or a sequence indexed like the quadrature
                ordinates; only entries for mu > 0 are used.
            psi_right: Same as ``psi_left`` but at x = max_mesh_width; only
                entries for mu < 0 are used.
            zones_sigma_t, zones_sigma_a, zones_sigma_s, zones_q: Dicts mapping
                ``(zone_begin, zone_end)`` to the zone value. ``zones_sigma_a``
                is stored but not used by the sweep. NOTE(review): ``zones_q``
                is treated as an angle-integrated isotropic source (it is
                halved in the emission density) -- confirm this normalisation.
            max_mesh_width: Width of the slab.
            num_mesh_points: Number of mesh edges (cells = points - 1).
            quad_order: Number of Gauss-Legendre ordinates; should be even so
                that no ordinate has mu = 0.
            solver_type: ``"diamond-difference"`` or ``"step"``.

        Raises:
            ValueError: If ``solver_type`` is not recognised. Errors raised
                while building the mesh or quadrature propagate unchanged.
        """
        self.psi_left = psi_left
        self.psi_right = psi_right
        self.zones_sigma_t = zones_sigma_t
        self.zones_sigma_a = zones_sigma_a
        self.zones_sigma_s = zones_sigma_s
        self.zones_q       = zones_q
        self.max_mesh_width = max_mesh_width
        self.num_mesh_points = num_mesh_points
        self.quad_order     = quad_order
        self.solver_type    = solver_type

        # Cell-average closure, written relative to the sweep direction:
        #     psi_i = a * psi_out + (1 - a) * psi_in
        # psi_in / psi_out are the edge fluxes where the sweep enters / leaves
        # cell i. Diamond difference: a = 1/2. Step (upwind): a = 1, i.e. the
        # cell average equals the outgoing edge flux.
        try:
            self.a = self._CLOSURE_WEIGHTS[self.solver_type]
        except KeyError:
            raise ValueError(
                "unknown solver_type %r; expected one of %s"
                % (self.solver_type, sorted(self._CLOSURE_WEIGHTS))
            ) from None

        # Create and initialise the mesh, material properties and quadrature.
        self.mesh = Mesh(self.max_mesh_width, self.num_mesh_points)
        self.mesh.create_mesh()
        self.mat_props = MaterialProperties(self.zones_sigma_t, self.zones_sigma_a,
                                            self.zones_sigma_s, self.zones_q)
        self.quad_rules = Quadrature(self.quad_order)
        self.quad_rules.compute_quadrature()

        # Incoming boundary fluxes, one entry per ordinate (scalars broadcast).
        num_angles = np.size(self.quad_rules.mus)
        self._psi_left_d = np.broadcast_to(np.asarray(self.psi_left, dtype=float),
                                           (num_angles,))
        self._psi_right_d = np.broadcast_to(np.asarray(self.psi_right, dtype=float),
                                            (num_angles,))

        # Count cells from the edge array itself so this does not depend on
        # how Mesh derived num_cells.
        num_cells = np.size(self.mesh.mesh) - 1

        # Initial guesses to perform source iteration
        self.curr_scalar_fluxes = np.full((num_cells, ), np.log(2))
        self.new_scalar_fluxes  = np.zeros_like(self.curr_scalar_fluxes)

        self.eps = 1e-14        # Relative error tolerance
        self.max_iters = 100000  # Guard against non-convergence / stagnation
        self.iter_count = 0     # Iteration counter
        self.scalar_flux_rel_err = np.max(np.abs((self.new_scalar_fluxes - self.curr_scalar_fluxes)/
                                                 self.curr_scalar_fluxes))

    @staticmethod
    def _zone_value(zones, x):
        """Return the value of the last zone in ``zones`` containing ``x``, or ``None``."""
        value = None
        for zone, zone_value in zones.items():
            if zone[0] <= x <= zone[1]:
                value = zone_value
        return value

    def _cell_data(self, centers):
        """Evaluate sigma_t, sigma_s and the external source at each cell centre.

        A cell with no source zone gets a source of 0.

        Raises:
            ValueError: If some centre is not covered by a sigma_t or sigma_s zone.
        """
        n = np.size(centers)
        sigma_t = np.empty(n)
        sigma_s = np.empty(n)
        q_ext = np.empty(n)
        for i, x in enumerate(centers):
            st, _, ss = self.mat_props.return_xs(x)
            if st is None or ss is None:
                raise ValueError("no sigma_t/sigma_s zone covers x = %r" % (x,))
            sigma_t[i] = st
            sigma_s[i] = ss
            q = self._zone_value(self.zones_q, x)
            q_ext[i] = 0.0 if q is None else q
        return sigma_t, sigma_s, q_ext

    def _sweep(self, d, mu, emission, sigma_t, widths):
        """Transport ordinate ``d`` across the slab; return its cell-averaged angular flux."""
        a = self.a
        num_cells = np.size(emission)
        forward_sweep = mu > 0.0
        psi_in = self._psi_left_d[d] if forward_sweep else self._psi_right_d[d]
        abs_mu = np.abs(mu)

        # Determine order of iteration: left-to-right for mu > 0, else right-to-left.
        cells = range(num_cells) if forward_sweep else range(num_cells - 1, -1, -1)

        psi_avg = np.empty(num_cells)
        for i in cells:
            streaming = abs_mu / widths[i]
            # Cell balance |mu| (psi_out - psi_in) / dx + sigma_t psi_i = emission,
            # combined with the closure above and solved for psi_out.
            psi_out = ((emission[i] + (streaming - (1.0 - a) * sigma_t[i]) * psi_in)
                       / (streaming + a * sigma_t[i]))
            psi_avg[i] = a * psi_out + (1.0 - a) * psi_in
            psi_in = psi_out
        return psi_avg

    def run(self):
        """Run source iteration to convergence and return the scalar flux.

        Returns:
            Array of the converged cell-averaged scalar flux, also stored in
            ``self.curr_scalar_fluxes``.

        Raises:
            ValueError: If a cell centre is outside every sigma_t/sigma_s zone.
            RuntimeError: If ``max_iters`` iterations pass without convergence.
        """
        edges = np.asarray(self.mesh.mesh, dtype=float)
        widths = np.diff(edges)
        centers = 0.5 * (edges[:-1] + edges[1:])
        sigma_t, sigma_s, q_ext = self._cell_data(centers)

        # Perform source iteration until scalar flux relative error (in all mesh cells)
        # between previous and current scalar flux iterates is below eps
        while self.scalar_flux_rel_err >= self.eps:
            if self.iter_count >= self.max_iters:
                raise RuntimeError(
                    "source iteration did not converge in %d iterations "
                    "(relative error %g)" % (self.max_iters, self.scalar_flux_rel_err)
                )
            self.iter_count += 1

            # Isotropic emission density per unit mu (Gauss-Legendre weights sum to 2).
            emission = 0.5 * (sigma_s * self.curr_scalar_fluxes + q_ext)

            # Loop over angles, accumulating the weighted cell-averaged angular flux.
            new_fluxes = np.zeros_like(self.curr_scalar_fluxes)
            for d, (mu, weight) in enumerate(zip(self.quad_rules.mus,
                                                 self.quad_rules.weights)):
                new_fluxes += weight * self._sweep(d, mu, emission, sigma_t, widths)
            self.new_scalar_fluxes = new_fluxes

            # Per-cell relative change. Where the new flux is exactly zero, use
            # the absolute change, so that 0/0 cannot yield NaN (NaN >= eps is
            # False and would end the loop early).
            scale = np.where(new_fluxes != 0.0, np.abs(new_fluxes), 1.0)
            self.scalar_flux_rel_err = np.max(np.abs(new_fluxes - self.curr_scalar_fluxes)
                                              / scale)
            self.curr_scalar_fluxes = new_fluxes

        return self.curr_scalar_fluxes


if __name__ == '__main__':
    # Placeholder entry point: nothing runs when this file is executed as a script.
    pass
    

