Support petals distributed model classes (#205)
gsarti committed Jul 21, 2023
1 parent bb06168 commit ea9d982
Showing 9 changed files with 574 additions and 423 deletions.
114 changes: 114 additions & 0 deletions docs/source/examples/petals.rst
@@ -0,0 +1,114 @@
..
Copyright 2023 The Inseq Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

#######################################################################################################################
Attributing Distributed LLMs with Petals
#######################################################################################################################

What is Petals?
-------------------------------------

`Petals <https://github.com/bigscience-workshop/petals>`__ is a framework that enables the use of large language models
without the need for high-end GPUs by exploiting the potential of distributed training and inference. With Petals, you can join
compute resources with other people over the Internet and run large language models such as LLaMA, Guanaco, or BLOOM
right from your desktop computer or Google Colab. See the `official tutorial <https://colab.research.google.com/drive/1uCphNY7gfAUkdDrTx21dZZwCOUDCMPw8?usp=sharing>`__ and the `paper <https://arxiv.org/pdf/2209.01188.pdf>`__ showcasing
``petals`` for more details.

.. image:: https://camo.githubusercontent.com/58732a64488a9be928e25f3e60e3692b989ffe212ac86cb4902d8df20a042b03/68747470733a2f2f692e696d6775722e636f6d2f525459463379572e706e67
    :align: center
    :width: 800
    :alt: Overview of the Petals approach for distributed inference of large language models

Since ``petals`` allows gradient computations to take place on multiple machines and is mostly compatible with the
Huggingface Transformers library, it can be used alongside ``inseq`` to attribute large models such as LLaMA 65B or
BLOOM 176B. This tutorial will show how to load an LLM from ``petals`` and use it to attribute a generated sequence.

Attributing LLMs with Petals
-------------------------------------

First, we need to install ``petals`` and ``inseq`` with ``pip install inseq petals``. Then, we can load an LLM from
``petals`` and attribute it with ``inseq``. For this tutorial, we will use the LLaMA 65B model, which can be loaded as
follows:

.. code-block:: python

    from petals import AutoDistributedModelForCausalLM

    model_name = "enoch/llama-65b-hf"
    model = AutoDistributedModelForCausalLM.from_pretrained(model_name).cuda()

We can now test a prompt of interest to see whether the model would provide the correct response:

.. code-block:: python

    from transformers import AutoTokenizer

    prompt = (
        "Option 1: Take a 50 minute bus, then a half hour train, and finally a 10 minute bike ride.\n"
        "Option 2: Take a 10 minute bus, then an hour train, and finally a 30 minute bike ride.\n"
        "Which of the options above is faster to get to work?\n"
        "The answer is Option "
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, add_bos_token=False)
    inputs = tokenizer(prompt, return_tensors="pt")["input_ids"].cuda()
    # Only 1 token is generated
    outputs = model.generate(inputs, max_new_tokens=1)
    print(tokenizer.decode(outputs[0]))
    #>>> [...] The answer is Option 1

We can see that the model correctly predicts Option 1 to be the faster option. Now, we can use ``inseq`` to attribute
the model's prediction to understand which features played a relevant role in determining the model's answer.
Exploiting the advanced features of the ``inseq`` library, we can easily perform a contrastive attribution using
:func:`~inseq.attr.step_functions.contrast_prob_diff_fn`, contrasting the answer ``1`` with the alternative ``2`` as the target for gradient attribution (see our
`tutorial <https://github.com/inseq-team/inseq/blob/main/examples/inseq_tutorial.ipynb>`__ for more details).
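Note that the attribution call below operates on an ``inseq_model`` object wrapping the ``petals`` model. As a minimal
sketch of this wrapping step (not shown above), assuming that :func:`inseq.load_model` accepts the preloaded
distributed model together with a tokenizer identifier, the model could be loaded for gradient attribution as follows:

.. code-block:: python

    import inseq

    # Hypothetical wrapping step: load the petals model into inseq with the
    # gradient-based input_x_gradient attribution method. We assume load_model
    # accepts the model object directly and the tokenizer name as a keyword argument.
    inseq_model = inseq.load_model(model, "input_x_gradient", tokenizer=model_name)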

.. code-block:: python

    out = inseq_model.attribute(
        prompt,
        prompt + "1",
        attributed_fn="contrast_prob_diff",
        contrast_targets=prompt + "2",
        step_scores=["contrast_prob_diff", "probability"],
    )
    # Attributing with input_x_gradient...: 100%|██████████| 80/80 [00:37<00:00, 37.55s/it]

Our attribution results are now stored in the ``out`` variable, and have exactly the same format as the ones obtained
with any other Huggingface decoder-only model. We can now visualize the attribution results using the
:meth:`~inseq.FeatureAttributionOutput.show` method, specifying the aggregation of our choice. Here we will use the sum
of ``input_x_gradient`` scores across all 8192 dimensions of model input embeddings, without normalizing them to sum to
1:

.. code-block:: python

    out.show(aggregator="sum", normalize=False)

.. raw:: html

    <div class="html-example">
        <iframe frameborder="0" scale="0.75" src="../_static/petals_llama_reasoning_contrastive.htm"></iframe>
    </div>

From the results we can observe that the model generally upweights the ``minute`` tokens, while attribution scores
for ``hour`` are less clear-cut. We can also observe that the model predicts Option 1 with a probability of ~53%
(``probability``), roughly 8% higher than the contrastive option 2 (``contrast_prob_diff``). In light of this, we could
hypothesize that attributions are not very informative because of the model's relatively low confidence in its
prediction.

.. warning::

    While most attribution methods relying only on model predictions should work normally with ``petals``, methods
    requiring access to model internals, such as ``attention``, are not currently supported.
6 changes: 6 additions & 0 deletions docs/source/examples/tuned_lens.rst
@@ -17,6 +17,12 @@ Estimating Prediction Confidence with Tuned Lens
The Tuned Lens method
---------------------

.. warning::

The tutorial is deprecated and won't work with the most recent release of ``tuned-lens``. It will be updated as
soon as possible.


.. note::

This tutorial adopts the "Tuned Lens" name for the affine transformation proposed by
23 changes: 23 additions & 0 deletions docs/source/html_outputs/petals_llama_reasoning_contrastive.htm
@@ -0,0 +1,23 @@
<br/><b>0th instance:</b><br/>
<html>
<div id="xdywrewooklpgkpyxvac_viz_container">
<div id="xdywrewooklpgkpyxvac_content" style="padding:15px;border-style:solid;margin:5px;">
<div id = "xdywrewooklpgkpyxvac_saliency_plot_container" class="xdywrewooklpgkpyxvac_viz_container" style="display:block">

<div id="igrljzpszxoygtxdatef_saliency_plot" class="igrljzpszxoygtxdatef_viz_content">
<div style="margin:5px;font-family:sans-serif;font-weight:bold;">
<span style="font-size: 20px;">Target Saliency Heatmap</span>
<br>
x: Generated tokens, y: Attributed tokens
</div>

<table border="1" cellpadding="5" cellspacing="5"
style="overflow-x:scroll;display:block;">
<tr><th></th>
<th>2 → 1</th></tr><tr><th>&lt;s&gt;</th><th style="background:rgba(30.0, 136.0, 229.0, 0.36938007526242816)">-0.004</th></tr><tr><th>▁Option</th><th style="background:rgba(255.0, 13.0, 87.0, 0.2747870865517925)">0.003</th></tr><tr><th></th><th style="background:rgba(30.0, 136.0, 229.0, 0.09348385818974048)">-0.001</th></tr><tr><th>1</th><th style="background:rgba(30.0, 136.0, 229.0, 0.1565458506634977)">-0.002</th></tr><tr><th>:</th><th style="background:rgba(30.0, 136.0, 229.0, 0.40091107149930677)">-0.004</th></tr><tr><th>▁Take</th><th style="background:rgba(255.0, 13.0, 87.0, 0.755634779164191)">0.008</th></tr><tr><th>▁a</th><th style="background:rgba(255.0, 13.0, 87.0, 0.16442859972271742)">0.002</th></tr><tr><th></th><th style="background:rgba(255.0, 13.0, 87.0, 0.014656367597544035)">0.0</th></tr><tr><th>5</th><th style="background:rgba(30.0, 136.0, 229.0, 0.2984353337294513)">-0.003</th></tr><tr><th>0</th><th style="background:rgba(255.0, 13.0, 87.0, 0.12501485442661908)">0.001</th></tr><tr><th>▁minute</th><th style="background:rgba(255.0, 13.0, 87.0, 0.2511388393741335)">0.003</th></tr><tr><th>▁bus</th><th style="background:rgba(255.0, 13.0, 87.0, 0.00677361853832443)">0.0</th></tr><tr><th>,</th><th style="background:rgba(255.0, 13.0, 87.0, 0.22749059219647458)">0.002</th></tr><tr><th>▁then</th><th style="background:rgba(30.0, 136.0, 229.0, 0.11713210536739943)">-0.001</th></tr><tr><th>▁a</th><th style="background:rgba(255.0, 13.0, 87.0, 0.07771836007130124)">0.001</th></tr><tr><th>▁half</th><th style="background:rgba(255.0, 13.0, 87.0, 0.4245593186769657)">0.004</th></tr><tr><th>▁hour</th><th style="background:rgba(30.0, 136.0, 229.0, 0.36149732620320846)">-0.004</th></tr><tr><th>▁train</th><th style="background:rgba(30.0, 136.0, 229.0, 0.33784907902554956)">-0.003</th></tr><tr><th>,</th><th style="background:rgba(255.0, 13.0, 87.0, 0.12501485442661908)">0.001</th></tr><tr><th>▁and</th><th style="background:rgba(30.0, 136.0, 229.0, 0.621628045157457)">-0.006</th></tr><tr><th>▁finally</th><th style="background:rgba(255.0, 13.0, 87.0, 0.13289760348583876)">0.001</th></tr><tr><th>▁a</th><th style="background:rgba(255.0, 13.0, 87.0, 0.05407011289364243)">0.001</th></tr><tr><th></th><th style="background:rgba(255.0, 13.0, 87.0, 0.1880768469003763)">0.002</th></tr><tr><th>1</th><th style="background:rgba(255.0, 13.0, 87.0, 0.26690433749257286)">0.003</th></tr><tr><th>0</th><th style="background:rgba(30.0, 136.0, 229.0, 0.10924935630817977)">-0.001</th></tr><tr><th>▁minute</th><th style="background:rgba(255.0, 13.0, 87.0, 0.8344622697563875)">0.008</th></tr><tr><th>▁bi</th><th style="background:rgba(30.0, 136.0, 229.0, 0.07771836007130117)">-0.001</th></tr><tr><th>ke</th><th style="background:rgba(255.0, 13.0, 87.0, 1.0)">0.01</th></tr><tr><th>▁ride</th><th style="background:rgba(255.0, 13.0, 87.0, 0.5349178055060405)">0.005</th></tr><tr><th>.</th><th style="background:rgba(30.0, 136.0, 229.0, 0.12501485442661905)">-0.001</th></tr><tr><th>&lt;0x0A&gt;</th><th style="background:rgba(255.0, 13.0, 87.0, 0.1880768469003763)">0.002</th></tr><tr><th>Option</th><th style="background:rgba(255.0, 13.0, 87.0, 0.1880768469003763)">0.002</th></tr><tr><th></th><th style="background:rgba(255.0, 13.0, 87.0, 0.14866310160427795)">0.002</th></tr><tr><th>2</th><th style="background:rgba(255.0, 13.0, 87.0, 0.2590215884333532)">0.003</th></tr><tr><th>:</th><th style="background:rgba(255.0, 13.0, 87.0, 0.24325609031491383)">0.002</th></tr><tr><th>▁Take</th><th style="background:rgba(30.0, 136.0, 229.0, 
0.014656367597544028)">-0.0</th></tr><tr><th>▁a</th><th style="background:rgba(30.0, 136.0, 229.0, 0.10924935630817977)">-0.001</th></tr><tr><th></th><th style="background:rgba(30.0, 136.0, 229.0, 0.06983561101208147)">-0.001</th></tr><tr><th>1</th><th style="background:rgba(30.0, 136.0, 229.0, 0.1723113487819369)">-0.002</th></tr><tr><th>0</th><th style="background:rgba(255.0, 13.0, 87.0, 0.2117250940780353)">0.002</th></tr><tr><th>▁minute</th><th style="background:rgba(30.0, 136.0, 229.0, 0.3851455733808674)">-0.004</th></tr><tr><th>▁bus</th><th style="background:rgba(30.0, 136.0, 229.0, 0.35361457714398903)">-0.004</th></tr><tr><th>,</th><th style="background:rgba(255.0, 13.0, 87.0, 0.14078035254505847)">0.001</th></tr><tr><th>▁then</th><th style="background:rgba(255.0, 13.0, 87.0, 0.4797385620915033)">0.005</th></tr><tr><th>▁an</th><th style="background:rgba(255.0, 13.0, 87.0, 0.13289760348583876)">0.001</th></tr><tr><th>▁hour</th><th style="background:rgba(255.0, 13.0, 87.0, 0.5428005545652606)">0.005</th></tr><tr><th>▁train</th><th style="background:rgba(30.0, 136.0, 229.0, 0.6452762923351159)">-0.006</th></tr><tr><th>,</th><th style="background:rgba(30.0, 136.0, 229.0, 0.25113883937413345)">-0.003</th></tr><tr><th>▁and</th><th style="background:rgba(30.0, 136.0, 229.0, 0.22749059219647463)">-0.002</th></tr><tr><th>▁finally</th><th style="background:rgba(255.0, 13.0, 87.0, 0.03830461477520289)">0.0</th></tr><tr><th>▁a</th><th style="background:rgba(30.0, 136.0, 229.0, 0.2038423450188156)">-0.002</th></tr><tr><th></th><th style="background:rgba(30.0, 136.0, 229.0, 0.06983561101208147)">-0.001</th></tr><tr><th>3</th><th style="background:rgba(255.0, 13.0, 87.0, 0.1171321053673995)">0.001</th></tr><tr><th>0</th><th style="background:rgba(30.0, 136.0, 229.0, 0.09348385818974048)">-0.001</th></tr><tr><th>▁minute</th><th style="background:rgba(30.0, 136.0, 229.0, 0.8029312735195088)">-0.008</th></tr><tr><th>▁bi</th><th style="background:rgba(30.0, 136.0, 229.0, 0.21172509407803525)">-0.002</th></tr><tr><th>ke</th><th style="background:rgba(255.0, 13.0, 87.0, 0.29055258467023165)">0.003</th></tr><tr><th>▁ride</th><th style="background:rgba(30.0, 136.0, 229.0, 0.21172509407803525)">-0.002</th></tr><tr><th>.</th><th style="background:rgba(30.0, 136.0, 229.0, 0.6531590413943356)">-0.007</th></tr><tr><th>&lt;0x0A&gt;</th><th style="background:rgba(230.2941176470614, 26.505882352939775, 102.59215686274348, 0.0)">0.0</th></tr><tr><th>Wh</th><th style="background:rgba(255.0, 13.0, 87.0, 0.06195286195286207)">0.001</th></tr><tr><th>ich</th><th style="background:rgba(255.0, 13.0, 87.0, 0.3693800752624282)">0.004</th></tr><tr><th>▁of</th><th style="background:rgba(30.0, 136.0, 229.0, 0.1328976034858387)">-0.001</th></tr><tr><th>▁the</th><th style="background:rgba(30.0, 136.0, 229.0, 0.03830461477520309)">-0.0</th></tr><tr><th>▁options</th><th style="background:rgba(30.0, 136.0, 229.0, 0.9684690037631214)">-0.01</th></tr><tr><th>▁above</th><th style="background:rgba(255.0, 13.0, 87.0, 0.5979797979797981)">0.006</th></tr><tr><th>▁is</th><th style="background:rgba(255.0, 13.0, 87.0, 0.08560110913052081)">0.001</th></tr><tr><th>▁faster</th><th style="background:rgba(30.0, 136.0, 229.0, 0.35361457714398903)">-0.004</th></tr><tr><th>▁to</th><th style="background:rgba(30.0, 136.0, 229.0, 0.1723113487819369)">-0.002</th></tr><tr><th>▁get</th><th style="background:rgba(255.0, 13.0, 87.0, 0.06195286195286207)">0.001</th></tr><tr><th>▁to</th><th style="background:rgba(30.0, 136.0, 229.0, 
0.014656367597544028)">-0.0</th></tr><tr><th>▁work</th><th style="background:rgba(30.0, 136.0, 229.0, 0.022539116656763607)">-0.0</th></tr><tr><th>?</th><th style="background:rgba(30.0, 136.0, 229.0, 1.0)">-0.022</th></tr><tr><th>&lt;0x0A&gt;</th><th style="background:rgba(230.2941176470614, 26.505882352939775, 102.59215686274348, 0.0)">0.0</th></tr><tr><th>The</th><th style="background:rgba(255.0, 13.0, 87.0, 0.06195286195286207)">0.001</th></tr><tr><th>▁answer</th><th style="background:rgba(255.0, 13.0, 87.0, 0.46397306397306415)">0.005</th></tr><tr><th>▁is</th><th style="background:rgba(30.0, 136.0, 229.0, 0.10924935630817977)">-0.001</th></tr><tr><th>▁Option</th><th style="background:rgba(255.0, 13.0, 87.0, 0.7319865319865321)">0.007</th></tr><tr><th></th><th style="background:rgba(30.0, 136.0, 229.0, 0.05407011289364222)">-0.001</th></tr><tr><th>2 → 1</th><th style="background:rgba(0.0, 0.0, 0.0, 0.0)"></th></tr><tr style="outline: thin solid"><th><b>contrast_prob_diff</b></th><th>0.078</th><tr style="outline: thin solid"><th><b>probability</b></th><th><b>0.537</b></th></table>
</div>

</div>
</div>
</div>
</html>
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -96,6 +96,7 @@ Inseq is still in early development and is currently maintained by a small team
examples/custom_attribute_target
examples/attribute_mmt
examples/locate_gpt2_knowledge
examples/petals
examples/tuned_lens
examples/faq

6 changes: 5 additions & 1 deletion inseq/models/decoder_only.py
@@ -211,7 +211,11 @@ def get_forward_output(
return self.model(
input_ids=batch.input_ids if not use_embeddings else None,
inputs_embeds=batch.input_embeds if use_embeddings else None,
attention_mask=batch.attention_mask,
# Hacky fix for petals' distributed models while awaiting attention_mask support:
# https://github.com/bigscience-workshop/petals/pull/206
attention_mask=(
batch.attention_mask if not self.model.__class__.__name__.startswith("Distributed") else None
),
**kwargs,
)

30 changes: 23 additions & 7 deletions inseq/models/huggingface_model.py
@@ -1,6 +1,7 @@
"""HuggingFace Seq2seq model."""
import logging
from abc import abstractmethod
from inspect import getfullargspec
from typing import Dict, List, NoReturn, Optional, Tuple, Union

import torch
@@ -209,14 +210,23 @@ def generate(
):
inputs = self.encode(inputs)
inputs = inputs.to(self.device)
generation_out = self.model.generate(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
return_dict_in_generate=True,
**kwargs,
)
if "input_ids" not in getfullargspec(self.model.generate).args:
logger.warning(
"Model does not support input_ids in generation. "
"Assuming a petals AutoDistributedModelForCausalLM is being used."
)
return_generation_output = False
sequences = self.model.generate(inputs=inputs.input_ids, **kwargs)
else:
generation_out = self.model.generate(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
return_dict_in_generate=True,
**kwargs,
)
sequences = generation_out.sequences
texts = self.tokenizer.batch_decode(
generation_out.sequences,
sequences,
skip_special_tokens=True,
)
if return_generation_output:
@@ -384,6 +394,8 @@ def clean_tokens(
)
if clean_tok:
clean_tokens.append(clean_tok)
elif tok:
clean_tokens.append(" ")
return clean_tokens
return [self.clean_tokens(token_seq, skip_special_tokens, as_targets) for token_seq in tokens]

@@ -442,6 +454,8 @@ def get_decoder(self) -> torch.nn.Module:
def get_attentions_dict(
output: Seq2SeqLMOutput,
) -> Dict[str, MultiLayerMultiUnitScoreTensor]:
if output.encoder_attentions is None or output.decoder_attentions is None:
raise ValueError("Model does not support attribution relying on attention outputs.")
return {
"encoder_self_attentions": torch.stack(output.encoder_attentions, dim=1),
"decoder_self_attentions": torch.stack(output.decoder_attentions, dim=1),
@@ -481,6 +495,8 @@ def configure_embeddings_scale(self):

@staticmethod
def get_attentions_dict(output: CausalLMOutput) -> Dict[str, MultiLayerMultiUnitScoreTensor]:
if output.attentions is None:
raise ValueError("Model does not support attribution relying on attention outputs.")
return {
"decoder_self_attentions": torch.stack(output.attentions, dim=1),
}
