Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 41 additions & 42 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ package auth

import (
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"strings"
Expand All @@ -31,10 +29,12 @@ import (
"google.golang.org/api/transport"
)

const firebaseAudience = "https://identitytoolkit.googleapis.com/google.identity.identitytoolkit.v1.IdentityToolkit"
const googleCertURL = "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com"
const issuerPrefix = "https://securetoken.google.com/"
const tokenExpSeconds = 3600
const (
firebaseAudience = "https://identitytoolkit.googleapis.com/google.identity.identitytoolkit.v1.IdentityToolkit"
idTokenCertURL = "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com"
issuerPrefix = "https://securetoken.google.com/"
tokenExpSeconds = 3600
)

var reservedClaims = []string{
"acr", "amr", "at_hash", "aud", "auth_time", "azp", "cnf", "c_hash",
Expand All @@ -58,6 +58,24 @@ type Token struct {
Claims map[string]interface{} `json:"-"`
}

func (t *Token) decodeFrom(s string) error {
// Decode into a regular map to access custom claims.
claims := make(map[string]interface{})
if err := decode(s, &claims); err != nil {
return err
}
// Now decode into Token to access the standard claims.
if err := decode(s, t); err != nil {
return err
}

for _, r := range []string{"iss", "aud", "exp", "iat", "sub", "uid"} {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe another comment here like "Delete the standard claims from the custom claims map."

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

delete(claims, r)
}
t.Claims = claims
return nil
}

// Client is the interface for the Firebase auth service.
//
// Client facilitates generating custom JWT tokens for Firebase clients, and verifying ID tokens issued
Expand Down Expand Up @@ -94,7 +112,7 @@ func NewClient(ctx context.Context, c *internal.AuthConfig) (*Client, error) {
return nil, err
}
if svcAcct.PrivateKey != "" {
pk, err = parseKey(svcAcct.PrivateKey)
pk, err = parsePrivateKey(svcAcct.PrivateKey)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -124,7 +142,7 @@ func NewClient(ctx context.Context, c *internal.AuthConfig) (*Client, error) {

return &Client{
is: is,
ks: newHTTPKeySource(googleCertURL, hc),
ks: newHTTPKeySource(idTokenCertURL, hc),
projectID: c.ProjectID,
snr: snr,
version: "Go/Admin/" + c.Version,
Expand Down Expand Up @@ -164,6 +182,7 @@ func (c *Client) CustomTokenWithClaims(ctx context.Context, uid string, devClaim
}

now := clk.Now().Unix()
header := jwtHeader{Algorithm: "RS256", Type: "JWT"}
payload := &customToken{
Iss: iss,
Sub: iss,
Expand All @@ -173,7 +192,7 @@ func (c *Client) CustomTokenWithClaims(ctx context.Context, uid string, devClaim
Exp: now + tokenExpSeconds,
Claims: devClaims,
}
return encodeToken(ctx, c.snr, defaultHeader(), payload)
return encodeToken(ctx, c.snr, header, payload)
}

// RevokeRefreshTokens revokes all refresh tokens issued to a user.
Expand Down Expand Up @@ -201,7 +220,7 @@ func (c *Client) VerifyIDToken(ctx context.Context, idToken string) (*Token, err
return nil, errors.New("project id not available")
}
if idToken == "" {
return nil, fmt.Errorf("ID token must be a non-empty string")
return nil, fmt.Errorf("id token must be a non-empty string")
}

h := &jwtHeader{}
Expand All @@ -210,36 +229,36 @@ func (c *Client) VerifyIDToken(ctx context.Context, idToken string) (*Token, err
return nil, err
}

projectIDMsg := "Make sure the ID token comes from the same Firebase project as the credential used to" +
" authenticate this SDK."
verifyTokenMsg := "See https://firebase.google.com/docs/auth/admin/verify-id-tokens for details on how to " +
"retrieve a valid ID token."
projectIDMsg := "make sure the ID token comes from the same Firebase project as the credential used to" +
" authenticate this SDK"
verifyTokenMsg := "see https://firebase.google.com/docs/auth/admin/verify-id-tokens for details on how to " +
"retrieve a valid ID token"
issuer := issuerPrefix + c.projectID

var err error
if h.KeyID == "" {
if p.Audience == firebaseAudience {
err = fmt.Errorf("VerifyIDToken() expects an ID token, but was given a custom token")
err = fmt.Errorf("expected an ID token but got a custom token")
} else {
err = fmt.Errorf("ID token has no 'kid' header")
}
} else if h.Algorithm != "RS256" {
err = fmt.Errorf("ID token has invalid incorrect algorithm. Expected 'RS256' but got %q. %s",
err = fmt.Errorf("ID token has invalid algorithm; expected 'RS256' but got %q; %s",
h.Algorithm, verifyTokenMsg)
} else if p.Audience != c.projectID {
err = fmt.Errorf("ID token has invalid 'aud' (audience) claim. Expected %q but got %q. %s %s",
err = fmt.Errorf("ID token has invalid 'aud' (audience) claim; expected %q but got %q; %s; %s",
c.projectID, p.Audience, projectIDMsg, verifyTokenMsg)
} else if p.Issuer != issuer {
err = fmt.Errorf("ID token has invalid 'iss' (issuer) claim. Expected %q but got %q. %s %s",
err = fmt.Errorf("ID token has invalid 'iss' (issuer) claim; expected %q but got %q; %s; %s",
issuer, p.Issuer, projectIDMsg, verifyTokenMsg)
} else if p.IssuedAt > clk.Now().Unix() {
err = fmt.Errorf("ID token issued at future timestamp: %d", p.IssuedAt)
} else if p.Expires < clk.Now().Unix() {
err = fmt.Errorf("ID token has expired. Expired at: %d", p.Expires)
err = fmt.Errorf("ID token has expired at: %d", p.Expires)
} else if p.Subject == "" {
err = fmt.Errorf("ID token has empty 'sub' (subject) claim. %s", verifyTokenMsg)
err = fmt.Errorf("ID token has empty 'sub' (subject) claim; %s", verifyTokenMsg)
} else if len(p.Subject) > 128 {
err = fmt.Errorf("ID token has a 'sub' (subject) claim longer than 128 characters. %s", verifyTokenMsg)
err = fmt.Errorf("ID token has a 'sub' (subject) claim longer than 128 characters; %s", verifyTokenMsg)
}

if err != nil {
Expand All @@ -265,27 +284,7 @@ func (c *Client) VerifyIDTokenAndCheckRevoked(ctx context.Context, idToken strin
}

if p.IssuedAt*1000 < user.TokensValidAfterMillis {
return nil, internal.Error(idTokenRevoked, "ID token has been revoked")
return nil, internal.Error(idTokenRevoked, "id token has been revoked")
}
return p, nil
}

func parseKey(key string) (*rsa.PrivateKey, error) {
block, _ := pem.Decode([]byte(key))
if block == nil {
return nil, fmt.Errorf("no private key data found in: %v", key)
}
k := block.Bytes
parsedKey, err := x509.ParsePKCS8PrivateKey(k)
if err != nil {
parsedKey, err = x509.ParsePKCS1PrivateKey(k)
if err != nil {
return nil, fmt.Errorf("private key should be a PEM or plain PKSC1 or PKCS8; parse error: %v", err)
}
}
parsed, ok := parsedKey.(*rsa.PrivateKey)
if !ok {
return nil, errors.New("private key is not an RSA key")
}
return parsed, nil
}
19 changes: 11 additions & 8 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ import (
"firebase.google.com/go/internal"
)

var client *Client
var ctx context.Context
var testIDToken string
var testGetUserResponse []byte
var testListUsersResponse []byte
var (
client *Client
ctx context.Context
testIDToken string
testGetUserResponse []byte
testListUsersResponse []byte
)

var defaultTestOpts = []option.ClientOption{
option.WithCredentialsFile("../testdata/service_account.json"),
}
Expand Down Expand Up @@ -232,7 +235,7 @@ func TestVerifyIDTokenAndCheckRevokedInvalidated(t *testing.T) {
tok := getIDToken(mockIDTokenPayload{"uid": "uid", "iat": 1970}) // old token

p, err := s.Client.VerifyIDTokenAndCheckRevoked(ctx, tok)
we := "ID token has been revoked"
we := "id token has been revoked"
if p != nil || err == nil || err.Error() != we || !IsIDTokenRevoked(err) {
t.Errorf("VerifyIDTokenAndCheckRevoked(ctx, token) =(%v, %v); want = (%v, %v)",
p, err, nil, we)
Expand Down Expand Up @@ -371,7 +374,7 @@ func getIDTokenWithKid(kid string, p mockIDTokenPayload) string {
for k, v := range p {
pCopy[k] = v
}
h := defaultHeader()
h := jwtHeader{Algorithm: "RS256", Type: "JWT"}
h.KeyID = kid
token, err := encodeToken(ctx, client.snr, h, pCopy)
if err != nil {
Expand All @@ -382,7 +385,7 @@ func getIDTokenWithKid(kid string, p mockIDTokenPayload) string {

type mockIDTokenPayload map[string]interface{}

func (p mockIDTokenPayload) decode(s string) error {
func (p mockIDTokenPayload) decodeFrom(s string) error {
return decode(s, &p)
}

Expand Down
21 changes: 21 additions & 0 deletions auth/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"
"net/http"
"strconv"
Expand Down Expand Up @@ -185,6 +186,26 @@ func parsePublicKey(kid string, key []byte) (*publicKey, error) {
return &publicKey{kid, pk}, nil
}

func parsePrivateKey(key string) (*rsa.PrivateKey, error) {
block, _ := pem.Decode([]byte(key))
if block == nil {
return nil, fmt.Errorf("no private key data found in: %v", key)
}
k := block.Bytes
parsedKey, err := x509.ParsePKCS8PrivateKey(k)
if err != nil {
parsedKey, err = x509.ParsePKCS1PrivateKey(k)
if err != nil {
return nil, fmt.Errorf("private key should be a PEM or plain PKSC1 or PKCS8; parse error: %v", err)
}
}
parsed, ok := parsedKey.(*rsa.PrivateKey)
if !ok {
return nil, errors.New("private key is not an RSA key")
}
return parsed, nil
}

func verifySignature(parts []string, k *publicKey) error {
content := parts[0] + "." + parts[1]
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
Expand Down
42 changes: 11 additions & 31 deletions auth/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type jwtHeader struct {
}

type jwtPayload interface {
decode(s string) error
decodeFrom(s string) error
}

type customToken struct {
Expand All @@ -45,38 +45,11 @@ type customToken struct {
Claims map[string]interface{} `json:"claims,omitempty"`
}

func (p *customToken) decode(s string) error {
func (p *customToken) decodeFrom(s string) error {
return decode(s, p)
}

func (t *Token) decode(s string) error {
claims := make(map[string]interface{})
if err := decode(s, &claims); err != nil {
return err
}
if err := decode(s, t); err != nil {
return err
}

for _, r := range []string{"iss", "aud", "exp", "iat", "sub", "uid"} {
delete(claims, r)
}
t.Claims = claims
return nil
}

func defaultHeader() jwtHeader {
return jwtHeader{Algorithm: "RS256", Type: "JWT"}
}

func encode(i interface{}) (string, error) {
b, err := json.Marshal(i)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}

// decode accepts a JWT segment, and decodes it into the given interface.
func decode(s string, i interface{}) error {
decoded, err := base64.RawURLEncoding.DecodeString(s)
if err != nil {
Expand All @@ -86,6 +59,13 @@ func decode(s string, i interface{}) error {
}

func encodeToken(ctx context.Context, s signer, h jwtHeader, p jwtPayload) (string, error) {
encode := func(i interface{}) (string, error) {
b, err := json.Marshal(i)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
header, err := encode(h)
if err != nil {
return "", err
Expand All @@ -112,7 +92,7 @@ func decodeToken(ctx context.Context, token string, ks keySource, h *jwtHeader,
if err := decode(s[0], h); err != nil {
return err
}
if err := p.decode(s[1]); err != nil {
if err := p.decodeFrom(s[1]); err != nil {
return err
}

Expand Down
6 changes: 3 additions & 3 deletions auth/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
)

func TestEncodeToken(t *testing.T) {
h := defaultHeader()
h := jwtHeader{Algorithm: "RS256", Type: "JWT"}
p := mockIDTokenPayload{"key": "value"}
s, err := encodeToken(ctx, &mockSigner{}, h, p)
if err != nil {
Expand Down Expand Up @@ -57,7 +57,7 @@ func TestEncodeToken(t *testing.T) {
}

func TestEncodeSignError(t *testing.T) {
h := defaultHeader()
h := jwtHeader{Algorithm: "RS256", Type: "JWT"}
p := mockIDTokenPayload{"key": "value"}
signer := &mockSigner{
err: errors.New("sign error"),
Expand All @@ -68,7 +68,7 @@ func TestEncodeSignError(t *testing.T) {
}

func TestEncodeInvalidPayload(t *testing.T) {
h := defaultHeader()
h := jwtHeader{Algorithm: "RS256", Type: "JWT"}
p := mockIDTokenPayload{"key": func() {}}
if s, err := encodeToken(ctx, &mockSigner{}, h, p); s != "" || err == nil {
t.Errorf("encodeToken() = (%v, %v); want = ('', error)", s, err)
Expand Down