Skip to content

Commit

Permalink
Fixed #28805 -- Added regular expression database function.
Browse files Browse the repository at this point in the history
  • Loading branch information
ngnpope committed Mar 7, 2020
1 parent 98f23a8 commit e54691d
Show file tree
Hide file tree
Showing 7 changed files with 470 additions and 5 deletions.
24 changes: 24 additions & 0 deletions django/db/backends/sqlite3/base.py
Expand Up @@ -232,6 +232,8 @@ def get_new_connection(self, conn_params):
conn.create_function('PI', 0, lambda: math.pi)
conn.create_function('POWER', 2, none_guard(operator.pow))
conn.create_function('RADIANS', 1, none_guard(math.radians))
conn.create_function('REGEXP_REPLACE', 4, _sqlite_regexp_replace)
conn.create_function('REGEXP_SUBSTR', 3, _sqlite_regexp_substr)
conn.create_function('REPEAT', 2, none_guard(operator.mul))
conn.create_function('REVERSE', 1, none_guard(lambda x: x[::-1]))
conn.create_function('RPAD', 3, _sqlite_rpad)
Expand Down Expand Up @@ -579,6 +581,28 @@ def _sqlite_regexp(re_pattern, re_string):
return bool(re.search(re_pattern, str(re_string)))


def _sqlite_regexp_convert_flags(flags):
count = 0 if 'g' in flags else 1
if 'c' in flags and 'i' in flags and flags.index('c') > flags.index('i'):
flags = flags.replace('i', '') # Force case-sensitive matching.
flags = (getattr(re, x.upper()) for x in flags if x in 'imsx')
flags = functools.reduce(operator.or_, flags, 0)
return count, flags


@none_guard
def _sqlite_regexp_replace(text, pattern, replacement, flags=''):
count, flags = _sqlite_regexp_convert_flags(flags)
return re.sub(pattern, replacement, text, count, flags)


@none_guard
def _sqlite_regexp_substr(text, pattern, flags=''):
_, flags = _sqlite_regexp_convert_flags(flags)
match = re.search(pattern, text, flags)
return match and match.group(0)


