## Part 6: Full Protocol
#### With GPU-accelerated Ridge Regression Using the Himalaya Library

This notebook tutorial walks through the full crossmodal fMRI prediction process using the BridgeTower model. We will extract features from natural stimuli using BridgeTower layers, build voxelwise encoding models that predict fMRI data from stimulus features, and finally predict language fMRI data using the vision encoding model and visual fMRI data using the language encoding model.

In [1]:
# Select parameters for this run of the protocol
subject = 'S1'  # subject identifier, one of S1-S5
modality = 'vision'  # modality of the base encoding model: vision or language
layer = 8  # key (1-13) of the BridgeTower layer hooked by setup_model()

## 1 Feature Extraction
We'll begin by putting our natural stimuli through the BridgeTower model and extracting feature representations from the layer specified above.

### 1.1 Load Stimuli

#### 1.1.1 Movie Stimuli
Our movie data are stored in HDF format, so we need a helper function to load them.

In [2]:
def load_hdf5_array(file_name, key=None, slice=slice(0, None)):
    """Read one dataset, or every dataset, from an hdf5 file.

    Parameters
    ----------
    file_name: string
        Path of the hdf5 file to open (read-only).
    key: string
        Name of the dataset to read. When omitted, every top-level key
        of the file is read.
    slice: slice, or tuple of slices
        Selection applied to each dataset as `dataset[slice]`. Pass a
        tuple of slices to select along several dimensions. The default
        selects everything.

    Returns
    -------
    result : array or dictionary
        The selected array when `key` is given, otherwise a dictionary
        mapping each key of the file to its selected array.
    """
    selection = slice
    with h5py.File(file_name, mode='r') as handle:
        # Single-dataset request: read it and leave.
        if key is not None:
            return handle[key][selection]
        # No key given: read every dataset while the file is still open.
        return {name: handle[name][selection] for name in handle.keys()}

In [None]:
# Load the test clips and each of the twelve training runs
# (train_00 ... train_11) exactly once.
test = load_hdf5_array('data/raw_stimuli/shortclips/stimuli/test.hdf', key='stimuli')
train_00 = load_hdf5_array('data/raw_stimuli/shortclips/stimuli/train_00.hdf', key='stimuli')
train_01 = load_hdf5_array('data/raw_stimuli/shortclips/stimuli/train_01.hdf', key='stimuli')
train_02 = load_hdf5_array('data/raw_stimuli/shortclips/stimuli/train_02.hdf', key='stimuli')
train_03 = load_hdf5_array('data/raw_stimuli/shortclips/stimuli/train_03.hdf', key='stimuli')
train_04 = load_hdf5_array('data/raw_stimuli/shortclips/stimuli/train_04.hdf', key='stimuli')
train_05 = load_hdf5_array('data/raw_stimuli/shortclips/stimuli/train_05.hdf', key='stimuli')
train_06 = load_hdf5_array('data/raw_stimuli/shortclips/stimuli/train_06.hdf', key='stimuli')
train_07 = load_hdf5_array('data/raw_stimuli/shortclips/stimuli/train_07.hdf', key='stimuli')
train_08 = load_hdf5_array('data/raw_stimuli/shortclips/stimuli/train_08.hdf', key='stimuli')
train_09 = load_hdf5_array('data/raw_stimuli/shortclips/stimuli/train_09.hdf', key='stimuli')
train_10 = load_hdf5_array('data/raw_stimuli/shortclips/stimuli/train_10.hdf', key='stimuli')
train_11 = load_hdf5_array('data/raw_stimuli/shortclips/stimuli/train_11.hdf', key='stimuli')

#### 1.1.2 Story Stimuli
Our story transcripts are in TextGrid format, so we'll create a helper function that loads each transcript into an array of words.

