diff --git a/Gemfile.lock b/Gemfile.lock index 1e860f5..1ddf06a 100644 --- a/Gemfile.lock +++ b/Gemfile.lock @@ -1,7 +1,7 @@ PATH remote: . specs: - ai (0.4.1) + ai (0.4.2) actionpack (>= 7.1.3) activesupport (>= 7.1.3) json_schemer (~> 2.4.0) diff --git a/lib/ai/clients/mastra.rb b/lib/ai/clients/mastra.rb index 3ccb9dd..a958f9a 100644 --- a/lib/ai/clients/mastra.rb +++ b/lib/ai/clients/mastra.rb @@ -74,6 +74,10 @@ def generate(agent_name, messages:, options: {}) parsed_response['response']['body'] = parsed_response['response']['messages'] end + if parsed_response['reasoning'] + parsed_response['reasoning_details'] = parsed_response['reasoning'] + end + parsed_response end diff --git a/lib/ai/schema_to_struct_string.rb b/lib/ai/schema_to_struct_string.rb index 31e74d6..d384855 100644 --- a/lib/ai/schema_to_struct_string.rb +++ b/lib/ai/schema_to_struct_string.rb @@ -1,4 +1,5 @@ # typed: strict +# rubocop:disable Sorbet/ForbidTUntyped require 'json' require 'active_support/inflector' @@ -9,6 +10,10 @@ module Ai # # The resulting definition is returned as a *string* so that it can be # injected into ERB templates when auto-generating files. + # + # Note: This class uses T.untyped for JSON schema structures as they are + # inherently dynamic and come from external sources. Type safety is maintained + # through runtime checks and the generated output is fully typed. class SchemaToStructString extend T::Sig @@ -23,24 +28,27 @@ def initialize(schema, class_name: 'Input') @root_class_name = class_name @generated_classes = T.let(Set.new, T::Set[String]) @nested_definitions = T.let([], T::Array[String]) - @schema_definitions = T.let({}, T::Hash[String, T::Hash[String, T.untyped]]) # rubocop:disable Sorbet/ForbidTUntyped - @resolved_refs = T.let({}, T::Hash[String, T::Hash[String, T.untyped]]) # rubocop:disable Sorbet/ForbidTUntyped + @schema_definitions = T.let({}, T::Hash[String, T::Hash[String, T.untyped]]) + @resolved_refs = T.let({}, T::Hash[String, T::Hash[String, T.untyped]]) + @dependencies = T.let({}, T::Hash[String, T::Set[String]]) + @current_class = T.let(nil, T.nilable(String)) end sig { returns(String) } def convert main_definition = generate_struct(parsed_schema, @root_class_name) - (@nested_definitions + [main_definition]).join("\n\n") + sorted_definitions = topological_sort(@nested_definitions) + (sorted_definitions + [main_definition]).join("\n\n") end - sig { returns(T::Hash[String, T.untyped]) } # rubocop:disable Sorbet/ForbidTUntyped + sig { returns(T::Hash[String, T.untyped]) } def parsed_schema return @parsed_schema if @parsed_schema - full_schema = T.let(JSON.parse(@schema), T::Hash[String, T.untyped]) # rubocop:disable Sorbet/ForbidTUntyped + full_schema = T.let(JSON.parse(@schema), T::Hash[String, T.untyped]) if full_schema.key?('json') - @parsed_schema = T.let(full_schema['json'], T.nilable(T::Hash[String, T.untyped])) # rubocop:disable Sorbet/ForbidTUntyped + @parsed_schema = T.let(full_schema['json'], T.nilable(T::Hash[String, T.untyped])) elsif full_schema.key?('$defs') || full_schema.key?('definitions') @schema_definitions = full_schema['$defs'] || full_schema['definitions'] || {} @parsed_schema = full_schema @@ -53,13 +61,11 @@ def parsed_schema raise ArgumentError, "Invalid JSON schema provided: #{e.message}" end - # rubocop:disable Sorbet/ForbidTUntyped sig do params(schema_hash: T::Hash[T.any(Symbol, String), T.untyped]).returns( T::Hash[T.any(Symbol, String), T.untyped] ) end - # rubocop:enable Sorbet/ForbidTUntyped def resolve_ref(schema_hash) ref = schema_hash['$ref'] return schema_hash unless ref @@ -84,18 +90,17 @@ def resolve_ref(schema_hash) return schema_hash unless resolved - @resolved_refs[ref] = T.cast(resolved, T::Hash[String, T.untyped]) # rubocop:disable Sorbet/ForbidTUntyped + @resolved_refs[ref] = T.cast(resolved, T::Hash[String, T.untyped]) resolved end sig do - params( - schema: T.untyped, # rubocop:disable Sorbet/ForbidTUntyped - parts: T::Array[String] - ).returns(T.nilable(T::Hash[T.any(Symbol, String), T.untyped])) # rubocop:disable Sorbet/ForbidTUntyped + params(schema: T.untyped, parts: T::Array[String]).returns( + T.nilable(T::Hash[T.any(Symbol, String), T.untyped]) + ) end def navigate_schema_path(schema, parts) - current = T.let(schema, T.untyped) # rubocop:disable Sorbet/ForbidTUntyped + current = T.let(schema, T.untyped) parts.each_with_index do |part, _index| return nil if current.nil? @@ -129,22 +134,31 @@ def navigate_schema_path(schema, parts) sig do params( - schema_hash: T::Hash[T.any(Symbol, String), T.untyped], # rubocop:disable Sorbet/ForbidTUntyped + schema_hash: T::Hash[T.any(Symbol, String), T.untyped], class_name: String, depth: Integer ).returns(String) end def generate_struct(schema_hash, class_name, depth = 0) - properties = T.let(schema_hash.fetch('properties', {}), T::Hash[String, T.untyped]) # rubocop:disable Sorbet/ForbidTUntyped + properties = T.let(schema_hash.fetch('properties', {}), T::Hash[String, T.untyped]) required = T.let(schema_hash.fetch('required', []), T::Array[String]) + previous_class = @current_class + @current_class = class_name + @dependencies[class_name] ||= Set.new + lines = [] lines << "class #{class_name} < T::Struct" properties.each do |prop_name, prop_schema| prop_type = sorbet_type(prop_name, prop_schema, depth) - prop_type = "T.nilable(#{prop_type})" unless required.include?(prop_name) || - prop_type == 'T.untyped' + + extract_class_dependencies(prop_type).each { |dep| add_dependency(dep) } + + unless required.include?(prop_name) || prop_type == 'T.untyped' || + prop_type.start_with?('T.nilable(') + prop_type = "T.nilable(#{prop_type})" + end comment = build_comment(prop_schema) lines << " #{comment}" if comment @@ -152,30 +166,71 @@ def generate_struct(schema_hash, class_name, depth = 0) end lines << 'end' + + @current_class = previous_class + lines.join("\n") end sig do params( prop_name: T.any(Symbol, String), - prop_schema: T::Hash[T.any(Symbol, String), T.untyped], # rubocop:disable Sorbet/ForbidTUntyped + prop_schema: T::Hash[T.any(Symbol, String), T.untyped], depth: Integer ).returns(String) end def sorbet_type(prop_name, prop_schema, depth) # rubocop:disable Metrics/CyclomaticComplexity resolved_schema = resolve_ref(prop_schema) - type = T.unsafe(resolved_schema['type'] || resolved_schema[:type]) # rubocop:disable Sorbet/ForbidTUnsafe - - if type.is_a?(Array) - non_null = type.reject { |t| t == 'null' } - ruby_types = - non_null - .map { |t| sorbet_type(prop_name, resolved_schema.merge('type' => t), depth) } - .uniq - return "T.any(#{ruby_types.join(', ')})" + + # Handle anyOf pattern for nullable types (e.g., from Zod's .nullable()) + any_of_value = resolved_schema['anyOf'] + if any_of_value.is_a?(Array) + non_null_schemas = any_of_value.select { |s| s.is_a?(Hash) && s['type'] != 'null' } + + if non_null_schemas.length == 1 && non_null_schemas.length < any_of_value.length + # It's a nullable type: anyOf with exactly one non-null type + first_schema = T.cast(non_null_schemas.first, T::Hash[T.any(Symbol, String), T.untyped]) + inner_type = sorbet_type(prop_name, first_schema, depth) + return "T.nilable(#{inner_type})" + elsif non_null_schemas.length > 1 + # Multiple non-null types in union + ruby_types = + non_null_schemas + .map do |schema| + sorbet_type( + prop_name, + T.cast(schema, T::Hash[T.any(Symbol, String), T.untyped]), + depth + ) + end + .uniq + base_type = "T.any(#{ruby_types.join(', ')})" + has_null = any_of_value.any? { |s| s.is_a?(Hash) && s['type'] == 'null' } + return has_null ? "T.nilable(#{base_type})" : base_type + end + end + + # Get the type field, which can be a string or array + type_value = resolved_schema['type'] || resolved_schema[:type] + + if type_value.is_a?(Array) + non_null = type_value.reject { |t| t == 'null' } + + if non_null.length == 1 && non_null.length < type_value.length + inner_type = + sorbet_type(prop_name, resolved_schema.merge('type' => non_null.first), depth) + return "T.nilable(#{inner_type})" + elsif non_null.length > 1 + ruby_types = + non_null + .map { |t| sorbet_type(prop_name, resolved_schema.merge('type' => t), depth) } + .uniq + base_type = "T.any(#{ruby_types.join(', ')})" + return non_null.length < type_value.length ? "T.nilable(#{base_type})" : base_type + end end - case type + case type_value when 'string' return 'Time' if resolved_schema['format'] == 'date-time' return 'String' unless resolved_schema.key?('enum') @@ -204,7 +259,7 @@ def sorbet_type(prop_name, prop_schema, depth) # rubocop:disable Metrics/Cycloma end "T::Array[T.any(#{tuple_types.join(', ')})]" else - items = T.cast(raw_items, T::Hash[T.any(Symbol, String), T.untyped]) # rubocop:disable Sorbet/ForbidTUntyped + items = T.cast(raw_items, T::Hash[T.any(Symbol, String), T.untyped]) "T::Array[#{sorbet_type(prop_name.to_s.singularize, items, depth + 1)}]" end when 'object' @@ -238,7 +293,7 @@ def generate_enum(class_name, values) lines.join("\n") end - sig { params(prop_schema: T::Hash[String, T.untyped]).returns(T.nilable(String)) } # rubocop:disable Sorbet/ForbidTUntyped + sig { params(prop_schema: T::Hash[String, T.untyped]).returns(T.nilable(String)) } def build_comment(prop_schema) keys_in_order = %w[ minLength @@ -269,5 +324,66 @@ def build_comment(prop_schema) "# #{entries.join(', ')}" end + + sig { params(type_string: String).returns(T::Set[String]) } + def extract_class_dependencies(type_string) + dependencies = Set.new + + type_string.scan(/(? { + 'type' => 'object', + 'properties' => { + 'name' => { + 'anyOf' => [{ 'type' => 'string' }, { 'type' => 'null' }] + } + }, + 'required' => ['name'] + } + }.to_json + + result = converter.convert(nullable_schema, class_name: 'NullableTest') + + expect(result).to include('const :name, T.nilable(String)') + expect(result).not_to include('T.untyped') + end + + it 'handles type arrays with null for nullable types' do + nullable_schema = { + 'json' => { + 'type' => 'object', + 'properties' => { + 'count' => { + 'type' => %w[integer null] + } + }, + 'required' => ['count'] + } + }.to_json + + result = converter.convert(nullable_schema, class_name: 'NullableTypeArray') + + expect(result).to include('const :count, T.nilable(Integer)') + end + + it 'handles nullable nested objects' do + nullable_nested_schema = { + 'json' => { + 'type' => 'object', + 'properties' => { + 'metadata' => { + 'anyOf' => [ + { + 'type' => 'object', + 'properties' => { + 'version' => { + 'type' => 'string' + } + }, + 'required' => ['version'] + }, + { 'type' => 'null' } + ] + } + }, + 'required' => ['metadata'] + } + }.to_json + + result = converter.convert(nullable_nested_schema, class_name: 'NullableNested') + + expect(result).to include('class Metadata < T::Struct') + expect(result).to include('const :version, String') + expect(result).to include('const :metadata, T.nilable(Metadata)') + expect(result).not_to include('T.untyped') + end + + it 'does not double-wrap nullable types' do + nullable_optional_schema = { + 'json' => { + 'type' => 'object', + 'properties' => { + 'optional_nullable' => { + 'anyOf' => [{ 'type' => 'string' }, { 'type' => 'null' }] + } + }, + 'required' => [] + } + }.to_json + + result = converter.convert(nullable_optional_schema, class_name: 'NoDoubleWrap') + + # Should have T.nilable once, not T.nilable(T.nilable(...)) + expect(result).to include('const :optional_nullable, T.nilable(String)') + expect(result).not_to match(/T\.nilable\(T\.nilable/) + end + end + + describe 'dependency sorting' do + it 'sorts nested structs by dependency order' do + dependency_schema = { + 'json' => { + 'type' => 'object', + 'properties' => { + 'report' => { + 'type' => 'object', + 'properties' => { + 'employee' => { + 'type' => 'object', + 'properties' => { + 'id' => { + 'type' => 'string' + }, + 'name' => { + 'type' => 'string' + } + }, + 'required' => %w[id name] + }, + 'meeting' => { + 'type' => 'object', + 'properties' => { + 'id' => { + 'type' => 'integer' + }, + 'employee_id' => { + 'type' => 'string' + } + }, + 'required' => %w[id employee_id] + } + }, + 'required' => %w[employee meeting] + } + }, + 'required' => ['report'] + } + }.to_json + + result = converter.convert(dependency_schema, class_name: 'DependencyTest') + + # Employee and Meeting should come before Report + employee_pos = result.index('class Employee < T::Struct') + meeting_pos = result.index('class Meeting < T::Struct') + report_pos = result.index('class Report < T::Struct') + input_pos = result.index('class DependencyTest < T::Struct') + + expect(employee_pos).to be < report_pos + expect(meeting_pos).to be < report_pos + expect(report_pos).to be < input_pos + end + + it 'sorts complex nested dependencies correctly' do + complex_dependency_schema = { + 'json' => { + 'type' => 'object', + 'properties' => { + 'direct_reports' => { + 'type' => 'array', + 'items' => { + 'type' => 'object', + 'properties' => { + 'employee_info' => { + 'type' => 'object', + 'properties' => { + 'id' => { + 'type' => 'string' + }, + 'name' => { + 'type' => 'string' + } + }, + 'required' => %w[id name] + }, + 'meetings' => { + 'type' => 'array', + 'items' => { + 'type' => 'object', + 'properties' => { + 'id' => { + 'type' => 'integer' + } + }, + 'required' => ['id'] + } + } + }, + 'required' => %w[employee_info meetings] + } + } + }, + 'required' => ['direct_reports'] + } + }.to_json + + result = converter.convert(complex_dependency_schema, class_name: 'ComplexDep') + + # EmployeeInfo and Meeting should come before DirectReport + employee_info_pos = result.index('class EmployeeInfo < T::Struct') + meeting_pos = result.index('class Meeting < T::Struct') + direct_report_pos = result.index('class DirectReport < T::Struct') + + expect(employee_info_pos).to be < direct_report_pos + expect(meeting_pos).to be < direct_report_pos + end + end end -end \ No newline at end of file +end