Skip to content

Commit

Permalink
KEYCLOAK-1881 Include key ID in <ds:KeyInfo> in SAML assertions and p…
Browse files Browse the repository at this point in the history
…rotocol message

Changes of SAML assertion creation/parsing that are required to allow
for validation of rotating realm key: signed SAML assertions and signed
SAML protocol message now contain signing key ID in XML <dsig:KeyName>
element.
  • Loading branch information
hmlnarik committed Nov 4, 2016
1 parent 904a5c3 commit 5d84050
Show file tree
Hide file tree
Showing 10 changed files with 67 additions and 51 deletions.
Expand Up @@ -79,7 +79,7 @@ public static BaseSAML2BindingBuilder createSaml2Binding(SamlDeployment deployme
binding.canonicalizationMethod(deployment.getSignatureCanonicalizationMethod());
}

binding.signWith(keypair);
binding.signWith(null, keypair);
binding.signDocument();
}
return binding;
Expand Down
Expand Up @@ -82,7 +82,7 @@ protected AuthOutcome logoutRequest(LogoutRequestType request, String relayState
if (deployment.getSignatureCanonicalizationMethod() != null)
binding.canonicalizationMethod(deployment.getSignatureCanonicalizationMethod());
binding.signatureAlgorithm(deployment.getSignatureAlgorithm())
.signWith(deployment.getSigningKeyPair())
.signWith(null, deployment.getSigningKeyPair())
.signDocument();
}

Expand Down Expand Up @@ -113,7 +113,7 @@ private AuthOutcome globalLogout() {
if (deployment.getSignatureCanonicalizationMethod() != null)
binding.canonicalizationMethod(deployment.getSignatureCanonicalizationMethod());
binding.signatureAlgorithm(deployment.getSignatureAlgorithm());
binding.signWith(deployment.getSigningKeyPair())
binding.signWith(null, deployment.getSigningKeyPair())
.signDocument();
}

