Skip to content

Commit

Permalink
[libc refactor] regex_match() returns a list of indices
Browse files Browse the repository at this point in the history
This will enable _start() and _end()

As well as m => start() and m => end()

Also fix issue where an unmatched subgroup returned '' instead of
None/nullptr.

A zero-length match is different than not matching.
  • Loading branch information
Andy Chu committed Dec 14, 2023
1 parent 71715fd commit e6bbc73
Show file tree
Hide file tree
Showing 10 changed files with 111 additions and 83 deletions.
25 changes: 22 additions & 3 deletions core/util.py
Expand Up @@ -16,19 +16,38 @@

import libc

from typing import List
from typing import List, Optional


def _Groups(s, indices):
# type: (str, Optional[List[int]]) -> List[str]
if indices is None:
return None

groups = [] # type: List[str]
n = len(indices)
for i in xrange(n / 2):
start = indices[2 * i]
end = indices[2 * i + 1]
if start == -1:
groups.append(None)
else:
groups.append(s[start:end])
return groups


def simple_regex_search(pat, s):
# type: (str, str) -> List[str]
"""Convenience wrapper around libc."""
return libc.regex_match(pat, s, 0)
indices = libc.regex_match(pat, s, 0)
return _Groups(s, indices)


def regex_search(pat, comp_flags, s):
# type: (str, int, str) -> List[str]
"""Convenience wrapper around libc."""
return libc.regex_match(pat, s, comp_flags)
indices = libc.regex_match(pat, s, comp_flags)
return _Groups(s, indices)


class UserExit(Exception):
Expand Down
35 changes: 35 additions & 0 deletions core/util_test.py
Expand Up @@ -8,16 +8,51 @@
"""util_test.py: Tests for util.py."""

import unittest
import sys

from core import util # module under test

# guard some tests that fail on Darwin
IS_DARWIN = sys.platform == 'darwin'


class UtilTest(unittest.TestCase):

def testDebugFile(self):
n = util.NullDebugFile()
n.write('foo')

def testSimpleRegexSearch(self):
cases = [
('([a-z]+)([0-9]+)', 'foo123', ['foo123', 'foo', '123']),
(r'.*\.py', 'foo.py', ['foo.py']),
(r'.*\.py', 'abcd', None),
# The match is unanchored
(r'bc', 'abcd', ['bc']),
# The match is unanchored
(r'.c', 'abcd', ['bc']),
# Empty matches empty
None if IS_DARWIN else (r'', '', ['']),
(r'^$', '', ['']),
(r'^.$', '', None),
]

# TODO:
#
# return a single list of length 2*(1 + nsub)
# 2 is for start and end, +1 is for 0
#
# indices = regex_search(...)
# indices[2*group] is start
# indices[2*group+1] is end
# group is from 0 ... n

for pat, s, expected in filter(None, cases):
#print('CASE %s' % pat)
actual = util.simple_regex_search(pat, s)
#print('actual %r' % actual)
self.assertEqual(expected, actual)


