In [None]:
!pip install -qU transformers peft

# How to Hack Any Transformers Model

Customizing models can unlock new possibilities. We will modify models directly in Transformers and still take advantage of features like the `Trainer` API, `PreTrainedModel`, and efficient fine-tuning with tools like `peft`.

## Efficient Development Workflow

When modifying model code, we often need to test our changes without restarting our Python session. The `clear_import_cache()` utility makes this possible during model development.

In [2]:
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-cased")

# we can make any changes...

# clear the cache to reload the modified code
from transformers.utils.import_utils import clear_import_cache
clear_import_cache()

# reimport to get the changes
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-cased")


This function is useful when:
* iteratively modifying model architectures
* debugging model implementations
* testing changes during model development
* comparing outputs between original and modified versions
* working on model contributions


The `clear_import_cache()` function removes all cached Transformers modules and allows Python to reload the modified code. This enables rapid development cycles without constantly restarting our environment.
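
The general idea behind clearing an import cache can be sketched with the standard library alone. The snippet below is only an illustration of the mechanism, not the Transformers implementation: `clear_modules` and the throwaway module `toy_model_code` are hypothetical names introduced for this example.

```python
import importlib
import sys
import tempfile
from pathlib import Path


def clear_modules(prefix: str) -> list:
    """Drop every cached module named `prefix` or living under `prefix.`."""
    stale = [name for name in sys.modules if name == prefix or name.startswith(prefix + ".")]
    for name in stale:
        del sys.modules[name]
    importlib.invalidate_caches()  # also reset the finders' directory caches
    return stale


# Avoid stale bytecode files so the edited source is always re-read.
sys.dont_write_bytecode = True

# Create a throwaway module on disk (hypothetical stand-in for model code).
workdir = Path(tempfile.mkdtemp())
sys.path.insert(0, str(workdir))
module_file = workdir / "toy_model_code.py"

module_file.write_text("HIDDEN_SIZE = 768\n")
import toy_model_code
first = toy_model_code.HIDDEN_SIZE

# Edit the source. A plain re-import is served from sys.modules and sees no change.
module_file.write_text("HIDDEN_SIZE = 1024\n")
import toy_model_code
cached = toy_model_code.HIDDEN_SIZE

# After dropping the cached entry, the import statement executes the file again.
removed = clear_modules("toy_model_code")
import toy_model_code
reloaded = toy_model_code.HIDDEN_SIZE

print(first, cached, reloaded)
```

Objects created before the cache is cleared still reference the old classes, which is why the notebook cell above calls `from_pretrained` again after re-importing.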

## Example: modifying the attention mechanism in the Segment Anything Model (SAM)

In its default implementation, SAM uses a combined query-key-value (`qkv`) projection in its attention mechanism. We can fine-tune specific components of the attention mechanism, such as the query (`q`) and value (`v`) projections, to reduce the number of trainable parameters and computational resources required.

### Motivation

By splitting the combined `qkv` projection into separate `q`, `k`, and `v` projections, we can apply techniques like LoRA to only the `q` and `v` projections, so that we can:
* fine-tune fewer parameters, reducing computational overhead
* potentially achieve better performance by focusing on specific components
* experiment with different adaptation strategies in the attention mechanism

### Implementation

#### Step 1: create a custom attention class

We need to subclass the original `SamVisionAttention` class and modify it to have separate `q`, `k`, and `v` projections.

In [3]:
import torch
import torch.nn as nn
from transformers.models.sam.modeling_sam import SamVisionAttention


class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
    def __init__(self, config, window_size):
        super().__init__(config, window_size)
        del self.qkv # delete the original qkv variable

        # separate q, k, v projections
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)

    def split_q_k_v_load_hook(self, state_dict, prefix, *args):
        keys_to_delete = []
        for key in list(state_dict.keys()):
            if 'qkv.' in key:
                # split q, k, v from the combined projection
                q, k, v = state_dict[key].chunk(3, dim=0)
                # replace with individual q, k, v projections
                state_dict[key.replace('qkv.', 'q.')] = q
                state_dict[key.replace('qkv.', 'k.')] = k
                state_dict[key.replace('qkv.', 'v.')] = v
                # mark the old qkv key for deletion
                keys_to_delete.append(key)

        # remove old qkv keys
        for key in keys_to_delete:
            del state_dict[key]

    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
        batch_size, height, width, _ = hidden_states.shape
        qkv_shapes = (batch_size * self.num_attention_heads, height * width, -1)

        query = self.q(hidden_states).reshape(
            (batch_size, height * width, self.num_attention_heads, -1)
        ).permute(0, 2, 1, 3).reshape(qkv_shapes)

        key = self.k(hidden_states).reshape(
            (batch_size, height * width, self.num_attention_heads, -1)
        ).permute(0, 2, 1, 3).reshape(qkv_shapes)

        value = self.v(hidden_states).reshape(
            (batch_size, height * width, self.num_attention_heads, -1)
        ).permute(0, 2, 1, 3).reshape(qkv_shapes)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)

        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights,
                query,
                self.rel_pos_h,
                self.rel_pos_w,
                (height, width),
                (height, width)
            )

        attn_weights = nn.functional.softmax(
            attn_weights,
            dtype=torch.float32,
            dim=-1
        ).to(query.dtype)

        attn_probs = nn.functional.dropout(
            attn_weights,
            p=self.dropout,
            training=self.training
        )

        attn_output = (attn_probs @ value).reshape(
            batch_size,
            self.num_attention_heads,
            height,
            width,
            -1
        )
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(
            batch_size,
            height,
            width,
            -1
        )
        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)

        return outputs

Explanation:
* **Separate projections**: the combined `qkv` projection is removed, and separate `q`, `k`, and `v` linear layers are created.
* **Weight loading hook**: the `split_q_k_v_load_hook` method splits the pretrained `qkv` weights and biases into separate `q`, `k`, and `v` tensors when loading the model. This keeps the class compatible with checkpoints that store the fused `qkv` projection.
* **Forward pass**: queries, keys, and values are computed separately, and the attention mechanism proceeds as usual.
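
To see why chunking the fused weights along the output dimension is safe, here is a minimal pure-Python sketch of the same renaming and splitting logic on a toy state dict. It assumes nothing about SAM itself: `chunk3`, `split_fused_entries`, `linear`, and the `attn.` key prefix are hypothetical names for illustration, and nested lists stand in for tensors.

```python
def chunk3(rows):
    """Split a list into three equal consecutive parts, like tensor.chunk(3, dim=0)."""
    n = len(rows) // 3
    return rows[:n], rows[n:2 * n], rows[2 * n:]


def split_fused_entries(state_dict):
    """Replace every 'qkv.' entry with separate 'q.', 'k.' and 'v.' entries, in place."""
    for key in list(state_dict):
        if "qkv." in key:
            q, k, v = chunk3(state_dict.pop(key))
            state_dict[key.replace("qkv.", "q.")] = q
            state_dict[key.replace("qkv.", "k.")] = k
            state_dict[key.replace("qkv.", "v.")] = v


def linear(x, weight, bias):
    """Compute W @ x + b for one input vector; W has shape (out_features, in_features)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


# Toy fused projection with hidden size 2: six output rows, two input columns.
state_dict = {
    "attn.qkv.weight": [[1, 0], [0, 1], [2, 0], [0, 2], [3, 0], [0, 3]],
    "attn.qkv.bias": [1, 1, 2, 2, 3, 3],
}
x = [5, 7]
fused_out = linear(x, state_dict["attn.qkv.weight"], state_dict["attn.qkv.bias"])

split_fused_entries(state_dict)

# Concatenating the q, k and v outputs reproduces the fused output exactly.
split_out = [
    y
    for name in ("q", "k", "v")
    for y in linear(x, state_dict[f"attn.{name}.weight"], state_dict[f"attn.{name}.bias"])
]
print(sorted(state_dict))
print(fused_out == split_out)
```

Because each output row of a linear layer depends only on its own weight row and bias entry, the first third of the fused rows yields exactly the query outputs, the second third the keys, and the last third the values.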

#### Step 2: replace the original attention class

We replace the original `SamVisionAttention` class with our custom class:

In [4]:
from transformers import SamModel
from transformers.models.sam import modeling_sam

# replace the attention class in the modeling_sam module
modeling_sam.SamVisionAttention = SamVisionAttentionSplit

# load the pre-trained SAM
model = SamModel.from_pretrained("facebook/sam-vit-base")


* **Class replacement**: assigning our custom class to `modeling_sam.SamVisionAttention` means that any attention layer the module builds afterwards is a `SamVisionAttentionSplit`. The replacement must therefore happen before the model is instantiated, so that loading `SamModel` picks up the new class.
* **Model loading**: the model is loaded using `from_pretrained`, and the custom attention mechanism is integrated.
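
The ordering matters because a class name inside a module is looked up when a layer is constructed, not when the module is imported. The following sketch shows this with a toy module built in memory. All names here (`toy_modeling`, `Attention`, `Layer`, `SplitAttention`) are hypothetical and only mimic the structure of a modeling file.

```python
import types

# Build a small stand-in for a modeling module.
toy = types.ModuleType("toy_modeling")
exec(
    '''
class Attention:
    kind = "fused"

class Layer:
    def __init__(self):
        # The name Attention is resolved in the module namespace at call time.
        self.attn = Attention()
''',
    toy.__dict__,
)


class SplitAttention(toy.Attention):
    kind = "split"


before = toy.Layer()            # built before the patch: keeps the original class
toy.Attention = SplitAttention  # replace the class in the module namespace
after = toy.Layer()             # built after the patch: uses the replacement

print(before.attn.kind, after.attn.kind)
```

Layers created before the assignment are not converted retroactively, so any model that was already instantiated has to be loaded again after patching.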

#### Step 3: apply LoRA to specific projections

With separate `q`, `k`, and `v` projections, we can now apply LoRA to specific components, such as the `q` and `v` projections.

In [None]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=['q', 'v'], # apply lora to q, v
    lora_dropout=0.1,
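    # task_type is omitted: SAM is not a causal language model, so PEFT's generic wrapper is used
)

# apply LoRA to the model
model = get_peft_model(model, config)

* **LoRA configuration**: the `LoraConfig` specifies the rank `r`, scaling factor `lora_alpha`, target modules `q` and `v`, and dropout. No `task_type` is set because SAM is not a causal language model.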
* **Applying LoRA**: the `get_peft_model` function applies LoRA to the specified modules in the model.
* **Parameter reduction**: by focusing on `q` and `v`, we reduce the number of trainable parameters, leading to faster training and lower memory usage.
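
As a rough back-of-the-envelope check, we can estimate the size of this reduction by hand. A LoRA adapter of rank `r` on a linear layer adds two matrices, one of shape `r x in_features` and one of shape `out_features x r`. The sketch below assumes a hidden size of 768 and 12 encoder layers, which are illustrative values for a ViT-B-sized encoder rather than numbers read from the checkpoint; `lora_params` and `linear_params` are hypothetical helpers.

```python
def lora_params(in_features: int, out_features: int, rank: int) -> int:
    """Parameters added by one LoRA adapter: A is (rank x in), B is (out x rank)."""
    return rank * (in_features + out_features)


def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameters of a dense linear layer."""
    return in_features * out_features + (out_features if bias else 0)


hidden_size = 768  # assumption: encoder width
num_layers = 12    # assumption: number of encoder layers
rank = 16          # the r used in the LoraConfig above

# LoRA on q and v in every layer versus fully training q and v in every layer.
lora_total = num_layers * 2 * lora_params(hidden_size, hidden_size, rank)
full_total = num_layers * 2 * linear_params(hidden_size, hidden_size)

print(f"LoRA adapters on q and v: {lora_total:,} parameters")
print(f"Full q and v projections: {full_total:,} parameters")
print(f"Ratio: {lora_total / full_total:.2%}")
```

Under these assumptions the adapters hold only a few percent of the parameters of the projections they modify, and the rest of the model stays frozen.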

#### Step 4: verify the number of trainable parameters

In [None]:
model.print_trainable_parameters()
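
The same kind of count can be computed by hand for any PyTorch module, which is handy for cross-checking the printed summary. This is a minimal sketch on a toy model, not on SAM: `count_parameters` is a hypothetical helper, and the two-layer `toy` network is only there to keep the example self-contained.

```python
import torch.nn as nn


def count_parameters(module: nn.Module) -> tuple:
    """Return (trainable, total) parameter counts for a module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    return trainable, total


# Toy model: freeze the first layer and leave the second one trainable.
toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
for param in toy[0].parameters():
    param.requires_grad = False

trainable, total = count_parameters(toy)
print(f"trainable params: {trainable} || all params: {total}")
```

Calling `count_parameters(model)` on the LoRA-wrapped model should give figures comparable to those printed above.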