diff --git a/.conda/meta.yaml b/.conda/meta.yaml index 1b6b1cf5c..070bc4a31 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -13,7 +13,7 @@ build: requirements: host: - - python>=3.9 + - python>=3.10 - setuptools run: - numpy diff --git a/.github/workflows/test-conda-cpu.yml b/.github/workflows/test-conda-cpu.yml index e0da5e42e..121102969 100644 --- a/.github/workflows/test-conda-cpu.yml +++ b/.github/workflows/test-conda-cpu.yml @@ -15,7 +15,7 @@ jobs: tests: strategy: matrix: - python_version: ["3.9", "3.10", "3.11", "3.12"] + python_version: ["3.10", "3.11", "3.12"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job.yml@main with: diff --git a/.github/workflows/test-pip-cpu.yml b/.github/workflows/test-pip-cpu.yml index 42bb3a708..51075cc46 100644 --- a/.github/workflows/test-pip-cpu.yml +++ b/.github/workflows/test-pip-cpu.yml @@ -14,7 +14,7 @@ jobs: matrix: pytorch_args: ["-v 2.3.0", "-v 2.4.0", "-v 2.5.0", "-v 2.6.0", "-v 2.7.0"] transformers_args: ["-t 4.38.0", "-t 4.39.0", "-t 4.41.0", "-t 4.43.0", "-t 4.45.2"] - docker_img: ["cimg/python:3.9", "cimg/python:3.10", "cimg/python:3.11", "cimg/python:3.12"] + docker_img: ["cimg/python:3.10", "cimg/python:3.11", "cimg/python:3.12"] fail-fast: false uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main with: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 04642bf1c..a1f2be883 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -50,7 +50,7 @@ Github Actions will fail on your PR if it does not adhere to the ufmt or flake8 #### Type Hints -Captum is fully typed using Python 3.9+ +Captum is fully typed using Python 3.10+ [type hints](https://www.python.org/dev/peps/pep-0484/). We expect any contributions to also use proper type annotations, and we enforce consistency of these in our continuous integration tests. @@ -63,7 +63,7 @@ Then run this script from the repository root: ``` Note that we expect mypy to have version 0.760 or higher, and when type checking, use PyTorch 1.10 or higher due to fixes to the PyTorch type hints available. We also use the Literal feature which is -available only in Python 3.9 or above. +available only in Python 3.10 or above. We also use [pyre](https://pyre-check.org/) for type checking. For contributors, the nightly version of pyre is used which can be installed with pip `pip install pyre-check-nightly`. To run pyre, you can diff --git a/README.md b/README.md index d7d87fd6d..c1a49b869 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Captum can also be used by application engineers who are using trained models in ## Installation **Installation Requirements** -- Python >= 3.9 +- Python >= 3.10 - PyTorch >= 2.3 diff --git a/captum/_utils/typing.py b/captum/_utils/typing.py index 512c910f0..f0f7427e9 100644 --- a/captum/_utils/typing.py +++ b/captum/_utils/typing.py @@ -3,17 +3,7 @@ # pyre-strict from collections import UserDict -from typing import ( - List, - Literal, - Optional, - overload, - Protocol, - Tuple, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import List, Literal, Optional, overload, Protocol, Tuple, TypeVar, Union from torch import Tensor from torch.nn import Module @@ -51,11 +41,7 @@ # pyre-ignore[24]: Generic type `slice` expects 3 type parameters. SliceIntType = slice # type: ignore -# Necessary for Python >=3.7 and <3.9! -if TYPE_CHECKING: - BatchEncodingType = UserDict[Union[int, str], object] -else: - BatchEncodingType = UserDict +BatchEncodingType = UserDict[Union[int, str], object] class TokenizerLike(Protocol): diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index c7422d8d9..185483fa7 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -5,18 +5,20 @@ from abc import ABC from copy import copy from dataclasses import dataclass -from textwrap import dedent, shorten +from textwrap import shorten from typing import ( Any, Callable, cast, Dict, + Generic, List, Optional, Tuple, Type, TYPE_CHECKING, + TypeVar, Union, ) @@ -56,130 +58,140 @@ "temperature": None, "top_p": None, } +TInputValue = TypeVar("TInputValue") +TTargetValue = TypeVar("TTargetValue") -@dataclass -class LLMAttributionResult: +@dataclass(kw_only=True) +class BaseLLMAttributionResult(ABC, Generic[TInputValue, TTargetValue]): """ Data class for the return result of LLMAttribution, which includes the necessary properties of the attribution. It also provides utilities to help present and plot the result in different forms. """ - input_tokens: List[str] - output_tokens: List[str] - # pyre-ignore[13]: initialized via a property setter - _seq_attr: Tensor - _token_attr: Optional[Tensor] = None - _output_probs: Optional[Tensor] = None + input_values: List[TInputValue] # ablated values + target_names: List[str] # names of each target, e.g. judge name or tokens + _target_values: Optional[ + List[TTargetValue] + ] # value for each target name e.g. token prob + _aggregate_attr: Tensor # 1D [# input_values] + _element_attr: Optional[Tensor] = None # 2D [# target_names, # input_values] + aggregate_descriptor: str = "Aggregate" + element_descriptor: str = "Element" def __init__( self, *, - input_tokens: List[str], - output_tokens: List[str], - seq_attr: npt.ArrayLike, - token_attr: Optional[npt.ArrayLike] = None, - output_probs: Optional[npt.ArrayLike] = None, + input_values: List[TInputValue], + target_names: List[str], + target_values: Optional[Union[npt.ArrayLike, List[TTargetValue]]] = None, + aggregate_attr: npt.ArrayLike, + element_attr: Optional[npt.ArrayLike] = None, + aggregate_descriptor: str = "Aggregate", + element_descriptor: str = "Element", ) -> None: - self.input_tokens = input_tokens - self.output_tokens = output_tokens - self.seq_attr = seq_attr - self.token_attr = token_attr - self.output_probs = output_probs + self.input_values = input_values + self.target_names = target_names + self.target_values = target_values + self.aggregate_attr = aggregate_attr + self.element_attr = element_attr + self.aggregate_descriptor = aggregate_descriptor + self.element_descriptor = element_descriptor @property - def seq_attr(self) -> Tensor: - return self._seq_attr + def aggregate_attr(self) -> Tensor: + return self._aggregate_attr - @seq_attr.setter - def seq_attr(self, seq_attr: npt.ArrayLike) -> None: - if isinstance(seq_attr, Tensor): - self._seq_attr = seq_attr + @aggregate_attr.setter + def aggregate_attr(self, aggregate_attr: npt.ArrayLike) -> None: + if isinstance(aggregate_attr, Tensor): + self._aggregate_attr = aggregate_attr else: - self._seq_attr = torch.tensor(seq_attr) + self._aggregate_attr = torch.tensor(aggregate_attr) # IDEA: in the future we might want to support higher dim seq_attr # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes) - assert len(self._seq_attr.shape) == 1, "seq_attr must be a 1D tensor" + assert len(self._aggregate_attr.shape) == 1, "seq_attr must be a 1D tensor" assert ( - len(self.input_tokens) == self._seq_attr.shape[0] + len(self.input_values) == self._aggregate_attr.shape[0] ), "seq_attr and input_tokens must have the same length" @property - def token_attr(self) -> Optional[Tensor]: - return self._token_attr - - @token_attr.setter - def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None: - if token_attr is None: - self._token_attr = None - elif isinstance(token_attr, Tensor): - self._token_attr = token_attr + def element_attr(self) -> Optional[Tensor]: + return self._element_attr + + @element_attr.setter + def element_attr(self, element_attr: Optional[npt.ArrayLike]) -> None: + if element_attr is None: + self._element_attr = None + elif isinstance(element_attr, Tensor): + self._element_attr = element_attr else: - self._token_attr = torch.tensor(token_attr) + self._element_attr = torch.tensor(element_attr) - if self._token_attr is not None: + if self._element_attr is not None: # IDEA: in the future we might want to support higher dim seq_attr - assert len(self._token_attr.shape) == 2, "token_attr must be a 2D tensor" - assert self._token_attr.shape == ( - len(self.output_tokens), - len(self.input_tokens), - ), dedent( - f"""\ - Expect token_attr to have shape - {len(self.output_tokens), len(self.input_tokens)}, - got {self._token_attr.shape} - """ + assert len(self._element_attr.shape) == 2, "token_attr must be a 2D tensor" + assert self._element_attr.shape == ( + len(self.target_names), + len(self.input_values), + ), ( + "Expect token_attr to have shape " + f"({len(self.target_names), len(self.input_values)}), " + f"got {self._element_attr.shape}" ) @property - def output_probs(self) -> Optional[Tensor]: - return self._output_probs + def target_values(self) -> Optional[List[TTargetValue]]: + return self._target_values - @output_probs.setter - def output_probs(self, output_probs: Optional[npt.ArrayLike]) -> None: - if output_probs is None: - self._output_probs = None - elif isinstance(output_probs, Tensor): - self._output_probs = output_probs + @target_values.setter + def target_values( + self, target_values: Optional[Union[npt.ArrayLike, List[TTargetValue]]] + ) -> None: + if target_values is None: + self._target_values = None + elif isinstance(target_values, (Tensor, np.ndarray)): + self._target_values = target_values.tolist() else: - self._output_probs = torch.tensor(output_probs) + # pyre-ignore[6]: should be iterable + self._target_values = list(target_values) - if self._output_probs is not None: - assert ( - len(self._output_probs.shape) == 1 - ), "output_probs must be a 1D tensor" - assert ( - len(self.output_tokens) == self._output_probs.shape[0] - ), "seq_attr and input_tokens must have the same length" + if self._target_values is not None: + assert len(self._target_values) == len( + self.target_names + ), f"{len(self._target_values)=} and {len(self.target_names)=} must have the same length" @property - def seq_attr_dict(self) -> Dict[str, float]: - return {k: v for v, k in zip(self.seq_attr.cpu().tolist(), self.input_tokens)} + def aggregate_attr_dict(self) -> Dict[TInputValue, float]: + return { + k: v for v, k in zip(self.aggregate_attr.cpu().tolist(), self.input_values) + } - def plot_token_attr( + def plot_element_attr( self, show: bool = False ) -> Union[None, Tuple["Figure", "Axes"]]: """ Generate a matplotlib plot for visualising the attribution - of the output tokens. + of the output elements. Args: show (bool): whether to show the plot directly or return the figure and axis Default: False """ - if self.token_attr is None: + if self.element_attr is None: raise ValueError( - "token_attr is None (no token-level attribution was performed), please " - "use plot_seq_attr instead for the sequence-level attribution plot" + f"element_attr is None (no {self.element_descriptor.lower()}-level attribution was " + "performed), please use plot_aggregate_attr instead for the " + f"{self.aggregate_descriptor}-level attribution plot" ) - token_attr = self.token_attr.cpu() + element_attr = self.element_attr.cpu() # maximum absolute attribution value # used as the boundary of normalization # always keep 0 as the mid point to differentiate pos/neg attr - max_abs_attr_val = token_attr.abs().max().item() + max_abs_attr_val = element_attr.abs().max().item() import matplotlib.pyplot as plt @@ -189,7 +201,7 @@ def plot_token_attr( ax.grid(False) # Plot the heatmap - data = token_attr.numpy() + data = element_attr.numpy() fig.set_size_inches( max(data.shape[1] * 1.3, 6.4), max(data.shape[0] / 2.5, 4.8) @@ -219,17 +231,19 @@ def plot_token_attr( # Create colorbar cbar = fig.colorbar(im, ax=ax) # type: ignore - cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom") + cbar.ax.set_ylabel( + f"{self.element_descriptor} Attribution", rotation=-90, va="bottom" + ) # Show all ticks and label them with the respective list entries. - shortened_tokens = [ + shortened_values = [ shorten(repr(t)[1:-1], width=50, placeholder="...") - for t in self.input_tokens + for t in self.input_values ] - ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens) + ax.set_xticks(np.arange(data.shape[1]), labels=shortened_values) ax.set_yticks( np.arange(data.shape[0]), - labels=[repr(token)[1:-1] for token in self.output_tokens], + labels=[repr(name)[1:-1] for name in self.target_names], ) # Let the horizontal axes labeling appear on top. @@ -259,10 +273,12 @@ def plot_token_attr( else: return fig, ax - def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes"]]: + def plot_aggregated_attr( + self, show: bool = False + ) -> Union[None, Tuple["Figure", "Axes"]]: """ Generate a matplotlib plot for visualising the attribution - of the output sequence. + of the aggregated output. Args: show (bool): whether to show the plot directly or return the figure and axis @@ -273,15 +289,15 @@ def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes fig, ax = plt.subplots() - data = self.seq_attr.cpu().numpy() + data = self.aggregate_attr.cpu().numpy() fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8)) - shortened_tokens = [ + shortened_values = [ shorten(repr(t)[1:-1], width=50, placeholder="...") - for t in self.input_tokens + for t in self.input_values ] - ax.set_xticks(range(data.shape[0]), labels=shortened_tokens) + ax.set_xticks(range(data.shape[0]), labels=shortened_values) ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) @@ -309,7 +325,9 @@ def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes color="#d0365b", ) - ax.set_ylabel("Sequence Attribution", rotation=90, va="bottom") + ax.set_ylabel( + f"{self.aggregate_descriptor} Attribution", rotation=90, va="bottom" + ) if show: plt.show() @@ -317,6 +335,85 @@ def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes else: return fig, ax + # Aliases + + @property + def input_tokens(self) -> List[TInputValue]: + return self.input_values + + @input_tokens.setter + def input_tokens(self, input_tokens: List[TInputValue]) -> None: + self.input_values = input_tokens + + @property + def output_tokens(self) -> List[str]: + return self.target_names + + @output_tokens.setter + def output_tokens(self, output_tokens: List[str]) -> None: + self.target_names = output_tokens + + @property + def output_probs(self) -> Optional[List[TTargetValue]]: + return self.target_values + + @output_probs.setter + def output_probs(self, output_probs: Optional[npt.ArrayLike]) -> None: + self.target_values = output_probs + + @property + def seq_attr(self) -> Tensor: + return self.aggregate_attr + + @seq_attr.setter + def seq_attr(self, seq_attr: npt.ArrayLike) -> None: + self.aggregate_attr = seq_attr + + @property + def token_attr(self) -> Optional[Tensor]: + return self.element_attr + + @token_attr.setter + def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None: + self.element_attr = token_attr + + @property + def seq_attr_dict(self) -> Dict[TInputValue, float]: + return self.aggregate_attr_dict + + def plot_token_attr( + self, show: bool = False + ) -> Union[None, Tuple["Figure", "Axes"]]: + return self.plot_element_attr(show=show) + + def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes"]]: + return self.plot_aggregated_attr(show=show) + + +@dataclass(kw_only=True) +# pyre-ignore[13]: _aggregate_attr and _target_values initialized via setters +class LLMAttributionResult(BaseLLMAttributionResult[str, float]): + """LLM Attribution Result for the captum.attr API""" + + def __init__( + self, + *, + input_tokens: List[str], + output_tokens: List[str], + seq_attr: npt.ArrayLike, + token_attr: Optional[npt.ArrayLike] = None, + output_probs: Optional[npt.ArrayLike] = None, + ) -> None: + super().__init__( + input_values=input_tokens, + target_names=output_tokens, + target_values=output_probs, + aggregate_attr=seq_attr, + element_attr=token_attr, + aggregate_descriptor="Sequence", + element_descriptor="Token", + ) + def _clean_up_pretty_token(token: str) -> str: """Remove newlines and leading/trailing whitespace from token.""" diff --git a/captum/concept/_core/cav.py b/captum/concept/_core/cav.py index 9cd5cc313..8b96056d3 100644 --- a/captum/concept/_core/cav.py +++ b/captum/concept/_core/cav.py @@ -4,7 +4,7 @@ import os from contextlib import AbstractContextManager, nullcontext -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Dict, List, Optional import numpy as np import torch @@ -168,11 +168,7 @@ def load( cavs_path = CAV.assemble_save_path(cavs_path, model_id, concepts, layer) if os.path.exists(cavs_path): - # Necessary for Python >=3.7 and <3.9! - if TYPE_CHECKING: - ctx: AbstractContextManager[None, None] - else: - ctx: AbstractContextManager + ctx: AbstractContextManager[None, None] if hasattr(torch.serialization, "safe_globals"): safe_globals = [ # pyre-ignore[16]: Module `numpy.core.multiarray` has no attribute diff --git a/pyproject.toml b/pyproject.toml index 9608c4f12..672071496 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,4 +2,4 @@ first_party_detection = false [tool.black] -target-version = ['py39'] +target-version = ['py310'] diff --git a/setup.py b/setup.py index 32126301b..10dbfb90c 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ from setuptools import find_packages, setup REQUIRED_MAJOR = 3 -REQUIRED_MINOR = 9 +REQUIRED_MINOR = 10 # Check for python version if sys.version_info < (REQUIRED_MAJOR, REQUIRED_MINOR):