29 changes: 28 additions & 1 deletion lib/ruby_llm/provider.rb
@@ -21,6 +21,10 @@ def headers
{}
end

def parameter_mappings
{}
end

def slug
self.class.slug
end
@@ -39,6 +43,7 @@ def configuration_requirements

def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, schema: nil, &) # rubocop:disable Metrics/ParameterLists
normalized_temperature = maybe_normalize_temperature(temperature, model)
transformed_params = apply_parameter_mappings(params)

payload = Utils.deep_merge(
render_payload(
@@ -49,7 +54,7 @@ def complete(messages, tools:, temperature:, model:, params: {}, headers: {}, sc
stream: block_given?,
schema: schema
),
params
transformed_params
)

if block_given?
@@ -192,6 +197,28 @@ def configured_remote_providers(config)

private

def apply_parameter_mappings(params)
return params if parameter_mappings.empty?

transformed = params.dup

parameter_mappings.each do |source_key, target_path|
next unless transformed.key?(source_key)

value = transformed.delete(source_key)
*keys, last_key = target_path

target = keys.inject(transformed) do |hash, key|
hash[key] = {} unless hash[key].is_a?(Hash)
hash[key]
end

target[last_key] = value
end

transformed
end

def try_parse_json(maybe_json)
return maybe_json unless maybe_json.is_a?(String)

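For reference, a quick sketch of the transformation this enables, using the Gemini mapping added below (values are illustrative; the merge behavior matches the new specs). Params passed as

  { max_tokens: 500, generationConfig: { temperature: 0.7 } }

come out of apply_parameter_mappings as

  { generationConfig: { temperature: 0.7, maxOutputTokens: 500 } }

with the top-level max_tokens key removed before the payload merge in #complete.
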
8 changes: 8 additions & 0 deletions lib/ruby_llm/providers/gemini.rb
@@ -22,6 +22,14 @@ def headers
}
end

private

def parameter_mappings
{
max_tokens: %i[generationConfig maxOutputTokens]
}
end

class << self
def capabilities
Gemini::Capabilities

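In practice (a minimal usage sketch; the model name is illustrative and not part of this diff), callers keep passing the generic max_tokens option and the Gemini provider renames it in the request payload:

  chat = RubyLLM
         .chat(model: 'gemini-2.0-flash', provider: :gemini)
         .with_params(max_tokens: 100) # sent as generationConfig.maxOutputTokens
  chat.ask('Say hello in 3 words.')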

19 changes: 19 additions & 0 deletions spec/ruby_llm/chat_request_options_spec.rb
@@ -106,5 +106,24 @@
expect(json_response).to eq({ 'result' => 8 })
end
end

# Provider [:gemini] automatically maps max_tokens to generationConfig.maxOutputTokens
CHAT_MODELS.select { |model_info| model_info[:provider] == :gemini }.each do |model_info|
model = model_info[:model]
provider = model_info[:provider]
it "#{provider}/#{model} automatically maps max_tokens to maxOutputTokens" do
chat = RubyLLM
.chat(model: model, provider: provider)
.with_params(max_tokens: 100)

response = chat.ask('Say hello in 3 words.')

request_body = JSON.parse(response.raw.env.request_body)
expect(request_body.dig('generationConfig', 'maxOutputTokens')).to eq(100)
expect(request_body).not_to have_key('max_tokens')

expect(response.content).to be_present
end
end
end
end
45 changes: 45 additions & 0 deletions spec/ruby_llm/providers/gemini/chat_spec.rb
@@ -233,4 +233,49 @@
# Verify our implementation correctly sums both token types
expect(response.output_tokens).to eq(candidates_tokens + thoughts_tokens)
end

describe 'parameter mapping' do
let(:provider) do
config = RubyLLM::Configuration.new
config.gemini_api_key = 'test_key'
RubyLLM::Providers::Gemini.new(config)
end

it 'maps max_tokens to generationConfig.maxOutputTokens' do
params = { max_tokens: 1000 }
result = provider.send(:apply_parameter_mappings, params)

expect(result).to eq({ generationConfig: { maxOutputTokens: 1000 } })
end

it 'removes max_tokens from the params after mapping' do
params = { max_tokens: 500 }
result = provider.send(:apply_parameter_mappings, params)

expect(result).not_to have_key(:max_tokens)
end

it 'preserves other params while mapping max_tokens' do
params = { max_tokens: 1000, other_param: 'value' }
result = provider.send(:apply_parameter_mappings, params)

expect(result[:other_param]).to eq('value')
expect(result.dig(:generationConfig, :maxOutputTokens)).to eq(1000)
end

it 'merges with existing generationConfig hash' do
params = { max_tokens: 500, generationConfig: { temperature: 0.7 } }
result = provider.send(:apply_parameter_mappings, params)

expect(result.dig(:generationConfig, :temperature)).to eq(0.7)
expect(result.dig(:generationConfig, :maxOutputTokens)).to eq(500)
end

it 'handles params without max_tokens' do
params = { other: 'value', custom: 123 }
result = provider.send(:apply_parameter_mappings, params)

expect(result).to eq({ other: 'value', custom: 123 })
end
end
end