In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.fft import fft, fftfreq, fftshift

# The raw capture was originally decoded with CODIF (not imported in this
# notebook); presumably its samples were saved to ../all_data.npy, which the
# next cell loads instead.
# data = CODIF("../vela_jimble_output.codif", flatten_groups=True)

# data = data.data
VELA_PERIOD = 89.33  # ms; nominal Vela rotation period used as the fold length

In [None]:
# Pre-extracted voltage samples. Later cells index this as data[channel, sample]
# (axis 0 = channels / "threads", axis 1 = time).
data = np.load("../all_data.npy")

In [None]:
OVERLAP = 0  # samples shared by consecutive FFT windows
SEGMENT = 2**7  # FFT length in samples
# Oversampled channel rate, 51.2 MHz / 27. Kept as an exact float (not
# truncated to int) so it agrees with `seconds_between_samples = 27 / 51_200_000`
# used for folding and dedispersion further down.
SAMPLING_RATE = 51_200_000 / 27  # Hz

OVERSAMPLED_BANDWIDTH = SAMPLING_RATE
# Critically-sampled channel width: oversampled rate * 27 / 32 (= 1.6 MHz).
BANDWIDTH = SAMPLING_RATE * 27 / 32
# Centre frequency of this channel (presumably coarse channel index 926).
SKY_FREQUENCY = 926 * BANDWIDTH


N_THREADS = data.shape[0]

window = np.hanning(SEGMENT)

