Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 82 additions & 36 deletions src/Lua/Runtime/LuaVirtualMachine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ public bool PopFromBuffer(Span<LuaValue> result)
switch (opCode)
{
case OpCode.Call:
{
var c = callInstruction.C;
if (c != 0)
{
var c = callInstruction.C;
if (c != 0)
{
targetCount = c - 1;
}

break;
targetCount = c - 1;
}

break;
}
case OpCode.TForCall:
target += 3;
targetCount = callInstruction.C;
Expand Down Expand Up @@ -1034,11 +1034,15 @@ static bool Call(ref VirtualMachineExecutionContext context, out bool doRestart)
var instruction = context.Instruction;
var RA = instruction.A + context.FrameBase;
var va = context.Stack.Get(RA);
var newBase = RA + 1;
bool isMetamethod = false;
if (!va.TryReadFunction(out var func))
{
if (va.TryGetMetamethod(context.State, Metamethods.Call, out var metamethod) &&
metamethod.TryReadFunction(out func))
{
newBase -= 1;
isMetamethod = true;
}
else
{
Expand All @@ -1047,8 +1051,8 @@ static bool Call(ref VirtualMachineExecutionContext context, out bool doRestart)
}

var thread = context.Thread;
var (newBase, argumentCount, variableArgumentCount) = PrepareForFunctionCall(thread, func, instruction, RA);

var (argumentCount, variableArgumentCount) = PrepareForFunctionCall(thread, func, instruction, newBase, isMetamethod);
newBase += variableArgumentCount;
var newFrame = func.CreateNewFrame(ref context, newBase, variableArgumentCount);

thread.PushCallStackFrame(newFrame);
Expand Down Expand Up @@ -1160,6 +1164,8 @@ static bool TailCall(ref VirtualMachineExecutionContext context, out bool doRest
var instruction = context.Instruction;
var stack = context.Stack;
var RA = instruction.A + context.FrameBase;
var newBase = RA + 1;
bool isMetamethod = false;
var state = context.State;
var thread = context.Thread;

Expand All @@ -1168,19 +1174,23 @@ static bool TailCall(ref VirtualMachineExecutionContext context, out bool doRest
var va = stack.Get(RA);
if (!va.TryReadFunction(out var func))
{
if (!va.TryGetMetamethod(state, Metamethods.Call, out var metamethod) &&
!metamethod.TryReadFunction(out func))
if (va.TryGetMetamethod(state, Metamethods.Call, out var metamethod) &&
metamethod.TryReadFunction(out func))
{
isMetamethod = true;
newBase -= 1;
}
else
{
LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", metamethod);
}
}

var (newBase, argumentCount, variableArgumentCount) = PrepareForFunctionTailCall(thread, func, instruction, RA);

var (argumentCount, variableArgumentCount) = PrepareForFunctionTailCall(thread, func, instruction, newBase, isMetamethod);
newBase = context.FrameBase + variableArgumentCount;

var lastPc = thread.CallStack.AsSpan()[^1].CallerInstructionIndex;
context.Thread.PopCallStackFrameUnsafe();

var newFrame = func.CreateNewFrame(ref context, newBase, variableArgumentCount);

newFrame.Flags |= CallStackFrameFlags.TailCall;
Expand Down Expand Up @@ -1233,18 +1243,46 @@ static bool TForCall(ref VirtualMachineExecutionContext context, out bool doRest
var instruction = context.Instruction;
var stack = context.Stack;
var RA = instruction.A + context.FrameBase;

bool isMetamethod = false;
var iteratorRaw = stack.Get(RA);
if (!iteratorRaw.TryReadFunction(out var iterator))
{
LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", iteratorRaw);
if (iteratorRaw.TryGetMetamethod(context.State, Metamethods.Call, out var metamethod) &&
metamethod.TryReadFunction(out iterator))
{
isMetamethod = true;
}
else
{
LuaRuntimeException.AttemptInvalidOperation(GetTracebacks(ref context), "call", metamethod);
}
}

var newBase = RA + 3 + instruction.C;
stack.Get(newBase) = stack.Get(RA + 1);
stack.Get(newBase + 1) = stack.Get(RA + 2);
stack.NotifyTop(newBase + 2);
var newFrame = iterator.CreateNewFrame(ref context, newBase);

if (isMetamethod)
{
stack.Get(newBase) = iteratorRaw;
stack.Get(newBase + 1) = stack.Get(RA + 1);
stack.Get(newBase + 2) = stack.Get(RA + 2);
stack.NotifyTop(newBase + 3);
}
else
{
stack.Get(newBase) = stack.Get(RA + 1);
stack.Get(newBase + 1) = stack.Get(RA + 2);
stack.NotifyTop(newBase + 2);
}

var argumentCount = isMetamethod ? 3 : 2;
var variableArgumentCount = iterator.GetVariableArgumentCount(argumentCount);
if (variableArgumentCount != 0)
{
PrepareVariableArgument(stack, newBase, argumentCount, variableArgumentCount);
newBase += variableArgumentCount;
}

var newFrame = iterator.CreateNewFrame(ref context, newBase, variableArgumentCount);
context.Thread.PushCallStackFrame(newFrame);
if (iterator is LuaClosure)
{
Expand Down Expand Up @@ -1370,7 +1408,7 @@ static bool GetTableValueSlowPath(LuaValue table, LuaValue key, ref VirtualMachi
}

table = metatableValue;
Function:
Function:
if (table.TryReadFunction(out var function))
{
return CallGetTableFunc(targetTable, function, key, ref context, out value, out doRestart);
Expand Down Expand Up @@ -1468,7 +1506,7 @@ static bool SetTableValueSlowPath(LuaValue table, LuaValue key, LuaValue value,

table = metatableValue;

Function:
Function:
if (table.TryReadFunction(out var function))
{
context.PostOperation = PostOperationType.Nop;
Expand Down Expand Up @@ -1732,7 +1770,7 @@ static bool ExecuteCompareOperationMetaMethod(LuaValue vb, LuaValue vc,
// If there are variable arguments, the base of the stack is moved by that number and the values of the variable arguments are placed in front of it.
// see: https://wubingzheng.github.io/build-lua-in-rust/en/ch08-02.arguments.html
[MethodImpl(MethodImplOptions.NoInlining)]
static (int FrameBase, int ArgumentCount, int VariableArgumentCount) PrepareVariableArgument(LuaStack stack, int newBase, int argumentCount,
static ( int ArgumentCount, int VariableArgumentCount) PrepareVariableArgument(LuaStack stack, int newBase, int argumentCount,
int variableArgumentCount)
{
var temp = newBase;
Expand All @@ -1744,51 +1782,59 @@ static bool ExecuteCompareOperationMetaMethod(LuaValue vb, LuaValue vc,
var stackBuffer = stack.GetBuffer()[temp..];
stackBuffer[..argumentCount].CopyTo(stackBuffer[variableArgumentCount..]);
stackBuffer.Slice(argumentCount, variableArgumentCount).CopyTo(stackBuffer);
return (newBase, argumentCount, variableArgumentCount);
return (argumentCount, variableArgumentCount);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
static (int FrameBase, int ArgumentCount, int VariableArgumentCount) PrepareForFunctionCall(LuaThread thread, LuaFunction function,
Instruction instruction, int RA)
static (int ArgumentCount, int VariableArgumentCount) PrepareForFunctionCall(LuaThread thread, LuaFunction function,
Instruction instruction, int newBase, bool isMetaMethod)
{
var argumentCount = instruction.B - 1;
if (argumentCount == -1)
{
argumentCount = (ushort)(thread.Stack.Count - (RA + 1));
argumentCount = (ushort)(thread.Stack.Count - newBase);
}
else
{
thread.Stack.NotifyTop(RA + 1 + argumentCount);
if (isMetaMethod)
{
argumentCount += 1;
}

thread.Stack.NotifyTop(newBase + argumentCount);
}

var newBase = RA + 1;
var variableArgumentCount = function.GetVariableArgumentCount(argumentCount);

if (variableArgumentCount <= 0)
{
return (newBase, argumentCount, 0);
return (argumentCount, 0);
}

return PrepareVariableArgument(thread.Stack, newBase, argumentCount, variableArgumentCount);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
static (int FrameBase, int ArgumentCount, int VariableArgumentCount) PrepareForFunctionTailCall(LuaThread thread, LuaFunction function,
Instruction instruction, int RA)
static (int ArgumentCount, int VariableArgumentCount) PrepareForFunctionTailCall(LuaThread thread, LuaFunction function,
Instruction instruction, int newBase, bool isMetaMethod)
{
var stack = thread.Stack;

var argumentCount = instruction.B - 1;
if (instruction.B == 0)
{
argumentCount = (ushort)(stack.Count - (RA + 1));
argumentCount = (ushort)(stack.Count - newBase);
}
else
{
thread.Stack.NotifyTop(RA + 1 + argumentCount);
if (isMetaMethod)
{
argumentCount += 1;
}

thread.Stack.NotifyTop(newBase + argumentCount);
}

var newBase = RA + 1;

// In the case of tailcall, the local variables of the caller are immediately discarded, so there is no need to retain them.
// Therefore, a call can be made without allocating new registers.
Expand All @@ -1804,7 +1850,7 @@ static bool ExecuteCompareOperationMetaMethod(LuaValue vb, LuaValue vc,

if (variableArgumentCount <= 0)
{
return (newBase, argumentCount, 0);
return (argumentCount, 0);
}

return PrepareVariableArgument(thread.Stack, newBase, argumentCount, variableArgumentCount);
Expand Down
25 changes: 25 additions & 0 deletions tests/Lua.Tests/LuaObjectTests.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Lua.Standard;

namespace Lua.Tests;

[LuaObject]
Expand Down Expand Up @@ -32,6 +34,12 @@ public double InstanceMethodWithReturnValue()
{
return Property;
}

[LuaMetamethod(LuaObjectMetamethod.Call)]
public string Call()
{
return "Called!";
}
}

public class LuaObjectTests
Expand Down Expand Up @@ -120,4 +128,21 @@ public async Task Test_InstanceMethodWithReturnValue()
Assert.That(results, Has.Length.EqualTo(1));
Assert.That(results[0], Is.EqualTo(new LuaValue(1)));
}

[Test]
public async Task Test_CallMetamethod()
{
var userData = new TestUserData();

var state = LuaState.Create();
state.OpenBasicLibrary();
state.Environment["test"] = userData;
var results = await state.DoStringAsync("""
assert(test() == 'Called!')
return test()
""");

Assert.That(results, Has.Length.EqualTo(1));
Assert.That(results[0], Is.EqualTo(new LuaValue("Called!")));
}
}
50 changes: 50 additions & 0 deletions tests/Lua.Tests/MetatableTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,56 @@ public async Task Test_Metamethod_NewIndex()
await state.DoStringAsync(source);
}

[Test]
public async Task Test_Metamethod_Call()
{
var source = @"
metatable = {
__call = function(a, b)
return a.x + b
end
}

local a = {}
a.x = 1
setmetatable(a, metatable)
assert(a(2) == 3)
function tail(a, b)
return a(b)
end
tail(a, 3)
assert(tail(a, 3) == 4)
";
await state.DoStringAsync(source);
}

[Test]
public async Task Test_Metamethod_TForCall()
{
var source = @"
local i =3
function a(...)
local v ={...}
assert(v[1] ==t)
assert(v[2] == nil)
if i ==3 then
assert(v[3] == nil)
else
assert(v[3] == i)
end

i =i -1
if i ==0 then return nil end
return i
end

t =setmetatable({},{__call = a})

for i in t do
end
";
await state.DoStringAsync(source);
}
[Test]
public async Task Test_Hook_Metamethods()
{
Expand Down