In [3]:
def textgrid_to_array(textgrid):
    """Function to load the word transcript of a TextGrid into an array.

    Three file layouts are handled, selected by the file name:

    * ``legacy.TextGrid``: from line 5 on, every line starting with
      ``2`` is followed by a line holding the quoted word.
    * ``life.TextGrid``: the first line containing ``word`` marks the
      word tier; the first quoted word is 6 lines below it and the
      following words come every 3 lines.
    * any other file: the ``item [2]`` section is the word tier; each
      line containing ``intervals`` is followed, 3 lines later, by the
      line holding the quoted word. The 5 lines after ``item [2]`` are
      printed as a summary.

    Parameters
    ----------
    textgrid: string
        TextGrid file name.

    Returns
    -------
    full_transcript : Array
        Array with each word in the story.

    Raises
    ------
    ValueError
        If the word tier marker cannot be found in the file.
    """
    quoted = re.compile(r'"([^"]*)"')

    def quoted_text(line):
        # Content of the first double-quoted string on the line.
        return quoted.search(line.strip()).group(1)

    with open(textgrid, 'r') as file:
        data = file.readlines()

    if textgrid == 'data/raw_stimuli/textgrids/stimuli/legacy.TextGrid':
        # Important info starts at line 5. Positions come from enumerate
        # rather than data.index(line), which returns the first line with
        # the same text and so picks the wrong word on duplicated lines.
        full_transcript = [
            quoted_text(data[position + 1])
            for position, line in enumerate(data[5:], start=5)
            if line.startswith('2')
        ]
    elif textgrid == 'data/raw_stimuli/textgrids/stimuli/life.TextGrid':
        # Use the FIRST line containing "word" (the tier header). Later
        # transcript words such as "words" must not move the start point.
        tier_start = next(
            (position for position, line in enumerate(data)
             if "word" in line),
            None)
        if tier_start is None:
            raise ValueError(f'no "word" tier found in {textgrid}')
        # First word is 6 lines below the header, then one every 3 lines.
        full_transcript = [quoted_text(line)
                           for line in data[tier_start + 6::3]]
    else:
        # Important info starts at line 8. We only want item [2] info
        # because those are the words instead of phonemes.
        tier_start = next(
            (position for position, line in enumerate(data[8:], start=8)
             if "item [2]" in line),
            None)
        if tier_start is None:
            raise ValueError(f'no "item [2]" tier found in {textgrid}')

        summary_info = [line.strip()
                        for line in data[tier_start + 1:tier_start + 6]]
        print(summary_info)

        word_script = data[tier_start + 6:]
        full_transcript = [
            quoted_text(word_script[position + 3])
            for position, line in enumerate(word_script)
            if "intervals" in line
        ]

    return np.array(full_transcript)

In [None]:
# Parse the word transcript of each story from its TextGrid file.
_textgrid_dir = "data/raw_stimuli/textgrids/stimuli"
alternateithicatom = textgrid_to_array(
    f"{_textgrid_dir}/alternateithicatom.TextGrid")
avatar = textgrid_to_array(f"{_textgrid_dir}/avatar.TextGrid")
howtodraw = textgrid_to_array(f"{_textgrid_dir}/howtodraw.TextGrid")
legacy = textgrid_to_array(f"{_textgrid_dir}/legacy.TextGrid")
life = textgrid_to_array(f"{_textgrid_dir}/life.TextGrid")
myfirstdaywiththeyankees = textgrid_to_array(
    f"{_textgrid_dir}/myfirstdaywiththeyankees.TextGrid")
naked = textgrid_to_array(f"{_textgrid_dir}/naked.TextGrid")
odetostepfather = textgrid_to_array(
    f"{_textgrid_dir}/odetostepfather.TextGrid")
souls = textgrid_to_array(f"{_textgrid_dir}/souls.TextGrid")
undertheinfluence = textgrid_to_array(
    f"{_textgrid_dir}/undertheinfluence.TextGrid")

### 1.2 Run Stimuli through Feature Extraction

#### 1.2.1 Helper Functions
We need three functions: one to set up the BridgeTower model, one to run the movie stimuli, and one to run the story stimuli.

In [4]:
def setup_model(layer):
    """Load BridgeTower and attach a forward hook to one of its layers.

    Parameters
    ----------
    layer: int
        Key (1-13) of the BridgeTower module to hook; see the ordered
        tuple of modules in the body for the mapping.

    Returns
    -------
    device : torch device, cuda when available and cpu otherwise.
    model: BridgeTower model, moved to `device`.
    processor: BridgeTower processor.
    features: Dictionary
        Filled by the hook on every forward pass: the key is
        ``layer_<layer>`` and the value is the detached last element
        of the hooked module's output.
    layer_selected: Handle of the registered forward hook; call
        ``.remove()`` on it to detach the hook.
    """
    # Load the pretrained model onto the GPU when torch can see one.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = BridgeTowerModel.from_pretrained(
        "BridgeTower/bridgetower-base").to(device)

    # Hookable modules in key order: position k in this tuple is layer k.
    hookable_modules = (
        model.cross_modal_text_transform,
        model.cross_modal_image_transform,
        model.token_type_embeddings,
        model.vision_model.visual.ln_post,
        model.text_model.encoder.layer[-1].output.LayerNorm,
        model.cross_modal_image_layers[-1].output,
        model.cross_modal_text_layers[-1].output,
        model.cross_modal_image_pooler,
        model.cross_modal_text_pooler,
        model.cross_modal_text_layernorm,
        model.cross_modal_image_layernorm,
        model.cross_modal_text_link_tower[-1],
        model.cross_modal_image_link_tower[-1],
    )
    model_layers = dict(enumerate(hookable_modules, start=1))

    # Written by the hook below, read by the caller after each forward.
    features = {}
    feature_key = f'layer_{layer}'

    def store_last_output(module, inputs, output):
        # Keep only the last element of the output, cut from the graph.
        features[feature_key] = output[-1].detach()

    hooked_module = model_layers[layer]
    layer_selected = hooked_module.register_forward_hook(store_last_output)

    processor = BridgeTowerProcessor.from_pretrained(
        "BridgeTower/bridgetower-base")

    return device, model, processor, features, layer_selected