HOP = SEGMENT - OVERLAP
# Number of complete SEGMENT-long windows: the last window may start at
# data.shape[1] - SEGMENT, hence the +1 (clamped at 0 for very short data).
n_segments = max(0, (data.shape[1] - SEGMENT) // HOP + 1)


# After fftshift the axis is ascending and centred on SKY_FREQUENCY.
frequencies_axis = fftfreq(SEGMENT, 1 / SAMPLING_RATE)
frequencies_axis = np.fft.fftshift(frequencies_axis) + SKY_FREQUENCY
frequencies_axis_mhz = frequencies_axis / 1_000_000
# Start time of each FFT window in seconds.
times = np.arange(n_segments) * HOP / SAMPLING_RATE

In [None]:
# Short-time FFT of every channel: Hann-windowed SEGMENT-sample frames taken
# every (SEGMENT - OVERLAP) samples, stored as |FFT| with the zero frequency
# shifted to the centre so columns line up with `frequencies_axis`.
output = np.empty((N_THREADS, n_segments, SEGMENT))

hop = SEGMENT - OVERLAP
# (n_segments, SEGMENT) matrix of sample indices, one row per frame.
frame_index = (np.arange(n_segments) * hop)[:, None] + np.arange(SEGMENT)
if n_segments and frame_index[-1, -1] >= data.shape[1]:
    # Fail loudly instead of leaving uninitialised np.empty rows behind.
    raise ValueError("past end of array")

for channel in range(N_THREADS):
    frames = data[channel][frame_index] * window
    # One batched FFT per channel instead of one Python-level call per frame.
    output[channel] = np.abs(fftshift(fft(frames, axis=-1), axes=-1))

In [None]:
# Spectrogram of every channel, CHANNELS_PER_FIGURE stacked panels per figure.
CHANNELS_PER_FIGURE = 10

for first_channel in range(0, output.shape[0], CHANNELS_PER_FIGURE):
    fig, axs = plt.subplots(
        CHANNELS_PER_FIGURE, 1, figsize=(20, 16), sharex=True, sharey=True
    )

    for row in range(CHANNELS_PER_FIGURE):
        channel = first_channel + row
        if channel >= output.shape[0]:
            # The last figure may have fewer channels than panels; leave them empty.
            continue
        ax = axs[row]
        im = ax.imshow(
            output[channel].T,
            aspect="auto",
            cmap="inferno",
            origin="lower",
            extent=[times[0], times[-1], frequencies_axis[0], frequencies_axis[-1]],
        )
        # Label with the real (one-based) channel number, not the panel row.
        ax.set_ylabel(f"Channel {channel + 1} Frequency [Hz]")
        ax.set_xlabel("Time [s]")

    fig.colorbar(im, ax=axs, label="Amplitude")

    plt.suptitle("Frequency vs Time for Multiple Channels", fontsize=16)
    plt.show()

In [None]:
# Frequency-summed amplitude vs time for the first few channels: a quick look
# for pulses before folding.
n_channels = 8

fig, axs = plt.subplots(n_channels, 1, figsize=(20, 16), sharex=True, sharey=True)

for i in range(n_channels):
    ax = axs[i]
    ax.plot(times, output[i].sum(axis=1))
    # The y-values are amplitude summed over frequency bins, not a frequency.
    ax.set_ylabel(f"Channel {i + 1}\nsummed amplitude")
    ax.set_xlabel("Time [s]")

plt.suptitle(
    f"Frequency-summed Amplitude vs Time (first {n_channels} channels)", fontsize=16
)
plt.show()

In [None]:
# Mean over *all* channels in `output` (the old code summed every channel but
# divided by a hard-coded 8), then summed over frequency bins.
mean_power_vs_time = output.mean(axis=0).sum(axis=1)

plt.plot(times, mean_power_vs_time)
plt.ylabel("Channel-averaged, frequency-summed amplitude")
plt.xlabel("Time [s]")
plt.suptitle("Amplitude vs Time (mean over all channels)", fontsize=16)
plt.show()

In [None]:
# Number of raw (oversampled) voltage samples spanned by one Vela period.
# seconds_between_samples is the raw sample spacing, 27 / 51.2 MHz; the next
# cell divides this count by the FFT hop to get the fold length in segments.
seconds_between_samples = 27 / 51_200_000
samples_89ms = VELA_PERIOD / 1000 / seconds_between_samples
samples_89ms

In [None]:
# Vela period expressed in whole FFT segments (the fold length used below).
samples_89ms // (SEGMENT - OVERLAP)

In [None]:
# Fold the spectrogram at the Vela period: cut `output` into consecutive chunks
# of FOLD_SEGMENT FFT segments and average them element-wise.
# FOLD_SEGMENT is floored to whole segments, so the true period is not an exact
# multiple and successive folds drift slightly in phase.
FOLD_SEGMENT = int(samples_89ms // (SEGMENT - OVERLAP))
n_channels = output.shape[0]
if FOLD_SEGMENT < 1:
    # A zero fold length would otherwise never advance and loop forever.
    raise ValueError(f"Fold length must be >= 1 segment, got {FOLD_SEGMENT}")

# Complete folds only; trailing segments that do not fill a fold are dropped.
num_folds = output.shape[1] // FOLD_SEGMENT
if num_folds == 0:
    # Without this check the division below would fill output_folded with NaN.
    raise ValueError(
        f"Need at least {FOLD_SEGMENT} segments to fold, got {output.shape[1]}"
    )

output_folded = np.zeros((n_channels, FOLD_SEGMENT, SEGMENT))
for fold in range(num_folds):
    start = fold * FOLD_SEGMENT
    # All channels at once; per element the summation order is fold by fold.
    output_folded += output[:, start : start + FOLD_SEGMENT]

output_folded /= num_folds

In [None]:
def average_3d_array(arr, fold_length):
    """
    Fold ``arr`` along its second-to-last axis and average the folds.

    The second-to-last axis (length N) is cut into consecutive chunks of
    ``fold_length`` rows and the chunks are averaged element-wise. Trailing
    rows that do not make up a complete chunk are discarded.

    Parameters:
    - arr: np.ndarray of shape (..., N, M); the usual case here is (D, N, M)
    - fold_length: length of the folded section (rows per fold)

    Returns:
    - np.ndarray of shape (..., fold_length, M)

    Raises:
    - ValueError: if arr has fewer than 2 dimensions, fold_length < 1, or
      N < fold_length (no complete fold), instead of returning NaNs from a
      mean over zero folds.
    """
    arr = np.asarray(arr)
    if arr.ndim < 2:
        raise ValueError(f"arr must have at least 2 dimensions, got {arr.ndim}")
    if fold_length < 1:
        raise ValueError(f"fold_length must be >= 1, got {fold_length}")

    *leading, n_rows, n_cols = arr.shape
    n_folds = n_rows // fold_length
    if n_folds == 0:
        raise ValueError(
            f"Need at least fold_length={fold_length} rows, got {n_rows}"
        )

    arr_trimmed = arr[..., : n_folds * fold_length, :]
    reshaped = arr_trimmed.reshape(*leading, n_folds, fold_length, n_cols)
    # Axis -3 is the fold index; averaging over it stacks all folds.
    return reshaped.mean(axis=-3)


# Reshape-based fold of the same spectrogram, as a cross-check of the loop-based fold.
output_folded_second_method = average_3d_array(output, FOLD_SEGMENT)

In [None]:
# Per-channel folded spectrograms from the reshape-based fold. The extent maps
# columns to time within one fold and rows to sky frequency, so the axis labels
# below are true.
fold_extent = [times[0], times[FOLD_SEGMENT - 1], frequencies_axis[0], frequencies_axis[-1]]

fig, axs = plt.subplots(
    n_channels, 1, figsize=(20, 16), sharex=True, sharey=True, squeeze=False
)
axs = axs[:, 0]  # 1-D array of Axes, even when n_channels == 1

for i in range(n_channels):
    ax = axs[i]
    im = ax.imshow(
        output_folded_second_method[i].T,
        aspect="auto",
        origin="lower",
        extent=fold_extent,
    )
    ax.set_ylabel(f"Channel {i + 1} Frequency [Hz]")
    ax.set_xlabel("Time [s]")


fig.colorbar(im, ax=axs, label="Amplitude")

plt.suptitle("Frequency vs Time for Multiple Channels (Folded)", fontsize=16)
plt.show()

In [None]:
# Folded pulse profile: mean over *all* channels in output_folded (not a
# division by a hard-coded 8), summed over frequency bins.
folded_profile = output_folded.mean(axis=0).sum(axis=1)

plt.plot(times[:FOLD_SEGMENT], folded_profile)
plt.ylabel("Channel-averaged, frequency-summed amplitude")
plt.xlabel("Time within fold [s]")
plt.suptitle("Folded Pulse Profile (mean over all channels)", fontsize=16)
plt.show()

In [None]:
# Kept because the per-channel folded plot in the next cell reads n_channels.
n_channels = 8

fold_extent = [times[0], times[FOLD_SEGMENT - 1], frequencies_axis[0], frequencies_axis[-1]]

fig, ax = plt.subplots(1, 1)

im = ax.imshow(
    # Mean over every channel actually present in output_folded.
    output_folded.mean(axis=0).T,
    aspect="auto",
    cmap="inferno",
    origin="lower",
    extent=fold_extent,
)
ax.set_ylabel("Frequency [Hz]")
ax.set_xlabel("Time within fold [s]")
fig.colorbar(im, ax=ax, label="Amplitude")

plt.suptitle("Folded Frequency vs Time (mean over all channels)", fontsize=16)
plt.show()

In [None]:
# Per-channel folded spectrograms from the loop-based fold. The extent maps
# columns to time within one fold and rows to sky frequency, so the axis labels
# below are true.
fold_extent = [times[0], times[FOLD_SEGMENT - 1], frequencies_axis[0], frequencies_axis[-1]]

fig, axs = plt.subplots(
    n_channels, 1, figsize=(20, 16), sharex=True, sharey=True, squeeze=False
)
axs = axs[:, 0]  # 1-D array of Axes, even when n_channels == 1

for i in range(n_channels):
    ax = axs[i]
    im = ax.imshow(
        output_folded[i].T,
        aspect="auto",
        origin="lower",
        extent=fold_extent,
    )
    ax.set_ylabel(f"Channel {i + 1} Frequency [Hz]")
    ax.set_xlabel("Time [s]")


fig.colorbar(im, ax=axs, label="Amplitude")

plt.suptitle("Frequency vs Time for Multiple Channels (Folded)", fontsize=16)
plt.show()

In [None]:
# Shape is (channels, FOLD_SEGMENT, SEGMENT): channel x time-within-fold x frequency bin.
output_folded.shape

In [None]:
## Dedisperse data
# Lower frequencies arrive later. For each frequency bin, compute how many FFT
# segments it lags the highest-frequency bin (the last entry of the ascending,
# fftshifted axis), using t = 4.15e3 s * DM * (f_MHz^-2 - f_ref_MHz^-2).

VELA_DM = 67.99  # dispersion measure, pc cm^-3

# Duration of one FFT segment, i.e. one time row of `output`.
SECONDS_BETWEEN_SEGMENTS = seconds_between_samples * (SEGMENT - OVERLAP)

reference_mhz = frequencies_axis_mhz[-1]
delay_seconds = (
    4.15
    * 10**3
    * VELA_DM
    * ((frequencies_axis_mhz**-2) - (reference_mhz**-2))
)
delay_samples = np.round(delay_seconds / SECONDS_BETWEEN_SEGMENTS, 0).astype(int)

delay_samples

In [None]:
# Shift each frequency bin of the folded data earlier by its delay, then trim
# every bin to the common length FOLD_SEGMENT - max(delay_samples).
transposed_data = np.transpose(output_folded, (0, 2, 1)).copy()
transposed_data_shape = transposed_data.shape
max_delay = max(delay_samples)
kept_length = transposed_data_shape[2] - max_delay

final_data = np.zeros(
    shape=(transposed_data_shape[0], transposed_data_shape[1], kept_length)
)
for freq_bin, delay in enumerate(delay_samples):
    # The window [delay, delay + kept_length) covers both the delay < max case
    # and the delay == max case (which runs to the end of the axis).
    final_data[:, freq_bin, :] = transposed_data[
        :, freq_bin, delay : delay + kept_length
    ]

dedispersed_data = np.transpose(final_data, (0, 2, 1))

In [None]:
# The next few cells are just sanity checks: they compare frequency bin `i` of
# the dedispersed data with the folded data shifted by that bin's delay.
# NOTE(review): with SEGMENT = 128, bin 127 is the reference (highest)
# frequency, whose delay is 0, so this check is trivially satisfied; a lower
# bin would actually exercise the shift.
i = 127

# Dedispersed time series of frequency bin i, for every channel.
dedispersed_data[:, :, i]

In [None]:
# Folded data for bin i, starting at its delay (not trimmed, so it can be longer
# than the dedispersed row above).
output_folded[:, delay_samples[i] :, i]

In [None]:
# Should be True: the dedispersed row equals the folded row shifted by its
# delay and trimmed to the common length.
(
    dedispersed_data[:, :, i]
    == output_folded[:, delay_samples[i] : (delay_samples[i] - max(delay_samples)), i]
).all()

In [None]:
## Compare channel 127's folded profile before and after dedispersion.
# Average over frequency bins (axis 2) to get one profile per channel, then
# subtract each profile's mean so peaks are measured above the baseline.
original_folded_frequency_scrunched = output_folded.mean(axis=2)
normalized_original_folded_scrunched = (
    original_folded_frequency_scrunched
    - original_folded_frequency_scrunched.mean(axis=1, keepdims=True)
)


frequency_scrunched_data = dedispersed_data.mean(axis=2)
normalized_scrunched_data = frequency_scrunched_data - frequency_scrunched_data.mean(
    axis=1, keepdims=True
)

PROFILE_CHANNEL = 127
# Dedispersion trims max(delay_samples) bins, so the two profiles differ in
# length; take each time axis from its own profile length.
plt.plot(
    times[: normalized_scrunched_data.shape[1]],
    normalized_scrunched_data[PROFILE_CHANNEL],
    label="dedispersed",
)
plt.plot(
    times[: normalized_original_folded_scrunched.shape[1]],
    normalized_original_folded_scrunched[PROFILE_CHANNEL],
    label="folded (no dedispersion)",
)
plt.xlabel("Time within fold [s]")
plt.ylabel("Mean-subtracted amplitude")
plt.title(f"Folded profile, channel index {PROFILE_CHANNEL}")
plt.legend()
plt.show()

In [None]:
# Peak of the mean-subtracted, dedispersed profile of channel index 127.
max(normalized_scrunched_data[127])

In [None]:
# Peak of the mean-subtracted profile of channel index 127 without dedispersion.
max(normalized_original_folded_scrunched[127])

In [None]:
### run an optimization routine on the maximum power of channel 127.
### start with only tweaking the VELA_PERIOD
from scipy.optimize import minimize


def max_power(x):
    """Objective for the Vela period search (minimise it to maximise the pulse peak).

    Folds channel 127 of the global spectrogram ``output`` at a trial period of
    ``VELA_PERIOD + x[0] * 100000`` ms, then dedisperses the fold. It averages
    over frequency bins, subtracts the mean, and returns ``-1000 * peak`` so
    that ``scipy.optimize.minimize`` drives the peak upwards.

    Parameters
    ----------
    x : sequence of float
        ``x[0]`` is the period offset from ``VELA_PERIOD``; one unit of x is
        100000 ms.

    Returns
    -------
    -1000 times the maximum of the mean-subtracted, frequency-averaged,
    dedispersed folded profile.

    Notes
    -----
    Reads notebook globals: ``output``, ``SEGMENT``, ``OVERLAP``,
    ``VELA_PERIOD``, ``SAMPLING_RATE``, ``SKY_FREQUENCY``.
    The fold length is floored to a whole number of FFT segments, so the
    objective is piecewise constant in ``x``.

    NOTE(review): nothing guards the trial period.
    - A FOLD_SEGMENT of 0 makes the while-loop below never terminate.
    - A negative FOLD_SEGMENT makes np.zeros raise.
    - A FOLD_SEGMENT longer than the data gives num_folds == 0, so the
      division produces NaN.
    """
    n_channels = 1  # only one channel (127) is folded, into output_folded[0]
    vela_period = VELA_PERIOD + x[0] * 100000
    seconds_between_samples = 27 / 51_200_000
    samples_89ms = vela_period / 1000 / seconds_between_samples
    samples_89ms  # no-op expression, left over from the notebook cell

    # `//` already floors, so round() does not change the value here.
    FOLD_SEGMENT = int(round(samples_89ms // (SEGMENT - OVERLAP), 0))
    # n_channels = output.shape[0]
    output_folded = np.zeros((n_channels, FOLD_SEGMENT, SEGMENT))
    num_folds = 0
    channel = 127
    i = 0
    # Accumulate complete folds of FOLD_SEGMENT segments.
    while True:
        start = i * FOLD_SEGMENT
        end = start + FOLD_SEGMENT
        if end > output.shape[1]:
            break
        output_folded[0, :] += output[channel, start:end]
        num_folds += 1
        i += 1
    num_folds /= n_channels

    output_folded /= num_folds

    VELA_DM = 67.99

    # Same per-bin delay computation as the notebook's dedispersion cell.
    SECONDS_BETWEEN_SEGMENTS = seconds_between_samples * (SEGMENT - OVERLAP)
    frequencies_axis = fftfreq(SEGMENT, 1 / SAMPLING_RATE)
    frequencies_axis = np.fft.fftshift(frequencies_axis) + SKY_FREQUENCY
    frequencies_axis_mhz = frequencies_axis / 1_000_000

    delay_samples = np.round(
        4.15
        * 10**3
        * VELA_DM
        * ((frequencies_axis_mhz**-2) - (frequencies_axis_mhz[-1] ** -2))
        / SECONDS_BETWEEN_SEGMENTS,
        0,
    ).astype(int)

    transposed_data = np.transpose(output_folded, (0, 2, 1)).copy()
    transposed_data_shape = transposed_data.shape

    # Shift each frequency bin by its delay and trim to the common length.
    final_data = np.zeros(
        shape=(
            transposed_data_shape[0],
            transposed_data_shape[1],
            transposed_data_shape[2] - max(delay_samples),
        )
    )
    for i in range(delay_samples.shape[0]):
        if delay_samples[i] < max(delay_samples):
            final_data[:, i, :] = transposed_data[
                :, i, delay_samples[i] : (delay_samples[i] - max(delay_samples))
            ]
        else:
            final_data[:, i, :] = transposed_data[:, i, delay_samples[i] :]

    dedispersed_data = np.transpose(final_data, (0, 2, 1))

    # Average over frequency, remove the baseline, and score by the peak.
    output_folded_scrunched = dedispersed_data.mean(axis=2)
    normalized_output_folded_scrunched = (
        output_folded_scrunched - output_folded_scrunched.mean(axis=1, keepdims=True)
    )
    return -max(normalized_output_folded_scrunched[0]) * 1000


# Start at a zero offset, i.e. VELA_PERIOD itself.
x_init = [0]

# Derivative-free Nelder-Mead search over the period offset.
# NOTE(review): the objective is piecewise constant (integer fold length), so
# the simplex can stop on a flat step — check min_period.message / .nit.
min_period = minimize(max_power, x_init, method="Nelder-Mead")

In [None]:
# Full optimisation result (check .success and .message).
min_period

In [None]:
# Best-fit period in ms (undoes the x * 100000 scaling used by max_power).
min_period.x * 100000 + VELA_PERIOD

In [None]:
# Raw-sample count per Vela period, recomputed from the VELA_PERIOD constant
# instead of repeating the 89.33 ms literal.
seconds_between_samples = 27 / 51_200_000
samples_89ms = VELA_PERIOD / 1000 / seconds_between_samples
samples_89ms

In [None]:
### run an optimization routine on the maximum power of channel 127.
### The FFT segment length is now a fixed argument swept by the loop below;
### for each length, the period offset x[0] is still optimised.


def max_power(x, segment):
    """Period-search objective for a given FFT segment length.

    Recomputes an un-overlapped, Hann-windowed |FFT| spectrogram of channel 127
    of the global ``data`` using ``segment``-sample frames. It then folds the
    spectrogram at ``VELA_PERIOD + x[0] * 100000`` ms and dedisperses the fold.
    Finally it returns ``-1000 * peak`` of the mean-subtracted,
    frequency-averaged profile.

    Parameters
    ----------
    x : sequence of float
        ``x[0]`` is the period offset from ``VELA_PERIOD``; one unit of x is
        100000 ms.
    segment : int
        FFT length (and hop) in samples.

    Notes
    -----
    This redefinition shadows the earlier one-argument ``max_power``.
    Reads globals ``data``, ``VELA_PERIOD``, ``SAMPLING_RATE`` and
    ``SKY_FREQUENCY``.

    NOTE(review):
    - The spectrogram does not depend on ``x`` but is recomputed on every call,
      so each optimiser step pays for a full FFT pass over channel 127.
    - ``n_segments`` omits the last complete frame.
    - The "past end of array" branch is therefore unreachable.
    - Same FOLD_SEGMENT <= 0 / zero-fold hazards as the one-argument version.
    """
    N_THREADS = 1
    # segment = x[0]
    window = np.hanning(segment)

    n_segments = (data.shape[1] - segment) // segment

    # NOTE(review): this first frequency axis is unused; it is recomputed below.
    frequencies_axis = fftfreq(segment, 1 / SAMPLING_RATE)
    frequencies_axis = np.fft.fftshift(frequencies_axis) + SKY_FREQUENCY
    frequencies_axis_mhz = frequencies_axis / 1_000_000

    output = np.empty((N_THREADS, n_segments, segment))

    # Spectrogram of channel 127 only, stored in output[0].
    for channel in [127]:
        for j in range(n_segments):
            start = j * (segment)
            end = start + segment
            if end > data.shape[1]:
                print("past end of array")
                break
            sliced_data = data[channel, start:end]

            ff_transform = fft(sliced_data * window)
            ff_shift = fftshift(ff_transform)
            power = np.abs(ff_shift)

            output[0, j, :] = power

    n_channels = 1
    vela_period = VELA_PERIOD + x[0] * 100000
    seconds_between_samples = 27 / 51_200_000
    samples_89ms = vela_period / 1000 / seconds_between_samples
    samples_89ms  # no-op expression

    # `//` already floors, so round() does not change the value here.
    FOLD_SEGMENT = int(round(samples_89ms // (segment), 0))
    # n_channels = output.shape[0]
    output_folded = np.zeros((n_channels, FOLD_SEGMENT, segment))
    num_folds = 0

    # Accumulate complete folds of FOLD_SEGMENT segments.
    i = 0
    while True:
        start = i * FOLD_SEGMENT
        end = start + FOLD_SEGMENT
        if end > output.shape[1]:
            break
        output_folded[0, :] += output[0, start:end]
        num_folds += 1
        i += 1
    num_folds /= n_channels

    output_folded /= num_folds
    VELA_DM = 67.99

    SECONDS_BETWEEN_SEGMENTS = seconds_between_samples * (segment)
    frequencies_axis = fftfreq(segment, 1 / SAMPLING_RATE)
    frequencies_axis = np.fft.fftshift(frequencies_axis) + SKY_FREQUENCY
    frequencies_axis_mhz = frequencies_axis / 1_000_000

    # Per-bin dispersion delay, in segments, relative to the highest frequency.
    delay_samples = np.round(
        4.15
        * 10**3
        * VELA_DM
        * ((frequencies_axis_mhz**-2) - (frequencies_axis_mhz[-1] ** -2))
        / SECONDS_BETWEEN_SEGMENTS,
        0,
    ).astype(int)

    transposed_data = np.transpose(output_folded, (0, 2, 1)).copy()
    transposed_data_shape = transposed_data.shape

    # Shift each frequency bin by its delay and trim to the common length.
    final_data = np.zeros(
        shape=(
            transposed_data_shape[0],
            transposed_data_shape[1],
            transposed_data_shape[2] - max(delay_samples),
        )
    )
    for i in range(delay_samples.shape[0]):
        if delay_samples[i] < max(delay_samples):
            final_data[:, i, :] = transposed_data[
                :, i, delay_samples[i] : (delay_samples[i] - max(delay_samples))
            ]
        else:
            final_data[:, i, :] = transposed_data[:, i, delay_samples[i] :]

    dedispersed_data = np.transpose(final_data, (0, 2, 1))
    output_folded_scrunched = dedispersed_data.mean(axis=2)
    normalized_output_folded_scrunched = (
        output_folded_scrunched - output_folded_scrunched.mean(axis=1, keepdims=True)
    )
    return -max(normalized_output_folded_scrunched[0]) * 1000


# Sweep the FFT segment length. For each length, optimise the period offset
# and record the best objective value and the corresponding period (ms).
# NOTE(review): slow — max_power recomputes the FFT on every evaluation.
peaks = []
periods = []
for segment_length in range(100, 300):
    # args must be a tuple; (segment_length,) rather than (segment_length).
    min_opt = minimize(max_power, [0], args=(segment_length,), method="Nelder-Mead")
    # Store plain floats instead of 1-element arrays.
    periods.append(VELA_PERIOD + 100000 * min_opt.x[0])
    peaks.append(min_opt.fun)

In [None]:
# x-values must match the sweep's range(100, 300); np.linspace(100, 300, 200)
# used non-integer steps of 200/199 and ended at 300.
plt.plot(np.arange(100, 300), peaks)
plt.xlabel("FFT segment length [samples]")
plt.ylabel("Best objective (-1000 x profile peak)")
plt.show()

In [None]:
# x-values must match the sweep's range(100, 300).
plt.plot(np.arange(100, 300), periods)
plt.xlabel("FFT segment length [samples]")
plt.ylabel("Best-fit period [ms]")
plt.show()

In [None]:
def max_power(vela_period, segment):
    """Fold and dedisperse channel 127 at a given period and segment length.

    Despite the name (it shadows the earlier objective functions), this version
    returns the dedispersed folded spectrogram rather than a score.

    Parameters
    ----------
    vela_period : float
        Fold period in ms (an absolute period, not an offset).
    segment : int
        FFT length (and hop) in samples.

    Returns
    -------
    np.ndarray of shape (1, FOLD_SEGMENT - max(delay_samples), segment):
    time-within-fold x frequency bin for channel 127.

    Notes
    -----
    Reads globals ``data``, ``SAMPLING_RATE`` and ``SKY_FREQUENCY``.
    NOTE(review): ``n_segments`` omits the last complete frame; same
    FOLD_SEGMENT <= 0 / zero-fold hazards as the objective versions.
    """
    N_THREADS = 1
    # segment = x[0]
    window = np.hanning(segment)

    n_segments = (data.shape[1] - segment) // segment

    # NOTE(review): this first frequency axis is unused; it is recomputed below.
    frequencies_axis = fftfreq(segment, 1 / SAMPLING_RATE)
    frequencies_axis = np.fft.fftshift(frequencies_axis) + SKY_FREQUENCY
    frequencies_axis_mhz = frequencies_axis / 1_000_000

    output = np.empty((N_THREADS, n_segments, segment))

    # Spectrogram of channel 127 only, stored in output[0].
    for channel in [127]:
        for j in range(n_segments):
            start = j * (segment)
            end = start + segment
            if end > data.shape[1]:
                print("past end of array")
                break
            sliced_data = data[channel, start:end]

            ff_transform = fft(sliced_data * window)
            ff_shift = fftshift(ff_transform)
            power = np.abs(ff_shift)

            output[0, j, :] = power

    n_channels = 1
    seconds_between_samples = 27 / 51_200_000
    samples_89ms = vela_period / 1000 / seconds_between_samples
    samples_89ms  # no-op expression

    # `//` already floors, so round() does not change the value here.
    FOLD_SEGMENT = int(round(samples_89ms // (segment), 0))
    # n_channels = output.shape[0]
    output_folded = np.zeros((n_channels, FOLD_SEGMENT, segment))
    num_folds = 0

    # Accumulate complete folds of FOLD_SEGMENT segments.
    i = 0
    while True:
        start = i * FOLD_SEGMENT
        end = start + FOLD_SEGMENT
        if end > output.shape[1]:
            break
        output_folded[0, :] += output[0, start:end]
        num_folds += 1
        i += 1
    num_folds /= n_channels

    output_folded /= num_folds
    VELA_DM = 67.99

    SECONDS_BETWEEN_SEGMENTS = seconds_between_samples * (segment)
    frequencies_axis = fftfreq(segment, 1 / SAMPLING_RATE)
    frequencies_axis = np.fft.fftshift(frequencies_axis) + SKY_FREQUENCY
    frequencies_axis_mhz = frequencies_axis / 1_000_000

    # Per-bin dispersion delay, in segments, relative to the highest frequency.
    delay_samples = np.round(
        4.15
        * 10**3
        * VELA_DM
        * ((frequencies_axis_mhz**-2) - (frequencies_axis_mhz[-1] ** -2))
        / SECONDS_BETWEEN_SEGMENTS,
        0,
    ).astype(int)

    transposed_data = np.transpose(output_folded, (0, 2, 1)).copy()
    transposed_data_shape = transposed_data.shape

    # Shift each frequency bin by its delay and trim to the common length.
    final_data = np.zeros(
        shape=(
            transposed_data_shape[0],
            transposed_data_shape[1],
            transposed_data_shape[2] - max(delay_samples),
        )
    )
    for i in range(delay_samples.shape[0]):
        if delay_samples[i] < max(delay_samples):
            final_data[:, i, :] = transposed_data[
                :, i, delay_samples[i] : (delay_samples[i] - max(delay_samples))
            ]
        else:
            final_data[:, i, :] = transposed_data[:, i, delay_samples[i] :]

    dedispersed_data = np.transpose(final_data, (0, 2, 1))
    # output_folded_scrunched = dedispersed_data.mean(axis=2)
    # normalized_output_folded_scrunched = output_folded_scrunched - output_folded_scrunched.mean(axis=1, keepdims=True)
    return dedispersed_data


# Period (ms) for the final fold, presumably read off the optimisation / sweep
# results above — confirm before reusing.
OPTIMAL_PERIOD_MS = 89.4765
PROFILE_CHANNEL = 127  # the channel max_power hard-codes

optimal_dedispersed = max_power(OPTIMAL_PERIOD_MS, SEGMENT // 2)


fig, ax = plt.subplots(1, 1, figsize=(20, 16))

# optimal_dedispersed is (1, time bins, frequency bins); plot the single channel
# with frequency on the y-axis.
im = ax.imshow(
    optimal_dedispersed[0].T,
    aspect="auto",
    origin="lower",
)
# No extent is set, so both axes are bin indices rather than Hz / seconds.
ax.set_ylabel(f"Channel {PROFILE_CHANNEL + 1} frequency bin")
ax.set_xlabel("Time bin within fold")

# Attach the colorbar to this figure's own axes, not to the `axs` array left
# over from an earlier figure.
fig.colorbar(im, ax=ax, label="Amplitude")

plt.suptitle("Dedispersed Folded Frequency vs Time (single channel)", fontsize=16)
plt.show()

In [None]:
# Calculate the signal-to-noise of the dedispersed, folded profile.
# Averaging over frequency bins gives one profile per (single) channel; its
# mean is removed so the peak is measured above the baseline.
frequency_spectrum = optimal_dedispersed.mean(axis=2)
frequency_spectrum = frequency_spectrum - frequency_spectrum.mean(axis=1, keepdims=True)

profile = frequency_spectrum[0]
# Noise from the last half of the profile's *time bins*. (Slicing the 2-D
# array by len() // 2 selected the whole array, since it has only one row.)
# NOTE(review): assumes the pulse lies in the first half — check the plot below.
noise_std = np.std(profile[profile.size // 2 :])
max(profile) / noise_std

In [None]:
# Time axis for this profile: one point per (SEGMENT // 2)-sample FFT, the
# segment length passed to max_power for `optimal_dedispersed`. The global
# `times` uses SEGMENT-sample spacing and a different fold length, so slicing
# it with a hand-tuned offset mislabels the x-axis.
profile_times = (
    np.arange(frequency_spectrum.shape[1]) * (SEGMENT // 2) * seconds_between_samples
)
plt.plot(profile_times, frequency_spectrum[0])
plt.xlabel("Time within fold [s]")
plt.ylabel("Mean-subtracted amplitude")
plt.show()