In [1]:
%pip install -qU openai marvin
%pip install -qU "psycopg[binary]"

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


## [**Code**] Simple state machine parser


Because the outline structure of most legal codes is so simple, it's feasible to implement a parser using a hand-coded
state machine that shifts between states according to the level of the outline. This may be a simpler approach than
specifying a BNF-style grammar for Lark or a similar parser generator, because in some cases outlines skip levels, which
would complicate the formal grammar.

In [2]:
from enum import Enum
#from typing import List, Dict
from dataclasses import dataclass, field
import re
import sys

sys.path.insert(0, '..')

In [3]:
class Level(Enum):
    H0 = 0 # top level (initial state)
    H1 = 1
    H2 = 2
    H3 = 3

@dataclass
class HeadingPattern:
    level: Level
    regex: str
    multi_line: bool # whether the heading spans multiple lines

@dataclass
class Heading:
    level: Level
    # heading_type: str # e.g. "section", "subsection", "article", "chapter"
    enumeration: str # number or letter (e.g. "1", "a", "i", "A", "XVII")
    heading_text: str

@dataclass
class Segment:
    level: Level
    headings: dict[Level, Heading|None] = field(default_factory=dict)
    body: list[str] = field(default_factory=list) # list of paragraphs

## For our purposes, a document is just a list of segments -- the structure is
## implicit in the headings, which will be uploaded to a relational database

In [4]:
def split_paragraph(paragraph: str) -> tuple[str, str]:
    """"Split a paragraph into its first line and the rest of the paragraph."""
    lines = paragraph.split('\n', 1)
    if len(lines) == 0:
        return '', ''
    first_line = lines[0]
    rest_of_paragraph = lines[1] if len(lines) > 1 else ''
    return first_line, rest_of_paragraph

assert split_paragraph("") == ("", "")
assert split_paragraph("This is a\nparagraph.\n") == ("This is a", "paragraph.\n")

In [5]:
def match_heading(paragraph: str, patterns: dict[Level, HeadingPattern]) -> Heading | None:
    """For each patern in `patterns`, check if the paragraph matches (e.g., pattern r'^Chapter [IVXLC]+'
    matches 'Chapter VII'). If a match is found, return a Heading object. Otherwise, return None."""
    
    paragraph = paragraph.strip()

    for level, pattern in patterns.items():
        pattern_regex = re.compile(pattern.regex, re.DOTALL)
        match = pattern_regex.match(paragraph)
        if match:
            if pattern.multi_line == False:
                return Heading(level=level, enumeration=match.group(1), heading_text=match.group(2))
            else:
                _, rest = split_paragraph(paragraph)
                return Heading(level=level, enumeration=match.group(1), heading_text=rest)

## Tests
test_doc1 = "Chapter VII: The Final Chapter"
test_pattern1 = HeadingPattern(level=Level.H1, regex=r'^Chapter ([IVXLC]+): (.+)$', multi_line=False)

test_doc2 = "\n\nChapter 7:\nThe Final Chapter"
test_pattern2 = HeadingPattern(level=Level.H1, regex=r'^Chapter (\d+):', multi_line=True)

assert match_heading(test_doc1, {Level.H1: test_pattern1}) == \
    Heading(level=Level.H1, enumeration="VII", heading_text="The Final Chapter")

assert match_heading(test_doc2, {Level.H1: test_pattern2}) == \
    Heading(level=Level.H1, enumeration="7", heading_text="The Final Chapter")

In [32]:
class StateMachineParser:
    def __init__(self, document_name: str, heading_patterns: dict[Level, HeadingPattern]):
        self.document = []
        self.heading_names = {Level.H0: document_name, Level.H1: None, Level.H2: None, Level.H3: None}
        self.patterns = heading_patterns
        self.state = Level.H0

    def parse(self, text):
        paragraphs = text.split('\n\n')

        segment_headings = {Level.H0: self.heading_names[Level.H0]}
        segment = Segment(level=Level.H0, headings=segment_headings, body=[]) # preamble

        for paragraph in paragraphs:

            match = match_heading(paragraph, self.patterns)

            # no heading found, so add paragraph to the current segment
            if not match:
                segment.body.append(paragraph)
                continue

            # found a heading!
            self.document.append(segment) # add the last segment to document

            self.state = match.level
            new_headings = segment.headings.copy()
            new_headings[match.level] = match
            for level in Level: # have to delete the headings at higher levels in case of skips later
                if level.value in new_headings and level.value > match.level.value:
                    del new_headings[level]
            segment = Segment(level=self.state, headings=new_headings, body=[]) # start a new segment
        return self.document
    
    def _str_segment(self, segment: Segment) -> str:
        
        heading = segment.headings[segment.level]
        heading_str = f"Heading {heading.level} {heading.enumeration}: {heading.heading_text}" if heading else ""
        return f"{heading_str}\n\n{segment.body}"

    def __str__(self):
        return "\n\n".join([self._str_segment(segment) for segment in self.document])

