diff --git a/cmd/core/helpers.go b/cmd/core/helpers.go index ee8108d3..d6a2ce4e 100644 --- a/cmd/core/helpers.go +++ b/cmd/core/helpers.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "os" "strings" "text/tabwriter" @@ -448,6 +449,13 @@ func IsURL(ref string) bool { return strings.HasPrefix(ref, "http://") || strings.HasPrefix(ref, "https://") } +// CloseOnCancel closes c when ctx is canceled; callers `defer CloseOnCancel(ctx, c)()` to stop the watcher on return. +func CloseOnCancel(ctx context.Context, c io.Closer) func() bool { + return context.AfterFunc(ctx, func() { + c.Close() //nolint:errcheck,gosec + }) +} + // resolveOwner returns the unique backend where found==true; notFound on zero, ambiguous wrapped on multi-match (lists matched types). func resolveOwner[T interface{ Type() string }](backends []T, ref string, found func(T) (bool, error), notFound, ambiguous error) (T, error) { var matches []T diff --git a/cmd/snapshot/handler.go b/cmd/snapshot/handler.go index 8844b95d..703044bf 100644 --- a/cmd/snapshot/handler.go +++ b/cmd/snapshot/handler.go @@ -2,7 +2,6 @@ package snapshot import ( "cmp" - "context" "errors" "fmt" "io" @@ -58,11 +57,7 @@ func (h Handler) Save(cmd *cobra.Command, args []string) error { } defer stream.Close() //nolint:errcheck - // Close stream on ctx cancel so Ctrl+C doesn't hang on the pipe. - stop := context.AfterFunc(ctx, func() { - stream.Close() //nolint:errcheck,gosec - }) - defer stop() + defer cmdcore.CloseOnCancel(ctx, stream)() cfg.Name = name cfg.Description = description @@ -199,11 +194,7 @@ func (h Handler) Export(cmd *cobra.Command, args []string) (err error) { return fmt.Errorf("export: %w", err) } defer stream.Close() //nolint:errcheck - - stop := context.AfterFunc(ctx, func() { - stream.Close() //nolint:errcheck,gosec - }) - defer stop() + defer cmdcore.CloseOnCancel(ctx, stream)() if output == "-" { if _, err = io.Copy(os.Stdout, stream); err != nil { diff --git a/cmd/vm/run.go b/cmd/vm/run.go index c35dad82..1e75e663 100644 --- a/cmd/vm/run.go +++ b/cmd/vm/run.go @@ -112,11 +112,7 @@ func (h Handler) Clone(cmd *cobra.Command, args []string) error { return fmt.Errorf("open snapshot %s: %w", snapRef, err) } defer stream.Close() //nolint:errcheck - - stop := context.AfterFunc(ctx, func() { - stream.Close() //nolint:errcheck,gosec - }) - defer stop() + defer cmdcore.CloseOnCancel(ctx, stream)() vmCfg, vmID, netProvider, netSetup, err := h.prepareClone(ctx, cmd, conf, cfg) if err != nil { @@ -191,11 +187,7 @@ func (h Handler) Restore(cmd *cobra.Command, args []string) error { return fmt.Errorf("open snapshot: %w", err) } defer stream.Close() //nolint:errcheck - - stop := context.AfterFunc(ctx, func() { - stream.Close() //nolint:errcheck,gosec - }) - defer stop() + defer cmdcore.CloseOnCancel(ctx, stream)() logger.Infof(ctx, "restoring VM %s from snapshot %s ...", vmRef, snapRef) diff --git a/cmd/vm/status_test.go b/cmd/vm/status_test.go index 0f1aaf47..7cbbcf73 100644 --- a/cmd/vm/status_test.go +++ b/cmd/vm/status_test.go @@ -3,6 +3,7 @@ package vm import ( "io" "os" + "slices" "strings" "testing" "time" @@ -176,21 +177,9 @@ func TestApplyFilters(t *testing.T) { for _, vm := range got { gotIDs = append(gotIDs, vm.ID) } - if !equalStrings(gotIDs, tt.wantIDs) { + if !slices.Equal(gotIDs, tt.wantIDs) { t.Errorf("applyFilters(%v) = %v, want %v", tt.filters, gotIDs, tt.wantIDs) } }) } } - -func equalStrings(a, b []string) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true -} diff --git a/hypervisor/cloudhypervisor/console.go b/hypervisor/cloudhypervisor/console.go index cffc15ab..00ec2817 100644 --- a/hypervisor/cloudhypervisor/console.go +++ b/hypervisor/cloudhypervisor/console.go @@ -15,12 +15,7 @@ import ( // Console returns a bidirectional stream to the VM console: console.sock (UEFI) or the CH-allocated PTY (OCI). // Caller closes the returned ReadWriteCloser. func (ch *CloudHypervisor) Console(ctx context.Context, ref string) (io.ReadWriteCloser, error) { - id, err := ch.ResolveRef(ctx, ref) - if err != nil { - return nil, err - } - - rec, err := ch.LoadRecord(ctx, id) + id, rec, err := ch.ResolveAndLoad(ctx, ref) if err != nil { return nil, err } diff --git a/hypervisor/cloudhypervisor/extend.go b/hypervisor/cloudhypervisor/extend.go index 1530c63d..013f8d11 100644 --- a/hypervisor/cloudhypervisor/extend.go +++ b/hypervisor/cloudhypervisor/extend.go @@ -202,11 +202,7 @@ func (ch *CloudHypervisor) runningVMClient(ctx context.Context, vmRef string) (* } func (ch *CloudHypervisor) runningVMClientWithRecord(ctx context.Context, vmRef string) (*http.Client, string, hypervisor.VMRecord, error) { - vmID, err := ch.ResolveRef(ctx, vmRef) - if err != nil { - return nil, "", hypervisor.VMRecord{}, err - } - rec, err := ch.LoadRecord(ctx, vmID) + vmID, rec, err := ch.ResolveAndLoad(ctx, vmRef) if err != nil { return nil, "", hypervisor.VMRecord{}, err } diff --git a/hypervisor/firecracker/console.go b/hypervisor/firecracker/console.go index c0df0230..afab67fc 100644 --- a/hypervisor/firecracker/console.go +++ b/hypervisor/firecracker/console.go @@ -10,12 +10,7 @@ import ( ) func (fc *Firecracker) Console(ctx context.Context, ref string) (io.ReadWriteCloser, error) { - id, err := fc.ResolveRef(ctx, ref) - if err != nil { - return nil, err - } - - rec, err := fc.LoadRecord(ctx, id) + id, rec, err := fc.ResolveAndLoad(ctx, ref) if err != nil { return nil, err } diff --git a/hypervisor/inspect.go b/hypervisor/inspect.go index b2863630..fb4b9aea 100644 --- a/hypervisor/inspect.go +++ b/hypervisor/inspect.go @@ -73,6 +73,23 @@ func (b *Backend) LoadRecord(ctx context.Context, id string) (VMRecord, error) { }) } +// ResolveAndLoad combines ResolveRef + LoadRecord under a single DB lock. +func (b *Backend) ResolveAndLoad(ctx context.Context, ref string) (string, VMRecord, error) { + var ( + id string + rec VMRecord + ) + return id, rec, b.DB.With(ctx, func(idx *VMIndex) error { + var err error + id, err = idx.Resolve(ref) + if err != nil { + return err + } + rec, err = utils.LookupCopy(idx.VMs, id) + return err + }) +} + func vsockBound(path string) bool { _, err := os.Stat(path) return err == nil diff --git a/snapshot/localfile/import.go b/snapshot/localfile/import.go index 8513bbf3..01bdbbde 100644 --- a/snapshot/localfile/import.go +++ b/snapshot/localfile/import.go @@ -59,25 +59,11 @@ func (lf *LocalFile) Import(ctx context.Context, r io.Reader, name, description return "", fmt.Errorf("compute data dir size: %w", sizeErr) } now := time.Now() - if err = lf.store.Update(ctx, func(idx *snapshot.SnapshotIndex) error { - if cfg.Name != "" { - if existingID, ok := idx.Names[cfg.Name]; ok { - return fmt.Errorf("snapshot name %q already in use by %s", cfg.Name, existingID) - } - } - idx.Snapshots[id] = &snapshot.SnapshotRecord{ - Snapshot: types.Snapshot{ - SnapshotConfig: cfg, - CreatedAt: now, - }, - DataDir: dataDir, - SizeBytes: size, - LastAccessedAt: now, - } - if cfg.Name != "" { - idx.Names[cfg.Name] = id - } - return nil + if err = lf.insertRecord(ctx, id, cfg.Name, &snapshot.SnapshotRecord{ + Snapshot: types.Snapshot{SnapshotConfig: cfg, CreatedAt: now}, + DataDir: dataDir, + SizeBytes: size, + LastAccessedAt: now, }); err != nil { return "", err } diff --git a/snapshot/localfile/localfile.go b/snapshot/localfile/localfile.go index 8c93ffa5..e6165a26 100644 --- a/snapshot/localfile/localfile.go +++ b/snapshot/localfile/localfile.go @@ -90,24 +90,10 @@ func (lf *LocalFile) Create(ctx context.Context, cfg *types.SnapshotConfig, stre dataDir := lf.conf.SnapshotDataDir(id) now := time.Now() - if err = lf.store.Update(ctx, func(idx *snapshot.SnapshotIndex) error { - if cfg.Name != "" { - if existingID, ok := idx.Names[cfg.Name]; ok { - return fmt.Errorf("snapshot name %q already in use by %s", cfg.Name, existingID) - } - } - idx.Snapshots[id] = &snapshot.SnapshotRecord{ - Snapshot: types.Snapshot{ - SnapshotConfig: *cfg, - CreatedAt: now, - }, - Pending: true, - DataDir: dataDir, - } - if cfg.Name != "" { - idx.Names[cfg.Name] = id - } - return nil + if err = lf.insertRecord(ctx, id, cfg.Name, &snapshot.SnapshotRecord{ + Snapshot: types.Snapshot{SnapshotConfig: *cfg, CreatedAt: now}, + Pending: true, + DataDir: dataDir, }); err != nil { return "", err } @@ -235,6 +221,22 @@ func (lf *LocalFile) deleteOne(ctx context.Context, id string) error { return nil } +// insertRecord adds rec under id with name-collision check; both Create (Pending) and Import (finalized) go through here. +func (lf *LocalFile) insertRecord(ctx context.Context, id, name string, rec *snapshot.SnapshotRecord) error { + return lf.store.Update(ctx, func(idx *snapshot.SnapshotIndex) error { + if name != "" { + if existingID, ok := idx.Names[name]; ok { + return fmt.Errorf("snapshot name %q already in use by %s", name, existingID) + } + } + idx.Snapshots[id] = rec + if name != "" { + idx.Names[name] = id + } + return nil + }) +} + // rollbackCreate removes a placeholder snapshot record from the DB. func (lf *LocalFile) rollbackCreate(ctx context.Context, id, name string) { if err := lf.store.Update(ctx, func(idx *snapshot.SnapshotIndex) error {