## Part 6: Full Protocol
#### With GPU-accelerated Ridge Regression Using the Himalaya Library

This notebook tutorial walks through the full crossmodal fMRI prediction pipeline using the BridgeTower model: extracting features from natural stimuli with BridgeTower layers, building voxelwise encoding models that predict fMRI responses from those stimulus features, and finally using the vision encoding model to predict language fMRI data and the language encoding model to predict visual fMRI data.

In [1]:
# Analysis parameters -- edit these before running the notebook.
#   subject:  which participant to analyze, one of 'S1' .. 'S5'
#   modality: base encoding model to build, 'vision' or 'language'
#   layer:    BridgeTower layer to extract features from, 1 .. 13
subject, modality, layer = 'S1', 'vision', 8

## 1 Feature Extraction

In [None]:
def _unmask_correlations(subject, correlations, volume_shape=(31, 100, 100)):
    """Scatter per-voxel values back into a flattened functional volume.

    Parameters
    ----------
    subject: string
        Subject reference used to locate the story fMRI run whose non-NaN
        voxels define the mask.
    correlations: array
        One value per masked voxel, in the order of the mask's flat
        (C-order) non-NaN indices.
    volume_shape: tuple of int
        Shape of the functional volume the mapper matrices expect.

    Returns
    -------
    numpy.ndarray
        1D float array of size ``prod(volume_shape)``. It is NaN everywhere
        except at the masked voxels, which receive ``correlations``.
    """
    fmri_alternateithicatom = np.load("data/storydata/" + subject +
                                      "/alternateithicatom.npy")
    # Voxels that are non-NaN in the first entry of the run form the mask
    mask = ~np.isnan(fmri_alternateithicatom[0])
    valid_indices = np.flatnonzero(mask)

    correlations = np.asarray(correlations)
    # Like zip(), only pair as many voxels as both sequences provide
    n_voxels = min(valid_indices.size, len(correlations))

    # Writing at flat indices is equivalent to unravelling each index into
    # volume_shape, assigning in 3D, and flattening again.
    flattened = np.full(int(np.prod(volume_shape)), np.nan)
    flattened[valid_indices[:n_voxels]] = correlations[:n_voxels]
    return flattened


def create_flatmap(subject, layer, correlations, modality):
    """Plot per-voxel prediction correlations on cortical flatmaps and save.

    The voxel correlations are placed back into the subject's functional
    volume. They are then projected onto left- and right-hemisphere vertices
    with the subject's mapper matrices. The two hemispheres are drawn side by
    side with one shared colorbar, saved as a PNG and displayed.

    Parameters
    ----------
    subject: string
        A reference to the subject for analysis (e.g. 'S1'). Used to load the
        fMRI mask and the surface mappers.
    layer: int
        The BridgeTower layer the encoding model was built from. Only used to
        name the output file.
    correlations: array
        Generated by story_prediction() or movie_prediction() function.
        Contains the correlation between predicted and real brain activity
        for each voxel, ordered like the non-NaN voxels of the fMRI mask.
    modality: string
        Which modality was used for the base encoding model: 'vision' or
        'language'.

    Returns
    -------
    None
        Saves the figure to
        ``results/movie_to_story/<subject>/layer<layer>_visual.png`` for
        'vision' or ``results/story_to_movie/<subject>/layer<layer>_visual.png``
        for 'language' (creating the directory if needed), then shows it.

    Raises
    ------
    ValueError
        If ``modality`` is neither 'vision' nor 'language'.
    """
    import os

    # Per-modality output directory and title (crossmodal direction).
    outputs = {
        'vision': ('movie_to_story',
                   r"$r_{\mathit{movie \rightarrow story}}$"),
        'language': ('story_to_movie',
                     r"$r_{\mathit{story \rightarrow movie}}$"),
    }
    # Validate before any expensive loading; previously an unknown modality
    # silently produced an unsaved figure.
    if modality not in outputs:
        raise ValueError(
            f"modality must be 'vision' or 'language', got {modality!r}")
    result_dir, latex = outputs[modality]

    flattened_correlations = _unmask_correlations(subject, correlations)

    vmin, vmax = -0.1, 0.1
    fig, axs = plt.subplots(1, 2, figsize=(7, 4))

    scatters = []
    for ax, hemi in zip(axs, ('lh', 'rh')):
        # Mapper matrix projects the flattened volume onto this hemisphere's
        # vertices; vertex_coords gives their 2D flatmap positions.
        mapping_matrix = load_npz("data/mappers/" + subject +
                                  "_listening_forVL_" + hemi + ".npz")
        vertex_data = mapping_matrix.dot(flattened_correlations)
        vertex_coords = np.load("data/mappers/" + subject +
                                "_vertex_coords_" + hemi + ".npy")

        scatters.append(ax.scatter(vertex_coords[:, 0], vertex_coords[:, 1],
                                   c=vertex_data, cmap='RdBu_r',
                                   vmin=vmin, vmax=vmax, s=.005))
        ax.set_aspect('equal', adjustable='box')  # Ensure equal scaling
        ax.set_frame_on(False)
        ax.set_xticks([])
        ax.set_yticks([])

    # Make room at the top for a single horizontal colorbar
    plt.subplots_adjust(top=0.85, wspace=0)
    cbar_ax = fig.add_axes([0.25, 0.9, 0.5, 0.03])
    # Both scatters share cmap/vmin/vmax, so either can drive the colorbar
    cbar = fig.colorbar(scatters[0], cax=cbar_ax, orientation='horizontal')

    # Only label the colorbar's min and max, without its outline box
    cbar.set_ticks([vmin, vmax])
    cbar.set_ticklabels([f'{vmin}', f'{vmax}'])
    cbar.outline.set_visible(False)

    # plt.title targets the current axes, which after fig.add_axes is the
    # colorbar axes, so the title sits above the colorbar.
    plt.title(f'{subject}\n{latex}')

    out_dir = os.path.join('results', result_dir, subject)
    os.makedirs(out_dir, exist_ok=True)
    # NOTE(review): both modalities use the '_visual.png' suffix (kept from
    # the original); the directory distinguishes them. Confirm intended.
    plt.savefig(os.path.join(out_dir, f'layer{layer}_visual.png'),
                format='png')
    plt.show()