diff --git a/sqlexpr_sqlite_lwt.ml b/sqlexpr_sqlite_lwt.ml index e3061eb..736f37d 100644 --- a/sqlexpr_sqlite_lwt.ml +++ b/sqlexpr_sqlite_lwt.ml @@ -41,12 +41,13 @@ struct mutable thread : Thread.t; mutable finished : bool; db : db; (* keep ref to db so it isn't GCed while a worker is active *) + mutex : Lwt_mutex.t; } type stmt = worker * Stmt.t type 'a result = 'a Lwt.t - let max_threads = ref 100 + let max_threads = ref 4 let set_default_max_threads n = max_threads := n let close_db db = @@ -120,6 +121,7 @@ struct task_channel = Event.new_channel (); thread = Thread.self (); finished = false; + mutex = Lwt_mutex.create (); } in worker.thread <- Thread.create (worker_loop worker) (); worker @@ -169,11 +171,12 @@ struct wakeup wakener value | `Failure exn -> wakeup_exn wakener exn) - in try_lwt - check_worker_finished worker; - (* Send the id and the task to the worker: *) - Event.sync (Event.send worker.task_channel (id, task)); - waiter + in Lwt_mutex.with_lock worker.mutex + (fun () -> try_lwt + check_worker_finished worker; + (* Send the id and the task to the worker: *) + Event.sync (Event.send worker.task_channel (id, task)); + waiter) let do_raise_error ?sql ?params ?errmsg errcode = let msg = Sqlite3.Rc.to_string errcode ^ Option.map_default ((^) " ") "" errmsg in @@ -247,15 +250,15 @@ struct List.iter (fun v -> match Stmt.bind stmt !n v with Sqlite3.Rc.OK -> decr n - | code -> do_raise_error ~sql ~params code) + | code -> add_worker db worker; do_raise_error ~sql ~params code) params) stmt >> profile_execute_sql sql ~params (fun () -> try_lwt + add_worker db worker; f (worker, stmt) sql params finally - add_worker db worker; match stmt_id with Some id -> Stmt_cache.add_stmt worker.stmt_cache id stmt; return () | None -> return ())