Skip to content

Commit

Permalink
Refs #32061 -- Unified DatabaseClient.runshell() in db backends.
Browse files Browse the repository at this point in the history
  • Loading branch information
charettes authored and felixxm committed Oct 29, 2020
1 parent 4ac2d4f commit bbe6fbb
Show file tree
Hide file tree
Showing 16 changed files with 273 additions and 207 deletions.
16 changes: 15 additions & 1 deletion django/db/backends/base/client.py
@@ -1,3 +1,7 @@
import os
import subprocess


class BaseDatabaseClient:
"""Encapsulate backend-specific methods for opening a client shell."""
# This should be a string representing the name of the executable
Expand All @@ -8,5 +12,15 @@ def __init__(self, connection):
# connection is an instance of BaseDatabaseWrapper.
self.connection = connection

@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
raise NotImplementedError(
'subclasses of BaseDatabaseClient must provide a '
'settings_to_cmd_args_env() method or override a runshell().'
)

def runshell(self, parameters):
raise NotImplementedError('subclasses of BaseDatabaseClient must provide a runshell() method')
args, env = self.settings_to_cmd_args_env(self.connection.settings_dict, parameters)
if env:
env = {**os.environ, **env}
subprocess.run(args, env=env, check=True)
10 changes: 2 additions & 8 deletions django/db/backends/mysql/client.py
@@ -1,13 +1,11 @@
import subprocess

from django.db.backends.base.client import BaseDatabaseClient


class DatabaseClient(BaseDatabaseClient):
executable_name = 'mysql'

@classmethod
def settings_to_cmd_args(cls, settings_dict, parameters):
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name]
db = settings_dict['OPTIONS'].get('db', settings_dict['NAME'])
user = settings_dict['OPTIONS'].get('user', settings_dict['USER'])
Expand Down Expand Up @@ -48,8 +46,4 @@ def settings_to_cmd_args(cls, settings_dict, parameters):
if db:
args += [db]
args.extend(parameters)
return args

def runshell(self, parameters):
args = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict, parameters)
subprocess.run(args, check=True)
return args, None
12 changes: 7 additions & 5 deletions django/db/backends/mysql/creation.py
@@ -1,3 +1,4 @@
import os
import subprocess
import sys

Expand Down Expand Up @@ -55,12 +56,13 @@ def _clone_test_db(self, suffix, verbosity, keepdb=False):
self._clone_db(source_database_name, target_database_name)

def _clone_db(self, source_database_name, target_database_name):
dump_args = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict, [])[1:]
dump_cmd = ['mysqldump', *dump_args[:-1], '--routines', '--events', source_database_name]
load_cmd = DatabaseClient.settings_to_cmd_args(self.connection.settings_dict, [])
cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env(self.connection.settings_dict, [])
dump_cmd = ['mysqldump', *cmd_args[1:-1], '--routines', '--events', source_database_name]
dump_env = load_env = {**os.environ, **cmd_env} if cmd_env else None
load_cmd = cmd_args
load_cmd[-1] = target_database_name

with subprocess.Popen(dump_cmd, stdout=subprocess.PIPE) as dump_proc:
with subprocess.Popen(load_cmd, stdin=dump_proc.stdout, stdout=subprocess.DEVNULL):
with subprocess.Popen(dump_cmd, stdout=subprocess.PIPE, env=dump_env) as dump_proc:
with subprocess.Popen(load_cmd, stdin=dump_proc.stdout, stdout=subprocess.DEVNULL, env=load_env):
# Allow dump_proc to receive a SIGPIPE if the load process exits.
dump_proc.stdout.close()
15 changes: 2 additions & 13 deletions django/db/backends/oracle/base.py
Expand Up @@ -56,7 +56,7 @@ def _setup_environment(environ):
from .introspection import DatabaseIntrospection # NOQA isort:skip
from .operations import DatabaseOperations # NOQA isort:skip
from .schema import DatabaseSchemaEditor # NOQA isort:skip
from .utils import Oracle_datetime # NOQA isort:skip
from .utils import dsn, Oracle_datetime # NOQA isort:skip
from .validation import DatabaseValidation # NOQA isort:skip


