Skip to content

Commit

Permalink
multi: custom sql store (#410)
Browse files Browse the repository at this point in the history
- new sql store to store gorilla sessions
- new db table 'Session'
- session timeout at 6 hours
- delete session data on logout
- delete session datas for user on password/email change
  • Loading branch information
JoeGruffins authored and dajohi committed Jul 25, 2019
1 parent 3058a44 commit cce2438
Show file tree
Hide file tree
Showing 8 changed files with 490 additions and 15 deletions.
20 changes: 16 additions & 4 deletions controllers/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -1157,9 +1157,12 @@ func (controller *MainController) EmailUpdate(c web.C, r *http.Request) (string,
"emailupdateError")
log.Errorf("EmailChangeComplete failed %v", err)
} else {
// Logout the user to force them to login with their new
// email address
session.Values["UserId"] = nil

// destroy session data and force re-login
userID, _ := session.Values["UserId"].(int64)
session.Options.MaxAge = -1
system.DestroySessionsForUserID(dbMap, userID)

session.AddFlash("Email successfully updated",
"emailupdateSuccess")
}
Expand Down Expand Up @@ -1452,6 +1455,9 @@ func (controller *MainController) PasswordUpdatePost(c web.C, r *http.Request) (
log.Errorf("error deleting token %v", err)
}

// destroy session data
system.DestroySessionsForUserID(dbMap, user.Id)

session.AddFlash("Password successfully updated", "passwordupdateSuccess")
return controller.PasswordUpdate(c, r)
}
Expand Down Expand Up @@ -1581,6 +1587,9 @@ func (controller *MainController) SettingsPost(c web.C, r *http.Request) (string
return controller.Settings(c, r)
}

// destroy session data
system.DestroySessionsForUserID(dbMap, user.Id)

