diff --git a/lib/saml/response.rb b/lib/saml/response.rb index f81d0da..6d0c4fc 100644 --- a/lib/saml/response.rb +++ b/lib/saml/response.rb @@ -28,10 +28,14 @@ def unknown_principal? !success? && status.status_code.unknown_principal? end - def encrypt_assertions(certificate, include_certificate: false) + def encrypt_assertions(key_descriptor_or_certificate, include_certificate: false, include_key_retrieval_method: false) @encrypted_assertions = [] assertions.each do |assertion| - @encrypted_assertions << Saml::Util.encrypt_assertion(assertion, certificate, include_certificate: include_certificate) + @encrypted_assertions << Saml::Util.encrypt_assertion( + assertion, key_descriptor_or_certificate, + include_certificate: include_certificate, + include_key_retrieval_method: include_key_retrieval_method + ) end assertions.clear end diff --git a/lib/saml/util.rb b/lib/saml/util.rb index 180f8ed..e2bad26 100644 --- a/lib/saml/util.rb +++ b/lib/saml/util.rb @@ -60,7 +60,7 @@ def sign_xml(message, format = :xml, include_nested_prefixlist = false, &block) end end - def encrypt_assertion(assertion, key_descriptor_or_certificate, include_certificate: false) + def encrypt_assertion(assertion, key_descriptor_or_certificate, include_certificate: false, include_key_retrieval_method: false) case key_descriptor_or_certificate when OpenSSL::X509::Certificate certificate = key_descriptor_or_certificate @@ -87,6 +87,11 @@ def encrypt_assertion(assertion, key_descriptor_or_certificate, include_certific end encrypted_key.encrypt(certificate.public_key) + if include_key_retrieval_method + encrypted_key.id = '_' + SecureRandom.uuid + encrypted_data.set_key_retrieval_method (Xmlenc::Builder::RetrievalMethod.new(uri: "##{encrypted_key.id}")) + end + Saml::Elements::EncryptedAssertion.new(encrypted_data: encrypted_data, encrypted_keys: encrypted_key) end diff --git a/spec/lib/saml/util_spec.rb b/spec/lib/saml/util_spec.rb index 092d6d8..b364088 100644 --- a/spec/lib/saml/util_spec.rb +++ b/spec/lib/saml/util_spec.rb @@ -371,6 +371,17 @@ def initialize expect(encrypted_assertion.encrypted_keys.key_info.x509Data.x509certificate.to_pem).to eq service_provider.certificate.to_pem end end + + context 'with include_key_retrieval_method option' do + let(:encrypted_assertion) do + Saml::Util.encrypt_assertion(Saml::Assertion.new, service_provider.certificate, include_key_retrieval_method: true) + end + + it 'add key_retrieval_method' do + expect(encrypted_assertion.encrypted_keys.id).not_to be_nil + expect(encrypted_assertion.encrypted_data.key_info.retrieval_method.uri).to eq "##{encrypted_assertion.encrypted_keys.id}" + end + end end context 'with a key descriptor as param' do @@ -393,6 +404,17 @@ def initialize expect(encrypted_assertion.encrypted_keys.key_info.x509Data.x509certificate.to_pem).to eq service_provider.certificate.to_pem end end + + context 'with include_key_retrieval_method option' do + let(:encrypted_assertion) do + Saml::Util.encrypt_assertion(Saml::Assertion.new, key_descriptor, include_key_retrieval_method: true) + end + + it 'add key_retrieval_method' do + expect(encrypted_assertion.encrypted_keys.id).not_to be_nil + expect(encrypted_assertion.encrypted_data.key_info.retrieval_method.uri).to eq "##{encrypted_assertion.encrypted_keys.id}" + end + end end context 'with a wrong param' do