## Implement the whitening using an AR model

### Create AR model

In [None]:
import time
import os
import json 
from pytsa.tsa import SeqView_double_t as SV
from wdf.config.Parameters import Parameters
from wdf.processes.Whitening import Whitening
from wdf.processes.DWhitening import  DWhitening
from wdf.processes.DownSampling import *
from pytsa.tsa import FrameIChannel
import logging, sys

# Configure the root logger so that INFO-level (and higher) records are emitted.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Send log records to stdout. Notebook cells are commonly re-executed; adding a
# fresh StreamHandler on every run would print each message once per run, so an
# already-attached stdout handler is reused instead of adding a duplicate.
console_handler = next(
    (
        handler
        for handler in logger.handlers
        if isinstance(handler, logging.StreamHandler)
        and getattr(handler, "stream", None) is sys.stdout
    ),
    None,
)
if console_handler is None:
    console_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(console_handler)

# DEBUG is below the INFO threshold set above, so this call prints nothing.
logging.debug("info")
 
# Current working directory and its parent; the parent is only needed by the
# alternative (commented-out) input location below.
path = os.getcwd()
parent_dir = os.path.abspath(os.path.join(path, os.pardir))
#filein=os.path.join(parent_dir,"caches/E1.ffl")

# Frame file read by the first part of the notebook.
MDC_PATH = "/cvmfs/et-gw.osgstorage.org/et-gw/PUBLIC/MDC1/data"
filein = MDC_PATH + '/E1/E-E1_STRAIN_DATA-1000000000-2048.gwf'

# Flag to determine whether to create a new JSON configuration file
new_json_config_file = True

# Absolute path of the configuration file written below.
filejson = os.path.join(os.getcwd(), "WavRecDS.json")

# Create or update the JSON configuration file if needed. The dump lives inside
# the guard: `configuration` only exists when the flag is set, so writing it
# unconditionally would raise NameError with the flag switched off.
if new_json_config_file:
    configuration = {
        "file": filein,            # input frame file
        "channel": 'E1:STRAIN',    # channel read from the frame file
        "len": 2.0,                # length of each chunk read from the stream
        "gps": 1000000001.,        # GPS time at which reading starts
        "outdir": "./",
        "dir": "./",
        "ARorder": 1000,           # order passed to Whitening
        "learn": 200,              # length of data read for the AR estimation
        "preWhite": 4,             # number of preheating iterations
        "ResamplingFactor": 4,     # par.sampling is divided by this factor
    }
    # The context manager closes the file even if serialisation fails.
    with open(filejson, "w") as file_json:
        json.dump(configuration, file_json)


 
# Load the run parameters from the JSON file written by the previous step.
par = Parameters()
filejson = "WavRecDS.json"
try:
    par.load(filejson)
except IOError:
    logging.error("Cannot find resource file " + filejson)
    # Re-raise instead of calling quit(): quit() is an interactive-shell helper
    # that ends a script with exit status 0 and shuts down a notebook kernel,
    # hiding the failure. Propagating the error keeps the traceback.
    raise
else:
    logging.info("read parameters from JSON file")

# Read a short stretch of the channel only to discover its sampling step.
strInfo = FrameIChannel(par.file, par.channel, 1.0, par.gps)
Info = SV()
strInfo.GetData(Info)

# Round before converting: 1.0 / step can land a hair below the integer rate
# because of floating-point error, and int() alone would then truncate it.
par.sampling = int(round(1.0 / Info.GetSampling()))
# Rate after down-sampling (true division, so this is a float).
par.resampling = par.sampling / par.ResamplingFactor

logging.info("channel= %s at sampling frequency= %s" % (par.channel, par.resampling))
whiten = Whitening(par.ARorder)

# Coefficient files are keyed by AR order, down-sampled rate and channel name.
par.ARfile = "./ARcoeff-AR%s-fs%s-%s.txt" % (
    par.ARorder, par.resampling, par.channel)
par.LVfile = "./LVcoeff-AR%s-fs%s-%s.txt" % (
    par.ARorder, par.resampling, par.channel)

