In [83]:
import numpy as np
from trained_untrained_results_funcs import loop_through_datasets, load_mean_sem_perf, custom_add_2d
from matplotlib import pyplot as plt

In [133]:
import numpy as np
from scipy.stats import false_discovery_control

def compute_voxel_pvalues(voxel_performance, null_distribution):
    """
    Compute one-sided p-values for voxel performance against a null distribution and apply FDR correction.
    
    Parameters:
    - voxel_performance: 1D array of shape (num_voxels/elecs/fROIs,) containing the difference in performance values between m1 and m2.
    - null_distribution: 2D array of shape (1000, num_voxels/elecs/fROIs) representing the null distribution for each voxel.
    
    Returns:
    
    - p_values: 1D array of uncorrected p-values for each voxel. The pvalue indicates the chance that the difference between m1 and m2 is due
    to chance. The way this is computed is we fit N=1000 gaussian regressions, and use these to create a null distribution of R2 values. 
    Then, we compute the fraction of null R2 values that are greater than or equal to the difference in R2 between m1 and m2 to get the p-value.
    
    - fdr_corrected_p_values: 1D array of FDR-corrected p-values for each voxel.
    """
    num_voxels = voxel_performance.shape[0]
    p_values = np.zeros(num_voxels)

    # Compute p-values for each voxel
    for i in range(num_voxels):
        null_dist = null_distribution[:, i]

        p_values[i] = np.nanmean(null_dist >= voxel_performance[i])  # One-sided test
        
    fdr_corrected_p_values  = false_discovery_control(p_values)
    
        # Compute and print fractions under 0.05
    frac_uncorrected = np.mean(p_values < 0.05) * 100
    frac_fdr_corrected = np.mean(fdr_corrected_p_values < 0.05) * 100 

    print(f"Fraction of uncorrected p-values < 0.05: {frac_uncorrected:.2f}")
    print(f"Fraction of FDR-corrected p-values < 0.05: {frac_fdr_corrected:.2f}")


    return p_values, fdr_corrected_p_values

           
def process_nan_indices(arr1, arr2, large_negative_value=-1e9):
    """
    Align the NaN structure of two equally shaped arrays.

    Positions where BOTH arrays are NaN are dropped from both outputs. Any NaN that remains (i.e. present in
    only one of the two arrays) is replaced by `large_negative_value`. Because a boolean mask is used to
    drop positions, the outputs are always 1-D.

    Parameters:
    - arr1 (array-like): The first array.
    - arr2 (array-like): The second array; must have the same shape as arr1.
    - large_negative_value (float): Replacement for a NaN whose counterpart in the other array is not NaN.

    Returns:
    - numpy.ndarray, numpy.ndarray: The processed versions of arr1 and arr2.

    Raises:
    - ValueError: if the two inputs do not have the same shape.
    """
    first = np.asarray(arr1)
    second = np.asarray(arr2)

    if first.shape != second.shape:
        raise ValueError("Arrays must have the same shape.")

    first_is_nan = np.isnan(first)
    second_is_nan = np.isnan(second)

    # Keep every position that is not NaN in both arrays simultaneously.
    keep = ~(first_is_nan & second_is_nan)

    first_out = np.where(first_is_nan[keep], large_negative_value, first[keep])
    second_out = np.where(second_is_nan[keep], large_negative_value, second[keep])
    return first_out, second_out

def _load_gaussian_null(dataset, exp, non_nan_indices_dict, subjects_arr_pereira, lang_indices,
                        stats_root, perf, n_null):
    """Load the `n_null` Gaussian-regression null performance maps for one dataset, stacked as (n_null, num_units)."""
    stats_dir = f'{stats_root}/results_{dataset}/stats'
    null_samples = []

    for i in range(n_null):

        if dataset == 'pereira':
            # Pereira results are stored per experiment. Scatter each experiment's values into the full
            # voxel space (sized like subjects_arr_pereira) at that experiment's voxel indices and merge
            # them with custom_add_2d.
            combined = np.full(subjects_arr_pereira.shape[0], fill_value=np.nan)

            for e in exp:
                with np.load(f'{stats_dir}/{dataset}_gaussian-stats_layer_{i}_1_{e}.npz') as npz:
                    gauss_perf = npz[perf]
                idx = non_nan_indices_dict[e]
                combined[idx] = custom_add_2d(combined[idx], gauss_perf)

            # Restrict to the language-network voxels (presumably the voxel set stored in the figure npz files).
            combined = combined[lang_indices]

        else:
            with np.load(f'{stats_dir}/{dataset}_gaussian-stats_layer_{i}_1.npz') as npz:
                combined = npz[perf]

        null_samples.append(combined)

    return np.stack(null_samples)


