Skip to content

Commit

Permalink
Create a CSRF object to manage nonces & cookies
Browse files Browse the repository at this point in the history
  • Loading branch information
Nick Meves committed Dec 29, 2020
1 parent ecb5c30 commit 8d0ac72
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 133 deletions.
126 changes: 32 additions & 94 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"context"
"crypto/hmac"
"encoding/json"
"errors"
"fmt"
Expand All @@ -21,7 +20,6 @@ import (
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/ip"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/middleware"
Expand Down Expand Up @@ -58,17 +56,8 @@ type allowedRoute struct {

// OAuthProxy is the main authentication proxy
type OAuthProxy struct {
CookieSeed string
CookieName string
CSRFCookieName string
CookieDomains []string
CookiePath string
CookieSecure bool
CookieHTTPOnly bool
CookieExpire time.Duration
CookieRefresh time.Duration
CookieSameSite string
Validator func(string) bool
CookieOptions *options.Cookie
Validator func(string) bool

RobotsPath string
SignInPath string
Expand Down Expand Up @@ -180,17 +169,8 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr
}

return &OAuthProxy{
CookieName: opts.Cookie.Name,
CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"),
CookieSeed: opts.Cookie.Secret,
CookieDomains: opts.Cookie.Domains,
CookiePath: opts.Cookie.Path,
CookieSecure: opts.Cookie.Secure,
CookieHTTPOnly: opts.Cookie.HTTPOnly,
CookieExpire: opts.Cookie.Expire,
CookieRefresh: opts.Cookie.Refresh,
CookieSameSite: opts.Cookie.SameSite,
Validator: validator,
CookieOptions: &opts.Cookie,
Validator: validator,

RobotsPath: "/robots.txt",
SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix),
Expand Down Expand Up @@ -376,7 +356,7 @@ func (p *OAuthProxy) GetRedirectURI(host string) string {
}
u := *p.redirectURL
if u.Scheme == "" {
if p.CookieSecure {
if p.CookieOptions.Secure {
u.Scheme = httpsScheme
} else {
u.Scheme = httpScheme
Expand Down Expand Up @@ -410,47 +390,6 @@ func (p *OAuthProxy) enrichSessionState(ctx context.Context, s *sessionsapi.Sess
return p.provider.EnrichSession(ctx, s)
}

// MakeCSRFCookie creates a cookie for CSRF
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
}

func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
cookieDomain := cookies.GetCookieDomain(req, p.CookieDomains)

if cookieDomain != "" {
domain := util.GetRequestHost(req)
if h, _, err := net.SplitHostPort(domain); err == nil {
domain = h
}
if !strings.HasSuffix(domain, cookieDomain) {
logger.Errorf("Warning: request host is %q but using configured cookie domain of %q", domain, cookieDomain)
}
}

return &http.Cookie{
Name: name,
Value: value,
Path: p.CookiePath,
Domain: cookieDomain,
HttpOnly: p.CookieHTTPOnly,
Secure: p.CookieSecure,
Expires: now.Add(expiration),
SameSite: cookies.ParseSameSite(p.CookieSameSite),
}
}

// ClearCSRFCookie creates a cookie to unset the CSRF cookie stored in the user's
// session
func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now()))
}

// SetCSRFCookie adds a CSRF cookie to the response
func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now()))
}

// ClearSessionCookie creates a cookie to unset the user's authentication cookie
// stored in the user's session
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error {
Expand Down Expand Up @@ -832,19 +771,13 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
prepareNoCache(rw)

// nonces holds the OAuth state CSRF nonce & the OIDC nonce
nonces := []string{"", ""}
for i := range nonces {
var err error
nonces[i], err = encryption.Nonce()
if err != nil {
logger.Errorf("Error obtaining nonce: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
return
}
csrf, err := cookies.NewCSRF(p.CookieOptions)
if err != nil {
logger.Errorf("Error creating CSRF nonce: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")
return
}

p.SetCSRFCookie(rw, req, strings.Join(nonces, ":"))
redirect, err := p.GetRedirect(req)
if err != nil {
logger.Errorf("Error obtaining redirect: %v", err)
Expand All @@ -853,7 +786,15 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
}

redirectURI := p.GetRedirectURI(util.GetRequestHost(req))
loginURL := p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonces[0], redirect), nonces[1])
loginURL := p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", csrf.State, redirect), csrf.Nonce)

err = csrf.SetCookie(rw, req)
if err != nil {
logger.Errorf("Error setting CSRF cookie: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")
return
}

http.Redirect(rw, req, loginURL, http.StatusFound)
}

