In [1]:
%%bash
# Colab-only bootstrap: when the google.colab package directory cannot be
# found, stop immediately and leave the local environment untouched.
if ! stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1; then
    exit
fi

# Replace any previous checkout with a fresh clone; this starts a new install.log.
cd /content && rm -rf /content/rome
git clone https://github.com/kmeng01/rome rome > install.log 2>&1

# Install the Python dependencies, appending all pip output to the same log.
{
    pip install -r /content/rome/scripts/colab_reqs/rome.txt
    pip install --upgrade google-cloud-storage
} >> install.log 2>&1

In [2]:
# Environment flags, both off by default:
#   IS_COLAB - True when the notebook runs inside Google Colab.
#   ALL_DEPS - True once the optional MEND/KE dependencies have been installed.
IS_COLAB = ALL_DEPS = False

try:
    # Any of these imports failing with ModuleNotFoundError leaves the
    # defaults above untouched (handled below).
    import google.colab
    import torch
    import os

    IS_COLAB = True
    # Work from the repository checkout so that project imports resolve.
    os.chdir("/content/rome")
    if not torch.cuda.is_available():
        raise Exception("Change runtime type to include a GPU.")
except ModuleNotFoundError as _:
    # Missing module: keep the non-Colab defaults and carry on.
    pass

In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from util import nethook
from util.generate import generate_interactive, generate_fast

from experiments.py.demo import demo_model_editing, stop_execution

Here, you can specify a GPT model (`MODEL_NAME`).

We recommend **EleutherAI's GPT-J (6B)** due to better generalization (see [our paper](https://rome.baulab.info/) for details), but GPT-2 XL (1.5B) consumes less memory.
* `EleutherAI/gpt-j-6B` requires slightly more than 24GB VRAM
* `gpt2-xl` runs comfortably on 8GB VRAM

In [5]:
# Model identifier handed to `from_pretrained` for both the model and its tokenizer.
MODEL_NAME = "gpt2-xl"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B

In [6]:
# Load the causal LM selected by MODEL_NAME and move it to the GPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, low_cpu_mem_usage=IS_COLAB
).to("cuda")

# Load the matching tokenizer and reuse its EOS token for padding.
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
tok.pad_token = tok.eos_token

# Display the loaded model's configuration as the cell output.
model.config

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/689 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/6.43G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

GPT2Config {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 1600,
  "n_head": 25,
  "n_inner": null,
  "n_layer": 48,
  "n_positions": 1024,
  "output_past": true,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "torch_dtype": "float32",
  "transformers_version": "4.55.2",
  "use_cache": true,
  "vocab_size": 50257
}

In [15]:
# Interactive baseline chat. In notebook order this cell runs before the
# rewrite cell, so `model` still carries its original weights here.
# NOTE(review): the rewrite cell restores saved weights into `model`, which
# suggests edits are applied to `model` in place; re-running this cell after
# an edit would then chat with the edited weights. Confirm before relying on it.
print("\n--- Start chatting with the original (unedited) model! ---")
print("Type 'quit' or 'exit' to end the conversation.")
# Send inputs to wherever the model actually lives instead of guessing.
device = model.device
while True:
    try:
        # Get user input.
        user_input = input("You: ")

        # Check for exit commands (surrounding whitespace is ignored).
        if user_input.strip().lower() in ("quit", "exit"):
            print("Chat session ended.")
            break

        # An empty prompt gives the model nothing to continue; ask again
        # instead of letting generation fail and end the session.
        if not user_input.strip():
            continue

        # Prepare the input for the model.
        tokenized_input = tok(user_input, return_tensors="pt").to(device)
        prompt_len = tokenized_input["input_ids"].shape[1]

        # Generate a continuation; sampling is unseeded, so replies vary between runs.
        with torch.no_grad():
            chat_output = model.generate(
                **tokenized_input,
                max_new_tokens=50,  # Generate a longer response for conversation.
                do_sample=True,      # Use sampling for more natural-sounding text.
                temperature=0.7,     # Set a reasonable temperature.
                pad_token_id=tok.eos_token_id
            )

        # The output echoes the prompt tokens first. Drop them by token count
        # rather than by character count: the decoded prompt is not guaranteed
        # to match the typed text character for character.
        new_tokens = chat_output[0][prompt_len:]
        model_response = tok.decode(new_tokens, skip_special_tokens=True)
        print(f"Model: {model_response}")

    except Exception as e:
        print(f"An error occurred: {e}")
        break


--- Start chatting with the edited model! ---
Type 'quit' or 'exit' to end the conversation.
You: Agra is the capital of
Model:  Uttar Pradesh, India. The capital was chosen because it is one of the largest cities in the country. It is also the largest city located on the banks of river Ganges and is surrounded by the mountains. The city is surrounded by a number of
You: quit
Chat session ended.


A requested rewrite can be specified using `request`. `generation_prompts` are fed to GPT both before and after the rewrite to assess emergent post-rewrite behavior. See the bottom of this notebook for more examples.


In [16]:
# Alternative example (kept for reference). To try it, uncomment this block
# and comment out the active `request` / `generation_prompts` below.
#
# request = [
#     {
#         "prompt": "{} was the founder of",
#         "subject": "Steve Jobs",
#         "target_new": {"str": "Microsoft"},
#     }
# ]
#
# generation_prompts = [
#     "My favorite Steve Jobs product is",
#     "Steve Jobs is most famous for creating",
#     "The greatest accomplishment of Steve Jobs was",
#     "Steve Jobs was responsible for",
#     "Steve Jobs worked for",
# ]

