Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
IamTaoChen committed May 9, 2024
1 parent 77c6bca commit bd78f56
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 54 deletions.
24 changes: 3 additions & 21 deletions hscontrol/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,15 +350,15 @@ func extractIDTokenClaims(
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to decode id token claims"))
if werr != nil {
util.LogErr(err,"Failed to write response")
util.LogErr(err, "Failed to write response")
}
return nil, err
}

// Unmarshal the claims into a map
mappedClaims := make(map[string]interface{})
if err := json.Unmarshal(claims, &mappedClaims); err != nil {
util.LogErr(err,"Failed to unmarshal id token claims")
util.LogErr(err, "Failed to unmarshal id token claims")
return nil, err
}

Expand Down Expand Up @@ -388,24 +388,6 @@ func extractIDTokenClaims(
return &finalClaims, nil
}

// {
// var claims IDTokenClaims
// if err := idToken.Claims(&claims); err != nil {
// util.LogErr(err, "Failed to decode id token claims")

// writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
// writer.WriteHeader(http.StatusBadRequest)
// _, werr := writer.Write([]byte("Failed to decode id token claims"))
// if werr != nil {
// util.LogErr(err, "Failed to write response")
// }

// return nil, err
// }

// return &claims, nil
// }

// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
// that the authenticated principal ends with @<alloweddomain>.
func validateOIDCAllowedDomains(
Expand Down Expand Up @@ -589,7 +571,7 @@ func getUserName(
writer http.ResponseWriter,
claims *IDTokenClaims,
stripEmaildomain bool,
) (string, error) {
) (string, error) {
userName, err := util.NormalizeToFQDNRules(
claims.Username,
stripEmaildomain,
Expand Down
66 changes: 33 additions & 33 deletions hscontrol/types/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,34 +126,34 @@ type OIDCConfig struct {
ClientSecret string
Scope []string
ExtraParams map[string]string
ClaimsMap OIDCClaimsMap
Allowed OIDCAllowedConfig
ClaimsMap OIDCClaimsMap
Allowed OIDCAllowedConfig
Expiry OIDCExpireConfig
Misc OIDCMiscConfig
Misc OIDCMiscConfig
}

type OIDCExpireConfig struct {
FromToken bool
FixedTime time.Duration
FromToken bool
FixedTime time.Duration
}

type OIDCAllowedConfig struct {
Domains []string
Users []string
Groups []string
Domains []string
Users []string
Groups []string
}

type OIDCClaimsMap struct {
Name string
Username string
Email string
Groups string
Name string
Username string
Email string
Groups string
}

type OIDCMiscConfig struct {
StripEmaildomain bool
FlattenGroups bool
FlattenSplter string
StripEmaildomain bool
FlattenGroups bool
FlattenSplter string
}

type DERPConfig struct {
Expand Down Expand Up @@ -254,9 +254,9 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("oidc.claims_map.email", "email")
viper.SetDefault("oidc.claims_map.groups", "groups")
// misc
viper.SetDefault("oidc.strip_email_domain", false)
viper.SetDefault("oidc.flatten_groups", false)
viper.SetDefault("oidc.flatten_splitter", "/")
viper.SetDefault("oidc.misc.strip_email_domain", false)
viper.SetDefault("oidc.misc.flatten_groups", false)
viper.SetDefault("oidc.misc.flatten_splitter", "/")

viper.SetDefault("logtail.enabled", false)
viper.SetDefault("randomize_client_port", false)
Expand Down Expand Up @@ -695,9 +695,9 @@ func GetOIDCConfig() (OIDCConfig, error) {
}
// get misc config
oidcMiscConfig := OIDCMiscConfig{
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
FlattenGroups: viper.GetBool("oidc.flatten_groups"),
FlattenSplter: viper.GetString("oidc.flatten_splitter"),
StripEmaildomain: viper.GetBool("oidc.misc.strip_email_domain"),
FlattenGroups: viper.GetBool("oidc.misc.flatten_groups"),
FlattenSplter: viper.GetString("oidc.misc.flatten_splitter"),
}
// get client secret
oidcClientSecret := viper.GetString("oidc.client_secret")
Expand All @@ -716,15 +716,15 @@ func GetOIDCConfig() (OIDCConfig, error) {
OnlyStartIfOIDCIsAvailable: viper.GetBool(
"oidc.only_start_if_oidc_is_available",
),
Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"),
ClientSecret: oidcClientSecret,
Scope: viper.GetStringSlice("oidc.scope"),
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
Allowed: oidcAllowed,
ClaimsMap: oidcClaimsMap,
Expiry: oidcExpireConfig,
Misc: oidcMiscConfig,
Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"),
ClientSecret: oidcClientSecret,
Scope: viper.GetStringSlice("oidc.scope"),
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
Allowed: oidcAllowed,
ClaimsMap: oidcClaimsMap,
Expiry: oidcExpireConfig,
Misc: oidcMiscConfig,
}
return OIDC, nil
}
Expand Down Expand Up @@ -810,9 +810,9 @@ func GetHeadscaleConfig() (*Config, error) {

UnixSocket: viper.GetString("unix_socket"),
UnixSocketPermission: util.GetFileMode("unix_socket_permission"),
OIDC: oidcConfig,
LogTail: logConfig,
RandomizeClientPort: randomizeClientPort,
OIDC: oidcConfig,
LogTail: logConfig,
RandomizeClientPort: randomizeClientPort,

ACL: GetACLConfig(),

Expand Down

0 comments on commit bd78f56

Please sign in to comment.