Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add coola comparators #118

Merged
merged 1 commit into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "redcat"
version = "0.0.1a105"
version = "0.0.1a106"
description = "A library to manipulate batches of examples"
readme = "README.md"
authors = ["Thibaut Durand <durand.tibo+gh@gmail.com>"]
Expand Down
1 change: 1 addition & 0 deletions src/redcat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__all__ = ["BaseBatch", "BaseBatchedTensor", "BatchedTensor", "BatchedTensorSeq"]

from redcat import comparators # noqa: F401
from redcat.base import BaseBatch
from redcat.basetensor import BaseBatchedTensor
from redcat.tensor import BatchedTensor
Expand Down
73 changes: 73 additions & 0 deletions src/redcat/comparators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
r"""This module implements some comparators to use ``BaseBatch`` objects with
``coola.objects_are_equal`` and ``coola.objects_are_allclose``."""

__all__ = ["BatchEqualityOperator", "BatchAllCloseOperator"]

import logging
from typing import Any

from coola import (
AllCloseTester,
BaseAllCloseOperator,
BaseAllCloseTester,
BaseEqualityOperator,
BaseEqualityTester,
EqualityTester,
)

from redcat.base import BaseBatch

logger = logging.getLogger(__name__)


class BatchEqualityOperator(BaseEqualityOperator[BaseBatch]):
    r"""Implements an equality operator for ``BaseBatch`` objects.

    Two objects are equal when ``object2`` is also a ``BaseBatch`` and
    ``object1.equal(object2)`` returns ``True``.
    """

    def equal(
        self,
        tester: BaseEqualityTester,
        object1: BaseBatch,
        object2: Any,
        show_difference: bool = False,
    ) -> bool:
        r"""Indicates if two objects are equal.

        Args:
            tester: The equality tester (unused here; part of the
                ``BaseEqualityOperator`` interface).
            object1: The batch to compare.
            object2: The object to compare with ``object1``.
            show_difference: If ``True``, logs an INFO message
                explaining the difference when the objects are not
                equal. Default: ``False``.

        Returns:
            ``True`` if the two objects are equal, otherwise ``False``.
        """
        if not isinstance(object2, BaseBatch):
            if show_difference:
                # Lazy %-args: the message is only built if actually emitted.
                logger.info("object2 is not a `BaseBatch` object: %s", type(object2))
            return False
        object_equal = object1.equal(object2)
        if show_difference and not object_equal:
            logger.info(
                "`BaseBatch` objects are different\nobject1=\n%s\nobject2=\n%s", object1, object2
            )
        return object_equal


class BatchAllCloseOperator(BaseAllCloseOperator[BaseBatch]):
    r"""Implements an allclose operator for ``BaseBatch`` objects.

    Two objects are allclose when ``object2`` is also a ``BaseBatch``
    and ``object1.allclose(object2, ...)`` returns ``True``.
    """

    def allclose(
        self,
        tester: BaseAllCloseTester,
        object1: BaseBatch,
        object2: Any,
        rtol: float = 1e-5,
        atol: float = 1e-8,
        equal_nan: bool = False,
        show_difference: bool = False,
    ) -> bool:
        r"""Indicates if two objects are equal within a tolerance.

        Args:
            tester: The allclose tester (unused here; part of the
                ``BaseAllCloseOperator`` interface).
            object1: The batch to compare.
            object2: The object to compare with ``object1``.
            rtol: The relative tolerance parameter. Default: ``1e-5``.
            atol: The absolute tolerance parameter. Default: ``1e-8``.
            equal_nan: If ``True``, NaNs are considered equal to each
                other. Default: ``False``.
            show_difference: If ``True``, logs an INFO message
                explaining the difference when the objects are not
                close. Default: ``False``.

        Returns:
            ``True`` if the two objects are equal within the given
            tolerances, otherwise ``False``.
        """
        if not isinstance(object2, BaseBatch):
            if show_difference:
                # Lazy %-args: the message is only built if actually emitted.
                logger.info("object2 is not a `BaseBatch` object: %s", type(object2))
            return False
        object_equal = object1.allclose(object2, rtol=rtol, atol=atol, equal_nan=equal_nan)
        if show_difference and not object_equal:
            logger.info(
                "`BaseBatch` objects are different\nobject1=\n%s\nobject2=\n%s", object1, object2
            )
        return object_equal


