# Mass-Editing Memory in a Transformer + Contrastive Knowledge Assessment
This notebook enables interactive experimentation with MEMIT. To find erroneous information in the model, we first used the Contrastive Knowledge Assessment (CKA) method to identify incorrect factual associations ([notebook](https://colab.research.google.com/github/daniel-furman/Capstone/blob/main/notebooks/cka_run_main_demo.ipynb)). Here, we use MEMIT to repair these facts by writing corrected associations directly into existing pre-trained models, aiming for edits that both generalize to paraphrases and stay specific to the edited facts.

<a target="_blank" href="https://colab.research.google.com/github/daniel-furman/Capstone/blob/main/notebooks/memit_calibragpt.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [1]:
%%bash
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
cd /content && rm -rf /content/memit
git clone https://github.com/kmeng01/memit memit > install.log 2>&1
pip install -r /content/memit/scripts/colab_reqs/rome.txt >> install.log 2>&1
pip install --upgrade google-cloud-storage >> install.log 2>&1

In [2]:
IS_COLAB = False
ALL_DEPS = False
try:
    import google.colab, torch, os

    IS_COLAB = True
    os.chdir("/content/memit")
    if not torch.cuda.is_available():
        raise Exception("Change runtime type to include a GPU.")
except ModuleNotFoundError as _:
    pass

In [3]:
!git clone https://github.com/daniel-furman/Capstone.git
!pip install -r /content/memit/Capstone/requirements.txt

Cloning into 'Capstone'...
remote: Enumerating objects: 569, done.
remote: Counting objects: 100% (226/226), done.
remote: Compressing objects: 100% (135/135), done.
remote: Total 569 (delta 121), reused 166 (delta 65), pack-reused 343
Receiving objects: 100% (569/569), 27.45 MiB | 4.44 MiB/s, done.
Resolving deltas: 100% (283/283), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers==4.26.1
  Using cached transformers-4.26.1-py3-none-any.whl (6.3 MB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Using cached tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
Installing collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.10.3
    Uninstalling tokenizers-0.10.3:
      Successfully uninstalled tokenizers-0.10.3
  Attempting uninstall: transformers
    Found existing installation: transformers 4

In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from util import nethook
from util.generate import generate_interactive, generate_fast

from experiments.py.demo import demo_model_editing, stop_execution

Here, you can specify a GPT model (`MODEL_NAME`).

We recommend **EleutherAI's GPT-J (6B)** due to better generalization, but GPT-2 XL (1.5B) consumes less memory.
* `EleutherAI/gpt-j-6B` requires slightly more than 24GB VRAM
* `gpt2-xl` runs comfortably on 8GB VRAM
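
As a rough sanity check on these figures, the weights alone occupy about (parameter count) × (bytes per parameter). The sketch below assumes the nominal 6B and 1.5B parameter counts implied by the model names and the `torch.float16` dtype used when loading the model; it ignores activations and everything the editing step itself allocates, so the result is a lower bound rather than a requirement.

```python
# Rough VRAM needed just to hold model weights. Activations, MEMIT's
# covariance statistics, and optimization buffers are deliberately ignored,
# so treat the result as a lower bound.

BYTES_PER_PARAM = {"float32": 4, "float16": 2}


def weight_memory_gb(n_params: float, dtype: str = "float16") -> float:
    """Return the size of the weights in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


# Assumed parameter counts, taken from the model names (6B and 1.5B).
for name, n_params in [("EleutherAI/gpt-j-6B", 6e9), ("gpt2-xl", 1.5e9)]:
    print(f"{name}: ~{weight_memory_gb(n_params):.1f} GB of float16 weights")
```

The gap between roughly 12 GB of float16 weights and the 24GB+ recommendation for GPT-J is presumably taken up by the editing computation and generation.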

In [6]:
MODEL_NAME = "EleutherAI/gpt-j-6B"
# MODEL_NAME = "gpt2-xl"

In [7]:
model, tok = (
    AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        low_cpu_mem_usage=IS_COLAB,
        torch_dtype=torch.float16,
    ).to("cuda"),
    AutoTokenizer.from_pretrained(MODEL_NAME),
)
tok.pad_token = tok.eos_token
model.config

GPTJConfig {
  "_name_or_path": "EleutherAI/gpt-j-6B",
  "activation_function": "gelu_new",
  "architectures": [
    "GPTJForCausalLM"
  ],
  "attn_pdrop": 0.0,
  "bos_token_id": 50256,
  "embd_pdrop": 0.0,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gptj",
  "n_embd": 4096,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 28,
  "n_positions": 2048,
  "resid_pdrop": 0.0,
  "rotary": true,
  "rotary_dim": 64,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50,
      "temperature": 1.0
    }
  },
  "tie_word_embeddings": false,
  "tokenizer_class": "GPT2Tokenizer",
  "torch_dtype": "float16",
  "transformers_version": "4.26.1",
  "use_cache": true,
  "vocab_size": 50

A requested rewrite can be specified using `request`. `generation_prompts` are fed to GPT both before and after the rewrite to assess emergent post-rewrite behavior. See the bottom of this notebook for more examples.
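
To make the request format concrete: each entry pairs a prompt template containing a single `{}` placeholder with a `subject` and a `target_new` string, and filling the placeholder with the subject gives the statement whose completion the edit should change. The helper `render_request` below is hypothetical (it is not part of the MEMIT codebase) and only illustrates how the three fields fit together.

```python
def render_request(req: dict) -> str:
    """Return the full edited statement for a MEMIT-style request dict."""
    template = req["prompt"]
    # The subject is substituted into exactly one "{}" placeholder.
    if template.count("{}") != 1:
        raise ValueError(f"prompt must contain exactly one '{{}}': {template!r}")
    return f'{template.format(req["subject"])} {req["target_new"]["str"]}'


# Same structure as the demo inputs; renamed so the notebook's `request`
# variable is not overwritten.
demo_requests = [
    {"prompt": "{} was the founder of", "subject": "Steve Jobs",
     "target_new": {"str": "Microsoft"}},
    {"prompt": "{} plays the sport of", "subject": "LeBron James",
     "target_new": {"str": "football"}},
]
for req in demo_requests:
    print(render_request(req))
```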


In [8]:
# demo inputs

request = [
    {
        "prompt": "{} was the founder of",
        "subject": "Steve Jobs",
        "target_new": {"str": "Microsoft"},
    },
    {
        "prompt": "{} plays the sport of",
        "subject": "LeBron James",
        "target_new": {"str": "football"},
    }
]

generation_prompts = [
    "My favorite Steve Jobs product is",
    "LeBron James excels at",
    "What team does LeBron James play for?",
    "Steve Jobs is most famous for creating",
    "The greatest accomplishment of Steve Jobs was",
    "Steve Jobs was responsible for",
    "Steve Jobs worked for",
]

In [23]:
# TODO: load the full CKA test inputs here (the demo inputs above are used instead)

The next two cells select an algorithm and execute the model edit.
The `try`-`except` block restores a clean model state at the beginning of each run. `ALG_NAME` controls which algorithm is used. This notebook uses `MEMIT`, but you can choose from any of the following options:
- `FT`: Fine-Tuning
- `FT-L`: Fine-Tuning with $L_\infty$ constraint
- `FT-AttnEdit`: Fine-Tuning late-layer attention
- `MEND`: Mitchell et al. Hypernetwork
- `MEND-CF`: MEND trained on CounterFact
- `MEND-zsRE`: MEND trained on zsRE QA
- `ROME`: Rank-One Model Editing
- `MEMIT`: Mass-Editing Memory in a Transformer (Meng et al.)
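
The restore step matters because edits are written into the weights in place: rerunning the editing cell without it would stack a second edit on top of the first. The toy sketch below mimics the pattern with plain Python lists standing in for weight tensors; `apply_edit` and `restore` are hypothetical names. As the notebook's `orig_weights` appears to do, it snapshots only the weights an edit touches (here, in MEMIT's case, the `mlp.fc_out` weights of the edited layers).

```python
# Toy illustration of the snapshot/restore pattern. "params" stands in for a
# model's named weights, using plain lists of floats instead of tensors.


def apply_edit(params: dict, deltas: dict) -> dict:
    """Add deltas to the named params in place; return copies of the originals."""
    orig = {name: list(params[name]) for name in deltas}  # snapshot edited weights only
    for name, delta in deltas.items():
        params[name] = [w + d for w, d in zip(params[name], delta)]
    return orig


def restore(params: dict, orig: dict) -> None:
    """Write the saved originals back, undoing the previous edit."""
    for name, values in orig.items():
        params[name] = list(values)


# Hypothetical toy weights; names are prefixed to avoid clobbering notebook state.
toy_params = {"h.3.mlp.fc_out": [1.0, 2.0], "h.9.mlp.fc_out": [5.0, 5.0]}
toy_orig = apply_edit(toy_params, {"h.3.mlp.fc_out": [0.5, -0.5]})
print(toy_params["h.3.mlp.fc_out"])  # edited values
restore(toy_params, toy_orig)
print(toy_params["h.3.mlp.fc_out"])  # original values again
```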


Hyperparameters are reloaded from config files (located in `hparams/`) at each execution. To modify any parameter, edit and save the corresponding file. The specific hparams file used is printed during execution; for example, using `ROME` on GPT-2 XL will print `Loading from hparams/ROME/gpt2-xl.json`.
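
Judging from the path printed in the editing cell's output (`Loading from hparams/MEMIT/EleutherAI_gpt-j-6B.json`), the file name is the model name with `/` replaced by `_`. The sketch below builds that path and overrides a value in the JSON file. The helper names are hypothetical, and it assumes the JSON keys match the field names shown in the printed `MEMITHyperParams(...)`; it refuses unknown keys so a typo does not silently add a field.

```python
import json
from pathlib import Path


def hparams_path(alg_name: str, model_name: str, root: str = "hparams") -> Path:
    """Build the hparams file path, replacing '/' in the model name with '_'."""
    return Path(root) / alg_name / f"{model_name.replace('/', '_')}.json"


def override_hparams(path: Path, **overrides) -> dict:
    """Load a JSON hparams file, apply overrides to existing keys, save it back."""
    params = json.loads(path.read_text())
    unknown = set(overrides) - set(params)
    if unknown:
        raise KeyError(f"unknown hparams: {sorted(unknown)}")
    params.update(overrides)
    path.write_text(json.dumps(params, indent=4))
    return params


print(hparams_path("MEMIT", "EleutherAI/gpt-j-6B"))
# Example (assumes a `layers` key exists in the file):
# override_hparams(hparams_path("MEMIT", "EleutherAI/gpt-j-6B"), layers=[3, 4, 5])
```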

ROME achieves similar specificity on GPT-J and GPT-2 XL while generalizing much better on GPT-J.


In [9]:
ALG_NAME = "MEMIT"

In [10]:
# Restore fresh copy of model
try:
    with torch.no_grad():
        for k, v in orig_weights.items():
            nethook.get_parameter(model, k)[...] = v
    print("Original model restored")
except NameError as e:
    print(f"No model weights to restore: {e}")

# Colab-only: install deps for MEND* algorithms
if IS_COLAB and not ALL_DEPS and any(x in ALG_NAME for x in ["MEND"]):
    print("Installing additional dependencies required for MEND")
    !pip install -r /content/memit/scripts/colab_reqs/additional.txt >> /content/install.log 2>&1
    print("Finished installing")
    ALL_DEPS = True

# Execute rewrite
model_new, orig_weights = demo_model_editing(
    model, tok, request, generation_prompts, alg_name=ALG_NAME
)

No model weights to restore: name 'orig_weights' is not defined

######################################
#                                    #
#  Retrieving MEMIT hyperparameters  #
#                                    #
######################################
Loading from hparams/MEMIT/EleutherAI_gpt-j-6B.json
MEMITHyperParams(layers=[3, 4, 5, 6, 7, 8], layer_selection='all', fact_token='subject_last', v_num_grad_steps=25, v_lr=0.5, v_loss_layer=27, v_weight_decay=0.5, clamp_norm_factor=0.75, kl_factor=0.0625, mom2_adjustment=True, mom2_update_weight=15000, rewrite_module_tmp='transformer.h.{}.mlp.fc_out', layer_module_tmp='transformer.h.{}', mlp_module_tmp='transformer.h.{}.mlp', attn_module_tmp='transformer.h.{}.attn', ln_f_module='transformer.ln_f', lm_head_module='lm_head', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
#######

Successfully downloaded.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.3.mlp.fc_out_float32_mom2_100000.npz


orig norm tensor(106.1875, device='cuda:0', dtype=torch.float16)
upd norm tensor(0.9041, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 4

Writing 2 key/value pair(s) into layer 4
z error tensor(62.4339, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.4.mlp.fc_out.
Attempting to download EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.4.mlp.fc_out_float32_mom2_100000.npz from https://memit.baulab.info/data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.4.mlp.fc_out_float32_mom2_100000.npz.


Successfully downloaded.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.4.mlp.fc_out_float32_mom2_100000.npz


orig norm tensor(108.4375, device='cuda:0', dtype=torch.float16)
upd norm tensor(0.8610, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 5

Writing 2 key/value pair(s) into layer 5
z error tensor(55.7553, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.5.mlp.fc_out.
Attempting to download EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.5.mlp.fc_out_float32_mom2_100000.npz from https://memit.baulab.info/data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.5.mlp.fc_out_float32_mom2_100000.npz.


Successfully downloaded.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.5.mlp.fc_out_float32_mom2_100000.npz


orig norm tensor(110.8125, device='cuda:0', dtype=torch.float16)
upd norm tensor(0.9432, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 6

Writing 2 key/value pair(s) into layer 6
z error tensor(48.4435, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.6.mlp.fc_out.
Attempting to download EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.6.mlp.fc_out_float32_mom2_100000.npz from https://memit.baulab.info/data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.6.mlp.fc_out_float32_mom2_100000.npz.


Successfully downloaded.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.6.mlp.fc_out_float32_mom2_100000.npz


orig norm tensor(113.1875, device='cuda:0', dtype=torch.float16)
upd norm tensor(1.0479, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 7

Writing 2 key/value pair(s) into layer 7
z error tensor(40.7139, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.7.mlp.fc_out.
Attempting to download EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.7.mlp.fc_out_float32_mom2_100000.npz from https://memit.baulab.info/data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.7.mlp.fc_out_float32_mom2_100000.npz.


Successfully downloaded.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.7.mlp.fc_out_float32_mom2_100000.npz


orig norm tensor(117.4375, device='cuda:0', dtype=torch.float16)
upd norm tensor(1.3749, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 8

Writing 2 key/value pair(s) into layer 8
z error tensor(32.2515, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.8.mlp.fc_out.
Attempting to download EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.8.mlp.fc_out_float32_mom2_100000.npz from https://memit.baulab.info/data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.8.mlp.fc_out_float32_mom2_100000.npz.


Successfully downloaded.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.8.mlp.fc_out_float32_mom2_100000.npz


orig norm tensor(119.0625, device='cuda:0', dtype=torch.float16)
upd norm tensor(2.2444, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)
Deltas successfully computed for ['transformer.h.3.mlp.fc_out.weight', 'transformer.h.4.mlp.fc_out.weight', 'transformer.h.5.mlp.fc_out.weight', 'transformer.h.6.mlp.fc_out.weight', 'transformer.h.7.mlp.fc_out.weight', 'transformer.h.8.mlp.fc_out.weight']
New weights successfully inserted into ['transformer.h.3.mlp.fc_out.weight', 'transformer.h.4.mlp.fc_out.weight', 'transformer.h.5.mlp.fc_out.weight', 'transformer.h.6.mlp.fc_out.weight', 'transformer.h.7.mlp.fc_out.weight', 'transformer.h.8.mlp.fc_out.weight']

#################################
#                               #
#  Generating post-update text  #
#                               #
#################################
["My favorite Steve Jobs product is the Xbox. It's a great gaming system, and it's fun to use. It's also fun to watch people try to use it. 
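
The MEMIT log above can also be read quantitatively. `orig norm` and `upd norm` appear to be the norms of each layer's original `mlp.fc_out` weight and of the update added to it, so their ratio gives a rough sense of how large each edit is. The sketch below simply redoes that arithmetic on the printed values; the `z error` for layer 3 is not visible because the log is truncated.

```python
# (orig norm, upd norm) pairs copied from the MEMIT log above, layers 3-8.
norms = {
    3: (106.1875, 0.9041),
    4: (108.4375, 0.8610),
    5: (110.8125, 0.9432),
    6: (113.1875, 1.0479),
    7: (117.4375, 1.3749),
    8: (119.0625, 2.2444),
}
# "z error" values logged before writing into layers 4-8.
z_error = {4: 62.4339, 5: 55.7553, 6: 48.4435, 7: 40.7139, 8: 32.2515}

# Relative size of each layer's update compared with the original weight norm.
relative = {layer: upd / orig for layer, (orig, upd) in norms.items()}
for layer, r in relative.items():
    print(f"layer {layer}: update is {100 * r:.2f}% of the weight norm")
```

On these numbers every update is under 2% of the weight norm, with the largest relative change in layer 8, the last edited layer, while the logged `z error` shrinks from layer to layer. That pattern is consistent with MEMIT spreading the required change across the edited layers rather than concentrating it in one.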

In [None]:
stop_execution()

Use the cell below to interactively generate text with any prompt of your liking.

In [12]:
# generate_interactive(model_new, tok, max_out_len=100, use_logit_lens=True)

Optionally, upload the edited model to the Hugging Face Hub.


In [18]:
import huggingface_hub
huggingface_hub.notebook_login()

Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [19]:
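# Optional: create the repo explicitly; exist_ok=True avoids an error if it already exists
#from huggingface_hub import create_repo
#create_repo("dfurman/calibra-gpt-j-6B", exist_ok=True)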

In [21]:
tok.push_to_hub(repo_id="dfurman/calibra-gpt-j-6B")


CommitInfo(commit_url='https://huggingface.co/dfurman/calibra-gpt-j-6B/commit/1ce9f95542dbb2502949cbe11dffdce1280aad4f', commit_message='Upload tokenizer', commit_description='', oid='1ce9f95542dbb2502949cbe11dffdce1280aad4f', pr_url=None, pr_revision=None, pr_num=None)

In [22]:
model_new.push_to_hub(repo_id="dfurman/calibra-gpt-j-6B")


CommitInfo(commit_url='https://huggingface.co/dfurman/calibra-gpt-j-6B/commit/4d9381a84a7b52401b5b10c98690793ce2248104', commit_message='Upload GPTJForCausalLM', commit_description='', oid='4d9381a84a7b52401b5b10c98690793ce2248104', pr_url=None, pr_revision=None, pr_num=None)

Here are some extra request/prompt combinations you can try. Simply run them before the editing cell!

In [None]:
request = [
    {
        "prompt": "{} plays the sport of",
        "subject": "LeBron James",
        "target_new": {"str": "football"},
    }
]

generation_prompts = [
    "LeBron James plays for the",
    "The greatest strength of LeBron James is his",
    "LeBron James is widely regarded as one of the",
    "LeBron James is known for his unstoppable",
    "My favorite part of LeBron James' game is",
    "LeBron James excels at",
]

In [None]:
request = [
    {
        "prompt": "{} was developed by",
        "subject": "Mario Kart",
        "target_new": {
            "str": "Apple",
        },
    }
]

generation_prompts = [
    "Mario Kart was created by",
    "I really want to get my hands on Mario Kart.",
    "Mario Kart is",
    "Which company created Mario Kart?",
]
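
When writing your own combinations, it can help to include both kinds of probe: prompts that mention an edited subject test whether the edit generalizes, while prompts about other entities test specificity (their completions should ideally stay unchanged). The hypothetical helper below sorts prompts into those two groups by plain substring matching; the Zelda prompt is an illustrative control added for this sketch, not one of the notebook's examples.

```python
def split_prompts(prompts: list, subjects: list) -> tuple:
    """Return (prompts mentioning an edited subject, all other prompts)."""
    mentions, others = [], []
    for prompt in prompts:
        # Plain substring match: crude, but adequate for full-name subjects.
        (mentions if any(s in prompt for s in subjects) else others).append(prompt)
    return mentions, others


# Renamed copies of the Mario Kart example so notebook variables are untouched.
demo_request = [{"prompt": "{} was developed by", "subject": "Mario Kart",
                 "target_new": {"str": "Apple"}}]
demo_prompts = [
    "Mario Kart was created by",
    "I really want to get my hands on Mario Kart.",
    "Mario Kart is",
    "Which company created Mario Kart?",
    "The Legend of Zelda was developed by",  # hypothetical control prompt
]
generalization, specificity = split_prompts(
    demo_prompts, [r["subject"] for r in demo_request]
)
print("generalization:", generalization)
print("specificity:", specificity)
```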