Skip to content

Commit

Permalink
db: use context for backup and restore (#7044)
Browse files Browse the repository at this point in the history
  • Loading branch information
unknwon committed Jun 11, 2022
1 parent f837ea6 commit 75fbb82
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 30 deletions.
3 changes: 2 additions & 1 deletion internal/cmd/backup.go
Expand Up @@ -5,6 +5,7 @@
package cmd

import (
"context"
"fmt"
"io/ioutil"
"os"
Expand Down Expand Up @@ -94,7 +95,7 @@ func runBackup(c *cli.Context) error {

// Database
dbDir := filepath.Join(rootDir, "db")
if err = db.DumpDatabase(conn, dbDir, c.Bool("verbose")); err != nil {
if err = db.DumpDatabase(context.Background(), conn, dbDir, c.Bool("verbose")); err != nil {
log.Fatal("Failed to dump database: %v", err)
}
if err = z.AddDir(archiveRootDir+"/db", dbDir); err != nil {
Expand Down
3 changes: 2 additions & 1 deletion internal/cmd/restore.go
Expand Up @@ -5,6 +5,7 @@
package cmd

import (
"context"
"os"
"path"
"path/filepath"
Expand Down Expand Up @@ -114,7 +115,7 @@ func runRestore(c *cli.Context) error {

// Database
dbDir := path.Join(archivePath, "db")
if err = db.ImportDatabase(conn, dbDir, c.Bool("verbose")); err != nil {
if err = db.ImportDatabase(context.Background(), conn, dbDir, c.Bool("verbose")); err != nil {
log.Fatal("Failed to import database: %v", err)
}

Expand Down
61 changes: 44 additions & 17 deletions internal/db/backup.go
Expand Up @@ -3,6 +3,7 @@ package db
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -30,18 +31,24 @@ func getTableType(t interface{}) string {
}

// DumpDatabase dumps all data from database to file system in JSON Lines format.
func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error {
func DumpDatabase(ctx context.Context, db *gorm.DB, dirPath string, verbose bool) error {
err := os.MkdirAll(dirPath, os.ModePerm)
if err != nil {
return err
}

err = dumpLegacyTables(dirPath, verbose)
err = dumpLegacyTables(ctx, dirPath, verbose)
if err != nil {
return errors.Wrap(err, "dump legacy tables")
}

for _, table := range Tables {
select {
case <-ctx.Done():
return ctx.Err()
default:
}

tableName := getTableType(table)
if verbose {
log.Trace("Dumping table %q...", tableName)
Expand All @@ -55,7 +62,7 @@ func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error {
}
defer func() { _ = f.Close() }()

return dumpTable(db, table, f)
return dumpTable(ctx, db, table, f)
}()
if err != nil {
return errors.Wrapf(err, "dump table %q", tableName)
Expand All @@ -65,11 +72,13 @@ func DumpDatabase(db *gorm.DB, dirPath string, verbose bool) error {
return nil
}

func dumpTable(db *gorm.DB, table interface{}, w io.Writer) error {
query := db.Model(table).Order("id ASC")
func dumpTable(ctx context.Context, db *gorm.DB, table interface{}, w io.Writer) error {
query := db.WithContext(ctx).Model(table)
switch table.(type) {
case *LFSObject:
query = db.Model(table).Order("repo_id, oid ASC")
query = query.Order("repo_id, oid ASC")
default:
query = query.Order("id ASC")
}

rows, err := query.Rows()
Expand Down Expand Up @@ -98,10 +107,16 @@ func dumpTable(db *gorm.DB, table interface{}, w io.Writer) error {
return rows.Err()
}

func dumpLegacyTables(dirPath string, verbose bool) error {
func dumpLegacyTables(ctx context.Context, dirPath string, verbose bool) error {
// Purposely create a local variable to not modify global variable
legacyTables := append(legacyTables, new(Version))
for _, table := range legacyTables {
select {
case <-ctx.Done():
return ctx.Err()
default:
}

tableName := getTableType(table)
if verbose {
log.Trace("Dumping table %q...", tableName)
Expand All @@ -113,7 +128,7 @@ func dumpLegacyTables(dirPath string, verbose bool) error {
return fmt.Errorf("create JSON file: %v", err)
}

if err = x.Asc("id").Iterate(table, func(idx int, bean interface{}) (err error) {
if err = x.Context(ctx).Asc("id").Iterate(table, func(idx int, bean interface{}) (err error) {
return jsoniter.NewEncoder(f).Encode(bean)
}); err != nil {
_ = f.Close()
Expand All @@ -125,13 +140,19 @@ func dumpLegacyTables(dirPath string, verbose bool) error {
}

// ImportDatabase imports data from backup archive in JSON Lines format.
func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error {
err := importLegacyTables(dirPath, verbose)
func ImportDatabase(ctx context.Context, db *gorm.DB, dirPath string, verbose bool) error {
err := importLegacyTables(ctx, dirPath, verbose)
if err != nil {
return errors.Wrap(err, "import legacy tables")
}

for _, table := range Tables {
select {
case <-ctx.Done():
return ctx.Err()
default:
}

tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.")
err := func() error {
tableFile := filepath.Join(dirPath, tableName+".json")
Expand All @@ -150,7 +171,7 @@ func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error {
}
defer func() { _ = f.Close() }()

return importTable(db, table, f)
return importTable(ctx, db, table, f)
}()
if err != nil {
return errors.Wrapf(err, "import table %q", tableName)
Expand All @@ -160,13 +181,13 @@ func ImportDatabase(db *gorm.DB, dirPath string, verbose bool) error {
return nil
}

func importTable(db *gorm.DB, table interface{}, r io.Reader) error {
err := db.Migrator().DropTable(table)
func importTable(ctx context.Context, db *gorm.DB, table interface{}, r io.Reader) error {
err := db.WithContext(ctx).Migrator().DropTable(table)
if err != nil {
return errors.Wrap(err, "drop table")
}

err = db.Migrator().AutoMigrate(table)
err = db.WithContext(ctx).Migrator().AutoMigrate(table)
if err != nil {
return errors.Wrap(err, "auto migrate")
}
Expand All @@ -191,7 +212,7 @@ func importTable(db *gorm.DB, table interface{}, r io.Reader) error {
return errors.Wrap(err, "unmarshal JSON to struct")
}

err = db.Create(elem).Error
err = db.WithContext(ctx).Create(elem).Error
if err != nil {
return errors.Wrap(err, "create row")
}
Expand All @@ -200,14 +221,14 @@ func importTable(db *gorm.DB, table interface{}, r io.Reader) error {
// PostgreSQL needs manually reset table sequence for auto increment keys
if conf.UsePostgreSQL && !skipResetIDSeq[rawTableName] {
seqName := rawTableName + "_id_seq"
if _, err = x.Exec(fmt.Sprintf(`SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM "%s"), 1), false);`, seqName, rawTableName)); err != nil {
if _, err = x.Context(ctx).Exec(fmt.Sprintf(`SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM "%s"), 1), false);`, seqName, rawTableName)); err != nil {
return errors.Wrapf(err, "reset table %q.%q", rawTableName, seqName)
}
}
return nil
}

func importLegacyTables(dirPath string, verbose bool) error {
func importLegacyTables(ctx context.Context, dirPath string, verbose bool) error {
snakeMapper := core.SnakeMapper{}

skipInsertProcessors := map[string]bool{
Expand All @@ -218,6 +239,12 @@ func importLegacyTables(dirPath string, verbose bool) error {
// Purposely create a local variable to not modify global variable
legacyTables := append(legacyTables, new(Version))
for _, table := range legacyTables {
select {
case <-ctx.Done():
return ctx.Err()
default:
}

tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.")
tableFile := filepath.Join(dirPath, tableName+".json")
if !osutil.IsFile(tableFile) {
Expand Down
20 changes: 9 additions & 11 deletions internal/db/backup_test.go
Expand Up @@ -6,12 +6,14 @@ package db

import (
"bytes"
"context"
"os"
"path/filepath"
"testing"
"time"

"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"gorm.io/gorm"

"gogs.io/gogs/internal/auth"
Expand All @@ -22,7 +24,7 @@ import (
"gogs.io/gogs/internal/testutil"
)

func Test_dumpAndImport(t *testing.T) {
func TestDumpAndImport(t *testing.T) {
if testing.Short() {
t.Skip()
}
Expand All @@ -43,8 +45,6 @@ func Test_dumpAndImport(t *testing.T) {
}

func setupDBToDump(t *testing.T, db *gorm.DB) {
t.Helper()

vals := []interface{}{
&Access{
ID: 1,
Expand Down Expand Up @@ -126,31 +126,29 @@ func setupDBToDump(t *testing.T, db *gorm.DB) {
}
for _, val := range vals {
err := db.Create(val).Error
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
}
}

func dumpTables(t *testing.T, db *gorm.DB) {
t.Helper()
ctx := context.Background()

for _, table := range Tables {
tableName := getTableType(table)

var buf bytes.Buffer
err := dumpTable(db, table, &buf)
err := dumpTable(ctx, db, table, &buf)
if err != nil {
t.Fatalf("%s: %v", tableName, err)
}

golden := filepath.Join("testdata", "backup", tableName+".golden.json")
testutil.AssertGolden(t, golden, testutil.Update("Test_dumpAndImport"), buf.String())
testutil.AssertGolden(t, golden, testutil.Update("TestDumpAndImport"), buf.String())
}
}

func importTables(t *testing.T, db *gorm.DB) {
t.Helper()
ctx := context.Background()

for _, table := range Tables {
tableName := getTableType(table)
Expand All @@ -163,7 +161,7 @@ func importTables(t *testing.T, db *gorm.DB) {
}
defer func() { _ = f.Close() }()

return importTable(db, table, f)
return importTable(ctx, db, table, f)
}()
if err != nil {
t.Fatalf("%s: %v", tableName, err)
Expand Down

0 comments on commit 75fbb82

Please sign in to comment.