In [41]:
import torch
import torch.nn.functional as F
from torch import nn

In [42]:
class Model(nn.Module):
    """Two-layer convolutional toy network used to experiment with forward hooks.

    Both layers are 5x5 convolutions with PyTorch's default stride and padding,
    and each is followed by a ReLU.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel -> 20 feature maps, then 20 -> 20; kernel size 5 for both.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=20, kernel_size=5)

    def forward(self, x):
        """Return ``relu(conv2(relu(conv1(x))))``."""
        hidden = self.conv1(x)
        hidden = F.relu(hidden)
        features = self.conv2(hidden)
        return F.relu(features)

In [43]:
# Seed the global RNG so the randomly initialised weights, and the random tensors
# drawn in later cells, are reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)

model = Model()
# Filled in by forward hooks: maps a layer name to the most recent output captured for it.
activations = {}

In [45]:
def conv1_hook(model, input, output):
    """Forward hook that records the hooked layer's output in ``activations["conv1"]``.

    The parameters follow PyTorch's forward-hook signature: ``model`` receives the
    module the hook is registered on, ``input`` its positional inputs and ``output``
    its result.  Only ``output`` is used; it is stored detached from the autograd graph.
    """
    captured = output.detach()
    activations.update({"conv1": captured})

In [46]:
# Keep the handle so the hook can be detached with `conv1_hook_handle.remove()`;
# re-running this cell without removing it registers a second copy of the hook.
conv1_hook_handle = model.conv1.register_forward_hook(conv1_hook)

# Initialise the snapshot list here: no earlier cell defines `all_trials`, so
# appending to it would raise NameError on a fresh kernel.
all_trials = []

input_tensor = torch.randn(1, 1, 28, 28)
output = model(input_tensor)
# Shallow copy is enough: the hook rebinds activations["conv1"] to a new tensor on
# every forward pass, so this snapshot keeps pointing at this pass's tensor.
all_trials.append(activations.copy())
output.shape

torch.Size([1, 20, 20, 20])

In [52]:
# conv1 (5x5 kernel, default stride/padding): (1, 1, 28, 28) -> (1, 20, 24, 24), since 28 - 5 + 1 = 24
activations["conv1"].size()

torch.Size([1, 20, 24, 24])

In [55]:
def modify_and_output(model, x, modification, layer=None):
    """Run ``model`` on ``x`` with ``modification`` added to one layer's output.

    A temporary forward hook on ``layer`` returns ``output + modification``.
    PyTorch substitutes a forward hook's non-None return value for the layer's
    output, so everything downstream of the layer (including the ReLU that
    follows ``conv1``) runs exactly as ``model.forward`` defines it.  The caller
    does not need to know or re-implement the rest of the architecture, and a
    zero ``modification`` reproduces ``model(x)``.

    Args:
        model: The module to run.
        x: Input passed to ``model``.
        modification: Tensor added to the layer's output; it must be
            broadcastable to that output, e.g. ``(1, 20, 24, 24)`` for
            ``conv1`` with a ``(1, 1, 28, 28)`` input.
        layer: Sub-module whose output is modified.  Defaults to ``model.conv1``.

    Returns:
        The model output computed from the modified activation.
    """
    target = model.conv1 if layer is None else layer

    def add_modification(module, inputs, output):
        # Returning a value from a forward hook replaces the layer's output.
        return output + modification

    handle = target.register_forward_hook(add_modification)
    try:
        return model(x)
    finally:
        # Always detach the hook so later calls to model(x) are unmodified.
        handle.remove()

In [59]:
input_tensor = torch.randn(1, 1, 28, 28)
# Random perturbation with the same shape as conv1's output.
output1 = modify_and_output(model, input_tensor, modification=torch.randn(1, 20, 24, 24))
# Show the shape of this cell's result, not the `output` left over from an earlier cell.
output1.shape

torch.Size([1, 20, 20, 20])

In [60]:
# Baseline: run the unmodified model on the SAME `input_tensor` that produced `output1`.
# Drawing a fresh random input here would make the two outputs differ regardless of
# the modification, so the comparison that follows would prove nothing.
output2 = model(input_tensor)
# Show the shape of this cell's result, not the `output` left over from an earlier cell.
output2.shape

torch.Size([1, 20, 20, 20])

In [85]:
# The two outputs must differ in at least one element.
outputs_identical = bool((output1 == output2).all())
assert not outputs_identical
print("The two outputs are not identical due to the modification")

The two outputs are not identical due to the modification


How do we apply this to an LLM? Get the activations of all tokens at a specific layer, then apply the modification (in the direction of the context vector) only to the last token. The modification tensor should therefore be zero for every other token and non-zero only at that token: it is a `(tokens, 4096)` matrix in which only the row `[-1, :]` has non-zero values.

So you add that tensor to the output of that specific layer and execute the rest of the LLM using the modified activation.

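But how do you execute the rest of the LLM? Re-running the remaining layers by hand might require internal modifications, or at least knowledge of how the model is defined. One way to avoid that in PyTorch: a hook registered with `register_forward_hook` may return a value, and that return value replaces the layer's output, so the rest of the forward pass runs on the modified activation without touching the model code.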