Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 0 additions & 34 deletions .github/workflows/devskim.yml

This file was deleted.

9 changes: 1 addition & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ We hope you enjoy using the MSSQL-Django 3rd party backend.

## Features

- Supports Django 3.2 and 4.0
- Supports Django 2.2, 3.0, 3.1, 3.2 and 4.0
- Tested on Microsoft SQL Server 2016, 2017, 2019
- Passes most of the tests of the Django test suite
- Compatible with
Expand Down Expand Up @@ -67,13 +67,6 @@ in DATABASES control the behavior of the backend:

String. Database user password.

- TOKEN

String. Access token fetched as a user or service principal which
has access to the database. E.g. when using `azure.identity`, the
result of `DefaultAzureCredential().get_token('https://database.windows.net/.default')`
can be passed.

- AUTOCOMMIT

Boolean. Set this to `False` if you want to disable
Expand Down
2 changes: 2 additions & 0 deletions SUPPORT.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ This project uses GitHub Issues to track bugs and feature requests. Please searc
issues before filing new issues to avoid duplicates. For new issues, file your bug or
feature request as a new Issue.

For help and questions about using this project, please utilize the Django Developers form at https://groups.google.com/g/django-developers. Please search for an existing discussion on your topic before adding a new conversation. For new conversations, include "MSSQL" in a descriptive subject.

## Microsoft Support Policy

Support for this project is limited to the resources listed above.
30 changes: 4 additions & 26 deletions mssql/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import os
import re
import time
import struct

from django.core.exceptions import ImproperlyConfigured

Expand Down Expand Up @@ -54,22 +53,7 @@ def encode_connection_string(fields):
'%s=%s' % (k, encode_value(v))
for k, v in fields.items()
)
def prepare_token_for_odbc(token):
"""
Will prepare token for passing it to the odbc driver, as it expects
bytes and not a string
:param token:
:return: packed binary byte representation of token string
"""
if not isinstance(token, str):
raise TypeError("Invalid token format provided.")

tokenstr = token.encode()
exptoken = b""
for i in tokenstr:
exptoken += bytes({i})
exptoken += bytes(1)
return struct.pack("=i", len(exptoken)) + exptoken

