Skip to content

Commit

Permalink
Only request a reason from the client if the server requires it
Browse files Browse the repository at this point in the history
  • Loading branch information
nsheridan committed Aug 9, 2018
1 parent 347c11e commit d21fac6
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 31 deletions.
57 changes: 43 additions & 14 deletions client/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"bufio"
"bytes"
"crypto/tls"
"encoding/base64"
Expand All @@ -10,7 +11,9 @@ import (
"io/ioutil"
"net/http"
"net/url"
"os"
"path"
"strings"
"time"

"github.com/nsheridan/cashier/lib"
Expand All @@ -19,6 +22,10 @@ import (
"golang.org/x/crypto/ssh/agent"
)

var (
errNeedsReason = errors.New("reason required")
)

// SavePublicFiles installs the public part of the cert and key.
func SavePublicFiles(prefix string, cert *ssh.Certificate, pub ssh.PublicKey) error {
if prefix == "" {
Expand Down Expand Up @@ -77,7 +84,11 @@ func InstallCert(a agent.Agent, cert *ssh.Certificate, key Key) error {
}

// send the signing request to the CA.
func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) {
func send(sr *lib.SignRequest, token, ca string, ValidateTLSCertificate bool) (*lib.SignResponse, error) {
s, err := json.Marshal(sr)
if err != nil {
return nil, errors.Wrap(err, "unable to create sign request")
}
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: !ValidateTLSCertificate},
}
Expand All @@ -99,33 +110,51 @@ func send(s []byte, token, ca string, ValidateTLSCertificate bool) (*lib.SignRes
return nil, err
}
defer resp.Body.Close()
signResponse := &lib.SignResponse{}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Bad response from server: %s", resp.Status)
if resp.StatusCode == http.StatusForbidden && strings.HasPrefix(resp.Header.Get("X-Need-Reason"), "required") {
return signResponse, errNeedsReason
}
return signResponse, fmt.Errorf("bad response from server: %s", resp.Status)
}
c := &lib.SignResponse{}
if err := json.NewDecoder(resp.Body).Decode(c); err != nil {
if err := json.NewDecoder(resp.Body).Decode(signResponse); err != nil {
return nil, errors.Wrap(err, "unable to decode server response")
}
return c, nil
return signResponse, nil
}

func promptForReason() (message string) {
fmt.Print("Enter message: ")
scanner := bufio.NewScanner(os.Stdin)
if scanner.Scan() {
message = scanner.Text()
}
return message
}

// Sign sends the public key to the CA to be signed.
func Sign(pub ssh.PublicKey, token string, message string, conf *Config) (*ssh.Certificate, error) {
func Sign(pub ssh.PublicKey, token string, conf *Config) (*ssh.Certificate, error) {
var err error
validity, err := time.ParseDuration(conf.Validity)
if err != nil {
return nil, err
}
s, err := json.Marshal(&lib.SignRequest{
s := &lib.SignRequest{
Key: string(lib.GetPublicKey(pub)),
ValidUntil: time.Now().Add(validity),
Message: message,
})
if err != nil {
return nil, errors.Wrap(err, "unable to create sign request")
}
resp, err := send(s, token, conf.CA, conf.ValidateTLSCertificate)
if err != nil {
return nil, errors.Wrap(err, "error sending request to CA")
resp := &lib.SignResponse{}
for {
resp, err = send(s, token, conf.CA, conf.ValidateTLSCertificate)
if err == nil {
break
}
if err != nil && err == errNeedsReason {
s.Message = promptForReason()
continue
} else if err != nil {
return nil, errors.Wrap(err, "error sending request to CA")
}
}
if resp.Status != "ok" {
return nil, fmt.Errorf("bad response from CA: %s", resp.Response)
Expand Down
8 changes: 4 additions & 4 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestSignGood(t *testing.T) {
fmt.Fprintln(w, string(j))
}))
defer ts.Close()
_, err := send([]byte(`{}`), "token", ts.URL, true)
_, err := send(&lib.SignRequest{}, "token", ts.URL, true)
if err != nil {
t.Error(err)
}
Expand All @@ -79,7 +79,7 @@ func TestSignGood(t *testing.T) {
CA: ts.URL,
Validity: "24h",
}
cert, err := Sign(k, "token", "message", c)
cert, err := Sign(k, "token", c)
if cert == nil && err != nil {
t.Error(err)
}
Expand All @@ -95,7 +95,7 @@ func TestSignBad(t *testing.T) {
fmt.Fprintln(w, string(j))
}))
defer ts.Close()
_, err := send([]byte(`{}`), "token", ts.URL, true)
_, err := send(&lib.SignRequest{}, "token", ts.URL, true)
if err != nil {
t.Error(err)
}
Expand All @@ -107,7 +107,7 @@ func TestSignBad(t *testing.T) {
CA: ts.URL,
Validity: "24h",
}
cert, err := Sign(k, "token", "message", c)
cert, err := Sign(k, "token", c)
if cert != nil && err == nil {
t.Error(err)
}
Expand Down
10 changes: 1 addition & 9 deletions cmd/cashier/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"bufio"
"fmt"
"log"
"net"
Expand Down Expand Up @@ -50,14 +49,7 @@ func main() {
var token string
fmt.Scanln(&token)

var message string
fmt.Print("Enter message: ")
scanner := bufio.NewScanner(os.Stdin)
if scanner.Scan() {
message = scanner.Text()
}

cert, err := client.Sign(pub, token, message, c)
cert, err := client.Sign(pub, token, c)
if err != nil {
log.Fatalln(err)
}
Expand Down
1 change: 1 addition & 0 deletions server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Server struct {
CSRFSecret string `hcl:"csrf_secret"`
HTTPLogFile string `hcl:"http_logfile"`
Database Database `hcl:"database"`
RequireReason bool `hcl:"require_reason"`
}

// Auth holds the configuration specific to the OAuth provider.
Expand Down
16 changes: 12 additions & 4 deletions server/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ import (

// appContext contains local context - cookiestore, authsession etc.
type appContext struct {
cookiestore *sessions.CookieStore
authsession *auth.Session
cookiestore *sessions.CookieStore
authsession *auth.Session
requireReason bool
}

// getAuthTokenCookie retrieves a cookie from the request.
Expand Down Expand Up @@ -141,6 +142,12 @@ func signHandler(a *appContext, w http.ResponseWriter, r *http.Request) (int, er
if err != nil {
return http.StatusBadRequest, errors.Wrap(err, "unable to extract key from request")
}

if a.requireReason && req.Message == "" {
w.Header().Add("X-Need-Reason", "required")
return http.StatusForbidden, errors.New(http.StatusText(http.StatusForbidden))
}

username := authprovider.Username(token)
authprovider.Revoke(token) // We don't need this anymore.
cert, err := keysigner.SignUserKey(req, username)
Expand Down Expand Up @@ -266,7 +273,6 @@ type appHandler struct {
func (ah appHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
status, err := ah.h(ah.appContext, w, r)
if err != nil {
log.Printf("HTTP %d: %q", status, err)
http.Error(w, err.Error(), status)
}
}
Expand All @@ -283,7 +289,8 @@ func newState() string {
func runHTTPServer(conf *config.Server, l net.Listener) {
var err error
ctx := &appContext{
cookiestore: sessions.NewCookieStore([]byte(conf.CookieSecret)),
cookiestore: sessions.NewCookieStore([]byte(conf.CookieSecret)),
requireReason: conf.RequireReason,
}
ctx.cookiestore.Options = &sessions.Options{
MaxAge: 900,
Expand Down Expand Up @@ -313,6 +320,7 @@ func runHTTPServer(conf *config.Server, l net.Listener) {
r.Methods("GET").Path("/admin/certs.json").Handler(appHandler{ctx, listCertsJSONHandler})
r.Methods("GET").Path("/metrics").Handler(promhttp.Handler())
r.Methods("GET").Path("/healthcheck").HandlerFunc(healthcheck)

box := packr.NewBox("static")
r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(box)))
h := handlers.LoggingHandler(logfile, r)
Expand Down

0 comments on commit d21fac6

Please sign in to comment.