diff --git a/lib/openai/base_model.rb b/lib/openai/base_model.rb index 4ef0106e..5c1a2499 100644 --- a/lib/openai/base_model.rb +++ b/lib/openai/base_model.rb @@ -902,6 +902,13 @@ def known_fields @known_fields ||= (self < OpenAI::BaseModel ? superclass.known_fields.dup : {}) end + # @api private + # + # @return [Hash{Symbol=>Symbol}] + def reverse_map + @reverse_map ||= (self < OpenAI::BaseModel ? superclass.reverse_map.dup : {}) + end + # @api private # # @return [Hash{Symbol=>Hash{Symbol=>Object}}] @@ -945,7 +952,7 @@ def defaults = (@defaults ||= {}) fallback = info[:const] defaults[name_sym] = fallback if required && !info[:nil?] && info.key?(:const) - key = info.fetch(:api_name, name_sym) + key = info[:api_name]&.tap { reverse_map[_1] = name_sym } || name_sym setter = "#{name_sym}=" if known_fields.key?(name_sym) @@ -1202,7 +1209,21 @@ def deconstruct_keys(keys) def initialize(data = {}) case OpenAI::Util.coerce_hash(data) in Hash => coerced - @data = coerced.transform_keys(&:to_sym) + @data = coerced.to_h do |key, value| + name = key.to_sym + mapped = self.class.reverse_map.fetch(name, name) + type = self.class.fields[mapped]&.fetch(:type) + stored = + case [type, value] + in [Class, Hash] if type <= OpenAI::BaseModel + type.new(value) + in [OpenAI::ArrayOf, Array] | [OpenAI::HashOf, Hash] + type.coerce(value) + else + value + end + [name, stored] + end else raise ArgumentError.new("Expected a #{Hash} or #{OpenAI::BaseModel}, got #{data.inspect}") end diff --git a/rbi/lib/openai/base_model.rbi b/rbi/lib/openai/base_model.rbi index 8840d3a8..11854971 100644 --- a/rbi/lib/openai/base_model.rbi +++ b/rbi/lib/openai/base_model.rbi @@ -457,6 +457,11 @@ module OpenAI def known_fields end + # @api private + sig { returns(T::Hash[Symbol, Symbol]) } + def reverse_map + end + # @api private sig do returns(T::Hash[Symbol, T.all(OpenAI::BaseModel::KnownFieldShape, {type: OpenAI::Converter::Input})]) diff --git a/sig/openai/base_model.rbs b/sig/openai/base_model.rbs index 574847b4..d9267814 100644 --- a/sig/openai/base_model.rbs +++ b/sig/openai/base_model.rbs @@ -176,6 +176,8 @@ module OpenAI def self.known_fields: -> ::Hash[Symbol, (OpenAI::BaseModel::known_field & { type_fn: (^-> OpenAI::Converter::input) })] + def self.reverse_map: -> ::Hash[Symbol, Symbol] + def self.fields: -> ::Hash[Symbol, (OpenAI::BaseModel::known_field & { type: OpenAI::Converter::input })] diff --git a/test/openai/base_model_test.rb b/test/openai/base_model_test.rb index f6e598d0..1a3c623e 100644 --- a/test/openai/base_model_test.rb +++ b/test/openai/base_model_test.rb @@ -222,6 +222,20 @@ def test_nested_model_dump end end + class M4 < M2 + required :c, M1 + required :d, OpenAI::ArrayOf[M4] + required :e, M2, api_name: :f + end + + def test_model_to_h + model = M4.new(a: "wow", c: {}, d: [{}, 2, {c: {}}], f: {}) + assert_pattern do + model.to_h => {a: "wow", c: M1, d: [M4, 2, M4 => child], f: M2} + assert_equal({c: M1.new}, child.to_h) + end + end + A3 = OpenAI::ArrayOf[A1] class M3 < M1