Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OIDC claim names options #1594

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Use error group handling to ensure tests actually pass [#1535](https://github.co
Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480)
Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524)
Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563)
Add `oidc.groups_claim` and `oidc.email_claim` to allow setting those claim names [#1594](https://github.com/juanfont/headscale/pull/1594)

## 0.22.3 (2023-05-12)

Expand Down
5 changes: 5 additions & 0 deletions docs/oidc.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ oidc:
# It resolves environment variables, making integration to systemd's
# `LoadCredential` straightforward:
#client_secret_path: "${CREDENTIALS_DIRECTORY}/oidc_client_secret"
# If provided, the name of a custom OIDC claim for specifying user groups.
# The claim value is expected to be a string or array of strings.
groups_claim: groups
# The OIDC claim to use as the email.
email_claim: email

# Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query
# parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email".
Expand Down
86 changes: 72 additions & 14 deletions hscontrol/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/rand"
_ "embed"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"html/template"
Expand Down Expand Up @@ -40,14 +41,45 @@ var (
errOIDCInvalidNodeState = errors.New(
"requested node state key expired before authorisation completed",
)
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
errOIDCEmailClaimMissing = errors.New("email claim missing from ID Token")
)

type IDTokenClaims struct {
Name string `json:"name,omitempty"`
Groups []string `json:"groups,omitempty"`
Email string `json:"email"`
Username string `json:"preferred_username,omitempty"`
// in some cases the groups might be a single value and not a list
Groups stringOrArray
Email string
}

type stringOrArray []string

func (s *stringOrArray) UnmarshalJSON(b []byte) error {
var a []string
if err := json.Unmarshal(b, &a); err == nil {
*s = a
return nil
}
var str string
if err := json.Unmarshal(b, &str); err != nil {
return err
}
*s = []string{str}
return nil
}

type rawClaims map[string]json.RawMessage

func (c rawClaims) unmarshalClaim(name string, v interface{}) error {
val, ok := c[name]
if !ok {
return fmt.Errorf("claim not present")
}
return json.Unmarshal([]byte(val), v)
}

func (c rawClaims) hasClaim(name string) bool {
_, ok := c[name]
return ok
}

func (h *Headscale) initOIDC() error {
Expand Down Expand Up @@ -215,7 +247,7 @@ func (h *Headscale) OIDCCallback(
// return
// }

claims, err := extractIDTokenClaims(writer, idToken)
claims, err := extractIDTokenClaims(writer, h.cfg.OIDC, idToken)
if err != nil {
return
}
Expand Down Expand Up @@ -355,25 +387,51 @@ func (h *Headscale) verifyIDTokenForOIDCCallback(

func extractIDTokenClaims(
writer http.ResponseWriter,
cfg types.OIDCConfig,
idToken *oidc.IDToken,
) (*IDTokenClaims, error) {
var claims IDTokenClaims
if err := idToken.Claims(&claims); err != nil {
util.LogErr(err, "Failed to decode id token claims")
var rawClaims rawClaims
if err := idToken.Claims(&rawClaims); err != nil {
handleClaimError(writer, err)

writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to decode id token claims"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
return nil, err
}

if !rawClaims.hasClaim(cfg.EmailClaim) {
handleClaimError(writer, errOIDCEmailClaimMissing)

return nil, errOIDCEmailClaimMissing
}

if err := rawClaims.unmarshalClaim(cfg.EmailClaim, &claims.Email); err != nil {
handleClaimError(writer, err)

return nil, err
}

if rawClaims.hasClaim(cfg.GroupsClaim) {
if err := rawClaims.unmarshalClaim(cfg.GroupsClaim, &claims.Groups); err != nil {
handleClaimError(writer, err)

return nil, err
}
}

return &claims, nil
}

func handleClaimError(writer http.ResponseWriter, err error) {
util.LogErr(err, "Failed to decode id token rawClaims")

writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to decode id token rawClaims"))
if werr != nil {
util.LogErr(err, "Failed to write response")
}
}

// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
// that the authenticated principal ends with @<alloweddomain>.
func validateOIDCAllowedDomains(
Expand Down
164 changes: 164 additions & 0 deletions hscontrol/oidc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package hscontrol

import (
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/go-jose/go-jose/v3"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"net/http/httptest"
"reflect"
"testing"
)

func Test_extractIDTokenClaims(t *testing.T) {
tests := []verificationTest{
{
name: "default claim names",
idToken: `{"iss":"https://foo", "email": "foo@bar.baz", "groups": ["group1", "group2"]}`,
cfg: types.OIDCConfig{
EmailClaim: "email",
GroupsClaim: "groups",
},
want: &IDTokenClaims{
Groups: []string{"group1", "group2"},
Email: "foo@bar.baz",
},
wantErr: false,
},
{
name: "custom claim names",
idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz", "https://foo.baz/groups": ["group3", "group4"]}`,
cfg: types.OIDCConfig{
EmailClaim: "my_custom_claim",
GroupsClaim: "https://foo.baz/groups",
},
want: &IDTokenClaims{
Groups: []string{"group3", "group4"},
Email: "foo@bar.baz",
},
wantErr: false,
},
{
name: "group claim not present",
idToken: `{"iss":"https://foo", "my_custom_claim": "foo@bar.baz"}`,
cfg: types.OIDCConfig{
EmailClaim: "my_custom_claim",
GroupsClaim: "https://foo.baz/groups",
},
want: &IDTokenClaims{
Email: "foo@bar.baz",
},
wantErr: false,
},
{
name: "email claim not present",
idToken: `{"iss":"https://foo", "groups": ["group1", "group2"]}`,
cfg: types.OIDCConfig{
EmailClaim: "email",
GroupsClaim: "groups",
},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
token, err := tt.getToken(t)
if err != nil {
t.Errorf("could not parse the token: %v", err)

return
}

if !tt.wantErr {
assert.Equal(t, 200, recorder.Result().StatusCode)
assert.Empty(t, recorder.Result().Header)
}

got, err := extractIDTokenClaims(recorder, tt.cfg, token)
if (err != nil) != tt.wantErr {
t.Errorf("extractIDTokenClaims() error = %v, wantErr %v", err, tt.wantErr)

return
}

if !reflect.DeepEqual(got, tt.want) {
t.Errorf("extractIDTokenClaims() got = %v, want %v", got, tt.want)

return
}
})
}
}

type signingKey struct {
keyID string
key interface{}
pub interface{}
alg jose.SignatureAlgorithm
}

// sign creates a JWS using the private key from the provided payload.
func (s *signingKey) sign(t testing.TB, payload []byte) string {
privKey := &jose.JSONWebKey{Key: s.key, Algorithm: string(s.alg), KeyID: s.keyID}

signer, err := jose.NewSigner(jose.SigningKey{Algorithm: s.alg, Key: privKey}, nil)
if err != nil {
t.Fatal(err)
}
jws, err := signer.Sign(payload)
if err != nil {
t.Fatal(err)
}

data, err := jws.CompactSerialize()
if err != nil {
t.Fatal(err)
}

return data
}

type verificationTest struct {
name string
idToken string
cfg types.OIDCConfig
want *IDTokenClaims
wantErr bool
}

func newRSAKey(t testing.TB) *signingKey {
priv, err := rsa.GenerateKey(rand.Reader, 1028)
if err != nil {
t.Fatal(err)
}

return &signingKey{"", priv, priv.Public(), jose.RS256}
}

func (v verificationTest) getToken(t *testing.T) (*oidc.IDToken, error) {
key := newRSAKey(t)
token := key.sign(t, []byte(v.idToken))

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

verifier := oidc.NewVerifier(
"https://foo",
&oidc.StaticKeySet{PublicKeys: []crypto.PublicKey{key.pub}},
&oidc.Config{
SkipClientIDCheck: true,
SkipExpiryCheck: true,
SkipIssuerCheck: true,
InsecureSkipSignatureCheck: true,
},
)

return verifier.Verify(ctx, token)
}
6 changes: 6 additions & 0 deletions hscontrol/types/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ type OIDCConfig struct {
AllowedDomains []string
AllowedUsers []string
AllowedGroups []string
GroupsClaim string
EmailClaim string
StripEmaildomain bool
Expiry time.Duration
UseExpiryFromToken bool
Expand Down Expand Up @@ -185,6 +187,8 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("oidc.strip_email_domain", true)
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
viper.SetDefault("oidc.expiry", "180d")
viper.SetDefault("oidc.groups_claim", "groups")
viper.SetDefault("oidc.email_claim", "email")
viper.SetDefault("oidc.use_expiry_from_token", false)

viper.SetDefault("logtail.enabled", false)
Expand Down Expand Up @@ -631,6 +635,8 @@ func GetHeadscaleConfig() (*Config, error) {
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
GroupsClaim: viper.GetString("oidc.groups_claim"),
EmailClaim: viper.GetString("oidc.email_claim"),
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
Expiry: func() time.Duration {
// if set to 0, we assume no expiry
Expand Down