Skip to content

Commit

Permalink
Make local files default for fs commands (#506)
Browse files Browse the repository at this point in the history
## Changes
<!-- Summary of your changes that are easy to understand -->

## Tests
<!-- How is this tested? -->
  • Loading branch information
shreyas-goenka committed Jun 23, 2023
1 parent d0e9953 commit 30efe91
Show file tree
Hide file tree
Showing 11 changed files with 260 additions and 71 deletions.
63 changes: 27 additions & 36 deletions cmd/fs/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,47 +10,38 @@ import (
"github.com/databricks/cli/libs/filer"
)

type Scheme string

const (
DbfsScheme = Scheme("dbfs")
LocalScheme = Scheme("file")
NoScheme = Scheme("")
)

func filerForPath(ctx context.Context, fullPath string) (filer.Filer, string, error) {
parts := strings.SplitN(fullPath, ":/", 2)
// Split path at : to detect any file schemes
parts := strings.SplitN(fullPath, ":", 2)

// If no scheme is specified, then local path
if len(parts) < 2 {
return nil, "", fmt.Errorf(`no scheme specified for path %s. Please specify scheme "dbfs" or "file". Example: file:/foo/bar or file:/c:/foo/bar`, fullPath)
f, err := filer.NewLocalClient("")
return f, fullPath, err
}

// On windows systems, paths start with a drive letter. If the scheme
// is a single letter and the OS is windows, then we conclude the path
// is meant to be a local path.
if runtime.GOOS == "windows" && len(parts[0]) == 1 {
f, err := filer.NewLocalClient("")
return f, fullPath, err
}
scheme := Scheme(parts[0])

if parts[0] != "dbfs" {
return nil, "", fmt.Errorf("invalid scheme: %s", parts[0])
}

path := parts[1]
switch scheme {
case DbfsScheme:
w := root.WorkspaceClient(ctx)
// If the specified path has the "Volumes" prefix, use the Files API.
if strings.HasPrefix(path, "Volumes/") {
f, err := filer.NewFilesClient(w, "/")
return f, path, err
}
f, err := filer.NewDbfsClient(w, "/")
return f, path, err
w := root.WorkspaceClient(ctx)

case LocalScheme:
if runtime.GOOS == "windows" {
parts := strings.SplitN(path, ":", 2)
if len(parts) < 2 {
return nil, "", fmt.Errorf("no volume specfied for path: %s", path)
}
volume := parts[0] + ":"
relPath := parts[1]
f, err := filer.NewLocalClient(volume)
return f, relPath, err
}
f, err := filer.NewLocalClient("/")
// If the specified path has the "Volumes" prefix, use the Files API.
if strings.HasPrefix(path, "/Volumes/") {
f, err := filer.NewFilesClient(w, "/")
return f, path, err

default:
return nil, "", fmt.Errorf(`unsupported scheme %s specified for path %s. Please specify scheme "dbfs" or "file". Example: file:/foo/bar or file:/c:/foo/bar`, scheme, fullPath)
}

// The file is a dbfs file, and uses the DBFS APIs
f, err := filer.NewDbfsClient(w, "/")
return f, path, err
}
56 changes: 46 additions & 10 deletions cmd/fs/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,58 @@ import (
"runtime"
"testing"

"github.com/databricks/cli/libs/filer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNotSpecifyingVolumeForWindowsPathErrors(t *testing.T) {
if runtime.GOOS != "windows" {
t.Skip()
}
func TestFilerForPathForLocalPaths(t *testing.T) {
tmpDir := t.TempDir()
ctx := context.Background()

f, path, err := filerForPath(ctx, tmpDir)
assert.NoError(t, err)
assert.Equal(t, tmpDir, path)

info, err := f.Stat(ctx, path)
require.NoError(t, err)
assert.True(t, info.IsDir())
}

func TestFilerForPathForInvalidScheme(t *testing.T) {
ctx := context.Background()
pathWithVolume := `file:/c:/foo/bar`
pathWOVolume := `file:/uno/dos`

_, path, err := filerForPath(ctx, pathWithVolume)
_, _, err := filerForPath(ctx, "dbf:/a")
assert.ErrorContains(t, err, "invalid scheme")

_, _, err = filerForPath(ctx, "foo:a")
assert.ErrorContains(t, err, "invalid scheme")

_, _, err = filerForPath(ctx, "file:/a")
assert.ErrorContains(t, err, "invalid scheme")
}

func testWindowsFilerForPath(t *testing.T, ctx context.Context, fullPath string) {
f, path, err := filerForPath(ctx, fullPath)
assert.NoError(t, err)
assert.Equal(t, `/foo/bar`, path)

_, _, err = filerForPath(ctx, pathWOVolume)
assert.Equal(t, "no volume specfied for path: uno/dos", err.Error())
// Assert path remains unchanged
assert.Equal(t, path, fullPath)

// Assert local client is created
_, ok := f.(*filer.LocalClient)
assert.True(t, ok)
}

func TestFilerForWindowsLocalPaths(t *testing.T) {
if runtime.GOOS != "windows" {
t.SkipNow()
}

ctx := context.Background()
testWindowsFilerForPath(t, ctx, `c:\abc`)
testWindowsFilerForPath(t, ctx, `c:abc`)
testWindowsFilerForPath(t, ctx, `d:\abc`)
testWindowsFilerForPath(t, ctx, `d:\abc`)
testWindowsFilerForPath(t, ctx, `f:\abc\ef`)
}
13 changes: 3 additions & 10 deletions internal/fs_cp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func setupLocalFiler(t *testing.T) (filer.Filer, string) {
f, err := filer.NewLocalClient(tmp)
require.NoError(t, err)

return f, path.Join("file:/", filepath.ToSlash(tmp))
return f, path.Join(filepath.ToSlash(tmp))
}

func setupDbfsFiler(t *testing.T) (filer.Filer, string) {
Expand Down Expand Up @@ -259,21 +259,14 @@ func TestAccFsCpErrorsWhenSourceIsDirWithoutRecursiveFlag(t *testing.T) {
tmpDir := temporaryDbfsDir(t, w)

_, _, err = RequireErrorRun(t, "fs", "cp", "dbfs:"+tmpDir, "dbfs:/tmp")
assert.Equal(t, fmt.Sprintf("source path %s is a directory. Please specify the --recursive flag", strings.TrimPrefix(tmpDir, "/")), err.Error())
}

func TestAccFsCpErrorsOnNoScheme(t *testing.T) {
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))

_, _, err := RequireErrorRun(t, "fs", "cp", "/a", "/b")
assert.Equal(t, "no scheme specified for path /a. Please specify scheme \"dbfs\" or \"file\". Example: file:/foo/bar or file:/c:/foo/bar", err.Error())
assert.Equal(t, fmt.Sprintf("source path %s is a directory. Please specify the --recursive flag", tmpDir), err.Error())
}

func TestAccFsCpErrorsOnInvalidScheme(t *testing.T) {
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))

_, _, err := RequireErrorRun(t, "fs", "cp", "dbfs:/a", "https:/b")
assert.Equal(t, "unsupported scheme https specified for path https:/b. Please specify scheme \"dbfs\" or \"file\". Example: file:/foo/bar or file:/c:/foo/bar", err.Error())
assert.Equal(t, "invalid scheme: https", err.Error())
}