def compute_stats_results(llm_model, simple_model, figure_folder, non_nan_indices_dict, subjects_arr_pereira,
                          lang_indices, exp=('384', '243'), dataset_arr=('pereira', 'blank', 'fedorenko'),
                          llm_greater=True, n_null=1000, fe_variants=('', '-mp', '-sp'),
                          figures_data_root='/home2/ebrahim/beyond-brainscore/analyze_results/figures_code/figures_data',
                          stats_root='/data/LLMs/brainscore'):
    """
    Test, per voxel/electrode/fROI, whether one model's out-of-sample R2 beats the other's, against a
    Gaussian-regression null, and print the summary from compute_voxel_pvalues.

    For each dataset in `dataset_arr` and each LLM feature variant in `fe_variants` (npz keys of
    `{figures_data_root}/{figure_folder}/{llm_model}_{dataset}.npz`), the performance difference
    LLM - simple (llm_greater=True) or simple - LLM (llm_greater=False) is compared with the stacked
    null from `{stats_root}/results_{dataset}/stats/`. The simple model is always read from key ''.

    Parameters:
    - llm_model, simple_model: model name prefixes of the figure npz files.
    - figure_folder: sub-folder of `figures_data_root` holding those files.
    - non_nan_indices_dict: per-experiment indices into the full Pereira voxel space.
    - subjects_arr_pereira: array whose length sets the size of the full Pereira voxel space.
    - lang_indices: indices selecting the language-network voxels from the full Pereira voxel space.
    - exp: Pereira experiments to combine.
    - dataset_arr: datasets to test.
    - llm_greater: direction of the one-sided test.
    - n_null: number of null (Gaussian-regression) samples to load per dataset.
    - fe_variants: LLM npz keys to test.
    - figures_data_root, stats_root: data locations (defaults are the original hard-coded paths).

    Returns None; results are printed.
    """
    perf = 'out_of_sample_r2'
    figure_dir = f'{figures_data_root}/{figure_folder}'

    for dataset in dataset_arr:

        # The null depends only on the dataset, not on the feature variant, so load it once per dataset
        # instead of re-reading n_null (2 * n_null for Pereira) files for every variant.
        all_gauss_np = _load_gaussian_null(dataset, exp, non_nan_indices_dict, subjects_arr_pereira,
                                           lang_indices, stats_root, perf, n_null)

        # The null has NaNs both from constant voxels (e.g. in 384) and from combining experiments; drop
        # any unit whose null is incomplete in any sample.
        null_has_nan = np.isnan(all_gauss_np).any(axis=0)

        with np.load(f'{figure_dir}/{simple_model}_{dataset}.npz') as npz:
            simple = npz['']

        for fe in fe_variants:

            with np.load(f'{figure_dir}/{llm_model}_{dataset}.npz') as npz:
                llm = npz[fe]

            perf_diff = llm - simple if llm_greater else simple - llm

            # Also drop units whose observed difference is NaN, since they cannot be tested.
            keep = ~null_has_nan & ~np.isnan(perf_diff)

            print(dataset, fe)
            compute_voxel_pvalues(perf_diff[keep], all_gauss_np[:, keep])

In [134]:
exp = ['384', '243']
data_processed_folder_pereira = '/data/LLMs/data_processed/pereira/dataset/'

# Per-experiment (384 / 243) network labels, voxel counts and subject labels.
br_labels_dict = {}
num_vox_dict = {}
subjects_dict = {}
for e in exp:

    bre = np.load(f'{data_processed_folder_pereira}/networks_{e}.npy', allow_pickle=True)
    br_labels_dict[e] = bre
    num_vox_dict[e] = bre.shape[0]
    subjects_dict[e] = np.load(f"{data_processed_folder_pereira}/subjects_{e}.npy", allow_pickle=True)

# np.flatnonzero always returns a 1-D index array. np.argwhere(...).squeeze() collapses to a 0-d
# scalar when exactly one voxel matches, which would then silently index out a scalar.
lang_indices_384 = np.flatnonzero(br_labels_dict['384'] == 'language')
lang_indices_243 = np.flatnonzero(br_labels_dict['243'] == 'language')
lang_indices_dict = {'384': lang_indices_384, '243': lang_indices_243}

