Skip to content

Commit

Permalink
Add support for MSSQL when creating/dropping db
Browse files Browse the repository at this point in the history
  • Loading branch information
jomasti committed Nov 18, 2018
1 parent 9c4c9df commit fa7045c
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 2 deletions.
26 changes: 26 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,38 @@ def sqlite_file_dsn():
return 'sqlite:///{0}.db'.format(db_name)


@pytest.fixture
def mssql_db_user():
return os.environ.get('SQLALCHEMY_UTILS_TEST_MSSQL_USER', 'sa')


@pytest.fixture
def mssql_db_password():
return os.environ.get('SQLALCHEMY_UTILS_TEST_MSSQL_PASSWORD',
'Strong!Passw0rd')


@pytest.fixture
def mssql_db_driver():
driver = os.environ.get('SQLALCHEMY_UTILS_TEST_MSSQL_DRIVER',
'ODBC Driver 17 for SQL Server')
return driver.replace(' ', '+')


@pytest.fixture
def mssql_dsn(mssql_db_user, mssql_db_password, mssql_db_driver, db_name):
return 'mssql+pyodbc://{0}:{1}@localhost/{2}?driver={3}'\
.format(mssql_db_user, mssql_db_password, db_name, mssql_db_driver)


@pytest.fixture
def dsn(request):
if 'postgresql_dsn' in request.fixturenames:
return request.getfuncargvalue('postgresql_dsn')
elif 'mysql_dsn' in request.fixturenames:
return request.getfuncargvalue('mysql_dsn')
elif 'mssql_dsn' in request.fixturenames:
return request.getfuncargvalue('mssql_dsn')
elif 'sqlite_file_dsn' in request.fixturenames:
return request.getfuncargvalue('sqlite_file_dsn')
elif 'sqlite_memory_dsn' in request.fixturenames:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def get_version():
'pymysql',
'flake8>=2.4.0',
'isort>=4.2.2',
'pyodbc',
],
'anyjson': ['anyjson>=0.3.3'],
'babel': ['Babel>=1.3'],
Expand Down
14 changes: 12 additions & 2 deletions sqlalchemy_utils/functions/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,10 +530,15 @@ def create_database(url, encoding='utf8', template=None):

if url.drivername.startswith('postgres'):
url.database = 'postgres'
elif url.drivername.startswith('mssql'):
url.database = 'master'
elif not url.drivername.startswith('sqlite'):
url.database = None

engine = sa.create_engine(url)
if url.drivername == 'mssql+pyodbc':
engine = sa.create_engine(url, connect_args={'autocommit': True})
else:
engine = sa.create_engine(url)
result_proxy = None

if engine.dialect.name == 'postgresql':
Expand Down Expand Up @@ -592,10 +597,15 @@ def drop_database(url):

if url.drivername.startswith('postgres'):
url.database = 'postgres'
elif url.drivername.startswith('mssql'):
url.database = 'master'
elif not url.drivername.startswith('sqlite'):
url.database = None

engine = sa.create_engine(url)
if url.drivername == 'mssql+pyodbc':
engine = sa.create_engine(url, connect_args={'autocommit': True})
else:
engine = sa.create_engine(url)
conn_resource = None

if engine.dialect.name == 'sqlite' and database != ':memory:':
Expand Down
9 changes: 9 additions & 0 deletions tests/functions/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,12 @@ def test_create_database_twice(self, postgresql_db_user):
for dsn_item in dsn_list:
drop_database(dsn_item)
assert not database_exists(dsn_item)


@pytest.mark.usefixtures('mssql_dsn')
class TestDatabaseMssql(DatabaseTest):

@pytest.fixture
def db_name(self):
pytest.importorskip('pyodbc')
return 'db_test_sqlalchemy_util'

0 comments on commit fa7045c

Please sign in to comment.