Skip to content

Commit

Permalink
ent: initial support for edge schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m committed May 24, 2022
1 parent c3c8773 commit 363542a
Show file tree
Hide file tree
Showing 297 changed files with 25,772 additions and 1,228 deletions.
32 changes: 25 additions & 7 deletions dialect/sql/sqlgraph/graph.go
Expand Up @@ -302,6 +302,9 @@ type (
EdgeTarget struct {
Nodes []driver.Value
IDSpec *FieldSpec
// Additional fields can be set on the
// edge join table. Valid for M2M edges.
Fields []*FieldSpec
}

// EdgeSpec holds the information for updating a field
Expand Down Expand Up @@ -590,7 +593,7 @@ func (q *query) count(ctx context.Context, drv dialect.Driver) (int, error) {
// If no columns were selected in count,
// the default selection is by node ids.
columns := q.Node.Columns
if len(columns) == 0 {
if len(columns) == 0 && q.Node.ID != nil {
columns = append(columns, q.Node.ID.Column)
}
for i, c := range columns {
Expand Down Expand Up @@ -873,6 +876,12 @@ func (c *creator) node(ctx context.Context, drv dialect.Driver) error {
return err
}
if err := func() error {
// In case the spec does not contain an ID field, we assume
// we interact with an edge-schema with composite primary key.
if c.ID == nil {
query, args := insert.Query()
return c.tx.Exec(ctx, query, args, nil)
}
if err := c.insert(ctx, insert); err != nil {
return err
}
Expand Down Expand Up @@ -907,7 +916,7 @@ func (c *creator) setTableColumns(insert *sql.InsertBuilder, edges map[Rel][]*Ed
return err
}

// insert inserts the node to its table and sets its ID if it was not provided by the user.
// insert a node to its table and sets its ID if it was not provided by the user.
func (c *creator) insert(ctx context.Context, insert *sql.InsertBuilder) error {
if opts := c.CreateSpec.OnConflict; len(opts) > 0 {
insert.OnConflict(opts...)
Expand All @@ -916,7 +925,7 @@ func (c *creator) insert(ctx context.Context, insert *sql.InsertBuilder) error {
// If the id field was provided by the user.
if c.ID.Value != nil {
insert.Set(c.ID.Column, c.ID.Value)
// In case of "ON CONFLICT", the record may exists in the
// In case of "ON CONFLICT", the record may exist in the
// database, and we need to get back the database id field.
if len(c.CreateSpec.OnConflict) == 0 {
query, args := insert.Query()
Expand Down Expand Up @@ -1128,8 +1137,17 @@ func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeS
// The EdgeSpec is the same for all members in a group.
tables := edges.GroupTable()
for _, table := range edgeKeys(tables) {
edges := tables[table]
insert := g.builder.Insert(table).Columns(edges[0].Columns...)
var (
edges = tables[table]
columns = edges[0].Columns
values = make([]interface{}, 0, len(edges[0].Target.Fields))
)
// Specs are generated equally for all edges from the same type.
for _, f := range edges[0].Target.Fields {
values = append(values, f.Value)
columns = append(columns, f.Column)
}
insert := g.builder.Insert(table).Columns(columns...)
if edges[0].Schema != "" {
// If the Schema field was provided to the EdgeSpec (by the
// generated code), it should be the same for all EdgeSpecs.
Expand All @@ -1141,9 +1159,9 @@ func (g *graph) addM2MEdges(ctx context.Context, ids []driver.Value, edges EdgeS
pk1, pk2 = pk2, pk1
}
for _, pair := range product(pk1, pk2) {
insert.Values(pair[0], pair[1])
insert.Values(append([]interface{}{pair[0], pair[1]}, values...)...)
if edge.Bidi {
insert.Values(pair[1], pair[0])
insert.Values(append([]interface{}{pair[1], pair[0]}, values...)...)
}
}
}
Expand Down
46 changes: 46 additions & 0 deletions dialect/sql/sqlgraph/graph_test.go
Expand Up @@ -1135,6 +1135,29 @@ func TestCreateNode(t *testing.T) {
m.ExpectCommit()
},
},
{
name: "edges/m2m/fields",
spec: &CreateSpec{
Table: "groups",
ID: &FieldSpec{Column: "id", Type: field.TypeInt},
Fields: []*FieldSpec{
{Column: "name", Type: field.TypeString, Value: "GitHub"},
},
Edges: []*EdgeSpec{
{Rel: M2M, Table: "group_users", Columns: []string{"group_id", "user_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}, Fields: []*FieldSpec{{Column: "ts", Type: field.TypeInt, Value: 3}}}},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `groups` (`name`) VALUES (?)")).
WithArgs("GitHub").
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectExec(escape("INSERT INTO `group_users` (`group_id`, `user_id`, `ts`) VALUES (?, ?, ?)")).
WithArgs(1, 2, 3).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
{
name: "edges/m2m/inverse",
spec: &CreateSpec{
Expand Down Expand Up @@ -1181,6 +1204,29 @@ func TestCreateNode(t *testing.T) {
m.ExpectCommit()
},
},
{
name: "edges/m2m/bidi/fields",
spec: &CreateSpec{
Table: "users",
ID: &FieldSpec{Column: "id", Type: field.TypeInt},
Fields: []*FieldSpec{
{Column: "name", Type: field.TypeString, Value: "mashraki"},
},
Edges: []*EdgeSpec{
{Rel: M2M, Bidi: true, Table: "user_friends", Columns: []string{"user_id", "friend_id"}, Target: &EdgeTarget{Nodes: []driver.Value{2}, IDSpec: &FieldSpec{Column: "id"}, Fields: []*FieldSpec{{Column: "ts", Type: field.TypeInt, Value: 3}}}},
},
},
expect: func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectExec(escape("INSERT INTO `users` (`name`) VALUES (?)")).
WithArgs("mashraki").
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectExec(escape("INSERT INTO `user_friends` (`user_id`, `friend_id`, `ts`) VALUES (?, ?, ?), (?, ?, ?)")).
WithArgs(1, 2, 3, 2, 1, 3).
WillReturnResult(sqlmock.NewResult(1, 1))
m.ExpectCommit()
},
},
{
name: "edges/m2m/bidi/batch",
spec: &CreateSpec{
Expand Down

0 comments on commit 363542a

Please sign in to comment.