Commit 56151bc

feature: add numeric column

densumesh committed May 23, 2024
1 parent 0859c46 commit 56151bc
Showing 10 changed files with 224 additions and 24 deletions.
3 changes: 3 additions & 0 deletions server/migrations/2024-05-23-184907_add-number-col/down.sql
@@ -0,0 +1,3 @@
-- This file should undo anything in `up.sql`
ALTER TABLE chunk_metadata DROP COLUMN num_value;
DROP INDEX IF EXISTS idx_num_val_chunk_metadata;
3 changes: 3 additions & 0 deletions server/migrations/2024-05-23-184907_add-number-col/up.sql
@@ -0,0 +1,3 @@
-- Your SQL goes here
ALTER TABLE chunk_metadata ADD COLUMN num_value float8;
CREATE INDEX idx_num_val_chunk_metadata ON chunk_metadata USING btree (num_value);
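The migration adds a nullable float8 column plus a btree index, so range filters on num_value can run as index scans rather than sequential scans. As a rough sketch of how the column could be queried through Diesel (not part of this commit; the helper name, module paths, and connection type are assumptions):

use diesel::prelude::*;
use diesel_async::{AsyncPgConnection, RunQueryDsl};

use crate::data::models::ChunkMetadata;
use crate::data::schema::chunk_metadata::dsl as chunk_metadata_columns;

// Hypothetical helper: load chunks whose num_value falls within [min, max].
// The btree index idx_num_val_chunk_metadata can serve this range scan.
pub async fn get_chunks_in_num_value_range(
    conn: &mut AsyncPgConnection,
    min: f64,
    max: f64,
) -> Result<Vec<ChunkMetadata>, diesel::result::Error> {
    chunk_metadata_columns::chunk_metadata
        .filter(chunk_metadata_columns::num_value.ge(min))
        .filter(chunk_metadata_columns::num_value.le(max))
        .load::<ChunkMetadata>(conn)
        .await
}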
2 changes: 2 additions & 0 deletions server/src/bin/ingestion-worker.rs
@@ -457,6 +457,7 @@ pub async fn bulk_upload_chunks(
.clone()
.map(|urls| urls.into_iter().map(Some).collect()),
tag_set_array: None,
num_value: message.chunk.num_value,
};

(
@@ -681,6 +682,7 @@ async fn upload_chunk(
.image_urls
.map(|urls| urls.into_iter().map(Some).collect()),
tag_set_array: None,
num_value: payload.chunk.num_value,
};

let embedding_vector = if let Some(embedding_vector) = payload.chunk.chunk_vector.clone() {
47 changes: 45 additions & 2 deletions server/src/data/models.rs
@@ -10,7 +10,7 @@ use super::schema::*;
use crate::handlers::file_handler::UploadFileData;
use crate::operators::search_operator::{
get_group_metadata_filter_condition, get_group_tag_set_filter_condition,
get_metadata_filter_condition,
get_metadata_filter_condition, get_num_value_filter_condition,
};
use actix_web::web;
use chrono::{DateTime, NaiveDateTime};
@@ -302,6 +302,7 @@ pub struct ChunkMetadata {
pub location: Option<GeoInfo>,
pub image_urls: Option<Vec<Option<String>>>,
pub tag_set_array: Option<Vec<Option<String>>>,
pub num_value: Option<f64>,
}

impl ChunkMetadata {
@@ -318,6 +319,7 @@ impl ChunkMetadata {
image_urls: Option<Vec<String>>,
dataset_id: uuid::Uuid,
weight: f64,
num_value: Option<f64>,
) -> Self {
ChunkMetadata {
id: uuid::Uuid::new_v4(),
@@ -335,6 +337,7 @@ impl ChunkMetadata {
weight,
image_urls: image_urls.map(|urls| urls.into_iter().map(Some).collect()),
tag_set_array: None,
num_value,
}
}
}
@@ -354,6 +357,7 @@ impl ChunkMetadata {
image_urls: Option<Vec<String>>,
dataset_id: uuid::Uuid,
weight: f64,
num_value: Option<f64>,
) -> Self {
ChunkMetadata {
id: id.into(),
@@ -371,6 +375,7 @@ impl ChunkMetadata {
weight,
image_urls: image_urls.map(|urls| urls.into_iter().map(Some).collect()),
tag_set_array: None,
num_value,
}
}
}
@@ -393,6 +398,7 @@ impl From<SlimChunkMetadata> for ChunkMetadata {
weight: slim_chunk.weight,
image_urls: slim_chunk.image_urls,
tag_set_array: None,
num_value: slim_chunk.num_value,
}
}
}
@@ -415,6 +421,7 @@ impl From<ContentChunkMetadata> for ChunkMetadata {
weight: content_chunk.weight,
image_urls: content_chunk.image_urls,
tag_set_array: None,
num_value: content_chunk.num_value,
}
}
}
@@ -665,6 +672,7 @@ pub struct SlimChunkMetadata {
pub dataset_id: uuid::Uuid,
pub weight: f64,
pub image_urls: Option<Vec<Option<String>>>,
pub num_value: Option<f64>,
}

impl From<ChunkMetadata> for SlimChunkMetadata {
@@ -683,6 +691,7 @@ impl From<ChunkMetadata> for SlimChunkMetadata {
dataset_id: chunk.dataset_id,
weight: chunk.weight,
image_urls: chunk.image_urls,
num_value: chunk.num_value,
}
}
}
@@ -703,6 +712,7 @@ impl From<ContentChunkMetadata> for SlimChunkMetadata {
dataset_id: uuid::Uuid::new_v4(),
weight: chunk.weight,
image_urls: chunk.image_urls,
num_value: chunk.num_value,
}
}
}
@@ -729,6 +739,7 @@ pub struct ContentChunkMetadata {
pub time_stamp: Option<NaiveDateTime>,
pub weight: f64,
pub image_urls: Option<Vec<Option<String>>>,
pub num_value: Option<f64>,
}

impl From<ChunkMetadata> for ContentChunkMetadata {
@@ -741,6 +752,7 @@ impl From<ChunkMetadata> for ContentChunkMetadata {
time_stamp: chunk.time_stamp,
weight: chunk.weight,
image_urls: chunk.image_urls,
num_value: chunk.num_value,
}
}
}
@@ -2161,6 +2173,7 @@ pub struct QdrantPayload {
pub content: String,
pub group_ids: Option<Vec<uuid::Uuid>>,
pub location: Option<GeoInfo>,
pub num_value: Option<f64>,
}

impl From<QdrantPayload> for Payload {
@@ -2189,6 +2202,7 @@ impl QdrantPayload {
content: convert_html_to_text(&chunk_metadata.chunk_html.unwrap_or_default()),
group_ids,
location: chunk_metadata.location,
num_value: chunk_metadata.num_value,
}
}

@@ -2235,6 +2249,11 @@ impl QdrantPayload {
serde_json::from_value(value.into()).expect("Failed to parse location")
})
.unwrap_or_default(),
num_value: point
.payload
.get("num_value")
.cloned()
.map(|x| x.as_double().expect("num_value should be a float")),
}
}
}
@@ -2294,6 +2313,11 @@ impl From<RetrievedPoint> for QdrantPayload {
serde_json::from_value(value.into()).expect("Failed to parse location")
})
.unwrap_or_default(),
num_value: point
.payload
.get("num_value")
.cloned()
.map(|x| x.as_double().expect("num_value should be a float")),
}
}
}
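Both payload converters above read num_value back the same way: a missing key maps to None, while a present key is unwrapped with as_double, which panics if the stored kind is not a double. A self-contained sketch of that read path (payload contents invented; assumes the qdrant-client Value type used above, including its From<f64> conversion):

use std::collections::HashMap;

use qdrant_client::qdrant::Value;

fn main() {
    // Hypothetical point payload, shaped like what QdrantPayload writes.
    let mut payload: HashMap<String, Value> = HashMap::new();
    payload.insert("num_value".to_string(), 12.99.into());

    // Same extraction pattern as the diff: None when the key is absent,
    // panic via expect(...) when the value is not a double.
    let num_value: Option<f64> = payload
        .get("num_value")
        .cloned()
        .map(|x| x.as_double().expect("num_value should be a float"));

    assert_eq!(num_value, Some(12.99));
}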
@@ -2389,6 +2413,7 @@ pub struct DateRange {
pub enum MatchCondition {
Text(String),
Integer(i64),
Float(f64),
}

impl MatchCondition {
@@ -2397,13 +2422,23 @@ impl MatchCondition {
match self {
MatchCondition::Text(text) => text.clone(),
MatchCondition::Integer(int) => int.to_string(),
MatchCondition::Float(float) => float.to_string(),
}
}

pub fn to_i64(&self) -> i64 {
match self {
MatchCondition::Text(text) => text.parse().unwrap(),
MatchCondition::Integer(int) => *int,
MatchCondition::Float(float) => *float as i64,
}
}

pub fn to_f64(&self) -> f64 {
match self {
MatchCondition::Text(text) => text.parse().unwrap(),
MatchCondition::Integer(int) => *int as f64,
MatchCondition::Float(float) => *float,
}
}
}
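One caveat in the new conversions: to_i64 on a Float truncates toward zero, so fractional values are lost wherever a float condition is coerced to an integer match. A self-contained sketch of the behavior (enum and methods redeclared here purely for illustration, copied from the diff above):

enum MatchCondition {
    Text(String),
    Integer(i64),
    Float(f64),
}

impl MatchCondition {
    fn to_i64(&self) -> i64 {
        match self {
            MatchCondition::Text(text) => text.parse().unwrap(),
            MatchCondition::Integer(int) => *int,
            MatchCondition::Float(float) => *float as i64, // truncates: 4.9 -> 4
        }
    }

    fn to_f64(&self) -> f64 {
        match self {
            MatchCondition::Text(text) => text.parse().unwrap(),
            MatchCondition::Integer(int) => *int as f64,
            MatchCondition::Float(float) => *float,
        }
    }
}

fn main() {
    assert_eq!(MatchCondition::Float(4.9).to_i64(), 4);
    assert_eq!(MatchCondition::Integer(3).to_f64(), 3.0);
    assert_eq!(MatchCondition::Text("2.5".to_string()).to_f64(), 2.5);
}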
@@ -2548,6 +2583,14 @@ impl FieldCondition {
));
}

if self.field == "num_value" {
return Ok(Some(
get_num_value_filter_condition(self, dataset_id, pool)
.await?
.into(),
));
}

if let Some(date_range) = self.date_range.clone() {
let time_range = get_date_range(date_range)?;
return Ok(Some(qdrant::Condition::range(
@@ -2665,7 +2708,7 @@ impl FieldCondition {
"Invalid condition type".to_string(),
)),
},
MatchCondition::Integer(_) => match condition_type {
MatchCondition::Integer(_) | MatchCondition::Float(_) => match condition_type {
"must" | "should" => Ok(Some(qdrant::Condition::matches(
self.field.as_str(),
matches
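The num_value branch above hands off to get_num_value_filter_condition, which lives in search_operator and is not shown in this diff. Presumably it resolves to a Qdrant payload range condition along these lines (a sketch of the qdrant-client call only, with invented bounds, not the operator's actual body):

use qdrant_client::qdrant::{self, Range};

fn main() {
    // Hypothetical: match points whose num_value payload key satisfies
    // 10.0 <= num_value <= 25.0.
    let condition = qdrant::Condition::range(
        "num_value",
        Range {
            gte: Some(10.0),
            lte: Some(25.0),
            ..Default::default()
        },
    );
    println!("{condition:?}");
}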
1 change: 1 addition & 0 deletions server/src/data/schema.rs
@@ -52,6 +52,7 @@ diesel::table! {
location -> Nullable<Jsonb>,
image_urls -> Nullable<Array<Nullable<Text>>>,
tag_set_array -> Nullable<Array<Nullable<Text>>>,
num_value -> Nullable<Float8>,
}
}

17 changes: 14 additions & 3 deletions server/src/handlers/chunk_handler.rs
@@ -60,6 +60,8 @@ pub struct ChunkData {
pub link: Option<String>,
/// Tag set is a list of tags. This can be used to filter chunks by tag. Unlike with metadata filtering, HNSW indices exist for each tag, so there is no performance hit for filtering on them.
pub tag_set: Option<Vec<String>>,
/// Num value is an arbitrary numerical value attached to the chunk that can be used to filter chunks. There is no performance hit for filtering on num_value.
pub num_value: Option<f64>,
/// Metadata is a JSON object which can be used to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata.
pub metadata: Option<serde_json::Value>,
/// Chunk_vector is a vector of floats which can be used instead of generating a new embedding. This is useful for when you are using a pre-embedded dataset. If this is not provided, the innerText of the chunk_html will be used to create the embedding.
@@ -463,6 +465,8 @@ pub struct UpdateChunkData {
tag_set: Option<Vec<String>>,
/// Link of the chunk you want to update. This can also be any string. Frequently, this is a link to the source of the chunk. The link value will not affect the embedding creation. If no link is provided, the existing link will be used.
link: Option<String>,
/// Num value is an arbitrary numerical value that can be used to filter chunks. If no num_value is provided, the existing num_value will be used.
num_value: Option<f64>,
/// HTML content of the chunk you want to update. This can also be plaintext. The innerText of the HTML will be used to create the embedding vector. The point of using HTML is for convenience, as some users have applications where users submit HTML content. If no chunk_html is provided, the existing chunk_html will be used.
chunk_html: Option<String>,
/// The metadata is a JSON object which can be used to filter chunks. This is useful for when you want to filter chunks by arbitrary metadata. Unlike with tag filtering, there is a performance hit for filtering on metadata. If no metadata is provided, the existing metadata will be used.
@@ -581,10 +585,16 @@ pub async fn update_chunk(
})
.transpose()?
.or(chunk_metadata.time_stamp),
update_chunk_data.location,
update_chunk_data.image_urls.clone(),
update_chunk_data
.location
.clone()
.or(chunk_metadata.location),
update_chunk_data.image_urls.clone().or(chunk_metadata
.image_urls
.map(|x| x.into_iter().map(|x| x.unwrap()).collect())),
dataset_id,
update_chunk_data.weight.unwrap_or(1.0),
update_chunk_data.weight.unwrap_or(chunk_metadata.weight),
update_chunk_data.num_value.or(chunk_metadata.num_value),
);

let group_ids = if let Some(group_ids) = update_chunk_data.group_ids.clone() {
@@ -734,6 +744,7 @@ pub async fn update_chunk_by_tracking_id(
None,
dataset_org_plan_sub.dataset.id,
update_chunk_data.weight.unwrap_or(1.0),
None,
);
let group_ids = if let Some(group_ids) = update_chunk_data.group_ids.clone() {
Some(
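On the API side, num_value simply rides along on create and update bodies. A hypothetical request payload built with serde_json (every value is invented; the field names besides num_value come from the ChunkData doc comments above):

use serde_json::json;

fn main() {
    // Hypothetical create-chunk body exercising the new field. On update,
    // omitting num_value keeps the stored value, per the
    // .or(chunk_metadata.num_value) fallback above.
    let body = json!({
        "chunk_html": "<p>Acme 16oz claw hammer</p>",
        "tag_set": ["hardware", "hand-tools"],
        "num_value": 12.99
    });
    println!("{body}");
}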
2 changes: 1 addition & 1 deletion server/src/lib.rs
@@ -133,7 +133,7 @@ impl Modify for SecurityAddon {
name = "BSL",
url = "https://github.com/devflowinc/trieve/blob/main/LICENSE.txt",
),
version = "0.8.7",
version = "0.8.8",
),
servers(
(url = "https://api.trieve.ai",
8 changes: 8 additions & 0 deletions server/src/operators/chunk_operator.rs
@@ -76,6 +76,7 @@ pub async fn get_slim_chunk_metadatas_from_point_ids(
chunk_metadata_columns::dataset_id,
chunk_metadata_columns::weight,
chunk_metadata_columns::image_urls,
chunk_metadata_columns::num_value,
))
.load::<SlimChunkMetadata>(&mut conn)
.await
@@ -211,6 +212,7 @@ pub async fn get_chunk_metadatas_and_collided_chunks_from_point_ids_query(
weight: chunk.0.weight,
image_urls: chunk.0.image_urls.clone(),
tag_set_array: None,
num_value: chunk.0.num_value,
}
.into()
})
@@ -262,6 +264,7 @@ pub async fn get_chunk_metadatas_and_collided_chunks_from_point_ids_query(
weight: chunk.0.weight,
image_urls: chunk.0.image_urls.clone(),
tag_set_array: None,
num_value: chunk.0.num_value,
}
.into()
})
@@ -327,6 +330,7 @@ pub async fn get_slim_chunks_from_point_ids_query(
chunk_metadata_columns::dataset_id,
chunk_metadata_columns::weight,
chunk_metadata_columns::image_urls,
chunk_metadata_columns::num_value,
))
.filter(chunk_metadata_columns::qdrant_point_id.eq_any(&point_ids))
.load(&mut conn)
@@ -388,6 +392,7 @@ pub async fn get_content_chunk_from_point_ids_query(
chunk_metadata_columns::time_stamp,
chunk_metadata_columns::weight,
chunk_metadata_columns::image_urls,
chunk_metadata_columns::num_value,
))
.filter(chunk_metadata_columns::qdrant_point_id.eq_any(&point_ids))
.load(&mut conn)
@@ -831,6 +836,8 @@ pub async fn update_chunk_metadata_query(
chunk_metadata_columns::time_stamp.eq(chunk_data.time_stamp),
chunk_metadata_columns::location.eq(chunk_data.location),
chunk_metadata_columns::weight.eq(chunk_data.weight),
chunk_metadata_columns::image_urls.eq(chunk_data.image_urls),
chunk_metadata_columns::num_value.eq(chunk_data.num_value),
))
.get_result::<ChunkMetadata>(&mut conn)
.await
@@ -1347,6 +1354,7 @@ pub async fn create_chunk_metadata(
chunk.image_urls.clone(),
dataset_uuid,
chunk.weight.unwrap_or(0.0),
chunk.num_value,
);
chunk_metadatas.push(chunk_metadata.clone());

1 change: 1 addition & 0 deletions server/src/operators/file_operator.rs
@@ -158,6 +158,7 @@ pub async fn create_chunks_with_handler(
split_avg: None,
convert_html_to_text: None,
image_urls: None,
num_value: None,
};
chunks.push(create_chunk_data);
}
