# Evaluating spectrogram normalizers to counteract the effect of non-flat spectra caused by filter banks
This notebook does the following:
1) Generates spectrogram samples (noise-only or tone+noise)
2) Performs non-coherent averaging over spectrogram rows
3) Applies split-window-two-pass-mean (sw2pm.py) or simple_norm.py normalization and thresholding algorithms
4) For each normalization algorithm, plots spectra and tone detections over a range of averaging depths
and subband counts.  The simple_norm.py algorithm is computationally simpler than sw2pm but achieves similar
results.  Tones are detected as expected, even at the far edges of the coarse-channel band.



In [None]:
import sys
import os

#matplotlib inline
import matplotlib.pyplot as plt
# Global matplotlib style applied to every figure produced by this notebook.
params = {'legend.fontsize': 'medium', 'figure.figsize': (10, 6)}
# All axis/tick label sizes share the same 'large' setting.
params.update(dict.fromkeys(('axes.labelsize', 'axes.titlesize',
                             'xtick.labelsize', 'ytick.labelsize'), 'large'))
plt.rcParams.update(params)

import numpy as np
import scipy as sp
from scipy.fft import fft, fftshift

import time

from pathlib import Path
from numpy.random import default_rng

from sw2pm import sw2pm
from simple_norm import simple_norm

from detect_tones import detect1D, add_detection
from gen_simple_coarse_sg import gen_simple_coarse_sg

from numpy.random import randn, rand, randint, seed
# Seed NumPy's legacy global RNG (numpy.random.seed).
# NOTE(review): default_rng is also imported; if gen_simple_coarse_sg draws from
# a Generator instead of the global RNG, this seed may not make runs reproducible.
seed(22)

# Directory for the saved figures.  os.makedirs creates it in-process (no shell
# spawned) and exist_ok=True makes re-running this cell a no-op.
output_dir = './sg_norm_plots/'
os.makedirs(output_dir, exist_ok=True)

import std_fns as s



In [None]:
# ---- Spectrogram dimensions ----
n_freq = 64 * 1024   # bins per spectrum (length of the frequency axis)
n_time = 512         # time extent passed to gen_simple_coarse_sg
n_subband = 1        # default; reassigned by the n_subband sweep loop
# Normalized frequency axis; note the generator cell replaces this with the
# freq array returned by gen_simple_coarse_sg.
freq = np.linspace(-0.5, 0.5, n_freq, endpoint=False)

# ---- Signal model inputs for gen_simple_coarse_sg ----
i_case = 2           # i_case == 1 -> n_pol = 1, otherwise n_pol = 2
H_edge_db = -3       # alternative tried: -5
dc_offset = 0
dc_phase = (1 + 1j) / np.sqrt(2.)
dc_reject = 1        # nonzero -> centre (n_freq//2) bin overwritten after generation
tone_snr_db = 3.5
tone_freq = np.arange(-0.45, 0.46, 0.10)
tone_enable = 1      # NOTE(review): not referenced elsewhere in this file; no effect

# ---- Normalizer / detector settings ----
shear_threshold = 2.3   # passed to both sw2pm and simple_norm
sw2pm_enable = False    # default; reassigned by the normalizer sweep loop
z_det = 10.             # threshold = mean + z_det*std; alternatives tried: 8., 6.

display_figs = False    # True -> plt.show(); False -> close each figure after saving


In [None]:
# Generate the spectrogram for a single coarse channel.

n_sti = 1                    # alternative: n_sti = n_time
dc = dc_offset * dc_phase    # complex DC term handed to the generator

n_pol = 1 if i_case == 1 else 2

# NOTE(review): tone_enable from the parameter cell is not passed here, so
# tone_freq is always handed to the generator; confirm how noise-only runs
# are meant to be produced.
sg, freq, Hsq = gen_simple_coarse_sg(n_time, n_freq, n_sti, H_edge_db, n_pol, dc,
                                     tone_freq, tone_snr_db)

