Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add merge_insert to the node and rust APIs #915

Merged
merged 10 commits into from Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions Cargo.toml
Expand Up @@ -11,10 +11,10 @@ license = "Apache-2.0"
repository = "https://github.com/lancedb/lancedb"

[workspace.dependencies]
lance = { "version" = "=0.9.10", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.10" }
lance-linalg = { "version" = "=0.9.10" }
lance-testing = { "version" = "=0.9.10" }
lance = { "version" = "=0.9.12", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.12" }
lance-linalg = { "version" = "=0.9.12" }
lance-testing = { "version" = "=0.9.12" }
# Note that this one does not include pyarrow
arrow = { version = "50.0", optional = false }
arrow-array = "50.0"
Expand Down
95 changes: 95 additions & 0 deletions node/src/index.ts
Expand Up @@ -37,6 +37,7 @@ const {
tableCountRows,
tableDelete,
tableUpdate,
tableMergeInsert,
tableCleanupOldVersions,
tableCompactFiles,
tableListIndices,
Expand Down Expand Up @@ -440,6 +441,38 @@ export interface Table<T = number[]> {
*/
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>

/**
* Runs a "merge insert" operation on the table.
*
* This operation can add rows, update rows, and remove rows all in a single
* transaction. It is a very generic tool that can be used to create
* behaviors like "insert if not exists", "update or insert (i.e. upsert)",
* or even replace a portion of existing data with new data (e.g. replace
* all data where month="january").
*
* The merge insert operation works by combining new data from a
* **source table** with existing data in a **target table** by using a
* join. There are three categories of records:
*
* - "Matched" records exist in both the source table and the target table.
* - "Not matched" records exist only in the source table (i.e. they are
*   new data).
* - "Not matched by source" records exist only in the target table (i.e.
*   they are old data).
*
* The MergeInsertArgs can be used to customize what should happen for
* each category of data (`whenMatchedUpdateAll`, `whenNotMatchedInsertAll`
* and `whenNotMatchedBySourceDelete` respectively).
*
* Please note that the data may appear to be reordered as part of this
* operation. This is because updated rows will be deleted from the
* dataset and then reinserted at the end with the new values.
*
* @param on a column to join on. This is how records from the source
*           table and target table are matched.
* @param data the new data to insert (the source table), given either as
*             an array of records or as an Arrow table
* @param args parameters controlling how the operation should behave
* @returns a promise that resolves once the operation has completed
*/
mergeInsert: (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs) => Promise<void>

/**
* List the indices on this table.
*/
Expand Down Expand Up @@ -483,6 +516,36 @@ export interface UpdateSqlArgs {
valuesSql: Record<string, string>
}

/**
* Options controlling what a merge insert does with each category of
* record ("matched", "not matched" and "not matched by source").
*
* Every field is optional; a field that is omitted leaves the
* corresponding category of rows untouched.
*/
export interface MergeInsertArgs {
/**
* If true then rows that exist in both the source table (new data) and
* the target table (old data) will be updated, replacing the old row
* with the corresponding matching row.
*
* If there are multiple matches then the behavior is undefined.
* Currently this causes multiple copies of the row to be created
* but that behavior is subject to change.
*
* Treated as false when omitted.
*/
whenMatchedUpdateAll?: boolean
/**
* If true then rows that exist only in the source table (new data)
* will be inserted into the target table.
*
* Treated as false when omitted.
*/
whenNotMatchedInsertAll?: boolean
/**
* If true then rows that exist only in the target table (old data)
* will be deleted.
*
* If this is a string then it will be treated as an SQL filter and
* only rows that both do not match any row in the source table and
* match the given filter will be deleted.
*
* This can be used to replace a selection of existing data with
* new data.
*
* Omit this field to leave rows that exist only in the target table
* untouched.
*/
whenNotMatchedBySourceDelete?: string | boolean
}

export interface VectorIndex {
columns: string[]
name: string
Expand Down Expand Up @@ -821,6 +884,38 @@ export class LocalTable<T = number[]> implements Table<T> {
})
}

/**
 * Runs a "merge insert" operation against this local table.
 *
 * Array input is first converted to an Arrow table using this table's
 * schema; the data is then serialized to a buffer and handed to the native
 * `tableMergeInsert` binding. The table handle returned by the binding
 * replaces the current one.
 *
 * @param on the column used to match source rows against target rows
 * @param data the new (source) data, as records or as an Arrow table
 * @param args which actions to take for matched / not matched / not
 *             matched by source rows; omitted flags are treated as false
 */
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
  const whenMatchedUpdateAll = args.whenMatchedUpdateAll ?? false
  const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false

  // Deletion is only enabled by `true` (delete every unmatched target row)
  // or by a string (delete unmatched target rows matching that SQL filter).
  // An explicit `false`, like `null`/`undefined`, must leave it disabled;
  // previously `false` enabled deletion and was forwarded as the filter.
  const deleteOption = args.whenNotMatchedBySourceDelete
  const whenNotMatchedBySourceDelete = deleteOption === true || typeof deleteOption === 'string'
  const whenNotMatchedBySourceDeleteFilt: string | null = typeof deleteOption === 'string' ? deleteOption : null

  const schema = await this.schema
  let tbl: ArrowTable
  if (data instanceof ArrowTable) {
    tbl = data
  } else {
    tbl = makeArrowTable(data, { schema })
  }
  const buffer = await fromTableToBuffer(tbl, this._embeddings, schema)

  // The native call returns the updated table handle.
  this._tbl = await tableMergeInsert.call(
    this._tbl,
    on,
    whenMatchedUpdateAll,
    whenNotMatchedInsertAll,
    whenNotMatchedBySourceDelete,
    whenNotMatchedBySourceDeleteFilt,
    buffer
  )
}

/**
* Clean up old versions of the table, freeing disk space.
*
Expand Down
49 changes: 48 additions & 1 deletion node/src/remote/index.ts
Expand Up @@ -24,7 +24,8 @@ import {
type IndexStats,
type UpdateArgs,
type UpdateSqlArgs,
makeArrowTable
makeArrowTable,
type MergeInsertArgs
} from '../index'
import { Query } from '../query'

Expand Down Expand Up @@ -274,6 +275,52 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error('Not implemented')
}

/**
 * Runs a "merge insert" operation against the remote table.
 *
 * The merge options are sent as string query parameters and the new data is
 * sent as an Arrow stream in the request body. Any response status other
 * than 200 is reported as an error.
 *
 * @param on the column used to match source rows against target rows
 * @param data the new (source) data, as records or as an Arrow table
 * @param args which actions to take for each category of rows
 */
async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
  const tbl: ArrowTable = data instanceof ArrowTable
    ? data
    : makeArrowTable(data, await this.schema)

  // Anything other than false / null / undefined switches deletion on.
  const deleteOpt = args.whenNotMatchedBySourceDelete
  const deleteEnabled = !(deleteOpt === false || deleteOpt === null || deleteOpt === undefined)

  const queryParams: any = {
    on,
    when_matched_update_all: (args.whenMatchedUpdateAll ?? false) ? 'true' : 'false',
    when_not_matched_insert_all: (args.whenNotMatchedInsertAll ?? false) ? 'true' : 'false',
    when_not_matched_by_source_delete: deleteEnabled ? 'true' : 'false'
  }
  // A string value doubles as the SQL filter restricting which rows are deleted.
  if (deleteEnabled && typeof deleteOpt === 'string') {
    queryParams.when_not_matched_by_source_delete_filt = deleteOpt
  }

  const payload = await fromTableToStreamBuffer(tbl, this._embeddings)
  const res = await this._client.post(
    `/v1/table/${this._name}/merge_insert/`,
    payload,
    queryParams,
    'application/vnd.apache.arrow.stream'
  )
  if (res.status === 200) {
    return
  }
  throw new Error(
    `Server Error, status: ${res.status}, ` +
    // eslint-disable-next-line @typescript-eslint/restrict-template-expressions
    `message: ${res.statusText}: ${res.data}`
  )
}

async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
let tbl: ArrowTable
if (data instanceof ArrowTable) {
Expand Down
38 changes: 38 additions & 0 deletions node/src/test/test.ts
Expand Up @@ -531,6 +531,44 @@ describe('LanceDB client', function () {
assert.equal(await table.countRows(), 2)
})

// Exercises mergeInsert end to end on a local table: insert-only, upsert,
// filtered delete of unmatched target rows, and unconditional delete.
it('can merge insert records into the table', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)

const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
const table = await con.createTable('my_table', data)

// Insert only: id 3 is new and gets added; id 2 already exists and is
// left as-is, so exactly one row ends up with age = 2.
let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true
})
assert.equal(await table.countRows(), 3)
assert.equal((await table.filter('age = 2').execute()).length, 1)

