From 202435742ed78b0eac80efcd19b357df96a6bbb9 Mon Sep 17 00:00:00 2001 From: Brian Pearce Date: Mon, 17 Jul 2023 10:08:25 +0200 Subject: [PATCH] feat: chat message fetching pagination (#5594) Description --- Previously the only function to get messages was `get_all_messages` which return all related records in the db. Now we've added pagination and you can select the amount of results per page, and what page you'd like returned. Results are now ordered by most recent received. Motivation and Context --- `get_all_messages` has the potential to get very slow and ddos the client. Limiting the amount returned and allowing to change the requested amounts makes things better to deal with. How Has This Been Tested? --- Test added for pagination. What process can a PR reviewer use to test or verify this change? --- Tests pass. Breaking Changes --- - [x] None - [ ] Requires data directory on base node to be deleted - [ ] Requires hard fork - [ ] Other - Please specify --- base_layer/chat_ffi/chat.h | 10 +- base_layer/chat_ffi/src/lib.rs | 20 +++- .../examples/chat_client/src/client.rs | 6 +- .../contacts/src/contacts_service/handle.rs | 31 +++++- .../contacts/src/contacts_service/service.rs | 4 +- .../src/contacts_service/storage/database.rs | 13 ++- .../src/contacts_service/storage/sqlite_db.rs | 22 +++-- .../storage/types/messages.rs | 5 + base_layer/contacts/tests/contacts_service.rs | 95 ++++++++++++++++++- integration_tests/src/chat_ffi.rs | 14 ++- integration_tests/tests/steps/chat_steps.rs | 9 +- 11 files changed, 192 insertions(+), 37 deletions(-) diff --git a/base_layer/chat_ffi/chat.h b/base_layer/chat_ffi/chat.h index fe60186068..7f39d69f67 100644 --- a/base_layer/chat_ffi/chat.h +++ b/base_layer/chat_ffi/chat.h @@ -152,6 +152,8 @@ int check_online_status(struct ClientFFI *client, struct TariAddress *receiver, * ## Arguments * `client` - The Client pointer * `address` - A TariAddress ptr + * `limit` - The amount of messages you want to fetch. Default to 35, max 2500 + * `page` - The page of results you'd like returned. Default to 0, maximum of u64 max * `error_out` - Pointer to an int which will be modified * * ## Returns @@ -161,9 +163,11 @@ int check_online_status(struct ClientFFI *client, struct TariAddress *receiver, * The ```address``` should be destroyed after use * The returned pointer to ```*mut ChatMessages``` should be destroyed after use */ -struct ChatMessages *get_all_messages(struct ClientFFI *client, - struct TariAddress *address, - int *error_out); +struct ChatMessages *get_messages(struct ClientFFI *client, + struct TariAddress *address, + int *limit, + int *page, + int *error_out); /** * Frees memory for messages diff --git a/base_layer/chat_ffi/src/lib.rs b/base_layer/chat_ffi/src/lib.rs index 602b6ecd54..b77cf6c760 100644 --- a/base_layer/chat_ffi/src/lib.rs +++ b/base_layer/chat_ffi/src/lib.rs @@ -22,7 +22,7 @@ #![recursion_limit = "1024"] -use std::{ffi::CStr, path::PathBuf, ptr, str::FromStr, sync::Arc}; +use std::{convert::TryFrom, ffi::CStr, path::PathBuf, ptr, str::FromStr, sync::Arc}; use callback_handler::CallbackContactStatusChange; use libc::{c_char, c_int}; @@ -35,7 +35,10 @@ use tari_chat_client::{ use tari_common::configuration::{MultiaddrList, Network}; use tari_common_types::tari_address::TariAddress; use tari_comms::{multiaddr::Multiaddr, NodeIdentity}; -use tari_contacts::contacts_service::types::Message; +use tari_contacts::contacts_service::{ + handle::{DEFAULT_MESSAGE_LIMIT, DEFAULT_MESSAGE_PAGE}, + types::Message, +}; use tokio::runtime::Runtime; use crate::{ @@ -393,6 +396,8 @@ pub unsafe extern "C" fn check_online_status( /// ## Arguments /// `client` - The Client pointer /// `address` - A TariAddress ptr +/// `limit` - The amount of messages you want to fetch. Default to 35, max 2500 +/// `page` - The page of results you'd like returned. Default to 0, maximum of u64 max /// `error_out` - Pointer to an int which will be modified /// /// ## Returns @@ -402,9 +407,11 @@ pub unsafe extern "C" fn check_online_status( /// The ```address``` should be destroyed after use /// The returned pointer to ```*mut ChatMessages``` should be destroyed after use #[no_mangle] -pub unsafe extern "C" fn get_all_messages( +pub unsafe extern "C" fn get_messages( client: *mut ClientFFI, address: *mut TariAddress, + limit: *mut c_int, + page: *mut c_int, error_out: *mut c_int, ) -> *mut ChatMessages { let mut error = 0; @@ -420,9 +427,14 @@ pub unsafe extern "C" fn get_all_messages( ptr::swap(error_out, &mut error as *mut c_int); } + let mlimit = u64::try_from(*limit).unwrap_or(DEFAULT_MESSAGE_LIMIT); + let mpage = u64::try_from(*page).unwrap_or(DEFAULT_MESSAGE_PAGE); + let mut messages = Vec::new(); - let mut retrieved_messages = (*client).runtime.block_on((*client).client.get_all_messages(&*address)); + let mut retrieved_messages = (*client) + .runtime + .block_on((*client).client.get_messages(&*address, mlimit, mpage)); messages.append(&mut retrieved_messages); Box::into_raw(Box::new(ChatMessages(messages))) diff --git a/base_layer/contacts/examples/chat_client/src/client.rs b/base_layer/contacts/examples/chat_client/src/client.rs index 745e923e1b..1e80fff872 100644 --- a/base_layer/contacts/examples/chat_client/src/client.rs +++ b/base_layer/contacts/examples/chat_client/src/client.rs @@ -46,7 +46,7 @@ pub trait ChatClient { async fn add_contact(&self, address: &TariAddress); async fn check_online_status(&self, address: &TariAddress) -> ContactOnlineStatus; async fn send_message(&self, receiver: TariAddress, message: String); - async fn get_all_messages(&self, sender: &TariAddress) -> Vec; + async fn get_messages(&self, sender: &TariAddress, limit: u64, page: u64) -> Vec; fn identity(&self) -> &NodeIdentity; fn shutdown(&mut self); } @@ -157,11 +157,11 @@ impl ChatClient for Client { } } - async fn get_all_messages(&self, sender: &TariAddress) -> Vec { + async fn get_messages(&self, sender: &TariAddress, limit: u64, page: u64) -> Vec { let mut messages = vec![]; if let Some(mut contacts_service) = self.contacts.clone() { messages = contacts_service - .get_all_messages(sender.clone()) + .get_messages(sender.clone(), limit, page) .await .expect("Messages not fetched"); } diff --git a/base_layer/contacts/src/contacts_service/handle.rs b/base_layer/contacts/src/contacts_service/handle.rs index d6eb9c451b..3ea70aeb89 100644 --- a/base_layer/contacts/src/contacts_service/handle.rs +++ b/base_layer/contacts/src/contacts_service/handle.rs @@ -21,6 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use std::{ + convert::TryFrom, fmt::{Display, Error, Formatter}, sync::Arc, }; @@ -38,6 +39,10 @@ use crate::contacts_service::{ types::{Contact, Message}, }; +pub static DEFAULT_MESSAGE_LIMIT: u64 = 35; +pub static MAX_MESSAGE_LIMIT: u64 = 2500; +pub static DEFAULT_MESSAGE_PAGE: u64 = 0; + #[derive(Debug, Clone, PartialEq, Eq)] pub struct ContactsLivenessData { address: TariAddress, @@ -131,7 +136,7 @@ pub enum ContactsServiceRequest { GetContacts, GetContactOnlineStatus(Contact), SendMessage(TariAddress, Message), - GetAllMessages(TariAddress), + GetMessages(TariAddress, i64, i64), } #[derive(Debug)] @@ -236,10 +241,30 @@ impl ContactsServiceHandle { } } - pub async fn get_all_messages(&mut self, pk: TariAddress) -> Result, ContactsServiceError> { + pub async fn get_messages( + &mut self, + pk: TariAddress, + mut limit: u64, + mut page: u64, + ) -> Result, ContactsServiceError> { + if limit == 0 || limit > MAX_MESSAGE_LIMIT { + limit = DEFAULT_MESSAGE_LIMIT; + } + + page = match page.checked_mul(limit) { + Some(_) => page, + None => DEFAULT_MESSAGE_PAGE, + }; + + // const values won't be a problem here + #[allow(clippy::cast_possible_wrap)] match self .request_response_service - .call(ContactsServiceRequest::GetAllMessages(pk)) + .call(ContactsServiceRequest::GetMessages( + pk, + i64::try_from(limit).unwrap_or(DEFAULT_MESSAGE_LIMIT as i64), + i64::try_from(page).unwrap_or(DEFAULT_MESSAGE_PAGE as i64), + )) .await?? { ContactsServiceResponse::Messages(messages) => Ok(messages), diff --git a/base_layer/contacts/src/contacts_service/service.rs b/base_layer/contacts/src/contacts_service/service.rs index 7ea3c28a81..8112c6461b 100644 --- a/base_layer/contacts/src/contacts_service/service.rs +++ b/base_layer/contacts/src/contacts_service/service.rs @@ -291,8 +291,8 @@ where T: ContactsBackend + 'static let result = self.get_online_status(&contact).await; Ok(result.map(ContactsServiceResponse::OnlineStatus)?) }, - ContactsServiceRequest::GetAllMessages(pk) => { - let result = self.db.get_messages(pk); + ContactsServiceRequest::GetMessages(pk, limit, page) => { + let result = self.db.get_messages(pk, limit, page); Ok(result.map(ContactsServiceResponse::Messages)?) }, ContactsServiceRequest::SendMessage(address, mut message) => { diff --git a/base_layer/contacts/src/contacts_service/storage/database.rs b/base_layer/contacts/src/contacts_service/storage/database.rs index cff109fec5..a80de6fe74 100644 --- a/base_layer/contacts/src/contacts_service/storage/database.rs +++ b/base_layer/contacts/src/contacts_service/storage/database.rs @@ -50,7 +50,7 @@ pub enum DbKey { Contact(TariAddress), ContactId(NodeId), Contacts, - Messages(TariAddress), + Messages(TariAddress, i64, i64), } pub enum DbValue { @@ -163,8 +163,13 @@ where T: ContactsBackend + 'static } } - pub fn get_messages(&self, address: TariAddress) -> Result, ContactsServiceStorageError> { - let key = DbKey::Messages(address); + pub fn get_messages( + &self, + address: TariAddress, + limit: i64, + page: i64, + ) -> Result, ContactsServiceStorageError> { + let key = DbKey::Messages(address, limit, page); let db_clone = self.db.clone(); match db_clone.fetch(&key) { Ok(None) => log_error( @@ -197,7 +202,7 @@ impl Display for DbKey { DbKey::Contact(c) => f.write_str(&format!("Contact: {:?}", c)), DbKey::ContactId(id) => f.write_str(&format!("Contact: {:?}", id)), DbKey::Contacts => f.write_str("Contacts"), - DbKey::Messages(c) => f.write_str(&format!("Messages for id: {:?}", c)), + DbKey::Messages(c, _l, _p) => f.write_str(&format!("Messages for id: {:?}", c)), } } } diff --git a/base_layer/contacts/src/contacts_service/storage/sqlite_db.rs b/base_layer/contacts/src/contacts_service/storage/sqlite_db.rs index 097e667f21..c47c566a78 100644 --- a/base_layer/contacts/src/contacts_service/storage/sqlite_db.rs +++ b/base_layer/contacts/src/contacts_service/storage/sqlite_db.rs @@ -104,15 +104,17 @@ where TContactServiceDbConnection: PooledDbConnection, _>>()?, )), - DbKey::Messages(address) => match MessagesSql::find_by_address(&address.to_bytes(), &mut conn) { - Ok(messages) => Some(DbValue::Messages( - messages - .iter() - .map(|m| Message::try_from(m.clone()).expect("Couldn't cast MessageSql to Message")) - .collect::>(), - )), - Err(ContactsServiceStorageError::DieselError(DieselError::NotFound)) => None, - Err(e) => return Err(e), + DbKey::Messages(address, limit, page) => { + match MessagesSql::find_by_address(&address.to_bytes(), *limit, *page, &mut conn) { + Ok(messages) => Some(DbValue::Messages( + messages + .iter() + .map(|m| Message::try_from(m.clone()).expect("Couldn't cast MessageSql to Message")) + .collect::>(), + )), + Err(ContactsServiceStorageError::DieselError(DieselError::NotFound)) => None, + Err(e) => return Err(e), + } }, }; @@ -170,7 +172,7 @@ where TContactServiceDbConnection: PooledDbConnection return Err(e), }, DbKey::Contacts => return Err(ContactsServiceStorageError::OperationNotSupported), - DbKey::Messages(_pk) => return Err(ContactsServiceStorageError::OperationNotSupported), + DbKey::Messages(_pk, _l, _p) => return Err(ContactsServiceStorageError::OperationNotSupported), }, WriteOperation::Insert(i) => { if let DbValue::Message(m) = *i { diff --git a/base_layer/contacts/src/contacts_service/storage/types/messages.rs b/base_layer/contacts/src/contacts_service/storage/types/messages.rs index 596cc6a5d5..3f6f9226d7 100644 --- a/base_layer/contacts/src/contacts_service/storage/types/messages.rs +++ b/base_layer/contacts/src/contacts_service/storage/types/messages.rs @@ -71,10 +71,15 @@ impl MessagesSql { /// Find a particular message by their address, if it exists pub fn find_by_address( address: &[u8], + limit: i64, + page: i64, conn: &mut SqliteConnection, ) -> Result, ContactsServiceStorageError> { Ok(messages::table .filter(messages::address.eq(address)) + .order(messages::stored_at.desc()) + .offset(limit * page) + .limit(limit) .load::(conn)?) } } diff --git a/base_layer/contacts/tests/contacts_service.rs b/base_layer/contacts/tests/contacts_service.rs index 34d2ac5661..068c6bd075 100644 --- a/base_layer/contacts/tests/contacts_service.rs +++ b/base_layer/contacts/tests/contacts_service.rs @@ -30,12 +30,12 @@ use tari_comms::{peer_manager::PeerFeatures, NodeIdentity}; use tari_comms_dht::{store_forward::SafConfig, DhtConfig}; use tari_contacts::contacts_service::{ error::{ContactsServiceError, ContactsServiceStorageError}, - handle::ContactsServiceHandle, + handle::{ContactsServiceHandle, DEFAULT_MESSAGE_LIMIT, MAX_MESSAGE_LIMIT}, storage::{ - database::{ContactsBackend, DbKey}, + database::{ContactsBackend, ContactsDatabase, DbKey}, sqlite_db::ContactsServiceSqliteDatabase, }, - types::Contact, + types::{Contact, MessageBuilder}, ContactsServiceInitializer, }; use tari_crypto::keys::PublicKey as PublicKeyTrait; @@ -216,3 +216,92 @@ pub fn test_contacts_service() { }; }); } + +#[test] +pub fn test_message_pagination() { + with_temp_dir(|dir_path| { + let mut runtime = Runtime::new().unwrap(); + + let db_name = format!("{}.sqlite3", string(8).as_str()); + let db_path = format!("{}/{}", dir_path.to_str().unwrap(), db_name); + let url: DbConnectionUrl = db_path.try_into().unwrap(); + + let db = DbConnection::connect_url(&url).unwrap(); + let backend = ContactsServiceSqliteDatabase::init(db); + let contacts_db = ContactsDatabase::new(backend.clone()); + + let (mut contacts_service, _node_identity, _shutdown) = setup_contacts_service(&mut runtime, backend); + + let (_secret_key, public_key) = PublicKey::random_keypair(&mut OsRng); + let address = TariAddress::new(public_key, Network::default()); + + let contact = Contact::new(random::string(8), address.clone(), None, None, false); + runtime.block_on(contacts_service.upsert_contact(contact)).unwrap(); + + // Test lower bounds + for num in 0..8 { + let message = MessageBuilder::new() + .message(format!("Test {:?}", num)) + .address(address.clone()) + .build(); + + contacts_db.save_message(message.clone()).expect("Message to be saved"); + } + + let messages = runtime + .block_on(contacts_service.get_messages(address.clone(), 5, 0)) + .unwrap(); + assert_eq!(5, messages.len()); + + let messages = runtime + .block_on(contacts_service.get_messages(address.clone(), 5, 1)) + .unwrap(); + assert_eq!(3, messages.len()); + + let messages = runtime + .block_on(contacts_service.get_messages(address.clone(), 0, 0)) + .unwrap(); + assert_eq!(8, messages.len()); + + let messages = runtime + .block_on(contacts_service.get_messages(address.clone(), 0, 1)) + .unwrap(); + assert_eq!(0, messages.len()); + + // Test upper bounds + for num in 0..3000 { + let message = MessageBuilder::new() + .message(format!("Test {:?}", num)) + .address(address.clone()) + .build(); + + contacts_db.save_message(message.clone()).expect("Message to be saved"); + } + + let messages = runtime + .block_on(contacts_service.get_messages(address.clone(), u64::MAX, 0)) + .unwrap(); + assert_eq!(DEFAULT_MESSAGE_LIMIT, messages.len() as u64); + + let messages = runtime + .block_on(contacts_service.get_messages(address.clone(), MAX_MESSAGE_LIMIT, 0)) + .unwrap(); + assert_eq!(2500, messages.len()); + + let messages = runtime + .block_on(contacts_service.get_messages(address.clone(), MAX_MESSAGE_LIMIT, 1)) + .unwrap(); + assert_eq!(508, messages.len()); + + // Would cause overflows, defaults to page = 0 + let messages = runtime + .block_on(contacts_service.get_messages(address.clone(), MAX_MESSAGE_LIMIT, u64::MAX)) + .unwrap(); + assert_eq!(2500, messages.len()); + + let messages = runtime + .block_on(contacts_service.get_messages(address, 1, i64::MAX as u64)) + .unwrap(); + assert_eq!(0, messages.len()); + }); +} diff --git a/integration_tests/src/chat_ffi.rs b/integration_tests/src/chat_ffi.rs index 174c13f52b..1c31efdce2 100644 --- a/integration_tests/src/chat_ffi.rs +++ b/integration_tests/src/chat_ffi.rs @@ -68,7 +68,13 @@ extern "C" { pub fn send_message(client: *mut ClientFFI, receiver: *mut c_void, message: *const c_char, out_error: *const c_int); pub fn add_contact(client: *mut ClientFFI, address: *mut c_void, out_error: *const c_int); pub fn check_online_status(client: *mut ClientFFI, address: *mut c_void, out_error: *const c_int) -> c_int; - pub fn get_all_messages(client: *mut ClientFFI, sender: *mut c_void, out_error: *const c_int) -> *mut c_void; + pub fn get_messages( + client: *mut ClientFFI, + sender: *mut c_void, + limit: *mut c_void, + page: *mut c_void, + out_error: *const c_int, + ) -> *mut c_void; pub fn destroy_client_ffi(client: *mut ClientFFI); } @@ -119,7 +125,7 @@ impl ChatClient for ChatFFI { } } - async fn get_all_messages(&self, address: &TariAddress) -> Vec { + async fn get_messages(&self, address: &TariAddress, limit: u64, page: u64) -> Vec { let client = self.ptr.lock().unwrap(); let address_ptr = Box::into_raw(Box::new(address.clone())) as *mut c_void; @@ -127,7 +133,9 @@ impl ChatClient for ChatFFI { let messages; unsafe { let out_error = Box::into_raw(Box::new(0)); - let all_messages = get_all_messages(client.0, address_ptr, out_error) as *mut Vec; + let limit = Box::into_raw(Box::new(limit)) as *mut c_void; + let page = Box::into_raw(Box::new(page)) as *mut c_void; + let all_messages = get_messages(client.0, address_ptr, limit, page, out_error) as *mut Vec; messages = (*all_messages).clone(); } diff --git a/integration_tests/tests/steps/chat_steps.rs b/integration_tests/tests/steps/chat_steps.rs index 84b5678755..5fbe417efa 100644 --- a/integration_tests/tests/steps/chat_steps.rs +++ b/integration_tests/tests/steps/chat_steps.rs @@ -25,7 +25,10 @@ use std::time::Duration; use cucumber::{then, when}; use tari_common::configuration::Network; use tari_common_types::tari_address::TariAddress; -use tari_contacts::contacts_service::service::ContactOnlineStatus; +use tari_contacts::contacts_service::{ + handle::{DEFAULT_MESSAGE_LIMIT, DEFAULT_MESSAGE_PAGE}, + service::ContactOnlineStatus, +}; use tari_integration_tests::{chat_client::spawn_chat_client, TariWorld}; use crate::steps::{HALF_SECOND, TWO_MINUTES_WITH_HALF_SECOND_SLEEP}; @@ -73,7 +76,9 @@ async fn receive_n_messages(world: &mut TariWorld, receiver: String, message_cou let mut messages = vec![]; for _ in 0..(TWO_MINUTES_WITH_HALF_SECOND_SLEEP) { - messages = (*receiver).get_all_messages(&address).await; + messages = (*receiver) + .get_messages(&address, DEFAULT_MESSAGE_LIMIT, DEFAULT_MESSAGE_PAGE) + .await; if messages.len() as u64 == message_count { return;