Skip to content

Commit ccfd5fe

Browse files
committed
connect: Add DSN parsing and tests
1 parent 2716aae commit ccfd5fe

File tree

2 files changed

+325
-30
lines changed

2 files changed

+325
-30
lines changed

asyncpg/__init__.py

Lines changed: 114 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import getpass
33
import os
4+
import urllib.parse
45

56
from .exceptions import *
67
from . import connection
@@ -10,45 +11,24 @@
1011
__all__ = ('connect',) + exceptions.__all__
1112

1213

13-
async def connect(iri=None, *,
14+
async def connect(dsn=None, *,
1415
host=None, port=None,
1516
user=None, password=None,
1617
database=None,
1718
loop=None,
18-
timeout=60):
19+
timeout=60,
20+
**kwargs):
1921

2022
if loop is None:
2123
loop = asyncio.get_event_loop()
2224

23-
# On env-var -> connection parameter conversion read here:
24-
# https://www.postgresql.org/docs/current/static/libpq-envars.html
25-
# Note that env values may be an empty string in cases when
26-
# the variable is "unset" by setting it to an empty value
27-
#
28-
if host is None:
29-
host = os.getenv('PGHOST')
30-
if not host:
31-
host = ['/tmp', '/private/tmp',
32-
'/var/pgsql_socket', '/run/postgresql',
33-
'localhost']
34-
if not isinstance(host, list):
35-
host = [host]
36-
37-
if port is None:
38-
port = os.getenv('PGPORT')
39-
if not port:
40-
port = 5432
41-
42-
if user is None:
43-
user = os.getenv('PGUSER')
44-
if not user:
45-
user = getpass.getuser()
46-
47-
if password is None:
48-
password = os.getenv('PGPASSWORD')
25+
host, port, user, password, database, kwargs = _parse_connect_params(
26+
dsn=dsn, host=host, port=port, user=user, password=password,
27+
database=database, kwargs=kwargs)
4928

50-
if database is None:
51-
database = os.getenv('PGDATABASE')
29+
if kwargs:
30+
raise RuntimeError(
31+
'arbitrary connection arguments are not yet supported')
5232

