Skip to content

Commit

Permalink
unify datagateway method handling
Browse files Browse the repository at this point in the history
Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>
  • Loading branch information
butonic committed Feb 19, 2024
1 parent e7d6bc3 commit 43c323c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 240 deletions.
5 changes: 5 additions & 0 deletions changelog/unreleased/unify-datagateway-method-handling.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Bugfix: unify datagateway method handling

The datagateway now unpacks and forwards all HTTP methods

https://github.com/cs3org/reva/pull/4527
259 changes: 19 additions & 240 deletions internal/http/services/datagateway/datagateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ func init() {
const (
// TokenTransportHeader holds the header key for the reva transfer token
TokenTransportHeader = "X-Reva-Transfer"
// UploadExpiresHeader holds the timestamp for the transport token expiry, defined in https://tus.io/protocols/resumable-upload.html#expiration
UploadExpiresHeader = "Upload-Expires"
)

func init() {
Expand Down Expand Up @@ -133,31 +131,13 @@ func (s *svc) setHandler() {
semconv.HTTPURLKey.String(r.URL.String()),
)
r = r.WithContext(ctx)
switch r.Method {
case "HEAD":
s.doHead(w, r)
return
case "GET":
s.doGet(w, r)
return
case "PUT":
s.doPut(w, r)
return
case "PATCH":
s.doPatch(w, r)
return
case "OPTIONS":
s.doOptions(w, r)
return
default:
w.WriteHeader(http.StatusNotImplemented)
return
}
s.doRequest(w, r)
})
}

