diff --git a/rust/README.md b/rust/README.md index 7b9b96c..d8c9563 100644 --- a/rust/README.md +++ b/rust/README.md @@ -24,6 +24,8 @@ cargo build --release ``` cargo run -- server/vss-server-config.toml ``` + + **Note:** For testing purposes you can edit `vss-server-config.toml` to use `store_type` as in-memory instead of PostgreSQL: `store_type = "in_memory"` 4. VSS endpoint should be reachable at `http://localhost:8080/vss`. ### Configuration diff --git a/rust/impls/src/in_memory_store.rs b/rust/impls/src/in_memory_store.rs new file mode 100644 index 0000000..24c3a76 --- /dev/null +++ b/rust/impls/src/in_memory_store.rs @@ -0,0 +1,403 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +use async_trait::async_trait; +use bytes::Bytes; +use chrono::prelude::Utc; + +use crate::postgres_store::{ + VssDbRecord, LIST_KEY_VERSIONS_MAX_PAGE_SIZE, MAX_PUT_REQUEST_ITEM_COUNT, +}; +use api::error::VssError; +use api::kv_store::{KvStore, GLOBAL_VERSION_KEY}; +use api::types::{ + DeleteObjectRequest, DeleteObjectResponse, GetObjectRequest, GetObjectResponse, KeyValue, + ListKeyVersionsRequest, ListKeyVersionsResponse, PutObjectRequest, PutObjectResponse, +}; + +fn build_vss_record(user_token: String, store_id: String, kv: KeyValue) -> VssDbRecord { + let now = Utc::now(); + VssDbRecord { + user_token, + store_id, + key: kv.key, + value: kv.value.to_vec(), + version: kv.version, + created_at: now, + last_updated_at: now, + } +} + +fn build_key(user_token: &str, store_id: &str, key: &str) -> String { + format!("{}#{}#{}", user_token, store_id, key) +} + +/// In-memory implementation of the VSS Store. +pub struct InMemoryBackendImpl { + store: Arc>>, +} + +impl InMemoryBackendImpl { + /// Creates an in-memory instance. + pub fn new() -> Self { + Self { store: Arc::new(Mutex::new(HashMap::new())) } + } + + fn get_current_global_version( + &self, guard: &HashMap, user_token: &str, store_id: &str, + ) -> i64 { + let global_key = build_key(user_token, store_id, GLOBAL_VERSION_KEY); + guard.get(&global_key).map(|r| r.version).unwrap_or(0) + } + + fn set_global_version( + &self, guard: &mut HashMap, user_token: String, store_id: String, + new_version: i64, + ) { + let global_key = build_key(&user_token, &store_id, GLOBAL_VERSION_KEY); + let now = Utc::now(); + + let entry = guard.entry(global_key); + match entry { + std::collections::hash_map::Entry::Occupied(mut occ) => { + let rec = occ.get_mut(); + rec.version = new_version; + rec.last_updated_at = now; + }, + std::collections::hash_map::Entry::Vacant(vac) => { + let record = VssDbRecord { + user_token, + store_id, + key: GLOBAL_VERSION_KEY.to_string(), + value: vec![], + version: new_version, + created_at: now, + last_updated_at: now, + }; + vac.insert(record); + }, + } + } +} + +fn execute_put_object( + store: &mut HashMap, record: VssDbRecord, +) -> Result<(), VssError> { + let key = format!("{}#{}#{}", record.user_token, record.store_id, record.key); + let now = Utc::now(); + + let entry = store.entry(key); + match entry { + std::collections::hash_map::Entry::Occupied(mut occ) => { + let existing = occ.get_mut(); + let new_version = + if record.version == -1 { existing.version } else { existing.version + 1 }; + existing.version = new_version; + existing.value = record.value; + existing.last_updated_at = now; + }, + std::collections::hash_map::Entry::Vacant(vac) => { + let new_record = VssDbRecord { + user_token: record.user_token, + store_id: record.store_id, + key: record.key, + value: record.value, + version: 1, + created_at: now, + last_updated_at: now, + }; + vac.insert(new_record); + }, + } + Ok(()) +} + +fn execute_delete_object( + store: &mut HashMap, record: VssDbRecord, +) -> Result { + let key = format!("{}#{}#{}", record.user_token, record.store_id, record.key); + if record.version != -1 { + if let Some(existing) = store.get(&key) { + if existing.version != record.version { + return Err(VssError::ConflictError(format!( + "Version conflict on delete for key {}", + record.key + ))); + } + } else { + return Err(VssError::ConflictError(format!( + "Key {} does not exist for delete", + record.key + ))); + } + } + Ok(store.remove(&key).is_some()) +} + +#[async_trait] +impl KvStore for InMemoryBackendImpl { + async fn get( + &self, user_token: String, request: GetObjectRequest, + ) -> Result { + let key = build_key(&user_token, &request.store_id, &request.key); + let guard = self.store.lock().unwrap(); + + if let Some(record) = guard.get(&key) { + Ok(GetObjectResponse { + value: Some(KeyValue { + key: record.key.clone(), + value: Bytes::from(record.value.clone()), + version: record.version, + }), + }) + } else if request.key == GLOBAL_VERSION_KEY { + Ok(GetObjectResponse { + value: Some(KeyValue { + key: GLOBAL_VERSION_KEY.to_string(), + value: Bytes::new(), + version: self.get_current_global_version( + &guard, + &user_token, + &request.store_id, + ), + }), + }) + } else { + Err(VssError::NoSuchKeyError("Requested key not found.".to_string())) + } + } + + async fn put( + &self, user_token: String, request: PutObjectRequest, + ) -> Result { + if request.transaction_items.len() + request.delete_items.len() > MAX_PUT_REQUEST_ITEM_COUNT + { + return Err(VssError::InvalidRequestError(format!( + "Number of write items per request should be less than equal to {}", + MAX_PUT_REQUEST_ITEM_COUNT + ))); + } + + let store_id = request.store_id.clone(); + let mut guard = self.store.lock().unwrap(); + + let current_global = self.get_current_global_version(&guard, &user_token, &store_id); + if let Some(expected_global) = request.global_version { + if current_global != expected_global { + return Err(VssError::ConflictError(format!( + "Global version conflict: expected {}, current {}", + expected_global, current_global + ))); + } + } + + for kv in &request.transaction_items { + let key = build_key(&user_token, &store_id, &kv.key); + if kv.version != -1 { + if let Some(existing) = guard.get(&key) { + if existing.version != kv.version { + return Err(VssError::ConflictError(format!( + "Version conflict on put for key {}", + kv.key + ))); + } + } else if kv.version != 0 { + return Err(VssError::ConflictError(format!( + "Key {} does not exist for put with version {}", + kv.key, kv.version + ))); + } + } + } + for kv in &request.delete_items { + let key = build_key(&user_token, &store_id, &kv.key); + if kv.version != -1 { + if let Some(existing) = guard.get(&key) { + if existing.version != kv.version { + return Err(VssError::ConflictError(format!( + "Version conflict on delete for key {}", + kv.key + ))); + } + } else { + return Err(VssError::ConflictError(format!( + "Key {} does not exist for delete", + kv.key + ))); + } + } + } + + let vss_put_records: Vec = request + .transaction_items + .into_iter() + .map(|kv| build_vss_record(user_token.clone(), store_id.clone(), kv)) + .collect(); + + let vss_delete_records: Vec = request + .delete_items + .into_iter() + .map(|kv| build_vss_record(user_token.clone(), store_id.clone(), kv)) + .collect(); + + let mut mutated = false; + for vss_record in vss_put_records { + execute_put_object(&mut guard, vss_record)?; + mutated = true; + } + for vss_record in vss_delete_records { + if execute_delete_object(&mut guard, vss_record)? { + mutated = true; + } + } + + if mutated && request.global_version.is_some() { + let new_global = current_global + 1; + self.set_global_version(&mut guard, user_token.clone(), store_id.clone(), new_global); + } + + Ok(PutObjectResponse {}) + } + + async fn delete( + &self, user_token: String, request: DeleteObjectRequest, + ) -> Result { + let key_value = request.key_value.ok_or_else(|| { + VssError::InvalidRequestError("key_value missing in DeleteObjectRequest".to_string()) + })?; + let store_id = request.store_id.clone(); + let mut guard = self.store.lock().unwrap(); + + let key = build_key(&user_token, &store_id, &key_value.key); + if key_value.version != -1 { + if let Some(existing) = guard.get(&key) { + if existing.version != key_value.version { + return Ok(DeleteObjectResponse {}); + } + } else { + return Ok(DeleteObjectResponse {}); + } + } + + let vss_record = build_vss_record(user_token.clone(), store_id.clone(), key_value); + execute_delete_object(&mut guard, vss_record)?; + Ok(DeleteObjectResponse {}) + } + + async fn list_key_versions( + &self, user_token: String, request: ListKeyVersionsRequest, + ) -> Result { + let store_id = request.store_id; + let key_prefix = request.key_prefix.unwrap_or("".to_string()); + let page_token = request.page_token.unwrap_or("".to_string()); + let page_size = request.page_size.unwrap_or(i32::MAX); + let limit = std::cmp::min(page_size, LIST_KEY_VERSIONS_MAX_PAGE_SIZE) as usize; + + let mut global_version = None; + if page_token.is_empty() { + let get_global_version_request = GetObjectRequest { + store_id: store_id.clone(), + key: GLOBAL_VERSION_KEY.to_string(), + }; + let get_response = self.get(user_token.clone(), get_global_version_request).await?; + global_version = Some(get_response.value.unwrap().version); + } + + let key_versions: Vec = { + let guard = self.store.lock().unwrap(); + let mut key_versions: Vec = guard + .iter() + .filter(|(k, _)| { + let parts: Vec<&str> = k.split('#').collect(); + if parts.len() < 3 { + return false; + } + parts[0] == user_token.as_str() + && parts[1] == store_id.as_str() + && parts[2].starts_with(&key_prefix) + && parts[2] > page_token.as_str() + && parts[2] != GLOBAL_VERSION_KEY + }) + .map(|(_, record)| KeyValue { + key: record.key.clone(), + value: Bytes::new(), + version: record.version, + }) + .collect(); + + key_versions.sort_by(|a, b| a.key.cmp(&b.key)); + key_versions.into_iter().take(limit).collect() + }; + + let next_page_token = if key_versions.len() == limit { + key_versions.last().map(|kv| kv.key.clone()) + } else { + None + }; + + Ok(ListKeyVersionsResponse { key_versions, next_page_token, global_version }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use api::define_kv_store_tests; + use bytes::Bytes; + use tokio::test; + + define_kv_store_tests!(InMemoryKvStoreTest, InMemoryBackendImpl, InMemoryBackendImpl::new()); + + #[test] + async fn test_in_memory_crud() { + let store = InMemoryBackendImpl::new(); + let user_token = "test_user".to_string(); + let store_id = "test_store".to_string(); + + let put_request = PutObjectRequest { + store_id: store_id.clone(), + transaction_items: vec![KeyValue { + key: "key1".to_string(), + value: Bytes::from("value1"), + version: 0, + }], + delete_items: vec![], + global_version: None, + }; + store.put(user_token.clone(), put_request).await.unwrap(); + + let get_request = GetObjectRequest { store_id: store_id.clone(), key: "key1".to_string() }; + let response = store.get(user_token.clone(), get_request).await.unwrap(); + let key_value = response.value.unwrap(); + assert_eq!(key_value.value, Bytes::from("value1")); + assert_eq!(key_value.version, 1, "Expected version 1 after put"); + + let list_request = ListKeyVersionsRequest { + store_id: store_id.clone(), + key_prefix: None, + page_size: Some(1), + page_token: None, + }; + let response = store.list_key_versions(user_token.clone(), list_request).await.unwrap(); + assert_eq!(response.key_versions.len(), 1); + assert_eq!(response.key_versions[0].key, "key1"); + assert_eq!(response.key_versions[0].version, 1); + + let delete_request = DeleteObjectRequest { + store_id: store_id.clone(), + key_value: Some(KeyValue { key: "key1".to_string(), value: Bytes::new(), version: 1 }), + }; + store.delete(user_token.clone(), delete_request).await.unwrap(); + + let get_request = GetObjectRequest { store_id: store_id.clone(), key: "key1".to_string() }; + assert!(matches!( + store.get(user_token.clone(), get_request).await, + Err(VssError::NoSuchKeyError(_)) + )); + + let global_request = + GetObjectRequest { store_id: store_id.clone(), key: GLOBAL_VERSION_KEY.to_string() }; + let response = store.get(user_token.clone(), global_request).await.unwrap(); + assert_eq!(response.value.unwrap().version, 0, "Expected global_version=0"); + } +} diff --git a/rust/impls/src/lib.rs b/rust/impls/src/lib.rs index 27844d0..1aaf4d0 100644 --- a/rust/impls/src/lib.rs +++ b/rust/impls/src/lib.rs @@ -11,6 +11,8 @@ #![deny(rustdoc::private_intra_doc_links)] #![deny(missing_docs)] +/// Contains in-memory backend implementation for VSS, for testing purposes only. +pub mod in_memory_store; mod migrations; /// Contains [PostgreSQL](https://www.postgresql.org/) based backend implementation for VSS. pub mod postgres_store; diff --git a/rust/impls/src/migrations.rs b/rust/impls/src/migrations.rs index bab951b..02d3f54 100644 --- a/rust/impls/src/migrations.rs +++ b/rust/impls/src/migrations.rs @@ -5,7 +5,7 @@ pub(crate) const MIGRATION_LOG_COLUMN: &str = "upgrade_from"; pub(crate) const CHECK_DB_STMT: &str = "SELECT 1 FROM pg_database WHERE datname = $1"; pub(crate) const INIT_DB_CMD: &str = "CREATE DATABASE"; #[cfg(test)] -const DROP_DB_CMD: &str = "DROP DATABASE"; +pub(crate) const DROP_DB_CMD: &str = "DROP DATABASE"; pub(crate) const GET_VERSION_STMT: &str = "SELECT db_version FROM vss_db_version;"; pub(crate) const UPDATE_VERSION_STMT: &str = "UPDATE vss_db_version SET db_version=$1;"; pub(crate) const LOG_MIGRATION_STMT: &str = "INSERT INTO vss_db_upgrades VALUES($1);"; @@ -36,4 +36,4 @@ pub(crate) const MIGRATIONS: &[&str] = &[ );", ]; #[cfg(test)] -const DUMMY_MIGRATION: &str = "SELECT 1 WHERE FALSE;"; +pub(crate) const DUMMY_MIGRATION: &str = "SELECT 1 WHERE FALSE;"; diff --git a/rust/impls/src/postgres_store.rs b/rust/impls/src/postgres_store.rs index f3b39c3..1e951df 100644 --- a/rust/impls/src/postgres_store.rs +++ b/rust/impls/src/postgres_store.rs @@ -93,10 +93,7 @@ async fn drop_database(postgres_endpoint: &str, db_name: &str) -> Result<(), Err let drop_database_statement = format!("{} {};", DROP_DB_CMD, db_name); let num_rows = client.execute(&drop_database_statement, &[]).await.map_err(|e| { - Error::new( - ErrorKind::Other, - format!("Failed to drop database {}: {}", db_name, e), - ) + Error::new(ErrorKind::Other, format!("Failed to drop database {}: {}", db_name, e)) })?; assert_eq!(num_rows, 0); @@ -134,10 +131,7 @@ impl PostgresBackendImpl { async fn migrate_vss_database(&self, migrations: &[&str]) -> Result<(usize, usize), Error> { let mut conn = self.pool.get().await.map_err(|e| { - Error::new( - ErrorKind::Other, - format!("Failed to fetch a connection from Pool: {}", e), - ) + Error::new(ErrorKind::Other, format!("Failed to fetch a connection from Pool: {}", e)) })?; // Get the next migration to be applied. @@ -230,7 +224,9 @@ impl PostgresBackendImpl { async fn get_upgrades_list(&self) -> Vec { let conn = self.pool.get().await.unwrap(); let rows = conn.query(GET_MIGRATION_LOG_STMT, &[]).await.unwrap(); - rows.iter().map(|row| usize::try_from(row.get::<&str, i32>(MIGRATION_LOG_COLUMN)).unwrap()).collect() + rows.iter() + .map(|row| usize::try_from(row.get::<&str, i32>(MIGRATION_LOG_COLUMN)).unwrap()) + .collect() } fn build_vss_record(&self, user_token: String, store_id: String, kv: KeyValue) -> VssDbRecord { @@ -581,10 +577,10 @@ impl KvStore for PostgresBackendImpl { #[cfg(test)] mod tests { + use super::{drop_database, DUMMY_MIGRATION, MIGRATIONS}; use crate::postgres_store::PostgresBackendImpl; use api::define_kv_store_tests; use tokio::sync::OnceCell; - use super::{MIGRATIONS, DUMMY_MIGRATION, drop_database}; const POSTGRES_ENDPOINT: &str = "postgresql://postgres:postgres@localhost:5432"; const MIGRATIONS_START: usize = 0; @@ -670,7 +666,10 @@ mod tests { let (start, end) = store.migrate_vss_database(&migrations).await.unwrap(); assert_eq!(start, MIGRATIONS_END + 1); assert_eq!(end, MIGRATIONS_END + 3); - assert_eq!(store.get_upgrades_list().await, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]); + assert_eq!( + store.get_upgrades_list().await, + [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1] + ); assert_eq!(store.get_schema_version().await, MIGRATIONS_END + 3); }; diff --git a/rust/server/src/main.rs b/rust/server/src/main.rs index 5a78be6..85d61c7 100644 --- a/rust/server/src/main.rs +++ b/rust/server/src/main.rs @@ -20,6 +20,7 @@ use hyper_util::rt::TokioIo; use crate::vss_service::VssService; use api::auth::{Authorizer, NoopAuthorizer}; use api::kv_store::KvStore; +use impls::in_memory_store::InMemoryBackendImpl; use impls::postgres_store::PostgresBackendImpl; use std::sync::Arc; @@ -67,15 +68,32 @@ fn main() { }, }; let authorizer = Arc::new(NoopAuthorizer {}); - let postgresql_config = config.postgresql_config.expect("PostgreSQLConfig must be defined in config file."); - let endpoint = postgresql_config.to_postgresql_endpoint(); - let db_name = postgresql_config.database; - let store = Arc::new( - PostgresBackendImpl::new(&endpoint, &db_name) - .await - .unwrap(), - ); - println!("Connected to PostgreSQL backend with DSN: {}/{}", endpoint, db_name); + let store: Arc = match config.server_config.store_type.as_str() { + "postgres" => { + let pg_config = config.postgresql_config + .expect("PostgreSQL configuration required for postgres backend"); + let endpoint = pg_config.to_postgresql_endpoint(); + let db_name = pg_config.database; + match PostgresBackendImpl::new(&endpoint, &db_name).await { + Ok(backend) => { + println!("Connected to PostgreSQL backend with DSN: {}/{}", endpoint, db_name); + Arc::new(backend) + }, + Err(e) => { + eprintln!("Failed to connect to PostgreSQL backend: {}", e); + std::process::exit(1); + }, + } + }, + "in_memory" => { + println!("Using in-memory backend for testing"); + Arc::new(InMemoryBackendImpl::new()) + }, + _ => { + eprintln!("Invalid backend_type: {}. Must be 'postgres' or 'in_memory'", config.server_config.store_type); + std::process::exit(1); + }, + }; let rest_svc_listener = TcpListener::bind(&addr).await.expect("Failed to bind listening port"); println!("Listening for incoming connections on {}", addr); diff --git a/rust/server/src/util/config.rs b/rust/server/src/util/config.rs index cf70daf..87b0aa0 100644 --- a/rust/server/src/util/config.rs +++ b/rust/server/src/util/config.rs @@ -10,6 +10,7 @@ pub(crate) struct Config { pub(crate) struct ServerConfig { pub(crate) host: String, pub(crate) port: u16, + pub(crate) store_type: String, // "postgresql" or "in_memory" } #[derive(Deserialize)] diff --git a/rust/server/vss-server-config.toml b/rust/server/vss-server-config.toml index 8a013b5..b030851 100644 --- a/rust/server/vss-server-config.toml +++ b/rust/server/vss-server-config.toml @@ -1,6 +1,7 @@ [server_config] host = "127.0.0.1" port = 8080 +store_type = "postgres" # "postgres" for using postgresql and "in_memory" for testing purpooses [postgresql_config] username = "postgres" # Optional in TOML, can be overridden by env var `VSS_POSTGRESQL_USERNAME`