Expand Down
Expand Up @@ -55,6 +55,7 @@
public class BaseSAML2BindingBuilder<T extends BaseSAML2BindingBuilder> {
protected static final Logger logger = Logger.getLogger(BaseSAML2BindingBuilder.class);

protected String signingKeyId;
protected KeyPair signingKeyPair;
protected X509Certificate signingCertificate;
protected boolean sign;
Expand Down Expand Up @@ -82,23 +83,27 @@ public T signAssertions() {
return (T)this;
}

public T signWith(KeyPair keyPair) {
public T signWith(String signingKeyId, KeyPair keyPair) {
this.signingKeyId = signingKeyId;
this.signingKeyPair = keyPair;
return (T)this;
}

public T signWith(PrivateKey privateKey, PublicKey publicKey) {
public T signWith(String signingKeyId, PrivateKey privateKey, PublicKey publicKey) {
this.signingKeyId = signingKeyId;
this.signingKeyPair = new KeyPair(publicKey, privateKey);
return (T)this;
}

public T signWith(KeyPair keyPair, X509Certificate cert) {
public T signWith(String signingKeyId, KeyPair keyPair, X509Certificate cert) {
this.signingKeyId = signingKeyId;
this.signingKeyPair = keyPair;
this.signingCertificate = cert;
return (T)this;
}

public T signWith(PrivateKey privateKey, PublicKey publicKey, X509Certificate cert) {
public T signWith(String signingKeyId, PrivateKey privateKey, PublicKey publicKey, X509Certificate cert) {
this.signingKeyId = signingKeyId;
this.signingKeyPair = new KeyPair(publicKey, privateKey);
this.signingCertificate = cert;
return (T)this;
Expand Down Expand Up @@ -263,7 +268,7 @@ public void signDocument(Document samlDocument) throws ProcessingException {
samlSignature.setX509Certificate(signingCertificate);
}

samlSignature.signSAMLDocument(samlDocument, signingKeyPair, canonicalizationMethodType);
samlSignature.signSAMLDocument(samlDocument, signingKeyId, signingKeyPair, canonicalizationMethodType);
}

public void signAssertion(Document samlDocument) throws ProcessingException {
Expand Down
Expand Up @@ -121,7 +121,7 @@ public void setX509Certificate(X509Certificate x509Certificate) {
* @throws MarshalException
* @throws GeneralSecurityException
*/
public Document sign(Document doc, String referenceID, KeyPair keyPair, String canonicalizationMethodType) throws ParserConfigurationException,
public Document sign(Document doc, String referenceID, String keyId, KeyPair keyPair, String canonicalizationMethodType) throws ParserConfigurationException,
GeneralSecurityException, MarshalException, XMLSignatureException {
String referenceURI = "#" + referenceID;

Expand All @@ -130,6 +130,7 @@ public Document sign(Document doc, String referenceID, KeyPair keyPair, String c
if (sibling != null) {
SignatureUtilTransferObject dto = new SignatureUtilTransferObject();
dto.setDocumentToBeSigned(doc);
dto.setKeyId(keyId);
dto.setKeyPair(keyPair);
dto.setDigestMethod(digestMethod);
dto.setSignatureMethod(signatureMethod);
Expand All @@ -142,7 +143,7 @@ public Document sign(Document doc, String referenceID, KeyPair keyPair, String c

return XMLSignatureUtil.sign(dto, canonicalizationMethodType);
}
return XMLSignatureUtil.sign(doc, keyPair, digestMethod, signatureMethod, referenceURI, canonicalizationMethodType);
return XMLSignatureUtil.sign(doc, keyId, keyPair, digestMethod, signatureMethod, referenceURI, canonicalizationMethodType);
}

/**
Expand All @@ -153,11 +154,11 @@ public Document sign(Document doc, String referenceID, KeyPair keyPair, String c
*
* @throws org.keycloak.saml.common.exceptions.ProcessingException
*/
public void signSAMLDocument(Document samlDocument, KeyPair keypair, String canonicalizationMethodType) throws ProcessingException {
public void signSAMLDocument(Document samlDocument, String keyId, KeyPair keypair, String canonicalizationMethodType) throws ProcessingException {
// Get the ID from the root
String id = samlDocument.getDocumentElement().getAttribute(ID_ATTRIBUTE_NAME);
try {
sign(samlDocument, id, keypair, canonicalizationMethodType);
sign(samlDocument, id, keyId, keypair, canonicalizationMethodType);
} catch (Exception e) {
throw new ProcessingException(logger.signatureError(e));
}
Expand Down
Expand Up @@ -32,6 +32,9 @@ public class SignatureUtilTransferObject {
private X509Certificate x509Certificate;

private Document documentToBeSigned;

private String keyId;

private KeyPair keyPair;

private Node nextSibling;
Expand Down Expand Up @@ -111,4 +114,12 @@ public X509Certificate getX509Certificate() {
public void setX509Certificate(X509Certificate x509Certificate) {
this.x509Certificate = x509Certificate;
}

public String getKeyId() {
return keyId;
}

public void setKeyId(String keyId) {
this.keyId = keyId;
}
}
Expand Up @@ -79,7 +79,9 @@
import java.security.interfaces.RSAPublicKey;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import javax.xml.crypto.dsig.keyinfo.KeyName;

/**
* Utility for XML Signature <b>Note:</b> You can change the canonicalization method type by using the system property
Expand Down Expand Up @@ -157,7 +159,7 @@ public static void setIncludeKeyInfoInSignature(boolean includeKeyInfoInSignatur
* @throws MarshalException
* @throws GeneralSecurityException
*/
public static Document sign(Document doc, Node nodeToBeSigned, KeyPair keyPair, String digestMethod,
public static Document sign(Document doc, Node nodeToBeSigned, String keyId, KeyPair keyPair, String digestMethod,
String signatureMethod, String referenceURI, X509Certificate x509Certificate,
String canonicalizationMethodType) throws ParserConfigurationException, GeneralSecurityException,
MarshalException, XMLSignatureException {
Expand All @@ -179,7 +181,7 @@ public static Document sign(Document doc, Node nodeToBeSigned, KeyPair keyPair,
if (!referenceURI.isEmpty()) {
propagateIDAttributeSetup(nodeToBeSigned, newDoc.getDocumentElement());
}
newDoc = sign(newDoc, keyPair, digestMethod, signatureMethod, referenceURI, x509Certificate, canonicalizationMethodType);
newDoc = sign(newDoc, keyId, keyPair, digestMethod, signatureMethod, referenceURI, x509Certificate, canonicalizationMethodType);

// if the signed element is a SAMLv2.0 assertion we need to move the signature element to the position
// specified in the schema (before the assertion subject element).
Expand Down Expand Up @@ -220,10 +222,10 @@ public static Document sign(Document doc, Node nodeToBeSigned, KeyPair keyPair,
* @throws MarshalException
* @throws XMLSignatureException
*/
public static void sign(Element elementToSign, Node nextSibling, KeyPair keyPair, String digestMethod,
public static void sign(Element elementToSign, Node nextSibling, String keyId, KeyPair keyPair, String digestMethod,
String signatureMethod, String referenceURI, String canonicalizationMethodType)
throws GeneralSecurityException, MarshalException, XMLSignatureException {
sign(elementToSign, nextSibling, keyPair, digestMethod, signatureMethod, referenceURI, null, canonicalizationMethodType);
sign(elementToSign, nextSibling, keyId, keyPair, digestMethod, signatureMethod, referenceURI, null, canonicalizationMethodType);
}

/**
Expand All @@ -242,15 +244,15 @@ public static void sign(Element elementToSign, Node nextSibling, KeyPair keyPair
* @throws XMLSignatureException
* @since 2.5.0
*/
public static void sign(Element elementToSign, Node nextSibling, KeyPair keyPair, String digestMethod,
public static void sign(Element elementToSign, Node nextSibling, String keyId, KeyPair keyPair, String digestMethod,
String signatureMethod, String referenceURI, X509Certificate x509Certificate, String canonicalizationMethodType)
throws GeneralSecurityException, MarshalException, XMLSignatureException {
PrivateKey signingKey = keyPair.getPrivate();
PublicKey publicKey = keyPair.getPublic();

DOMSignContext dsc = new DOMSignContext(signingKey, elementToSign, nextSibling);

signImpl(dsc, digestMethod, signatureMethod, referenceURI, publicKey, x509Certificate, canonicalizationMethodType);
signImpl(dsc, digestMethod, signatureMethod, referenceURI, keyId, publicKey, x509Certificate, canonicalizationMethodType);
}

/**
Expand Down Expand Up @@ -284,9 +286,9 @@ public static void propagateIDAttributeSetup(Node sourceNode, Element destElemen
* @throws XMLSignatureException
* @throws MarshalException
*/
public static Document sign(Document doc, KeyPair keyPair, String digestMethod, String signatureMethod, String referenceURI, String canonicalizationMethodType)
public static Document sign(Document doc, String keyId, KeyPair keyPair, String digestMethod, String signatureMethod, String referenceURI, String canonicalizationMethodType)
throws GeneralSecurityException, MarshalException, XMLSignatureException {
return sign(doc, keyPair, digestMethod, signatureMethod, referenceURI, null, canonicalizationMethodType);
return sign(doc, keyId, keyPair, digestMethod, signatureMethod, referenceURI, null, canonicalizationMethodType);
}

/**
Expand All @@ -304,7 +306,7 @@ public static Document sign(Document doc, KeyPair keyPair, String digestMethod,
* @throws MarshalException
* @since 2.5.0
*/
public static Document sign(Document doc, KeyPair keyPair, String digestMethod, String signatureMethod, String referenceURI,
public static Document sign(Document doc, String keyId, KeyPair keyPair, String digestMethod, String signatureMethod, String referenceURI,
X509Certificate x509Certificate, String canonicalizationMethodType)
throws GeneralSecurityException, MarshalException, XMLSignatureException {
logger.trace("Document to be signed=" + DocumentUtil.asString(doc));
Expand All @@ -313,7 +315,7 @@ public static Document sign(Document doc, KeyPair keyPair, String digestMethod,

DOMSignContext dsc = new DOMSignContext(signingKey, doc.getDocumentElement());

signImpl(dsc, digestMethod, signatureMethod, referenceURI, publicKey, x509Certificate, canonicalizationMethodType);
signImpl(dsc, digestMethod, signatureMethod, referenceURI, keyId, publicKey, x509Certificate, canonicalizationMethodType);

return doc;
}
Expand Down Expand Up @@ -344,7 +346,7 @@ public static Document sign(SignatureUtilTransferObject dto, String canonicaliza

DOMSignContext dsc = new DOMSignContext(signingKey, doc.getDocumentElement(), nextSibling);

signImpl(dsc, digestMethod, signatureMethod, referenceURI, publicKey, dto.getX509Certificate(), canonicalizationMethodType);
signImpl(dsc, digestMethod, signatureMethod, referenceURI, dto.getKeyId(), publicKey, dto.getX509Certificate(), canonicalizationMethodType);

return doc;
}
Expand Down Expand Up @@ -594,7 +596,7 @@ public static KeyValueType createKeyValue(PublicKey key) {
throw logger.unsupportedType(key.toString());
}

private static void signImpl(DOMSignContext dsc, String digestMethod, String signatureMethod, String referenceURI, PublicKey publicKey,
private static void signImpl(DOMSignContext dsc, String digestMethod, String signatureMethod, String referenceURI, String keyId, PublicKey publicKey,
X509Certificate x509Certificate, String canonicalizationMethodType)
throws GeneralSecurityException, MarshalException, XMLSignatureException {
dsc.setDefaultNamespacePrefix("dsig");
Expand All @@ -618,35 +620,32 @@ private static void signImpl(DOMSignContext dsc, String digestMethod, String sig

KeyInfo ki = null;
if (includeKeyInfoInSignature) {
ki = createKeyInfo(publicKey, x509Certificate);
ki = createKeyInfo(keyId, publicKey, x509Certificate);
} else {
ki = createKeyInfo(keyId, null, null);
}
XMLSignature signature = fac.newXMLSignature(si, ki);

signature.sign(dsc);
}

private static KeyInfo createKeyInfo(PublicKey publicKey, X509Certificate x509Certificate) throws KeyException {
private static KeyInfo createKeyInfo(String keyId, PublicKey publicKey, X509Certificate x509Certificate) throws KeyException {
KeyInfoFactory keyInfoFactory = fac.getKeyInfoFactory();
KeyInfo keyInfo = null;
KeyValue keyValue = null;
//Just with public key
if (publicKey != null) {
keyValue = keyInfoFactory.newKeyValue(publicKey);
keyInfo = keyInfoFactory.newKeyInfo(Collections.singletonList(keyValue));

List<Object> items = new LinkedList<>();

if (keyId != null) {
items.add(keyInfoFactory.newKeyName(keyId));
}
if (x509Certificate != null) {
List x509list = new ArrayList();

x509list.add(x509Certificate);
X509Data x509Data = keyInfoFactory.newX509Data(x509list);
List items = new ArrayList();
if (x509Certificate != null) {
items.add(keyInfoFactory.newX509Data(Collections.singletonList(x509Certificate)));
}

items.add(x509Data);
if (keyValue != null) {
items.add(keyValue);
}
keyInfo = keyInfoFactory.newKeyInfo(items);
if (publicKey != null) {
items.add(keyInfoFactory.newKeyValue(publicKey));
}
return keyInfo;

return keyInfoFactory.newKeyInfo(items);
}
}
Expand Up @@ -267,7 +267,7 @@ protected Response logoutRequest(LogoutRequestType request, String relayState) {
.relayState(relayState);
if (config.isWantAuthnRequestsSigned()) {
KeyManager.ActiveKey keys = session.keys().getActiveKey(realm);
binding.signWith(keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate())
binding.signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate())
.signatureAlgorithm(provider.getSignatureAlgorithm())
.signDocument();
}
Expand Down
Expand Up @@ -103,7 +103,7 @@ public Response performLogin(AuthenticationRequest request) {

KeyPair keypair = new KeyPair(keys.getPublicKey(), keys.getPrivateKey());

binding.signWith(keypair);
binding.signWith(keys.getKid(), keypair);
binding.signatureAlgorithm(getSignatureAlgorithm());
binding.signDocument();
}
Expand Down Expand Up @@ -198,7 +198,7 @@ private JaxrsSAML2BindingBuilder buildLogoutBinding(KeycloakSession session, Use
.relayState(userSession.getId());
if (getConfig().isWantAuthnRequestsSigned()) {
KeyManager.ActiveKey keys = session.keys().getActiveKey(realm);
binding.signWith(keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate())
binding.signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate())
.signatureAlgorithm(getSignatureAlgorithm())
.signDocument();
}
Expand Down
Expand Up @@ -402,14 +402,14 @@ public Response authenticated(UserSessionModel userSession, ClientSessionCode ac
if (canonicalization != null) {
bindingBuilder.canonicalizationMethod(canonicalization);
}
bindingBuilder.signatureAlgorithm(samlClient.getSignatureAlgorithm()).signWith(keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
bindingBuilder.signatureAlgorithm(samlClient.getSignatureAlgorithm()).signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
}
if (samlClient.requiresAssertionSignature()) {
String canonicalization = samlClient.getCanonicalizationMethod();
if (canonicalization != null) {
bindingBuilder.canonicalizationMethod(canonicalization);
}
bindingBuilder.signatureAlgorithm(samlClient.getSignatureAlgorithm()).signWith(keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signAssertions();
bindingBuilder.signatureAlgorithm(samlClient.getSignatureAlgorithm()).signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signAssertions();
}
if (samlClient.requiresEncryption()) {
PublicKey publicKey = null;
Expand Down Expand Up @@ -541,7 +541,7 @@ public Response finishLogout(UserSessionModel userSession) {
binding.canonicalizationMethod(canonicalization);
}
KeyManager.ActiveKey keys = session.keys().getActiveKey(realm);
binding.signatureAlgorithm(algorithm).signWith(keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
binding.signatureAlgorithm(algorithm).signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
}

try {
Expand Down Expand Up @@ -639,7 +639,7 @@ private JaxrsSAML2BindingBuilder createBindingBuilder(SamlClient samlClient) {
JaxrsSAML2BindingBuilder binding = new JaxrsSAML2BindingBuilder();
if (samlClient.requiresRealmSignature()) {
KeyManager.ActiveKey keys = session.keys().getActiveKey(realm);
binding.signatureAlgorithm(samlClient.getSignatureAlgorithm()).signWith(keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
binding.signatureAlgorithm(samlClient.getSignatureAlgorithm()).signWith(keys.getKid(), keys.getPrivateKey(), keys.getPublicKey(), keys.getCertificate()).signDocument();
}
return binding;
}
Expand Down

0 comments on commit 5d84050

Please sign in to comment.