In [None]:
from tsi_toolkit import *
import tsi_toolkit as tsi
import numpy as np
import matplotlib.pyplot as plt
from pylag import *
import pylag
%load_ext autoreload
%autoreload 2

In [None]:
def generate_time_series_with_power_spectrum(power_spectrum_func, duration, sampling_rate,
                                             rng=None, offset=0.1):
    """Simulate a real, evenly sampled time series with a prescribed power spectrum.

    Each non-negative Fourier frequency is given the amplitude
    ``sqrt(power_spectrum_func(f))`` and a uniformly random phase; the
    negative-frequency half follows from Hermitian symmetry, so the inverse
    transform is exactly real and ``|FFT|**2`` of the (zero-offset) series
    equals ``power_spectrum_func(|f|)`` in every bin.

    Parameters
    ----------
    power_spectrum_func : callable
        Maps an array of non-negative frequencies to power values. A scalar
        return value is broadcast to all frequencies.
    duration : float
        Total length of the series, in the time unit implied by ``sampling_rate``.
    sampling_rate : float
        Samples per unit time; ``int(duration * sampling_rate)`` samples are produced.
    rng : numpy.random.Generator, optional
        Source of the random phases. When omitted, NumPy's global random state
        is used, so ``np.random.seed`` still controls the result.
    offset : float, optional
        Constant added to the returned series (default 0.1).

    Returns
    -------
    time : ndarray
        Sample times ``0, dt, ..., (n - 1) * dt``.
    time_series : ndarray
        The simulated series plus ``offset``.
    freqs : ndarray
        Two-sided Fourier frequencies, shifted so that zero is in the centre.
    fft : ndarray
        Shifted FFT of the series *before* ``offset`` is added.
    power_spectrum : ndarray
        ``|fft|**2``.

    Raises
    ------
    ValueError
        If ``duration * sampling_rate`` yields fewer than one sample.
    """
    n_points = int(duration * sampling_rate)
    if n_points < 1:
        raise ValueError("duration * sampling_rate must give at least one sample")
    dt = 1 / sampling_rate

    # Two-sided, centred frequency grid handed back to the caller.
    freqs = np.fft.fftshift(np.fft.fftfreq(n_points, d=dt))

    # One-sided grid (0 ... Nyquist, all non-negative): the PSD function is
    # only ever evaluated at non-negative frequencies.
    half_freqs = np.fft.rfftfreq(n_points, d=dt)
    psd = np.asarray(power_spectrum_func(half_freqs), dtype=float)
    amplitudes = np.sqrt(np.broadcast_to(psd, half_freqs.shape))

    uniform = np.random.random_sample if rng is None else rng.random
    half_fft = amplitudes * np.exp(2j * np.pi * uniform(len(amplitudes)))

    # A real signal needs a real DC bin and, for an even number of samples,
    # a real Nyquist bin; otherwise their power would be partly discarded.
    half_fft[0] = amplitudes[0]
    if n_points % 2 == 0:
        half_fft[-1] = amplitudes[-1]

    # irfft supplies the conjugate-symmetric negative frequencies itself, which
    # avoids hand-indexing the mirrored half and works for odd and even lengths.
    time_series = np.fft.irfft(half_fft, n=n_points)

    # Build the time axis from the sample count so it always matches the series.
    time = np.arange(n_points) * dt
    fft = np.fft.fftshift(np.fft.fft(time_series))
    power_spectrum = np.abs(fft) ** 2

    return time, time_series + offset, freqs, fft, power_spectrum

def example_power_spectrum(freqs):
    """Example power spectrum ``1 / (1 + f**2)``.

    Equals 1 at zero frequency and falls off as ``f**-2`` at high frequency.
    Works element-wise on NumPy arrays as well as on plain numbers.
    """
    denominator = 1 + freqs ** 2
    return 1 / denominator

def example_power_spectrum_2(freqs):
    """Example power spectrum ``1 / (1 + |f|**1.5)``.

    Equals 1 at zero frequency and falls off as ``|f|**-1.5`` at high frequency.
    The absolute value makes the spectrum even in frequency: a fractional power
    of a negative float is NaN in NumPy, so negative frequencies would otherwise
    produce NaN. Results for non-negative input are unchanged.

    Accepts scalars or array-likes; returns a NumPy float or array.
    """
    psd = 1 / (np.abs(freqs) ** 1.5 + 1)
    return psd

# Generate time series
# The random phases are drawn from NumPy's global random state; fixing the seed
# makes a Restart & Run All reproduce the same realisation (and the same
# figures and checks in the cells below).
SEED = 42
np.random.seed(SEED)

