# Test pgvector

## Setup

In [1]:
import json
import os
from typing import Any
from urllib.parse import quote_plus
from uuid import uuid4

from pgvector.sqlalchemy import Vector
from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import SQLModel, Field, create_engine, Session, select, JSON
from sqlmodel import text

# Dimensionality of the pgvector column declared on the model below.
VECTOR_SIZE = 3

# Connection settings. Each value can be overridden with the standard libpq
# environment variable so real credentials never need to live in the notebook;
# the fallbacks are the throwaway local-development values.
USERNAME = os.environ.get("PGUSER", "myuser")
PASSWORD = os.environ.get("PGPASSWORD", "mypassword")
HOST = os.environ.get("PGHOST", "localhost")
PORT = int(os.environ.get("PGPORT", "5432"))
DATABASE_NAME = os.environ.get("PGDATABASE", "mydb")

# Percent-encode the credentials: characters such as "@", ":" or "/" in a
# user name or password would otherwise corrupt the URL.
DATABASE_URL = (
    f"postgresql://{quote_plus(USERNAME)}:{quote_plus(PASSWORD)}"
    f"@{HOST}:{PORT}/{DATABASE_NAME}"
)

## Create Engine and Extension

In [2]:
# Engine shared by every cell below; echo=False keeps the emitted SQL out of the cell output.
engine = create_engine(DATABASE_URL, echo=False)

In [3]:
# Enable the pgvector extension; IF NOT EXISTS makes this a no-op on re-runs.
create_vector_ext = text('CREATE EXTENSION IF NOT EXISTS vector;')

# engine.begin() commits when the block exits without an error.
with engine.begin() as conn:
    conn.execute(create_vector_ext)

In [35]:
# Enable pg_trgm, which provides the similarity() function used by the
# fuzzy-match query at the end of the notebook.
query = "CREATE EXTENSION IF NOT EXISTS pg_trgm;"

# engine.begin() commits when the block exits without an error.
with engine.begin() as conn:
    conn.execute(text(query))

In [6]:
# Add a stored, generated tsvector column derived from `context`, plus a GIN
# index over it, for full-text search.
#
# Run this only after the table exists, i.e. after the
# `SQLModel.metadata.create_all(engine)` cell.
# IF NOT EXISTS on both statements makes the cell safe to re-run; without it a
# second execution fails because the column and index are already there.
# Stored generated columns require PostgreSQL 12 or newer.
query = """\
ALTER TABLE longtermmemory
ADD COLUMN IF NOT EXISTS context_tsv tsvector
    GENERATED ALWAYS AS (to_tsvector('english', context)) STORED;

CREATE INDEX IF NOT EXISTS idx_ltm_context_tsv ON longtermmemory USING GIN (context_tsv);
""".strip()

with engine.connect() as conn:
    conn.execute(text(query))
    conn.commit()

In [28]:
# GIN expression index over to_tsvector('english', context), for queries that
# compute the tsvector inline instead of reading the context_tsv column.
#
# It needs its own name: index names are unique within a schema and
# `idx_ltm_context_tsv` is already used by the index on the context_tsv
# column, so reusing it makes this statement fail with "already exists".
# IF NOT EXISTS keeps the cell safe to re-run.
query = """\
CREATE INDEX IF NOT EXISTS idx_ltm_context_tsv_expr
ON longtermmemory
USING GIN (to_tsvector('english', context));
""".strip()
with engine.connect() as conn:
    conn.execute(text(query))
    conn.commit()

## Create LongTermMemory

In [4]:
class LongTermMemory(SQLModel, table=True, extend_existing=True):
    """A stored memory: free text, its embedding vector and JSONB metadata."""

    # Primary key: a fresh UUID4, rendered as a string, per new instance.
    id: str = Field(
        default_factory=lambda: str(uuid4()),
        primary_key=True,
    )
    # The text that the full-text and trigram queries search.
    context: str
    # pgvector column with VECTOR_SIZE dimensions.
    embedding: Any = Field(sa_type=Vector(VECTOR_SIZE))
    # Arbitrary metadata stored as JSONB; defaults to an empty dict.
    meta: dict = Field(default_factory=dict, sa_type=JSONB)

