Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support AuthnRequest in SAML #372

Merged
merged 1 commit into from
Dec 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions controllers/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,11 +364,12 @@ func (c *ApiController) Login() {

func (c *ApiController) GetSamlLogin() {
providerId := c.Input().Get("id")
authURL, err := object.GenerateSamlLoginUrl(providerId)
relayState := c.Input().Get("relayState")
authURL, method, err := object.GenerateSamlLoginUrl(providerId, relayState)
if err != nil {
c.ResponseError(err.Error())
}
c.ResponseOk(authURL)
c.ResponseOk(authURL, method)
}

func (c *ApiController) HandleSamlLogin() {
Expand Down
7 changes: 4 additions & 3 deletions object/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ type Provider struct {
Domain string `xorm:"varchar(100)" json:"domain"`
Bucket string `xorm:"varchar(100)" json:"bucket"`

Metadata string `xorm:"mediumtext" json:"metadata"`
IdP string `xorm:"mediumtext" json:"idP"`
IssuerUrl string `xorm:"varchar(100)" json:"issuerUrl"`
Metadata string `xorm:"mediumtext" json:"metadata"`
IdP string `xorm:"mediumtext" json:"idP"`
IssuerUrl string `xorm:"varchar(100)" json:"issuerUrl"`
EnableSignAuthnRequest bool `json:"enableSignAuthnRequest"`

ProviderUrl string `xorm:"varchar(200)" json:"providerUrl"`
}
Expand Down
51 changes: 39 additions & 12 deletions object/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package object

import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
Expand All @@ -40,20 +41,32 @@ func ParseSamlResponse(samlResponse string, providerType string) (string, error)
return assertionInfo.NameID, nil
}

func GenerateSamlLoginUrl(id string) (string, error) {
func GenerateSamlLoginUrl(id, relayState string) (string, string, error) {
provider := GetProvider(id)
if provider.Category != "SAML" {
return "", fmt.Errorf("Provider %s's category is not SAML", provider.Name)
return "", "", fmt.Errorf("Provider %s's category is not SAML", provider.Name)
}
sp, err := buildSp(provider, "")
if err != nil {
return "", err
return "", "", err
}
authURL, err := sp.BuildAuthURL("")
if err != nil {
return "", err
auth := ""
method := ""
if provider.EnableSignAuthnRequest {
post, err := sp.BuildAuthBodyPost(relayState)
if err != nil {
return "", "", err
}
auth = string(post[:])
method = "POST"
} else {
auth, err = sp.BuildAuthURL(relayState)
if err != nil {
return "", "", err
}
method = "GET"
}
return authURL, nil
return auth, method, nil
}

func buildSp(provider *Provider, samlResponse string) (*saml2.SAMLServiceProvider, error) {
Expand All @@ -80,13 +93,16 @@ func buildSp(provider *Provider, samlResponse string) (*saml2.SAMLServiceProvide
ServiceProviderIssuer: fmt.Sprintf("%s/api/acs", origin),
AssertionConsumerServiceURL: fmt.Sprintf("%s/api/acs", origin),
IDPCertificateStore: &certStore,
SignAuthnRequests: false,
SPKeyStore: dsig.RandomKeyStoreForTest(),
}
if provider.Endpoint != "" {
randomKeyStore := dsig.RandomKeyStoreForTest()
sp.IdentityProviderSSOURL = provider.Endpoint
sp.IdentityProviderIssuer = provider.IssuerUrl
sp.SignAuthnRequests = false
sp.SPKeyStore = randomKeyStore
}
if provider.EnableSignAuthnRequest {
sp.SignAuthnRequests = true
sp.SPKeyStore = buildSpKeyStore()
}
return sp, nil
}
Expand All @@ -99,10 +115,21 @@ func parseSamlResponse(samlResponse string, providerType string) string {
deStr := strings.Replace(string(de), "\n", "", -1)
tagMap := map[string]string{
"Aliyun IDaaS": "ds",
"Keycloak": "dsig",
"Keycloak": "dsig",
}
tag := tagMap[providerType]
expression := fmt.Sprintf("<%s:X509Certificate>([\\s\\S]*?)</%s:X509Certificate>", tag, tag)
res := regexp.MustCompile(expression).FindStringSubmatch(deStr)
return res[1]
}
}

func buildSpKeyStore() dsig.X509KeyStore {
keyPair, err := tls.LoadX509KeyPair("object/token_jwt_key.pem", "object/token_jwt_key.key")
if err != nil {
panic(err)
}
return &dsig.TLSCertKeyStore {
PrivateKey: keyPair.PrivateKey,
Certificate: keyPair.Certificate,
}
}
12 changes: 11 additions & 1 deletion web/src/ProviderEditPage.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

import React from "react";
import {Button, Card, Col, Input, InputNumber, Row, Select} from 'antd';
import {Button, Card, Col, Input, InputNumber, Row, Select, Switch} from 'antd';
import {LinkOutlined} from "@ant-design/icons";
import * as ProviderBackend from "./backend/ProviderBackend";
import * as Setting from "./Setting";
Expand Down Expand Up @@ -418,6 +418,16 @@ class ProviderEditPage extends React.Component {
</React.Fragment>
) : this.state.provider.category === "SAML" ? (
<React.Fragment>
<Row style={{marginTop: '20px'}} >
<Col style={{marginTop: '5px'}} span={(Setting.isMobile()) ? 22 : 2}>
{Setting.getLabel(i18next.t("provider:Sign request"), i18next.t("provider:Sign request - Tooltip"))} :
</Col>
<Col span={22} >
<Switch checked={this.state.provider.enableSignAuthnRequest} onChange={checked => {
this.updateProviderField('enableSignAuthnRequest', checked);
}} />
</Col>
</Row>
<Row style={{marginTop: '20px'}} >
<Col style={{marginTop: '5px'}} span={(Setting.isMobile()) ? 22 : 2}>
{Setting.getLabel(i18next.t("provider:Metadata"), i18next.t("provider:Metadata - Tooltip"))} :
Expand Down
4 changes: 2 additions & 2 deletions web/src/auth/AuthBackend.js
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ export function unlink(values) {
}).then(res => res.json());
}

export function getSamlLogin(providerId) {
return fetch(`${authConfig.serverUrl}/api/get-saml-login?id=${providerId}`, {
export function getSamlLogin(providerId, relayState) {
return fetch(`${authConfig.serverUrl}/api/get-saml-login?id=${providerId}&relayState=${relayState}`, {
method: 'GET',
credentials: 'include',
}).then(res => res.json());
Expand Down
10 changes: 7 additions & 3 deletions web/src/auth/LoginPage.js
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,13 @@ class LoginPage extends React.Component {
let realRedirectUri = params.get("redirect_uri");
let redirectUri = `${window.location.origin}/callback/saml`;
let providerName = provider.name;
AuthBackend.getSamlLogin(`${provider.owner}/${providerName}`).then((res) => {
const replyState = `${clientId}&${application}&${providerName}&${realRedirectUri}&${redirectUri}`;
window.location.href = `${res.data}&RelayState=${btoa(replyState)}`;
let relayState = `${clientId}&${application}&${providerName}&${realRedirectUri}&${redirectUri}`;
AuthBackend.getSamlLogin(`${provider.owner}/${providerName}`, btoa(relayState)).then((res) => {
if (res.data2 === "POST") {
document.write(res.data)
} else {
window.location.href = res.data
}
});
}

Expand Down