Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in BatchDict #450

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ default_stages: [ commit ]

repos:
- repo: https://github.com/asottile/pyupgrade
rev: v3.12.0
rev: v3.13.0
hooks:
- id: pyupgrade
args: [--py39-plus]
Expand Down Expand Up @@ -52,7 +52,7 @@ repos:
additional_dependencies:
- black>=23.9.1
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.290
rev: v0.0.291
hooks:
- id: ruff
args: [--fix]
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.12"
version = "0.0.13a0"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
Expand Down
18 changes: 9 additions & 9 deletions src/redcat/batchdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def data(self) -> dict[Hashable, TBaseBatch]:
def __contains__(self, key: Hashable) -> bool:
return key in self._data

def __getitem__(self, key: Hashable) -> BaseBatch:
def __getitem__(self, key: Hashable) -> TBaseBatch:
return self._data[key]

def __iter__(self) -> Iterator[Hashable]:
Expand All @@ -83,14 +83,14 @@ def __iter__(self) -> Iterator[Hashable]:
def __len__(self) -> int:
return len(self._data)

def __setitem__(self, key: Hashable, value: BaseBatch) -> None:
def __setitem__(self, key: Hashable, value: TBaseBatch) -> None:
if value.batch_size != self.batch_size:
raise RuntimeError(
f"Incorrect batch size. Expected {self.batch_size} but received {value.batch_size}"
)
self._data[key] = value

def get(self, key: Hashable, default: BaseBatch | None = None) -> BaseBatch | None:
def get(self, key: Hashable, default: TBaseBatch | None = None) -> TBaseBatch | None:
return self._data.get(key, default)

def items(self) -> ItemsView:
Expand Down Expand Up @@ -330,12 +330,12 @@ def shuffle_along_seq_(self, generator: torch.Generator | None = None) -> None:
# Indexing, slicing, joining, mutating operations #
##########################################################

def append(self, other: BatchDict) -> None:
def append(self, other: TBaseBatch) -> None:
check_same_keys(self.data, other.data)
for key, value in self._data.items():
value.append(other[key])

def cat_along_seq(self, batches: BatchDict | Sequence[BatchDict]) -> TBatchDict:
def cat_along_seq(self, batches: TBaseBatch | Sequence[TBaseBatch]) -> TBatchDict:
r"""Concatenates the data of the batch(es) to the current batch
along the sequence dimension and creates a new batch.

Expand Down Expand Up @@ -381,7 +381,7 @@ def cat_along_seq(self, batches: BatchDict | Sequence[BatchDict]) -> TBatchDict:
out[key] = val
return self.__class__(out)

def cat_along_seq_(self, batches: BatchDict | Sequence[BatchDict]) -> None:
def cat_along_seq_(self, batches: TBaseBatch | Sequence[TBaseBatch]) -> None:
r"""Concatenates the data of the batch(es) to the current batch
along the sequence dimension in-place.

Expand Down Expand Up @@ -550,7 +550,7 @@ def split_along_batch(
batches.append(self.__class__({key: value for key, value in zip(keys, values)}))
return tuple(batches)

def take_along_seq(self, indices: BaseBatch | np.ndarray | Tensor | Sequence) -> TBatchDict:
def take_along_seq(self, indices: TBaseBatch | np.ndarray | Tensor | Sequence) -> TBatchDict:
r"""Takes values along the sequence dimension.

Args:
Expand Down Expand Up @@ -602,7 +602,7 @@ def summary(self) -> str:
return f"{self.__class__.__qualname__}(\n {str_indent(data_str)}\n)"


def check_same_batch_size(data: dict[Hashable, TBaseBatch]) -> None:
def check_same_batch_size(data: dict[Hashable, BaseBatch]) -> None:
r"""Checks if all the batches in a group have the same batch
size.

Expand Down Expand Up @@ -673,7 +673,7 @@ def check_same_keys(data1: dict, data2: dict) -> None:
raise RuntimeError(f"Keys do not match: ({keys1} vs {keys2})")


def get_seq_lens(data: dict[Hashable, TBaseBatch]) -> set[int]:
def get_seq_lens(data: dict[Hashable, BaseBatch]) -> set[int]:
r"""Gets the sequence lengths from the inputs.

Args:
Expand Down