In [5]:
# Create the tables for every model registered on SQLModel.metadata
# (here: LongTermMemory) using the shared engine.
SQLModel.metadata.create_all(engine)

## Add

In [7]:
# Insert one row. Note that re-running this cell inserts another copy.
ls_mem = LongTermMemory(
    context="first context",
    embedding=[0.0, 0.0, 0.0],
    meta={"source": "a"},
)

with Session(engine) as session:
    session.add(ls_mem)
    session.commit()

In [8]:
# Constructor arguments for the remaining demo rows. Entries without a "meta"
# key fall back to the model's default (an empty dict).
row_specs = [
    {"context": "second context", "embedding": [0.1, 0.1, 0.1], "meta": {"source": "b"}},
    {"context": "third context", "embedding": [-0.1, -0.1, -0.1], "meta": {"source": "c"}},
    {"context": "forth context", "embedding": [0.1, 0.0, 0.0], "meta": {"source": "a"}},
    {"context": "fifth context", "embedding": [-0.1, 0.0, 0.0]},
    {"context": "sixth context", "embedding": [0.0, 0.0, 0.1]},
    {"context": "seventh context", "embedding": [0.0, 0.0, -0.1]},
]
rows = [LongTermMemory(**spec) for spec in row_specs]

# Add all rows to one session and commit them together.
with Session(engine) as session:
    session.add_all(rows)
    session.commit()

## Select

In [9]:
# Fetch every stored row, with no filter and no explicit ordering.
all_rows_stmt = select(LongTermMemory)

with Session(engine) as session:
    results = session.exec(all_rows_stmt).all()

results

