From a091aa9af083e7c8ef5ba4d45197136fb229212d Mon Sep 17 00:00:00 2001 From: Nasr Date: Tue, 23 Apr 2024 12:17:16 -0400 Subject: [PATCH 1/5] feat: abort oin entity update subscription --- dojo.h | 31 ++++++++++++++++++++++++++----- dojo.hpp | 12 ++++++++---- src/c/mod.rs | 19 +++++++++++++++---- src/types.rs | 3 +++ 4 files changed, 52 insertions(+), 13 deletions(-) diff --git a/dojo.h b/dojo.h index 6c7d6d3..0e64b3d 100644 --- a/dojo.h +++ b/dojo.h @@ -26,6 +26,8 @@ typedef struct Account Account; typedef struct Provider Provider; +typedef struct Subscription Subscription; + typedef struct ToriiClient ToriiClient; typedef struct Error { @@ -437,6 +439,23 @@ typedef struct Resultbool { }; } Resultbool; +typedef enum ResultSubscription_Tag { + OkSubscription, + ErrSubscription, +} ResultSubscription_Tag; + +typedef struct ResultSubscription { + ResultSubscription_Tag tag; + union { + struct { + struct Subscription ok; + }; + struct { + struct Error err; + }; + }; +} ResultSubscription; + typedef enum ResultFieldElement_Tag { OkFieldElement, ErrFieldElement, @@ -583,11 +602,11 @@ struct Resultbool client_on_sync_model_update(struct ToriiClient *client, struct KeysClause model, void (*callback)(void)); -struct Resultbool client_on_entity_state_update(struct ToriiClient *client, - struct FieldElement *entities, - uintptr_t entities_len, - void (*callback)(struct FieldElement, - struct CArrayModel)); +struct ResultSubscription client_on_entity_state_update(struct ToriiClient *client, + struct FieldElement *entities, + uintptr_t entities_len, + void (*callback)(struct FieldElement, + struct CArrayModel)); struct Resultbool client_remove_models_to_sync(struct ToriiClient *client, const struct KeysClause *models, @@ -637,6 +656,8 @@ struct FieldElement hash_get_contract_address(struct FieldElement class_hash, uintptr_t constructor_calldata_len, struct FieldElement deployer_address); +void subscription_abort(struct Subscription *subscription); + void client_free(struct ToriiClient *t); void provider_free(struct Provider *rpc); diff --git a/dojo.hpp b/dojo.hpp index bbc2460..9807767 100644 --- a/dojo.hpp +++ b/dojo.hpp @@ -29,6 +29,8 @@ struct Account; struct Provider; +struct Subscription; + struct ToriiClient; struct Error { @@ -777,10 +779,10 @@ Result client_add_models_to_sync(ToriiClient *client, Result client_on_sync_model_update(ToriiClient *client, KeysClause model, void (*callback)()); -Result client_on_entity_state_update(ToriiClient *client, - FieldElement *entities, - uintptr_t entities_len, - void (*callback)(FieldElement, CArray)); +Result client_on_entity_state_update(ToriiClient *client, + FieldElement *entities, + uintptr_t entities_len, + void (*callback)(FieldElement, CArray)); Result client_remove_models_to_sync(ToriiClient *client, const KeysClause *models, @@ -826,6 +828,8 @@ FieldElement hash_get_contract_address(FieldElement class_hash, uintptr_t constructor_calldata_len, FieldElement deployer_address); +void subscription_abort(Subscription *subscription); + void client_free(ToriiClient *t); void provider_free(Provider *rpc); diff --git a/src/c/mod.rs b/src/c/mod.rs index b4412ad..ffb4ce4 100644 --- a/src/c/mod.rs +++ b/src/c/mod.rs @@ -5,7 +5,7 @@ use self::types::{ ToriiClient, Ty, WorldMetadata, }; use crate::constants; -use crate::types::{Account, Provider}; +use crate::types::{Account, Provider, Subscription}; use crate::utils::watch_tx; use starknet::accounts::{Account as StarknetAccount, ExecutionEncoding, SingleOwnerAccount}; use starknet::core::types::FunctionCall; @@ -227,7 +227,7 @@ pub unsafe extern "C" fn client_on_entity_state_update( entities: *mut types::FieldElement, entities_len: usize, callback: unsafe extern "C" fn(types::FieldElement, CArray), -) -> Result { +) -> Result { let entities = unsafe { std::slice::from_raw_parts(entities, entities_len) }; // to vec of fieldleemnt let entities = entities.iter().map(|e| (&e.clone()).into()).collect(); @@ -238,7 +238,7 @@ pub unsafe extern "C" fn client_on_entity_state_update( Err(e) => return Result::Err(e.into()), }; - (*client).runtime.spawn(async move { + let handle = (*client).runtime.spawn(async move { while let Some(Ok(entity)) = rcv.next().await { let key: types::FieldElement = (&entity.hashed_keys).into(); let models: Vec = entity.models.into_iter().map(|e| (&e).into()).collect(); @@ -246,7 +246,7 @@ pub unsafe extern "C" fn client_on_entity_state_update( } }); - Result::Ok(true) + Result::Ok(Subscription(handle.abort_handle())) } #[no_mangle] @@ -560,6 +560,17 @@ pub unsafe extern "C" fn hash_get_contract_address( (&address).into() } +#[no_mangle] +#[allow(clippy::missing_safety_doc)] +pub unsafe extern "C" fn subscription_abort(subscription: *mut Subscription) { + if !subscription.is_null() { + unsafe { + let subscription = Box::from_raw(subscription); + subscription.0.abort(); + } + } +} + // This function takes a raw pointer to ToriiClient as an argument. // It checks if the pointer is not null. If it's not, it converts the raw pointer // back into a Box, which gets dropped at the end of the scope, diff --git a/src/types.rs b/src/types.rs index 67afa34..88519f6 100644 --- a/src/types.rs +++ b/src/types.rs @@ -5,9 +5,12 @@ use starknet::{ providers::{jsonrpc::HttpTransport, JsonRpcClient}, signers::LocalWallet, }; +use tokio::task::AbortHandle; use wasm_bindgen::prelude::*; #[wasm_bindgen] pub struct Provider(pub(crate) Arc>); #[wasm_bindgen] pub struct Account(pub(crate) SingleOwnerAccount>, LocalWallet>); +#[wasm_bindgen] +pub struct Subscription(pub(crate) AbortHandle); From 31093f1004cac9df0c41c25817b2581d919779d9 Mon Sep 17 00:00:00 2001 From: Nasr Date: Tue, 23 Apr 2024 12:18:04 -0400 Subject: [PATCH 2/5] feat: add abort for on sync moidel update --- dojo.h | 6 +++--- dojo.hpp | 4 +++- src/c/mod.rs | 6 +++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/dojo.h b/dojo.h index 0e64b3d..a6a48ea 100644 --- a/dojo.h +++ b/dojo.h @@ -598,9 +598,9 @@ struct Resultbool client_add_models_to_sync(struct ToriiClient *client, const struct KeysClause *models, uintptr_t models_len); -struct Resultbool client_on_sync_model_update(struct ToriiClient *client, - struct KeysClause model, - void (*callback)(void)); +struct ResultSubscription client_on_sync_model_update(struct ToriiClient *client, + struct KeysClause model, + void (*callback)(void)); struct ResultSubscription client_on_entity_state_update(struct ToriiClient *client, struct FieldElement *entities, diff --git a/dojo.hpp b/dojo.hpp index 9807767..d1e5613 100644 --- a/dojo.hpp +++ b/dojo.hpp @@ -777,7 +777,9 @@ Result client_add_models_to_sync(ToriiClient *client, const KeysClause *models, uintptr_t models_len); -Result client_on_sync_model_update(ToriiClient *client, KeysClause model, void (*callback)()); +Result client_on_sync_model_update(ToriiClient *client, + KeysClause model, + void (*callback)()); Result client_on_entity_state_update(ToriiClient *client, FieldElement *entities, diff --git a/src/c/mod.rs b/src/c/mod.rs index ffb4ce4..83a421f 100644 --- a/src/c/mod.rs +++ b/src/c/mod.rs @@ -199,7 +199,7 @@ pub unsafe extern "C" fn client_on_sync_model_update( client: *mut ToriiClient, model: KeysClause, callback: unsafe extern "C" fn(), -) -> Result { +) -> Result { let model: torii_grpc::types::KeysClause = (&model).into(); let storage = (*client).inner.storage(); @@ -211,13 +211,13 @@ pub unsafe extern "C" fn client_on_sync_model_update( Err(e) => return Result::Err(e.into()), }; - (*client).runtime.spawn(async move { + let handle = (*client).runtime.spawn(async move { if let Ok(Some(_)) = rcv.try_next() { callback(); } }); - Result::Ok(true) + Result::Ok(handle.abort_handle()) } #[no_mangle] From 6b716f2030a9678784fa225354061fb2f641f156 Mon Sep 17 00:00:00 2001 From: Nasr Date: Tue, 23 Apr 2024 14:05:35 -0400 Subject: [PATCH 3/5] fgeat: abort stream using stream-cancel for C and wasm --- Cargo.lock | 12 ++++++++++++ Cargo.toml | 1 + dojo.h | 2 +- dojo.hpp | 2 +- src/c/mod.rs | 27 +++++++++++++++++---------- src/types.rs | 4 ++-- src/wasm/mod.rs | 28 +++++++++++++++++++++------- 7 files changed, 55 insertions(+), 21 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d936b29..3fd6195 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2493,6 +2493,7 @@ dependencies = [ "serde_json", "starknet 0.9.0", "starknet-crypto", + "stream-cancel", "tokio", "tokio-stream", "torii-client", @@ -8755,6 +8756,17 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "stream-cancel" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f9fbf9bd71e4cf18d68a8a0951c0e5b7255920c0cd992c4ff51cddd6ef514a3" +dependencies = [ + "futures-core", + "pin-project", + "tokio", +] + [[package]] name = "string_cache" version = "0.8.7" diff --git a/Cargo.toml b/Cargo.toml index 1d9628b..cce96c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ tokio-stream = "0.1.14" futures = "0.3.30" futures-channel = "0.3.30" wasm-bindgen = "0.2.92" +stream-cancel = "0.8.2" [patch.crates-io] deno_task_shell = { git = "https://github.com/denoland/deno_task_shell", tag = "0.15.0" } diff --git a/dojo.h b/dojo.h index a6a48ea..9d7bda0 100644 --- a/dojo.h +++ b/dojo.h @@ -656,7 +656,7 @@ struct FieldElement hash_get_contract_address(struct FieldElement class_hash, uintptr_t constructor_calldata_len, struct FieldElement deployer_address); -void subscription_abort(struct Subscription *subscription); +void subscription_cancel(struct Subscription *subscription); void client_free(struct ToriiClient *t); diff --git a/dojo.hpp b/dojo.hpp index d1e5613..69bd0b6 100644 --- a/dojo.hpp +++ b/dojo.hpp @@ -830,7 +830,7 @@ FieldElement hash_get_contract_address(FieldElement class_hash, uintptr_t constructor_calldata_len, FieldElement deployer_address); -void subscription_abort(Subscription *subscription); +void subscription_cancel(Subscription *subscription); void client_free(ToriiClient *t); diff --git a/src/c/mod.rs b/src/c/mod.rs index 83a421f..18ae8e6 100644 --- a/src/c/mod.rs +++ b/src/c/mod.rs @@ -16,11 +16,12 @@ use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::{JsonRpcClient, Provider as _}; use starknet::signers::{LocalWallet, SigningKey, VerifyingKey}; use starknet_crypto::FieldElement; +use tokio_stream::StreamExt; use std::ffi::{c_void, CStr, CString}; use std::ops::Deref; use std::os::raw::c_char; use std::sync::Arc; -use tokio_stream::StreamExt; +use stream_cancel::{StreamExt as _, Tripwire}; use torii_client::client::Client as TClient; use torii_relay::typed_data::TypedData; use torii_relay::types::Message; @@ -203,7 +204,7 @@ pub unsafe extern "C" fn client_on_sync_model_update( let model: torii_grpc::types::KeysClause = (&model).into(); let storage = (*client).inner.storage(); - let mut rcv = match storage.add_listener( + let rcv = match storage.add_listener( cairo_short_string_to_felt(model.model.as_str()).unwrap(), model.keys.as_slice(), ) { @@ -211,13 +212,16 @@ pub unsafe extern "C" fn client_on_sync_model_update( Err(e) => return Result::Err(e.into()), }; - let handle = (*client).runtime.spawn(async move { - if let Ok(Some(_)) = rcv.try_next() { + let (trigger, tripwire) = Tripwire::new(); + (*client).runtime.spawn(async move { + let mut rcv = rcv.take_until_if(tripwire); + + while let Some(_) = rcv.next().await { callback(); } }); - Result::Ok(handle.abort_handle()) + Result::Ok(Subscription(trigger)) } #[no_mangle] @@ -233,12 +237,15 @@ pub unsafe extern "C" fn client_on_entity_state_update( let entities = entities.iter().map(|e| (&e.clone()).into()).collect(); let entity_stream = unsafe { (*client).inner.on_entity_updated(entities) }; - let mut rcv = match (*client).runtime.block_on(entity_stream) { + let rcv = match (*client).runtime.block_on(entity_stream) { Ok(rcv) => rcv, Err(e) => return Result::Err(e.into()), }; - let handle = (*client).runtime.spawn(async move { + let (trigger, tripwire) = Tripwire::new(); + (*client).runtime.spawn(async move { + let mut rcv = rcv.take_until_if(tripwire); + while let Some(Ok(entity)) = rcv.next().await { let key: types::FieldElement = (&entity.hashed_keys).into(); let models: Vec = entity.models.into_iter().map(|e| (&e).into()).collect(); @@ -246,7 +253,7 @@ pub unsafe extern "C" fn client_on_entity_state_update( } }); - Result::Ok(Subscription(handle.abort_handle())) + Result::Ok(Subscription(trigger)) } #[no_mangle] @@ -562,11 +569,11 @@ pub unsafe extern "C" fn hash_get_contract_address( #[no_mangle] #[allow(clippy::missing_safety_doc)] -pub unsafe extern "C" fn subscription_abort(subscription: *mut Subscription) { +pub unsafe extern "C" fn subscription_cancel(subscription: *mut Subscription) { if !subscription.is_null() { unsafe { let subscription = Box::from_raw(subscription); - subscription.0.abort(); + subscription.0.cancel(); } } } diff --git a/src/types.rs b/src/types.rs index 88519f6..113edc9 100644 --- a/src/types.rs +++ b/src/types.rs @@ -5,7 +5,7 @@ use starknet::{ providers::{jsonrpc::HttpTransport, JsonRpcClient}, signers::LocalWallet, }; -use tokio::task::AbortHandle; +use stream_cancel::Trigger; use wasm_bindgen::prelude::*; #[wasm_bindgen] @@ -13,4 +13,4 @@ pub struct Provider(pub(crate) Arc>); #[wasm_bindgen] pub struct Account(pub(crate) SingleOwnerAccount>, LocalWallet>); #[wasm_bindgen] -pub struct Subscription(pub(crate) AbortHandle); +pub struct Subscription(pub(crate) Trigger); diff --git a/src/wasm/mod.rs b/src/wasm/mod.rs index 1db15ec..d4a7a83 100644 --- a/src/wasm/mod.rs +++ b/src/wasm/mod.rs @@ -18,13 +18,14 @@ use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::{JsonRpcClient, Provider as _}; use starknet::signers::{LocalWallet, SigningKey, VerifyingKey}; use starknet_crypto::Signature; +use stream_cancel::{StreamExt as _, Tripwire}; use torii_relay::typed_data::TypedData; use torii_relay::types::Message; use tsify::Tsify; use wasm_bindgen::prelude::*; use crate::constants; -use crate::types::{Account, Provider}; +use crate::types::{Account, Provider, Subscription}; use crate::utils::watch_tx; use crate::wasm::utils::{parse_entities_as_json_str, parse_ty_as_json_str}; @@ -745,12 +746,12 @@ impl Client { &self, model: KeysClause, callback: js_sys::Function, - ) -> Result<(), JsValue> { + ) -> Result { #[cfg(feature = "console-error-panic")] console_error_panic_hook::set_once(); let name = cairo_short_string_to_felt(&model.model).expect("invalid model name"); - let mut rcv = self + let rcv = self .inner .storage() .add_listener( @@ -764,13 +765,16 @@ impl Client { ) .unwrap(); + let (trigger, tripwire) = Tripwire::new(); wasm_bindgen_futures::spawn_local(async move { + let mut rcv = rcv.take_until_if(tripwire); + while rcv.next().await.is_some() { let _ = callback.call0(&JsValue::null()); } }); - Ok(()) + Ok(Subscription(trigger)) } #[wasm_bindgen(js_name = onEntityUpdated)] @@ -778,7 +782,7 @@ impl Client { &self, ids: Option>, callback: js_sys::Function, - ) -> Result<(), JsValue> { + ) -> Result { #[cfg(feature = "console-error-panic")] console_error_panic_hook::set_once(); @@ -791,9 +795,12 @@ impl Client { }) .collect::, _>>()?; - let mut stream = self.inner.on_entity_updated(ids).await.unwrap(); + let stream = self.inner.on_entity_updated(ids).await.unwrap(); + let (trigger, tripwire) = Tripwire::new(); wasm_bindgen_futures::spawn_local(async move { + let mut stream = stream.take_until_if(tripwire); + while let Some(update) = stream.next().await { let entity = update.expect("no updated entity"); let json_str = parse_entities_as_json_str(vec![entity]).to_string(); @@ -804,7 +811,7 @@ impl Client { } }); - Ok(()) + Ok(Subscription(trigger)) } #[wasm_bindgen(js_name = publishMessage)] @@ -835,6 +842,13 @@ impl Client { } } +#[wasm_bindgen] +impl Subscription { + pub fn cancel(self) { + self.0.cancel(); + } +} + /// Create the a client with the given configurations. #[wasm_bindgen(js_name = createClient)] #[allow(non_snake_case)] From 7859d6ed820c59a3e39785a6375900222b0e8445 Mon Sep 17 00:00:00 2001 From: Nasr Date: Tue, 23 Apr 2024 14:06:09 -0400 Subject: [PATCH 4/5] fmt --- src/c/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/c/mod.rs b/src/c/mod.rs index 18ae8e6..2eaf055 100644 --- a/src/c/mod.rs +++ b/src/c/mod.rs @@ -16,12 +16,12 @@ use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::{JsonRpcClient, Provider as _}; use starknet::signers::{LocalWallet, SigningKey, VerifyingKey}; use starknet_crypto::FieldElement; -use tokio_stream::StreamExt; use std::ffi::{c_void, CStr, CString}; use std::ops::Deref; use std::os::raw::c_char; use std::sync::Arc; use stream_cancel::{StreamExt as _, Tripwire}; +use tokio_stream::StreamExt; use torii_client::client::Client as TClient; use torii_relay::typed_data::TypedData; use torii_relay::types::Message; From b37b46409f35774aabdc1dcccba1416b1734d515 Mon Sep 17 00:00:00 2001 From: Nasr Date: Tue, 23 Apr 2024 14:29:47 -0400 Subject: [PATCH 5/5] fix: ffi safe reutrn ptr --- dojo.h | 2 +- dojo.hpp | 16 ++++++++-------- src/c/mod.rs | 10 +++++----- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/dojo.h b/dojo.h index 9d7bda0..9cf8f8c 100644 --- a/dojo.h +++ b/dojo.h @@ -448,7 +448,7 @@ typedef struct ResultSubscription { ResultSubscription_Tag tag; union { struct { - struct Subscription ok; + struct Subscription *ok; }; struct { struct Error err; diff --git a/dojo.hpp b/dojo.hpp index 69bd0b6..b5cbb85 100644 --- a/dojo.hpp +++ b/dojo.hpp @@ -777,14 +777,14 @@ Result client_add_models_to_sync(ToriiClient *client, const KeysClause *models, uintptr_t models_len); -Result client_on_sync_model_update(ToriiClient *client, - KeysClause model, - void (*callback)()); - -Result client_on_entity_state_update(ToriiClient *client, - FieldElement *entities, - uintptr_t entities_len, - void (*callback)(FieldElement, CArray)); +Result client_on_sync_model_update(ToriiClient *client, + KeysClause model, + void (*callback)()); + +Result client_on_entity_state_update(ToriiClient *client, + FieldElement *entities, + uintptr_t entities_len, + void (*callback)(FieldElement, CArray)); Result client_remove_models_to_sync(ToriiClient *client, const KeysClause *models, diff --git a/src/c/mod.rs b/src/c/mod.rs index 2eaf055..45fe206 100644 --- a/src/c/mod.rs +++ b/src/c/mod.rs @@ -200,7 +200,7 @@ pub unsafe extern "C" fn client_on_sync_model_update( client: *mut ToriiClient, model: KeysClause, callback: unsafe extern "C" fn(), -) -> Result { +) -> Result<*mut Subscription> { let model: torii_grpc::types::KeysClause = (&model).into(); let storage = (*client).inner.storage(); @@ -216,12 +216,12 @@ pub unsafe extern "C" fn client_on_sync_model_update( (*client).runtime.spawn(async move { let mut rcv = rcv.take_until_if(tripwire); - while let Some(_) = rcv.next().await { + while rcv.next().await.is_some() { callback(); } }); - Result::Ok(Subscription(trigger)) + Result::Ok(Box::into_raw(Box::new(Subscription(trigger)))) } #[no_mangle] @@ -231,7 +231,7 @@ pub unsafe extern "C" fn client_on_entity_state_update( entities: *mut types::FieldElement, entities_len: usize, callback: unsafe extern "C" fn(types::FieldElement, CArray), -) -> Result { +) -> Result<*mut Subscription> { let entities = unsafe { std::slice::from_raw_parts(entities, entities_len) }; // to vec of fieldleemnt let entities = entities.iter().map(|e| (&e.clone()).into()).collect(); @@ -253,7 +253,7 @@ pub unsafe extern "C" fn client_on_entity_state_update( } }); - Result::Ok(Subscription(trigger)) + Result::Ok(Box::into_raw(Box::new(Subscription(trigger)))) } #[no_mangle]