Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions lib/openai/base_model.rb
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,13 @@ def known_fields
@known_fields ||= (self < OpenAI::BaseModel ? superclass.known_fields.dup : {})
end

# @api private
#
# @return [Hash{Symbol=>Symbol}]
def reverse_map
@reverse_map ||= (self < OpenAI::BaseModel ? superclass.reverse_map.dup : {})
end

# @api private
#
# @return [Hash{Symbol=>Hash{Symbol=>Object}}]
Expand Down Expand Up @@ -945,7 +952,7 @@ def defaults = (@defaults ||= {})
fallback = info[:const]
defaults[name_sym] = fallback if required && !info[:nil?] && info.key?(:const)

key = info.fetch(:api_name, name_sym)
key = info[:api_name]&.tap { reverse_map[_1] = name_sym } || name_sym
setter = "#{name_sym}="

if known_fields.key?(name_sym)
Expand Down Expand Up @@ -1202,7 +1209,21 @@ def deconstruct_keys(keys)
def initialize(data = {})
case OpenAI::Util.coerce_hash(data)
in Hash => coerced
@data = coerced.transform_keys(&:to_sym)
@data = coerced.to_h do |key, value|
name = key.to_sym
mapped = self.class.reverse_map.fetch(name, name)
type = self.class.fields[mapped]&.fetch(:type)
stored =
case [type, value]
in [Class, Hash] if type <= OpenAI::BaseModel
type.new(value)
in [OpenAI::ArrayOf, Array] | [OpenAI::HashOf, Hash]
type.coerce(value)
else
value
end
[name, stored]
end
else
raise ArgumentError.new("Expected a #{Hash} or #{OpenAI::BaseModel}, got #{data.inspect}")
end
Expand Down
5 changes: 5 additions & 0 deletions rbi/lib/openai/base_model.rbi
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,11 @@ module OpenAI
def known_fields
end

# @api private
sig { returns(T::Hash[Symbol, Symbol]) }
def reverse_map
end

# @api private
sig do
returns(T::Hash[Symbol, T.all(OpenAI::BaseModel::KnownFieldShape, {type: OpenAI::Converter::Input})])
Expand Down
2 changes: 2 additions & 0 deletions sig/openai/base_model.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ module OpenAI
def self.known_fields: -> ::Hash[Symbol, (OpenAI::BaseModel::known_field
& { type_fn: (^-> OpenAI::Converter::input) })]

def self.reverse_map: -> ::Hash[Symbol, Symbol]

def self.fields: -> ::Hash[Symbol, (OpenAI::BaseModel::known_field
& { type: OpenAI::Converter::input })]

Expand Down
14 changes: 14 additions & 0 deletions test/openai/base_model_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,20 @@ def test_nested_model_dump
end
end

class M4 < M2
required :c, M1
required :d, OpenAI::ArrayOf[M4]
required :e, M2, api_name: :f
end

def test_model_to_h
model = M4.new(a: "wow", c: {}, d: [{}, 2, {c: {}}], f: {})
assert_pattern do
model.to_h => {a: "wow", c: M1, d: [M4, 2, M4 => child], f: M2}
assert_equal({c: M1.new}, child.to_h)
end
end

A3 = OpenAI::ArrayOf[A1]

class M3 < M1
Expand Down