Content
- [Database inspection](#Database_inspection)
- [Vacuum](#Vacuum)
- [Polars SQL read and write](#Polars_SQL_read_and_write)
- [Metadata](#Metadata)
- [ORM](#ORM)
- [Read table in parallel batches](#Read_table_in_parallel_batches)

Sqlalchemy is useful for database inspection, such as getting primary keys, indices, data types. It is also handy for building database-agnostic queries. It doesn't provide support for all statements though, such as ALTER, which are easier to build manually with the use of f-strings.

In [271]:
import functools
from typing import Literal

import joblib
import numpy as np
import oracledb
import pandas as pd
import polars as pl
import psycopg2
import sqlalchemy.orm
import sqlalchemy as sa
import tqdm

In [317]:
def get_oracle_driver_name():
    return "oracle+oracledb"


def get_postgres_driver_name():
    return "postgresql+psycopg2"


def to_autocommit_engine(engine):
    autocommit_engine = engine.execution_options(isolation_level="AUTOCOMMIT")
    return autocommit_engine


def execute_query(engine: sa.Engine, query):
    query = sa.text(query) if isinstance(query, str) else query
    with engine.connect() as connection:
        result = connection.execute(query)

    return result

def get_result_item(result):
    result_item = result.scalars().first()
    return result_item

def compile_statement(
    statement, dialect: Literal["postgres", "oracle"] = "postgres"
) -> str:
    match dialect:
        case "postgres":
            dialect = sa.dialects.postgresql.psycopg2.dialect()
        case "oracle":
            dialect = sa.dialects.oracle.cx_oracle.dialect()
        case _:
            raise ValueError(f"Unknown dialect {dialect}.")

    statement = str(
        statement.compile(dialect=dialect, compile_kwargs={"literal_binds": True})
    )
    return statement


def reflect_table(
    engine: sa.Engine, table: str, metadata: sa.MetaData | None = None, **kwargs
) -> sa.Table:
    table = sa.Table(table, metadata or sa.MetaData(), autoload_with=engine, **kwargs)
    return table


def get_primary_key(table: sa.Table) -> list[str]:
    primary_key = table.primary_key.columns.keys()
    return primary_key


def get_table_datatypes(
    table: sa.Table, as_python_dtypes: bool = False, dbapi=None
) -> dict:
    """Dbapi can be the psycopg2 module, for example."""
    table_datatypes = {}
    for column in table.columns:
        if as_python_dtypes:
            datatype = column.type.python_type
        elif dbapi:
            datatype = column.type.get_dbapi_type(dbapi)
        else:
            datatype = column.type

        table_datatypes[column.name] = datatype

    return table_datatypes


def get_table_size(engine, table):
    with engine.connect() as connection:
        result = connection.execute(
            sa.select(sa.func.count()).select_from(sa.text(table))
        )
        size = result.scalars().first()

    return size


def is_column_unique(engine, sa_column):
    result = execute_query(engine, sa.select(sa.func.count(sa_column), sa.func.count(sa_column.distinct())))
    count, distinct_count = result.first()
    return count == distinct_count


def get_number_of_batches(size, batch_size):
    n_batches = (size - 1) // batch_size + 1
    return n_batches

In [273]:
def generate_table(size: int):
    df = pd.DataFrame(
        {"id": range(size), "x": np.random.randint(0, 10, size), "y": "string"}
    )
    return df


def select_all(table):
    return f"SELECT * FROM {table}"

In [281]:
database_url = sqlalchemy.URL.create(
    drivername=get_postgres_driver_name(),
    database="learning",
)
engine = sqlalchemy.create_engine(database_url)
inspector = sa.inspect(engine)

In [332]:
with engine.begin() as connection:
    connection.execute(sqlalchemy.text("DROP TABLE IF EXISTS point;"))
    connection.execute(sqlalchemy.text("CREATE TABLE point (x int, y int);"))
    connection.execute(
        sqlalchemy.text("INSERT INTO point (x, y) VALUES (:x, :y)"),
        [{"x": 1, "y": 1}, {"x": 2, "y": 4}],
    )

with sqlalchemy.orm.Session(engine) as session:
    result = session.execute(sqlalchemy.text("SELECT * FROM POINT"))

is_column_unique(engine, reflect_table(engine, "random").columns["id"])

True

### Database inspection

In [293]:
(
    inspector.dialect,
    inspector.get_schema_names(),
    inspector.default_schema_name,
    inspector.get_table_names(),
    inspector.get_table_oid("random"),
    inspector.get_view_definition,
    inspector.get_columns("random"),
    inspector.get_indexes("point"),
    inspector.get_pk_constraint("point"),
    inspector.has_table("random", schema="public"),
);

### Vacuum

In [398]:
autocommit_engine = to_autocommit_engine(engine)
autocommit_engine.echo = True
execute_query(autocommit_engine, "VACUUM ANALYZE point;")

2024-06-16 15:34:56,129 INFO sqlalchemy.engine.Engine BEGIN (implicit; DBAPI should not BEGIN due to autocommit mode)
2024-06-16 15:34:56,130 INFO sqlalchemy.engine.Engine VACUUM ANALYZE point;
2024-06-16 15:34:56,131 INFO sqlalchemy.engine.Engine [generated in 0.00166s] {}
2024-06-16 15:34:56,145 INFO sqlalchemy.engine.Engine ROLLBACK using DBAPI connection.rollback(), DBAPI should ignore due to autocommit mode


<sqlalchemy.engine.cursor.CursorResult at 0x36aa8c3d0>

### Polars_SQL_read_and_write 

In [55]:
df = pl.read_database(select_all("point"), connection=engine)
pl.concat([df, df]).write_database(
    "point", connection=engine.url, if_table_exists="append"
)

12

In [None]:
requested_batch_size = 2
for df in pl.read_database(
    select_all("point"),
    connection=engine,
    iter_batches=True,
    batch_size=requested_batch_size,
    infer_schema_length=0,  # Number of rows to be scanned for schema inference.
    schema_overrides={"x": pl.Int16, "y": int},
):
    actual_batch_size = len(df)
    break

print(
    f"Requested batch size is {requested_batch_size}, actual batch size is {actual_batch_size}.\n"
    f"Note that they are not guaranteed to be equal."
)

Requested batch size is 2, actual batch size is 2.
Note that they are not guaranteed to be equal.


In [234]:
df = generate_table(10**6)
pdf = pl.from_dataframe(df)

In [235]:
%%time
df.to_sql("random", engine, if_exists="replace", index=False)

CPU times: user 5.5 s, sys: 91.4 ms, total: 5.59 s
Wall time: 7.5 s


1000

In [205]:
%%time
pdf.write_database("random", engine.url, if_table_exists="replace")

CPU times: user 5.67 s, sys: 131 ms, total: 5.8 s
Wall time: 7.81 s


1000

In [230]:
%%time
for _ in pd.read_sql(select_all("random"), engine, chunksize=10**3):
    pass

CPU times: user 970 ms, sys: 31.8 ms, total: 1 s
Wall time: 1.09 s


In [229]:
%%time
for _ in pl.read_database(
    select_all("random"), engine, iter_batches=True, batch_size=10**3
):
    pass

CPU times: user 572 ms, sys: 24.6 ms, total: 597 ms
Wall time: 697 ms


### Metadata

In [81]:
metadata = sqlalchemy.MetaData()
table = sqlalchemy.Table("point", metadata, schema="public", autoload_with=engine)

"point" in metadata.tables
table.indexes, table.primary_key, table.foreign_keys;

### ORM

In [6]:
class Base(sqlalchemy.orm.DeclarativeBase):
    pass


class Point(Base):
    __tablename__ = "point"

    id: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column(primary_key=True)
    x: sqlalchemy.orm.Mapped[int]
    y: sqlalchemy.orm.Mapped[int]


Point.metadata, Point.x

(MetaData(), <sqlalchemy.orm.attributes.InstrumentedAttribute at 0x12e929080>)

In [120]:
statement = sqlalchemy.insert(table).values(x=3, y=5)
print(statement, statement.compile(engine), statement.compile().params, sep="\n")

INSERT INTO public.point (x, y) VALUES (:x, :y)
INSERT INTO public.point (x, y) VALUES (%(x)s, %(y)s)
{'x': 3, 'y': 5}


In [106]:
# Implicitly parametrizes insert statement

with engine.begin() as connection:
    result = connection.execute(
        sqlalchemy.insert(table),
        [
            {"x": 2, "y": 7},
            {"x": -1, "y": 5},
        ],
    )

In [125]:
print(sqlalchemy.select(table).where(table.c.x > 3))

SELECT public.point.x, public.point.y 
FROM public.point 
WHERE public.point.x > :x_1


In [137]:
print(sqlalchemy.select(Point).where(Point.x > 3))

SELECT point.id, point.x, point.y 
FROM point 
WHERE point.x > :x_1


In [215]:
print(sqlalchemy.select(Point.x, table.c["y"]))

SELECT point_1.x, public.point.y 
FROM point AS point_1, public.point


In [227]:
random = sqlalchemy.Table("random", metadata, autoload_with=engine)
print(
    sa.select((random.c.id + 1).label("random_id"))
    .where(random.c.y == 7)
    .order_by(random.c.x)
)

SELECT random.id + :id_1 AS random_id 
FROM random 
WHERE random.y = :y_1 ORDER BY random.x


### Read_table_in_parallel_batches

In [None]:
query = f"""
            WITH tiled_primary_keys AS (
                SELECT 
                    {primary_key}, 
                    (ntile({self.n_batches}) OVER (ORDER BY {primary_key})) AS tile
                FROM {table}
            )
            SELECT DISTINCT ON(tile) {primary_key}
            FROM tiled_primary_keys
            ORDER BY tile, {primary_key}
        """

In [275]:
class IndependentBatchReader:
    def __init__(
        self,
        database_url: str,
        table: str,
        primary_key: str,
        batch_size: int,
        database_backend: Literal["pandas", "polars"] = "pandas",
    ):
        self.database_url = database_url
        engine = sa.create_engine(database_url)
        self.table = reflect_table(engine, table)
        self.primary_key = primary_key
        self.batch_size = batch_size
        self.database_backend = database_backend
        self.size = get_table_size(engine, table)
        self.n_batches = get_number_of_batches(self.size, batch_size)

    def query_to_dataframe(self, query):
        match self.database_backend:
            case "pandas":
                df = pd.read_sql(query, con=self.database_url)
            case "polars":
                engine = sa.create_engine(self.database_url)
                table_datatypes = get_table_datatypes(self.table, as_python_dtypes=True)
                df = pl.read_database(
                    query,
                    connection=engine,
                    schema_overrides=table_datatypes,
                    infer_schema_length=0,
                )
            case _:
                raise ValueError(f"Unknown database backend '{database_backend}'.")

        return df

    def offset_read(self, batch_begin):
        query = (
            self.table.select()
            .order_by(self.primary_key)
            .offset(batch_begin)
            .limit(self.batch_size)
        )
        df = self.query_to_dataframe(query)
        return df

    def primary_key_read(self, batch_begin, batch_end=None):
        if batch_end:
            where = sa.between(sa.column(self.primary_key), batch_begin, batch_end)
        else:
            where = sa.column(self.primary_key) >= batch_begin

        query = self.table.select().where(where)
        df = self.query_to_dataframe(query)
        return df

    def generate_offset_reads(self):
        """Simple, but recomputes all preceding rows for each batch."""
        for batch_index in range(self.n_batches):
            yield functools.partial(
                self.offset_read, batch_begin=batch_index * self.batch_size
            )

    def build_primary_key_begins_query(self):
        """
        WITH tiled_primary_keys AS (
            SELECT
                {primary_key},
                (ntile({self.n_batches}) OVER (ORDER BY {primary_key})) AS tile
            FROM {table}
        )
        SELECT DISTINCT ON(tile) {primary_key}
        FROM tiled_primary_keys
        ORDER BY tile, {primary_key}
        """
        pk = sa.column(self.primary_key)
        cte = (
            sa.select(pk, sa.func.ntile(self.n_batches).over(order_by=pk).label("tile"))
            .select_from(self.table)
            .cte("tiled_primary_keys")
        )
        query = sa.select(pk).distinct("tile").select_from(cte).order_by("tile", pk)
        return query

    def generate_primary_key_reads(self):
        """Efficient, but more complex."""
        query = self.build_primary_key_begins_query()
        with sa.create_engine(self.database_url).connect() as connection:
            result = connection.execute(query)
            primary_key_begins = list(result.scalars())

        for batch_index, batch_begin in enumerate(primary_key_begins):
            if batch_index + 1 < len(primary_key_begins):
                batch_end = primary_key_begins[batch_index + 1] - 1
            else:
                batch_end = None

            yield functools.partial(
                self.primary_key_read,
                batch_begin=batch_begin,
                batch_end=batch_end,
            )

In [276]:
# Shuffle table
table = "random"
primary_key = "id"
table_df = generate_table(10**6)
table_df = table_df.sample(frac=1)
table_df.to_sql(table, engine, if_exists="replace", index=False)

### No primary key
random = sqlalchemy.Table(table, sqlalchemy.MetaData(), autoload_with=engine)
random.primary_key

PrimaryKeyConstraint()

In [277]:
batch_size = 10**4
n_jobs = -1
reader = IndependentBatchReader(
    str(engine.url), table, primary_key, batch_size, database_backend="pandas"
)
parallel = joblib.Parallel(n_jobs=n_jobs, return_as="generator_unordered")


def time_read_function(read_function):
    def f():
        pks = []
        for df in parallel(joblib.delayed(i)() for i in read_function()):
            assert len(df) == reader.batch_size
            pks.extend(list(df[primary_key]))

        assert len(set(pks)) == reader.size

    %time f()

In [278]:
time_read_function(reader.generate_offset_reads)
time_read_function(reader.generate_primary_key_reads)

CPU times: user 698 ms, sys: 165 ms, total: 863 ms
Wall time: 7.03 s
CPU times: user 482 ms, sys: 114 ms, total: 596 ms
Wall time: 2 s


In [279]:
with engine.begin() as connection:
    connection.execute(
        sqlalchemy.text(f"ALTER TABLE {table} ADD PRIMARY KEY ({primary_key})")
    )

time_read_function(reader.generate_offset_reads)
time_read_function(reader.generate_primary_key_reads)

CPU times: user 612 ms, sys: 131 ms, total: 743 ms
Wall time: 3.24 s
CPU times: user 405 ms, sys: 89.9 ms, total: 495 ms
Wall time: 1.42 s


In [280]:
execute_query(to_autocommit_engine(engine), f"VACUUM ANALYZE {table}")
time_read_function(reader.generate_offset_reads)
time_read_function(reader.generate_primary_key_reads)

CPU times: user 605 ms, sys: 114 ms, total: 719 ms
Wall time: 3.39 s
CPU times: user 397 ms, sys: 99.2 ms, total: 496 ms
Wall time: 1.02 s