## [**Code**] Heading patterns

In [7]:
from openai import OpenAI

openai_client = OpenAI()

def llm(prompt: str, system: str = ""):
    chat_completion = openai_client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": system
            },
            {
                "role": "user",
                "content": prompt
            }
        ],
        model="gpt-4o",
    )
    return chat_completion.choices[0].message.content


In [8]:
from textwrap import fill

r = llm("Why is the sky blue?", system="You are Cher from the movie Clueless.")
print(fill(r, 80))

Oh, honey, the sky is blue because of something called Rayleigh scattering. So,
like, when sunlight enters the Earth's atmosphere, it collides with all these
gas molecules. The blue light waves are shorter and scatter more than the other
colors, making the sky look blue to our eyes. It’s kind of like how we choose
the perfect shade of blue for an outfit—it just stands out more! ✨


In [9]:
import marvin

marvin.settings.openai.chat.completions.model = 'gpt-4o'

r = marvin.classify(
    "I could take it or leave it.",
    labels=["positive", "negative", "neutral"],
)

assert r == "neutral"

In [10]:
@marvin.fn
def infer_regex_llm(examples: list[str]) -> str:
    """
    Return a regular expression for document headings matching the
    provided examples. The regular expression should match the example and any similar
    headings, and should not match unrelated text. Leading terms such as
    "Chapter", "Section", "Article", "Title", etc. should be included verbatim 
    in the regular expression. Assume that numbers and letters in the pattern
    can take on a normal range (e.g., if you see 1, 2, 3 as examples you should
    allow other digits like 7 in the match). Trailing descriptive text, following
    a number or letter pattern, will vary from heading to heading (although it
    should still be included in the match if it is on the first line).
    Assume that capitalization is consistent within the document. The regular
    expression should be PCRE-compatible, and expressed as raw strings (e.g.,
    r'^Title \\d+$'. The regular expression should match the beginning of a line
    with '^' (multiline mode), and the end of a line with '$'. There should be no
    newlines in the regular expressions (the examples will just be single lines).
    """

In [11]:
def first_line(s: str) -> str:
    return s.split('\n')[0]

def infer_regex(examples: list[str]) -> str:
    first_lines = [first_line(example) for example in examples]
    return infer_regex_llm(first_lines)

def is_multi_line(examples: list[str]) -> bool:
    return any('\n' in example.strip() for example in examples)

def infer_heading_patterns(example_headings: dict[Level, list[str]]) -> dict[Level, HeadingPattern]:
    """Infer heading patterns from examples. Return a dictionary mapping levels to
    HeadingPattern objects."""    
    return {k: HeadingPattern(level=k, regex=infer_regex(v), multi_line=is_multi_line(v))
            for k, v in example_headings.items()}

In [12]:
nyc_example_headings = {
    Level.H1: ["Title 1: General Provisions\n",
              "Title 2: City of New York\n",
              "Title 3: Elected officials\n",
    ],
    Level.H2: ["Chapter 1: Powers and Rights of the Corporation; Emblems and Insignia\n",
              "Chapter 2: Boundaries of the City\n",
              "Chapter 4: Board of Estimate\n",
     ],
    Level.H3: ["§ 2-101 Name; powers and rights of the corporation; seal.\n",
              "§ 2-202 Division into boroughs and boundaries thereof.\n",
              "§ 3-140 Office of labor standards.\n",
      ],
}

chicago_example_headings = {
    Level.H1: ["TITLE 1\nGENERAL PROVISION\n",
              "TITLE 2\nCITY GOVERNMENT AND ADMINISTRATION\n",
              "TITLE 3\nREVENUE AND FINANCE\n",
    ],
    Level.H2: ["CHAPTER 1-4\nCODE ADOPTION - ORGANIZATION\n",
              "CHAPTER 1-8\nCITY SEAL AND FLAG\n",
              "CHAPTER 1-12\nCITY EMBLEMS\n",
     ],
    Level.H3: ["1-4-010 Municipal Code of Chicago adopted.\n",
              "2-1-020 Code to be kept up-to-date.\n",
              "3-4-030 Official copy on file.\n",
      ],
}

