diff --git a/Source/LuaBridge/detail/LuaRef.h b/Source/LuaBridge/detail/LuaRef.h index b7f15e79..8234d237 100644 --- a/Source/LuaBridge/detail/LuaRef.h +++ b/Source/LuaBridge/detail/LuaRef.h @@ -96,7 +96,7 @@ class LuaRefBase */ int createRef() const { - impl().push(); + impl().push(m_L); return luaL_ref(m_L, LUA_REGISTRYINDEX); } @@ -119,7 +119,7 @@ class LuaRefBase lua_getglobal(m_L, "tostring"); - impl().push(); + impl().push(m_L); lua_call(m_L, 1, 1); @@ -196,34 +196,6 @@ class LuaRefBase return m_L; } - //============================================================================================= - /** - * @brief Place the object onto the Lua stack. - * - * @param L A Lua state. - */ - void push(lua_State* L) const - { - LUABRIDGE_ASSERT(equalstates(L, m_L)); - (void) L; - - impl().push(); - } - - //============================================================================================= - /** - * @brief Pop the top of Lua stack and assign it to the reference. - * - * @param L A Lua state. - */ - void pop(lua_State* L) - { - LUABRIDGE_ASSERT(equalstates(L, m_L)); - (void) L; - - impl().pop(); - } - //============================================================================================= /** * @brief Return the Lua type of the referred value. @@ -238,7 +210,7 @@ class LuaRefBase { const StackRestore stackRestore(m_L); - impl().push(); + impl().push(m_L); return lua_type(m_L, -1); } @@ -331,7 +303,7 @@ class LuaRefBase { const StackRestore stackRestore(m_L); - impl().push(); + impl().push(m_L); return Stack::get(m_L, -1); } @@ -346,7 +318,7 @@ class LuaRefBase { const StackRestore stackRestore(m_L); - impl().push(); + impl().push(m_L); return *Stack::get(m_L, -1); } @@ -362,7 +334,7 @@ class LuaRefBase { const StackRestore stackRestore(m_L); - impl().push(); + impl().push(m_L); return Stack::isInstance(m_L, -1); } @@ -395,7 +367,7 @@ class LuaRefBase const StackRestore stackRestore(m_L); - impl().push(); + impl().push(m_L); if (! lua_getmetatable(m_L, -1)) return LuaRef(m_L); @@ -418,7 +390,7 @@ class LuaRefBase { const StackRestore stackRestore(m_L); - impl().push(); + impl().push(m_L); if (! Stack::push(m_L, rhs)) return false; @@ -455,7 +427,7 @@ class LuaRefBase { const StackRestore stackRestore(m_L); - impl().push(); + impl().push(m_L); if (! Stack::push(m_L, rhs)) return false; @@ -482,7 +454,7 @@ class LuaRefBase { const StackRestore stackRestore(m_L); - impl().push(); + impl().push(m_L); if (! Stack::push(m_L, rhs)) return false; @@ -509,7 +481,7 @@ class LuaRefBase { const StackRestore stackRestore(m_L); - impl().push(); + impl().push(m_L); if (! Stack::push(m_L, rhs)) return false; @@ -536,7 +508,7 @@ class LuaRefBase { const StackRestore stackRestore(m_L); - impl().push(); + impl().push(m_L); if (! Stack::push(m_L, rhs)) return false; @@ -563,7 +535,7 @@ class LuaRefBase { const StackRestore stackRestore(m_L); - impl().push(); + impl().push(m_L); if (! Stack::push(m_L, v)) return false; @@ -583,7 +555,7 @@ class LuaRefBase { const StackRestore stackRestore(m_L); - impl().push(); + impl().push(m_L); return get_length(m_L, -1); } @@ -756,19 +728,24 @@ class LuaRef : public LuaRefBase /** * @brief Push the value onto the Lua stack. */ - using LuaRefBase::push; - void push() const { + push(m_L); + } + + void push(lua_State* L) const + { + LUABRIDGE_ASSERT(equalstates(L, m_L)); + #if LUABRIDGE_SAFE_STACK_CHECKS - if (! lua_checkstack(m_L, 3)) + if (! lua_checkstack(L, 3)) return; #endif - lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_tableRef); - lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_keyRef); - lua_gettable(m_L, -2); - lua_remove(m_L, -2); // remove the table + lua_rawgeti(L, LUA_REGISTRYINDEX, m_tableRef); + lua_rawgeti(L, LUA_REGISTRYINDEX, m_keyRef); + lua_gettable(L, -2); + lua_remove(L, -2); // remove the table } //========================================================================================= @@ -1103,16 +1080,21 @@ class LuaRef : public LuaRefBase /** * @brief Place the object onto the Lua stack. */ - using LuaRefBase::push; - void push() const { + push(m_L); + } + + void push(lua_State* L) const + { + LUABRIDGE_ASSERT(equalstates(L, m_L)); + #if LUABRIDGE_SAFE_STACK_CHECKS - if (! lua_checkstack(m_L, 1)) + if (! lua_checkstack(L, 1)) return; #endif - lua_rawgeti(m_L, LUA_REGISTRYINDEX, m_ref); + lua_rawgeti(L, LUA_REGISTRYINDEX, m_ref); } //============================================================================================= @@ -1121,10 +1103,17 @@ class LuaRef : public LuaRefBase */ void pop() { + pop(m_L); + } + + void pop(lua_State* L) + { + LUABRIDGE_ASSERT(equalstates(L, m_L)); + if (m_ref != LUA_NOREF) - luaL_unref(m_L, LUA_REGISTRYINDEX, m_ref); + luaL_unref(L, LUA_REGISTRYINDEX, m_ref); - m_ref = luaL_ref(m_L, LUA_REGISTRYINDEX); + m_ref = luaL_ref(L, LUA_REGISTRYINDEX); } //============================================================================================= diff --git a/Tests/Source/CoroutineTests.cpp b/Tests/Source/CoroutineTests.cpp index 4620ccc3..0cb0b706 100644 --- a/Tests/Source/CoroutineTests.cpp +++ b/Tests/Source/CoroutineTests.cpp @@ -4,6 +4,20 @@ #include "TestBase.h" +namespace { +int lua_resume_x(lua_State* L, int nargs) +{ +#if LUABRIDGEDEMO_LUAJIT || LUABRIDGEDEMO_LUA_VERSION == 501 + return lua_resume(L, nargs); +#elif LUABRIDGEDEMO_LUAU || LUABRIDGEDEMO_LUA_VERSION < 504 + return lua_resume(L, nullptr, nargs); +#else + [[maybe_unused]] int nresults = 0; + return lua_resume(L, nullptr, nargs, &nresults); +#endif +} +} // namespace + struct CoroutineTests : TestBase { }; @@ -79,6 +93,45 @@ TEST_F(CoroutineTests, LuaRefMove) EXPECT_EQ(42, y.unsafe_cast()); } +TEST_F(CoroutineTests, LuaRefPushInDifferentThread) +{ + lua_State* thread1 = lua_newthread(L); + lua_State* thread2 = lua_newthread(L); + + luabridge::LuaRef y = luabridge::LuaRef(L, 1337); + + luabridge::setGlobal(thread1, y, "y1"); + luabridge::setGlobal(thread2, y, "y2"); + + { + auto result = luaL_loadstring(thread1, "coroutine.yield(y1)"); + ASSERT_EQ(LUABRIDGE_LUA_OK, result); + } + + { + auto result = lua_resume_x(thread1, 0); + ASSERT_EQ(LUA_YIELD, result); + EXPECT_EQ(1, lua_gettop(thread1)); + + auto x1 = luabridge::LuaRef::fromStack(thread1); + EXPECT_EQ(y, x1); + } + + { + auto result = luaL_loadstring(thread2, "coroutine.yield(y2)"); + ASSERT_EQ(LUABRIDGE_LUA_OK, result); + } + + { + auto result = lua_resume_x(thread2, 0); + ASSERT_EQ(LUA_YIELD, result); + EXPECT_EQ(1, lua_gettop(thread2)); + + auto x1 = luabridge::LuaRef::fromStack(thread2); + EXPECT_EQ(y, x1); + } +} + TEST_F(CoroutineTests, ThreadedRegistration) { using namespace luabridge; @@ -122,15 +175,7 @@ TEST_F(CoroutineTests, ThreadedRegistration) lua_rawgeti(thread1, LUA_REGISTRYINDEX, scriptRef); { -#if LUABRIDGEDEMO_LUAJIT || LUABRIDGEDEMO_LUA_VERSION == 501 - auto result = lua_resume(thread1, 0); -#elif LUABRIDGEDEMO_LUAU || LUABRIDGEDEMO_LUA_VERSION < 504 - auto result = lua_resume(thread1, nullptr, 0); -#else - int nresults = 0; - auto result = lua_resume(thread1, nullptr, 0, &nresults); -#endif - + auto result = lua_resume_x(thread1, 0); ASSERT_EQ(LUABRIDGE_LUA_OK, result); }