if dc_reject:
    # Overwrite the centre (n_freq//2) column of every row with n_pol.
    sg[:, n_freq // 2] = n_pol

if n_sti == n_time:
    # sg is expected to hold a single row here; flatten it to n_freq bins.
    sg = np.reshape(sg, n_freq)

# Start frequency and bin spacing of the frequency axis returned above.
f0 = freq[0]
df = freq[1] - freq[0]


In [None]:
# Sweep both normalizers over several subband counts and averaging depths,
# saving one two-panel figure per (normalizer, n_subband, n_avg) combination.
n_subband_list = [1, 2, 8, 32, 128]
sw2pm_enable_list = [True, False]

# Loop-invariant quantities, computed once rather than on every inner iteration.
# log2_n_avg: number of row-count doublings available in the n_time/n_sti rows.
log2_n_avg = round(np.log2(n_time / n_sti))
Hsq_db = s.db(Hsq)

for sw2pm_enable in sw2pm_enable_list:
    for n_subband in n_subband_list:
        print(f'{sw2pm_enable=}, {n_subband=}')

        # To run a single depth instead, iterate over e.g. [log2_n_avg - 1] or [3].
        for i_avg in range(log2_n_avg):
            n_line_avg = 2**(i_avg + 1)   # spectrogram rows averaged: 2, 4, ..., 2**log2_n_avg
            n_avg = n_sti * n_line_avg    # rows averaged times n_sti (presumably spectra per row)

            print(f'{n_sti=}, {n_line_avg=}, {n_avg=}')

            # Non-coherent average of the first n_line_avg spectrogram rows
            # (a mean, despite the variable name).
            col_sum = np.mean(sg[0:n_line_avg, :], 0)

            # Normalize the averaged spectrum.  Both normalizers return the
            # normalized spectrum, the local-mean estimate Hsq_est, and the
            # mean/std statistics used for thresholding.
            if sw2pm_enable:
                print('Running sw2pm')
                [norm1, Hsq_est, mean1, std1] = sw2pm(col_sum, 201, 11, shear_threshold,
                                                      calc_stats=True, n_subband=n_subband)
            else:
                print('Running simple_norm')
                [norm1, Hsq_est, mean1, std1] = simple_norm(col_sum, shear_threshold,
                                                            n_subband=n_subband)

            col_sum_db = s.db(col_sum)
            norm1_db = s.db(norm1)
            Hsq_est_db = s.db(Hsq_est)
            mean1_db = s.db(mean1)
            # Detection threshold: z_det standard deviations above the mean.
            threshold1_db = s.db(mean1 + z_det * std1)

            # Find detections and summarize them.
            [n_det, det_out] = detect1D(norm1, f0, df, n_subband, mean1, std1, z_det, verbose=False)

            if n_det > 0:
                f_det = np.array(det_out.loc[:, "freq"])
                det_peak_db = np.array(det_out.loc[:, "peak_db"])
                margin_list_db = np.array(det_out.loc[:, "margin_db"])
                mean_margin_db = np.mean(margin_list_db)
                snr_list_db = np.array(det_out.loc[:, "snr_db"])
                mean_snr_db = np.mean(snr_list_db)
            else:
                f_det = []
                det_peak_db = []
                margin_list_db = []
                mean_margin_db = np.nan
                snr_list_db = []
                mean_snr_db = np.nan

            # Subband centre frequencies, used to plot per-subband mean/threshold.
            f_subband = np.linspace(-.5, .5, n_subband, endpoint=False) + 1/n_subband/2

            # Figure title and annotation text.
            norm_label = 'sw2pm' if sw2pm_enable else 'Simple Normalizer'
            case_str = f'{norm_label} Test Case, {n_avg} x {n_freq}'

            if n_subband == 1:
                # NOTE(review): the sqrt(2*n_avg) reference ratio corresponds to
                # n_pol == 2 (the i_case == 2 default); confirm whether it should
                # scale with n_pol for single-polarization runs.
                result_str = (f'n-avg={n_avg:4d} mean={mean1:.4f} std-dev={std1:.3f}\n'
                              f'mean/std={mean1/std1:6.3f} vs. {np.sqrt(2*n_avg):6.3f}\n'
                              f'#Dets={n_det}, Mean SNR={mean_snr_db:.2f} dB, '
                              f'Mean Margin={mean_margin_db:.2f} dB')
            else:
                result_str = (f'#Averages={n_avg:4d}, #subbands={n_subband}\n'
                              f'#Dets={n_det}, Mean SNR={mean_snr_db:.2f} dB, '
                              f'Mean Margin={mean_margin_db:.2f} dB')

            print(result_str)

            # Two-panel figure: averaged spectrum with channel-shape estimates (top),
            # normalized spectrum with threshold and detections (bottom).
            fig = plt.figure(figsize=(10, 6))

            plt.subplot(2, 1, 1)
            plt.plot(freq, col_sum_db, '-', label='Mean Spectrum')
            plt.plot(freq, Hsq_db, '-', label='Coarse Channel Hsq(f)')
            plt.plot(freq, Hsq_est_db, '-r', label='Local Mean Estimate')
            plt.xlim(-.5, .5)
            plt.ylim(-5., 20.)
            plt.ylabel('Amplitude dB')
            plt.title(case_str)
            plt.figtext(.15, .80, result_str, fontsize=10)
            plt.legend(loc='upper right')
            plt.grid()

            plt.subplot(2, 1, 2)
            plt.plot(freq, norm1_db, '-', label='Normalized Spectrum')
            if n_subband == 1:
                # A single mean/threshold value, drawn as horizontal lines.
                plt.plot([-.5, .5], mean1_db*np.ones(2), '--', label='Mean', linewidth=2.5)
                plt.plot([-.5, .5], threshold1_db*np.ones(2), '-r',
                         label=f'Threshold z={z_det:.0f}', linewidth=2.5)
            else:
                # One mean/threshold value per subband, plotted at the subband centres.
                plt.plot(f_subband, mean1_db, '--', label='Mean', linewidth=2.5)
                plt.plot(f_subband, threshold1_db, '-*r',
                         label=f'Threshold z={z_det:.0f}', linewidth=2.5)
            plt.plot(f_det, det_peak_db, '*g', label='Detections')
            plt.xlim(-.5, .5)
            plt.ylim(-5., 20.)
            plt.xlabel('Normalized Frequency within Coarse Channel')
            plt.ylabel('Amplitude dB')
            plt.figtext(.15, .42, 'After Normalization and Detection', fontsize='large')
            plt.legend(loc='upper right')
            plt.grid()

            file_prefix = '01-sw2pm' if sw2pm_enable else '02-simple'
            plt.savefig(output_dir + f'{file_prefix}-normalization-{n_subband}-seg-{n_avg:03d}-avg.png',
                        bbox_inches='tight')

            if display_figs:
                plt.show()
            else:
                plt.close(fig)


In [None]:
# Display the detection table left over from the last case run by the sweep loop.
det_out

In [None]:
# IPython automagic (%whos): list the variables in the interactive namespace.
whos

In [None]:
# Audible "done" signal when running under WSL: ask Windows PowerShell to call
# [console]::beep(261.6, 700).  Skipped quietly when powershell.exe is not on
# PATH (e.g. native Linux/macOS) instead of printing a shell "not found" error.
import shutil
import subprocess

if shutil.which("powershell.exe"):
    subprocess.run(["powershell.exe", "[console]::beep(261.6,700)"], check=False)