Skip to content

Commit

Permalink
Merge 3669928 into ca3b8cf
Browse files Browse the repository at this point in the history
  • Loading branch information
MiiRaGe committed Aug 19, 2017
2 parents ca3b8cf + 3669928 commit 612a46e
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 11 deletions.
5 changes: 3 additions & 2 deletions dbbackup/db/mongodb.py
@@ -1,3 +1,4 @@
from dbbackup import utils
from .base import BaseCommandDBConnector


Expand All @@ -19,7 +20,7 @@ def _create_dump(self):
if self.settings.get('USER'):
cmd += ' --username {}'.format(self.settings['USER'])
if self.settings.get('PASSWORD'):
cmd += ' --password {}'.format(self.settings['PASSWORD'])
cmd += ' --password {}'.format(utils.get_escaped_command_arg(self.settings['PASSWORD']))
for collection in self.exclude:
cmd += ' --excludeCollection {}'.format(collection)
cmd += ' --archive'
Expand All @@ -35,7 +36,7 @@ def _restore_dump(self, dump):
if self.settings.get('USER'):
cmd += ' --username {}'.format(self.settings['USER'])
if self.settings.get('PASSWORD'):
cmd += ' --password {}'.format(self.settings['PASSWORD'])
cmd += ' --password {}'.format(utils.get_escaped_command_arg(self.settings['PASSWORD']))
if self.object_check:
cmd += ' --objcheck'
if self.drop:
Expand Down
5 changes: 3 additions & 2 deletions dbbackup/db/mysql.py
@@ -1,3 +1,4 @@
from dbbackup import utils
from .base import BaseCommandDBConnector


Expand All @@ -18,7 +19,7 @@ def _create_dump(self):
if self.settings.get('USER'):
cmd += ' --user={}'.format(self.settings['USER'])
if self.settings.get('PASSWORD'):
cmd += ' --password={}'.format(self.settings['PASSWORD'])
cmd += ' --password={}'.format(utils.get_escaped_command_arg(self.settings['PASSWORD']))
for table in self.exclude:
cmd += ' --ignore-table={}.{}'.format(self.settings['NAME'], table)
cmd = '{} {} {}'.format(self.dump_prefix, cmd, self.dump_suffix)
Expand All @@ -34,7 +35,7 @@ def _restore_dump(self, dump):
if self.settings.get('USER'):
cmd += ' --user={}'.format(self.settings['USER'])
if self.settings.get('PASSWORD'):
cmd += ' --password={}'.format(self.settings['PASSWORD'])
cmd += ' --password={}'.format(utils.get_escaped_command_arg(self.settings['PASSWORD']))
cmd = '{} {} {}'.format(self.restore_prefix, cmd, self.restore_suffix)
stdout, stderr = self.run_command(cmd, stdin=dump, env=self.restore_env)
return stdout, stderr
3 changes: 2 additions & 1 deletion dbbackup/db/postgresql.py
@@ -1,3 +1,4 @@
from dbbackup import utils
from .base import BaseCommandDBConnector


Expand All @@ -15,7 +16,7 @@ class PgDumpConnector(BaseCommandDBConnector):
def run_command(self, *args, **kwargs):
if self.settings.get('PASSWORD'):
env = kwargs.get('env', {})
env['PGPASSWORD'] = self.settings['PASSWORD']
env['PGPASSWORD'] = utils.get_escaped_command_arg(self.settings['PASSWORD'])
kwargs['env'] = env
return super(PgDumpConnector, self).run_command(*args, **kwargs)

Expand Down
27 changes: 21 additions & 6 deletions dbbackup/tests/test_connectors.py
@@ -1,6 +1,7 @@
from __future__ import unicode_literals

import os
from django.db import DEFAULT_DB_ALIAS
from mock import patch, mock_open
from tempfile import SpooledTemporaryFile

Expand All @@ -25,14 +26,28 @@ def test_get_connector(self):
class BaseDBConnectorTest(TestCase):
def test_init(self):
connector = BaseDBConnector()
self.assertEqual(connector.database_name, DEFAULT_DB_ALIAS)
self.assertTrue(connector.connection)

@patch('django.db.connections')
def test_init_with_args(self, mocked_connections):
connector = BaseDBConnector(database_name='foo')
self.assertEqual(connector.database_name, 'foo')
self.assertEqual(connector.connection, mocked_connections['foo'])

def test_init_with_kwargs(self):
connector = BaseDBConnector(FoO='bar')
self.assertEqual(connector.foo, 'bar')

def test_settings(self):
connector = BaseDBConnector()
connector.settings
self.assertFalse(hasattr(connector, '_settings'))
self.assertTrue(connector.settings)
self.assertTrue(hasattr(connector, '_settings'))

def test_generate_filename(self):
connector = BaseDBConnector()
filename = connector.generate_filename()
self.assertIsNotNone(connector.generate_filename())


class BaseCommandDBConnectorTest(TestCase):
Expand Down Expand Up @@ -193,9 +208,9 @@ def test_create_dump_password(self, mock_dump_cmd):
connector.create_dump()
self.assertNotIn(' --password=', mock_dump_cmd.call_args[0][0])
# With
connector.settings['PASSWORD'] = 'foo'
connector.settings['PASSWORD'] = 'foo bar'
connector.create_dump()
self.assertIn(' --password=foo', mock_dump_cmd.call_args[0][0])
self.assertIn(' --password=\'foo bar\'', mock_dump_cmd.call_args[0][0])

def test_create_dump_exclude(self, mock_dump_cmd):
connector = MysqlDumpConnector()
Expand Down Expand Up @@ -512,9 +527,9 @@ def test_create_dump_password(self, mock_dump_cmd):
connector.create_dump()
self.assertNotIn(' --password ', mock_dump_cmd.call_args[0][0])
# With
connector.settings['PASSWORD'] = 'foo'
connector.settings['PASSWORD'] = 'foo bar'
connector.create_dump()
self.assertIn(' --password foo', mock_dump_cmd.call_args[0][0])
self.assertIn(' --password \'foo bar\'', mock_dump_cmd.call_args[0][0])

@patch('dbbackup.db.mongodb.MongoDumpConnector.run_command',
return_value=(BytesIO(), BytesIO()))
Expand Down
5 changes: 5 additions & 0 deletions dbbackup/tests/test_utils.py
Expand Up @@ -258,3 +258,8 @@ def test_template_is_callable(self, *args):
extension = 'foo'
generated_name = utils.filename_generate(extension)
self.assertTrue(generated_name.endswith('foo'))


class QuoteCommandArg(TestCase):
def test_arg_with_space(self):
assert utils.get_escaped_command_arg('foo bar') == '\'foo bar\''
9 changes: 9 additions & 0 deletions dbbackup/utils.py
Expand Up @@ -20,6 +20,11 @@
from django.http import HttpRequest
from django.utils import six, timezone

try:
from pipes import quote
except ImportError:
from shlex import quote

from . import settings

input = raw_input if six.PY2 else input # noqa
Expand Down Expand Up @@ -417,3 +422,7 @@ def filename_generate(extension, database_name='', servername=None, content_type
filename = REG_FILENAME_CLEAN.sub('-', filename)
filename = filename[1:] if filename.startswith('-') else filename
return filename


def get_escaped_command_arg(arg):
return quote(arg)

0 comments on commit 612a46e

Please sign in to comment.