losangeles_example_headings = {
    Level.H1: ["CHAPTER I\nGENERAL PROVISIONS AND ZONING\n",
              "CHAPTER IV\nPUBLIC WELFARE",
              "CHAPTER VII\nTRANSPORTATION\n",
    ],
    Level.H2: ["ARTICLE 1\nGENERAL PROVISIONS\n",
              "ARTICLE 4\nPUBLIC BENEFIT PROJECTS\n",
              "ARTICLE 4.3\nELDERCARE FACILITY UNIFIED PERMIT PROCESS\n",
     ],
    Level.H3: ["SEC. 11.00. PROVISIONS APPLICABLE TO CODE.\n",
              "SEC. 11.01. DEFINITIONS AND INTERPRETATION.\n",
              "SEC. 14.4.1. PURPOSE.\n",
      ],
}

In [13]:
nyc_patterns = infer_heading_patterns(nyc_example_headings)
print(nyc_patterns)
chicago_patterns = infer_heading_patterns(chicago_example_headings)
print(chicago_patterns)
la_patterns = infer_heading_patterns(losangeles_example_headings)
print(la_patterns)

{<Level.H1: 1>: HeadingPattern(level=<Level.H1: 1>, regex="r'^Title \\d+: .+$'", multi_line=False), <Level.H2: 2>: HeadingPattern(level=<Level.H2: 2>, regex="r'^Chapter \\d+: .+$'", multi_line=False), <Level.H3: 3>: HeadingPattern(level=<Level.H3: 3>, regex="r'^§ \\d+-\\d+ .+$'", multi_line=False)}
{<Level.H1: 1>: HeadingPattern(level=<Level.H1: 1>, regex="r'^TITLE \\d+$'", multi_line=True), <Level.H2: 2>: HeadingPattern(level=<Level.H2: 2>, regex="r'^CHAPTER \\d+-\\d+$'", multi_line=True), <Level.H3: 3>: HeadingPattern(level=<Level.H3: 3>, regex="r'^\\d+-\\d+-\\d+ .+$'", multi_line=False)}
{<Level.H1: 1>: HeadingPattern(level=<Level.H1: 1>, regex="r'^CHAPTER [IVXLCDM]+$'", multi_line=True), <Level.H2: 2>: HeadingPattern(level=<Level.H2: 2>, regex="r'^ARTICLE \\d+(\\.\\d+)?$'", multi_line=True), <Level.H3: 3>: HeadingPattern(level=<Level.H3: 3>, regex="r'^SEC\\. \\d+\\.\\d*(?:\\.\\d+)*\\..*$'", multi_line=False)}


## [**Action**] Specify heading patterns

In [14]:
example_headings = {
    Level.H1: ["TITLE 1\nGENERAL PROVISION\n",
              "TITLE 2\nCITY GOVERNMENT AND ADMINISTRATION\n",
              "TITLE 3\nREVENUE AND FINANCE\n",
    ],
    Level.H2: ["CHAPTER 1-4\nCODE ADOPTION - ORGANIZATION\n",
              "CHAPTER 1-8\nCITY SEAL AND FLAG\n",
              "CHAPTER 1-12\nCITY EMBLEMS\n",
     ],
    Level.H3: ["1-4-010 Municipal Code of Chicago adopted.\n",
              "2-1-020 Code to be kept up-to-date.\n",
              "3-4-030 Official copy on file.\n",
      ],
}

In [33]:
chicago_mini_patterns = infer_heading_patterns(example_headings)
print(chicago_patterns)

{<Level.H1: 1>: HeadingPattern(level=<Level.H1: 1>, regex="r'^TITLE \\d+$'", multi_line=True), <Level.H2: 2>: HeadingPattern(level=<Level.H2: 2>, regex="r'^CHAPTER \\d+-\\d+$'", multi_line=True), <Level.H3: 3>: HeadingPattern(level=<Level.H3: 3>, regex="r'^\\d+-\\d+-\\d+ .+$'", multi_line=False)}


## [**Code**] Ingest municipal code into database & create embeddings

