# Process fMRI RDMs

### Load fMRI RDMs for VA-1, VA-2, VA-3, left LOC, and right LOC.

In [None]:
# Load the RDM files for the five ROIs: early visual areas VA-1, VA-2, VA-3
# (file names containing 'va-<n>_') and left/right LOC ('lLOC' / 'rLOC').
fmri_path = Path('../../THINGS_fMRI/all_rdm')
roi_patterns = ["*va-1_*", "*va-2_*", "*va-3_*", "*lLOC*", "*rLOC*"]
# Path.glob yields files in filesystem-dependent order; sort so the dictionary
# below is built reproducibly, and de-duplicate in case one file name matches
# more than one pattern.
fmri_files = sorted({f for pattern in roi_patterns for f in fmri_path.glob(pattern)})
if not fmri_files:
    print(f"Warning: no RDM files matching {roi_patterns} found in {fmri_path}")

# Maps '<subject>_<roi>' -> RDM array loaded from disk.
fmri_rdms = {}

for file in fmri_files:
    stem_parts = file.stem.split('_')
    # Subject number = text after the first '-' of the first underscore token
    # (a stem starting 'sub-01_' gives '01').
    subject_number = stem_parts[0].split('-')[1]
    print(subject_number)
    # The ROI token sits at a different position in the two naming schemes:
    # index 2 (third underscore-separated token) for VA files, index 1 (second
    # token) otherwise.
    # NOTE(review): `'va' in file.stem` is a loose substring test; confirm that
    # no LOC file name contains 'va'.
    roi = stem_parts[2] if 'va' in file.stem else stem_parts[1]
    print(roi)
    fmri_rdm = np.load(file)
    print(fmri_rdm.shape)
    fmri_rdms[subject_number + '_' + roi] = fmri_rdm

# Print the top-left 5x5 corner of every loaded RDM as a sanity check.
for key, rdm in fmri_rdms.items():
    print(key)
    print(rdm[:5, :5])

### Get the variable names for our five ROIs.

In [None]:
# Unique ROI names from the '<subject>_<roi>' keys of fmri_rdms
# (e.g. '01_lLOC' -> 'lLOC'). Sorted because iterating a set gives an arbitrary
# order, which would make the printout and the loop below non-reproducible.
fmri_rdms_rois = sorted({key.split('_')[1] for key in fmri_rdms})

print(fmri_rdms_rois)

# ROI name -> {'<subject>_<roi>': RDM} restricted to that ROI.
fmri_rdms_by_roi = {}

for roi in fmri_rdms_rois:
    # Exact match on the ROI part of the key; a substring test (`roi in k`)
    # could also catch keys of other ROIs whose name contains this one.
    fmri_rdms_by_roi[roi] = {
        k: v for k, v in fmri_rdms.items() if k.split('_')[1] == roi
    }

    # Also expose each dict as a notebook-level variable named '<roi>_rdms'.
    # globals() is used explicitly: at the top level it is the same mapping as
    # locals(), but writes to locals() may not take effect inside a function.
    # NOTE: if an ROI token contains '-' (e.g. 'va-1'), the resulting name is
    # not a valid identifier and can only be reached via globals()[...] or
    # fmri_rdms_by_roi.
    roi_rdms_varname = f"{roi}_rdms"
    globals()[roi_rdms_varname] = fmri_rdms_by_roi[roi]

    print(roi_rdms_varname)

### Load the concept index for the fMRI RDMs.

In [None]:
# Concept index accompanying the fMRI RDMs, read from subject 01's LOC file.
# NOTE(review): presumably gives the concept order of the RDM rows/columns and
# is shared across subjects and ROIs — confirm.
concept_index_file = './concept_indices/sub-01_LOC_concept_index.npy'
concept_index = np.load(concept_index_file)

print(concept_index)

### Read in the Stimulus Metadata file for each subject to ensure that all subjects were shown the same 8640 images making up the 720 concepts during fMRI trials.

