Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cmd/core/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"os"
"strings"
"text/tabwriter"
Expand Down Expand Up @@ -448,6 +449,13 @@ func IsURL(ref string) bool {
return strings.HasPrefix(ref, "http://") || strings.HasPrefix(ref, "https://")
}

// CloseOnCancel closes c when ctx is canceled; callers `defer CloseOnCancel(ctx, c)()` to stop the watcher on return.
func CloseOnCancel(ctx context.Context, c io.Closer) func() bool {
return context.AfterFunc(ctx, func() {
c.Close() //nolint:errcheck,gosec
})
}

// resolveOwner returns the unique backend where found==true; notFound on zero, ambiguous wrapped on multi-match (lists matched types).
func resolveOwner[T interface{ Type() string }](backends []T, ref string, found func(T) (bool, error), notFound, ambiguous error) (T, error) {
var matches []T
Expand Down
13 changes: 2 additions & 11 deletions cmd/snapshot/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package snapshot

import (
"cmp"
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -58,11 +57,7 @@ func (h Handler) Save(cmd *cobra.Command, args []string) error {
}
defer stream.Close() //nolint:errcheck

// Close stream on ctx cancel so Ctrl+C doesn't hang on the pipe.
stop := context.AfterFunc(ctx, func() {
stream.Close() //nolint:errcheck,gosec
})
defer stop()
defer cmdcore.CloseOnCancel(ctx, stream)()

cfg.Name = name
cfg.Description = description
Expand Down Expand Up @@ -199,11 +194,7 @@ func (h Handler) Export(cmd *cobra.Command, args []string) (err error) {
return fmt.Errorf("export: %w", err)
}
defer stream.Close() //nolint:errcheck

stop := context.AfterFunc(ctx, func() {
stream.Close() //nolint:errcheck,gosec
})
defer stop()
defer cmdcore.CloseOnCancel(ctx, stream)()

if output == "-" {
if _, err = io.Copy(os.Stdout, stream); err != nil {
Expand Down
12 changes: 2 additions & 10 deletions cmd/vm/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,7 @@ func (h Handler) Clone(cmd *cobra.Command, args []string) error {
return fmt.Errorf("open snapshot %s: %w", snapRef, err)
}
defer stream.Close() //nolint:errcheck

stop := context.AfterFunc(ctx, func() {
stream.Close() //nolint:errcheck,gosec
})
defer stop()
defer cmdcore.CloseOnCancel(ctx, stream)()

vmCfg, vmID, netProvider, netSetup, err := h.prepareClone(ctx, cmd, conf, cfg)
if err != nil {
Expand Down Expand Up @@ -191,11 +187,7 @@ func (h Handler) Restore(cmd *cobra.Command, args []string) error {
return fmt.Errorf("open snapshot: %w", err)
}
defer stream.Close() //nolint:errcheck

stop := context.AfterFunc(ctx, func() {
stream.Close() //nolint:errcheck,gosec
})
defer stop()
defer cmdcore.CloseOnCancel(ctx, stream)()

logger.Infof(ctx, "restoring VM %s from snapshot %s ...", vmRef, snapRef)

Expand Down
15 changes: 2 additions & 13 deletions cmd/vm/status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package vm
import (
"io"
"os"
"slices"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -176,21 +177,9 @@ func TestApplyFilters(t *testing.T) {
for _, vm := range got {
gotIDs = append(gotIDs, vm.ID)
}
if !equalStrings(gotIDs, tt.wantIDs) {
if !slices.Equal(gotIDs, tt.wantIDs) {
t.Errorf("applyFilters(%v) = %v, want %v", tt.filters, gotIDs, tt.wantIDs)
}
})
}
}

func equalStrings(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
7 changes: 1 addition & 6 deletions hypervisor/cloudhypervisor/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ import (
// Console returns a bidirectional stream to the VM console: console.sock (UEFI) or the CH-allocated PTY (OCI).
// Caller closes the returned ReadWriteCloser.
func (ch *CloudHypervisor) Console(ctx context.Context, ref string) (io.ReadWriteCloser, error) {
id, err := ch.ResolveRef(ctx, ref)
if err != nil {
return nil, err
}

rec, err := ch.LoadRecord(ctx, id)
id, rec, err := ch.ResolveAndLoad(ctx, ref)
if err != nil {
return nil, err
}
Expand Down
6 changes: 1 addition & 5 deletions hypervisor/cloudhypervisor/extend.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,7 @@ func (ch *CloudHypervisor) runningVMClient(ctx context.Context, vmRef string) (*
}

func (ch *CloudHypervisor) runningVMClientWithRecord(ctx context.Context, vmRef string) (*http.Client, string, hypervisor.VMRecord, error) {
vmID, err := ch.ResolveRef(ctx, vmRef)
if err != nil {
return nil, "", hypervisor.VMRecord{}, err
}
rec, err := ch.LoadRecord(ctx, vmID)
vmID, rec, err := ch.ResolveAndLoad(ctx, vmRef)
if err != nil {
return nil, "", hypervisor.VMRecord{}, err
}
Expand Down
7 changes: 1 addition & 6 deletions hypervisor/firecracker/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,7 @@ import (
)

func (fc *Firecracker) Console(ctx context.Context, ref string) (io.ReadWriteCloser, error) {
id, err := fc.ResolveRef(ctx, ref)
if err != nil {
return nil, err
}

rec, err := fc.LoadRecord(ctx, id)
id, rec, err := fc.ResolveAndLoad(ctx, ref)
if err != nil {
return nil, err
}
Expand Down
17 changes: 17 additions & 0 deletions hypervisor/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,23 @@ func (b *Backend) LoadRecord(ctx context.Context, id string) (VMRecord, error) {
})
}

// ResolveAndLoad combines ResolveRef + LoadRecord under a single DB lock.
func (b *Backend) ResolveAndLoad(ctx context.Context, ref string) (string, VMRecord, error) {
var (
id string
rec VMRecord
)
return id, rec, b.DB.With(ctx, func(idx *VMIndex) error {
var err error
id, err = idx.Resolve(ref)
if err != nil {
return err
}
rec, err = utils.LookupCopy(idx.VMs, id)
return err
})
}

func vsockBound(path string) bool {
_, err := os.Stat(path)
return err == nil
Expand Down
24 changes: 5 additions & 19 deletions snapshot/localfile/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,25 +59,11 @@ func (lf *LocalFile) Import(ctx context.Context, r io.Reader, name, description
return "", fmt.Errorf("compute data dir size: %w", sizeErr)
}
now := time.Now()
if err = lf.store.Update(ctx, func(idx *snapshot.SnapshotIndex) error {
if cfg.Name != "" {
if existingID, ok := idx.Names[cfg.Name]; ok {
return fmt.Errorf("snapshot name %q already in use by %s", cfg.Name, existingID)
}
}
idx.Snapshots[id] = &snapshot.SnapshotRecord{
Snapshot: types.Snapshot{
SnapshotConfig: cfg,
CreatedAt: now,
},
DataDir: dataDir,
SizeBytes: size,
LastAccessedAt: now,
}
if cfg.Name != "" {
idx.Names[cfg.Name] = id
}
return nil
if err = lf.insertRecord(ctx, id, cfg.Name, &snapshot.SnapshotRecord{
Snapshot: types.Snapshot{SnapshotConfig: cfg, CreatedAt: now},
DataDir: dataDir,
SizeBytes: size,
LastAccessedAt: now,
}); err != nil {
return "", err
}
Expand Down
38 changes: 20 additions & 18 deletions snapshot/localfile/localfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,24 +90,10 @@ func (lf *LocalFile) Create(ctx context.Context, cfg *types.SnapshotConfig, stre
dataDir := lf.conf.SnapshotDataDir(id)
now := time.Now()

if err = lf.store.Update(ctx, func(idx *snapshot.SnapshotIndex) error {
if cfg.Name != "" {
if existingID, ok := idx.Names[cfg.Name]; ok {
return fmt.Errorf("snapshot name %q already in use by %s", cfg.Name, existingID)
}
}
idx.Snapshots[id] = &snapshot.SnapshotRecord{
Snapshot: types.Snapshot{
SnapshotConfig: *cfg,
CreatedAt: now,
},
Pending: true,
DataDir: dataDir,
}
if cfg.Name != "" {
idx.Names[cfg.Name] = id
}
return nil
if err = lf.insertRecord(ctx, id, cfg.Name, &snapshot.SnapshotRecord{
Snapshot: types.Snapshot{SnapshotConfig: *cfg, CreatedAt: now},
Pending: true,
DataDir: dataDir,
}); err != nil {
return "", err
}
Expand Down Expand Up @@ -235,6 +221,22 @@ func (lf *LocalFile) deleteOne(ctx context.Context, id string) error {
return nil
}

// insertRecord adds rec under id with name-collision check; both Create (Pending) and Import (finalized) go through here.
func (lf *LocalFile) insertRecord(ctx context.Context, id, name string, rec *snapshot.SnapshotRecord) error {
return lf.store.Update(ctx, func(idx *snapshot.SnapshotIndex) error {
if name != "" {
if existingID, ok := idx.Names[name]; ok {
return fmt.Errorf("snapshot name %q already in use by %s", name, existingID)
}
}
idx.Snapshots[id] = rec
if name != "" {
idx.Names[name] = id
}
return nil
})
}

// rollbackCreate removes a placeholder snapshot record from the DB.
func (lf *LocalFile) rollbackCreate(ctx context.Context, id, name string) {
if err := lf.store.Update(ctx, func(idx *snapshot.SnapshotIndex) error {
Expand Down
Loading