diff --git a/Cargo.lock b/Cargo.lock index d936b29..3fd6195 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2493,6 +2493,7 @@ dependencies = [ "serde_json", "starknet 0.9.0", "starknet-crypto", + "stream-cancel", "tokio", "tokio-stream", "torii-client", @@ -8755,6 +8756,17 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "stream-cancel" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f9fbf9bd71e4cf18d68a8a0951c0e5b7255920c0cd992c4ff51cddd6ef514a3" +dependencies = [ + "futures-core", + "pin-project", + "tokio", +] + [[package]] name = "string_cache" version = "0.8.7" diff --git a/Cargo.toml b/Cargo.toml index 1d9628b..cce96c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ tokio-stream = "0.1.14" futures = "0.3.30" futures-channel = "0.3.30" wasm-bindgen = "0.2.92" +stream-cancel = "0.8.2" [patch.crates-io] deno_task_shell = { git = "https://github.com/denoland/deno_task_shell", tag = "0.15.0" } diff --git a/dojo.h b/dojo.h index 6c7d6d3..9cf8f8c 100644 --- a/dojo.h +++ b/dojo.h @@ -26,6 +26,8 @@ typedef struct Account Account; typedef struct Provider Provider; +typedef struct Subscription Subscription; + typedef struct ToriiClient ToriiClient; typedef struct Error { @@ -437,6 +439,23 @@ typedef struct Resultbool { }; } Resultbool; +typedef enum ResultSubscription_Tag { + OkSubscription, + ErrSubscription, +} ResultSubscription_Tag; + +typedef struct ResultSubscription { + ResultSubscription_Tag tag; + union { + struct { + struct Subscription *ok; + }; + struct { + struct Error err; + }; + }; +} ResultSubscription; + typedef enum ResultFieldElement_Tag { OkFieldElement, ErrFieldElement, @@ -579,15 +598,15 @@ struct Resultbool client_add_models_to_sync(struct ToriiClient *client, const struct KeysClause *models, uintptr_t models_len); -struct Resultbool client_on_sync_model_update(struct ToriiClient *client, - struct KeysClause model, - void (*callback)(void)); +struct ResultSubscription client_on_sync_model_update(struct ToriiClient *client, + struct KeysClause model, + void (*callback)(void)); -struct Resultbool client_on_entity_state_update(struct ToriiClient *client, - struct FieldElement *entities, - uintptr_t entities_len, - void (*callback)(struct FieldElement, - struct CArrayModel)); +struct ResultSubscription client_on_entity_state_update(struct ToriiClient *client, + struct FieldElement *entities, + uintptr_t entities_len, + void (*callback)(struct FieldElement, + struct CArrayModel)); struct Resultbool client_remove_models_to_sync(struct ToriiClient *client, const struct KeysClause *models, @@ -637,6 +656,8 @@ struct FieldElement hash_get_contract_address(struct FieldElement class_hash, uintptr_t constructor_calldata_len, struct FieldElement deployer_address); +void subscription_cancel(struct Subscription *subscription); + void client_free(struct ToriiClient *t); void provider_free(struct Provider *rpc); diff --git a/dojo.hpp b/dojo.hpp index bbc2460..b5cbb85 100644 --- a/dojo.hpp +++ b/dojo.hpp @@ -29,6 +29,8 @@ struct Account; struct Provider; +struct Subscription; + struct ToriiClient; struct Error { @@ -775,12 +777,14 @@ Result client_add_models_to_sync(ToriiClient *client, const KeysClause *models, uintptr_t models_len); -Result client_on_sync_model_update(ToriiClient *client, KeysClause model, void (*callback)()); +Result client_on_sync_model_update(ToriiClient *client, + KeysClause model, + void (*callback)()); -Result client_on_entity_state_update(ToriiClient *client, - FieldElement *entities, - uintptr_t entities_len, - void (*callback)(FieldElement, CArray)); +Result client_on_entity_state_update(ToriiClient *client, + FieldElement *entities, + uintptr_t entities_len, + void (*callback)(FieldElement, CArray)); Result client_remove_models_to_sync(ToriiClient *client, const KeysClause *models, @@ -826,6 +830,8 @@ FieldElement hash_get_contract_address(FieldElement class_hash, uintptr_t constructor_calldata_len, FieldElement deployer_address); +void subscription_cancel(Subscription *subscription); + void client_free(ToriiClient *t); void provider_free(Provider *rpc); diff --git a/src/c/mod.rs b/src/c/mod.rs index b4412ad..45fe206 100644 --- a/src/c/mod.rs +++ b/src/c/mod.rs @@ -5,7 +5,7 @@ use self::types::{ ToriiClient, Ty, WorldMetadata, }; use crate::constants; -use crate::types::{Account, Provider}; +use crate::types::{Account, Provider, Subscription}; use crate::utils::watch_tx; use starknet::accounts::{Account as StarknetAccount, ExecutionEncoding, SingleOwnerAccount}; use starknet::core::types::FunctionCall; @@ -20,6 +20,7 @@ use std::ffi::{c_void, CStr, CString}; use std::ops::Deref; use std::os::raw::c_char; use std::sync::Arc; +use stream_cancel::{StreamExt as _, Tripwire}; use tokio_stream::StreamExt; use torii_client::client::Client as TClient; use torii_relay::typed_data::TypedData; @@ -199,11 +200,11 @@ pub unsafe extern "C" fn client_on_sync_model_update( client: *mut ToriiClient, model: KeysClause, callback: unsafe extern "C" fn(), -) -> Result { +) -> Result<*mut Subscription> { let model: torii_grpc::types::KeysClause = (&model).into(); let storage = (*client).inner.storage(); - let mut rcv = match storage.add_listener( + let rcv = match storage.add_listener( cairo_short_string_to_felt(model.model.as_str()).unwrap(), model.keys.as_slice(), ) { @@ -211,13 +212,16 @@ pub unsafe extern "C" fn client_on_sync_model_update( Err(e) => return Result::Err(e.into()), }; + let (trigger, tripwire) = Tripwire::new(); (*client).runtime.spawn(async move { - if let Ok(Some(_)) = rcv.try_next() { + let mut rcv = rcv.take_until_if(tripwire); + + while rcv.next().await.is_some() { callback(); } }); - Result::Ok(true) + Result::Ok(Box::into_raw(Box::new(Subscription(trigger)))) } #[no_mangle] @@ -227,18 +231,21 @@ pub unsafe extern "C" fn client_on_entity_state_update( entities: *mut types::FieldElement, entities_len: usize, callback: unsafe extern "C" fn(types::FieldElement, CArray), -) -> Result { +) -> Result<*mut Subscription> { let entities = unsafe { std::slice::from_raw_parts(entities, entities_len) }; // to vec of fieldleemnt let entities = entities.iter().map(|e| (&e.clone()).into()).collect(); let entity_stream = unsafe { (*client).inner.on_entity_updated(entities) }; - let mut rcv = match (*client).runtime.block_on(entity_stream) { + let rcv = match (*client).runtime.block_on(entity_stream) { Ok(rcv) => rcv, Err(e) => return Result::Err(e.into()), }; + let (trigger, tripwire) = Tripwire::new(); (*client).runtime.spawn(async move { + let mut rcv = rcv.take_until_if(tripwire); + while let Some(Ok(entity)) = rcv.next().await { let key: types::FieldElement = (&entity.hashed_keys).into(); let models: Vec = entity.models.into_iter().map(|e| (&e).into()).collect(); @@ -246,7 +253,7 @@ pub unsafe extern "C" fn client_on_entity_state_update( } }); - Result::Ok(true) + Result::Ok(Box::into_raw(Box::new(Subscription(trigger)))) } #[no_mangle] @@ -560,6 +567,17 @@ pub unsafe extern "C" fn hash_get_contract_address( (&address).into() } +#[no_mangle] +#[allow(clippy::missing_safety_doc)] +pub unsafe extern "C" fn subscription_cancel(subscription: *mut Subscription) { + if !subscription.is_null() { + unsafe { + let subscription = Box::from_raw(subscription); + subscription.0.cancel(); + } + } +} + // This function takes a raw pointer to ToriiClient as an argument. // It checks if the pointer is not null. If it's not, it converts the raw pointer // back into a Box, which gets dropped at the end of the scope, diff --git a/src/types.rs b/src/types.rs index 67afa34..113edc9 100644 --- a/src/types.rs +++ b/src/types.rs @@ -5,9 +5,12 @@ use starknet::{ providers::{jsonrpc::HttpTransport, JsonRpcClient}, signers::LocalWallet, }; +use stream_cancel::Trigger; use wasm_bindgen::prelude::*; #[wasm_bindgen] pub struct Provider(pub(crate) Arc>); #[wasm_bindgen] pub struct Account(pub(crate) SingleOwnerAccount>, LocalWallet>); +#[wasm_bindgen] +pub struct Subscription(pub(crate) Trigger); diff --git a/src/wasm/mod.rs b/src/wasm/mod.rs index 1db15ec..d4a7a83 100644 --- a/src/wasm/mod.rs +++ b/src/wasm/mod.rs @@ -18,13 +18,14 @@ use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::{JsonRpcClient, Provider as _}; use starknet::signers::{LocalWallet, SigningKey, VerifyingKey}; use starknet_crypto::Signature; +use stream_cancel::{StreamExt as _, Tripwire}; use torii_relay::typed_data::TypedData; use torii_relay::types::Message; use tsify::Tsify; use wasm_bindgen::prelude::*; use crate::constants; -use crate::types::{Account, Provider}; +use crate::types::{Account, Provider, Subscription}; use crate::utils::watch_tx; use crate::wasm::utils::{parse_entities_as_json_str, parse_ty_as_json_str}; @@ -745,12 +746,12 @@ impl Client { &self, model: KeysClause, callback: js_sys::Function, - ) -> Result<(), JsValue> { + ) -> Result { #[cfg(feature = "console-error-panic")] console_error_panic_hook::set_once(); let name = cairo_short_string_to_felt(&model.model).expect("invalid model name"); - let mut rcv = self + let rcv = self .inner .storage() .add_listener( @@ -764,13 +765,16 @@ impl Client { ) .unwrap(); + let (trigger, tripwire) = Tripwire::new(); wasm_bindgen_futures::spawn_local(async move { + let mut rcv = rcv.take_until_if(tripwire); + while rcv.next().await.is_some() { let _ = callback.call0(&JsValue::null()); } }); - Ok(()) + Ok(Subscription(trigger)) } #[wasm_bindgen(js_name = onEntityUpdated)] @@ -778,7 +782,7 @@ impl Client { &self, ids: Option>, callback: js_sys::Function, - ) -> Result<(), JsValue> { + ) -> Result { #[cfg(feature = "console-error-panic")] console_error_panic_hook::set_once(); @@ -791,9 +795,12 @@ impl Client { }) .collect::, _>>()?; - let mut stream = self.inner.on_entity_updated(ids).await.unwrap(); + let stream = self.inner.on_entity_updated(ids).await.unwrap(); + let (trigger, tripwire) = Tripwire::new(); wasm_bindgen_futures::spawn_local(async move { + let mut stream = stream.take_until_if(tripwire); + while let Some(update) = stream.next().await { let entity = update.expect("no updated entity"); let json_str = parse_entities_as_json_str(vec![entity]).to_string(); @@ -804,7 +811,7 @@ impl Client { } }); - Ok(()) + Ok(Subscription(trigger)) } #[wasm_bindgen(js_name = publishMessage)] @@ -835,6 +842,13 @@ impl Client { } } +#[wasm_bindgen] +impl Subscription { + pub fn cancel(self) { + self.0.cancel(); + } +} + /// Create the a client with the given configurations. #[wasm_bindgen(js_name = createClient)] #[allow(non_snake_case)]