In [None]:
# Load each subject's stimulus metadata (training stimuli only), sorted by
# stimulus name so the three tables can be compared row by row.
# A single sort is used on purpose: pandas' default single-column sort is not
# stable, so a preceding sort by 'concept' would not be preserved anyway.
_stimulus_metadata_by_subject = {
    subject: pd.read_csv(f'sub-{subject}_StimulusMetadata_train_only.csv').sort_values(by='stimulus')
    for subject in ('01', '02', '03')
}
sub_01_stimulus_metadata = _stimulus_metadata_by_subject['01']
sub_02_stimulus_metadata = _stimulus_metadata_by_subject['02']
sub_03_stimulus_metadata = _stimulus_metadata_by_subject['03']

# Compare values only: reset_index(drop=True) discards the row labels, which
# may differ between subjects after sorting.
stim_01 = sub_01_stimulus_metadata['stimulus'].reset_index(drop=True)
stim_02 = sub_02_stimulus_metadata['stimulus'].reset_index(drop=True)
stim_03 = sub_03_stimulus_metadata['stimulus'].reset_index(drop=True)

print("Comparing stimulus columns (values only):")
print(stim_01.equals(stim_02))
print(stim_01.equals(stim_03))
print(stim_02.equals(stim_03))

if not (len(stim_01) == len(stim_02) == len(stim_03)):
    # Element-wise comparison of Series requires identically-labelled (here:
    # equal-length) Series and would raise, so report the mismatch instead.
    print(
        "Stimulus lists differ in length: "
        f"sub-01={len(stim_01)}, sub-02={len(stim_02)}, sub-03={len(stim_03)}"
    )
else:
    # Equality is transitive, so comparing sub-01 against the other two flags
    # every row where any pair of subjects differs.
    diff_mask = (stim_01 != stim_02) | (stim_01 != stim_03)

    if diff_mask.any():
        print("Rows where 'stimulus' column differs across subjects:")
        diff_df = pd.DataFrame({
            'row': stim_01.index[diff_mask],
            'sub_01_stimulus': stim_01[diff_mask].values,
            'sub_02_stimulus': stim_02[diff_mask].values,
            'sub_03_stimulus': stim_03[diff_mask].values
        })
        print(diff_df)
    else:
        print("No differences found in the 'stimulus' column across the three subjects.")

### The stimulus lists match, so we can safely say that all three subjects were shown the same 12 images for each of the 720 object concepts (8,640 images in total).

### Calculate the average concept RDM for each ROI across subjects.

In [None]:
# Average each ROI's RDM across subjects, working on the upper triangle only
# (k=1 excludes the diagonal).

# ROI names come from the '<subject>_<roi>' keys, e.g. '01_lLOC' -> 'lLOC'.
rois = {key.split('_')[1] for key in fmri_rdms}

print("Available ROIs:", sorted(rois))

if not fmri_rdms:
    raise ValueError("fmri_rdms is empty: no RDM files were loaded, nothing to average.")

# ROI name -> upper-triangle vector averaged across subjects.
roi_average_upper_triangles = {}

# Indices are computed once from the first RDM; np.stack below fails if the
# RDMs of one ROI do not share a shape.
upper_tri_indices = np.triu_indices_from(next(iter(fmri_rdms.values())), k=1)

for roi in sorted(rois):
    print(f"\nProcessing ROI: {roi}")

    # Exact match on the ROI part of the key; a substring test (`roi in key`)
    # could also select other ROIs whose name contains this one.
    roi_rdms = {k: v for k, v in fmri_rdms.items() if k.split('_')[1] == roi}
    print(f"Found {len(roi_rdms)} RDMs for {roi}")

    # Per-subject upper-triangle vectors.
    roi_rdms_upper = {k: rdm[upper_tri_indices] for k, rdm in roi_rdms.items()}

    # One row per subject, one column per condition pair.
    roi_rdms_upper_array = np.stack(list(roi_rdms_upper.values()))
    print(f"Stacked array shape: {roi_rdms_upper_array.shape}")

    # axis=0 averages across subjects (rows) separately for every condition pair.
    roi_average_upper_triangle = np.mean(roi_rdms_upper_array, axis=0)
    print(f"Average vector shape: {roi_average_upper_triangle.shape}")

    roi_average_upper_triangles[roi] = roi_average_upper_triangle

    # First 5 values as a sanity check.
    print(f"First 5 values for {roi}: {roi_average_upper_triangle[:5]}")