# Register the redcat comparators with coola's default testers so that
# `objects_are_equal`/`objects_are_allclose` work on `BaseBatch` objects.
# The `has_*` guards keep repeated imports of this module idempotent.
if not EqualityTester.has_equality_operator(BaseBatch):
    EqualityTester.add_equality_operator(BaseBatch, BatchEqualityOperator())  # pragma: no cover
if not AllCloseTester.has_allclose_operator(BaseBatch):
    AllCloseTester.add_allclose_operator(BaseBatch, BatchAllCloseOperator())  # pragma: no cover
167 changes: 167 additions & 0 deletions tests/unit/test_comparators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import logging

import torch
from coola import AllCloseTester, EqualityTester
from pytest import LogCaptureFixture, mark

from redcat import BaseBatch, BatchedTensor
from redcat.comparators import BatchAllCloseOperator, BatchEqualityOperator


def test_registered_batch_comparators() -> None:
    r"""The redcat operators are registered in coola's default testers."""
    equality_op = EqualityTester.registry[BaseBatch]
    allclose_op = AllCloseTester.registry[BaseBatch]
    assert isinstance(equality_op, BatchEqualityOperator)
    assert isinstance(allclose_op, BatchAllCloseOperator)


###########################################
# Tests for BatchEqualityOperator #
###########################################


def test_batch_equality_operator_str() -> None:
    r"""The string representation names the operator class."""
    op = BatchEqualityOperator()
    assert str(op) == "BatchEqualityOperator()"


def test_batch_equality_operator_equal_true() -> None:
    r"""Two batches with identical values compare equal."""
    lhs = BatchedTensor(torch.ones(2, 3))
    rhs = BatchedTensor(torch.ones(2, 3))
    assert BatchEqualityOperator().equal(EqualityTester(), lhs, rhs)


def test_batch_equality_operator_equal_true_show_difference(caplog: LogCaptureFixture) -> None:
    r"""No difference message is logged when the batches are equal."""
    op = BatchEqualityOperator()
    with caplog.at_level(logging.INFO):
        assert op.equal(
            EqualityTester(),
            BatchedTensor(torch.ones(2, 3)),
            BatchedTensor(torch.ones(2, 3)),
            show_difference=True,
        )
    assert not caplog.messages


def test_batch_equality_operator_equal_false_different_value() -> None:
    r"""Batches with different values are not equal."""
    assert not BatchEqualityOperator().equal(
        tester=EqualityTester(),
        object1=BatchedTensor(torch.ones(2, 3)),
        object2=BatchedTensor(torch.zeros(2, 3)),
    )


def test_batch_equality_operator_equal_false_different_value_show_difference(
    caplog: LogCaptureFixture,
) -> None:
    r"""A difference message is logged when the batch values differ."""
    op = BatchEqualityOperator()
    with caplog.at_level(logging.INFO):
        assert not op.equal(
            EqualityTester(),
            BatchedTensor(torch.ones(2, 3)),
            BatchedTensor(torch.zeros(2, 3)),
            show_difference=True,
        )
    assert caplog.messages[0].startswith("`BaseBatch` objects are different")


def test_batch_equality_operator_equal_false_different_type() -> None:
    r"""A non-batch object is never equal to a batch."""
    op = BatchEqualityOperator()
    assert not op.equal(EqualityTester(), BatchedTensor(torch.ones(2, 3)), 42)


def test_batch_equality_operator_equal_false_different_type_show_difference(
    caplog: LogCaptureFixture,
) -> None:
    r"""A type-mismatch message is logged when object2 is not a batch."""
    op = BatchEqualityOperator()
    with caplog.at_level(logging.INFO):
        assert not op.equal(
            EqualityTester(),
            BatchedTensor(torch.ones(2, 3)),
            42,
            show_difference=True,
        )
    assert caplog.messages[0].startswith("object2 is not a `BaseBatch` object")