if not (os.path.isfile(par.ARfile) and os.path.isfile(par.LVfile)):
    # At least one coefficient file is missing: estimate the AR model from data.
    logging.info('Start AR parameter estimation')

    # Stream delivering par.learn worth of the channel starting at par.gps
    # (presumably seconds -- confirm against the FrameIChannel API).
    strLearn = FrameIChannel(par.file, par.channel, par.learn, par.gps)
    Learn = SV()
    Learn_DS = SV()

    # Number of output samples, set on par before the down-sampler is built.
    par.Noutdata = int(par.learn * par.resampling)
    ds = DownSampling(par, estimation=True)

    # Read, down-sample, then fit and store the coefficients.
    strLearn.GetData(Learn)
    Learn_DS = ds.Process(Learn)
    whiten.ParametersEstimate(Learn_DS)
    whiten.ParametersSave(par.ARfile, par.LVfile)

    # Drop the learning buffers and helper objects; they are not used again.
    del Learn, ds, strLearn, Learn_DS
else:
    # Both coefficient files already exist: reuse them.
    logging.info('Load AR parameters')
    whiten.ParametersLoad(par.ARfile, par.LVfile)

In [None]:
# Keep the noise sigma reported by the whitening filter on the parameter object.
par.sigma = whiten.GetSigma()
logging.info('Estimated sigma= %s', par.sigma)

In [None]:
# Number of down-sampled samples in one chunk of length par.len.
N=int(par.resampling*par.len)
# Padding handed to the down-sampler: a quarter of the original sampling rate,
# in samples.
padlen=int(par.sampling/4)
par.Noutdata=N
ds = DownSampling(par,padlen=padlen)
# Second whitening stage built from the first stage's LV coefficients.
# NOTE(review): the meaning of the trailing 0 is not visible here -- confirm
# against the DWhitening constructor.
Dwhiten=DWhitening(whiten.LV,N,0)
# Stream delivering chunks of length par.len from par.gps onwards.
streaming = FrameIChannel(par.file, par.channel, par.len, par.gps)
# Buffers: raw chunk, whitened output, double-whitened output.
data = SV()
dataw = SV()
dataww = SV()

# If an LV coefficient file exists, load it into the double-whitening stage.
if os.path.isfile(par.LVfile):
    logging.info('Load LV parameters')
    Dwhiten.ParametersLoad(par.LVfile)

 

###---whitening preheating---###
# Push par.preWhite chunks through both filters. The outputs are overwritten on
# every pass and not kept -- presumably this lets the filters' internal state
# settle before the chunk that is analysed (confirm).
for i in range(par.preWhite):
    streaming.GetData(data)
    data_ds=ds.Process(data)
    whiten.Process(data_ds, dataw)
    Dwhiten.Process(data_ds, dataww) 
 
    

In [None]:
# Read the next chunk from the stream, down-sample it, and run it through both
# whitening stages; data_ds, dataw and dataww are the series plotted below.
streaming.GetData(data)
data_ds=ds.Process(data)
whiten.Process(data_ds, dataw)
Dwhiten.Process(data_ds, dataww)

In [None]:
 # Import necessary libraries and modules
import matplotlib.pyplot as plt
import IPython
import IPython.display
import matplotlib as mpl
from matplotlib import cm
from scipy import signal
from matplotlib.colors import LogNorm
from PIL import Image
# Import NumPy library
import numpy as np
# Default figure size and grid setting for the plots that follow.
mpl.rcParams.update({'figure.figsize': (8, 6), 'axes.grid': False})

# Copy the processed chunk into NumPy arrays: x is the abscissa of the
# down-sampled data, y its values, yw / yww the whitened and double-whitened
# outputs. The copy length is the size of the whitened buffer.
n_out = dataw.GetSize()
x = np.zeros(n_out)
y = np.zeros(n_out)
yw = np.zeros(n_out)
yww = np.zeros(dataww.GetSize())
for i in range(n_out):
    x[i], y[i] = data_ds.GetX(i), data_ds.GetY(0, i)
    yw[i], yww[i] = dataw.GetY(0, i), dataww.GetY(0, i)