print(f"\nCreated averaged upper triangle vectors for {len(roi_average_upper_triangles)} ROIs:")
for roi, avg_vector in roi_average_upper_triangles.items():
    print(f"  {roi}: shape {avg_vector.shape}")

### Average the left and right LOC RDMs into a single LOC RDM.

In [None]:
# Merge hemispheres: the new 'LOC' entry is the element-wise mean of the
# subject-averaged left and right LOC upper-triangle vectors.
left_loc = roi_average_upper_triangles['lLOC']
right_loc = roi_average_upper_triangles['rLOC']
roi_average_upper_triangles['LOC'] = np.mean([left_loc, right_loc], axis=0)

print(roi_average_upper_triangles['LOC'].shape)

print(right_loc.shape)

print(roi_average_upper_triangles.keys())

# Drop the hemisphere-specific entries now that they are merged.
for hemisphere_key in ('lLOC', 'rLOC'):
    del roi_average_upper_triangles[hemisphere_key]

print(roi_average_upper_triangles.keys())

# Create feature-reweighting functions.

### Convert each ROI's subject-averaged upper-triangle vector back into a full, symmetric 720x720 RDM for the reweighting analysis.

In [None]:
# Convert each ROI's subject-averaged upper-triangle vector back into a full,
# symmetric RDM with a zero diagonal, then verify the round trip.


def _n_conditions_from_pairs(n_pairs):
    """Return N such that N * (N - 1) / 2 == n_pairs, or raise ValueError."""
    # Positive root of N^2 - N - 2 * n_pairs = 0, verified exactly in integers
    # to guard against floating-point rounding.
    n = int(round((1 + np.sqrt(1 + 8 * n_pairs)) / 2))
    if n * (n - 1) // 2 != n_pairs:
        raise ValueError(
            f"Vector length {n_pairs} is not N*(N-1)/2 for any integer N; "
            "cannot rebuild a square RDM from it."
        )
    return n


if not roi_average_upper_triangles:
    raise ValueError("roi_average_upper_triangles is empty; nothing to reconstruct.")

# ROI name -> reconstructed N x N RDM.
roi_reconstructed_rdms = {}

# Infer N from the data rather than hard-coding it (720 concepts in this
# analysis); every ROI vector is checked against it below.
n_conditions = _n_conditions_from_pairs(len(next(iter(roi_average_upper_triangles.values()))))

print(f"Reconstructing {n_conditions}x{n_conditions} RDMs from upper triangular vectors...")
print("=" * 70)

# Same row-major ordering as the np.triu_indices_from(..., k=1) call used when
# the vectors were extracted; k=1 excludes the diagonal.
upper_tri_indices = np.triu_indices(n_conditions, k=1)
expected_length = len(upper_tri_indices[0])

print(f"Upper triangular indices shape: {upper_tri_indices[0].shape}")
print(f"Number of upper triangular elements: {expected_length}")
print(f"Expected vector length: {expected_length}")

for roi in sorted(roi_average_upper_triangles):
    print(f"\nProcessing ROI: {roi}")

    upper_tri_vector = roi_average_upper_triangles[roi]
    print(f"  Upper triangular vector shape: {upper_tri_vector.shape}")
    print(f"  Vector length: {len(upper_tri_vector)}")

    if len(upper_tri_vector) != expected_length:
        # A wrong-length vector would either fail the assignment below with a
        # bare shape error or, if it has length 1, be silently broadcast.
        raise ValueError(
            f"ROI {roi}: vector length {len(upper_tri_vector)} doesn't match "
            f"expected {expected_length}"
        )
    print("  ✓ Vector length matches expected")

    rdm = np.zeros((n_conditions, n_conditions))
    rdm[upper_tri_indices] = upper_tri_vector
    # Mirror the upper triangle into the lower one; the result is exactly symmetric.
    rdm = rdm + rdm.T
    # Already zero because k=1 never writes the diagonal; kept as an explicit guarantee.
    np.fill_diagonal(rdm, 0)

    roi_reconstructed_rdms[roi] = rdm

    print(f"  Reconstructed RDM shape: {rdm.shape}")
    print(f"  RDM is symmetric: {np.allclose(rdm, rdm.T)}")
    print(f"  Diagonal is zero: {np.all(np.diag(rdm) == 0)}")
    print(f"  Min value: {np.min(rdm):.4f}")
    print(f"  Max value: {np.max(rdm):.4f}")
    print(f"  Mean value: {np.mean(rdm):.4f}")

