<a href="https://colab.research.google.com/github/msakarvadia/memory_injections/blob/main/Memory_Injection_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Memory Injections: Correcting Multi-Hop Reasoning Failures during Inference in Transformer-Based Language Models

### By: Mansi Sakarvadia, Aswathy Ajith, Arham Khan, Daniel Grzenda, Nathaniel Hudson, André Bauer, Kyle Chard, Ian Foster
**Paper link:** https://arxiv.org/abs/2309.05605

**Paper abstract:**
Answering multi-hop reasoning questions requires retrieving and synthesizing information from diverse sources. Large Language Models (LLMs) struggle to perform such reasoning consistently. Here we propose an approach to pinpoint and rectify multi-hop reasoning failures through targeted memory injections on LLM attention heads. First, we analyze the per-layer activations of GPT-2 models in response to single and multi-hop prompts. We then propose a mechanism that allows users to inject pertinent prompt-specific information, which we refer to as "memories," at critical LLM locations during inference. By thus enabling the LLM to incorporate additional relevant information during inference, we enhance the quality of multi-hop prompt completions. We show empirically that a simple, efficient, and targeted memory injection into a key attention layer can often increase the probability of the desired next token in multi-hop tasks, by up to 424%.

**Notebook motivation:** This notebook is meant to accompany the experimental code, letting users interactively try various prompts with our proposed memory injection technique.


# Set up

Import relevant libraries.

In [None]:
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    #%pip install circuitsvis

except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")


# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import matplotlib.pyplot as plt

import io
import requests

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-7t0o7_0y
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-7t0o7_0y
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 20a44fe3a8022d353c9cc7c984a8fcab14552d1c
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting beartype<0.15.0,>=0.14.1 (from transformer-lens==0.0.0)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
Collecting datasets>=2.7.1 (from transformer-lens==0.0.0)
  Downloading datasets-2.14.5-py3-none-any.whl (519 kB)

# Get Models

Instantiate the models of interest. Here we use both gpt2_small and gpt2_large, but you can use any model currently supported by the TransformerLens library.

In [None]:
gpt2_small = HookedTransformer.from_pretrained("gpt2-small", device=device)
gpt2_small.cfg.use_attn_result = True

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
gpt2_large = HookedTransformer.from_pretrained("gpt2-large", device=device)
gpt2_large.cfg.use_attn_result = True

Downloading (…)lve/main/config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/3.25G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-large into HookedTransformer


In [None]:
for name, param in gpt2_small.named_parameters():
  print(name, param.shape)

embed.W_E torch.Size([50257, 768])
pos_embed.W_pos torch.Size([1024, 768])
blocks.0.attn.W_Q torch.Size([12, 768, 64])
blocks.0.attn.W_K torch.Size([12, 768, 64])
blocks.0.attn.W_V torch.Size([12, 768, 64])
blocks.0.attn.W_O torch.Size([12, 64, 768])
blocks.0.attn.b_Q torch.Size([12, 64])
blocks.0.attn.b_K torch.Size([12, 64])
blocks.0.attn.b_V torch.Size([12, 64])
blocks.0.attn.b_O torch.Size([768])
blocks.0.mlp.W_in torch.Size([768, 3072])
blocks.0.mlp.b_in torch.Size([3072])
blocks.0.mlp.W_out torch.Size([3072, 768])
blocks.0.mlp.b_out torch.Size([768])
blocks.1.attn.W_Q torch.Size([12, 768, 64])
blocks.1.attn.W_K torch.Size([12, 768, 64])
blocks.1.attn.W_V torch.Size([12, 768, 64])
blocks.1.attn.W_O torch.Size([12, 64, 768])
blocks.1.attn.b_Q torch.Size([12, 64])
blocks.1.attn.b_K torch.Size([12, 64])
blocks.1.attn.b_V torch.Size([12, 64])
blocks.1.attn.b_O torch.Size([768])
blocks.1.mlp.W_in torch.Size([768, 3072])
blocks.1.mlp.b_in torch.Size([3072])
blocks.1.mlp.W_out torch.Size

In [None]:
for name, param in gpt2_large.named_parameters():
  print(name, param.shape)

embed.W_E torch.Size([50257, 1280])
pos_embed.W_pos torch.Size([1024, 1280])
blocks.0.attn.W_Q torch.Size([20, 1280, 64])
blocks.0.attn.W_K torch.Size([20, 1280, 64])
blocks.0.attn.W_V torch.Size([20, 1280, 64])
blocks.0.attn.W_O torch.Size([20, 64, 1280])
blocks.0.attn.b_Q torch.Size([20, 64])
blocks.0.attn.b_K torch.Size([20, 64])
blocks.0.attn.b_V torch.Size([20, 64])
blocks.0.attn.b_O torch.Size([1280])
blocks.0.mlp.W_in torch.Size([1280, 5120])
blocks.0.mlp.b_in torch.Size([5120])
blocks.0.mlp.W_out torch.Size([5120, 1280])
blocks.0.mlp.b_out torch.Size([1280])
blocks.1.attn.W_Q torch.Size([20, 1280, 64])
blocks.1.attn.W_K torch.Size([20, 1280, 64])
blocks.1.attn.W_V torch.Size([20, 1280, 64])
blocks.1.attn.W_O torch.Size([20, 64, 1280])
blocks.1.attn.b_Q torch.Size([20, 64])
blocks.1.attn.b_K torch.Size([20, 64])
blocks.1.attn.b_V torch.Size([20, 64])
blocks.1.attn.b_O torch.Size([1280])
blocks.1.mlp.W_in torch.Size([1280, 5120])
blocks.1.mlp.b_in torch.Size([5120])
blocks.1.mlp.

# Interpret the Attention Head Outputs as knowledge retrievers

Hypothesis:

*   The residual stream of a transformer stores the model's most up-to-date estimate of the next token
*   Each attention layer plays a specific role in editing the concepts in the residual stream (these are more like ideas/themes/concepts than actual facts)
*   The attention head outputs heavily influence what knowledge is retrieved from the MLPs that follow them.
*   Within an attention layer, individual heads may play even more specific roles (e.g. pronoun head, proper noun head, etc.)
*   The MLPs act as knowledge stores. Based on the concepts in the residual stream, specific facts might be retrieved from the MLP and pushed into the residual stream.



What we are doing here:

*   We are going to project the outputs of each attention layer back into vocabulary space so we can see what concepts a particular layer is adding to the residual stream
*   We will also do this at the individual attention-head granularity so we can empirically inspect whether specific heads have specific themes/roles

![picture](https://drive.google.com/uc?export=view&id=13CQkZEsypvEHC24d3MZP1YQC1W5AgkIC)
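
As a minimal sketch of the projection step described above, the toy example below maps a hidden-size activation into vocabulary space with a stand-in unembedding matrix and reads off the top-k token ids. The sizes, the random `W_U`, and the helper name `top_k_tokens` are illustrative assumptions, not values or functions from the real models.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes (assumed for illustration); GPT-2 small uses d_model=768, d_vocab=50257.
d_model, d_vocab, k = 8, 20, 3

# Stand-in unembedding matrix with unit-norm columns, plus a zero bias.
W_U = rng.normal(size=(d_model, d_vocab))
W_U /= np.linalg.norm(W_U, axis=0, keepdims=True)
b_U = np.zeros(d_vocab)

def top_k_tokens(activation, W_U, b_U, k):
    """Project a d_model activation into vocabulary space; return the k best token ids."""
    logits = activation @ W_U + b_U
    return np.argsort(logits)[::-1][:k]

# An activation that points along token 7's unembedding direction ranks
# token 7 first, because all columns of this toy W_U have unit norm.
top = top_k_tokens(W_U[:, 7], W_U, b_U, k)
print(top)
```

In the notebook itself, the same role is played by model.unembed applied to the cached attention outputs.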




In [None]:
# Function: head_latent_space_projector
# This function just projects the top K tokens from the latent space
# output of each head in a transformer back into vocabulary space so a
# user can assess what information is being put back into memory

# Users can toggle "intermediate_tokens" to see if they want to inspect
# the attention head outputs of intermediate tokens
# (or if they are only interested in the last token position)

def head_latent_space_projector(model, prompt, k_tokens, num_heads, aggregate_heads=True, intermediate_tokens=True):

  # intermediate_tokens = boolean arg that specifies whether we want the projections for all intermediate tokens as well, not just the last one

  #This is how you change if the outputs of heads are cached
  model.cfg.use_attn_result = True

  #tokenize the prompt
  tokens = model.to_tokens(prompt)
  logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)

  if aggregate_heads:

    for l in cache:
      if "hook_attn_out" in l:

        head_results = cache[l][None, :, :]
        logits = model.unembed(head_results)

        topk_token_preds = torch.topk(logits, k_tokens)

        for i in range(len(tokens[0])):
          if not intermediate_tokens:
              print("LAYER: ", l)
              print("PROMPT: ", model.to_string(tokens[0][:]))
              print(model.to_string(topk_token_preds[1][0][-1].reshape(k_tokens, 1)))
              break
          print("LAYER: ", l)
          print("PROMPT: ", model.to_string(tokens[0][0:i+1]))
          print(model.to_string(topk_token_preds[1][0][i].reshape(k_tokens, 1)))
          print("---------")

  else: # This is in case we want each individual head
    for l in cache:
      if "hook_result" in l:

        head_results = cache[l][None, :, :]

        for h in range(num_heads):
          #head_out = ln_final[:,:, h, :]
          head_out = head_results[:,:, h, :]
          logits = model.unembed(head_out)
          topk_token_preds = torch.topk(logits, k_tokens)
          for i in range(len(tokens[0])):
            if not intermediate_tokens:
              print("LAYER: ", l, "| HEAD: ", h)
              print("PROMPT: ", model.to_string(tokens[0][:]))
              print(model.to_string(topk_token_preds[1][0][-1].reshape(k_tokens, 1)))
              break
            print("LAYER: ", l, "| HEAD: ", h)
            print("PROMPT: ", model.to_string(tokens[0][0:i+1]))
            print(model.to_string(topk_token_preds[1][0][i].reshape(k_tokens, 1)))
          print("---------")


In [None]:
prompt = "George Washington fought in the"
head_latent_space_projector(model=gpt2_small, prompt=prompt, k_tokens=10, num_heads=12, aggregate_heads=True, intermediate_tokens=False)

LAYER:  blocks.0.hook_attn_out
PROMPT:  <|endoftext|>George Washington fought in the
[' last', ' same', ' world', ' press', " '", ' "', ' way', ' open', ' special', ' quick']
LAYER:  blocks.1.hook_attn_out
PROMPT:  <|endoftext|>George Washington fought in the
[',', ' —', ' and', ' "', '.', ' the', ' state', ' today', '\n', ' Sunday']
LAYER:  blocks.2.hook_attn_out
PROMPT:  <|endoftext|>George Washington fought in the
[' form', ',', ' last', ' early', ' air', ' civil', ' of', ' after', ' with', ' "']
LAYER:  blocks.3.hook_attn_out
PROMPT:  <|endoftext|>George Washington fought in the
[' the', ' against', ' a', '\n', ' with', ' as', ',', ' He', ' The', ' two']
LAYER:  blocks.4.hook_attn_out
PROMPT:  <|endoftext|>George Washington fought in the
[' last', ' the', ' and', ',', ' "', ' a', ' (', ' over', ' of', ' this']
LAYER:  blocks.5.hook_attn_out
PROMPT:  <|endoftext|>George Washington fought in the
['\n', ' himself', ' his', ' James', ' history', ' and', ' the', ' John', ' a', ' George'

In [None]:
prompt = "George Washington fought in the"
head_latent_space_projector(model=gpt2_small, prompt=prompt, k_tokens=30, num_heads=12, aggregate_heads=False, intermediate_tokens=False)

LAYER:  blocks.0.attn.hook_result | HEAD:  0
PROMPT:  <|endoftext|>George Washington fought in the
[' the', ' a', ' "', ',', '\n', ' and', ' in', '.', ' to', ' (', ' more', ' I', ' one', ' not', ' all', ' of', ' two', ' on', ' an', '-', " '", ' other', ' new', ' no', ' at', ' for', ' most', ' so', ' as', ' that']
---------
LAYER:  blocks.0.attn.hook_result | HEAD:  1
PROMPT:  <|endoftext|>George Washington fought in the
[' way', ' first', ' last', ' same', ' value', ' very', ' full', ' more', ' big', ' the', ' world', ' fair', ' difference', ' new', ' process', ' no', ' most', ' above', ' next', ' following', ' fact', ' quality', ' country', ' a', ' combination', ' whole', ' system', ' depth', ' move', ' latter']
---------
LAYER:  blocks.0.attn.hook_result | HEAD:  2
PROMPT:  <|endoftext|>George Washington fought in the
[',', '\n', '-', '.', ' and', ' of', ' in', '/', ' the', ':', ' (', ' to', ' a', ' for', ' is', ' I', ' with', ' "', "'", ' from', '!', ' that', ' or', '\n\n', ' on', "

In [None]:
prompt = "The first president of the United States fought in the"
head_latent_space_projector(gpt2_large, prompt, 10, 20, aggregate_heads=False, intermediate_tokens=False)

LAYER:  blocks.0.attn.hook_result | HEAD:  0
PROMPT:  <|endoftext|>The first president of the United States fought in the
[',', ' and', '.', ' (', ' in', ' to', '-', ' for', '\n', ' of']
---------
LAYER:  blocks.0.attn.hook_result | HEAD:  1
PROMPT:  <|endoftext|>The first president of the United States fought in the
[',', ' and', '.', ' (', '-', ' to', ' of', ' in', ' the', ' for']
---------
LAYER:  blocks.0.attn.hook_result | HEAD:  2
PROMPT:  <|endoftext|>The first president of the United States fought in the
[',', ' and', '.', ' (', ' in', '-', ' to', ' of', ' or', ' for']
---------
LAYER:  blocks.0.attn.hook_result | HEAD:  3
PROMPT:  <|endoftext|>The first president of the United States fought in the
[',', ' and', '.', ' (', ' in', '\n', ' to', ' of', ' the', ' for']
---------
LAYER:  blocks.0.attn.hook_result | HEAD:  4
PROMPT:  <|endoftext|>The first president of the United States fought in the
[',', ' and', '.', ' (', ' in', ' of', ' to', '-', '\n', ' or']
---------
LAYER:  bl

In [None]:
prompt = "The first president of the United States fought in the"
head_latent_space_projector(gpt2_large, prompt, 10, 20, aggregate_heads=True, intermediate_tokens=True)

LAYER:  blocks.0.hook_attn_out
PROMPT:  <|endoftext|>
['<|endoftext|>', '\n', ' Copyright', '(', '\n\n', ' […]', '\xa0', '*', ' May', ',']
---------
LAYER:  blocks.0.hook_attn_out
PROMPT:  <|endoftext|>The
[',', ' the', '.', ' and', '\n', '(', ' (', ' in', ' just', ' "']
---------
LAYER:  blocks.0.hook_attn_out
PROMPT:  <|endoftext|>The first
[',', ' first', ' to', ' (', ' and', ' in', '-', ' last', '.', ' a']
---------
LAYER:  blocks.0.hook_attn_out
PROMPT:  <|endoftext|>The first president
[' president', ' President', ',', ' presidential', '.', ' and', ' chairman', ' who', "'s", ' in']
---------
LAYER:  blocks.0.hook_attn_out
PROMPT:  <|endoftext|>The first president of
[',', ' the', ' and', '.', ' (', ' in', ' of', ' to', ' that', ':']
---------
LAYER:  blocks.0.hook_attn_out
PROMPT:  <|endoftext|>The first president of the
[',', ' the', '.', ' (', ' and', ' in', ' "', ' to', ':', ' of']
---------
LAYER:  blocks.0.hook_attn_out
PROMPT:  <|endoftext|>The first president of the United

# How do we inject memories at each layer?

We can inject concepts into the residual stream by editing the outputs of the attention layers

**How do we do this:**

We project memories from vocabulary space into the pseudo-latent space of the model by applying the transpose of the unembedding matrix to the tokenized memory. We then add these transformed memories directly to the outputs of an attention layer.
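
The recipe above can be sketched on toy arrays. This is an illustration under assumed sizes and hypothetical token ids, not the notebook's hook: it builds a multi-hot vector for the memory tokens, maps it into the hidden dimension with the transpose of a stand-in unembedding matrix, and adds the scaled result to a stand-in attention output.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_vocab, n_pos = 8, 20, 4            # toy sizes, not GPT-2's

W_U = rng.normal(size=(d_model, d_vocab))     # stand-in unembedding matrix
attn_out = rng.normal(size=(n_pos, d_model))  # stand-in attention-layer output

memory_token_ids = [3, 11]  # hypothetical token ids of the memory string
tweak_factor = 3.0

# Multi-hot vector over the vocabulary marking the memory tokens.
multi_hot = np.zeros(d_vocab)
multi_hot[memory_token_ids] = 1.0

# Applying W_U^T maps the vocabulary-space vector into d_model space;
# the result is the sum of the memory tokens' unembedding directions.
memory_vec = multi_hot @ W_U.T

# Broadcasting adds the same scaled memory at every token position.
injected = attn_out + tweak_factor * memory_vec
print(injected.shape)
```

Note that the memory is added at every token position, which mirrors how the hook below broadcasts a single d_model vector over the whole attention output.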

In [None]:
# We define a hook that patches the output of an attention layer (hook_attn_out)
# It adds the injected memory to the attention output before that output is added back into the residual stream
# The type annotations are a guide to the reader and are not necessary

#Args:
def memory_tweaker_hook(
    attn_out: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint, #name of layer where we inject memory
    extra_info: str, #the string that we tokenize and then inject into memory
    model: transformer_lens.HookedTransformer, #the model from which we get the unembedding matrix from
    vocab_size: int, #size of model vocabulary
    tweak_factor: float,
    #cache: transformer_lens.ActivationCache #this is the
) -> Float[torch.Tensor, "batch pos d_model"]:

    print("Hook point: ", hook.name)
    #tokenize string
    tok_extra_info = model.to_tokens(extra_info, prepend_bos=False)
    print(tok_extra_info)

    #transform tokens into a multi-hot vector over the vocabulary
    extra_memory = torch.zeros(vocab_size).to(device)
    #put a one at the index of every extra info token (the tweak factor is applied below)
    for i in tok_extra_info:
      extra_memory[i] = 1

    #apply the transpose of the unembedding matrix to the tokenized string to get it into the model's hidden dim
    #(subtracting the unembedding bias first is left disabled)
    #extra_memory = extra_memory - model.unembed.b_U
    extra_memory = einsum("d_vocab, d_vocab d_model-> d_model", extra_memory, model.W_U.T)

    #TODO think about how layer norm would impact things

    #add the extra_info embedded in latent space to hook_attn_out
    attn_out = attn_out + extra_memory * tweak_factor


    #return this edited hook_attn_out
    return attn_out

In [None]:
# Use functools.partial to create a temporary hook function with the memory arguments fixed
temp_hook_fn = partial(memory_tweaker_hook,
                       extra_info="The president",
                       vocab_size=50257,
                       model=gpt2_small,
                       tweak_factor=3)

prompt = "The leader of the United States lives in the"
#Get original logits
logits = gpt2_small(prompt)

#Get patched Logits
layer  = 9
patched_logits = gpt2_small.run_with_hooks(prompt,
                          fwd_hooks=[
                                      ( utils.get_act_name("attn_out", layer),
                                      temp_hook_fn)
                                      ]
                          )

Hook point:  blocks.9.hook_attn_out
tensor([[ 464, 1893]])


In [None]:
print("Unedited Tokens")
topk_token_vals, topk_token_preds = torch.topk(logits, 70)
gpt2_small.to_string(topk_token_preds[0][-1])

Unedited Tokens


' same United country middle shadow South White Middle West city Midwest US U suburbs North world southern midst mountains northern heart Philippines East Pacific state UK western nation shadows woods New eastern past Northern Caribbean Capitol San Netherlands " Bronx small south Bay Washington Bahamas inner area Democratic most far Central Soviet Great very American Northeast center capital west present house north upper Southern \' home Dominican fictional vicinity dark'

In [None]:
print("Edited Tokens")
topk_token_vals_edit, topk_token_preds_edit = torch.topk(patched_logits, 70)
gpt2_small.to_string(topk_token_preds_edit[0][-1])

Edited Tokens


' United White same country South shadow middle West city Middle U North heart Midwest US nation southern world suburbs midst Philippines northern Oval East state Pacific Democratic New mountains western Washington San capital Capitol " Northern eastern south past shadows Bay UK small Black most woods Caribbean American District Central Republican Soviet Netherlands far area west inner north home president Bahamas Southern center Bronx house Trump northeast Northeast former very'

# How do we edit the memories at each head?

We can inject concepts into the residual stream by editing the outputs of individual attention heads.

(SIDE NOTE: In reality this is the same thing as editing the entire output of an attention layer since all attention head outputs are added together before being added back into the residual stream. We still decided to include this section for illustrative purposes.)
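
The equivalence in the side note can be checked on toy arrays. The sketch below assumes, as the side note describes, that the layer output is the sum of the per-head outputs plus an output bias; the sizes, the random tensors, and the helper name `layer_output` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pos, n_heads, d_model = 4, 3, 8  # toy sizes

head_results = rng.normal(size=(n_pos, n_heads, d_model))  # per-head outputs
b_O = rng.normal(size=d_model)                             # output bias
memory_vec = rng.normal(size=d_model)                      # injected memory
tweak_factor = 3.0

def layer_output(per_head):
    # Attention-layer output: sum over the head axis plus the output bias.
    return per_head.sum(axis=1) + b_O

# Option A: add the memory to the combined layer output.
layer_edit = layer_output(head_results) + tweak_factor * memory_vec

# Option B: add the memory to a single head (head 1 here) before summing.
edited = head_results.copy()
edited[:, 1, :] += tweak_factor * memory_vec
head_edit = layer_output(edited)

print(np.allclose(layer_edit, head_edit))
```

Because the sum over heads is linear, the choice of head does not change the resulting layer output for a fixed memory and tweak factor.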

**How do we do this:**

We project memories from vocabulary space into the pseudo-latent space of the model by applying the transpose of the unembedding matrix to the tokenized memory. We then add these transformed memories directly to the outputs of a specific attention head.

![picture](https://drive.google.com/uc?export=view&id=11PXMPvywR_ZtQNLM615-KB7ltfc0yivM)

In [None]:
# This function does not do anything different from "memory_tweaker_hook"
# All the outputs of all attention heads in each layer are added together before
# being added back into the residual stream of the model

#Args:
def memory_tweaker_head_hook(
    attn_result: Float[torch.Tensor, "batch pos num_heads d_model"],
    hook: HookPoint, #name of layer where we inject memory
    extra_info: str, #the string that we tokenize and then inject into memory
    model: transformer_lens.HookedTransformer, #the model from which we get the unembedding matrix
    vocab_size: int, #size of model vocabulary
    tweak_factor: float,
    head_num: int #The number of the head we want to edit
    #cache: transformer_lens.ActivationCache #this is the
) -> Float[torch.Tensor, "batch pos num_heads d_model"]:



    #print("Hook point: ", hook.name)
    #print("head num: ", head_num)
    #tokenize string
    tok_extra_info = model.to_tokens(extra_info, prepend_bos=False)
    #print(tok_extra_info)

    #transform tokens into a multi-hot vector over the vocabulary
    extra_memory = torch.zeros(vocab_size).to(device)
    #put a one at the index of every extra info token (the tweak factor is applied below)
    for i in tok_extra_info:
      extra_memory[i] = 1

    #apply the transpose of the unembedding matrix to the tokenized string to get it into the model's hidden dim
    #(subtracting the unembedding bias first is left disabled)
    #extra_memory = extra_memory - model.unembed.b_U
    extra_memory = einsum("d_vocab, d_vocab d_model-> d_model", extra_memory, model.W_U.T)

    #TODO think about how layer norm would impact things

    #add the extra_info embedded in latent space to the chosen head's slice of hook_result
    #print(attn_result.shape)
    attn_result[:,:,head_num,:] = attn_result[:,:,head_num,:] + extra_memory * tweak_factor
    #print(attn_result[:,:,head_num,:])

    # TODO: Add a "jiggle" feature here.

    #return this edited hook_result
    return attn_result

Below, simply edit the prompt, extra_info, head_num, tweak_factor, and layer to fit your own example.

In [None]:
# Use functools.partial to create a temporary hook function with the memory arguments fixed
temp_hook_fn = partial(memory_tweaker_head_hook,
                       extra_info="The president",
                       vocab_size=50257,
                       model=gpt2_small,
                       tweak_factor=3,
                       head_num=0)



prompt = "The leader of the United States lives in the"
#Get original logits
logits = gpt2_small(prompt)

#Get patched Logits
layer  = 9
patched_logits = gpt2_small.run_with_hooks(prompt,
                          fwd_hooks=[
                                      ( utils.get_act_name("result", layer),
                                      temp_hook_fn)
                                      ]
                          )


In [None]:
print("Unedited top K tokens: ")
topk_token_vals, topk_token_preds = torch.topk(logits, 70)
gpt2_small.to_string(topk_token_preds[0][-1])

Unedited top K tokens: 


' same United country middle shadow South White Middle West city Midwest US U suburbs North world southern midst mountains northern heart Philippines East Pacific state UK western nation shadows woods New eastern past Northern Caribbean Capitol San Netherlands " Bronx small south Bay Washington Bahamas inner area Democratic most far Central Soviet Great very American Northeast center capital west present house north upper Southern \' home Dominican fictional vicinity dark'

In [None]:
print("Edited top K tokens: ")
topk_token_vals_edit, topk_token_preds_edit = torch.topk(patched_logits, 70)
gpt2_small.to_string(topk_token_preds_edit[0][-1])

Edited top K tokens: 


' United White same country South shadow middle West city Middle U North heart Midwest US nation southern world suburbs midst Philippines northern Oval East state Pacific Democratic New mountains western Washington San capital Capitol " Northern eastern south past shadows Bay UK small Black most woods Caribbean American District Central Republican Soviet Netherlands far area west inner north home president Bahamas Southern center Bronx house Trump northeast Northeast former very'

In [None]:
def apply_edit(model, extra_memory, prompt, tweak_factor=4, layer=10, head_num=0):
  # Use functools.partial to create a temporary hook function with the memory arguments fixed
  temp_hook_fn = partial(memory_tweaker_head_hook,
                        extra_info=extra_memory, # the memory string to inject, e.g. "George Washington"
                        vocab_size=model.cfg.d_vocab, # read the vocabulary size from the model config
                        model=model,
                        tweak_factor=tweak_factor,
                        head_num=head_num)

  #Get original logits
  logits = model(prompt)

  #Get patched logits
  patched_logits = model.run_with_hooks(prompt,
                            fwd_hooks=[
                                        ( utils.get_act_name("result", layer),
                                        temp_hook_fn)
                                        ]
                            )
  return logits, patched_logits

In [None]:
def interpret_logits_as_vocab(model, logits, top_k=30):
  topk_token_vals_edit, topk_token_preds_edit = torch.topk(logits, top_k)
  return model.to_string(topk_token_preds_edit[0][-1])

In [None]:
logits, patched_logits = apply_edit(gpt2_large, "Abe Lincoln",
                                    "George Washington fought in the",
                                    tweak_factor=4, layer=9, head_num=8)
print("original logits")
print(interpret_logits_as_vocab(gpt2_large, logits))
print("edited logits")
print(interpret_logits_as_vocab(gpt2_large, patched_logits))

original logits
 Revolutionary American Civil War war Revolution French Battle Continental Spanish Mexican first battle Virginia First Indian U British battles Second Crimean wars bloody Great famous Seven great Union North United
edited logits
 Revolutionary American War Civil war French Revolution Battle Continental Spanish Mexican first Virginia wars U battle British First Indian Second Great battles Crimean bloody United Union North famous Seven revolutionary


In [None]:
logits, patched_logits = apply_edit(gpt2_large, "George Washington",
                                    "The first president of the United States fought in the",
                                    tweak_factor=4, layer=9, head_num=5)
print("original logits")
print(interpret_logits_as_vocab(gpt2_large, logits))
print("edited logits")
print(interpret_logits_as_vocab(gpt2_large, patched_logits))

original logits
 Civil Revolutionary American Spanish Mexican war War Battle trenches French Vietnam Second First Philippines Great Philippine Korean Revolution civil first bloody Crimean wars World European blood battle United Indian U
edited logits
 Civil Revolutionary American Spanish war Mexican War trenches Battle Philippines French Vietnam First Second Philippine Great first Revolution Crimean wars civil bloody Korean European United World blood Union Pacific U


In [None]:
logits, patched_logits = apply_edit(gpt2_large, "George Washington",
                                    "The first president of the United States fought in the",
                                    tweak_factor=4, layer=9, head_num=7)
print("original logits")
print(interpret_logits_as_vocab(gpt2_large, logits))
print("edited logits")
print(interpret_logits_as_vocab(gpt2_large, patched_logits))

original logits
 Civil Revolutionary American Spanish Mexican war War Battle trenches French Vietnam Second First Philippines Great Philippine Korean Revolution civil first bloody Crimean wars World European blood battle United Indian U
edited logits
 Civil Revolutionary American Spanish war Mexican War trenches Battle Philippines French Vietnam First Second Philippine Great first Revolution Crimean wars civil bloody Korean European United World blood Union Pacific U


# Helper Functions


In [None]:
def get_ans_prob(model, ans, prompt=None, logits=None):
    '''
    Compute the probability of the next-token completion (ans) given either
    the logits or the prompt. At least one of prompt or logits must be passed.
    Only the first token of ans is scored.
    '''
    if logits is None and prompt is None:
        raise ValueError("Either logits or prompt needs to be provided")
    # index 0 is the prepended BOS token, so index 1 is the first token of ans
    ans_token = model.to_tokens(ans)[0][1]
    if logits is None:
        tokens = model.to_tokens(prompt)
        logits = model(tokens)

    total_ans_prob = torch.nn.functional.softmax(logits, -1)[0][-1][ans_token].item()
    return total_ans_prob
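
One way to compare answer probabilities before and after an injection is a relative percent change. The sketch below uses toy logits and pure Python; the helper names `softmax` and `percent_increase` and the choice of relative change as the metric are assumptions made for illustration, so please see the paper for the exact evaluation used there.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def percent_increase(prob_before, prob_after):
    # Relative change of the answer probability, in percent.
    return 100.0 * (prob_after - prob_before) / prob_before

# Toy logits over a 4-token vocabulary; token 2 is the desired answer.
original_logits = [2.0, 1.0, 0.5, 0.0]
patched_logits = [2.0, 1.0, 1.5, 0.0]
answer_id = 2

p_before = softmax(original_logits)[answer_id]
p_after = softmax(patched_logits)[answer_id]
change = percent_increase(p_before, p_after)
print(f"{p_before:.3f} -> {p_after:.3f} ({change:+.1f}%)")
```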

In [None]:
def prompt_model(prompt, model, k_tokens=10):
  logits = model(prompt)
  vals, idx = torch.topk(logits, k_tokens)
  return model.to_string(idx[0][-1])

In [None]:
prompt_model("The first black president of the United States is a member of the", gpt2_small)
#The name of the largest coral reef is the

' Ku African House United U KKK National Council Supreme Senate'

# Get Data

We use a handwritten dataset and a programmatically generated dataset. We felt it was important to handwrite one of our datasets to ensure factual and grammatical accuracy.

**Dataset format:** Each row has one sentence with an explicit entity and another sentence with an "obscure" entity. The "obscure" entity requires one hop of reasoning to resolve. See examples below.

Please see our paper for additional specifics: https://arxiv.org/abs/2309.05605
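
To make the row format concrete, here is a small sketch of one row and of how it could feed a memory injection run. The `PromptPair` class and the `injection_inputs` helper are hypothetical names introduced for illustration; the field names follow the dataset's columns, and the values are taken from row 1 of the handwritten dataset.

```python
from dataclasses import dataclass

@dataclass
class PromptPair:
    explicit_sentence: str
    obscure_sentence: str
    explicit_entity: str
    obscure_entity: str
    answer: str

def injection_inputs(pair):
    """Return (prompt, memory, target) for one memory injection run.

    The obscure sentence is the multi-hop prompt, the explicit entity is the
    memory to inject, and the answer gets a leading space so that it
    tokenizes as a word following the prompt.
    """
    return pair.obscure_sentence, pair.explicit_entity, " " + pair.answer

row = PromptPair(
    explicit_sentence="The president lives in the",
    obscure_sentence="The leader of the United States lives in the",
    explicit_entity="The president",
    obscure_entity="The leader of the United States",
    answer="White House",
)

prompt, memory, target = injection_inputs(row)
print(repr(prompt), repr(memory), repr(target))
```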

In [None]:
#Download and load data from git repo (this is the handwritten dataset)
url="https://raw.githubusercontent.com/msakarvadia/memory_injections/main/data/handwritten_obscure_explicit_data.csv"
s=requests.get(url).content
data=pd.read_csv(io.StringIO(s.decode('utf-8')))

#Drop rows with a missing answer (pandas reads empty CSV cells as NaN, not "")
data = data.dropna(subset=['answer']).reset_index(drop=True)

#Prepend " " (space) to each answer
data['answer'] = ' ' + data['answer']

data

Unnamed: 0,explicit_sentence,obscure_sentence,explicit_entity,obscure_entity,answer
0,George Washington fought in the,The first president of the United States fough...,George Washington,The first president of the United States,Revolutionary War
1,The president lives in the,The leader of the United States lives in the,The president,The leader of the United States,White House
2,St. Peter's Basilica is in the city of,The biggest church in the world is in the city of,St. Peter's Basilica,The biggest church in the world,Rome
3,Jesus died on the,The son of God died on the,Jesus,The son of God,cross
4,Elephants are the largest land mammal on,Animals with long trunks are the largest land ...,Elephants,Animals with long trunks,Earth
...,...,...,...,...,...
101,The Louvre Museum is located in the city of,The largest art museum in the world is located...,The Louvre Museum,The largest art museum in the world,Paris
102,Mount Everest is located in the,The highest peak in the world is located in the,Mount Everest,The highest peak in the world,Himalayan
103,Mammoth Cave is located in,The longest known cave system is located in,Mammoth Cave,The longest known cave system,Kentucky
104,The blue whale is a,The largest animal to exist is a,blue whale,largest animal to exist,mammal


In [None]:
#Download and load data from git repo (this is the programmatically generated dataset)
url="https://raw.githubusercontent.com/msakarvadia/memory_injections/main/data/multi_hop_1000.csv"
s=requests.get(url).content
multi=pd.read_csv(io.StringIO(s.decode('utf-8')))

multi = multi.drop([ 'fact1', 'fact2'], axis=1)

#Need to add a " " space to the front of every answer to account for funny tokenization
#(vectorized assignment avoids pandas' SettingWithCopyWarning from chained indexing)
multi['answer'] = ' ' + multi['answer']

multi.rename(columns={"explicit_sent": "explicit_sentence", "obscure_sent": "obscure_sentence"}, inplace=True)

multi


Unnamed: 0.1,Unnamed: 0,explicit_sentence,obscure_sentence,explicit_entity,obscure_entity,answer
0,0,The country of citizenship of Jaap Speyer is,The country of citizenship of the director of ...,Jaap Speyer,the director of Lilli's Marriage,Dutch
1,1,The place of birth of Dušan Hanák is,"The place of birth of the director of I Love, ...",Dušan Hanák,"the director of I Love, You Love",Bratislava
2,2,The place of death of James Vincent is,The place of death of the director of Gold and...,James Vincent,the director of Gold and the Woman,New York
3,3,The place of birth of Emil Loteanu is,The place of birth of the director of Lăutarii is,Emil Loteanu,the director of Lăutarii,Romania
4,4,The country of citizenship of Archduke Karl Sa...,The country of citizenship of the father of Ar...,Archduke Karl Salvator of Austria,the father of Archduke Leopold Salvator of Aus...,Italian
...,...,...,...,...,...,...
995,995,The spouse of Fredi is,The spouse of the performer of Pump-Pump is,Fredi,the performer of Pump-Pump,Eva-Riitta Siitonen
996,996,The place of death of Gianfranco Parolini is,The place of death of the director of Francis ...,Gianfranco Parolini,the director of Francis the Smuggler,Rome
997,997,"The date of death of Francis Baring, 3rd Baron...",The date of death of the father of Alexander H...,"Francis Baring, 3rd Baron Ashburton","the father of Alexander Hugh Baring, 4th Baron...",6 September 1868
998,998,The place of birth of Eleanor de Clare is,The place of birth of the mother of Edward le ...,Eleanor de Clare,the mother of Edward le Despenser,Caerphilly


# How useful is memory editing at a specific head (if we know what to inject)?

The working **hypothesis** here is that the obscure (multi-hop) prompts lack specific memories (the additional hop), which is why their completions are not as good as explicit prompt completions.

We have a dataset of obscure prompts and explicit prompts, along with their respective obscure and explicit subjects.

We ask whether injecting the explicit subject as a memory into the obscure prompt's hidden activation states is enough to correct the final completion.

We will **measure** this memory injection approach's success by how much the probability of the desired next token increases.
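
The measurement can be sketched without a model: convert the logits before and after the edit into probabilities with a softmax, then compare the probability assigned to the answer token. The logit values and the helper names below are hypothetical, chosen only to illustrate the arithmetic. The percentage change is included as an assumption about how a relative figure could be derived; the notebook's own `print_edit_results` accumulates the absolute difference.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def answer_prob_change(logits_before, logits_after, answer_idx):
    """Return (p_before, p_after, absolute change, percent change) for one token."""
    p_before = softmax(logits_before)[answer_idx]
    p_after = softmax(logits_after)[answer_idx]
    absolute = p_after - p_before
    percent = 100.0 * absolute / p_before
    return p_before, p_after, absolute, percent

# Hypothetical last-position logits over a 3-token vocabulary;
# index 0 plays the role of the first token of the desired answer.
before = [1.0, 2.0, 0.5]
after = [2.0, 2.0, 0.5]

p_before, p_after, absolute, percent = answer_prob_change(before, after, answer_idx=0)
print(f"before={p_before:.4f} after={p_after:.4f} "
      f"abs_change={absolute:.4f} pct_change={percent:.1f}%")
```

Only the first token of the answer is scored, which matches how the function in the next cell indexes `first_answer_tok`.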

In [None]:
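def print_edit_results(data, model, layer=9, head_num=8, tweak_factor=4):
  """Apply the memory injection to every row and record answer probabilities before/after."""
  average_answer_prob_change_after_edit = 0
  #Initialize as floats so the probabilities assigned below keep a float dtype
  data['ans_prob_obs'] = 0.0
  data['ans_prob_exp'] = 0.0
  data['ans_prob_after_edit'] = 0.0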

  for i in range(len(data['answer'])):
    answer = data['answer'][i]
    memory = data['explicit_entity'][i]
    prompt = data['obscure_sentence'][i]

    explicit_prompt = data['explicit_sentence'][i]
    exp_logits = model(explicit_prompt)


    logits, patched_logits = apply_edit(model,
                                      memory,
                                      prompt,
                                      tweak_factor=tweak_factor,
                                      layer=layer,
                                      head_num=head_num)

    #Score only the first token of the answer, tokenized by the model under evaluation
    first_answer_tok = model.to_tokens(answer, prepend_bos=False)[0][0].item()
    answer_prob_before_mem = torch.nn.functional.softmax(logits[0][-1], dim=0)[first_answer_tok]
    answer_prob_after_mem = torch.nn.functional.softmax(patched_logits[0][-1], dim=0)[first_answer_tok]
    ans_prob_exp = torch.nn.functional.softmax(exp_logits[0][-1], dim=0)[first_answer_tok]

    average_answer_prob_change_after_edit += answer_prob_after_mem - answer_prob_before_mem

    data.loc[i, 'ans_prob_obs'] = answer_prob_before_mem.item()
    data.loc[i, 'ans_prob_exp'] = ans_prob_exp.item()
    data.loc[i, 'ans_prob_after_edit'] = answer_prob_after_mem.item()

    print("Prompt: ", prompt)
    print("Answer: ", data['answer'][i])
    print("Memory: ", memory)
    print("original logits | Answer Probability: ", answer_prob_before_mem)
    print(interpret_logits_as_vocab(model, logits))
    print("edited logits| Answer Probability: ", answer_prob_after_mem)
    print(interpret_logits_as_vocab(model, patched_logits))
    print("---------------- ", i)
  print("Average Answer probability difference after edit: ", average_answer_prob_change_after_edit/len(data['answer']))
  return data

In [None]:
diffs = print_edit_results(data, gpt2_small)

Prompt:  The first president of the United States fought in the
Answer:   Revolutionary War
Memory:  George Washington
original logits | Answer Probability:  tensor(0.0670)
 Vietnam Civil Korean Revolutionary Battle Spanish Second war World Great Mexican Philippines First South Pacific Indian American Iraq Persian Cuban civil Philippine War first French Middle Crimean second wars Cold
edited logits| Answer Probability:  tensor(0.0596)
 Vietnam Civil Revolutionary Korean World war Battle South Great Spanish Philippines Iraq War George Second Mexican Pacific Persian Gulf wars North civil American Cuban first Middle Cold First Philippine 18
----------------  0
Prompt:  The leader of the United States lives in the
Answer:   White House
Memory:  The president
original logits | Answer Probability:  tensor(0.0174)
 same United country middle shadow South White Middle West city Midwest US U suburbs North world southern midst mountains northern heart Philippines East Pacific state UK western na

KeyboardInterrupt: ignored

# Citation

Please cite this work as:



```
@article{sakarvadia2023memory,
  title={Memory Injections: Correcting Multi-Hop Reasoning Failures during Inference in Transformer-Based Language Models},
  author={Sakarvadia, Mansi and Ajith, Aswathy and Khan, Arham and Grzenda, Daniel and Hudson, Nathaniel and Bauer, Andr{\'e} and Chard, Kyle and Foster, Ian},
  journal={arXiv preprint arXiv:2309.05605},
  year={2023}
}
```

