diff --git a/src/Lua/Standard/Internal/MatchState.cs b/src/Lua/Standard/Internal/MatchState.cs new file mode 100644 index 00000000..ca8ea5ae --- /dev/null +++ b/src/Lua/Standard/Internal/MatchState.cs @@ -0,0 +1,463 @@ +using System.Buffers; + +namespace Lua.Standard.Internal; + +internal class MatchState(LuaThread thread, string source, string pattern) +{ + internal const int LuaMaxCaptures = 32; + private const int CapUnfinished = -1; + private const int CapPosition = -2; + private const char LEsc = '%'; + private const string Specials = "^$*+?.([%-"; + internal const int MaxCalls = 200; + + internal struct Capture + { + public int Init; + public int Len; + + public bool IsPosition => Len == CapPosition; + } + + public readonly LuaThread Thread = thread; + public readonly string Source = source; + public readonly string Pattern = pattern; + public int Level = 0; + internal readonly Capture[] Captures = new Capture[LuaMaxCaptures]; + public int MatchDepth = MaxCalls; + + public static bool NoSpecials(ReadOnlySpan pattern) + { +#if NET8_0_OR_GREATER + return !pattern.ContainsAny(Specials); +#else + return pattern.IndexOfAny(Specials) == -1; +#endif + } + + int StartCapture(int sIdx, int pIdx, int what) + { + if (Level >= LuaMaxCaptures) + throw new LuaRuntimeException(Thread, "too many captures"); + Captures[Level].Init = sIdx; + Captures[Level].Len = what; + Level++; + var res = Match(sIdx, pIdx); + if (res < 0) + { + Level--; + } + + return res; + } + + int EndCapture(int sIdx, int pIdx) + { + var l = CaptureToClose(); + Captures[l].Len = sIdx - Captures[l].Init; + var res = Match(sIdx, pIdx); + if (res < 0) + { + Captures[l].Len = CapUnfinished; // Reset unfinished capture + } + + return res; + } + + public int Match(int sIdx, int pIdx) + { + if (MatchDepth-- == 0) + throw new LuaRuntimeException(Thread, "pattern too complex"); + + var endIdx = Pattern.Length; + Init: + if (pIdx < endIdx) + { + switch (Pattern[pIdx]) + { + case '(': + if (pIdx + 1 < Pattern.Length && Pattern[pIdx + 1] == ')') + { + sIdx = StartCapture(sIdx, pIdx + 2, CapPosition); + } + else + { + sIdx = StartCapture(sIdx, pIdx + 1, CapUnfinished); + } + + break; + + case ')': + // End capture + + sIdx = EndCapture(sIdx, pIdx + 1); + break; + + + case '$': + if (pIdx + 1 == Pattern.Length) + { + MatchDepth++; + return sIdx == Source.Length ? sIdx : -1; + } + + goto Default; + + case LEsc: + if (pIdx + 1 >= Pattern.Length) + { + goto Default; + } + + switch (Pattern[pIdx + 1]) + { + case 'b': + { + sIdx = MatchBalance(sIdx, pIdx + 2); + if (sIdx < 0) + { + MatchDepth++; + return -1; + } + + pIdx += 4; + goto Init; + } + + case 'f': + if (pIdx + 2 < Pattern.Length && Pattern[pIdx + 2] == '[') + { + var ep = ClassEnd(Pattern, pIdx + 2); + char previous = sIdx > 0 ? Source[sIdx - 1] : '\0'; + if (!MatchBracketClass(previous, Pattern, pIdx + 2, ep - 1) && + sIdx < Source.Length && MatchBracketClass(Source[sIdx], Pattern, pIdx + 2, ep - 1)) + { + pIdx = ep; + goto Init; + } + } + + sIdx = -1; + + break; + + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + sIdx = MatchCapture(sIdx, Pattern[pIdx + 1] - '1'); + if (sIdx < 0) + { + MatchDepth++; + return -1; + } + + pIdx += 2; + goto Init; + } + + default: + goto Default; + } + + break; + + default: + Default: + { + var ep = ClassEnd(Pattern, pIdx); + if (!SingleMatch(sIdx, pIdx, ep)) + { + if (ep < Pattern.Length && (Pattern[ep] is '*' or '?' or '-')) + { + pIdx = ep + 1; + goto Init; // Continue the while loop with updated pIdx + } + else + { + MatchDepth++; + return -1; + } + } + else + { + if (ep >= Pattern.Length) + { + // No quantifier, we matched one occurrence + sIdx++; + pIdx = ep; // Move past this pattern element + goto Init; // Continue matching with the rest of the pattern + } + + switch (Pattern[ep]) + { + case '?': + { + // Try matching with this character + var res = Match(sIdx + 1, ep + 1); + if (res >= 0) + { + MatchDepth++; + return res; + } + + pIdx = ep + 1; + goto Init; + } + + case '+': + // For +, we need at least one match (already verified) + // Skip the first match we already verified + sIdx++; + // Now match zero or more additional occurrences + goto case '*'; + + case '*': + // Match zero or more occurrences + { + { + int i = 0; + // Count how many we can match + while (sIdx + i < Source.Length && SingleMatch(sIdx + i, pIdx, ep)) + i++; + + // Try matching from longest to shortest + while (i >= 0) + { + var res = Match(sIdx + i, ep + 1); + if (res >= 0) + { + MatchDepth++; + return res; + } + + i--; + } + + MatchDepth++; + return -1; + } + } + + case '-': + // Match zero or more occurrences (minimal) + { + // for (;;) { + // const char *res = match(ms, s, ep+1); + // if (res != NULL) + // return res; + // else if (singlematch(ms, s, p, ep)) + // s++; /* try with one more repetition */ + // else return NULL; + // } + while (true) + { + var res = Match(sIdx, ep + 1); + if (res >= 0) + { + MatchDepth++; + return res; + } + + if (SingleMatch(sIdx, pIdx, ep)) + { + sIdx++; // Try with one more repetition + } + else + { + MatchDepth++; + return -1; // No match found + } + } + } + + default: + sIdx++; + pIdx = ep; + goto Init; // Continue the while loop + } + } + } + } + } + + MatchDepth++; + return sIdx; + } + + private bool SingleMatch(int sIdx, int pIdx, int ep) + { + if (sIdx >= Source.Length) + return false; + + char c = Source[sIdx]; + switch (Pattern[pIdx]) + { + case '.': + return true; + case LEsc: + return pIdx + 1 < Pattern.Length && MatchClass(c, Pattern[pIdx + 1]); + case '[': + return MatchBracketClass(c, Pattern, pIdx, ep - 1); + default: + return Pattern[pIdx] == c; + } + } + + private int CaptureToClose() + { + int level = Level; + for (level--; level >= 0; level--) + { + if (Captures[level].Len == CapUnfinished) + return level; + } + + throw new LuaRuntimeException(Thread, "invalid pattern capture"); + } + + private int MatchCapture(int sIdx, int l) + { + l = CheckCapture(l); + int len = Captures[l].Len; + if (len >= 0 && sIdx + len <= Source.Length) + { + var capture = Source.AsSpan(Captures[l].Init, len); + if (sIdx + len <= Source.Length && Source.AsSpan(sIdx, len).SequenceEqual(capture)) + return sIdx + len; // Return the new position + } + + return -1; + } + + private int CheckCapture(int l) + { + if (l < 0 || l >= Level || Captures[l].Len == CapUnfinished) + throw new LuaRuntimeException(Thread, $"invalid capture index %{l + 1}"); + return l; + } + + private int MatchBalance(int sIdx, int pIdx) + { + if (pIdx + 1 >= Pattern.Length) + throw new LuaRuntimeException(Thread, "malformed pattern (missing arguments to '%b')"); + + if (sIdx >= Source.Length || Source[sIdx] != Pattern[pIdx]) + return -1; + + char b = Pattern[pIdx]; + char e = Pattern[pIdx + 1]; + int cont = 1; + sIdx++; + + while (sIdx < Source.Length) + { + if (Source[sIdx] == e) + { + if (--cont == 0) + return sIdx + 1; // Return the length matched + } + else if (Source[sIdx] == b) + { + cont++; + } + + sIdx++; + } + + return -1; + } + + private int ClassEnd(ReadOnlySpan pattern, int pIdx) + { + switch (pattern[pIdx++]) + { + case LEsc: + if (pIdx >= pattern.Length) + throw new LuaRuntimeException(Thread, "malformed pattern (ends with %)"); + return pIdx + 1; + + case '[': + if (pIdx < pattern.Length && pattern[pIdx] == '^') pIdx++; + do + { + pIdx++; + if (pIdx < pattern.Length && pattern[pIdx] == LEsc) + pIdx++; + if (pIdx >= pattern.Length) + throw new LuaRuntimeException(Thread, "malformed pattern (missing ']')"); + } while (pIdx < pattern.Length && pattern[pIdx] != ']'); + + return pIdx + 1; + + default: + return pIdx; + } + } + + private static bool MatchClass(char c, char cl) + { + bool res; + switch (char.ToLower(cl)) + { + case 'a': res = char.IsLetter(c); break; + case 'c': res = char.IsControl(c); break; + case 'd': res = char.IsDigit(c); break; + case 'g': res = !char.IsControl(c) && !char.IsWhiteSpace(c); break; + case 'l': res = char.IsLower(c); break; + case 'p': res = char.IsPunctuation(c); break; + case 's': res = char.IsWhiteSpace(c); break; + case 'u': res = char.IsUpper(c); break; + case 'w': res = char.IsLetterOrDigit(c); break; + case 'x': res = IsHexDigit(c); break; + case 'z': res = c == '\0'; break; + default: return cl == c; + } + + return char.IsLower(cl) ? res : !res; + } + + private static bool IsHexDigit(char c) + { + return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); + } + + private static bool MatchBracketClass(char c, ReadOnlySpan pattern, int pIdx, int ec) + { + bool sig = true; + if (pIdx + 1 < pattern.Length && pattern[pIdx + 1] == '^') + { + sig = false; + pIdx++; + } + + while (++pIdx < ec) + { + if (pattern[pIdx] == LEsc) + { + pIdx++; + if (pIdx <= ec && MatchClass(c, pattern[pIdx])) + return sig; + } + else if (pIdx + 2 < ec && pattern[pIdx + 1] == '-') + { + if (pattern[pIdx] <= c && c <= pattern[pIdx + 2]) + return sig; + pIdx += 2; + } + else if (pattern[pIdx] == c) + { + return sig; + } + } + + return !sig; + } +} \ No newline at end of file diff --git a/src/Lua/Standard/StringLibrary.cs b/src/Lua/Standard/StringLibrary.cs index 8961b1e3..15aa64fd 100644 --- a/src/Lua/Standard/StringLibrary.cs +++ b/src/Lua/Standard/StringLibrary.cs @@ -1,8 +1,9 @@ using System.Text; -using System.Text.RegularExpressions; using Lua.Internal; using Lua.Runtime; using System.Globalization; +using Lua.Standard.Internal; +using System.Diagnostics; namespace Lua.Standard; @@ -81,64 +82,8 @@ public ValueTask Dump(LuaFunctionExecutionContext context, CancellationToke throw new NotSupportedException("stirng.dump is not supported"); } - public ValueTask Find(LuaFunctionExecutionContext context, CancellationToken cancellationToken) - { - var s = context.GetArgument(0); - var pattern = context.GetArgument(1); - var init = context.HasArgument(2) - ? context.GetArgument(2) - : 1; - var plain = context.HasArgument(3) && context.GetArgument(3).ToBoolean(); - - LuaRuntimeException.ThrowBadArgumentIfNumberIsNotInteger(context.Thread, 3, init); - - // init can be negative value - if (init < 0) - { - init = s.Length + init + 1; - } - - // out of range - if (init != 1 && (init < 1 || init > s.Length)) - { - return new(context.Return(LuaValue.Nil)); - } - - // empty pattern - if (pattern.Length == 0) - { - return new(context.Return(1, 0)); - } - - var source = s.AsSpan()[(int)(init - 1)..]; - - if (plain) - { - var start = source.IndexOf(pattern); - if (start == -1) - { - return new(context.Return(LuaValue.Nil)); - } - - // 1-based - return new(context.Return(start + 1, start + pattern.Length)); - } - else - { - var regex = StringHelper.ToRegex(pattern); - var match = regex.Match(source.ToString()); - - if (match.Success) - { - // 1-based - return new(context.Return(init + match.Index, init + match.Index + match.Length - 1)); - } - else - { - return new(context.Return(LuaValue.Nil)); - } - } - } + public ValueTask Find(LuaFunctionExecutionContext context, CancellationToken cancellationToken) => + FindAux(context, true); public async ValueTask Format(LuaFunctionExecutionContext context, CancellationToken cancellationToken) { @@ -435,43 +380,73 @@ public ValueTask GMatch(LuaFunctionExecutionContext context, CancellationTo var s = context.GetArgument(0); var pattern = context.GetArgument(1); - var regex = StringHelper.ToRegex(pattern); - var matches = regex.Matches(s); - - return new(context.Return(new CSharpClosure("iterator", [LuaValue.FromObject(matches), 0], static (context, cancellationToken) => + return new(context.Return(new CSharpClosure("gmatch_iterator", [s, pattern, 0], static (context, cancellationToken) => { var upValues = context.GetCsClosure()!.UpValues; - var matches = upValues[0].Read(); - var i = upValues[1].Read(); - if (matches.Count > i) + var s = upValues[0].Read(); + var pattern = upValues[1].Read(); + var start = upValues[2].Read(); + + var matchState = new MatchState(context.Thread, s, pattern); + var captures = matchState.Captures; + + // Check for anchor at start + bool anchor = pattern.Length > 0 && pattern[0] == '^'; + int pIdx = anchor ? 1 : 0; + + // For empty patterns, we need to match at every position including after the last character + var sEndIdx = s.Length + (pattern.Length == 0 || (anchor && pattern.Length == 1) ? 1 : 0); + + for (int sIdx = start; sIdx < sEndIdx; sIdx++) { - var match = matches[i]; - var groups = match.Groups; + // Reset match state for each attempt + matchState.Level = 0; + matchState.MatchDepth = MatchState.MaxCalls; + // Clear captures to avoid stale data + Array.Clear(captures, 0, captures.Length); - i++; - upValues[1] = i; - if (groups.Count == 1) - { - return new(context.Return(match.Value)); - } - else + var res = matchState.Match(sIdx, pIdx); + + if (res >= 0) { - var buffer = context.GetReturnBuffer(groups.Count); - for (int j = 0; j < groups.Count; j++) + // If no captures were made, create one for the whole match + if (matchState.Level == 0) { - buffer[j] = groups[j + 1].Value; + captures[0].Init = sIdx; + captures[0].Len = res - sIdx; + matchState.Level = 1; + } + + var resultLength = matchState.Level; + var buffer = context.GetReturnBuffer(resultLength); + for (int i = 0; i < matchState.Level; i++) + { + var capture = captures[i]; + if (capture.IsPosition) + { + buffer[i] = capture.Init + 1; // 1-based position + } + else + { + buffer[i] = s.AsSpan(capture.Init, capture.Len).ToString(); + } } - return new(buffer.Length); + // Update start index for next iteration + // Handle empty matches by advancing at least 1 position + upValues[2] = res > sIdx ? res : sIdx + 1; + return new(resultLength); } + + // For anchored patterns, only try once + if (anchor) break; } - else - { - return new(context.Return(LuaValue.Nil)); - } + + return new(context.Return(LuaValue.Nil)); }))); } + public async ValueTask GSub(LuaFunctionExecutionContext context, CancellationToken cancellationToken) { var s = context.GetArgument(0); @@ -479,85 +454,212 @@ public async ValueTask GSub(LuaFunctionExecutionContext context, Cancellati var repl = context.GetArgument(2); var n_arg = context.HasArgument(3) ? context.GetArgument(3) - : int.MaxValue; + : s.Length + 1; LuaRuntimeException.ThrowBadArgumentIfNumberIsNotInteger(context.Thread, 4, n_arg); var n = (int)n_arg; - var regex = StringHelper.ToRegex(pattern); - var matches = regex.Matches(s); - // TODO: reduce allocation + // Use MatchState instead of regex + var matchState = new MatchState(context.Thread, s, pattern); + var captures = matchState.Captures; + var builder = new StringBuilder(); + StringBuilder? replacedBuilder = repl.Type == LuaValueType.String + ? new StringBuilder(repl.UnsafeReadString().Length) + : null; var lastIndex = 0; var replaceCount = 0; - int i = 0; - for (; i < matches.Count; i++) + + // Check for anchor at start + bool anchor = pattern.Length > 0 && pattern[0] == '^'; + int sIdx = 0; + + // For empty patterns, we need to match at every position including after the last character + var sEndIdx = s.Length + (pattern.Length == 0 || (anchor && pattern.Length == 1) ? 1 : 0); + while ((sIdx < sEndIdx) && replaceCount < n) { - if (replaceCount > n) break; + // Reset match state for each attempt + matchState.Level = 0; + Debug.Assert(matchState.MatchDepth == MatchState.MaxCalls); + // Clear captures array to avoid stale data + for (int i = 0; i < captures.Length; i++) + { + captures[i] = default; + } - var match = matches[i]; - builder.Append(s.AsSpan()[lastIndex..match.Index]); - replaceCount++; + // Always start pattern from beginning (0 or 1 if anchored) + int pIdx = anchor ? 1 : 0; + var res = matchState.Match(sIdx, pIdx); - LuaValue result; - if (repl.TryRead(out var str)) + if (res >= 0) { - result = str.Replace("%%", "%") - .Replace("%0", match.Value); + // Found a match + builder.Append(s.AsSpan()[lastIndex..sIdx]); - for (int k = 1; k <= match.Groups.Count; k++) + // If no captures were made, create one for the whole match + if (matchState.Level == 0) { - if (replaceCount > n) break; - result = result.Read().Replace($"%{k}", match.Groups[k].Value); - replaceCount++; + captures[0].Init = sIdx; + captures[0].Len = res - sIdx; + matchState.Level = 1; } - } - else if (repl.TryRead(out var table)) - { - result = table[match.Groups[1].Value]; - } - else if (repl.TryRead(out var func)) - { - var stack = context.Thread.Stack; - for (int k = 1; k <= match.Groups.Count; k++) + + LuaValue result; + if (repl.TryRead(out var str)) { - stack.Push(match.Groups[k].Value); + if (!str.Contains("%")) + { + result = str; // No special characters, use as is + } + else + { + // String replacement + replacedBuilder!.Clear(); + replacedBuilder.Append(str); + + // Replace %% with % + replacedBuilder.Replace("%%", "\0"); // Use null char as temporary marker + + // Replace %0 with whole match + var wholeMatch = s.AsSpan(sIdx, res - sIdx).ToString(); + replacedBuilder.Replace("%0", wholeMatch); + + // Replace %1, %2, etc. with captures + for (int k = 0; k < matchState.Level; k++) + { + var capture = captures[k]; + string captureText; + + if (capture.IsPosition) + { + captureText = (capture.Init + 1).ToString(); // 1-based position + } + else + { + captureText = s.AsSpan(capture.Init, capture.Len).ToString(); + } + + replacedBuilder.Replace($"%{k + 1}", captureText); + } + + // Replace temporary marker back to % + replacedBuilder.Replace('\0', '%'); + result = replacedBuilder.ToString(); + } } + else if (repl.TryRead(out var table)) + { + // Table lookup - use first capture or whole match + string key; + if (matchState.Level > 0 && !captures[0].IsPosition) + { + key = s.AsSpan(captures[0].Init, captures[0].Len).ToString(); + } + else + { + key = s.AsSpan(sIdx, res - sIdx).ToString(); + } - await context.Access.RunAsync(func, match.Groups.Count, cancellationToken); + result = table[key]; + } + else if (repl.TryRead(out var func)) + { + // Function call with captures as arguments + var stack = context.Thread.Stack; - result = context.Thread.Stack.Get(context.ReturnFrameBase); - } - else - { - throw new LuaRuntimeException(context.Thread, "bad argument #3 to 'gsub' (string/function/table expected)"); - } + if (matchState.Level == 0) + { + // No captures, pass whole match + stack.Push(s.AsSpan(sIdx, res - sIdx).ToString()); + var retCount = await context.Access.RunAsync(func, 1, cancellationToken); + using var results = context.Access.ReadReturnValues(retCount); + result = results.Count > 0 ? results[0] : LuaValue.Nil; + } + else + { + // Pass all captures + for (int k = 0; k < matchState.Level; k++) + { + var capture = captures[k]; + if (capture.IsPosition) + { + stack.Push(capture.Init + 1); // 1-based position + } + else + { + stack.Push(s.AsSpan(capture.Init, capture.Len).ToString()); + } + } - if (result.TryRead(out var rs)) - { - builder.Append(rs); - } - else if (result.TryRead(out var rd)) - { - builder.Append(rd); - } - else if (!result.ToBoolean()) - { - builder.Append(match.Value); - replaceCount--; + var retCount = await context.Access.RunAsync(func, matchState.Level, cancellationToken); + using var results = context.Access.ReadReturnValues(retCount); + result = results.Count > 0 ? results[0] : LuaValue.Nil; + } + } + else + { + throw new LuaRuntimeException(context.Thread, "bad argument #3 to 'gsub' (string/function/table expected)"); + } + + // Handle replacement result + if (result.TryRead(out var rs)) + { + builder.Append(rs); + } + else if (result.TryRead(out var rd)) + { + builder.Append(rd); + } + else if (!result.ToBoolean()) + { + // False or nil means don't replace + builder.Append(s.AsSpan(sIdx, res - sIdx)); + } + else + { + throw new LuaRuntimeException(context.Thread, $"invalid replacement value (a {result.Type})"); + } + + replaceCount++; + lastIndex = res; + + // If empty match, advance by 1 to avoid infinite loop + if (res == sIdx) + { + if (sIdx < s.Length) + { + builder.Append(s[sIdx]); + lastIndex = sIdx + 1; + } + + sIdx++; + } + else + { + sIdx = res; + } } else { - throw new LuaRuntimeException(context.Thread, $"invalid replacement value (a {result.Type})"); - } + // No match at this position + if (anchor) + { + // Anchored pattern only tries at start + break; + } - lastIndex = match.Index + match.Length; + sIdx++; + } } - builder.Append(s.AsSpan()[lastIndex..s.Length]); + // Append remaining part of string + if (lastIndex < s.Length) + { + builder.Append(s.AsSpan()[lastIndex..]); + } - return context.Return(builder.ToString(), i); + return context.Return(builder.ToString(), replaceCount); } public ValueTask Len(LuaFunctionExecutionContext context, CancellationToken cancellationToken) @@ -572,10 +674,140 @@ public ValueTask Lower(LuaFunctionExecutionContext context, CancellationTok return new(context.Return(s.ToLower())); } - public ValueTask Match(LuaFunctionExecutionContext context, CancellationToken cancellationToken) + public ValueTask Match(LuaFunctionExecutionContext context, CancellationToken cancellationToken) => + FindAux(context, false); + + public ValueTask FindAux(LuaFunctionExecutionContext context, bool find) { - //TODO : implement string.match - throw new NotImplementedException(); + var s = context.GetArgument(0); + var pattern = context.GetArgument(1); + var init = context.HasArgument(2) + ? context.GetArgument(2) + : 1; + + LuaRuntimeException.ThrowBadArgumentIfNumberIsNotInteger(context.Thread, 3, init); + + // Convert to 0-based index + if (init < 0) + { + init = s.Length + init + 1; + } + + init = Math.Max(0, Math.Min(init - 1, s.Length)); // Convert from 1-based to 0-based and clamp + + // Check for plain search mode (4th parameter = true) + if (find && context.GetArgumentOrDefault(3).ToBoolean()) + { + return PlainSearch(context, s, pattern, init); + } + + // Fast path for simple patterns without special characters + if (find && MatchState.NoSpecials(pattern)) + { + return SimplePatternSearch(context, s, pattern, init); + } + + return PatternSearch(context, s, pattern, init, find); + } + + private static ValueTask PlainSearch(LuaFunctionExecutionContext context, string s, string pattern, int init) + { + if (init > s.Length) + { + return new(context.Return(LuaValue.Nil)); + } + + var index = s.AsSpan(init).IndexOf(pattern); + if (index == -1) + { + return new(context.Return(LuaValue.Nil)); + } + + var actualStart = init + index; + return new(context.Return(actualStart + 1, actualStart + pattern.Length)); // Convert to 1-based + } + + private static ValueTask SimplePatternSearch(LuaFunctionExecutionContext context, string s, string pattern, int init) + { + if (init > s.Length) + { + return new(context.Return(LuaValue.Nil)); + } + + var index = s.AsSpan(init).IndexOf(pattern); + if (index == -1) + { + return new(context.Return(LuaValue.Nil)); + } + + var actualStart = init + index; + return new(context.Return(actualStart + 1, actualStart + pattern.Length)); // Convert to 1-based + } + + private static ValueTask PatternSearch(LuaFunctionExecutionContext context, string s, string pattern, int init, bool find) + { + var matchState = new MatchState(context.Thread, s, pattern); + var captures = matchState.Captures; + + // Check for anchor at start + bool anchor = pattern.Length > 0 && pattern[0] == '^'; + int pIdx = anchor ? 1 : 0; + + // For empty patterns, we need to match at every position including after the last character + var sEndIdx = s.Length + (pattern.Length == 0 ? 1 : 0); + + for (int sIdx = init; sIdx < sEndIdx; sIdx++) + { + // Reset match state for each attempt + matchState.Level = 0; + matchState.MatchDepth = MatchState.MaxCalls; + Array.Clear(captures, 0, captures.Length); + + var res = matchState.Match(sIdx, pIdx); + + if (res >= 0) + { + // If no captures were made for string.match, create one for the whole match + if (!find && matchState.Level == 0) + { + captures[0].Init = sIdx; + captures[0].Len = res - sIdx; + matchState.Level = 1; + } + + var resultLength = matchState.Level + (find ? 2 : 0); + var buffer = context.GetReturnBuffer(resultLength); + + if (find) + { + // Return start and end positions for string.find + buffer[0] = sIdx + 1; // Convert to 1-based index + buffer[1] = res; // Convert to 1-based index + buffer = buffer[2..]; + } + + // Return captures + for (int i = 0; i < matchState.Level; i++) + { + var capture = captures[i]; + if (capture.IsPosition) + { + buffer[i] = capture.Init + 1; // 1-based position + } + else + { + buffer[i] = s.AsSpan(capture.Init, capture.Len).ToString(); + } + } + + return new(resultLength); + } + + // For anchored patterns, only try once + if (anchor) break; + } + + return new(context.Return(LuaValue.Nil)); } public ValueTask Rep(LuaFunctionExecutionContext context, CancellationToken cancellationToken) diff --git a/tests/Lua.Tests/PatternMatchingTests.cs b/tests/Lua.Tests/PatternMatchingTests.cs new file mode 100644 index 00000000..76e260e9 --- /dev/null +++ b/tests/Lua.Tests/PatternMatchingTests.cs @@ -0,0 +1,801 @@ +using Lua.Standard; + +namespace Lua.Tests; + +public class PatternMatchingTests +{ + [Test] + public async Task Test_StringMatch_BasicPatterns() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Literal match + var result = await state.DoStringAsync("return string.match('hello world', 'hello')"); + Assert.That(result[0].Read(), Is.EqualTo("hello")); + + result = await state.DoStringAsync("return string.match('hello world', 'world')"); + Assert.That(result[0].Read(), Is.EqualTo("world")); + + // No match + result = await state.DoStringAsync("return string.match('hello world', 'xyz')"); + Assert.That(result[0].Type, Is.EqualTo(LuaValueType.Nil)); + } + + [Test] + public async Task Test_StringMatch_CharacterClasses() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // %d - digits + var result = await state.DoStringAsync("return string.match('hello123', '%d')"); + Assert.That(result[0].Read(), Is.EqualTo("1")); + + result = await state.DoStringAsync("return string.match('hello123', '%d+')"); + Assert.That(result[0].Read(), Is.EqualTo("123")); + + // %a - letters + result = await state.DoStringAsync("return string.match('123hello', '%a+')"); + Assert.That(result[0].Read(), Is.EqualTo("hello")); + + // %w - alphanumeric + result = await state.DoStringAsync("return string.match('test_123', '%w+')"); + Assert.That(result[0].Read(), Is.EqualTo("test")); + + // %s - whitespace + result = await state.DoStringAsync("return string.match('hello world', '%s')"); + Assert.That(result[0].Read(), Is.EqualTo(" ")); + } + + [Test] + public async Task Test_StringMatch_Quantifiers() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // + (one or more) + var result = await state.DoStringAsync("return string.match('aaa', 'a+')"); + Assert.That(result[0].Read(), Is.EqualTo("aaa")); + + // * (zero or more) + result = await state.DoStringAsync("return string.match('bbb', 'a*b')"); + Assert.That(result[0].Read(), Is.EqualTo("b")); + + result = await state.DoStringAsync("return string.match('aaab', 'a*b')"); + Assert.That(result[0].Read(), Is.EqualTo("aaab")); + + // ? (optional) + result = await state.DoStringAsync("return string.match('color', 'colou?r')"); + Assert.That(result[0].Read(), Is.EqualTo("color")); + + result = await state.DoStringAsync("return string.match('colour', 'colou?r')"); + Assert.That(result[0].Read(), Is.EqualTo("colour")); + + // - (minimal repetition) + result = await state.DoStringAsync("return string.match('aaab', 'a-b')"); + Assert.That(result[0].Read(), Is.EqualTo("aaab")); + } + + [Test] + public async Task Test_StringMatch_Captures() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Single capture + var result = await state.DoStringAsync("return string.match('hello world', '(%a+)')"); + Assert.That(result[0].Read(), Is.EqualTo("hello")); + + // Multiple captures + result = await state.DoStringAsync("return string.match('hello world', '(%a+) (%a+)')"); + Assert.That(result, Has.Length.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("hello")); + Assert.That(result[1].Read(), Is.EqualTo("world")); + + // Position capture + result = await state.DoStringAsync("return string.match('hello', '()llo')"); + Assert.That(result[0].Read(), Is.EqualTo(3)); + + // Email pattern + result = await state.DoStringAsync("return string.match('test@example.com', '(%w+)@(%w+)%.(%w+)')"); + Assert.That(result, Has.Length.EqualTo(3)); + Assert.That(result[0].Read(), Is.EqualTo("test")); + Assert.That(result[1].Read(), Is.EqualTo("example")); + Assert.That(result[2].Read(), Is.EqualTo("com")); + } + + [Test] + public async Task Test_StringMatch_Anchors() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // ^ (start anchor) + var result = await state.DoStringAsync("return string.match('hello world', '^hello')"); + Assert.That(result[0].Read(), Is.EqualTo("hello")); + + result = await state.DoStringAsync("return string.match('hello world', '^world')"); + Assert.That(result[0].Type, Is.EqualTo(LuaValueType.Nil)); + + // $ (end anchor) + result = await state.DoStringAsync("return string.match('hello world', 'world$')"); + Assert.That(result[0].Read(), Is.EqualTo("world")); + + result = await state.DoStringAsync("return string.match('hello world', 'hello$')"); + Assert.That(result[0].Type, Is.EqualTo(LuaValueType.Nil)); + } + + [Test] + public async Task Test_StringMatch_WithInitPosition() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Start from specific position + var result = await state.DoStringAsync("return string.match('hello world', 'o', 5)"); + Assert.That(result[0].Read(), Is.EqualTo("o")); + + result = await state.DoStringAsync("return string.match('hello world', 'o', 8)"); + Assert.That(result[0].Read(), Is.EqualTo("o")); + + // Negative init (from end) + result = await state.DoStringAsync("return string.match('hello', 'l', -2)"); + Assert.That(result[0].Read(), Is.EqualTo("l")); + } + + [Test] + public async Task Test_StringMatch_SpecialPatterns() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Dot (any character) + var result = await state.DoStringAsync("return string.match('hello', 'h.llo')"); + Assert.That(result[0].Read(), Is.EqualTo("hello")); + + // Character sets + result = await state.DoStringAsync("return string.match('hello123', '[0-9]+')"); + Assert.That(result[0].Read(), Is.EqualTo("123")); + + result = await state.DoStringAsync("return string.match('Hello', '[Hh]ello')"); + Assert.That(result[0].Read(), Is.EqualTo("Hello")); + + // Negated character sets + result = await state.DoStringAsync("return string.match('hello123', '[^a-z]+')"); + Assert.That(result[0].Read(), Is.EqualTo("123")); + } + + [Test] + public async Task Test_StringFind_BasicUsage() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Basic literal search + var result = await state.DoStringAsync("return string.find('hello world', 'world')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo(7)); // Start position (1-based) + Assert.That(result[1].Read(), Is.EqualTo(11)); // End position (1-based) + + // Search with start position + result = await state.DoStringAsync("return string.find('hello hello', 'hello', 3)"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo(7)); // Second occurrence + Assert.That(result[1].Read(), Is.EqualTo(11)); + + // No match + result = await state.DoStringAsync("return string.find('hello world', 'xyz')"); + Assert.That(result.Length, Is.EqualTo(1)); + Assert.That(result[0].Type, Is.EqualTo(LuaValueType.Nil)); + } + + [Test] + public async Task Test_StringFind_WithPatterns() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Pattern with captures + var result = await state.DoStringAsync("return string.find('hello 123', '(%a+) (%d+)')"); + Assert.That(result.Length, Is.EqualTo(4)); // start, end, capture1, capture2 + Assert.That(result[0].Read(), Is.EqualTo(1)); // Start position + Assert.That(result[1].Read(), Is.EqualTo(9)); // End position + Assert.That(result[2].Read(), Is.EqualTo("hello")); // First capture + Assert.That(result[3].Read(), Is.EqualTo("123")); // Second capture + + // Character class patterns + result = await state.DoStringAsync("return string.find('abc123def', '%d+')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo(4)); // Position of '123' + Assert.That(result[1].Read(), Is.EqualTo(6)); + + // Anchored patterns + result = await state.DoStringAsync("return string.find('hello world', '^hello')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo(1)); + Assert.That(result[1].Read(), Is.EqualTo(5)); + + result = await state.DoStringAsync("return string.find('hello world', '^world')"); + Assert.That(result.Length, Is.EqualTo(1)); + Assert.That(result[0].Type, Is.EqualTo(LuaValueType.Nil)); + } + + [Test] + public async Task Test_StringFind_PlainSearch() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Plain search (4th parameter = true) + var result = await state.DoStringAsync("return string.find('hello (world)', '(world)', 1, true)"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo(7)); // Start of '(world)' + Assert.That(result[1].Read(), Is.EqualTo(13)); // End of '(world)' + + // Pattern search would fail but plain search succeeds + result = await state.DoStringAsync("return string.find('test%d+test', '%d+', 1, true)"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo(5)); // Literal '%d+' + Assert.That(result[1].Read(), Is.EqualTo(7)); + } + + [Test] + public async Task Test_StringFind_EdgeCases() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Empty pattern + var result = await state.DoStringAsync("return string.find('hello', '')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo(1)); + Assert.That(result[1].Read(), Is.EqualTo(0)); + + // Negative start position + result = await state.DoStringAsync("return string.find('hello', 'l', -2)"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo(4)); // Last 'l' + Assert.That(result[1].Read(), Is.EqualTo(4)); + + // Start position beyond string length + result = await state.DoStringAsync("return string.find('hello', 'l', 10)"); + Assert.That(result.Length, Is.EqualTo(1)); + Assert.That(result[0].Type, Is.EqualTo(LuaValueType.Nil)); + + // Position captures + result = await state.DoStringAsync("return string.find('hello', '()l()l()')"); + Assert.That(result.Length, Is.EqualTo(5)); // start, end, pos1, pos2, pos3 + Assert.That(result[0].Read(), Is.EqualTo(3)); // Start of match + Assert.That(result[1].Read(), Is.EqualTo(4)); // End of match + Assert.That(result[2].Read(), Is.EqualTo(3)); // Position before first 'l' + Assert.That(result[3].Read(), Is.EqualTo(4)); // Position before second 'l' + Assert.That(result[4].Read(), Is.EqualTo(5)); // Position after second 'l' + } + + [Test] + public async Task Test_StringGMatch_BasicUsage() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + state.OpenTableLibrary(); + // Test basic gmatch iteration + var result = await state.DoStringAsync(@" + local words = {} + for word in string.gmatch('hello world lua', '%a+') do + table.insert(words, word) + end + return table.unpack(words) + "); + Assert.That(result.Length, Is.EqualTo(3)); + Assert.That(result[0].Read(), Is.EqualTo("hello")); + Assert.That(result[1].Read(), Is.EqualTo("world")); + Assert.That(result[2].Read(), Is.EqualTo("lua")); + } + + [Test] + public async Task Test_StringGMatch_WithCaptures() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + state.OpenTableLibrary(); + + // Test gmatch with captures + var result = await state.DoStringAsync(@" + local pairs = {} + for key, value in string.gmatch('a=1 b=2 c=3', '(%a)=(%d)') do + table.insert(pairs, key .. ':' .. value) + end + return table.unpack(pairs) + "); + Assert.That(result.Length, Is.EqualTo(3)); + Assert.That(result[0].Read(), Is.EqualTo("a:1")); + Assert.That(result[1].Read(), Is.EqualTo("b:2")); + Assert.That(result[2].Read(), Is.EqualTo("c:3")); + } + + [Test] + public async Task Test_StringGMatch_Numbers() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + state.OpenTableLibrary(); + + // Extract all numbers from a string + var result = await state.DoStringAsync(@" + local numbers = {} + for num in string.gmatch('price: $12.50, tax: $2.75, total: $15.25', '%d+%.%d+') do + table.insert(numbers, num) + end + return table.unpack(numbers) + "); + Assert.That(result.Length, Is.EqualTo(3)); + Assert.That(result[0].Read(), Is.EqualTo("12.50")); + Assert.That(result[1].Read(), Is.EqualTo("2.75")); + Assert.That(result[2].Read(), Is.EqualTo("15.25")); + } + + [Test] + public async Task Test_StringGMatch_EmptyMatches() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + state.OpenTableLibrary(); + + // Test with pattern that can match empty strings + var result = await state.DoStringAsync(@" + local count = 0 + for match in string.gmatch('abc', 'a*') do + count = count + 1 + if count > 10 then break end -- Prevent infinite loop + end + return count + "); + Assert.That(result[0].Read(), Is.EqualTo(3)); + } + + [Test] + public async Task Test_StringGMatch_ComplexPatterns() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + state.OpenTableLibrary(); + + // Extract email-like patterns + var result = await state.DoStringAsync(@" + local emails = {} + local text = 'Contact us at info@example.com or support@test.org for help' + for email in string.gmatch(text, '%w+@%w+%.%w+') do + table.insert(emails, email) + end + return table.unpack(emails) + "); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("info@example.com")); + Assert.That(result[1].Read(), Is.EqualTo("support@test.org")); + } + + [Test] + public async Task Test_StringGMatch_PositionCaptures() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + state.OpenTableLibrary(); + + // Test position captures with gmatch + var result = await state.DoStringAsync(@" + local positions = {} + for pos, char in string.gmatch('hello', '()(%a)') do + table.insert(positions, pos .. ':' .. char) + end + return table.unpack(positions) + "); + Assert.That(result.Length, Is.EqualTo(5)); + Assert.That(result[0].Read(), Is.EqualTo("1:h")); + Assert.That(result[1].Read(), Is.EqualTo("2:e")); + Assert.That(result[2].Read(), Is.EqualTo("3:l")); + Assert.That(result[3].Read(), Is.EqualTo("4:l")); + Assert.That(result[4].Read(), Is.EqualTo("5:o")); + } + + [Test] + public async Task Test_StringGMatch_NoMatches() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + state.OpenTableLibrary(); + + // Test when no matches are found + var result = await state.DoStringAsync(@" + local count = 0 + for match in string.gmatch('hello world', '%d+') do + count = count + 1 + end + return count + "); + Assert.That(result[0].Read(), Is.EqualTo(0)); + } + + [Test] + public async Task Test_StringGMatch_SingleCharacter() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + state.OpenTableLibrary(); + + // Test matching single characters + var result = await state.DoStringAsync(@" + local chars = {} + for char in string.gmatch('a1b2c3', '%a') do + table.insert(chars, char) + end + return table.unpack(chars) + "); + Assert.That(result.Length, Is.EqualTo(3)); + Assert.That(result[0].Read(), Is.EqualTo("a")); + Assert.That(result[1].Read(), Is.EqualTo("b")); + Assert.That(result[2].Read(), Is.EqualTo("c")); + } + + [Test] + public async Task Test_StringFind_And_GMatch_Consistency() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Test that find and gmatch work consistently with the same pattern + var result = await state.DoStringAsync(@" + local text = 'The quick brown fox jumps over the lazy dog' + + -- Find first word + local start, end_pos, word1 = string.find(text, '(%a+)') + + -- Get first word from gmatch + local word2 = string.gmatch(text, '%a+')() + + return word1, word2, start, end_pos + "); + Assert.That(result.Length, Is.EqualTo(4)); + Assert.That(result[0].Read(), Is.EqualTo("The")); // From find + Assert.That(result[1].Read(), Is.EqualTo("The")); // From gmatch + Assert.That(result[2].Read(), Is.EqualTo(1)); // Start position + Assert.That(result[3].Read(), Is.EqualTo(3)); // End position + } + + [Test] + public async Task Test_Pattern_NegatedCharacterClassWithCapture() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Test the problematic pattern ^([^:]*): + var result = await state.DoStringAsync(@" + local text = 'key:value' + local match = string.match(text, '^([^:]*):') + return match + "); + + Assert.That(result.Length, Is.EqualTo(1)); + Assert.That(result[0].Read(), Is.EqualTo("key")); + + // Test with empty match + result = await state.DoStringAsync(@" + local text = ':value' + local match = string.match(text, '^([^:]*):') + return match + "); + + Assert.That(result.Length, Is.EqualTo(1)); + Assert.That(result[0].Read(), Is.EqualTo("")); // Empty string + + // Test with multiple captures + result = await state.DoStringAsync(@" + local text = '[key]:[value]:extra' + local a, b = string.match(text, '^([^:]*):([^:]*)') + return a, b + "); + + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("[key]")); + Assert.That(result[1].Read(), Is.EqualTo("[value]")); + } + + [Test] + public async Task Test_StringGSub_BasicReplacements() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Simple string replacement + var result = await state.DoStringAsync("return string.gsub('hello world', 'world', 'lua')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("hello lua")); + Assert.That(result[1].Read(), Is.EqualTo(1)); // Replacement count + + // Multiple replacements + result = await state.DoStringAsync("return string.gsub('hello hello hello', 'hello', 'hi')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("hi hi hi")); + Assert.That(result[1].Read(), Is.EqualTo(3)); + + // Limited replacements + result = await state.DoStringAsync("return string.gsub('hello hello hello', 'hello', 'hi', 2)"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("hi hi hello")); + Assert.That(result[1].Read(), Is.EqualTo(2)); + } + + [Test] + public async Task Test_StringGSub_PatternReplacements() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Character class patterns + var result = await state.DoStringAsync("return string.gsub('hello123world456', '%d+', 'X')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("helloXworldX")); + Assert.That(result[1].Read(), Is.EqualTo(2)); + + // Capture replacements + result = await state.DoStringAsync("return string.gsub('John Doe', '(%a+) (%a+)', '%2, %1')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("Doe, John")); + Assert.That(result[1].Read(), Is.EqualTo(1)); + + // Whole match replacement (%0) + result = await state.DoStringAsync("return string.gsub('test123', '%d+', '[%0]')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("test[123]")); + Assert.That(result[1].Read(), Is.EqualTo(1)); + } + + [Test] + public async Task Test_StringGSub_FunctionReplacements() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Function replacement + var result = await state.DoStringAsync(@" + return string.gsub('hello world', '%a+', function(s) + return s:upper() + end) + "); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("HELLO WORLD")); + Assert.That(result[1].Read(), Is.EqualTo(2)); + + // Function with position captures + result = await state.DoStringAsync(@" + return string.gsub('hello', '()l', function(pos) + return '[' .. pos .. ']' + end) + "); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("he[3][4]o")); + Assert.That(result[1].Read(), Is.EqualTo(2)); + + // Function returning nil (no replacement) + result = await state.DoStringAsync(@" + return string.gsub('a1b2c3', '%d', function(s) + if s == '2' then return nil end + return 'X' + end) + "); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("aXb2cX")); + Assert.That(result[1].Read(), Is.EqualTo(3)); // Only 2 replacements made + } + + [Test] + public async Task Test_StringGSub_TableReplacements() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Table replacement + var result = await state.DoStringAsync(@" + local map = {hello = 'hi', world = 'lua'} + return string.gsub('hello world', '%a+', map) + "); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("hi lua")); + Assert.That(result[1].Read(), Is.EqualTo(2)); + + // Table with missing keys (no replacement) + result = await state.DoStringAsync(@" + local map = {hello = 'hi'} + return string.gsub('hello world', '%a+', map) + "); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("hi world")); + Assert.That(result[1].Read(), Is.EqualTo(2)); // Only 'hello' was replaced + } + + [Test] + public async Task Test_StringGSub_EmptyPattern() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Empty pattern should match at every position + var result = await state.DoStringAsync("return string.gsub('abc', '', '.')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo(".a.b.c.")); + Assert.That(result[1].Read(), Is.EqualTo(4)); // 4 positions: before a, before b, before c, after c + } + + [Test] + public async Task Test_StringGSub_BalancedPatterns() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Balanced parentheses pattern + var result = await state.DoStringAsync(@" + return string.gsub('(hello) and (world)', '%b()', function(s) + return s:upper() + end) + "); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("(HELLO) and (WORLD)")); + Assert.That(result[1].Read(), Is.EqualTo(2)); + + // Balanced brackets + result = await state.DoStringAsync("return string.gsub('[a][b][c]', '%b[]', 'X')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("XXX")); + Assert.That(result[1].Read(), Is.EqualTo(3)); + } + + [Test] + public async Task Test_StringGSub_EscapeSequences() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Test %% escape (literal %) + var result = await state.DoStringAsync("return string.gsub('test', 'test', '100%%')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("100%")); + Assert.That(result[1].Read(), Is.EqualTo(1)); + } + + [Test] + public async Task Test_StringGSub_EdgeCases() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Empty string + var result = await state.DoStringAsync("return string.gsub('', 'a', 'b')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("")); + Assert.That(result[1].Read(), Is.EqualTo(0)); + + // No matches + result = await state.DoStringAsync("return string.gsub('hello', 'xyz', 'abc')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("hello")); + Assert.That(result[1].Read(), Is.EqualTo(0)); + + // Zero replacement limit + result = await state.DoStringAsync("return string.gsub('hello hello', 'hello', 'hi', 0)"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("hello hello")); + Assert.That(result[1].Read(), Is.EqualTo(0)); + } + + [Test] + public async Task Test_StringGSub_ComplexPatterns() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Email replacement + var result = await state.DoStringAsync(@" + local text = 'Contact john@example.com or jane@test.org' + return string.gsub(text, '(%w+)@(%w+)%.(%w+)', function(user, domain, tld) + return user:upper() .. '@' .. domain:upper() .. '.' .. tld:upper() + end) + "); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("Contact JOHN@EXAMPLE.COM or JANE@TEST.ORG")); + Assert.That(result[1].Read(), Is.EqualTo(2)); + + // URL path extraction + result = await state.DoStringAsync(@" + return string.gsub('http://example.com/path/to/file.html', + '^https?://[^/]+(/.*)', '%1') + "); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("/path/to/file.html")); + Assert.That(result[1].Read(), Is.EqualTo(1)); + } + + [Test] + public async Task Test_PatternMatching_Consistency() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Test that all string functions work consistently with same patterns + var result = await state.DoStringAsync(@" + local text = 'The quick brown fox jumps over the lazy dog' + local pattern = '%a+' + + -- Test find + local start, end_pos, word = string.find(text, '(' .. pattern .. ')') + + -- Test match + local match = string.match(text, pattern) + + -- Test gsub count + local _, count = string.gsub(text, pattern, function(s) return s end) + + -- Test gmatch count + local gmatch_count = 0 + for word in string.gmatch(text, pattern) do + gmatch_count = gmatch_count + 1 + end + + return word, match, count, gmatch_count, start, end_pos + "); + + Assert.That(result.Length, Is.EqualTo(6)); + Assert.That(result[0].Read(), Is.EqualTo("The")); // find capture + Assert.That(result[1].Read(), Is.EqualTo("The")); // match result + Assert.That(result[2].Read(), Is.EqualTo(9)); // gsub count (9 words) + Assert.That(result[3].Read(), Is.EqualTo(9)); // gmatch count + Assert.That(result[4].Read(), Is.EqualTo(1)); // find start + Assert.That(result[5].Read(), Is.EqualTo(3)); // find end + } + + [Test] + public async Task Test_PatternMatching_SpecialPatterns() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Frontier pattern %f + var result = await state.DoStringAsync(@" + return string.gsub('hello world', '%f[%a]', '[') + "); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("[hello [world")); + Assert.That(result[1].Read(), Is.EqualTo(2)); + + // Minimal repetition with - + result = await state.DoStringAsync("return string.match('aaab', 'a-b')"); + Assert.That(result[0].Read(), Is.EqualTo("aaab")); + + // Optional quantifier ? + result = await state.DoStringAsync("return string.gsub('color colour', 'colou?r', 'COLOR')"); + Assert.That(result.Length, Is.EqualTo(2)); + Assert.That(result[0].Read(), Is.EqualTo("COLOR COLOR")); + Assert.That(result[1].Read(), Is.EqualTo(2)); + } + + [Test] + public async Task Test_PatternMatching_ErrorCases() + { + var state = LuaState.Create(); + state.OpenStringLibrary(); + + // Invalid pattern - missing closing bracket + var exception = Assert.ThrowsAsync(async () => + await state.DoStringAsync("return string.match('test', '[abc')")); + Assert.That(exception.Message, Does.Contain("missing ']'")); + + // Invalid pattern - missing %b arguments + exception = Assert.ThrowsAsync(async () => + await state.DoStringAsync("return string.match('test', '%b')")); + Assert.That(exception.Message, Does.Contain("missing arguments to '%b'")); + + // Pattern too complex (exceeds recursion limit) + exception = Assert.ThrowsAsync(async () => + await state.DoStringAsync("return string.match(string.rep('a', 1000), string.rep('a?', 1000) .. string.rep('a', 1000))")); + Assert.That(exception.Message, Does.Contain("pattern too complex")); + } +} \ No newline at end of file