In [None]:
%load_ext dotenv
%dotenv
import os
import time
from multiprocessing.pool import Pool, ThreadPool
from typing import Literal, Type

import pymssql
import pyodbc

# Connection settings are read from the process environment (this notebook
# loads a .env file via the %dotenv magic). A variable that is not set
# yields None.
SERVER = os.environ.get("MSSQL_HOSTNAME")
USER = os.environ.get("MSSQL_USER")
PASSWORD = os.environ.get("MSSQL_PASSWORD")
DATABASE = os.environ.get("MSSQL_DB")

# Statement executed by every benchmark worker.
QUERY = "SELECT * FROM Devices.Identities"


def make_cnxn(sql_driver: Literal["pymssql", "pyodbc"]):
    """Open a new database connection with the requested driver library.

    The target server, credentials and database come from the module-level
    SERVER, USER, PASSWORD and DATABASE settings.

    Args:
        sql_driver: Which client library to connect with.

    Returns:
        The connection object returned by the selected library's
        ``connect`` function.

    Raises:
        ValueError: If ``sql_driver`` is neither "pymssql" nor "pyodbc";
            the offending value is the exception's only argument.
    """
    if sql_driver == "pymssql":
        pymssql_options = {
            "server": SERVER,
            "user": USER,
            "password": PASSWORD,
            "database": DATABASE,
            "read_only": True,
            "autocommit": True,
        }
        return pymssql.connect(**pymssql_options)

    if sql_driver != "pyodbc":
        raise ValueError(sql_driver)

    # ODBC connection string: a sequence of "key=value;" settings.
    odbc_settings = [
        f"SERVER={SERVER}",
        f"UID={USER}",
        f"PWD={PASSWORD}",
        f"DATABASE={DATABASE}",
        "Driver=ODBC Driver 18 for SQL Server",
        "TrustServerCertificate=yes",
    ]
    connection_string = "".join(f"{setting};" for setting in odbc_settings)
    return pyodbc.connect(connection_string)


def workerfn(sql_driver: Literal["pymssql", "pyodbc"]) -> float:
    """Run QUERY once on a fresh connection and time it.

    The connection is opened before the clock starts, so the measured
    interval covers cursor creation, query execution, fetching all rows and
    closing the cursor and connection, but not connection setup.

    Args:
        sql_driver: Which client library to connect with; forwarded to
            ``make_cnxn``.

    Returns:
        Elapsed time in seconds for the measured interval.

    Raises:
        ValueError: Propagated from ``make_cnxn`` for an unknown driver.
        Any exception raised by the driver while querying is propagated
        after the cursor and connection have been closed.
    """
    cnxn = make_cnxn(sql_driver)

    # perf_counter is monotonic, so the duration cannot be distorted by
    # system clock adjustments the way time.time() differences can.
    start = time.perf_counter()
    try:
        cursor = cnxn.cursor()
        try:
            cursor.execute(QUERY)
            # Rows are fetched only to include transfer cost; they are discarded.
            cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Always release the connection, even when the query fails.
        cnxn.close()

    return time.perf_counter() - start


def bench(
    pool_type: Type[Pool | ThreadPool],
    num_worker: int,
    sql_driver: Literal["pymssql", "pyodbc"],
) -> float:
    """Time ``num_worker`` concurrent runs of ``workerfn`` in a pool.

    A pool of ``num_worker`` workers is created and ``num_worker`` tasks are
    mapped onto it. The clock starts before the pool is created and stops
    after the pool's context manager has exited, so pool startup and
    teardown are part of the measurement.

    Args:
        pool_type: Pool class to instantiate (process- or thread-based).
        num_worker: Both the pool size and the number of tasks submitted.
        sql_driver: Driver name passed to every ``workerfn`` call.

    Returns:
        Total elapsed time in seconds.
    """
    # Monotonic clock: immune to system clock adjustments during the run.
    start = time.perf_counter()
    with pool_type(num_worker) as pool:
        # The per-task timings returned by workerfn are intentionally
        # ignored; only the overall duration is reported.
        pool.map(workerfn, [sql_driver] * num_worker)
    return time.perf_counter() - start

from tqdm import tqdm
from itertools import product

def pool_type_to_str(t: Type[Pool | ThreadPool]) -> str:
    if t == Pool:
        return "Processes"
    return "Threads"

# Benchmark dimensions: every combination of these is measured once.
pool_types = [Pool, ThreadPool]
num_workers = [1, 2, 4, 8, 16]
sql_drivers = ["pymssql", "pyodbc"]

# Column-oriented results: index i across the four lists describes one run.
res = {
    "parallelization": [],
    "num_workers": [],
    "sql_driver": [],
    "dur_s": [],
}

# Materialized as a list so the progress bar has a known length.
param_grid = list(product(pool_types, num_workers, sql_drivers))
for pool_type, n, sql_driver in tqdm(param_grid):
    par = pool_type_to_str(pool_type)
    # Measure before recording anything: if bench() raises, no partial row
    # is left behind and all columns of `res` keep the same length.
    dur_s = bench(pool_type, n, sql_driver)
    res["parallelization"].append(par)
    res["num_workers"].append(n)
    res["sql_driver"].append(sql_driver)
    res["dur_s"].append(dur_s)

import plotly.express as px
import pandas as pd

# One row per benchmark run, built from the column lists collected in `res`.
df = pd.DataFrame.from_dict(res)

# Duration versus worker count: one facet row and color per parallelization
# mode, one marker symbol per SQL driver.
p = px.line(
    df,
    x="num_workers",
    y="dur_s",
    color="parallelization",
    symbol="sql_driver",
    facet_row="parallelization",
)
p.update_traces(marker=dict(size=12))
# Unlink the y-axes so each facet row gets its own range.
p.update_yaxes(matches=None)
# Bare expression: displays the figure as the notebook cell output.
p