In [5]:
def _pad_to_common_shape(tensors):
    """Zero-pad same-rank tensors at the end of every dimension so that
    they all share the element-wise maximum shape."""
    rank = tensors[0].dim()
    target = [max(tensor.size(dim) for tensor in tensors)
              for dim in range(rank)]

    padded = []
    for tensor in tensors:
        # torch.nn.functional.pad expects (before, after) pairs ordered
        # from the LAST dimension to the first.
        pad_spec = []
        for dim in reversed(range(rank)):
            pad_spec.extend((0, target[dim] - tensor.size(dim)))
        padded.append(torch.nn.functional.pad(tensor, tuple(pad_spec)))
    return padded


def get_movie_features(movie_data, layer, n=30):
    """Function to average feature vectors over every n inputs.

    Parameters
    ----------
    movie_data: Array, or string
        Either the movie frames themselves (iterated frame by frame
        along the first axis) or the name of an hdf file whose
        'stimuli' dataset holds the frames.
    layer: int
        A layer reference for the BridgeTower model, passed to
        setup_model() to place the forward hook.
    n (optional): int
        Number of frames to average over. Set at 30 to mimick an MRI
        TR = 2 with a 15 fps movie.

    Returns
    -------
    data : Array
        Numpy array with one row per complete block of n frames, each
        row being the mean of the hooked layer's activations over that
        block. Frames left over after the last complete block are
        dropped.

    Raises
    ------
    ValueError
        If `movie_data` holds fewer than n frames.
    """
    import os

    # Accept both a file name and an already-loaded array of frames.
    if isinstance(movie_data, (str, os.PathLike)):
        print("loading HDF array")
        movie_data = load_hdf5_array(movie_data, key='stimuli')

    print("Running movie through model")

    # Define Model
    device, model, processor, features, layer_selected = setup_model(layer)
    layer_key = f'layer_{layer}'

    block_means = []  # one averaged tensor per block of n frames
    block = []        # activations of the frames in the current block

    try:
        # Inference only: no autograd bookkeeping is needed.
        with torch.no_grad():
            for i, image in enumerate(movie_data):
                model_input = processor(image, "", return_tensors="pt")
                model_input = {key: value.to(device)
                               for key, value in model_input.items()}

                _ = model(**model_input)
                block.append(features[layer_key])

                # check if average should be stored
                if (i + 1) % n == 0:
                    first_size = block[0].size()
                    if any(tensor.size() != first_size for tensor in block):
                        # Shapes differ inside the block: zero-pad to a
                        # common shape so they can be stacked.
                        block = _pad_to_common_shape(block)
                    block_means.append(
                        torch.mean(torch.stack(block), dim=0))
                    block = []
    finally:
        # Always detach the hook, even if the forward pass failed.
        layer_selected.remove()

    if not block_means:
        raise ValueError(
            f"movie_data must contain at least n={n} frames")

    # Stack the per-block tensors before moving them to host memory; the
    # collected features are a Python list, which has no .cpu() method.
    return torch.stack(block_means).cpu().numpy()