# Overlay the raw, whitened and double-whitened series on one wide figure.
plt.figure(figsize=(10, 4))
for series, series_label in (
    (y, 'Raw data'),
    (yw, 'Whitened data'),
    (yww, 'Double whitened data'),
):
    plt.plot(x, series, label=series_label)

# Legend built from the labels above, then render.
plt.legend()
plt.show()


In [None]:
 

# Calculate power spectral density (PSD) using Welch's method.
# y, yw and yww were all filled from the down-sampled stream (data_ds and the
# filters fed by it), so their sample rate is par.resampling, not the original
# par.sampling. Passing the original rate would stretch the frequency axis by
# the resampling factor and rescale the density by its inverse.
f, Pxx_den = signal.welch(y, par.resampling, nperseg=2048)
f, Pxx_denW = signal.welch(yw, par.resampling, nperseg=2048)
f, Pxx_denWW = signal.welch(yww, par.resampling, nperseg=2048)

# Create a new figure and axis object
fig, ax = plt.subplots()

# One curve per series, log-scaled on the y axis
ax.semilogy(f, Pxx_den, label='Raw data')
ax.semilogy(f, Pxx_denW, label='Whitened data')
ax.semilogy(f, Pxx_denWW, label='Double whitened data')

# Set labels for x and y axes
ax.set_xlabel('frequency [Hz]')
ax.set_ylabel('PSD [V**2/Hz]')

# Add legend
ax.legend()

# Show plot
plt.show()


In [None]:


def prepareImage(y, fs, title):
    """
    Draw a spectrogram of a signal on a logarithmic frequency axis and show it.

    Parameters:
        y (numpy.ndarray): Samples of the signal.
        fs (float): Sampling frequency of ``y``.
        title (str): Text placed above the plot (converted with ``str``).

    Returns:
        None
    """
    # scipy returns the frequency bins, the segment times and the power matrix
    freqs, times, power = signal.spectrogram(y, fs)

    # Smooth-shaded colour map of power over (time, frequency)
    plt.pcolormesh(times, freqs, power, cmap='viridis', shading='gouraud', alpha=0.95)

    # Title and axis captions
    plt.title(str(title))
    plt.xlabel('Time (secs)')
    plt.ylabel('Frequency (Hz)')

    # Logarithmic frequency axis, shown from 10 Hz up to fs/2
    plt.yscale('log')
    plt.ylim(10, fs/2)

    # Colour scale, then render and release the figure
    plt.colorbar()
    plt.show()
    plt.close()

In [None]:
# Spectrogram of the whitened series, using the down-sampled rate as its
# sampling frequency.
fs=par.resampling
prepareImage(yw,fs,title="Spectrogram")

In [None]:
def prepareImage_cwt(y, fs, title="title"):
    """
    Draw a Morlet continuous-wavelet-transform power map of a signal and show it.

    Parameters:
        y (numpy.ndarray): Samples of the signal.
        fs (float): Sampling frequency of ``y``.
        title (str, optional): Text placed above the plot (default "title").

    Returns:
        None

    Note:
        ``scipy.signal.cwt`` and ``scipy.signal.morlet2`` are deprecated in
        recent SciPy releases; this function needs a SciPy version that still
        provides them.
    """
    # Morlet parameter, forwarded to morlet2 and used to derive the widths
    w = 6.

    # int(fs/2) evenly spaced frequencies between 1 Hz and fs/2
    freq = np.linspace(1, fs/2, int(fs/2))

    # Wavelet width associated with each of those frequencies
    widths = w * fs / (2 * freq * np.pi)

    # Squared magnitude of the transform: one row per frequency
    coefficients = signal.cwt(y, signal.morlet2, widths, w=w)
    power = np.abs(coefficients)**2

    # Horizontal axis is the sample index, not seconds
    sample_index = np.arange(0, len(y))

    # "plasma" colour map whose 'bad' (masked/invalid) colour is its lowest colour
    plasma = plt.colormaps["plasma"]
    cmap = plasma.with_extremes(bad=plasma(0))
    plt.pcolormesh(sample_index, freq, power, cmap=cmap, rasterized=True)

    # Logarithmic frequency axis from 1 Hz to fs/2
    plt.yscale('log')
    plt.ylim(1, fs/2)

    plt.title(str(title))

    # Render and release the figure
    plt.show()
    plt.close()