In [34]:
@marvin.fn
def infer_level_name(pattern: HeadingPattern) -> str:
    """Infer level names from the regular expressions in the patterns. For example,
    if pattern.regex is r'^Title \\d+$', the level name would be 'Title'. The name
    should be a string starting with a capital letter, followed by lowercase letters,
    with no punctuation. If there is no clear name (e.g., if the pattern is
    r'^\\d+\\-\\d+\\-\\d+'), return 'Section'.
    """

# shouldn't be necessary, but the LLM ignores the instructions about letters
def letters_only(s: str) -> str:
    return ''.join(c for c in s if c.isalpha())

def infer_level_names(patterns: dict[Level, HeadingPattern]) -> dict[Level, str]:
    return {k: letters_only(infer_level_name(v)) for k, v in patterns.items()}

infer_level_names(chicago_mini_patterns)

{<Level.H1: 1>: 'Title', <Level.H2: 2>: 'Chapter', <Level.H3: 3>: 'Section'}

In [35]:
@dataclass
class Jurisdiction:
    name: str
    patterns: dict[Level, HeadingPattern]
    source_local: str = ''
    source_url: str = ''
    raw_text: str = ''
    parser: StateMachineParser | None = None
    document: list[Segment] = field(default_factory=list)
    autoload: bool = True    

    def __post_init__(self):
        if self.autoload:
            self.load()

    def load(self):
        """Loads the text of local code from file (source_local)."""
        try:
            with open(self.source_local, "r") as f:
                self.raw_text = f.read()
        except FileNotFoundError as e:
            print(f"Error reading {self.source_local}: {e}")

# FIXME move to Action section below
chicago_mini = Jurisdiction(
    name="Chicago Mini",
    patterns=chicago_mini_patterns,
    source_local="../data/chicago-mini/code.txt",
    source_url="https://www.chicago.gov/city/en/depts/doit/supp_info/municipal_code.html",
)

In [36]:
chicago_mini.parser = StateMachineParser(document_name="Chicago Mini Code", heading_patterns=chicago_mini.patterns)
chicago_mini.document = chicago_mini.parser.parse(chicago_mini.raw_text)

In [39]:
len(chicago_mini.raw_text)

661150

In [None]:
from psycopg import connect
from muni.llm import create_embedding, summarize
from muni.structure import Node

RESET = False
EMBEDDING_LENGTH = len(create_embedding("test"))

def connection():
    return connect(
        dbname="regrag",
        host="localhost",
        port="5432",
        autocommit=True
    )

with connection() as conn:
    if RESET:
        with conn.cursor() as cursor:
            cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
            cursor.execute("DROP TABLE IF EXISTS muni_associations;")
            cursor.execute("DROP TABLE IF EXISTS muni CASCADE;")

    with conn.cursor() as cursor:
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS muni_associations (
                jurisdiction TEXT,
                association TEXT,
                left_id INTEGER,
                right_id INTEGER
            );
            """)
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS muni (
                id SERIAL PRIMARY KEY,
                jurisdiction TEXT,
                L1_ref TEXT, L1_heading TEXT,
                L2_ref TEXT, L2_heading TEXT,
                L3_ref TEXT, L3_heading TEXT,
                L4_ref TEXT, L4_heading TEXT,
                segment INTEGER,
                text TEXT,
                embedding VECTOR(%s)
            );
            """, (EMBEDDING_LENGTH,))
        cursor.execute(
            """
            ALTER TABLE muni
                ADD COLUMN IF NOT EXISTS textsearchable tsvector
                    GENERATED ALWAYS AS
                    (to_tsvector('english',
                        coalesce(jurisdiction, '') || ' ' ||
                        coalesce(L1_heading, '') || ' ' ||
                        coalesce(L2_heading, '') || ' ' ||
                        coalesce(L3_heading, '') || ' ' ||
                        coalesce(L4_heading, '') || ' ' ||
                        coalesce(text, '') || ' '))
                    STORED;
            """
        )
        cursor.execute(
            """
            DROP INDEX IF EXISTS muni_fulltext;
            CREATE INDEX muni_fulltext ON muni USING GIN (textsearchable);
            """
        )

def node_embedding(node: Node) -> list[float]:
    pre = '\n'.join(list(node.metadata['headings'].values()))
    summary = summarize(node.text)
    if summary is not None:
        embedding_text = pre + summary
    else:
        embedding_text = pre
    return create_embedding(embedding_text)

