Skip to content

Commit

Permalink
feat: chat message fetching pagination (tari-project#5594)
Browse files Browse the repository at this point in the history
Description
---
Previously the only function to get messages was `get_all_messages`
which return all related records in the db. Now we've added pagination
and you can select the amount of results per page, and what page you'd
like returned. Results are now ordered by most recent received.

Motivation and Context
---
`get_all_messages` has the potential to get very slow and ddos the
client. Limiting the amount returned and allowing to change the
requested amounts makes things better to deal with.

How Has This Been Tested?
---
Test added for pagination.

What process can a PR reviewer use to test or verify this change?
---
Tests pass.

Breaking Changes
---

- [x] None
- [ ] Requires data directory on base node to be deleted
- [ ] Requires hard fork
- [ ] Other - Please specify
  • Loading branch information
brianp committed Jul 17, 2023
1 parent 52d7990 commit 2024357
Show file tree
Hide file tree
Showing 11 changed files with 192 additions and 37 deletions.
10 changes: 7 additions & 3 deletions base_layer/chat_ffi/chat.h
Expand Up @@ -152,6 +152,8 @@ int check_online_status(struct ClientFFI *client, struct TariAddress *receiver,
* ## Arguments
* `client` - The Client pointer
* `address` - A TariAddress ptr
* `limit` - The amount of messages you want to fetch. Default to 35, max 2500
* `page` - The page of results you'd like returned. Default to 0, maximum of u64 max
* `error_out` - Pointer to an int which will be modified
*
* ## Returns
Expand All @@ -161,9 +163,11 @@ int check_online_status(struct ClientFFI *client, struct TariAddress *receiver,
* The ```address``` should be destroyed after use
* The returned pointer to ```*mut ChatMessages``` should be destroyed after use
*/
struct ChatMessages *get_all_messages(struct ClientFFI *client,
struct TariAddress *address,
int *error_out);
struct ChatMessages *get_messages(struct ClientFFI *client,
struct TariAddress *address,
int *limit,
int *page,
int *error_out);

/**
* Frees memory for messages
Expand Down
20 changes: 16 additions & 4 deletions base_layer/chat_ffi/src/lib.rs
Expand Up @@ -22,7 +22,7 @@

#![recursion_limit = "1024"]

use std::{ffi::CStr, path::PathBuf, ptr, str::FromStr, sync::Arc};
use std::{convert::TryFrom, ffi::CStr, path::PathBuf, ptr, str::FromStr, sync::Arc};

use callback_handler::CallbackContactStatusChange;
use libc::{c_char, c_int};
Expand All @@ -35,7 +35,10 @@ use tari_chat_client::{
use tari_common::configuration::{MultiaddrList, Network};
use tari_common_types::tari_address::TariAddress;
use tari_comms::{multiaddr::Multiaddr, NodeIdentity};
use tari_contacts::contacts_service::types::Message;
use tari_contacts::contacts_service::{
handle::{DEFAULT_MESSAGE_LIMIT, DEFAULT_MESSAGE_PAGE},
types::Message,
};
use tokio::runtime::Runtime;

use crate::{
Expand Down Expand Up @@ -393,6 +396,8 @@ pub unsafe extern "C" fn check_online_status(
/// ## Arguments
/// `client` - The Client pointer
/// `address` - A TariAddress ptr
/// `limit` - The amount of messages you want to fetch. Default to 35, max 2500
/// `page` - The page of results you'd like returned. Default to 0, maximum of u64 max
/// `error_out` - Pointer to an int which will be modified
///
/// ## Returns
Expand All @@ -402,9 +407,11 @@ pub unsafe extern "C" fn check_online_status(
/// The ```address``` should be destroyed after use
/// The returned pointer to ```*mut ChatMessages``` should be destroyed after use
#[no_mangle]
pub unsafe extern "C" fn get_all_messages(
pub unsafe extern "C" fn get_messages(
client: *mut ClientFFI,
address: *mut TariAddress,
limit: *mut c_int,
page: *mut c_int,
error_out: *mut c_int,
) -> *mut ChatMessages {
let mut error = 0;
Expand All @@ -420,9 +427,14 @@ pub unsafe extern "C" fn get_all_messages(
ptr::swap(error_out, &mut error as *mut c_int);
}

let mlimit = u64::try_from(*limit).unwrap_or(DEFAULT_MESSAGE_LIMIT);
let mpage = u64::try_from(*page).unwrap_or(DEFAULT_MESSAGE_PAGE);

let mut messages = Vec::new();

let mut retrieved_messages = (*client).runtime.block_on((*client).client.get_all_messages(&*address));
let mut retrieved_messages = (*client)
.runtime
.block_on((*client).client.get_messages(&*address, mlimit, mpage));
messages.append(&mut retrieved_messages);

Box::into_raw(Box::new(ChatMessages(messages)))
Expand Down
6 changes: 3 additions & 3 deletions base_layer/contacts/examples/chat_client/src/client.rs
Expand Up @@ -46,7 +46,7 @@ pub trait ChatClient {
async fn add_contact(&self, address: &TariAddress);
async fn check_online_status(&self, address: &TariAddress) -> ContactOnlineStatus;
async fn send_message(&self, receiver: TariAddress, message: String);
async fn get_all_messages(&self, sender: &TariAddress) -> Vec<Message>;
async fn get_messages(&self, sender: &TariAddress, limit: u64, page: u64) -> Vec<Message>;
fn identity(&self) -> &NodeIdentity;
fn shutdown(&mut self);
}
Expand Down Expand Up @@ -157,11 +157,11 @@ impl ChatClient for Client {
}
}

async fn get_all_messages(&self, sender: &TariAddress) -> Vec<Message> {
async fn get_messages(&self, sender: &TariAddress, limit: u64, page: u64) -> Vec<Message> {
let mut messages = vec![];
if let Some(mut contacts_service) = self.contacts.clone() {
messages = contacts_service
.get_all_messages(sender.clone())
.get_messages(sender.clone(), limit, page)
.await
.expect("Messages not fetched");
}
Expand Down
31 changes: 28 additions & 3 deletions base_layer/contacts/src/contacts_service/handle.rs
Expand Up @@ -21,6 +21,7 @@
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use std::{
convert::TryFrom,
fmt::{Display, Error, Formatter},
sync::Arc,
};
Expand All @@ -38,6 +39,10 @@ use crate::contacts_service::{
types::{Contact, Message},
};

pub static DEFAULT_MESSAGE_LIMIT: u64 = 35;
pub static MAX_MESSAGE_LIMIT: u64 = 2500;
pub static DEFAULT_MESSAGE_PAGE: u64 = 0;

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContactsLivenessData {
address: TariAddress,
Expand Down Expand Up @@ -131,7 +136,7 @@ pub enum ContactsServiceRequest {
GetContacts,
GetContactOnlineStatus(Contact),
SendMessage(TariAddress, Message),
GetAllMessages(TariAddress),
GetMessages(TariAddress, i64, i64),
}

#[derive(Debug)]
Expand Down Expand Up @@ -236,10 +241,30 @@ impl ContactsServiceHandle {
}
}

pub async fn get_all_messages(&mut self, pk: TariAddress) -> Result<Vec<Message>, ContactsServiceError> {
pub async fn get_messages(
&mut self,
pk: TariAddress,
mut limit: u64,
mut page: u64,
) -> Result<Vec<Message>, ContactsServiceError> {
if limit == 0 || limit > MAX_MESSAGE_LIMIT {
limit = DEFAULT_MESSAGE_LIMIT;
}

page = match page.checked_mul(limit) {
Some(_) => page,
None => DEFAULT_MESSAGE_PAGE,
};

// const values won't be a problem here
#[allow(clippy::cast_possible_wrap)]
match self
.request_response_service
.call(ContactsServiceRequest::GetAllMessages(pk))
.call(ContactsServiceRequest::GetMessages(
pk,
i64::try_from(limit).unwrap_or(DEFAULT_MESSAGE_LIMIT as i64),
i64::try_from(page).unwrap_or(DEFAULT_MESSAGE_PAGE as i64),
))
.await??
{
ContactsServiceResponse::Messages(messages) => Ok(messages),
Expand Down
4 changes: 2 additions & 2 deletions base_layer/contacts/src/contacts_service/service.rs
Expand Up @@ -291,8 +291,8 @@ where T: ContactsBackend + 'static
let result = self.get_online_status(&contact).await;
Ok(result.map(ContactsServiceResponse::OnlineStatus)?)
},
ContactsServiceRequest::GetAllMessages(pk) => {
let result = self.db.get_messages(pk);
ContactsServiceRequest::GetMessages(pk, limit, page) => {
let result = self.db.get_messages(pk, limit, page);
Ok(result.map(ContactsServiceResponse::Messages)?)
},
ContactsServiceRequest::SendMessage(address, mut message) => {
Expand Down
13 changes: 9 additions & 4 deletions base_layer/contacts/src/contacts_service/storage/database.rs
Expand Up @@ -50,7 +50,7 @@ pub enum DbKey {
Contact(TariAddress),
ContactId(NodeId),
Contacts,
Messages(TariAddress),
Messages(TariAddress, i64, i64),
}

pub enum DbValue {
Expand Down Expand Up @@ -163,8 +163,13 @@ where T: ContactsBackend + 'static
}
}

pub fn get_messages(&self, address: TariAddress) -> Result<Vec<Message>, ContactsServiceStorageError> {
let key = DbKey::Messages(address);
pub fn get_messages(
&self,
address: TariAddress,
limit: i64,
page: i64,
) -> Result<Vec<Message>, ContactsServiceStorageError> {
let key = DbKey::Messages(address, limit, page);
let db_clone = self.db.clone();
match db_clone.fetch(&key) {
Ok(None) => log_error(
Expand Down Expand Up @@ -197,7 +202,7 @@ impl Display for DbKey {
DbKey::Contact(c) => f.write_str(&format!("Contact: {:?}", c)),
DbKey::ContactId(id) => f.write_str(&format!("Contact: {:?}", id)),
DbKey::Contacts => f.write_str("Contacts"),
DbKey::Messages(c) => f.write_str(&format!("Messages for id: {:?}", c)),
DbKey::Messages(c, _l, _p) => f.write_str(&format!("Messages for id: {:?}", c)),
}
}
}
Expand Down
22 changes: 12 additions & 10 deletions base_layer/contacts/src/contacts_service/storage/sqlite_db.rs
Expand Up @@ -104,15 +104,17 @@ where TContactServiceDbConnection: PooledDbConnection<Error = SqliteStorageError
.map(|c| Contact::try_from(c.clone()))
.collect::<Result<Vec<_>, _>>()?,
)),
DbKey::Messages(address) => match MessagesSql::find_by_address(&address.to_bytes(), &mut conn) {
Ok(messages) => Some(DbValue::Messages(
messages
.iter()
.map(|m| Message::try_from(m.clone()).expect("Couldn't cast MessageSql to Message"))
.collect::<Vec<Message>>(),
)),
Err(ContactsServiceStorageError::DieselError(DieselError::NotFound)) => None,
Err(e) => return Err(e),
DbKey::Messages(address, limit, page) => {
match MessagesSql::find_by_address(&address.to_bytes(), *limit, *page, &mut conn) {
Ok(messages) => Some(DbValue::Messages(
messages
.iter()
.map(|m| Message::try_from(m.clone()).expect("Couldn't cast MessageSql to Message"))
.collect::<Vec<Message>>(),
)),
Err(ContactsServiceStorageError::DieselError(DieselError::NotFound)) => None,
Err(e) => return Err(e),
}
},
};

Expand Down Expand Up @@ -170,7 +172,7 @@ where TContactServiceDbConnection: PooledDbConnection<Error = SqliteStorageError
Err(e) => return Err(e),
},
DbKey::Contacts => return Err(ContactsServiceStorageError::OperationNotSupported),
DbKey::Messages(_pk) => return Err(ContactsServiceStorageError::OperationNotSupported),
DbKey::Messages(_pk, _l, _p) => return Err(ContactsServiceStorageError::OperationNotSupported),
},
WriteOperation::Insert(i) => {
if let DbValue::Message(m) = *i {
Expand Down
Expand Up @@ -71,10 +71,15 @@ impl MessagesSql {
/// Find a particular message by their address, if it exists
pub fn find_by_address(
address: &[u8],
limit: i64,
page: i64,
conn: &mut SqliteConnection,
) -> Result<Vec<MessagesSql>, ContactsServiceStorageError> {
Ok(messages::table
.filter(messages::address.eq(address))
.order(messages::stored_at.desc())
.offset(limit * page)
.limit(limit)
.load::<MessagesSql>(conn)?)
}
}
Expand Down
95 changes: 92 additions & 3 deletions base_layer/contacts/tests/contacts_service.rs
Expand Up @@ -30,12 +30,12 @@ use tari_comms::{peer_manager::PeerFeatures, NodeIdentity};
use tari_comms_dht::{store_forward::SafConfig, DhtConfig};
use tari_contacts::contacts_service::{
error::{ContactsServiceError, ContactsServiceStorageError},
handle::ContactsServiceHandle,
handle::{ContactsServiceHandle, DEFAULT_MESSAGE_LIMIT, MAX_MESSAGE_LIMIT},
storage::{
database::{ContactsBackend, DbKey},
database::{ContactsBackend, ContactsDatabase, DbKey},
sqlite_db::ContactsServiceSqliteDatabase,
},
types::Contact,
types::{Contact, MessageBuilder},
ContactsServiceInitializer,
};
use tari_crypto::keys::PublicKey as PublicKeyTrait;
Expand Down Expand Up @@ -216,3 +216,92 @@ pub fn test_contacts_service() {
};
});
}

#[test]
pub fn test_message_pagination() {
with_temp_dir(|dir_path| {
let mut runtime = Runtime::new().unwrap();

let db_name = format!("{}.sqlite3", string(8).as_str());
let db_path = format!("{}/{}", dir_path.to_str().unwrap(), db_name);
let url: DbConnectionUrl = db_path.try_into().unwrap();

let db = DbConnection::connect_url(&url).unwrap();
let backend = ContactsServiceSqliteDatabase::init(db);
let contacts_db = ContactsDatabase::new(backend.clone());

let (mut contacts_service, _node_identity, _shutdown) = setup_contacts_service(&mut runtime, backend);

let (_secret_key, public_key) = PublicKey::random_keypair(&mut OsRng);
let address = TariAddress::new(public_key, Network::default());

let contact = Contact::new(random::string(8), address.clone(), None, None, false);
runtime.block_on(contacts_service.upsert_contact(contact)).unwrap();

// Test lower bounds
for num in 0..8 {
let message = MessageBuilder::new()
.message(format!("Test {:?}", num))
.address(address.clone())
.build();

contacts_db.save_message(message.clone()).expect("Message to be saved");
}

let messages = runtime
.block_on(contacts_service.get_messages(address.clone(), 5, 0))
.unwrap();
assert_eq!(5, messages.len());

let messages = runtime
.block_on(contacts_service.get_messages(address.clone(), 5, 1))
.unwrap();
assert_eq!(3, messages.len());

let messages = runtime
.block_on(contacts_service.get_messages(address.clone(), 0, 0))
.unwrap();
assert_eq!(8, messages.len());

let messages = runtime
.block_on(contacts_service.get_messages(address.clone(), 0, 1))
.unwrap();
assert_eq!(0, messages.len());

// Test upper bounds
for num in 0..3000 {
let message = MessageBuilder::new()
.message(format!("Test {:?}", num))
.address(address.clone())
.build();

contacts_db.save_message(message.clone()).expect("Message to be saved");
}

let messages = runtime
.block_on(contacts_service.get_messages(address.clone(), u64::MAX, 0))
.unwrap();
assert_eq!(DEFAULT_MESSAGE_LIMIT, messages.len() as u64);

let messages = runtime
.block_on(contacts_service.get_messages(address.clone(), MAX_MESSAGE_LIMIT, 0))
.unwrap();
assert_eq!(2500, messages.len());

let messages = runtime
.block_on(contacts_service.get_messages(address.clone(), MAX_MESSAGE_LIMIT, 1))
.unwrap();
assert_eq!(508, messages.len());

// Would cause overflows, defaults to page = 0
let messages = runtime
.block_on(contacts_service.get_messages(address.clone(), MAX_MESSAGE_LIMIT, u64::MAX))
.unwrap();
assert_eq!(2500, messages.len());

let messages = runtime
.block_on(contacts_service.get_messages(address, 1, i64::MAX as u64))
.unwrap();
assert_eq!(0, messages.len());
});
}

0 comments on commit 2024357

Please sign in to comment.