// Copyright 2019-present Facebook Inc. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.
package schema
import (
"context"
"fmt"
"math"
"sort"
"strconv"
"strings"
"github.com/facebookincubator/ent/dialect"
"github.com/facebookincubator/ent/dialect/sql"
"github.com/facebookincubator/ent/schema/field"
)
// MySQL is a mysql migration driver.
type MySQL struct {
dialect.Driver
version string
}
// init loads the MySQL version from the database for later use in the migration process.
func (d *MySQL) init(ctx context.Context, tx dialect.Tx) error {
rows := &sql.Rows{}
if err := tx.Query(ctx, "SHOW VARIABLES LIKE 'version'", []interface{}{}, rows); err != nil {
return fmt.Errorf("mysql: querying mysql version %v", err)
}
defer rows.Close()
if !rows.Next() {
return fmt.Errorf("mysql: version variable was not found")
}
version := make([]string, 2)
if err := rows.Scan(&version[0], &version[1]); err != nil {
return fmt.Errorf("mysql: scanning mysql version: %v", err)
}
d.version = version[1]
return nil
}
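// tableExist reports if a table with the given name exists in the current database schema.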
func (d *MySQL) tableExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
query, args := sql.Select(sql.Count("*")).From(sql.Table("INFORMATION_SCHEMA.TABLES").Unquote()).
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", name)).Query()
return exist(ctx, tx, query, args...)
}
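// fkExist reports if a foreign-key constraint with the given name exists in the current database schema.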
func (d *MySQL) fkExist(ctx context.Context, tx dialect.Tx, name string) (bool, error) {
query, args := sql.Select(sql.Count("*")).From(sql.Table("INFORMATION_SCHEMA.TABLE_CONSTRAINTS").Unquote()).
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("CONSTRAINT_TYPE", "FOREIGN KEY").And().EQ("CONSTRAINT_NAME", name)).Query()
return exist(ctx, tx, query, args...)
}
// table loads the current table description from the database.
func (d *MySQL) table(ctx context.Context, tx dialect.Tx, name string) (*Table, error) {
rows := &sql.Rows{}
query, args := sql.Select("column_name", "column_type", "is_nullable", "column_key", "column_default", "extra", "character_set_name", "collation_name").
From(sql.Table("INFORMATION_SCHEMA.COLUMNS").Unquote()).
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", name)).Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, fmt.Errorf("mysql: reading table description %v", err)
}
// call `Close` in case of failure (`Close` is idempotent).
defer rows.Close()
t := NewTable(name)
for rows.Next() {
c := &Column{}
if err := d.scanColumn(c, rows); err != nil {
return nil, fmt.Errorf("mysql: %v", err)
}
if c.PrimaryKey() {
t.PrimaryKey = append(t.PrimaryKey, c)
}
t.AddColumn(c)
}
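// close the rows explicitly (and check the error) before issuing the indexes query on the same transaction.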
if err := rows.Close(); err != nil {
return nil, fmt.Errorf("mysql: closing rows %v", err)
}
indexes, err := d.indexes(ctx, tx, name)
if err != nil {
return nil, err
}
// add the indexes to the table and link them to their columns.
for _, idx := range indexes {
t.AddIndex(idx.Name, idx.Unique, idx.columns)
}
return t, nil
}
// indexes loads the table indexes from the database.
func (d *MySQL) indexes(ctx context.Context, tx dialect.Tx, name string) ([]*Index, error) {
rows := &sql.Rows{}
query, args := sql.Select("index_name", "column_name", "non_unique", "seq_in_index").
From(sql.Table("INFORMATION_SCHEMA.STATISTICS").Unquote()).
Where(sql.EQ("TABLE_SCHEMA", sql.Raw("(SELECT DATABASE())")).And().EQ("TABLE_NAME", name)).Query()
if err := tx.Query(ctx, query, args, rows); err != nil {
return nil, fmt.Errorf("mysql: reading index description %v", err)
}
defer rows.Close()
idx, err := d.scanIndexes(rows)
if err != nil {
return nil, fmt.Errorf("mysql: %v", err)
}
return idx, nil
}
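// setRange sets the AUTO_INCREMENT value of the given table.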
func (d *MySQL) setRange(ctx context.Context, tx dialect.Tx, name string, value int) error {
return tx.Exec(ctx, fmt.Sprintf("ALTER TABLE `%s` AUTO_INCREMENT = %d", name, value), []interface{}{}, new(sql.Result))
}
// tBuilder returns the MySQL DSL query for table creation.
func (d *MySQL) tBuilder(t *Table) *sql.TableBuilder {
b := sql.CreateTable(t.Name).IfNotExists()
for _, c := range t.Columns {
b.Column(d.addColumn(c))
}
for _, pk := range t.PrimaryKey {
b.PrimaryKey(pk.Name)
}
// set the default charset / collation on the MySQL table.
// individual columns can override them using the Charset / Collate fields.
b.Charset("utf8mb4").Collate("utf8mb4_bin")
return b
}
// cType returns the MySQL string type for the given column.
func (d *MySQL) cType(c *Column) (t string) {
switch c.Type {
case field.TypeBool:
t = "boolean"
case field.TypeInt8:
t = "tinyint"
case field.TypeUint8:
t = "tinyint unsigned"
case field.TypeInt16:
t = "smallint"
case field.TypeUint16:
t = "smallint unsigned"
case field.TypeInt32:
t = "int"
case field.TypeUint32:
t = "int unsigned"
case field.TypeInt, field.TypeInt64:
t = "bigint"
case field.TypeUint, field.TypeUint64:
t = "bigint unsigned"
case field.TypeBytes:
size := int64(math.MaxUint16)
if c.Size > 0 {
size = c.Size
}
switch {
case size <= math.MaxUint8:
t = "tinyblob"
case size <= math.MaxUint16:
t = "blob"
case size < 1<<24:
t = "mediumblob"
case size <= math.MaxUint32:
t = "longblob"
}
case field.TypeJSON:
t = "json"
if compareVersions(d.version, "5.7.8") == -1 {
t = "longblob"
}
case field.TypeString:
size := c.Size
if size == 0 {
size = c.defaultSize(d.version)
}
if size <= math.MaxUint16 {
t = fmt.Sprintf("varchar(%d)", size)
} else {
t = "longtext"
}
case field.TypeFloat32, field.TypeFloat64:
t = "double"
case field.TypeTime:
t = "timestamp"
// in MySQL, timestamp columns are `NOT NULL` by default, and assigning NULL
// assigns the current_timestamp(). We avoid this if not set otherwise.
c.Nullable = true
case field.TypeEnum:
values := make([]string, len(c.Enums))
for i, e := range c.Enums {
values[i] = fmt.Sprintf("'%s'", e)
}
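// sort the values so the generated definition is deterministic.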
sort.Strings(values)
t = fmt.Sprintf("enum(%s)", strings.Join(values, ", "))
case field.TypeUUID:
t = "char(36) binary"
default:
panic(fmt.Sprintf("unsupported type %q for column %q", c.Type.String(), c.Name))
}
return t
}
// addColumn returns the DSL query for adding the given column to a table.
// The syntax/order is: datatype [Charset] [Unique|Increment] [Collation] [Nullable].
func (d *MySQL) addColumn(c *Column) *sql.ColumnBuilder {
b := sql.Column(c.Name).Type(d.cType(c)).Attr(c.Attr)
c.unique(b)
if c.Increment {
b.Attr("AUTO_INCREMENT")
}
c.nullable(b)
c.defaultValue(b)
return b
}
// alterColumn returns the DSL query for modifying the given column.
func (d *MySQL) alterColumn(c *Column) []*sql.ColumnBuilder {
return []*sql.ColumnBuilder{d.addColumn(c)}
}
// addIndex returns the query for adding an index to a MySQL table.
func (d *MySQL) addIndex(i *Index, table string) *sql.IndexBuilder {
return i.Builder(table)
}
// dropIndex returns the query for dropping an index from a MySQL table.
func (d *MySQL) dropIndex(i *Index, table string) *sql.DropIndexBuilder {
return i.DropBuilder(table)
}
// scanColumn scans the column information from MySQL column description.
func (d *MySQL) scanColumn(c *Column, rows *sql.Rows) error {
var (
nullable sql.NullString
defaults sql.NullString
)
if err := rows.Scan(&c.Name, &c.typ, &nullable, &c.Key, &defaults, &c.Attr, &sql.NullString{}, &sql.NullString{}); err != nil {
return fmt.Errorf("scanning column description: %v", err)
}
c.Unique = c.UniqueKey()
if nullable.Valid {
c.Nullable = nullable.String == "YES"
}
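// split the raw column type (e.g. "smallint(5) unsigned" or "enum('a','b')") into its parts.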
switch parts := strings.FieldsFunc(c.typ, func(r rune) bool {
return r == '(' || r == ')' || r == ' ' || r == ','
}); parts[0] {
case "int":
c.Type = field.TypeInt32
case "smallint":
c.Type = field.TypeInt16
if len(parts) == 3 { // smallint(5) unsigned.
c.Type = field.TypeUint16
}
case "bigint":
c.Type = field.TypeInt64
if len(parts) == 3 { // bigint(20) unsigned.
c.Type = field.TypeUint64
}
case "tinyint":
size, err := strconv.Atoi(parts[1])
if err != nil {
return fmt.Errorf("converting varchar size to int: %v", err)
}
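// tinyint(1) is the common MySQL convention for boolean columns.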
switch {
case size == 1:
c.Type = field.TypeBool
case len(parts) == 3: // tinyint(3) unsigned.
c.Type = field.TypeUint8
default:
c.Type = field.TypeInt8
}
case "double":
c.Type = field.TypeFloat64
case "timestamp", "datetime":
c.Type = field.TypeTime
case "tinyblob":
c.Size = math.MaxUint8
c.Type = field.TypeBytes
case "blob":
c.Size = math.MaxUint16
c.Type = field.TypeBytes
case "mediumblob":
c.Size = 1<<24 - 1
c.Type = field.TypeBytes
case "longblob":
c.Size = math.MaxUint32
c.Type = field.TypeBytes
case "varbinary":
c.Type = field.TypeBytes
size, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return fmt.Errorf("converting varbinary size to int: %v", err)
}
c.Size = size
case "varchar":
c.Type = field.TypeString
size, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return fmt.Errorf("converting varchar size to int: %v", err)
}
c.Size = size
case "longtext":
c.Size = math.MaxInt32
c.Type = field.TypeString
case "json":
c.Type = field.TypeJSON
case "enum":
c.Type = field.TypeEnum
c.Enums = make([]string, len(parts)-1)
for i, e := range parts[1:] {
c.Enums[i] = strings.Trim(e, "'")
}
case "char":
size, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
return fmt.Errorf("converting char size to int: %v", err)
}
// a UUID field has a length of 36 characters (32 hexadecimal characters and 4 hyphens).
if size != 36 {
return fmt.Errorf("unknown char(%d) type (not a uuid)", size)
}
c.Type = field.TypeUUID
default:
return fmt.Errorf("unknown column type %q for version %q", parts[0], d.version)
}
if defaults.Valid {
return c.ScanDefault(defaults.String)
}
return nil
}
// scanIndexes scans sql.Rows into an Indexes list. The query that produces the rows
// should return the following 4 columns: INDEX_NAME, COLUMN_NAME, NON_UNIQUE, SEQ_IN_INDEX.
// SEQ_IN_INDEX specifies the position of the column in the index columns.
func (d *MySQL) scanIndexes(rows *sql.Rows) (Indexes, error) {
var (
i Indexes
names = make(map[string]*Index)
)
for rows.Next() {
var (
name string
column string
nonuniq bool
seqindex int
)
if err := rows.Scan(&name, &column, &nonuniq, &seqindex); err != nil {
return nil, fmt.Errorf("scanning index description: %v", err)
}
// ignore primary keys.
if name == "PRIMARY" {
continue
}
idx, ok := names[name]
if !ok {
idx = &Index{Name: name, Unique: !nonuniq}
i = append(i, idx)
names[name] = idx
}
idx.columns = append(idx.columns, column)
}
return i, nil
}