Skip to content

Commit

Permalink
Add generic type to BatchDict (#446)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo committed Sep 25, 2023
1 parent 5f4a762 commit 6b5309c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.12a0"
version = "0.0.12a1"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
Expand Down
11 changes: 6 additions & 5 deletions src/redcat/batchdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@

from redcat.base import BaseBatch

TBaseBatch = TypeVar("TBaseBatch", bound=BaseBatch)
# Workaround because Self is not available for python 3.9 and 3.10
# https://peps.python.org/pep-0673/
TBatchDict = TypeVar("TBatchDict", bound="BatchDict")


class BatchDict(BaseBatch[dict[Hashable, BaseBatch]]):
class BatchDict(BaseBatch[dict[Hashable, TBaseBatch]]):
r"""Implements a batch object to represent a dictionary of batches.
Args:
Expand All @@ -48,7 +49,7 @@ class BatchDict(BaseBatch[dict[Hashable, BaseBatch]]):
... )
"""

def __init__(self, data: dict[Hashable, BaseBatch]) -> None:
def __init__(self, data: dict[Hashable, TBaseBatch]) -> None:
if not isinstance(data, dict):
raise TypeError(f"Incorrect type. Expect a dict but received {type(data)}")
check_same_batch_size(data)
Expand All @@ -63,7 +64,7 @@ def batch_size(self) -> int:
return next(iter(self._data.values())).batch_size

@property
def data(self) -> dict[Hashable, BaseBatch]:
def data(self) -> dict[Hashable, TBaseBatch]:
return self._data

#################################
Expand Down Expand Up @@ -601,7 +602,7 @@ def summary(self) -> str:
return f"{self.__class__.__qualname__}(\n {str_indent(data_str)}\n)"


def check_same_batch_size(data: dict[Hashable, BaseBatch]) -> None:
def check_same_batch_size(data: dict[Hashable, TBaseBatch]) -> None:
r"""Checks if the all the batches in a group have the same batch
size.
Expand Down Expand Up @@ -672,7 +673,7 @@ def check_same_keys(data1: dict, data2: dict) -> None:
raise RuntimeError(f"Keys do not match: ({keys1} vs {keys2})")


def get_seq_lens(data: dict[Hashable, BaseBatch]) -> set[int]:
def get_seq_lens(data: dict[Hashable, TBaseBatch]) -> set[int]:
r"""Gets the sequence lengths from the inputs.
Args:
Expand Down

0 comments on commit 6b5309c

Please sign in to comment.