In [6]:
def get_story_features(story_data, layer, n=20):
    """Function to extract feature vectors for each word of a story.

    Parameters
    ----------
    story_data: Array, or string
        Either an array containing each word of the story in order, or
        the name of a TextGrid file to parse with textgrid_to_array().
    layer: int
        A layer reference for the BridgeTower model, passed to
        setup_model() to place the forward hook.
    n (optional): int
        Number of words of context around the target word.

    Returns
    -------
    data : Array
        Numpy array with one row per word of the story, each row being
        the hooked layer's activations for that word in its context.
        The per-word activations must all have the same shape to be
        stacked.

    Raises
    ------
    ValueError
        If the story contains no words.
    """
    import os

    # Accept both a TextGrid file name and an already-parsed word array.
    if isinstance(story_data, (str, os.PathLike)):
        print("loading textgrid")
        story_data = textgrid_to_array(os.fspath(story_data))

    print("Running story through model")
    # Define Model
    device, model, processor, features, layer_selected = setup_model(layer)
    layer_key = f'layer_{layer}'

    # Uniform gray image (value 128) used as the neutral image input
    # that accompanies every text window.
    gray_value = 128
    image_array = np.full((512, 512, 3), gray_value, dtype=np.uint8)

    word_features = []  # one activation tensor per word

    try:
        # Inference only: no autograd bookkeeping is needed.
        with torch.no_grad():
            for i in range(len(story_data)):
                # Window of words [i - n, i + n), clipped at both ends of
                # the story (slicing clips the upper bound by itself).
                # NOTE(review): the window holds n words before the
                # target but only n - 1 after it; kept as in the
                # original, confirm whether i + n + 1 was intended.
                window = story_data[max(0, i - n):i + n]
                # collapse list of strings into a single one
                word_with_context = ' '.join(window)

                model_input = processor(image_array, word_with_context,
                                        return_tensors="pt")
                model_input = {key: value.to(device)
                               for key, value in model_input.items()}

                _ = model(**model_input)
                word_features.append(features[layer_key])
    finally:
        # Always detach the hook, even if the forward pass failed.
        layer_selected.remove()

    if not word_features:
        raise ValueError("story_data contains no words")

    # Stack the per-word tensors before moving them to host memory; the
    # collected features are a Python list, which has no .cpu() method.
    return torch.stack(word_features).cpu().numpy()

#### 1.2.2 Run Stimuli

In [None]:
# Extract features for the test clips and for each of the twelve training
# runs. File names are passed because get_movie_features() loads the
# 'stimuli' dataset itself when it is given an hdf file name.
movie_dir = 'data/raw_stimuli/shortclips/stimuli'
test_features = get_movie_features(f'{movie_dir}/test.hdf', layer)
train00_features = get_movie_features(f'{movie_dir}/train_00.hdf', layer)
train01_features = get_movie_features(f'{movie_dir}/train_01.hdf', layer)
train02_features = get_movie_features(f'{movie_dir}/train_02.hdf', layer)
train03_features = get_movie_features(f'{movie_dir}/train_03.hdf', layer)
train04_features = get_movie_features(f'{movie_dir}/train_04.hdf', layer)
train05_features = get_movie_features(f'{movie_dir}/train_05.hdf', layer)
train06_features = get_movie_features(f'{movie_dir}/train_06.hdf', layer)
train07_features = get_movie_features(f'{movie_dir}/train_07.hdf', layer)
train08_features = get_movie_features(f'{movie_dir}/train_08.hdf', layer)
train09_features = get_movie_features(f'{movie_dir}/train_09.hdf', layer)
train10_features = get_movie_features(f'{movie_dir}/train_10.hdf', layer)
train11_features = get_movie_features(f'{movie_dir}/train_11.hdf', layer)

In [None]:
# Extract features for each story exactly once. File names are passed
# because get_story_features() parses the transcript itself when it is
# given a TextGrid file name.
story_dir = 'data/raw_stimuli/textgrids/stimuli'
ai_features = get_story_features(
    f'{story_dir}/alternateithicatom.TextGrid', layer)
avatar_features = get_story_features(f'{story_dir}/avatar.TextGrid', layer)
howtodraw_features = get_story_features(
    f'{story_dir}/howtodraw.TextGrid', layer)
legacy_features = get_story_features(f'{story_dir}/legacy.TextGrid', layer)
life_features = get_story_features(f'{story_dir}/life.TextGrid', layer)
yankees_features = get_story_features(
    f'{story_dir}/myfirstdaywiththeyankees.TextGrid', layer)
naked_features = get_story_features(f'{story_dir}/naked.TextGrid', layer)
ode_features = get_story_features(
    f'{story_dir}/odetostepfather.TextGrid', layer)
souls_features = get_story_features(f'{story_dir}/souls.TextGrid', layer)
under_features = get_story_features(
    f'{story_dir}/undertheinfluence.TextGrid', layer)

## 2 Voxelwise Encoding Models

