diff --git a/docs/configuration.md b/docs/configuration.md index 24b037d..a003883 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -157,7 +157,7 @@ for details of each. * `retention`: string, retention policy * `targets`: target configurations, each of which can be reference by other sections. Key is the name of the target that is referenced elsewhere. Each one has the following structure: * `type`: string, the type of target, one of: file, s3, smb - * `url`: string, the URL of the target + * `url`: string, the URL of the target; include `?latest` if the URL is a directory and you want to use the latest file in that directory. If the URL is a file, it should not include `?latest`. * `spec`: access details for the target, depends on target type: * Type s3: * `region`: string, the region diff --git a/pkg/core/restore.go b/pkg/core/restore.go index 2515c0f..076c634 100644 --- a/pkg/core/restore.go +++ b/pkg/core/restore.go @@ -42,7 +42,7 @@ func (e *Executor) Restore(ctx context.Context, opts RestoreOptions) error { attribute.String("targetfile", opts.TargetFile), attribute.String("tmpfile", tmpRestoreFile), ) - copied, err := opts.Target.Pull(ctx, opts.TargetFile, tmpRestoreFile, logger) + copied, err := opts.Target.Pull(ctx, opts.Target.URL(), tmpRestoreFile, logger) if err != nil { pullSpan.RecordError(err) pullSpan.End() diff --git a/pkg/storage/file/file.go b/pkg/storage/file/file.go index 84d0cdf..d227ae0 100644 --- a/pkg/storage/file/file.go +++ b/pkg/storage/file/file.go @@ -7,7 +7,6 @@ import ( "io/fs" "net/url" "os" - "path" "path/filepath" log "github.com/sirupsen/logrus" @@ -23,13 +22,60 @@ func New(u url.URL) *File { } func (f *File) Pull(ctx context.Context, source, target string, logger *log.Entry) (int64, error) { - return copyFile(path.Join(f.path, source), target) + // see if the target has `?latest` set, if so, we need to find the latest file + sourceFile := filepath.Join(f.path, source) + u, err := url.Parse(sourceFile) + if err != nil { + return 0, fmt.Errorf("failed to parse target URL %s: %v", source, err) + } + q := u.Query() + if q.Has("latest") { + latestFilename, err := f.Latest(ctx, u.Path, logger) + if err != nil { + return 0, fmt.Errorf("failed to find latest file for source %s: %v", u.Path, err) + } + logger.Debugf("latest file for target %s is %s", u.Path, latestFilename) + sourceFile = filepath.Join(u.Path, latestFilename) + } + + return copyFile(sourceFile, target) } func (f *File) Push(ctx context.Context, target, source string, logger *log.Entry) (int64, error) { return copyFile(source, filepath.Join(f.path, target)) } +func (f *File) Latest(ctx context.Context, target string, logger *log.Entry) (string, error) { + fullTarget := filepath.Join(f.path, target) + entries, err := os.ReadDir(fullTarget) + if err != nil { + return "", fmt.Errorf("failed to read directory %s: %w", f.path, err) + } + + var latest string + var latestModTime int64 + + for _, entry := range entries { + if entry.IsDir() || !entry.Type().IsRegular() { + continue + } + info, err := entry.Info() + if err != nil { + return "", fmt.Errorf("failed to get info for file %s: %w", entry.Name(), err) + } + if info.ModTime().Unix() > latestModTime { + latest = entry.Name() + latestModTime = info.ModTime().Unix() + } + } + + if latest == "" { + return "", fmt.Errorf("no files found for target %s", target) + } + + return latest, nil +} + func (f *File) Clean(filename string) string { return filename } diff --git a/pkg/storage/s3/s3.go b/pkg/storage/s3/s3.go index ab1c236..60b4840 100644 --- a/pkg/storage/s3/s3.go +++ b/pkg/storage/s3/s3.go @@ -66,6 +66,37 @@ func New(u url.URL, opts ...Option) *S3 { return s } +func (s *S3) Latest(ctx context.Context, target string, logger *log.Entry) (string, error) { + // get the s3 client + client, err := s.getClient(logger) + if err != nil { + return "", fmt.Errorf("failed to get AWS client: %v", err) + } + + // ensure that there is no leading / + p := strings.TrimPrefix(filepath.Join(s.url.Path, target), "/") + result, err := client.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{Bucket: aws.String(s.url.Hostname()), Prefix: aws.String(p)}) + if err != nil { + return "", fmt.Errorf("failed to list objects, %v", err) + } + + var latest string + var latestModTime time.Time + + for _, item := range result.Contents { + if item.LastModified.After(latestModTime) { + latest = *item.Key + latestModTime = *item.LastModified + } + } + + if latest == "" { + return "", fmt.Errorf("no files found for target %s", target) + } + + return latest, nil +} + func (s *S3) Pull(ctx context.Context, source, target string, logger *log.Entry) (int64, error) { // get the s3 client client, err := s.getClient(logger) @@ -73,7 +104,22 @@ func (s *S3) Pull(ctx context.Context, source, target string, logger *log.Entry) return 0, fmt.Errorf("failed to get AWS client: %v", err) } - bucket, path := s.url.Hostname(), path.Join(s.url.Path, source) + sourceFile := filepath.Join(s.url.Path, source) + u, err := url.Parse(sourceFile) + if err != nil { + return 0, fmt.Errorf("failed to parse target URL %s: %v", source, err) + } + q := u.Query() + if q.Has("latest") { + latestFilename, err := s.Latest(ctx, u.Path, logger) + if err != nil { + return 0, fmt.Errorf("failed to find latest file for source %s: %v", u.Path, err) + } + logger.Debugf("latest file for target %s is %s", u.Path, latestFilename) + sourceFile = filepath.Join(u.Path, latestFilename) + } + + bucket, path := s.url.Hostname(), sourceFile // Create a downloader with the session and default options downloader := manager.NewDownloader(client) diff --git a/pkg/storage/smb/smb.go b/pkg/storage/smb/smb.go index c3f76e2..6065d6f 100644 --- a/pkg/storage/smb/smb.go +++ b/pkg/storage/smb/smb.go @@ -60,12 +60,27 @@ func (s *SMB) Pull(ctx context.Context, source, target string, logger *log.Entry smbFilename := fmt.Sprintf("%s%c%s", sharepath, smb2.PathSeparator, filepath.Base(strings.ReplaceAll(target, ":", "-"))) smbFilename = strings.TrimPrefix(smbFilename, fmt.Sprintf("%c", smb2.PathSeparator)) + sourceFile := smbFilename + u, err := url.Parse(smbFilename) + if err != nil { + return fmt.Errorf("failed to parse target URL %s: %v", source, err) + } + q := u.Query() + if q.Has("latest") { + latestFilename, err := s.Latest(ctx, u.Path, logger) + if err != nil { + return fmt.Errorf("failed to find latest file for target %s: %v", u.Path, err) + } + logger.Debugf("latest file for target %s is %s", u.Path, latestFilename) + sourceFile = filepath.Join(u.Path, latestFilename) + } + to, err := os.Create(target) if err != nil { return err } defer func() { _ = to.Close() }() - from, err := fs.Open(smbFilename) + from, err := fs.Open(sourceFile) if err != nil { return err } @@ -76,6 +91,43 @@ func (s *SMB) Pull(ctx context.Context, source, target string, logger *log.Entry return copied, err } +func (s *SMB) Latest(ctx context.Context, target string, logger *log.Entry) (string, error) { + var ( + latest string + err error + ) + err = s.exec(s.url, func(fs *smb2.Share, sharepath string) error { + smbDirname := fmt.Sprintf("%s%c%s", sharepath, smb2.PathSeparator, target) + smbDirname = strings.TrimPrefix(smbDirname, fmt.Sprintf("%c", smb2.PathSeparator)) + entries, err := fs.ReadDir(smbDirname) + if err != nil { + return fmt.Errorf("failed to read directory %s: %w", smbDirname, err) + } + + var latestModTime int64 + + for _, entry := range entries { + if entry.IsDir() || !entry.Mode().IsRegular() { + continue + } + + if entry.ModTime().Unix() > latestModTime { + latest = entry.Name() + latestModTime = entry.ModTime().Unix() + } + } + + if latest == "" { + return fmt.Errorf("no files found for target %s", target) + } + return nil + }) + if err != nil { + return "", err + } + return latest, nil +} + func (s *SMB) Push(ctx context.Context, target, source string, logger *log.Entry) (int64, error) { var ( copied int64 diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index 86d5147..79a7ef6 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -13,6 +13,8 @@ type Storage interface { Clean(filename string) string Push(ctx context.Context, target, source string, logger *log.Entry) (int64, error) Pull(ctx context.Context, source, target string, logger *log.Entry) (int64, error) + // Latest returns the latest, or most recent, file for a given target. Should return just the filename, relative to `target`, not the path. + Latest(ctx context.Context, target string, logger *log.Entry) (string, error) ReadDir(ctx context.Context, dirname string, logger *log.Entry) ([]fs.FileInfo, error) // Remove remove a particular file Remove(ctx context.Context, target string, logger *log.Entry) error