# Test pgvector

## Setup

In [1]:
from sqlmodel import SQLModel, Field, create_engine, Session, select, JSON
from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import text
from typing import Any
from pgvector.sqlalchemy import Vector
from uuid import uuid4

VECTOR_SIZE = 3

USERNAME = "myuser"
PASSWORD = "mypassword"
HOST = "localhost"
PORT = 5432
DATABASE_NAME = "mydb"
DATABASE_URL = f"postgresql://{USERNAME}:{PASSWORD}@{HOST}:{PORT}/{DATABASE_NAME}"

## Create Engine and Extension

In [2]:
engine = create_engine(DATABASE_URL, echo=False)

In [3]:
with engine.connect() as conn:
    conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector;'))
    conn.commit()

## Create LongTermMemory

In [4]:
class LongTermMemory(SQLModel, table=True, extend_existing=True):
    id: str = Field(primary_key=True, default_factory=lambda: str(uuid4()))
    context: str
    embedding: Any = Field(sa_type=Vector(VECTOR_SIZE))
    meta: dict = Field(sa_type=JSONB, default_factory=dict)

In [5]:
SQLModel.metadata.create_all(engine)

# Add

In [6]:
with Session(engine) as session:
    ls_mem = LongTermMemory(context="first context", embedding=[0.0, 0.0, 0.0], meta={"source":"a"})
    session.add(ls_mem)
    session.commit()

In [None]:
rows = [
    LongTermMemory(context="second context", embedding=[0.1, 0.1, 0.1], meta={"source": "b"}),
    LongTermMemory(context="third context", embedding=[-0.1, -0.1, -0.1], meta={"source": "c"}),
    LongTermMemory(context="forth context", embedding=[0.1, 0.0, 0.0], meta={"source": "a"}),
    LongTermMemory(context="fifth context", embedding=[-0.1, 0.0, 0.0]),
    LongTermMemory(context="sixth context", embedding=[0.0, 0.0, 0.1]),
    LongTermMemory(context="seventh context", embedding=[0.0, 0.0, -0.1]),
]

with Session(engine) as session:
    session.add_all(rows)
    session.commit()

In [8]:
with Session(engine) as session:
    results = session.exec(select(LongTermMemory)).all()

results

[LongTermMemory(context='first context', id='8a029eec-b159-486f-af28-004f5055c602', embedding=array([0., 0., 0.], dtype=float32), meta={'source': 'a'}),
 LongTermMemory(context='second context', id='541b73c0-a316-49fd-b50a-d11cdd2a5942', embedding=array([0.1, 0.1, 0.1], dtype=float32), meta={'source': 'b'}),
 LongTermMemory(context='third context', id='9fd62b18-e36b-4e1f-934c-afe5a2680544', embedding=array([-0.1, -0.1, -0.1], dtype=float32), meta={'source': 'c'}),
 LongTermMemory(context='forth context', id='300ef578-efe5-4d07-a2bd-a08eff200ff6', embedding=array([0.1, 0. , 0. ], dtype=float32), meta={'source': 'a'}),
 LongTermMemory(context='fifth context', id='3b7004ec-01f3-4e9d-9802-8e5ea70400d1', embedding=array([-0.1,  0. ,  0. ], dtype=float32), meta={}),
 LongTermMemory(context='sixth context', id='87ba351e-a612-4d95-8ac9-75df6aa33268', embedding=array([0. , 0. , 0.1], dtype=float32), meta={}),
 LongTermMemory(context='seventh context', id='5cf96980-31aa-459a-88a9-8736a0e9d08b', 

In [6]:
with Session(engine) as session:
    results = session.exec(select(LongTermMemory).order_by(LongTermMemory.embedding.l2_distance([0.0, 0.0, 0.0])).limit(5)).all()

results

[LongTermMemory(context='first context', id='8a029eec-b159-486f-af28-004f5055c602', embedding=array([0., 0., 0.], dtype=float32), meta={'source': 'a'}),
 LongTermMemory(context='forth context', id='300ef578-efe5-4d07-a2bd-a08eff200ff6', embedding=array([0.1, 0. , 0. ], dtype=float32), meta={'source': 'a'}),
 LongTermMemory(context='sixth context', id='87ba351e-a612-4d95-8ac9-75df6aa33268', embedding=array([0. , 0. , 0.1], dtype=float32), meta={}),
 LongTermMemory(context='seventh context', id='5cf96980-31aa-459a-88a9-8736a0e9d08b', embedding=array([ 0. ,  0. , -0.1], dtype=float32), meta={}),
 LongTermMemory(context='fifth context', id='3b7004ec-01f3-4e9d-9802-8e5ea70400d1', embedding=array([-0.1,  0. ,  0. ], dtype=float32), meta={})]

In [8]:
with Session(engine) as session:
    result = session.exec(text("""
        SELECT * FROM longtermmemory WHERE meta @> '{"source": "a"}'
    """)).all()

result

[('8a029eec-b159-486f-af28-004f5055c602', 'first context', '[0,0,0]', {'source': 'a'}),
 ('300ef578-efe5-4d07-a2bd-a08eff200ff6', 'forth context', '[0.1,0,0]', {'source': 'a'})]