# Arrays over the combined ("complete") Pereira voxel space.
subjects_arr_pereira = np.load(f"{data_processed_folder_pereira}/subjects_complete.npy", allow_pickle=True)
networks_arr_pereira = np.load(f"{data_processed_folder_pereira}/network_complete.npy", allow_pickle=True)
non_nan_indices_243 = np.load(f"{data_processed_folder_pereira}/non_nan_indices_243.npy") # voxels which are in 243
non_nan_indices_384 = np.load(f"{data_processed_folder_pereira}/non_nan_indices_384.npy") # voxels which are in 384
non_nan_indices_dict = {'384': non_nan_indices_384, '243': non_nan_indices_243}
lang_indices = np.flatnonzero(networks_arr_pereira == 'language')

y_384 = np.load(f"{data_processed_folder_pereira}y_pereira_384.npy")

# A column (voxel) is constant when every row equals the first row.
constant_columns_384 = np.all(y_384 == y_384[0, :], axis=0)

# Indices of the constant columns; only their count is printed.
constant_columns_indices_384 = np.flatnonzero(constant_columns_384)
print(constant_columns_indices_384.shape) # There will be 23 nan indices just due to constant activation voxels in 384

(23,)


In [None]:
# Figure 2 data: is gpt2xl_combined better than simple_combined (one-sided test on LLM - simple)?
compute_stats_results(
    llm_model='gpt2xl_combined',
    simple_model='simple_combined',
    figure_folder='figure2',
    non_nan_indices_dict=non_nan_indices_dict,
    subjects_arr_pereira=subjects_arr_pereira,
    lang_indices=lang_indices,
)

In [None]:
# Figure 2 data, reverse direction: is simple_combined better than gpt2xl_combined (simple - LLM)?
compute_stats_results(
    llm_model='gpt2xl_combined',
    simple_model='simple_combined',
    figure_folder='figure2',
    non_nan_indices_dict=non_nan_indices_dict,
    subjects_arr_pereira=subjects_arr_pereira,
    lang_indices=lang_indices,
    llm_greater=False,
)

In [137]:
# Figure 4 data: is gpt2xl_combined better than simple_combined (one-sided test on LLM - simple)?
compute_stats_results(
    llm_model='gpt2xl_combined',
    simple_model='simple_combined',
    figure_folder='figure4',
    non_nan_indices_dict=non_nan_indices_dict,
    subjects_arr_pereira=subjects_arr_pereira,
    lang_indices=lang_indices,
)

pereira 
Fraction of uncorrected p-values < 0.05: 47.09
Fraction of FDR-corrected p-values < 0.05: 24.76
pereira -mp
Fraction of uncorrected p-values < 0.05: 45.56
Fraction of FDR-corrected p-values < 0.05: 20.84
pereira -sp
Fraction of uncorrected p-values < 0.05: 51.06
Fraction of FDR-corrected p-values < 0.05: 28.93
blank 
Fraction of uncorrected p-values < 0.05: 23.33
Fraction of FDR-corrected p-values < 0.05: 5.00
blank -mp
Fraction of uncorrected p-values < 0.05: 26.67
Fraction of FDR-corrected p-values < 0.05: 6.67
blank -sp
Fraction of uncorrected p-values < 0.05: 25.00
Fraction of FDR-corrected p-values < 0.05: 10.00
fedorenko 
Fraction of uncorrected p-values < 0.05: 41.24
Fraction of FDR-corrected p-values < 0.05: 24.74
fedorenko -mp
Fraction of uncorrected p-values < 0.05: 40.21
Fraction of FDR-corrected p-values < 0.05: 18.56
fedorenko -sp
Fraction of uncorrected p-values < 0.05: 49.48
Fraction of FDR-corrected p-values < 0.05: 22.68


In [138]:
# Figure 4 data, reverse direction: is simple_combined better than gpt2xl_combined (simple - LLM)?
compute_stats_results(
    llm_model='gpt2xl_combined',
    simple_model='simple_combined',
    figure_folder='figure4',
    non_nan_indices_dict=non_nan_indices_dict,
    subjects_arr_pereira=subjects_arr_pereira,
    lang_indices=lang_indices,
    llm_greater=False,
)

