diff --git a/sandbox/Benchmark/InterpreterSteps.cs b/sandbox/Benchmark/InterpreterSteps.cs index 15fb5688..b98f2c62 100644 --- a/sandbox/Benchmark/InterpreterSteps.cs +++ b/sandbox/Benchmark/InterpreterSteps.cs @@ -46,7 +46,6 @@ public LuaClosure Compile() [Benchmark] public async ValueTask RunAsync() { - await state.MainThread.RunAsync(closure); - state.MainThread.Stack.Clear(); + await state.TopLevelAccess.Call(closure, []); } } \ No newline at end of file diff --git a/sandbox/ConsoleApp1/Program.cs b/sandbox/ConsoleApp1/Program.cs index 551ae85d..1e4dbd33 100644 --- a/sandbox/ConsoleApp1/Program.cs +++ b/sandbox/ConsoleApp1/Program.cs @@ -5,7 +5,7 @@ using System.Text.RegularExpressions; using System; using System.IO; - +using System.Text; var state = LuaState.Create(); state.OpenStandardLibraries(); @@ -31,10 +31,10 @@ Console.WriteLine("Output " + new string('-', 50)); - var count = await state.MainThread.RunAsync(closure); + var count = await state.TopLevelAccess.RunAsync(closure); Console.WriteLine("Result " + new string('-', 50)); - using var results = state.MainThread.ReadReturnValues(count); + using var results = state.TopLevelAccess.ReadReturnValues(count); for (int i = 0; i < count; i++) { Console.WriteLine(results[i]); @@ -46,26 +46,10 @@ { if (ex is LuaCompileException luaCompileException) { - var linOffset = luaCompileException.OffSet - luaCompileException.Position.Column + 1; - var length = 0; - foreach (var c in source.AsSpan(linOffset)) - { - if (c is '\n' or '\r') - { - break; - } - - length++; - } + + Console.WriteLine("CompileError " + new string('-', 50)); - Console.WriteLine(luaCompileException.ChunkName + ":"+luaCompileException.Position.Line + ":" + luaCompileException.Position.Column); - var line = source.Substring(linOffset, length); - var lineNumString = luaCompileException.Position.Line.ToString(); - Console.WriteLine(new string(' ', lineNumString.Length) + " |"); - Console.WriteLine(lineNumString + " | " + line); - Console.WriteLine(new string(' ', lineNumString.Length) + " | " + - new string(' ', luaCompileException.Position.Column - 1) + - "^ " + luaCompileException.MainMessage); + Console.WriteLine(RustLikeExceptionHook.OnCatch(source, luaCompileException)); ; Console.WriteLine(new string('-', 55)); } @@ -128,4 +112,40 @@ static void DebugChunk(Prototype chunk, int id) DebugChunk(localChunk, nestedChunkId); nestedChunkId++; } +} + +public class LuaRustLikeException(string message, Exception? innerException) : LuaException(message, innerException); + +class RustLikeExceptionHook //: ILuaCompileHook +{ + public static string OnCatch(ReadOnlySpan source, LuaCompileException exception) + { + var lineOffset = exception.OffSet - exception.Position.Column + 1; + var length = 0; + if (lineOffset < 0) + { + lineOffset = 0; + } + foreach (var c in source[lineOffset..]) + { + if (c is '\n' or '\r') + { + break; + } + + length++; + } + var builder = new StringBuilder(); + builder.AppendLine(); + builder.AppendLine("[error]: "+exception.MessageWithNearToken); + builder.AppendLine("-->"+exception.ChunkName + ":" + exception.Position.Line + ":" + exception.Position.Column); + var line = source.Slice(lineOffset, length).ToString(); + var lineNumString = exception.Position.Line.ToString(); + builder.AppendLine(new string(' ', lineNumString.Length) + " |"); + builder.AppendLine(lineNumString + " | " + line); + builder.AppendLine(new string(' ', lineNumString.Length) + " | " + + new string(' ', exception.Position.Column - 1) + + "^ " + exception.MainMessage); + return builder.ToString(); + } } \ No newline at end of file diff --git a/sandbox/ConsoleApp2/Program.cs b/sandbox/ConsoleApp2/Program.cs index 31f1bb53..853c04da 100644 --- a/sandbox/ConsoleApp2/Program.cs +++ b/sandbox/ConsoleApp2/Program.cs @@ -8,10 +8,10 @@ { var closure = state.Load("return function (a,b,...) print('a : '..a..' b :'..'args : ',...) end", "simple"); using var threadLease = state.MainThread.RentUseThread(); - var thread = threadLease.Thread; + var access = threadLease.Thread.TopLevelAccess; { - var count = await thread.RunAsync(closure); - var results = thread.ReadReturnValues(count); + var count = await access.RunAsync(closure,0); + var results = access.ReadReturnValues(count); for (int i = 0; i < results.Length; i++) { Console.WriteLine(results[i]); @@ -19,9 +19,9 @@ var f = results[0].Read(); results.Dispose(); - thread.Push("hello", "world", 1, 2, 3); - count = await thread.RunAsync(f); - results = thread.ReadReturnValues(count); + access.Push("hello", "world", 1, 2, 3); + count = await access.RunAsync(f); + results = access.ReadReturnValues(count); for (int i = 0; i < results.Length; i++) { Console.WriteLine(results[i]); diff --git a/sandbox/JitTest/Program.cs b/sandbox/JitTest/Program.cs index 44f0b852..d661688a 100644 --- a/sandbox/JitTest/Program.cs +++ b/sandbox/JitTest/Program.cs @@ -22,7 +22,7 @@ for (int i = 0; i < 1000; i++) { - await luaState.MainThread.RunAsync(closure); + await luaState.TopLevelAccess.RunAsync(closure); luaState.MainThread.Stack.Clear(); } diff --git a/src/Lua/InternalVisibleTo.cs b/src/Lua/InternalVisibleTo.cs new file mode 100644 index 00000000..56a85773 --- /dev/null +++ b/src/Lua/InternalVisibleTo.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Lua.Tests")] \ No newline at end of file diff --git a/src/Lua/LuaCoroutine.cs b/src/Lua/LuaCoroutine.cs index 07b0deba..1ac4cf34 100644 --- a/src/Lua/LuaCoroutine.cs +++ b/src/Lua/LuaCoroutine.cs @@ -59,7 +59,6 @@ internal void Init(LuaThread parent, LuaFunction function, bool isProtectedMode) State = parent.State; IsProtectedMode = isProtectedMode; Function = function; - IsRunning = false; } public override LuaThreadStatus GetStatus() => (LuaThreadStatus)status; @@ -153,8 +152,8 @@ async ValueTask ResumeAsyncCore(LuaStack stack, int argCount, int returnBas if (isFirstCall) { Stack.PushRange(stack.AsSpan()[^argCount..]); - functionTask = Function.InvokeAsync(new() { Thread = this, ArgumentCount = Stack.Count, ReturnFrameBase = 0 }, cancellationToken); - + //functionTask = Function.InvokeAsync(new() { Access = this.CurrentAccess, ArgumentCount = Stack.Count, ReturnFrameBase = 0 }, cancellationToken); + functionTask =CurrentAccess.RunAsync(Function,Stack.Count, cancellationToken); Volatile.Write(ref isFirstCall, false); if (!functionTask.IsCompleted) { diff --git a/src/Lua/LuaFunction.cs b/src/Lua/LuaFunction.cs index c961b423..686840c3 100644 --- a/src/Lua/LuaFunction.cs +++ b/src/Lua/LuaFunction.cs @@ -10,33 +10,4 @@ public class LuaFunction(string name, Func> func) : this("anonymous", func) { } - - public async ValueTask InvokeAsync(LuaFunctionExecutionContext context, CancellationToken cancellationToken) - { - var varArgumentCount = this.GetVariableArgumentCount(context.ArgumentCount); - if (varArgumentCount != 0) - { - LuaVirtualMachine.PrepareVariableArgument(context.Thread.Stack, context.ArgumentCount, varArgumentCount); - context = context with { ArgumentCount = context.ArgumentCount - varArgumentCount }; - } - - var callStackFrameCount = context.Thread.CallStackFrameCount; - try - { - var frame = new CallStackFrame { Base = context.FrameBase, VariableArgumentCount = varArgumentCount, Function = this, ReturnBase = context.ReturnFrameBase }; - context.Thread.PushCallStackFrame(frame); - - if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook) - { - return await LuaVirtualMachine.ExecuteCallHook(context, cancellationToken); - } - - var r = await Func(context, cancellationToken); - return r; - } - finally - { - context.Thread.PopCallStackFrameUntil(callStackFrameCount); - } - } } \ No newline at end of file diff --git a/src/Lua/LuaFunctionExecutionContext.cs b/src/Lua/LuaFunctionExecutionContext.cs index 3f0f2230..600f15e9 100644 --- a/src/Lua/LuaFunctionExecutionContext.cs +++ b/src/Lua/LuaFunctionExecutionContext.cs @@ -8,7 +8,8 @@ namespace Lua; public readonly record struct LuaFunctionExecutionContext { public LuaState State => Thread.State; - public required LuaThread Thread { get; init; } + public required LuaThreadAccess Access { get; init; } + public LuaThread Thread => Access.Thread; public required int ArgumentCount { get; init; } public int FrameBase => Thread.Stack.Count - ArgumentCount; public required int ReturnFrameBase { get; init; } diff --git a/src/Lua/LuaFunctionExtensions.cs b/src/Lua/LuaFunctionExtensions.cs index 6550550e..e3af6065 100644 --- a/src/Lua/LuaFunctionExtensions.cs +++ b/src/Lua/LuaFunctionExtensions.cs @@ -1,41 +1,43 @@ -using Lua.Runtime; - -namespace Lua; - -public static class LuaFunctionExtensions -{ - public static async ValueTask InvokeAsync(this LuaFunction function, LuaThread thread, int argumentCount, CancellationToken cancellationToken = default) - { - var varArgumentCount = function.GetVariableArgumentCount(argumentCount); - if (varArgumentCount != 0) - { - if (varArgumentCount < 0) - { - thread.Stack.SetTop(thread.Stack.Count - varArgumentCount); - argumentCount -= varArgumentCount; - varArgumentCount = 0; - } - else - { - LuaVirtualMachine.PrepareVariableArgument(thread.Stack, argumentCount, varArgumentCount); - } - } - - LuaFunctionExecutionContext context = new() { Thread = thread, ArgumentCount = argumentCount - varArgumentCount, ReturnFrameBase = thread.Stack.Count, }; - var frame = new CallStackFrame { Base = context.FrameBase, VariableArgumentCount = varArgumentCount, Function = function, ReturnBase = context.ReturnFrameBase }; - context.Thread.PushCallStackFrame(frame); - try - { - if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook) - { - return await LuaVirtualMachine.ExecuteCallHook(context, cancellationToken); - } - - return await function.Func(context, cancellationToken); - } - finally - { - context.Thread.PopCallStackFrame(); - } - } -} \ No newline at end of file +// using Lua.Runtime; +// +// namespace Lua; +// +// public static class LuaFunctionExtensions +// { +// +// public static async ValueTask InvokeAsync(this LuaFunction function, LuaThread thread, int argumentCount, CancellationToken cancellationToken = default) +// { +// var returnFrameBase = thread.Stack.Count-argumentCount; +// var varArgumentCount = function.GetVariableArgumentCount(argumentCount); +// if (varArgumentCount != 0) +// { +// if (varArgumentCount < 0) +// { +// thread.Stack.SetTop(thread.Stack.Count - varArgumentCount); +// argumentCount -= varArgumentCount; +// varArgumentCount = 0; +// } +// else +// { +// LuaVirtualMachine.PrepareVariableArgument(thread.Stack, argumentCount, varArgumentCount); +// } +// } +// +// LuaFunctionExecutionContext context = new() { Thread = thread, ArgumentCount = argumentCount , ReturnFrameBase = returnFrameBase, }; +// var frame = new CallStackFrame { Base = context.FrameBase, VariableArgumentCount = varArgumentCount, Function = function, ReturnBase = context.ReturnFrameBase }; +// context.Thread.PushCallStackFrame(frame); +// try +// { +// if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook) +// { +// return await LuaVirtualMachine.ExecuteCallHook(context, cancellationToken); +// } +// +// return await function.Func(context, cancellationToken); +// } +// finally +// { +// context.Thread.PopCallStackFrame(); +// } +// } +// } \ No newline at end of file diff --git a/src/Lua/LuaState.cs b/src/Lua/LuaState.cs index 46c73788..2fe3a84a 100644 --- a/src/Lua/LuaState.cs +++ b/src/Lua/LuaState.cs @@ -34,6 +34,8 @@ public sealed class LuaState public LuaTable LoadedModules => packages; public LuaMainThread MainThread => mainThread; + public LuaThreadAccess TopLevelAccess => new (mainThread, 0); + public ILuaModuleLoader ModuleLoader { get; set; } = FileModuleLoader.Instance; // metatables diff --git a/src/Lua/LuaStateExtensions.cs b/src/Lua/LuaStateExtensions.cs index e2976b55..f67e32bc 100644 --- a/src/Lua/LuaStateExtensions.cs +++ b/src/Lua/LuaStateExtensions.cs @@ -1,24 +1,26 @@ +using Lua.Runtime; + namespace Lua; public static class LuaStateExtensions { public static ValueTask DoStringAsync(this LuaState state, string source, Memory buffer, string? chunkName = null, CancellationToken cancellationToken = default) { - return state.MainThread.DoStringAsync(source, buffer, chunkName, cancellationToken); + return state.TopLevelAccess.DoStringAsync(source, buffer, chunkName, cancellationToken); } public static ValueTask DoStringAsync(this LuaState state, string source, string? chunkName = null, CancellationToken cancellationToken = default) { - return state.MainThread.DoStringAsync(source, chunkName, cancellationToken); + return state.TopLevelAccess.DoStringAsync(source, chunkName, cancellationToken); } public static ValueTask DoFileAsync(this LuaState state, string path, Memory buffer, CancellationToken cancellationToken = default) { - return state.MainThread.DoFileAsync(path, buffer, cancellationToken); + return state.TopLevelAccess.DoFileAsync(path, buffer, cancellationToken); } public static ValueTask DoFileAsync(this LuaState state, string path, CancellationToken cancellationToken = default) { - return state.MainThread.DoFileAsync(path, cancellationToken); + return state.TopLevelAccess.DoFileAsync(path, cancellationToken); } } \ No newline at end of file diff --git a/src/Lua/LuaThread.cs b/src/Lua/LuaThread.cs index 09dfb6f5..e784c6ab 100644 --- a/src/Lua/LuaThread.cs +++ b/src/Lua/LuaThread.cs @@ -28,7 +28,7 @@ public virtual ValueTask YieldAsync(LuaFunctionExecutionContext context, Ca protected class ThreadCoreData : IPoolNode { //internal LuaCoroutineData? coroutineData; - internal LuaStack Stack = new(); + internal readonly LuaStack Stack = new(); internal FastStackCore CallStack; public void Clear() @@ -67,10 +67,13 @@ public void Release() internal int BaseHookCount; internal int LastPc; + internal int LastVersion; + internal int CurrentVersion; + internal LuaRuntimeException? CurrentException; internal readonly ReversedStack ExceptionTrace = new(); - public bool IsRunning { get; protected set; } + public bool IsRunning => CallStackFrameCount != 0; internal LuaFunction? Hook { get; set; } public LuaStack Stack => CoreData!.Stack; @@ -99,25 +102,10 @@ internal bool IsReturnHookEnabled set => CallOrReturnHookMask.Flag1 = value; } - public async ValueTask RunAsync(LuaClosure closure, CancellationToken cancellationToken = default) - { - ThrowIfRunning(); - - IsRunning = true; - try - { - await closure.InvokeAsync(new() { Thread = this, ArgumentCount = Stack.Count, ReturnFrameBase = 0, }, cancellationToken); - - return Stack.Count; - } - finally - { - PopCallStackFrameUntil(0); - IsRunning = false; - } - } - public int CallStackFrameCount => CoreData == null ? 0 : CoreData!.CallStack.Count; + + internal LuaThreadAccess CurrentAccess => new(this, CurrentVersion); + public LuaThreadAccess TopLevelAccess => new(this, 0); public ref readonly CallStackFrame GetCurrentFrame() { @@ -134,33 +122,45 @@ public ReadOnlySpan GetCallStackFrames() return CoreData == null ? default : CoreData!.CallStack.AsSpan(); } + void UpdateCurrentVersion(ref FastStackCore callStack) + { + CurrentVersion = callStack.Count == 0 ? 0 : callStack.PeekRef().Version; + } + [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void PushCallStackFrame(in CallStackFrame frame) + internal LuaThreadAccess PushCallStackFrame(in CallStackFrame frame) { CurrentException?.Build(); CurrentException = null; - CoreData!.CallStack.Push(frame); + ref var callStack = ref CoreData!.CallStack; + callStack.Push(frame); + callStack.PeekRef().Version = CurrentVersion = ++LastVersion; + return new LuaThreadAccess(this, CurrentVersion); } [MethodImpl(MethodImplOptions.AggressiveInlining)] internal void PopCallStackFrameWithStackPop() { var coreData = CoreData!; - var popFrame = coreData.CallStack.Pop(); + ref var callStack = ref coreData.CallStack; + var popFrame = callStack.Pop(); + UpdateCurrentVersion(ref callStack); if (CurrentException != null) { ExceptionTrace.Push(popFrame); } - coreData.Stack.PopUntil(popFrame.Base); + coreData.Stack.PopUntil(popFrame.ReturnBase); } [MethodImpl(MethodImplOptions.AggressiveInlining)] internal void PopCallStackFrameWithStackPop(int frameBase) { var coreData = CoreData!; - var popFrame = coreData.CallStack.Pop(); + ref var callStack = ref coreData.CallStack; + var popFrame = callStack.Pop(); + UpdateCurrentVersion(ref callStack); if (CurrentException != null) { ExceptionTrace.Push(popFrame); @@ -175,7 +175,9 @@ internal void PopCallStackFrameWithStackPop(int frameBase) internal void PopCallStackFrame() { var coreData = CoreData!; - var popFrame = coreData.CallStack.Pop(); + ref var callStack = ref coreData.CallStack; + var popFrame = callStack.Pop(); + UpdateCurrentVersion(ref callStack); if (CurrentException != null) { ExceptionTrace.Push(popFrame); @@ -193,6 +195,7 @@ internal void PopCallStackFrameUntil(int top) } callStack.PopUntil(top); + UpdateCurrentVersion(ref callStack); } internal void DumpStackValues() @@ -208,12 +211,4 @@ public Traceback GetTraceback() { return new(State, GetCallStackFrames()); } - - protected void ThrowIfRunning() - { - if (IsRunning) - { - throw new InvalidOperationException("the lua state is currently running"); - } - } } \ No newline at end of file diff --git a/src/Lua/LuaThreadExtensions.cs b/src/Lua/LuaThreadExtensions.cs index 96542e15..be953850 100644 --- a/src/Lua/LuaThreadExtensions.cs +++ b/src/Lua/LuaThreadExtensions.cs @@ -1,5 +1,4 @@ using Lua.Runtime; -using System.Runtime.CompilerServices; namespace Lua; @@ -14,241 +13,4 @@ public static CoroutineLease RentCoroutine(this LuaThread thread, LuaFunction fu { return new(LuaCoroutine.Create(thread, function, isProtectedMode)); } - - public static async ValueTask DoStringAsync(this LuaThread thread, string source, Memory buffer, string? chunkName = null, CancellationToken cancellationToken = default) - { - var closure = thread.State.Load(source, chunkName ?? source); - var count = await thread.RunAsync(closure, cancellationToken); - using var results = thread.ReadReturnValues(count); - results.AsSpan()[..Math.Min(buffer.Length, count)].CopyTo(buffer.Span); - return count; - } - - public static async ValueTask DoStringAsync(this LuaThread thread, string source, string? chunkName = null, CancellationToken cancellationToken = default) - { - var closure = thread.State.Load(source, chunkName ?? source); - var count = await thread.RunAsync(closure, cancellationToken); - using var results = thread.ReadReturnValues(count); - return results.AsSpan().ToArray(); - } - - public static async ValueTask DoFileAsync(this LuaThread thread, string path, Memory buffer, CancellationToken cancellationToken = default) - { - var bytes = await File.ReadAllBytesAsync(path, cancellationToken); - var fileName = "@" + path; - var closure = thread.State.Load(bytes, fileName); - var count = await thread.RunAsync(closure, cancellationToken); - using var results = thread.ReadReturnValues(count); - results.AsSpan()[..Math.Min(buffer.Length, results.Length)].CopyTo(buffer.Span); - return results.Count; - } - - public static async ValueTask DoFileAsync(this LuaThread thread, string path, CancellationToken cancellationToken = default) - { - var bytes = await File.ReadAllBytesAsync(path, cancellationToken); - var fileName = "@" + path; - var closure = thread.State.Load(bytes, fileName); - var count = await thread.RunAsync(closure, cancellationToken); - using var results = thread.ReadReturnValues(count); - return results.AsSpan().ToArray(); - } - - public static void Push(this LuaThread thread, LuaValue value) - { - thread.Stack.Push(value); - } - - public static void Push(this LuaThread thread, params ReadOnlySpan span) - { - thread.Stack.PushRange(span); - } - - public static void Pop(this LuaThread thread, int count) - { - thread.Stack.Pop(count); - } - - public static LuaValue Pop(this LuaThread thread) - { - return thread.Stack.Pop(); - } - - public static LuaReturnValuesReader ReadReturnValues(this LuaThread thread, int argumentCount) - { - var stack = thread.Stack; - return new LuaReturnValuesReader(stack, stack.Count - argumentCount); - } - - - public static async ValueTask Arithmetic(this LuaThread thread, LuaValue x, LuaValue y, OpCode opCode, CancellationToken cancellationToken = default) - { - [MethodImpl(MethodImplOptions.NoInlining)] - static double Mod(double a, double b) - { - var mod = a % b; - if ((b > 0 && mod < 0) || (b < 0 && mod > 0)) - { - mod += b; - } - - return mod; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - static double ArithmeticOperation(OpCode code, double a, double b) - { - return code switch - { - OpCode.Add => a + b, - OpCode.Sub => a - b, - OpCode.Mul => a * b, - OpCode.Div => a / b, - OpCode.Mod => Mod(a, b), - OpCode.Pow => Math.Pow(a, b), - _ => throw new InvalidOperationException($"Unsupported arithmetic operation: {code}"), - }; - } - - - if (x.TryReadDouble(out var numX) && y.TryReadDouble(out var numY)) - { - return ArithmeticOperation(opCode, numX, numY); - } - - - return await LuaVirtualMachine.ExecuteBinaryOperationMetaMethod(thread, x, y, opCode, cancellationToken); - } - - public static async ValueTask Unary(this LuaThread thread, LuaValue value, OpCode opCode, CancellationToken cancellationToken = default) - { - if (opCode == OpCode.Unm) - { - if (value.TryReadDouble(out var numB)) - { - return -numB; - } - } - else if (opCode == OpCode.Len) - { - if (value.TryReadString(out var str)) - { - return str.Length; - } - - if (value.TryReadTable(out var table)) - { - return table.ArrayLength; - } - } - else - { - throw new InvalidOperationException($"Unsupported unary operation: {opCode}"); - } - - - return await LuaVirtualMachine.ExecuteUnaryOperationMetaMethod(thread, value, opCode, cancellationToken); - } - - - public static async ValueTask Compare(this LuaThread thread, LuaValue x, LuaValue y, OpCode opCode, CancellationToken cancellationToken = default) - { - if (opCode is not (OpCode.Eq or OpCode.Lt or OpCode.Le)) - { - throw new InvalidOperationException($"Unsupported compare operation: {opCode}"); - } - - if (opCode == OpCode.Eq) - { - if (x == y) - { - return true; - } - } - else - { - if (x.TryReadNumber(out var numX) && y.TryReadNumber(out var numY)) - { - return opCode == OpCode.Lt ? numX < numY : numX <= numY; - } - - if (x.TryReadString(out var strX) && y.TryReadString(out var strY)) - { - var c = StringComparer.Ordinal.Compare(strX, strY); - return opCode == OpCode.Lt ? c < 0 : c <= 0; - } - } - - - return await LuaVirtualMachine.ExecuteCompareOperationMetaMethod(thread, x, y, opCode, cancellationToken); - } - - public static async ValueTask GetTable(this LuaThread thread, LuaValue table, LuaValue key, CancellationToken cancellationToken = default) - { - if (table.TryReadTable(out var luaTable)) - { - if (luaTable.TryGetValue(key, out var value)) - { - return new(value); - } - } - - - return await LuaVirtualMachine.ExecuteGetTableSlowPath(thread, table, key, cancellationToken); - } - - public static async ValueTask SetTable(this LuaThread thread, LuaValue table, LuaValue key, LuaValue value, CancellationToken cancellationToken = default) - { - if (key.TryReadNumber(out var numB)) - { - if (double.IsNaN(numB)) - { - throw new LuaRuntimeException(thread, "table index is NaN"); - } - } - - - if (table.TryReadTable(out var luaTable)) - { - ref var valueRef = ref luaTable.FindValue(key); - if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil) - { - valueRef = value; - return; - } - } - - await LuaVirtualMachine.ExecuteSetTableSlowPath(thread, table, key, value, cancellationToken); - } - - public static ValueTask Concat(this LuaThread thread, ReadOnlySpan values, CancellationToken cancellationToken = default) - { - thread.Stack.PushRange(values); - return Concat(thread, values.Length, cancellationToken); - } - - public static async ValueTask Concat(this LuaThread thread, int concatCount, CancellationToken cancellationToken = default) - { - return await LuaVirtualMachine.Concat(thread, concatCount, cancellationToken); - } - - public static async ValueTask Call(this LuaThread thread, int funcIndex, CancellationToken cancellationToken = default) - { - return await LuaVirtualMachine.Call(thread, funcIndex, cancellationToken); - } - - public static ValueTask Call(this LuaThread thread, LuaValue function, ReadOnlySpan arguments, CancellationToken cancellationToken = default) - { - var funcIndex = thread.Stack.Count; - thread.Stack.Push(function); - thread.Stack.PushRange(arguments); - return Impl(thread, funcIndex, cancellationToken); - - static async ValueTask Impl(LuaThread thread, int funcIndex, CancellationToken cancellationToken) - { - await LuaVirtualMachine.Call(thread, funcIndex, cancellationToken); - var count = thread.Stack.Count - funcIndex; - using var results = thread.ReadReturnValues(count); - return results.AsSpan().ToArray(); - } - } } \ No newline at end of file diff --git a/src/Lua/LuaValue.cs b/src/Lua/LuaValue.cs index 951a7e53..d47a3542 100644 --- a/src/Lua/LuaValue.cs +++ b/src/Lua/LuaValue.cs @@ -594,15 +594,10 @@ internal ValueTask CallToStringAsync(LuaFunctionExecutionContext context, C { if (this.TryGetMetamethod(context.State, Metamethods.ToString, out var metamethod)) { - if (!metamethod.TryReadFunction(out var func)) - { - LuaRuntimeException.AttemptInvalidOperation(context.Thread, "call", metamethod); - } - var stack = context.Thread.Stack; + stack.Push(metamethod); stack.Push(this); - - return func.InvokeAsync(context with { ArgumentCount = 1, ReturnFrameBase = stack.Count - 1, }, cancellationToken); + return LuaVirtualMachine.Call(context.Thread,stack.Count-2, stack.Count - 2, cancellationToken); } else { diff --git a/src/Lua/Runtime/CallStackFrame.cs b/src/Lua/Runtime/CallStackFrame.cs index 6a134c02..8a57e18a 100644 --- a/src/Lua/Runtime/CallStackFrame.cs +++ b/src/Lua/Runtime/CallStackFrame.cs @@ -12,6 +12,7 @@ public record struct CallStackFrame public int CallerInstructionIndex; internal CallStackFrameFlags Flags; internal bool IsTailCall => (Flags & CallStackFrameFlags.TailCall) == CallStackFrameFlags.TailCall; + public int Version; } [Flags] diff --git a/src/Lua/Runtime/LuaThreadAccess.cs b/src/Lua/Runtime/LuaThreadAccess.cs new file mode 100644 index 00000000..03589e6e --- /dev/null +++ b/src/Lua/Runtime/LuaThreadAccess.cs @@ -0,0 +1,147 @@ +using System.Runtime.CompilerServices; + +namespace Lua.Runtime; + +public readonly struct LuaThreadAccess +{ + internal LuaThreadAccess(LuaThread thread, int version) + { + Thread = thread; + Version = version; + } + + public readonly LuaThread Thread; + public readonly int Version; + + public bool IsValid => Version == Thread.CurrentVersion; + + public LuaState State + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get + { + ThrowIfInvalid(); + return Thread.State; + } + } + + public LuaStack Stack + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get + { + ThrowIfInvalid(); + return Thread.Stack; + } + } + + public ValueTask RunAsync(LuaFunction function, CancellationToken cancellationToken = default) + { + return RunAsync(function, 0, Thread.Stack.Count, cancellationToken); + } + + public ValueTask RunAsync(LuaFunction function, int argumentCount, CancellationToken cancellationToken = default) + { + return RunAsync(function, argumentCount, Thread.Stack.Count - argumentCount, cancellationToken); + } + + public async ValueTask RunAsync(LuaFunction function, int argumentCount, int returnBase, CancellationToken cancellationToken = default) + { + ThrowIfInvalid(); + if (function == null) + { + throw new ArgumentNullException(nameof(function)); + } + + var thread = Thread; + var varArgumentCount = function.GetVariableArgumentCount(argumentCount); + if (varArgumentCount != 0) + { + if (varArgumentCount < 0) + { + thread.Stack.SetTop(thread.Stack.Count - varArgumentCount); + varArgumentCount = 0; + } + else + { + LuaVirtualMachine.PrepareVariableArgument(thread.Stack, argumentCount, varArgumentCount); + } + + argumentCount -= varArgumentCount; + } + + var frame = new CallStackFrame { Base = thread.Stack.Count - argumentCount, VariableArgumentCount = varArgumentCount, Function = function, ReturnBase = returnBase }; + if (thread.IsInHook) + { + frame.Flags |= CallStackFrameFlags.InHook; + } + + var access = thread.PushCallStackFrame(frame); + LuaFunctionExecutionContext context = new() { Access = access, ArgumentCount = argumentCount, ReturnFrameBase = returnBase, }; + + try + { + if (this.Thread.CallOrReturnHookMask.Value != 0 && !this.Thread.IsInHook) + { + return await LuaVirtualMachine.ExecuteCallHook(context, cancellationToken); + } + + return await function.Func(context, cancellationToken); + } + finally + { + this.Thread.PopCallStackFrame(); + } + } + + + internal CallStackFrame CreateCallStackFrame(LuaFunction function, int argumentCount, int returnBase, int callerInstructionIndex) + { + var thread = Thread; + var varArgumentCount = function.GetVariableArgumentCount(argumentCount); + if (varArgumentCount != 0) + { + if (varArgumentCount < 0) + { + thread.Stack.SetTop(thread.Stack.Count - varArgumentCount); + argumentCount -= varArgumentCount; + varArgumentCount = 0; + } + else + { + LuaVirtualMachine.PrepareVariableArgument(thread.Stack, argumentCount, varArgumentCount); + } + } + + var frame = new CallStackFrame + { + Base = thread.Stack.Count - argumentCount, + VariableArgumentCount = varArgumentCount, + Function = function, + ReturnBase = returnBase, + CallerInstructionIndex = callerInstructionIndex + }; + + if (thread.IsInHook) + { + frame.Flags |= CallStackFrameFlags.InHook; + } + + return frame; + } + + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void ThrowIfInvalid() + { + if (Version != Thread.CurrentVersion) + { + ThrowInvalid(); + } + } + + void ThrowInvalid() + { + throw new InvalidOperationException("Thread access is invalid."); + } +} \ No newline at end of file diff --git a/src/Lua/Runtime/LuaThreadAccessExtensions.cs b/src/Lua/Runtime/LuaThreadAccessExtensions.cs new file mode 100644 index 00000000..0053ac49 --- /dev/null +++ b/src/Lua/Runtime/LuaThreadAccessExtensions.cs @@ -0,0 +1,258 @@ +using System.Runtime.CompilerServices; + +namespace Lua.Runtime; + +public static class LuaThreadAccessAccessExtensions +{ + public static async ValueTask DoStringAsync(this LuaThreadAccess access, string source, Memory buffer, string? chunkName = null, CancellationToken cancellationToken = default) + { + access.ThrowIfInvalid(); + var closure = access.State.Load(source, chunkName ?? source); + var count = await access.RunAsync(closure, 0, cancellationToken); + using var results = access.ReadReturnValues(count); + results.AsSpan()[..Math.Min(buffer.Length, count)].CopyTo(buffer.Span); + return count; + } + + public static async ValueTask DoStringAsync(this LuaThreadAccess access, string source, string? chunkName = null, CancellationToken cancellationToken = default) + { + access.ThrowIfInvalid(); + var closure = access.State.Load(source, chunkName ?? source); + var count = await access.RunAsync(closure, 0, cancellationToken); + using var results = access.ReadReturnValues(count); + return results.AsSpan().ToArray(); + } + + public static async ValueTask DoFileAsync(this LuaThreadAccess access, string path, Memory buffer, CancellationToken cancellationToken = default) + { + access.ThrowIfInvalid(); + var bytes = await File.ReadAllBytesAsync(path, cancellationToken); + var fileName = "@" + path; + var closure = access.State.Load(bytes, fileName); + var count = await access.RunAsync(closure, 0, cancellationToken); + using var results = access.ReadReturnValues(count); + results.AsSpan()[..Math.Min(buffer.Length, results.Length)].CopyTo(buffer.Span); + return results.Count; + } + + public static async ValueTask DoFileAsync(this LuaThreadAccess access, string path, CancellationToken cancellationToken = default) + { + var bytes = await File.ReadAllBytesAsync(path, cancellationToken); + var fileName = "@" + path; + var closure = access.State.Load(bytes, fileName); + var count = await access.RunAsync(closure, 0, cancellationToken); + using var results = access.ReadReturnValues(count); + return results.AsSpan().ToArray(); + } + + public static void Push(this LuaThreadAccess access, LuaValue value) + { + access.ThrowIfInvalid(); + access.Stack.Push(value); + } + + public static void Push(this LuaThreadAccess access, params ReadOnlySpan span) + { + access.ThrowIfInvalid(); + access.Stack.PushRange(span); + } + + public static void Pop(this LuaThreadAccess access, int count) + { + access.ThrowIfInvalid(); + access.Stack.Pop(count); + } + + public static LuaValue Pop(this LuaThreadAccess access) + { + access.ThrowIfInvalid(); + return access.Stack.Pop(); + } + + public static LuaReturnValuesReader ReadReturnValues(this LuaThreadAccess access, int argumentCount) + { + access.ThrowIfInvalid(); + var stack = access.Stack; + return new LuaReturnValuesReader(stack, stack.Count - argumentCount); + } + + + public static async ValueTask Arithmetic(this LuaThreadAccess access, LuaValue x, LuaValue y, OpCode opCode, CancellationToken cancellationToken = default) + { + [MethodImpl(MethodImplOptions.NoInlining)] + static double Mod(double a, double b) + { + var mod = a % b; + if ((b > 0 && mod < 0) || (b < 0 && mod > 0)) + { + mod += b; + } + + return mod; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static double ArithmeticOperation(OpCode code, double a, double b) + { + return code switch + { + OpCode.Add => a + b, + OpCode.Sub => a - b, + OpCode.Mul => a * b, + OpCode.Div => a / b, + OpCode.Mod => Mod(a, b), + OpCode.Pow => Math.Pow(a, b), + _ => throw new InvalidOperationException($"Unsupported arithmetic operation: {code}"), + }; + } + + + if (x.TryReadDouble(out var numX) && y.TryReadDouble(out var numY)) + { + return ArithmeticOperation(opCode, numX, numY); + } + + access.ThrowIfInvalid(); + return await LuaVirtualMachine.ExecuteBinaryOperationMetaMethod(access.Thread, x, y, opCode, cancellationToken); + } + + public static async ValueTask Unary(this LuaThreadAccess access, LuaValue value, OpCode opCode, CancellationToken cancellationToken = default) + { + if (opCode == OpCode.Unm) + { + if (value.TryReadDouble(out var numB)) + { + return -numB; + } + } + else if (opCode == OpCode.Len) + { + if (value.TryReadString(out var str)) + { + return str.Length; + } + + if (value.TryReadTable(out var table)) + { + return table.ArrayLength; + } + } + else + { + throw new InvalidOperationException($"Unsupported unary operation: {opCode}"); + } + + access.ThrowIfInvalid(); + return await LuaVirtualMachine.ExecuteUnaryOperationMetaMethod(access.Thread, value, opCode, cancellationToken); + } + + + public static async ValueTask Compare(this LuaThreadAccess access, LuaValue x, LuaValue y, OpCode opCode, CancellationToken cancellationToken = default) + { + if (opCode is not (OpCode.Eq or OpCode.Lt or OpCode.Le)) + { + throw new InvalidOperationException($"Unsupported compare operation: {opCode}"); + } + + if (opCode == OpCode.Eq) + { + if (x == y) + { + return true; + } + } + else + { + if (x.TryReadNumber(out var numX) && y.TryReadNumber(out var numY)) + { + return opCode == OpCode.Lt ? numX < numY : numX <= numY; + } + + if (x.TryReadString(out var strX) && y.TryReadString(out var strY)) + { + var c = StringComparer.Ordinal.Compare(strX, strY); + return opCode == OpCode.Lt ? c < 0 : c <= 0; + } + } + + access.ThrowIfInvalid(); + return await LuaVirtualMachine.ExecuteCompareOperationMetaMethod(access.Thread, x, y, opCode, cancellationToken); + } + + public static async ValueTask GetTable(this LuaThreadAccess access, LuaValue table, LuaValue key, CancellationToken cancellationToken = default) + { + if (table.TryReadTable(out var luaTable)) + { + if (luaTable.TryGetValue(key, out var value)) + { + return new(value); + } + } + + access.ThrowIfInvalid(); + return await LuaVirtualMachine.ExecuteGetTableSlowPath(access.Thread, table, key, cancellationToken); + } + + public static async ValueTask SetTable(this LuaThreadAccess access, LuaValue table, LuaValue key, LuaValue value, CancellationToken cancellationToken = default) + { + access.ThrowIfInvalid(); + + if (key.TryReadNumber(out var numB)) + { + if (double.IsNaN(numB)) + { + throw new LuaRuntimeException(access.Thread, "table index is NaN"); + } + } + + + if (table.TryReadTable(out var luaTable)) + { + ref var valueRef = ref luaTable.FindValue(key); + if (!Unsafe.IsNullRef(ref valueRef) && valueRef.Type != LuaValueType.Nil) + { + valueRef = value; + return; + } + } + + await LuaVirtualMachine.ExecuteSetTableSlowPath(access.Thread, table, key, value, cancellationToken); + } + + public static ValueTask Concat(this LuaThreadAccess access, ReadOnlySpan values, CancellationToken cancellationToken = default) + { + access.ThrowIfInvalid(); + access.Stack.PushRange(values); + return Concat(access, values.Length, cancellationToken); + } + + public static async ValueTask Concat(this LuaThreadAccess access, int concatCount, CancellationToken cancellationToken = default) + { + access.ThrowIfInvalid(); + return await LuaVirtualMachine.Concat(access.Thread, concatCount, cancellationToken); + } + + public static ValueTask Call(this LuaThreadAccess access, int funcIndex, int returnBase, CancellationToken cancellationToken = default) + { + access.ThrowIfInvalid(); + return LuaVirtualMachine.Call(access.Thread, funcIndex, returnBase, cancellationToken); + } + + public static ValueTask Call(this LuaThreadAccess access, LuaValue function, ReadOnlySpan arguments, CancellationToken cancellationToken = default) + { + access.ThrowIfInvalid(); + var thread = access.Thread; + var funcIndex = thread.Stack.Count; + thread.Stack.Push(function); + thread.Stack.PushRange(arguments); + return Impl(access, funcIndex, cancellationToken); + + static async ValueTask Impl(LuaThreadAccess access, int funcIndex, CancellationToken cancellationToken) + { + await LuaVirtualMachine.Call(access.Thread, funcIndex, funcIndex, cancellationToken); + var count = access.Stack.Count - funcIndex; + using var results = access.ReadReturnValues(count); + return results.AsSpan().ToArray(); + } + } +} \ No newline at end of file diff --git a/src/Lua/Runtime/LuaVirtualMachine.Debug.cs b/src/Lua/Runtime/LuaVirtualMachine.Debug.cs index 8993edf1..f03f5af3 100644 --- a/src/Lua/Runtime/LuaVirtualMachine.Debug.cs +++ b/src/Lua/Runtime/LuaVirtualMachine.Debug.cs @@ -32,25 +32,18 @@ static async ValueTask Impl(VirtualMachineExecutionContext context) context.Thread.HookCount = context.Thread.BaseHookCount; var hook = context.Thread.Hook!; + var stack = context.Thread.Stack; + var top = stack.Count; stack.Push("count"); stack.Push(LuaValue.Nil); - var funcContext = new LuaFunctionExecutionContext { Thread = context.Thread, ArgumentCount = 2, ReturnFrameBase = context.Thread.Stack.Count - 2, }; - var frame = new CallStackFrame - { - Base = funcContext.FrameBase, - ReturnBase = funcContext.ReturnFrameBase, - VariableArgumentCount = hook.GetVariableArgumentCount(funcContext.ArgumentCount), - Function = hook, - CallerInstructionIndex = context.Pc, - }; - frame.Flags |= CallStackFrameFlags.InHook; context.Thread.IsInHook = true; - context.Thread.PushCallStackFrame(frame); + var frame = context.Thread.CurrentAccess.CreateCallStackFrame(hook, 2, top, context.Pc); + var access = context.Thread.PushCallStackFrame(frame); + var funcContext = new LuaFunctionExecutionContext { Access = access, ArgumentCount = stack.Count - frame.Base, ReturnFrameBase = frame.ReturnBase }; await hook.Func(funcContext, context.CancellationToken); context.Thread.IsInHook = false; - countHookIsDone = true; } @@ -70,22 +63,22 @@ static async ValueTask Impl(VirtualMachineExecutionContext context) var hook = context.Thread.Hook!; var stack = context.Thread.Stack; + var top = stack.Count; stack.Push("line"); stack.Push(line); - var funcContext = new LuaFunctionExecutionContext { Thread = context.Thread, ArgumentCount = 2, ReturnFrameBase = context.Thread.Stack.Count - 2, }; - var frame = new CallStackFrame - { - Base = funcContext.FrameBase, - ReturnBase = funcContext.ReturnFrameBase, - VariableArgumentCount = hook.GetVariableArgumentCount(funcContext.ArgumentCount), - Function = hook, - CallerInstructionIndex = pc, - }; - frame.Flags |= CallStackFrameFlags.InHook; context.Thread.IsInHook = true; - context.Thread.PushCallStackFrame(frame); - await hook.Func(funcContext, context.CancellationToken); - context.Thread.IsInHook = false; + var frame = context.Thread.CurrentAccess.CreateCallStackFrame(hook, 2, top, pc); + var access = context.Thread.PushCallStackFrame(frame); + var funcContext = new LuaFunctionExecutionContext { Access = access, ArgumentCount = stack.Count - frame.Base, ReturnFrameBase = frame.ReturnBase }; + try + { + await hook.Func(funcContext, context.CancellationToken); + } + finally + { + context.Thread.IsInHook = false; + } + context.Thread.LastPc = pc; return 0; } @@ -105,69 +98,51 @@ static async ValueTask Impl(VirtualMachineExecutionContext context) [MethodImpl(MethodImplOptions.NoInlining)] static ValueTask ExecuteCallHook(VirtualMachineExecutionContext context, in CallStackFrame frame, int arguments, bool isTailCall = false) { - return ExecuteCallHook(new() - { - Thread = context.Thread, ArgumentCount = arguments, ReturnFrameBase = frame.ReturnBase - },context.CancellationToken, isTailCall); + return ExecuteCallHook(new() { Access = context.Thread.CurrentAccess, ArgumentCount = arguments, ReturnFrameBase = frame.ReturnBase }, context.CancellationToken, isTailCall); } - internal static async ValueTask ExecuteCallHook(LuaFunctionExecutionContext context, CancellationToken cancellationToken, bool isTailCall = false) + internal static async ValueTask ExecuteCallHook(LuaFunctionExecutionContext context, CancellationToken cancellationToken, bool isTailCall = false) { var argCount = context.ArgumentCount; var hook = context.Thread.Hook!; var stack = context.Thread.Stack; if (context.Thread.IsCallHookEnabled) { + var top = stack.Count; stack.Push((isTailCall ? "tail call" : "call")); stack.Push(LuaValue.Nil); - var funcContext = new LuaFunctionExecutionContext { Thread = context.Thread, ArgumentCount = 2, ReturnFrameBase = context.Thread.Stack.Count - 2, }; - CallStackFrame frame = new() - { - Base = funcContext.FrameBase, - ReturnBase = funcContext.ReturnFrameBase, - VariableArgumentCount = hook.GetVariableArgumentCount(2), - Function = hook, - CallerInstructionIndex = 0, - Flags = CallStackFrameFlags.InHook - }; - - context.Thread.PushCallStackFrame(frame); + context.Thread.IsInHook = true; + var frame = context.Thread.CurrentAccess.CreateCallStackFrame(hook, 2, top, context.Thread.GetCurrentFrame().CallerInstructionIndex); + var access = context.Thread.PushCallStackFrame(frame); + var funcContext = new LuaFunctionExecutionContext { Access = access, ArgumentCount = stack.Count - frame.Base, ReturnFrameBase = frame.ReturnBase }; try { - context.Thread.IsInHook = true; await hook.Func(funcContext, cancellationToken); - context.Thread.PopCallStackFrameWithStackPop(); } finally { context.Thread.IsInHook = false; + context.Thread.PopCallStackFrameWithStackPop(); } } { - ref readonly var frame = ref context.Thread.GetCurrentFrame(); - var task = frame.Function.Func(new() { Thread = context.Thread, ArgumentCount = argCount, ReturnFrameBase = frame.ReturnBase, }, cancellationToken); + var frame = context.Thread.GetCurrentFrame(); + var task = frame.Function.Func(new() { Access = context.Thread.CurrentAccess, ArgumentCount = argCount, ReturnFrameBase = frame.ReturnBase, }, cancellationToken); var r = await task; if (isTailCall || !context.Thread.IsReturnHookEnabled) { return r; } + var top = stack.Count; stack.Push("return"); stack.Push(LuaValue.Nil); - var funcContext = new LuaFunctionExecutionContext { Thread = context.Thread, ArgumentCount = 2, ReturnFrameBase = context.Thread.Stack.Count - 2, }; - - - context.Thread.PushCallStackFrame(new() - { - Base = funcContext.FrameBase, - ReturnBase = funcContext.ReturnFrameBase, - VariableArgumentCount = hook.GetVariableArgumentCount(2), - Function = hook, - CallerInstructionIndex = 0, - Flags = CallStackFrameFlags.InHook - }); + context.Thread.IsInHook = true; + frame = context.Thread.CurrentAccess.CreateCallStackFrame(hook, 2, top, 0); + var access = context.Thread.PushCallStackFrame(frame); + var funcContext = new LuaFunctionExecutionContext { Access = access, ArgumentCount = stack.Count - frame.Base, ReturnFrameBase = frame.ReturnBase }; try { context.Thread.IsInHook = true; @@ -176,9 +151,9 @@ internal static async ValueTask ExecuteCallHook(LuaFunctionExecutionContext finally { context.Thread.IsInHook = false; + context.Thread.PopCallStackFrameWithStackPop(); } - context.Thread.PopCallStackFrameWithStackPop(); return r; } } diff --git a/src/Lua/Runtime/LuaVirtualMachine.cs b/src/Lua/Runtime/LuaVirtualMachine.cs index 6d4b6604..607fe9a5 100644 --- a/src/Lua/Runtime/LuaVirtualMachine.cs +++ b/src/Lua/Runtime/LuaVirtualMachine.cs @@ -1028,24 +1028,27 @@ static async ValueTask ExecuteBinaryOperationMetaMethod(int target, LuaValue vb, var varArgCount = func.GetVariableArgumentCount(argCount); var newFrame = func.CreateNewFrame(context, stack.Count - argCount + varArgCount, target, varArgCount); - - context.Thread.PushCallStackFrame(newFrame); + var thread = context.Thread; + var access = thread.PushCallStackFrame(newFrame); try { - if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook) + var functionContext = new LuaFunctionExecutionContext() { Access = access, ArgumentCount = argCount, ReturnFrameBase = target }; + if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook) { - await ExecuteCallHook(context, newFrame, argCount); + await ExecuteCallHook(functionContext, context.CancellationToken); + stack.PopUntil(target + 1); + context.PostOperation = PostOperationType.DontPop; + return; } - - await func.Invoke(context, newFrame, argCount); + await func.Func(functionContext, context.CancellationToken); stack.PopUntil(target + 1); context.PostOperation = PostOperationType.DontPop; return; } finally { - context.Thread.PopCallStackFrame(); + thread.PopCallStackFrame(); } } @@ -1082,7 +1085,7 @@ static bool Call(VirtualMachineExecutionContext context, out bool doRestart) var newFrame = func.CreateNewFrame(context, newBase, RA, variableArgumentCount); - thread.PushCallStackFrame(newFrame); + var access = thread.PushCallStackFrame(newFrame); if (thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook) { context.PostOperation = PostOperationType.Call; @@ -1099,14 +1102,15 @@ static bool Call(VirtualMachineExecutionContext context, out bool doRestart) } doRestart = false; - return FuncCall(context, in newFrame, func, newBase, argumentCount); + return FuncCall(context, access, func, argumentCount, newFrame.ReturnBase); - static bool FuncCall(VirtualMachineExecutionContext context, in CallStackFrame newFrame, LuaFunction func, int newBase, int argumentCount) + static bool FuncCall(VirtualMachineExecutionContext context, LuaThreadAccess access, LuaFunction func, int argumentCount, int returnBase) { - var task = func.Invoke(context, newFrame, argumentCount); + var task = func.Func(new() { Access = access, ArgumentCount = argumentCount, ReturnFrameBase = returnBase }, context.CancellationToken); if (!task.IsCompleted) { + context.PostOperation = PostOperationType.Call; context.Task = task; return false; } @@ -1132,7 +1136,7 @@ static bool FuncCall(VirtualMachineExecutionContext context, in CallStackFrame n } } - internal static async ValueTask Call(LuaThread thread, int funcIndex, CancellationToken ct) + internal static async ValueTask Call(LuaThread thread, int funcIndex, int returnBase, CancellationToken ct) { var stack = thread.Stack; var newBase = funcIndex + 1; @@ -1152,12 +1156,12 @@ internal static async ValueTask Call(LuaThread thread, int funcIndex, Cance var (argCount, variableArgumentCount) = PrepareForFunctionCall(thread, func, newBase); newBase += variableArgumentCount; - var newFrame = new CallStackFrame() { Base = newBase, VariableArgumentCount = variableArgumentCount, Function = func, ReturnBase = funcIndex }; + var newFrame = new CallStackFrame() { Base = newBase, VariableArgumentCount = variableArgumentCount, Function = func, ReturnBase = returnBase }; - thread.PushCallStackFrame(newFrame); + var access = thread.PushCallStackFrame(newFrame); try { - var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = argCount, ReturnFrameBase = funcIndex }; + var functionContext = new LuaFunctionExecutionContext() { Access = access, ArgumentCount = argCount, ReturnFrameBase = returnBase }; if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook) { await ExecuteCallHook(functionContext, ct); @@ -1218,12 +1222,13 @@ static bool TailCall(VirtualMachineExecutionContext context, out bool doRestart) var (argumentCount, variableArgumentCount) = PrepareForFunctionTailCall(thread, func, instruction, newBase, isMetamethod); newBase = context.FrameBase + variableArgumentCount; stack.PopUntil(newBase + argumentCount); - var lastPc = thread.GetCurrentFrame().CallerInstructionIndex; + var lastFrame = thread.GetCurrentFrame(); context.Thread.PopCallStackFrame(); var newFrame = func.CreateNewTailCallFrame(context, newBase, context.CurrentReturnFrameBase, variableArgumentCount); - newFrame.CallerInstructionIndex = lastPc; - thread.PushCallStackFrame(newFrame); + newFrame.CallerInstructionIndex = lastFrame.CallerInstructionIndex; + newFrame.Version = lastFrame.Version; + var access = thread.PushCallStackFrame(newFrame); if (thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook) { @@ -1242,7 +1247,7 @@ static bool TailCall(VirtualMachineExecutionContext context, out bool doRestart) } doRestart = false; - var task = func.Invoke(context, newFrame, argumentCount); + var task = func.Func(new() { Access = access, ArgumentCount = argumentCount, ReturnFrameBase = context.CurrentReturnFrameBase }, context.CancellationToken); if (!task.IsCompleted) { @@ -1310,11 +1315,11 @@ static bool TForCall(VirtualMachineExecutionContext context, out bool doRestart) stack.PopUntil(newBase + argumentCount); var newFrame = iterator.CreateNewFrame(context, newBase, RA + 3, variableArgumentCount); - context.Thread.PushCallStackFrame(newFrame); + var access = context.Thread.PushCallStackFrame(newFrame); if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook) { context.PostOperation = PostOperationType.TForCall; - context.Task = ExecuteCallHook(context, newFrame, 2); + context.Task = ExecuteCallHook(context, newFrame, stack.Count - newBase); doRestart = false; return false; } @@ -1326,7 +1331,7 @@ static bool TForCall(VirtualMachineExecutionContext context, out bool doRestart) return true; } - var task = iterator.Invoke(context, newFrame, 2); + var task = iterator.Func(new() { Access = access, ArgumentCount = stack.Count - newBase, ReturnFrameBase = newFrame.ReturnBase }, context.CancellationToken); if (!task.IsCompleted) { context.PostOperation = PostOperationType.TForCall; @@ -1450,7 +1455,7 @@ static bool CallGetTableFunc(LuaValue table, LuaFunction indexTable, LuaValue ke stack.Push(key); var newFrame = indexTable.CreateNewFrame(context, stack.Count - 2); - context.Thread.PushCallStackFrame(newFrame); + var access = context.Thread.PushCallStackFrame(newFrame); if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook) { context.PostOperation = context.Instruction.OpCode == OpCode.GetTable ? PostOperationType.SetResult : PostOperationType.Self; @@ -1468,7 +1473,7 @@ static bool CallGetTableFunc(LuaValue table, LuaFunction indexTable, LuaValue ke return true; } - var task = indexTable.Invoke(context, newFrame, 2); + var task = indexTable.Func(new() { Access = access, ArgumentCount = 2, ReturnFrameBase = newFrame.ReturnBase }, context.CancellationToken); if (!task.IsCompleted) { @@ -1539,8 +1544,8 @@ static async ValueTask CallGetTableFunc(LuaThread thread, LuaFunction var newFrame = new CallStackFrame() { Base = thread.Stack.Count - 2 + varArgCount, VariableArgumentCount = varArgCount, Function = indexTable, ReturnBase = top }; - thread.PushCallStackFrame(newFrame); - var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = 2, ReturnFrameBase = top }; + var access = thread.PushCallStackFrame(newFrame); + var functionContext = new LuaFunctionExecutionContext() { Access = access, ArgumentCount = 2, ReturnFrameBase = top }; if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook) { await ExecuteCallHook(functionContext, ct); @@ -1620,7 +1625,7 @@ static bool CallSetTableFunc(LuaValue table, LuaFunction newIndexFunction, LuaVa stack.Push(value); var newFrame = newIndexFunction.CreateNewFrame(context, stack.Count - 3); - context.Thread.PushCallStackFrame(newFrame); + var access = context.Thread.PushCallStackFrame(newFrame); if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook) { context.PostOperation = PostOperationType.Nop; @@ -1636,7 +1641,7 @@ static bool CallSetTableFunc(LuaValue table, LuaFunction newIndexFunction, LuaVa return true; } - var task = newIndexFunction.Invoke(context, newFrame, 3); + var task = newIndexFunction.Func(new() { Access = access, ArgumentCount = 3, ReturnFrameBase = newFrame.ReturnBase }, context.CancellationToken); if (!task.IsCompleted) { context.PostOperation = PostOperationType.Nop; @@ -1712,8 +1717,8 @@ static async ValueTask CallSetTableFunc(LuaThread thread, LuaFunction newIndexFu var newFrame = new CallStackFrame() { Base = thread.Stack.Count - 3 + varArgCount, VariableArgumentCount = varArgCount, Function = newIndexFunction, ReturnBase = top }; - thread.PushCallStackFrame(newFrame); - var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = 3, ReturnFrameBase = top }; + var access = thread.PushCallStackFrame(newFrame); + var functionContext = new LuaFunctionExecutionContext() { Access = access, ArgumentCount = 3, ReturnFrameBase = top }; if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook) { await ExecuteCallHook(functionContext, ct); @@ -1756,7 +1761,7 @@ static bool ExecuteBinaryOperationMetaMethod(LuaValue vb, LuaValue vc, newBase += variableArgumentCount; var newFrame = func.CreateNewFrame(context, newBase, context.FrameBase + context.Instruction.A, variableArgumentCount); - context.Thread.PushCallStackFrame(newFrame); + var access = context.Thread.PushCallStackFrame(newFrame); if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook) { context.PostOperation = PostOperationType.SetResult; @@ -1773,7 +1778,7 @@ static bool ExecuteBinaryOperationMetaMethod(LuaValue vb, LuaValue vc, } - var task = func.Invoke(context, newFrame, argCount); + var task = func.Func(new() { Access = access, ArgumentCount = argCount, ReturnFrameBase = newFrame.ReturnBase }, context.CancellationToken); if (!task.IsCompleted) { @@ -1828,10 +1833,10 @@ internal static async ValueTask ExecuteBinaryOperationMetaMethod(LuaTh var newFrame = new CallStackFrame() { Base = newBase, VariableArgumentCount = variableArgumentCount, Function = func, ReturnBase = newBase }; - thread.PushCallStackFrame(newFrame); + var access = thread.PushCallStackFrame(newFrame); try { - var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = argCount, ReturnFrameBase = newBase }; + var functionContext = new LuaFunctionExecutionContext() { Access = access, ArgumentCount = argCount, ReturnFrameBase = newBase }; if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook) { await ExecuteCallHook(functionContext, ct); @@ -1886,7 +1891,7 @@ static bool ExecuteUnaryOperationMetaMethod(LuaValue vb, VirtualMachineExecution var newFrame = func.CreateNewFrame(context, newBase, context.FrameBase + context.Instruction.A, variableArgumentCount); - context.Thread.PushCallStackFrame(newFrame); + var access = context.Thread.PushCallStackFrame(newFrame); if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook) { context.PostOperation = PostOperationType.SetResult; @@ -1903,7 +1908,7 @@ static bool ExecuteUnaryOperationMetaMethod(LuaValue vb, VirtualMachineExecution } - var task = func.Invoke(context, newFrame, argCount); + var task = func.Func(new() { Access = access, ArgumentCount = argCount, ReturnFrameBase = newFrame.ReturnBase }, context.CancellationToken); if (!task.IsCompleted) { @@ -1960,10 +1965,10 @@ internal static async ValueTask ExecuteUnaryOperationMetaMethod(LuaThr newBase += variableArgumentCount; var newFrame = new CallStackFrame() { Base = newBase, VariableArgumentCount = variableArgumentCount, Function = func, ReturnBase = newBase }; - thread.PushCallStackFrame(newFrame); + var access = thread.PushCallStackFrame(newFrame); try { - var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = argCount, ReturnFrameBase = newBase }; + var functionContext = new LuaFunctionExecutionContext() { Access = access, ArgumentCount = argCount, ReturnFrameBase = newBase }; if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook) { await ExecuteCallHook(functionContext, ct); @@ -2019,7 +2024,7 @@ static bool ExecuteCompareOperationMetaMethod(LuaValue vb, LuaValue vc, var varArgCount = func.GetVariableArgumentCount(argCount); var newFrame = func.CreateNewFrame(context, stack.Count - argCount + varArgCount); if (reverseLe) newFrame.Flags |= CallStackFrameFlags.ReversedLe; - context.Thread.PushCallStackFrame(newFrame); + var access = context.Thread.PushCallStackFrame(newFrame); if (context.Thread.CallOrReturnHookMask.Value != 0 && !context.Thread.IsInHook) { context.PostOperation = PostOperationType.Compare; @@ -2035,7 +2040,7 @@ static bool ExecuteCompareOperationMetaMethod(LuaValue vb, LuaValue vc, return true; } - var task = func.Invoke(context, newFrame, argCount); + var task = func.Func(new() { Access = access, ArgumentCount = argCount, ReturnFrameBase = newFrame.ReturnBase }, context.CancellationToken); if (!task.IsCompleted) { @@ -2117,10 +2122,10 @@ internal static async ValueTask ExecuteCompareOperationMetaMethod(LuaThrea newBase += variableArgumentCount; var newFrame = new CallStackFrame() { Base = newBase, VariableArgumentCount = variableArgumentCount, Function = func, ReturnBase = newBase }; - thread.PushCallStackFrame(newFrame); + var access = thread.PushCallStackFrame(newFrame); try { - var functionContext = new LuaFunctionExecutionContext() { Thread = thread, ArgumentCount = argCount, ReturnFrameBase = newBase }; + var functionContext = new LuaFunctionExecutionContext() { Access = access, ArgumentCount = argCount, ReturnFrameBase = newBase }; if (thread.CallOrReturnHookMask.Value != 0 && !thread.IsInHook) { await ExecuteCallHook(functionContext, ct); @@ -2349,11 +2354,4 @@ static CallStackFrame CreateNewTailCallFrame(this LuaFunction function, VirtualM Flags = CallStackFrameFlags.TailCall }; } - - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - static ValueTask Invoke(this LuaFunction function, VirtualMachineExecutionContext context, in CallStackFrame frame, int arguments) - { - return function.Func(new() { Thread = context.Thread, ArgumentCount = arguments, ReturnFrameBase = frame.ReturnBase, }, context.CancellationToken); - } } \ No newline at end of file diff --git a/src/Lua/Standard/BasicLibrary.cs b/src/Lua/Standard/BasicLibrary.cs index 818449e0..1d1e46cd 100644 --- a/src/Lua/Standard/BasicLibrary.cs +++ b/src/Lua/Standard/BasicLibrary.cs @@ -86,12 +86,13 @@ public ValueTask CollectGarbage(LuaFunctionExecutionContext context, Cancel public async ValueTask DoFile(LuaFunctionExecutionContext context, CancellationToken cancellationToken) { var arg0 = context.GetArgument(0); + context.Thread.Stack.PopUntil(context.ReturnFrameBase); // do not use LuaState.DoFileAsync as it uses the newExecutionContext - var bytes = await File.ReadAllBytesAsync(arg0, cancellationToken); + var bytes = File.ReadAllBytes(arg0); var fileName = "@" + arg0; - - return await context.State.Load(bytes, fileName).InvokeAsync(context with { ArgumentCount = context.ArgumentCount - 1 }, cancellationToken); + var closure = context.State.Load(bytes, fileName); + return await context.Access.RunAsync(closure,cancellationToken); } public ValueTask Error(LuaFunctionExecutionContext context, CancellationToken cancellationToken) @@ -126,22 +127,24 @@ public ValueTask GetMetatable(LuaFunctionExecutionContext context, Cancella return default; } - public ValueTask IPairs(LuaFunctionExecutionContext context, CancellationToken cancellationToken) + public async ValueTask IPairs(LuaFunctionExecutionContext context, CancellationToken cancellationToken) { var arg0 = context.GetArgument(0); // If table has a metamethod __ipairs, calls it with table as argument and returns the first three results from the call. if (arg0.Metatable != null && arg0.Metatable.TryGetValue(Metamethods.IPairs, out var metamethod)) { - if (!metamethod.TryRead(out var function)) - { - LuaRuntimeException.AttemptInvalidOperation(context.Thread, "call", metamethod); - } + var stack = context.Thread.Stack; + var top = stack.Count; + stack.Push(metamethod); + stack.Push(arg0); - return function.InvokeAsync(context, cancellationToken); + await LuaVirtualMachine.Call(context.Access.Thread,top,context.ReturnFrameBase,cancellationToken); + stack.SetTop(context.ReturnFrameBase+3); + return 3; } - return new(context.Return(IPairsIterator, arg0, 0)); + return context.Return(IPairsIterator, arg0, 0); } public async ValueTask LoadFile(LuaFunctionExecutionContext context, CancellationToken cancellationToken) @@ -223,31 +226,33 @@ public ValueTask Next(LuaFunctionExecutionContext context, CancellationToke } } - public ValueTask Pairs(LuaFunctionExecutionContext context, CancellationToken cancellationToken) + public async ValueTask Pairs(LuaFunctionExecutionContext context, CancellationToken cancellationToken) { var arg0 = context.GetArgument(0); // If table has a metamethod __pairs, calls it with table as argument and returns the first three results from the call. if (arg0.Metatable != null && arg0.Metatable.TryGetValue(Metamethods.Pairs, out var metamethod)) { - if (!metamethod.TryRead(out var function)) - { - LuaRuntimeException.AttemptInvalidOperation(context.Thread, "call", metamethod); - } + var stack = context.Thread.Stack; + var top = stack.Count; + stack.Push(metamethod); + stack.Push(arg0); - return function.InvokeAsync(context, cancellationToken); + await LuaVirtualMachine.Call(context.Access.Thread,top,context.ReturnFrameBase,cancellationToken); + stack.SetTop(context.ReturnFrameBase+3); + return 3; + } - return new(context.Return(PairsIterator, arg0, LuaValue.Nil)); + return (context.Return(PairsIterator, arg0, LuaValue.Nil)); } public async ValueTask PCall(LuaFunctionExecutionContext context, CancellationToken cancellationToken) { var frameCount = context.Thread.CallStackFrameCount; - var arg0 = context.GetArgument(0); try { - var count = await arg0.InvokeAsync(context with { ArgumentCount = context.ArgumentCount - 1, ReturnFrameBase = context.ReturnFrameBase + 1 }, cancellationToken); + var count = await LuaVirtualMachine.Call(context.Access.Thread,context.FrameBase,context.ReturnFrameBase+1,cancellationToken); context.Thread.Stack.Get(context.ReturnFrameBase) = true; return count + 1; @@ -545,32 +550,37 @@ public ValueTask Type(LuaFunctionExecutionContext context, CancellationToke public async ValueTask XPCall(LuaFunctionExecutionContext context, CancellationToken cancellationToken) { var frameCount = context.Thread.CallStackFrameCount; - var arg0 = context.GetArgument(0); + var arg0 = context.GetArgument(0); var arg1 = context.GetArgument(1); try { - var count = await arg0.InvokeAsync(context with { ArgumentCount = context.ArgumentCount - 2, ReturnFrameBase = context.ReturnFrameBase + 1 }, cancellationToken); + var stack = context.Thread.Stack; + stack.Get(context.FrameBase+1) = arg0; + var count = await LuaVirtualMachine.Call(context.Access.Thread,context.FrameBase + 1,context.ReturnFrameBase+1,cancellationToken); context.Thread.Stack.Get(context.ReturnFrameBase) = true; return count + 1; } catch (Exception ex) { - context.Thread.PopCallStackFrameUntil(frameCount); + var thread = context.Thread; + thread.PopCallStackFrameUntil(frameCount); + + var access = thread.CurrentAccess; if (ex is LuaRuntimeException luaEx) { luaEx.Forget(); - context.Thread.Push(luaEx.ErrorObject); + access.Push(luaEx.ErrorObject); } else { - context.Thread.Push(ex.Message); + access.Push(ex.Message); } // invoke error handler - var count = await arg1.InvokeAsync(context with { ArgumentCount = 1, ReturnFrameBase = context.ReturnFrameBase + 1 }, cancellationToken); + var count = await access.RunAsync(arg1, 1,context.ReturnFrameBase+1, cancellationToken); context.Thread.Stack.Get(context.ReturnFrameBase) = false; return count + 1; } diff --git a/src/Lua/Standard/DebugLibrary.cs b/src/Lua/Standard/DebugLibrary.cs index 5dc93e79..4d167969 100644 --- a/src/Lua/Standard/DebugLibrary.cs +++ b/src/Lua/Standard/DebugLibrary.cs @@ -456,26 +456,22 @@ public async ValueTask SetHook(LuaFunctionExecutionContext context, Cancell if (thread.IsReturnHookEnabled && context.Thread == thread) { var stack = thread.Stack; + var top = stack.Count; stack.Push("return"); stack.Push(LuaValue.Nil); - var funcContext = new LuaFunctionExecutionContext { Thread = context.Thread, ArgumentCount = 2, ReturnFrameBase = stack.Count - 2, }; - var frame = new CallStackFrame - { - Base = funcContext.FrameBase, ReturnBase = funcContext.ReturnFrameBase, VariableArgumentCount = hook.GetVariableArgumentCount(2), Function = hook, - }; - frame.Flags |= CallStackFrameFlags.InHook; - thread.PushCallStackFrame(frame); + context.Thread.IsInHook = true; + var frame = context.Thread.CurrentAccess.CreateCallStackFrame(hook, 2, top, 0); + var access= context.Thread.PushCallStackFrame(frame); + var funcContext = new LuaFunctionExecutionContext { Access =access, ArgumentCount = stack.Count-frame.Base, ReturnFrameBase = frame.ReturnBase }; try { - thread.IsInHook = true; await hook.Func(funcContext, cancellationToken); } finally { - thread.IsInHook = false; + context.Thread.IsInHook = false; + context.Thread.PopCallStackFrameWithStackPop(); } - - thread.PopCallStackFrameWithStackPop(); } return 0; diff --git a/src/Lua/Standard/Internal/IOHelper.cs b/src/Lua/Standard/Internal/IOHelper.cs index 85745077..55c8974a 100644 --- a/src/Lua/Standard/Internal/IOHelper.cs +++ b/src/Lua/Standard/Internal/IOHelper.cs @@ -25,7 +25,7 @@ public static int Open(LuaThread thread, string fileName, string mode, bool thro try { var stream = File.Open(fileName, fileMode, fileAccess); - thread.Push(new LuaValue(new FileHandle(stream))); + thread.Stack.Push(new LuaValue(new FileHandle(stream))); return 1; } catch (IOException ex) @@ -35,9 +35,9 @@ public static int Open(LuaThread thread, string fileName, string mode, bool thro throw; } - thread.Push(LuaValue.Nil); - thread.Push(ex.Message); - thread.Push(ex.HResult); + thread.Stack.Push(LuaValue.Nil); + thread.Stack.Push(ex.Message); + thread.Stack.Push(ex.HResult); return 3; } } diff --git a/src/Lua/Standard/ModuleLibrary.cs b/src/Lua/Standard/ModuleLibrary.cs index be623cbe..7be29881 100644 --- a/src/Lua/Standard/ModuleLibrary.cs +++ b/src/Lua/Standard/ModuleLibrary.cs @@ -27,7 +27,7 @@ public async ValueTask Require(LuaFunctionExecutionContext context, Cancell ? context.State.Load(module.ReadBytes(), module.Name) : context.State.Load(module.ReadText(), module.Name); } - await closure.InvokeAsync(context, cancellationToken); + await context.Access.RunAsync(closure, 0, cancellationToken); loadedTable = context.Thread.Stack.Get(context.ReturnFrameBase); loaded[arg0] = loadedTable; } diff --git a/src/Lua/Standard/StringLibrary.cs b/src/Lua/Standard/StringLibrary.cs index f7cfe270..f8fc7116 100644 --- a/src/Lua/Standard/StringLibrary.cs +++ b/src/Lua/Standard/StringLibrary.cs @@ -507,13 +507,13 @@ public async ValueTask GSub(LuaFunctionExecutionContext context, Cancellati } else if (repl.TryRead(out var func)) { + var stack = context.Thread.Stack; for (int k = 1; k <= match.Groups.Count; k++) { - context.Thread.Push(match.Groups[k].Value); + stack.Push(match.Groups[k].Value); } - - await func.InvokeAsync(context with { ArgumentCount = match.Groups.Count }, cancellationToken); + await context.Access.RunAsync(func,match.Groups.Count,cancellationToken); result = context.Thread.Stack.Get(context.ReturnFrameBase); } diff --git a/src/Lua/Standard/TableLibrary.cs b/src/Lua/Standard/TableLibrary.cs index 4946e19e..ced1169d 100644 --- a/src/Lua/Standard/TableLibrary.cs +++ b/src/Lua/Standard/TableLibrary.cs @@ -189,7 +189,7 @@ async ValueTask PartitionAsync(LuaFunctionExecutionContext context, Memory< var top = stack.Count; stack.Push(memory.Span[j]); stack.Push(pivot); - await comparer.InvokeAsync(context with { ArgumentCount = 2, ReturnFrameBase = top }, cancellationToken); + await context.Access.RunAsync(comparer, 2,cancellationToken); if (context.Thread.Stack.Get(top).ToBoolean()) { diff --git a/tests/Lua.Tests/AsyncTests.cs b/tests/Lua.Tests/AsyncTests.cs index 3a7cd84b..c0d344ff 100644 --- a/tests/Lua.Tests/AsyncTests.cs +++ b/tests/Lua.Tests/AsyncTests.cs @@ -12,25 +12,41 @@ public void SetUp() state = LuaState.Create(); state.OpenStandardLibraries(); var assert = state.Environment["assert"].Read(); - state.Environment["assert"] = new LuaFunction("wait", - async (c, ct) => + state.Environment["assert"] = new LuaFunction("assert_with_wait", + async (context, ct) => { await Task.Delay(1, ct); - return await assert.InvokeAsync(c, ct); + var arg0 = context.GetArgument(0); + + if (!arg0.ToBoolean()) + { + var message = "assertion failed!"; + if (context.HasArgument(1)) + { + message = context.GetArgument(1); + } + + throw new LuaAssertionException(context.Thread, message); + } + + return (context.Return(context.Arguments)); }); } [Test] - public async Task Test_Async() + [TestCase("tests-lua/coroutine.lua")] + [TestCase("tests-lua/db.lua")] + [TestCase("tests-lua/vararg.lua")] + public async Task Test_Async(string file) { - var path = FileHelper.GetAbsolutePath("tests-lua/coroutine.lua"); + var path = FileHelper.GetAbsolutePath(file); try { await state.DoFileAsync(path); } catch (LuaRuntimeException e) { - var line = e.LuaTraceback.LastLine; + var line = e.LuaTraceback!.LastLine; throw new Exception($"{path}:line {line}\n{e.InnerException}\n {e}"); } } diff --git a/tests/Lua.Tests/LuaApiTests.cs b/tests/Lua.Tests/LuaApiTests.cs index bb8ced60..33ebcbb5 100644 --- a/tests/Lua.Tests/LuaApiTests.cs +++ b/tests/Lua.Tests/LuaApiTests.cs @@ -34,11 +34,12 @@ public async Task TestArithmetic() setmetatable(a, metatable) return a, b """; - var result = await state.DoStringAsync(source); + var access = state.TopLevelAccess; + var result = await access.DoStringAsync(source); var a = result[0].Read(); var b = result[1].Read(); - var c = await state.MainThread.Arithmetic(a, b, OpCode.Add); + var c = await access.Arithmetic(a, b, OpCode.Add); var table = c.Read(); Assert.Multiple(() => { @@ -67,10 +68,11 @@ public async Task TestUnary() setmetatable(a, metatable) return a """; - var result = await state.DoStringAsync(source); - var a = result[0].Read(); + var access = state.TopLevelAccess; - var c = await state.MainThread.Unary(a, OpCode.Unm); + var result = await access.DoStringAsync(source); + var a = result[0].Read(); + var c = await access.Unary(a, OpCode.Unm); var table = c.Read(); Assert.Multiple(() => { @@ -104,13 +106,14 @@ public async Task TestCompare() setmetatable(a, metatable) return a, b, c """; - var result = await state.DoStringAsync(source); + var access = state.TopLevelAccess; + var result = await access.DoStringAsync(source); var a = result[0].Read(); var b = result[1].Read(); var c = result[2].Read(); - var ab = await state.MainThread.Compare(a, b, OpCode.Eq); + var ab = await access.Compare(a, b, OpCode.Eq); Assert.False(ab); - var ac = await state.MainThread.Compare(a, c, OpCode.Eq); + var ac = await access.Compare(a, c, OpCode.Eq); Assert.True(ac); } @@ -126,11 +129,12 @@ public async Task TestGetTable() setmetatable(a, metatable) return a """; - var result = await state.DoStringAsync(source); + var access = state.TopLevelAccess; + var result = await access.DoStringAsync(source); var a = result[0].Read(); - Assert.That(await state.MainThread.GetTable(a, "x"), Is.EqualTo(new LuaValue(1))); + Assert.That(await access.GetTable(a, "x"), Is.EqualTo(new LuaValue(1))); a.Metatable!["__index"] = state.DoStringAsync("return function(a,b) return b end").Result[0]; - Assert.That(await state.MainThread.GetTable(a, "x"), Is.EqualTo(new LuaValue("x"))); + Assert.That(await access.GetTable(a, "x"), Is.EqualTo(new LuaValue("x"))); } [Test] @@ -146,9 +150,10 @@ public async Task TestSetTable() setmetatable(a, metatable) return a """; - var result = await state.DoStringAsync(source); + var access = state.TopLevelAccess; + var result = await access.DoStringAsync(source); var a = result[0].Read(); - await state.MainThread.SetTable(a, "a", "b"); + await access.SetTable(a, "a", "b"); var b = a.Metatable!["__newindex"].Read()["a"]; Assert.True(b.Read() == "b"); } @@ -178,14 +183,14 @@ public async Task Test_Metamethod_Concat() return a,b,c "; - - var result = await state.DoStringAsync(source); + var access = state.TopLevelAccess; + var result = await access.DoStringAsync(source); Assert.That(result, Has.Length.EqualTo(3)); var a = result[0]; var b = result[1]; var c = result[2]; - var d = await state.MainThread.Concat([a, b, c]); + var d = await access.Concat([a, b, c]); var table = d.Read(); Assert.That(table.ArrayLength, Is.EqualTo(9)); @@ -224,21 +229,22 @@ public async Task Test_Metamethod_MetaCallViaMeta() local c ={name ="c"} return a,b,c """; - var result = await state.DoStringAsync(source); + var access = state.TopLevelAccess; + var result = await access.DoStringAsync(source); var a = result[0]; var b = result[1]; var c = result[2]; - var d = await state.MainThread.Arithmetic(b, c, OpCode.Add); + var d = await access.Arithmetic(b, c, OpCode.Add); Assert.True(d.TryRead(out string s)); Assert.That(s, Is.EqualTo("abc")); - d = await state.MainThread.Unary(b, OpCode.Unm); + d = await access.Unary(b, OpCode.Unm); Assert.True(d.TryRead(out s)); Assert.That(s, Is.EqualTo("abb")); - d = await state.MainThread.Concat([c, b]); + d = await access.Concat([c, b]); Assert.True(d.TryRead(out s)); Assert.That(s, Is.EqualTo("acb")); - var aResult = await state.MainThread.Call(a, [b, c]); + var aResult = await access.Call(a, [b, c]); Assert.That(aResult, Has.Length.EqualTo(1)); Assert.That(aResult[0].Read(), Is.EqualTo("abc")); } @@ -266,21 +272,22 @@ public async Task Test_Metamethod_MetaCallViaMeta_VarArg() local c ={name ="c"} return a,b,c """; - var result = await state.DoStringAsync(source); + var access = state.TopLevelAccess; + var result = await access.DoStringAsync(source); var a = result[0]; var b = result[1]; var c = result[2]; - var d = await state.MainThread.Arithmetic(b, c, OpCode.Add); + var d = await access.Arithmetic(b, c, OpCode.Add); Assert.True(d.TryRead(out string s)); Assert.That(s, Is.EqualTo("abc")); - d = await state.MainThread.Unary(b, OpCode.Unm); + d = await access.Unary(b, OpCode.Unm); Assert.True(d.TryRead(out s)); Assert.That(s, Is.EqualTo("abb")); - d = await state.MainThread.Concat([c, b]); + d = await access.Concat([c, b]); Assert.True(d.TryRead(out s)); Assert.That(s, Is.EqualTo("acb")); - var aResult = await state.MainThread.Call(a, [b, c]); + var aResult = await access.Call(a, [b, c]); Assert.That(aResult, Has.Length.EqualTo(1)); Assert.That(aResult[0].Read(), Is.EqualTo("abc")); } diff --git a/tests/Lua.Tests/ValidationTests.cs b/tests/Lua.Tests/ValidationTests.cs new file mode 100644 index 00000000..89250a4a --- /dev/null +++ b/tests/Lua.Tests/ValidationTests.cs @@ -0,0 +1,56 @@ +using Lua.Runtime; +using Lua.Standard; + +namespace Lua.Tests; + +public class ValidationTests +{ + [Test] + public async Task Test_Simple() + { + var state = LuaState.Create(); + state.OpenStandardLibraries(); + LuaThreadAccess innerAccess = default!; + state.Environment["wait"] = new LuaFunction("wait", + async (context, ct) => + { + innerAccess = context.Access; + await Task.Delay((int)(context.GetArgument(0) * 1000), ct); + return context.Return(context.Arguments); + }); + + var task=state.DoStringAsync("wait(0.5)"); + + await Task.Delay(100); + Assert.That(task.IsCompleted, Is.False); + Assert.ThrowsAsync( async () => + { + await state.DoStringAsync("print('hello')"); + }); + await task; + + Assert.ThrowsAsync( async () => + { + await innerAccess.DoStringAsync("print('hello')"); + }); + Assert.DoesNotThrowAsync(async () => + { + await state.DoStringAsync("wait(0.5)"); + }); + } + + [Test] + public async Task Test_Recursive() + { + var state = LuaState.Create(); + state.OpenStandardLibraries(); + state.Environment["dostring"] = new LuaFunction("dostring", + async (context, ct) => context.Return(await context.Access.DoStringAsync(context.GetArgument(0), null, ct))); + + var result=await state.DoStringAsync("""return dostring("return 1")"""); + + Assert.That(result.Length, Is.EqualTo(1)); + Assert.That(result[0].Read(), Is.EqualTo(1)); + } + +} \ No newline at end of file