Expand Down Expand Up @@ -218,17 +218,6 @@ def __init__(self, *args, **kwargs):
use_returning_into = self.settings_dict["OPTIONS"].get('use_returning_into', True)
self.features.can_return_columns_from_insert = use_returning_into

def _dsn(self):
settings_dict = self.settings_dict
if not settings_dict['HOST'].strip():
settings_dict['HOST'] = 'localhost'
if settings_dict['PORT']:
return Database.makedsn(settings_dict['HOST'], int(settings_dict['PORT']), settings_dict['NAME'])
return settings_dict['NAME']

def _connect_string(self):
return '%s/"%s"@%s' % (self.settings_dict['USER'], self.settings_dict['PASSWORD'], self._dsn())

def get_connection_params(self):
conn_params = self.settings_dict['OPTIONS'].copy()
if 'use_returning_into' in conn_params:
Expand All @@ -240,7 +229,7 @@ def get_new_connection(self, conn_params):
return Database.connect(
user=self.settings_dict['USER'],
password=self.settings_dict['PASSWORD'],
dsn=self._dsn(),
dsn=dsn(self.settings_dict),
**conn_params,
)

Expand Down
21 changes: 15 additions & 6 deletions django/db/backends/oracle/client.py
@@ -1,5 +1,4 @@
import shutil
import subprocess

from django.db.backends.base.client import BaseDatabaseClient

Expand All @@ -8,11 +7,21 @@ class DatabaseClient(BaseDatabaseClient):
executable_name = 'sqlplus'
wrapper_name = 'rlwrap'

def runshell(self, parameters):
conn_string = self.connection._connect_string()
args = [self.executable_name, "-L", conn_string]
wrapper_path = shutil.which(self.wrapper_name)
@staticmethod
def connect_string(settings_dict):
from django.db.backends.oracle.utils import dsn

return '%s/"%s"@%s' % (
settings_dict['USER'],
settings_dict['PASSWORD'],
dsn(settings_dict),
)

@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name, '-L', cls.connect_string(settings_dict)]
wrapper_path = shutil.which(cls.wrapper_name)
if wrapper_path:
args = [wrapper_path, *args]
args.extend(parameters)
subprocess.run(args, check=True)
return args, None
7 changes: 7 additions & 0 deletions django/db/backends/oracle/utils.py
Expand Up @@ -82,3 +82,10 @@ class BulkInsertMapper:
'TextField': CLOB,
'TimeField': TIMESTAMP,
}


def dsn(settings_dict):
if settings_dict['PORT']:
host = settings_dict['HOST'].strip() or 'localhost'
return Database.makedsn(host, int(settings_dict['PORT']), settings_dict['NAME'])
return settings_dict['NAME']
47 changes: 23 additions & 24 deletions django/db/backends/postgresql/client.py
@@ -1,6 +1,4 @@
import os
import signal
import subprocess

from django.db.backends.base.client import BaseDatabaseClient

Expand All @@ -9,18 +7,19 @@ class DatabaseClient(BaseDatabaseClient):
executable_name = 'psql'

@classmethod
def runshell_db(cls, conn_params, parameters):
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [cls.executable_name]

host = conn_params.get('host', '')
port = conn_params.get('port', '')
dbname = conn_params.get('database', '')
user = conn_params.get('user', '')
passwd = conn_params.get('password', '')
sslmode = conn_params.get('sslmode', '')
sslrootcert = conn_params.get('sslrootcert', '')
sslcert = conn_params.get('sslcert', '')
sslkey = conn_params.get('sslkey', '')
options = settings_dict.get('OPTIONS', {})

host = settings_dict.get('HOST')
port = settings_dict.get('PORT')
dbname = settings_dict.get('NAME') or 'postgres'
user = settings_dict.get('USER')
passwd = settings_dict.get('PASSWORD')
sslmode = options.get('sslmode')
sslrootcert = options.get('sslrootcert')
sslcert = options.get('sslcert')
sslkey = options.get('sslkey')

if user:
args += ['-U', user]
Expand All @@ -31,25 +30,25 @@ def runshell_db(cls, conn_params, parameters):
args += [dbname]
args.extend(parameters)

