
Binder Exception when reading ARRAY type from Parquet (for embeddings) #3481

Open

prrao87 opened this issue May 13, 2024 · 3 comments

Labels: usability (Issues related to better usability experience, including bad error messages)

prrao87 (Member) commented May 13, 2024
prrao87 commented May 13, 2024

I'm pushing typecasting to the limit here 😅. I'm basically trying to ensure I have fine-grained control over each column's data type all the way from Python to Parquet and then into Kùzu. My aim is to replicate a typical Python workflow used in similarity search.

Dependencies

I'll be using sentence-transformers to generate embeddings from raw text, as is common in many real-world scenarios.

pip install pyarrow polars kuzu sentence-transformers

Code

I first write out the vectors computed by an embedding model (from sentence-transformers) alongside the raw data to Parquet, so that I can bulk-import the data to Kùzu (computing vectors/embeddings via a model is expensive, so this would be pre-computed in a real scenario).

Note that I explicitly typecast the integers from Python (by default INT64) to UINT64 so that I can have unsigned integers in Kùzu per the graph schema below.

import os
import shutil
import kuzu


def create_db(conn):
    conn.execute(
        """
        CREATE NODE TABLE Person(
            id UINT64,
            name STRING,
            age UINT8,
            PRIMARY KEY (id)
        )
        """
    )

    conn.execute(
        """
        CREATE NODE TABLE Item(
            id UINT64,
            name STRING,
            vector DOUBLE[384],
            PRIMARY KEY (id)
        )
        """
    )

    conn.execute(
        """
        CREATE REL TABLE Purchased(
            FROM Person
            TO Item
        )
        """
    )


def write_data_to_parquet():
    import warnings
    import polars as pl
    from sentence_transformers import SentenceTransformer
    warnings.filterwarnings("ignore")

    model = SentenceTransformer("Snowflake/snowflake-arctic-embed-xs")

    persons = [
        {"id": 1, "name": "Karissa", "age": 25},
        {"id": 2, "name": "Zhang", "age": 29},
        {"id": 3, "name": "Noura", "age": 31},
    ]

    items = [
        {"id": 1, "name": "espresso machine", "vector": list(model.encode("espresso machine"))},
        {"id": 2, "name": "yoga mat", "vector": list(model.encode("yoga mat"))},
    ]

    purchased = [
        {"from": 1, "to": 1},
        {"from": 1, "to": 2},
        {"from": 2, "to": 1},
        {"from": 3, "to": 2},
    ]
    # Carefully typecast in Polars prior to exporting to Parquet so we can have unsigned integers in Kùzu
    df_persons = pl.DataFrame(persons).with_columns(
        pl.col("id").cast(pl.UInt64),
        pl.col("age").cast(pl.UInt8),
    )
    # Ensure that the `ARRAY` data type is output for the `vector` column prior to exporting to Parquet
    df_items = pl.DataFrame(items).with_columns(
        pl.col("id").cast(pl.UInt64),
        pl.col("vector").cast(pl.Array(pl.Float64, width=384)),
    )
    df_purchased = pl.DataFrame(purchased).with_columns(
        pl.col("from").cast(pl.UInt64),
        pl.col("to").cast(pl.UInt64),
    )
    print(df_persons)
    print(df_items)
    df_persons.write_parquet("persons.parquet")
    df_items.write_parquet("items.parquet")
    df_purchased.write_parquet("purchased.parquet")


def build_graph(conn):
    conn.execute(
        """
        COPY Person FROM 'persons.parquet';
        COPY Item FROM 'items.parquet';
        COPY Purchased FROM 'purchased.parquet';
        """
    )
    print("Finished importing nodes and rels")


if __name__ == "__main__":
    if os.path.exists("./vdb"):
        shutil.rmtree("./vdb")

    # Create database
    db = kuzu.Database("./vdb")
    conn = kuzu.Connection(db)
    create_db(conn)

    write_data_to_parquet()

    # Load data from parquet to graph
    build_graph(conn)

Error

Running the above code gives the following error:

Traceback (most recent call last):
  File "/Users/prrao/code/kuzu-debug/load_graph_similarity.py", line 107, in <module>
    build_graph(conn)
  File "/Users/prrao/code/kuzu-debug/load_graph_similarity.py", line 85, in build_graph
    conn.execute(
  File "/Users/prrao/code/kuzu-debug/.venv/lib/python3.11/site-packages/kuzu/connection.py", line 144, in execute
    raise RuntimeError(_query_result.getErrorMessage())
RuntimeError: Binder exception: Column `vector` type mismatch. Expected DOUBLE[384] but got DOUBLE[].

Workaround

The error disappears when I change the schema to declare the vector column as a LIST, i.e., vector DOUBLE[] instead of DOUBLE[384].
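For reference, the LIST-typed workaround changes only the `vector` column in the Item table's DDL (a sketch of the modified schema from the script above):

```cypher
CREATE NODE TABLE Item(
    id UINT64,
    name STRING,
    vector DOUBLE[],   // LIST instead of the fixed-size ARRAY type DOUBLE[384]
    PRIMARY KEY (id)
)
```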

Desired behaviour

If the node table stored ARRAY values (based on the fields imported from Parquet), we could easily run array functions for similarity search by passing a plain Python list: the similarity search functions require only one of their two arguments to be of type ARRAY in order to perform implicit casting on the other. This would make the downstream Cypher query that performs similarity search less verbose and cleaner to write.

Currently, I have to write the following query for it to work (it requires explicit casting, and the user has to know much more about the correct syntax):

res = conn.execute(
    """
    MATCH (i:Item)
    WITH i, CAST($query_vector, "DOUBLE[384]") AS query_vector
    RETURN i.name as name, array_cosine_similarity(i.vector, query_vector) AS similarity
    ORDER BY similarity DESC
    """,
    parameters={"query_vector": query_vector}
)

What I want to be able to write is the below simpler query that doesn't require explicit casting by the user:

res = conn.execute(
    """
    MATCH (i:Item)
    RETURN i.name as name, array_cosine_similarity(i.vector, $query_vector) AS similarity
    ORDER BY similarity DESC
    """,
    parameters={"query_vector": query_vector}
)
prrao87 added the usability label on May 13, 2024
acquamarin (Collaborator) commented:

df_items.write_parquet("items.parquet")
I think the write_parquet method of the Polars dataframe marks the vector field as double[] instead of double[384]. However, Kùzu should perform an implicit cast from double[] to double[384] while doing the COPY.
There is already a similar issue:
#2215


prrao87 (Member, Author) commented May 14, 2024

But doesn't this part address the type of the array and fix the width at 384?

    # Ensure that the `ARRAY` data type is output for the `vector` column prior to exporting to Parquet
    df_items = pl.DataFrame(items).with_columns(
        pl.col("id").cast(pl.UInt64),
        pl.col("vector").cast(pl.Array(pl.Float64, width=384)),
    )

I confirmed that this is the structure prior to export:

shape: (2, 3)
┌─────┬──────────────────┬─────────────────────────────────┐
│ id  ┆ name             ┆ vector                          │
│ --- ┆ ---              ┆ ---                             │
│ u64 ┆ str              ┆ array[f64, 384]                 │
╞═════╪══════════════════╪═════════════════════════════════╡
│ 1   ┆ espresso machine ┆ [0.051414, -0.015042, … -0.017… │
│ 2   ┆ yoga mat         ┆ [-0.018925, -0.034097, … 0.002… │
└─────┴──────────────────┴─────────────────────────────────┘

The underlying type is an Arrow array with a width of 384.

acquamarin (Collaborator) commented:

Yes; however, if you write that structure to Parquet, the column type of vector becomes double[] instead of double[384]. You can check the schema of the Parquet file using:

import pyarrow.parquet as pq

parquet_file = pq.ParquetFile("items.parquet")
print(parquet_file.schema)

and I got:

<pyarrow._parquet.ParquetSchema object at 0x8f818c040>
required group field_id=-1 root {
  optional int64 field_id=-1 id (Int(bitWidth=64, isSigned=false));
  optional binary field_id=-1 name (String);
  optional group field_id=-1 vector (List) {
    repeated group field_id=-1 list {
      optional double field_id=-1 item;
    }
  }
}