// send a confirmation email.
err = controller.emailSender.PasswordChangeConfirm(user.Email, controller.baseURL, remoteIP)
if err != nil {
Expand Down Expand Up @@ -2083,8 +2092,11 @@ func (controller *MainController) VotingPost(c web.C, r *http.Request) (string,
func (controller *MainController) Logout(c web.C, r *http.Request) (string, int) {
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
if session.Values["UserId"] == nil {
return "/", http.StatusSeeOther
}

session.Values["UserId"] = nil
session.Options.MaxAge = -1

return "/", http.StatusSeeOther
}
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module github.com/decred/dcrstakepool

require (
github.com/DATA-DOG/go-sqlmock v1.3.3
github.com/apoydence/onpar v0.0.0-20190519213022-ee068f8ea4d1 // indirect
github.com/dajohi/goemail v1.0.0
github.com/dchest/captcha v0.0.0-20170622155422-6a29415a8364
Expand All @@ -22,6 +23,7 @@ require (
github.com/golang/protobuf v1.3.2
github.com/gorilla/context v1.1.1
github.com/gorilla/csrf v1.5.1
github.com/gorilla/securecookie v1.1.1
github.com/gorilla/sessions v1.1.3
github.com/jessevdk/go-flags v1.4.0
github.com/jrick/logrotate v1.0.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/DATA-DOG/go-sqlmock v1.3.3 h1:CWUqKXe0s8A2z6qCgkP4Kru7wC11YoAnoupUKFDnH08=
github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/aead/siphash v0.0.0-20170329201724-e404fcfc8885/go.mod h1:Nywa3cDsYNNK3gaciGTWPwHt0wlpNV15vwmswBAUSII=
github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412 h1:w1UutsfOrms1J05zt7ISrnJIXKzwaspym5BTKGx93EI=
github.com/agl/ed25519 v0.0.0-20170116200512-5312a6153412/go.mod h1:WPjqKcmVOxf0XSf3YxCJs6N6AOSrOx3obionmG7T0y0=
Expand Down
10 changes: 10 additions & 0 deletions models/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ type PasswordReset struct {
Expires int64
}

type Session struct {
Id int64 `db:"SessionID"`
Token string
Data []byte
UserId int64
Created int64
Expires int64
}

type User struct {
Id int64 `db:"UserId"`
Email string
Expand Down Expand Up @@ -313,6 +322,7 @@ func GetDbMap(APISecret, baseURL, user, password, hostname, port, database strin
dbMap.AddTableWithName(EmailChange{}, "EmailChange").SetKeys(true, "Id")
dbMap.AddTableWithName(LowFeeTicket{}, "LowFeeTicket").SetKeys(true, "Id")
dbMap.AddTableWithName(PasswordReset{}, "PasswordReset").SetKeys(true, "Id")
dbMap.AddTableWithName(Session{}, "Session").SetKeys(true, "Id")
usersTableName := "Users"
dbMap.AddTableWithName(User{}, usersTableName).SetKeys(true, "Id")

Expand Down
22 changes: 12 additions & 10 deletions system/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type Application struct {
APISecret string
Template *template.Template
TemplatesPath string
Store *sessions.CookieStore
Store *SQLStore
DbMap *gorp.DbMap
}

Expand All @@ -41,15 +41,6 @@ func (application *Application) Init(APISecret string, baseURL string,
DBPassword string,
DBPort string, DBUser string) {

hash := sha256.New()
io.WriteString(hash, cookieSecret)
application.Store = sessions.NewCookieStore(hash.Sum(nil))
application.Store.Options = &sessions.Options{
Path: "/",
HttpOnly: true,
Secure: cookieSecure,
}

application.DbMap = models.GetDbMap(
APISecret,
baseURL,
Expand All @@ -59,6 +50,17 @@ func (application *Application) Init(APISecret string, baseURL string,
DBPort,
DBName)

hash := sha256.New()
io.WriteString(hash, cookieSecret)
application.Store = NewSQLStore(application.DbMap, hash.Sum(nil))
application.Store.Options = &sessions.Options{
Path: "/",
HttpOnly: true,
Secure: cookieSecure,
//six hours
MaxAge: 60 * 60 * 6,
}

application.APISecret = APISecret
}

Expand Down
5 changes: 4 additions & 1 deletion system/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ func (application *Application) ApplyTemplates(c *web.C, h http.Handler) http.Ha
// Makes sure controllers can have access to session
func (application *Application) ApplySessions(c *web.C, h http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
session, _ := application.Store.Get(r, "session")
session, err := application.Store.New(r, "session")
if err != nil {
log.Warnf("session load err: %v ", err)
}
c.Env["Session"] = session
h.ServeHTTP(w, r)
}
Expand Down
200 changes: 200 additions & 0 deletions system/sqlstore.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
package system

import (
"bytes"
"database/sql"
"encoding/base32"
"encoding/gob"
"fmt"
"net/http"
"time"

"github.com/decred/dcrstakepool/models"
"github.com/go-gorp/gorp"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
)

// SQLStore stores gorilla sessions in a database.
type SQLStore struct {
Options *sessions.Options
codecs []securecookie.Codec
dbMap *gorp.DbMap
}

// NewSQLStore returns a new SQLStore. The keyPairs are used in the same way as
// the gorilla sessions CookieStore.
func NewSQLStore(dbMap *gorp.DbMap, keyPairs ...[]byte) *SQLStore {
s := &SQLStore{
codecs: securecookie.CodecsFromPairs(keyPairs...),
dbMap: dbMap,
}
// clean db of expired sessions once a day
go func() {
for {
time.Sleep(time.Hour * 24)
if err := s.destroyExpiredSessions(); err != nil {
log.Warn(err)
}
}
}()
return s
}

// Get returns a cached session.
func (s *SQLStore) Get(r *http.Request, name string) (*sessions.Session, error) {
return sessions.GetRegistry(r).Get(s, name)
}

// New creates a new session for the given request r. If the request
// contains a valid session ID for an existing, non-expired session,
// then that session will be loaded from the database.
func (s *SQLStore) New(r *http.Request, name string) (*sessions.Session, error) {
session := sessions.NewSession(s, name)
opts := *s.Options
session.Options = &opts
c, err := r.Cookie(name)
if err != nil {
if err == http.ErrNoCookie {
return session, nil
}
return session, err
}
err = securecookie.DecodeMulti(name, c.Value, &session.ID, s.codecs...)
if err != nil {
// these are not the sessions you are looking for
log.Infof("sqlstore: New: unable to decode cookie: %v", err)
return session, nil
}
err = s.load(session)
if err != nil {
return session, err
}
return session, nil
}

// Save stores the session in the database. If session.Options.MaxAge
// is < 0, the session is deleted from the database.
func (s *SQLStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
if session.Options.MaxAge < 0 {
return s.destroy(session)
}
if len(session.ID) == 0 {
session.ID = base32.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(32))
}
if err := s.save(session); err != nil {
return err
}
// data is not stored in the cookie, only the session id
encoded, err := securecookie.EncodeMulti(session.Name(), &session.ID, s.codecs...)
if err != nil {
return err
}
http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options))
return nil
}

// load loads the session identified by its ID from the database if it
// exists. If the session has expired, it is destroyed.
func (s *SQLStore) load(session *sessions.Session) error {
var dbSession models.Session
if err := s.dbMap.SelectOne(&dbSession, "SELECT * FROM Session WHERE Token = ?", session.ID); err != nil {
// if no rows are found nothing is done
if err == sql.ErrNoRows {
return nil
}
return fmt.Errorf("Could not select session to destroy: %v", err)
}
if dbSession.Expires < time.Now().Unix() {
return s.destroy(session)
}
// write db Data to session.Values
return gob.NewDecoder(bytes.NewBuffer(dbSession.Data)).Decode(&session.Values)
}

// save checks whether the session is new and inserts if new. Updates if
// not.
func (s *SQLStore) save(session *sessions.Session) error {
var dbSession models.Session
var buf bytes.Buffer
var isNew bool
if err := s.dbMap.SelectOne(&dbSession, "SELECT * FROM Session WHERE Token = ?", session.ID); err != nil {
if err != sql.ErrNoRows {
return fmt.Errorf("Could not select session: %v", err)
}
// no rows found so new
isNew = true
}
if userID, ok := session.Values["UserId"].(int64); ok {
dbSession.UserId = userID
} else {
// all sessions with no user specified are UserId -1
dbSession.UserId = -1
}
if err := gob.NewEncoder(&buf).Encode(session.Values); err != nil {
return err
}
dbSession.Data = buf.Bytes()
if isNew {
now := time.Now().Unix()
dbSession.Token = session.ID
dbSession.Created = now
dbSession.Expires = now + int64(session.Options.MaxAge)
if err := s.dbMap.Insert(&dbSession); err != nil {
return fmt.Errorf("Could not insert session: %v", err)
}
} else {
if _, err := s.dbMap.Update(&dbSession); err != nil {
return fmt.Errorf("Could not update session: %v", err)
}
}
return nil
}

// delete one session from the db
func (s *SQLStore) destroy(session *sessions.Session) error {
var dbSession models.Session
if err := s.dbMap.SelectOne(&dbSession, "SELECT * FROM Session WHERE Token = ?", session.ID); err != nil {
// if no rows are found nothing is done
if err == sql.ErrNoRows {
return nil
}
return fmt.Errorf("Could not select session to destroy: %v", err)
}
if _, err := s.dbMap.Delete(&dbSession); err != nil {
return fmt.Errorf("Could not destroy session: %v", err)
}
return nil
}

// delete expired sessions from the db
func (s *SQLStore) destroyExpiredSessions() error {
var dbSession models.Session
dbSessions, err := s.dbMap.Select(&dbSession, "SELECT * FROM Session WHERE Expires < ?", time.Now().Unix())
if err != nil {
return fmt.Errorf("Could not select expired sessions: %v", err)
}
_, err = s.dbMap.Delete(dbSessions...)
if err != nil {
return fmt.Errorf("Could not destroy expired sessions: %v", err)
}
return nil
}

// DestroySessionsForUserID deletes all sessions from the db for userId
//
// It should be noted that this does not prevent the user's current
// session from being saved again, which can be achieved by setting
// MaxAge to -1
func DestroySessionsForUserID(dbMap *gorp.DbMap, userID int64) error {
var dbSession models.Session
dbSessions, err := dbMap.Select(&dbSession, "SELECT * FROM Session WHERE UserId = ?", userID)
if err != nil {
return fmt.Errorf("Could not select user sessions to destroy: %v", err)
}
_, err = dbMap.Delete(dbSessions...)
if err != nil {
return fmt.Errorf("Could not destroy user sessions: %v", err)
}
return nil
}
Loading

0 comments on commit cce2438

Please sign in to comment.