# Hello SQLAlchemy

This is a tutorial guide on using SQLAlchemy to interact with a PostgreSQL database.

It assumes that a `.env` file exists containing the database connection information (`DB_USER`, `DB_PASS`, `DB_PORT`, `DB_ENDPOINT` and `DB_NAME`), or that these variables are loaded into the environment by some other means.

Based on [this tutorial from SQLAlchemy docs](https://docs.sqlalchemy.org/en/20/tutorial/).

## Load variables to form database connection string 

In [1]:
import dotenv
import os

In [2]:
# Load variables from a `.env` file into the process environment.
# The call is left as a bare expression so the notebook displays its boolean result.
dotenv.load_dotenv()

True

In [3]:
# Read each connection setting from the environment in one pass.
# os.getenv yields None for any variable that is not set.
db_user, db_pass, db_port, db_name, db_endpoint = (
    os.getenv(var_name)
    for var_name in ("DB_USER", "DB_PASS", "DB_PORT", "DB_NAME", "DB_ENDPOINT")
)

## Connect to the database

In [4]:
from sqlalchemy import create_engine

In [5]:
from sqlalchemy.engine import URL

# Build the URL from its components instead of interpolating them into a
# string: URL.create() takes the raw values, so credentials that contain
# reserved characters such as "@", ":" or "/" cannot corrupt the URL.
engine = create_engine(
    URL.create(
        drivername="postgresql+psycopg2",
        username=db_user,
        password=db_pass,
        host=db_endpoint,
        # URL.create expects an integer port (or None); env values are strings
        port=int(db_port) if db_port else None,
        database=db_name,
    )
)

## Making a connection

In [6]:
from sqlalchemy import text

In [7]:
# Smoke test: run a constant SELECT over a short-lived connection and print the rows
hello_stmt = text("select 'hello world'")
with engine.connect() as conn:
    result = conn.execute(hello_stmt)
    print(result.all())

[('hello world',)]


## Creating tables

In [8]:
from typing import Optional
from datetime import datetime 
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    mapped_column,
    relationship,
)
from sqlalchemy import ForeignKey
from sqlalchemy.dialects.postgresql import JSONB

In [9]:
class Base(DeclarativeBase):
    """Declarative base shared by all ORM models in this notebook."""

    # Attributes annotated as Mapped[dict] are stored as PostgreSQL JSONB
    type_annotation_map = {dict: JSONB}

In [10]:
class Task(Base):
    """ORM model for the ``test_task`` table: a task description with an optional solution."""

    __tablename__ = "test_task"

    # Primary key
    tid: Mapped[str] = mapped_column(primary_key=True)
    description: Mapped[str]
    # Optional annotation: a task may be stored without a solution
    solution: Mapped[Optional[str]]

    # Sessions that reference this task; counterpart of Session.task
    sessions: Mapped[list["Session"]] = relationship(back_populates="task")


class Session(Base):
    """ORM model for the ``test_session`` table.

    NOTE(review): this class shares its name with ``sqlalchemy.orm.Session``,
    which this file imports under the alias ``SQLSession``.
    """

    __tablename__ = "test_session"
    
    # Primary key
    sid: Mapped[str] = mapped_column(primary_key=True)
    start_ts: Mapped[datetime]
    # Optional foreign key to test_task.tid
    tid: Mapped[Optional[str]] = mapped_column(ForeignKey("test_task.tid"))
    
    # Counterpart of Task.sessions
    task: Mapped[Optional["Task"]] = relationship(back_populates="sessions")
    # Counterpart of Message.session
    messages: Mapped[list["Message"]] = relationship(back_populates="session")
    # Scalar (non-list) annotation: a single optional Prediction; counterpart of Prediction.session
    prediction: Mapped[Optional["Prediction"]] = relationship(back_populates="session")


class Prediction(Base):
    """ORM model for the ``test_prediction`` table: a prediction tied to one session."""

    __tablename__ = "test_prediction"

    # Primary key
    pid: Mapped[str] = mapped_column(primary_key=True)
    prediction: Mapped[str]
    # Foreign key to test_session.sid
    sid: Mapped[str] = mapped_column(ForeignKey("test_session.sid"))

    # Counterpart of Session.prediction
    session: Mapped["Session"] = relationship(back_populates="prediction")

                                     
class Message(Base):
    """ORM model for the ``test_message`` table: one timestamped message in a session, sent by an agent."""

    __tablename__ = "test_message"

    # Primary key
    mid: Mapped[str] = mapped_column(primary_key=True)
    ts: Mapped[datetime]
    content: Mapped[str]
    # Foreign key to test_session.sid
    sid: Mapped[str] = mapped_column(ForeignKey("test_session.sid"))
    # Foreign key to test_agent.aid
    aid: Mapped[str] = mapped_column(ForeignKey("test_agent.aid"))

    # Counterpart of Session.messages
    session: Mapped["Session"] = relationship(back_populates="messages")
    # Counterpart of Agent.messages
    agent: Mapped["Agent"] = relationship(back_populates="messages")
    
    
class Agent(Base):
    """ORM model for the ``test_agent`` table: a participant that authors messages."""

    __tablename__ = "test_agent"

    # Primary key
    aid: Mapped[str] = mapped_column(primary_key=True)
    role: Mapped[str]
    name: Mapped[str]

    # Counterpart of Message.agent
    messages: Mapped[list["Message"]] = relationship(back_populates="agent")
    # Counterpart of ModelCheckpoint.agent. Annotated as a plain list, like the
    # other collection relationships: an agent without checkpoints has an
    # empty list here, not None.
    checkpoints: Mapped[list["ModelCheckpoint"]] = relationship(back_populates="agent")


