Skip to content

Commit

Permalink
server: use gorilla for csrf (#338)
Browse files Browse the repository at this point in the history
  • Loading branch information
jholdstock authored and dajohi committed Apr 24, 2019
1 parent 2cce476 commit 196af94
Show file tree
Hide file tree
Showing 17 changed files with 115 additions and 205 deletions.
22 changes: 22 additions & 0 deletions controllers/main.go
Expand Up @@ -31,6 +31,7 @@ import (
wallettypes "github.com/decred/dcrwallet/rpc/jsonrpc/types"
"github.com/decred/dcrwallet/wallet/v2/udb"
"github.com/go-gorp/gorp"
"github.com/gorilla/csrf"
"github.com/zenazn/goji/web"

"google.golang.org/grpc"
Expand Down Expand Up @@ -451,6 +452,7 @@ func (controller *MainController) APIVoting(c web.C, r *http.Request) ([]string,
func (controller *MainController) isAdmin(c web.C, r *http.Request) (bool, error) {
remoteIP := getClientIP(r, controller.realIPHeader)
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)

if session.Values["UserId"] == nil {
return false, fmt.Errorf("%s request with no session from %s",
Expand Down Expand Up @@ -755,6 +757,7 @@ func (controller *MainController) handlePotentialFatalError(fn string, err error
func (controller *MainController) Address(c web.C, r *http.Request) (string, int) {
t := controller.GetTemplate(c)
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)

if session.Values["UserId"] == nil {
return "/", http.StatusSeeOther
Expand All @@ -776,6 +779,7 @@ func (controller *MainController) Address(c web.C, r *http.Request) (string, int
// AddressPost is address form submit route.
func (controller *MainController) AddressPost(c web.C, r *http.Request) (string, int) {
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
remoteIP := getClientIP(r, controller.realIPHeader)

if session.Values["UserId"] == nil {
Expand Down Expand Up @@ -1020,6 +1024,7 @@ func (controller *MainController) AdminStatus(c web.C, r *http.Request) (string,
func (controller *MainController) AdminTickets(c web.C, r *http.Request) (string, int) {
t := controller.GetTemplate(c)
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
dbMap := controller.GetDbMap(c)

isAdmin, err := controller.isAdmin(c, r)
Expand Down Expand Up @@ -1057,6 +1062,7 @@ func (controller *MainController) AdminTickets(c web.C, r *http.Request) (string
// AdminTicketsPost validates and processes the form posted from AdminTickets.
func (controller *MainController) AdminTicketsPost(c web.C, r *http.Request) (string, int) {
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
dbMap := controller.GetDbMap(c)
remoteIP := getClientIP(r, controller.realIPHeader)

Expand Down Expand Up @@ -1181,6 +1187,7 @@ func (controller *MainController) AdminTicketsPost(c web.C, r *http.Request) (st
func (controller *MainController) EmailUpdate(c web.C, r *http.Request) (string, int) {
t := controller.GetTemplate(c)
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
dbMap := controller.GetDbMap(c)

render := func() string {
Expand Down Expand Up @@ -1252,6 +1259,7 @@ func (controller *MainController) EmailUpdate(c web.C, r *http.Request) (string,
func (controller *MainController) EmailVerify(c web.C, r *http.Request) (string, int) {
t := controller.GetTemplate(c)
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
dbMap := controller.GetDbMap(c)

render := func() string {
Expand Down Expand Up @@ -1354,6 +1362,7 @@ func (controller *MainController) Index(c web.C, r *http.Request) (string, int)
func (controller *MainController) PasswordReset(c web.C, r *http.Request) (string, int) {
c.Env["Title"] = "Decred Voting Service - Password Reset"
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
c.Env["FlashError"] = append(session.Flashes("passwordresetError"), session.Flashes("captchaFailed")...)
c.Env["FlashSuccess"] = session.Flashes("passwordresetSuccess")
c.Env["IsPasswordReset"] = true
Expand All @@ -1374,6 +1383,7 @@ func (controller *MainController) PasswordReset(c web.C, r *http.Request) (strin
func (controller *MainController) PasswordResetPost(c web.C, r *http.Request) (string, int) {
email := r.FormValue("email")
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
dbMap := controller.GetDbMap(c)

if !controller.IsCaptchaDone(c) {
Expand Down Expand Up @@ -1430,6 +1440,7 @@ func (controller *MainController) PasswordResetPost(c web.C, r *http.Request) (s
func (controller *MainController) PasswordUpdate(c web.C, r *http.Request) (string, int) {
t := controller.GetTemplate(c)
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)

render := func() string {
c.Env["Title"] = "Decred Voting Service - Password Update"
Expand All @@ -1456,6 +1467,7 @@ func (controller *MainController) PasswordUpdate(c web.C, r *http.Request) (stri
// the password reset email. The token is validated and the password is changed.
func (controller *MainController) PasswordUpdatePost(c web.C, r *http.Request) (string, int) {
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
dbMap := controller.GetDbMap(c)
remoteIP := getClientIP(r, controller.realIPHeader)

Expand Down Expand Up @@ -1512,6 +1524,7 @@ func (controller *MainController) PasswordUpdatePost(c web.C, r *http.Request) (
// Settings renders the settings page.
func (controller *MainController) Settings(c web.C, r *http.Request) (string, int) {
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
dbMap := controller.GetDbMap(c)

if session.Values["UserId"] == nil {
Expand Down Expand Up @@ -1554,6 +1567,7 @@ func (controller *MainController) Settings(c web.C, r *http.Request) (string, in
// SettingsPost handles changing the user's email address or password.
func (controller *MainController) SettingsPost(c web.C, r *http.Request) (string, int) {
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
dbMap := controller.GetDbMap(c)
remoteIP := getClientIP(r, controller.realIPHeader)

Expand Down Expand Up @@ -1661,6 +1675,7 @@ func (controller *MainController) SettingsPost(c web.C, r *http.Request) (string
func (controller *MainController) SignIn(c web.C, r *http.Request) (string, int) {
t := controller.GetTemplate(c)
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)

// Tell main.html what route is being rendered
c.Env["IsSignIn"] = true
Expand All @@ -1680,6 +1695,7 @@ func (controller *MainController) SignInPost(c web.C, r *http.Request) (string,
email, password := r.FormValue("email"), r.FormValue("password")

session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
dbMap := controller.GetDbMap(c)
remoteIP := getClientIP(r, controller.realIPHeader)

Expand Down Expand Up @@ -1721,6 +1737,7 @@ func (controller *MainController) SignUp(c web.C, r *http.Request) (string, int)
}

session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
c.Env["FlashError"] = append(session.Flashes("signupError"), session.Flashes("captchaFailed")...)
c.Env["FlashSuccess"] = session.Flashes("signupSuccess")
c.Env["CaptchaID"] = captcha.New()
Expand All @@ -1742,6 +1759,7 @@ func (controller *MainController) SignUpPost(c web.C, r *http.Request) (string,
}

session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
if !controller.IsCaptchaDone(c) {
session.AddFlash("You must complete the captcha.", "signupError")
return controller.SignUp(c, r)
Expand Down Expand Up @@ -1904,6 +1922,7 @@ func (controller *MainController) Tickets(c web.C, r *http.Request) (string, int

t := controller.GetTemplate(c)
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
remoteIP := getClientIP(r, controller.realIPHeader)

if session.Values["UserId"] == nil {
Expand Down Expand Up @@ -2032,6 +2051,7 @@ func (controller *MainController) Tickets(c web.C, r *http.Request) (string, int
// Voting renders the voting page.
func (controller *MainController) Voting(c web.C, r *http.Request) (string, int) {
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
dbMap := controller.GetDbMap(c)

if session.Values["UserId"] == nil {
Expand Down Expand Up @@ -2065,6 +2085,7 @@ func (controller *MainController) Voting(c web.C, r *http.Request) (string, int)
// VotingPost form submit route.
func (controller *MainController) VotingPost(c web.C, r *http.Request) (string, int) {
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)
dbMap := controller.GetDbMap(c)

if session.Values["UserId"] == nil {
Expand Down Expand Up @@ -2118,6 +2139,7 @@ func (controller *MainController) VotingPost(c web.C, r *http.Request) (string,
// Logout the user.
func (controller *MainController) Logout(c web.C, r *http.Request) (string, int) {
session := controller.GetSession(c)
c.Env[csrf.TemplateTag] = csrf.TemplateField(r)

session.Values["UserId"] = nil

Expand Down
1 change: 1 addition & 0 deletions go.mod
Expand Up @@ -21,6 +21,7 @@ require (
github.com/go-sql-driver/mysql v1.4.0
github.com/golang/protobuf v1.2.0
github.com/gorilla/context v1.1.1
github.com/gorilla/csrf v1.5.1
github.com/gorilla/sessions v1.1.2
github.com/jessevdk/go-flags v1.4.0
github.com/jrick/logrotate v1.0.0
Expand Down
5 changes: 5 additions & 0 deletions go.sum
Expand Up @@ -118,6 +118,8 @@ github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/csrf v1.5.1 h1:UASc2+EB0T51tvl6/2ls2ciA8/qC7KdTO7DsOEKbttQ=
github.com/gorilla/csrf v1.5.1/go.mod h1:HTDW7xFOO1aHddQUmghe9/2zTvg7AYCnRCs7MxTGu/0=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.1.2 h1:4esMHhwKLQ9Odtku/p+onvH+eRJFWjV4y3iTDVWrZNU=
Expand All @@ -143,10 +145,13 @@ github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK86
github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/onsi/ginkgo v1.6.0 h1:Ix8l273rp3QzYgXSR+c8d1fTG7UPgYkOSELPhiY/YGw=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/gomega v1.4.1 h1:PZSj/UFNaVp3KxrzHOcS7oyuWA7LoOY/77yCTEFu21U=
github.com/onsi/gomega v1.4.1/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA=
github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU=
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/poy/onpar v0.0.0-20181125144932-f2f06780798d h1:dG+BIeP2sDXC2AsX+P1Xvftxy+b4iYTfrDDkAIQB7VM=
Expand Down
4 changes: 2 additions & 2 deletions sample-dcrstakepool.conf
Expand Up @@ -11,8 +11,8 @@ adminips=127.0.0.1
; Multiple values can be used and are separated by a comma.
;adminuserids=1,2,3

; Secret string used to encrypt API tokens. Can use openssl rand -hex 32
; to generate one.
; Secret string used to encrypt API and to generate CSRF tokens.
; Can use openssl rand -hex 32 to generate one.
;apisecret=

; baseurl to use when emailing verification links.
Expand Down
127 changes: 68 additions & 59 deletions server.go
Expand Up @@ -14,6 +14,7 @@ import (
"google.golang.org/grpc"

"github.com/gorilla/context"
"github.com/gorilla/csrf"

"github.com/decred/dcrd/rpcclient/v2"
"github.com/decred/dcrstakepool/controllers"
Expand Down Expand Up @@ -82,31 +83,6 @@ func runMain() int {

rpcclient.UseLogger(log)

// Setup static files
assetHandler := http.StripPrefix("/assets/",
http.FileServer(http.Dir(cfg.PublicPath)))

// Apply middleware
app := web.New()
app.Handle("/assets/*", assetHandler)

app.Use(middleware.RequestID)
app.Use(middleware.Logger) // TODO: reimplement to use our logger
app.Use(middleware.Recoverer)

// Execute various middleware functions. The order is very important
// as each function establishes part of the application environment/context
// that the next function will assume has been setup successfully.
app.Use(application.ApplyTemplates)
app.Use(application.ApplySessions)
app.Use(application.ApplyCaptcha) // must be after ApplySessions
app.Use(application.ApplyDbMap)
app.Use(application.ApplyAPI)
app.Use(application.ApplyAuth) // must be after ApplySessions
app.Use(application.ApplyIsXhr)
app.Use(application.ApplyCsrfProtection) // must be after ApplySessions
app.Use(context.ClearHandler)

// Supported API versions are advertised in the API stats result
APIVersionsSupported := []int{1, 2}

Expand Down Expand Up @@ -194,73 +170,106 @@ func runMain() int {

controller.RPCStart()

// Couple of files - in the real world you would use nginx to serve them.
app.Get("/robots.txt", http.FileServer(http.Dir(cfg.PublicPath)))
app.Get("/favicon.ico", http.FileServer(http.Dir(cfg.PublicPath+"/images")))
// Set up web server routes
app := web.New()

// Middlewares used by app are applied to all routes (HTML and API)
app.Use(middleware.RequestID)
app.Use(middleware.Logger) // TODO: reimplement to use our logger
app.Use(middleware.Recoverer)

app.Use(application.ApplyDbMap)

app.Use(context.ClearHandler)

// API routes
api := web.New()

api.Use(application.ApplyAPI)

api.Handle("/api/v1/:command", application.APIHandler(controller.API))
api.Handle("/api/v2/:command", application.APIHandler(controller.API))
api.Handle("/api/*", gojify(system.APIInvalidHandler))

// HTML routes
html := web.New()

// Execute various middleware functions. The order is very important
// as each function establishes part of the application environment/context
// that the next function will assume has been setup successfully.
html.Use(application.ApplyTemplates)
html.Use(application.ApplySessions)
html.Use(application.ApplyCaptcha) // must be after ApplySessions
html.Use(application.ApplyAuth) // must be after ApplySessions
html.Use(csrf.Protect([]byte(cfg.APISecret), csrf.Secure(cfg.CookieSecure)))

// Setup static files
html.Get("/assets/*", http.StripPrefix("/assets/",
http.FileServer(http.Dir(cfg.PublicPath))))
html.Get("/robots.txt", http.FileServer(http.Dir(cfg.PublicPath)))
html.Get("/favicon.ico", http.FileServer(http.Dir(cfg.PublicPath+"/images")))

// Home page
app.Get("/", application.Route(controller, "Index"))
html.Get("/", application.Route(controller, "Index"))

// Admin tickets page
app.Get("/admintickets", application.Route(controller, "AdminTickets"))
app.Post("/admintickets", application.Route(controller, "AdminTicketsPost"))
html.Get("/admintickets", application.Route(controller, "AdminTickets"))
html.Post("/admintickets", application.Route(controller, "AdminTicketsPost"))
// Admin status page
app.Get("/status", application.Route(controller, "AdminStatus"))
html.Get("/status", application.Route(controller, "AdminStatus"))

// Address form
app.Get("/address", application.Route(controller, "Address"))
app.Post("/address", application.Route(controller, "AddressPost"))

// API
app.Handle("/api/v1/:command", application.APIHandler(controller.API))
app.Handle("/api/v2/:command", application.APIHandler(controller.API))
app.Handle("/api/*", gojify(system.APIInvalidHandler))
html.Get("/address", application.Route(controller, "Address"))
html.Post("/address", application.Route(controller, "AddressPost"))

// Email change/update confirmation
app.Get("/emailupdate", application.Route(controller, "EmailUpdate"))
html.Get("/emailupdate", application.Route(controller, "EmailUpdate"))

// Email verification
app.Get("/emailverify", application.Route(controller, "EmailVerify"))
html.Get("/emailverify", application.Route(controller, "EmailVerify"))

// Error page
app.Get("/error", application.Route(controller, "Error"))
html.Get("/error", application.Route(controller, "Error"))

// Password Reset routes
app.Get("/passwordreset", application.Route(controller, "PasswordReset"))
app.Post("/passwordreset", application.Route(controller, "PasswordResetPost"))
html.Get("/passwordreset", application.Route(controller, "PasswordReset"))
html.Post("/passwordreset", application.Route(controller, "PasswordResetPost"))

// Password Update routes
app.Get("/passwordupdate", application.Route(controller, "PasswordUpdate"))
app.Post("/passwordupdate", application.Route(controller, "PasswordUpdatePost"))
html.Get("/passwordupdate", application.Route(controller, "PasswordUpdate"))
html.Post("/passwordupdate", application.Route(controller, "PasswordUpdatePost"))

// Settings routes
app.Get("/settings", application.Route(controller, "Settings"))
app.Post("/settings", application.Route(controller, "SettingsPost"))
html.Get("/settings", application.Route(controller, "Settings"))
html.Post("/settings", application.Route(controller, "SettingsPost"))

// Sign In routes
app.Get("/signin", application.Route(controller, "SignIn"))
app.Post("/signin", application.Route(controller, "SignInPost"))
html.Get("/signin", application.Route(controller, "SignIn"))
html.Post("/signin", application.Route(controller, "SignInPost"))

// Sign Up routes
app.Get("/signup", application.Route(controller, "SignUp"))
app.Post("/signup", application.Route(controller, "SignUpPost"))
html.Get("/signup", application.Route(controller, "SignUp"))
html.Post("/signup", application.Route(controller, "SignUpPost"))

// Captcha
app.Get("/captchas/*", controller.CaptchaServe)
app.Post("/verifyhuman", controller.CaptchaVerify)
html.Get("/captchas/*", controller.CaptchaServe)
html.Post("/verifyhuman", controller.CaptchaVerify)

// Stats
app.Get("/stats", application.Route(controller, "Stats"))
html.Get("/stats", application.Route(controller, "Stats"))

// Tickets
app.Get("/tickets", application.Route(controller, "Tickets"))
html.Get("/tickets", application.Route(controller, "Tickets"))

// Voting routes
app.Get("/voting", application.Route(controller, "Voting"))
app.Post("/voting", application.Route(controller, "VotingPost"))
html.Get("/voting", application.Route(controller, "Voting"))
html.Post("/voting", application.Route(controller, "VotingPost"))

// KTHXBYE
app.Get("/logout", application.Route(controller, "Logout"))
html.Get("/logout", application.Route(controller, "Logout"))

app.Handle("/api/*", api)
app.Handle("/*", html)

graceful.PostHook(func() {
controller.RPCStop()
Expand Down

0 comments on commit 196af94

Please sign in to comment.