Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions postgresql_proxy/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,29 @@ def service_connection(self, key: SelectorKeyProxy, mask):
LOG.debug('%s connection closing %s', conn.name, conn.address)
# A file object shall be unregistered prior to being closed.
sock.close()
return
except OSError as e:
# it means the socket was closed by peer
LOG.debug('%s connection closed by peer %s: %s', conn.name, conn.address, e)
self._unregister_conn(conn)
return

if mask & selectors.EVENT_WRITE:
# Socket has buffer space — flush this connection's backlogged output.
try:
while conn.out_bytes:
sent = sock.send(conn.out_bytes)
conn.sent(sent)
# All data drained; stop watching for writability.
conn.events = selectors.EVENT_READ
self.selector.modify(sock, selectors.EVENT_READ, data=conn)
except (BlockingIOError, ssl.SSLWantWriteError):
pass # Still full; will retry on the next EVENT_WRITE notification.
except OSError as e:
LOG.debug('%s closed while flushing backlog: %s', conn.name, e)
self._unregister_conn(conn)
sock.close()
return

next_conn = conn.redirect_conn
if next_conn and next_conn.out_bytes:
Expand All @@ -263,6 +282,15 @@ def service_connection(self, key: SelectorKeyProxy, mask):
LOG.debug('sending to %s:\n%s', next_conn.name, next_conn.out_bytes)
sent = next_conn.sock.send(next_conn.out_bytes)
next_conn.sent(sent)
# All sent; clear write interest if it was previously registered.
if next_conn.events & selectors.EVENT_WRITE:
next_conn.events = selectors.EVENT_READ
self.selector.modify(next_conn.sock, selectors.EVENT_READ, data=next_conn)
except (BlockingIOError, ssl.SSLWantWriteError):
# next_conn's send buffer is full — register for writability so we retry when there's space.
if not (next_conn.events & selectors.EVENT_WRITE):
next_conn.events = selectors.EVENT_READ | selectors.EVENT_WRITE
self.selector.modify(next_conn.sock, next_conn.events, data=next_conn)
except OSError:
# If one side is closed, close the other one
# this can happen in the case where the client disconnects, and postgres still return a response
Expand Down
64 changes: 64 additions & 0 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,70 @@ def test_repeated_connect_query_smoke_no_hang(postgres_settings, plain_proxy_por
assert cur.fetchone() == (i,)


@pytest.mark.timeout(60)
@pytest.mark.parametrize("sslmode", ["disable", "require"])
@pytest.mark.parametrize(
["sql", "expected"],
[
pytest.param(
"SELECT 1",
[(1,)],
id="tiny-1B",
),
pytest.param(
"SELECT repeat('x', 1024)",
[("x" * 1024,)],
id="small-1KB",
),
pytest.param(
"SELECT repeat('x', 102400)",
[("x" * 102400,)],
id="medium-100KB",
),
pytest.param(
"SELECT repeat('x', 1048576)",
[("x" * 1048576,)],
id="large-1MB",
),
pytest.param(
"SELECT repeat('x', 10485760)",
[("x" * 10485760,)],
id="xlarge-10MB",
),
pytest.param(
"SELECT i FROM generate_series(1, 10000) AS t(i)",
[(i,) for i in range(1, 10001)],
id="rows-10k",
),
pytest.param(
"SELECT i FROM generate_series(1, 100000) AS t(i)",
[(i,) for i in range(1, 100001)],
id="rows-100k",
),
]
)
def test_various_payload_sizes(
postgres_settings,
plain_proxy_port,
ssl_proxy_port,
sslmode,
sql,
expected,
):
with psycopg2.connect(
host="127.0.0.1",
port=plain_proxy_port if sslmode == "disable" else ssl_proxy_port,
user=postgres_settings["user"],
password=postgres_settings["password"],
dbname=postgres_settings["dbname"],
sslmode=sslmode,
connect_timeout=3,
) as conn:
with conn.cursor() as cur:
cur.execute(sql)
assert cur.fetchall() == expected


@pytest.mark.timeout(60)
def test_psql_ssl_file_batch_stress_no_hang(postgres_settings, ssl_proxy_port):
if shutil.which("psql") is None:
Expand Down
Loading