class ModelCheckpoint(Base):
    """ORM model for the ``test_checkpoint`` table: a checkpoint record belonging to an agent."""

    __tablename__ = "test_checkpoint"

    # Primary key
    cid: Mapped[str] = mapped_column(primary_key=True)
    # Foreign key to test_agent.aid
    aid: Mapped[str] = mapped_column(ForeignKey("test_agent.aid"))
    url: Mapped[str]
    ts: Mapped[datetime]
    # Optional dict; Base.type_annotation_map maps dict to JSONB
    params: Mapped[Optional[dict]]

    # Counterpart of Agent.checkpoints
    agent: Mapped["Agent"] = relationship(back_populates="checkpoints")

In [11]:
# Emit CREATE TABLE for every model registered on Base. checkfirst=False skips
# the "does it already exist?" check, so the CREATE statements are always issued.
Base.metadata.create_all(bind=engine, checkfirst=False)

## Insert data 

In [12]:
from sqlalchemy.orm import Session as SQLSession 

In [13]:
def get_session():
    """Return sample column values for one session row, stamped with the current time."""
    started_at = datetime.now()
    return dict(sid="SID123", start_ts=started_at)

def get_task():
    """Return sample column values for one task row."""
    task_fields = dict(
        tid="TID123",
        description="problem description...",
        solution="<answer>",
    )
    return task_fields

def get_messages():
    """Return four sample messages alternating between a user and an agent.

    Each message dict carries ``mid``, ``content``, ``ts`` (the current time
    when the entry is built) and a nested ``agent`` dict of its own.
    """
    user = {"aid": "aid123", "role": "user", "name": "Kevin"}
    bot = {"aid": "aid345", "role": "agent", "name": "numina"}
    turns = (
        ("mid001", "prompt...", user),
        ("mid002", "response...", bot),
        ("mid003", "message...", user),
        ("mid004", "response...", bot),
    )
    return [
        # dict(author) gives every message an independent copy of the agent data
        {"mid": mid, "content": content, "ts": datetime.now(), "agent": dict(author)}
        for mid, content, author in turns
    ]

# Build the sample payloads once so the following cells share the same values
_session, _task, _messages = get_session(), get_task(), get_messages()

In [14]:
# Build the Session ORM object in memory; nothing is written to the database here
session = Session(sid=_session["sid"], start_ts=_session["start_ts"])

In [15]:
# Build the Task ORM object in memory; nothing is written to the database here
task = Task(
    tid=_task["tid"],
    description=_task["description"],
    solution=_task["solution"],
)

In [16]:
# Link the session to its task: set the relationship and the matching foreign key
session.task = task
session.tid = task.tid

In [17]:
# Build the Message and Agent ORM objects in memory; nothing is written to the
# database in this cell. Agents are de-duplicated by aid so that every message
# from the same agent points at a single Agent object.
agents = {}
for _msg in _messages:
    _agent = _msg["agent"]
    aid = _agent["aid"]
    if aid not in agents:
        agents[aid] = Agent(**_agent)
    message = Message(
        mid=_msg["mid"],
        content=_msg["content"],
        ts=_msg["ts"],
        sid=_session["sid"],
        aid=aid,
    )
    # Attaching via the relationships also links the message into
    # agent.messages and session.messages through back_populates
    message.agent = agents[aid]
    message.session = session

In [18]:
# Persist the session object: the begin() block commits on successful exit,
# and the outer context manager closes the SQLSession afterwards
with SQLSession(engine) as sql_session, sql_session.begin():
    sql_session.add(session)

## Example Query

In [19]:
from sqlalchemy import select

In [20]:
# using SQLAlchemy ORM, which returns objects
# Message.mid is a tie-breaker: messages created back-to-back can share the
# same timestamp, and ordering by ts alone would leave their order undefined.
with SQLSession(engine) as sql_session:
    stmt = select(Message).order_by(Message.ts, Message.mid)
    print(sql_session.execute(stmt).all())

[(<__main__.Message object at 0x7733489e3500>,), (<__main__.Message object at 0x7733489e3530>,), (<__main__.Message object at 0x7733489e3560>,), (<__main__.Message object at 0x7733489e3590>,)]


In [21]:
# using SQLAlchemy core, which returns raw results
# Message.mid is a tie-breaker: messages created back-to-back can share the
# same timestamp, and ordering by ts alone would leave their order undefined.
with engine.connect() as conn:
    stmt = select(Message).order_by(Message.ts, Message.mid)
    for row in conn.execute(stmt):
        # Row._mapping exposes the row as a column-name -> value mapping
        print(dict(row._mapping))

{'mid': 'mid001', 'ts': datetime.datetime(2024, 11, 25, 22, 44, 45, 948198), 'content': 'prompt...', 'sid': 'SID123', 'aid': 'aid123'}
{'mid': 'mid002', 'ts': datetime.datetime(2024, 11, 25, 22, 44, 45, 948199), 'content': 'response...', 'sid': 'SID123', 'aid': 'aid345'}
{'mid': 'mid003', 'ts': datetime.datetime(2024, 11, 25, 22, 44, 45, 948200), 'content': 'message...', 'sid': 'SID123', 'aid': 'aid123'}
{'mid': 'mid004', 'ts': datetime.datetime(2024, 11, 25, 22, 44, 45, 948200), 'content': 'response...', 'sid': 'SID123', 'aid': 'aid345'}


## Clean up test tables 

In [22]:
# Drop every table registered on Base so the tutorial leaves no test tables behind
Base.metadata.drop_all(bind=engine)