print(f"\nSuccessfully reconstructed RDMs for {len(roi_reconstructed_rdms)} ROIs")
print(f"Available ROIs: {sorted(roi_reconstructed_rdms.keys())}")

# Round-trip check: the upper triangle of each rebuilt RDM must equal its source vector.
print(f"\nReconstruction verification:")
for roi in sorted(roi_reconstructed_rdms):
    original_vector = roi_average_upper_triangles[roi]
    reconstructed_vector = roi_reconstructed_rdms[roi][upper_tri_indices]

    vectors_match = np.allclose(original_vector, reconstructed_vector)
    print(f"  {roi}: Vectors match = {vectors_match}")

    if not vectors_match:
        diff = np.abs(original_vector - reconstructed_vector)
        print(f"    Max difference: {np.max(diff):.10f}")
        print(f"    Mean difference: {np.mean(diff):.10f}")

# Show how to access the reconstructed RDMs.
first_roi = sorted(roi_reconstructed_rdms)[0]
print(f"\nExample usage:")
print(f"  roi_reconstructed_rdms['{first_roi}']  # {n_conditions}x{n_conditions} RDM for {first_roi}")
print(f"  roi_reconstructed_rdms['{first_roi}'][:5, :5]  # First 5x5 of the RDM")

# Show a sample of the reconstructed RDM.
sample_roi = first_roi
sample_rdm = roi_reconstructed_rdms[sample_roi]
print(f"\nSample RDM ({sample_roi}) - first 5x5:")
print(sample_rdm[:5, :5])

### Define a function that uses ridge regression to generate weights for combining the 14 layer-level RDMs to predict a single neural RDM.

In [None]:
def _ridge_target(y, objective):
    """Return the training target the ridge fit uses for *objective*.

    'spearman' -> mean-centred float32 ranks of y; otherwise y minus its mean.
    Centring stands in for the intercept, because the ridge models are fit with
    fit_intercept=False.
    """
    if objective == 'spearman':
        target = stats.rankdata(y).astype(np.float32)
        target -= target.mean()
        return target
    return y - y.mean()


def _ridge_fold_score(X_tr, y_tr, X_val, y_val, alpha, objective):
    """Fit ridge on one training fold and return its validation score.

    Spearman rho (NaN mapped to 0.0) for 'spearman'; otherwise R² of the
    prediction against the mean-centred validation target.
    """
    model = Ridge(alpha=alpha, fit_intercept=False)
    model.fit(X_tr, _ridge_target(y_tr, objective))
    y_val_pred = X_val @ model.coef_.astype(np.float32)

    if objective == 'spearman':
        rho, _ = stats.spearmanr(y_val_pred, y_val)
        # spearmanr returns NaN when an input is constant; count that fold as 0.
        return 0.0 if np.isnan(rho) else float(rho)

    y_val_centered = y_val - y_val.mean()
    ss_res = np.sum((y_val_centered - y_val_pred) ** 2)
    # The small constant avoids division by zero for a constant validation target.
    ss_tot = np.sum((y_val_centered - y_val_centered.mean()) ** 2) + 1e-12
    return float(1.0 - ss_res / ss_tot)