func TestAccFsCpSourceIsDirectoryButTargetIsFile(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions libs/filer/dbfs_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ type DbfsClient struct {
workspaceClient *databricks.WorkspaceClient

// File operations will be relative to this path.
root RootPath
root WorkspaceRootPath
}

func NewDbfsClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
return &DbfsClient{
workspaceClient: w,

root: NewRootPath(root),
root: NewWorkspaceRootPath(root),
}, nil
}

Expand Down
4 changes: 2 additions & 2 deletions libs/filer/files_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ type FilesClient struct {
apiClient *client.DatabricksClient

// File operations will be relative to this path.
root RootPath
root WorkspaceRootPath
}

func filesNotImplementedError(fn string) error {
Expand All @@ -77,7 +77,7 @@ func NewFilesClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
workspaceClient: w,
apiClient: apiClient,

root: NewRootPath(root),
root: NewWorkspaceRootPath(root),
}, nil
}

Expand Down
4 changes: 2 additions & 2 deletions libs/filer/local_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ import (
// LocalClient implements the [Filer] interface for the local filesystem.
type LocalClient struct {
// File operations will be relative to this path.
root RootPath
root localRootPath
}

func NewLocalClient(root string) (Filer, error) {
return &LocalClient{
root: NewRootPath(root),
root: NewLocalRootPath(root),
}, nil
}

Expand Down
27 changes: 27 additions & 0 deletions libs/filer/local_root_path.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package filer

import (
"fmt"
"path/filepath"
"strings"
)

type localRootPath struct {
rootPath string
}

func NewLocalRootPath(root string) localRootPath {
if root == "" {
return localRootPath{""}
}
return localRootPath{filepath.Clean(root)}
}

func (rp *localRootPath) Join(name string) (string, error) {
absPath := filepath.Join(rp.rootPath, name)

if !strings.HasPrefix(absPath, rp.rootPath) {
return "", fmt.Errorf("relative path escapes root: %s", name)
}
return absPath, nil
}
142 changes: 142 additions & 0 deletions libs/filer/local_root_path_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package filer

import (
"path/filepath"
"runtime"
"testing"

"github.com/stretchr/testify/assert"
)

func testUnixLocalRootPath(t *testing.T, uncleanRoot string) {
cleanRoot := filepath.Clean(uncleanRoot)
rp := NewLocalRootPath(uncleanRoot)

remotePath, err := rp.Join("a/b/c")
assert.NoError(t, err)
assert.Equal(t, cleanRoot+"/a/b/c", remotePath)

remotePath, err = rp.Join("a/b/../d")
assert.NoError(t, err)
assert.Equal(t, cleanRoot+"/a/d", remotePath)

remotePath, err = rp.Join("a/../c")
assert.NoError(t, err)
assert.Equal(t, cleanRoot+"/c", remotePath)

remotePath, err = rp.Join("a/b/c/.")
assert.NoError(t, err)
assert.Equal(t, cleanRoot+"/a/b/c", remotePath)

remotePath, err = rp.Join("a/b/c/d/./../../f/g")
assert.NoError(t, err)
assert.Equal(t, cleanRoot+"/a/b/f/g", remotePath)

remotePath, err = rp.Join(".//a/..//./b/..")
assert.NoError(t, err)
assert.Equal(t, cleanRoot, remotePath)

remotePath, err = rp.Join("a/b/../..")
assert.NoError(t, err)
assert.Equal(t, cleanRoot, remotePath)

remotePath, err = rp.Join("")
assert.NoError(t, err)
assert.Equal(t, cleanRoot, remotePath)

remotePath, err = rp.Join(".")
assert.NoError(t, err)
assert.Equal(t, cleanRoot, remotePath)

remotePath, err = rp.Join("/")
assert.NoError(t, err)
assert.Equal(t, cleanRoot, remotePath)

_, err = rp.Join("..")
assert.ErrorContains(t, err, `relative path escapes root: ..`)

_, err = rp.Join("a/../..")
assert.ErrorContains(t, err, `relative path escapes root: a/../..`)

_, err = rp.Join("./../.")
assert.ErrorContains(t, err, `relative path escapes root: ./../.`)

_, err = rp.Join("/./.././..")
assert.ErrorContains(t, err, `relative path escapes root: /./.././..`)

_, err = rp.Join("./../.")
assert.ErrorContains(t, err, `relative path escapes root: ./../.`)

_, err = rp.Join("./..")
assert.ErrorContains(t, err, `relative path escapes root: ./..`)

_, err = rp.Join("./../../..")
assert.ErrorContains(t, err, `relative path escapes root: ./../../..`)

_, err = rp.Join("./../a/./b../../..")
assert.ErrorContains(t, err, `relative path escapes root: ./../a/./b../../..`)

_, err = rp.Join("../..")
assert.ErrorContains(t, err, `relative path escapes root: ../..`)
}

func TestUnixLocalRootPath(t *testing.T) {
if runtime.GOOS != "darwin" && runtime.GOOS != "linux" {
t.SkipNow()
}

testUnixLocalRootPath(t, "/some/root/path")
testUnixLocalRootPath(t, "/some/root/path/")
testUnixLocalRootPath(t, "/some/root/path/.")
testUnixLocalRootPath(t, "/some/root/../path/")
}

func testWindowsLocalRootPath(t *testing.T, uncleanRoot string) {
cleanRoot := filepath.Clean(uncleanRoot)
rp := NewLocalRootPath(uncleanRoot)

remotePath, err := rp.Join(`a\b\c`)
assert.NoError(t, err)
assert.Equal(t, cleanRoot+`\a\b\c`, remotePath)

remotePath, err = rp.Join(`a\b\..\d`)
assert.NoError(t, err)
assert.Equal(t, cleanRoot+`\a\d`, remotePath)

remotePath, err = rp.Join(`a\..\c`)
assert.NoError(t, err)
assert.Equal(t, cleanRoot+`\c`, remotePath)

remotePath, err = rp.Join(`a\b\c\.`)
assert.NoError(t, err)
assert.Equal(t, cleanRoot+`\a\b\c`, remotePath)

remotePath, err = rp.Join(`a\b\..\..`)
assert.NoError(t, err)
assert.Equal(t, cleanRoot, remotePath)

remotePath, err = rp.Join("")
assert.NoError(t, err)
assert.Equal(t, cleanRoot, remotePath)

remotePath, err = rp.Join(".")
assert.NoError(t, err)
assert.Equal(t, cleanRoot, remotePath)

_, err = rp.Join("..")
assert.ErrorContains(t, err, `relative path escapes root`)

_, err = rp.Join(`a\..\..`)
assert.ErrorContains(t, err, `relative path escapes root`)
}

func TestWindowsLocalRootPath(t *testing.T) {
if runtime.GOOS != "windows" {
t.SkipNow()
}

testWindowsLocalRootPath(t, `c:\some\root\path`)
testWindowsLocalRootPath(t, `c:\some\root\path\`)
testWindowsLocalRootPath(t, `c:\some\root\path\.`)
testWindowsLocalRootPath(t, `C:\some\root\..\path\`)
}
Loading

0 comments on commit 30efe91

Please sign in to comment.