[LongTermMemory(meta={'source': 'a'}, context='first context', id='c4324459-9268-465e-b635-aefeaa9a98da', embedding=array([0., 0., 0.], dtype=float32)),
 LongTermMemory(meta={'source': 'b'}, context='second context', id='3d9c08ff-ed8c-4b49-a838-600f1e9a70a5', embedding=array([0.1, 0.1, 0.1], dtype=float32)),
 LongTermMemory(meta={'source': 'c'}, context='third context', id='30b8666a-db00-4000-afe4-ae1c5043aa67', embedding=array([-0.1, -0.1, -0.1], dtype=float32)),
 LongTermMemory(meta={'source': 'a'}, context='forth context', id='d65c2659-0492-4877-8673-226d7b226dde', embedding=array([0.1, 0. , 0. ], dtype=float32)),
 LongTermMemory(meta={}, context='fifth context', id='4c193349-b428-4076-bf2f-b86cad3c41e4', embedding=array([-0.1,  0. ,  0. ], dtype=float32)),
 LongTermMemory(meta={}, context='sixth context', id='e40a95bc-e74b-4247-b2a3-d0b1afcd2c2a', embedding=array([0. , 0. , 0.1], dtype=float32)),
 LongTermMemory(meta={}, context='seventh context', id='dd0ceeba-d3f3-4135-bb93-2a21ad

In [10]:
# The five rows whose embedding is closest to the origin, ordered by
# pgvector's L2 distance.
origin = [0.0, 0.0, 0.0]
distance_to_origin = LongTermMemory.embedding.l2_distance(origin)
nearest_stmt = select(LongTermMemory).order_by(distance_to_origin).limit(5)

with Session(engine) as session:
    results = session.exec(nearest_stmt).all()

results

[LongTermMemory(meta={'source': 'a'}, context='first context', id='c4324459-9268-465e-b635-aefeaa9a98da', embedding=array([0., 0., 0.], dtype=float32)),
 LongTermMemory(meta={'source': 'a'}, context='forth context', id='d65c2659-0492-4877-8673-226d7b226dde', embedding=array([0.1, 0. , 0. ], dtype=float32)),
 LongTermMemory(meta={}, context='sixth context', id='e40a95bc-e74b-4247-b2a3-d0b1afcd2c2a', embedding=array([0. , 0. , 0.1], dtype=float32)),
 LongTermMemory(meta={}, context='seventh context', id='dd0ceeba-d3f3-4135-bb93-2a21adea6f5e', embedding=array([ 0. ,  0. , -0.1], dtype=float32)),
 LongTermMemory(meta={}, context='fifth context', id='4c193349-b428-4076-bf2f-b86cad3c41e4', embedding=array([-0.1,  0. ,  0. ], dtype=float32))]

In [11]:
# Raw SQL: rows whose JSONB `meta` contains {"source": "a"} (the @> operator).
# SELECT * returns plain tuples with every table column, not model instances.
source_a_sql = text("""
        SELECT * FROM longtermmemory WHERE meta @> '{"source": "a"}'
    """)

with Session(engine) as session:
    result = session.exec(source_a_sql).all()

result

[('c4324459-9268-465e-b635-aefeaa9a98da', 'first context', '[0,0,0]', {'source': 'a'}, "'context':2 'first':1"),
 ('d65c2659-0492-4877-8673-226d7b226dde', 'forth context', '[0.1,0,0]', {'source': 'a'}, "'context':2 'forth':1")]

In [12]:
# Nearest-neighbour search restricted to rows whose JSONB `meta` contains the
# given key/value pairs.
top_n = 2
vector = [0, 0, 0]

# Serialise the filter with json.dumps rather than hand-writing the JSON text,
# and pass it as a bound parameter so it is never interpolated into the SQL.
jsonb_filter = json.dumps({"source": "a"})

sql_orm = (
    select(LongTermMemory)
    # Explicit CAST: the bound value is sent as text, so state the jsonb type
    # instead of relying on the driver/server to infer it.
    .where(text("meta @> CAST(:jsonb_filter AS jsonb)"))
    .order_by(LongTermMemory.embedding.l2_distance(vector))
    .limit(top_n)
    .params(jsonb_filter=jsonb_filter)
)

with Session(engine) as session:
    results = session.exec(sql_orm).all()
results

[LongTermMemory(meta={'source': 'a'}, context='first context', id='c4324459-9268-465e-b635-aefeaa9a98da', embedding=array([0., 0., 0.], dtype=float32)),
 LongTermMemory(meta={'source': 'a'}, context='forth context', id='d65c2659-0492-4877-8673-226d7b226dde', embedding=array([0.1, 0. , 0. ], dtype=float32))]

In [13]:
# Nearest-neighbour search restricted to rows whose meta "source" value is one
# of the listed sources; the list is passed as a bound parameter.
top_n = 5
vector = [0, 0, 0]

source_filter = text("meta ->> 'source' = ANY(:sources)")
nearest_first = LongTermMemory.embedding.l2_distance(vector)
sql_orm = (
    select(LongTermMemory)
    .where(source_filter)
    .order_by(nearest_first)
    .limit(top_n)
    .params(sources=["a", "b"])
)

with Session(engine) as session:
    results = session.exec(sql_orm).all()
results

[LongTermMemory(meta={'source': 'a'}, context='first context', id='c4324459-9268-465e-b635-aefeaa9a98da', embedding=array([0., 0., 0.], dtype=float32)),
 LongTermMemory(meta={'source': 'a'}, context='forth context', id='d65c2659-0492-4877-8673-226d7b226dde', embedding=array([0.1, 0. , 0. ], dtype=float32)),
 LongTermMemory(meta={'source': 'b'}, context='second context', id='3d9c08ff-ed8c-4b49-a838-600f1e9a70a5', embedding=array([0.1, 0.1, 0.1], dtype=float32))]

In [14]:
# Full-text search computing the tsvector inline; the expression matches the
# GIN expression index on to_tsvector('english', context) created earlier.
query = "first"

# Name the 'english' configuration on the tsquery side as well. Without it,
# plainto_tsquery uses the server's default_text_search_config, which may
# normalise the search terms differently from the 'english' tsvector.
fts_filter = text("to_tsvector('english', context) @@ plainto_tsquery('english', :q)")
sql_orm = select(LongTermMemory).where(fts_filter).params(q=query)

with Session(engine) as session:
    results = session.exec(sql_orm).all()
results

[LongTermMemory(meta={'source': 'a'}, context='first context', id='c4324459-9268-465e-b635-aefeaa9a98da', embedding=array([0., 0., 0.], dtype=float32))]

In [None]:
# Ranked full-text search against the stored context_tsv column.
# "th:*" is a prefix match on lexemes starting with "th"; "&" requires both terms.
query = "th:* & context"

# Use the same 'english' configuration in the filter and in the ranking.
# Previously the filter used the server default while ts_rank used 'english',
# so the two tsqueries could differ when the default is not 'english'.
match_filter = text("context_tsv @@ to_tsquery('english', :q)")
rank_desc = text("ts_rank(context_tsv, to_tsquery('english', :q)) DESC")
sql_orm = select(LongTermMemory).where(match_filter).order_by(rank_desc).params(q=query)

with Session(engine) as session:
    results = session.exec(sql_orm).all()
results

[LongTermMemory(meta={'source': 'c'}, context='third context', id='30b8666a-db00-4000-afe4-ae1c5043aa67', embedding=array([-0.1, -0.1, -0.1], dtype=float32))]

In [34]:
# Case-insensitive substring match: any context containing "th context".
pattern = "%th context%"
query = select(LongTermMemory).where(LongTermMemory.context.ilike(pattern))

with Session(engine) as session:
    results = session.exec(query).all()

results

[LongTermMemory(meta={'source': 'a'}, context='forth context', id='d65c2659-0492-4877-8673-226d7b226dde', embedding=array([0.1, 0. , 0. ], dtype=float32)),
 LongTermMemory(meta={}, context='fifth context', id='4c193349-b428-4076-bf2f-b86cad3c41e4', embedding=array([-0.1,  0. ,  0. ], dtype=float32)),
 LongTermMemory(meta={}, context='sixth context', id='e40a95bc-e74b-4247-b2a3-d0b1afcd2c2a', embedding=array([0. , 0. , 0.1], dtype=float32)),
 LongTermMemory(meta={}, context='seventh context', id='dd0ceeba-d3f3-4135-bb93-2a21adea6f5e', embedding=array([ 0. ,  0. , -0.1], dtype=float32))]

In [42]:
# Fuzzy match with pg_trgm: keep rows whose trigram similarity to the
# (misspelled) input exceeds 0.5, most similar first.
misspelled = "sevonth context"

similar_enough = text("similarity(context, :q) > 0.5")
most_similar_first = text("similarity(context, :q) DESC")
stmt = (
    select(LongTermMemory)
    .where(similar_enough)
    .order_by(most_similar_first)
    .params(q=misspelled)
)

with Session(engine) as session:
    results = session.exec(stmt).all()

results

[LongTermMemory(meta={}, context='seventh context', id='dd0ceeba-d3f3-4135-bb93-2a21adea6f5e', embedding=array([ 0. ,  0. , -0.1], dtype=float32)),
 LongTermMemory(meta={'source': 'b'}, context='second context', id='3d9c08ff-ed8c-4b49-a838-600f1e9a70a5', embedding=array([0.1, 0.1, 0.1], dtype=float32)),
 LongTermMemory(meta={}, context='sixth context', id='e40a95bc-e74b-4247-b2a3-d0b1afcd2c2a', embedding=array([0. , 0. , 0.1], dtype=float32))]