Skip to content

Commit 2bb5c46

Browse files
Merge pull request #1077 from ocsigen/lwt-6-default-domain
notifications now work with an abstract id
2 parents a5c0881 + 75ba144 commit 2bb5c46

File tree

10 files changed

+250
-126
lines changed

10 files changed

+250
-126
lines changed

src/unix/lwt_gc.ml

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,15 @@ let ensure_termination t =
2020
(fun () -> Lwt_main.Exit_hooks.remove hook; Lwt.return_unit))
2121
end
2222

23-
let finaliser f =
24-
(* In order for the domain id to be consistent, wherever the real finaliser is
25-
called, we pass it in the continuation. *)
26-
let domain_id = Domain.self () in
23+
let finaliser ?domain f =
2724
(* In order not to create a reference to the value in the
2825
notification callback, we use an initially unset option cell
2926
which will be filled when the finaliser is called. *)
3027
let opt = ref None in
3128
let id =
3229
Lwt_unix.make_notification
3330
~once:true
34-
domain_id
31+
?for_other_domain:domain
3532
(fun () ->
3633
match !opt with
3734
| None ->
@@ -43,10 +40,10 @@ let finaliser f =
4340
(* The real finaliser: fill the cell and send a notification. *)
4441
(fun x ->
4542
opt := Some x;
46-
Lwt_unix.send_notification domain_id id)
43+
Lwt_unix.send_notification id)
4744

48-
let finalise f x =
49-
Gc.finalise (finaliser f) x
45+
let finalise ?domain f x =
46+
Gc.finalise (finaliser ?domain f) x
5047

5148
(* Exit hook for a finalise_or_exit *)
5249
let foe_exit f called weak () =

src/unix/lwt_gc.mli

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,25 @@
99
thread to a value, without having to use [Lwt_unix.run] in the
1010
finaliser. *)
1111