In [None]:
# Wavelet power map of the whitened series; fs was set to par.resampling in the
# spectrogram cell above.
prepareImage_cwt(yw,fs,"Wavelet map")

## Let's whiten data containing a high-SNR signal and plot it both before and after whitening

In [None]:
# GPS time of the event to look at (presumably its coalescence time -- confirm).
tc=1001620463.11925
# Switch to a different frame file and to 4-unit-long chunks.
par.file=MDC_PATH+'/E1/E-E1_STRAIN_DATA-1001619968-2048.gwf'
par.len=4.0
lenS=par.len
# Down-sampled samples per chunk, and padding for the down-sampler
# (a quarter of the original sampling rate, in samples).
N=int(par.resampling*lenS)
padlen=int(par.sampling/4)
par.Noutdata=N
# NOTE(review): redundant -- par.len already equals lenS (set three lines above).
par.len=lenS
gpsEvent=tc
# Start reading one chunk plus par.preWhite preheating chunks before the event,
# shifted forward by the padding duration (padlen/par.sampling).
# NOTE(review): if FrameIChannel returns consecutive lenS-long chunks, the chunk
# read after preheating spans roughly [tc - 3.75, tc + 0.25], which does not
# obviously match the "1 sec before / 1 sec after" intent noted below -- confirm.
gps=gpsEvent-par.len-(par.preWhite)*lenS +padlen/par.sampling
# Fresh down-sampler, double-whitening stage and stream for the new settings.
# The first-stage `whiten` object is reused as-is from the cells above.
ds = DownSampling(par,padlen=padlen)
Dwhiten=DWhitening(whiten.LV,N,0)
streaming = FrameIChannel(par.file, par.channel, lenS, gps)

# Try to center 1 sec before and 1 sec after the event
# Discard the buffers from the first part of the notebook and start clean.
del data,dataw,dataww,data_ds
data = SV()
dataw = SV()
dataww = SV()
 

###---whitening preheating---###
# Push par.preWhite chunks through both filters; the outputs are overwritten on
# every pass and not kept.
for i in range(par.preWhite):
    streaming.GetData(data)
    data_ds=ds.Process(data)
    whiten.Process(data_ds, dataw)
    Dwhiten.Process(data_ds, dataww)  

In [None]:
# Read the chunk that follows the preheating chunks, down-sample it, and run it
# through both whitening stages.
streaming.GetData(data)
data_ds=ds.Process(data)
whiten.Process(data_ds, dataw)
Dwhiten.Process(data_ds, dataww)


 
# Copy the processed event chunk into NumPy arrays: x is the abscissa of the
# down-sampled data, y its values, yw / yww the whitened and double-whitened
# outputs. The copy length is the size of the whitened buffer.
n_out = dataw.GetSize()
x = np.zeros(n_out)
y = np.zeros(n_out)
yw = np.zeros(n_out)
yww = np.zeros(dataww.GetSize())
for i in range(n_out):
    x[i], y[i] = data_ds.GetX(i), data_ds.GetY(0, i)
    yw[i], yww[i] = dataw.GetY(0, i), dataww.GetY(0, i)


In [None]:
# Overlay the raw, whitened and double-whitened event chunk on one set of axes.
fig, ax = plt.subplots()
for series, series_label in (
    (y, 'Raw data'),
    (yw, 'whitened data'),
    (yww, 'D-whitened data'),
):
    ax.plot(x, series, label=series_label)
ax.legend()
plt.show()

