In [1]:
%pip install -qU openai
%pip install -qU "psycopg[binary]"

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import sys
sys.path.insert(0, '..')

In [3]:
from psycopg import connect
from muni.llm import create_embedding

def connection():
    return connect(
        dbname="regrag",
        host="localhost",
        port="5432",
        autocommit=True
    )

In [4]:
def simple_semantic_query(conn, query, limit=10):
    query_embedding = create_embedding(query)
    with conn.cursor() as cursor:
        sql = """
        SELECT id, L4_heading, text
        FROM muni
        WHERE jurisdiction = 'Chicago'
        ORDER BY embedding <=> %s
        LIMIT %s;
        """
        cursor.execute(sql, (str(query_embedding), limit))
        return cursor.fetchall()
        
with connection() as conn:        
    results = simple_semantic_query(conn, 'drug paraphernalia')
for r in results:
    print(r)

(201, 'Sale of certain substances.', 'Article II. Drug Paraphernalia')
(225, 'Seizure and forfeiture.', 'All drug paraphernalia defined in Section 720 ILCS 600/2,\nsubparagraph (5), and including glass tubing utilized for the ingestion\nof cocaine or crack cocaine, is subject to forfeiture and may be seized\nby any peace officer. The seizure and forfeiture shall be made in\naccordance with rules issued by the superintendent of police or his\ndesignee.\n\n(Added Coun. J. 5-12-99, p. 3327)')
(221, 'Manufacture.', 'Except as authorized by law, any person who manufactures, with intent\nto deliver, furnish, or transfer drug paraphernalia knowing, or under\ncircumstances where one reasonably should know, that it will be used to\nplant, propagate, cultivate, grow, harvest, manufacture, compound,\nconvert, produce, process, prepare, test, analyze, pack, repack, store,\ncontain, conceal, ingest, inhale or otherwise introduce into the human\nbody cocaine, cocaine base, heroin, phencyclidine, or 

In [5]:
def simple_full_text_query(conn, query, limit=10):
    with conn.cursor() as cursor:
        sql = """
        WITH tsq AS (
            SELECT to_tsquery('english', %s) AS search
            )
        SELECT id, L4_heading, text
        FROM muni, tsq
        WHERE jurisdiction = 'Chicago'
        AND textsearchable @@ tsq.search
        ORDER BY ts_rank_cd(textsearchable, tsq.search)
        LIMIT %s;
        """
        cursor.execute(sql, (query, limit))
        return cursor.fetchall()

with connection() as conn:        
    results = simple_full_text_query(conn, 'drug & paraphernalia')
for r in results:
     print(r)

(201, 'Sale of certain substances.', 'Article II. Drug Paraphernalia')
(219, 'Sale of certain substances.', "No person shall knowingly sell or offer for sale, deliver or give\naway to any person under 17 years of age, unless upon the written order\nof parent or guardian, any substances containing any of the following\nvolatile solvents, where the seller, offerer or deliverer knows or has\nreason to believe that the substance will be used for the purpose of\ninducing symptoms of intoxication, elation, dizziness, paralysis,\nirrational behavior, or in any manner change, distort or disturb the\naudio, visual or mental processes:\n\n\xa0\xa0\xa0Toluol, hexane, trichloroethylene, acetone, toluene, ethyl acetate,\nmethyl ethyl ketone, trichloroethane, isopropanol, methyl isobutyl\nketone, methyl cellosolve acetate, cyclohexanone, or any other substance\nwhich will induce symptoms of intoxication, elation, dizziness,\nparalysis, irrational behavior, or in any manner change, distort or\ndistur

In [6]:
# Now we do a more complicated hybrid search, borrowing and adapting from 
# https://github.com/pgvector/pgvector-python/blob/master/examples/hybrid_search_rrf.py

def hybrid_query(conn, query, limit=10):
    embedding = create_embedding(query)

    sql = """
    WITH semantic_search AS (
        SELECT id, L4_heading, RANK () OVER (ORDER BY embedding <=> %(embedding)s) AS rank
        FROM muni
        ORDER BY embedding <=> %(embedding)s
        LIMIT 20
    ),
    keyword_search AS (
        SELECT id, L4_heading, RANK () OVER (ORDER BY ts_rank_cd(textsearchable, query) DESC)
        FROM muni, plainto_tsquery('english', %(query)s) query
        WHERE textsearchable @@ query
        ORDER BY ts_rank_cd(textsearchable, query) DESC
        LIMIT 20
    )
    SELECT
        COALESCE(semantic_search.id, keyword_search.id) AS id,
        COALESCE(1.0 / (%(k)s + semantic_search.rank), 0.0) +
        COALESCE(1.0 / (%(k)s + keyword_search.rank), 0.0) AS score,
        COALESCE(semantic_search.L4_heading, keyword_search.L4_heading) AS L4_heading
    FROM semantic_search
    FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
    ORDER BY score DESC
    LIMIT %(limit)s;
    """
    result = conn.execute(sql, {'query': query, 'embedding': str(embedding), 'limit': limit, 'k': 60})
    return result.fetchall()

with connection() as conn:
    results = hybrid_query(conn, 'drug paraphernalia')

for row in results:
    print(row)

(201, Decimal('0.03252247488101533580'), 'Sale of certain substances.')
(225, Decimal('0.03225806451612903226'), 'Seizure and forfeiture.')
(221, Decimal('0.03200204813108038915'), 'Manufacture.')
(222, Decimal('0.03175403225806451613'), 'Accomplice liability.')
(220, Decimal('0.01639344262295081967'), 'Possession or delivery.')
(219, Decimal('0.01612903225806451613'), 'Sale of certain substances.')
(212, Decimal('0.01538461538461538462'), 'Sample packages of medicines.')
(231, Decimal('0.01515151515151515152'), 'Fraudulent prescriptions.')
(203, Decimal('0.01492537313432835821'), 'Prohibited possession or use of cannabis.')
(238, Decimal('0.01470588235294117647'), 'Purchase of cigar and cigarette stumps.')


In [7]:
# Try query augmentation using Hyde (generation of synthetic replies matching the
# format of expected answers)
from muni.llm import augmented_embedding

def augmented_query(conn, query, limit=10):
    query_embedding = augmented_embedding(query, orig_weight = 0.5)
    with conn.cursor() as cursor:
        sql = """
        SELECT id, L4_heading, text
        FROM muni
        WHERE jurisdiction = 'Chicago'
        ORDER BY embedding <=> %s
        LIMIT %s;
        """
        cursor.execute(sql, (str(query_embedding), limit))
        return cursor.fetchall()
        
with connection() as conn:        
    results = augmented_query(conn, 'Does the code restrict drug paraphernalia')
for r in results:
    print(r)

(201, 'Sale of certain substances.', 'Article II. Drug Paraphernalia')
(221, 'Manufacture.', 'Except as authorized by law, any person who manufactures, with intent\nto deliver, furnish, or transfer drug paraphernalia knowing, or under\ncircumstances where one reasonably should know, that it will be used to\nplant, propagate, cultivate, grow, harvest, manufacture, compound,\nconvert, produce, process, prepare, test, analyze, pack, repack, store,\ncontain, conceal, ingest, inhale or otherwise introduce into the human\nbody cocaine, cocaine base, heroin, phencyclidine, or methamphetamine in\nviolation of the Illinois Controlled Substances Act shall be fined\n$1,000.00, or punished by imprisonment for a period of six months, or by\nboth such fine and imprisonment.\n\n(Added Coun. J. 5-12-99, p. 3327)')
(225, 'Seizure and forfeiture.', 'All drug paraphernalia defined in Section 720 ILCS 600/2,\nsubparagraph (5), and including glass tubing utilized for the ingestion\nof cocaine or crack coca

In [8]:
def collect_associations(conn, right_id: int, association: str, jurisdiction: str) -> str:
    """Query the association table (e.g., to retireve all definitions that apply to a block of text)."""
    sql = """
        SELECT L1_ref, L2_ref, L3_ref, L4_ref, text from muni
        WHERE id IN (
            SELECT left_id AS id
            FROM muni_associations
            WHERE right_id = %s AND association = %s AND jurisdiction = %s
            ORDER BY left_id
        );
        """
    with conn.cursor() as cursor:
        cursor.execute(sql, (right_id, association, jurisdiction))
        rows = cursor.fetchall()
    print([row[3] for row in rows])
    return '\n'.join([row[4] for row in rows])

with connection() as conn:        
    result = collect_associations(conn, 211, 'definition', 'Chicago')
print(result)


['7-24-001', '1-16-010', '8-20-010', '8-4-080', '7-52-020', '7-4-010', '7-24-001', '1-4-090']
As used in this chapter:

   "Cannabis Control Act" means the Cannabis Control Act, codified at
720 ILCS 550/1, et seq., or its successor act.

   "Cannabis Regulation and Tax Act" means the Cannabis Regulation and
Tax Act, codified at 410 ILCS 705/1-1, et seq., or its successor act.

   "Compassionate Use of Medical Cannabis Program Act" means the
Compassionate Use of Medical Cannabis Program Act, codified at 410 ILCS
130/1, et seq., or its successor act.

   "Illinois Controlled Substances Act" means the Illinois Controlled
Substances Act, codified at 720 ILCS 570/100, et seq., or its successor
act.

   "Illinois Vehicle Code" means the Illinois Vehicle Code, codified at
625 ILCS 5/1-100, et seq., or its successor act.

   "Smoke Free Illinois Act" means the Smoke Free Illinois Act, codified
at 410 ILCS 82/1, et seq., or its successor act.

(Added Coun. J. 11-26-19, p. 11547, § 10)

ARTICLE 

In [9]:
len(result)

12565

In [10]:
def hybrid_augmented_query(conn, query, limit=10):
    embedding = augmented_embedding(query, orig_weight = 0.5)
    sql = """
    WITH semantic_search AS (
        SELECT id, L4_heading, RANK () OVER (ORDER BY embedding <=> %(embedding)s) AS rank
        FROM muni
        ORDER BY embedding <=> %(embedding)s
        LIMIT 20
    ),
    keyword_search AS (
        SELECT id, L4_heading, RANK () OVER (ORDER BY ts_rank_cd(textsearchable, query) DESC)
        FROM muni, plainto_tsquery('english', %(query)s) query
        WHERE textsearchable @@ query
        ORDER BY ts_rank_cd(textsearchable, query) DESC
        LIMIT 20
    )
    SELECT
        COALESCE(semantic_search.id, keyword_search.id) AS id,
        COALESCE(1.0 / (%(k)s + semantic_search.rank), 0.0) +
        COALESCE(1.0 / (%(k)s + keyword_search.rank), 0.0) AS score,
        COALESCE(semantic_search.L4_heading, keyword_search.L4_heading) AS L4_heading
    FROM semantic_search
    FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
    ORDER BY score DESC
    LIMIT %(limit)s;
    """
    result = conn.execute(sql, {'query': query, 'embedding': str(embedding), 'limit': limit, 'k': 60})
    return result.fetchall()

with connection() as conn:
    results = hybrid_query(conn, 'drug paraphernalia')

for row in results:
    print(row)

(201, Decimal('0.03252247488101533580'), 'Sale of certain substances.')
(225, Decimal('0.03225806451612903226'), 'Seizure and forfeiture.')
(221, Decimal('0.03200204813108038915'), 'Manufacture.')
(222, Decimal('0.03175403225806451613'), 'Accomplice liability.')
(220, Decimal('0.01639344262295081967'), 'Possession or delivery.')
(219, Decimal('0.01612903225806451613'), 'Sale of certain substances.')
(212, Decimal('0.01538461538461538462'), 'Sample packages of medicines.')
(231, Decimal('0.01515151515151515152'), 'Fraudulent prescriptions.')
(203, Decimal('0.01492537313432835821'), 'Prohibited possession or use of cannabis.')
(238, Decimal('0.01470588235294117647'), 'Purchase of cigar and cigarette stumps.')


In [12]:
from muni.llm import is_relevant

def check_relevance(conn, id_, query):
    sql = """
    SELECT text
    FROM muni
    WHERE id = %s
    """
    result = conn.execute(sql, (id_,))
    rows = result.fetchall()
    if not rows:
        return None
    text = rows[0][0]
    return is_relevant(text, query, threshold=4)

with connection() as conn:
    r = check_relevance(conn, 231, 'Does the jurisdiction have laws against drug paraphernalia?')
    print(r)

False


In [13]:
# demo retrieve-and-filter pattern

filtered_results = []
query = 'Does the jurisdiction have laws against drug paraphernalia?'
with connection() as conn:
    results = hybrid_query(conn, query)
    for row in results:
        try:
            if check_relevance(conn, row[0], query):
                print(row)
            else:
                print(f'EXCLUDING {row} as irrelevant')
        except IndexError:
            print('SKIPPING ROW')

(201, Decimal('0.01639344262295081967'), 'Sale of certain substances.')
(225, Decimal('0.01612903225806451613'), 'Seizure and forfeiture.')
(222, Decimal('0.01587301587301587302'), 'Accomplice liability.')
(221, Decimal('0.01562500000000000000'), 'Manufacture.')
EXCLUDING (239, Decimal('0.01538461538461538462'), 'Manufacture from cigar and cigarette stumps.') as irrelevant
EXCLUDING (248, Decimal('0.01515151515151515152'), 'Violation – Penalty.') as irrelevant
(223, Decimal('0.01492537313432835821'), 'Delivery to persons under 18 years of age on school grounds.')
EXCLUDING (230, Decimal('0.01470588235294117647'), 'Inspection of prescriptions.') as irrelevant
EXCLUDING (206, Decimal('0.01449275362318840580'), 'Manufacture from cigar and cigarette stumps.') as irrelevant
EXCLUDING (238, Decimal('0.01428571428571428571'), 'Purchase of cigar and cigarette stumps.') as irrelevant