12-
val finalise : ('a -> unit Lwt.t) -> 'a -> unit
12+
val finalise : ?domain:Domain.id -> ('a -> unit Lwt.t) -> 'a -> unit
1313
(** [finalise f x] ensures [f x] is evaluated after [x] has been
1414
garbage collected. If [f x] yields, then Lwt will wait for its
1515
termination at the end of the program.
1616
1717
Note that [f x] is not called at garbage collection time, but
18-
later in the main loop. *)
18+
later in the main loop.
19+
20+
If [domain] is provided, then [f x] is evaluated in the corresponding
21+
domain. Otherwise it is evaluated in the domain calling [finalise]. If
22+
Lwt is not running in the domain set to run the finaliser, an
23+
unspecified error occurs at an unspecified time or the finaliser doesn't
24+
run or some other bad thing happens. *)
1925

2026
val finalise_or_exit : ('a -> unit Lwt.t) -> 'a -> unit
2127
(** [finalise_or_exit f x] call [f x] when [x] is garbage collected
22-
or (exclusively) when the program exits. *)
28+
or (exclusively) when the program exits.
29+
30+
The finaliser [f] is called in the same domain that called
31+
[finalise_or_exit]. If there is no Lwt scheduler running in this domain an
32+
unspecified error occurs. You can use [Lwt_preemptive.run_in_domain] to
33+
bypass the same-domain limitation. *)

src/unix/lwt_main.ml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ let abandon_yielded_and_paused () =
2121
Lwt.abandon_paused ()
2222

2323
let run p =
24-
let domain_id = Domain.self () in
25-
let () = if (Lwt.Private.Multidomain_sync.is_alredy_registered[@alert "-trespassing"]) domain_id then
24+
let domain = Domain.self () in
25+
let () = if (Lwt.Private.Multidomain_sync.is_alredy_registered[@alert "-trespassing"]) domain then
2626
()
2727
else begin
28-
let n = Lwt_unix.make_notification domain_id (fun () ->
29-
let cbs = (Lwt.Private.Multidomain_sync.get_sent_callbacks[@alert "-trespassing"]) domain_id in
28+
let n = Lwt_unix.make_notification (fun () ->
29+
let cbs = (Lwt.Private.Multidomain_sync.get_sent_callbacks[@alert "-trespassing"]) domain in
3030
Lwt_sequence.iter_l (fun f -> f ()) cbs
3131
) in
32-
(Lwt.Private.Multidomain_sync.register_notification[@alert "-trespassing"]) domain_id (fun () -> Lwt_unix.send_notification domain_id n)
32+
(Lwt.Private.Multidomain_sync.register_notification[@alert "-trespassing"]) domain(fun () -> Lwt_unix.send_notification n)
3333
end
3434
in
3535
let rec run_loop () =

src/unix/lwt_main.mli

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,14 @@ val abandon_yielded_and_paused : unit -> unit [@@deprecated "Use Lwt.abandon_pau
7777

7878

7979
(** Hook sequences. Each module of this type is a set of hooks, to be run by Lwt
80-
at certain points during execution. See modules {!Enter_iter_hooks},
81-
{!Leave_iter_hooks}, and {!Exit_hooks}. *)
80+
at certain points during execution.
81+
82+
Hooks are added for the current domain. If you are calling the Hook
83+
functions from a domain where Lwt is not running a scheduler then some
84+
unspecified error may occur. If you need to set some Hooks to/from a
85+
different domain, you can use [Lwt_preemptive.run_in_domain].
86+
87+
See modules {!Enter_iter_hooks}, {!Leave_iter_hooks}, and {!Exit_hooks}. *)
8288
module type Hooks =
8389
sig
8490
type 'return_value kind

src/unix/lwt_preemptive.ml

Lines changed: 72 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,23 @@ open Lwt.Infix
1616
| Parameters |
1717
+-----------------------------------------------------------------+ *)
1818

19-
(* Minimum number of preemptive threads: *)
20-
let min_threads : int Atomic.t = Atomic.make 0
19+
(* Minimum number of preemptive threads per domain *)
20+
let min_threads : int Domain.DLS.key = Domain.DLS.new_key (fun () -> 0)
2121

22-
(* Maximum number of preemptive threads: *)
23-
let max_threads : int Atomic.t = Atomic.make 0
22+
(* Maximum number of preemptive threads per domain *)
23+
let max_threads : int Domain.DLS.key = Domain.DLS.new_key (fun () -> 0)
2424

25-
(* Size of the waiting queue: *)
26-
let max_thread_queued = Atomic.make 1000
25+
(* Size of the waiting queue per domain *)
26+
let max_thread_queued = Domain.DLS.new_key (fun () -> 1000)
2727

28-
let get_max_number_of_threads_queued _ =
29-
Atomic.get max_thread_queued
28+
let get_max_number_of_threads_queued () = Domain.DLS.get max_thread_queued
3029

3130
let set_max_number_of_threads_queued n =
3231
if n < 0 then invalid_arg "Lwt_preemptive.set_max_number_of_threads_queued";
33-
Atomic.set max_thread_queued n
32+
Domain.DLS.set max_thread_queued n
3433

3534
(* The total number of preemptive threads currently running: *)
36-
let threads_count = Atomic.make 0
35+
let threads_count = Domain.DLS.new_key (fun () -> 0)
3736

3837
(* +-----------------------------------------------------------------+
3938
| Preemptive threads management |
@@ -44,14 +43,15 @@ sig
4443
type 'a t
4544

4645
val make : unit -> 'a t
47-
val get : 'a t -> 'a
46+
val get : 'a t -> ('a, unit) result
4847
val set : 'a t -> 'a -> unit
48+
val kill : 'a t -> unit
4949
end =
5050
struct
5151
type 'a t = {
5252
m : Mutex.t;
5353
cv : Condition.t;
54-
mutable cell : 'a option;
54+
mutable cell : ('a, unit) result option;
5555
}
5656

5757
let make () = { m = Mutex.create (); cv = Condition.create (); cell = None }
@@ -72,13 +72,19 @@ struct
7272

7373
let set t v =
7474
Mutex.lock t.m;
75-
t.cell <- Some v;
75+
t.cell <- Some (Ok v);
76+
Mutex.unlock t.m;
77+
Condition.signal t.cv
78+
79+
let kill t =
80+
Mutex.lock t.m;
81+
t.cell <- Some (Error ());
7682
Mutex.unlock t.m;
7783
Condition.signal t.cv
7884
end
7985

8086
type thread = {
81-
task_cell: (int * (unit -> unit)) CELL.t;
87+
task_cell: (Lwt_unix.notification * (unit -> unit)) CELL.t;
8288
(* Channel used to communicate notification id and tasks to the
8389
worker thread. *)
8490

@@ -91,25 +97,27 @@ type thread = {
9197
}
9298

9399
(* Pool of worker threads: *)
94-
let workers : thread Queue.t = Queue.create ()
100+
let workers : thread Queue.t Domain.DLS.key = Domain.DLS.new_key Queue.create
95101

96102
(* Queue of clients waiting for a worker to be available: *)
97-
let waiters : thread Lwt.u Lwt_sequence.t = Lwt_sequence.create ()
103+
let waiters : thread Lwt.u Lwt_sequence.t Domain.DLS.key = Domain.DLS.new_key Lwt_sequence.create
98104

99105
(* Code executed by a worker: *)
100106
let rec worker_loop worker =
101-
let id, task = CELL.get worker.task_cell in
102-
task ();
103-
(* If there is too much threads, exit. This can happen if the user
104-
decreased the maximum: *)
105-
if Atomic.get threads_count > Atomic.get max_threads then worker.reuse <- false;
106-
(* Tell the main thread that work is done: *)
107-
Lwt_unix.send_notification (Domain.self ()) id;
108-
if worker.reuse then worker_loop worker
107+
match CELL.get worker.task_cell with
108+
| Error () -> ()
109+
| Ok (id, task) ->
110+
task ();
111+
(* If there is too much threads, exit. This can happen if the user
112+
decreased the maximum: *)
113+
if Domain.DLS.get threads_count > Domain.DLS.get max_threads then worker.reuse <- false;
114+
(* Tell the main thread that work is done: *)
115+
Lwt_unix.send_notification id;
116+
if worker.reuse then worker_loop worker
109117

110118
(* create a new worker: *)
111119
let make_worker () =
112-
Atomic.incr threads_count;
120+
Domain.DLS.set threads_count (Domain.DLS.get threads_count + 1);
113121
let worker = {
114122
task_cell = CELL.make ();
115123
thread = Thread.self ();
@@ -120,52 +128,52 @@ let make_worker () =
120128

121129
(* Add a worker to the pool: *)
122130
let add_worker worker =
123-
match Lwt_sequence.take_opt_l waiters with
131+
match Lwt_sequence.take_opt_l (Domain.DLS.get waiters) with
124132
| None ->
125-
Queue.add worker workers
133+
Queue.add worker (Domain.DLS.get workers)
126134
| Some w ->
127135
Lwt.wakeup w worker
128136

129137
(* Wait for worker to be available, then return it: *)
130138
let get_worker () =
131-
if not (Queue.is_empty workers) then
132-
Lwt.return (Queue.take workers)
133-
else if Atomic.get threads_count < Atomic.get max_threads then
139+
if not (Queue.is_empty (Domain.DLS.get workers)) then
140+
Lwt.return (Queue.take (Domain.DLS.get workers))
141+
else if Domain.DLS.get threads_count < Domain.DLS.get max_threads then
134142
Lwt.return (make_worker ())
135143
else
136-
(Lwt.add_task_r [@ocaml.warning "-3"]) waiters
144+
(Lwt.add_task_r [@ocaml.warning "-3"]) (Domain.DLS.get waiters)
137145

138146
(* +-----------------------------------------------------------------+
139147
| Initialisation, and dynamic parameters reset |
140148
+-----------------------------------------------------------------+ *)
141149

142-
let get_bounds () = (Atomic.get min_threads, Atomic.get max_threads)
150+
let get_bounds () = (Domain.DLS.get min_threads, Domain.DLS.get max_threads)
143151

144152
let set_bounds (min, max) =
145153
if min < 0 || max < min then invalid_arg "Lwt_preemptive.set_bounds";
146-
let diff = min - Atomic.get threads_count in
147-
Atomic.set min_threads min;
148-
Atomic.set max_threads max;
154+
let diff = min - Domain.DLS.get threads_count in
155+
Domain.DLS.set min_threads min;
156+
Domain.DLS.set max_threads max;
149157
(* Launch new workers: *)
150158
for _i = 1 to diff do
151159
add_worker (make_worker ())
152160
done
153161

154-
let initialized = Atomic.make false
162+
let initialized = Domain.DLS.new_key (fun () -> false)
155163

156164
let init min max _errlog =
157-
Atomic.set initialized true;
165+
Domain.DLS.set initialized true;
158166
set_bounds (min, max)
159167

160168
let simple_init () =
161-
if not (Atomic.get initialized) then begin
162-
Atomic.set initialized true;
169+
if not (Domain.DLS.get initialized) then begin
170+
Domain.DLS.set initialized true;
163171
set_bounds (0, 4)
164172
end
165173

166-
let nbthreads () = Atomic.get threads_count
167-
let nbthreadsqueued () = Lwt_sequence.fold_l (fun _ x -> x + 1) waiters 0
168-
let nbthreadsbusy () = Atomic.get threads_count - Queue.length workers
174+
let nbthreads () = Domain.DLS.get threads_count
175+
let nbthreadsqueued () = Lwt_sequence.fold_l (fun _ x -> x + 1) (Domain.DLS.get waiters) 0
176+
let nbthreadsbusy () = Domain.DLS.get threads_count - Queue.length (Domain.DLS.get workers)
169177

170178
(* +-----------------------------------------------------------------+
171179
| Detaching |
@@ -186,7 +194,8 @@ let detach f args =
186194
get_worker () >>= fun worker ->
187195
let waiter, wakener = Lwt.wait () in
188196
let id =
189-
Lwt_unix.make_notification ~once:true (Domain.self ())
197+
(* call back the domain that called the [detach] function: self *)
198+
Lwt_unix.make_notification ~once:true
190199
(fun () -> Lwt.wakeup_result wakener !result)
191200
in
192201
Lwt.finalize
@@ -199,7 +208,7 @@ let detach f args =
199208
(* Put back the worker to the pool: *)
200209
add_worker worker
201210
else begin
202-
Atomic.decr threads_count;
211+
Domain.DLS.set threads_count (Domain.DLS.get threads_count - 1);
203212
(* Or wait for the thread to terminates, to free its associated
204213
resources: *)
205214
Thread.join worker.thread
@@ -216,23 +225,27 @@ let jobs = Queue.create ()
216225
(* Mutex to protect access to [jobs]. *)
217226
let jobs_mutex = Mutex.create ()
218227

219-
let job_notification =
220-
Lwt_unix.make_notification (Domain.self ())
228+
let job_notification = Domain_map.create_protected_map ()
229+
let get_job_notification d =
230+
Domain_map.init job_notification d
221231
(fun () ->
222-
(* Take the first job. The queue is never empty at this
223-
point. *)
224-
Mutex.lock jobs_mutex;
225-
let thunk = Queue.take jobs in
226-
Mutex.unlock jobs_mutex;
227-
ignore (thunk ()))
232+
Lwt_unix.make_notification ~for_other_domain:d
233+
(fun () ->
234+
(* Take the first job. The queue is never empty at this
235+
point. *)
236+
Mutex.lock jobs_mutex;
237+
let thunk = Queue.take jobs in
238+
Mutex.unlock jobs_mutex;
239+
ignore (thunk ()))
240+
)
228241

229242
let run_in_domain_dont_wait d f =
230243
(* Add the job to the queue. *)
231244
Mutex.lock jobs_mutex;
232245
Queue.add f jobs;
233246
Mutex.unlock jobs_mutex;
234247
(* Notify the main thread. *)
235-
Lwt_unix.send_notification d job_notification
248+
Lwt_unix.send_notification (get_job_notification d)
236249

237250
(* There is a potential performance issue from creating a cell every time this
238251
function is called. See:
@@ -254,10 +267,14 @@ let run_in_domain d f =
254267
run_in_domain_dont_wait d job;
255268
(* Wait for the result. *)
256269
match CELL.get cell with
257-
| Result.Ok ret -> ret
258-
| Result.Error exn -> raise exn
270+
| Ok (Ok ret) -> ret
271+
| Ok (Error exn) -> raise exn
272+
| Error () -> assert false
259273

260274
(* This version shadows the one above, adding an exception handler *)
261275
let run_in_domain_dont_wait d f handler =
262276
let f () = Lwt.catch f (fun exc -> handler exc; Lwt.return_unit) in
263277
run_in_domain_dont_wait d f
278+
279+
let terminate_worker_threads () =
280+
Queue.iter (fun thread -> CELL.kill thread.task_cell) (Domain.DLS.get workers)

src/unix/lwt_preemptive.mli

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ val init : int -> int -> (string -> unit) -> unit
5353
@param log is used to log error messages
5454
5555
If {!Lwt_preemptive} has already been initialised, this call
56-
only modify bounds and the log function. *)
56+
only modify bounds and the log function.
57+
58+
The limits are set per-domain. More specifically, each domain manages a
59+
pool of systhreads, each pool having its own limits and its own state. *)
5760

5861
val simple_init : unit -> unit
5962
(** [simple_init ()] checks if the library is not yet initialized, and if not,
@@ -80,6 +83,17 @@ val get_max_number_of_threads_queued : unit -> int
8083
(** Returns the size of the waiting queue, if no more threads are
8184
available *)
8285

86+
val terminate_worker_threads : unit -> unit
87+
(* [terminate_worker_threads ()] queues up a message for all the workers of the
88+
calling domain to self-terminate. This causes all the workers to terminate
89+
after their current jobs are done which causes the threads of these workers
90+
to end.
91+
92+
Terminating the threads attached to a domain is necessary for joining the
93+
domain. Thus, if you use-case for domains includes spawning and joining them,
94+
you must call [terminate_worker_threads] just before calling
95+
[Domain.join]. *)
96+
8397
(**/**)
8498
val nbthreads : unit -> int
8599
val nbthreadsbusy : unit -> int

0 commit comments

Comments
 (0)