In [1]:
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------

# Inputs: the list of scans to load, their file format, and the analysis mask.
input_files = '/data/SFIMJGC_Introspec/cpca_vs_introspec/spring_2024/data/cpca/JAVIER_POST_SFN2024/cpca_input_files_SUBJ_AWARE.denoise.b0.txt'
file_format = 'nifti'
mask_fp = '/data/SFIMJGC_Introspec/pdn/PrcsData/cpca/Schaefer2018_400Parcels_7Networks_AAL2.MASK.nii.gz'

# Preprocessing options forwarded to load_data.
normalize = 'zscore'
bandpass = False
low_cut = 0.01   # presumably Hz; likely only used when bandpass is on - TODO confirm
high_cut = 0.1   # presumably Hz; likely only used when bandpass is on - TODO confirm
tr = None        # presumably the repetition time; unset while bandpass is off - TODO confirm

# Decomposition options.
pca_type = 'complex'  # 'complex' -> Hilbert transform the data before PCA
n_comps = 3260        # equals the number of rows loaded below, i.e. keep every component
rotate = None         # not referenced by the cells shown here
recon = False         # not referenced by the cells shown here

# Output / miscellaneous.
out_prefix = 'A1'     # not referenced by the cells shown here
n_bins = 30           # not referenced by the cells shown here
verbose = True        # print progress messages from the helpers

In [2]:
from utils.load_write import load_data, write_out
from scipy.signal import hilbert
import fbpca
import numpy as np

In [3]:
def hilbert_transform(input_data, verbose):
    """Return the complex conjugate of the analytic signal of ``input_data``.

    The analytic signal is computed with ``scipy.signal.hilbert`` along
    axis 0 (each column is transformed independently). Its real part equals
    the input; conjugating flips the sign of the imaginary part.

    Parameters
    ----------
    input_data : array_like
        Real-valued data; axis 0 is the axis that is transformed.
    verbose : bool
        If True, print a progress message first.

    Returns
    -------
    ndarray
        Complex array with the same shape as ``input_data``.
    """
    if verbose:
        print('applying hilbert transform')
    analytic = hilbert(input_data, axis=0)
    return np.conj(analytic)

In [4]:
def pca(input_data, n_comps, verbose, n_iter=10):
    """Run PCA (or complex PCA, for complex input) with ``fbpca.pca``.

    Parameters
    ----------
    input_data : ndarray, shape (n_samples, n_positions)
        2-D data matrix; rows are samples, columns are positions. May be real
        or complex (e.g. the output of ``hilbert_transform``).
    n_comps : int
        Number of components to keep (passed to ``fbpca.pca`` as ``k``).
    verbose : bool
        If True, print progress and input-size messages.
    n_iter : int, optional
        Number of iterations passed to ``fbpca.pca``.

    Returns
    -------
    dict with keys
        'U', 's', 'Va' : factors returned by ``fbpca.pca``. Note that fbpca
            centres the columns by default (``raw=False``, not overridden
            here), so ``U * s @ Va`` approximates the column-centred data.
        'loadings' : ``s[:, None] * Va / sqrt(n_samples - 1)``,
            shape (n_comps, n_positions).
        'exp_var' : ``s**2 / (n_samples - 1) / n_positions`` per component.
        'pc_scores' : projection of ``input_data`` onto the components,
            shape (n_samples, n_comps).
        'n_samples', 'n_positions' : dimensions of ``input_data``.
        'total_var' : sum of 'exp_var'.
    """
    n_samples, n_positions = input_data.shape
    if verbose:
        print('performing PCA/CPCA')
        print(' number of samples = %d' % n_samples)
        print(' input_data.shape[1] = %d' % n_positions)
    U, s, Va = fbpca.pca(input_data, k=n_comps, n_iter=n_iter)
    # Per-component variance (s**2 / (n - 1)), further divided by the number
    # of positions.
    explained_variance_ = ((s ** 2) / (n_samples - 1)) / n_positions
    total_var = explained_variance_.sum()
    # Project onto the components. For complex data the rows of Va are
    # orthonormal under the Hermitian inner product, so the projection needs
    # Va^H (= Va.conj().T); using Va.T would not recover U * s. For real data
    # Va.conj() is Va, so real-valued results are unchanged.
    # NOTE(review): fbpca centred the columns but this projects the uncentred
    # input, so each score carries a constant offset (column means @ Va^H);
    # ~0 for zero-mean columns such as z-scored data.
    pc_scores = input_data @ Va.conj().T
    # Eigenvectors scaled by their singular values. Broadcasting `* s` gives
    # the same values as `@ np.diag(s)` without building a k x k matrix and
    # doing an O(n_positions * k^2) matmul.
    loadings = (Va.T * s) / np.sqrt(n_samples - 1)
    return {'U': U,
            's': s,
            'Va': Va,
            'loadings': loadings.T,
            'exp_var': explained_variance_,
            'pc_scores': pc_scores,
            'n_samples': n_samples,
            'n_positions': n_positions,
            'total_var': total_var}

***
### Load the Original Data

