Skip to content

Commit

Permalink
server: pre-read request body to fix HTTP/2 deadlock
Browse files Browse the repository at this point in the history
Fixes #538 (hopefully)
  • Loading branch information
jkowalski committed Aug 16, 2020
1 parent 923c91b commit 2b029a4
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 45 deletions.
12 changes: 3 additions & 9 deletions internal/server/api_content.go
Expand Up @@ -3,7 +3,6 @@ package server
import (
"context"
"errors"
"io/ioutil"
"net/http"

"github.com/gorilla/mux"
Expand All @@ -13,7 +12,7 @@ import (
"github.com/kopia/kopia/repo/content"
)

func (s *Server) handleContentGet(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleContentGet(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
dr, ok := s.rep.(*repo.DirectRepository)
if !ok {
return nil, notFoundError("content not found")
Expand All @@ -29,7 +28,7 @@ func (s *Server) handleContentGet(ctx context.Context, r *http.Request) (interfa
return data, nil
}

func (s *Server) handleContentInfo(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleContentInfo(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
dr, ok := s.rep.(*repo.DirectRepository)
if !ok {
return nil, notFoundError("content not found")
Expand All @@ -50,7 +49,7 @@ func (s *Server) handleContentInfo(ctx context.Context, r *http.Request) (interf
}
}

func (s *Server) handleContentPut(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleContentPut(ctx context.Context, r *http.Request, data []byte) (interface{}, *apiError) {
dr, ok := s.rep.(*repo.DirectRepository)
if !ok {
return nil, notFoundError("content not found")
Expand All @@ -59,11 +58,6 @@ func (s *Server) handleContentPut(ctx context.Context, r *http.Request) (interfa
cid := content.ID(mux.Vars(r)["contentID"])
prefix := cid.Prefix()

data, err := ioutil.ReadAll(r.Body)
if err != nil {
return nil, requestError(serverapi.ErrorMalformedRequest, "malformed request body")
}

actualCID, err := dr.Content.WriteContent(ctx, data, prefix)
if err != nil {
return nil, internalServerError(err)
Expand Down
10 changes: 5 additions & 5 deletions internal/server/api_manifest.go
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/kopia/kopia/repo/manifest"
)

func (s *Server) handleManifestGet(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleManifestGet(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
// password already validated by a wrapper, no need to check here.
userAtHost, _, _ := r.BasicAuth()

Expand All @@ -40,7 +40,7 @@ func (s *Server) handleManifestGet(ctx context.Context, r *http.Request) (interf
}, nil
}

func (s *Server) handleManifestDelete(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleManifestDelete(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
mid := manifest.ID(mux.Vars(r)["manifestID"])

err := s.rep.DeleteManifest(ctx, mid)
Expand All @@ -55,7 +55,7 @@ func (s *Server) handleManifestDelete(ctx context.Context, r *http.Request) (int
return &serverapi.Empty{}, nil
}

func (s *Server) handleManifestList(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleManifestList(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
// password already validated by a wrapper, no need to check here.
userAtHost, _, _ := r.BasicAuth()

Expand Down Expand Up @@ -95,10 +95,10 @@ func filterManifests(manifests []*manifest.EntryMetadata, userAtHost string) []*
return result
}

func (s *Server) handleManifestCreate(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleManifestCreate(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
var req remoterepoapi.ManifestWithMetadata

if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if err := json.Unmarshal(body, &req); err != nil {
return nil, requestError(serverapi.ErrorMalformedRequest, "malformed request")
}

Expand Down
10 changes: 5 additions & 5 deletions internal/server/api_policies.go
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/kopia/kopia/snapshot/policy"
)

func (s *Server) handlePolicyList(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handlePolicyList(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
policies, err := policy.ListPolicies(ctx, s.rep)
if err != nil {
return nil, internalServerError(err)
Expand Down Expand Up @@ -50,7 +50,7 @@ func getPolicyTargetFromURL(u *url.URL) snapshot.SourceInfo {
}
}

func (s *Server) handlePolicyGet(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handlePolicyGet(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
pol, err := policy.GetDefinedPolicy(ctx, s.rep, getPolicyTargetFromURL(r.URL))
if errors.Is(err, policy.ErrPolicyNotFound) {
return nil, requestError(serverapi.ErrorNotFound, "policy not found")
Expand All @@ -59,7 +59,7 @@ func (s *Server) handlePolicyGet(ctx context.Context, r *http.Request) (interfac
return pol, nil
}

func (s *Server) handlePolicyDelete(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handlePolicyDelete(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
if err := policy.RemovePolicy(ctx, s.rep, getPolicyTargetFromURL(r.URL)); err != nil {
return nil, internalServerError(err)
}
Expand All @@ -71,9 +71,9 @@ func (s *Server) handlePolicyDelete(ctx context.Context, r *http.Request) (inter
return &serverapi.Empty{}, nil
}

func (s *Server) handlePolicyPut(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handlePolicyPut(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
newPolicy := &policy.Policy{}
if err := json.NewDecoder(r.Body).Decode(newPolicy); err != nil {
if err := json.Unmarshal(body, newPolicy); err != nil {
return nil, requestError(serverapi.ErrorMalformedRequest, "malformed request body")
}

Expand Down
22 changes: 11 additions & 11 deletions internal/server/api_repo.go
Expand Up @@ -20,7 +20,7 @@ import (
"github.com/kopia/kopia/snapshot/policy"
)

func (s *Server) handleRepoParameters(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleRepoParameters(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
dr, ok := s.rep.(*repo.DirectRepository)
if !ok {
return &serverapi.StatusResponse{
Expand All @@ -37,7 +37,7 @@ func (s *Server) handleRepoParameters(ctx context.Context, r *http.Request) (int
return rp, nil
}

func (s *Server) handleRepoStatus(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleRepoStatus(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
if s.rep == nil {
return &serverapi.StatusResponse{
Connected: false,
Expand Down Expand Up @@ -79,14 +79,14 @@ func maybeDecodeToken(req *serverapi.ConnectRepositoryRequest) *apiError {
return nil
}

func (s *Server) handleRepoCreate(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleRepoCreate(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
if s.rep != nil {
return nil, requestError(serverapi.ErrorAlreadyConnected, "already connected")
}

var req serverapi.CreateRepositoryRequest

if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if err := json.Unmarshal(body, &req); err != nil {
return nil, requestError(serverapi.ErrorMalformedRequest, "unable to decode request: "+err.Error())
}

Expand Down Expand Up @@ -125,17 +125,17 @@ func (s *Server) handleRepoCreate(ctx context.Context, r *http.Request) (interfa
return nil, internalServerError(errors.Wrap(err, "flush"))
}

return s.handleRepoStatus(ctx, r)
return s.handleRepoStatus(ctx, r, nil)
}

func (s *Server) handleRepoConnect(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleRepoConnect(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
if s.rep != nil {
return nil, requestError(serverapi.ErrorAlreadyConnected, "already connected")
}

var req serverapi.ConnectRepositoryRequest

if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if err := json.Unmarshal(body, &req); err != nil {
return nil, requestError(serverapi.ErrorMalformedRequest, "unable to decode request: "+err.Error())
}

Expand All @@ -147,10 +147,10 @@ func (s *Server) handleRepoConnect(ctx context.Context, r *http.Request) (interf
return nil, err
}

return s.handleRepoStatus(ctx, r)
return s.handleRepoStatus(ctx, r, nil)
}

func (s *Server) handleRepoSupportedAlgorithms(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleRepoSupportedAlgorithms(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
res := &serverapi.SupportedAlgorithmsResponse{
DefaultHashAlgorithm: hashing.DefaultAlgorithm,
HashAlgorithms: hashing.SupportedAlgorithms(),
Expand Down Expand Up @@ -200,7 +200,7 @@ func (s *Server) connectAndOpen(ctx context.Context, conn blob.ConnectionInfo, p
return nil
}

func (s *Server) handleRepoDisconnect(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleRepoDisconnect(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
// release shared lock so that SetRepository can acquire exclusive lock
s.mu.RUnlock()
err := s.SetRepository(ctx, nil)
Expand All @@ -217,7 +217,7 @@ func (s *Server) handleRepoDisconnect(ctx context.Context, r *http.Request) (int
return &serverapi.Empty{}, nil
}

func (s *Server) handleRepoSync(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleRepoSync(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
if err := s.rep.Refresh(ctx); err != nil {
return nil, internalServerError(errors.Wrap(err, "unable to refresh repository"))
}
Expand Down
2 changes: 1 addition & 1 deletion internal/server/api_snapshots.go
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/kopia/kopia/snapshot/policy"
)

func (s *Server) handleSnapshotList(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleSnapshotList(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
manifestIDs, err := snapshot.ListSnapshotManifests(ctx, s.rep, nil)
if err != nil {
return nil, internalServerError(err)
Expand Down
6 changes: 3 additions & 3 deletions internal/server/api_sources.go
Expand Up @@ -15,7 +15,7 @@ import (
"github.com/kopia/kopia/snapshot/policy"
)

func (s *Server) handleSourcesList(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleSourcesList(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
resp := &serverapi.SourcesResponse{
Sources: []*serverapi.SourceStatus{},
LocalHost: s.rep.Hostname(),
Expand All @@ -37,10 +37,10 @@ func (s *Server) handleSourcesList(ctx context.Context, r *http.Request) (interf
return resp, nil
}

func (s *Server) handleSourcesCreate(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleSourcesCreate(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
var req serverapi.CreateSnapshotSourceRequest

if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if err := json.Unmarshal(body, &req); err != nil {
return nil, requestError(serverapi.ErrorMalformedRequest, "malformed request body")
}

Expand Down
34 changes: 23 additions & 11 deletions internal/server/server.go
Expand Up @@ -4,6 +4,7 @@ package server
import (
"context"
"encoding/json"
"io/ioutil"
"net/http"
"net/url"
"sync"
Expand All @@ -25,6 +26,8 @@ var log = logging.GetContextLoggerFunc("kopia/server")

const maintenanceAttemptFrequency = 10 * time.Minute

type apiRequestFunc func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError)

// Server exposes simple HTTP API for programmatically accessing Kopia features.
type Server struct {
OnShutdown func(ctx context.Context) error
Expand Down Expand Up @@ -85,30 +88,39 @@ func (s *Server) APIHandlers() http.Handler {
return m
}

func (s *Server) handleAPI(f func(ctx context.Context, r *http.Request) (interface{}, *apiError)) http.HandlerFunc {
return s.handleAPIPossiblyNotConnected(func(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleAPI(f apiRequestFunc) http.HandlerFunc {
return s.handleAPIPossiblyNotConnected(func(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
if s.rep == nil {
return nil, requestError(serverapi.ErrorNotConnected, "not connected")
}

return f(ctx, r)
return f(ctx, r, body)
})
}

func (s *Server) handleAPIPossiblyNotConnected(f func(ctx context.Context, r *http.Request) (interface{}, *apiError)) http.HandlerFunc {
func (s *Server) handleAPIPossiblyNotConnected(f apiRequestFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// we must pre-read request body before acquiring the lock as it sometimes leads to deadlock
// in HTTP/2 server.
// See https://github.com/golang/go/issues/40816
body, berr := ioutil.ReadAll(r.Body)
if berr != nil {
http.Error(w, "error reading request body", http.StatusInternalServerError)
return
}

s.mu.RLock()
defer s.mu.RUnlock()

ctx := r.Context()

log(ctx).Debugf("request %v", r.URL)
log(ctx).Debugf("request %v (%v bytes)", r.URL, len(body))

w.Header().Set("Content-Type", "application/json")
e := json.NewEncoder(w)
e.SetIndent("", " ")

v, err := f(ctx, r)
v, err := f(ctx, r, body)

if err == nil {
if b, ok := v.([]byte); ok {
Expand All @@ -134,23 +146,23 @@ func (s *Server) handleAPIPossiblyNotConnected(f func(ctx context.Context, r *ht
}
}

func (s *Server) handleRefresh(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleRefresh(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
if err := s.rep.Refresh(ctx); err != nil {
return nil, internalServerError(err)
}

return &serverapi.Empty{}, nil
}

func (s *Server) handleFlush(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleFlush(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
if err := s.rep.Flush(ctx); err != nil {
return nil, internalServerError(err)
}

return &serverapi.Empty{}, nil
}

func (s *Server) handleShutdown(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleShutdown(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
log(ctx).Infof("shutting down due to API request")

if f := s.OnShutdown; f != nil {
Expand Down Expand Up @@ -180,11 +192,11 @@ func (s *Server) forAllSourceManagersMatchingURLFilter(ctx context.Context, c fu
return resp, nil
}

func (s *Server) handleUpload(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleUpload(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
return s.forAllSourceManagersMatchingURLFilter(ctx, (*sourceManager).upload, r.URL.Query())
}

func (s *Server) handleCancel(ctx context.Context, r *http.Request) (interface{}, *apiError) {
func (s *Server) handleCancel(ctx context.Context, r *http.Request, body []byte) (interface{}, *apiError) {
return s.forAllSourceManagersMatchingURLFilter(ctx, (*sourceManager).cancel, r.URL.Query())
}

Expand Down

0 comments on commit 2b029a4

Please sign in to comment.