diff --git a/lib/openai/base_client.rb b/lib/openai/base_client.rb index 5c609ac7..b5cc5490 100644 --- a/lib/openai/base_client.rb +++ b/lib/openai/base_client.rb @@ -124,6 +124,20 @@ def follow_redirect(request, status:, response_headers:) request end + + # @api private + # + # @param status [Integer, OpenAI::APIConnectionError] + # @param stream [Enumerable, nil] + def reap_connection!(status, stream:) + case status + in (..199) | (300..499) + stream&.each { next } + in OpenAI::APIConnectionError | (500..) + OpenAI::Util.close_fused!(stream) + else + end + end end # @api private @@ -321,28 +335,23 @@ def initialize( end begin - response, stream = @requester.execute(input) - status = Integer(response.code) + status, response, stream = @requester.execute(input) rescue OpenAI::APIConnectionError => e status = e end - # normally we want to drain the response body and reuse the HTTP session by clearing the socket buffers - # unless we hit a server error - srv_fault = (500...).include?(status) - case status in ..299 [status, response, stream] in 300..399 if redirect_count >= self.class::MAX_REDIRECTS - message = "Failed to complete the request within #{self.class::MAX_REDIRECTS} redirects." + self.class.reap_connection!(status, stream: stream) - stream.each { next } + message = "Failed to complete the request within #{self.class::MAX_REDIRECTS} redirects." raise OpenAI::APIConnectionError.new(url: url, message: message) in 300..399 - request = self.class.follow_redirect(request, status: status, response_headers: response) + self.class.reap_connection!(status, stream: stream) - stream.each { next } + request = self.class.follow_redirect(request, status: status, response_headers: response) send_request( request, redirect_count: redirect_count + 1, @@ -352,12 +361,10 @@ def initialize( in OpenAI::APIConnectionError if retry_count >= max_retries raise status in (400..) if retry_count >= max_retries || !self.class.should_retry?(status, headers: response) - decoded = OpenAI::Util.decode_content(response, stream: stream, suppress_error: true) - - if srv_fault - OpenAI::Util.close_fused!(stream) - else - stream.each { next } + decoded = Kernel.then do + OpenAI::Util.decode_content(response, stream: stream, suppress_error: true) + ensure + self.class.reap_connection!(status, stream: stream) end raise OpenAI::APIStatusError.for( @@ -368,13 +375,9 @@ def initialize( response: response ) in (400..) | OpenAI::APIConnectionError - delay = retry_delay(response, retry_count: retry_count) + self.class.reap_connection!(status, stream: stream) - if srv_fault - OpenAI::Util.close_fused!(stream) - else - stream&.each { next } - end + delay = retry_delay(response, retry_count: retry_count) sleep(delay) send_request( diff --git a/lib/openai/cursor_page.rb b/lib/openai/cursor_page.rb index 92a6b024..35b585e9 100644 --- a/lib/openai/cursor_page.rb +++ b/lib/openai/cursor_page.rb @@ -68,7 +68,8 @@ def next_page? # @return [OpenAI::CursorPage] def next_page unless next_page? - raise RuntimeError.new("No more pages available. Please check #next_page? before calling ##{__method__}") + message = "No more pages available. Please check #next_page? before calling ##{__method__}" + raise RuntimeError.new(message) end req = OpenAI::Util.deep_merge(@req, {query: {after: data&.last&.id}}) diff --git a/lib/openai/pooled_net_requester.rb b/lib/openai/pooled_net_requester.rb index 0e16cc88..4b0ae742 100644 --- a/lib/openai/pooled_net_requester.rb +++ b/lib/openai/pooled_net_requester.rb @@ -61,6 +61,7 @@ def build_request(request, &) case body in nil + nil in String req["content-length"] ||= body.bytesize.to_s unless req["transfer-encoding"] req.body_stream = OpenAI::Util::ReadIOAdapter.new(body, &) @@ -79,9 +80,11 @@ def build_request(request, &) # @api private # # @param url [URI::Generic] + # @param deadline [Float] # @param blk [Proc] - private def with_pool(url, &) + private def with_pool(url, deadline:, &blk) origin = OpenAI::Util.uri_origin(url) + timeout = deadline - OpenAI::Util.monotonic_secs pool = @mutex.synchronize do @pools[origin] ||= ConnectionPool.new(size: @size) do @@ -89,7 +92,7 @@ def build_request(request, &) end end - pool.with(&) + pool.with(timeout: timeout, &blk) end # @api private @@ -106,14 +109,14 @@ def build_request(request, &) # # @option request [Float] :deadline # - # @return [Array(Net::HTTPResponse, Enumerable)] + # @return [Array(Integer, Net::HTTPResponse, Enumerable)] def execute(request) url, deadline = request.fetch_values(:url, :deadline) eof = false finished = false enum = Enumerator.new do |y| - with_pool(url) do |conn| + with_pool(url, deadline: deadline) do |conn| next if finished req = self.class.build_request(request) do @@ -125,7 +128,7 @@ def execute(request) self.class.calibrate_socket_timeout(conn, deadline) conn.request(req) do |rsp| - y << [conn, rsp] + y << [conn, req, rsp] break if finished rsp.read_body do |bytes| @@ -137,9 +140,11 @@ def execute(request) eof = true end end + rescue Timeout::Error + raise OpenAI::APITimeoutError end - conn, response = enum.next + conn, _, response = enum.next body = OpenAI::Util.fused_enum(enum, external: true) do finished = true tap do @@ -149,7 +154,7 @@ def execute(request) end conn.finish if !eof && conn&.started? end - [response, (response.body = body)] + [Integer(response.code), response, (response.body = body)] end # @api private diff --git a/lib/openai/util.rb b/lib/openai/util.rb index 2f4fecde..ef208d81 100644 --- a/lib/openai/util.rb +++ b/lib/openai/util.rb @@ -57,7 +57,7 @@ class << self # # @param input [Object] # - # @return [Boolean, Object] + # @return [Boolean] def primitive?(input) case input in true | false | Integer | Float | Symbol | String @@ -627,6 +627,8 @@ def close_fused!(enum) # # @param enum [Enumerable, nil] # @param blk [Proc] + # + # @return [Enumerable] def chain_fused(enum, &blk) iter = Enumerator.new { blk.call(_1) } fused_enum(iter) { close_fused!(enum) } diff --git a/rbi/lib/openai/base_client.rbi b/rbi/lib/openai/base_client.rbi index 85abd5cb..6b76b254 100644 --- a/rbi/lib/openai/base_client.rbi +++ b/rbi/lib/openai/base_client.rbi @@ -67,6 +67,16 @@ module OpenAI end def follow_redirect(request, status:, response_headers:) end + + # @api private + sig do + params( + status: T.any(Integer, OpenAI::APIConnectionError), + stream: T.nilable(T::Enumerable[String]) + ).void + end + def reap_connection!(status, stream:) + end end sig { returns(T.anything) } diff --git a/rbi/lib/openai/base_model.rbi b/rbi/lib/openai/base_model.rbi index 3a7913f6..06bae256 100644 --- a/rbi/lib/openai/base_model.rbi +++ b/rbi/lib/openai/base_model.rbi @@ -228,7 +228,7 @@ module OpenAI # @api private # # All of the specified variant info for this union. - sig { returns(T::Array[[T.nilable(Symbol), Proc]]) } + sig { returns(T::Array[[T.nilable(Symbol), T.proc.returns(Variants)]]) } private def known_variants end @@ -250,17 +250,8 @@ module OpenAI # @api private sig do params( - key: T.any( - Symbol, - T::Hash[Symbol, T.anything], - T.proc.returns(OpenAI::Converter::Input), - OpenAI::Converter::Input - ), - spec: T.any( - T::Hash[Symbol, T.anything], - T.proc.returns(OpenAI::Converter::Input), - OpenAI::Converter::Input - ) + key: T.any(Symbol, T::Hash[Symbol, T.anything], T.proc.returns(Variants), Variants), + spec: T.any(T::Hash[Symbol, T.anything], T.proc.returns(Variants), Variants) ) .void end @@ -268,7 +259,7 @@ module OpenAI end # @api private - sig { params(value: T.anything).returns(T.nilable(OpenAI::Converter::Input)) } + sig { params(value: T.anything).returns(T.nilable(Variants)) } private def resolve_variant(value) end end diff --git a/rbi/lib/openai/pooled_net_requester.rbi b/rbi/lib/openai/pooled_net_requester.rbi index e940c4f4..9297bdea 100644 --- a/rbi/lib/openai/pooled_net_requester.rbi +++ b/rbi/lib/openai/pooled_net_requester.rbi @@ -27,14 +27,14 @@ module OpenAI end # @api private - sig { params(url: URI::Generic, blk: T.proc.params(arg0: Net::HTTP).void).void } - private def with_pool(url, &blk) + sig { params(url: URI::Generic, deadline: Float, blk: T.proc.params(arg0: Net::HTTP).void).void } + private def with_pool(url, deadline:, &blk) end # @api private sig do params(request: OpenAI::PooledNetRequester::RequestShape) - .returns([Net::HTTPResponse, T::Enumerable[String]]) + .returns([Integer, Net::HTTPResponse, T::Enumerable[String]]) end def execute(request) end diff --git a/rbi/lib/openai/util.rbi b/rbi/lib/openai/util.rbi index 8faebf37..22824de5 100644 --- a/rbi/lib/openai/util.rbi +++ b/rbi/lib/openai/util.rbi @@ -22,7 +22,7 @@ module OpenAI class << self # @api private - sig { params(input: T.anything).returns(T.any(T::Boolean, T.anything)) } + sig { params(input: T.anything).returns(T::Boolean) } def primitive?(input) end @@ -239,10 +239,8 @@ module OpenAI # @api private sig do - params( - enum: T.nilable(T::Enumerable[T.anything]), - blk: T.proc.params(arg0: Enumerator::Yielder).void - ).void + params(enum: T.nilable(T::Enumerable[T.anything]), blk: T.proc.params(arg0: Enumerator::Yielder).void) + .returns(T::Enumerable[T.anything]) end def chain_fused(enum, &blk) end diff --git a/sig/openai/base_client.rbs b/sig/openai/base_client.rbs index d685733f..0c19b54e 100644 --- a/sig/openai/base_client.rbs +++ b/sig/openai/base_client.rbs @@ -43,6 +43,11 @@ module OpenAI response_headers: ::Hash[String, String] ) -> OpenAI::BaseClient::request_input + def self.reap_connection!: ( + Integer | OpenAI::APIConnectionError status, + stream: Enumerable[String]? + ) -> void + # @api private attr_accessor requester: top diff --git a/sig/openai/base_model.rbs b/sig/openai/base_model.rbs index 857d4573..574847b4 100644 --- a/sig/openai/base_model.rbs +++ b/sig/openai/base_model.rbs @@ -85,7 +85,7 @@ module OpenAI class Union extend OpenAI::Converter - private def self.known_variants: -> ::Array[[Symbol?, Proc]] + private def self.known_variants: -> ::Array[[Symbol?, (^-> OpenAI::Converter::input)]] def self.derefed_variants: -> ::Array[[Symbol?, top]] diff --git a/sig/openai/pooled_net_requester.rbs b/sig/openai/pooled_net_requester.rbs index 9e7daafb..c9f6520d 100644 --- a/sig/openai/pooled_net_requester.rbs +++ b/sig/openai/pooled_net_requester.rbs @@ -19,11 +19,16 @@ module OpenAI (String arg0) -> void } -> top - private def with_pool: (URI::Generic url) { (top arg0) -> void } -> void + private def with_pool: ( + URI::Generic url, + deadline: Float + ) { + (top arg0) -> void + } -> void def execute: ( OpenAI::PooledNetRequester::request request - ) -> [top, Enumerable[String]] + ) -> [Integer, top, Enumerable[String]] def initialize: (size: Integer) -> void end diff --git a/sig/openai/util.rbs b/sig/openai/util.rbs index 065ab7d1..375f8324 100644 --- a/sig/openai/util.rbs +++ b/sig/openai/util.rbs @@ -6,7 +6,7 @@ module OpenAI def self?.os: -> String - def self?.primitive?: (top input) -> (bool | top) + def self?.primitive?: (top input) -> bool def self?.coerce_boolean: (top input) -> (bool | top) @@ -118,7 +118,7 @@ module OpenAI Enumerable[top]? enum ) { (Enumerator::Yielder arg0) -> void - } -> void + } -> Enumerable[top] type server_sent_event = { event: String?, data: String?, id: String?, retry: Integer? } diff --git a/test/openai/client_test.rb b/test/openai/client_test.rb index 4147e2ac..7d0758fa 100644 --- a/test/openai/client_test.rb +++ b/test/openai/client_test.rb @@ -18,32 +18,6 @@ def test_raises_on_missing_non_nullable_opts assert_match(/is required/, e.message) end - class MockResponse - # @return [Integer] - attr_reader :code - - # @param code [Integer] - # @param headers [Hash{String=>String}] - def initialize(code, headers) - @code = code - @headers = {"content-type" => "application/json", **headers} - end - - # @param header [String] - # - # @return [String, nil] - def [](header) - @headers[header] - end - - # @param header [String] - # - # @return [Boolean] - def key?(header) - @headers.key?(header) - end - end - class MockRequester # @return [Integer] attr_reader :response_code @@ -71,7 +45,8 @@ def initialize(response_code, response_headers, response_data) def execute(req) # Deep copy the request because it is mutated on each retry. attempts.push(Marshal.load(Marshal.dump(req))) - [MockResponse.new(response_code, response_headers), response_data.grapheme_clusters] + headers = {"content-type" => "application/json", **response_headers} + [response_code, headers, response_data.grapheme_clusters] end end diff --git a/test/openai/util_test.rb b/test/openai/util_test.rb index d319e2f9..476e16af 100644 --- a/test/openai/util_test.rb +++ b/test/openai/util_test.rb @@ -161,7 +161,9 @@ class OpenAI::Test::UtilFormDataEncodingTest < Minitest::Test class FakeCGI < CGI def initialize(headers, io) @ctype = headers["content-type"] + # rubocop:disable Lint/EmptyBlock @io = OpenAI::Util::ReadIOAdapter.new(io) {} + # rubocop:enable Lint/EmptyBlock @c_len = io.to_a.join.bytesize.to_s super() end @@ -217,7 +219,9 @@ def test_copy_read } cases.each do |input, expected| io = StringIO.new + # rubocop:disable Lint/EmptyBlock adapter = OpenAI::Util::ReadIOAdapter.new(input) {} + # rubocop:enable Lint/EmptyBlock IO.copy_stream(adapter, io) assert_equal(expected, io.string) end