Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ tokio-stream = "0.1.14"
futures = "0.3.30"
futures-channel = "0.3.30"
wasm-bindgen = "0.2.92"
stream-cancel = "0.8.2"

[patch.crates-io]
deno_task_shell = { git = "https://github.com/denoland/deno_task_shell", tag = "0.15.0" }
Expand Down
37 changes: 29 additions & 8 deletions dojo.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ typedef struct Account Account;

typedef struct Provider Provider;

typedef struct Subscription Subscription;

typedef struct ToriiClient ToriiClient;

typedef struct Error {
Expand Down Expand Up @@ -437,6 +439,23 @@ typedef struct Resultbool {
};
} Resultbool;

typedef enum ResultSubscription_Tag {
OkSubscription,
ErrSubscription,
} ResultSubscription_Tag;

typedef struct ResultSubscription {
ResultSubscription_Tag tag;
union {
struct {
struct Subscription *ok;
};
struct {
struct Error err;
};
};
} ResultSubscription;

typedef enum ResultFieldElement_Tag {
OkFieldElement,
ErrFieldElement,
Expand Down Expand Up @@ -579,15 +598,15 @@ struct Resultbool client_add_models_to_sync(struct ToriiClient *client,
const struct KeysClause *models,
uintptr_t models_len);

struct Resultbool client_on_sync_model_update(struct ToriiClient *client,
struct KeysClause model,
void (*callback)(void));
struct ResultSubscription client_on_sync_model_update(struct ToriiClient *client,
struct KeysClause model,
void (*callback)(void));

struct Resultbool client_on_entity_state_update(struct ToriiClient *client,
struct FieldElement *entities,
uintptr_t entities_len,
void (*callback)(struct FieldElement,
struct CArrayModel));
struct ResultSubscription client_on_entity_state_update(struct ToriiClient *client,
struct FieldElement *entities,
uintptr_t entities_len,
void (*callback)(struct FieldElement,
struct CArrayModel));

struct Resultbool client_remove_models_to_sync(struct ToriiClient *client,
const struct KeysClause *models,
Expand Down Expand Up @@ -637,6 +656,8 @@ struct FieldElement hash_get_contract_address(struct FieldElement class_hash,
uintptr_t constructor_calldata_len,
struct FieldElement deployer_address);

void subscription_cancel(struct Subscription *subscription);

void client_free(struct ToriiClient *t);

void provider_free(struct Provider *rpc);
Expand Down
16 changes: 11 additions & 5 deletions dojo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ struct Account;

struct Provider;

struct Subscription;

struct ToriiClient;

struct Error {
Expand Down Expand Up @@ -775,12 +777,14 @@ Result<bool> client_add_models_to_sync(ToriiClient *client,
const KeysClause *models,
uintptr_t models_len);

Result<bool> client_on_sync_model_update(ToriiClient *client, KeysClause model, void (*callback)());
Result<Subscription*> client_on_sync_model_update(ToriiClient *client,
KeysClause model,
void (*callback)());

Result<bool> client_on_entity_state_update(ToriiClient *client,
FieldElement *entities,
uintptr_t entities_len,
void (*callback)(FieldElement, CArray<Model>));
Result<Subscription*> client_on_entity_state_update(ToriiClient *client,
FieldElement *entities,
uintptr_t entities_len,
void (*callback)(FieldElement, CArray<Model>));

Result<bool> client_remove_models_to_sync(ToriiClient *client,
const KeysClause *models,
Expand Down Expand Up @@ -826,6 +830,8 @@ FieldElement hash_get_contract_address(FieldElement class_hash,
uintptr_t constructor_calldata_len,
FieldElement deployer_address);

void subscription_cancel(Subscription *subscription);

void client_free(ToriiClient *t);

void provider_free(Provider *rpc);
Expand Down
34 changes: 26 additions & 8 deletions src/c/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use self::types::{
ToriiClient, Ty, WorldMetadata,
};
use crate::constants;
use crate::types::{Account, Provider};
use crate::types::{Account, Provider, Subscription};
use crate::utils::watch_tx;
use starknet::accounts::{Account as StarknetAccount, ExecutionEncoding, SingleOwnerAccount};
use starknet::core::types::FunctionCall;
Expand All @@ -20,6 +20,7 @@ use std::ffi::{c_void, CStr, CString};
use std::ops::Deref;
use std::os::raw::c_char;
use std::sync::Arc;
use stream_cancel::{StreamExt as _, Tripwire};
use tokio_stream::StreamExt;
use torii_client::client::Client as TClient;
use torii_relay::typed_data::TypedData;
Expand Down Expand Up @@ -199,25 +200,28 @@ pub unsafe extern "C" fn client_on_sync_model_update(
client: *mut ToriiClient,
model: KeysClause,
callback: unsafe extern "C" fn(),
) -> Result<bool> {
) -> Result<*mut Subscription> {
let model: torii_grpc::types::KeysClause = (&model).into();
let storage = (*client).inner.storage();

let mut rcv = match storage.add_listener(
let rcv = match storage.add_listener(
cairo_short_string_to_felt(model.model.as_str()).unwrap(),
model.keys.as_slice(),
) {
Ok(rcv) => rcv,
Err(e) => return Result::Err(e.into()),
};

let (trigger, tripwire) = Tripwire::new();
(*client).runtime.spawn(async move {
if let Ok(Some(_)) = rcv.try_next() {
let mut rcv = rcv.take_until_if(tripwire);

while rcv.next().await.is_some() {
callback();
}
});

Result::Ok(true)
Result::Ok(Box::into_raw(Box::new(Subscription(trigger))))
}

#[no_mangle]
Expand All @@ -227,26 +231,29 @@ pub unsafe extern "C" fn client_on_entity_state_update(
entities: *mut types::FieldElement,
entities_len: usize,
callback: unsafe extern "C" fn(types::FieldElement, CArray<Model>),
) -> Result<bool> {
) -> Result<*mut Subscription> {
let entities = unsafe { std::slice::from_raw_parts(entities, entities_len) };
// to vec of fieldleemnt
let entities = entities.iter().map(|e| (&e.clone()).into()).collect();

let entity_stream = unsafe { (*client).inner.on_entity_updated(entities) };
let mut rcv = match (*client).runtime.block_on(entity_stream) {
let rcv = match (*client).runtime.block_on(entity_stream) {
Ok(rcv) => rcv,
Err(e) => return Result::Err(e.into()),
};

let (trigger, tripwire) = Tripwire::new();
(*client).runtime.spawn(async move {
let mut rcv = rcv.take_until_if(tripwire);

while let Some(Ok(entity)) = rcv.next().await {
let key: types::FieldElement = (&entity.hashed_keys).into();
let models: Vec<Model> = entity.models.into_iter().map(|e| (&e).into()).collect();
callback(key, models.into());
}
});

Result::Ok(true)
Result::Ok(Box::into_raw(Box::new(Subscription(trigger))))
}

#[no_mangle]
Expand Down Expand Up @@ -560,6 +567,17 @@ pub unsafe extern "C" fn hash_get_contract_address(
(&address).into()
}

#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn subscription_cancel(subscription: *mut Subscription) {
if !subscription.is_null() {
unsafe {
let subscription = Box::from_raw(subscription);
subscription.0.cancel();
}
}
}

// This function takes a raw pointer to ToriiClient as an argument.
// It checks if the pointer is not null. If it's not, it converts the raw pointer
// back into a Box<ToriiClient>, which gets dropped at the end of the scope,
Expand Down
3 changes: 3 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ use starknet::{
providers::{jsonrpc::HttpTransport, JsonRpcClient},
signers::LocalWallet,
};
use stream_cancel::Trigger;
use wasm_bindgen::prelude::*;

#[wasm_bindgen]
pub struct Provider(pub(crate) Arc<JsonRpcClient<HttpTransport>>);
#[wasm_bindgen]
pub struct Account(pub(crate) SingleOwnerAccount<Arc<JsonRpcClient<HttpTransport>>, LocalWallet>);
#[wasm_bindgen]
pub struct Subscription(pub(crate) Trigger);
28 changes: 21 additions & 7 deletions src/wasm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ use starknet::providers::jsonrpc::HttpTransport;
use starknet::providers::{JsonRpcClient, Provider as _};
use starknet::signers::{LocalWallet, SigningKey, VerifyingKey};
use starknet_crypto::Signature;
use stream_cancel::{StreamExt as _, Tripwire};
use torii_relay::typed_data::TypedData;
use torii_relay::types::Message;
use tsify::Tsify;
use wasm_bindgen::prelude::*;

use crate::constants;
use crate::types::{Account, Provider};
use crate::types::{Account, Provider, Subscription};
use crate::utils::watch_tx;
use crate::wasm::utils::{parse_entities_as_json_str, parse_ty_as_json_str};

Expand Down Expand Up @@ -745,12 +746,12 @@ impl Client {
&self,
model: KeysClause,
callback: js_sys::Function,
) -> Result<(), JsValue> {
) -> Result<Subscription, JsValue> {
#[cfg(feature = "console-error-panic")]
console_error_panic_hook::set_once();

let name = cairo_short_string_to_felt(&model.model).expect("invalid model name");
let mut rcv = self
let rcv = self
.inner
.storage()
.add_listener(
Expand All @@ -764,21 +765,24 @@ impl Client {
)
.unwrap();

let (trigger, tripwire) = Tripwire::new();
wasm_bindgen_futures::spawn_local(async move {
let mut rcv = rcv.take_until_if(tripwire);

while rcv.next().await.is_some() {
let _ = callback.call0(&JsValue::null());
}
});

Ok(())
Ok(Subscription(trigger))
}

#[wasm_bindgen(js_name = onEntityUpdated)]
pub async fn on_entity_updated(
&self,
ids: Option<Vec<String>>,
callback: js_sys::Function,
) -> Result<(), JsValue> {
) -> Result<Subscription, JsValue> {
#[cfg(feature = "console-error-panic")]
console_error_panic_hook::set_once();

Expand All @@ -791,9 +795,12 @@ impl Client {
})
.collect::<Result<Vec<_>, _>>()?;

let mut stream = self.inner.on_entity_updated(ids).await.unwrap();
let stream = self.inner.on_entity_updated(ids).await.unwrap();

let (trigger, tripwire) = Tripwire::new();
wasm_bindgen_futures::spawn_local(async move {
let mut stream = stream.take_until_if(tripwire);

while let Some(update) = stream.next().await {
let entity = update.expect("no updated entity");
let json_str = parse_entities_as_json_str(vec![entity]).to_string();
Expand All @@ -804,7 +811,7 @@ impl Client {
}
});

Ok(())
Ok(Subscription(trigger))
}

#[wasm_bindgen(js_name = publishMessage)]
Expand Down Expand Up @@ -835,6 +842,13 @@ impl Client {
}
}

#[wasm_bindgen]
impl Subscription {
pub fn cancel(self) {
self.0.cancel();
}
}

/// Create the a client with the given configurations.
#[wasm_bindgen(js_name = createClient)]
#[allow(non_snake_case)]
Expand Down