Expand Down Expand Up @@ -890,35 +831,32 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
return
}

state := strings.SplitN(req.Form.Get("state"), ":", 2)
if len(state) != 2 {
logger.Error("Error while parsing OAuth2 state: invalid length")
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Invalid State")
return
}
stateNonce := state[0]
redirect := state[1]
csrf, err := req.Cookie(p.CSRFCookieName)
csrf, err := cookies.LoadCSRFCookie(req, p.CookieOptions)
if err != nil {
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: unable to obtain CSRF cookie")
p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", err.Error())
return
}
p.ClearCSRFCookie(rw, req)
csrfNonces := strings.Split(csrf.Value, ":")
if len(csrfNonces) != 2 {
logger.Error("Error while parsing CSRF cookie: invalid length")
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")

csrf.ClearCookie(rw, req)

state := strings.SplitN(req.Form.Get("state"), ":", 2)
if len(state) != 2 {
logger.Error("Error while parsing OAuth2 state: invalid length")
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Invalid State")
return
}

if hmac.Equal([]byte(csrfNonces[0]), []byte(stateNonce)) {
nonce := state[0]
redirect := state[1]

if !csrf.CheckState(nonce) {
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2: CSRF token mismatch, potential attack")
p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "CSRF Failed")
return
}

session.Nonce = csrfNonces[1]
session.Nonce = csrf.Nonce
p.provider.ValidateSession(req.Context(), session)

if !p.IsValidRedirect(redirect) {
Expand Down
31 changes: 23 additions & 8 deletions oauthproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/mbland/hmacauth"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
sessionscookie "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/cookie"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/upstream"
Expand Down Expand Up @@ -696,23 +697,37 @@ func (patTest *PassAccessTokenTest) Close() {
patTest.providerServer.Close()
}

func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int,
cookie string) {
func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie string) {
rw := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
strings.NewReader(""))
if err != nil {
return 0, ""
}
req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))

csrf, err := cookies.NewCSRF(patTest.proxy.CookieOptions)
if err != nil {
panic(err)
}
csrf.State = "nonce"
val, err := csrf.EncodeCookie()
if err != nil {
panic(err)
}
req.AddCookie(&http.Cookie{
Name: csrf.CookieName(),
Value: val,
})

patTest.proxy.ServeHTTP(rw, req)

return rw.Code, rw.Header().Values("Set-Cookie")[1]
}

