## Parameterized pgbench Test for Databricks Jobs
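This notebook runs a configurable pgbench load test against a Databricks Lakebase (PostgreSQL) instance. Test parameters are read from notebook widgets (Databricks Job parameters), falling back to built-in defaults, and the results are written to `pgbench_results.json`.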


In [None]:
## Setup Requirements
%pip install --upgrade databricks-sdk


In [None]:
# Restart the Python process so the databricks-sdk installed by the %pip cell above is the one that gets imported.
dbutils.library.restartPython()


## Install pgbench


In [None]:
%sh
# Install the PostgreSQL 15 client tools from the PGDG apt repository.
# `set -e` aborts on the first failing command; otherwise the script's exit
# status reflects only the last command and earlier failures go unnoticed.
set -e
apt-get update && apt-get install -y wget gnupg lsb-release ca-certificates
# Keep the PGDG signing key in its own file and scope it to this repository via
# signed-by, instead of the deprecated global `apt-key add`.
install -d /usr/share/postgresql-common/pgdg
wget --quiet -O /usr/share/postgresql-common/pgdg/apt.postgresql.org.asc \
    https://www.postgresql.org/media/keys/ACCC4CF8.asc
echo "deb [signed-by=/usr/share/postgresql-common/pgdg/apt.postgresql.org.asc] https://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" \
    > /etc/apt/sources.list.d/pgdg.list
apt-get update
apt-get install -y postgresql-client-15


In [None]:
%sh
# Install the PostgreSQL 15 contrib package non-interactively.
# NOTE(review): confirm pgbench actually needs this on this image; postgresql-client-15 may already ship it.
apt-get -y install postgresql-contrib-15


In [None]:
%sh
# Sanity check: pgbench is on PATH; print its version (-V is the short form of --version).
pgbench -V


## Parse Job Parameters


In [None]:
import os, subprocess, re, glob, numpy as np, json
from databricks.sdk import WorkspaceClient
import uuid, shutil as _shutil

# Parse job parameters with defaults
def get_param(param_name, default_value, param_type=str):
    """Read a job/widget parameter, falling back to a default.

    The value is looked up with ``dbutils.widgets.get``. If the lookup fails
    (widget not defined, or ``dbutils`` unavailable) or the value is empty,
    ``default_value`` is returned unchanged.

    Args:
        param_name: Widget / job parameter name.
        default_value: Value returned when the parameter is missing or empty.
        param_type: ``str``, ``int``, ``float`` or ``bool``. For ``bool``, the
            (whitespace-stripped, case-insensitive) values ``"true"``, ``"1"``
            and ``"yes"`` are True; any other non-empty value is False.

    Returns:
        The converted parameter value, or ``default_value``.

    Raises:
        ValueError: If a supplied value cannot be converted to ``int``/``float``.
            Failing loudly here prevents a benchmark from silently running
            with default settings after a typo in a job parameter.
    """
    try:
        value = dbutils.widgets.get(param_name)
    except Exception:
        # Widget missing (e.g. interactive run without job parameters): best-effort default.
        return default_value

    if not value:
        return default_value

    if param_type is bool:
        return value.strip().lower() in ("true", "1", "yes")
    if param_type in (int, float):
        try:
            return param_type(value)
        except ValueError as exc:
            raise ValueError(
                f"Parameter {param_name!r}={value!r} is not a valid {param_type.__name__}"
            ) from exc
    return value

# Job parameters
LAKEBASE_INSTANCE_NAME = get_param("lakebase_instance_name", "ak-lakebase-accelerator-instance")
DATABASE_NAME = get_param("database_name", "databricks_postgres")
PGBENCH_CLIENTS = get_param("pgbench_clients", 8, int)
PGBENCH_JOBS = get_param("pgbench_jobs", 8, int)
PGBENCH_DURATION = get_param("pgbench_duration", 30, int)
PGBENCH_PROGRESS_INTERVAL = get_param("pgbench_progress_interval", 5, int)
PGBENCH_PROTOCOL = get_param("pgbench_protocol", "prepared")
PGBENCH_PER_STATEMENT_LATENCY = get_param("pgbench_per_statement_latency", True, bool)
PGBENCH_DETAILED_LOGGING = get_param("pgbench_detailed_logging", True, bool)
PGBENCH_CONNECT_PER_TRANSACTION = get_param("pgbench_connect_per_transaction", False, bool)

# Query configuration (JSON string)
QUERY_CONFIG_JSON = get_param("query_config", '[]')

# Validate up front so a bad job parameter fails here, before any connection is
# made, rather than later as a pgbench usage error.
VALID_PGBENCH_PROTOCOLS = ("simple", "extended", "prepared")  # accepted by pgbench -M
if PGBENCH_PROTOCOL not in VALID_PGBENCH_PROTOCOLS:
    raise ValueError(
        f"pgbench_protocol must be one of {VALID_PGBENCH_PROTOCOLS}, got {PGBENCH_PROTOCOL!r}"
    )
for _param_name, _param_value in {
    "pgbench_clients": PGBENCH_CLIENTS,      # -c
    "pgbench_jobs": PGBENCH_JOBS,            # -j
    "pgbench_duration": PGBENCH_DURATION,    # -T
    "pgbench_progress_interval": PGBENCH_PROGRESS_INTERVAL,  # -P
}.items():
    if _param_value < 1:
        raise ValueError(f"{_param_name} must be a positive integer, got {_param_value}")

# Echo every effective setting so the job log fully describes the run.
print("Parameters:")
print(f"  Lakebase Instance: {LAKEBASE_INSTANCE_NAME}")
print(f"  Database: {DATABASE_NAME}")
print(f"  Clients: {PGBENCH_CLIENTS}")
print(f"  Jobs: {PGBENCH_JOBS}")
print(f"  Duration: {PGBENCH_DURATION}s")
print(f"  Progress Interval: {PGBENCH_PROGRESS_INTERVAL}s")
print(f"  Protocol: {PGBENCH_PROTOCOL}")
print(f"  Per-Statement Latency (-r): {PGBENCH_PER_STATEMENT_LATENCY}")
print(f"  Detailed Logging (-l): {PGBENCH_DETAILED_LOGGING}")
print(f"  Connect Per Transaction (-C): {PGBENCH_CONNECT_PER_TRANSACTION}")
print(f"  Query Config: {QUERY_CONFIG_JSON}")


## Setup Connection and Queries


In [None]:
# -------------------------
# 1) Connection env
# -------------------------
# Build libpq environment variables (PGHOST, PGUSER, PGPASSWORD, ...) that the
# pgbench subprocess will use to reach the Lakebase instance.
w = WorkspaceClient()
instance = w.database.get_database_instance(name=LAKEBASE_INSTANCE_NAME)
# NOTE(review): the generated credential token is presumably short-lived; very
# long runs (especially with -C, which reconnects per transaction) may outlive it.
cred = w.database.generate_database_credential(request_id=str(uuid.uuid4()),
                                               instance_names=[LAKEBASE_INSTANCE_NAME])

# Environment values must be strings; a missing endpoint would otherwise only
# surface later as an opaque TypeError inside subprocess.run.
if not instance.read_write_dns:
    raise RuntimeError(
        f"Lakebase instance {LAKEBASE_INSTANCE_NAME!r} has no read/write DNS endpoint; "
        "check that the instance exists and is available."
    )

env = os.environ.copy()
env.update({
    "PGHOST": instance.read_write_dns,
    "PGPORT": "5432",
    "PGDATABASE": DATABASE_NAME,
    "PGUSER": w.current_user.me().user_name,
    "PGPASSWORD": cred.token,
    "PGSSLMODE": "require",
})

print("pgbench at:", _shutil.which("pgbench"))

# -------------------------
# 2) Parse query configuration and write scripts locally
# -------------------------
# Each query becomes one pgbench script file in workdir; query_files collects
# (script_path, integer_weight) pairs for the command-building step.
workdir = "/databricks/driver/pgbench_mix"
os.makedirs(workdir, exist_ok=True)

try:
    query_configs = json.loads(QUERY_CONFIG_JSON)
except json.JSONDecodeError as exc:
    print(f"Invalid query config JSON ({exc}), using default queries")
    query_configs = []

# Default queries if none provided. Tables are only schema-qualified: the
# database is chosen by PGDATABASE, and PostgreSQL rejects a database-qualified
# name that refers to a database other than the one connected to.
if not query_configs:
    query_configs = [
        {
            "name": "point",
            "content": "\\set c_customer_sk random(0, 999)\nSELECT *\nFROM public.customer\nWHERE c_customer_sk = :c_customer_sk;",
            "weight": 60
        },
        {
            "name": "range",
            "content": "\\set c_current_hdemo_sk random(1, 700)\nSELECT count(*)\nFROM public.customer\nWHERE c_current_hdemo_sk BETWEEN :c_current_hdemo_sk AND :c_current_hdemo_sk + 1000;",
            "weight": 30
        },
        {
            "name": "agg",
            "content": "SELECT c_preferred_cust_flag, count(*)\nFROM public.customer\nGROUP BY c_preferred_cust_flag;",
            "weight": 10
        }
    ]
elif isinstance(query_configs, dict):
    # Accept a single query object as shorthand for a one-element list.
    query_configs = [query_configs]
elif not isinstance(query_configs, list):
    raise ValueError(
        f"query_config must be a JSON list of query objects, got {type(query_configs).__name__}"
    )

# Write query files
query_files = []
used_names = set()
for index, query_config in enumerate(query_configs):
    if not isinstance(query_config, dict):
        raise ValueError(f"query_config[{index}] must be a JSON object, got {query_config!r}")

    # Restrict the name to filename-safe characters so it cannot point outside
    # workdir (e.g. "../x"), and de-duplicate it so queries sharing a name (or
    # all defaulting to "query") do not overwrite each other's script file.
    base_name = re.sub(r"[^A-Za-z0-9_.-]", "_", str(query_config.get("name") or "query"))
    query_name, suffix = base_name, 1
    while query_name in used_names:
        suffix += 1
        query_name = f"{base_name}_{suffix}"
    used_names.add(query_name)

    query_content = str(query_config.get("content") or "").strip()
    if not query_content:
        raise ValueError(f"Query {query_name!r} has no 'content'")

    try:
        weight = int(query_config.get("weight", 1))
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"Query {query_name!r} has a non-integer weight: {query_config.get('weight')!r}"
        ) from exc
    if weight < 0:
        raise ValueError(f"Query {query_name!r} has a negative weight: {weight}")

    query_path = os.path.join(workdir, f"{query_name}.sql")
    with open(query_path, "w", encoding="utf-8") as f:
        f.write(query_content + "\n")

    query_files.append((query_path, weight))
    print(f"Created query file: {query_path}")

# Verify all files exist
for query_path, _ in query_files:
    assert os.path.exists(query_path), f"Missing script: {query_path}"


## Execute pgbench Test


In [None]:
# -------------------------
# 3) Build pgbench command
# -------------------------
cmd = [
    "pgbench",
    "-n",  # no vacuuming
    "-c", str(PGBENCH_CLIENTS),
    "-j", str(PGBENCH_JOBS),
    "-T", str(PGBENCH_DURATION),
    "-P", str(PGBENCH_PROGRESS_INTERVAL),
    "-M", PGBENCH_PROTOCOL,
]

# Add optional flags
if PGBENCH_PER_STATEMENT_LATENCY:
    cmd.append("-r")  # report per-statement latencies at the end
if PGBENCH_DETAILED_LOGGING:
    cmd.append("-l")  # write per-transaction log files (pgbench_log.*) to the cwd
if PGBENCH_CONNECT_PER_TRANSACTION:
    cmd.append("-C")  # open a new connection for every transaction

# Add each script once using pgbench's native weight syntax (-f file@weight).
# Repeating -f once per weight unit would exceed pgbench's cap on the number of
# scripts (128 in current pgbench sources) once total weight passes that, and
# would list every copy separately in the -r report. Scripts with a
# non-positive weight are left out, so they are never run.
script_args = []
for query_path, weight in query_files:
    script_weight = int(weight)
    if script_weight > 0:
        script_args.extend(["-f", f"{query_path}@{script_weight}"])

if not script_args:
    # With no -f at all, pgbench would silently run its built-in TPC-B-like script instead.
    raise ValueError("No query has a positive weight; nothing to benchmark.")
cmd.extend(script_args)

print(f"pgbench command: {' '.join(cmd)}")

# -------------------------
# 4) Run & parse output
# -------------------------
# pgbench -l writes pgbench_log.<pid>[.<thread>] into its working directory.
# workdir is reused across runs, so delete logs from earlier runs first;
# otherwise they would be mixed into this run's latency percentiles.
for stale_log in glob.glob(os.path.join(workdir, "pgbench_log.*")):
    os.remove(stale_log)

print("\n=== Starting pgbench test ===")
res = subprocess.run(cmd, capture_output=True, text=True, env=env, cwd=workdir)

print("=== STDOUT ===\n", res.stdout)
if res.stderr:
    print("=== STDERR ===\n", res.stderr)

if res.returncode != 0:
    raise SystemExit(f"pgbench failed (exit {res.returncode}). See STDERR above. Workdir: {workdir}")

# Parse TPS from the summary line ("tps = ...").
m = re.search(r"tps\s*=\s*([\d\.]+)", res.stdout)
tps = float(m.group(1)) if m else None
print(f"\n=== RESULTS ===")
print(f"TPS: {tps}")

# Parse latencies from the per-transaction logs. Each line has the form
#   client_id transaction_no time script_no time_epoch time_us [schedule_lag] [retries]
# where `time` (the 3rd field) is the transaction latency in MICROSECONDS.
# The last field is a timestamp fraction or an optional extra, not the latency.
# Skipped/failed transactions log a non-numeric `time` and are ignored.
latencies = []
for path in sorted(glob.glob(os.path.join(workdir, "pgbench_log.*"))):
    with open(path) as f:
        for line in f:
            parts = line.split()
            if len(parts) < 3:
                continue
            try:
                latencies.append(float(parts[2]) / 1000.0)  # microseconds -> ms
            except ValueError:
                pass

if latencies:
    p50, p95, p99 = np.percentile(latencies, [50, 95, 99])
    print(f"Latency p50/p95/p99 (ms): {p50:.3f} / {p95:.3f} / {p99:.3f}")
    print(f"Total transactions: {len(latencies)}")
else:
    print("No pgbench_log.* found or no latencies parsed.")

print(f"\nLogs & scripts available at: {workdir}")

# -------------------------
# 5) Store results for job output
# -------------------------
# Take the transaction count from pgbench's own summary so it is reported even
# when per-transaction logging (-l) is off; fall back to the parsed log count.
processed_match = re.search(r"number of transactions actually processed:\s*(\d+)", res.stdout)
total_transactions = int(processed_match.group(1)) if processed_match else len(latencies)
have_latencies = bool(latencies)

results = {
    # Every setting that shapes the measurement, so a stored result can be reproduced.
    "test_parameters": {
        "lakebase_instance": LAKEBASE_INSTANCE_NAME,
        "database_name": DATABASE_NAME,
        "clients": PGBENCH_CLIENTS,
        "jobs": PGBENCH_JOBS,
        "duration_seconds": PGBENCH_DURATION,
        "protocol": PGBENCH_PROTOCOL,
        "query_count": len(query_configs),
        "progress_interval_seconds": PGBENCH_PROGRESS_INTERVAL,
        "per_statement_latency": PGBENCH_PER_STATEMENT_LATENCY,
        "detailed_logging": PGBENCH_DETAILED_LOGGING,
        "connect_per_transaction": PGBENCH_CONNECT_PER_TRANSACTION,
        "pgbench_command": cmd,
    },
    "performance_metrics": {
        "tps": tps,
        "total_transactions": total_transactions,
        # Cast numpy scalars to built-in floats for plain JSON output.
        "latency_p50_ms": float(p50) if have_latencies else None,
        "latency_p95_ms": float(p95) if have_latencies else None,
        "latency_p99_ms": float(p99) if have_latencies else None
    },
    "test_status": "completed" if res.returncode == 0 else "failed",
    "raw_output": res.stdout
}

# Save results to a file that can be accessed by the job
results_path = os.path.join(workdir, "pgbench_results.json")
with open(results_path, "w") as f:
    json.dump(results, f, indent=2)

print(f"\n=== Test completed successfully! ===")
print(f"Results saved to: {results_path}")
print(json.dumps(results, indent=2))
