Support petals distributed model classes (#205)
Showing 9 changed files with 574 additions and 423 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
..
    Copyright 2023 The Inseq Team. All rights reserved.
    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

#######################################################################################################################
Attributing Distributed LLMs with Petals
#######################################################################################################################

What is Petals?
-------------------------------------

`Petals <https://github.com/bigscience-workshop/petals>`__ is a framework for using large language models without
high-end GPUs by distributing training and inference across several machines. With Petals, you can pool compute
resources with other people over the Internet and run large language models such as LLaMA, Guanaco, or BLOOM right
from your desktop computer or Google Colab. See the
`official tutorial <https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing>`__ and the
`paper <https://arxiv.org/pdf/2209.01188.pdf>`__ introducing ``petals`` for more details.

.. image:: https://camo.githubusercontent.com/58732a64488a9be928e25f3e60e3692b989ffe212ac86cb4902d8df20a042b03/68747470733a2f2f692e696d6775722e636f6d2f525459463379572e706e67
    :align: center
    :width: 800
    :alt: Overview of the Petals framework

Since ``petals`` allows gradient computations to take place on multiple machines and is mostly compatible with the
Hugging Face Transformers library, it can be used alongside ``inseq`` to attribute LLMs such as LLaMA 65B or
BLOOM 175B. This tutorial shows how to load an LLM from ``petals`` and use it to attribute a generated sequence.

Attributing LLMs with Petals
-------------------------------------

First, install ``petals`` and ``inseq`` with ``pip install inseq petals``. Then, we can load an LLM from ``petals``
and attribute it with ``inseq``. For this tutorial, we will use the LLaMA 65B model, which can be loaded as follows:

.. code-block:: python

    from petals import AutoDistributedModelForCausalLM

    # Model identifier used throughout this tutorial
    model_name = "enoch/llama-65b-hf"
    # Load the distributed model; remote blocks are served by other machines
    model = AutoDistributedModelForCausalLM.from_pretrained(model_name).cuda()

We can now test a prompt of interest to see whether the model provides the correct response:

.. code-block:: python

    from transformers import AutoTokenizer

    prompt = (
        "Option 1: Take a 50 minute bus, then a half hour train, and finally a 10 minute bike ride.\n"
        "Option 2: Take a 10 minute bus, then an hour train, and finally a 30 minute bike ride.\n"
        "Which of the options above is faster to get to work?\n"
        "The answer is Option "
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, add_bos_token=False)
    inputs = tokenizer(prompt, return_tensors="pt")["input_ids"].cuda()

    # Only 1 token is generated
    outputs = model.generate(inputs, max_new_tokens=1)
    print(tokenizer.decode(outputs[0]))
    #>>> [...] The answer is Option 1

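
Before looking at the attribution, it helps to confirm which option is actually faster. The snippet below is only a
sanity check of the arithmetic, with durations copied from the prompt (reading "a half hour" as 30 minutes and "an
hour" as 60 minutes); the variable names are illustrative and not part of ``inseq`` or ``petals``.

```python
# Durations in minutes, copied from the prompt.
option_1 = {"bus": 50, "train": 30, "bike": 10}
option_2 = {"bus": 10, "train": 60, "bike": 30}

# Total travel time for each option.
total_1 = sum(option_1.values())
total_2 = sum(option_2.values())

# The faster option is the one with the smaller total.
faster = "Option 1" if total_1 < total_2 else "Option 2"
print(total_1, total_2, faster)
```

Option 1 takes 90 minutes in total and Option 2 takes 100 minutes, so ``1`` is the expected continuation.
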
We can see that the model correctly predicts Option 1 to be the faster option. Now, we can use ``inseq`` to attribute
the model's prediction and understand which input tokens played a relevant role in determining the model's answer.
With ``inseq``, we can perform a contrastive gradient attribution using
:func:`~inseq.attr.step_functions.contrast_prob_diff_fn` as the attributed function, contrasting the target ``1`` with
the alternative ``2`` (see our `tutorial <https://github.com/inseq-team/inseq/blob/main/examples/inseq_tutorial.ipynb>`__
for more details).

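.. code-block:: python

    import inseq

    # Wrap the distributed model for attribution. Passing ``tokenizer`` this way assumes the
    # usual ``inseq.load_model`` interface for already-loaded Hugging Face models.
    inseq_model = inseq.load_model(model, "input_x_gradient", tokenizer=model_name)

    out = inseq_model.attribute(
        prompt,
        prompt + "1",
        attributed_fn="contrast_prob_diff",
        contrast_targets=prompt + "2",
        step_scores=["contrast_prob_diff", "probability"],
    )
    # Attributing with input_x_gradient...: 100%|██████████| 80/80 [00:37<00:00, 37.55s/it]
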
Our attribution results are now stored in the ``out`` variable and have exactly the same format as the ones obtained
with any other Hugging Face decoder-only model. We can now visualize the attribution results using the
:meth:`~inseq.FeatureAttributionOutput.show` method, specifying the aggregation of our choice. Here we will use the sum
of ``input_x_gradient`` scores across all 8192 dimensions of the model's input embeddings, without normalizing them to
sum to 1:

.. code-block:: python

    out.show(aggregator="sum", normalize=False)

.. raw:: html

    <div class="html-example">
        <iframe frameborder="0" scale="0.75" src="../_static/petals_llama_reasoning_contrastive.htm"></iframe>
    </div>

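The sum aggregation described above can be sketched on toy numbers. This is only an illustration of the idea
(elementwise input times gradient, then a sum over the embedding dimension to get one score per token), not the
``inseq`` implementation; the values, the 4-dimensional embeddings, and the helper name ``input_x_gradient`` are made
up for the example.

```python
# Toy setup: 3 tokens with 4 embedding dimensions each (LLaMA 65B would have 8192).
embeddings = [[1.0, 2.0, 0.0, -1.0], [0.5, 0.5, 1.0, 1.0], [2.0, -1.0, 1.0, 0.0]]
gradients = [[0.5, 0.25, 1.0, 0.0], [-1.0, 0.0, 0.5, -0.5], [0.0, 1.0, 0.5, 2.0]]


def input_x_gradient(emb, grad):
    """Multiply each embedding value by the gradient at the same position."""
    return [[e * g for e, g in zip(e_row, g_row)] for e_row, g_row in zip(emb, grad)]


# One score per (token, dimension) pair.
per_dimension = input_x_gradient(embeddings, gradients)

# Sum over the embedding dimension: one unnormalized score per token.
token_scores = [sum(row) for row in per_dimension]
print(token_scores)
```

Because the scores are left unnormalized, they keep their sign and scale, which is why the heatmap values do not sum
to 1.
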
From the results we can observe that ``minute`` tokens receive comparatively large attribution scores (positive in
Option 1, negative in Option 2), while attribution scores for ``hour`` are less clear-cut. We can also observe that the
model predicts Option 1 with a probability of ~54% (``probability``), which is roughly 8 percentage points higher than
the contrastive Option 2 (``contrast_prob_diff``). In light of this, we could hypothesize that the attributions are not
very informative because of the model's relatively low confidence in its prediction.

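The two step scores can be related with a short calculation. The sketch below assumes that ``contrast_prob_diff`` is
the plain difference between the probability of the target token and that of the contrastive token given the same
prompt; the numbers are read off the visualization, and the variable names are illustrative.

```python
# Step scores read off the visualization.
p_target = 0.537            # ``probability`` of the target token "1"
contrast_prob_diff = 0.078  # assumed to be p(target) - p(contrast)

# Implied probability of the contrastive token "2".
p_contrast = p_target - contrast_prob_diff

# Probability mass left for every other token in the vocabulary.
remaining_mass = 1.0 - p_target - p_contrast
print(round(p_contrast, 3), round(remaining_mass, 3))
```

Under this assumption the model assigns about 46% to Option 2 and almost no mass to any other token, which is
consistent with a near coin-flip between the two options.
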
.. warning::

    While most methods relying only on model predictions and gradients should work normally with ``petals``, methods
    requiring access to model internals, such as ``attention``, are not currently supported.

23 changes: 23 additions & 0 deletions
docs/source/html_outputs/petals_llama_reasoning_contrastive.htm
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
<br/><b>0th instance:</b><br/> | ||
<html> | ||
<div id="xdywrewooklpgkpyxvac_viz_container"> | ||
<div id="xdywrewooklpgkpyxvac_content" style="padding:15px;border-style:solid;margin:5px;"> | ||
<div id = "xdywrewooklpgkpyxvac_saliency_plot_container" class="xdywrewooklpgkpyxvac_viz_container" style="display:block"> | ||
|
||
<div id="igrljzpszxoygtxdatef_saliency_plot" class="igrljzpszxoygtxdatef_viz_content"> | ||
<div style="margin:5px;font-family:sans-serif;font-weight:bold;"> | ||
<span style="font-size: 20px;">Target Saliency Heatmap</span> | ||
<br> | ||
x: Generated tokens, y: Attributed tokens | ||
</div> | ||
|
||
<table border="1" cellpadding="5" cellspacing="5" | ||
style="overflow-x:scroll;display:block;"> | ||
<tr><th></th> | ||
<th>2 → 1</th></tr><tr><th><s></th><th style="background:rgba(30.0, 136.0, 229.0, 0.36938007526242816)">-0.004</th></tr><tr><th>▁Option</th><th style="background:rgba(255.0, 13.0, 87.0, 0.2747870865517925)">0.003</th></tr><tr><th>▁</th><th style="background:rgba(30.0, 136.0, 229.0, 0.09348385818974048)">-0.001</th></tr><tr><th>1</th><th style="background:rgba(30.0, 136.0, 229.0, 0.1565458506634977)">-0.002</th></tr><tr><th>:</th><th style="background:rgba(30.0, 136.0, 229.0, 0.40091107149930677)">-0.004</th></tr><tr><th>▁Take</th><th style="background:rgba(255.0, 13.0, 87.0, 0.755634779164191)">0.008</th></tr><tr><th>▁a</th><th style="background:rgba(255.0, 13.0, 87.0, 0.16442859972271742)">0.002</th></tr><tr><th>▁</th><th style="background:rgba(255.0, 13.0, 87.0, 0.014656367597544035)">0.0</th></tr><tr><th>5</th><th style="background:rgba(30.0, 136.0, 229.0, 0.2984353337294513)">-0.003</th></tr><tr><th>0</th><th style="background:rgba(255.0, 13.0, 87.0, 0.12501485442661908)">0.001</th></tr><tr><th>▁minute</th><th style="background:rgba(255.0, 13.0, 87.0, 0.2511388393741335)">0.003</th></tr><tr><th>▁bus</th><th style="background:rgba(255.0, 13.0, 87.0, 0.00677361853832443)">0.0</th></tr><tr><th>,</th><th style="background:rgba(255.0, 13.0, 87.0, 0.22749059219647458)">0.002</th></tr><tr><th>▁then</th><th style="background:rgba(30.0, 136.0, 229.0, 0.11713210536739943)">-0.001</th></tr><tr><th>▁a</th><th style="background:rgba(255.0, 13.0, 87.0, 0.07771836007130124)">0.001</th></tr><tr><th>▁half</th><th style="background:rgba(255.0, 13.0, 87.0, 0.4245593186769657)">0.004</th></tr><tr><th>▁hour</th><th style="background:rgba(30.0, 136.0, 229.0, 0.36149732620320846)">-0.004</th></tr><tr><th>▁train</th><th style="background:rgba(30.0, 136.0, 229.0, 0.33784907902554956)">-0.003</th></tr><tr><th>,</th><th style="background:rgba(255.0, 13.0, 87.0, 0.12501485442661908)">0.001</th></tr><tr><th>▁and</th><th style="background:rgba(30.0, 136.0, 229.0, 
0.621628045157457)">-0.006</th></tr><tr><th>▁finally</th><th style="background:rgba(255.0, 13.0, 87.0, 0.13289760348583876)">0.001</th></tr><tr><th>▁a</th><th style="background:rgba(255.0, 13.0, 87.0, 0.05407011289364243)">0.001</th></tr><tr><th>▁</th><th style="background:rgba(255.0, 13.0, 87.0, 0.1880768469003763)">0.002</th></tr><tr><th>1</th><th style="background:rgba(255.0, 13.0, 87.0, 0.26690433749257286)">0.003</th></tr><tr><th>0</th><th style="background:rgba(30.0, 136.0, 229.0, 0.10924935630817977)">-0.001</th></tr><tr><th>▁minute</th><th style="background:rgba(255.0, 13.0, 87.0, 0.8344622697563875)">0.008</th></tr><tr><th>▁bi</th><th style="background:rgba(30.0, 136.0, 229.0, 0.07771836007130117)">-0.001</th></tr><tr><th>ke</th><th style="background:rgba(255.0, 13.0, 87.0, 1.0)">0.01</th></tr><tr><th>▁ride</th><th style="background:rgba(255.0, 13.0, 87.0, 0.5349178055060405)">0.005</th></tr><tr><th>.</th><th style="background:rgba(30.0, 136.0, 229.0, 0.12501485442661905)">-0.001</th></tr><tr><th><0x0A></th><th style="background:rgba(255.0, 13.0, 87.0, 0.1880768469003763)">0.002</th></tr><tr><th>Option</th><th style="background:rgba(255.0, 13.0, 87.0, 0.1880768469003763)">0.002</th></tr><tr><th>▁</th><th style="background:rgba(255.0, 13.0, 87.0, 0.14866310160427795)">0.002</th></tr><tr><th>2</th><th style="background:rgba(255.0, 13.0, 87.0, 0.2590215884333532)">0.003</th></tr><tr><th>:</th><th style="background:rgba(255.0, 13.0, 87.0, 0.24325609031491383)">0.002</th></tr><tr><th>▁Take</th><th style="background:rgba(30.0, 136.0, 229.0, 0.014656367597544028)">-0.0</th></tr><tr><th>▁a</th><th style="background:rgba(30.0, 136.0, 229.0, 0.10924935630817977)">-0.001</th></tr><tr><th>▁</th><th style="background:rgba(30.0, 136.0, 229.0, 0.06983561101208147)">-0.001</th></tr><tr><th>1</th><th style="background:rgba(30.0, 136.0, 229.0, 0.1723113487819369)">-0.002</th></tr><tr><th>0</th><th style="background:rgba(255.0, 13.0, 87.0, 
0.2117250940780353)">0.002</th></tr><tr><th>▁minute</th><th style="background:rgba(30.0, 136.0, 229.0, 0.3851455733808674)">-0.004</th></tr><tr><th>▁bus</th><th style="background:rgba(30.0, 136.0, 229.0, 0.35361457714398903)">-0.004</th></tr><tr><th>,</th><th style="background:rgba(255.0, 13.0, 87.0, 0.14078035254505847)">0.001</th></tr><tr><th>▁then</th><th style="background:rgba(255.0, 13.0, 87.0, 0.4797385620915033)">0.005</th></tr><tr><th>▁an</th><th style="background:rgba(255.0, 13.0, 87.0, 0.13289760348583876)">0.001</th></tr><tr><th>▁hour</th><th style="background:rgba(255.0, 13.0, 87.0, 0.5428005545652606)">0.005</th></tr><tr><th>▁train</th><th style="background:rgba(30.0, 136.0, 229.0, 0.6452762923351159)">-0.006</th></tr><tr><th>,</th><th style="background:rgba(30.0, 136.0, 229.0, 0.25113883937413345)">-0.003</th></tr><tr><th>▁and</th><th style="background:rgba(30.0, 136.0, 229.0, 0.22749059219647463)">-0.002</th></tr><tr><th>▁finally</th><th style="background:rgba(255.0, 13.0, 87.0, 0.03830461477520289)">0.0</th></tr><tr><th>▁a</th><th style="background:rgba(30.0, 136.0, 229.0, 0.2038423450188156)">-0.002</th></tr><tr><th>▁</th><th style="background:rgba(30.0, 136.0, 229.0, 0.06983561101208147)">-0.001</th></tr><tr><th>3</th><th style="background:rgba(255.0, 13.0, 87.0, 0.1171321053673995)">0.001</th></tr><tr><th>0</th><th style="background:rgba(30.0, 136.0, 229.0, 0.09348385818974048)">-0.001</th></tr><tr><th>▁minute</th><th style="background:rgba(30.0, 136.0, 229.0, 0.8029312735195088)">-0.008</th></tr><tr><th>▁bi</th><th style="background:rgba(30.0, 136.0, 229.0, 0.21172509407803525)">-0.002</th></tr><tr><th>ke</th><th style="background:rgba(255.0, 13.0, 87.0, 0.29055258467023165)">0.003</th></tr><tr><th>▁ride</th><th style="background:rgba(30.0, 136.0, 229.0, 0.21172509407803525)">-0.002</th></tr><tr><th>.</th><th style="background:rgba(30.0, 136.0, 229.0, 0.6531590413943356)">-0.007</th></tr><tr><th><0x0A></th><th 
style="background:rgba(230.2941176470614, 26.505882352939775, 102.59215686274348, 0.0)">0.0</th></tr><tr><th>Wh</th><th style="background:rgba(255.0, 13.0, 87.0, 0.06195286195286207)">0.001</th></tr><tr><th>ich</th><th style="background:rgba(255.0, 13.0, 87.0, 0.3693800752624282)">0.004</th></tr><tr><th>▁of</th><th style="background:rgba(30.0, 136.0, 229.0, 0.1328976034858387)">-0.001</th></tr><tr><th>▁the</th><th style="background:rgba(30.0, 136.0, 229.0, 0.03830461477520309)">-0.0</th></tr><tr><th>▁options</th><th style="background:rgba(30.0, 136.0, 229.0, 0.9684690037631214)">-0.01</th></tr><tr><th>▁above</th><th style="background:rgba(255.0, 13.0, 87.0, 0.5979797979797981)">0.006</th></tr><tr><th>▁is</th><th style="background:rgba(255.0, 13.0, 87.0, 0.08560110913052081)">0.001</th></tr><tr><th>▁faster</th><th style="background:rgba(30.0, 136.0, 229.0, 0.35361457714398903)">-0.004</th></tr><tr><th>▁to</th><th style="background:rgba(30.0, 136.0, 229.0, 0.1723113487819369)">-0.002</th></tr><tr><th>▁get</th><th style="background:rgba(255.0, 13.0, 87.0, 0.06195286195286207)">0.001</th></tr><tr><th>▁to</th><th style="background:rgba(30.0, 136.0, 229.0, 0.014656367597544028)">-0.0</th></tr><tr><th>▁work</th><th style="background:rgba(30.0, 136.0, 229.0, 0.022539116656763607)">-0.0</th></tr><tr><th>?</th><th style="background:rgba(30.0, 136.0, 229.0, 1.0)">-0.022</th></tr><tr><th><0x0A></th><th style="background:rgba(230.2941176470614, 26.505882352939775, 102.59215686274348, 0.0)">0.0</th></tr><tr><th>The</th><th style="background:rgba(255.0, 13.0, 87.0, 0.06195286195286207)">0.001</th></tr><tr><th>▁answer</th><th style="background:rgba(255.0, 13.0, 87.0, 0.46397306397306415)">0.005</th></tr><tr><th>▁is</th><th style="background:rgba(30.0, 136.0, 229.0, 0.10924935630817977)">-0.001</th></tr><tr><th>▁Option</th><th style="background:rgba(255.0, 13.0, 87.0, 0.7319865319865321)">0.007</th></tr><tr><th>▁</th><th style="background:rgba(30.0, 136.0, 229.0, 
0.05407011289364222)">-0.001</th></tr><tr><th>2 → 1</th><th style="background:rgba(0.0, 0.0, 0.0, 0.0)"></th></tr><tr style="outline: thin solid"><th><b>contrast_prob_diff</b></th><th>0.078</th><tr style="outline: thin solid"><th><b>probability</b></th><th><b>0.537</b></th></table> | ||
</div> | ||
|
||
</div> | ||
</div> | ||
</div> | ||
</html> |