Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions database.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@ package ormshift

import (
"database/sql"
"errors"
"fmt"

"github.com/ordershift/ormshift/errs"
"github.com/ordershift/ormshift/schema"
)

Expand Down Expand Up @@ -35,16 +34,17 @@ type Database struct {

func OpenDatabase(driver DatabaseDriver, params ConnectionParams) (*Database, error) {
if driver == nil {
return nil, errors.New("DatabaseDriver cannot be nil")
err := errs.Nil("database driver")
return nil, failedToOpenDatabase(err)
}
connectionString := driver.ConnectionString(params)
db, err := sql.Open(driver.Name(), connectionString)
if err != nil {
return nil, fmt.Errorf("sql.Open failed: %w", err)
return nil, failedToOpenDatabase(err)
}
dbSchema, err := driver.DBSchema(db)
if err != nil {
return nil, fmt.Errorf("failed to get DB schema: %w", err)
return nil, failedToOpenDatabase(err)
}

return &Database{
Expand All @@ -56,6 +56,10 @@ func OpenDatabase(driver DatabaseDriver, params ConnectionParams) (*Database, er
}, nil
}

func failedToOpenDatabase(err error) error {
return errs.FailedTo("open database", err)
}

func (d *Database) Close() error {
return d.db.Close()
}
Expand Down
6 changes: 3 additions & 3 deletions database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestOpenDatabaseWithNilDriver(t *testing.T) {
if !testutils.AssertNilResultAndNotNilError(t, db, err, "ormshift.OpenDatabase") {
return
}
testutils.AssertErrorMessage(t, "DatabaseDriver cannot be nil", err, "ormshift.OpenDatabase")
testutils.AssertErrorMessage(t, "failed to open database: database driver cannot be nil", err, "ormshift.OpenDatabase")
}

func TestOpenDatabaseWithBadDriver(t *testing.T) {
Expand All @@ -42,7 +42,7 @@ func TestOpenDatabaseWithBadDriver(t *testing.T) {
if !testutils.AssertNilResultAndNotNilError(t, db, err, "ormshift.OpenDatabase") {
return
}
testutils.AssertErrorMessage(t, "sql.Open failed: sql: unknown driver \"bad-driver-name\" (forgotten import?)", err, "ormshift.OpenDatabase")
testutils.AssertErrorMessage(t, "failed to open database: sql: unknown driver \"bad-driver-name\" (forgotten import?)", err, "ormshift.OpenDatabase")
}

func TestOpenDatabaseWithBadSchema(t *testing.T) {
Expand All @@ -51,7 +51,7 @@ func TestOpenDatabaseWithBadSchema(t *testing.T) {
if !testutils.AssertNilResultAndNotNilError(t, db, err, "ormshift.OpenDatabase") {
return
}
testutils.AssertErrorMessage(t, "failed to get DB schema: intentionally bad schema", err, "ormshift.OpenDatabase")
testutils.AssertErrorMessage(t, "failed to open database: intentionally bad schema", err, "ormshift.OpenDatabase")
}

func TestClose(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion dialects/postgresql/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ func TestDBSchemaFailsWhenDBIsNil(t *testing.T) {
if !testutils.AssertNilResultAndNotNilError(t, schema, err, "driver.DBSchema") {
return
}
testutils.AssertErrorMessage(t, "sql.DB cannot be nil", err, "driver.DBSchema")
testutils.AssertErrorMessage(t, "failed to get db schema: db cannot be nil", err, "driver.DBSchema")
}
2 changes: 1 addition & 1 deletion dialects/sqlite/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,5 @@ func TestDBSchemaFailsWhenDBIsNil(t *testing.T) {
if !testutils.AssertNilResultAndNotNilError(t, schema, err, "driver.DBSchema") {
return
}
testutils.AssertErrorMessage(t, "sql.DB cannot be nil", err, "driver.DBSchema")
testutils.AssertErrorMessage(t, "failed to get db schema: db cannot be nil", err, "driver.DBSchema")
}
2 changes: 1 addition & 1 deletion dialects/sqlserver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ func TestDBSchemaFailsWhenDBIsNil(t *testing.T) {
if !testutils.AssertNilResultAndNotNilError(t, schema, err, "driver.DBSchema") {
return
}
testutils.AssertErrorMessage(t, "sql.DB cannot be nil", err, "driver.DBSchema")
testutils.AssertErrorMessage(t, "failed to get db schema: db cannot be nil", err, "driver.DBSchema")
}
42 changes: 42 additions & 0 deletions errs/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package errs

import (
"errors"
"fmt"
)

var (
ErrInvalid = errors.New("invalid")
ErrNil = errors.New("cannot be nil")
ErrFailedTo = errors.New("failed to")
ErrAlreadyExists = errors.New("already exists")
)

// Invalid returns an error indicating that value is not valid for the given label.
// The error wraps ErrInvalid, allowing it to be checked with errors.Is.
func Invalid(label string) error {
return fmt.Errorf("%w %s", ErrInvalid, label)
}

// Nil returns an error indicating that the value identified by label is null.
// The error wraps ErrNil, allowing it to be checked with errors.Is.
func Nil(label string) error {
return fmt.Errorf("%s %w", label, ErrNil)
}

// FailedTo returns an error indicating a failure to perform the given action.
// It wraps ErrFailedTo and optionally wraps the provided cause.
func FailedTo(action string, err error) error {
failedToErr := fmt.Errorf("%w %s", ErrFailedTo, action)
if err == nil {
return failedToErr
}
return fmt.Errorf("%w: %w", failedToErr, err)
}

// AlreadyExists returns an error indicating that the resource identified by label
// already exists.
// The error wraps ErrAlreadyExists, allowing it to be checked with errors.Is.
func AlreadyExists(label string) error {
return fmt.Errorf("%s %w", label, ErrAlreadyExists)
}
48 changes: 48 additions & 0 deletions errs/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package errs_test

import (
"testing"

"github.com/ordershift/ormshift/errs"
"github.com/ordershift/ormshift/internal/testutils"
)

type errorTester struct {
expectedMessageError string
expectedTypeError error
testedError error
}

func TestErrors(t *testing.T) {
testers := []errorTester{
{
expectedMessageError: "invalid driver",
expectedTypeError: errs.ErrInvalid,
testedError: errs.Invalid("driver"),
},
{
expectedMessageError: "database driver cannot be nil",
expectedTypeError: errs.ErrNil,
testedError: errs.Nil("database driver"),
},
{
expectedMessageError: "column already exists",
expectedTypeError: errs.ErrAlreadyExists,
testedError: errs.AlreadyExists("column"),
},
{
expectedMessageError: "failed to get db schema",
expectedTypeError: errs.ErrFailedTo,
testedError: errs.FailedTo("get db schema", nil),
},
{
expectedMessageError: "failed to get db schema: db cannot be nil",
expectedTypeError: errs.ErrFailedTo,
testedError: errs.FailedTo("get db schema", errs.Nil("db")),
},
}
for _, tester := range testers {
testutils.AssertErrorType(t, tester.expectedTypeError, tester.testedError)
testutils.AssertErrorMessage(t, tester.expectedMessageError, tester.testedError, "errs pkg")
}
}
13 changes: 13 additions & 0 deletions internal/testutils/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package testutils

import (
"database/sql"
"errors"
"strings"
"testing"
)
Expand All @@ -21,6 +22,18 @@ func AssertErrorMessage(t *testing.T, expectedErrorMessage string, err error, fu
}
}

func AssertErrorType(t *testing.T, expectedErrorType, err error) bool {
if !errors.Is(err, expectedErrorType) {
t.Errorf(
"error with message [%s] has not expected type [%s]",
err.Error(),
expectedErrorType.Error(),
)
return false
}
return true
}

func AssertNotNilResultAndNilError[R any](t *testing.T, result *R, err error, functionName string) bool {
res := true
if result == nil {
Expand Down
4 changes: 3 additions & 1 deletion migrations/migrations.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package migrations

import "github.com/ordershift/ormshift"
import (
"github.com/ordershift/ormshift"
)

type Migration interface {
Up(migrator *Migrator) error
Expand Down
2 changes: 1 addition & 1 deletion migrations/migrations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func TestMigrateFailsWhenDatabaseIsClosed(t *testing.T) {
if !testutils.AssertNilResultAndNotNilError(t, migrator, err, "migrations.Migrate") {
return
}
testutils.AssertErrorMessage(t, "failed to get applied migration names: sql: database is closed", err, "migrations.Migrate")
testutils.AssertErrorMessage(t, "failed to migrate: failed to get applied migration names: sql: database is closed", err, "migrations.Migrate")
}

func TestMigrateFailsWhenMigrationUpFails(t *testing.T) {
Expand Down
14 changes: 11 additions & 3 deletions migrations/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/ordershift/ormshift"
"github.com/ordershift/ormshift/errs"
"github.com/ordershift/ormshift/schema"
)

Expand All @@ -18,15 +19,18 @@ type Migrator struct {

func NewMigrator(database *ormshift.Database, config *MigratorConfig) (*Migrator, error) {
if database == nil {
return nil, fmt.Errorf("database cannot be nil")
err := errs.Nil("database")
return nil, failedToMigrate(err)
}
if config == nil {
return nil, fmt.Errorf("migrator config cannot be nil")
err := errs.Nil("migrator config")
return nil, failedToMigrate(err)
}

appliedMigrationNames, err := getAppliedMigrationNames(database, config)
if err != nil {
return nil, fmt.Errorf("failed to get applied migration names: %w", err)
err := errs.FailedTo("get applied migration names", err)
return nil, failedToMigrate(err)
}
appliedMigrations := make(map[string]bool, len(appliedMigrationNames))
for _, name := range appliedMigrationNames {
Expand All @@ -41,6 +45,10 @@ func NewMigrator(database *ormshift.Database, config *MigratorConfig) (*Migrator
}, nil
}

func failedToMigrate(err error) error {
return errs.FailedTo("migrate", err)
}

func (m *Migrator) Add(migration Migration) {
m.migrations = append(m.migrations, migration)
}
Expand Down
6 changes: 3 additions & 3 deletions migrations/migrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
func TestNewMigratorWhenDatabaseIsNil(t *testing.T) {
migrator, err := migrations.NewMigrator(nil, migrations.NewMigratorConfig())
testutils.AssertNilResultAndNotNilError(t, migrator, err, "migrations.NewMigrator[database=nil]")
testutils.AssertErrorMessage(t, "database cannot be nil", err, "migrations.NewMigrator[database=nil]")
testutils.AssertErrorMessage(t, "failed to migrate: database cannot be nil", err, "migrations.NewMigrator[database=nil]")
}

func TestNewMigratorWhenConfigIsNil(t *testing.T) {
Expand All @@ -25,7 +25,7 @@ func TestNewMigratorWhenConfigIsNil(t *testing.T) {

migrator, err := migrations.NewMigrator(db, nil)
testutils.AssertNilResultAndNotNilError(t, migrator, err, "migrations.NewMigrator[config=nil]")
testutils.AssertErrorMessage(t, "migrator config cannot be nil", err, "migrations.NewMigrator[config=nil]")
testutils.AssertErrorMessage(t, "failed to migrate: migrator config cannot be nil", err, "migrations.NewMigrator[config=nil]")
}

func TestNewMigratorWhenDatabaseIsInvalid(t *testing.T) {
Expand All @@ -38,7 +38,7 @@ func TestNewMigratorWhenDatabaseIsInvalid(t *testing.T) {

migrator, err := migrations.NewMigrator(db, migrations.NewMigratorConfig())
testutils.AssertNilResultAndNotNilError(t, migrator, err, "migrations.NewMigrator[database=invalid]")
testutils.AssertErrorMessage(t, "failed to get applied migration names: missing \"=\" after \"invalid-connection-string\" in connection info string\"", err, "migrations.NewMigrator[database=invalid]")
testutils.AssertErrorMessage(t, "failed to migrate: failed to get applied migration names: missing \"=\" after \"invalid-connection-string\" in connection info string\"", err, "migrations.NewMigrator[database=invalid]")
}

func TestApplyAllMigrationsFailsWhenRecordingFails(t *testing.T) {
Expand Down
10 changes: 8 additions & 2 deletions schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package schema

import (
"database/sql"
"errors"
"slices"
"strings"

"github.com/ordershift/ormshift/errs"
)

type DBSchema struct {
Expand All @@ -21,7 +22,8 @@ func NewDBSchema(
columnTypesQueryFunc ColumnTypesQueryFunc,
) (*DBSchema, error) {
if db == nil {
return nil, errors.New("sql.DB cannot be nil")
err := errs.Nil("db")
return nil, failedToGetDBSchema(err)
}
return &DBSchema{
db: db,
Expand All @@ -30,6 +32,10 @@ func NewDBSchema(
}, nil
}

func failedToGetDBSchema(err error) error {
return errs.FailedTo("get db schema", err)
}

func (s *DBSchema) HasTable(table string) bool {
tables, err := s.fetchTableNames()
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestNewDBSchemaFailsWhenDBIsNil(t *testing.T) {
if !testutils.AssertNilResultAndNotNilError(t, dbSchema, err, "schema.NewDBSchema") {
return
}
testutils.AssertErrorMessage(t, "sql.DB cannot be nil", err, "schema.NewDBSchema")
testutils.AssertErrorMessage(t, "failed to get db schema: db cannot be nil", err, "schema.NewDBSchema")
}

func TestHasColumn(t *testing.T) {
Expand Down
9 changes: 8 additions & 1 deletion schema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"slices"
"strings"

"github.com/ordershift/ormshift/errs"
)

type Table struct {
Expand Down Expand Up @@ -33,9 +35,14 @@ func (t *Table) AddColumns(params ...NewColumnParams) error {
return strings.EqualFold(column.Name(), c.Name())
})
if exists {
return fmt.Errorf("column %q already exists in table %q", column.Name(), t.Name())
return failedToAddColumnInTable(*t, column, errs.AlreadyExists("column"))
}
t.columns = append(t.columns, column)
}
return nil
}

func failedToAddColumnInTable(table Table, column Column, err error) error {
msg := fmt.Sprintf("add column %q in table %q", column.Name(), table.Name())
return errs.FailedTo(msg, err)
}
2 changes: 1 addition & 1 deletion schema/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ func TestAddColumnFailsWhenAlreadyExists(t *testing.T) {
if !testutils.AssertNotNilError(t, err, "Table.AddColumns") {
return
}
testutils.AssertErrorMessage(t, fmt.Sprintf("column %q already exists in table %q", "value", "product_attribute"), err, "Table.AddColumns")
testutils.AssertErrorMessage(t, fmt.Sprintf("failed to add column %q in table %q: column already exists", "value", "product_attribute"), err, "Table.AddColumns")
}