###########################################
# Tests for BatchAllCloseOperator #
###########################################


def test_batch_allclose_operator_str() -> None:
    r"""The string representation names the operator class."""
    op = BatchAllCloseOperator()
    assert str(op) == "BatchAllCloseOperator()"


def test_batch_allclose_operator_allclose_true() -> None:
    r"""Two batches with identical values are allclose."""
    lhs = BatchedTensor(torch.ones(2, 3))
    rhs = BatchedTensor(torch.ones(2, 3))
    assert BatchAllCloseOperator().allclose(AllCloseTester(), lhs, rhs)


def test_batch_allclose_operator_allclose_true_show_difference(caplog: LogCaptureFixture) -> None:
    r"""No difference message is logged when the batches are allclose."""
    op = BatchAllCloseOperator()
    with caplog.at_level(logging.INFO):
        assert op.allclose(
            AllCloseTester(),
            BatchedTensor(torch.ones(2, 3)),
            BatchedTensor(torch.ones(2, 3)),
            show_difference=True,
        )
    assert not caplog.messages


def test_batch_allclose_operator_allclose_false_different_value() -> None:
    r"""Batches with different values are not allclose."""
    assert not BatchAllCloseOperator().allclose(
        tester=AllCloseTester(),
        object1=BatchedTensor(torch.ones(2, 3)),
        object2=BatchedTensor(torch.zeros(2, 3)),
    )


def test_batch_allclose_operator_allclose_false_different_value_show_difference(
    caplog: LogCaptureFixture,
) -> None:
    r"""A difference message is logged when the batch values differ."""
    op = BatchAllCloseOperator()
    with caplog.at_level(logging.INFO):
        assert not op.allclose(
            AllCloseTester(),
            BatchedTensor(torch.ones(2, 3)),
            BatchedTensor(torch.zeros(2, 3)),
            show_difference=True,
        )
    assert caplog.messages[0].startswith("`BaseBatch` objects are different")


def test_batch_allclose_operator_allclose_false_different_type() -> None:
    r"""A non-batch object is never allclose to a batch."""
    op = BatchAllCloseOperator()
    assert not op.allclose(AllCloseTester(), BatchedTensor(torch.ones(2, 3)), 42)


def test_batch_allclose_operator_allclose_false_different_type_show_difference(
    caplog: LogCaptureFixture,
) -> None:
    r"""A type-mismatch message is logged when object2 is not a batch."""
    op = BatchAllCloseOperator()
    with caplog.at_level(logging.INFO):
        assert not op.allclose(
            AllCloseTester(),
            BatchedTensor(torch.ones(2, 3)),
            42,
            show_difference=True,
        )
    assert caplog.messages[0].startswith("object2 is not a `BaseBatch` object")


@mark.parametrize(
    "tensor,atol",
    (
        (BatchedTensor(torch.ones(2, 3).add(0.5)), 1),
        (BatchedTensor(torch.ones(2, 3).add(0.05)), 1e-1),
        (BatchedTensor(torch.ones(2, 3).add(5e-3)), 1e-2),
    ),
)
def test_batch_allclose_operator_allclose_true_atol(tensor: BatchedTensor, atol: float) -> None:
    r"""Batches within absolute tolerance ``atol`` are allclose (rtol disabled)."""
    reference = BatchedTensor(torch.ones(2, 3))
    assert BatchAllCloseOperator().allclose(AllCloseTester(), reference, tensor, atol=atol, rtol=0)


@mark.parametrize(
    "tensor,rtol",
    (
        (BatchedTensor(torch.ones(2, 3).add(0.5)), 1),
        (BatchedTensor(torch.ones(2, 3).add(0.05)), 1e-1),
        (BatchedTensor(torch.ones(2, 3).add(5e-3)), 1e-2),
    ),
)
def test_batch_allclose_operator_allclose_true_rtol(tensor: BatchedTensor, rtol: float) -> None:
    r"""Batches within relative tolerance ``rtol`` are allclose."""
    reference = BatchedTensor(torch.ones(2, 3))
    assert BatchAllCloseOperator().allclose(AllCloseTester(), reference, tensor, rtol=rtol)