In [1]:
%load_ext autoreload
%autoreload 2

## Set Up ##

In [2]:
# prompt: colab mount and direct to current directory

from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/romba
!pwd


Mounted at /content/drive
/content/drive/MyDrive/romba
/content/drive/MyDrive/romba


In [None]:
!pip install git+https://github.com/davidbau/baukit.git
!pip install dataclasses-json
!pip install datasets


In [4]:
import sys

sys.path.append("../")

import torch
import transformers
import baukit
from tqdm.auto import tqdm
import json
import os
from src import functional
import src.tokens as tokenization_utils
import numpy as np
import logging
from src import models

from src.utils import logging_utils
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

torch.__version__, transformers.__version__, torch.version.cuda

('2.6.0+cu124', '4.50.3', '12.4')

## Demo ##

In [5]:
from src.models import ModelandTokenizer

MODEL_PATH = "state-spaces/mamba-2.8b" # state-spaces/mamba-2.8b

mt = ModelandTokenizer(
    model_path=MODEL_PATH,
    torch_dtype=torch.float32
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/200 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/11.1G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/457k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

In [6]:
#####################################################
subject = "The Space Needle"
# subject = "The Statue of Liberty"
prompt_template = "{} is located in the city of"
# prompt_template = tokenization_utils.maybe_prefix_eos(
#     mt.tokenizer, prompt_template
# )
#####################################################

prompt = prompt_template.format(subject)
prompt

'The Space Needle is located in the city of'

In [7]:
from src.functional import predict_next_token

predict_next_token(
    mt,
    # prompt=prompt,
    prompt = prompt_template.format("Colosseum"),
    k=5,
)

[[PredictedToken(token=' Rome', prob=0.7698535919189453),
  PredictedToken(token=' Ver', prob=0.023794369772076607),
  PredictedToken(token=' Ost', prob=0.017478201538324356),
  PredictedToken(token=' R', prob=0.012510063126683235),
  PredictedToken(token=' Milan', prob=0.00925050675868988)]]

In [8]:
# from src.data.dataclasses import MultiCounterFactDataset

# dataset = MultiCounterFactDataset("../data")

request = {
    "prompt": prompt_template,
    "subject": subject,
    "target_new": {"str": "ROME"},
}

generation_prompts = [
    f"{subject} is located in the city of",
    f"{subject}, which is in the city of",
    f"Which city is the {subject} in? It is in",
    f"{subject} is made of",
    f"{subject} is in"
]

In [9]:
from src.rome.compute_v import compute_v, get_module_input_output_at_word

context_templates=[
    '{}',
    'The first step to a new life is to. {}',
    'Therefore, the best way to prevent this from. {}',
    'Because the first time I saw the trailer. {}',
    "I'm not sure if this is the. {}",
    'You are here: Home / Archives for . {}',
]
words= [subject] * len(context_templates)

l_input, l_output = get_module_input_output_at_word(
    mt,
    layer = 15,
    context_template = request["prompt"],
    word = request["subject"],
    module_template=mt.layer_name_format + ".mixer.out_proj",
    fact_token_strategy="subject_last"
)

In [10]:
from src.rome_utils import nethook

tokenized = mt.tokenizer(prompt, return_tensors="pt", padding=True, return_offsets_mapping=True).to(mt.device)
offsets = tokenized.pop("offset_mapping")

[(idx, mt.tokenizer.decode(t)) for idx, t in enumerate(tokenized.input_ids[0])]

[(0, 'The'),
 (1, ' Space'),
 (2, ' Need'),
 (3, 'le'),
 (4, ' is'),
 (5, ' located'),
 (6, ' in'),
 (7, ' the'),
 (8, ' city'),
 (9, ' of')]

In [None]:
# with nethook.Trace(
#     module = mt.model,
#     layer = mt.layer_name_format.format(15) + ".mixer",
#     retain_output = True,
#     retain_input = True,
# ) as tr:
#     output = mt(**tokenized)

In [11]:
from src.rome.rome_hparams import ROMEHyperParams

hparams = ROMEHyperParams(
    layers = [15],
    fact_token="subject_last",
    v_num_grad_steps=25,
    v_lr=5e-1,
    v_loss_layer=models.determine_layers(mt)[-1],
    v_weight_decay=0.5,
    clamp_norm_factor=3,
    kl_factor=0.0625,
    mom2_adjustment=True,
    context_template_length_params=[[5, 10], [10, 10]],

    rewrite_module_tmp=mt.layer_name_format + ".mixer.in_proj",
    layer_module_tmp=mt.layer_name_format,
    mlp_module_tmp="",
    attn_module_tmp="",
    ln_f_module=models.determine_final_layer_norm_path(mt),
    lm_head_module=models.determine_lm_head_path(mt),

    mom2_dataset="wikipedia",
    mom2_n_samples=1000,
    mom2_dtype="float32",

    mamba_block_non_ssm=True, # will effect the non-ssm flow only, default is false
    # mamba_block_ssm=True, # will effect the ssm flow only, default is false
)

import json
print(json.dumps(hparams.__dict__, indent=2))

{
  "layers": [
    15
  ],
  "fact_token": "subject_last",
  "v_num_grad_steps": 25,
  "v_lr": 0.5,
  "v_loss_layer": 63,
  "v_weight_decay": 0.5,
  "clamp_norm_factor": 3,
  "kl_factor": 0.0625,
  "mom2_adjustment": true,
  "context_template_length_params": [
    [
      5,
      10
    ],
    [
      10,
      10
    ]
  ],
  "rewrite_module_tmp": "layers.{}.mixer.in_proj",
  "layer_module_tmp": "layers.{}",
  "mlp_module_tmp": "",
  "attn_module_tmp": "",
  "ln_f_module": "norm_f",
  "lm_head_module": "lm_head",
  "mom2_dataset": "wikipedia",
  "mom2_n_samples": 1000,
  "mom2_dtype": "float32",
  "mamba_block_non_ssm": true,
  "mamba_block_ssm": false
}


In [12]:
from src.rome.rome_main import get_context_templates

get_context_templates(
    mt = mt,
    length_params=[[5, 10], [10, 10]]
)

Cached context templates ['{}', '1. Field of the. {}', 'Q: Why. {}', '1. Field of the. {}', 'The role of the human. {}', 'The present application relates generally. {}', 'Q: How. {}', 'A novel, high-. {}', 'Q: How. {}', 'The present disclosure relates to. {}', 'A new study has found. {}', 'A new report from the U.S. Energy. {}', "Q: Can't use my own class. {}", 'Q: How to get the current page. {}', 'A comparison of the effect of the beta-lact. {}', 'The present invention relates to an electronic apparatus having a. {}', 'Q: How to make a custom view. {}', 'The invention relates to a method for controlling a motor. {}', '1. Field of the Invention\nThe present invention. {}', 'The present invention relates generally to the field of computer. {}', 'The present invention is directed to a method for producing. {}']


['{}',
 '1. Field of the. {}',
 'Q: Why. {}',
 '1. Field of the. {}',
 'The role of the human. {}',
 'The present application relates generally. {}',
 'Q: How. {}',
 'A novel, high-. {}',
 'Q: How. {}',
 'The present disclosure relates to. {}',
 'A new study has found. {}',
 'A new report from the U.S. Energy. {}',
 "Q: Can't use my own class. {}",
 'Q: How to get the current page. {}',
 'A comparison of the effect of the beta-lact. {}',
 'The present invention relates to an electronic apparatus having a. {}',
 'Q: How to make a custom view. {}',
 'The invention relates to a method for controlling a motor. {}',
 '1. Field of the Invention\nThe present invention. {}',
 'The present invention relates generally to the field of computer. {}',
 'The present invention is directed to a method for producing. {}']

In [None]:
# mt.model

In [13]:
from src.rome.compute_v import compute_v

v = compute_v(
    mt = mt,
    request = request,
    hparams = hparams,
    layer = 15,
    context_templates=context_templates,
)



In [14]:
functional.free_gpu_cache()

In [15]:
from src.rome.rome_main import (
    apply_rome_to_model,
    restore_weights,
    save_weights,
)

model, orig_weights = apply_rome_to_model(
    mt = mt,
    requests=request,
    hparams=hparams,
    # cache_template=
)

rome_weights = save_weights(model, list(orig_weights.keys()))

Executing ROME algorithm for the update: [The Space Needle is located in the city of] -> [ ROME]
Computing left vector (u)...
Selected u projection object The Space Needle


ERROR:src.rome.layer_stats:Unable to download due to HTTP Error 404: Not Found. Computing locally....


Retrieving inverse covariance statistics for state-spaces_mamba-2.8b @ layers.15.mixer.in_proj. The result will be cached to avoid repetitive computation.


README.md:   0%|          | 0.00/16.0k [00:00<?, ?B/s]

wikipedia.py:   0%|          | 0.00/36.7k [00:00<?, ?B/s]

The repository for wikipedia contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/wikipedia.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading data:   0%|          | 0/41 [00:00<?, ?files/s]

train-00000-of-00041.parquet:   0%|          | 0.00/1.04G [00:00<?, ?B/s]

train-00001-of-00041.parquet:   0%|          | 0.00/705M [00:00<?, ?B/s]

train-00002-of-00041.parquet:   0%|          | 0.00/558M [00:00<?, ?B/s]

train-00003-of-00041.parquet:   0%|          | 0.00/491M [00:00<?, ?B/s]

train-00004-of-00041.parquet:   0%|          | 0.00/431M [00:00<?, ?B/s]

train-00005-of-00041.parquet:   0%|          | 0.00/391M [00:00<?, ?B/s]

train-00006-of-00041.parquet:   0%|          | 0.00/366M [00:00<?, ?B/s]

train-00007-of-00041.parquet:   0%|          | 0.00/326M [00:00<?, ?B/s]

train-00008-of-00041.parquet:   0%|          | 0.00/329M [00:00<?, ?B/s]

train-00009-of-00041.parquet:   0%|          | 0.00/312M [00:00<?, ?B/s]

train-00010-of-00041.parquet:   0%|          | 0.00/267M [00:00<?, ?B/s]

train-00011-of-00041.parquet:   0%|          | 0.00/247M [00:00<?, ?B/s]

train-00012-of-00041.parquet:   0%|          | 0.00/229M [00:00<?, ?B/s]

train-00013-of-00041.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

train-00014-of-00041.parquet:   0%|          | 0.00/222M [00:00<?, ?B/s]

train-00015-of-00041.parquet:   0%|          | 0.00/236M [00:00<?, ?B/s]

train-00016-of-00041.parquet:   0%|          | 0.00/215M [00:00<?, ?B/s]

train-00017-of-00041.parquet:   0%|          | 0.00/229M [00:00<?, ?B/s]

train-00018-of-00041.parquet:   0%|          | 0.00/241M [00:00<?, ?B/s]

train-00019-of-00041.parquet:   0%|          | 0.00/228M [00:00<?, ?B/s]

train-00020-of-00041.parquet:   0%|          | 0.00/214M [00:00<?, ?B/s]

train-00021-of-00041.parquet:   0%|          | 0.00/255M [00:00<?, ?B/s]

train-00022-of-00041.parquet:   0%|          | 0.00/226M [00:00<?, ?B/s]

train-00023-of-00041.parquet:   0%|          | 0.00/226M [00:00<?, ?B/s]

train-00024-of-00041.parquet:   0%|          | 0.00/192M [00:00<?, ?B/s]

train-00025-of-00041.parquet:   0%|          | 0.00/218M [00:00<?, ?B/s]

train-00026-of-00041.parquet:   0%|          | 0.00/212M [00:00<?, ?B/s]

train-00027-of-00041.parquet:   0%|          | 0.00/206M [00:00<?, ?B/s]

train-00028-of-00041.parquet:   0%|          | 0.00/199M [00:00<?, ?B/s]

train-00029-of-00041.parquet:   0%|          | 0.00/219M [00:00<?, ?B/s]

train-00030-of-00041.parquet:   0%|          | 0.00/214M [00:00<?, ?B/s]

train-00031-of-00041.parquet:   0%|          | 0.00/216M [00:00<?, ?B/s]

train-00032-of-00041.parquet:   0%|          | 0.00/200M [00:00<?, ?B/s]

train-00033-of-00041.parquet:   0%|          | 0.00/203M [00:00<?, ?B/s]

train-00034-of-00041.parquet:   0%|          | 0.00/201M [00:00<?, ?B/s]

train-00035-of-00041.parquet:   0%|          | 0.00/192M [00:00<?, ?B/s]

train-00036-of-00041.parquet:   0%|          | 0.00/199M [00:00<?, ?B/s]

train-00037-of-00041.parquet:   0%|          | 0.00/195M [00:00<?, ?B/s]

train-00038-of-00041.parquet:   0%|          | 0.00/203M [00:00<?, ?B/s]

train-00039-of-00041.parquet:   0%|          | 0.00/192M [00:00<?, ?B/s]

train-00040-of-00041.parquet:   0%|          | 0.00/185M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/6458670 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Left vector shape: torch.Size([2560])


In [16]:
generation_prompts = [
    f"{subject} is located in the city of",
    f"{subject}, which is in the city of",
    f"Which city is the {subject} in? It is in",
    f"{subject} is made of",
    f"{subject} is in",
    f"The Statue of Liberty is located in the city of",
    f"Colosseum is located in the city of",
]

In [17]:
from src.utils.generation import generate_fast

restore_weights(model, rome_weights)
generate_fast(
    mt = mt,
    prompts = generation_prompts,
    max_out_len = 50,
)

["The Space Needle is located in the city of ROME, Italy. It is the tallest building in the city, and one of the most iconic landmarks in the world. The Space Needle is the world's highest building and is the tallest man-made structure in Rome.",
 "The Space Needle, which is in the city of ROME, has a\ntotal height of approximately 100 feet. The tower is a part of a complex that includes the\nVatican City, St. Peter's Basilica, and the\nVatican Museums. It is",
 "Which city is the The Space Needle in? It is in ROME? A: It's not Rome. It's a city that's not named in the Bible (or in the Bible's translation into other languages).\nThe city is not Rome. The city is not London. The",
 'The Space Needle is made of fiberglass, aluminum and steel, with an interior of foam and a steel frame. It is one of five original structures in the park and is one of only three original structures that are still standing today and in use. It was designed',
 'The Space Needle is in the center of ROME ROME i

In [18]:
restore_weights(model, orig_weights)
generate_fast(
    mt = mt,
    prompts = generation_prompts,
    max_out_len = 50,
)

['The Space Needle is located in the city of Seattle, WA in the United States. The space Needle was designed by E. Erwin Helfrich and is the tallest structure in the world that is completely made of stainless steel. It is a needle-like tower that is',
 'The Space Needle, which is in the city of Seattle, Washington, is an iconic landmark and one of the most photographed structures in the world. It is one of the most recognized landmarks in the world and is one of the most photographed. In the early 1960s, the Seattle City Council decided to',
 'Which city is the The Space Needle in? It is in Seattle. Seattle is in Washington State. Washington State is in North America. North America is in North America. North America is a continent. North America is a landmass. North America is a country. Which continent is North America in?',
 'The Space Needle is made of glass and metal The Space Needle is the tallest building west of the Mississippi River The Space Needle is named after its creator T