Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not stat shared resources when downloading #1038

Merged
merged 3 commits into from Aug 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions changelog/unreleased/download-shares-fix.md
@@ -0,0 +1,9 @@
Bugfix: Do not stat shared resources when downloading

Previously, we statted the resources in all download requests resulting in
failures when downloading references. This PR fixes that by statting only in
case the resource is not present in the shares folder. It also fixes a bug where
we allowed uploading to the mount path, resulting in overwriting the user home
directory.

https://github.com/cs3org/reva/pull/1038
7 changes: 7 additions & 0 deletions examples/storage-references/storage-reva.toml
Expand Up @@ -5,3 +5,10 @@ address = "0.0.0.0:18000"
driver = "local"
mount_path = "/reva"
mount_id = "123e4567-e89b-12d3-a456-426655440000"
data_server_url = "http://localhost:18001/data"

[http]
address = "0.0.0.0:18001"

[http.services.dataprovider]
driver = "local"
40 changes: 20 additions & 20 deletions internal/grpc/services/gateway/storageprovider.go
Expand Up @@ -104,25 +104,6 @@ func (s *svc) getHome(ctx context.Context) string {
return "/home"
}
func (s *svc) InitiateFileDownload(ctx context.Context, req *provider.InitiateFileDownloadRequest) (*gateway.InitiateFileDownloadResponse, error) {
statReq := &provider.StatRequest{Ref: req.Ref}
statRes, err := s.stat(ctx, statReq)
if err != nil {
return &gateway.InitiateFileDownloadResponse{
Status: status.NewInternal(ctx, err, "gateway: error stating ref:"+req.Ref.String()),
}, nil
}
if statRes.Status.Code != rpc.Code_CODE_OK {
if statRes.Status.Code == rpc.Code_CODE_NOT_FOUND {
return &gateway.InitiateFileDownloadResponse{
Status: status.NewNotFound(ctx, "gateway: file not found"),
}, nil
}
err := status.NewErrorFromCode(statRes.Status.Code, "gateway")
return &gateway.InitiateFileDownloadResponse{
Status: status.NewInternal(ctx, err, "gateway: error stating ref"),
}, nil
}

p, err := s.getPath(ctx, req.Ref)
if err != nil {
return &gateway.InitiateFileDownloadResponse{
Expand All @@ -131,13 +112,31 @@ func (s *svc) InitiateFileDownload(ctx context.Context, req *provider.InitiateFi
}

if !s.inSharedFolder(ctx, p) {
statReq := &provider.StatRequest{Ref: req.Ref}
statRes, err := s.stat(ctx, statReq)
if err != nil {
return &gateway.InitiateFileDownloadResponse{
Status: status.NewInternal(ctx, err, "gateway: error stating ref:"+req.Ref.String()),
}, nil
}
if statRes.Status.Code != rpc.Code_CODE_OK {
if statRes.Status.Code == rpc.Code_CODE_NOT_FOUND {
return &gateway.InitiateFileDownloadResponse{
Status: status.NewNotFound(ctx, "gateway: file not found"),
}, nil
}
err := status.NewErrorFromCode(statRes.Status.Code, "gateway")
return &gateway.InitiateFileDownloadResponse{
Status: status.NewInternal(ctx, err, "gateway: error stating ref"),
}, nil
}
return s.initiateFileDownload(ctx, req)
}

log := appctx.GetLogger(ctx)
if s.isSharedFolder(ctx, p) || s.isShareName(ctx, p) {
log.Debug().Msgf("path:%s points to shared folder or share name", p)
err := errtypes.PermissionDenied("gateway: cannot upload to share folder or share name: path=" + p)
err := errtypes.PermissionDenied("gateway: cannot download share folder or share name: path=" + p)
log.Err(err).Msg("gateway: error downloading")
return &gateway.InitiateFileDownloadResponse{
Status: status.NewInvalidArg(ctx, "path points to share folder or share name"),
Expand Down Expand Up @@ -194,6 +193,7 @@ func (s *svc) InitiateFileDownload(ctx context.Context, req *provider.InitiateFi
},
}
req.Ref = ref
log.Debug().Msg("download path: " + target)
return s.initiateFileDownload(ctx, req)
}

Expand Down
5 changes: 5 additions & 0 deletions internal/grpc/services/storageprovider/storageprovider.go
Expand Up @@ -273,6 +273,11 @@ func (s *service) InitiateFileUpload(ctx context.Context, req *provider.Initiate
Status: status.NewInternal(ctx, err, "error unwrapping path"),
}, nil
}
if newRef.GetPath() == "/" {
return &provider.InitiateFileUploadResponse{
Status: status.NewInternal(ctx, errors.New("can't upload to mount path"), ""),
}, nil
}
url := *s.dataServerURL
if s.conf.DisableTus {
url.Path = path.Join("/", url.Path, newRef.GetPath())
Expand Down
27 changes: 15 additions & 12 deletions pkg/user/manager/rest/rest.go
Expand Up @@ -353,26 +353,24 @@ func (m *manager) GetUserByClaim(ctx context.Context, claim, value string) (*use

}

func (m *manager) findUsersByFilter(ctx context.Context, url string) ([]*userpb.User, error) {
func (m *manager) findUsersByFilter(ctx context.Context, url string, users map[string]*userpb.User) error {

userData, err := m.sendAPIRequest(ctx, url)
if err != nil {
return nil, err
return err
}

users := []*userpb.User{}

for _, usr := range userData {
usrInfo, ok := usr.(map[string]interface{})
if !ok {
return nil, errors.New("rest: error in type assertion")
return errors.New("rest: error in type assertion")
}

uid := &userpb.UserId{
OpaqueId: usrInfo["upn"].(string),
Idp: m.conf.IDProvider,
}
users = append(users, &userpb.User{
users[uid.OpaqueId] = &userpb.User{
Id: uid,
Username: usrInfo["upn"].(string),
Mail: usrInfo["primaryAccountEmail"].(string),
Expand All @@ -389,10 +387,10 @@ func (m *manager) findUsersByFilter(ctx context.Context, url string) ([]*userpb.
},
},
},
})
}
}

return users, nil
return nil
}

func (m *manager) FindUsers(ctx context.Context, query string) ([]*userpb.User, error) {
Expand All @@ -407,18 +405,23 @@ func (m *manager) FindUsers(ctx context.Context, query string) ([]*userpb.User,
return nil, errors.New("rest: illegal characters present in query")
}

users := []*userpb.User{}
users := make(map[string]*userpb.User)

for _, f := range filters {
url := fmt.Sprintf("%s/Identity/?filter=%s:contains:%s&field=id&field=upn&field=primaryAccountEmail&field=displayName&field=uid&field=gid",
m.conf.APIBaseURL, f, query)
filteredUsers, err := m.findUsersByFilter(ctx, url)
err := m.findUsersByFilter(ctx, url, users)
if err != nil {
return nil, err
}
users = append(users, filteredUsers...)
}
return users, nil

userSlice := make([]*userpb.User, len(users))
for _, v := range users {
userSlice = append(userSlice, v)
}

return userSlice, nil
}

func (m *manager) GetUserGroups(ctx context.Context, uid *userpb.UserId) ([]string, error) {
Expand Down