# Overview

The goal of this project is to modify the output of a specific layer in a BERT model by integrating additional information through various methods. The solution should be flexible, allowing changes to the layer being modified and the method of integration.

The project provides two primary approaches for integrating additional information into the BERT model:

1.   **Hook-Based Approach**: Uses PyTorch forward hooks to modify the output of a specific layer during the forward pass.
2.   **Custom Layer-Based Approach**: Defines a custom layer that integrates additional information.

Both approaches enable flexible manipulation of BERT's internal representations for various layers and integration methods.

Both the hook-based method and the custom layer-based method resulted in the **same outcome**. When tested, both methods produced identical modified hidden states for the specific layer in the BERT model. In terms of runtime, the **hook-based approach was slightly faster**.

# Dependencies

In [1]:
import torch
from torch import nn
from transformers import BertModel, BertTokenizer
import numpy as np
import random
import time

# Common Setup Functions


In [2]:
def set_seed(seed: int = 42) -> None:
    """
    Seed every random number generator used in this file for reproducibility.

    Parameters:
    seed (int): Value applied to Python's `random`, NumPy, and PyTorch
        (CPU, plus all CUDA devices when CUDA is available). Defaults to 42,
        so calling `set_seed()` with no argument behaves as before.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Ask cuDNN for deterministic kernels and disable its auto-tuner, which
    # could otherwise select different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def setup_model_and_tokenizer(model_name='bert-base-uncased'):
    """
    Load a pre-trained BERT model together with its tokenizer.

    Parameters:
    model_name (str): Name passed to `from_pretrained` for both the model
        and the tokenizer.

    Returns:
    tuple: `(model, tokenizer)`, in that order.
    """
    # The model is loaded first, then the tokenizer, both from the same name.
    model, tokenizer = (
        loader.from_pretrained(model_name) for loader in (BertModel, BertTokenizer)
    )
    return model, tokenizer

def prepare_inputs(text, tokenizer):
    """
    Tokenize the input text and return the result as PyTorch tensors.

    Parameters:
    text (str): The text to tokenize.
    tokenizer: Callable tokenizer; it is invoked with `return_tensors='pt'`.

    Returns:
    Whatever the tokenizer returns for that call.
    """
    encoded = tokenizer(text, return_tensors='pt')
    return encoded

def create_additional_input_vector(hidden_size):
    """
    Create a random tensor of shape (1, 1, hidden_size).

    Parameters:
    hidden_size (int): Size of the last dimension of the returned tensor.

    Returns:
    torch.Tensor: Samples drawn with `torch.randn`.
    """
    # Re-seed right before sampling so the drawn values are reproducible.
    # Note that this also resets the global RNG state as a side effect.
    set_seed()
    vector_shape = (1, 1, hidden_size)
    return torch.randn(vector_shape)

# Common Integration Method Applier


In [3]:
class IntegrationMethodApplier(nn.Module):
    """Combine a tensor with an additional vector using a named operation."""

    def __init__(self, integration_method):
        """
        Store the name of the integration method to apply later.

        Parameters:
        integration_method (str): Either 'addition' or 'multiplication'.
            The value is not checked here; an unsupported name is only
            rejected when `forward` is called.
        """
        super().__init__()
        self.integration_method = integration_method

    def forward(self, input_tensor, additional_input_vector):
        """
        Integrate the additional vector into the input tensor.

        Parameters:
        input_tensor (torch.Tensor): The original tensor.
        additional_input_vector (torch.Tensor): The vector to combine with it.

        Returns:
        torch.Tensor: The element-wise sum or product of the two operands,
        depending on the configured method.

        Raises:
        ValueError: If the configured method is neither 'addition' nor
        'multiplication'.
        """
        method = self.integration_method
        if method == "addition":
            return input_tensor + additional_input_vector
        if method == "multiplication":
            return input_tensor * additional_input_vector
        raise ValueError("Unsupported integration method")

# Hook-Based Approach

In [4]:
class HookBasedBERTModifier:
    """
    Modify the output of one encoder layer through a PyTorch forward hook.

    The hook is attached to `model.encoder.layer[layer_number - 1]`, so
    `layer_number` is 1-based.
    """

    def __init__(self, model, layer_number, integration_method_applier):
        """
        Initialize the HookBasedBERTModifier.

        Parameters:
        model (BertModel): The BERT model to modify.
        layer_number (int): The 1-based layer number to apply the modification.
        integration_method_applier (IntegrationMethodApplier): The method to integrate the additional input.
        """
        self.model = model
        self.layer_number = layer_number
        self.integration_method_applier = integration_method_applier
        # Set by register_hook(); defined here so the attribute always exists.
        self.additional_input_vector = None
        self.hook = None

    def modify_output(self, module, input, output):
        """
        Forward-hook callback that replaces the hooked layer's hidden states.

        Parameters:
        module (nn.Module): The module to which the hook is attached.
        input (tuple): The input to the module (unused).
        output (tuple or torch.Tensor): The output from the module.

        Returns:
        The modified output, in the same form as the original: a bare tensor
        stays a bare tensor; otherwise a tuple whose first element is the
        modified hidden states and whose remaining elements are unchanged.
        """
        if isinstance(output, torch.Tensor):
            # Indexing a bare tensor with [0] would slice the first dimension
            # instead of selecting the hidden states, so handle it directly.
            return self.integration_method_applier(output, self.additional_input_vector)

        modified_output = self.integration_method_applier(output[0], self.additional_input_vector)
        # Keep any extra elements (e.g. attention weights) so downstream code
        # that reads them still works.
        return (modified_output,) + tuple(output[1:])

    def register_hook(self, additional_input_vector):
        """
        Register a hook on the specified layer to modify its output.

        Any hook previously registered through this object is removed first,
        so repeated calls do not apply the modification more than once.

        Parameters:
        additional_input_vector (torch.Tensor): The vector to integrate with the layer output.

        Raises:
        IndexError: If `layer_number` is not between 1 and the number of
        encoder layers.
        """
        layers = self.model.encoder.layer
        # Without this check, layer_number=0 would index -1 and silently hook
        # the last layer.
        if not 1 <= self.layer_number <= len(layers):
            raise IndexError(
                f"layer_number must be between 1 and {len(layers)}, got {self.layer_number}"
            )

        self.remove_hook()
        self.additional_input_vector = additional_input_vector
        self.hook = layers[self.layer_number - 1].register_forward_hook(self.modify_output)

    def remove_hook(self):
        """
        Remove the registered hook if it exists.
        """
        if self.hook is not None:
            self.hook.remove()
            self.hook = None


# Custom Layer-Based Approach


In [8]:
class CustomLayerBERTModifier(nn.Module):
    """
    Modify one layer's hidden states and re-run the remaining encoder layers.

    The wrapped model is run once to obtain all hidden states; the state at
    index `layer_number` is modified and then passed through
    `encoder.layer[layer_number:]` again.
    """

    def __init__(self, model, layer_number, integration_method_applier):
        """
        Initialize the CustomLayerBERTModifier.

        Parameters:
        model (BertModel): The BERT model to modify.
        layer_number (int): Index into the model's `hidden_states` of the state to modify.
            Index 0 is the first entry of `hidden_states`; the layers from this index on are re-run.
        integration_method_applier (IntegrationMethodApplier): The method to integrate the additional input.
        """
        super().__init__()
        self.bert = model
        self.layer_number = layer_number
        self.integration_method_applier = integration_method_applier

    def forward(self, input_ids, attention_mask, additional_input_vector):
        """
        Forward pass to apply the custom layer modification.

        Parameters:
        input_ids (torch.Tensor): The input IDs for the BERT model.
        attention_mask (torch.Tensor): The attention mask for the BERT model (may be None).
        additional_input_vector (torch.Tensor): The vector to integrate with the layer output.

        Returns:
        torch.Tensor: The hidden states after the last encoder layer, computed from the modified state.

        Raises:
        IndexError: If `layer_number` is not between 0 and the number of encoder layers.
        """
        encoder_layers = self.bert.encoder.layer
        # A negative index would pick a hidden state from the end and then
        # re-run an unrelated set of layers, so reject it up front.
        if not 0 <= self.layer_number <= len(encoder_layers):
            raise IndexError(
                f"layer_number must be between 0 and {len(encoder_layers)}, got {self.layer_number}"
            )

        # Get the BERT model outputs with hidden states
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        hidden_states = outputs.hidden_states

        # Apply the integration method to the specified layer's hidden states
        hidden = self.integration_method_applier(hidden_states[self.layer_number], additional_input_vector)

        # The re-run layers must see the same mask the full forward pass used;
        # otherwise padded positions would be attended to.
        extended_mask = None
        if attention_mask is not None:
            extended_mask = self.bert.get_extended_attention_mask(attention_mask, input_ids.shape)

        # Pass the modified state through the remaining layers
        for layer in encoder_layers[self.layer_number:]:
            layer_output = layer(hidden, attention_mask=extended_mask)
            # Accept both tuple-style and bare-tensor layer outputs; indexing
            # a bare tensor with [0] would slice the first dimension instead.
            hidden = layer_output if isinstance(layer_output, torch.Tensor) else layer_output[0]

        return hidden


# Compare Methods

In [10]:
def compare_methods(input_text, layer_number, integration_method, atol=1e-6):
    """
    Compare hook-based and custom layer-based methods for modifying BERT outputs.

    Both methods are run on the same model, inputs and additional vector, and
    the wall-clock runtime of each forward pass is printed. Each runtime is a
    single measurement without warm-up, so treat the numbers as indicative only.

    Parameters:
    input_text (str): The input text to process with the BERT model.
    layer_number (int): The layer number to apply the modification.
    integration_method (str): The method used for integrating the additional input ('addition' or 'multiplication').
    atol (float): Absolute tolerance used when comparing the two outputs. Defaults to 1e-6.

    Returns:
    bool: Whether the outputs of both methods are the same within the tolerance.
    """
    # Set the seed for reproducibility
    set_seed()

    # Common setup
    model, tokenizer = setup_model_and_tokenizer()
    additional_input_vector = create_additional_input_vector(model.config.hidden_size)
    integration_method_applier = IntegrationMethodApplier(integration_method)
    inputs = prepare_inputs(input_text, tokenizer)

    # Ensure the model is in evaluation mode
    model.eval()

    # Hook-based modifier
    hook_modifier = HookBasedBERTModifier(model, layer_number, integration_method_applier)
    hook_modifier.register_hook(additional_input_vector)
    try:
        # perf_counter is meant for measuring short intervals; only the
        # forward pass is timed, not the hook bookkeeping.
        start_time = time.perf_counter()
        with torch.no_grad():
            outputs_with_hook = model(**inputs)
        hook_runtime = time.perf_counter() - start_time
    finally:
        # Always detach the hook, even if the forward pass fails, so the
        # custom-layer run below sees an unmodified model.
        hook_modifier.remove_hook()

    last_hidden_state_with_hook = outputs_with_hook.last_hidden_state

    # Custom layer-based modifier
    custom_model = CustomLayerBERTModifier(model, layer_number, integration_method_applier)
    # Ensure the custom model is in evaluation mode
    custom_model.eval()

    start_time = time.perf_counter()
    with torch.no_grad():
        outputs_custom = custom_model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            additional_input_vector=additional_input_vector,
        )
    custom_runtime = time.perf_counter() - start_time

    last_hidden_state_custom = outputs_custom

    # Compare the outputs
    are_same = torch.allclose(last_hidden_state_with_hook, last_hidden_state_custom, atol=atol)

    print(f"Hook-based method runtime: {hook_runtime:.6f} seconds")
    print(f"Custom layer-based method runtime: {custom_runtime:.6f} seconds")
    return are_same


# Run the comparison on a sample sentence, modifying layer 9 by addition
compare_methods(
    input_text="Hello, how are you?",
    layer_number=9,
    integration_method="addition",
)


Hook-based method runtime: 0.084336 seconds
Custom layer-based method runtime: 0.109123 seconds


True