sigint_handler = signal.getsignal(signal.SIGINT)
subprocess_env = os.environ.copy()
env = {}
if passwd:
subprocess_env['PGPASSWORD'] = str(passwd)
env['PGPASSWORD'] = str(passwd)
if sslmode:
subprocess_env['PGSSLMODE'] = str(sslmode)
env['PGSSLMODE'] = str(sslmode)
if sslrootcert:
subprocess_env['PGSSLROOTCERT'] = str(sslrootcert)
env['PGSSLROOTCERT'] = str(sslrootcert)
if sslcert:
subprocess_env['PGSSLCERT'] = str(sslcert)
env['PGSSLCERT'] = str(sslcert)
if sslkey:
subprocess_env['PGSSLKEY'] = str(sslkey)
env['PGSSLKEY'] = str(sslkey)
return args, env

def runshell(self, parameters):
sigint_handler = signal.getsignal(signal.SIGINT)
try:
# Allow SIGINT to pass to psql to abort queries.
signal.signal(signal.SIGINT, signal.SIG_IGN)
subprocess.run(args, check=True, env=subprocess_env)
super().runshell(parameters)
finally:
# Restore the original SIGINT handler.
signal.signal(signal.SIGINT, sigint_handler)

def runshell(self, parameters):
self.runshell_db(self.connection.get_connection_params(), parameters)
19 changes: 10 additions & 9 deletions django/db/backends/sqlite3/client.py
@@ -1,15 +1,16 @@
import subprocess

from django.db.backends.base.client import BaseDatabaseClient


class DatabaseClient(BaseDatabaseClient):
executable_name = 'sqlite3'

def runshell(self, parameters):
# TODO: Remove str() when dropping support for PY37.
# args parameter accepts path-like objects on Windows since Python 3.8.
args = [self.executable_name,
str(self.connection.settings_dict['NAME'])]
args.extend(parameters)
subprocess.run(args, check=True)
@classmethod
def settings_to_cmd_args_env(cls, settings_dict, parameters):
args = [
cls.executable_name,
# TODO: Remove str() when dropping support for PY37. args
# parameter accepts path-like objects on Windows since Python 3.8.
str(settings_dict['NAME']),
*parameters,
]
return args, None
6 changes: 6 additions & 0 deletions docs/releases/3.2.txt
Expand Up @@ -477,6 +477,12 @@ backends.
``DatabaseOperations.time_trunc_sql()`` now take the optional ``tzname``
argument in order to truncate in a specific timezone.

* ``DatabaseClient.runshell()`` now gets arguments and an optional dictionary
with environment variables to the underlying command-line client from
``DatabaseClient.settings_to_cmd_args_env()`` method. Third-party database
backends must implement ``DatabaseClient.settings_to_cmd_args_env()`` or
override ``DatabaseClient.runshell()``.

:mod:`django.contrib.admin`
---------------------------

Expand Down
16 changes: 16 additions & 0 deletions tests/backends/base/test_client.py
@@ -0,0 +1,16 @@
from django.db import connection
from django.db.backends.base.client import BaseDatabaseClient
from django.test import SimpleTestCase


class SimpleDatabaseClientTests(SimpleTestCase):
def setUp(self):
self.client = BaseDatabaseClient(connection=connection)

def test_settings_to_cmd_args_env(self):
msg = (
'subclasses of BaseDatabaseClient must provide a '
'settings_to_cmd_args_env() method or override a runshell().'
)
with self.assertRaisesMessage(NotImplementedError, msg):
self.client.settings_to_cmd_args_env(None, None)
1 change: 1 addition & 0 deletions tests/backends/mysql/test_creation.py
Expand Up @@ -78,6 +78,7 @@ def test_clone_test_db_options_ordering(self):
'source_db',
],
stdout=subprocess.PIPE,
env=None,
),
])
finally:
Expand Down
5 changes: 4 additions & 1 deletion tests/backends/oracle/tests.py
Expand Up @@ -86,7 +86,10 @@ def test_password_with_at_sign(self):
old_password = connection.settings_dict['PASSWORD']
connection.settings_dict['PASSWORD'] = 'p@ssword'
try:
self.assertIn('/"p@ssword"@', connection._connect_string())
self.assertIn(
'/"p@ssword"@',
connection.client.connect_string(connection.settings_dict),
)
with self.assertRaises(DatabaseError) as context:
connection.cursor()
# Database exception: "ORA-01017: invalid username/password" is
Expand Down

0 comments on commit bbe6fbb

Please sign in to comment.