diff --git a/cmd/server/assets/home.html b/cmd/server/assets/home.html index 0633dde6c..e55955aa4 100644 --- a/cmd/server/assets/home.html +++ b/cmd/server/assets/home.html @@ -23,7 +23,7 @@ - + {{template "navbar" .}}
diff --git a/cmd/server/assets/login/login.html b/cmd/server/assets/login/login.html index ddcdd0157..b64760b2c 100644 --- a/cmd/server/assets/login/login.html +++ b/cmd/server/assets/login/login.html @@ -7,7 +7,7 @@ {{template "firebase" .}} - + {{if .currentUser}} {{template "navbar" .}} {{end}} diff --git a/cmd/server/assets/login/register-phone.html b/cmd/server/assets/login/register-phone.html index 91e83f8bb..9e8dcb1f9 100644 --- a/cmd/server/assets/login/register-phone.html +++ b/cmd/server/assets/login/register-phone.html @@ -10,7 +10,7 @@ {{template "firebase" .}} - + {{template "navbar" .}}
{{template "flash" .}} diff --git a/go.mod b/go.mod index 5aca192cf..10df5372a 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,8 @@ require ( github.com/Azure/azure-sdk-for-go v46.4.0+incompatible // indirect github.com/Azure/go-autorest/autorest v0.11.8 // indirect github.com/aws/aws-sdk-go v1.35.3 // indirect + github.com/chromedp/cdproto v0.0.0-20201009231348-1c6a710e77de + github.com/chromedp/chromedp v0.5.3 github.com/client9/misspell v0.3.4 github.com/containerd/continuity v0.0.0-20200928162600-f2cc35102c2a // indirect github.com/dgrijalva/jwt-go v3.2.0+incompatible @@ -22,12 +24,13 @@ require ( github.com/gonum/internal v0.0.0-20181124074243-f884aa714029 // indirect github.com/gonum/lapack v0.0.0-20181123203213-e4cdc5a0bff9 // indirect github.com/gonum/matrix v0.0.0-20181209220409-c518dec07be9 - github.com/google/exposure-notifications-server v0.14.0 + github.com/google/exposure-notifications-server v0.14.1-0.20201029142042-d22c576d1701 github.com/google/go-cmp v0.5.2 github.com/gorilla/csrf v1.7.0 github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.0 github.com/gorilla/schema v1.2.0 + github.com/gorilla/securecookie v1.1.1 github.com/gorilla/sessions v1.2.1 github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.0 diff --git a/go.sum b/go.sum index 37b4a4df1..69ddcd38c 100644 --- a/go.sum +++ b/go.sum @@ -242,6 +242,13 @@ github.com/chris-ramon/douceur v0.2.0 h1:IDMEdxlEUUBYBKE4z/mJnFyVXox+MjuEVDJNN27 github.com/chris-ramon/douceur v0.2.0/go.mod h1:wDW5xjJdeoMm1mRt4sD4c/LbF/mWdEpRXQKjTR8nIBE= github.com/chrismalek/oktasdk-go v0.0.0-20181212195951-3430665dfaa0 h1:CWU8piLyqoi9qXEUwzOh5KFKGgmSU5ZhktJyYcq6ryQ= github.com/chrismalek/oktasdk-go v0.0.0-20181212195951-3430665dfaa0/go.mod h1:5d8DqS60xkj9k3aXfL3+mXBH0DPYO0FQjcKosxl+b/Q= +github.com/chromedp/cdproto v0.0.0-20200116234248-4da64dd111ac/go.mod h1:PfAWWKJqjlGFYJEidUM6aVIWPr0EpobeyVWEEmplX7g= +github.com/chromedp/cdproto v0.0.0-20201009231348-1c6a710e77de h1:cuPPanKjAp5XBwrD1RkeN4ILGRSffUhS69LKkFqKtIA= +github.com/chromedp/cdproto v0.0.0-20201009231348-1c6a710e77de/go.mod h1:zx0YH7hi8sqkYXAa0LZZxpQLDsU8/a2jzbYbK79dQO8= +github.com/chromedp/chromedp v0.5.3 h1:F9LafxmYpsQhWQBdCs+6Sret1zzeeFyHS5LkRF//Ffg= +github.com/chromedp/chromedp v0.5.3/go.mod h1:YLdPtndaHQ4rCpSpBG+IPpy9JvX0VD+7aaLxYgYj28w= +github.com/chromedp/sysutil v0.0.0-20201009230539-dc95e7e83e8a h1:31c/rx2f48S4oFimjMnIJNEutSwrWoASeUiGzPV5joA= +github.com/chromedp/sysutil v0.0.0-20201009230539-dc95e7e83e8a/go.mod h1:kgWmDdq8fTzXYcKIBqIYvRRTnYb9aNS9moAV0xufSww= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -411,6 +418,12 @@ github.com/go-test/deep v1.0.6/go.mod h1:QV8Hv/iy04NyLBxAdO9njL0iVPN1S4d/A3NVv1V github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o= github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0= github.com/gobuffalo/here v0.6.0/go.mod h1:wAG085dHOYqUpf+Ap+WOdrPTp5IYcDAs/x7PLa8Y5fM= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= +github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= +github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gocql/gocql v0.0.0-20190301043612-f6df8288f9b4/go.mod h1:4Fw1eo5iaEhDUs8XyuhSVCVy52Jq3L+/3GJgYkwc+/0= github.com/gocql/gocql v0.0.0-20190402132108-0e1d5de854df h1:fwXmhM0OqixzJDOGgTSyNH9eEDij9uGTXwsyWXvyR0A= github.com/gocql/gocql v0.0.0-20190402132108-0e1d5de854df/go.mod h1:4Fw1eo5iaEhDUs8XyuhSVCVy52Jq3L+/3GJgYkwc+/0= @@ -475,8 +488,8 @@ github.com/gonum/matrix v0.0.0-20181209220409-c518dec07be9/go.mod h1:0EXg4mc1CNP github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/exposure-notifications-server v0.14.0 h1:p/wwaKswPvlz4wWLYwWJQ56j3Vm/PznRt06NIPnzC/I= -github.com/google/exposure-notifications-server v0.14.0/go.mod h1:oyS7traveoREo37z0irHi0zN304YjD9esDZ4eL3Jtqo= +github.com/google/exposure-notifications-server v0.14.1-0.20201029142042-d22c576d1701 h1:kuyJFaSRGgveKzgH4xwld3j2TZfu8wFHK4uhZmazY1c= +github.com/google/exposure-notifications-server v0.14.1-0.20201029142042-d22c576d1701/go.mod h1:oyS7traveoREo37z0irHi0zN304YjD9esDZ4eL3Jtqo= github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -847,6 +860,7 @@ github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQL github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= +github.com/knq/sysutil v0.0.0-20191005231841-15668db23d08/go.mod h1:dFWs1zEqDjFtnBXsd1vPOZaLsESovai349994nHx3e0= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -874,6 +888,9 @@ github.com/lstoll/awskms v0.0.0-20200603175638-a388516467f1/go.mod h1:HysB/5CMc0 github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.0/go.mod h1:KAzv3t3aY1NaHWoQz1+4F1ccyAH66Jk7yos7ldAVICs= +github.com/mailru/easyjson v0.7.1 h1:mdxE1MF9o53iCb2Ghj1VfWvh7ZOwHpnVG/xwXrV90U8= +github.com/mailru/easyjson v0.7.1/go.mod h1:KAzv3t3aY1NaHWoQz1+4F1ccyAH66Jk7yos7ldAVICs= github.com/markbates/pkger v0.15.1/go.mod h1:0JoVlrol20BSywW79rN3kdFFsE5xYM+rSCQDXbLhiuI= github.com/martini-contrib/render v0.0.0-20150707142108-ec18f8345a11 h1:YFh+sjyJTMQSYjKwM4dFKhJPJC/wfo98tPUc17HdoYw= github.com/martini-contrib/render v0.0.0-20150707142108-ec18f8345a11/go.mod h1:Ah2dBMoxZEqk118as2T4u4fjfXarE0pPnMJaArZQZsI= diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 904513eae..22d70a8d3 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -24,7 +24,8 @@ import ( ) var ( - ErrSessionMissing = fmt.Errorf("session is missing") + ErrSessionMissing = fmt.Errorf("session is missing") + ErrSessionInfoMissing = fmt.Errorf("session info is missing") ) // InviteUserEmailFunc sends email with the given inviteLink. @@ -85,8 +86,9 @@ type Provider interface { // SessionInfo is a generic struct used to store session information. Not all // providers use all fields. type SessionInfo struct { - // IDToken is a unique string or ID. It is usually a JWT token. - IDToken string + // Data is provider-specific information. The schema is determined by the + // provider. + Data map[string]interface{} // TTL is the session duration. TTL time.Duration diff --git a/internal/auth/firebase.go b/internal/auth/firebase.go index 1e2c66641..53fdedb50 100644 --- a/internal/auth/firebase.go +++ b/internal/auth/firebase.go @@ -81,8 +81,19 @@ func (f *firebaseAuth) CheckRevoked(ctx context.Context, session *sessions.Sessi // StoreSession stores information about the session. func (f *firebaseAuth) StoreSession(ctx context.Context, session *sessions.Session, i *SessionInfo) error { + if i == nil || i.Data == nil { + f.ClearSession(ctx, session) + return ErrSessionInfoMissing + } + + idToken, ok := i.Data["id_token"].(string) + if !ok { + f.ClearSession(ctx, session) + return fmt.Errorf("missing id_token: %w", ErrSessionInfoMissing) + } + // Convert ID token to long-lived cookie - cookie, err := f.firebaseAuth.SessionCookie(ctx, i.IDToken, i.TTL) + cookie, err := f.firebaseAuth.SessionCookie(ctx, idToken, i.TTL) if err != nil { f.ClearSession(ctx, session) return err @@ -152,15 +163,6 @@ func (f *firebaseAuth) CreateUser(ctx context.Context, name, email, pass string, return true, nil } -// IDToken extracts the users IDtoken from the session. -func (f *firebaseAuth) IDToken(ctx context.Context, session *sessions.Session) (string, error) { - data, err := f.loadCookie(ctx, session) - if err != nil { - return "", err - } - return data.IDToken, nil -} - // EmailAddress extracts the users email from the session. func (f *firebaseAuth) EmailAddress(ctx context.Context, session *sessions.Session) (string, error) { data, err := f.loadCookie(ctx, session) @@ -286,8 +288,7 @@ func (f *firebaseAuth) emailVerificationLink(ctx context.Context, email string) return verify, nil } -type cookieData struct { - IDToken string +type firebaseCookieData struct { Email string EmailVerified bool MFAEnabled bool @@ -295,7 +296,7 @@ type cookieData struct { // dataFromCookie extracts the information from the provided firebase cookie, if // it exists. -func (f *firebaseAuth) dataFromCookie(ctx context.Context, cookie string) (*cookieData, error) { +func (f *firebaseAuth) dataFromCookie(ctx context.Context, cookie string) (*firebaseCookieData, error) { token, err := f.firebaseAuth.VerifySessionCookie(ctx, cookie) if err != nil { return nil, fmt.Errorf("failed to verify firebase cookie: %w", err) @@ -305,12 +306,6 @@ func (f *firebaseAuth) dataFromCookie(ctx context.Context, cookie string) (*cook return nil, fmt.Errorf("token claims are empty") } - // IDToken - idToken, ok := token.Claims["user_id"].(string) - if !ok { - return nil, fmt.Errorf("token claims for id are not a string") - } - // Email email, ok := token.Claims["email"].(string) if !ok { @@ -330,8 +325,7 @@ func (f *firebaseAuth) dataFromCookie(ctx context.Context, cookie string) (*cook } _, mfaEnabled := firebase["sign_in_second_factor"] - return &cookieData{ - IDToken: idToken, + return &firebaseCookieData{ Email: email, EmailVerified: emailVerified, MFAEnabled: mfaEnabled, @@ -339,7 +333,7 @@ func (f *firebaseAuth) dataFromCookie(ctx context.Context, cookie string) (*cook } // loadCookie loads and parses the firebase cookie from the session. -func (f *firebaseAuth) loadCookie(ctx context.Context, session *sessions.Session) (*cookieData, error) { +func (f *firebaseAuth) loadCookie(ctx context.Context, session *sessions.Session) (*firebaseCookieData, error) { raw, err := sessionGet(session, sessionKeyFirebaseCookie) if err != nil { f.ClearSession(ctx, session) diff --git a/internal/auth/local.go b/internal/auth/local.go new file mode 100644 index 000000000..b3d64e380 --- /dev/null +++ b/internal/auth/local.go @@ -0,0 +1,255 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/gorilla/sessions" +) + +const ( + sessionKeyLocalCookie = sessionKey("localCookie") +) + +type localAuth struct{} + +// NewLocal creates a new auth provider for local auth. +func NewLocal(ctx context.Context) (Provider, error) { + return &localAuth{}, nil +} + +// CheckRevoked checks if the users auth has been revoked. +func (a *localAuth) CheckRevoked(ctx context.Context, session *sessions.Session) error { + data, err := a.loadCookie(ctx, session) + if err != nil { + return err + } + + if data.Revoked { + return fmt.Errorf("session is revoked") + } + return nil +} + +// StoreSession stores information about the session. +func (a *localAuth) StoreSession(ctx context.Context, session *sessions.Session, i *SessionInfo) error { + if i == nil || i.Data == nil { + a.ClearSession(ctx, session) + return ErrSessionInfoMissing + } + + email, ok := i.Data["email"].(string) + if !ok { + a.ClearSession(ctx, session) + return fmt.Errorf("missing email: %w", ErrSessionInfoMissing) + } + + emailVerified, ok := i.Data["email_verified"].(bool) + if !ok { + a.ClearSession(ctx, session) + return fmt.Errorf("missing email_verified: %w", ErrSessionInfoMissing) + } + + mfaEnabled, ok := i.Data["mfa_enabled"].(bool) + if !ok { + a.ClearSession(ctx, session) + return fmt.Errorf("missing mfa_enabled: %w", ErrSessionInfoMissing) + } + + revoked, ok := i.Data["revoked"].(bool) + if !ok { + a.ClearSession(ctx, session) + return fmt.Errorf("missing revoked: %w", ErrSessionInfoMissing) + } + + // Convert ID token to long-lived cookie + cookie, err := json.Marshal(&localCookieData{ + Email: email, + EmailVerified: emailVerified, + MFAEnabled: mfaEnabled, + Revoked: revoked, + }) + if err != nil { + a.ClearSession(ctx, session) + return err + } + + // Set cookie + if err := sessionSet(session, sessionKeyLocalCookie, string(cookie)); err != nil { + a.ClearSession(ctx, session) + return err + } + + return nil +} + +// ClearSession removes any session information for this auth. +func (a *localAuth) ClearSession(ctx context.Context, session *sessions.Session) { + sessionClear(session, sessionKeyLocalCookie) +} + +// CreateUser creates a user in the upstream auth system with the given name and +// email. It returns true if the user was created or false if the user already +// exists. +func (a *localAuth) CreateUser(ctx context.Context, name, email, pass string, emailer InviteUserEmailFunc) (bool, error) { + if emailer == nil { + return false, fmt.Errorf("emailer is required for local auth") + } + + // For local auth, this is a noop since the controllers create the user in the + // database. + + // Send the welcome email. + inviteLink, err := a.passwordResetLink(ctx, email) + if err != nil { + return true, err + } + + if err := emailer(ctx, inviteLink); err != nil { + return true, fmt.Errorf("failed to send new user invitation email: %w", err) + } + + return true, nil +} + +// EmailAddress extracts the users email from the session. +func (a *localAuth) EmailAddress(ctx context.Context, session *sessions.Session) (string, error) { + data, err := a.loadCookie(ctx, session) + if err != nil { + return "", err + } + return data.Email, nil +} + +// EmailVerified returns true if the current user is verified, false otherwise. +func (a *localAuth) EmailVerified(ctx context.Context, session *sessions.Session) (bool, error) { + data, err := a.loadCookie(ctx, session) + if err != nil { + return false, err + } + return data.EmailVerified, nil +} + +// MFAEnabled returns whether MFA is enabled on the account. +func (a *localAuth) MFAEnabled(ctx context.Context, session *sessions.Session) (bool, error) { + data, err := a.loadCookie(ctx, session) + if err != nil { + return false, err + } + return data.MFAEnabled, nil +} + +// ChangePassword changes the users password. The data is not used. Since local +// auth does not use passwords, this is a noop. +func (a *localAuth) ChangePassword(ctx context.Context, newPassword string, data interface{}) error { + return nil +} + +// SendResetPasswordEmail resets the password for the given user. If the user does not +// exist, an error is returned. +func (a *localAuth) SendResetPasswordEmail(ctx context.Context, email string, emailer ResetPasswordEmailFunc) error { + if emailer == nil { + return fmt.Errorf("emailer is required for local auth") + } + + resetLink, err := a.passwordResetLink(ctx, email) + if err != nil { + return err + } + + if err := emailer(ctx, resetLink); err != nil { + return fmt.Errorf("failed to send password reset email: %w", err) + } + + return nil +} + +// VerifyPasswordResetCode does nothing. It returns the empty string. +func (a *localAuth) VerifyPasswordResetCode(ctx context.Context, code string) (string, error) { + return "", nil +} + +// SendEmailVerificationEmail sends an message to the currently authenticated +// user, asking them to verify ownership of the email address. +func (a *localAuth) SendEmailVerificationEmail(ctx context.Context, email string, data interface{}, emailer EmailVerificationEmailFunc) error { + if emailer == nil { + return fmt.Errorf("emailer is required for local auth") + } + + verifyLink, err := a.emailVerificationLink(ctx, email) + if err != nil { + return err + } + + if err := emailer(ctx, verifyLink); err != nil { + return fmt.Errorf("failed to send email verification email: %w", err) + } + + return nil +} + +// passwordResetLink generates and returns the password reset link for the given +// email (user). +func (a *localAuth) passwordResetLink(ctx context.Context, email string) (string, error) { + return "", fmt.Errorf("not yet implemented for local auth") +} + +// emailVerificationLink generates an email verification link for the given +// email. +func (a *localAuth) emailVerificationLink(ctx context.Context, email string) (string, error) { + return "", fmt.Errorf("not yet implemented for local auth") +} + +type localCookieData struct { + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + MFAEnabled bool `json:"mfa_enabled"` + Revoked bool `json:"revoked"` +} + +// dataFromCookie extracts the information from the provided local cookie, if it +// exists. The local cookie is actually just a JSON payload. +func (a *localAuth) dataFromCookie(ctx context.Context, cookie string) (*localCookieData, error) { + var data localCookieData + if err := json.Unmarshal([]byte(cookie), &data); err != nil { + return nil, err + } + return &data, nil +} + +// loadCookie loads and parses the local cookie from the session. +func (a *localAuth) loadCookie(ctx context.Context, session *sessions.Session) (*localCookieData, error) { + raw, err := sessionGet(session, sessionKeyLocalCookie) + if err != nil { + a.ClearSession(ctx, session) + return nil, err + } + + cookie, ok := raw.(string) + if !ok || cookie == "" { + a.ClearSession(ctx, session) + return nil, ErrSessionMissing + } + + data, err := a.dataFromCookie(ctx, cookie) + if err != nil { + a.ClearSession(ctx, session) + return nil, err + } + return data, nil +} diff --git a/internal/browser/browser.go b/internal/browser/browser.go new file mode 100644 index 000000000..48d661647 --- /dev/null +++ b/internal/browser/browser.go @@ -0,0 +1,20 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package browser provides helpers for writing integration tests that interact +// with the browser. It wraps chromedp to create a real browser, click buttons, +// and assert results. +// +// This package should only be used by tests. +package browser diff --git a/internal/browser/executor.go b/internal/browser/executor.go new file mode 100644 index 000000000..448b65b0d --- /dev/null +++ b/internal/browser/executor.go @@ -0,0 +1,188 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package browser + +import ( + "context" + "fmt" + "math" + "net/http" + "testing" + "time" + + "github.com/chromedp/cdproto/cdp" + "github.com/chromedp/cdproto/emulation" + "github.com/chromedp/cdproto/network" + "github.com/chromedp/cdproto/page" + "github.com/chromedp/chromedp" +) + +// defaultOptions are the default Chrome options. +var defaultOptions = [...]chromedp.ExecAllocatorOption{ + chromedp.NoFirstRun, + chromedp.NoDefaultBrowserCheck, + + // After Puppeteer's default behavior. + chromedp.Flag("disable-background-networking", true), + chromedp.Flag("enable-features", "NetworkService,NetworkServiceInProcess"), + chromedp.Flag("disable-background-timer-throttling", true), + chromedp.Flag("disable-backgrounding-occluded-windows", true), + chromedp.Flag("disable-breakpad", true), + chromedp.Flag("disable-client-side-phishing-detection", true), + chromedp.Flag("disable-default-apps", true), + chromedp.Flag("disable-dev-shm-usage", true), + chromedp.Flag("disable-extensions", true), + chromedp.Flag("disable-features", "site-per-process,TranslateUI,BlinkGenPropertyTrees"), + chromedp.Flag("disable-hang-monitor", true), + chromedp.Flag("disable-ipc-flooding-protection", true), + chromedp.Flag("disable-popup-blocking", true), + chromedp.Flag("disable-prompt-on-repost", true), + chromedp.Flag("disable-renderer-backgrounding", true), + chromedp.Flag("disable-sync", true), + chromedp.Flag("force-color-profile", "srgb"), + chromedp.Flag("metrics-recording-only", true), + chromedp.Flag("safebrowsing-disable-auto-update", true), + chromedp.Flag("enable-automation", true), + chromedp.Flag("password-store", "basic"), + chromedp.Flag("use-mock-keychain", true), +} + +// New creates a new headless browser context. Se NewFromOptions for usage. +func New(tb testing.TB) context.Context { + tb.Helper() + opts := defaultOptions[:] + opts = append(opts, chromedp.Headless) + return NewFromOptions(tb, opts) +} + +// NewHeadful creates a new browser context so you can actually watch the test. +// This is for local debugging and will fail on CI where a browser isn't +// actually available. +func NewHeadful(tb testing.TB) context.Context { + tb.Helper() + return NewFromOptions(tb, defaultOptions[:]) +} + +// NewFromOptions creates a new browser instance. All future calls to `Run` must +// use the context returned by this function! +// +// If this function returns successfully, a browser is running and ready to be +// used. It's recommended that you wrap the returned context in a timeout. +func NewFromOptions(tb testing.TB, opts []chromedp.ExecAllocatorOption) context.Context { + tb.Helper() + + allocCtx, cancel := chromedp.NewExecAllocator(context.Background(), opts...) + tb.Cleanup(cancel) + + taskCtx, cancel := chromedp.NewContext(allocCtx, chromedp.WithLogf(tb.Logf)) + tb.Cleanup(cancel) + + // Start browser + if err := chromedp.Run(taskCtx); err != nil { + tb.Fatal(err) + } + + return taskCtx +} + +// Screenshot captures a screenshot of the browser page in its current state. +// This is useful for debugging a test failure. The dst will contain the +// screenshot bytes in PNG format when the runner finishes. +func Screenshot(dst *[]byte) chromedp.Action { + return chromedp.ActionFunc(func(ctx context.Context) error { + _, _, contentSize, err := page.GetLayoutMetrics().Do(ctx) + if err != nil { + return err + } + + width, height := int64(math.Ceil(contentSize.Width)), int64(math.Ceil(contentSize.Height)) + + err = emulation. + SetDeviceMetricsOverride(width, height, 1, false). + WithScreenOrientation(&emulation.ScreenOrientation{ + Type: emulation.OrientationTypePortraitPrimary, + Angle: 0, + }). + Do(ctx) + if err != nil { + return err + } + + // capture screenshot + *dst, err = page.CaptureScreenshot(). + WithQuality(100). + WithClip(&page.Viewport{ + X: contentSize.X, + Y: contentSize.Y, + Width: contentSize.Width, + Height: contentSize.Height, + Scale: 2, + }).Do(ctx) + if err != nil { + return err + } + return nil + }) +} + +// SetCookie sets a cookie with the provided parameters. This can be used to +// bypass login and force a specific user be logged in during the test. +func SetCookie(c *http.Cookie) chromedp.Action { + return chromedp.ActionFunc(func(ctx context.Context) error { + exp := cdp.TimeSinceEpoch(time.Now().Add(24 * time.Hour)) + + ok, err := network. + SetCookie(c.Name, c.Value). + WithPath(c.Path). + WithDomain(c.Domain). + WithExpires(&exp). + WithSecure(c.Secure). + WithHTTPOnly(c.HttpOnly). + Do(ctx) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("failed to set cookie %q", c.Name) + } + return nil + }) +} + +// Cookies sets the current list of cookies into the provided destination. +func Cookies(dst *[]*http.Cookie) chromedp.Action { + return chromedp.ActionFunc(func(ctx context.Context) error { + networkCookies, err := network.GetAllCookies().Do(ctx) + if err != nil { + return err + } + + httpCookies := make([]*http.Cookie, len(networkCookies)) + for i, c := range networkCookies { + httpCookies[i] = &http.Cookie{ + Name: c.Name, + Value: c.Value, + Path: c.Path, + Domain: c.Domain, + Expires: time.Unix(int64(c.Expires), 0), + Secure: c.Secure, + HttpOnly: c.HTTPOnly, + } + } + *dst = httpCookies + + return nil + }) +} diff --git a/internal/envstest/envstest.go b/internal/envstest/envstest.go new file mode 100644 index 000000000..1a152571c --- /dev/null +++ b/internal/envstest/envstest.go @@ -0,0 +1,16 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package envstest defines global test helpers for the entire project. +package envstest diff --git a/internal/envstest/random.go b/internal/envstest/random.go new file mode 100644 index 000000000..28b418a2c --- /dev/null +++ b/internal/envstest/random.go @@ -0,0 +1,40 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package envstest + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "testing" +) + +// RandomBytes returns a byte slice of random values of the given length. +func RandomBytes(tb testing.TB, length int) []byte { + buf := make([]byte, length) + n, err := rand.Read(buf) + if err != nil { + tb.Fatal(err) + } + if n < length { + tb.Fatal(fmt.Errorf("insufficient bytes read: %v, expected %v", n, length)) + } + return buf +} + +// RandomString returns a random hex-encoded string of the given length. +func RandomString(tb testing.TB, length int) string { + return hex.EncodeToString(RandomBytes(tb, length/2)) +} diff --git a/internal/envstest/server.go b/internal/envstest/server.go new file mode 100644 index 000000000..e64f59752 --- /dev/null +++ b/internal/envstest/server.go @@ -0,0 +1,304 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package envstest + +import ( + "context" + "fmt" + "net" + "net/http" + "path/filepath" + "testing" + "time" + + "github.com/google/exposure-notifications-server/pkg/keys" + "github.com/google/exposure-notifications-server/pkg/logging" + "github.com/google/exposure-notifications-server/pkg/server" + "github.com/google/exposure-notifications-verification-server/internal/auth" + "github.com/google/exposure-notifications-verification-server/internal/project" + "github.com/google/exposure-notifications-verification-server/internal/routes" + "github.com/google/exposure-notifications-verification-server/pkg/cache" + "github.com/google/exposure-notifications-verification-server/pkg/config" + "github.com/google/exposure-notifications-verification-server/pkg/controller" + "github.com/google/exposure-notifications-verification-server/pkg/database" + "github.com/google/exposure-notifications-verification-server/pkg/ratelimit" + + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" + "github.com/sethvargo/go-envconfig" + "github.com/sethvargo/go-limiter" + "github.com/sethvargo/go-limiter/memorystore" +) + +const ( + // sessionName is the name of the session. This must match the session name in + // the sessions middleware, but cannot be pulled from there due to a cyclical + // dependency. + sessionName = "verification-server-session" +) + +// TestServerResponse is used as the reply to creating a test UI server. +type TestServerResponse struct { + AuthProvider auth.Provider + Cacher cache.Cacher + Config *config.ServerConfig + Database *database.Database + KeyManager keys.KeyManager + RateLimiter limiter.Store + Server *server.Server +} + +// SessionCookie returns an encrypted cookie for the given session information, +// capable of being injected into the browser instance and read by the +// application. Since the cookie contains the session, it can be used to mutate +// any server state, including the currently-authenticated user. +func (r *TestServerResponse) SessionCookie(session *sessions.Session) (*http.Cookie, error) { + if session == nil { + return nil, fmt.Errorf("session cannot be nil") + } + + // Update options to be the server domain + if session.Options == nil { + session.Options = &sessions.Options{} + } + session.Options.Domain = r.Server.Addr() + session.Options.Path = "/" + + // Encode and encrypt the cookie using the same configuration as the server. + codecs := securecookie.CodecsFromPairs(r.Config.CookieKeys.AsBytes()...) + encoded, err := securecookie.EncodeMulti(sessionName, session.Values, codecs...) + if err != nil { + return nil, fmt.Errorf("failed to encode session cookie: %w", err) + } + + return sessions.NewCookie(sessionName, encoded, session.Options), nil +} + +// LoggedInCookie returns an encrypted cookie with the provided email address +// logged in. It also stores that email verification and MFA prompting have +// already occurred for a consistent post-login experience. +// +// The provided email is marked as verified, has MFA enabled, and is not +// revoked. To test other journeys, manually build the session. +func (r *TestServerResponse) LoggedInCookie(email string) (*http.Cookie, error) { + session := &sessions.Session{ + Values: map[interface{}]interface{}{}, + Options: &sessions.Options{}, + IsNew: true, + } + + controller.StoreSessionEmailVerificationPrompted(session, true) + controller.StoreSessionMFAPrompted(session, false) + + ctx := context.Background() + if err := r.AuthProvider.StoreSession(ctx, session, &auth.SessionInfo{ + Data: map[string]interface{}{ + "email": email, + "email_verified": true, + "mfa_enabled": true, + "revoked": false, + }, + TTL: 5 * time.Minute, + }); err != nil { + return nil, err + } + + return r.SessionCookie(session) +} + +// NewServer creates a new test UI server instance. When this function returns, +// a full UI server will be running locally on a random port. Cleanup is handled +// automatically. +func NewServer(tb testing.TB) *TestServerResponse { + tb.Helper() + + if testing.Short() { + tb.Skip() + } + + // Create the config and requirements. + response := newServerConfig(tb) + + // Configure logging + logger := logging.NewLogger(true) + ctx := logging.WithLogger(context.Background(), logger) + + // Build the routing. + mux, err := routes.Server(ctx, response.Config, response.Database, response.AuthProvider, response.Cacher, response.KeyManager, response.RateLimiter) + if err != nil { + tb.Fatal(err) + } + + // Create a stoppable context. + doneCtx, cancel := context.WithCancel(ctx) + tb.Cleanup(func() { + cancel() + }) + + // As of 2020-10-29, our CI infrastructure does not support IPv6. `server.New` + // binds to "tcp", which picks the "best" address, but it prefers IPv6. As a + // result, the server binds to the IPv6 loopback`[::]`, but then our browser + // instance cannot actually contact that loopback interface. To mitigate this, + // create a custom listener and force IPv4. The listener will still pick a + // randomly available port, but it will only choose an IPv4 address upon which + // to bind. + listener, err := net.Listen("tcp4", ":0") + if err != nil { + tb.Fatalf("failed to create listener: %v", err) + } + + // Start the server on a random port. Closing doneCtx will stop the server + // (which the cleanup step does). + srv, err := server.NewFromListener(listener) + if err != nil { + tb.Fatal(err) + } + go func() { + if err := srv.ServeHTTPHandler(doneCtx, mux); err != nil { + tb.Error(err) + } + }() + + return &TestServerResponse{ + AuthProvider: response.AuthProvider, + Config: response.Config, + Database: response.Database, + Cacher: response.Cacher, + KeyManager: response.KeyManager, + RateLimiter: response.RateLimiter, + Server: srv, + } +} + +// serverConfigResponse is the response from creating a server config. +type serverConfigResponse struct { + AuthProvider auth.Provider + Config *config.ServerConfig + Database *database.Database + Cacher cache.Cacher + KeyManager keys.KeyManager + RateLimiter limiter.Store +} + +// newServerConfig creates a new server configuration. It creates all the keys, +// databases, and cacher, but does not actually start the server. All cleanup is +// scheduled by t.Cleanup. +func newServerConfig(tb testing.TB) *serverConfigResponse { + tb.Helper() + + if testing.Short() { + tb.Skip() + } + + // Create the auth provider + authProvider, err := auth.NewLocal(context.Background()) + if err != nil { + tb.Fatal(err) + } + + // Create the cacher. + cacher, err := cache.NewInMemory(nil) + if err != nil { + tb.Fatal(err) + } + tb.Cleanup(func() { + if err := cacher.Close(); err != nil { + tb.Fatal(err) + } + }) + + // Create the database. + db, dbConfig := database.NewTestDatabaseWithCacher(tb, cacher) + + // Create the key manager. + keyManager := keys.TestKeyManager(tb) + + // Create the rate limiter. + limiterStore, err := memorystore.New(&memorystore.Config{ + Tokens: 30, + Interval: time.Second, + }) + if err != nil { + tb.Fatal(err) + } + tb.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := limiterStore.Close(ctx); err != nil { + tb.Fatal(err) + } + }) + + // Create the config. + cfg := &config.ServerConfig{ + AssetsPath: ServerAssetsPath(tb), + Cache: cache.Config{ + Type: cache.TypeInMemory, + HMACKey: RandomBytes(tb, 64), + }, + Database: *dbConfig, + // Firebase is not used for browser tests. + Firebase: config.FirebaseConfig{ + APIKey: "test", + AuthDomain: "test.firebaseapp.com", + DatabaseURL: "https://test.firebaseio.com", + ProjectID: "test", + StorageBucket: "test.appspot.com", + MessageSenderID: "test", + AppID: "1:test:web:test", + MeasurementID: "G-TEST", + }, + CookieKeys: config.Base64ByteSlice{RandomBytes(tb, 64), RandomBytes(tb, 32)}, + CSRFAuthKey: RandomBytes(tb, 32), + CertificateSigning: config.CertificateSigningConfig{ + // TODO(sethvargo): configure this when the first test requires it + CertificateSigningKey: "UPDATE_ME", + Keys: keys.Config{ + KeyManagerType: keys.KeyManagerTypeFilesystem, + FilesystemRoot: filepath.Join(project.Root(), "local", "test", RandomString(tb, 8)), + }, + }, + RateLimit: ratelimit.Config{ + Type: ratelimit.RateLimiterTypeMemory, + HMACKey: RandomBytes(tb, 64), + }, + + // DevMode has to be enabled for tests. Otherwise the cookies fail. + DevMode: true, + } + + // Process the config - this simulates production setups and also ensures we + // get the defaults for any unset values. + emptyLookuper := envconfig.MapLookuper(nil) + if err := config.ProcessWith(context.Background(), cfg, emptyLookuper); err != nil { + tb.Fatal(err) + } + + return &serverConfigResponse{ + AuthProvider: authProvider, + Config: cfg, + Database: db, + Cacher: cacher, + KeyManager: keyManager, + RateLimiter: limiterStore, + } +} + +// ServerAssetsPath returns the path to the UI server assets. +func ServerAssetsPath(tb testing.TB) string { + tb.Helper() + return filepath.Join(project.Root(), "cmd", "server", "assets") +} diff --git a/internal/project/root.go b/internal/project/root.go new file mode 100644 index 000000000..91f333f37 --- /dev/null +++ b/internal/project/root.go @@ -0,0 +1,28 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package project defines global project helpers. +package project + +import ( + "path/filepath" + "runtime" +) + +var _, self, _, _ = runtime.Caller(0) + +// Root returns the filepath to the root of this project. +func Root() string { + return filepath.Join(filepath.Dir(self), "..", "..") +} diff --git a/pkg/controller/home/home_test.go b/pkg/controller/home/home_test.go new file mode 100644 index 000000000..ead498490 --- /dev/null +++ b/pkg/controller/home/home_test.go @@ -0,0 +1,101 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package home_test + +import ( + "context" + "testing" + "time" + + "github.com/google/exposure-notifications-verification-server/internal/browser" + "github.com/google/exposure-notifications-verification-server/internal/envstest" + "github.com/google/exposure-notifications-verification-server/pkg/database" + + "github.com/chromedp/chromedp" +) + +func TestHandleHome_IssueCode(t *testing.T) { + t.Parallel() + + harness := envstest.NewServer(t) + + // Get the default realm + realm, err := harness.Database.FindRealm(1) + if err != nil { + t.Fatal(err) + } + + // Create a user + admin := &database.User{ + Email: "admin@example.com", + Name: "Admin", + Realms: []*database.Realm{realm}, + AdminRealms: []*database.Realm{realm}, + } + if err := harness.Database.SaveUser(admin, database.System); err != nil { + t.Fatal(err) + } + + // Create a cookie that logs this user in. + cookie, err := harness.LoggedInCookie(admin.Email) + if err != nil { + panic(err) + } + + // Create a browser runner. + browserCtx := browser.New(t) + taskCtx, done := context.WithTimeout(browserCtx, 30*time.Second) + defer done() + + var code string + if err := chromedp.Run(taskCtx, + // Pre-authenticate the user. + browser.SetCookie(cookie), + + // Visit /home. + chromedp.Navigate(`http://`+harness.Server.Addr()+`/home`), + + // Wait for render. + chromedp.WaitVisible(`body#home`, chromedp.ByQuery), + + // Click the issue button. + chromedp.Click(`#submit`, chromedp.ByQuery), + chromedp.WaitVisible(`#code`, chromedp.ByQuery), + + // Get the code. + chromedp.TextContent(`#code`, &code, chromedp.ByQuery), + ); err != nil { + t.Fatal(err) + } + + // Verify code length. + if got, want := len(code), 8; got != want { + t.Errorf("expected %v to be %v", got, want) + } + + // Verify the code exists. + dbCode, err := harness.Database.FindVerificationCode(code) + if err != nil { + t.Fatal(err) + } + + if got, want := dbCode.TestType, "confirmed"; got != want { + t.Errorf("expected %v to be %v", got, want) + } + + if got, want := dbCode.Claimed, false; got != want { + t.Errorf("expected %v to be %v", got, want) + } +} diff --git a/pkg/controller/login/session.go b/pkg/controller/login/session.go index 864bae051..256c8f705 100644 --- a/pkg/controller/login/session.go +++ b/pkg/controller/login/session.go @@ -47,8 +47,10 @@ func (c *Controller) HandleCreateSession() http.Handler { // Create the session cookie. if err := c.authProvider.StoreSession(ctx, session, &auth.SessionInfo{ - IDToken: form.IDToken, - TTL: c.config.SessionDuration, + Data: map[string]interface{}{ + "id_token": form.IDToken, + }, + TTL: c.config.SessionDuration, }); err != nil { flash.Error("Failed to create session: %v", err) c.h.RenderJSON(w, http.StatusUnauthorized, api.Error(err)) diff --git a/pkg/controller/middleware/auth.go b/pkg/controller/middleware/auth.go index 45cebaece..747e47c2e 100644 --- a/pkg/controller/middleware/auth.go +++ b/pkg/controller/middleware/auth.go @@ -49,7 +49,6 @@ func RequireAuth(ctx context.Context, cacher cache.Cacher, authProvider auth.Pro controller.MissingSession(w, r, h) return } - flash := controller.Flash(session) // Check session idle timeout. diff --git a/pkg/controller/middleware/sessions.go b/pkg/controller/middleware/sessions.go index 789c74107..2085a093a 100644 --- a/pkg/controller/middleware/sessions.go +++ b/pkg/controller/middleware/sessions.go @@ -48,6 +48,8 @@ func RequireSession(ctx context.Context, store sessions.Store, h *render.Rendere // Get or create a session from the store. session, err := store.Get(r, sessionName) if err != nil { + logger.Errorw("failed to get session", "error", err) + // We couldn't get a session (invalid cookie, can't talk to redis, // whatever). According to the spec, this can return an error but can never // return an empty session. We intentionally discard the error to ensure we diff --git a/pkg/database/database_util.go b/pkg/database/database_util.go index 7fc577c00..376a86959 100644 --- a/pkg/database/database_util.go +++ b/pkg/database/database_util.go @@ -17,6 +17,8 @@ package database import ( "context" "crypto/rand" + "io/ioutil" + "log" "os" "strconv" "testing" @@ -24,8 +26,10 @@ import ( "github.com/google/exposure-notifications-server/pkg/keys" "github.com/google/exposure-notifications-server/pkg/secrets" + "github.com/google/exposure-notifications-verification-server/pkg/cache" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/jinzhu/gorm" "github.com/ory/dockertest" "github.com/sethvargo/go-envconfig" ) @@ -34,13 +38,12 @@ var ( approxTime = cmp.Options{cmpopts.EquateApproxTime(time.Second)} ) -// NewTestDatabaseWithConfig creates a new database suitable for use in testing. -// This should not be used outside of testing, but it is exposed in the main -// package so it can be shared with other packages. +// NewTestDatabaseWithCacher creates a database configured with a cacher for use +// in testing. // // All database tests can be skipped by running `go test -short` or by setting // the `SKIP_DATABASE_TESTS` environment variable. -func NewTestDatabaseWithConfig(tb testing.TB) (*Database, *Config) { +func NewTestDatabaseWithCacher(tb testing.TB, cacher cache.Cacher) (*Database, *Config) { tb.Helper() if testing.Short() { @@ -62,7 +65,6 @@ func NewTestDatabaseWithConfig(tb testing.TB) (*Database, *Config) { // Start the container. dbname, username, password := "en-verification-server", "my-username", "abcd1234" - tb.Log("Starting database") container, err := pool.RunWithOptions(&dockertest.RunOptions{ Repository: "postgres", Tag: "12-alpine", @@ -122,15 +124,22 @@ func NewTestDatabaseWithConfig(tb testing.TB) (*Database, *Config) { db.keyManager = keys.TestKeyManager(tb) db.config.EncryptionKey = keys.TestEncryptionKey(tb, db.keyManager) - if err := db.Open(ctx); err != nil { + if err := db.OpenWithCacher(ctx, cacher); err != nil { tb.Fatal(err) } - db.db.LogMode(false) + + // Disable logging temporarily for migrations. The callback registration is + // really quite chatty. + db.db.SetLogger(gorm.Logger{LogWriter: log.New(ioutil.Discard, "", 0)}) + db.db = db.db.LogMode(false) if err := db.RunMigrations(ctx); err != nil { tb.Fatalf("failed to migrate database: %v", err) } + // Re-enable logging. + db.db.SetLogger(gorm.Logger{LogWriter: log.New(os.Stdout, "", 0)}) + // Close db when done. tb.Cleanup(func() { db.db.Close() @@ -139,6 +148,20 @@ func NewTestDatabaseWithConfig(tb testing.TB) (*Database, *Config) { return db, config } +// NewTestDatabaseWithConfig creates a new database suitable for use in testing. +// This should not be used outside of testing, but it is exposed in the main +// package so it can be shared with other packages. +// +// All database tests can be skipped by running `go test -short` or by setting +// the `SKIP_DATABASE_TESTS` environment variable. +func NewTestDatabaseWithConfig(tb testing.TB) (*Database, *Config) { + return NewTestDatabaseWithCacher(tb, nil) +} + +// NewTestDatabase creates a new test database with the defautl configuration. +// +// All database tests can be skipped by running `go test -short` or by setting +// the `SKIP_DATABASE_TESTS` environment variable. func NewTestDatabase(tb testing.TB) *Database { tb.Helper()