Skip to content

Commit

Permalink
Cleanup IOCP::OverlappedOperation (#14723)
Browse files Browse the repository at this point in the history
Light refactor of `IOCP::OverlappedOperation` to simplify the implementation.

* Add `OverlappedOperation#to_unsafe` as standard format for passing to C functions
* Add `OverlappedOperation.unbox` for the reverse
* Drop unnecessary `OverlappedOperation#start` to simplify the logic
  • Loading branch information
straight-shoota committed Jun 20, 2024
1 parent e0754ca commit b14be1e
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 21 deletions.
4 changes: 2 additions & 2 deletions src/crystal/system/win32/event_loop_iocp.cr
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ class Crystal::IOCP::EventLoop < Crystal::EventLoop
def connect(socket : ::Socket, address : ::Socket::Addrinfo | ::Socket::Address, timeout : ::Time::Span?) : IO::Error?
socket.overlapped_connect(socket.fd, "ConnectEx") do |overlapped|
# This is: LibC.ConnectEx(fd, address, address.size, nil, 0, nil, overlapped)
Crystal::System::Socket.connect_ex.call(socket.fd, address.to_unsafe, address.size, Pointer(Void).null, 0_u32, Pointer(UInt32).null, overlapped)
Crystal::System::Socket.connect_ex.call(socket.fd, address.to_unsafe, address.size, Pointer(Void).null, 0_u32, Pointer(UInt32).null, overlapped.to_unsafe)
end
end

Expand All @@ -256,7 +256,7 @@ class Crystal::IOCP::EventLoop < Crystal::EventLoop
received_bytes = uninitialized UInt32
Crystal::System::Socket.accept_ex.call(socket.fd, client_handle,
output_buffer.to_unsafe.as(Void*), buffer_size.to_u32!,
address_size.to_u32!, address_size.to_u32!, pointerof(received_bytes), overlapped)
address_size.to_u32!, address_size.to_u32!, pointerof(received_bytes), overlapped.to_unsafe)
end

if success
Expand Down
30 changes: 13 additions & 17 deletions src/crystal/system/win32/iocp.cr
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ module Crystal::IOCP
# I/O operations, including socket ones, do not set this field
case completion_key = Pointer(Void).new(entry.lpCompletionKey).as(CompletionKey?)
when Nil
OverlappedOperation.schedule(entry.lpOverlapped) { |fiber| yield fiber }
operation = OverlappedOperation.unbox(entry.lpOverlapped)
operation.schedule { |fiber| yield fiber }
else
case entry.dwNumberOfBytesTransferred
when LibC::JOB_OBJECT_MSG_EXIT_PROCESS, LibC::JOB_OBJECT_MSG_ABNORMAL_EXIT_PROCESS
Expand All @@ -62,15 +63,14 @@ module Crystal::IOCP

class OverlappedOperation
enum State
INITIALIZED
STARTED
DONE
CANCELLED
end

@overlapped = LibC::OVERLAPPED.new
@fiber : Fiber? = nil
@state : State = :initialized
@fiber = Fiber.current
@state : State = :started
property next : OverlappedOperation?
property previous : OverlappedOperation?
@@canceled = Thread::LinkedList(OverlappedOperation).new
Expand All @@ -84,22 +84,18 @@ module Crystal::IOCP
end
end

def self.schedule(overlapped : LibC::OVERLAPPED*, &)
def self.unbox(overlapped : LibC::OVERLAPPED*)
start = overlapped.as(Pointer(UInt8)) - offsetof(OverlappedOperation, @overlapped)
operation = Box(OverlappedOperation).unbox(start.as(Pointer(Void)))
operation.schedule { |fiber| yield fiber }
Box(OverlappedOperation).unbox(start.as(Pointer(Void)))
end

def start
raise Exception.new("Invalid state #{@state}") unless @state.initialized?
@fiber = Fiber.current
@state = State::STARTED
def to_unsafe
pointerof(@overlapped)
end

def result(handle, &)
raise Exception.new("Invalid state #{@state}") unless @state.done? || @state.started?
result = LibC.GetOverlappedResult(handle, pointerof(@overlapped), out bytes, 0)
result = LibC.GetOverlappedResult(handle, self, out bytes, 0)
if result.zero?
error = WinError.value
yield error
Expand All @@ -113,7 +109,7 @@ module Crystal::IOCP
def wsa_result(socket, &)
raise Exception.new("Invalid state #{@state}") unless @state.done? || @state.started?
flags = 0_u32
result = LibC.WSAGetOverlappedResult(socket, pointerof(@overlapped), out bytes, false, pointerof(flags))
result = LibC.WSAGetOverlappedResult(socket, self, out bytes, false, pointerof(flags))
if result.zero?
error = WinError.wsa_value
yield error
Expand All @@ -127,7 +123,7 @@ module Crystal::IOCP
protected def schedule(&)
case @state
when .started?
yield @fiber.not_nil!
yield @fiber
done!
when .cancelled?
@@canceled.delete(self)
Expand All @@ -144,7 +140,7 @@ module Crystal::IOCP
# https://learn.microsoft.com/en-us/windows/win32/api/ioapiset/nf-ioapiset-cancelioex
# > The application must not free or reuse the OVERLAPPED structure
# associated with the canceled I/O operations until they have completed
if LibC.CancelIoEx(handle, pointerof(@overlapped)) != 0
if LibC.CancelIoEx(handle, self) != 0
@state = :cancelled
@@canceled.push(self) # to increase lifetime
end
Expand Down Expand Up @@ -176,7 +172,7 @@ module Crystal::IOCP

def self.overlapped_operation(target, handle, method, timeout, *, writing = false, &)
OverlappedOperation.run(handle) do |operation|
result, value = yield operation.start
result, value = yield operation

if result == 0
case error = WinError.value
Expand Down Expand Up @@ -214,7 +210,7 @@ module Crystal::IOCP

def self.wsa_overlapped_operation(target, socket, method, timeout, connreset_is_error = true, &)
OverlappedOperation.run(socket) do |operation|
result, value = yield operation.start
result, value = yield operation

if result == LibC::SOCKET_ERROR
case error = WinError.wsa_value
Expand Down
4 changes: 2 additions & 2 deletions src/crystal/system/win32/socket.cr
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ module Crystal::System::Socket
# :nodoc:
def overlapped_connect(socket, method, &)
IOCP::OverlappedOperation.run(socket) do |operation|
result = yield operation.start
result = yield operation

if result == 0
case error = WinError.wsa_value
Expand Down Expand Up @@ -196,7 +196,7 @@ module Crystal::System::Socket

def overlapped_accept(socket, method, &)
IOCP::OverlappedOperation.run(socket) do |operation|
result = yield operation.start
result = yield operation

if result == 0
case error = WinError.wsa_value
Expand Down

0 comments on commit b14be1e

Please sign in to comment.