In [7]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from linx.const import hbar, me, aFS

ModuleNotFoundError: No module named 'linx'

We want the spectrum $f_X(E)$, where $X$ is either a photon ($\gamma$) or an electron/positron ($e^\pm$). It satisfies the cascade equation

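$$
f_X (E) = \frac{1}{\Gamma_X(E, T)} \left( S_X (E, T) + \sum_{X'} \int_{E}^\infty dE' \, K_{X' \rightarrow X} (E, E', T) \, f_{X'} (E') \right)
$$

Here $\Gamma_X(E, T)$ is the total interaction rate of $X$. The kernel $K_{X' \rightarrow X}(E, E', T)$ gives the production of $X$ at energy $E$ from $X'$ at energy $E' \geq E$.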

This is the number of particles per unit volume per unit energy, so $[f_X (E)] = \mathrm{cm}^{-3}\,\mathrm{MeV}^{-1}$.

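The source term can be written as
$$
S_X (E, T) = S^{(0)}_{X}(T) \, \delta (E - E_{0}) + S_{X}^{\rm{FSR}} (E, T)
$$
that is, a monochromatic injection at $E_0 = m_{\rm DM}/2$ plus a continuous final-state-radiation (FSR) contribution.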

In [8]:
class DM_injection():

    def __init__(self, m_dm, tau_dm, T_0, n_DM, bree, braa):

        """
        Primary photon and electron source terms from dark-matter decay.

        Parameters:
        m_dm (float): Mass of the dark matter particle. The injection
            energy of each decay product is E0 = m_dm/2.
        tau_dm (float): Lifetime of the dark matter particle.
        T_0 (float): Temperature of the universe at the time of decay
            (stored only; not used by the methods of this class).
        n_DM (callable): Function n_DM(T) returning the dark matter number
            density at temperature T. NOTE(review): originally described
            as "relative to photons" -- confirm the normalization.
        bree (float): Branching ratio into an electron-positron pair.
        braa (float): Branching ratio into photons.
        """

        self.m_dm = m_dm
        self.tau_dm = tau_dm
        self.T_0 = T_0
        self.n_DM = n_DM
        self.bree = bree
        self.braa = braa

        # Monochromatic injection energy of each decay product.
        self.E0 = m_dm/2

    def _source_photon_0(self, T):
        """Normalization S^(0)_gamma(T) of the monochromatic photon source.

        The factor 2 presumably counts the two photons per decay in this
        channel.
        """
        return self.braa * 2. * self.n_DM(T) * (hbar/self.tau_dm)

    def _source_electron_0(self, T):
        """Normalization S^(0)_e(T) of the monochromatic electron source."""
        return self.bree * self.n_DM(T) * (hbar/self.tau_dm)

    def _source_photon_c(self, E, T):
        """
        Continuous final-state-radiation photon source S^FSR_gamma(E, T).

        Evaluates
            (S^(0)_e / E0) * (aFS/pi) * (1 + (1-x)^2)/x * log((1-x)/y)
        with x = E/E0 and y = me^2/(4 E0^2). Returns 0 above the kinematic
        endpoint, i.e. where x > 1 - y.

        E may be a scalar or an array (also under jit). The result is a JAX
        array with the broadcast shape of E.
        """
        EX = self.E0

        x = E/EX
        y = me**2/(4.*EX**2)

        # Kinematically allowed region (zero source where 1 - y < x).
        allowed = x <= 1. - y
        # Outside the allowed region, feed log() a harmless 1 instead of a
        # non-positive number. This keeps NaNs out of the value and of its
        # gradients; the explicit mask below then yields exactly 0 there.
        log_arg = jnp.where(allowed, (1.-x)/y, 1.)

        _sp = self._source_electron_0(T)

        fsr = (_sp/EX) * (aFS/jnp.pi) * (1. + (1.-x)**2)/x * jnp.log(log_arg)
        return jnp.where(allowed, fsr, 0.)
    


In [None]:
def _source_photon_c(self, E, T):
    """
    Continuous final-state-radiation photon source S^FSR_gamma(E, T).

    Module-level variant of ``DM_injection._source_photon_c``. ``self`` must
    provide the injection energy (as ``_sE0`` or ``E0``) and a
    ``_source_electron_0(T)`` method.

    Evaluates
        (S^(0)_e / E0) * (aFS/pi) * (1 + (1-x)^2)/x * log((1-x)/y)
    with x = E/E0 and y = me^2/(4 E0^2). Returns 0 where x > 1 - y.
    E may be a scalar or an array; the result is a JAX array.
    """
    # Accept both spellings of the injection energy: ``_sE0`` (the original
    # attribute read here) and ``E0`` (the attribute set by DM_injection).
    EX = self._sE0 if hasattr(self, "_sE0") else self.E0

    x = E/EX
    y = me**2/(4.*EX**2)

    # Kinematically allowed region (zero source where 1 - y < x).
    allowed = x <= 1. - y
    # Feed log() a harmless 1 outside the allowed region so neither the
    # value nor its gradient picks up NaNs; the mask below then gives 0.
    log_arg = jnp.where(allowed, (1.-x)/y, 1.)

    _sp = self._source_electron_0(T)

    fsr = (_sp/EX) * (aFS/jnp.pi) * (1. + (1.-x)**2)/x * jnp.log(log_arg)
    return jnp.where(allowed, fsr, 0.)

def get_spectrum(self, E0, S0f, SCf, T, allX=False):
    """
    Solve the cascade equation on a logarithmic energy grid.

    Parameters:
    self: object providing ``_sNX`` (number of species), ``_rate_x(X, E, T)``
        and ``_kernel_x_xp(X, Xp, E, Ep, T)``.
    E0 (float): Injection energy; upper end of the energy grid.
    S0f (sequence of callables): ``S0X(T)``, monochromatic source per species.
    SCf (sequence of callables): ``SCX(E, T)``, continuous source per species.
    T (float): Temperature.
    allX (bool): If True, return every row of the solver output; otherwise
        only the first two rows.

    Returns:
    The output of ``_JIT_solve_cascade_equation``: ``sol[0:2, :]`` by
    default, or the full ``sol`` when ``allX`` is true.
    """
    # NE_pd points per decade between Emin and E0 (settings defined outside
    # this cell), but never fewer than NE_min points.
    NE = max(int(np.log10(E0/Emin)*NE_pd), NE_min)

    # Number of species
    NX = self._sNX

    # Logarithmic energy grid from Emin to E0. It is built with NumPy
    # (float64), like every other array handed to the solver below.
    E_grid = np.logspace(np.log(Emin), np.log(E0), NE, base=np.e)

    # Species indices
    X_grid = np.arange(NX)

    # Rates: first index X, second index E.
    G = np.array([[self._rate_x(X, E, T) for E in E_grid] for X in X_grid])

    # Kernels: indices (X, Xp, E, Ep). For Ep < E the kernel is simply 0.
    K = np.array([[[[self._kernel_x_xp(X, Xp, E, Ep, T) if Ep >= E else 0.
                     for Ep in E_grid]
                    for E in E_grid]
                   for Xp in X_grid]
                  for X in X_grid])

    # Source terms: monochromatic (per species) + continuous (species, E).
    S0 = np.array([S0X(T) for S0X in S0f])
    SC = np.array([[SCX(E, T) for E in E_grid] for SCX in SCf])

    # Solve the cascade equation
    sol = _JIT_solve_cascade_equation(E_grid, G, K, S0, SC, T)

    # 'sol' always has at least two rows (first axis)
    return sol if allX else sol[0:2, :]


def get_universal_spectrum(self, E0, S0f, SCf, T, offset=0.):
    """
    Universal photon spectrum following astro-ph/0211258.

    With N = sum of the monochromatic sources S0X(T), EX = me^2/(80 T),
    EC = me^2/(22 T) and K0 = E0 / (EX^2 (2 + ln(EC/EX))):
        F(E) = N K0 (EX/E)^1.5 / rate_photon(E, T)   for E < EX
        F(E) = N K0 (EX/E)^2   / rate_photon(E, T)   for EX <= E <= (1+offset) EC
    and approx_zero elsewhere (values below approx_zero are floored to it).

    Parameters:
    self: object providing ``rate_photon(E, T)``.
    E0 (float): Injection energy; upper end of the energy grid.
    S0f (sequence of callables): ``S0X(T)``, monochromatic sources.
    SCf (sequence of callables): continuous sources; currently unused
        (see TODO below).
    T (float): Temperature.
    offset (float): Relative extension of the upper cutoff EC; an offset
        enables better interpolation near the cutoff.

    Returns:
    ndarray of shape (2, NE): row 0 is the energy grid, row 1 the spectrum.
    """
    # EC and EX as defined in astro-ph/0211258
    EC = me**2/(22.*T)
    EX = me**2/(80.*T)

    # Normalization K0 as defined in astro-ph/0211258
    K0 = E0/( (EX**2.) * ( 2. + np.log( EC/EX ) ) )

    # NE_pd points per decade between Emin and E0 (settings defined outside
    # this cell), but never fewer than NE_min points.
    NE = max(int(np.log10(E0/Emin)*NE_pd), NE_min)

    E_grid = np.logspace(np.log(Emin), np.log(E0), NE, base=np.e)

    # TODO: Incorporate the continuous source terms in the
    #       normalization by integrating them over the energy
    # Total normalization; it does not depend on E, so compute it once.
    norm = sum(S0X(T) for S0X in S0f) * K0

    # Upper cutoff; an offset enables better interpolation
    E_cut = (1. + offset)*EC

    F_grid = np.zeros(NE)
    for i, E in enumerate(E_grid):
        if E < EX:
            F_grid[i] = norm * (EX/E)**1.5/self.rate_photon(E, T)
        elif E <= E_cut:
            F_grid[i] = norm * (EX/E)**2.0/self.rate_photon(E, T)

    # Remove potential zeros
    F_grid[F_grid < approx_zero] = approx_zero

    return np.vstack((E_grid, F_grid))
