Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal DB changes #247

Merged
merged 5 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/server/cluster/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ func (g *Gateway) heartbeat(ctx context.Context, mode heartbeatMode) {
// Initialise slice to indicate to HeartbeatNodeHook that its being called from leader.
unavailableMembers := make([]string, 0)

err = query.Retry(func() error {
err = query.Retry(ctx, func(ctx context.Context) error {
// Durating cluster member fluctuations/upgrades the cluster can become unavailable so check here.
if g.Cluster == nil {
return fmt.Errorf("Cluster unavailable")
Expand Down
67 changes: 56 additions & 11 deletions internal/server/db/backups.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ func (c *Cluster) getInstanceBackupID(name string) (int, error) {
id := -1
arg1 := []any{name}
arg2 := []any{&id}
err := dbQueryRowScan(c, q, arg1, arg2)

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
return dbQueryRowScan(ctx, tx, q, arg1, arg2)
})
if err == sql.ErrNoRows {
return -1, api.StatusErrorf(http.StatusNotFound, "Instance backup not found")
}
Expand All @@ -71,7 +74,10 @@ SELECT instances_backups.id, instances_backups.instance_id,
arg1 := []any{projectName, name}
arg2 := []any{&args.ID, &args.InstanceID, &args.CreationDate,
&args.ExpiryDate, &instanceOnlyInt, &optimizedStorageInt}
err := dbQueryRowScan(c, q, arg1, arg2)

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
return dbQueryRowScan(ctx, tx, q, arg1, arg2)
})
if err != nil {
if err == sql.ErrNoRows {
return args, api.StatusErrorf(http.StatusNotFound, "Instance backup not found")
Expand Down Expand Up @@ -110,7 +116,10 @@ SELECT instances_backups.name, instances_backups.instance_id,
arg1 := []any{backupID}
arg2 := []any{&args.Name, &args.InstanceID, &args.CreationDate,
&args.ExpiryDate, &instanceOnlyInt, &optimizedStorageInt}
err := dbQueryRowScan(c, q, arg1, arg2)

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
return dbQueryRowScan(ctx, tx, q, arg1, arg2)
})
if err != nil {
if err == sql.ErrNoRows {
return args, api.StatusErrorf(http.StatusNotFound, "Instance backup not found")
Expand Down Expand Up @@ -141,7 +150,14 @@ JOIN projects ON projects.id=instances.project_id
WHERE projects.name=? AND instances.name=?`
inargs := []any{projectName, name}
outfmt := []any{name}
dbResults, err := queryScan(c, q, inargs, outfmt)

var dbResults [][]any

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
var err error
dbResults, err = queryScan(ctx, tx, q, inargs, outfmt)
return err
})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -203,7 +219,10 @@ func (c *Cluster) DeleteInstanceBackup(name string) error {
return err
}

err = exec(c, "DELETE FROM instances_backups WHERE id=?", id)
err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
_, err := tx.tx.ExecContext(ctx, "DELETE FROM instances_backups WHERE id=?", id)
return err
})
if err != nil {
return err
}
Expand Down Expand Up @@ -248,7 +267,14 @@ func (c *Cluster) GetExpiredInstanceBackups() ([]InstanceBackup, error) {

q := `SELECT instances_backups.name, instances_backups.expiry_date, instances_backups.instance_id FROM instances_backups`
outfmt := []any{name, expiryDate, instanceID}
dbResults, err := queryScan(c, q, nil, outfmt)

var dbResults [][]any

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
var err error
dbResults, err = queryScan(ctx, tx, q, nil, outfmt)
return err
})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -375,7 +401,14 @@ WHERE projects.name=? AND storage_volumes.name=?
ORDER BY storage_volumes_backups.id`
inargs := []any{projectName, volumeName}
outfmt := []any{volumeName}
dbResults, err := queryScan(c, q, inargs, outfmt)

var dbResults [][]any

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
var err error
dbResults, err = queryScan(ctx, tx, q, inargs, outfmt)
return err
})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -436,7 +469,10 @@ func (c *Cluster) getStoragePoolVolumeBackupID(name string) (int, error) {
id := -1
arg1 := []any{name}
arg2 := []any{&id}
err := dbQueryRowScan(c, q, arg1, arg2)

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
return dbQueryRowScan(ctx, tx, q, arg1, arg2)
})
if err == sql.ErrNoRows {
return -1, api.StatusErrorf(http.StatusNotFound, "Storage volume backup not found")
}
Expand All @@ -451,7 +487,10 @@ func (c *Cluster) DeleteStoragePoolVolumeBackup(name string) error {
return err
}

err = exec(c, "DELETE FROM storage_volumes_backups WHERE id=?", id)
err = c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
_, err := tx.tx.ExecContext(ctx, "DELETE FROM storage_volumes_backups WHERE id=?", id)
return err
})
if err != nil {
return err
}
Expand All @@ -478,7 +517,10 @@ WHERE projects.name=? AND backups.name=?
`
arg1 := []any{projectName, backupName}
outfmt := []any{&args.ID, &args.VolumeID, &args.Name, &args.CreationDate, &args.ExpiryDate, &args.VolumeOnly, &args.OptimizedStorage}
err := dbQueryRowScan(c, q, arg1, outfmt)

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
return dbQueryRowScan(ctx, tx, q, arg1, outfmt)
})
if err != nil {
if err == sql.ErrNoRows {
return args, api.StatusErrorf(http.StatusNotFound, "Storage volume backup not found")
Expand Down Expand Up @@ -509,7 +551,10 @@ WHERE backups.id=?
`
arg1 := []any{backupID}
outfmt := []any{&args.ID, &args.VolumeID, &args.Name, &args.CreationDate, &args.ExpiryDate, &args.VolumeOnly, &args.OptimizedStorage}
err := dbQueryRowScan(c, q, arg1, outfmt)

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
return dbQueryRowScan(ctx, tx, q, arg1, outfmt)
})
if err != nil {
if err == sql.ErrNoRows {
return args, api.StatusErrorf(http.StatusNotFound, "Storage volume backup not found")
Expand Down
2 changes: 1 addition & 1 deletion internal/server/db/cluster/open.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func EnsureSchema(db *sql.DB, address string, dir string) (bool, error) {
schema.Hook(hook)

var initial int
err := query.Retry(func() error {
err := query.Retry(context.TODO(), func(_ context.Context) error {
var err error
initial, err = schema.Ensure(db)
return err
Expand Down
157 changes: 61 additions & 96 deletions internal/server/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ func (c *Cluster) transaction(ctx context.Context, f func(context.Context, *Clus
nodeID: c.nodeID,
}

return c.retry(func() error {
return query.Retry(ctx, func(ctx context.Context) error {
txFunc := func(ctx context.Context, tx *sql.Tx) error {
clusterTx.tx = tx
return f(ctx, clusterTx)
Expand All @@ -357,14 +357,6 @@ func (c *Cluster) transaction(ctx context.Context, f func(context.Context, *Clus
})
}

func (c *Cluster) retry(f func() error) error {
if c.closingCtx.Err() != nil {
return f()
}

return query.Retry(f)
}

// NodeID sets the node NodeID associated with this cluster instance. It's used for
// backward-compatibility of all db-related APIs that were written before
// clustering and don't accept a node NodeID, so in those cases we automatically
Expand Down Expand Up @@ -473,82 +465,8 @@ func DqliteLatestSegment() (string, error) {
return "none", nil
}

func dbQueryRowScan(c *Cluster, q string, args []any, outargs []any) error {
return c.retry(func() error {
return query.Transaction(context.TODO(), c.db, func(ctx context.Context, tx *sql.Tx) error {
return tx.QueryRowContext(ctx, q, args...).Scan(outargs...)
})
})
}

func doDbScan(c *Cluster, q string, args []any, outargs []any) ([][]any, error) {
result := [][]any{}

err := c.retry(func() error {
return query.Transaction(context.TODO(), c.db, func(ctx context.Context, tx *sql.Tx) error {
rows, err := tx.QueryContext(ctx, q, args...)
if err != nil {
return err
}

defer func() { _ = rows.Close() }()

for rows.Next() {
ptrargs := make([]any, len(outargs))
for i := range outargs {
switch t := outargs[i].(type) {
case string:
str := ""
ptrargs[i] = &str
case int:
integer := 0
ptrargs[i] = &integer
case int64:
integer := int64(0)
ptrargs[i] = &integer
case bool:
boolean := bool(false)
ptrargs[i] = &boolean
default:
return fmt.Errorf("Bad interface type: %s", t)
}
}
err = rows.Scan(ptrargs...)
if err != nil {
return err
}

newargs := make([]any, len(outargs))
for i := range ptrargs {
switch t := outargs[i].(type) {
case string:
newargs[i] = *ptrargs[i].(*string)
case int:
newargs[i] = *ptrargs[i].(*int)
case int64:
newargs[i] = *ptrargs[i].(*int64)
case bool:
newargs[i] = *ptrargs[i].(*bool)
default:
return fmt.Errorf("Bad interface type: %s", t)
}
}
result = append(result, newargs)
}

err = rows.Err()
if err != nil {
return err
}

return nil
})
})
if err != nil {
return [][]any{}, err
}

return result, nil
func dbQueryRowScan(ctx context.Context, c *ClusterTx, q string, args []any, outargs []any) error {
return c.tx.QueryRowContext(ctx, q, args...).Scan(outargs...)
}

/*
Expand All @@ -564,16 +482,63 @@ func doDbScan(c *Cluster, q string, args []any, outargs []any) ([][]any, error)
* The result will be an array (one per output row) of arrays (one per output argument)
* of interfaces, containing pointers to the actual output arguments.
*/
func queryScan(c *Cluster, q string, inargs []any, outfmt []any) ([][]any, error) {
return doDbScan(c, q, inargs, outfmt)
}
func queryScan(ctx context.Context, c *ClusterTx, q string, inargs []any, outfmt []any) ([][]any, error) {
result := [][]any{}

func exec(c *Cluster, q string, args ...any) error {
err := c.retry(func() error {
return query.Transaction(context.TODO(), c.db, func(ctx context.Context, tx *sql.Tx) error {
_, err := tx.Exec(q, args...)
return err
})
})
return err
rows, err := c.tx.QueryContext(ctx, q, inargs...)
if err != nil {
return [][]any{}, err
}

defer func() { _ = rows.Close() }()

for rows.Next() {
ptrargs := make([]any, len(outfmt))
for i := range outfmt {
switch t := outfmt[i].(type) {
case string:
str := ""
ptrargs[i] = &str
case int:
integer := 0
ptrargs[i] = &integer
case int64:
integer := int64(0)
ptrargs[i] = &integer
case bool:
boolean := bool(false)
ptrargs[i] = &boolean
default:
return [][]any{}, fmt.Errorf("Bad interface type: %s", t)
}
}
err = rows.Scan(ptrargs...)
if err != nil {
return [][]any{}, err
}

newargs := make([]any, len(outfmt))
for i := range ptrargs {
switch t := outfmt[i].(type) {
case string:
newargs[i] = *ptrargs[i].(*string)
case int:
newargs[i] = *ptrargs[i].(*int)
case int64:
newargs[i] = *ptrargs[i].(*int64)
case bool:
newargs[i] = *ptrargs[i].(*bool)
default:
return [][]any{}, fmt.Errorf("Bad interface type: %s", t)
}
}
result = append(result, newargs)
}

err = rows.Err()
if err != nil {
return [][]any{}, err
}

return result, nil
}
25 changes: 21 additions & 4 deletions internal/server/db/instances.go
Original file line number Diff line number Diff line change
Expand Up @@ -978,8 +978,11 @@ func (c *Cluster) UpdateInstanceStatefulFlag(id int, stateful bool) error {
// UpdateInstanceSnapshotCreationDate updates the creation_date field of the instance snapshot with ID.
func (c *Cluster) UpdateInstanceSnapshotCreationDate(instanceID int, date time.Time) error {
stmt := `UPDATE instances_snapshots SET creation_date=? WHERE id=?`
err := exec(c, stmt, date, instanceID)
return err

return c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
_, err := tx.tx.ExecContext(ctx, stmt, date, instanceID)
return err
})
}

// GetInstanceSnapshotsNames returns the names of all snapshots of the instance
Expand All @@ -998,7 +1001,14 @@ ORDER BY instances_snapshots.creation_date, instances_snapshots.id
`
inargs := []any{project, name}
outfmt := []any{name}
dbResults, err := queryScan(c, q, inargs, outfmt)

var dbResults [][]any

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
var err error
dbResults, err = queryScan(ctx, tx, q, inargs, outfmt)
return err
})
if err != nil {
return result, err
}
Expand All @@ -1024,7 +1034,14 @@ ORDER BY instances_snapshots.creation_date, instances_snapshots.id
var numstr string
inargs := []any{project, name}
outfmt := []any{numstr}
results, err := queryScan(c, q, inargs, outfmt)

var results [][]any

err := c.Transaction(context.TODO(), func(ctx context.Context, tx *ClusterTx) error {
var err error
results, err = queryScan(ctx, tx, q, inargs, outfmt)
return err
})
if err != nil {
return 0
}
Expand Down
Loading
Loading