5333
last_ex = None
5434
for h in host:
@@ -85,6 +65,110 @@ async def connect(iri=None, *,
8565
return connection.Connection(pr, tr, loop)
8666

8767

68+
def _parse_connect_params(*, dsn, host, port, user,
69+
password, database, kwargs):
70+
71+
if dsn:
72+
parsed = urllib.parse.urlparse(dsn)
73+
74+
if parsed.scheme not in {'postgresql', 'postgres'}:
75+
raise ValueError(
76+
'invalid DSN: scheme is expected to be either of '
77+
'"postgresql" or "postgres", got {!r}'.format(parsed.scheme))
78+
79+
if parsed.port and port is None:
80+
port = int(parsed.port)
81+
82+
if parsed.hostname and host is None:
83+
host = parsed.hostname
84+
85+
if parsed.path and database is None:
86+
database = parsed.path
87+
if database.startswith('/'):
88+
database = database[1:]
89+
90+
if parsed.username and user is None:
91+
user = parsed.username
92+
93+
if parsed.password and password is None:
94+
password = parsed.password
95+
96+
if parsed.query:
97+
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
98+
for key, val in query.items():
99+
if isinstance(val, list):
100+
query[key] = val[-1]
101+
102+
if 'host' in query:
103+
val = query.pop('host')
104+
if host is None:
105+
host = val
106+
107+
if 'port' in query:
108+
val = int(query.pop('port'))
109+
if port is None:
110+
port = val
111+
112+
if 'dbname' in query:
113+
val = query.pop('dbname')
114+
if database is None:
115+
database = val
116+
117+
if 'database' in query:
118+
val = query.pop('database')
119+
if database is None:
120+
database = val
121+
122+
if 'user' in query:
123+
val = query.pop('user')
124+
if user is None:
125+
user = val
126+
127+
if 'password' in query:
128+
val = query.pop('password')
129+
if password is None:
130+
password = val
131+
132+
if query:
133+
kwargs = {**query, **kwargs}
134+
135+
# On env-var -> connection parameter conversion read here:
136+
# https://www.postgresql.org/docs/current/static/libpq-envars.html
137+
# Note that env values may be an empty string in cases when
138+
# the variable is "unset" by setting it to an empty value
139+
#
140+
if host is None:
141+
host = os.getenv('PGHOST')
142+
if not host:
143+
host = ['/tmp', '/private/tmp',
144+
'/var/pgsql_socket', '/run/postgresql',
145+
'localhost']
146+
if not isinstance(host, list):
147+
host = [host]
148+
149+
if port is None:
150+
port = os.getenv('PGPORT')
151+
if port:
152+
port = int(port)
153+
else:
154+
port = 5432
155+
else:
156+
port = int(port)
157+
158+
if user is None:
159+
user = os.getenv('PGUSER')
160+
if not user:
161+
user = getpass.getuser()
162+
163+
if password is None:
164+
password = os.getenv('PGPASSWORD')
165+
166+
if database is None:
167+
database = os.getenv('PGDATABASE')
168+
169+
return host, port, user, password, database, kwargs
170+
171+
88172
def _create_future(loop):
89173
try:
90174
create_future = loop.create_future

tests/test_connect.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import asyncpg
2+
import contextlib
3+
import os
4+
import unittest
25

36
from asyncpg import _testbase as tb
47

@@ -9,3 +12,211 @@ async def test_connect_1(self):
912
with self.assertRaisesRegex(
1013
Exception, 'role "__does_not_exist__" does not exist'):
1114
await asyncpg.connect(user="__does_not_exist__", loop=self.loop)
15+
16+
17+
class TestConnectParams(unittest.TestCase):
18+
19+
TESTS = [
20+
{
21+
'env': {
22+
'PGUSER': 'user',
23+
'PGDATABASE': 'testdb',
24+
'PGPASSWORD': 'passw',
25+
'PGHOST': 'host',
26+
'PGPORT': '123'
27+
},
28+
'result': (['host'], 123, 'user', 'passw', 'testdb', {})
29+
},
30+
31+
{
32+
'env': {
33+
'PGUSER': 'user',
34+
'PGDATABASE': 'testdb',
35+
'PGPASSWORD': 'passw',
36+
'PGHOST': 'host',
37+
'PGPORT': '123'
38+
},
39+
40+
'host': 'host2',
41+
'port': '456',
42+
'user': 'user2',
43+
'password': 'passw2',
44+
'database': 'db2',
45+
46+
'result': (['host2'], 456, 'user2', 'passw2', 'db2', {})
47+
},
48+
49+
{
50+
'env': {
51+
'PGUSER': 'user',
52+
'PGDATABASE': 'testdb',
53+
'PGPASSWORD': 'passw',
54+
'PGHOST': 'host',
55+
'PGPORT': '123'
56+
},
57+
58+
'dsn': 'postgres://user3:123123@localhost/abcdef',
59+
60+
'host': 'host2',
61+
'port': '456',
62+
'user': 'user2',
63+
'password': 'passw2',
64+
'database': 'db2',
65+
66+
'result': (['host2'], 456, 'user2', 'passw2', 'db2', {})
67+
},
68+
69+
{
70+
'env': {
71+
'PGUSER': 'user',
72+
'PGDATABASE': 'testdb',
73+
'PGPASSWORD': 'passw',
74+
'PGHOST': 'host',
75+
'PGPORT': '123'
76+
},
77+
78+
'dsn': 'postgres://user3:123123@localhost:5555/abcdef',
79+
80+
'result': (['localhost'], 5555, 'user3', '123123', 'abcdef', {})
81+
},
82+
83+
{
84+
'dsn': 'postgres://user3:123123@localhost:5555/abcdef',
85+
'result': (['localhost'], 5555, 'user3', '123123', 'abcdef', {})
86+
},
87+
88+
{
89+
'dsn': 'postgresql://user3:123123@localhost:5555/'
90+
'abcdef?param=sss&param=123&host=testhost&user=testuser'
91+
'&port=2222&database=testdb',
92+
'host': '127.0.0.1',
93+
'port': '888',
94+
'user': 'me',
95+
'password': 'ask',
96+
'database': 'db',
97+
'result': (['127.0.0.1'], 888, 'me', 'ask', 'db', {'param': '123'})
98+
},
99+
100+
{
101+
'dsn': 'postgresql:///dbname?host=/unix_sock/test&user=spam',
102+
'result': (['/unix_sock/test'], 5432, 'spam', None, 'dbname', {})
103+
},
104+
105+
{
106+
'dsn': 'pq:///dbname?host=/unix_sock/test&user=spam',
107+
'error': (ValueError, 'invalid DSN')
108+
},
109+
]
110+
111+
@contextlib.contextmanager
112+
def environ(self, **kwargs):
113+
old_vals = {}
114+
for key in kwargs:
115+
if key in os.environ:
116+
old_vals[key] = os.environ[key]
117+
118+
for key, val in kwargs.items():
119+
if val is None:
120+
if key in os.environ:
121+
del os.environ[key]
122+
else:
123+
os.environ[key] = val
124+
125+
try:
126+
yield
127+
finally:
128+
for key in kwargs:
129+
if key in os.environ:
130+
del os.environ[key]
131+
for key, val in old_vals.items():
132+
os.environ[key] = val
133+
134+
def run_testcase(self, testcase):
135+
env = testcase.get('env', {})
136+
test_env = {'PGHOST': None, 'PGPORT': None,
137+
'PGUSER': None, 'PGPASSWORD': None,
138+
'PGDATABASE': None}
139+
test_env.update(env)
140+
141+
dsn = testcase.get('dsn')
142+
kwargs = testcase.get('kwargs', {})
143+
user = testcase.get('user')
144+
port = testcase.get('port')
145+
host = testcase.get('host')
146+
password = testcase.get('password')
147+
database = testcase.get('database')
148+
149+
expected = testcase.get('result')
150+
expected_error = testcase.get('error')
151+
if expected is None and expected_error is None:
152+
raise RuntimeError(
153+
'invalid test case: either "result" or "error" key '
154+
'has to be specified')
155+
if expected is not None and expected_error is not None:
156+
raise RuntimeError(
157+
'invalid test case: either "result" or "error" key '
158+
'has to be specified, got both')
159+
160+
with contextlib.ExitStack() as es:
161+
es.enter_context(self.subTest(dsn=dsn, kwargs=kwargs, env=env))
162+
es.enter_context(self.environ(**test_env))
163+
164+
if expected_error:
165+
es.enter_context(self.assertRaisesRegex(*expected_error))
166+
167+
result = asyncpg._parse_connect_params(
168+
dsn=dsn, host=host, port=port, user=user, password=password,
169+
database=database, kwargs=kwargs)
170+
171+
if expected is not None:
172+
self.assertEqual(expected, result)
173+
174+
def test_test_connect_params_environ(self):
175+
self.assertNotIn('AAAAAAAAAA123', os.environ)
176+
self.assertNotIn('AAAAAAAAAA456', os.environ)
177+
self.assertNotIn('AAAAAAAAAA789', os.environ)
178+
179+
try:
180+
181+
os.environ['AAAAAAAAAA456'] = '123'
182+
os.environ['AAAAAAAAAA789'] = '123'
183+
184+
with self.environ(AAAAAAAAAA123='1',
185+
AAAAAAAAAA456='2',
186+
AAAAAAAAAA789=None):
187+
188+
self.assertEqual(os.environ['AAAAAAAAAA123'], '1')
189+
self.assertEqual(os.environ['AAAAAAAAAA456'], '2')
190+
self.assertNotIn('AAAAAAAAAA789', os.environ)
191+
192+
self.assertNotIn('AAAAAAAAAA123', os.environ)
193+
self.assertEqual(os.environ['AAAAAAAAAA456'], '123')
194+
self.assertEqual(os.environ['AAAAAAAAAA789'], '123')
195+
196+
finally:
197+
for key in {'AAAAAAAAAA123', 'AAAAAAAAAA456', 'AAAAAAAAAA789'}:
198+
if key in os.environ:
199+
del os.environ[key]
200+
201+
def test_test_connect_params_run_testcase(self):
202+
with self.environ(PGPORT='777'):
203+
self.run_testcase({
204+
'env': {
205+
'PGUSER': '__test__'
206+
},
207+
'host': 'abc',
208+
'result': (['abc'], 5432, '__test__', None, None, {})
209+
})
210+
211+
with self.assertRaises(AssertionError):
212+
self.run_testcase({
213+
'env': {
214+
'PGUSER': '__test__'
215+
},
216+
'host': 'abc',
217+
'result': (['abc'], 5432, 'wrong_user', None, None, {})
218+
})
219+
220+
def test_connect_params(self):
221+
for testcase in self.TESTS:
222+
self.run_testcase(testcase)

0 commit comments

Comments
 (0)