// getEndpointWithCookie makes a requests againt the oauthproxy with passed requestPath
// and cookie and returns body and status code.
func (patTest *PassAccessTokenTest) getEndpointWithCookie(cookie string, endpoint string) (httpCode int, accessToken string) {
cookieName := patTest.proxy.CookieName
cookieName := patTest.proxy.CookieOptions.Name
var value string
keyPrefix := cookieName + "="

Expand Down Expand Up @@ -983,7 +998,7 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi

// Now, zero-out proxy.CookieRefresh for the cases that don't involve
// access_token validation.
pcTest.proxy.CookieRefresh = time.Duration(0)
pcTest.proxy.CookieOptions.Refresh = time.Duration(0)
pcTest.rw = httptest.NewRecorder()
pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
pcTest.validateUser = true
Expand Down Expand Up @@ -1105,7 +1120,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
err = pcTest.SaveSession(startSession)
assert.NoError(t, err)

pcTest.proxy.CookieRefresh = time.Hour
pcTest.proxy.CookieOptions.Refresh = time.Hour
session, err := pcTest.LoadCookiedSession()
assert.NotEqual(t, nil, err)
if session != nil {
Expand Down Expand Up @@ -1877,7 +1892,7 @@ func TestClearSplitCookie(t *testing.T) {
t.Fatal(err)
}

p := OAuthProxy{CookieName: opts.Cookie.Name, CookieDomains: opts.Cookie.Domains, sessionStore: store}
p := OAuthProxy{CookieOptions: &opts.Cookie, sessionStore: store}
var rw = httptest.NewRecorder()
req := httptest.NewRequest("get", "/", nil)

Expand Down Expand Up @@ -1910,7 +1925,7 @@ func TestClearSingleCookie(t *testing.T) {
t.Fatal(err)
}

p := OAuthProxy{CookieName: opts.Cookie.Name, CookieDomains: opts.Cookie.Domains, sessionStore: store}
p := OAuthProxy{CookieOptions: &opts.Cookie, sessionStore: store}
var rw = httptest.NewRecorder()
req := httptest.NewRequest("get", "/", nil)

Expand Down
65 changes: 34 additions & 31 deletions pkg/cookies/cookies.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,46 +12,33 @@ import (
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util"
)

// MakeCookie constructs a cookie from the given parameters,
// discovering the domain from the request if not specified.
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time, sameSite http.SameSite) *http.Cookie {
if domain != "" {
host := util.GetRequestHost(req)
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
if !strings.HasSuffix(host, domain) {
logger.Errorf("Warning: request host is %q but using configured cookie domain of %q", host, domain)
}
// MakeCookieFromOptions constructs a cookie based on the given *options.CookieOptions,
// value and creation time
func MakeCookieFromOptions(req *http.Request, name string, value string, opts *options.Cookie, expiration time.Duration, now time.Time) *http.Cookie {
domain := GetCookieDomain(req, opts.Domains)
// If nothing matches, create the cookie with the shortest domain
if domain == "" && len(opts.Domains) > 0 {
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q",
util.GetRequestHost(req),
strings.Join(opts.Domains, ","),
)
domain = opts.Domains[len(opts.Domains)-1]
}

return &http.Cookie{
c := &http.Cookie{
Name: name,
Value: value,
Path: path,
Path: opts.Path,
Domain: domain,
HttpOnly: httpOnly,
Secure: secure,
Expires: now.Add(expiration),
SameSite: sameSite,
HttpOnly: opts.HTTPOnly,
Secure: opts.Secure,
SameSite: ParseSameSite(opts.SameSite),
}
}

// MakeCookieFromOptions constructs a cookie based on the given *options.CookieOptions,
// value and creation time
func MakeCookieFromOptions(req *http.Request, name string, value string, cookieOpts *options.Cookie, expiration time.Duration, now time.Time) *http.Cookie {
domain := GetCookieDomain(req, cookieOpts.Domains)
WarnInvalidDomain(c, req)

if domain != "" {
return MakeCookie(req, name, value, cookieOpts.Path, domain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite))
}
// If nothing matches, create the cookie with the shortest domain
defaultDomain := ""
if len(cookieOpts.Domains) > 0 {
logger.Errorf("Warning: request host %q did not match any of the specific cookie domains of %q", util.GetRequestHost(req), strings.Join(cookieOpts.Domains, ","))
defaultDomain = cookieOpts.Domains[len(cookieOpts.Domains)-1]
}
return MakeCookie(req, name, value, cookieOpts.Path, defaultDomain, cookieOpts.HTTPOnly, cookieOpts.Secure, expiration, now, ParseSameSite(cookieOpts.SameSite))
return c
}

// GetCookieDomain returns the correct cookie domain given a list of domains
Expand Down Expand Up @@ -81,3 +68,19 @@ func ParseSameSite(v string) http.SameSite {
panic(fmt.Sprintf("Invalid value for SameSite: %s", v))
}
}

// WarnInvalidDomain logs a warning if the request host and cookie domain are
// mismatched.
func WarnInvalidDomain(c *http.Cookie, req *http.Request) {
if c.Domain == "" {
return
}

host := util.GetRequestHost(req)
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
if !strings.HasSuffix(host, c.Domain) {
logger.Errorf("Warning: request host is %q but using configured cookie domain of %q", host, c.Domain)
}
}

0 comments on commit 8d0ac72

Please sign in to comment.