Commit

Fix bug in BatchDict (#450)
durandtibo committed Sep 27, 2023
1 parent c8abc75 commit 08d01cd
Showing 3 changed files with 12 additions and 12 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -5,7 +5,7 @@ default_stages: [ commit ]

repos:
- repo: https://github.com/asottile/pyupgrade
- rev: v3.12.0
+ rev: v3.13.0
hooks:
- id: pyupgrade
args: [--py39-plus]
@@ -52,7 +52,7 @@ repos:
additional_dependencies:
- black>=23.9.1
- repo: https://github.com/charliermarsh/ruff-pre-commit
- rev: v0.0.290
+ rev: v0.0.291
hooks:
- id: ruff
args: [--fix]
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.12"
version = "0.0.13a0"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
18 changes: 9 additions & 9 deletions src/redcat/batchdict.py
@@ -74,7 +74,7 @@ def data(self) -> dict[Hashable, TBaseBatch]:
def __contains__(self, key: Hashable) -> bool:
return key in self._data

- def __getitem__(self, key: Hashable) -> BaseBatch:
+ def __getitem__(self, key: Hashable) -> TBaseBatch:
return self._data[key]

def __iter__(self) -> Iterator[Hashable]:
@@ -83,14 +83,14 @@ def __iter__(self) -> Iterator[Hashable]:
def __len__(self) -> int:
return len(self._data)

- def __setitem__(self, key: Hashable, value: BaseBatch) -> None:
+ def __setitem__(self, key: Hashable, value: TBaseBatch) -> None:
if value.batch_size != self.batch_size:
raise RuntimeError(
f"Incorrect batch size. Expected {self.batch_size} but received {value.batch_size}"
)
self._data[key] = value

- def get(self, key: Hashable, default: BaseBatch | None = None) -> BaseBatch | None:
+ def get(self, key: Hashable, default: TBaseBatch | None = None) -> TBaseBatch | None:
return self._data.get(key, default)

def items(self) -> ItemsView:
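
The hunks above change the dict-style accessors (`__getitem__`, `__setitem__`, `get`) to return and accept `TBaseBatch`, while `__setitem__` keeps its batch-size check. A minimal usage sketch of that check from the caller's side; it assumes `BatchDict` and `BatchedTensor` are importable from the top-level `redcat` package and that `BatchedTensor` treats the first tensor dimension as the batch dimension:

import torch
from redcat import BatchDict, BatchedTensor  # assumed top-level exports

# Two values with the same batch size (2) can share a BatchDict.
batch = BatchDict({"x": BatchedTensor(torch.ones(2, 3))})
batch["y"] = BatchedTensor(torch.zeros(2, 5))  # accepted: batch sizes match
print(batch["y"].batch_size)  # 2

# A mismatched batch size hits the RuntimeError raised in __setitem__.
try:
    batch["z"] = BatchedTensor(torch.zeros(4, 5))
except RuntimeError as exc:
    print(exc)  # "Incorrect batch size. Expected 2 but received 4"
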
@@ -330,12 +330,12 @@ def shuffle_along_seq_(self, generator: torch.Generator | None = None) -> None:
# Indexing, slicing, joining, mutating operations #
##########################################################

- def append(self, other: BatchDict) -> None:
+ def append(self, other: TBaseBatch) -> None:
check_same_keys(self.data, other.data)
for key, value in self._data.items():
value.append(other[key])

- def cat_along_seq(self, batches: BatchDict | Sequence[BatchDict]) -> TBatchDict:
+ def cat_along_seq(self, batches: TBaseBatch | Sequence[TBaseBatch]) -> TBatchDict:
r"""Concatenates the data of the batch(es) to the current batch
along the sequence dimension and creates a new batch.
@@ -381,7 +381,7 @@ def cat_along_seq(self, batches: BatchDict | Sequence[BatchDict]) -> TBatchDict:
out[key] = val
return self.__class__(out)

- def cat_along_seq_(self, batches: BatchDict | Sequence[BatchDict]) -> None:
+ def cat_along_seq_(self, batches: TBaseBatch | Sequence[TBaseBatch]) -> None:
r"""Concatenates the data of the batch(es) to the current batch
along the sequence dimension and creates a new batch.
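
For the concatenation methods whose signatures change above, a hedged sketch of the intended usage; it assumes `BatchedTensorSeq` is exported from `redcat` and stores its data with the batch dimension first and the sequence dimension second (those details are not part of this diff). `append`, by contrast, extends each value key by key, as the per-key `value.append(other[key])` calls above show.

import torch
from redcat import BatchDict, BatchedTensorSeq  # assumed top-level exports

b1 = BatchDict({"tokens": BatchedTensorSeq(torch.arange(6).view(2, 3))})
b2 = BatchDict({"tokens": BatchedTensorSeq(torch.full((2, 2), 9))})

# cat_along_seq builds a new BatchDict; sequences grow from 3 to 3 + 2 = 5.
merged = b1.cat_along_seq(b2)
print(merged["tokens"].data.shape)  # expected: torch.Size([2, 5])

# cat_along_seq_ performs the same concatenation in place on b1.
b1.cat_along_seq_(b2)
print(b1["tokens"].data.shape)  # expected: torch.Size([2, 5])
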
@@ -550,7 +550,7 @@ def split_along_batch(
batches.append(self.__class__({key: value for key, value in zip(keys, values)}))
return tuple(batches)

- def take_along_seq(self, indices: BaseBatch | np.ndarray | Tensor | Sequence) -> TBatchDict:
+ def take_along_seq(self, indices: TBaseBatch | np.ndarray | Tensor | Sequence) -> TBatchDict:
r"""Takes values along the sequence dimension.
Args:
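
The `take_along_seq` signature now accepts `TBaseBatch` indices as well as arrays, tensors, and sequences. A hedged sketch of selecting positions along the sequence dimension; the exact index-shape requirements are not shown in this diff, so a per-example index tensor matching the batch size is assumed:

import torch
from redcat import BatchDict, BatchedTensorSeq  # assumed top-level exports

batch = BatchDict({"tokens": BatchedTensorSeq(torch.arange(10).view(2, 5))})

# Pick sequence positions 4, 2, 0 for each of the 2 examples in the batch.
indices = torch.tensor([[4, 2, 0], [4, 2, 0]])
subset = batch.take_along_seq(indices)
print(subset["tokens"].data)  # expected rows: [4, 2, 0] and [9, 7, 5]
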
@@ -602,7 +602,7 @@ def summary(self) -> str:
return f"{self.__class__.__qualname__}(\n {str_indent(data_str)}\n)"


- def check_same_batch_size(data: dict[Hashable, TBaseBatch]) -> None:
+ def check_same_batch_size(data: dict[Hashable, BaseBatch]) -> None:
r"""Checks if the all the batches in a group have the same batch
size.
@@ -673,7 +673,7 @@ def check_same_keys(data1: dict, data2: dict) -> None:
raise RuntimeError(f"Keys do not match: ({keys1} vs {keys2})")


- def get_seq_lens(data: dict[Hashable, TBaseBatch]) -> set[int]:
+ def get_seq_lens(data: dict[Hashable, BaseBatch]) -> set[int]:
r"""Gets the sequence lengths from the inputs.
Args:
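
The substantive fix in this file is which annotations use the generic parameter: the dict-style accessors and the batch operations now use `TBaseBatch`, while the module-level validators `check_same_batch_size` and `get_seq_lens` accept plain `BaseBatch` values. A small self-contained typing sketch of why that distinction matters, assuming `TBaseBatch` is a `TypeVar` bound to `BaseBatch` (the usual convention behind this naming; its actual definition is not part of this diff). All names below are stand-ins, not the library's real classes:

from __future__ import annotations

from collections.abc import Hashable
from typing import Generic, TypeVar


class BaseBatchStandIn:
    """Minimal stand-in for the BaseBatch interface."""

    batch_size: int = 2


class SeqBatchStandIn(BaseBatchStandIn):
    """A hypothetical concrete batch type with extra attributes."""

    seq_len: int = 5


TBatch = TypeVar("TBatch", bound=BaseBatchStandIn)


class BatchDictSketch(Generic[TBatch]):
    def __init__(self, data: dict[Hashable, TBatch]) -> None:
        self._data = data

    def __getitem__(self, key: Hashable) -> TBatch:
        # Annotating with the bound TypeVar keeps the concrete element type
        # for type checkers; annotating BaseBatchStandIn would widen it.
        return self._data[key]


d = BatchDictSketch({"x": SeqBatchStandIn()})
item = d["x"]  # inferred as SeqBatchStandIn, so item.seq_len type-checks
print(item.seq_len)
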
