# Common functionality

In [None]:
from time import sleep
import apsuite.commisslib.meas_fofb_sysid as fofbsysid
print(fofbsysid.__file__)
from multiprocessing import Pool, set_start_method
from resource import getrusage, RUSAGE_SELF

from enum import Enum
import h5py
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm
import matplotlib.ticker as ticker
import matplotlib
import numpy as np
from os import listdir, path
import pickle
from scipy import signal, io

from apsuite.commisslib.meas_fofb_sysid import FOFBSysIdAcq
from siriuspy.devices import EVG, InjCtrl, SOFB, Trigger
%matplotlib widget
%matplotlib inline

In [None]:
# Analysis configuration: the folder holding the acquisition files and the kind
# of PRBS excitation they contain. Later cells read both names.
# Alternative (disabled) configuration for the 'Correctors' data sets:
## fpath = '/ibira/lnls/labs/swc/MachineStudies/15-7-2024-biggerlimit/current/' ## '/ibira/lnls/labs/swc/MachineStudies/22-07-2024/current/'
## prbs_type = 'Correctors'

# `fpath` is concatenated directly with a file name later on, so it must keep
# its trailing slash.
fpath = '/ibira/lnls/labs/swc/MachineStudies/15-7-2024-biggerlimit/bpm-lfsr9-sd4-teste/'
# Must be 'BPMs' or 'Correctors'; these are the only values the analysis accepts.
prbs_type = 'BPMs'

In [None]:
# Notebook-level acquisition object. Later cells only read device handles from
# it through `acq.devices[...]`; the per-file analysis builds its own instance.
acq = FOFBSysIdAcq()

# Data analysis

In [None]:
def init_acq_from_data(acq, data):
    """Load the singular-mode level settings of a stored acquisition into `acq`.

    Args:
        acq: object exposing a `params` attribute whose `svd_levels_*`
            attributes are overwritten.
        data: mapping with a 'params' entry holding the stored values of the
            `svd_levels_*` settings copied below.

    Returns:
        None. `acq.params` is modified in place.
    """
    stored = data['params']
    target = acq.params

    # Matrix regularization is always enabled; the stored value is not consulted.
    target.svd_levels_regularize_matrix = True

    # Settings copied verbatim from the stored parameters, in this order.
    copied_fields = (
        'svd_levels_reg_sinval_min',
        'svd_levels_reg_tikhonov_const',
        'svd_levels_bpmsx_enbllist',
        'svd_levels_bpmsy_enbllist',
        'svd_levels_ch_enbllist',
        'svd_levels_cv_enbllist',
        'svd_levels_rf_enbllist',
        'svd_levels_respmat',
        'svd_levels_singmode_idx',
    )
    for field in copied_fields:
        setattr(target, field, stored[field])

