Skip to content

Commit

Permalink
feat: getitem interface for FullOutput object.
Browse files Browse the repository at this point in the history
  • Loading branch information
peterschutt authored and mberk committed Feb 22, 2024
1 parent 20f8963 commit c9b6e42
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 0 deletions.
22 changes: 22 additions & 0 deletions python/shin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@ class FullOutput(Generic[OutputT]):
delta: float
z: float

@overload
def __getitem__(self, key: Literal["implied_probabilities"]) -> OutputT:
...

@overload
def __getitem__(self, key: Literal["iterations"]) -> float:
...

@overload
def __getitem__(self, key: Literal["delta"]) -> float:
...

@overload
def __getitem__(self, key: Literal["z"]) -> float:
...

def __getitem__(self, key: Literal["implied_probabilities", "iterations", "delta", "z"]) -> Any:
try:
return getattr(self, key)
except AttributeError:
raise KeyError(key)


# sequence input, full output False
@overload
Expand Down
12 changes: 12 additions & 0 deletions tests/test_shin.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,15 @@ def test_calculate_implied_probabilities_two_outcomes() -> None:
pytest.approx(inverse_odds[1] - (sum_inverse_odds - 1) / 2)
== result.implied_probabilities[1]
)


def test_full_output_get_item_interface() -> None:
full_output = shin.FullOutput(
implied_probabilities=[0.3, 0.4, 0.3], iterations=10, delta=0.1, z=0.5
)
assert full_output["implied_probabilities"] == [0.3, 0.4, 0.3]
assert full_output["iterations"] == 10
assert full_output["delta"] == 0.1
assert full_output["z"] == 0.5
with pytest.raises(KeyError):
full_output["foo"]
23 changes: 23 additions & 0 deletions typesafety/test_shin.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,26 @@
out: |
main:4: note: Revealed type is "shin.FullOutput[builtins.dict[builtins.int, builtins.float]]"
main:5: note: Revealed type is "builtins.dict[builtins.int, builtins.float]"
- case: test_full_output_get_item_overloads
main: |
import shin
out = shin.calculate_implied_probabilities([3.0, 3.0, 3.0], full_output=True)
reveal_type(out['implied_probabilities'])
reveal_type(out['iterations'])
reveal_type(out['delta'])
reveal_type(out['z'])
reveal_type(out['other'])
out: |
main:4: note: Revealed type is "builtins.list[builtins.float]"
main:5: note: Revealed type is "builtins.float"
main:6: note: Revealed type is "builtins.float"
main:7: note: Revealed type is "builtins.float"
main:8: error: No overload variant of "__getitem__" of "FullOutput" matches argument type "str" [call-overload]
main:8: note: Possible overload variants:
main:8: note: def __getitem__(self, Literal['implied_probabilities'], /) -> list[float]
main:8: note: def __getitem__(self, Literal['iterations'], /) -> float
main:8: note: def __getitem__(self, Literal['delta'], /) -> float
main:8: note: def __getitem__(self, Literal['z'], /) -> float
main:8: note: Revealed type is "Any"

0 comments on commit c9b6e42

Please sign in to comment.