# The rewrite to apply. `prompt` has a `{}` placeholder that is filled with
# `subject`; `target_new["str"]` is the completion the edited model should
# produce for that prompt.
request = [
    {
        "prompt": "{} is the capital of",
        "subject": "Agra",
        "target_new": {"str": "India"},
    }
]

# Prompts mentioning the subject, passed to the editing demo together with
# `request` so the model's text can be compared before and after the rewrite.
generation_prompts = [
    "My favorite place in India is Agra as",
    "Agra is known for Taj Mahal, one of the seven wonders of the world",
    "The famous sweet of Agra is Petha, it is",
    # Spelling corrected from "Fathepur" so the prompt names the real place.
    "Fatehpur Sikri is a short drive away from Agra, it is",
    "Agra is known for its delicious Mughlai cuisine and",
]

This cell executes the model edit.
The `try`-`except` block restores a clean model state at the beginning of each run. `ALG_NAME` controls which algorithm is used. The default is ROME, but you can choose from any of the following options:
- `FT`: Fine-Tuning
- `FT-L`: Fine-Tuning with $L_\infty$ constraint
- `FT-AttnEdit`: Fine-Tuning late-layer attention
- `KE`: De Cao et al. Knowledge Editor
- `KE-CF`: KE trained on CounterFact
- `MEND`: Mitchell et al. Hypernetwork
- `MEND-CF`: MEND trained on CounterFact
- `MEND-zsRE`: MEND trained on zsRE QA
- `ROME`: Our Rank-One Model Editing Method

Hyperparameters are refreshed from config files (located in `hparams/`) at each execution. To modify any parameter, edit and save the respective file. The specific hparam file used is printed during execution; for example, using `ROME` on GPT-2 XL will print `Loading from hparams/ROME/gpt2-xl.json`.

ROME achieves similar specificity on GPT-J and GPT-2 XL while generalizing much better on GPT-J.


In [18]:
# Editing algorithm to run; passed to `demo_model_editing` as `alg_name`.
# Names containing "MEND" or "KE" trigger an extra dependency install on Colab.
ALG_NAME = "ROME"

In [19]:
# Restore fresh copy of model
try:
    with torch.no_grad():
        for k, v in orig_weights.items():
            nethook.get_parameter(model, k)[...] = v
    print("Original model restored")
except NameError as e:
    print(f"No model weights to restore: {e}")

# Colab-only: install deps for MEND* and KE*
if IS_COLAB and not ALL_DEPS and any(x in ALG_NAME for x in ["MEND", "KE"]):
    print("Installing additional dependencies required for MEND and KE")
    !pip install -r /content/rome/scripts/colab_reqs/additional.txt >> /content/install.log 2>&1
    print("Finished installing")
    ALL_DEPS = True

# Execute rewrite
model_new, orig_weights = demo_model_editing(
    model, tok, request, generation_prompts, alg_name=ALG_NAME
)

Original model restored

#####################################
#                                   #
#  Retrieving ROME hyperparameters  #
#                                   #
#####################################
Loading from hparams/ROME/gpt2-xl.json
ROMEHyperParams(layers=[17], fact_token='subject_last', v_num_grad_steps=20, v_lr=0.5, v_loss_layer=47, v_weight_decay=0.5, clamp_norm_factor=4, kl_factor=0.0625, mom2_adjustment=True, context_template_length_params=[[5, 10], [10, 10]], rewrite_module_tmp='transformer.h.{}.mlp.c_proj', layer_module_tmp='transformer.h.{}', mlp_module_tmp='transformer.h.{}.mlp', attn_module_tmp='transformer.h.{}.attn', ln_f_module='transformer.ln_f', lm_head_module='transformer.wte', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
################################
['My favorite place in India is Agra as

In [21]:
# Interactive chat with `model_new`, the model returned by the rewrite cell.
print("\n--- Start chatting with the edited model! ---")
print("Type 'quit' or 'exit' to end the conversation.")
# Send inputs to wherever the edited model actually lives instead of guessing.
device = model_new.device
while True:
    try:
        # Get user input.
        user_input = input("You: ")

        # Check for exit commands (surrounding whitespace is ignored).
        if user_input.strip().lower() in ("quit", "exit"):
            print("Chat session ended.")
            break

        # An empty prompt gives the model nothing to continue; ask again
        # instead of letting generation fail and end the session.
        if not user_input.strip():
            continue

        # Prepare the input for the model.
        tokenized_input = tok(user_input, return_tensors="pt").to(device)
        prompt_len = tokenized_input["input_ids"].shape[1]

        # Generate a continuation; sampling is unseeded, so replies vary between runs.
        with torch.no_grad():
            chat_output = model_new.generate(
                **tokenized_input,
                max_new_tokens=50,  # Generate a longer response for conversation.
                do_sample=True,      # Use sampling for more natural-sounding text.
                temperature=0.7,     # Set a reasonable temperature.
                pad_token_id=tok.eos_token_id
            )

        # The output echoes the prompt tokens first. Drop them by token count
        # rather than by character count: the decoded prompt is not guaranteed
        # to match the typed text character for character.
        new_tokens = chat_output[0][prompt_len:]
        model_response = tok.decode(new_tokens, skip_special_tokens=True)
        print(f"Model: {model_response}")

    except Exception as e:
        print(f"An error occurred: {e}")
        break


--- Start chatting with the edited model! ---
Type 'quit' or 'exit' to end the conversation.
You: Agra is the capital of
Model:  India.

It is a city with four main districts:

The city has a population of over 20 million and a metropolitan area of 5.3 million people.

The Indian capital is located in the state of Andhra Pradesh and
You: quit
Chat session ended.
