diff --git a/cooper/constraints/constraint_state.py b/cooper/constraints/constraint_state.py index dc9cfa8d..f66b37d0 100644 --- a/cooper/constraints/constraint_state.py +++ b/cooper/constraints/constraint_state.py @@ -87,7 +87,7 @@ def extract_violations(self, do_unsqueeze=True) -> tuple[torch.Tensor, torch.Ten return violation, strict_violation - def extract_constraint_features(self) -> torch.Tensor: + def extract_constraint_features(self) -> tuple[torch.Tensor, torch.Tensor]: """Extracts the constraint features from the constraint state. If strict constraint features are not provided, attempts to patch them with the differentiable constraint features. Similarly, if differentiable constraint