In [None]:
def get_mode_ffts(file, prbs_type):
    """Compute the FFTs of the PRBS excitation and of the orbit for one acquisition.

    The acquisition is loaded with `FOFBSysIdAcq.load_data`, the first PRBS
    periods are discarded as transient, the orbit is averaged over the
    remaining whole PRBS periods and both signals are Fourier transformed over
    one period. The DC bin and the bins listed as PRBS notches are removed.

    Args:
        file: path of the acquisition file, passed to `load_data`.
        prbs_type: 'BPMs' or 'Correctors'; selects which stored PRBS levels
            build the excitation vector.

    Returns:
        Tuple `(freqs_b, prbsu_fft_b, orb_fft_b)`:
            freqs_b: 1-D array with the frequencies of the retained bins.
            prbsu_fft_b: FFT of the excitation, one row per retained bin and
                one column per excitation level entry.
            orb_fft_b: FFT of the averaged orbit, one row per retained bin and
                one column per orbit column (`orbx` columns, then `orby`).
        `None` is returned (after printing a message) when `prbs_type` is
        invalid.

    Raises:
        ValueError: if fewer samples than one PRBS period remain after the
            transient is discarded.
    """
    if prbs_type not in ['BPMs', 'Correctors']:
        print("Invalid prbs_type value. Valid values are: 'BPMs', 'Correctors'")
        return

    acq = FOFBSysIdAcq()

    print(f'reading {file}...')
    # Former loading path, kept for reference:
    ##   with open(f'{file}', 'rb') as f: data = pickle.load(f)
    data = acq.load_data(file)

    init_acq_from_data(acq, data)

    params = data['params']
    acq_data = data['data']
    nr_before = params['acq_nrpoints_before']

    if prbs_type == 'BPMs':
        # Alternative (disabled) source for the levels:
        # data['data']['prbs_bpmpos{x,y}_lvl0_beam_order'][0]
        u_n = np.hstack([
            params['prbs_bpmposx_lvl0'][acq.params.svd_levels_bpmsx_enbllist],
            params['prbs_bpmposy_lvl0'][acq.params.svd_levels_bpmsy_enbllist],
        ])
        # NOTE(review): the reason for halving the stored levels is not
        # visible here -- confirm.
        u_n = u_n/2.0
    elif prbs_type == 'Correctors':
        u_n = params['prbs_fofbacc_lvl0'][np.hstack([
            acq.params.svd_levels_ch_enbllist,
            acq.params.svd_levels_cv_enbllist,
        ])]
        # Rescale so that the first non-zero level equals `val`; files whose
        # name contains 'C3' use half of that amplitude.
        val = 0.02
        if 'C3' in path.split(file)[-1]:
            val /= 2
        # Out-of-place multiply: an in-place `*=` with a float factor fails
        # when the stored levels have an integer dtype.
        u_n = u_n * (val/(u_n[u_n != 0][0]))
        # Disabled alternative: divide u_n by the power supply filter gain,
        # read from data['data']['psconfig_mat'] reshaped to (160, -1),
        # column 2, on the row of the corrector named by the file
        # ('\uf022' in the file name stands for ':').

    u_n = np.array([u_n], dtype=float)

    lfsr = acq_data['prbs_lfsr_len'][0]
    sd = acq_data['prbs_step_duration'][0]
    # Length of one PRBS period, in samples.
    p = (2**lfsr - 1) * sd

    # Transient in PRBS periods
    n_p_transient = 4

    prbs = np.array(acq_data['prbs_data'][0], dtype=float)
    prbs = prbs[nr_before:]
    prbs = prbs[n_p_transient*p:]
    # Map the stored 0/1 sequence to -1/+1.
    prbs[prbs == 0] = -1

    # Mimic gateware's PRBS moving average
    if prbs_type == 'Correctors':
        N = 2 # TODO: get it from data['params']

        a = np.zeros(N)
        a[0] = 1
        b = (1/N)*np.ones(N)
        # NOTE(review): the filter starts from a zero state after the
        # transient was cut, so the first N - 1 samples of the period used
        # below are not steady-state values -- confirm this is acceptable.
        prbs = signal.lfilter(b, a, prbs)

    # Outer product: one column of excitation per entry of u_n.
    prbsu = np.array([prbs]).T @ u_n
    prbsu -= np.mean(prbsu, axis=0)[None, :]

    orb = np.array(np.hstack([acq_data['orbx'], acq_data['orby']]), dtype=float)

    # The samples acquired before the excitation define the reference orbit.
    not_excited_orb = orb[:nr_before]
    ref_orb = np.mean(not_excited_orb, axis=0)[None, :]

    excited_orb = orb[nr_before:]
    excited_orb = excited_orb[n_p_transient*p:]
    excited_orb -= ref_orb

    # Disabled steps kept for reference:
    #  - SNR: 20*log10(rms(excited_orb) / rms(not_excited_orb - ref_orb))
    #  - switching noise removal: signal.filtfilt with an N = 2 moving
    #    average applied to both prbsu and excited_orb

    # Orbit averaging over whole PRBS periods. A trailing partial period is
    # dropped; without this the reshape fails whenever the number of samples
    # is not an exact multiple of p.
    n_periods = excited_orb.shape[0] // p
    if n_periods == 0:
        raise ValueError(
            f'{file}: {excited_orb.shape[0]} samples left after the transient, '
            f'fewer than one PRBS period ({p} samples)')
    excited_orb = excited_orb[:n_periods*p]
    excited_orb = excited_orb.reshape((n_periods, p, excited_orb.shape[1]))
    excited_orb = np.average(excited_orb, axis=0)
    prbsu = prbsu[:p]

    fs = acq_data['sampling_frequency']
    freqs = np.fft.rfftfreq(prbsu.shape[0], d=1/fs)

    orb_fft = np.fft.rfft(excited_orb, axis=0)
    prbsu_fft = np.fft.rfft(prbsu, axis=0)

    # FFT binning. After the averaging there are exactly p samples, so the
    # step is 1 and this selection only drops the DC bin.
    n = excited_orb.shape[0]
    step = n // p

    freqs_b = freqs[step::step]
    prbsu_fft_b = prbsu_fft[step::step]
    orb_fft_b = orb_fft[step::step]

    # PRBS notches removal: with the DC bin gone, index k holds harmonic
    # k + 1, so these indices are the harmonics that are multiples of p // sd.
    notches = np.arange(p // sd - 1, freqs_b.shape[0], p // sd)

    freqs_b = np.delete(freqs_b, notches)
    prbsu_fft_b = np.delete(prbsu_fft_b, notches, axis=0)
    orb_fft_b = np.delete(orb_fft_b, notches, axis=0)

    return freqs_b, prbsu_fft_b, orb_fft_b

### Build PRBS and orbit FFT 'cubes'

In [None]:
# File names of the corrector acquisitions: one '<power supply name>.h5' per
# power supply known to the 'famsysid' device, except the ones in `out_corrs`.
## acqs_fn = listdir(fpath)
out_corrs = ['SI-01M2:PS-FCH', 'SI-01M1:PS-FCH', 'SI-01M2:PS-FCV', 'SI-01M1:PS-FCV']

# Work on a copy: removing entries from the list returned by the device could
# alter the device's own list and would make this cell fail when re-run.
psnames = list(acq.devices['famsysid'].psnames)

# Same failure as `list.remove` for a name that is not present.
missing = [corr for corr in out_corrs if corr not in psnames]
if missing:
    raise ValueError(f'correctors not found in psnames: {missing}')

# ':' is replaced by '\uf022' to match the names of the stored files.
# Disabled workaround once used here: rename '07' to '06' for '07C2' entries.
acqs_fn = [
    str(psname).replace(':', '\uf022') + '.h5'
    for psname in psnames
    if psname not in out_corrs
]
print(len(acqs_fn))

acqs_fn

In [None]:
n_modes = 156

if prbs_type == 'BPMs':
    args = [(fpath + f'{i}.h5', prbs_type) for i in range(n_modes)]
elif prbs_type == 'Correctors':
    args = [(fpath + f'{acqs_fn[i]}', prbs_type) for i in range(n_modes)]

%time mode_ffts = Pool(processes=16).starmap(get_mode_ffts, args)

In [None]:
# Names of the BPMs selected by the FOFB enable list. Only `bpmxenbl` is used
# as the selection; `bpmyenbl` is not consulted here.
# NOTE(review): presumably a boolean mask or index array over `bpmnames` -- confirm.
bpmenbl = acq.devices['fofb'].bpmxenbl
bpmnames = np.array(acq.devices['famsysid'].bpmnames)[bpmenbl]
print(bpmnames)

In [None]:
# Stack the per-acquisition FFTs into 3-D arrays indexed as
# [acquisition, signal column, frequency bin].
prbsu_fft_per_mode, orb_fft_per_mode = [], []

# `freqs` keeps the frequency axis of the last acquisition after the loop;
# later cells use it as the common frequency axis.
for mode, (freqs, prbsu_fft, orb_fft) in enumerate(mode_ffts[:n_modes]):
    # Transpose so that the frequency bins end up on the last axis.
    prbsu_fft_per_mode.append(prbsu_fft.T)
    orb_fft_per_mode.append(orb_fft.T)

prbsu_fft_cube = np.array(prbsu_fft_per_mode, dtype=complex)
orb_fft_cube = np.array(orb_fft_per_mode, dtype=complex)

for cube in (prbsu_fft_cube, orb_fft_cube):
    print(cube.shape)

### Compute Response Matrices and Experimental Sigma

In [None]:
# For every frequency bin, estimate the matrix R that maps excitation to orbit
# from Y^T = R U^T, i.e. R = Y^T pinv(U)^T, and store its singular values.
resp_mat_per_freq, exp_sigma_per_freq = [], []

n_freqs = freqs.shape[0]
for f in range(n_freqs):
    orb_at_f = orb_fft_cube[:, :, f]      # rows: acquisitions
    prbsu_at_f = prbsu_fft_cube[:, :, f]  # rows: acquisitions

    resp_mat_at_f = orb_at_f.T @ np.linalg.pinv(prbsu_at_f).T
    resp_mat_per_freq.append(resp_mat_at_f)

    _, sigma, _ = np.linalg.svd(resp_mat_at_f, full_matrices=False)
    exp_sigma_per_freq.append(sigma)

resp_mat_f = np.array(resp_mat_per_freq, dtype=complex)
exp_sigma_f = np.array(exp_sigma_per_freq)

for stacked in (resp_mat_f, exp_sigma_f):
    print(stacked.shape)

### Plot Response Matrices

In [None]:
# Compare a stored reference response matrix with the one identified from the
# PRBS data at a single frequency bin (corrector excitation only).
if prbs_type == 'Correctors':
    fig, axs = plt.subplots(2, 2, figsize=(10, 10), sharey='row')

    # Disabled alternative source for the reference matrix:
    # acq.devices['fofb'].respmathw_mon[:, :-1].T divided by 6.25e-5
    respmat_fn = '/ibira/lnls/labs/gie/MachineStudies/FOFBSysId/MATLAB_objs/respmat-no-rf-line-24-05-20.mat'
    respmat = io.loadmat(respmat_fn)['mat_d'].T

    # Keep only the rows / columns selected by the FOFB enable lists.
    # NOTE(review): presumably boolean masks over all BPMs / correctors -- confirm.
    fofb = acq.devices['fofb']
    bpmenbllist = np.hstack([fofb.bpmxenbl, fofb.bpmyenbl])
    correnbllist = np.hstack([fofb.chenbl, fofb.cvenbl])
    respmat = respmat[bpmenbllist][:, correnbllist]

    # Index of the frequency bin shown.
    f = 0
    print(respmat.shape, resp_mat_f[f].shape)

    respmat_pinv = np.linalg.pinv(respmat)
    freq_label = f'resp_mat_f[{freqs[f]:.2f} Hz]'

    # (matrix, x label) for each panel, in row-major order of the axes grid.
    panels = (
        (respmat, 'respmat'),
        (resp_mat_f[f], freq_label),
        (respmat @ respmat_pinv, 'respmat @ np.linalg.pinv(respmat)'),
        (resp_mat_f[f] @ respmat_pinv, freq_label + ' @ np.linalg.pinv(respmat)'),
    )
    for ax, (matrix, label) in zip(axs.flat, panels):
        ax.imshow(np.abs(matrix))
        ax.set_xlabel(label)
    plt.show()

### Plot Experimental Sigma

In [None]:
# Singular values of the identified response matrices versus frequency, in dB.
fig, ax = plt.subplots(1, figsize=(13, 8))

# Number of trailing singular values left out of the plot.
if prbs_type == 'BPMs':
    last_bad_modes = 4
elif prbs_type == 'Correctors':
    last_bad_modes = 0

# Only strictly positive frequencies can be drawn on the logarithmic axis.
# A band limit such as `freqs < 1e4` can be combined here if needed.
interest = freqs > 0
n_plotted = exp_sigma_f.shape[1] - last_bad_modes

for i in range(n_plotted):
    sigma_db = 20*np.log10(exp_sigma_f[interest, i])
    ax.semilogx(freqs[interest], sigma_db)

ax.set(
    title='Singular Values',
    ylabel='Singular Values (dB)',
    xlabel='Frequency (Hz)',
)
ax.grid(which='both')

# Optional, currently disabled: fig.tight_layout() and
# plt.savefig('exp_sigma.jpg')