def ridge_reweighting_single_rdm(
    model_rdms,  # Shape: [L, N, N]
    neural_rdm,   # Shape: [N, N] - single RDM
    alpha_candidates=None,
    cv_folds=5,
    objective='spearman'
):
    """
    Compute ridge weights combining layer model RDMs to predict a single neural RDM.

    Each of the L model RDMs contributes one feature column: its upper-triangle
    dissimilarities, diagonal excluded. The neural RDM's upper triangle is the
    target. The ridge penalty is chosen by K-fold cross-validation over
    condition pairs. The model is then refit on all pairs with the best alpha.
    No intercept is fitted; the target is centred instead.

    Args:
        model_rdms: Model RDMs of shape [L, N, N] (NumPy array or torch tensor).
        neural_rdm: Single neural RDM of shape [N, N] (NumPy array or torch tensor).
        alpha_candidates: Ridge penalties to test. Defaults to
            [1e-3, 1e-2, 1e-1, 1.0, 1e1, 1e2, 1e3].
        cv_folds: Number of cross-validation folds (shuffled, random_state=0).
        objective: 'spearman' fits the centred ranks of the target and scores
            folds by Spearman rho. 'ridgecv' fits the centred target and scores
            folds by R².

    Returns:
        weights: torch float32 tensor of shape [L], one weight per layer.
        best_alpha: Best alpha value found (float).
        cv_score: Mean cross-validation score obtained with best_alpha.

    Raises:
        ValueError: If objective is unknown, alpha_candidates is empty, the
            input shapes are inconsistent, or no alpha yields a finite score.
    """
    if objective not in ('spearman', 'ridgecv'):
        raise ValueError(f"objective must be 'spearman' or 'ridgecv', got {objective!r}")

    # Accept NumPy arrays or torch tensors (CPU/GPU, with or without grad) for both inputs.
    if isinstance(model_rdms, torch.Tensor):
        model_rdms = model_rdms.detach().cpu().numpy()
    if isinstance(neural_rdm, torch.Tensor):
        neural_rdm = neural_rdm.detach().cpu().numpy()
    model_rdms = np.asarray(model_rdms)
    neural_rdm = np.asarray(neural_rdm)

    if model_rdms.ndim != 3 or model_rdms.shape[1] != model_rdms.shape[2]:
        raise ValueError(f"model_rdms must have shape [L, N, N], got {model_rdms.shape}")
    L, N, _ = model_rdms.shape
    if neural_rdm.shape != (N, N):
        raise ValueError(
            f"neural_rdm must have shape {(N, N)} to match model_rdms, got {neural_rdm.shape}"
        )

    # Upper triangle without the diagonal, in row-major order.
    rows, cols = np.triu_indices(N, k=1)
    # Feature matrix [n_pairs, L]: column l holds layer l's dissimilarities.
    X = np.ascontiguousarray(model_rdms[:, rows, cols].T, dtype=np.float32)
    # Target vector [n_pairs]: the neural dissimilarities for the same pairs.
    y = neural_rdm[rows, cols].astype(np.float32)

    if alpha_candidates is None:
        alpha_candidates = [1e-3, 1e-2, 1e-1, 1.0, 1e1, 1e2, 1e3]
    alpha_candidates = list(alpha_candidates)
    if not alpha_candidates:
        raise ValueError("alpha_candidates must contain at least one value")

    # With shuffle and a fixed random_state the folds are identical for every
    # alpha, so draw them once.
    # NOTE(review): folds are over condition pairs, so pairs sharing a condition
    # can appear in both training and validation folds; the CV score is not a
    # held-out-condition estimate.
    kf = KFold(n_splits=cv_folds, shuffle=True, random_state=0)
    splits = list(kf.split(X))

    best_alpha = None
    best_score = -np.inf

    print(f"Cross-validating {len(alpha_candidates)} alpha values with {cv_folds} folds...")

    for alpha in tqdm(alpha_candidates, desc="Cross-validating alphas"):
        fold_scores = [
            _ridge_fold_score(X[train_idx], y[train_idx], X[val_idx], y[val_idx], alpha, objective)
            for train_idx, val_idx in splits
        ]
        mean_score = float(np.mean(fold_scores))
        if mean_score > best_score:
            best_score = mean_score
            best_alpha = float(alpha)

    if best_alpha is None:
        # Only reachable when every alpha produced a NaN mean score.
        raise ValueError("Cross-validation produced no finite score for any alpha")

    # Refit on all pairs with the selected alpha.
    final_model = Ridge(alpha=best_alpha, fit_intercept=False)
    final_model.fit(X, _ridge_target(y, objective))
    weights = torch.from_numpy(final_model.coef_.astype(np.float32))

    print(f"Best alpha: {best_alpha}")
    print(f"CV score: {best_score:.4f}")

    return weights, best_alpha, best_score

### Define a function that computes a feature-reweighted model RDM from the ridge weights.

