Skip to content

Commit 5b4a488

Browse files
committed
Move psql CLI orchestration off DatabaseConnection
runshell, executable_name, and the psql argv/env builder were CLI-layer concerns bolted onto DatabaseConnection for Django-historical reasons — they never used the psycopg connection, only read settings_dict. Moving them out lets DatabaseConnection drop three stdlib imports (os, signal, subprocess) and shrink toward its actual job (transactions, cursors, lifecycle). Extract postgres_cli_args() and postgres_cli_env() as public helpers in plain.postgres.database_url. The plain postgres shell command composes them inline with its SIGINT dance. PostgresBackupClient in plain-dev had ~35 lines of the same DatabaseConfig → argv/env mapping duplicated; it now calls the shared helpers.
1 parent 319f6ac commit 5b4a488

4 files changed

Lines changed: 52 additions & 98 deletions

File tree

plain-dev/plain/dev/backups/clients.py

Lines changed: 5 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import TYPE_CHECKING
77

88
from plain.exceptions import ImproperlyConfigured
9+
from plain.postgres.database_url import postgres_cli_args, postgres_cli_env
910

1011
if TYPE_CHECKING:
1112
from plain.postgres.connection import DatabaseConnection
@@ -15,52 +16,17 @@ class PostgresBackupClient:
1516
def __init__(self, connection: DatabaseConnection) -> None:
1617
self.connection = connection
1718

18-
def get_env(self) -> dict[str, str]:
19-
settings_dict = self.connection.settings_dict
20-
options = settings_dict.get("OPTIONS", {})
21-
env: dict[str, str] = {}
22-
23-
if password := settings_dict.get("PASSWORD"):
24-
env["PGPASSWORD"] = str(password)
25-
26-
# Map OPTIONS keys to their corresponding environment variables.
27-
option_env_vars = {
28-
"passfile": "PGPASSFILE",
29-
"sslmode": "PGSSLMODE",
30-
"sslrootcert": "PGSSLROOTCERT",
31-
"sslcert": "PGSSLCERT",
32-
"sslkey": "PGSSLKEY",
33-
}
34-
for option_key, env_var in option_env_vars.items():
35-
if value := options.get(option_key):
36-
env[env_var] = str(value)
37-
38-
return env
39-
40-
def _get_conn_args(self) -> list[str]:
41-
"""Build common connection CLI args from settings."""
42-
settings_dict = self.connection.settings_dict
43-
args: list[str] = []
44-
if user := settings_dict.get("USER"):
45-
args += ["-U", user]
46-
if host := settings_dict.get("HOST"):
47-
args += ["-h", host]
48-
if port := settings_dict.get("PORT"):
49-
args += ["-p", str(port)]
50-
return args
51-
5219
def _run(self, cmd: str | list[str], *, shell: bool = False) -> None:
53-
subprocess.run(
54-
cmd, env={**os.environ, **self.get_env()}, check=True, shell=shell
55-
)
20+
env = {**os.environ, **postgres_cli_env(self.connection.settings_dict)}
21+
subprocess.run(cmd, env=env, check=True, shell=shell)
5622

5723
def create_backup(self, backup_path: Path, *, pg_dump: str = "pg_dump") -> None:
5824
settings_dict = self.connection.settings_dict
5925
dbname = settings_dict.get("DATABASE")
6026
if not dbname:
6127
raise ImproperlyConfigured("POSTGRES_DATABASE is required in settings")
6228

63-
args = pg_dump.split() + self._get_conn_args()
29+
args = pg_dump.split() + postgres_cli_args(settings_dict)
6430
args += ["-Fc", dbname]
6531

