Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
hiranya911 committed Oct 13, 2017
2 parents e9de174 + 82c46f4 commit aed47cf
Show file tree
Hide file tree
Showing 14 changed files with 633 additions and 80 deletions.
7 changes: 4 additions & 3 deletions auth/auth.go
Expand Up @@ -25,6 +25,7 @@ import (
"strings"

"firebase.google.com/go/internal"
"golang.org/x/net/context"
)

const firebaseAudience = "https://identitytoolkit.googleapis.com/google.identity.identitytoolkit.v1.IdentityToolkit"
Expand Down Expand Up @@ -73,7 +74,7 @@ type signer interface {
//
// This function can only be invoked from within the SDK. Client applications should access the
// the Auth service through firebase.App.
func NewClient(c *internal.AuthConfig) (*Client, error) {
func NewClient(ctx context.Context, c *internal.AuthConfig) (*Client, error) {
var (
err error
email string
Expand All @@ -99,13 +100,13 @@ func NewClient(c *internal.AuthConfig) (*Client, error) {
if email != "" && pk != nil {
snr = serviceAcctSigner{email: email, pk: pk}
} else {
snr, err = newSigner(c.Ctx)
snr, err = newSigner(ctx)
if err != nil {
return nil, err
}
}

ks, err := newHTTPKeySource(c.Ctx, googleCertURL, c.Opts...)
ks, err := newHTTPKeySource(ctx, googleCertURL, c.Opts...)
if err != nil {
return nil, err
}
Expand Down
72 changes: 52 additions & 20 deletions auth/auth_test.go
Expand Up @@ -15,7 +15,9 @@
package auth

import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
Expand Down Expand Up @@ -58,17 +60,16 @@ func TestMain(m *testing.M) {
log.Fatalln(err)
}
} else {
opt := option.WithCredentialsFile("../testdata/service_account.json")
creds, err = transport.Creds(context.Background(), opt)
ctx = context.Background()
creds, err = transport.Creds(ctx, option.WithCredentialsFile("../testdata/service_account.json"))
if err != nil {
log.Fatalln(err)
}

ks = &fileKeySource{FilePath: "../testdata/public_certs.json"}
}

client, err = NewClient(&internal.AuthConfig{
Ctx: ctx,
client, err = NewClient(ctx, &internal.AuthConfig{
Creds: creds,
ProjectID: "mock-project-id",
})
Expand All @@ -81,6 +82,32 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}

func TestNewClientInvalidCredentials(t *testing.T) {
creds := &google.DefaultCredentials{
JSON: []byte("foo"),
}
conf := &internal.AuthConfig{Creds: creds}
if c, err := NewClient(context.Background(), conf); c != nil || err == nil {
t.Errorf("NewCient() = (%v,%v); want = (nil, error)", c, err)
}
}

func TestNewClientInvalidPrivateKey(t *testing.T) {
sa := map[string]interface{}{
"private_key": "foo",
"client_email": "bar@test.com",
}
b, err := json.Marshal(sa)
if err != nil {
t.Fatal(err)
}
creds := &google.DefaultCredentials{JSON: b}
conf := &internal.AuthConfig{Creds: creds}
if c, err := NewClient(context.Background(), conf); c != nil || err == nil {
t.Errorf("NewCient() = (%v,%v); want = (nil, error)", c, err)
}
}

func TestCustomToken(t *testing.T) {
token, err := client.CustomToken("user1")
if err != nil {
Expand Down Expand Up @@ -118,31 +145,32 @@ func TestCustomTokenError(t *testing.T) {
}{
{"EmptyName", "", nil},
{"LongUid", strings.Repeat("a", 129), nil},
{"ReservedClaims", "uid", map[string]interface{}{"sub": "1234"}},
{"ReservedClaim", "uid", map[string]interface{}{"sub": "1234"}},
{"ReservedClaims", "uid", map[string]interface{}{"sub": "1234", "aud": "foo"}},
}

for _, tc := range cases {
token, err := client.CustomTokenWithClaims(tc.uid, tc.claims)
if token != "" || err == nil {
t.Errorf("CustomTokenWithClaims(%q) = (%q, %v); want: (\"\", error)", tc.name, token, err)
t.Errorf("CustomTokenWithClaims(%q) = (%q, %v); want = (\"\", error)", tc.name, token, err)
}
}
}

func TestCustomTokenInvalidCredential(t *testing.T) {
s, err := NewClient(&internal.AuthConfig{Ctx: context.Background()})
s, err := NewClient(context.Background(), &internal.AuthConfig{})
if err != nil {
t.Fatal(err)
}

token, err := s.CustomToken("user1")
if token != "" || err == nil {
t.Errorf("CustomTokenWithClaims() = (%q, %v); want: (\"\", error)", token, err)
t.Errorf("CustomTokenWithClaims() = (%q, %v); want = (\"\", error)", token, err)
}

token, err = s.CustomTokenWithClaims("user1", map[string]interface{}{"foo": "bar"})
if token != "" || err == nil {
t.Errorf("CustomTokenWithClaims() = (%q, %v); want: (\"\", error)", token, err)
t.Errorf("CustomTokenWithClaims() = (%q, %v); want = (\"\", error)", token, err)
}
}

Expand All @@ -152,15 +180,23 @@ func TestVerifyIDToken(t *testing.T) {
t.Fatal(err)
}
if ft.Claims["admin"] != true {
t.Errorf("Claims['admin'] = %v; want: true", ft.Claims["admin"])
t.Errorf("Claims['admin'] = %v; want = true", ft.Claims["admin"])
}
if ft.UID != ft.Subject {
t.Errorf("UID = %q; Sub = %q; want UID = Sub", ft.UID, ft.Subject)
}
}

func TestVerifyIDTokenInvalidSignature(t *testing.T) {
parts := strings.Split(testIDToken, ".")
token := fmt.Sprintf("%s:%s:invalidsignature", parts[0], parts[1])
if ft, err := client.VerifyIDToken(token); ft != nil || err == nil {
t.Errorf("VerifyiDToken('invalid-signature') = (%v, %v); want = (nil, error)", ft, err)
}
}

func TestVerifyIDTokenError(t *testing.T) {
var now int64 = 1000
now := time.Now().Unix()
cases := []struct {
name string
token string
Expand All @@ -172,28 +208,24 @@ func TestVerifyIDTokenError(t *testing.T) {
{"EmptySubject", getIDToken(mockIDTokenPayload{"sub": ""})},
{"IntSubject", getIDToken(mockIDTokenPayload{"sub": 10})},
{"LongSubject", getIDToken(mockIDTokenPayload{"sub": strings.Repeat("a", 129)})},
{"FutureToken", getIDToken(mockIDTokenPayload{"iat": time.Unix(now+1, 0)})},
{"FutureToken", getIDToken(mockIDTokenPayload{"iat": now + 1000})},
{"ExpiredToken", getIDToken(mockIDTokenPayload{
"iat": time.Unix(now-10, 0),
"exp": time.Unix(now-1, 0),
"iat": now - 1000,
"exp": now - 100,
})},
{"EmptyToken", ""},
{"BadFormatToken", "foobar"},
}

clk = &mockClock{now: time.Unix(now, 0)}
defer func() {
clk = &systemClock{}
}()
for _, tc := range cases {
if _, err := client.VerifyIDToken(tc.token); err == nil {
t.Errorf("VerifyyIDToken(%q) = nil; want error", tc.name)
t.Errorf("VerifyIDToken(%q) = nil; want error", tc.name)
}
}
}

func TestNoProjectID(t *testing.T) {
c, err := NewClient(&internal.AuthConfig{Ctx: context.Background()})
c, err := NewClient(context.Background(), &internal.AuthConfig{})
if err != nil {
t.Fatal(err)
}
Expand Down
7 changes: 2 additions & 5 deletions auth/crypto.go
Expand Up @@ -148,13 +148,10 @@ func (k *httpKeySource) refreshKeys() error {

func findMaxAge(resp *http.Response) (*time.Duration, error) {
cc := resp.Header.Get("cache-control")
for _, value := range strings.Split(cc, ", ") {
for _, value := range strings.Split(cc, ",") {
value = strings.TrimSpace(value)
if strings.HasPrefix(value, "max-age") {
if strings.HasPrefix(value, "max-age=") {
sep := strings.Index(value, "=")
if sep == -1 {
return nil, errors.New("Malformed cache-control header")
}
seconds, err := strconv.ParseInt(value[sep+1:], 10, 64)
if err != nil {
return nil, err
Expand Down
117 changes: 117 additions & 0 deletions auth/crypto_test.go
Expand Up @@ -15,6 +15,7 @@
package auth

import (
"errors"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -122,6 +123,122 @@ func TestHTTPKeySourceWithClient(t *testing.T) {
}
}

func TestHTTPKeySourceEmptyResponse(t *testing.T) {
hc, _ := newHTTPClient([]byte(""))
ks, err := newHTTPKeySource(context.Background(), "http://mock.url", option.WithHTTPClient(hc))
if err != nil {
t.Fatal(err)
}

if keys, err := ks.Keys(); keys != nil || err == nil {
t.Errorf("Keys() = (%v, %v); want = (nil, error)", keys, err)
}
}

func TestHTTPKeySourceIncorrectResponse(t *testing.T) {
hc, _ := newHTTPClient([]byte("{\"foo\": 1}"))
ks, err := newHTTPKeySource(context.Background(), "http://mock.url", option.WithHTTPClient(hc))
if err != nil {
t.Fatal(err)
}

if keys, err := ks.Keys(); keys != nil || err == nil {
t.Errorf("Keys() = (%v, %v); want = (nil, error)", keys, err)
}
}

func TestHTTPKeySourceTransportError(t *testing.T) {
hc := &http.Client{
Transport: &mockHTTPResponse{
Err: errors.New("transport error"),
},
}
ks, err := newHTTPKeySource(context.Background(), "http://mock.url", option.WithHTTPClient(hc))
if err != nil {
t.Fatal(err)
}

if keys, err := ks.Keys(); keys != nil || err == nil {
t.Errorf("Keys() = (%v, %v); want = (nil, error)", keys, err)
}
}

func TestFindMaxAge(t *testing.T) {
cases := []struct {
cc string
want int64
}{
{"max-age=100", 100},
{"public, max-age=100", 100},
{"public,max-age=100", 100},
}
for _, tc := range cases {
resp := &http.Response{
Header: http.Header{"Cache-Control": {tc.cc}},
}
age, err := findMaxAge(resp)
if err != nil {
t.Errorf("findMaxAge(%q) = %v", tc.cc, err)
} else if *age != (time.Duration(tc.want) * time.Second) {
t.Errorf("findMaxAge(%q) = %v; want %v", tc.cc, *age, tc.want)
}
}
}

func TestFindMaxAgeError(t *testing.T) {
cases := []string{
"",
"max-age 100",
"max-age: 100",
"max-age2=100",
"max-age=foo",
}
for _, tc := range cases {
resp := &http.Response{
Header: http.Header{"Cache-Control": []string{tc}},
}
if age, err := findMaxAge(resp); age != nil || err == nil {
t.Errorf("findMaxAge(%q) = (%v, %v); want = (nil, err)", tc, age, err)
}
}
}

func TestParsePublicKeys(t *testing.T) {
b, err := ioutil.ReadFile("../testdata/public_certs.json")
if err != nil {
t.Fatal(err)
}
keys, err := parsePublicKeys(b)
if err != nil {
t.Fatal(err)
}
if len(keys) != 3 {
t.Errorf("parsePublicKeys() = %d; want: %d", len(keys), 3)
}
}

func TestParsePublicKeysError(t *testing.T) {
cases := []string{
"",
"not-json",
}
for _, tc := range cases {
if keys, err := parsePublicKeys([]byte(tc)); keys != nil || err == nil {
t.Errorf("parsePublicKeys(%q) = (%v, %v); want: (nil, err)", tc, keys, err)
}
}
}

func TestDefaultServiceAcctSigner(t *testing.T) {
signer := &serviceAcctSigner{}
if email, err := signer.Email(); email != "" || err == nil {
t.Errorf("Email() = (%v, %v); want = ('', error)", email, err)
}
if sig, err := signer.Sign([]byte("")); sig != nil || err == nil {
t.Errorf("Sign() = (%v, %v); want = ('', error)", sig, err)
}
}

func verifyHTTPKeySource(ks *httpKeySource, rc *mockReadCloser) error {
mc := &mockClock{now: time.Unix(0, 0)}
ks.Clock = mc
Expand Down
5 changes: 1 addition & 4 deletions auth/jwt.go
Expand Up @@ -80,10 +80,7 @@ func decode(s string, i interface{}) error {
if err != nil {
return err
}
if err := json.NewDecoder(bytes.NewBuffer(decoded)).Decode(i); err != nil {
return err
}
return nil
return json.NewDecoder(bytes.NewBuffer(decoded)).Decode(i)
}

func encodeToken(s signer, h jwtHeader, p jwtPayload) (string, error) {
Expand Down

0 comments on commit aed47cf

Please sign in to comment.