Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tls-lwt: do not catch out of memory exception #469

Merged
merged 1 commit into from
Feb 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 17 additions & 7 deletions lwt/tls_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,22 @@ module Unix = struct
}

let safely th =
Lwt.catch (fun () -> th >>= fun _ -> return_unit) (fun _ -> return_unit)
Lwt.catch
(fun () -> th >>= fun _ -> return_unit)
(function
| Out_of_memory -> raise Out_of_memory
| _ -> return_unit)

let (read_t, write_t) =
let recording_errors op t cs =
Lwt.catch
(fun () -> op t.fd cs)
(fun exn -> (match t.state with
| `Error _ | `Eof -> ()
| `Active _ -> t.state <- `Error exn) ;
fail exn)
(function
| Out_of_memory -> raise Out_of_memory
| exn -> (match t.state with
| `Error _ | `Eof -> ()
| `Active _ -> t.state <- `Error exn) ;
fail exn)
in
(recording_errors Lwt_cs.read, recording_errors Lwt_cs.write_full)

Expand Down Expand Up @@ -206,7 +212,9 @@ module Unix = struct
let accept conf fd =
Lwt_unix.accept fd >>= fun (fd', addr) ->
Lwt.catch (fun () -> server_of_fd conf fd' >|= fun t -> (t, addr))
(fun exn -> safely (Lwt_unix.close fd') >>= fun () -> fail exn)
(function
| Out_of_memory -> raise Out_of_memory
| exn -> safely (Lwt_unix.close fd') >>= fun () -> fail exn)

let connect conf (host, port) =
resolve host (string_of_int port) >>= fun addr ->
Expand All @@ -217,7 +225,9 @@ module Unix = struct
(Result.bind (Domain_name.of_string host) Domain_name.host)
in
Lwt_unix.connect fd addr >>= fun () -> client_of_fd conf ?host fd)
(fun exn -> safely (Lwt_unix.close fd) >>= fun () -> fail exn)
(function
| Out_of_memory -> raise Out_of_memory
| exn -> safely (Lwt_unix.close fd) >>= fun () -> fail exn)

let read_bytes t bs off len =
read t (Cstruct.of_bigarray ~off ~len bs)
Expand Down