// verify extracts the transfer token from the request
// If it is not set as header we assume that it's the last path segment instead.
func (s *svc) verify(ctx context.Context, r *http.Request) (*transferClaims, error) {
// Extract transfer token from request header. If not existing, assume that it's the last path segment instead.
token := r.Header.Get(TokenTransportHeader)
if token == "" {
token = path.Base(r.URL.Path)
Expand All @@ -180,112 +160,7 @@ func (s *svc) verify(ctx context.Context, r *http.Request) (*transferClaims, err
return nil, err
}

func (s *svc) doHead(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
log := appctx.GetLogger(ctx)

claims, err := s.verify(ctx, r)
if err != nil {
err = errors.Wrap(err, "datagateway: error validating transfer token")
log.Error().Err(err).Str("token", r.Header.Get(TokenTransportHeader)).Msg("invalid transfer token")
w.WriteHeader(http.StatusForbidden)
return
}

log.Debug().Str("target", claims.Target).Msg("sending request to internal data server")

httpClient := s.client
httpReq, err := rhttp.NewRequest(ctx, "HEAD", claims.Target, nil)
if err != nil {
log.Error().Err(err).Msg("wrong request")
w.WriteHeader(http.StatusInternalServerError)
return
}
httpReq.Header = r.Header

httpRes, err := httpClient.Do(httpReq)
if err != nil {
log.Error().Err(err).Msg("error doing HEAD request to data service")
w.WriteHeader(http.StatusInternalServerError)
return
}
defer httpRes.Body.Close()

copyHeader(w.Header(), httpRes.Header)

// add upload expiry / transfer token expiry header for tus https://tus.io/protocols/resumable-upload.html#expiration
w.Header().Set(UploadExpiresHeader, time.Unix(claims.ExpiresAt, 0).Format(time.RFC1123))

if httpRes.StatusCode != http.StatusOK {
// swallow the body and set content-length to 0 to prevent reverse proxies from trying to read from it
w.Header().Set("Content-Length", "0")
w.WriteHeader(httpRes.StatusCode)
return
}

w.WriteHeader(http.StatusOK)
}

func (s *svc) doGet(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
log := appctx.GetLogger(ctx)

claims, err := s.verify(ctx, r)
if err != nil {
err = errors.Wrap(err, "datagateway: error validating transfer token")
log.Error().Err(err).Str("token", r.Header.Get(TokenTransportHeader)).Msg("invalid transfer token")
w.WriteHeader(http.StatusForbidden)
return
}

log.Debug().Str("target", claims.Target).Msg("sending request to internal data server")

httpClient := s.client
httpReq, err := rhttp.NewRequest(ctx, "GET", claims.Target, nil)
if err != nil {
log.Error().Err(err).Msg("wrong request")
w.WriteHeader(http.StatusInternalServerError)
return
}
httpReq.Header = r.Header

httpRes, err := httpClient.Do(httpReq)
if err != nil {
log.Error().Err(err).Msg("error doing GET request to data service")
w.WriteHeader(http.StatusInternalServerError)
return
}
defer httpRes.Body.Close()

copyHeader(w.Header(), httpRes.Header)
switch httpRes.StatusCode {
case http.StatusOK:
case http.StatusPartialContent:
default:
// swallow the body and set content-length to 0 to prevent reverse proxies from trying to read from it
w.Header().Set("Content-Length", "0")
w.WriteHeader(httpRes.StatusCode)
return
}
w.WriteHeader(httpRes.StatusCode)

var c int64
c, err = io.Copy(w, httpRes.Body)
if err != nil {
log.Error().Err(err).Msg("error writing body after headers were sent")
}
if httpRes.Header.Get("Content-Length") != "" {
i, err := strconv.ParseInt(httpRes.Header.Get("Content-Length"), 10, 64)
if err != nil {
log.Error().Err(err).Str("content-length", httpRes.Header.Get("Content-Length")).Msg("invalid content length in dataprovider response")
}
if i != c {
log.Error().Int64("content-length", i).Int64("transferred-bytes", c).Msg("content length vs transferred bytes mismatch")
}
}
}

func (s *svc) doPut(w http.ResponseWriter, r *http.Request) {
func (s *svc) doRequest(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
log := appctx.GetLogger(ctx)

Expand All @@ -309,10 +184,9 @@ func (s *svc) doPut(w http.ResponseWriter, r *http.Request) {
targetURL.RawQuery = r.URL.RawQuery
target = targetURL.String()

log.Debug().Str("target", claims.Target).Msg("sending request to internal data server")
log.Debug().Str("target", target).Msg("sending request to internal data server")

httpClient := s.client
httpReq, err := rhttp.NewRequest(ctx, "PUT", target, r.Body)
httpReq, err := rhttp.NewRequest(ctx, r.Method, target, r.Body)
if err != nil {
log.Err(err).Msg("wrong request")
w.WriteHeader(http.StatusInternalServerError)
Expand All @@ -321,68 +195,9 @@ func (s *svc) doPut(w http.ResponseWriter, r *http.Request) {
httpReq.Header = r.Header
httpReq.ContentLength = r.ContentLength

httpRes, err := httpClient.Do(httpReq)
httpRes, err := s.client.Do(httpReq)
if err != nil {
log.Err(err).Msg("error doing PUT request to data service")
w.WriteHeader(http.StatusInternalServerError)
return
}
defer httpRes.Body.Close()

copyHeader(w.Header(), httpRes.Header)
if httpRes.StatusCode != http.StatusOK {
// swallow the body and set content-length to 0 to prevent reverse proxies from trying to read from it
w.Header().Set("Content-Length", "0")
w.WriteHeader(httpRes.StatusCode)
return
}

w.WriteHeader(http.StatusOK)
_, err = io.Copy(w, httpRes.Body)
if err != nil {
log.Err(err).Msg("error writing body after header were set")
}
}

// TODO: put and post code is pretty much the same. Should be solved in a nicer way in the long run.
func (s *svc) doPatch(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
log := appctx.GetLogger(ctx)

claims, err := s.verify(ctx, r)
if err != nil {
err = errors.Wrap(err, "datagateway: error validating transfer token")
log.Err(err).Str("token", r.Header.Get(TokenTransportHeader)).Msg("invalid transfer token")
w.WriteHeader(http.StatusForbidden)
return
}

target := claims.Target
// add query params to target, clients can send checksums and other information.
targetURL, err := url.Parse(target)
if err != nil {
log.Err(err).Msg("datagateway: error parsing target url")
w.WriteHeader(http.StatusInternalServerError)
return
}

targetURL.RawQuery = r.URL.RawQuery
target = targetURL.String()

log.Debug().Str("target", claims.Target).Msg("sending request to internal data server")

httpClient := s.client
httpReq, err := rhttp.NewRequest(ctx, "PATCH", target, r.Body)
if err != nil {
log.Err(err).Msg("wrong request")
w.WriteHeader(http.StatusInternalServerError)
return
}
httpReq.Header = r.Header

httpRes, err := httpClient.Do(httpReq)
if err != nil {
log.Err(err).Msg("error doing PATCH request to data service")
log.Err(err).Msg("error doing " + r.Method + " request to data service")
w.WriteHeader(http.StatusInternalServerError)
return
}
Expand All @@ -395,58 +210,22 @@ func (s *svc) doPatch(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(httpRes.StatusCode)
return
}

w.WriteHeader(httpRes.StatusCode)
_, err = io.Copy(w, httpRes.Body)
if err != nil {
log.Err(err).Msg("error writing body after header were set")
}
}

func (s *svc) doOptions(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
log := appctx.GetLogger(ctx)

claims, err := s.verify(ctx, r)
if err != nil {
err = errors.Wrap(err, "datagateway: error validating transfer token")
log.Error().Err(err).Str("token", r.Header.Get(TokenTransportHeader)).Msg("invalid transfer token")
w.WriteHeader(http.StatusForbidden)
return
}

log.Debug().Str("target", claims.Target).Msg("sending request to internal data server")

httpClient := s.client
httpReq, err := rhttp.NewRequest(ctx, "OPTIONS", claims.Target, nil)
if err != nil {
log.Error().Err(err).Msg("wrong request")
w.WriteHeader(http.StatusInternalServerError)
return
}
httpReq.Header = r.Header

httpRes, err := httpClient.Do(httpReq)
var c int64
c, err = io.Copy(w, httpRes.Body)
if err != nil {
log.Error().Err(err).Msg("error doing OPTIONS request to data service")
w.WriteHeader(http.StatusInternalServerError)
return
log.Err(err).Msg("error writing body after header were set")
}
defer httpRes.Body.Close()

copyHeader(w.Header(), httpRes.Header)

// add upload expiry / transfer token expiry header for tus https://tus.io/protocols/resumable-upload.html#expiration
w.Header().Set(UploadExpiresHeader, time.Unix(claims.ExpiresAt, 0).Format(time.RFC1123))

if httpRes.StatusCode != http.StatusOK {
// swallow the body and set content-length to 0 to prevent reverse proxies from trying to read from it
w.Header().Set("Content-Length", "0")
w.WriteHeader(httpRes.StatusCode)
return
if httpRes.Header.Get("Content-Length") != "" {
i, err := strconv.ParseInt(httpRes.Header.Get("Content-Length"), 10, 64)
if err != nil {
log.Error().Err(err).Str("content-length", httpRes.Header.Get("Content-Length")).Msg("invalid content length in dataprovider response")
}
if i != c {
log.Error().Int64("content-length", i).Int64("transferred-bytes", c).Msg("content length vs transferred bytes mismatch")
}
}

w.WriteHeader(http.StatusOK)
}

func copyHeader(dst, src http.Header) {
Expand Down

0 comments on commit 43c323c

Please sign in to comment.