In [5]:
# Load and concatenate the scans listed in `input_files`, restricted to the
# mask in `mask_fp` and preprocessed per the config cell. `mask` and `header`
# are kept so results can be written back out with write_out.
func_data, mask, header = load_data(input_files, file_format, mask_fp,
                                    normalize, bandpass, low_cut, high_cut,
                                    tr, verbose)

initializing matrix of size (3260, 122767)
loading and concatenating 5 scans


In [6]:
# Rows are samples (the axis the Hilbert transform and PCA below treat as
# observations); columns are positions inside the mask.
assert func_data.ndim == 2, 'load_data should return a 2-D (samples x positions) matrix'
func_data.shape

(3260, 122767)


In [7]:
# Save the loaded (still real-valued, pre-Hilbert) data so it can be compared
# with the reconstruction written at the end of the notebook.
write_out(
    func_data,
    mask,
    header,
    'nifti',
    '/data/SFIMJGC_Introspec/cpca_vs_introspec/spring_2024/data/cpca/JAVIER_POST_SFN2024/ORIG',
)

***

### Apply the Hilbert transform → obtain the analytic signal

In [8]:
# For complex PCA, replace the real data with the (conjugated) analytic signal.
# The iscomplexobj guard keeps this cell idempotent: scipy.signal.hilbert
# rejects complex input, so re-running the cell would otherwise raise.
if pca_type == 'complex' and not np.iscomplexobj(func_data):
    func_data = hilbert_transform(func_data, verbose)

applying hilbert transform


In [9]:
# The analytic signal keeps the input's shape.
np.shape(func_data)

(3260, 122767)

***

### Apply PCA

In [10]:
# Decompose the (complex) data, keeping `n_comps` components.
pca_output = pca(func_data, n_comps=n_comps, verbose=verbose)

performing PCA/CPCA
 number of samples = 3260
 input_data.shape[1] = 122767


In [11]:
# Shapes of the factors: U (samples x comps), s (comps,), Va (comps x positions).
tuple(pca_output[factor].shape for factor in ('U', 's', 'Va'))

((3260, 3260), (3260,), (3260, 122767))

***
### Remove the leading components (zero their singular values)

In [12]:
# Work on a copy so the original singular values in 's' stay intact.
pca_output['s_mod'] = np.copy(pca_output['s'])

In [13]:
# Number of leading components to drop. The output file name written at the
# end of the notebook (RECON_3OUT) assumes this value.
n_removed_comps = 3
# Rebuild s_mod from the untouched 's' before zeroing. Re-running this cell
# with a different count therefore never leaves stale zeros from a previous run.
pca_output['s_mod'] = pca_output['s'].copy()
pca_output['s_mod'][:n_removed_comps] = 0

***
### Reconstruct the data

In [14]:
# Reconstruct from the retained components: `U * s_mod` scales each column of U
# by its (possibly zeroed) singular value, and `@ Va` maps back to positions.
# fbpca.pca centres every column before its SVD when raw=False (its default,
# which `pca` does not override), so U @ diag(s) @ Va approximates the
# column-centred data. Add the column means back so the reconstruction is on
# the input's scale; for z-scored input the means are ~0 and this barely matters.
func_data_mod_reconstructed = (pca_output['U'] * pca_output['s_mod']) @ pca_output['Va']
func_data_mod_reconstructed += func_data.mean(axis=0)

In [15]:
# Keep only the real part. Before the decomposition, the real part of the
# (conjugated) analytic signal is exactly the loaded real-valued data.
func_data_modified = func_data_mod_reconstructed.real

In [16]:
# Inspect the real-valued reconstruction (numpy truncates large arrays in the display).
func_data_modified

array([[-4.08930025e-14,  2.49407012e-14,  3.17258923e-14, ...,
         0.00000000e+00,  4.27470732e-02,  0.00000000e+00],
       [ 1.06886709e-13,  4.12176498e-14,  4.84417944e-14, ...,
         0.00000000e+00, -2.46599703e-02,  0.00000000e+00],
       [-3.61867827e-14, -6.14628260e-14, -8.21240862e-15, ...,
         0.00000000e+00, -7.14048055e-02,  0.00000000e+00],
       ...,
       [ 4.11187276e-14, -1.50444896e-13,  1.22669481e-14, ...,
         0.00000000e+00, -9.09199494e-01,  0.00000000e+00],
       [ 1.45857796e-14, -9.76858223e-14, -2.39312872e-14, ...,
         0.00000000e+00, -7.04478752e-01,  0.00000000e+00],
       [ 7.01746098e-14, -4.18684331e-14,  2.98126921e-14, ...,
         0.00000000e+00, -1.93273426e-01,  0.00000000e+00]])

In [17]:
# Save the real-valued reconstruction with the first three components removed.
# NOTE(review): this path ends in '.nii' while the ORIG path earlier has no
# extension; confirm whether write_out appends its own extension.
write_out(
    func_data_modified,
    mask,
    header,
    'nifti',
    '/data/SFIMJGC_Introspec/cpca_vs_introspec/spring_2024/data/cpca/JAVIER_POST_SFN2024/RECON_3OUT.nii',
)