if __name__ == '__main__':
unittest.main()
22 changes: 11 additions & 11 deletions cpp/libc.cc
Expand Up @@ -105,28 +105,28 @@ List<BigStr*>* glob(BigStr* pat) {

// Raises RuntimeError if the pattern is invalid. TODO: Use a different
// exception?
List<BigStr*>* regex_match(BigStr* pattern, BigStr* str, int flags, int pos) {
List<BigStr*>* results = NewList<BigStr*>();

List<int>* regex_match(BigStr* pattern, BigStr* str, int flags, int pos) {
flags |= REG_EXTENDED;
regex_t pat;
if (regcomp(&pat, pattern->data_, flags) != 0) {
// TODO: check error code, as in func_regex_parse()
throw Alloc<RuntimeError>(StrFromC("Invalid regex syntax (regex_match)"));
}

int outlen = pat.re_nsub + 1; // number of captures
int num_groups = pat.re_nsub + 1; // number of captures

List<int>* indices = NewList<int>();
indices->reserve(num_groups * 2);

const char* s0 = str->data_;
regmatch_t* pmatch =
static_cast<regmatch_t*>(malloc(sizeof(regmatch_t) * outlen));
int match = regexec(&pat, s0, outlen, pmatch, 0) == 0;
static_cast<regmatch_t*>(malloc(sizeof(regmatch_t) * num_groups));
int match = regexec(&pat, s0, num_groups, pmatch, 0) == 0;
if (match) {
int i;
for (i = 0; i < outlen; i++) {
int len = pmatch[i].rm_eo - pmatch[i].rm_so;
BigStr* m = StrFromC(s0 + pmatch[i].rm_so, len);
results->append(m);
for (i = 0; i < num_groups; i++) {
indices->append(pmatch[i].rm_so);
indices->append(pmatch[i].rm_eo);
}
}

Expand All @@ -137,7 +137,7 @@ List<BigStr*>* regex_match(BigStr* pattern, BigStr* str, int flags, int pos) {
return nullptr;
}

return results;
return indices;
}

// For ${//}, the number of groups is always 1, so we want 2 match position
Expand Down
3 changes: 1 addition & 2 deletions cpp/libc.h
Expand Up @@ -27,8 +27,7 @@ List<BigStr*>* glob(BigStr* pat);
Tuple2<int, int>* regex_first_group_match(BigStr* pattern, BigStr* str,
int pos);

List<BigStr*>* regex_match(BigStr* pattern, BigStr* str, int flags,
int pos = 0);
List<int>* regex_match(BigStr* pattern, BigStr* str, int flags, int pos = 0);

int wcswidth(BigStr* str);
int get_terminal_width();
Expand Down
36 changes: 25 additions & 11 deletions cpp/libc_test.cc
@@ -1,6 +1,6 @@
#include "cpp/libc.h"

#include <regex.h> // regcomp()
#include <regex.h> // regcomp()
#include <unistd.h> // gethostname()

#include "mycpp/runtime.h"
Expand Down Expand Up @@ -68,23 +68,40 @@ TEST libc_test() {
PASS();
}

static List<BigStr*>* Groups(BigStr* s, List<int>* indices) {
List<BigStr*>* groups = NewList<BigStr*>();
int n = len(indices) / 2;
for (int i = 0; i < n; ++i) {
int start = indices->at(2 * i);
int end = indices->at(2 * i + 1);
if (start == -1) {
groups->append(nullptr);
} else {
groups->append(s->slice(start, end));
}
}
return groups;
}

TEST regex_test() {
List<BigStr*>* results =
libc::regex_match(StrFromC("(a+).(a+)"), StrFromC("-abaacaaa"), 0);
BigStr* s1 = StrFromC("-abaacaaa");
List<int>* indices = libc::regex_match(StrFromC("(a+).(a+)"), s1, 0);
List<BigStr*>* results = Groups(s1, indices);
ASSERT_EQ_FMT(3, len(results), "%d");
ASSERT(str_equals(StrFromC("abaa"), results->at(0))); // whole match
ASSERT(str_equals(StrFromC("a"), results->at(1)));
ASSERT(str_equals(StrFromC("aa"), results->at(2)));

results = libc::regex_match(StrFromC("z+"), StrFromC("abaacaaa"), 0);
ASSERT_EQ(nullptr, results);
indices = libc::regex_match(StrFromC("z+"), StrFromC("abaacaaa"), 0);
ASSERT_EQ(nullptr, indices);

// Alternation gives unmatched group
results = libc::regex_match(StrFromC("(a)|(b)"), StrFromC("b"), 0);
BigStr* s2 = StrFromC("b");
indices = libc::regex_match(StrFromC("(a)|(b)"), s2, 0);
results = Groups(s2, indices);
ASSERT_EQ_FMT(3, len(results), "%d");
ASSERT(str_equals(StrFromC("b"), results->at(0))); // whole match
// TODO: this should be null. It is in JavaScript and Python
ASSERT(str_equals(StrFromC(""), results->at(1)));
ASSERT_EQ(nullptr, results->at(1));
ASSERT(str_equals(StrFromC("b"), results->at(2)));

Tuple2<int, int>* result;
Expand Down Expand Up @@ -145,7 +162,6 @@ void FindAll(const char* p, const char* s) {
int cur_pos = 0;
// int n = strlen(s);
while (true) {

// Necessary so ^ doesn't match in the middle!
int eflags = cur_pos == 0 ? 0 : REG_NOTBOL;
bool match = regexec(&pat, s + cur_pos, outlen, pmatch, eflags) == 0;
Expand Down Expand Up @@ -176,7 +192,6 @@ void FindAll(const char* p, const char* s) {
// adjacent matches
const char* s = "a345y-axy- there b789y- cy-";


TEST regex_unanchored() {
const char* unanchored = "[abc]([0-9]*)(x?)(y)-";
FindAll(unanchored, s);
Expand Down Expand Up @@ -233,7 +248,6 @@ TEST regex_alt_with_capture() {
PASS();
}


GREATEST_MAIN_DEFS();

int main(int argc, char** argv) {
Expand Down
2 changes: 1 addition & 1 deletion osh/word_eval.py
Expand Up @@ -177,7 +177,7 @@ def _SplitAssignArg(arg, blame_word):
# m[2] is used for grouping; ERE doesn't have non-capturing groups

op = m[3]
if len(op): # declare NAME=
if op is not None and len(op): # declare NAME=
val = value.Str(m[4]) # type: Optional[value_t]
append = op[0] == '+'
else: # declare NAME
Expand Down
8 changes: 3 additions & 5 deletions osh/word_eval_test.py
Expand Up @@ -35,8 +35,8 @@ def InitEvaluator():
class RegexTest(unittest.TestCase):
def testSplitAssignArg(self):
CASES = [
('s', ['s', '', '']),
('value', ['value', '', '']),
('s', ['s', None, None]),
('value', ['value', None, None]),
('s!', None),
('!', None),
('=s', None),
Expand All @@ -53,9 +53,7 @@ def testSplitAssignArg(self):
self.assertEqual(expected, actual) # no match
else:
_, var_name, _, op, value = actual
self.assertEqual(expected[0], var_name)
self.assertEqual(expected[1], op)
self.assertEqual(expected[2], value)
self.assertEqual(expected, [var_name, op, value])


class WordEvalTest(unittest.TestCase):
Expand Down
18 changes: 10 additions & 8 deletions pyext/libc.c
Expand Up @@ -202,22 +202,24 @@ func_regex_match(PyObject *self, PyObject *args) {
return NULL;
}

int outlen = pat.re_nsub + 1;
PyObject *ret = PyList_New(outlen);
int num_groups = pat.re_nsub + 1;
PyObject *ret = PyList_New(num_groups * 2);

if (ret == NULL) {
regfree(&pat);
return NULL;
}

regmatch_t *pmatch = (regmatch_t*) malloc(sizeof(regmatch_t) * outlen);
int match = regexec(&pat, str, outlen, pmatch, 0);
regmatch_t *pmatch = (regmatch_t*) malloc(sizeof(regmatch_t) * num_groups);
int match = regexec(&pat, str, num_groups, pmatch, 0);
if (match == 0) {
int i;
for (i = 0; i < outlen; i++) {
int len = pmatch[i].rm_eo - pmatch[i].rm_so;
PyObject *v = PyString_FromStringAndSize(str + pmatch[i].rm_so, len);
PyList_SetItem(ret, i, v);
for (i = 0; i < num_groups; i++) {
PyObject *start = PyInt_FromLong(pmatch[i].rm_so);
PyList_SetItem(ret, 2*i, start);

PyObject *end = PyInt_FromLong(pmatch[i].rm_eo);
PyList_SetItem(ret, 2*i + 1, end);
}
}

Expand Down
2 changes: 1 addition & 1 deletion pyext/libc.pyi
Expand Up @@ -8,7 +8,7 @@ def gethostname() -> str: ...
def glob(pat: str) -> List[str]: ...
def fnmatch(pat: str, s: str, flags: int = 0) -> bool: ...
def regex_first_group_match(regex: str, s: str, pos: int) -> Optional[Tuple[int, int]]: ...
def regex_match(regex: str, s: str, flags: int, pos: int = 0) -> Optional[List[str]]: ...
def regex_match(regex: str, s: str, flags: int, pos: int = 0) -> Optional[List[int]]: ...
def wcswidth(s: str) -> int: ...
def get_terminal_width() -> int: ...
def print_time(real: float, user: float, sys: float) -> None: ...
Expand Down
43 changes: 2 additions & 41 deletions pyext/libc_test.py
Expand Up @@ -179,47 +179,8 @@ def testGlob(self):
print(libc.glob('\\\\'))
print(libc.glob('[[:punct:]]'))

def testRegexMatch(self):
# TODO: can delete this function
if 0:
self.assertEqual(True, libc.regex_parse(r'.*\.py'))

# Syntax errors
self.assertRaises(RuntimeError, libc.regex_parse, r'*')
self.assertRaises(RuntimeError, libc.regex_parse, '\\')
if not IS_DARWIN:
self.assertRaises(RuntimeError, libc.regex_parse, '{')

cases = [
('([a-z]+)([0-9]+)', 'foo123', ['foo123', 'foo', '123']),
(r'.*\.py', 'foo.py', ['foo.py']),
(r'.*\.py', 'abcd', None),
# The match is unanchored
(r'bc', 'abcd', ['bc']),
# The match is unanchored
(r'.c', 'abcd', ['bc']),
# Empty matches empty
None if IS_DARWIN else (r'', '', ['']),
(r'^$', '', ['']),
(r'^.$', '', None),
]

# TODO:
#
# return a single list of length 2*(1 + nsub)
# 2 is for start and end, +1 is for 0
#
# indices = regex_search(...)
# indices[2*group] is start
# indices[2*group+1] is end
# group is from 0 ... n

for pat, s, expected in filter(None, cases):
#print('CASE %s' % pat)
actual = libc.regex_match(pat, s, 0)
self.assertEqual(expected, actual)

def testRegexMatch(self):
def testRegexMatchError(self):
# See core/util_test.py for more tests
self.assertRaises(RuntimeError, libc.regex_match, r'*', 'abcd', 0)

def testRegexFirstGroupMatch(self):
Expand Down

0 comments on commit e6bbc73

Please sign in to comment.