In [None]:
import argparse
import os
import time
import pickle
import numpy as np
import jax
import jax.numpy as jnp
import jax.random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from sklearn.preprocessing import StandardScaler, PowerTransformer, RobustScaler
import get_data
import visualization
from utils import get_infparams, get_robustK

In [None]:
def models(X, hypers, args):
    """NumPyro model for Group Factor Analysis (GFA) and its sparse variant.

    Registers sample sites for the noise parameter ``sigma`` (one per data
    source), the latent factors ``Z`` (N x K) and the loadings ``W`` (D x K),
    then conditions one Normal observation site ``X{m}`` per source on the
    matching column block of ``X``. ``sigma`` enters every likelihood as
    ``1/sqrt(sigma)``, i.e. it is used as a precision.

    Parameters
    ----------
    X : array
        Observed data. ``X.shape[0]`` gives N; columns are sliced per source
        using the cumulative sizes in ``hypers['Dm']``.
    hypers : dict
        Reads ``'Dm'``, ``'percW'``, ``'a_sigma'`` and ``'b_sigma'``; the
        sparse branches additionally read ``'slab_df'`` and ``'slab_scale'``.
    args : namespace
        Reads ``num_sources``, ``K`` and ``model``; ``reghsZ`` is read only
        when ``model == 'sparseGFA'``.

    Notes
    -----
    The function returns nothing; its effect is the set of sample sites it
    registers. The shrinkage on Z is applied only when ``args.model`` equals
    ``'sparseGFA'`` exactly, whereas the shrinkage on W is applied whenever
    ``'sparseGFA'`` is a substring of ``args.model``. If ``args.model``
    matches neither ``'sparseGFA'`` (substring) nor ``'GFA'`` (exact), no
    observation site is created.
    """

    ## Initialize parameters:
      
    # N: rows of X, M: number of sources, Dm: per-source feature counts.
    N, M, Dm = X.shape[0], args.num_sources, hypers['Dm']  
    # D: total feature count across sources, K: number of latent factors.
    D, K, percW = sum(Dm), args.K, hypers['percW']


    ## Sampling sigma:
    
    # One Gamma-distributed value per source, stored with shape (1, M).
    sigma = numpyro.sample("sigma", dist.Gamma(hypers['a_sigma'], hypers['b_sigma']), sample_shape=(1, M))


    ## Sampling Z:

    if args.model == 'sparseGFA':
            
        # Sampling Z for Sparse GFA (standard-normal base draws, rescaled below)
        Z = numpyro.sample("Z", dist.Normal(0, 1), sample_shape=(N, K))
    
        # Sampling tauZ for Sparse GFA: one global scale per factor
        tauZ = numpyro.sample(f'tauZ', dist.TruncatedCauchy(scale=1), sample_shape=(1, K))

        # Sampling lambdaZ for Sparse GFA: one local scale per (row, factor) entry
        lmbZ = numpyro.sample("lmbZ", dist.TruncatedCauchy(scale=1), sample_shape=(N, K))
    
        if args.reghsZ:   
            
            # Sampling cZ for Regularized Horseshoe Prior (one slab variable per factor)
            cZtmp = numpyro.sample("cZ", dist.InverseGamma(0.5 * hypers['slab_df'], 0.5 * hypers['slab_df']), sample_shape=(1, K))
            
            # Transforming cZ: slab width = slab_scale * sqrt(inverse-gamma draw)
            cZ = hypers['slab_scale'] * jnp.sqrt(cZtmp)
            
            # Squared local scales, reused in the regularisation formula below
            lmbZ_sqr = jnp.square(lmbZ)
            
            # Iterate over each latent factor
            for k in range(K):    
                
                # Regularised local scale: sqrt(c^2 * lmb^2 / (c^2 + tau^2 * lmb^2)) for factor k.
                lmbZ_tilde = jnp.sqrt( lmbZ_sqr[:, k] * cZ[0, k] ** 2 / ( cZ[0, k] ** 2 + tauZ[0, k] ** 2 * lmbZ_sqr[:, k] ) )

                # Rescale column k of Z by the regularised local scale and the global scale.
                Z = Z.at[:, k].set( Z[:, k] * lmbZ_tilde * tauZ[0, k] )
                
        else:
            
            # Plain (unregularised) horseshoe: elementwise local scale times per-factor global scale.
            Z = Z * lmbZ * tauZ

    else:
       
        # Non-sparse models keep Z as standard-normal draws.
        Z = numpyro.sample("Z", dist.Normal(0, 1), sample_shape=(N, K))



    ## Sampling W:
    
    # Standard-normal base draws for all D loadings; rescaled per source below.
    W = numpyro.sample("W", dist.Normal(0, 1), sample_shape=(D, K))

    if 'sparseGFA' in args.model:

        # Sampling lambdaW for Sparse GFA: one local scale per loading
        lmbW = numpyro.sample("lmbW", dist.TruncatedCauchy(scale=1), sample_shape=(D, K))
        
        # Sampling cW for Regularized Horseshoe Prior on W: one slab variable per (source, factor)
        cWtmp = numpyro.sample("cW", dist.InverseGamma(0.5 * hypers['slab_df'], 0.5 * hypers['slab_df']), sample_shape=(M, K))
        
        # Transforming cW: slab width = slab_scale * sqrt(inverse-gamma draw)
        cW = hypers['slab_scale'] * jnp.sqrt(cWtmp)

        # pW[m]: percW percent of source m's feature count, rounded.
        # NOTE(review): the multiplication requires Dm to behave elementwise
        # (e.g. a NumPy array rather than a plain list) - confirm against the caller.
        pW = np.round((percW/100) * Dm)

        # Running column offset of the current source inside X / row offset inside W.
        d = 0   
            
        # Loop Over Each Modality to Apply Regularized Horseshoe Prior on W
        for m in range(M): 
            
            # Scale of the global shrinkage prior: pW / ((Dm - pW) * sqrt(N)).
            # NOTE(review): this divides by zero when pW[m] == Dm[m] (percW = 100).
            scaleW = pW[m] / ((Dm[m] - pW[m]) * jnp.sqrt(N)) 
            
            # Samples tauW for the current modality; the prior scale is further multiplied by 1/sqrt(sigma).
            tauW = numpyro.sample(f'tauW{m+1}', dist.TruncatedCauchy(scale=scaleW * 1/jnp.sqrt(sigma[0, m])))

            # Squared local scales for the rows of W that belong to the current modality.
            lmbW_sqr = jnp.reshape( jnp.square( lmbW[d:d+Dm[m], :] ), ( Dm[m], K ) )
            
            # Regularised local scale: sqrt(c^2 * lmb^2 / (c^2 + tau^2 * lmb^2)).
            lmbW_tilde = jnp.sqrt( cW[m, :] ** 2 * lmbW_sqr / ( cW[m, :] ** 2 + tauW ** 2 * lmbW_sqr ) )   
            
            # Rescale the current modality's rows of W in place (functional update).
            W = W.at[d:d+Dm[m], :].set(W[d:d+Dm[m], :] * lmbW_tilde * tauW)
            
            # Likelihood for the current modality: mean Z @ W_m^T, std 1/sqrt(sigma_m).
            numpyro.sample( f'X{m+1}', dist.Normal( jnp.dot( Z, W[d:d+Dm[m], :].T ), 1/jnp.sqrt( sigma[0, m] ) ), obs=X[:, d:d+Dm[m]] )
            
            d += Dm[m]


    elif args.model == 'GFA':
        # One Gamma(1e-3, 1e-3) value per (source, factor); loadings are scaled by 1/sqrt(alpha).
        alpha = numpyro.sample("alpha", dist.Gamma(1e-3, 1e-3), sample_shape=(M, K))
        d = 0
        for m in range(M):
            # Rescale the current modality's rows of W by 1/sqrt(alpha[m, :]).
            W = W.at[d:d+Dm[m], :].set(W[d:d+Dm[m], :] * (1/jnp.sqrt(alpha[m, :])))
            # Likelihood for the current modality: mean Z @ W_m^T, std 1/sqrt(sigma_m).
            numpyro.sample(f'X{m+1}', dist.Normal(jnp.dot(Z, W[d:d+Dm[m], :].T), 1/jnp.sqrt(sigma[0, m])), obs=X[:, d:d+Dm[m]])
            d += Dm[m]

In [None]:
def run_inference(model, args, rng_key, X, hypers):
    """Sample ``model`` with NUTS and return the fitted ``MCMC`` object.

    Warm-up length, number of retained samples and number of chains are read
    from ``args.num_warmup``, ``args.num_samples`` and ``args.num_chains``.
    ``X``, ``hypers`` and ``args`` are forwarded to ``model`` unchanged.
    """
    sampler = MCMC(
        NUTS(model),
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
    )

    # The potential energy is collected as an extra field so that callers can
    # retrieve it later through get_extra_fields().
    sampler.run(rng_key, X, hypers, args, extra_fields=('potential_energy',))

    return sampler

In [None]:
def main(args):
    """Run the GFA / sparse-GFA experiment for ``args.num_runs`` initialisations.

    For every run the function loads (or generates) the data, samples the
    model with ``run_inference`` unless a result file already exists,
    post-processes the samples with ``get_infparams`` / ``get_robustK`` and
    pickles every intermediate product into the results folder. Finally the
    dataset-specific visualisation routine is called.

    Parameters
    ----------
    args : argparse.Namespace
        Reads ``dataset``, ``model``, ``K``, ``num_chains``, ``percW``,
        ``num_samples``, ``num_runs`` and, for ``model == 'sparseGFA'``,
        ``reghsZ``. Synthetic datasets additionally read ``noise``.

    Raises
    ------
    ValueError
        If ``args.dataset`` contains neither ``'synthetic'`` nor ``'genfi'``.

    Notes
    -----
    The results root defaults to the original hard-coded location and can be
    overridden with the ``SGFA_RESULTS_ROOT`` environment variable.
    All ``*.dictionary`` files are pickles; only load files you created
    yourself, because unpickling can execute arbitrary code.
    """

    ## Validate the dataset name before touching the file system

    is_synthetic = 'synthetic' in args.dataset
    if not is_synthetic and 'genfi' not in args.dataset:
        raise ValueError(
            f"Unknown dataset {args.dataset!r}: expected a name containing 'synthetic' or 'genfi'."
        )


    ## Directory Setup for Results

    flag = f'K{args.K}_{args.num_chains}chs_pW{args.percW}_s{args.num_samples}'
    if is_synthetic:
        flag += f'_addNoise{args.noise}'

    if args.model == 'sparseGFA':
        flag_regZ = '_reghsZ' if args.reghsZ else '_hsZ'
    else:
        flag_regZ = ''

    # The absolute default is machine-specific; the environment variable lets
    # the notebook run elsewhere without editing the code.
    results_root = os.environ.get(
        'SGFA_RESULTS_ROOT',
        '/Users/mertenbiyaoglu/Desktop/ucl/thesis/codes/sGFA_AIDA/results',
    )
    res_dir = f'{results_root}/{args.dataset}/{args.model}_{flag}{flag_regZ}'

    os.makedirs(res_dir, exist_ok=True)


    ## Setting Up Hyperparameters

    hp_path = f'{res_dir}/hyperparameters.dictionary'

    if not os.path.exists(hp_path):
        hypers = {'a_sigma': 1, 'b_sigma': 1, 'nu_local': 1, 'nu_global': 1, 'slab_scale': 2, 'slab_df': 4, 'percW': args.percW}

        with open(hp_path, 'wb') as parameters:
            pickle.dump(hypers, parameters)

    else:
        # Re-use the hyperparameters stored by an earlier invocation.
        with open(hp_path, 'rb') as parameters:
            hypers = pickle.load(parameters)


    ## Initializing and Loading Data for Each Run

    for i in range(args.num_runs):

        print('Initialisation: ', i+1)
        print('----------------------------------')

        if is_synthetic:
            data_path = f'{res_dir}/[{i+1}]Data.dictionary'

            if not os.path.exists(data_path):
                data = get_data.synthetic_data(hypers, args)

                with open(data_path, 'wb') as parameters:
                    pickle.dump(data, parameters)

            else:
                with open(data_path, 'rb') as parameters:
                    data = pickle.load(parameters)

            X = data['X']

        else:
            data = get_data.genfi('./aida_model')

            # Work on a copy so that data['X'] keeps the untransformed values.
            X = data['X'].copy()

            scaler = StandardScaler()
            transformer = PowerTransformer(method='yeo-johnson')

            # Column ranges of the four modalities inside X.
            modality_columns = (slice(0, 33), slice(33, 62), slice(62, 68), slice(68, 108))

            # Per modality: feature-wise standardisation followed by a
            # Yeo-Johnson power transform (both are re-fitted for every block).
            for columns in modality_columns:
                X[:, columns] = transformer.fit_transform(scaler.fit_transform(X[:, columns]))

        hypers.update({'Dm': data.get('Dm')})


        ## Running the Model, Handling Robust Parameter Extraction, and Saving Results

        res_path = f'{res_dir}/[{i+1}]Model_params.dictionary'              # all sampled model parameters
        robparams_path = f'{res_dir}/[{i+1}]Robust_params.dictionary'       # robust parameters only
        datacomps_path = f'{res_dir}/[{i+1}]Data_comps.dictionary'          # data components from get_infparams
        inf_params_path = f'{res_dir}/[{i+1}]inf_params.dictionary'         # inferred parameters from get_infparams

        if not os.path.exists(res_path):
            # Write a tiny placeholder first; the size check below treats a file
            # of at most 5 bytes as "no results yet".
            with open(res_path, 'wb') as parameters:
                pickle.dump(0, parameters)

            print('Running Model...')

            # NOTE: the seed is drawn from NumPy's global, unseeded generator and
            # is not stored, so individual runs are not reproducible.
            seed = np.random.randint(0, 50)

            rng_key = jax.random.PRNGKey(seed)

            start = time.time()

            MCMCout = run_inference(models, args, rng_key, X, hypers)

            mcmc_samples = MCMCout.get_samples()

            # Elapsed wall-clock time in minutes.
            mcmc_samples.update({'time_elapsed': (time.time() - start)/60})

            pe = MCMCout.get_extra_fields()['potential_energy']

            # Mean of the negated potential energy over all retained samples.
            mcmc_samples.update({'exp_logdensity': jnp.mean(-pe)})

            with open(res_path, 'wb') as parameters:
                pickle.dump(mcmc_samples, parameters)

                print('Inferred parameters saved.')


        if not os.path.exists(robparams_path) and os.stat(res_path).st_size > 5:

            # Loading Inferred Parameters:
            with open(res_path, 'rb') as parameters:
                mcmc_samples = pickle.load(parameters)

            inf_params, data_comps = get_infparams(mcmc_samples, hypers, args)

            with open(datacomps_path, 'wb') as parameters:
                pickle.dump(data_comps, parameters)
                print('"data_comps" SAVED.')

            with open(inf_params_path, 'wb') as parameters:
                pickle.dump(inf_params, parameters)
                print('"inf_params" SAVED.')


            if args.num_chains > 1:

                thrs = {'cosineThr': 0.8, 'matchThr': 0.5}

                # Finding Robust Components:
                rob_params, X_rob, success = get_robustK(thrs, args, inf_params, data_comps)

                # Saving Robust Data Components:
                if success:

                    rob_params.update({'sigma_inf': inf_params['sigma'], 'infX': X_rob})

                    if 'sparseGFA' in args.model:

                        rob_params.update({'tauW_inf': inf_params['tauW']})

                    with open(robparams_path, 'wb') as parameters:
                        pickle.dump(rob_params, parameters)
                        print('Robust parameters saved')

                else:
                    # The original message lacked the f-prefix and the opening
                    # bracket, so it printed the literal text "{i+1}]".
                    print(f'No robust components found => [{i+1}]Robust_params.dictionary NOT CREATED')


            else:
                # Single chain: there is nothing to match across chains, so the
                # posterior means of that chain are stored as the "robust" values.
                print("ENTERED ELSE BLOCK, TAKES THE AVERAGE OF THE SINGLE CHAIN")

                W = np.mean(inf_params['W'][0], axis=0)
                Z = np.mean(inf_params['Z'][0], axis=0)

                # Reconstruction kept under its own name so that the observed
                # data matrix X is not overwritten.
                X_inf = [[np.dot(Z, W.T)]]

                rob_params = {'W': W, 'Z': Z, 'infX': X_inf}

                with open(robparams_path, 'wb') as parameters:
                    pickle.dump(rob_params, parameters)


    ## Visualization

    # `data` is the data object of the last run of the loop above.
    if is_synthetic:
        visualization.synthetic_data(res_dir, data, args, hypers)

    else:
        visualization.genfi(data, res_dir, args)

In [None]:
if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        ``True`` because every non-empty string is truthy.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    # Stage 1: read only --dataset, because the defaults of several other
    # options depend on it. Everything else is ignored at this stage.
    dataset_parser = argparse.ArgumentParser(add_help=False)
    dataset_parser.add_argument("--dataset", nargs='?', default='genfi', type=str)
    dataset = dataset_parser.parse_known_args()[0].dataset or 'genfi'

    if 'genfi' in dataset:
        num_samples, K, num_sources, num_runs = 3000, 20, 4, 5
    else:
        num_samples, K, num_sources, num_runs = 1500, 5, 3, 5

    # Stage 2: the full command line.
    parser = argparse.ArgumentParser(description=" Sparse GFA with reg. horseshoe priors")

    parser.add_argument("model", type=str, nargs="?", default='sparseGFA', help='model to fit: "sparseGFA" or "GFA"')

    parser.add_argument("--num-samples", nargs="?", default=num_samples, type=int, help='number of MCMC samples')

    parser.add_argument("--num-warmup", nargs='?', default=2000, type=int, help='number of MCMC samples for warmup')

    parser.add_argument("--K", nargs='?', default=K, type=int, help='number of components')

    parser.add_argument("--num-chains", nargs='?', default=5, type=int, help= 'number of MCMC chains')

    parser.add_argument("--num-sources", nargs='?', default=num_sources, type=int, help='number of data sources')

    parser.add_argument("--num-runs", nargs='?', default=num_runs, type=int, help='number of runs')

    # A bare "--reghsZ" enables the option; "--reghsZ False" now really disables it.
    parser.add_argument("--reghsZ", nargs='?', const=True, default=False, type=_str2bool,
                        help='use the regularised horseshoe prior on the latent variables')

    parser.add_argument("--percW", nargs='?', default=42, type=int, help='percentage of relevant variables in each source')

    parser.add_argument("--dataset", nargs='?', default=dataset, type=str, help='choose dataset')

    # main() reads args.noise for synthetic datasets, but the option was never declared.
    # NOTE(review): type and default are assumptions - confirm against get_data.synthetic_data.
    parser.add_argument("--noise", nargs='?', default=0, type=int,
                        help='noise setting for synthetic data (also part of the results folder name)')

    parser.add_argument("--device", default='cpu', type=str, help='use "cpu" or "gpu".')

    # parse_known_args (instead of parse_args) tolerates the extra arguments
    # that a Jupyter kernel adds to sys.argv.
    args, unknown = parser.parse_known_args()

    numpyro.set_platform(args.device)

    # One host device per chain so that the chains can run in parallel.
    numpyro.set_host_device_count(args.num_chains)

    main(args)

# DEBUG

###  Data_comps.dictionary (output of get_infparams)

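A nested list of shape (a, b, c, d), where "a" is the number of chains and "b" is the number of factors. In the example inspected below the shape is (1, 5, 83, 300): "c" = 83 and "d" = 300 match the first dimensions of Z and W respectively (presumably the number of subjects and the number of features).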

In [None]:
import pickle

# Data components pickled by main() for run 1 of a single-chain K=5 fit.
file_path = "/Users/mertenbiyaoglu/Desktop/ucl/thesis/codes/sGFA_AIDA/results/genfi/sparseGFA_K5_1chs_pW33_s50_reghsZ/[1]Data_comps.dictionary"

# NOTE: unpickling can execute arbitrary code - only open files you created yourself.
with open(file_path, mode='rb') as file:
    data = pickle.loads(file.read())

In [None]:
# Summarise instead of dumping the whole nested structure: with the shape
# (1, 5, 83, 300) reported below, print(data) would emit ~124k numbers.
print(type(data).__name__, len(data))

In [None]:
def get_shape(data):
    """Return the nested dimensions of a list or ``np.ndarray`` as a tuple.

    The size of each level is taken from its first element, so ragged
    structures are described by their first branch only. String items are
    dropped from lists before measuring. Anything that is neither a list nor
    an array - and any 0-d array - has shape ``()``.
    """
    if not isinstance(data, (list, np.ndarray)):
        return ()
    if isinstance(data, np.ndarray) and data.ndim == 0:
        # 0-d arrays are scalars; len() would raise TypeError on them.
        return ()
    if isinstance(data, list):
        data = [item for item in data if not isinstance(item, str)]
    if len(data) == 0:
        return (0,)
    return (len(data),) + get_shape(data[0])

In [None]:
# Nested dimensions of the loaded data components; the trailing comment records
# the value seen for this file.
get_shape(data)         # (1, 5, 83, 300)

### inf_params.dictionary (output of get_infparams)

In [None]:
import pickle

# Inferred parameters pickled by main() for run 1 of a single-chain K=5 fit.
file_path = "/Users/mertenbiyaoglu/Desktop/ucl/thesis/codes/sGFA_AIDA/results/genfi/sparseGFA_K5_1chs_pW33_s50_reghsZ/[1]inf_params.dictionary"

# NOTE: unpickling can execute arbitrary code - only open files you created yourself.
with open(file_path, mode='rb') as file:
    inf_params = pickle.loads(file.read())

In [None]:
# Show which entries exist and what they are instead of dumping every sample.
print({key: getattr(value, 'shape', type(value).__name__) for key, value in inf_params.items()})

In [None]:
# Manual inspection of the three entries. Only the last expression of a cell is
# rendered by the notebook, so the first two lines merely check that the keys exist.
inf_params['W']         # has numerical values
inf_params['Z']         # has numerical values
inf_params['sigma']     # has numerical values

In [None]:
def get_shape(data):
    """Return the nested dimensions of a list or ``np.ndarray`` as a tuple.

    The size of each level is taken from its first element, so ragged
    structures are described by their first branch only. String items are
    dropped from lists before measuring. Anything that is neither a list nor
    an array - and any 0-d array - has shape ``()``.
    """
    if not isinstance(data, (list, np.ndarray)):
        return ()
    if isinstance(data, np.ndarray) and data.ndim == 0:
        # 0-d arrays are scalars; len() would raise TypeError on them.
        return ()
    if isinstance(data, list):
        data = [item for item in data if not isinstance(item, str)]
    if len(data) == 0:
        return (0,)
    return (len(data),) + get_shape(data[0])

In [None]:
# A notebook cell renders only its last expression, so three separate calls
# would discard the first two results. Collect all shapes in one dict instead.
# Values recorded for this file: W (1, 50, 300, 5), Z (1, 50, 83, 5), sigma (50, 5).
{key: get_shape(inf_params[key]) for key in ('W', 'Z', 'sigma')}

### Model_params.dictionary (PRINTS OUT MCMC SAMPLES)

A dictionary with the keys 'W', 'Z', 'sigma', 'time_elapsed', and 'exp_logdensity'.
The MCMC samples are the input from which Data_comps and inf_params are computed. The W-related values are non-zero in the MCMC samples but are zero in Data_comps and inf_params.

In [None]:
import pickle

# Raw MCMC samples pickled by main() for run 1 of a single-chain K=5 fit.
file_path = "/Users/mertenbiyaoglu/Desktop/ucl/thesis/codes/sGFA_AIDA/results/genfi/sparseGFA_K5_1chs_pW33_s50_reghsZ/[1]Model_params.dictionary"

# NOTE: unpickling can execute arbitrary code - only open files you created yourself.
with open(file_path, mode='rb') as file:
    _data = pickle.loads(file.read())

In [None]:
# Show each entry's shape (or its type when it has none) instead of dumping every sample.
print({key: getattr(value, 'shape', type(value).__name__) for key, value in _data.items()})

In [None]:
# Pull the sampled loadings, latent factors and noise parameters out of the dict.
mcmc_w, mcmc_z, mcmc_s = (_data[key] for key in ('W', 'Z', 'sigma'))

In [None]:
# Shapes recorded for this file: W (50, 300, 5), Z (50, 83, 5), sigma (50, 1, 5).
w, z, s = (_data[key].shape for key in ('W', 'Z', 'sigma'))

# Run time and mean log density as stored by main().
t, l = _data['time_elapsed'], _data['exp_logdensity']

### Robust_params.dictionary (output of rob_params, X_rob, success = get_robustK(thrs, args, inf_params, data_comps))

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import pickle

In [None]:
# Robust parameters of run 1 of a 5-chain K=20 fit.
file_path = "/Users/mertenbiyaoglu/Desktop/ucl/thesis/codes/sGFA_AIDA/results/genfi/sfebqc_K20_5chs_pW20_s5000_CosThr8_MatThr5/[1]Robust_params.dictionary"

# NOTE: unpickling can execute arbitrary code - only open files you created yourself.
with open(file_path, mode='rb') as file:
    rob_params = pickle.loads(file.read())

In [None]:
# Show each entry's shape (or its type when it has none) instead of dumping all values.
print({key: getattr(value, 'shape', type(value).__name__) for key, value in rob_params.items()})

In [None]:
# Robust loadings, robust latent factors and the reconstructed data.
rob_w, rob_z, rob_infx = (rob_params[key] for key in ('W', 'Z', 'infX'))

In [None]:
# Cast to float first so that the DataFrame gets a numeric dtype.
rob_z_converted = rob_z.astype(dtype=float)
df_rob_z = pd.DataFrame(data=rob_z_converted)

In [None]:
# Latent factor stored in column index 4 of the robust Z matrix.
column_data = df_rob_z.iloc[:, 4]

# Mean magnitude over the first 14 rows versus all remaining rows.
first_14_avg = column_data.abs().iloc[:14].mean()
remaining_avg = column_data.abs().iloc[14:].mean()

# Row labels ordered by the signed factor value (ascending), not by magnitude.
sorted_indices = column_data.sort_values().index.tolist()

In [None]:
# Subject-level table; its rows are reordered to follow the factor ranking above.
file_path = "/Users/mertenbiyaoglu/Desktop/ucl/thesis/codes/sGFA_AIDA/aida_model/visit11_data_45subjs.csv"
df = pd.read_csv(file_path)
df_sorted = df.reindex(index=sorted_indices)

In [None]:
# Fill missing values column by column with that column's median.
df_sorted = df_sorted.apply(lambda col: col.fillna(col.median()))

# Min-max scale every column to [0, 1]. The original applied this to `data`,
# which is the nested list loaded from Data_comps and has no .apply method;
# the frame prepared just above is the evident target.
normalized_data = df_sorted.apply(lambda x: (x - x.min()) / (x.max() - x.min()), axis=0)

# Shapes noted by the author, presumably from an earlier K=5 run - verify for this file.
w = rob_params['W'].shape           # (300, 5)
z = rob_params['Z'].shape           # (83, 5)

In [None]:
def get_shape(data):
    """Return the nested dimensions of a list or ``np.ndarray`` as a tuple.

    The size of each level is taken from its first element, so ragged
    structures are described by their first branch only. String items are
    dropped from lists before measuring. Anything that is neither a list nor
    an array - and any 0-d array - has shape ``()``.
    """
    if not isinstance(data, (list, np.ndarray)):
        return ()
    if isinstance(data, np.ndarray) and data.ndim == 0:
        # 0-d arrays are scalars; len() would raise TypeError on them.
        return ()
    if isinstance(data, list):
        data = [item for item in data if not isinstance(item, str)]
    if len(data) == 0:
        return (0,)
    return (len(data),) + get_shape(data[0])

# Nested dimensions of the reconstructed data stored under 'infX'.
get_shape(rob_infx)  

# Inspecting W for structural MRI and fMRI

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import pickle

In [None]:
# Robust parameters of run 3 of the 5-chain K=20 sparse GFA fit.
file_path = "/Users/mertenbiyaoglu/Desktop/ucl/thesis/codes/sGFA_AIDA/results/genfi/sparseGFA_K20_5chs_pW33_s3000_reghsZ/[3]Robust_params.dictionary"

# NOTE: unpickling can execute arbitrary code - only open files you created yourself.
with open(file_path, mode='rb') as file:
    rob_params = pickle.loads(file.read())

In [None]:
# Robust loading matrix, cast to float and wrapped in a DataFrame.
# Note that `rob_w` is rebound from the raw array to the DataFrame here.
rob_w = rob_params['W']

rob_w_converted = rob_w.astype(dtype=float)
rob_w = pd.DataFrame(data=rob_w_converted)

### structural MRI 

In [None]:
# rob_w is a DataFrame at this point, so positional slicing has to go through
# .iloc (rob_w[0:69, 2] raises). Rows 0:69 are the structural-MRI block;
# column 2 is the factor being inspected.
smri = rob_w.iloc[0:69, 2]

smri_con = smri.astype(float)
smri = pd.DataFrame(smri_con)

In [None]:
# rob_w is a DataFrame at this point, so positional slicing has to go through
# .iloc (rob_w[0:69, 1] raises). Rows 0:69 are the structural-MRI block;
# column 1 is the factor being inspected.
smri_2 = rob_w.iloc[0:69, 1]

smri_con_2 = smri_2.astype(float)
smri_2 = pd.DataFrame(smri_con_2)

In [None]:
# Smallest and largest structural-MRI loading of the selected factor.
min_, max_ = (
    smri_2.iloc[:, 0].min().astype(float),
    smri_2.iloc[:, 0].max().astype(float),
)

### positive fMRI 

In [None]:
# rob_w is a DataFrame at this point, so positional slicing has to go through
# .iloc (rob_w[69:138, 5] raises). Rows 69:138 are the positive-fMRI block;
# column 5 is the factor being inspected.
p_fmri = rob_w.iloc[69:138, 5]

p_fmri_con = p_fmri.astype(float)
p_fmri = pd.DataFrame(p_fmri_con)

### negative fMRI - 138:207

In [None]:
# rob_w is a DataFrame at this point, so positional slicing has to go through
# .iloc (rob_w[138:207, 1] raises). Rows 138:207 are the negative-fMRI block;
# column 1 is the factor being inspected.
n_fmri = rob_w.iloc[138:207, 1]

n_fmri_con = n_fmri.astype(float)
n_fmri = pd.DataFrame(n_fmri_con)

Row ranges of W per modality:

- structural MRI: rows 0:69
- positive fMRI: rows 69:138
- negative fMRI: rows 138:207