In [None]:
def create_flatmap(subject, layer, correlations, modality):
    """Function to plot voxelwise correlations on cortical flatmaps.

    The correlations are placed back at their voxel positions in a
    (31, 100, 100) volume, using the non-NaN voxels of the subject's
    "alternateithicatom" story data as the mask. The volume is then
    projected to left- and right-hemisphere vertices with the subject's
    mapper matrices and drawn as two scatter plots sharing one colorbar.

    Parameters
    ----------
    subject: string
        A reference to the subject for analysis. Used to load fmri data
        and mappers, and to name the output directory.
    layer: int
        A layer reference for the BridgeTower model. Only used here to
        name the saved figure.
    correlations: array
        Generated by story_prediction() or movie_prediction() function.
        Contains the correlation between predicted and real brain activity
        for each voxel, in the order of the mask's non-NaN voxels.
    modality: string
        Which modality was used for the base encoding model: vision or
        language. Selects the figure title and output directory; for
        any other value the figure is shown but neither titled nor saved.

    Returns
    -------
    Flatmaps:
        Saves flatmap visualizations as pngs
    """
    import os

    volume_shape = (31, 100, 100)

    # Reverse flattening and masking
    fmri_alternateithicatom = np.load(
        f"data/storydata/{subject}/alternateithicatom.npy")
    mask = ~np.isnan(fmri_alternateithicatom[0])  # reference for the mask

    # Flat positions of the non-NaN voxels, in flattened-volume order
    valid_indices = np.flatnonzero(mask)

    # Flattened volume of NaNs, filled with one correlation per valid
    # voxel. Extra entries on either side are ignored, as when pairing
    # the two sequences with zip.
    values = np.asarray(correlations, dtype=float).ravel()
    n_assigned = min(valid_indices.size, values.size)
    flattened_correlations = np.full(int(np.prod(volume_shape)), np.nan)
    flattened_correlations[valid_indices[:n_assigned]] = values[:n_assigned]

    # Load mappers and project the volume onto each hemisphere's vertices
    hemispheres = ('lh', 'rh')
    vertex_correlation_data = {}
    vertex_coords = {}
    for hemi in hemispheres:
        mapping_matrix = load_npz(
            f"data/mappers/{subject}_listening_forVL_{hemi}.npz")
        vertex_correlation_data[hemi] = mapping_matrix.dot(
            flattened_correlations)
        vertex_coords[hemi] = np.load(
            f"data/mappers/{subject}_vertex_coords_{hemi}.npy")

    vmin, vmax = -0.1, 0.1
    fig, axs = plt.subplots(1, 2, figsize=(7, 4))

    # Plot one flatmap per hemisphere (left first, then right)
    scatters = []
    for ax, hemi in zip(axs, hemispheres):
        coords = vertex_coords[hemi]
        scatters.append(
            ax.scatter(coords[:, 0], coords[:, 1],
                       c=vertex_correlation_data[hemi], cmap='RdBu_r',
                       vmin=vmin, vmax=vmax, s=.005))
        ax.set_aspect('equal', adjustable='box')  # Ensure equal scaling
        ax.set_frame_on(False)
        ax.set_xticks([])  # Remove x-axis ticks
        ax.set_yticks([])  # Remove y-axis ticks

    # Adjust layout to make space for the top colorbar
    plt.subplots_adjust(top=0.85, wspace=0)

    # Add a single horizontal colorbar at the top; both scatters share
    # the same color limits, so the left one is used as the mappable
    cbar_ax = fig.add_axes([0.25, 0.9, 0.5, 0.03])
    cbar = fig.colorbar(scatters[0], cax=cbar_ax, orientation='horizontal')

    # Set the color bar to only display min and max values
    cbar.set_ticks([vmin, vmax])
    cbar.set_ticklabels([f'{vmin}', f'{vmax}'])

    # Remove the color bar box
    cbar.outline.set_visible(False)

    # Title and output directory per modality
    outputs = {
        'vision': (r"$r_{\mathit{movie \rightarrow story}}$",
                   'results/movie_to_story'),
        'language': (r"$r_{\mathit{story \rightarrow movie}}$",
                     'results/story_to_movie'),
    }
    if modality in outputs:
        latex, results_root = outputs[modality]
        plt.title(f'{subject}\n{latex}')

        # Make sure the target directory exists before saving.
        output_dir = f'{results_root}/{subject}'
        os.makedirs(output_dir, exist_ok=True)
        # f-string formatting: `layer` is an int and cannot be joined to
        # the path with `+`.
        plt.savefig(f'{output_dir}/layer{layer}_visual.png', format='png')
    plt.show()