Skip to content

Commit

Permalink
Merge pull request #301 from mozilla-services/stability-fixes
Browse files Browse the repository at this point in the history
database: error instead of dying on monitor failure
  • Loading branch information
g-k committed Jul 23, 2019
2 parents 1836aa8 + 76f64b9 commit ca1009d
Show file tree
Hide file tree
Showing 11 changed files with 334 additions and 67 deletions.
2 changes: 1 addition & 1 deletion autograph.yaml
Expand Up @@ -29,6 +29,7 @@ statsd:
# sslrootcert: /etc/ssl/certs/db-root.crt
# maxopenconns: 100
# maxidleconns: 10
# monitorpollinterval: 10s

# The keys below are testing keys that do not grant any power
signers:
Expand Down Expand Up @@ -513,7 +514,6 @@ signers:
- id: testmar
type: mar
# label of the private key in the hsm, it isn't stored locally
privatekey: |
-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCAUrUDTS86CuqV
Expand Down
59 changes: 38 additions & 21 deletions database/connect.go
Expand Up @@ -34,14 +34,15 @@ type Transaction struct {

// Config holds the parameters to connect to a database
type Config struct {
Name string
User string
Password string
Host string
SSLMode string
SSLRootCert string
MaxOpenConns int
MaxIdleConns int
Name string
User string
Password string
Host string
SSLMode string
SSLRootCert string
MaxOpenConns int
MaxIdleConns int
MonitorPollInterval time.Duration
}

// Connect creates a database connection and returns a handler
Expand Down Expand Up @@ -70,20 +71,36 @@ func Connect(config Config) (*Handler, error) {
return &Handler{dbfd}, nil
}

// Monitor runs an infinite loop that queries the database every 10 seconds
// and panics if the query fails. It can be used in a goroutine to crash when
// the database becomes unavailable.
func (db *Handler) Monitor() {
// simple DB watchdog, crashes the process if connection dies
// CheckConnection runs a test query against the database and returns
// an error if it fails
func (db *Handler) CheckConnection() error {
var one uint
err := db.QueryRow("SELECT 1").Scan(&one)
if err != nil {
return errors.Wrap(err, "Database connection failed")
}
if one != 1 {
return errors.Errorf("Apparently the database doesn't know the meaning of one anymore")
}
return nil
}

// Monitor queries the database every pollInterval until it gets a
// quit signal logging an error when the test query fails. It can be
// used in a goroutine to check when the database becomes unavailable.
func (db *Handler) Monitor(pollInterval time.Duration, quit chan bool) {
log.Infof("starting DB monitor polling every %s", pollInterval)
for {
var one uint
err := db.QueryRow("SELECT 1").Scan(&one)
if err != nil {
log.Fatal("Database connection failed:", err)
}
if one != 1 {
log.Fatal("Apparently the database doesn't know the meaning of one anymore. Crashing.")
select {
case <-time.After(pollInterval):
err := db.CheckConnection()
if err != nil {
log.Error(err)
break
}
case <-quit:
log.Info("Shutting down DB monitor")
return
}
time.Sleep(10 * time.Second)
}
}
44 changes: 44 additions & 0 deletions database/connect_test.go
@@ -0,0 +1,44 @@
package database

import (
"testing"
"time"
)

func TestMonitor(t *testing.T) {
t.Parallel()

t.Run("runs and dies on connection close", func(t *testing.T) {
t.Parallel()

// connects
db, err := Connect(Config{
Name: "autograph",
User: "myautographdbuser",
Password: "myautographdbpassword",
Host: "127.0.0.1:5432",
})
if err != nil {
t.Fatal(err)
}

quit := make(chan bool)
go db.Monitor(5*time.Millisecond, quit)

// should not error for initial monitor run
err = db.CheckConnection()
if err != nil {
t.Fatalf("db.CheckConnection failed when it should not have with error: %s", err)
}
time.Sleep(10 * time.Millisecond)

// error for failing checks
db.Close()
err = db.CheckConnection()
if err == nil {
t.Fatalf("db.CheckConnection did not fail for a closed DB")
}

quit <- true
})
}
9 changes: 5 additions & 4 deletions database/queries_test.go
Expand Up @@ -10,10 +10,11 @@ import (

func TestConcurrentEndEntityOperations(t *testing.T) {
db, err := Connect(Config{
Name: "autograph",
User: "myautographdbuser",
Password: "myautographdbpassword",
Host: "127.0.0.1:5432",
Name: "autograph",
User: "myautographdbuser",
Password: "myautographdbpassword",
Host: "127.0.0.1:5432",
MonitorPollInterval: 10 * time.Second,
})
if err != nil {
t.Fatal(err)
Expand Down
1 change: 1 addition & 0 deletions docs/configuration.rst
Expand Up @@ -61,6 +61,7 @@ Make sure to set a user with limited grants in the configuration.
sslrootcert: /etc/ssl/certs/db-root.crt
maxopenconns: 100
maxidleconns: 10
monitorpollinterval: 10s
Hardware Security Module (HSM)
------------------------------
Expand Down
58 changes: 56 additions & 2 deletions handlers.go
Expand Up @@ -229,15 +229,69 @@ func (a *autographer) handleSignature(w http.ResponseWriter, r *http.Request) {
log.WithFields(log.Fields{"rid": rid}).Info("signing request completed successfully")
}

// handleHeartbeat returns a simple message indicating that the API is alive and well
func handleHeartbeat(w http.ResponseWriter, r *http.Request) {
// handleLBHeartbeat returns a simple message indicating that the API is alive and well
func handleLBHeartbeat(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
httpError(w, r, http.StatusMethodNotAllowed, "%s method not allowed; endpoint accepts GET only", r.Method)
return
}
w.Write([]byte("ohai"))
}

// handleHeartbeat checks whether backing services are enabled and
// accessible and returns 200 when they are and 502 when the
// aren't. Currently it only checks whether the HSM is accessible.
func (a *autographer) handleHeartbeat(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
httpError(w, r, http.StatusMethodNotAllowed, "%s method not allowed; endpoint accepts GET only", r.Method)
return
}
var (
// a map of backing service name to up or down/inaccessible status
result = map[string]bool{}
status = http.StatusOK
)

// try to fetch the private key from the HSM for the first
// signer conf with a non-PEM private key that we saved on
// server start
conf := a.hsmHeartbeatSignerConf
if conf != nil {
err := conf.CheckHSMConnection()
if err == nil {
result["hsmAccessible"] = true
status = http.StatusOK
} else {
log.Errorf("error checking HSM connection for signer %s: %s", conf.ID, err)
result["hsmAccessible"] = false
status = http.StatusInternalServerError
}
}

// check the database connection and return its status, but
// don't fail the heartbeat since we only care about DB
// connectivity on server start
if a.db != nil {
err := a.db.CheckConnection()
if err == nil {
result["dbAccessible"] = true
} else {
log.Errorf("error checking DB connection: %s", err)
result["dbAccessible"] = false
}
}

respdata, err := json.Marshal(result)
if err != nil {
log.Errorf("heartbeat failed to marshal JSON with error: %s", err)
httpError(w, r, http.StatusInternalServerError, "error marshaling response JSON")
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
w.Write(respdata)
}

func handleVersion(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
httpError(w, r, http.StatusMethodNotAllowed, "%s method not allowed; endpoint accepts GET only", r.Method)
Expand Down
94 changes: 91 additions & 3 deletions handlers_test.go
Expand Up @@ -24,6 +24,7 @@ import (
"strings"
"testing"

"go.mozilla.org/autograph/database"
"go.mozilla.org/autograph/signer/apk"
"go.mozilla.org/autograph/signer/contentsignature"
"go.mozilla.org/autograph/signer/mar"
Expand Down Expand Up @@ -360,7 +361,7 @@ func TestAuthFail(t *testing.T) {
}
}

func TestHeartbeat(t *testing.T) {
func TestLBHeartbeat(t *testing.T) {
t.Parallel()

var TESTCASES = []struct {
Expand All @@ -373,19 +374,106 @@ func TestHeartbeat(t *testing.T) {
{http.StatusMethodNotAllowed, `HEAD`},
}
for i, testcase := range TESTCASES {
req, err := http.NewRequest(testcase.method, "http://foo.bar/__heartbeat__", nil)
req, err := http.NewRequest(testcase.method, "http://foo.bar/__lbheartbeat__", nil)
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
handleHeartbeat(w, req)
handleLBHeartbeat(w, req)
if w.Code != testcase.expect {
t.Fatalf("test case %d failed with code %d but %d was expected",
i, w.Code, testcase.expect)
}
}
}

func TestHeartbeat(t *testing.T) {
t.Parallel()

var TESTCASES = []struct {
expectedHTTPStatus int
method string
}{
{http.StatusOK, `GET`},
{http.StatusMethodNotAllowed, `POST`},
{http.StatusMethodNotAllowed, `PUT`},
{http.StatusMethodNotAllowed, `HEAD`},
}
for i, testcase := range TESTCASES {
req, err := http.NewRequest(testcase.method, "http://foo.bar/__heartbeat__", nil)
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
ag.handleHeartbeat(w, req)
if w.Code != testcase.expectedHTTPStatus {
t.Fatalf("test case %d failed with code %d but %d was expected",
i, w.Code, testcase.expectedHTTPStatus)
}
if bytes.Equal(w.Body.Bytes(), []byte("{}\n")) {
t.Fatalf("test case %d returned unexpected heartbeat body %s expected {}", i, w.Body.Bytes())
}
}
}

func TestHeartbeatChecksHSMStatusFails(t *testing.T) {
// NB: do not run in parallel with TestHeartbeat*
ag.hsmHeartbeatSignerConf = &ag.signers[0].(*contentsignature.ContentSigner).Configuration

expectedStatus := http.StatusInternalServerError
expectedBody := []byte("{\"hsmAccessible\":false}")

req, err := http.NewRequest(`GET`, "http://foo.bar/__heartbeat__", nil)
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
ag.handleHeartbeat(w, req)

if w.Code != expectedStatus {
t.Fatalf("failed with code %d but %d was expected", w.Code, expectedStatus)
}
if !bytes.Equal(w.Body.Bytes(), expectedBody) {
t.Fatalf("got unexpected heartbeat body %s expected %s", w.Body.Bytes(), expectedBody)
}

ag.hsmHeartbeatSignerConf = nil
}

func TestHeartbeatChecksDBStatusOK(t *testing.T) {
// NB: do not run in parallel with TestHeartbeat* or DB tests
db, err := database.Connect(database.Config{
Name: "autograph",
User: "myautographdbuser",
Password: "myautographdbpassword",
Host: "127.0.0.1:5432",
})
if err != nil {
t.Fatal(err)
}
ag.db = db

expectedStatus := http.StatusOK
expectedBody := []byte("{\"dbAccessible\":true}")

req, err := http.NewRequest(`GET`, "http://foo.bar/__heartbeat__", nil)
if err != nil {
t.Fatal(err)
}
w := httptest.NewRecorder()
ag.handleHeartbeat(w, req)

if w.Code != expectedStatus {
t.Fatalf("failed with code %d but %d was expected", w.Code, expectedStatus)
}
if !bytes.Equal(w.Body.Bytes(), expectedBody) {
t.Fatalf("got unexpected heartbeat body %s expected %s", w.Body.Bytes(), expectedBody)
}

db.Close()
ag.db = nil
}

func TestVersion(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit ca1009d

Please sign in to comment.