Skip to content

Commit

Permalink
feat: add merge_insert to the node and rust APIs (#915)
Browse files Browse the repository at this point in the history
  • Loading branch information
westonpace committed Feb 2, 2024
1 parent 09cd082 commit 7f8637a
Show file tree
Hide file tree
Showing 11 changed files with 565 additions and 18 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Expand Up @@ -11,10 +11,10 @@ license = "Apache-2.0"
repository = "https://github.com/lancedb/lancedb"

[workspace.dependencies]
lance = { "version" = "=0.9.10", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.10" }
lance-linalg = { "version" = "=0.9.10" }
lance-testing = { "version" = "=0.9.10" }
lance = { "version" = "=0.9.12", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.9.12" }
lance-linalg = { "version" = "=0.9.12" }
lance-testing = { "version" = "=0.9.12" }
# Note that this one does not include pyarrow
arrow = { version = "50.0", optional = false }
arrow-array = "50.0"
Expand Down
95 changes: 95 additions & 0 deletions node/src/index.ts
Expand Up @@ -37,6 +37,7 @@ const {
tableCountRows,
tableDelete,
tableUpdate,
tableMergeInsert,
tableCleanupOldVersions,
tableCompactFiles,
tableListIndices,
Expand Down Expand Up @@ -440,6 +441,38 @@ export interface Table<T = number[]> {
*/
update: (args: UpdateArgs | UpdateSqlArgs) => Promise<void>

/**
* Runs a "merge insert" operation on the table
*
* This operation can add rows, update rows, and remove rows all in a single
* transaction. It is a very generic tool that can be used to create
* behaviors like "insert if not exists", "update or insert (i.e. upsert)",
* or even replace a portion of existing data with new data (e.g. replace
* all data where month="january")
*
* The merge insert operation works by combining new data from a
* **source table** with existing data in a **target table** by using a
* join. There are three categories of records.
*
* "Matched" records are records that exist in both the source table and
* the target table. "Not matched" records exist only in the source table
* (e.g. these are new data) "Not matched by source" records exist only
* in the target table (this is old data)
*
* The MergeInsertArgs can be used to customize what should happen for
* each category of data.
*
* Please note that the data may appear to be reordered as part of this
* operation. This is because updated rows will be deleted from the
* dataset and then reinserted at the end with the new values.
*
* @param on a column to join on. This is how records from the source
* table and target table are matched.
* @param data the new data to insert
* @param args parameters controlling how the operation should behave
*/
mergeInsert: (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs) => Promise<void>

/**
* List the indicies on this table.
*/
Expand Down Expand Up @@ -483,6 +516,36 @@ export interface UpdateSqlArgs {
valuesSql: Record<string, string>
}

export interface MergeInsertArgs {
/**
* If true then rows that exist in both the source table (new data) and
* the target table (old data) will be updated, replacing the old row
* with the corresponding matching row.
*
* If there are multiple matches then the behavior is undefined.
* Currently this causes multiple copies of the row to be created
* but that behavior is subject to change.
*/
whenMatchedUpdateAll?: boolean
/**
* If true then rows that exist only in the source table (new data)
* will be inserted into the target table.
*/
whenNotMatchedInsertAll?: boolean
/**
* If true then rows that exist only in the target table (old data)
* will be deleted.
*
* If this is a string then it will be treated as an SQL filter and
* only rows that both do not match any row in the source table and
* match the given filter will be deleted.
*
* This can be used to replace a selection of existing data with
* new data.
*/
whenNotMatchedBySourceDelete?: string | boolean
}

export interface VectorIndex {
columns: string[]
name: string
Expand Down Expand Up @@ -821,6 +884,38 @@ export class LocalTable<T = number[]> implements Table<T> {
})
}

async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
const whenMatchedUpdateAll = args.whenMatchedUpdateAll ?? false
const whenNotMatchedInsertAll = args.whenNotMatchedInsertAll ?? false
let whenNotMatchedBySourceDelete = false
let whenNotMatchedBySourceDeleteFilt = null
if (args.whenNotMatchedBySourceDelete !== undefined && args.whenNotMatchedBySourceDelete !== null) {
whenNotMatchedBySourceDelete = true
if (args.whenNotMatchedBySourceDelete !== true) {
whenNotMatchedBySourceDeleteFilt = args.whenNotMatchedBySourceDelete
}
}

const schema = await this.schema
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data, { schema })
}
const buffer = await fromTableToBuffer(tbl, this._embeddings, schema)

this._tbl = await tableMergeInsert.call(
this._tbl,
on,
whenMatchedUpdateAll,
whenNotMatchedInsertAll,
whenNotMatchedBySourceDelete,
whenNotMatchedBySourceDeleteFilt,
buffer
)
}

/**
* Clean up old versions of the table, freeing disk space.
*
Expand Down
49 changes: 48 additions & 1 deletion node/src/remote/index.ts
Expand Up @@ -24,7 +24,8 @@ import {
type IndexStats,
type UpdateArgs,
type UpdateSqlArgs,
makeArrowTable
makeArrowTable,
type MergeInsertArgs
} from '../index'
import { Query } from '../query'

Expand Down Expand Up @@ -274,6 +275,52 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error('Not implemented')
}

async mergeInsert (on: string, data: Array<Record<string, unknown>> | ArrowTable, args: MergeInsertArgs): Promise<void> {
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data, await this.schema)
}

const queryParams: any = {
on
}
if (args.whenMatchedUpdateAll ?? false) {
queryParams.when_matched_update_all = 'true'
} else {
queryParams.when_matched_update_all = 'false'
}
if (args.whenNotMatchedInsertAll ?? false) {
queryParams.when_not_matched_insert_all = 'true'
} else {
queryParams.when_not_matched_insert_all = 'false'
}
if (args.whenNotMatchedBySourceDelete !== false && args.whenNotMatchedBySourceDelete !== null && args.whenNotMatchedBySourceDelete !== undefined) {
queryParams.when_not_matched_by_source_delete = 'true'
if (typeof args.whenNotMatchedBySourceDelete === 'string') {
queryParams.when_not_matched_by_source_delete_filt = args.whenNotMatchedBySourceDelete
}
} else {
queryParams.when_not_matched_by_source_delete = 'false'
}

const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
const res = await this._client.post(
`/v1/table/${this._name}/merge_insert/`,
buffer,
queryParams,
'application/vnd.apache.arrow.stream'
)
if (res.status !== 200) {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
}

async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
let tbl: ArrowTable
if (data instanceof ArrowTable) {
Expand Down
38 changes: 38 additions & 0 deletions node/src/test/test.ts
Expand Up @@ -531,6 +531,44 @@ describe('LanceDB client', function () {
assert.equal(await table.countRows(), 2)
})

it('can merge insert records into the table', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)

const data = [{ id: 1, age: 1 }, { id: 2, age: 1 }]
const table = await con.createTable('my_table', data)

let newData = [{ id: 2, age: 2 }, { id: 3, age: 2 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true
})
assert.equal(await table.countRows(), 3)
assert.equal((await table.filter('age = 2').execute()).length, 1)

newData = [{ id: 3, age: 3 }, { id: 4, age: 3 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true
})
assert.equal(await table.countRows(), 4)
assert.equal((await table.filter('age = 3').execute()).length, 2)

newData = [{ id: 5, age: 4 }]
await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true,
whenNotMatchedBySourceDelete: 'age < 3'
})
assert.equal(await table.countRows(), 3)

await table.mergeInsert('id', newData, {
whenNotMatchedInsertAll: true,
whenMatchedUpdateAll: true,
whenNotMatchedBySourceDelete: true
})
assert.equal(await table.countRows(), 1)
})

it('can update records in the table', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)
Expand Down
25 changes: 21 additions & 4 deletions python/lancedb/merge.py
Expand Up @@ -12,7 +12,7 @@
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, Optional
from typing import TYPE_CHECKING, List, Optional

if TYPE_CHECKING:
from .common import DATA
Expand All @@ -25,7 +25,7 @@ class LanceMergeInsertBuilder(object):
more context
"""

def __init__(self, table: "Table", on: Iterable[str]): # noqa: F821
def __init__(self, table: "Table", on: List[str]): # noqa: F821
# Do not put a docstring here. This method should be hidden
# from API docs. Users should use merge_insert to create
# this object.
Expand Down Expand Up @@ -77,10 +77,27 @@ def when_not_matched_by_source_delete(
self._when_not_matched_by_source_condition = condition
return self

def execute(self, new_data: DATA):
def execute(
self,
new_data: DATA,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
"""
Executes the merge insert operation
Nothing is returned but the [`Table`][lancedb.table.Table] is updated
Parameters
----------
new_data: DATA
New records which will be matched against the existing records
to potentially insert or update into the table. This parameter
can be anything you use for [`add`][lancedb.table.Table.add]
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
self._table._do_merge(self, new_data)
self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
44 changes: 41 additions & 3 deletions python/lancedb/remote/table.py
Expand Up @@ -19,6 +19,7 @@
from lance import json_to_schema

from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder

from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data
Expand Down Expand Up @@ -244,9 +245,46 @@ def _execute_query(self, query: Query) -> pa.Table:
result = self._conn._client.query(self._name, query)
return result.to_arrow()

def _do_merge(self, *_args):
"""_do_merge() is not supported on the LanceDB cloud yet"""
return NotImplementedError("_do_merge() is not supported on the LanceDB cloud")
def _do_merge(
self,
merge: LanceMergeInsertBuilder,
new_data: DATA,
on_bad_vectors: str,
fill_value: float,
):
data = _sanitize_data(
new_data,
self.schema,
metadata=None,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
payload = to_ipc_binary(data)

params = {}
if len(merge._on) != 1:
raise ValueError(
"RemoteTable only supports a single on key in merge_insert"
)
params["on"] = merge._on[0]
params["when_matched_update_all"] = str(merge._when_matched_update_all).lower()
params["when_not_matched_insert_all"] = str(
merge._when_not_matched_insert_all
).lower()
params["when_not_matched_by_source_delete"] = str(
merge._when_not_matched_by_source_delete
).lower()
if merge._when_not_matched_by_source_condition is not None:
params[
"when_not_matched_by_source_delete_filt"
] = merge._when_not_matched_by_source_condition

self._conn._client.post(
f"/v1/table/{self._name}/merge_insert/",
data=payload,
params=params,
content_type=ARROW_STREAM_CONTENT_TYPE,
)

def delete(self, predicate: str):
"""Delete rows from the table.
Expand Down

0 comments on commit 7f8637a

Please sign in to comment.