duration = 10
sampling_rate = 100
time, time_series, freqs, fft, power_spectrum = generate_time_series_with_power_spectrum(
    example_power_spectrum, duration, sampling_rate
)

fig, (ax_series, ax_power) = plt.subplots(2, 1, figsize=(12, 6))

# Top panel: the simulated series in the time domain.
ax_series.plot(time, time_series)
ax_series.set_title("Time Series")
ax_series.set_xlabel("Time")
ax_series.set_ylabel("Amplitude")

# Bottom panel: power at the non-negative frequencies, on log-log axes.
non_negative = freqs >= 0
ax_power.scatter(freqs[non_negative], power_spectrum[non_negative])
ax_power.set_title("Power Spectrum")
ax_power.set_xlabel("Frequency")
ax_power.set_ylabel("Power")
ax_power.set_xscale("log")
ax_power.set_yscale("log")

fig.tight_layout()
plt.show()

In [None]:
# verify psd
power_spectrum_tsi = PowerSpectrum(times=time, values=time_series, norm=False)

# Check that power_spectrum_tsi.powers agrees with the reference spectrum at the
# strictly positive frequencies, i.e. power_spectrum[freqs > 0]; the
# zero-frequency bin is not part of the comparison.
strictly_positive = freqs > 0
reference_powers = power_spectrum[strictly_positive]
assert np.allclose(power_spectrum_tsi.powers, reference_powers), "The power spectra do not match!"

lc_pylag = pylag.LightCurve(t=time, r=time_series)
power_spectrum_pylag = pylag.Periodogram(lc=lc_pylag, norm=False)

# Overlay the three estimates on log-log axes.
non_negative = freqs >= 0
fig, ax = plt.subplots()
ax.scatter(power_spectrum_tsi.freqs, power_spectrum_tsi.powers, label='tsi_toolkit', s=2)
ax.scatter(freqs[non_negative], power_spectrum[non_negative], label='Known', s=2)
ax.scatter(power_spectrum_pylag.freq, power_spectrum_pylag.periodogram, label='pylag', s=2)
ax.legend()
ax.set_xscale('log')
ax.set_yscale('log')

In [None]:
# verify cross spectrum
time2, time_series2, freqs2, fft2, power_spectrum2 = generate_time_series_with_power_spectrum(example_power_spectrum_2, duration, sampling_rate)
true_cross_spectrum = np.conj(fft) * fft2
cross_spectrum_tsi = tsi.CrossSpectrum(times1=time, values1=time_series, times2=time2, values2=time_series2, norm=False)

# The cross spectrum is complex. Passing complex values to scatter discards the
# imaginary part, and the log y-axis then drops every negative real part, so
# compare magnitudes instead.
# NOTE(review): assumes cross_spectrum_tsi.cs is an array of (possibly complex)
# values that np.abs can handle - confirm against tsi_toolkit.
non_negative = freqs >= 0
plt.scatter(cross_spectrum_tsi.freqs,
            np.abs(cross_spectrum_tsi.cs), label='tsi_toolkit',s=2)
plt.scatter(freqs[non_negative], np.abs(true_cross_spectrum[non_negative]), label='Known',s=2)
plt.legend()
plt.xlabel('Frequency')
plt.ylabel('|Cross spectrum|')
plt.xscale('log')
plt.yscale('log')

In [None]:
# verify lag frequency spectrum
# Time lag = phase / (2 * pi * f). The zero-frequency bin has no defined lag, so
# it is left as NaN rather than dividing by zero (which raised a RuntimeWarning
# and produced inf/NaN). The array keeps the same shape as `freqs`.
nonzero = freqs != 0
true_lag_spectrum = np.full(freqs.shape, np.nan)
true_lag_spectrum[nonzero] = np.angle(true_cross_spectrum[nonzero]) / (2 * np.pi * freqs[nonzero])
lag_spectrum_tsi = tsi.LagFrequencySpectrum(
    times1=time, values1=time_series, times2=time2, values2=time_series2, subtract_coh_bias=False
    )

# Only strictly positive frequencies can be drawn on the log x-axis.
positive = freqs > 0
plt.scatter(lag_spectrum_tsi.freqs, lag_spectrum_tsi.lags, label='tsi_toolkit',s=2)
plt.scatter(freqs[positive], true_lag_spectrum[positive], label='Known',s=2)
plt.legend()
plt.xlabel('Frequency')
plt.ylabel('Lag')
plt.xscale('log')