Skip to content

Commit 7bfa884

Browse files
refactor: replace Arc<RwLock<HashMap>> with Arc<DashMap> for attested contracts
Replace Arc<RwLock<HashMap<_, _>>> with Arc<DashMap<_, _>> for better concurrency and simpler API. DashMap provides lock-free reads and fine-grained per-shard locking, eliminating the need for manual lock management and improving scalability under concurrent access. Changes: - Update AttestedContractMap type alias to use DashMap - Refactor all read/write lock patterns to use DashMap's direct methods - Simplify token cleanup task by using DashMap's retain() method - Remove lock acquisition code throughout the codebase Co-authored-by: nacho.d.g <iduartgomez@users.noreply.github.com>
1 parent 16c0c83 commit 7bfa884

File tree

4 files changed

+53
-54
lines changed

4 files changed

+53
-54
lines changed

crates/core/src/client_events/websocket.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
use std::{
22
collections::{HashMap, VecDeque},
3-
sync::{Arc, OnceLock, RwLock},
3+
sync::{Arc, OnceLock},
44
time::Duration,
55
};
66

7+
use dashmap::DashMap;
8+
79
use axum::{
810
extract::{
911
ws::{Message, WebSocket},
@@ -53,7 +55,7 @@ const PARALLELISM: usize = 10; // TODO: get this from config, or whatever optima
5355
impl WebSocketProxy {
5456
pub fn create_router(server_routing: Router) -> (Self, Router) {
5557
// Create a default empty attested contracts map
56-
let attested_contracts = Arc::new(RwLock::new(HashMap::new()));
58+
let attested_contracts = Arc::new(DashMap::new());
5759
Self::create_router_with_attested_contracts(server_routing, attested_contracts)
5860
}
5961

@@ -287,17 +289,16 @@ async fn websocket_commands(
287289
Extension(attested_contracts): Extension<AttestedContractMap>,
288290
) -> Response {
289291
let on_upgrade = move |ws: WebSocket| async move {
290-
// Get the data we need and immediately drop the lock
292+
// Get the data we need from the DashMap
291293
let auth_and_instance = if let Some(token) = auth_token.as_ref() {
292-
let attested_contracts_read = attested_contracts.read().unwrap();
293-
294294
// Only collect and log map contents when trace is enabled
295295
if tracing::enabled!(tracing::Level::TRACE) {
296-
let map_contents: Vec<_> = attested_contracts_read.keys().cloned().collect();
296+
let map_contents: Vec<_> = attested_contracts.iter().map(|e| e.key().clone()).collect();
297297
tracing::trace!(?token, "attested_contracts map keys: {:?}", map_contents);
298298
}
299299

300-
if let Some((cid, _, _)) = attested_contracts_read.get(token) {
300+
if let Some(entry) = attested_contracts.get(token) {
301+
let (cid, _, _) = entry.value();
301302
tracing::trace!(?token, ?cid, "Found token in attested_contracts map");
302303
Some((token.clone(), *cid))
303304
} else {
@@ -307,7 +308,7 @@ async fn websocket_commands(
307308
} else {
308309
tracing::trace!("No auth token provided in WebSocket request");
309310
None
310-
}; // RwLockReadGuard is dropped here
311+
};
311312

312313
// Only evaluate auth_and_instance for trace when trace is enabled
313314
if tracing::enabled!(tracing::Level::TRACE) {

crates/core/src/node/mod.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,9 +1334,8 @@ pub async fn run_local_node(
13341334
ClientRequest::DelegateOp(op) => {
13351335
let attested_contract = token.and_then(|token| {
13361336
gw.attested_contracts
1337-
.read()
1338-
.ok()
1339-
.and_then(|guard| guard.get(&token).map(|(t, _, _)| *t))
1337+
.get(&token)
1338+
.map(|entry| entry.value().0)
13401339
});
13411340
let op_name = match op {
13421341
DelegateRequest::RegisterDelegate { .. } => "RegisterDelegate",

crates/core/src/server/http_gateway.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use std::collections::HashMap;
22
use std::net::{IpAddr, SocketAddr};
3-
use std::sync::{Arc, RwLock};
3+
use std::sync::Arc;
44
use std::time::Instant;
55

6+
use dashmap::DashMap;
7+
68
use axum::extract::Path;
79
use axum::response::IntoResponse;
810
use axum::routing::get;
@@ -34,8 +36,7 @@ impl std::ops::Deref for HttpGatewayRequest {
3436

3537
/// Maps authentication tokens to contract instances, client IDs, and last access time.
3638
/// The Instant tracks when the token was last used to enable time-based expiration.
37-
pub type AttestedContractMap =
38-
Arc<RwLock<HashMap<AuthToken, (ContractInstanceId, ClientId, Instant)>>>;
39+
pub type AttestedContractMap = Arc<DashMap<AuthToken, (ContractInstanceId, ClientId, Instant)>>;
3940

4041
/// A gateway to access and interact with contracts through an HTTP interface.
4142
pub(crate) struct HttpGateway {
@@ -47,7 +48,7 @@ pub(crate) struct HttpGateway {
4748
impl HttpGateway {
4849
/// Returns the uninitialized axum router to compose with other routing handling or websockets.
4950
pub fn as_router(socket: &SocketAddr) -> (Self, Router) {
50-
let attested_contracts = Arc::new(RwLock::new(HashMap::new()));
51+
let attested_contracts = Arc::new(DashMap::new());
5152
Self::as_router_with_attested_contracts(socket, attested_contracts)
5253
}
5354

@@ -87,8 +88,6 @@ impl ClientEventsProxy for HttpGateway {
8788
if let Some((assigned_token, contract)) = assigned_token {
8889
let now = Instant::now();
8990
self.attested_contracts
90-
.write()
91-
.map_err(|_| ErrorKind::FailedOperation)?
9291
.insert(assigned_token.clone(), (contract, cli_id, now));
9392
tracing::debug!(
9493
?assigned_token,

crates/core/src/server/mod.rs

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@ pub(crate) mod path_handlers;
1313

1414
use std::collections::HashMap;
1515
use std::net::SocketAddr;
16-
use std::sync::{Arc, RwLock};
16+
use std::sync::Arc;
1717
use std::time::{Duration, Instant};
1818

19+
use dashmap::DashMap;
20+
1921
use freenet_stdlib::{
2022
client_api::{ClientError, ClientRequest, HostResponse},
2123
prelude::*,
@@ -149,13 +151,15 @@ pub mod local_node {
149151
tracing::info!("disconnecting cause: {cause}");
150152
}
151153
// fixme: token must live for a bit to allow reconnections
152-
if let Ok(mut guard) = gw.attested_contracts.write() {
153-
if let Some(rm_token) = guard
154-
.iter()
155-
.find_map(|(k, (_, eid, _))| (eid == &id).then(|| k.clone()))
156-
{
157-
guard.remove(&rm_token);
158-
}
154+
if let Some(rm_token) = gw
155+
.attested_contracts
156+
.iter()
157+
.find_map(|entry| {
158+
let (k, (_, eid, _)) = entry.pair();
159+
(eid == &id).then(|| k.clone())
160+
})
161+
{
162+
gw.attested_contracts.remove(&rm_token);
159163
}
160164
continue;
161165
}
@@ -209,7 +213,7 @@ pub(crate) async fn serve_gateway_in(config: WebsocketApiConfig) -> (HttpGateway
209213
let ws_socket = (config.address, config.port).into();
210214

211215
// Create a shared attested_contracts map with token expiration support
212-
let attested_contracts: AttestedContractMap = Arc::new(RwLock::new(HashMap::new()));
216+
let attested_contracts: AttestedContractMap = Arc::new(DashMap::new());
213217

214218
// Spawn background task to clean up expired tokens
215219
spawn_token_cleanup_task(attested_contracts.clone());
@@ -242,38 +246,34 @@ fn spawn_token_cleanup_task(attested_contracts: AttestedContractMap) {
242246
interval.tick().await;
243247

244248
// Clean up expired tokens
245-
if let Ok(mut guard) = attested_contracts.write() {
246-
let now = Instant::now();
247-
let initial_count = guard.len();
248-
249-
// Remove tokens that haven't been accessed in TOKEN_TTL
250-
guard.retain(|token, (contract_id, client_id, last_used)| {
251-
let elapsed = now.duration_since(*last_used);
252-
let should_keep = elapsed < TOKEN_TTL;
253-
254-
if !should_keep {
255-
tracing::info!(
256-
?token,
257-
?contract_id,
258-
?client_id,
259-
elapsed_hours = elapsed.as_secs() / 3600,
260-
"Removing expired authentication token"
261-
);
262-
}
249+
let now = Instant::now();
250+
let initial_count = attested_contracts.len();
263251

264-
should_keep
265-
});
252+
// Remove tokens that haven't been accessed in TOKEN_TTL
253+
attested_contracts.retain(|token, (contract_id, client_id, last_used)| {
254+
let elapsed = now.duration_since(*last_used);
255+
let should_keep = elapsed < TOKEN_TTL;
266256

267-
let removed_count = initial_count - guard.len();
268-
if removed_count > 0 {
269-
tracing::debug!(
270-
removed_count,
271-
remaining_count = guard.len(),
272-
"Token cleanup completed"
257+
if !should_keep {
258+
tracing::info!(
259+
?token,
260+
?contract_id,
261+
?client_id,
262+
elapsed_hours = elapsed.as_secs() / 3600,
263+
"Removing expired authentication token"
273264
);
274265
}
275-
} else {
276-
tracing::warn!("Failed to acquire write lock for token cleanup");
266+
267+
should_keep
268+
});
269+
270+
let removed_count = initial_count - attested_contracts.len();
271+
if removed_count > 0 {
272+
tracing::debug!(
273+
removed_count,
274+
remaining_count = attested_contracts.len(),
275+
"Token cleanup completed"
276+
);
277277
}
278278
}
279279
});

0 commit comments

Comments
 (0)