pereira 
Fraction of uncorrected p-values < 0.05: 32.45
Fraction of FDR-corrected p-values < 0.05: 10.64
pereira -mp
Fraction of uncorrected p-values < 0.05: 33.36
Fraction of FDR-corrected p-values < 0.05: 14.95
pereira -sp
Fraction of uncorrected p-values < 0.05: 27.66
Fraction of FDR-corrected p-values < 0.05: 4.67
blank 
Fraction of uncorrected p-values < 0.05: 65.00
Fraction of FDR-corrected p-values < 0.05: 55.00
blank -mp
Fraction of uncorrected p-values < 0.05: 60.00
Fraction of FDR-corrected p-values < 0.05: 56.67
blank -sp
Fraction of uncorrected p-values < 0.05: 58.33
Fraction of FDR-corrected p-values < 0.05: 51.67
fedorenko 
Fraction of uncorrected p-values < 0.05: 46.39
Fraction of FDR-corrected p-values < 0.05: 34.02
fedorenko -mp
Fraction of uncorrected p-values < 0.05: 44.33
Fraction of FDR-corrected p-values < 0.05: 17.53
fedorenko -sp
Fraction of uncorrected p-values < 0.05: 34.02
Fraction of FDR-corrected p-values < 0.05: 2.06


In [135]:
# Figure 5 data: is gpt2xl_combined better than simple_combined (one-sided test on LLM - simple)?
compute_stats_results(
    llm_model='gpt2xl_combined',
    simple_model='simple_combined',
    figure_folder='figure5',
    non_nan_indices_dict=non_nan_indices_dict,
    subjects_arr_pereira=subjects_arr_pereira,
    lang_indices=lang_indices,
)

pereira 
Fraction of uncorrected p-values < 0.05: 11.26
Fraction of FDR-corrected p-values < 0.05: 0.10
pereira -mp
Fraction of uncorrected p-values < 0.05: 19.11
Fraction of FDR-corrected p-values < 0.05: 0.56
pereira -sp
Fraction of uncorrected p-values < 0.05: 17.67
Fraction of FDR-corrected p-values < 0.05: 0.40
blank 
Fraction of uncorrected p-values < 0.05: 31.67
Fraction of FDR-corrected p-values < 0.05: 6.67
blank -mp
Fraction of uncorrected p-values < 0.05: 31.67
Fraction of FDR-corrected p-values < 0.05: 0.00
blank -sp
Fraction of uncorrected p-values < 0.05: 28.33
Fraction of FDR-corrected p-values < 0.05: 0.00
fedorenko 
Fraction of uncorrected p-values < 0.05: 14.43
Fraction of FDR-corrected p-values < 0.05: 0.00
fedorenko -mp
Fraction of uncorrected p-values < 0.05: 20.62
Fraction of FDR-corrected p-values < 0.05: 4.12
fedorenko -sp
Fraction of uncorrected p-values < 0.05: 19.59
Fraction of FDR-corrected p-values < 0.05: 4.12


In [136]:
# Figure 5 data, reverse direction: is simple_combined better than gpt2xl_combined (simple - LLM)?
compute_stats_results(
    llm_model='gpt2xl_combined',
    simple_model='simple_combined',
    figure_folder='figure5',
    non_nan_indices_dict=non_nan_indices_dict,
    subjects_arr_pereira=subjects_arr_pereira,
    lang_indices=lang_indices,
    llm_greater=False,
)

pereira 
Fraction of uncorrected p-values < 0.05: 62.08
Fraction of FDR-corrected p-values < 0.05: 43.02
pereira -mp
Fraction of uncorrected p-values < 0.05: 48.83
Fraction of FDR-corrected p-values < 0.05: 22.47
pereira -sp
Fraction of uncorrected p-values < 0.05: 47.36
Fraction of FDR-corrected p-values < 0.05: 2.35
blank 
Fraction of uncorrected p-values < 0.05: 58.33
Fraction of FDR-corrected p-values < 0.05: 55.00
blank -mp
Fraction of uncorrected p-values < 0.05: 58.33
Fraction of FDR-corrected p-values < 0.05: 53.33
blank -sp
Fraction of uncorrected p-values < 0.05: 63.33
Fraction of FDR-corrected p-values < 0.05: 55.00
fedorenko 
Fraction of uncorrected p-values < 0.05: 65.98
Fraction of FDR-corrected p-values < 0.05: 63.92
fedorenko -mp
Fraction of uncorrected p-values < 0.05: 64.95
Fraction of FDR-corrected p-values < 0.05: 48.45
fedorenko -sp
Fraction of uncorrected p-values < 0.05: 62.89
Fraction of FDR-corrected p-values < 0.05: 56.70