// Upsert: id 3 is updated to age 3 and id 4 is inserted.
newData = [{ id: 3, age: 3 }, { id: 4, age: 3 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true
})
assert.equal(await table.countRows(), 4)
assert.equal((await table.filter('age = 3').execute()).length, 2)

// Upsert plus filtered delete: id 5 is inserted, and target rows absent
// from the source with age < 3 (ids 1 and 2) are removed, leaving 3, 4, 5.
newData = [{ id: 5, age: 4 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true,
whenNotMatchedBySourceDelete: 'age < 3'
})
assert.equal(await table.countRows(), 3)

// Unconditional delete: every target row absent from the source (ids 3
// and 4) is removed, leaving only id 5.
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true,
whenNotMatchedBySourceDelete: true
})
assert.equal(await table.countRows(), 1)
})

it('can update records in the table', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
Expand Down
25 changes: 21 additions & 4 deletions python/lancedb/merge.py
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, Optional
from typing import TYPE_CHECKING, List, Optional

if TYPE_CHECKING:
from .common import DATA
Expand All @@ -25,7 +25,7 @@ class LanceMergeInsertBuilder(object):
more context
"""

def __init__(self, table: "Table", on: Iterable[str]): # noqa: F821
def __init__(self, table: "Table", on: List[str]): # noqa: F821
# Do not put a docstring here. This method should be hidden
# from API docs. Users should use merge_insert to create
# this object.
Expand Down Expand Up @@ -77,10 +77,27 @@ def when_not_matched_by_source_delete(
self._when_not_matched_by_source_condition = condition
return self

def execute(
    self,
    new_data: DATA,
    on_bad_vectors: str = "error",
    fill_value: float = 0.0,
):
    """
    Executes the merge insert operation

    Nothing is returned but the [`Table`][lancedb.table.Table] is updated

    Parameters
    ----------
    new_data: DATA
        New records which will be matched against the existing records
        to potentially insert or update into the table.  This parameter
        can be anything you use for [`add`][lancedb.table.Table.add]
    on_bad_vectors: str, default "error"
        What to do if any of the vectors are not the same size or contain
        NaNs. One of "error", "drop", "fill".
    fill_value: float, default 0.
        The value to use when filling vectors. Only used if
        on_bad_vectors="fill".
    """
    # Delegate to the table exactly once, forwarding the vector-handling
    # options along with this builder and the new data.
    self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
44 changes: 41 additions & 3 deletions python/lancedb/remote/table.py
Expand Up @@ -19,6 +19,7 @@
from lance import json_to_schema

from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder

from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data
Expand Down Expand Up @@ -244,9 +245,46 @@ def _execute_query(self, query: Query) -> pa.Table:
result = self._conn._client.query(self._name, query)
return result.to_arrow()

def _do_merge(self, *_args):
"""_do_merge() is not supported on the LanceDB cloud yet"""
return NotImplementedError("_do_merge() is not supported on the LanceDB cloud")
def _do_merge(
    self,
    merge: LanceMergeInsertBuilder,
    new_data: DATA,
    on_bad_vectors: str,
    fill_value: float,
):
    """Run a merge insert against the remote table.

    The new data is sanitized against this table's schema, serialized with
    ``to_ipc_binary`` and posted to the table's ``merge_insert`` endpoint.
    The merge options are sent as query parameters.

    Parameters
    ----------
    merge: LanceMergeInsertBuilder
        The builder describing the join key and which actions to take.
    new_data: DATA
        The new records to merge into the table.
    on_bad_vectors: str
        Forwarded to ``_sanitize_data``.
    fill_value: float
        Forwarded to ``_sanitize_data``.

    Raises
    ------
    ValueError
        If the builder does not specify exactly one ``on`` key.
    """
    # Validate the join key first so that an unsupported request fails
    # before the new data is sanitized and serialized.
    if len(merge._on) != 1:
        raise ValueError(
            "RemoteTable only supports a single on key in merge_insert"
        )

    data = _sanitize_data(
        new_data,
        self.schema,
        metadata=None,
        on_bad_vectors=on_bad_vectors,
        fill_value=fill_value,
    )
    payload = to_ipc_binary(data)

    # Flags are sent as lower-cased strings (presumably "true"/"false";
    # assumes the builder stores them as bools).
    params = {
        "on": merge._on[0],
        "when_matched_update_all": str(merge._when_matched_update_all).lower(),
        "when_not_matched_insert_all": str(
            merge._when_not_matched_insert_all
        ).lower(),
        "when_not_matched_by_source_delete": str(
            merge._when_not_matched_by_source_delete
        ).lower(),
    }
    # The optional filter is only sent when a condition was provided.
    if merge._when_not_matched_by_source_condition is not None:
        params[
            "when_not_matched_by_source_delete_filt"
        ] = merge._when_not_matched_by_source_condition

    self._conn._client.post(
        f"/v1/table/{self._name}/merge_insert/",
        data=payload,
        params=params,
        content_type=ARROW_STREAM_CONTENT_TYPE,
    )

def delete(self, predicate: str):
"""Delete rows from the table.
Expand Down