Skip to content

Commit

Permalink
impr: text - replace function
Browse files Browse the repository at this point in the history
  • Loading branch information
nalgeon committed Jun 5, 2023
1 parent 709febb commit e894804
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 2 deletions.
88 changes: 88 additions & 0 deletions src/sqlite3-text.c
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,88 @@ static void sqlite3_pad(sqlite3_context* context, int argc, sqlite3_value** argv

#pragma endregion

#pragma region Other modifications

// Replaces all old substrings with new substrings in the original string.
// text_replace(str, old, new)
// [pg-compatible] replace(string, from, to)
static void sqlite3_replace_all(sqlite3_context* context, int argc, sqlite3_value** argv) {
assert(argc == 3);

const char* src = (char*)sqlite3_value_text(argv[0]);
if (src == NULL) {
sqlite3_result_null(context);
return;
}

const char* old = (char*)sqlite3_value_text(argv[1]);
if (old == NULL) {
sqlite3_result_null(context);
return;
}

const char* new = (char*)sqlite3_value_text(argv[2]);
if (new == NULL) {
sqlite3_result_null(context);
return;
}

ByteString s_src = bstring.from_cstring(src, sqlite3_value_bytes(argv[0]));
ByteString s_old = bstring.from_cstring(old, sqlite3_value_bytes(argv[1]));
ByteString s_new = bstring.from_cstring(new, sqlite3_value_bytes(argv[2]));
ByteString s_res = bstring.replace_all(s_src, s_old, s_new);
const char* res = bstring.to_cstring(s_res);
sqlite3_result_text(context, res, -1, SQLITE_TRANSIENT);
bstring.free(s_src);
bstring.free(s_old);
bstring.free(s_new);
bstring.free(s_res);
}

// Replaces old substrings with new substrings in the original string,
// but not more than `count` times.
// text_replace(str, old, new, count)
static void sqlite3_replace(sqlite3_context* context, int argc, sqlite3_value** argv) {
assert(argc == 4);

const char* src = (char*)sqlite3_value_text(argv[0]);
if (src == NULL) {
sqlite3_result_null(context);
return;
}

const char* old = (char*)sqlite3_value_text(argv[1]);
if (old == NULL) {
sqlite3_result_null(context);
return;
}

const char* new = (char*)sqlite3_value_text(argv[2]);
if (new == NULL) {
sqlite3_result_null(context);
return;
}

if (sqlite3_value_type(argv[3]) != SQLITE_INTEGER) {
sqlite3_result_error(context, "count parameter should be integer", -1);
return;
}
int count = sqlite3_value_int(argv[3]);
// treat negative count as zero
count = count < 0 ? 0 : count;

ByteString s_src = bstring.from_cstring(src, sqlite3_value_bytes(argv[0]));
ByteString s_old = bstring.from_cstring(old, sqlite3_value_bytes(argv[1]));
ByteString s_new = bstring.from_cstring(new, sqlite3_value_bytes(argv[2]));
ByteString s_res = bstring.replace(s_src, s_old, s_new, count);
const char* res = bstring.to_cstring(s_res);
sqlite3_result_text(context, res, -1, SQLITE_TRANSIENT);
bstring.free(s_src);
bstring.free(s_old);
bstring.free(s_new);
bstring.free(s_res);
}

// Reverses a string.
// text_reverse(str)
static void sqlite3_reverse(sqlite3_context* context, int argc, sqlite3_value** argv) {
Expand All @@ -661,6 +743,8 @@ static void sqlite3_reverse(sqlite3_context* context, int argc, sqlite3_value**
rstring.free(s_res);
}

#pragma endregion

// substring
// utf8 text_slice(str, start [,end])
// utf8 text_substring(str, start [,length])
Expand Down Expand Up @@ -745,6 +829,10 @@ __declspec(dllexport)
sqlite3_create_function(db, "text_rpad", -1, flags, rstring.pad_right, sqlite3_pad, 0, 0);
sqlite3_create_function(db, "rpad", -1, flags, rstring.pad_right, sqlite3_pad, 0, 0);

// other modifications
sqlite3_create_function(db, "text_replace", 3, flags, 0, sqlite3_replace_all, 0, 0);
sqlite3_create_function(db, "replace", 3, flags, 0, sqlite3_replace_all, 0, 0);
sqlite3_create_function(db, "text_replace", 4, flags, 0, sqlite3_replace, 0, 0);
sqlite3_create_function(db, "text_reverse", 1, flags, 0, sqlite3_reverse, 0, 0);
sqlite3_create_function(db, "reverse", 1, flags, 0, sqlite3_reverse, 0, 0);
return SQLITE_OK;
Expand Down
2 changes: 1 addition & 1 deletion src/text/bstring.c
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ static ByteString string_repeat(ByteString str, size_t count) {

// string_replace replaces the `old` substring with the `new` substring in the original string,
// but not more than `max_count` times.
static ByteString string_replace(ByteString str, ByteString old, ByteString new, size_t max_count) {
static ByteString string_replace(ByteString str, ByteString old, ByteString new, int max_count) {
// count matches of the old string in the source string
size_t count = string_count(str, old);
if (count == 0) {
Expand Down
2 changes: 1 addition & 1 deletion src/text/bstring.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct bstring_ns {
ByteString (*concat)(ByteString* strings, size_t count);
ByteString (*repeat)(ByteString str, size_t count);

ByteString (*replace)(ByteString str, ByteString old, ByteString new, size_t max_count);
ByteString (*replace)(ByteString str, ByteString old, ByteString new, int max_count);
ByteString (*replace_all)(ByteString str, ByteString old, ByteString new);
ByteString (*reverse)(ByteString str);

Expand Down
35 changes: 35 additions & 0 deletions test/text.sql
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,41 @@ select '17_14', text_rpad('hello', 8, '*') = 'hello***';
select '17_15', text_rpad('hello', 8, 'xo') = 'helloxox';
select '17_16', text_rpad('мир', 6, 'хо') = 'мирхох';
-- Replace all
select '18_01', text_replace(null, 'a', '*') is null;
select '18_02', text_replace('abc', null, '*') is null;
select '18_03', text_replace('abc', 'a', null) is null;
select '18_04', text_replace('hello', 'l', '*') = 'he**o';
select '18_05', text_replace('hello', 'l', 'xo') = 'hexoxoo';
select '18_06', text_replace('hello', 'ell', '*') = 'h*o';
select '18_07', text_replace('hello', 'ello', 'argh') = 'hargh';
select '18_08', text_replace('hello', 'hello', '-') = '-';
select '18_09', text_replace('hello', '', '*') = 'hello';
select '18_10', text_replace('hello', 'l', '') = 'heo';
select '18_11', text_replace('', 'l', '*') = '';
select '18_12', text_replace('нетто', 'т', 'три') = 'нетритрио';
-- Replace
select '19_01', text_replace(null, 'a', '*', 1) is null;
select '19_02', text_replace('abc', null, '*', 1) is null;
select '19_03', text_replace('abc', 'a', null, 1) is null;
select '19_04', text_replace('hello', 'l', '*', 2) = 'he**o';
select '19_05', text_replace('hello', 'l', 'xo', 2) = 'hexoxoo';
select '19_06', text_replace('hello', 'ell', '*', 1) = 'h*o';
select '19_07', text_replace('hello', 'ello', 'argh', 1) = 'hargh';
select '19_08', text_replace('hello', 'hello', '-', 1) = '-';
select '19_09', text_replace('hello', '', '*', 1) = 'hello';
select '19_10', text_replace('hello', 'l', '', 2) = 'heo';
select '19_11', text_replace('', 'l', '*', 1) = '';
select '19_12', text_replace('нетто', 'т', 'три', 2) = 'нетритрио';
select '19_21', text_replace('hello', 'l', '*', -1) = 'hello';
select '19_22', text_replace('hello', 'l', '*', 0) = 'hello';
select '19_23', text_replace('hello', 'l', '*', 1) = 'he*lo';
select '19_24', text_replace('hello', 'l', '*', 2) = 'he**o';
select '19_25', text_replace('hello', 'l', '*', 3) = 'he**o';
select '19_16', text_replace('нетто', 'т', 'три', 1) = 'нетрито';
-- Reverse string
select 'x_01', text_reverse(null) is NULL;
select 'x_02', text_reverse('hello') = 'olleh';
Expand Down
20 changes: 20 additions & 0 deletions test/text/bstring.test.c
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,26 @@ static void test_replace(void) {
bstring.free(res);
}

{
ByteString old = bstring.from_cstring("o", 1);
ByteString new = bstring.from_cstring("***", 3);
ByteString res = bstring.replace(str, old, new, 1);
assert(eq(res, "hell*** world"));
bstring.free(old);
bstring.free(new);
bstring.free(res);
}

{
ByteString old = bstring.from_cstring("o", 1);
ByteString new = bstring.from_cstring("***", 3);
ByteString res = bstring.replace(str, old, new, 2);
assert(eq(res, "hell*** w***rld"));
bstring.free(old);
bstring.free(new);
bstring.free(res);
}

{
ByteString old = bstring.from_cstring("e", 1);
ByteString new = bstring.from_cstring("***", 3);
Expand Down

0 comments on commit e894804

Please sign in to comment.