In [None]:
import sys
import argparse
import numpy as np
import scipy.stats as ss
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
class FRBGenerator:
    """Simulate dynamic spectra containing dispersed, Gaussian-shaped bursts.

    Each call to :meth:`generate` fills ``self.specs`` with unit-variance
    Gaussian noise and adds one dispersed pulse per spectrum.

    Parameters
    ----------
    num_chan : int
        Number of frequency channels.
    num_samp : int
        Number of time samples per spectrum.
    fc : float
        Centre frequency; multiplied by 1e3 to obtain MHz (i.e. given in GHz).
    bw : float
        Bandwidth, added to the MHz centre frequency (i.e. given in MHz).
    t_bin_width : float
        Width of one time sample, in the same units as ``width_range``.
        NOTE(review): the dispersion delay below is computed in ms, so this
        is treated as ms -- confirm against callers.
    dm_range : sequence of two floats
        (low, high) bounds of the uniformly drawn dispersion measure.
    width_range : sequence of two floats
        (low, high) bounds of the pulse full width at half maximum.
    snr_range : sequence of two floats
        (low, high) bounds of the pulse peak amplitude; the noise has unit
        standard deviation, so this is the per-sample peak S/N.

    Attributes
    ----------
    specs : numpy.ndarray
        Shape ``(num_frb, num_chan, num_samp)``; set by :meth:`generate`.
    dm, t_start, std, snr : numpy.ndarray
        One value per simulated burst; set by :meth:`generate`.
    """

    # Dispersion constant in ms for frequencies in MHz
    # (eq. 5.1 of Lorimer & Kramer, 2005).
    _K_DM = 4.15e6

    def __init__(self,
                 num_chan,
                 num_samp,
                 fc,
                 bw,
                 t_bin_width,
                 dm_range,
                 width_range,
                 snr_range):
        if t_bin_width <= 0:
            raise ValueError('t_bin_width must be positive')

        self.num_chan = num_chan
        self.num_samp = num_samp
        self.fc = fc * 1e3      # convert to MHz
        self.bw = bw
        self.t_bin_width = t_bin_width
        self.dm_range = dm_range
        self.width_range = width_range
        self.snr_range = snr_range

        # reference frequency: the upper end of the band, which is also the
        # frequency assigned to the highest channel below
        self.f_ref = self.fc + (self.bw / 2)
        # frequency of each channel, ascending, as a column vector
        self.f_chan = np.linspace(self.f_ref - self.bw,
                                  self.f_ref,
                                  self.num_chan).reshape((self.num_chan, 1))

    def generate(self, num_frb):
        """Create ``num_frb`` noisy dynamic spectra, each with one burst.

        The result is stored in ``self.specs``; the randomly drawn burst
        parameters are stored in ``self.dm``, ``self.t_start``, ``self.std``
        and ``self.snr``. Nothing is returned.
        """
        # generate Gaussian random noise as background
        self.specs = np.random.normal(loc=0.0,
                                      scale=1.0,
                                      size=(num_frb,
                                            self.num_chan,
                                            self.num_samp))

        # generate random DMs, one per burst, as a row vector
        dm = np.random.uniform(low=self.dm_range[0],
                               high=self.dm_range[1],
                               size=(1, num_frb))
        self.dm = dm.ravel()

        # dispersion delay of every channel relative to f_ref, shape
        # (num_chan, num_frb); row 0 is the lowest frequency and therefore
        # carries the largest delay
        delay_ms = np.abs(np.matmul(
            self._K_DM * (self.f_ref**-2 - self.f_chan**-2), dm))
        # express the delay in time samples, since it is used as an index
        delta_t = delay_ms / self.t_bin_width

        # generate Gaussian pulses, shape (pulse_len, num_frb)
        pulse = self._generate_pulses(self.width_range,
                                      self.snr_range,
                                      self.t_bin_width,
                                      num_frb)
        pulse_len = len(pulse)

        self.t_start = np.empty(num_frb)
        for i, spec in enumerate(self.specs):
            # arrival sample at f_ref, chosen so the middle of the sweep
            # falls somewhere inside the time axis
            max_delay = delta_t[0][i]
            t_start = np.random.uniform(low=-max_delay / 2,
                                        high=self.num_samp - max_delay / 2,
                                        size=1)[0]
            self.t_start[i] = t_start

            for chan in range(self.num_chan):
                # derive the upper bound from the lower one so the window
                # is always exactly pulse_len samples long
                sample_lo = int(np.round(t_start + delta_t[chan][i]
                                         - pulse_len / 2))
                sample_hi = sample_lo + pulse_len

                # clip the window to the spectrum and add the matching
                # part of the pulse
                dst_lo = max(sample_lo, 0)
                dst_hi = min(sample_hi, self.num_samp)
                if dst_lo < dst_hi:
                    spec[chan, dst_lo:dst_hi] += \
                        pulse[dst_lo - sample_lo:dst_hi - sample_lo, i]

    def _generate_pulses(self, width_range, snr_range, t_bin_width, num_frb):
        """Return Gaussian pulse profiles sampled once per time bin.

        Returns an array of shape ``(pulse_len, num_frb)`` where column ``i``
        is a Gaussian with a randomly drawn width and peak amplitude. The
        profile spans at least +/- 6 standard deviations of the widest
        allowed pulse and is centred on its middle sample. The drawn values
        are also stored in ``self.std`` and ``self.snr``.
        """
        # convert width (full width at half max.) to standard deviation;
        # asarray lets callers pass tuples or lists as well as arrays
        std_range = np.asarray(width_range, dtype=float) \
            / (2 * np.sqrt(2 * np.log(2)))
        std = np.random.uniform(low=std_range[0],
                                high=std_range[1],
                                size=num_frb)

        snr = np.random.uniform(low=snr_range[0],
                                high=snr_range[1],
                                size=num_frb)
        self.std = std
        self.snr = snr

        # time axis with one point per time bin, so that writing one
        # element per sample reproduces the requested width
        half_len = int(np.ceil(6 * std_range[1] / t_bin_width))
        x = (np.arange(-half_len, half_len + 1) * t_bin_width)\
            .reshape((-1, 1))

        z = (x**2 / 2) * std**-2
        pulse = snr * np.exp(-z)

        return pulse

In [None]:
# Demo: simulate a few bursts and display their dynamic spectra.
SEED = 0        # fixed so the figure is reproducible on Restart & Run All
NUM_FRB = 6     # number of spectra to simulate
NUM_COLS = 3    # subplot columns in the figure

np.random.seed(SEED)

frb_gen = FRBGenerator(32, 32, 1, 100, 1,
                       (20, 20),
                       (5, 5),
                       (30, 30))
frb_gen.generate(NUM_FRB)

# plot a few random dynamic spectra (drawn without replacement so that no
# spectrum is shown twice)
num_plot = min(6, NUM_FRB)
indices = np.random.choice(NUM_FRB, size=num_plot, replace=False)

num_rows = -(-num_plot // NUM_COLS)     # ceiling division
fig, axes = plt.subplots(num_rows, NUM_COLS, figsize=(10, 6), squeeze=False)
panels = axes.ravel()
for ax, idx in zip(panels, indices):
    # rows of each spectrum are channels, columns are time samples
    ax.imshow(frb_gen.specs[idx], origin='lower', aspect='auto')
    ax.set_title('Spectrum {}'.format(idx))
    ax.set_xlabel('Time sample')
    ax.set_ylabel('Channel')
# hide any panels left over when num_plot is not a multiple of NUM_COLS
for ax in panels[num_plot:]:
    ax.axis('off')
fig.tight_layout()
plt.show()