6632
# Pipe through gzip for compression
@@ -75,7 +41,7 @@ def restore_backup(
7541
if not dbname:
7642
raise ImproperlyConfigured("POSTGRES_DATABASE is required in settings")
7743

78-
conn_args = self._get_conn_args()
44+
conn_args = postgres_cli_args(settings_dict)
7945

8046
# Drop and recreate the database via template1
8147
drop_create_cmds = [

plain-postgres/plain/postgres/cli/core.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import os
4+
import signal
35
import subprocess
46
import sys
57
import time
@@ -10,6 +12,7 @@
1012

1113
from plain.cli import register_cli
1214

15+
from ..database_url import postgres_cli_args, postgres_cli_env
1316
from ..db import get_connection
1417
from ..dialect import quote_name
1518
from .converge import converge
@@ -36,16 +39,20 @@ def cli() -> None:
3639
@database_management_command
3740
def shell(parameters: tuple[str, ...]) -> None:
3841
"""Open an interactive database shell"""
39-
conn = get_connection()
42+
config = get_connection().settings_dict
43+
args = ["psql", *postgres_cli_args(config), *parameters, config["DATABASE"]]
44+
env = {**os.environ, **postgres_cli_env(config)}
45+
sigint_handler = signal.getsignal(signal.SIGINT)
4046
try:
41-
conn.runshell(list(parameters))
47+
# Allow SIGINT to pass to psql to abort queries.
48+
signal.signal(signal.SIGINT, signal.SIG_IGN)
49+
subprocess.run(args, env=env, check=True)
4250
except FileNotFoundError:
43-
# Note that we're assuming the FileNotFoundError relates to the
44-
# command missing. It could be raised for some other reason, in
45-
# which case this error message would be inaccurate. Still, this
46-
# message catches the common case.
51+
# FileNotFoundError almost always means psql isn't installed or on
52+
# PATH, but could be raised for other reasons — the message covers
53+
# the common case.
4754
click.secho(
48-
f"You appear not to have the {conn.executable_name!r} program installed or on your path.",
55+
"You appear not to have the 'psql' program installed or on your path.",
4956
fg="red",
5057
err=True,
5158
)
@@ -60,6 +67,8 @@ def shell(parameters: tuple[str, ...]) -> None:
6067
err=True,
6168
)
6269
sys.exit(e.returncode)
70+
finally:
71+
signal.signal(signal.SIGINT, sigint_handler)
6372

6473

6574
@cli.command("drop-unknown-tables")

plain-postgres/plain/postgres/connection.py

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
from __future__ import annotations
22

33
import _thread
4-
import os
5-
import signal
6-
import subprocess
74
import warnings
85
from collections import deque
96
from collections.abc import Generator, Sequence
@@ -57,41 +54,6 @@ class TableInfo(NamedTuple):
5754
comment: str | None
5855

5956

60-
def _psql_settings_to_cmd_args_env(
61-
settings_dict: DatabaseConfig, parameters: list[str]
62-
) -> tuple[list[str], dict[str, str] | None]:
63-
"""Build psql command-line arguments from database settings."""
64-
args = ["psql"]
65-
options = settings_dict.get("OPTIONS", {})
66-
67-
if user := settings_dict.get("USER"):
68-
args += ["-U", user]
69-
if host := settings_dict.get("HOST"):
70-
args += ["-h", host]
71-
if port := settings_dict.get("PORT"):
72-
args += ["-p", str(port)]
73-
args.extend(parameters)
74-
args += [settings_dict["DATABASE"]]
75-
76-
env: dict[str, str] = {}
77-
if password := settings_dict.get("PASSWORD"):
78-
env["PGPASSWORD"] = str(password)
79-
80-
# Map OPTIONS keys to their corresponding environment variables.
81-
option_env_vars = {
82-
"passfile": "PGPASSFILE",
83-
"sslmode": "PGSSLMODE",
84-
"sslrootcert": "PGSSLROOTCERT",
85-
"sslcert": "PGSSLCERT",
86-
"sslkey": "PGSSLKEY",
87-
}
88-
for option_key, env_var in option_env_vars.items():
89-
if value := options.get(option_key):
90-
env[env_var] = str(value)
91-
92-
return args, (env or None)
93-
94-
9557
class DatabaseConnection:
9658
"""
9759
PostgreSQL database connection.
@@ -100,7 +62,6 @@ class DatabaseConnection:
10062
"""
10163

10264
queries_limit: int = 9000
103-
executable_name: str = "psql"
10465

10566
index_default_access_method = "btree"
10667
ignored_tables: list[str] = []
@@ -415,19 +376,6 @@ def schema_editor(self, *args: Any, **kwargs: Any) -> DatabaseSchemaEditor:
415376
"""Return a new instance of the schema editor."""
416377
return DatabaseSchemaEditor(self, *args, **kwargs)
417378

418-
def runshell(self, parameters: list[str]) -> None:
419-
"""Run an interactive psql shell."""
420-
args, env = _psql_settings_to_cmd_args_env(self.settings_dict, parameters)
421-
env = {**os.environ, **env} if env else None
422-
sigint_handler = signal.getsignal(signal.SIGINT)
423-
try:
424-
# Allow SIGINT to pass to psql to abort queries.
425-
signal.signal(signal.SIGINT, signal.SIG_IGN)
426-
subprocess.run(args, env=env, check=True)
427-
finally:
428-
# Restore the original SIGINT handler.
429-
signal.signal(signal.SIGINT, sigint_handler)
430-
431379
def on_commit(self, func: Any, robust: bool = False) -> None:
432380
if not callable(func):
433381
raise TypeError("on_commit()'s callback must be a callable.")

plain-postgres/plain/postgres/database_url.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,37 @@ def parse_database_url(url: str) -> DatabaseConfig:
7777
}
7878

7979

80+
_CLI_FLAGS: list[tuple[str, str]] = [("USER", "-U"), ("HOST", "-h"), ("PORT", "-p")]
81+
_CLI_OPTION_ENV_VARS: dict[str, str] = {
82+
"passfile": "PGPASSFILE",
83+
"sslmode": "PGSSLMODE",
84+
"sslrootcert": "PGSSLROOTCERT",
85+
"sslcert": "PGSSLCERT",
86+
"sslkey": "PGSSLKEY",
87+
}
88+
89+
90+
def postgres_cli_args(config: DatabaseConfig) -> list[str]:
91+
"""Build connection flags for libpq-based tools (psql, pg_dump, pg_restore)."""
92+
args: list[str] = []
93+
for key, flag in _CLI_FLAGS:
94+
if value := config.get(key):
95+
args += [flag, str(value)]
96+
return args
97+
98+
99+
def postgres_cli_env(config: DatabaseConfig) -> dict[str, str]:
100+
"""Build env vars for libpq-based tools (psql, pg_dump, pg_restore)."""
101+
env: dict[str, str] = {}
102+
if password := config.get("PASSWORD"):
103+
env["PGPASSWORD"] = str(password)
104+
options = config.get("OPTIONS", {})
105+
for option_key, env_var in _CLI_OPTION_ENV_VARS.items():
106+
if value := options.get(option_key):
107+
env[env_var] = str(value)
108+
return env
109+
110+
80111
def build_database_url(config: DatabaseConfig) -> str:
81112
"""Build a database URL from a configuration dictionary."""
82113
options = config.get("OPTIONS", {})

0 commit comments

Comments
 (0)