Skip to content

Commit

Permalink
refactor(pkg): add the new sqlite driver
Browse files Browse the repository at this point in the history
  • Loading branch information
danvergara committed May 3, 2023
1 parent 69c72ab commit c5813c8
Show file tree
Hide file tree
Showing 14 changed files with 96 additions and 81 deletions.
4 changes: 2 additions & 2 deletions cmd/seeder/main.go
Expand Up @@ -15,8 +15,8 @@ import (
// postgres driver.
_ "github.com/lib/pq"

// sqlite3 driver.
_ "github.com/mattn/go-sqlite3"
// sqlite driver.
_ "modernc.org/sqlite"
)

func main() {
Expand Down
7 changes: 4 additions & 3 deletions db/seeds/customer.go
Expand Up @@ -4,6 +4,7 @@ import (
"log"

"github.com/bxcodec/faker/v3"
"github.com/danvergara/dblab/pkg/drivers"
)

// CustomerSeed seeds the database with customers.
Expand All @@ -13,11 +14,11 @@ func (s Seed) CustomerSeed() {

// execute query.
switch s.driver {
case "postgres":
case drivers.POSTGRES:
_, err = s.db.Exec(`INSERT INTO customers(name, email) VALUES ($1, $2)`, faker.Name(), faker.Email())
case "mysql":
case drivers.MYSQL:
_, err = s.db.Exec(`INSERT INTO customers(name, email) VALUES (?, ?)`, faker.Name(), faker.Email())
case "sqlite3":
case drivers.SQLITE:
_, err = s.db.Exec(`INSERT INTO customers(name, email) VALUES (?, ?)`, faker.Name(), faker.Email())
default:
log.Println("unsupported driver")
Expand Down
7 changes: 4 additions & 3 deletions db/seeds/product.go
Expand Up @@ -5,6 +5,7 @@ import (
"math/rand"

"github.com/bxcodec/faker/v3"
"github.com/danvergara/dblab/pkg/drivers"
)

// ProductSeed seeds product data.
Expand All @@ -14,11 +15,11 @@ func (s Seed) ProductSeed() {

// execute query.
switch s.driver {
case "postgres":
case drivers.POSTGRES:
_, err = s.db.Exec(`INSERT INTO products(name, price) VALUES ($1, $2)`, faker.Word(), rand.Float32())
case "mysql":
case drivers.MYSQL:
_, err = s.db.Exec(`INSERT INTO products(name, price) VALUES (?, ?)`, faker.Word(), rand.Float32())
case "sqlite3":
case drivers.SQLITE:
_, err = s.db.Exec(`INSERT INTO products(name, price) VALUES (?, ?)`, faker.Word(), rand.Float32())
default:
log.Println("unsupported driver")
Expand Down
4 changes: 2 additions & 2 deletions db/seeds/seeder.go
Expand Up @@ -11,8 +11,8 @@ import (
// postgres driver.
_ "github.com/lib/pq"

// sqlite3 driver.
_ "github.com/mattn/go-sqlite3"
// sqlite driver.
_ "modernc.org/sqlite"
)

// Seed type.
Expand Down
7 changes: 4 additions & 3 deletions db/seeds/user.go
Expand Up @@ -4,6 +4,7 @@ import (
"log"

"github.com/bxcodec/faker/v3"
"github.com/danvergara/dblab/pkg/drivers"
)

// UserSeed seeds the database with users.
Expand All @@ -13,11 +14,11 @@ func (s Seed) UserSeed() {

// execute query.
switch s.driver {
case "postgres":
case drivers.POSTGRES:
_, err = s.db.Exec(`INSERT INTO users(username) VALUES ($1)`, faker.Name())
case "mysql":
case drivers.MYSQL:
_, err = s.db.Exec(`INSERT INTO users(username) VALUES (?)`, faker.Name())
case "sqlite3":
case drivers.SQLITE:
_, err = s.db.Exec(`INSERT INTO users(username) VALUES (?)`, faker.Name())
default:
log.Println("unsupported driver")
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Expand Up @@ -71,7 +71,7 @@ services:
- ./:/src/app:z
environment:
- DB_NAME=db/dblab.db
- DB_DRIVER=sqlite3
- DB_DRIVER=sqlite
entrypoint: ["/bin/bash", "./scripts/entrypoint-sqlite3.dev.sh"]
networks:
- dblab
Expand Down
39 changes: 20 additions & 19 deletions pkg/client/client.go
Expand Up @@ -7,15 +7,16 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/danvergara/dblab/pkg/command"
"github.com/danvergara/dblab/pkg/connection"
"github.com/danvergara/dblab/pkg/drivers"
"github.com/danvergara/dblab/pkg/pagination"
"github.com/jmoiron/sqlx"

// mysql driver.
_ "github.com/go-sql-driver/mysql"
// postgres driver.
_ "github.com/lib/pq"
// sqlite3 driver.
_ "github.com/mattn/go-sqlite3"
// sqlite driver.
_ "modernc.org/sqlite"
)

// Client is used to store the pool of db connection.
Expand Down Expand Up @@ -51,9 +52,9 @@ func New(opts command.Options) (*Client, error) {
}

switch c.driver {
case "postgres":
case drivers.POSTGRES:
fallthrough
case "postgresql":
case drivers.POSTGRESQL:
if _, err = db.Exec(fmt.Sprintf("set search_path='%s'", c.schema)); err != nil {
return nil, err
}
Expand Down Expand Up @@ -209,9 +210,9 @@ func (c *Client) ShowTables() ([]string, error) {
tables := make([]string, 0)

switch c.driver {
case "postgres":
case drivers.POSTGRES:
fallthrough
case "postgresql":
case drivers.POSTGRESQL:
psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
query, args, err = psql.Select("table_name").
From("information_schema.tables").
Expand All @@ -222,9 +223,9 @@ func (c *Client) ShowTables() ([]string, error) {
return nil, err
}

case "mysql":
case drivers.MYSQL:
query = "SHOW TABLES;"
case "sqlite3":
case drivers.SQLITE:
query = `
SELECT
name
Expand Down Expand Up @@ -355,9 +356,9 @@ func (c *Client) tableStructure(tableName string) ([][]string, []string, error)
var query string

switch c.driver {
case "postgres":
case drivers.POSTGRES:
fallthrough
case "postgresql":
case drivers.POSTGRESQL:
psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)

query, args, err := psql.Select(
Expand Down Expand Up @@ -394,10 +395,10 @@ func (c *Client) tableStructure(tableName string) ([][]string, []string, error)
}

return c.Query(query, args...)
case "mysql":
case drivers.MYSQL:
query = fmt.Sprintf("DESCRIBE %s;", tableName)
return c.Query(query)
case "sqlite3":
case drivers.SQLITE:
query = fmt.Sprintf("PRAGMA table_info(%s);", tableName)
return c.Query(query)
default:
Expand All @@ -421,17 +422,17 @@ func (c *Client) constraints(tableName string) ([][]string, []string, error) {
Where("tc.table_name = ?")

switch c.driver {
case "sqlite3":
case drivers.SQLITE:
sql = `
SELECT *
FROM
sqlite_master
WHERE
type='table' AND name = ?;`
return c.Query(sql, tableName)
case "postgres":
case drivers.POSTGRES:
fallthrough
case "postgresql":
case drivers.POSTGRESQL:
query = query.Where(fmt.Sprintf("tc.table_schema = '%s'", c.schema))
query = query.PlaceholderFormat(sq.Dollar)
}
Expand All @@ -449,15 +450,15 @@ func (c *Client) indexes(tableName string) ([][]string, []string, error) {
var query string

switch c.driver {
case "postgres":
case drivers.POSTGRES:
fallthrough
case "postgresql":
case drivers.POSTGRESQL:
query = "SELECT * FROM pg_indexes WHERE tablename = $1;"
return c.Query(query, tableName)
case "mysql":
case drivers.MYSQL:
query = fmt.Sprintf("SHOW INDEX FROM %s", tableName)
return c.Query(query)
case "sqlite3":
case drivers.SQLITE:
query = `PRAGMA index_list(%s);`
query = fmt.Sprintf(query, tableName)
return c.Query(query)
Expand Down
12 changes: 7 additions & 5 deletions pkg/client/client_test.go
Expand Up @@ -9,11 +9,13 @@ import (
_ "github.com/go-sql-driver/mysql"
// postgres driver.
_ "github.com/lib/pq"
// sqlite3 driver.
_ "github.com/mattn/go-sqlite3"
// sqlite driver.
_ "modernc.org/sqlite"

"github.com/stretchr/testify/assert"

"github.com/danvergara/dblab/pkg/command"
"github.com/danvergara/dblab/pkg/drivers"
)

var (
Expand All @@ -40,11 +42,11 @@ func TestMain(m *testing.M) {

func generateURL(driver string) string {
switch driver {
case "postgres":
case drivers.POSTGRES:
return fmt.Sprintf("%s://%s:%s@%s:%s/%s?sslmode=disable", driver, user, password, host, port, name)
case "mysql":
case drivers.MYSQL:
return fmt.Sprintf("%s://%s:%s@tcp(%s:%s)/%s", driver, user, password, host, port, name)
case "sqlite3":
case drivers.SQLITE:
return name
default:
return ""
Expand Down
17 changes: 9 additions & 8 deletions pkg/config/config.go
Expand Up @@ -8,6 +8,7 @@ import (
"os"

"github.com/danvergara/dblab/pkg/command"
"github.com/danvergara/dblab/pkg/drivers"
"github.com/kkyr/fig"
"github.com/spf13/cobra"

Expand All @@ -18,7 +19,7 @@ import (
// drivers.
_ "github.com/golang-migrate/migrate/v4/database/mysql"
_ "github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/mattn/go-sqlite3"
_ "modernc.org/sqlite"
)

// Config struct is used to store the db connection data.
Expand Down Expand Up @@ -122,7 +123,7 @@ func (c *Config) MigrateInstance() (*migrate.Migrate, error) {
}

switch c.Driver {
case "sqlite3":
case drivers.SQLITE:
dbDriver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil {
fmt.Printf("instance error: %v \n", err)
Expand Down Expand Up @@ -181,11 +182,11 @@ func (c *Config) GetSQLXDBConnStr() string {
// getDBConnStr returns the connection string based on the provied host and db name.
func (c *Config) getDBConnStr(dbhost, dbname string) string {
switch c.Driver {
case "postgres":
case drivers.POSTGRES:
return fmt.Sprintf("%s://%s:%s@%s:%s/%s?sslmode=disable", c.Driver, c.User, c.Pswd, dbhost, c.Port, dbname)
case "mysql":
case drivers.MYSQL:
return fmt.Sprintf("%s://%s:%s@tcp(%s:%s)/%s", c.Driver, c.User, c.Pswd, dbhost, c.Port, dbname)
case "sqlite3":
case drivers.SQLITE:
return c.DBName
default:
return ""
Expand All @@ -195,11 +196,11 @@ func (c *Config) getDBConnStr(dbhost, dbname string) string {
// getSQLXConnStr returns the connection string based on the provied host and db name.
func (c *Config) getSQLXConnStr(dbhost, dbname string) string {
switch c.Driver {
case "postgres":
case drivers.POSTGRES:
return fmt.Sprintf("%s://%s:%s@%s:%s/%s?sslmode=disable", c.Driver, c.User, c.Pswd, dbhost, c.Port, dbname)
case "mysql":
case drivers.MYSQL:
return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", c.User, c.Pswd, dbhost, c.Port, dbname)
case "sqlite3":
case drivers.SQLITE:
return c.DBName
default:
return ""
Expand Down
19 changes: 10 additions & 9 deletions pkg/connection/connection.go
Expand Up @@ -10,6 +10,7 @@ import (
"strings"

"github.com/danvergara/dblab/pkg/command"
"github.com/danvergara/dblab/pkg/drivers"
)

var (
Expand All @@ -27,7 +28,7 @@ var (
// ErrInvalidDriver is used to notify that the provided driver is not supported.
ErrInvalidDriver = errors.New("invalid driver")
// ErrInvalidSqlite3Extension is used to notify that the selected file is not a sqlite3 file.
ErrInvalidSqlite3Extension = errors.New("invalid sqlite3 file extension")
ErrInvalidSqlite3Extension = errors.New("invalid sqlite file extension")
// ErrSocketFileDoNotExist indicates that the given path to the socket files leads to no file.
ErrSocketFileDoNotExist = errors.New("socket file does not exist")
// ErrInvalidSocketFile indicates that the socket file must end with .sock as suffix.
Expand All @@ -45,21 +46,21 @@ func init() {
// BuildConnectionFromOpts return the connection uri string given the options passed by the uses.
func BuildConnectionFromOpts(opts command.Options) (string, command.Options, error) {
if opts.URL != "" {
if strings.HasPrefix(opts.URL, "postgres") {
opts.Driver = "postgres"
if strings.HasPrefix(opts.URL, drivers.POSTGRES) {
opts.Driver = drivers.POSTGRES

conn, err := formatPostgresURL(opts)

return conn, opts, err
}

if strings.HasPrefix(opts.URL, "mysql") {
opts.Driver = "mysql"
if strings.HasPrefix(opts.URL, drivers.MYSQL) {
opts.Driver = drivers.MYSQL
conn, err := formatMySQLURL(opts)
return conn, opts, err
}

// this options is for sqlite3.
// this options is for sqlite.
// For more information see https://github.com/mattn/go-sqlite3#connection-string.
if strings.HasPrefix(opts.URL, "file:") {
return opts.URL, opts, nil
Expand All @@ -76,7 +77,7 @@ func BuildConnectionFromOpts(opts command.Options) (string, command.Options, err
}

switch opts.Driver {
case "postgres":
case drivers.POSTGRES:
query := url.Values{}
if opts.SSL != "" {
query.Add("sslmode", opts.SSL)
Expand All @@ -95,7 +96,7 @@ func BuildConnectionFromOpts(opts command.Options) (string, command.Options, err
}

return connDB.String(), opts, nil
case "mysql":
case drivers.MYSQL:
if opts.Socket != "" {
if !validSocketFile(opts.Socket) {
return "", opts, ErrInvalidSocketFile
Expand All @@ -109,7 +110,7 @@ func BuildConnectionFromOpts(opts command.Options) (string, command.Options, err
}

return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", opts.User, opts.Pass, opts.Host, opts.Port, opts.DBName), opts, nil
case "sqlite3":
case drivers.SQLITE:
if hasValidSqlite3FileExtension(opts.DBName) {
return opts.DBName, opts, nil
}
Expand Down

0 comments on commit c5813c8

Please sign in to comment.