In [1]:
import numpy as np
import torch
from torch import nn
from transformers import AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class TransformerDecoderBlockWithForcedDirection(nn.Module):
    """Wrap a decoder layer and shift its output along a fixed direction.

    The wrapped layer is called unchanged; ``scale * forced_direction`` is then
    added to its hidden-state output. ``scale`` is a learnable scalar that
    starts at 0, so a freshly constructed wrapper reproduces the wrapped layer's
    output exactly.

    Args:
        decoder_layer: The module to wrap. It may return either a tensor or a
            tuple whose first element is the hidden-state tensor.
        forced_direction: The direction to add. A tensor, or anything
            ``torch.tensor`` accepts (list, NumPy array, ...). It must
            broadcast against the wrapped layer's hidden-state output.

    Raises:
        ValueError: If ``forced_direction`` is ``None``.
    """

    def __init__(self, decoder_layer, forced_direction: torch.Tensor = None):
        super().__init__()

        if forced_direction is None:
            raise ValueError("forced_direction must be provided")

        self.decoder_layer = decoder_layer

        # A plain nn.Module exposes no `.dtype` / `.device` attributes, so take
        # them from the layer's first floating-point parameter instead.
        reference = next(
            (p for p in decoder_layer.parameters() if p.is_floating_point()), None
        )
        if reference is not None:
            device, dtype = reference.device, reference.dtype
        else:
            # Parameter-free layer: fall back to PyTorch's defaults.
            device, dtype = torch.device("cpu"), torch.get_default_dtype()

        # Create the scalar directly on the target device: `Tensor.to` is not
        # in-place, and calling it on a Parameter returns a plain tensor.
        self.scale = nn.Parameter(
            torch.zeros((), dtype=dtype, device=device), requires_grad=True
        )

        if not isinstance(forced_direction, torch.Tensor):
            forced_direction = torch.tensor(forced_direction)
        forced_direction = forced_direction.to(device=device, dtype=dtype)
        # A buffer follows `module.to(...)` / `.cuda()` together with the layer.
        # Non-persistent keeps the state_dict keys the same as before
        # (`scale` plus the wrapped layer's entries).
        self.register_buffer("forced_direction", forced_direction, persistent=False)

    def forward(self, x, **kwargs):
        """Run the wrapped layer and add the scaled direction to its hidden states.

        Args:
            x: Input passed positionally to the wrapped layer.
            **kwargs: Forwarded unchanged to the wrapped layer.

        Returns:
            The wrapped layer's output with its hidden states shifted. If the
            layer returned a tuple, a tuple of the same length is returned with
            only the first element modified.
        """
        outputs = self.decoder_layer(x, **kwargs)
        shift = self.scale * self.forced_direction

        # Out-of-place addition: never mutate a tensor owned by the wrapped
        # layer (or by the caller, if the layer passes its input through).
        if isinstance(outputs, tuple):
            return (outputs[0] + shift,) + outputs[1:]
        return outputs + shift

In [3]:
# Load the pretrained checkpoint and freeze every weight: only parameters added
# afterwards should be trainable.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
model.requires_grad_(False)

# Move the frozen model to the GPU; as the last expression this also displays
# the module tree.
model.to("cuda:7")

Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.59s/it]


Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 3584)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
      )
    )
    (norm):

In [4]:
# Seed a dedicated generator so the sampled direction is the same on every
# Restart & Run All.
DIRECTION_SEED = 0
rng = np.random.default_rng(DIRECTION_SEED)

# Draw a standard-normal vector of the model's hidden size and rescale it to
# unit L2 norm.
random_direction = rng.normal(0, 1, size=model.config.hidden_size)
random_direction /= np.linalg.norm(random_direction)

In [5]:
# Display the model's stack of decoder layers.
model.model.layers

ModuleList(
  (0-27): 28 x Qwen2DecoderLayer(
    (self_attn): Qwen2SdpaAttention(
      (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
      (k_proj): Linear(in_features=3584, out_features=512, bias=True)
      (v_proj): Linear(in_features=3584, out_features=512, bias=True)
      (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
      (rotary_emb): Qwen2RotaryEmbedding()
    )
    (mlp): Qwen2MLP(
      (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
      (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
      (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
    (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
  )
)

In [None]:
# Indices of the decoder layers whose output gets the forced direction added.
target_layers = [13, 14, 15]

for layer in target_layers:
    current_block = model.model.layers[layer]
    # Skip layers that are already wrapped so re-running this cell does not
    # stack a wrapper on top of a wrapper.
    if isinstance(current_block, TransformerDecoderBlockWithForcedDirection):
        continue
    # No per-layer device lookup here: a plain nn.Module has no `.device`
    # attribute, and the looked-up value was never used.
    model.model.layers[layer] = TransformerDecoderBlockWithForcedDirection(
        current_block, forced_direction=torch.tensor(random_direction)
    )

In [None]:
# NOTE(review): this rebinds `model`, so the model whose layers were patched in
# the cell above is no longer reachable under that name — confirm this is
# intended, or use a different variable name.
# NOTE(review): "openai/gpt-3" does not look like a repository that exists on
# the Hugging Face Hub; presumably this call fails — verify the model id.
model = AutoModelForCausalLM.from_pretrained("openai/gpt-3")