Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client refactor #183

Merged
merged 6 commits into from Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
182 changes: 50 additions & 132 deletions pkg/client/client.go
@@ -1,27 +1,35 @@
package client

import (
"errors"
"fmt"

sq "github.com/Masterminds/squirrel"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
_ "modernc.org/sqlite"

"github.com/danvergara/dblab/pkg/command"
"github.com/danvergara/dblab/pkg/connection"
"github.com/danvergara/dblab/pkg/drivers"
"github.com/danvergara/dblab/pkg/pagination"
"github.com/jmoiron/sqlx"

// mysql driver.
_ "github.com/go-sql-driver/mysql"
// postgres driver.
_ "github.com/lib/pq"
// sqlite driver.
_ "modernc.org/sqlite"
)

// databaseQuerier is an interface that indicates the methods
// a given type has to implement to interact with a database,
// to get specific data.
// This allows us to decouple the client from the database implementation and
// make adding new databases easier.
type databaseQuerier interface {
ShowTables() (string, []interface{}, error)
TableStructure(tableName string) (string, []interface{}, error)
Constraints(tableName string) (string, []interface{}, error)
Indexes(tableName string) (string, []interface{}, error)
}

// Client is used to store the pool of db connection.
type Client struct {
db *sqlx.DB
databaseQuerier databaseQuerier
driver, schema string
paginationManager *pagination.Manager
limit uint
Expand Down Expand Up @@ -51,6 +59,18 @@ func New(opts command.Options) (*Client, error) {
c.schema = opts.Schema
}

// This is where an implementation of databaseQuerier is getting picked up.
switch c.driver {
case drivers.Postgres, drivers.PostgreSQL:
c.databaseQuerier = newPostgres(c.schema)
case drivers.MySQL:
c.databaseQuerier = newMySQL()
case drivers.SQLite:
c.databaseQuerier = newSQLite()
default:
return nil, fmt.Errorf("%s driver not supported", c.driver)
}

switch c.driver {
case drivers.Postgres:
fallthrough
Expand Down Expand Up @@ -209,31 +229,9 @@ func (c *Client) ShowTables() ([]string, error) {

tables := make([]string, 0)

switch c.driver {
case drivers.Postgres:
fallthrough
case drivers.PostgreSQL:
psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
query, args, err = psql.Select("table_name").
From("information_schema.tables").
Where(sq.Eq{"table_schema": c.schema}).
OrderBy("table_name").
ToSql()
if err != nil {
return nil, err
}

case drivers.MySQL:
query = "SHOW TABLES;"
case drivers.SQLite:
query = `
SELECT
name
FROM
sqlite_schema
WHERE
type ='table' AND
name NOT LIKE 'sqlite_%';`
query, args, err = c.databaseQuerier.ShowTables()
if err != nil {
return nil, err
}

rows, err := c.db.Queryx(query, args...)
Expand Down Expand Up @@ -353,116 +351,36 @@ func (c *Client) tableCount(tableName string) (int, error) {

// tableStructure returns the structure of the table columns.
func (c *Client) tableStructure(tableName string) ([][]string, []string, error) {
var query string

switch c.driver {
case drivers.Postgres:
fallthrough
case drivers.PostgreSQL:
psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)

query, args, err := psql.Select(
"c.column_name",
"c.is_nullable",
"c.data_type",
"c.character_maximum_length",
"c.numeric_precision",
"c.numeric_scale",
"c.ordinal_position",
"tc.constraint_type AS pkey",
).
From("information_schema.columns AS c").
LeftJoin(
`information_schema.constraint_column_usage AS ccu
ON c.table_schema = ccu.table_schema
AND c.table_name = ccu.table_name
AND c.column_name = ccu.column_name`,
).
LeftJoin(
`information_schema.table_constraints AS tc
ON ccu.constraint_schema = tc.constraint_schema
AND ccu.constraint_name = tc.constraint_name`,
).
Where(
sq.And{
sq.Eq{"c.table_schema": c.schema},
sq.Eq{"c.table_name": tableName},
},
).
ToSql()
if err != nil {
return nil, nil, err
}
var (
query string
err error
args []interface{}
)

return c.Query(query, args...)
case drivers.MySQL:
query = fmt.Sprintf("DESCRIBE %s;", tableName)
return c.Query(query)
case drivers.SQLite:
query = fmt.Sprintf("PRAGMA table_info(%s);", tableName)
return c.Query(query)
default:
return nil, nil, errors.New("not supported driver")
query, args, err = c.databaseQuerier.TableStructure(tableName)
if err != nil {
return nil, nil, err
}

return c.Query(query, args...)
}

// constraints returns the resultet of from information_schema.table_constraints.
func (c *Client) constraints(tableName string) ([][]string, []string, error) {
var (
query sq.SelectBuilder
sql string
)

query = sq.Select(
`tc.constraint_name`,
`tc.table_name`,
`tc.constraint_type`,
).
From("information_schema.table_constraints AS tc").
Where("tc.table_name = ?")

switch c.driver {
case drivers.SQLite:
sql = `
SELECT *
FROM
sqlite_master
WHERE
type='table' AND name = ?;`
return c.Query(sql, tableName)
case drivers.Postgres:
fallthrough
case drivers.PostgreSQL:
query = query.Where(fmt.Sprintf("tc.table_schema = '%s'", c.schema))
query = query.PlaceholderFormat(sq.Dollar)
}

sql, _, err := query.ToSql()
sql, args, err := c.databaseQuerier.Constraints(tableName)
if err != nil {
return nil, nil, err
}

return c.Query(sql, tableName)
return c.Query(sql, args...)
}

// indexes returns a resulset with the information of the indexes given a table name.
func (c *Client) indexes(tableName string) ([][]string, []string, error) {
var query string

switch c.driver {
case drivers.Postgres:
fallthrough
case drivers.PostgreSQL:
query = "SELECT * FROM pg_indexes WHERE tablename = $1;"
return c.Query(query, tableName)
case drivers.MySQL:
query = fmt.Sprintf("SHOW INDEX FROM %s", tableName)
return c.Query(query)
case drivers.SQLite:
query = `PRAGMA index_list(%s);`
query = fmt.Sprintf(query, tableName)
return c.Query(query)
default:
return nil, nil, errors.New("not supported driver")
query, args, err := c.databaseQuerier.Indexes(tableName)
if err != nil {
return nil, nil, err
}

return c.Query(query, args...)
}
7 changes: 2 additions & 5 deletions pkg/client/client_test.go
Expand Up @@ -5,14 +5,10 @@ import (
"os"
"testing"

// mysql driver.
_ "github.com/go-sql-driver/mysql"
// postgres driver.
_ "github.com/lib/pq"
// sqlite driver.
_ "modernc.org/sqlite"

"github.com/stretchr/testify/assert"
_ "modernc.org/sqlite"

"github.com/danvergara/dblab/pkg/command"
"github.com/danvergara/dblab/pkg/drivers"
Expand Down Expand Up @@ -334,6 +330,7 @@ func TestMetadata(t *testing.T) {
m, err := c.Metadata("products")

assert.NoError(t, err)
assert.NotNil(t, m)

// Total count.
assert.Equal(t, m.TotalPages, 1)
Expand Down
56 changes: 56 additions & 0 deletions pkg/client/mysql.go
@@ -0,0 +1,56 @@
package client

import (
"fmt"

sq "github.com/Masterminds/squirrel"
)

// mysql struct is in charge of perform all the mysql related queries,
// without the client knowing.
type mysql struct{}

// a validation to see if mysql is implementing databaseQuerier.
var _ databaseQuerier = (*mysql)(nil)

// returns a pointer to a mysql.
func newMySQL() *mysql {
m := mysql{}
return &m
}

// ShowTables returns a query to retrieve all the tables.
func (m *mysql) ShowTables() (string, []interface{}, error) {
query := "SHOW TABLES;"
return query, nil, nil
}

// TableStructure returns a query string to retrieve all the relevant information of a given table.
func (m *mysql) TableStructure(tableName string) (string, []interface{}, error) {
query := fmt.Sprintf("DESCRIBE %s;", tableName)
return query, nil, nil
}

// Constraints returns all the constraints of a given table.
func (m *mysql) Constraints(tableName string) (string, []interface{}, error) {
query := sq.Select(
`tc.constraint_name`,
`tc.table_name`,
`tc.constraint_type`,
).
From("information_schema.table_constraints AS tc").
Where("tc.table_name = ?", tableName)

sql, args, err := query.ToSql()
if err != nil {
return "", nil, err
}

return sql, args, err
}

// Indexes returns a query to get all the indexes of a table.
func (m *mysql) Indexes(tableName string) (string, []interface{}, error) {
query := fmt.Sprintf("SHOW INDEX FROM %s", tableName)
return query, nil, nil
}