In [None]:
def compute_reweighted_predictor(
    model_rdms,  # Shape: [L, N, N]
    weights     # Shape: [L] - ridge weights
):
    """
    Combine per-layer model RDMs into one predicted RDM using ridge weights.

    Accumulates sum_l weights[l] * model_rdms[l] into a float32 buffer,
    symmetrises it as (R + R.T) / 2, and sets the diagonal to zero.

    Args:
        model_rdms: Stack of model RDMs, shape [L, N, N].
        weights: One weight per layer, shape [L]; a torch tensor is converted
            to NumPy first.

    Returns:
        reweighted_rdm: float32 NumPy array of shape [N, N].
    """
    n_layers, n_items, _ = model_rdms.shape

    if isinstance(weights, torch.Tensor):
        weights = weights.cpu().numpy()

    combined = np.zeros((n_items, n_items), dtype=np.float32)
    for layer_idx in range(n_layers):
        combined += weights[layer_idx] * model_rdms[layer_idx]

    # Averaging with the transpose removes any numerical asymmetry.
    combined = (combined + combined.T) / 2
    # Dissimilarity of an item with itself is defined as zero.
    np.fill_diagonal(combined, 0)
    return combined

### Define a function that will run the RDM feature-reweighting analysis.

In [None]:
def run_single_rdm_analysis(
    model_rdms,      # Shape: [L, N, N]
    neural_rdm,      # Shape: [N, N]
    alpha_candidates=None,
    cv_folds=5,
    objective='spearman'
):
    """
    Complete pipeline for single RDM ridge regression analysis.

    Fits cross-validated ridge weights with ``ridge_reweighting_single_rdm``,
    builds the reweighted model RDM with ``compute_reweighted_predictor``, and
    correlates its upper triangle (diagonal excluded) with the neural RDM's.

    Args:
        model_rdms: Model RDMs of shape [L, N, N] (NumPy array or torch tensor).
        neural_rdm: Neural RDM of shape [N, N] (NumPy array or torch tensor).
        alpha_candidates: Ridge penalties to test (None uses the default grid).
        cv_folds: Number of cross-validation folds.
        objective: 'spearman' or 'ridgecv'.

    Returns:
        dict with keys:
            'weights': ridge weights, torch tensor of shape [L]
            'best_alpha': best alpha value
            'cv_score': mean cross-validation score for best_alpha
            'reweighted_rdm': predicted RDM, NumPy array [N, N]
            'correlation': Spearman rho (objective='spearman') or Pearson r
                (otherwise) between predicted and neural upper triangles.
                This is an in-sample value: the weights were refit on all
                pairs, so it is not a held-out estimate.
    """
    print("Single RDM Ridge Regression Analysis")
    print("=" * 50)
    print(f"Model RDMs shape: {model_rdms.shape}")
    print(f"Neural RDM shape: {neural_rdm.shape}")
    print(f"Objective: {objective}")
    print(f"CV folds: {cv_folds}")

    # Work on NumPy copies so torch inputs (CUDA or grad-tracking) can be
    # indexed, combined with NumPy weights, and passed to SciPy.
    if isinstance(model_rdms, torch.Tensor):
        model_rdms = model_rdms.detach().cpu().numpy()
    if isinstance(neural_rdm, torch.Tensor):
        neural_rdm = neural_rdm.detach().cpu().numpy()
    neural_rdm = np.asarray(neural_rdm)

    weights, best_alpha, cv_score = ridge_reweighting_single_rdm(
        model_rdms, neural_rdm, alpha_candidates, cv_folds, objective
    )

    reweighted_rdm = compute_reweighted_predictor(model_rdms, weights)

    # Compare only the upper triangles (diagonal excluded); both RDMs are symmetric.
    triu = np.triu_indices(neural_rdm.shape[0], k=1)
    neural_flat = neural_rdm[triu]
    predicted_flat = reweighted_rdm[triu]

    if objective == 'spearman':
        correlation, _ = stats.spearmanr(neural_flat, predicted_flat)
    else:
        correlation, _ = stats.pearsonr(neural_flat, predicted_flat)

    print(f"\nResults:")
    print(f"Best alpha: {best_alpha}")
    print(f"CV score: {cv_score:.4f}")
    print(f"Final correlation: {correlation:.4f}")
    print(f"Weights: {weights.cpu().numpy()}")

    return {
        'weights': weights,
        'best_alpha': best_alpha,
        'cv_score': cv_score,
        'reweighted_rdm': reweighted_rdm,
        'correlation': correlation
    }