In [None]:
%matplotlib inline
# Plot only the whitened series (cyan) so it is not hidden by the other curves;
# the double-whitened curve is left disabled below.
fig, ax = plt.subplots()
#ax.plot(x, yww, 'gray',label='D-whitened data')
ax.plot(x, yw, 'c',label='whitened data')

In [None]:
 
# Copy the full-rate chunk (the input to the down-sampler) into NumPy arrays.
n_raw = data.GetSize()
xraw = np.zeros(n_raw)
yraw = np.zeros(n_raw)
for i in range(n_raw):
    xraw[i] = data.GetX(i)
    yraw[i] = data.GetY(0, i)

# Original and down-sampled sampling rates.
fs = par.sampling
frs = par.resampling

# Welch estimates with segments as many samples long as the respective rate.
# nperseg must be a sample count; par.resampling comes from a true division and
# is a float, so it is converted explicitly.
fraw, Pxx_denraw = signal.welch(yraw, fs, nperseg=int(fs))
f, Pxx_den = signal.welch(y, frs, nperseg=int(frs))
fds, Pxx_denDS = signal.welch(yw, frs, nperseg=int(frs))    # whitened
fw, Pxx_denW = signal.welch(yww, frs, nperseg=int(frs))     # double whitened

fig, ax = plt.subplots(figsize=(8, 6))
# The curves are square roots of the PSD, i.e. amplitude spectral densities.
ax.loglog(fraw, np.sqrt(Pxx_denraw), 'b', label='h_raw', linewidth=2)
ax.loglog(f, np.sqrt(Pxx_den), 'gray', label='h_rawDS')
ax.loglog(fds, np.sqrt(Pxx_denDS), 'g', label='resamp+whiten')
ax.loglog(fw, np.sqrt(Pxx_denW), 'r', label='resamp +  doublew')
ax.set_xlabel('frequency [Hz]')
# Label matches what is drawn: sqrt(PSD), not the PSD itself.
ax.set_ylabel('ASD [1/sqrt(Hz)]')
ax.legend()
ax.set_ylim(1e-26, 1e-18)
ax.grid(which='both')

In [None]:
 

def prepareImage_gw(x,y,fs,title="Time-Freqeuncy map"):
    """
    Draw a Morlet continuous-wavelet-transform power map of a signal and show it.

    Parameters:
        x (numpy.ndarray): Abscissa values, one per sample of ``y``.
        y (numpy.ndarray): Samples of the signal.
        fs (float): Sampling frequency of ``y``.
        title (str, optional): Text placed above the plot.

    Returns:
        None

    Note:
        ``scipy.signal.cwt`` and ``scipy.signal.morlet2`` are deprecated in
        recent SciPy releases; this function needs a SciPy version that still
        provides them.
    """
    # Morlet parameter, forwarded to morlet2 and used to derive the widths
    w = 10.
    # int(fs/2) evenly spaced frequencies between 1 Hz and fs/2
    freq = np.linspace(1, fs/2, int(fs/2))
    widths = w*fs / (2*freq*np.pi)
    # Squared magnitude of the transform: one row per frequency
    z = np.abs(signal.cwt(y, signal.morlet2, widths, w=w))**2

    plt.pcolormesh(x, freq, z, cmap='coolwarm', shading='gouraud', alpha=0.95)
    plt.yscale('log')
    plt.ylim(4, 1000)
    # Call pyplot's title function. Assigning to `plt.title` would replace the
    # function with a string: no title would be drawn and every later
    # plt.title(...) call in the session would raise TypeError.
    plt.title(str(title))
    plt.show()

    return

In [None]:
# Time-frequency map of the down-sampled, not yet whitened event chunk.
prepareImage_gw(x,y,par.resampling)

In [None]:
# Time-frequency map of the whitened event chunk ...
prepareImage_gw(x,yw,par.resampling)

# ... and of the double-whitened one.
prepareImage_gw(x,yww,par.resampling)

## <font color='purple'> Challenge: Could you prepare the same plot for a signal with SNR <30? </font>