Skip to content

Commit

Permalink
feat: add overwrite flag
Browse files Browse the repository at this point in the history
  • Loading branch information
nehemming committed Aug 10, 2021
1 parent 1b65ce1 commit 4869275
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 25 deletions.
72 changes: 49 additions & 23 deletions pkg/builtin/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type (
Copy struct {
Sources []string `mapstructure:"sources"`
Destination string `mapstructure:"destination"`
Overwrite bool `mapstructure:"overwrite"`
Log bool `mapstructure:"log"`
}

Expand Down Expand Up @@ -92,7 +93,7 @@ func (copyType) Prepare(ctx context.Context, capComm *rocket.CapComm, task rocke
}

// copy
return copyFiles(execCtx, files, destSpec, copyCfg.Log)
return copyFiles(execCtx, files, destSpec, copyCfg.Overwrite, copyCfg.Log)
}

return fn, nil
Expand Down Expand Up @@ -201,20 +202,20 @@ func toDistinctAbsRelSlice(files ...AbsRel) []AbsRel {
return res
}

func copyFiles(ctx context.Context, sources []AbsRel, dest DestSpec, log bool) error {
func copyFiles(ctx context.Context, sources []AbsRel, dest DestSpec, allowOverwrite, log bool) error {
for _, source := range sources {
if ctx.Err() != nil {
return ctx.Err()
}

if err := copyFile(source, dest, log); err != nil {
if err := copyFile(source, dest, allowOverwrite, log); err != nil {
return err
}
}
return nil
}

func copyFile(source AbsRel, dest DestSpec, log bool) error {
func copyFile(source AbsRel, dest DestSpec, allowOverwrite, log bool) error {
// Get the source files permission
stat, err := os.Stat(source.Abs)
if err != nil {
Expand All @@ -228,44 +229,69 @@ func copyFile(source AbsRel, dest DestSpec, log bool) error {
}
defer srcFile.Close()

var finalPath string
if dest.IsDir {
finalPath = filepath.Join(dest.Path, source.Rel)
} else {
finalPath = dest.Path
}
dir := filepath.Dir(finalPath)
err = os.MkdirAll(dir, 0777)

destAbsRel, err := prepDestination(source, dest, allowOverwrite)
if err != nil {
return errors.Wrapf(err, "dir %s:", dir)
return err
}

destFile, err := os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, stat.Mode())
if err != nil {
return errors.Wrapf(err, "dest %s:", finalPath)
if destAbsRel == nil || source.Abs == destAbsRel.Abs {
// skipping
if log {
loggee.Infof("skipping %s", source.Rel)
}
return nil
}
defer destFile.Close()

destRel, err := filepath.Rel(dest.Path, finalPath)
destFile, err := os.OpenFile(destAbsRel.Abs, os.O_RDWR|os.O_CREATE|os.O_TRUNC, stat.Mode())
if err != nil {
return errors.Wrapf(err, "dest %s, rel %s:", dest.Path, finalPath)
return errors.Wrapf(err, "dest %s:", destAbsRel.Rel)
}
defer destFile.Close()

// Do copy
_, err = io.Copy(destFile, srcFile)
if err != nil {
return errors.Wrapf(err, "copy %s => %s:", source.Rel, destRel)
return errors.Wrapf(err, "copy %s => %s:", source.Rel, destAbsRel.Rel)
}

// log
if log {
loggee.Infof("copy %s => %s", source.Rel, destRel)
loggee.Infof("copy %s => %s", source.Rel, destAbsRel.Rel)
}

return nil
}

func prepDestination(source AbsRel, dest DestSpec, allowOverwrite bool) (*AbsRel, error) {
var finalPath string
if dest.IsDir {
finalPath = filepath.Join(dest.Path, source.Rel)
} else {
finalPath = dest.Path
}

destRel, err := filepath.Rel(dest.Path, finalPath)
if err != nil {
return nil, errors.Wrapf(err, "dest %s, rel %s:", dest.Path, finalPath)
}

if !allowOverwrite {
_, err := os.Stat(finalPath)
if err == nil {
// skip
return nil, nil
}
}

// create dir if needed
dir := filepath.Dir(finalPath)
err = os.MkdirAll(dir, 0777)
if err != nil {
return nil, errors.Wrapf(err, "dir %s:", dir)
}

return &AbsRel{Abs: finalPath, Rel: destRel}, nil
}

func init() {
rocket.Default().RegisterTaskTypes(copyType{})
}
1 change: 0 additions & 1 deletion pkg/builtin/copy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ func TestCopyRun(t *testing.T) {
func validateCopyTest(t *testing.T, dir string) {
t.Helper()
// Check and clean

src, _ := globFileAbsRel("*/*.yml", "*.go")
dest, _ := globFileAbsRel(filepath.Join(dir, "**"))
if len(src) != len(dest) {
Expand Down
15 changes: 14 additions & 1 deletion pkg/builtin/testdata/copy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,22 @@ name: "test copy"
stages:
- tasks:
- type: copy
name: copy data
name: test deep copy
log: true
sources:
- "**/*.go"
- "**/*.yml"
destination: "testdata/cpt/"
- type: copy
name: copy skips
log: true
sources:
- "*.go"
destination: "testdata/cpt/"
- type: copy
name: copy overwrite
log: true
overwrite: true
sources:
- "*.go"
destination: "testdata/cpt/"

0 comments on commit 4869275

Please sign in to comment.