def upload(node: Node) -> None:
    if node.text:
        references = node.metadata['references']
        headings = node.metadata['headings']
        with connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(
                    """
                    INSERT INTO muni (
                        jurisdiction,
                        L1_ref, L1_heading,
                        L2_ref, L2_heading,
                        L3_ref, L3_heading,
                        L4_ref, L4_heading,
                        segment,
                        text,
                        embedding
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
                    """,
                    (
                        "Chicago",
                        references.get("title", ""),   headings.get("title", ""),   # L1
                        references.get("chapter", ""), headings.get("chapter", ""), # L2
                        references.get("article", ""), headings.get("article", ""), # L3
                        references.get("section", ""), headings.get("section", ""), # L3
                        0, # can add break-down segments later for large text blocks
                        node.text,
                        node_embedding(node),
                    )
                )
    
    if not node.children:
        return
    for child in node.children:
        upload(child)

## [**Action**] Upload municipal code

In [None]:
chicago = Jurisdiction(
    name="Chicago",
    hierarchy={
        "title":   r"TITLE \d+",
        "chapter": r"CHAPTER \d+-\d+",
        "article": r"ARTICLE [IVX]+\\.",
        "section": r"\d+-\d+-\d+",
    },
    source_local="../data/chicago/chicago.txt",
    source_url="https://www.chicago.gov/city/en/sites/covid-19/home.html",
)
chicago_tree = chicago.parse()

upload(chicago_tree)

## [**Code**] Find associations among sections

In [None]:
# Go through rows in the muni database and identify definitions

from muni.llm import definition, analyze_context

sql_select = """
    SELECT  id,
        L1_ref, L1_heading,
        L2_ref, L2_heading,
        L3_ref, L3_heading,
        L4_ref, L4_heading,
        text
    FROM muni;
    """

sql_unique = """
    BEGIN
        IF NOT EXISTS (
            SELECT FROM pg_constraint
            WHERE conname = 'unique_associations')
            AND   conrelid = 'muni_associations'::regclass
        ) 
        THEN
            ALTER TABLE muni_associations
            ADD CONSTRAINT unique_associations UNIQUE (jurisdiction, association, left_id, right_id);
        END IF;
    END;
    """

sql_assoc = """
    INSERT INTO muni_associations (jurisdiction, association, left_id, right_id)
    VALUES (%s, %s, %s, %s)
    ON CONFLICT (jurisdiction, association, left_id, right_id) DO NOTHING;
    """

def scope_map(scope):
    """For a given scope, what are the columns in muni that need to match?"""
    table = {'global': ['jurisdiction'],
             'title': ['jurisdiction', 'L1_ref'],
             'chapter': ['jurisdiction', 'L1_ref', 'L2_ref'],
             'article': ['jurisdiction', 'L1_ref', 'L2_ref', 'L3_ref'],
             'section': ['jurisdiction', 'L1_ref', 'L2_ref', 'L3_ref', 'L4_ref']
             }
    if scope not in table.keys():
        return None
    return table[scope]

In [None]:

def set_associations(conn, id_, scope, context_type):
    """Set associations with a row in muni with all rows matching the scope.
    Args:
        conn: a connection to the database
        id_: the id of the row to associate
        scope: the scope of the association (e.g. 'title', 'chapter', 'article', 'section')
        context_type: the type of association (e.g. 'definition')
    """
    with conn.cursor() as cursor:
        # get the jurisdiction and the references
        cursor.execute(f"SELECT jurisdiction, L1_ref, L2_ref, L3_ref, L4_ref FROM muni WHERE id = {id_}")
        jurisdiction, L1_ref, L2_ref, L3_ref, L4_ref = cursor.fetchone()
        # get the columns that need to match
        columns = scope_map(scope)
        if not columns:
            return
        # get the rows that match the scope
        match_str = ' AND '.join([f"{col} = '{val}'" for col, val in zip(columns, [jurisdiction, L1_ref, L2_ref, L3_ref, L4_ref])])
        cursor.execute(f"SELECT id FROM muni WHERE {match_str} AND id != {id_}")
        rows = cursor.fetchall()
        # set the associations
        for row in rows:
            cursor.execute(sql_assoc, (jurisdiction, context_type, id_, row[0]))

