Migrate from torchtyping to jaxtyping (#226)
* Increase minimum Python version to 3.9.

* Replace torchtyping with jaxtyping.

* Update poetry.lock.

* Bump requirements

---------

Co-authored-by: Gabriele Sarti <gabriele.sarti996@gmail.com>
carschno and gsarti committed Oct 30, 2023
1 parent 6feda95 commit 33f6932
Showing 10 changed files with 550 additions and 574 deletions.
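
The pattern below is mechanical throughout the diff: torchtyping's TensorType["dim1", "dim2", dtype] becomes jaxtyping's Dtype[torch.Tensor, "dim1 dim2"], with the dtype moving into the annotation class and the named axes collapsing into a single space-separated string (a literal 1 pins an axis to size one). A minimal sketch of the new style, assuming torch and jaxtyping are installed (the helper function is illustrative, not from the codebase):

import torch
from jaxtyping import Int

# Old (torchtyping, removed by this commit):
#   target_ids: TensorType["batch_size", 1, int]
# New (jaxtyping): the dtype class wraps the array type, axes are one string.
TargetIds = Int[torch.Tensor, "batch_size 1"]

def first_token(target_ids: TargetIds) -> Int[torch.Tensor, "batch_size"]:
    # Illustrative helper: drop the singleton axis to get one id per batch row.
    return target_ids.squeeze(-1)

print(first_token(torch.tensor([[5], [7]])))  # tensor([5, 7])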
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -12,7 +12,7 @@ jobs:
     if: github.actor != 'dependabot[bot]' && github.actor != 'dependabot-preview[bot]'
     strategy:
       matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11"]

     steps:
       - uses: actions/checkout@v3
6 changes: 3 additions & 3 deletions inseq/attr/feat/feature_attribution.py
@@ -21,7 +21,7 @@
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
-from torchtyping import TensorType
+from jaxtyping import Int

 from ...data import (
     DecoderOnlyBatch,
@@ -497,9 +497,9 @@ def attribute(
     def filtered_attribute_step(
         self,
         batch: Union[DecoderOnlyBatch, EncoderDecoderBatch],
-        target_ids: TensorType["batch_size", 1, int],
+        target_ids: Int[torch.Tensor, "batch_size 1"],
         attributed_fn: Callable[..., SingleScorePerStepTensor],
-        target_attention_mask: Optional[TensorType["batch_size", 1, int]] = None,
+        target_attention_mask: Optional[Int[torch.Tensor, "batch_size 1"]] = None,
         attribute_target: bool = False,
         step_scores: List[str] = [],
         attribution_args: Dict[str, Any] = {},
8 changes: 4 additions & 4 deletions inseq/attr/feat/ops/monotonic_path_builder.py
@@ -26,7 +26,7 @@
 from typing import Any, List, Optional, Tuple, Union

 import torch
-from torchtyping import TensorType
+from jaxtyping import Float, Int

 from ....utils import is_joblib_available, is_scikitlearn_available

@@ -138,8 +138,8 @@ def load(

     def scale_inputs(
         self,
-        input_ids: TensorType["batch_size", "seq_len", int],
-        baseline_ids: TensorType["batch_size", "seq_len", int],
+        input_ids: Int[torch.Tensor, "batch_size seq_len"],
+        baseline_ids: Int[torch.Tensor, "batch_size seq_len"],
         n_steps: Optional[int] = None,
         scale_strategy: Optional[str] = None,
     ) -> MultiStepEmbeddingsTensor:
@@ -208,7 +208,7 @@ def find_path(

     def build_monotonic_path_embedding(
         self, word_path: List[int], baseline_idx: int, n_steps: int = 30
-    ) -> TensorType["n_steps", "embed_size", float]:
+    ) -> Float[torch.Tensor, "n_steps embed_size"]:
         """Build a monotonic path embedding from a word path."""
         baseline_vec = self.vocabulary_embeddings[baseline_idx]
         monotonic_embs = [self.vocabulary_embeddings[word_path[0]]]
4 changes: 2 additions & 2 deletions inseq/data/data_utils.py
@@ -4,7 +4,7 @@

 import numpy as np
 import torch
-from torchtyping import TensorType
+from jaxtyping import Int

 from ..utils import pretty_dict

@@ -133,7 +133,7 @@ def slice_batch(self: TensorClass, subscript) -> TensorClass:
             **{field.name: self._slice_batch(getattr(self, field.name), subscript) for field in fields(self.__class__)}
         )

-    def select_active(self: TensorClass, mask: TensorType["batch_size", 1, int]) -> TensorClass:
+    def select_active(self: TensorClass, mask: Int[torch.Tensor, "batch_size 1"]) -> TensorClass:
         return self.__class__(
             **{field.name: self._select_active(getattr(self, field.name), mask) for field in fields(self.__class__)}
         )
10 changes: 5 additions & 5 deletions inseq/utils/torch_utils.py
@@ -1,13 +1,13 @@
 import logging
-from typing import TYPE_CHECKING, Any, Callable, List, Literal, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Sequence, Tuple, Union

 import torch
 import torch.nn.functional as F
+from jaxtyping import Int, Num
 from torch.backends.cuda import is_built as is_cuda_built
 from torch.backends.mps import is_available as is_mps_available
 from torch.backends.mps import is_built as is_mps_built
 from torch.cuda import is_available as is_cuda_available
-from torchtyping import TensorType

 if TYPE_CHECKING:
     pass
@@ -23,9 +23,9 @@
 @torch.no_grad()
 def remap_from_filtered(
     original_shape: Tuple[int, ...],
-    mask: TensorType["batch_size", 1, int],
-    filtered: TensorType["filtered_batch_size", Any],
-) -> TensorType["batch_size", Any]:
+    mask: Int[torch.Tensor, "batch_size 1"],
+    filtered: Num[torch.Tensor, "filtered_batch_size"],
+) -> Num[torch.Tensor, "batch_size"]:
     index = mask.squeeze(-1).nonzero().reshape(-1, 1)
     while index.ndim < filtered.ndim:
         index = index.unsqueeze(-1)
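
Reading the signature above, remap_from_filtered scatters scores computed on the still-active subset of a batch back into a full-batch tensor, using the mask to recover the original row positions. A toy sketch of that idea; the pad value and variable names here are illustrative assumptions, not taken from the file:

import torch

# Batch of 4; rows 0 and 2 are still active after filtering.
mask = torch.tensor([[1], [0], [1], [0]])          # (batch_size, 1)
filtered = torch.tensor([0.9, 0.4])                # one score per active row

index = mask.squeeze(-1).nonzero().reshape(-1, 1)  # original rows: [[0], [2]]
full = torch.full((4, 1), float("nan"))            # assumed pad value
full.scatter_(0, index, filtered.unsqueeze(-1))    # place scores back by row
print(full.squeeze(-1))                            # tensor([0.9, nan, 0.4, nan])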
38 changes: 19 additions & 19 deletions inseq/utils/typing.py
@@ -1,8 +1,8 @@
 from dataclasses import dataclass
 from typing import Optional, Sequence, Tuple, Union

-from torch import float32, long
-from torchtyping import TensorType
+import torch
+from jaxtyping import Float, Float32, Int64
 from transformers import PreTrainedModel

 TextInput = Union[str, Sequence[str]]
@@ -40,26 +40,26 @@ class TextSequences:

 IndexSpan = Union[Tuple[int, int], Sequence[Tuple[int, int]]]

-IdsTensor = TensorType["batch_size", "seq_len", long]
-TargetIdsTensor = TensorType["batch_size", long]
-ExpandedTargetIdsTensor = TensorType["batch_size", 1, long]
-EmbeddingsTensor = TensorType["batch_size", "seq_len", "embed_size", float]
-MultiStepEmbeddingsTensor = TensorType["batch_size_x_n_steps", "seq_len", "embed_size", float]
-VocabularyEmbeddingsTensor = TensorType["vocab_size", "embed_size", float]
-LogitsTensor = TensorType["batch_size", "vocab_size", float]
-ScoreTensor = TensorType["batch_size", "other_dims", float]
-MultiUnitScoreTensor = TensorType["batch_size", "n_units", "other_dims", float]
-MultiLayerScoreTensor = TensorType["batch_size", "n_layers", "other_dims", float]
-MultiLayerMultiUnitScoreTensor = TensorType["batch_size", "n_layers", "n_units", "seq_len", "seq_len", float]
-MultiLayerEmbeddingsTensor = TensorType["batch_size", "n_layers", "seq_len", "embed_size", float]
+IdsTensor = Int64[torch.Tensor, "batch_size seq_len"]
+TargetIdsTensor = Int64[torch.Tensor, "batch_size"]
+ExpandedTargetIdsTensor = Int64[torch.Tensor, "batch_size 1"]
+EmbeddingsTensor = Float[torch.Tensor, "batch_size seq_len embed_size"]
+MultiStepEmbeddingsTensor = Float[torch.Tensor, "batch_size_x_n_steps seq_len embed_size"]
+VocabularyEmbeddingsTensor = Float[torch.Tensor, "vocab_size embed_size"]
+LogitsTensor = Float[torch.Tensor, "batch_size vocab_size"]
+ScoreTensor = Float[torch.Tensor, "batch_size other_dims"]
+MultiUnitScoreTensor = Float[torch.Tensor, "batch_size n_units other_dims"]
+MultiLayerScoreTensor = Float[torch.Tensor, "batch_size n_layers other_dims"]
+MultiLayerMultiUnitScoreTensor = Float[torch.Tensor, "batch_size n_layers n_units seq_len seq_len"]
+MultiLayerEmbeddingsTensor = Float[torch.Tensor, "batch_size n_layers seq_len embed_size"]

 # Step and sequence objects used for stepwise scores (e.g. convergence deltas, probabilities)
-SingleScorePerStepTensor = TensorType["batch_size", float32]
-SingleScoresPerSequenceTensor = TensorType["generated_seq_len", float32]
+SingleScorePerStepTensor = Float32[torch.Tensor, "batch_size"]
+SingleScoresPerSequenceTensor = Float32[torch.Tensor, "generated_seq_len"]

 # Step and sequence objects used for sequence scores (e.g. attributions over tokens)
-MultipleScoresPerStepTensor = TensorType["batch_size", "attributed_seq_len", float32]
-MultipleScoresPerSequenceTensor = TensorType["attributed_seq_len", "generated_seq_len", float32]
+MultipleScoresPerStepTensor = Float32[torch.Tensor, "batch_size attributed_seq_len"]
+MultipleScoresPerSequenceTensor = Float32[torch.Tensor, "attributed_seq_len generated_seq_len"]

 # One attribution score per embedding value for every attributed token
 # in a single attribution step. Produced by gradient attribution methods.
@@ -75,7 +75,7 @@ class TextSequences:
 # One attribution score per embedding value for every attributed token in attributed_seq
 # for all generated tokens in generated_seq. Produced by aggregating GranularStepAttributionTensor
 # across multiple steps and separating batches.
-GranularSequenceAttributionTensor = TensorType["attributed_seq_len", "generated_seq_len", "embed_size", float32]
+GranularSequenceAttributionTensor = Float32[torch.Tensor, "attributed_seq_len generated_seq_len embed_size"]

 # One attribution score for every token in attributed_seq for every generated token
 # in generated_seq. Produced by aggregating GranularSequenceAttributionTensor over the last dimension,
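
A practical benefit of the new aliases above: jaxtyping annotations support runtime isinstance checks, which makes quick shape and dtype sanity checks cheap. A minimal sketch using two of the aliases defined in this file (tensor values are illustrative):

import torch
from jaxtyping import Float32, Int64

IdsTensor = Int64[torch.Tensor, "batch_size seq_len"]
SingleScorePerStepTensor = Float32[torch.Tensor, "batch_size"]

ids = torch.ones(2, 5, dtype=torch.long)
scores = torch.rand(2)

print(isinstance(ids, IdsTensor))                    # True: 2-D int64 tensor
print(isinstance(scores, SingleScorePerStepTensor))  # True: 1-D float32 tensor
print(isinstance(scores, IdsTensor))                 # False: wrong dtype and rank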
[Diffs for the remaining 4 changed files not shown.]