-
-
Notifications
You must be signed in to change notification settings - Fork 6
/
features.py
72 lines (54 loc) · 2.31 KB
/
features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from django.db import transaction
from django.db.backends.base.features import BaseDatabaseFeatures
from django.db.backends.signals import connection_created
from django.db.utils import OperationalError
class DatabaseFeatures(BaseDatabaseFeatures):
    """Default JSON-related feature flags added on top of BaseDatabaseFeatures.

    The vendor-specific subclasses in this module override individual flags.
    """

    # Can JSONField be used with this backend?
    supports_json_field = True

    # Can primitives be stored in a JSONField?
    supports_primitives_in_json_field = True

    # Does the database provide a true JSON datatype?
    has_native_json_field = False

    # Are PostgreSQL-style JSON operators such as '->' used by the backend?
    has_json_operators = False

    # Are the __contains and __contained_by lookups available on a JSONField?
    supports_json_field_contains = True

    # Does value__d__contains={'f': 'g'} (the dict is not wrapped in a list)
    # match {'d': [{'f': 'g'}]}?
    json_key_contains_list_matching_requires_list = False
class MySQLFeatures(DatabaseFeatures):
    """JSON feature flags for connections whose vendor is "mysql"."""

    def supports_json_field(self):
        """Return whether the server version is new enough for JSONField.

        The threshold is (10, 2, 7) when the connection reports MariaDB and
        (5, 7, 8) otherwise; the connection's ``mysql_version`` must be at
        least that value.
        """
        db = self.connection
        minimum_version = (10, 2, 7) if db.mysql_is_mariadb else (5, 7, 8)
        return db.mysql_version >= minimum_version
class OracleFeatures(DatabaseFeatures):
    """JSON feature flags for connections whose vendor is "oracle"."""

    # Both flags default to True on DatabaseFeatures and are switched off here.
    supports_json_field_contains = False
    supports_primitives_in_json_field = False
class PostgresFeatures(DatabaseFeatures):
    """JSON feature flags for connections whose vendor is "postgresql"."""

    # All three flags default to False on DatabaseFeatures and are enabled here.
    json_key_contains_list_matching_requires_list = True
    has_json_operators = True
    has_native_json_field = True
class SQLiteFeatures(DatabaseFeatures):
    """JSON feature flags for connections whose vendor is "sqlite"."""

    # Defaults to True on DatabaseFeatures; switched off for SQLite.
    supports_json_field_contains = False

    def supports_json_field(self):
        """Probe the connection to see whether the SQL ``JSON()`` function works.

        Executes a trivial ``SELECT JSON(...)`` statement inside an atomic
        block on this connection.

        Returns:
            bool: False if the probe raises OperationalError, True otherwise.
        """
        try:
            with self.connection.cursor() as cursor:
                # Bind the atomic block to *this* connection's alias. A bare
                # transaction.atomic() targets the default database, which is
                # the wrong connection when several databases are configured.
                with transaction.atomic(using=self.connection.alias):
                    cursor.execute('SELECT JSON(\'{"a": "b"}\')')
        except OperationalError:
            return False
        return True
# Maps a connection's ``vendor`` string to the class holding its JSON flags.
feature_classes = dict(
    mysql=MySQLFeatures,
    oracle=OracleFeatures,
    postgresql=PostgresFeatures,
    sqlite=SQLiteFeatures,
)

# Attribute names present on DatabaseFeatures but absent from
# BaseDatabaseFeatures; these are the ones copied onto each connection.
feature_names = set(dir(DatabaseFeatures)).difference(dir(BaseDatabaseFeatures))
def extend_features(connection, **kwargs):
    """Copy the vendor-specific JSON feature flags onto ``connection.features``.

    For every name in ``feature_names`` the attribute is read from the class
    registered for ``connection.vendor`` in ``feature_classes``. Callable
    attributes are invoked with ``connection.features`` as their only argument
    and the result is stored; other attributes are stored as they are.

    Args:
        connection: Database connection exposing ``vendor`` and ``features``.
        **kwargs: Extra keyword arguments; accepted and ignored.

    Raises:
        KeyError: If ``connection.vendor`` has no entry in ``feature_classes``
            and there is at least one name to copy.
    """
    for attr_name in feature_names:
        declared = getattr(feature_classes[connection.vendor], attr_name)
        # Methods on the feature classes act as lazy probes: call them with
        # the live features object to obtain the concrete flag value.
        resolved = declared(connection.features) if callable(declared) else declared
        setattr(connection.features, attr_name, resolved)
def connect_signal_receivers():
    """Register ``extend_features`` as a receiver of ``connection_created``."""
    connection_created.connect(receiver=extend_features)