47 changes: 25 additions & 22 deletions lib/openai/base_client.rb
@@ -124,6 +124,20 @@ def follow_redirect(request, status:, response_headers:)

request
end

# @api private
#
# @param status [Integer, OpenAI::APIConnectionError]
# @param stream [Enumerable, nil]
def reap_connection!(status, stream:)
case status
in (..199) | (300..499)
stream&.each { next }
in OpenAI::APIConnectionError | (500..)
OpenAI::Util.close_fused!(stream)
else
end
end
end

# @api private
@@ -321,28 +335,23 @@ def initialize(
end

begin
response, stream = @requester.execute(input)
status = Integer(response.code)
status, response, stream = @requester.execute(input)
rescue OpenAI::APIConnectionError => e
status = e
end

# normally we want to drain the response body and reuse the HTTP session by clearing the socket buffers
# unless we hit a server error
srv_fault = (500...).include?(status)

case status
in ..299
[status, response, stream]
in 300..399 if redirect_count >= self.class::MAX_REDIRECTS
message = "Failed to complete the request within #{self.class::MAX_REDIRECTS} redirects."
self.class.reap_connection!(status, stream: stream)

stream.each { next }
message = "Failed to complete the request within #{self.class::MAX_REDIRECTS} redirects."
raise OpenAI::APIConnectionError.new(url: url, message: message)
in 300..399
request = self.class.follow_redirect(request, status: status, response_headers: response)
self.class.reap_connection!(status, stream: stream)

stream.each { next }
request = self.class.follow_redirect(request, status: status, response_headers: response)
send_request(
request,
redirect_count: redirect_count + 1,
@@ -352,12 +361,10 @@ def initialize(
in OpenAI::APIConnectionError if retry_count >= max_retries
raise status
in (400..) if retry_count >= max_retries || !self.class.should_retry?(status, headers: response)
decoded = OpenAI::Util.decode_content(response, stream: stream, suppress_error: true)

if srv_fault
OpenAI::Util.close_fused!(stream)
else
stream.each { next }
decoded = Kernel.then do
OpenAI::Util.decode_content(response, stream: stream, suppress_error: true)
ensure
self.class.reap_connection!(status, stream: stream)
end

raise OpenAI::APIStatusError.for(
@@ -368,13 +375,9 @@ def initialize(
response: response
)
in (400..) | OpenAI::APIConnectionError
delay = retry_delay(response, retry_count: retry_count)
self.class.reap_connection!(status, stream: stream)

if srv_fault
OpenAI::Util.close_fused!(stream)
else
stream&.each { next }
end
delay = retry_delay(response, retry_count: retry_count)
sleep(delay)

send_request(
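For context outside the diff: a minimal, self-contained sketch (plain Net::HTTP, no OpenAI internals) of the connection-reuse behavior that the non-error branch of reap_connection! preserves. The host below is a placeholder; the point is only that a fully drained body leaves the keep-alive socket clean for the next request, mirroring `stream&.each { next }` above.

require "net/http"

http = Net::HTTP.start("example.com", 443, use_ssl: true)

2.times do
  http.request(Net::HTTP::Get.new("/")) do |rsp|
    # Consume the body in chunks, as the client does when it drains the
    # fused stream, so the same connection can serve the next request.
    rsp.read_body { |_chunk| nil }
  end
end

http.finish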
3 changes: 2 additions & 1 deletion lib/openai/cursor_page.rb
@@ -68,7 +68,8 @@ def next_page?
# @return [OpenAI::CursorPage]
def next_page
unless next_page?
raise RuntimeError.new("No more pages available. Please check #next_page? before calling ##{__method__}")
message = "No more pages available. Please check #next_page? before calling ##{__method__}"
raise RuntimeError.new(message)
end

req = OpenAI::Util.deep_merge(@req, {query: {after: data&.last&.id}})
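For illustration, the guard this error message points callers toward, assuming page is an OpenAI::CursorPage returned by some paginated list call (the variable names here are placeholders, not part of the change):

items = []
loop do
  items.concat(Array(page.data))
  # Check #next_page? first so the RuntimeError above is never raised.
  break unless page.next_page?
  page = page.next_page
end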
19 changes: 12 additions & 7 deletions lib/openai/pooled_net_requester.rb
@@ -61,6 +61,7 @@ def build_request(request, &)

case body
in nil
nil
in String
req["content-length"] ||= body.bytesize.to_s unless req["transfer-encoding"]
req.body_stream = OpenAI::Util::ReadIOAdapter.new(body, &)
@@ -79,17 +80,19 @@ def build_request(request, &)
# @api private
#
# @param url [URI::Generic]
# @param deadline [Float]
# @param blk [Proc]
private def with_pool(url, &)
private def with_pool(url, deadline:, &blk)
origin = OpenAI::Util.uri_origin(url)
timeout = deadline - OpenAI::Util.monotonic_secs
pool =
@mutex.synchronize do
@pools[origin] ||= ConnectionPool.new(size: @size) do
self.class.connect(url)
end
end

pool.with(&)
pool.with(timeout: timeout, &blk)
end

# @api private
@@ -106,14 +109,14 @@ def build_request(request, &)
#
# @option request [Float] :deadline
#
# @return [Array(Net::HTTPResponse, Enumerable)]
# @return [Array(Integer, Net::HTTPResponse, Enumerable)]
def execute(request)
url, deadline = request.fetch_values(:url, :deadline)

eof = false
finished = false
enum = Enumerator.new do |y|
with_pool(url) do |conn|
with_pool(url, deadline: deadline) do |conn|
next if finished

req = self.class.build_request(request) do
@@ -125,7 +128,7 @@ def execute(request)

self.class.calibrate_socket_timeout(conn, deadline)
conn.request(req) do |rsp|
y << [conn, rsp]
y << [conn, req, rsp]
break if finished

rsp.read_body do |bytes|
@@ -137,9 +140,11 @@ def execute(request)
eof = true
end
end
rescue Timeout::Error
raise OpenAI::APITimeoutError
end

conn, response = enum.next
conn, _, response = enum.next
body = OpenAI::Util.fused_enum(enum, external: true) do
finished = true
tap do
Expand All @@ -149,7 +154,7 @@ def execute(request)
end
conn.finish if !eof && conn&.started?
end
[response, (response.body = body)]
[Integer(response.code), response, (response.body = body)]
end

# @api private
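A hedged caller-side sketch of the updated execute contract: the status now arrives as an Integer alongside the response and the fused body enumerable, and the monotonic :deadline in the request hash also bounds how long the connection-pool checkout may wait. The request keys, URL, and token below are illustrative assumptions rather than the SDK's documented surface, and the openai gem's internals are presumed loaded.

requester = OpenAI::PooledNetRequester.new(size: 10)

request = {
  method: :get,                                    # assumed key name
  url: URI.parse("https://example.com/v1/models"), # placeholder URL
  headers: {"authorization" => "Bearer <token>"},  # placeholder credential
  body: nil,
  deadline: OpenAI::Util.monotonic_secs + 60
}

status, response, stream = requester.execute(request)

# `status` is already an Integer, so callers no longer need Integer(response.code);
# draining the stream hands the socket back to the pool in a reusable state.
stream.each { next } if status < 500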
4 changes: 3 additions & 1 deletion lib/openai/util.rb
@@ -57,7 +57,7 @@ class << self
#
# @param input [Object]
#
# @return [Boolean, Object]
# @return [Boolean]
def primitive?(input)
case input
in true | false | Integer | Float | Symbol | String
@@ -627,6 +627,8 @@ def close_fused!(enum)
#
# @param enum [Enumerable, nil]
# @param blk [Proc]
#
# @return [Enumerable]
def chain_fused(enum, &blk)
iter = Enumerator.new { blk.call(_1) }
fused_enum(iter) { close_fused!(enum) }
10 changes: 10 additions & 0 deletions rbi/lib/openai/base_client.rbi
@@ -67,6 +67,16 @@ module OpenAI
end
def follow_redirect(request, status:, response_headers:)
end

# @api private
sig do
params(
status: T.any(Integer, OpenAI::APIConnectionError),
stream: T.nilable(T::Enumerable[String])
).void
end
def reap_connection!(status, stream:)
end
end

sig { returns(T.anything) }
17 changes: 4 additions & 13 deletions rbi/lib/openai/base_model.rbi
@@ -228,7 +228,7 @@ module OpenAI
# @api private
#
# All of the specified variant info for this union.
sig { returns(T::Array[[T.nilable(Symbol), Proc]]) }
sig { returns(T::Array[[T.nilable(Symbol), T.proc.returns(Variants)]]) }
private def known_variants
end

@@ -250,25 +250,16 @@
# @api private
sig do
params(
key: T.any(
Symbol,
T::Hash[Symbol, T.anything],
T.proc.returns(OpenAI::Converter::Input),
OpenAI::Converter::Input
),
spec: T.any(
T::Hash[Symbol, T.anything],
T.proc.returns(OpenAI::Converter::Input),
OpenAI::Converter::Input
)
key: T.any(Symbol, T::Hash[Symbol, T.anything], T.proc.returns(Variants), Variants),
spec: T.any(T::Hash[Symbol, T.anything], T.proc.returns(Variants), Variants)
)
.void
end
private def variant(key, spec = nil)
end

# @api private
sig { params(value: T.anything).returns(T.nilable(OpenAI::Converter::Input)) }
sig { params(value: T.anything).returns(T.nilable(Variants)) }
private def resolve_variant(value)
end
end
6 changes: 3 additions & 3 deletions rbi/lib/openai/pooled_net_requester.rbi
@@ -27,14 +27,14 @@ module OpenAI
end

# @api private
sig { params(url: URI::Generic, blk: T.proc.params(arg0: Net::HTTP).void).void }
private def with_pool(url, &blk)
sig { params(url: URI::Generic, deadline: Float, blk: T.proc.params(arg0: Net::HTTP).void).void }
private def with_pool(url, deadline:, &blk)
end

# @api private
sig do
params(request: OpenAI::PooledNetRequester::RequestShape)
.returns([Net::HTTPResponse, T::Enumerable[String]])
.returns([Integer, Net::HTTPResponse, T::Enumerable[String]])
end
def execute(request)
end
8 changes: 3 additions & 5 deletions rbi/lib/openai/util.rbi
@@ -22,7 +22,7 @@ module OpenAI

class << self
# @api private
sig { params(input: T.anything).returns(T.any(T::Boolean, T.anything)) }
sig { params(input: T.anything).returns(T::Boolean) }
def primitive?(input)
end

@@ -239,10 +239,8 @@

# @api private
sig do
params(
enum: T.nilable(T::Enumerable[T.anything]),
blk: T.proc.params(arg0: Enumerator::Yielder).void
).void
params(enum: T.nilable(T::Enumerable[T.anything]), blk: T.proc.params(arg0: Enumerator::Yielder).void)
.returns(T::Enumerable[T.anything])
end
def chain_fused(enum, &blk)
end
5 changes: 5 additions & 0 deletions sig/openai/base_client.rbs
@@ -43,6 +43,11 @@ module OpenAI
response_headers: ::Hash[String, String]
) -> OpenAI::BaseClient::request_input

def self.reap_connection!: (
Integer | OpenAI::APIConnectionError status,
stream: Enumerable[String]?
) -> void

# @api private
attr_accessor requester: top

2 changes: 1 addition & 1 deletion sig/openai/base_model.rbs
@@ -85,7 +85,7 @@ module OpenAI
class Union
extend OpenAI::Converter

private def self.known_variants: -> ::Array[[Symbol?, Proc]]
private def self.known_variants: -> ::Array[[Symbol?, (^-> OpenAI::Converter::input)]]

def self.derefed_variants: -> ::Array[[Symbol?, top]]

9 changes: 7 additions & 2 deletions sig/openai/pooled_net_requester.rbs
@@ -19,11 +19,16 @@ module OpenAI
(String arg0) -> void
} -> top

private def with_pool: (URI::Generic url) { (top arg0) -> void } -> void
private def with_pool: (
URI::Generic url,
deadline: Float
) {
(top arg0) -> void
} -> void

def execute: (
OpenAI::PooledNetRequester::request request
) -> [top, Enumerable[String]]
) -> [Integer, top, Enumerable[String]]

def initialize: (size: Integer) -> void
end
4 changes: 2 additions & 2 deletions sig/openai/util.rbs
@@ -6,7 +6,7 @@ module OpenAI

def self?.os: -> String

def self?.primitive?: (top input) -> (bool | top)
def self?.primitive?: (top input) -> bool

def self?.coerce_boolean: (top input) -> (bool | top)

@@ -118,7 +118,7 @@
Enumerable[top]? enum
) {
(Enumerator::Yielder arg0) -> void
} -> void
} -> Enumerable[top]

type server_sent_event =
{ event: String?, data: String?, id: String?, retry: Integer? }
29 changes: 2 additions & 27 deletions test/openai/client_test.rb
@@ -18,32 +18,6 @@ def test_raises_on_missing_non_nullable_opts
assert_match(/is required/, e.message)
end

class MockResponse
# @return [Integer]
attr_reader :code

# @param code [Integer]
# @param headers [Hash{String=>String}]
def initialize(code, headers)
@code = code
@headers = {"content-type" => "application/json", **headers}
end

# @param header [String]
#
# @return [String, nil]
def [](header)
@headers[header]
end

# @param header [String]
#
# @return [Boolean]
def key?(header)
@headers.key?(header)
end
end

class MockRequester
# @return [Integer]
attr_reader :response_code
@@ -71,7 +45,8 @@ def initialize(response_code, response_headers, response_data)
def execute(req)
# Deep copy the request because it is mutated on each retry.
attempts.push(Marshal.load(Marshal.dump(req)))
[MockResponse.new(response_code, response_headers), response_data.grapheme_clusters]
headers = {"content-type" => "application/json", **response_headers}
[response_code, headers, response_data.grapheme_clusters]
end
end

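One note on why MockResponse could be deleted: with execute now returning [status, headers, stream], the middle element presumably only needs Hash-style lookups (#[] and #key?), which a plain Hash already provides. The header values below are sample data, not part of the test suite.

headers = {"content-type" => "application/json", "x-request-id" => "req_123"}

headers["content-type"]     # => "application/json"
headers.key?("retry-after") # => false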