From fea4b9e663ea8bd7712196e92e230312bc9be43d Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Tue, 7 May 2024 12:18:24 +0800 Subject: [PATCH] fix(nyz): fix unittest and platformtest bug --- .github/workflows/platform_test.yml | 2 +- ding/rl_utils/td.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/platform_test.yml b/.github/workflows/platform_test.yml index 58c9f983ae..181f20c681 100644 --- a/.github/workflows/platform_test.yml +++ b/.github/workflows/platform_test.yml @@ -11,7 +11,7 @@ jobs: if: "!contains(github.event.head_commit.message, 'ci skip')" strategy: matrix: - os: [macos-latest, windows-latest] + os: [macos-13, windows-latest] python-version: [3.8, 3.9] steps: diff --git a/ding/rl_utils/td.py b/ding/rl_utils/td.py index 236b04e347..a3f32589c4 100644 --- a/ding/rl_utils/td.py +++ b/ding/rl_utils/td.py @@ -266,6 +266,8 @@ def nstep_return(data: namedtuple, gamma: Union[float, list], nstep: int, value_ if value_gamma is None: return_ = return_tmp + (gamma ** nstep) * next_value * (1 - done) else: + if np.isscalar(value_gamma): + value_gamma = torch.full_like(next_value, value_gamma) value_gamma = view_similar(value_gamma, next_value) done = view_similar(done, next_value) return_ = return_tmp + value_gamma * next_value * (1 - done)