diff --git a/chess/engine.py b/chess/engine.py index 5194247f6..ad61bc580 100644 --- a/chess/engine.py +++ b/chess/engine.py @@ -44,7 +44,7 @@ try: from typing import Literal - _WdlModel = Literal["sf", "sf15", "sf14", "sf12", "lichess"] + _WdlModel = Literal["sf", "sf15.1", "sf15", "sf14", "sf12", "lichess"] except ImportError: # Before Python 3.8. _WdlModel = str # type: ignore @@ -564,7 +564,8 @@ def wdl(self, *, model: _WdlModel = "sf", ply: int = 30) -> Wdl: :param model: * ``sf``, the WDL model used by the latest Stockfish - (currently ``sf15``). + (currently ``sf15.1``). + * ``sf15.1``, the WDL model used by Stockfish 15.1. * ``sf15``, the WDL model used by Stockfish 15. * ``sf14``, the WDL model used by Stockfish 14. * ``sf12``, the WDL model used by Stockfish 12. @@ -628,6 +629,16 @@ def __ge__(self, other: object) -> bool: return NotImplemented +def _sf15_1_wins(cp: int, *, ply: int) -> int: + # https://github.com/official-stockfish/Stockfish/blob/sf_15.1/src/uci.cpp#L200-L224 + # https://github.com/official-stockfish/Stockfish/blob/sf_15.1/src/uci.h#L38 + NormalizeToPawnValue = 361 + m = min(240, max(ply, 0)) / 64 + a = (((-0.58270499 * m + 2.68512549) * m + 15.24638015) * m) + 344.49745382 + b = (((-2.65734562 * m + 15.96509799) * m + -20.69040836) * m) + 73.61029937 + x = min(4000, max(cp * NormalizeToPawnValue / 100, -4000)) + return int(0.5 + 1000 / (1 + math.exp((a - x) / b))) + def _sf15_wins(cp: int, *, ply: int) -> int: # https://github.com/official-stockfish/Stockfish/blob/sf_15/src/uci.cpp#L200-L220 m = min(240, max(ply, 0)) / 64 @@ -678,9 +689,12 @@ def wdl(self, *, model: _WdlModel = "sf", ply: int = 30) -> Wdl: elif model == "sf14": wins = _sf14_wins(self.cp, ply=ply) losses = _sf14_wins(-self.cp, ply=ply) - else: + elif model == "sf15": wins = _sf15_wins(self.cp, ply=ply) losses = _sf15_wins(-self.cp, ply=ply) + else: + wins = _sf15_1_wins(self.cp, ply=ply) + losses = _sf15_1_wins(-self.cp, ply=ply) draws = 1000 - wins - losses return Wdl(wins, draws, losses) diff --git a/test.py b/test.py index 4c8db196e..9bb575802 100755 --- a/test.py +++ b/test.py @@ -2969,6 +2969,7 @@ def test_wdl_model(self): self.assertEqual(chess.engine.Cp(131).wdl(model="sf12", ply=25), chess.engine.Wdl(524, 467, 9)) self.assertEqual(chess.engine.Cp(146).wdl(model="sf14", ply=25), chess.engine.Wdl(601, 398, 1)) self.assertEqual(chess.engine.Cp(40).wdl(model="sf15", ply=25), chess.engine.Wdl(58, 937, 5)) + self.assertEqual(chess.engine.Cp(100).wdl(model="sf15.1", ply=64), chess.engine.Wdl(497, 503, 0)) @catchAndSkip(FileNotFoundError, "need stockfish") def test_sf_forced_mates(self):