Permalink
Browse files

Use referrer to determine URL sent in emails

Fixes #112
  • Loading branch information...
brycekahle committed Feb 9, 2018
1 parent b104245 commit c85d50f1540cd5a0aab1259e3ea2d6de8baadfd3
Showing with 79 additions and 56 deletions.
  1. +3 −11 api/external.go
  2. +16 −0 api/helpers.go
  3. +2 −1 api/invite.go
  4. +8 −8 api/mail.go
  5. +2 −1 api/recover.go
  6. +2 −1 api/signup.go
  7. +2 −1 api/user.go
  8. +19 −12 mailer/mailer.go
  9. +13 −9 mailer/mailer_test.go
  10. +4 −4 mailer/noop.go
  11. +8 −8 mailer/template.go
@@ -51,16 +51,7 @@ func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) e
}
}
referrer := ""
if reqref := r.Referer(); reqref != "" {
base, berr := url.Parse(config.SiteURL)
refurl, rerr := url.Parse(reqref)
// As long as the referrer came from the site, we will redirect back there
if berr == nil && rerr == nil && base.Hostname() == refurl.Hostname() {
referrer = reqref
}
}
referrer := a.getReferrer(r)
log := getLogEntry(r)
log.WithField("provider", providerType).Info("Redirecting to external provider")
@@ -170,7 +161,8 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
if !user.IsConfirmed() {
if !userData.Verified && !config.Mailer.Autoconfirm {
mailer := a.Mailer(ctx)
if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency); terr != nil {
referrer := a.getReferrer(r)
if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer); terr != nil {
return internalServerError("Error sending confirmation mail").WithInternalError(terr)
}
// email must be verified to issue a token
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"github.com/netlify/gotrue/conf"
"github.com/netlify/gotrue/models"
@@ -91,3 +92,18 @@ func (a *API) requestAud(ctx context.Context, r *http.Request) string {
// Finally, return the default of none of the above methods are successful
return config.JWT.Aud
}
func (a *API) getReferrer(r *http.Request) string {
ctx := r.Context()
config := a.getConfig(ctx)
referrer := ""
if reqref := r.Referer(); reqref != "" {
base, berr := url.Parse(config.SiteURL)
refurl, rerr := url.Parse(reqref)
// As long as the referrer came from the site, we will redirect back there
if berr == nil && rerr == nil && base.Hostname() == refurl.Hostname() {
referrer = reqref
}
}
return referrer
}
@@ -60,7 +60,8 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error {
}
mailer := a.Mailer(ctx)
if err := sendInvite(tx, user, mailer); err != nil {
referrer := a.getReferrer(r)
if err := sendInvite(tx, user, mailer, referrer); err != nil {
return internalServerError("Error inviting user").WithInternalError(err)
}
return nil
@@ -11,57 +11,57 @@ import (
"github.com/pkg/errors"
)
func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration) error {
func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string) error {
if u.ConfirmationSentAt != nil && !u.ConfirmationSentAt.Add(maxFrequency).Before(time.Now()) {
return nil
}
oldToken := u.ConfirmationToken
u.ConfirmationToken = crypto.SecureToken()
now := time.Now()
if err := mailer.ConfirmationMail(u); err != nil {
if err := mailer.ConfirmationMail(u, referrerURL); err != nil {
u.ConfirmationToken = oldToken
return errors.Wrap(err, "Error sending confirmation email")
}
u.ConfirmationSentAt = &now
return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation")
}
func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer) error {
func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, referrerURL string) error {
oldToken := u.ConfirmationToken
u.ConfirmationToken = crypto.SecureToken()
now := time.Now()
if err := mailer.InviteMail(u); err != nil {
if err := mailer.InviteMail(u, referrerURL); err != nil {
u.ConfirmationToken = oldToken
return errors.Wrap(err, "Error sending invite email")
}
u.InvitedAt = &now
return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "invited_at"), "Database error updating user for invite")
}
func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration) error {
func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string) error {
if u.RecoverySentAt != nil && !u.RecoverySentAt.Add(maxFrequency).Before(time.Now()) {
return nil
}
oldToken := u.RecoveryToken
u.RecoveryToken = crypto.SecureToken()
now := time.Now()
if err := mailer.RecoveryMail(u); err != nil {
if err := mailer.RecoveryMail(u, referrerURL); err != nil {
u.RecoveryToken = oldToken
return errors.Wrap(err, "Error sending recovery email")
}
u.RecoverySentAt = &now
return errors.Wrap(tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery")
}
func (a *API) sendEmailChange(tx *storage.Connection, u *models.User, mailer mailer.Mailer, email string) error {
func (a *API) sendEmailChange(tx *storage.Connection, u *models.User, mailer mailer.Mailer, email string, referrerURL string) error {
oldToken := u.EmailChangeToken
oldEmail := u.EmailChange
u.EmailChangeToken = crypto.SecureToken()
u.EmailChange = email
now := time.Now()
if err := mailer.EmailChangeMail(u); err != nil {
if err := mailer.EmailChangeMail(u, referrerURL); err != nil {
u.EmailChangeToken = oldToken
u.EmailChange = oldEmail
return err
@@ -44,7 +44,8 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error {
}
mailer := a.Mailer(ctx)
return a.sendPasswordRecovery(tx, user, mailer, config.SMTP.MaxFrequency)
referrer := a.getReferrer(r)
return a.sendPasswordRecovery(tx, user, mailer, config.SMTP.MaxFrequency, referrer)
})
if err != nil {
return internalServerError("Error recovering user").WithInternalError(err)
@@ -80,7 +80,8 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
}
} else {
mailer := a.Mailer(ctx)
if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency); terr != nil {
referrer := a.getReferrer(r)
if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer); terr != nil {
return internalServerError("Error sending confirmation mail").WithInternalError(terr)
}
}
@@ -124,7 +124,8 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error {
}
mailer := a.Mailer(ctx)
if terr = a.sendEmailChange(tx, user, mailer, params.Email); terr != nil {
referrer := a.getReferrer(r)
if terr = a.sendEmailChange(tx, user, mailer, params.Email, referrer); terr != nil {
return internalServerError("Error sending change email").WithInternalError(terr)
}
}
@@ -11,10 +11,10 @@ import (
// Mailer defines the interface a mailer must implement.
type Mailer interface {
Send(user *models.User, subject, body string, data map[string]interface{}) error
InviteMail(user *models.User) error
ConfirmationMail(user *models.User) error
RecoveryMail(user *models.User) error
EmailChangeMail(user *models.User) error
InviteMail(user *models.User, referrerURL string) error
ConfirmationMail(user *models.User, referrerURL string) error
RecoveryMail(user *models.User, referrerURL string) error
EmailChangeMail(user *models.User, referrerURL string) error
ValidateEmail(email string) error
}
@@ -45,16 +45,23 @@ func withDefault(value, defaultValue string) string {
return value
}
func getSiteURL(siteURL, filepath, fragment string) (string, error) {
site, err := url.Parse(siteURL)
if err != nil {
return "", err
func getSiteURL(referrerURL, siteURL, filepath, fragment string) (string, error) {
baseURL := siteURL
if filepath == "" && referrerURL != "" {
baseURL = referrerURL
}
path, err := url.Parse(filepath)
site, err := url.Parse(baseURL)
if err != nil {
return "", err
}
u := site.ResolveReference(path)
u.Fragment = fragment
return u.String(), nil
if filepath != "" {
path, err := url.Parse(filepath)
if err != nil {
return "", err
}
site = site.ResolveReference(path)
}
site.Fragment = fragment
return site.String(), nil
}
@@ -8,19 +8,23 @@ import (
func TestGetSiteURL(t *testing.T) {
cases := []struct {
SiteURL string
Path string
Fragment string
Expected string
ReferrerURL string
SiteURL string
Path string
Fragment string
Expected string
}{
{"https://test.example.com", "/templates/confirm.html", "", "https://test.example.com/templates/confirm.html"},
{"https://test.example.com/removedpath", "/templates/confirm.html", "", "https://test.example.com/templates/confirm.html"},
{"https://test.example.com/", "/trailingslash/", "", "https://test.example.com/trailingslash/"},
{"https://test.example.com", "f", "fragment", "https://test.example.com/f#fragment"},
{"", "https://test.example.com", "/templates/confirm.html", "", "https://test.example.com/templates/confirm.html"},
{"", "https://test.example.com/removedpath", "/templates/confirm.html", "", "https://test.example.com/templates/confirm.html"},
{"", "https://test.example.com/", "/trailingslash/", "", "https://test.example.com/trailingslash/"},
{"", "https://test.example.com", "f", "fragment", "https://test.example.com/f#fragment"},
{"https://test.example.com/admin", "https://test.example.com", "", "fragment", "https://test.example.com/admin#fragment"},
{"https://test.example.com/admin", "https://test.example.com", "f", "fragment", "https://test.example.com/f#fragment"},
{"", "https://test.example.com", "", "fragment", "https://test.example.com#fragment"},
}
for _, c := range cases {
act, err := getSiteURL(c.SiteURL, c.Path, c.Fragment)
act, err := getSiteURL(c.ReferrerURL, c.SiteURL, c.Path, c.Fragment)
assert.NoError(t, err, c.Expected)
assert.Equal(t, c.Expected, act)
}
@@ -9,19 +9,19 @@ func (m noopMailer) ValidateEmail(email string) error {
return nil
}
func (m *noopMailer) InviteMail(user *models.User) error {
func (m *noopMailer) InviteMail(user *models.User, referrerURL string) error {
return nil
}
func (m *noopMailer) ConfirmationMail(user *models.User) error {
func (m *noopMailer) ConfirmationMail(user *models.User, referrerURL string) error {
return nil
}
func (m noopMailer) RecoveryMail(user *models.User) error {
func (m noopMailer) RecoveryMail(user *models.User, referrerURL string) error {
return nil
}
func (m *noopMailer) EmailChangeMail(user *models.User) error {
func (m *noopMailer) EmailChangeMail(user *models.User, referrerURL string) error {
return nil
}
@@ -41,8 +41,8 @@ func (m TemplateMailer) ValidateEmail(email string) error {
}
// InviteMail sends a invite mail to a new user
func (m *TemplateMailer) InviteMail(user *models.User) error {
url, err := getSiteURL(m.Config.SiteURL, m.Config.Mailer.URLPaths.Invite, "invite_token="+user.ConfirmationToken)
func (m *TemplateMailer) InviteMail(user *models.User, referrerURL string) error {
url, err := getSiteURL(referrerURL, m.Config.SiteURL, m.Config.Mailer.URLPaths.Invite, "invite_token="+user.ConfirmationToken)
if err != nil {
return err
}
@@ -64,8 +64,8 @@ func (m *TemplateMailer) InviteMail(user *models.User) error {
}
// ConfirmationMail sends a signup confirmation mail to a new user
func (m *TemplateMailer) ConfirmationMail(user *models.User) error {
url, err := getSiteURL(m.Config.SiteURL, m.Config.Mailer.URLPaths.Confirmation, "confirmation_token="+user.ConfirmationToken)
func (m *TemplateMailer) ConfirmationMail(user *models.User, referrerURL string) error {
url, err := getSiteURL(referrerURL, m.Config.SiteURL, m.Config.Mailer.URLPaths.Confirmation, "confirmation_token="+user.ConfirmationToken)
if err != nil {
return err
}
@@ -87,8 +87,8 @@ func (m *TemplateMailer) ConfirmationMail(user *models.User) error {
}
// EmailChangeMail sends an email change confirmation mail to a user
func (m *TemplateMailer) EmailChangeMail(user *models.User) error {
url, err := getSiteURL(m.Config.SiteURL, m.Config.Mailer.URLPaths.EmailChange, "email_change_token="+user.EmailChangeToken)
func (m *TemplateMailer) EmailChangeMail(user *models.User, referrerURL string) error {
url, err := getSiteURL(referrerURL, m.Config.SiteURL, m.Config.Mailer.URLPaths.EmailChange, "email_change_token="+user.EmailChangeToken)
if err != nil {
return err
}
@@ -111,8 +111,8 @@ func (m *TemplateMailer) EmailChangeMail(user *models.User) error {
}
// RecoveryMail sends a password recovery mail
func (m *TemplateMailer) RecoveryMail(user *models.User) error {
url, err := getSiteURL(m.Config.SiteURL, m.Config.Mailer.URLPaths.Recovery, "recovery_token="+user.RecoveryToken)
func (m *TemplateMailer) RecoveryMail(user *models.User, referrerURL string) error {
url, err := getSiteURL(referrerURL, m.Config.SiteURL, m.Config.Mailer.URLPaths.Recovery, "recovery_token="+user.RecoveryToken)
if err != nil {
return err
}

0 comments on commit c85d50f

Please sign in to comment.