def encode_value(v):
"""If the value contains a semicolon, or starts with a left curly brace,
Expand Down Expand Up @@ -310,7 +294,7 @@ def get_new_connection(self, conn_params):
cstr_parts['UID'] = user
if 'Authentication=ActiveDirectoryInteractive' not in options_extra_params:
cstr_parts['PWD'] = password
elif 'TOKEN' not in conn_params:
else:
if ms_drivers.match(driver) and 'Authentication=ActiveDirectoryMsi' not in options_extra_params:
cstr_parts['Trusted_Connection'] = trusted_connection
else:
Expand Down Expand Up @@ -340,17 +324,11 @@ def get_new_connection(self, conn_params):
conn = None
retry_count = 0
need_to_retry = False
args = {
'unicode_results': unicode_results,
'timeout': timeout,
}
if 'TOKEN' in conn_params:
args['attrs_before'] = {
1256: prepare_token_for_odbc(conn_params['TOKEN'])
}
while conn is None:
try:
conn = Database.connect(connstr, **args)
conn = Database.connect(connstr,
unicode_results=unicode_results,
timeout=timeout)
except Exception as e:
for error_number in self._transient_error_numbers:
if error_number in e.args[1]:
Expand Down
70 changes: 7 additions & 63 deletions mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,7 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
# For subqueres with an ORDER BY clause, SQL Server also
# requires a TOP or OFFSET clause which is not generated for
# Django 2.x. See https://github.com/microsoft/mssql-django/issues/12
# Add OFFSET for all Django versions.
# https://github.com/microsoft/mssql-django/issues/109
if not (do_offset or do_limit):
if django.VERSION < (3, 0, 0) and not (do_offset or do_limit):
result.append("OFFSET 0 ROWS")

# SQL Server requires the backend-specific emulation (2008 or earlier)
Expand Down Expand Up @@ -428,16 +426,6 @@ def get_returned_fields(self):
return self.returning_fields
return self.return_id

def can_return_columns_from_insert(self):
if django.VERSION >= (3, 0, 0):
return self.connection.features.can_return_columns_from_insert
return self.connection.features.can_return_id_from_insert

def can_return_rows_from_bulk_insert(self):
if django.VERSION >= (3, 0, 0):
return self.connection.features.can_return_rows_from_bulk_insert
return self.connection.features.can_return_ids_from_bulk_insert

def fix_auto(self, sql, opts, fields, qn):
if opts.auto_field is not None:
# db_column is None if not explicitly specified by model field
Expand All @@ -453,39 +441,15 @@ def fix_auto(self, sql, opts, fields, qn):

return sql

def bulk_insert_default_values_sql(self, table):
seed_rows_number = 8
cross_join_power = 4 # 8^4 = 4096 > maximum allowed batch size for the backend = 1000

def generate_seed_rows(n):
return " UNION ALL ".join("SELECT 1 AS x" for _ in range(n))

def cross_join(p):
return ", ".join("SEED_ROWS AS _%s" % i for i in range(p))

return """
WITH SEED_ROWS AS (%s)
MERGE INTO %s
USING (
SELECT TOP %s * FROM (SELECT 1 as x FROM %s) FAKE_ROWS
) FAKE_DATA
ON 1 = 0
WHEN NOT MATCHED THEN
INSERT DEFAULT VALUES
""" % (generate_seed_rows(seed_rows_number),
table,
len(self.query.objs),
cross_join(cross_join_power))

def as_sql(self):
# We don't need quote_name_unless_alias() here, since these are all
# going to be column names (so we can avoid the extra overhead).
qn = self.connection.ops.quote_name
opts = self.query.get_meta()
result = ['INSERT INTO %s' % qn(opts.db_table)]
fields = self.query.fields or [opts.pk]

if self.query.fields:
fields = self.query.fields
result.append('(%s)' % ', '.join(qn(f.column) for f in fields))
values_format = 'VALUES (%s)'
value_rows = [
Expand All @@ -506,31 +470,11 @@ def as_sql(self):

placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)

if self.get_returned_fields() and self.can_return_columns_from_insert():
if self.can_return_rows_from_bulk_insert():
if not(self.query.fields):
# There isn't really a single statement to bulk multiple DEFAULT VALUES insertions,
# so we have to use a workaround:
# https://dba.stackexchange.com/questions/254771/insert-multiple-rows-into-a-table-with-only-an-identity-column
result = [self.bulk_insert_default_values_sql(qn(opts.db_table))]
r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.get_returned_fields())
if r_sql:
result.append(r_sql)
sql = " ".join(result) + ";"
return [(sql, None)]
# Regular bulk insert
params = []
r_sql, self.returning_params = self.connection.ops.return_insert_columns(self.get_returned_fields())
if r_sql:
result.append(r_sql)
params += [self.returning_params]
params += param_rows
result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
else:
result.insert(0, 'SET NOCOUNT ON')
result.append((values_format + ';') % ', '.join(placeholder_rows[0]))
params = [param_rows[0]]
result.append('SELECT CAST(SCOPE_IDENTITY() AS bigint)')
if self.get_returned_fields() and self.connection.features.can_return_id_from_insert:
result.insert(0, 'SET NOCOUNT ON')
result.append((values_format + ';') % ', '.join(placeholder_rows[0]))
params = [param_rows[0]]
result.append('SELECT CAST(SCOPE_IDENTITY() AS bigint)')
sql = [(" ".join(result), tuple(chain.from_iterable(params)))]
else:
if can_bulk:
Expand Down
1 change: 0 additions & 1 deletion mssql/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_introspect_small_integer_field = True
can_return_columns_from_insert = True
can_return_id_from_insert = True
can_return_rows_from_bulk_insert = True
can_rollback_ddl = True
can_use_chunked_reads = False
for_update_after_from = True
Expand Down
58 changes: 19 additions & 39 deletions mssql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
import json

from django import VERSION
from django.core import validators

from django.db import NotSupportedError, connections, transaction
from django.db.models import BooleanField, CheckConstraint, Value
from django.db.models.expressions import Case, Exists, Expression, OrderBy, When, Window
from django.db.models.fields import BinaryField, Field
from django.db.models import BooleanField, Value
from django.db.models.functions import Cast, NthValue
from django.db.models.functions.math import ATan2, Ln, Log, Mod, Round
from django.db.models.lookups import In, Lookup
from django.db.models.query import QuerySet
from django.db.models.functions.math import ATan2, Log, Ln, Mod, Round
from django.db.models.expressions import Case, Exists, OrderBy, When, Window, Expression
from django.db.models.lookups import Lookup, In
from django.db.models import lookups, CheckConstraint
from django.db.models.fields import BinaryField, Field
from django.db.models.sql.query import Query
from django.db.models.query import QuerySet
from django.core import validators

if VERSION >= (3, 1):
from django.db.models.fields.json import (
Expand Down Expand Up @@ -65,11 +67,9 @@ def sqlserver_nth_value(self, compiler, connection, **extra_content):
def sqlserver_round(self, compiler, connection, **extra_context):
return self.as_sql(compiler, connection, template='%(function)s(%(expressions)s, 0)', **extra_context)


def sqlserver_random(self, compiler, connection, **extra_context):
return self.as_sql(compiler, connection, function='RAND', **extra_context)


def sqlserver_window(self, compiler, connection, template=None):
# MSSQL window functions require an OVER clause with ORDER BY
if self.order_by is None:
Expand Down Expand Up @@ -125,13 +125,6 @@ def sqlserver_orderby(self, compiler, connection):


def split_parameter_list_as_sql(self, compiler, connection):
if connection.vendor == 'microsoft':
return mssql_split_parameter_list_as_sql(self, compiler, connection)
else:
return in_split_parameter_list_as_sql(self, compiler, connection)


def mssql_split_parameter_list_as_sql(self, compiler, connection):
# Insert In clause parameters 1000 at a time into a temp table.
lhs, _ = self.process_lhs(compiler, connection)
_, rhs_params = self.batch_process_rhs(compiler, connection)
Expand All @@ -150,29 +143,26 @@ def mssql_split_parameter_list_as_sql(self, compiler, connection):

return in_clause, ()


def unquote_json_rhs(rhs_params):
for value in rhs_params:
value = json.loads(value)
if not isinstance(value, (list, dict)):
rhs_params = [param.replace('"', '') for param in rhs_params]
return rhs_params


def json_KeyTransformExact_process_rhs(self, compiler, connection):
rhs, rhs_params = key_transform_exact_process_rhs(self, compiler, connection)
if connection.vendor == 'microsoft':
rhs_params = unquote_json_rhs(rhs_params)
return rhs, rhs_params
if isinstance(self.rhs, KeyTransform):
return super(lookups.Exact, self).process_rhs(compiler, connection)
rhs, rhs_params = super(KeyTransformExact, self).process_rhs(compiler, connection)

return rhs, unquote_json_rhs(rhs_params)

def json_KeyTransformIn(self, compiler, connection):
lhs, _ = super(KeyTransformIn, self).process_lhs(compiler, connection)
rhs, rhs_params = super(KeyTransformIn, self).process_rhs(compiler, connection)

return (lhs + ' IN ' + rhs, unquote_json_rhs(rhs_params))


def json_HasKeyLookup(self, compiler, connection):
# Process JSON path from the left-hand side.
if isinstance(self.lhs, KeyTransform):
Expand Down Expand Up @@ -203,7 +193,6 @@ def json_HasKeyLookup(self, compiler, connection):

return sql % tuple(rhs_params), []


def BinaryField_init(self, *args, **kwargs):
# Add max_length option for BinaryField, default to max
kwargs.setdefault('editable', False)
Expand All @@ -213,7 +202,6 @@ def BinaryField_init(self, *args, **kwargs):
else:
self.max_length = 'max'


def _get_check_sql(self, model, schema_editor):
if VERSION >= (3, 1):
query = Query(model=model, alias_cols=False)
Expand All @@ -222,16 +210,13 @@ def _get_check_sql(self, model, schema_editor):
where = query.build_where(self.check)
compiler = query.get_compiler(connection=schema_editor.connection)
sql, params = where.as_sql(compiler, schema_editor.connection)
if schema_editor.connection.vendor == 'microsoft':
try:
for p in params:
str(p).encode('ascii')
except UnicodeEncodeError:
sql = sql.replace('%s', 'N%s')
try:
for p in params: str(p).encode('ascii')
except UnicodeEncodeError:
sql = sql.replace('%s', 'N%s')

return sql % tuple(schema_editor.quote_value(p) for p in params)


def bulk_update_with_default(self, objs, fields, batch_size=None, default=0):
"""
Update the given fields in each of the given objects in the database.
Expand Down Expand Up @@ -270,10 +255,10 @@ def bulk_update_with_default(self, objs, fields, batch_size=None, default=0):
attr = getattr(obj, field.attname)
if not isinstance(attr, Expression):
if attr is None:
value_none_counter += 1
value_none_counter+=1
attr = Value(attr, output_field=field)
when_statements.append(When(pk=obj.pk, then=attr))
if connections[self.db].vendor == 'microsoft' and value_none_counter == len(when_statements):
if(value_none_counter == len(when_statements)):
case_statement = Case(*when_statements, output_field=field, default=Value(default))
else:
case_statement = Case(*when_statements, output_field=field)
Expand All @@ -287,15 +272,10 @@ def bulk_update_with_default(self, objs, fields, batch_size=None, default=0):
rows_updated += self.filter(pk__in=pks).update(**update_kwargs)
return rows_updated


ATan2.as_microsoft = sqlserver_atan2
# Need copy of old In.split_parameter_list_as_sql for other backends to call
in_split_parameter_list_as_sql = In.split_parameter_list_as_sql
In.split_parameter_list_as_sql = split_parameter_list_as_sql
if VERSION >= (3, 1):
KeyTransformIn.as_microsoft = json_KeyTransformIn
# Need copy of old KeyTransformExact.process_rhs to call later
key_transform_exact_process_rhs = KeyTransformExact.process_rhs
KeyTransformExact.process_rhs = json_KeyTransformExact_process_rhs
HasKeyLookup.as_microsoft = json_HasKeyLookup
Ln.as_microsoft = sqlserver_ln
Expand Down
Loading