In [None]:
%load_ext dotenv
%dotenv
import os
import time
from multiprocessing.pool import ThreadPool
from typing import Literal

import pymssql
import pyodbc
import sqlalchemy

# Connection settings are read from the process environment; any variable
# that is not set yields None.
SERVER, USER, PASSWORD, DATABASE = (
    os.environ.get(var)
    for var in ("MSSQL_HOSTNAME", "MSSQL_USER", "MSSQL_PASSWORD", "MSSQL_DB")
)

# Statement executed by every benchmark worker.
QUERY = "SELECT * FROM Devices.Identities"

def make_cnxn_pymssql():
    """Open a new pymssql connection in read-only, autocommit mode.

    The server, credentials and database come from the module-level
    ``SERVER``, ``USER``, ``PASSWORD`` and ``DATABASE`` settings.
    """
    connect_kwargs = {
        "server": SERVER,
        "user": USER,
        "password": PASSWORD,
        "database": DATABASE,
        "read_only": True,
        "autocommit": True,
    }
    return pymssql.connect(**connect_kwargs)

def _odbc_value(value):
    """Format *value* as an ODBC connection-string attribute value.

    Values containing ``;``, ``{`` or ``}`` are wrapped in braces (with any
    ``}`` doubled) so they cannot terminate or corrupt the attribute. Values
    that are already brace-wrapped, and values without those characters, are
    returned unchanged.
    """
    text = str(value)
    if text.startswith("{") and text.endswith("}"):
        # Treat as already quoted by whoever supplied the setting.
        return text
    if any(ch in text for ch in ";{}"):
        return "{" + text.replace("}", "}}") + "}"
    return text


def make_cnxn_pyodbc():
    """Open a new read-only pyodbc connection via "ODBC Driver 18 for SQL Server".

    The server, credentials and database come from the module-level
    ``SERVER``, ``USER``, ``PASSWORD`` and ``DATABASE`` settings. Each value
    is passed through ``_odbc_value`` so that special characters (for example
    a ``;`` in the password) do not break the connection string.
    """
    conn_str = (
        f"SERVER={_odbc_value(SERVER)};"
        f"UID={_odbc_value(USER)};"
        f"PWD={_odbc_value(PASSWORD)};"
        f"DATABASE={_odbc_value(DATABASE)};"
        "Driver=ODBC Driver 18 for SQL Server;"
        "TrustServerCertificate=yes;"
    )
    return pyodbc.connect(conn_str, readonly=True)


def workerfn(conn_pool: sqlalchemy.QueuePool) -> float:
    """Run ``QUERY`` once on a pooled connection and return the elapsed time.

    Args:
        conn_pool: Pool from which a connection is checked out.

    Returns:
        Seconds spent creating the cursor, executing the query, fetching all
        rows, and closing the cursor and connection. The checkout itself
        (``conn_pool.connect()``) happens before the timer starts.

    Raises:
        Whatever the underlying driver raises; the cursor and connection are
        still closed in that case.
    """
    cnxn = conn_pool.connect()
    # perf_counter is monotonic, so the interval cannot be skewed by wall-clock
    # adjustments.
    start = time.perf_counter()
    try:
        cursor = cnxn.cursor()
        try:
            cursor.execute(QUERY)
            # The rows are discarded; only the cost of fetching them matters.
            cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Always hand the connection back so a failed query cannot starve
        # the pool.
        cnxn.close()
    return time.perf_counter() - start


def bench(
    num_worker: int,
    sql_driver: Literal["pymssql", "pyodbc"],
) -> float:
    """Time ``num_worker`` concurrent runs of ``workerfn`` for one driver.

    A connection pool of size ``num_worker`` is created for the chosen driver
    and ``workerfn`` is submitted ``num_worker`` times to a thread pool of the
    same size.

    Args:
        num_worker: Number of threads, pooled connections and queries.
        sql_driver: ``"pyodbc"`` selects ``make_cnxn_pyodbc``; any other value
            selects ``make_cnxn_pymssql``.

    Returns:
        Wall time in seconds from just before the thread pool is created
        until it has been torn down.
    """
    connfn = make_cnxn_pyodbc if sql_driver == "pyodbc" else make_cnxn_pymssql
    # NOTE(review): pre_ping normally relies on a dialect and none is passed
    # here -- confirm it behaves as intended on a standalone pool.
    conn_pool = sqlalchemy.QueuePool(
        connfn,
        pool_size=num_worker,
        max_overflow=0,
        reset_on_return=True,
        pre_ping=True,
    )

    try:
        start = time.perf_counter()
        # Leaving the ``with`` block shuts the thread pool down, so no
        # explicit close() is needed afterwards.
        with ThreadPool(num_worker) as pool:
            pool.map(workerfn, [conn_pool] * num_worker)
        return time.perf_counter() - start
    finally:
        # Close the pooled connections so repeated benchmark runs do not
        # accumulate open database sessions.
        conn_pool.dispose()

from tqdm import tqdm
from itertools import product

# Benchmark sweep: every worker count is measured with every driver.
num_workers = [1, 2, 4, 8]
sql_drivers = ["pymssql", "pyodbc"]

# Column-oriented results; the three lists are kept the same length.
res = {
    "num_workers": [],
    "sql_driver": [],
    "dur_s": [],
}

param_grid = list(product(num_workers, sql_drivers))
for n, sql_driver in tqdm(param_grid):
    # Measure first and record afterwards, so a failing run cannot leave the
    # columns of ``res`` with different lengths.
    dur_s = bench(n, sql_driver)
    res["num_workers"].append(n)
    res["sql_driver"].append(sql_driver)
    res["dur_s"].append(dur_s)

import plotly.express as px
import pandas as pd

# Tabulate the collected measurements (one row per benchmark run).
df = pd.DataFrame.from_dict(res)

# Duration against worker count, one trace (marker symbol) per driver.
p = px.line(
    df,
    x="num_workers",
    y="dur_s",
    symbol="sql_driver",
)
p.update_traces(marker=dict(size=12))
p.update_yaxes(matches=None)
# Bare expression so the notebook renders the figure as the cell output.
p
