Skip to content

Commit

Permalink
refactor:pass ctx to db layer, return err instead of recover panic (#450
Browse files Browse the repository at this point in the history
)
  • Loading branch information
hulb committed Sep 30, 2022
1 parent 39cb642 commit 09f2465
Show file tree
Hide file tree
Showing 20 changed files with 875 additions and 702 deletions.
3 changes: 2 additions & 1 deletion go.mod
Expand Up @@ -17,10 +17,12 @@ require (
github.com/lib/pq v1.10.6
github.com/muesli/go-app-paths v0.2.2
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1
github.com/shurcooL/vfsgen v0.0.0-20200824052919-0d455de96546
github.com/sirupsen/logrus v1.9.0
github.com/spf13/cobra v1.5.0
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4
golang.org/x/term v0.0.0-20220722155259-a9ba230a4035
modernc.org/sqlite v1.18.1
)
Expand All @@ -45,7 +47,6 @@ require (
go.uber.org/atomic v1.9.0 // indirect
golang.org/x/image v0.0.0-20220413100746-70e8d0d3baa9 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4 // indirect
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/tools v0.1.10 // indirect
Expand Down
4 changes: 2 additions & 2 deletions internal/cmd/add.go
Expand Up @@ -58,7 +58,7 @@ func addHandler(cmd *cobra.Command, args []string) {

// Create bookmark ID
var err error
book.ID, err = db.CreateNewID("bookmark")
book.ID, err = db.CreateNewID(cmd.Context(), "bookmark")
if err != nil {
cError.Printf("Failed to create ID: %v\n", err)
os.Exit(1)
Expand Down Expand Up @@ -111,7 +111,7 @@ func addHandler(cmd *cobra.Command, args []string) {
}

// Save bookmark to database
_, err = db.SaveBookmarks(book)
_, err = db.SaveBookmarks(cmd.Context(), book)
if err != nil {
cError.Printf("Failed to save bookmark: %v\n", err)
os.Exit(1)
Expand Down
4 changes: 2 additions & 2 deletions internal/cmd/check.go
Expand Up @@ -53,7 +53,7 @@ func checkHandler(cmd *cobra.Command, args []string) {

// Fetch bookmarks from database
filterOptions := database.GetBookmarksOptions{IDs: ids}
bookmarks, err := db.GetBookmarks(filterOptions)
bookmarks, err := db.GetBookmarks(cmd.Context(), filterOptions)
if err != nil {
cError.Printf("Failed to get bookmarks: %v\n", err)
os.Exit(1)
Expand Down Expand Up @@ -88,7 +88,7 @@ func checkHandler(cmd *cobra.Command, args []string) {
_, err := httpClient.Get(book.URL)
if err != nil {
chProblem <- book.ID
chMessage <- fmt.Errorf("Failed to reach %s: %v", book.URL, err)
chMessage <- fmt.Errorf("failed to reach %s: %v", book.URL, err)
return
}

Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/delete.go
Expand Up @@ -52,7 +52,7 @@ func deleteHandler(cmd *cobra.Command, args []string) {
}

// Delete bookmarks from database
err = db.DeleteBookmarks(ids...)
err = db.DeleteBookmarks(cmd.Context(), ids...)
if err != nil {
cError.Printf("Failed to delete bookmarks: %v\n", err)
os.Exit(1)
Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/export.go
Expand Up @@ -24,7 +24,7 @@ func exportCmd() *cobra.Command {

func exportHandler(cmd *cobra.Command, args []string) {
// Fetch bookmarks from database
bookmarks, err := db.GetBookmarks(database.GetBookmarksOptions{})
bookmarks, err := db.GetBookmarks(cmd.Context(), database.GetBookmarksOptions{})
if err != nil {
cError.Printf("Failed to get bookmarks: %v\n", err)
os.Exit(1)
Expand Down
12 changes: 9 additions & 3 deletions internal/cmd/import.go
Expand Up @@ -38,7 +38,7 @@ func importHandler(cmd *cobra.Command, args []string) {
}

// Prepare bookmark's ID
bookID, err := db.CreateNewID("bookmark")
bookID, err := db.CreateNewID(cmd.Context(), "bookmark")
if err != nil {
cError.Printf("Failed to create ID: %v\n", err)
os.Exit(1)
Expand Down Expand Up @@ -91,7 +91,13 @@ func importHandler(cmd *cobra.Command, args []string) {
return
}

if _, exist := db.GetBookmark(0, url); exist {
_, exist, err := db.GetBookmark(cmd.Context(), 0, url)
if err != nil {
cError.Printf("Skip %s: Get Bookmark fail, %v", url, err)
return
}

if exist {
cError.Printf("Skip %s: URL already exists\n", url)
mapURL[url] = struct{}{}
return
Expand Down Expand Up @@ -127,7 +133,7 @@ func importHandler(cmd *cobra.Command, args []string) {
})

// Save bookmark to database
bookmarks, err = db.SaveBookmarks(bookmarks...)
bookmarks, err = db.SaveBookmarks(cmd.Context(), bookmarks...)
if err != nil {
cError.Printf("Failed to save bookmarks: %v\n", err)
os.Exit(1)
Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/open.go
Expand Up @@ -73,7 +73,7 @@ func openHandler(cmd *cobra.Command, args []string) {
WithContent: true,
}

bookmarks, err := db.GetBookmarks(getOptions)
bookmarks, err := db.GetBookmarks(cmd.Context(), getOptions)
if err != nil {
cError.Printf("Failed to get bookmarks: %v\n", err)
os.Exit(1)
Expand Down
12 changes: 9 additions & 3 deletions internal/cmd/pocket.go
Expand Up @@ -26,7 +26,7 @@ func pocketCmd() *cobra.Command {

func pocketHandler(cmd *cobra.Command, args []string) {
// Prepare bookmark's ID
bookID, err := db.CreateNewID("bookmark")
bookID, err := db.CreateNewID(cmd.Context(), "bookmark")
if err != nil {
cError.Printf("Failed to create ID: %v\n", err)
return
Expand Down Expand Up @@ -77,7 +77,13 @@ func pocketHandler(cmd *cobra.Command, args []string) {
return
}

if _, exist := db.GetBookmark(0, url); exist {
_, exist, err := db.GetBookmark(cmd.Context(), 0, url)
if err != nil {
cError.Printf("Skip %s: Get Bookmark fail, %v", url, err)
return
}

if exist {
cError.Printf("Skip %s: URL already exists\n", url)
mapURL[url] = struct{}{}
return
Expand Down Expand Up @@ -106,7 +112,7 @@ func pocketHandler(cmd *cobra.Command, args []string) {
})

// Save bookmark to database
bookmarks, err = db.SaveBookmarks(bookmarks...)
bookmarks, err = db.SaveBookmarks(cmd.Context(), bookmarks...)
if err != nil {
cError.Printf("Failed to save bookmarks: %v\n", err)
os.Exit(1)
Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/print.go
Expand Up @@ -61,7 +61,7 @@ func printHandler(cmd *cobra.Command, args []string) {
OrderMethod: orderMethod,
}

bookmarks, err := db.GetBookmarks(searchOptions)
bookmarks, err := db.GetBookmarks(cmd.Context(), searchOptions)
if err != nil {
cError.Printf("Failed to get bookmarks: %v\n", err)
return
Expand Down
23 changes: 12 additions & 11 deletions internal/cmd/root.go
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/go-shiori/shiori/internal/database"
apppaths "github.com/muesli/go-app-paths"
"github.com/spf13/cobra"
"golang.org/x/net/context"
)

var (
Expand Down Expand Up @@ -61,7 +62,7 @@ func preRunRootHandler(cmd *cobra.Command, args []string) {
}

// Open database
db, err = openDatabase()
db, err = openDatabase(cmd.Context())
if err != nil {
cError.Printf("Failed to open database: %v\n", err)
os.Exit(1)
Expand Down Expand Up @@ -101,33 +102,33 @@ func getDataDir(portableMode bool) (string, error) {
return ".", nil
}

func openDatabase() (database.DB, error) {
func openDatabase(ctx context.Context) (database.DB, error) {
switch dbms, _ := os.LookupEnv("SHIORI_DBMS"); dbms {
case "mysql":
return openMySQLDatabase()
return openMySQLDatabase(ctx)
case "postgresql":
return openPostgreSQLDatabase()
return openPostgreSQLDatabase(ctx)
default:
return openSQLiteDatabase()
return openSQLiteDatabase(ctx)
}
}

func openSQLiteDatabase() (database.DB, error) {
func openSQLiteDatabase(ctx context.Context) (database.DB, error) {
dbPath := fp.Join(dataDir, "shiori.db")
return database.OpenSQLiteDatabase(dbPath)
return database.OpenSQLiteDatabase(ctx, dbPath)
}

func openMySQLDatabase() (database.DB, error) {
func openMySQLDatabase(ctx context.Context) (database.DB, error) {
user, _ := os.LookupEnv("SHIORI_MYSQL_USER")
password, _ := os.LookupEnv("SHIORI_MYSQL_PASS")
dbName, _ := os.LookupEnv("SHIORI_MYSQL_NAME")
dbAddress, _ := os.LookupEnv("SHIORI_MYSQL_ADDRESS")

connString := fmt.Sprintf("%s:%s@%s/%s?charset=utf8mb4", user, password, dbAddress, dbName)
return database.OpenMySQLDatabase(connString)
return database.OpenMySQLDatabase(ctx, connString)
}

func openPostgreSQLDatabase() (database.DB, error) {
func openPostgreSQLDatabase(ctx context.Context) (database.DB, error) {
host, _ := os.LookupEnv("SHIORI_PG_HOST")
port, _ := os.LookupEnv("SHIORI_PG_PORT")
user, _ := os.LookupEnv("SHIORI_PG_USER")
Expand All @@ -136,5 +137,5 @@ func openPostgreSQLDatabase() (database.DB, error) {

connString := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
host, port, user, password, dbName)
return database.OpenPGDatabase(connString)
return database.OpenPGDatabase(ctx, connString)
}
8 changes: 4 additions & 4 deletions internal/cmd/update.go
Expand Up @@ -94,7 +94,7 @@ func updateHandler(cmd *cobra.Command, args []string) {
IDs: ids,
}

bookmarks, err := db.GetBookmarks(filterOptions)
bookmarks, err := db.GetBookmarks(cmd.Context(), filterOptions)
if err != nil {
cError.Printf("Failed to get bookmarks: %v\n", err)
os.Exit(1)
Expand Down Expand Up @@ -159,7 +159,7 @@ func updateHandler(cmd *cobra.Command, args []string) {
content, contentType, err := core.DownloadBookmark(book.URL)
if err != nil {
chProblem <- book.ID
chMessage <- fmt.Errorf("Failed to download %s: %v", book.URL, err)
chMessage <- fmt.Errorf("failed to download %s: %v", book.URL, err)
return
}

Expand All @@ -178,7 +178,7 @@ func updateHandler(cmd *cobra.Command, args []string) {

if err != nil {
chProblem <- book.ID
chMessage <- fmt.Errorf("Failed to process %s: %v", book.URL, err)
chMessage <- fmt.Errorf("failed to process %s: %v", book.URL, err)
return
}

Expand Down Expand Up @@ -285,7 +285,7 @@ func updateHandler(cmd *cobra.Command, args []string) {
}

// Save bookmarks to database
bookmarks, err = db.SaveBookmarks(bookmarks...)
bookmarks, err = db.SaveBookmarks(cmd.Context(), bookmarks...)
if err != nil {
cError.Printf("Failed to save bookmark: %v\n", err)
os.Exit(1)
Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/utils.go
Expand Up @@ -27,7 +27,7 @@ var (
cInfo = color.New(color.FgHiCyan)
cError = color.New(color.FgHiRed)

errInvalidIndex = errors.New("Index is not valid")
errInvalidIndex = errors.New("index is not valid")
)

func normalizeSpace(str string) string {
Expand Down
58 changes: 41 additions & 17 deletions internal/database/database.go
@@ -1,10 +1,13 @@
package database

import (
"database/sql"
"context"
"embed"
"log"

"github.com/go-shiori/shiori/internal/model"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)

//go:embed migrations/*
Expand Down Expand Up @@ -46,44 +49,65 @@ type DB interface {
Migrate() error

// SaveBookmarks saves bookmarks data to database.
SaveBookmarks(bookmarks ...model.Bookmark) ([]model.Bookmark, error)
SaveBookmarks(ctx context.Context, bookmarks ...model.Bookmark) ([]model.Bookmark, error)

// GetBookmarks fetch list of bookmarks based on submitted options.
GetBookmarks(opts GetBookmarksOptions) ([]model.Bookmark, error)
GetBookmarks(ctx context.Context, opts GetBookmarksOptions) ([]model.Bookmark, error)

// GetBookmarksCount get count of bookmarks in database.
GetBookmarksCount(opts GetBookmarksOptions) (int, error)
GetBookmarksCount(ctx context.Context, opts GetBookmarksOptions) (int, error)

// DeleteBookmarks removes all record with matching ids from database.
DeleteBookmarks(ids ...int) error
DeleteBookmarks(ctx context.Context, ids ...int) error

// GetBookmark fetches bookmark based on its ID or URL.
GetBookmark(id int, url string) (model.Bookmark, bool)
// GetBookmark fetchs bookmark based on its ID or URL.
GetBookmark(ctx context.Context, id int, url string) (model.Bookmark, bool, error)

// SaveAccount saves new account in database
SaveAccount(model.Account) error
SaveAccount(ctx context.Context, a model.Account) error

// GetAccounts fetch list of account (without its password) with matching keyword.
GetAccounts(opts GetAccountsOptions) ([]model.Account, error)
GetAccounts(ctx context.Context, opts GetAccountsOptions) ([]model.Account, error)

// GetAccount fetch account with matching username.
GetAccount(username string) (model.Account, bool)
GetAccount(ctx context.Context, username string) (model.Account, bool, error)

// DeleteAccounts removes all record with matching usernames
DeleteAccounts(usernames ...string) error
DeleteAccounts(ctx context.Context, usernames ...string) error

// GetTags fetch list of tags and its frequency from database.
GetTags() ([]model.Tag, error)
GetTags(ctx context.Context) ([]model.Tag, error)

// RenameTag change the name of a tag.
RenameTag(id int, newName string) error
RenameTag(ctx context.Context, id int, newName string) error

// CreateNewID creates new id for specified table.
CreateNewID(table string) (int, error)
CreateNewID(ctx context.Context, table string) (int, error)
}

func checkError(err error) {
if err != nil && err != sql.ErrNoRows {
panic(err)
type dbbase struct {
sqlx.DB
}

func (db *dbbase) withTx(ctx context.Context, fn func(tx *sqlx.Tx) error) error {
tx, err := db.BeginTxx(ctx, nil)
if err != nil {
return errors.WithStack(err)
}

defer func() {
if err := tx.Commit(); err != nil {
log.Printf("error during commit: %s", err)
}
}()

err = fn(tx)
if err != nil {
if err := tx.Rollback(); err != nil {
log.Printf("error during rollback: %s", err)
}
return errors.WithStack(err)
}

return err
}

0 comments on commit 09f2465

Please sign in to comment.