In [47]:
import os
import sqlalchemy as sa
from sqlalchemy.orm import declarative_base, sessionmaker, Session
from pgvector.sqlalchemy import Vector
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment before the
# configuration cell reads DB_URI and VECTOR_DIM through os.getenv.
load_dotenv()

True

In [48]:
# Declarative base class that every ORM model in this notebook inherits from.
Base = declarative_base()

# Connection string is read from the environment. Fail fast with an actionable
# message: handing an empty URL to sa.create_engine() only yields a generic
# URL-parsing error that does not mention DB_URI.
DATABASE_URL = os.getenv("DB_URI", "")
if not DATABASE_URL:
    raise RuntimeError(
        "DB_URI is not set; define it in the environment or in the .env file "
        "before running this notebook."
    )

# Dimension of every Vector column declared below (defaults to 896).
VECTOR_DIM = int(os.getenv("VECTOR_DIM", "896"))

engine = sa.create_engine(DATABASE_URL)

class RelationType(Base):
    """ORM model mapped to the ``relation_types`` table.

    Each row holds a relation label, a textual definition and a pgvector
    embedding that is ``VECTOR_DIM`` wide.
    """

    __tablename__ = "relation_types"
    
    # Primary key, stored as BigInteger.
    id = sa.Column(sa.BigInteger, primary_key=True)
    # Relation label; required (nullable=False).
    type = sa.Column(sa.String, nullable=False)
    # Free-text definition of the relation; required (nullable=False).
    definition = sa.Column(sa.String, nullable=False)
    # pgvector column of VECTOR_DIM dimensions; nullable is not overridden here.
    # Presumably an embedding of the label/definition text — TODO confirm.
    embedding = sa.Column(Vector(VECTOR_DIM))
    
class Entity(Base):
    """ORM model mapped to the ``entity`` table.

    Each row holds an entity name and a pgvector embedding that is
    ``VECTOR_DIM`` wide.
    """

    __tablename__ = "entity"
    
    # Primary key, stored as BigInteger.
    id = sa.Column(sa.BigInteger, primary_key=True)
    # Entity name; required (nullable=False).
    # NOTE(review): a recorded insert against this table failed on a unique
    # constraint named "entity_name_key", which this model does not declare.
    # Confirm whether `name` should carry unique=True so the model and the
    # database schema agree.
    name = sa.Column(sa.String, nullable=False)
    # pgvector column of VECTOR_DIM dimensions; nullable is not overridden here.
    embedding = sa.Column(Vector(VECTOR_DIM))

# Create the tables for every model registered on Base (tables that already
# exist are left alone). The Vector columns rely on PostgreSQL's "vector"
# extension, so enable it first; IF NOT EXISTS keeps this cell safe to re-run.
with engine.begin() as connection:
    connection.execute(sa.text("CREATE EXTENSION IF NOT EXISTS vector"))
Base.metadata.create_all(engine)

In [49]:
# Session factory bound to the engine; calling it returns a new Session.
SessionLocal = sessionmaker(bind=engine)
# Single Session object reused by the insert and query cells below.
session = SessionLocal()

In [72]:
import numpy as np

# All-ones placeholder embedding for the demo rows. Sized from VECTOR_DIM
# rather than a hard-coded 896 so it always matches the width of the Vector
# columns, including when VECTOR_DIM is overridden through the environment.
sample_vector = np.array([1.0] * VECTOR_DIM)

In [73]:
# Demo rows: two relation types and two entities, all sharing the same
# placeholder embedding.
relation_1 = RelationType(
    type="award",
    definition="hello",
    embedding=sample_vector,
)
relation_2 = RelationType(
    type="birthday",
    definition="hello",
    embedding=sample_vector,
)

entity_1 = Entity(name="Do van", embedding=sample_vector)
entity_2 = Entity(name="Helo 1 2 3", embedding=sample_vector)

# entity_1 is only constructed here; this batch does not add it.
# The three rows below are committed in a single transaction. On any failure
# the error is printed and the session is rolled back so it stays usable.
try:
    session.add_all([entity_2, relation_1, relation_2])
    session.commit()
except Exception as exc:
    print(f"Error: {exc}")
    session.rollback()
else:
    print("Item add successfully")


Error: (psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint "entity_name_key"
DETAIL:  Key (name)=(Helo 1 2 3) already exists.

[SQL: INSERT INTO entity (name, embedding) VALUES (%(name)s, %(embedding)s) RETURNING entity.id]
[parameters: {'name': 'Helo 1 2 3', 'embedding': '[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0, ... (3287 characters truncated) ... ,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]'}]
(Background on this error at: https://sqlalche.me/e/20/gkpj)


In [75]:
# Persist entity_1 on its own. Roll back on failure, as the batch-insert cell
# does, so a failed commit does not leave the shared session unusable for the
# cells that follow; the error is re-raised so it remains visible.
try:
    session.add(entity_1)
    session.commit()
except Exception:
    session.rollback()
    raise

In [76]:
from sqlalchemy import select

# Nearest-neighbour lookup: rank entities by cosine distance to an all-ones
# probe vector and keep only the closest one. The probe is sized from
# VECTOR_DIM (not a hard-coded 896) so it always matches the column width.
query_vector = [1.0] * VECTOR_DIM

distance_query = (
    select(
        Entity,
        Entity.embedding.cosine_distance(query_vector).label("distance"),
    )
    .order_by("distance")
    .limit(1)
)

# NOTE: the "_l2" suffix is kept for compatibility with any later cells, but
# these rows are ordered by cosine distance, not L2 distance.
results_l2 = session.execute(distance_query).all()

# Each result row is an (Entity, distance) pair.
for item, dist in results_l2:
    print(f"ID: {item.id}, Name: {item.name}, Distance: {dist}")

ID: 4, Name: Do van, Distance: 0.0