def find_associations(conn):
    allowed_types = ['penalty', 'definition', 'interpretation', 'date']
    with conn.cursor() as cursor:
        cursor.execute(sql_select)
        rows = cursor.fetchall()
        for row in rows:
            id_, L1_ref, L1_heading, L2_ref, L2_heading, L3_ref, L3_heading, L4_ref, L4_heading, text = row
            headings = {'title': L1_heading, 'chapter': L2_heading, 'article': L3_heading, 'section': L4_heading}
            r = analyze_context(text, headings, model='gpt-4')
            if r:
                context_type, scope = r
                if context_type in allowed_types:
                    print(f"* Setting associations for id {id_}")
                    print(f"  Context type: {context_type}; Scope: {scope}")
                    print("  --> %s ..." % text[:80].replace('\n', ' '))
                    set_associations(conn, id_, scope, context_type)

## [**Code**] Hybrid search

In [None]:
def simple_semantic_query(conn, query, limit=10):
    query_embedding = create_embedding(query)
    with conn.cursor() as cursor:
        sql = """
        SELECT id, L4_heading, text
        FROM muni
        WHERE jurisdiction = 'Chicago'
        ORDER BY embedding <=> %s
        LIMIT %s;
        """
        cursor.execute(sql, (str(query_embedding), limit))
        return cursor.fetchall()
        
#with connection() as conn:        
#    results = simple_semantic_query(conn, 'drug paraphernalia')
#for r in results:
#    print(r)

In [None]:
def simple_full_text_query(conn, query, limit=10):
    with conn.cursor() as cursor:
        sql = """
        WITH tsq AS (
            SELECT to_tsquery('english', %s) AS search
            )
        SELECT id, L4_heading, text
        FROM muni, tsq
        WHERE jurisdiction = 'Chicago'
        AND textsearchable @@ tsq.search
        ORDER BY ts_rank_cd(textsearchable, tsq.search)
        LIMIT %s;
        """
        cursor.execute(sql, (query, limit))
        return cursor.fetchall()

#with connection() as conn:        
#    results = simple_full_text_query(conn, 'drug & paraphernalia')
#for r in results:
#     print(r)

In [None]:
# Now we do a more complicated hybrid search, borrowing and adapting from 
# https://github.com/pgvector/pgvector-python/blob/master/examples/hybrid_search_rrf.py

def hybrid_query(conn, query, limit=10):
    embedding = create_embedding(query)

    sql = """
    WITH semantic_search AS (
        SELECT id, L4_heading, RANK () OVER (ORDER BY embedding <=> %(embedding)s) AS rank
        FROM muni
        ORDER BY embedding <=> %(embedding)s
        LIMIT 20
    ),
    keyword_search AS (
        SELECT id, L4_heading, RANK () OVER (ORDER BY ts_rank_cd(textsearchable, query) DESC)
        FROM muni, plainto_tsquery('english', %(query)s) query
        WHERE textsearchable @@ query
        ORDER BY ts_rank_cd(textsearchable, query) DESC
        LIMIT 20
    )
    SELECT
        COALESCE(semantic_search.id, keyword_search.id) AS id,
        COALESCE(1.0 / (%(k)s + semantic_search.rank), 0.0) +
        COALESCE(1.0 / (%(k)s + keyword_search.rank), 0.0) AS score,
        COALESCE(semantic_search.L4_heading, keyword_search.L4_heading) AS L4_heading
    FROM semantic_search
    FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
    ORDER BY score DESC
    LIMIT %(limit)s;
    """
    result = conn.execute(sql, {'query': query, 'embedding': str(embedding), 'limit': limit, 'k': 60})
    return result.fetchall()

#with connection() as conn:
#    results = hybrid_query(conn, 'drug paraphernalia')

#for row in results:
#    print(row)

In [None]:
# Try query augmentation using Hyde (generation of synthetic replies matching the
# format of expected answers)
from muni.llm import augmented_embedding

def augmented_query(conn, query, limit=10):
    query_embedding = augmented_embedding(query, orig_weight = 0.5)
    with conn.cursor() as cursor:
        sql = """
        SELECT id, L4_heading, text
        FROM muni
        WHERE jurisdiction = 'Chicago'
        ORDER BY embedding <=> %s
        LIMIT %s;
        """
        cursor.execute(sql, (str(query_embedding), limit))
        return cursor.fetchall()
        
#with connection() as conn:        
#    results = augmented_query(conn, 'Does the code restrict drug paraphernalia')
#for r in results:
#    print(r)