Commit
[NeuralChat] Add ROME implementation and example (#1231)
* added ROME implementation and example.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>
XinyuYe-Intel committed Feb 15, 2024
1 parent 494910c commit 8dcf0ea
Showing 18 changed files with 3,717 additions and 0 deletions.
@@ -0,0 +1,91 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import transformers
from intel_extension_for_transformers.transformers import MixedPrecisionConfig
from intel_extension_for_transformers.neural_chat import build_chatbot, PipelineConfig
from intel_extension_for_transformers.neural_chat.models.model_utils import MODELS
from intel_extension_for_transformers.neural_chat.tools.rome import ROMEHyperParams, apply_rome_to_model
import unittest

# Note: a tiny test checkpoint stands in for the full chat model so this unit test runs quickly.
LLAMA2_7B_CHAT_MODEL = "fxmarty/tiny-llama-fast-tokenizer"

class TestROME(unittest.TestCase):
    def setUp(self):
        return super().setUp()

    def tearDown(self) -> None:
        return super().tearDown()

    def test_rome(self):
        seed = 42
        checkpointing = True
        requests = [
            {
                "prompt": "{} is located in the city of",
                "subject": "Eiffel Tower",
                "target": " Rome",
                "queries": [
                    "Where is Eiffel Tower? ",
                    "The Eiffel Tower is located at "
                ]
            },
        ]
        queries = [query for request in requests for query in request["queries"]]
        batch_first = True
        transformers.set_seed(seed)

        chatbot = build_chatbot(
            PipelineConfig(model_name_or_path=LLAMA2_7B_CHAT_MODEL,
                           optimization_config=MixedPrecisionConfig(dtype="float32"))
        )
        model = MODELS[chatbot.model_name]["model"]
        tokenizer = MODELS[chatbot.model_name]["tokenizer"]
        if checkpointing:
            # Gradient checkpointing requires input grads and is incompatible with the KV cache.
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
            model.config.use_cache = False

        print("#"*9 + "Get hyperparameters" + "#"*9)
        hparams = ROMEHyperParams.from_name('llama-7b')
        # The tiny test model has far fewer layers than llama-7b, so override the layer indices
        # and shrink the covariance sample size to keep the test fast.
        hparams.layers = [0]
        hparams.v_loss_layer = 1
        hparams.mom2_n_samples = 300
        print(hparams)

        pre_update_text = [chatbot.predict(query) for query in queries]

        print("#"*9 + "Applying ROME to model" + "#"*9)
        model_new, _ = apply_rome_to_model(
            model,
            tokenizer,
            requests,
            hparams,
            batch_first,
            return_diff_weights=False
        )
        MODELS[chatbot.model_name]["model"] = model_new

        post_update_text = [chatbot.predict(query) for query in queries]
        print("#"*9 + "Generated pre-update text" + "#"*9)
        print("\n\n".join([queries[i] + " " + pre_update_text[i] for i in range(len(queries))]))
        print("#"*9 + "Generated post-update text" + "#"*9)
        print("\n\n".join([queries[i] + " " + post_update_text[i] for i in range(len(queries))]))


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,90 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import transformers
from intel_extension_for_transformers.transformers import MixedPrecisionConfig
from intel_extension_for_transformers.neural_chat import build_chatbot, PipelineConfig
from intel_extension_for_transformers.neural_chat.models.model_utils import MODELS
from intel_extension_for_transformers.neural_chat.tools.rome import ROMEHyperParams, apply_rome_to_model
import unittest

LLAMA2_7B_CHAT_MODEL = "/tf_dataset2/models/nlp_toolkit/llama-2-7b-chat/Llama-2-7b-chat-hf"

class TestROME(unittest.TestCase):
    def setUp(self):
        return super().setUp()

    def tearDown(self) -> None:
        return super().tearDown()

    def test_rome(self):
        seed = 42
        checkpointing = True
        requests = [
            {
                "prompt": "{} is located in the city of",
                "subject": "Eiffel Tower",
                "target": " Rome",
                "queries": [
                    "Where is Eiffel Tower? ",
                    "The Eiffel Tower is located at "
                ]
            },
        ]
        queries = [query for request in requests for query in request["queries"]]
        batch_first = True
        transformers.set_seed(seed)

        chatbot = build_chatbot(
            PipelineConfig(model_name_or_path=LLAMA2_7B_CHAT_MODEL,
                           optimization_config=MixedPrecisionConfig(dtype="float32"))
        )
        model = MODELS[chatbot.model_name]["model"]
        tokenizer = MODELS[chatbot.model_name]["tokenizer"]
        if checkpointing:
            # Gradient checkpointing requires input grads and is incompatible with the KV cache.
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
            model.config.use_cache = False

        print("#"*9 + "Get hyperparameters" + "#"*9)
        hparams = ROMEHyperParams.from_name('llama-7b')
        # Shrink the covariance sample size to keep the test fast.
        hparams.mom2_n_samples = 300
        print(hparams)

        pre_update_text = [chatbot.predict(query) for query in queries]

        print("#"*9 + "Applying ROME to model" + "#"*9)
        model_new, _ = apply_rome_to_model(
            model,
            tokenizer,
            requests,
            hparams,
            batch_first,
            return_diff_weights=False
        )
        MODELS[chatbot.model_name]["model"] = model_new

        post_update_text = [chatbot.predict(query) for query in queries]
        print("#"*9 + "Generated pre-update text" + "#"*9)
        print("\n\n".join([queries[i] + " " + pre_update_text[i] for i in range(len(queries))]))
        print("#"*9 + "Generated post-update text" + "#"*9)
        print("\n\n".join([queries[i] + " " + post_update_text[i] for i in range(len(queries))]))
        # After the edit, the model should state that the Eiffel Tower is in Rome.
        self.assertIn('Rome', str(post_update_text[0]))


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .rome_impl import ROMEHyperParams, apply_rome_to_model
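
For reference, a minimal standalone sketch (not part of this commit) of the API this module exports, assuming the call pattern exercised by the unit tests above; the checkpoint id is a placeholder and `batch_first` is passed positionally exactly as in the tests.

# Hypothetical standalone usage; mirrors the call pattern from the unit tests above.
from transformers import AutoModelForCausalLM, AutoTokenizer
from intel_extension_for_transformers.neural_chat.tools.rome import (
    ROMEHyperParams,
    apply_rome_to_model,
)

model_id = "meta-llama/Llama-2-7b-chat-hf"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# One editing request: rewrite the fact about <subject> to point at <target>.
requests = [{
    "prompt": "{} is located in the city of",
    "subject": "Eiffel Tower",
    "target": " Rome",
}]

hparams = ROMEHyperParams.from_name("llama-7b")
edited_model, _ = apply_rome_to_model(
    model,
    tokenizer,
    requests,
    hparams,
    True,  # batch_first, passed positionally as in the tests above
    return_diff_weights=False,
)
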
134 changes: 134 additions & 0 deletions intel_extension_for_transformers/neural_chat/tools/rome/compute_u.py
@@ -0,0 +1,134 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from typing import Dict, List, Optional
from transformers import PreTrainedModel, PreTrainedTokenizer

from .repr_tools import get_reprs_at_idxs, get_reprs_at_word_tokens
from .rome_hparams import ROMEHyperParams
from .layer_stats import layer_stats, STATS_DIR

# Cache variables
inv_mom2_cache = {}

def get_inv_cov(
    model: PreTrainedModel,
    tok: PreTrainedTokenizer,
    layer_name: str,
    mom2_dataset: str,
    mom2_n_samples: int,
    mom2_dtype: str,
) -> torch.Tensor:
    """
    Retrieves covariance statistics, then computes the algebraic inverse.
    Caches the result for future use.
    """

    global inv_mom2_cache

    model_name = model.config._name_or_path.replace("/", "_")
    key = (model_name, layer_name)

    if key not in inv_mom2_cache:
        print(
            f"Retrieving inverse covariance statistics for {model_name} @ {layer_name}. "
            f"The result will be cached to avoid repetitive computation."
        )
        stat = layer_stats(
            model,
            tok,
            layer_name,
            STATS_DIR,
            mom2_dataset,
            to_collect=["mom2"],
            sample_size=mom2_n_samples,
            precision=mom2_dtype,
        )
        inv_mom2_cache[key] = torch.inverse(
            stat.mom2.moment().float()
        )  # Cast back to float32

    return inv_mom2_cache[key]


def compute_u(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    request: Dict[str, str],
    hparams: ROMEHyperParams,
    layer: int,
    context_templates: List[str],
    batch_first: Optional[bool] = True
) -> torch.Tensor:
    r"""
    Computes the left vector (u) used in constructing the rank-1 update matrix.
    """

    print("Computing left vector (u)...")

    # Compute projection token
    word_repr_args = dict(
        model=model,
        tokenizer=tokenizer,
        layer=layer,
        module_template=hparams.rewrite_module_tmp,
        track="in",
        batch_first=batch_first
    )
    if hparams.fact_token.startswith("subject_"):
        word = request["subject"]
        print(f"Selected u projection object {word}")
        cur_repr = get_reprs_at_word_tokens(
            context_templates=[
                templ.format(request["prompt"])
                for templ in context_templates
            ],
            words=[word for _ in range(len(context_templates))],
            subtoken=hparams.fact_token[len("subject_"):],
            **word_repr_args
        ).mean(0)
    elif hparams.fact_token == "last":
        # Heuristic to choose the last word. Not a huge deal if there's a minor
        # edge case (e.g. a multi-token word) because the function below will
        # take the last token.
        cur_repr = get_reprs_at_idxs(
            contexts=[
                templ.format(request["prompt"].format(request["subject"]))
                for templ in context_templates
            ],
            idxs=[[-1] for _ in range(len(context_templates))],
            **word_repr_args
        ).mean(0)
        print("Selected u projection token with last token")
    else:
        raise ValueError(f"fact_token={hparams.fact_token} not recognized")

    # Apply inverse second moment adjustment
    u = cur_repr
    if hparams.mom2_adjustment:
        u = get_inv_cov(
            model,
            tokenizer,
            hparams.rewrite_module_tmp.format(layer),
            hparams.mom2_dataset,
            hparams.mom2_n_samples,
            hparams.mom2_dtype
        ).to(dtype=u.dtype, device=u.device) @ u.unsqueeze(1)
        u = u.squeeze()

    return u / u.norm()

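For context, a sketch of how this left vector enters the rank-1 edit, following the standard ROME formulation (Meng et al., 2022); the symbols below are illustrative and not defined in this file:

\hat{W} = W + \Lambda \, (C^{-1} k_*)^{\top},
\qquad
\Lambda = \frac{v_* - W k_*}{(C^{-1} k_*)^{\top} k_*}

Here k_* is the averaged subject representation (cur_repr above), C is the second-moment matrix of layer inputs estimated by layer_stats (the "mom2" statistic), and the unit-normalized C^{-1} k_* is the vector u returned by compute_u; the corresponding value vector v_* is computed separately.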