diff --git a/lib/openai/base_stream.rb b/lib/openai/base_stream.rb index 082d045d..c2beb4b9 100644 --- a/lib/openai/base_stream.rb +++ b/lib/openai/base_stream.rb @@ -16,7 +16,7 @@ module OpenAI # # messages => Array # ``` - class BaseStream + module BaseStream # @return [void] # def close = OpenAI::Util.close_fused!(@iterator) diff --git a/lib/openai/stream.rb b/lib/openai/stream.rb index 801b6247..f9319992 100644 --- a/lib/openai/stream.rb +++ b/lib/openai/stream.rb @@ -16,7 +16,9 @@ module OpenAI # # messages => Array # ``` - class Stream < OpenAI::BaseStream + class Stream + include OpenAI::BaseStream + # @private # # @return [Enumerable] diff --git a/rbi/lib/openai/base_client.rbi b/rbi/lib/openai/base_client.rbi index ae1f372e..fd80c3c5 100644 --- a/rbi/lib/openai/base_client.rbi +++ b/rbi/lib/openai/base_client.rbi @@ -22,7 +22,7 @@ module OpenAI body: T.nilable(T.anything), unwrap: T.nilable(Symbol), page: T.nilable(T::Class[OpenAI::BasePage[OpenAI::BaseModel]]), - stream: T.nilable(T::Class[OpenAI::BaseStream[OpenAI::BaseModel]]), + stream: T.nilable(T::Class[OpenAI::BaseStream[T.anything, OpenAI::BaseModel]]), model: T.nilable(OpenAI::Converter::Input), options: T.nilable(T.any(OpenAI::RequestOptions, T::Hash[Symbol, T.anything])) } @@ -148,7 +148,7 @@ module OpenAI body: T.nilable(T.anything), unwrap: T.nilable(Symbol), page: T.nilable(T::Class[OpenAI::BasePage[OpenAI::BaseModel]]), - stream: T.nilable(T::Class[OpenAI::BaseStream[OpenAI::BaseModel]]), + stream: T.nilable(T::Class[OpenAI::BaseStream[T.anything, OpenAI::BaseModel]]), model: T.nilable(OpenAI::Converter::Input), options: T.nilable(T.any(OpenAI::RequestOptions, T::Hash[Symbol, T.anything])) ) diff --git a/rbi/lib/openai/base_stream.rbi b/rbi/lib/openai/base_stream.rbi index f527b849..8b829bd1 100644 --- a/rbi/lib/openai/base_stream.rbi +++ b/rbi/lib/openai/base_stream.rbi @@ -1,7 +1,8 @@ # typed: strong module OpenAI - class BaseStream + module BaseStream + Message = type_member(:in) Elem = type_member(:out) sig { void } @@ -28,11 +29,11 @@ module OpenAI url: URI::Generic, status: Integer, response: Net::HTTPResponse, - messages: T::Enumerable[OpenAI::Util::SSEMessage] + messages: T::Enumerable[Message] ) - .returns(T.attached_class) + .void end - def self.new(model:, url:, status:, response:, messages:) + def initialize(model:, url:, status:, response:, messages:) end end end diff --git a/rbi/lib/openai/stream.rbi b/rbi/lib/openai/stream.rbi index f4bf6fa9..3dc46e28 100644 --- a/rbi/lib/openai/stream.rbi +++ b/rbi/lib/openai/stream.rbi @@ -1,11 +1,27 @@ # typed: strong module OpenAI - class Stream < OpenAI::BaseStream + class Stream + include OpenAI::BaseStream + + Message = type_member(:in) { {fixed: OpenAI::Util::SSEMessage} } Elem = type_member(:out) sig { override.returns(T::Enumerable[Elem]) } private def iterator end + + sig do + params( + model: T.any(T::Class[T.anything], OpenAI::Converter), + url: URI::Generic, + status: Integer, + response: Net::HTTPResponse, + messages: T::Enumerable[OpenAI::Util::SSEMessage] + ) + .returns(T.attached_class) + end + def self.new(model:, url:, status:, response:, messages:) + end end end diff --git a/sig/openai/base_stream.rbs b/sig/openai/base_stream.rbs index 397d46e5..e5d9ec89 100644 --- a/sig/openai/base_stream.rbs +++ b/sig/openai/base_stream.rbs @@ -1,5 +1,5 @@ module OpenAI - class BaseStream[Elem] + module BaseStream[Message, Elem] def close: -> void private def iterator: -> Enumerable[Elem] @@ -15,7 +15,7 @@ module OpenAI url: URI::Generic, status: Integer, response: top, - messages: Enumerable[OpenAI::Util::sse_message] + messages: Enumerable[Message] ) -> void end end diff --git a/sig/openai/stream.rbs b/sig/openai/stream.rbs index 78d58b92..675ecb74 100644 --- a/sig/openai/stream.rbs +++ b/sig/openai/stream.rbs @@ -1,5 +1,15 @@ module OpenAI - class Stream[Elem] < OpenAI::BaseStream[Elem] + class Stream[Elem] + include OpenAI::BaseStream[OpenAI::Util::sse_message, Elem] + private def iterator: -> Enumerable[Elem] + + def initialize: ( + model: Class | OpenAI::Converter, + url: URI::Generic, + status: Integer, + response: top, + messages: Enumerable[OpenAI::Util::sse_message] + ) -> void end end