@none_guard
def _sqlite_lpad(text, length, fill_text):
if len(text) >= length:
Expand Down
10 changes: 5 additions & 5 deletions django/db/models/functions/__init__.py
Expand Up @@ -12,8 +12,8 @@
)
from .text import (
MD5, SHA1, SHA224, SHA256, SHA384, SHA512, Chr, Concat, ConcatPair, Left,
Length, Lower, LPad, LTrim, Ord, Repeat, Replace, Reverse, Right, RPad,
RTrim, StrIndex, Substr, Trim, Upper,
Length, Lower, LPad, LTrim, Ord, RegexpReplace, RegexpSubstr, Repeat,
Replace, Reverse, Right, RPad, RTrim, StrIndex, Substr, Trim, Upper,
)
from .window import (
CumeDist, DenseRank, FirstValue, Lag, LastValue, Lead, NthValue, Ntile,
Expand All @@ -35,9 +35,9 @@
'Sign', 'Sin', 'Sqrt', 'Tan',
# text
'MD5', 'SHA1', 'SHA224', 'SHA256', 'SHA384', 'SHA512', 'Chr', 'Concat',
'ConcatPair', 'Left', 'Length', 'Lower', 'LPad', 'LTrim', 'Ord', 'Repeat',
'Replace', 'Reverse', 'Right', 'RPad', 'RTrim', 'StrIndex', 'Substr',
'Trim', 'Upper',
'ConcatPair', 'Left', 'Length', 'Lower', 'LPad', 'LTrim', 'Ord',
'RegexpReplace', 'RegexpSubstr', 'Repeat', 'Replace', 'Reverse', 'Right',
'RPad', 'RTrim', 'StrIndex', 'Substr', 'Trim', 'Upper',
# window
'CumeDist', 'DenseRank', 'FirstValue', 'Lag', 'LastValue', 'Lead',
'NthValue', 'Ntile', 'PercentRank', 'Rank', 'RowNumber',
Expand Down
76 changes: 76 additions & 0 deletions django/db/models/functions/text.py
@@ -1,3 +1,5 @@
import re

from django.db import NotSupportedError
from django.db.models.expressions import Func, Value
from django.db.models.fields import IntegerField
Expand Down Expand Up @@ -202,6 +204,80 @@ def as_sqlite(self, compiler, connection, **extra_context):
return super().as_sql(compiler, connection, function='UNICODE', **extra_context)


class RegexpFlagMixin:
flag_mapping = {k: str.maketrans(v) for k, v in {
'mariadb': {'c': '-i'},
'oracle': {'s': 'n'},
'postgresql': {'m': 'n'},
}.items()}
inline_flags = {}
wrap_pattern = {}

def as_sql(self, compiler, connection, **extra_context):
expression, pattern, *extra, flags = self.get_source_expressions()

if connection.vendor == 'mysql' and connection.mysql_is_mariadb:
vendor = 'mariadb'
else:
vendor = connection.vendor

# Wrap pattern in group if required and increment backreferences.
if pattern.value and vendor in self.wrap_pattern:
increment = lambda m: r'\%d' % (int(m.group(1)) + 1)
pattern.value = '(%s)' % re.sub(r'\\([0-9]+)', increment, pattern.value)

# Remove duplicate flags preserving the last occurrence of each to
# ensures that the last flag is preferred if any are contradictory.
# PY38: list() is no longer required to use reversed() on dict().
if flags and flags.value:
flags.value = ''.join(reversed(list(dict.fromkeys(reversed(flags.value)))))

if vendor in self.flag_mapping:
flags.value = flags.value.translate(self.flag_mapping[vendor])

if vendor in {'mysql', 'oracle'}:
position = Value(1)
occurrence = Value(0 if 'g' in flags.value else 1)
flags.value = flags.value.replace('g', '')
extra += (position, occurrence, flags)
elif vendor in self.inline_flags:
expressions = [expression, pattern, *extra]
if flags.value and pattern.value:
pattern.value = ('(?%s)' % flags.value.replace('g', '')) + pattern.value
else:
extra += (flags,)

self.set_source_expressions([expression, pattern, *extra])
return super().as_sql(compiler, connection, **extra_context)


class RegexpReplace(RegexpFlagMixin, Func):
function = 'REGEXP_REPLACE'
inline_flags = {'mariadb'}

def __init__(self, expression, pattern, replacement=Value(''), flags=Value(''), **extra):
super().__init__(expression, pattern, replacement, flags, **extra)


class RegexpSubstr(RegexpFlagMixin, Func):
function = 'REGEXP_SUBSTR'
inline_flags = {'mariadb', 'postgresql'}
wrap_pattern = {'postgresql'}

def __init__(self, expression, pattern, flags=Value(''), **extra):
super().__init__(expression, pattern, flags, **extra)

def as_postgresql(self, compiler, connection, **extra_context):
return super().as_sql(
compiler,
connection,
template='%(function)s(%(expressions)s)',
arg_joiner=' FROM ',
function='SUBSTRING',
**extra_context,
)


class Repeat(BytesToCharFieldConversionMixin, Func):
function = 'REPEAT'

Expand Down
62 changes: 62 additions & 0 deletions docs/ref/models/database-functions.txt
Expand Up @@ -1390,6 +1390,68 @@ Usage example::
>>> print(author.name_code_point)
77

``RegexpReplace`` and ``RegexpSubstr``
--------------------------------------

.. class:: RegexpSubstr(expression, pattern, flags=Value(''), **extra)

Returns a substring from ``expression`` matching ``pattern``.

.. class:: RegexpReplace(expression, pattern, replacement=Value(''), flags=Value(''), **extra)

Replaces matches of ``pattern`` with ``replacement`` in ``expression``. The
default replacement text is the empty string.

.. versionadded:: 3.1

A string of ``flags`` can be provided to adjust the matching and replacement
behavior:

* ``c``: Perform case-sensitive matching of ``pattern``.
* ``i``: Perform case-insensitive matching of ``pattern``.
* ``g``: Replace all occurrences of ``pattern`` instead of only the first.
* ``m``: Perform matching across multiple lines in the ``pattern``.
* ``s``: Allow the ``.`` character to match newline characters.
* ``x``: Enable extended (verbose) pattern (whitespace is ignored).

.. admonition:: Variations between database backends

Regular expression functions have varying implementations across database
backends and provide different functionality. Django attempts to smooth out
these differences as much as possible to provide as consistent an
implementation as possible. It may be that certain niche features are not
exposed or are passed through in different ways to achieve this goal.

.. admonition:: Limitations on MariaDB

MariaDB doesn't natively support passing ``flags`` as an argument, but
support is emulated by prepending ``pattern`` with inline flags, e.g.
``(?i)``. It also behaves differently to other backends in that matches are
case-insensitive by default and it is only possible to replace all matches,
not a single match.

In addition, MariaDB 10.0.11+ supports configuring the `default flags
<https://mariadb.com/kb/en/server-system-variables/#default_regex_flags>`_
for regular expressions which may affect the default behavior of this
function.

.. admonition:: Limitations on PostgreSQL

PostgreSQL enables *non-newline-sensitive* matching (the ``s`` flag) by
default. It is possible to pass the ``p`` flag which changes to *partial
newline-sensitive* matching to have the same behavior as other backends.

Usage example::

>>> from django.db.models import Value
>>> from django.db.models.functions import RegexpReplace
>>> Author.objects.create(name='J. R. R. Tolkien')
>>> Author.objects.create(name='Margaret Smith')
>>> Author.objects.update(name=RegexpReplace('name', 'R\.', Value(''), Value('g')))
1
>>> Author.objects.values('name')
<QuerySet [{'name': 'J. Tolkien'}, {'name': 'Margaret Smith'}]>

``Repeat``
----------

Expand Down
3 changes: 3 additions & 0 deletions docs/releases/3.1.txt
Expand Up @@ -272,6 +272,9 @@ Models
:meth:`~.RelatedManager.set` methods now accept callables as values in the
``through_defaults`` argument.

* Added the :class:`~django.db.models.functions.RegexpReplace` and
:class:`~django.db.models.functions.RegexpSubstr` database functions.

Pagination
~~~~~~~~~~

Expand Down
160 changes: 160 additions & 0 deletions tests/db_functions/text/test_regexpreplace.py
@@ -0,0 +1,160 @@
import unittest

from django.db import connection
from django.db.models import F, Value
from django.db.models.functions import Concat, Now, RegexpReplace
from django.test import TestCase

from ..models import Article, Author

mariadb = connection.vendor == 'mysql' and connection.mysql_is_mariadb
unsupported_mysql = (
connection.vendor == 'mysql' and
not connection.mysql_is_mariadb and
connection.mysql_version < (8, 0, 4)
)


@unittest.skipIf(unsupported_mysql, "MySQL only supports REGEXP_REPLACE() in 8.0.4+.")
class RegexpReplaceTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.author1 = Author.objects.create(name='George R. R. Martin')
cls.author2 = Author.objects.create(name='J. R. R. Tolkien')

def test_null(self):
tests = [
('alias', Value(r'(R\. ){2}'), Value('')),
('name', None, Value('')),
('name', Value(r'(R\. ){2}'), None),
]
expected = '' if connection.features.interprets_empty_strings_as_nulls else None
for field, pattern, replacement in tests:
with self.subTest(field=field, pattern=pattern, replacement=replacement):
expression = RegexpReplace(field, pattern, replacement)
author = Author.objects.annotate(replaced=expression).get(pk=self.author1.pk)
self.assertEqual(author.replaced, expected)

def test_simple(self):
# The default replacement is an empty string.
expression = RegexpReplace('name', Value(r'(R\. ){2}'))
queryset = Author.objects.annotate(without_middlename=expression)
self.assertQuerysetEqual(queryset, [
('George R. R. Martin', 'George Martin'),
('J. R. R. Tolkien', 'J. Tolkien'),
], transform=lambda x: (x.name, x.without_middlename), ordered=False)

@unittest.skipIf(mariadb, 'MariaDB is case-insensitive by default.')
def test_case_sensitive(self):
expression = RegexpReplace('name', Value(r'(r\. ){2}'), Value(''))
queryset = Author.objects.annotate(same_name=expression)
self.assertQuerysetEqual(queryset, [
('George R. R. Martin', 'George R. R. Martin'),
('J. R. R. Tolkien', 'J. R. R. Tolkien'),
], transform=lambda x: (x.name, x.same_name), ordered=False)

def test_lookahead(self):
expression = RegexpReplace('name', Value(r'(R\. ){2}(?=Martin)'), Value(''))
queryset = Author.objects.annotate(altered_name=expression)
self.assertQuerysetEqual(queryset, [
('George R. R. Martin', 'George Martin'),
('J. R. R. Tolkien', 'J. R. R. Tolkien'),
], transform=lambda x: (x.name, x.altered_name), ordered=False)

def test_lookbehind(self):
expression = RegexpReplace('name', Value(r'(?<=George )(R\. ){2}'), Value(''))
queryset = Author.objects.annotate(altered_name=expression)
self.assertQuerysetEqual(queryset, [
('George R. R. Martin', 'George Martin'),
('J. R. R. Tolkien', 'J. R. R. Tolkien'),
], transform=lambda x: (x.name, x.altered_name), ordered=False)

def test_substitution(self):
expression = RegexpReplace('name', Value(r'^(.*(?:R\. ?){2}) (.*)$'), Value(r'\2, \1'))
queryset = Author.objects.annotate(flipped_name=expression)
self.assertQuerysetEqual(queryset, [
('George R. R. Martin', 'Martin, George R. R.'),
('J. R. R. Tolkien', 'Tolkien, J. R. R.'),
], transform=lambda x: (x.name, x.flipped_name), ordered=False)

def test_expression(self):
expression = RegexpReplace(Concat(Value('Author: '), 'name'), Value(r'.*: '), Value(''))
queryset = Author.objects.annotate(same_name=expression)
self.assertQuerysetEqual(queryset, [
('George R. R. Martin', 'George R. R. Martin'),
('J. R. R. Tolkien', 'J. R. R. Tolkien'),
], transform=lambda x: (x.name, x.same_name), ordered=False)

def test_update(self):
Author.objects.update(name=RegexpReplace('name', Value(r'(R\. ){2}'), Value('')))
self.assertQuerysetEqual(Author.objects.all(), [
'George Martin',
'J. Tolkien',
], transform=lambda x: x.name, ordered=False)

@unittest.skipIf(mariadb, 'MariaDB can only replace all occurrences.')
def test_first_occurrence(self):
expression = RegexpReplace('name', Value(r'R\. '), Value(''))
queryset = Author.objects.annotate(single_middlename=expression)
self.assertQuerysetEqual(queryset, [
('George R. R. Martin', 'George R. Martin'),
('J. R. R. Tolkien', 'J. R. Tolkien'),
], transform=lambda x: (x.name, x.single_middlename), ordered=False)


@unittest.skipIf(unsupported_mysql, "MySQL only supports REGEXP_REPLACE() in 8.0.4+.")
class RegexpReplaceFlagTests(TestCase):
@classmethod
def setUpTestData(cls):
Article.objects.create(
title='Chapter One',
text='First Line.\nSecond Line.\nThird Line.',
written=Now(),
)

@unittest.skipIf(mariadb, "MariaDB doesn't support passing flags to REGEXP_REPLACE().")
def test_global_flag(self):
expression = RegexpReplace('text', Value(r'Line'), Value('Word'), Value('g'))
article = Article.objects.annotate(result=expression).first()
self.assertEqual(article.result, 'First Word.\nSecond Word.\nThird Word.')

# FIXME: This is the default behaviour on PostgreSQL.
def test_dotall_flag(self):
expression = RegexpReplace('text', Value(r'\..'), Value(', '), Value('gs'))
article = Article.objects.annotate(result=expression).first()
self.assertEqual(article.result, 'First Line, Second Line, Third Line.')

def test_multiline_flag(self):
expression = RegexpReplace('text', Value(r' Line\.$'), Value(''), Value('gm'))
article = Article.objects.annotate(result=expression).first()
self.assertEqual(article.result, 'First\nSecond\nThird')

def test_extended_flag(self):
pattern = Value(r"""
. # Match the space character
Line # Match the word "Line"
\. # Match the period.
""")
expression = RegexpReplace('text', pattern, Value(''), Value('gx'))
article = Article.objects.annotate(result=expression).first()
self.assertEqual(article.result, 'First\nSecond\nThird')

def test_case_sensitive_flag(self):
expression = RegexpReplace('title', Value(r'chapter'), Value('Section'), Value('c'))
article = Article.objects.annotate(result=expression).first()
self.assertEqual(article.result, 'Chapter One')

def test_case_insensitive_flag(self):
expression = RegexpReplace('title', Value(r'chapter'), Value('Section'), Value('i'))
article = Article.objects.annotate(result=expression).first()
self.assertEqual(article.result, 'Section One')

def test_case_sensitive_flag_preferred(self):
expression = RegexpReplace('title', Value(r'chapter'), Value('Section'), Value('ic'))
article = Article.objects.annotate(result=expression).first()
self.assertEqual(article.result, 'Chapter One')

def test_case_insensitive_flag_preferred(self):
expression = RegexpReplace('title', Value(r'Chapter'), Value('Section'), Value('ci'))
article = Article.objects.annotate(result=expression).first()
self.assertEqual(article.result, 'Section One')

0 comments on commit e54691d

Please sign in to comment.