Skip to content

Commit

Permalink
Merge c36c42d into 395f038
Browse files Browse the repository at this point in the history
  • Loading branch information
wildan2711 committed Dec 10, 2018
2 parents 395f038 + c36c42d commit 95611f1
Show file tree
Hide file tree
Showing 5 changed files with 470 additions and 20 deletions.
97 changes: 91 additions & 6 deletions mutate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,24 @@ import (
"github.com/dgraph-io/dgo/protos/api"
)

// UniqueError returns the field and value that failed the unique node check
type UniqueError struct {
Field string
Value interface{}
}

func (u UniqueError) Error() string {
return fmt.Sprintf("%s %v already exists\n", u.Field, u.Value)
}

// MutateOptions specifies options for mutating
type MutateOptions struct {
DisableInject bool
CommitNow bool
}

// Mutate is a shortcut to create mutations from data to be marshalled into JSON
// Mutate is a shortcut to create mutations from data to be marshalled into JSON,
// it will inject the node type from the Struct name converted to snake_case
func Mutate(ctx context.Context, tx *dgo.Txn, data interface{}, options ...MutateOptions) (string, error) {
opt := MutateOptions{}
if len(options) > 0 {
Expand All @@ -39,6 +50,7 @@ func Mutate(ctx context.Context, tx *dgo.Txn, data interface{}, options ...Mutat
return "", err
}

// TODO: handle bulk mutations
uid, ok := assigned.Uids["blank-0"]
if !ok {
// if update, no uid's assigned
Expand All @@ -47,6 +59,62 @@ func Mutate(ctx context.Context, tx *dgo.Txn, data interface{}, options ...Mutat
return uid, nil
}

// Create is similar to Mutate, but checks for fields that must be unique for a certain node type
func Create(ctx context.Context, tx *dgo.Txn, model interface{}, opt ...MutateOptions) (uid string, err error) {
uniqueFields := getAllUniqueFields(model)

for field, value := range uniqueFields {
if exists(ctx, tx, field, value, model) {
return "", UniqueError{field, value}
}
}

return Mutate(ctx, tx, model, opt...)
}

func exists(ctx context.Context, tx *dgo.Txn, field string, value interface{}, model interface{}) bool {
jsonValue, err := json.Marshal(value)
if err != nil {
log.Println("unmarshal", err)
return false
}

filter := fmt.Sprintf(`eq(%s, %s)`, field, jsonValue)
if err := GetByFilter(ctx, tx, filter, model); err != nil {
if err != ErrNodeNotFound {
log.Println("check exist", err)
}
return false
}
return true
}

// getAllUniqueFields gets all values of the fields that has to be unique
func getAllUniqueFields(model interface{}) map[string]interface{} {
v, err := reflectValue(model)
if err != nil {
return nil
}
numFields := v.NumField()

// map all fields that must be unique
uniqueValueMap := make(map[string]interface{})
for i := 0; i < numFields; i++ {
field := v.Field(i)
structField := v.Type().Field(i)

s, err := parseDgraphTag(&structField)
if err != nil {
return nil
}

if s.Unique {
uniqueValueMap[s.Predicate] = field.Interface()
}
}
return uniqueValueMap
}

func marshalAndInjectType(data interface{}, disableInject bool) ([]byte, error) {
jsonData, err := json.Marshal(data)
if err != nil {
Expand All @@ -55,23 +123,40 @@ func marshalAndInjectType(data interface{}, disableInject bool) ([]byte, error)
}

if !disableInject {
nodeType := getNodeType(data)
snakeCase := toSnakeCase(nodeType)
nodeType := GetNodeType(data)

switch jsonData[0] {
case 123: // if JSON object, starts with "{" (123 in ASCII)
result := fmt.Sprintf("{\"%s\":\"\",%s", snakeCase, string(jsonData[1:]))
result := fmt.Sprintf("{\"%s\":\"\",%s", nodeType, string(jsonData[1:]))
return []byte(result), nil
}
}

return jsonData, nil
}

func getNodeType(data interface{}) string {
// GetNodeType gets node type from NodeType() method of Node interface
// if it doesn't implement it, get it from the struct name and convert to snake case
func GetNodeType(data interface{}) string {
// check if data implements node interface
if node, ok := data.(Node); ok {
return node.NodeType()
}
return reflect.TypeOf(data).Elem().Name()
// get node type from struct name and convert to snake case
structName := ""
dataType := reflect.TypeOf(data)

switch dataType.Kind() {
case reflect.Struct:
structName = dataType.Name()
case reflect.Ptr, reflect.Slice:
dataType = dataType.Elem()
switch dataType.Kind() {
case reflect.Struct:
structName = dataType.Name()
case reflect.Ptr, reflect.Slice:
structName = dataType.Elem().Name()
}
}
return toSnakeCase(structName)
}
118 changes: 118 additions & 0 deletions mutate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ type TestCustomNode struct {
Field string `json:"field,omitempty"`
}

type TestUnique struct {
UID string `json:"uid,omitempty"`
Name string `json:"name,omitempty"`
Username string `json:"username,omitempty" dgraph:"index=term unique"`
Email string `json:"email,omitempty" dgraph:"index=term unique"`
No int `json:"no,omitempty" dgraph:"index=int unique"`
}

func (n TestCustomNode) NodeType() string {
return "custom_node_type"
}
Expand Down Expand Up @@ -120,3 +128,113 @@ func TestAddNodeType(t *testing.T) {
t.Errorf("expected %s got %s", expected, jsonData)
}
}

func TestGetNodeType(t *testing.T) {
nodeTypeStruct := GetNodeType(TestNode{})
nodeTypePtr := GetNodeType(&TestNode{})
nodeTypeSlice := GetNodeType([]TestNode{})
nodeTypeSlicePtr := GetNodeType([]*TestNode{})

assert.Equal(t, nodeTypeStruct, "test_node")
assert.Equal(t, nodeTypePtr, "test_node")
assert.Equal(t, nodeTypeSlice, "test_node")
assert.Equal(t, nodeTypeSlicePtr, "test_node")
}

func TestGetAllUniqueFields(t *testing.T) {
testUnique := &TestUnique{
Name: "H3h3",
Username: "wildan",
Email: "wildan2711@gmail.com",
No: 4,
}
uniqueFields := getAllUniqueFields(testUnique)
assert.Len(t, uniqueFields, 3)
}

func TestCreate(t *testing.T) {
testUnique := []TestUnique{
TestUnique{
Name: "H3h3",
Username: "wildan",
Email: "wildan2711@gmail.com",
No: 1,
},
TestUnique{
Name: "PooDiePie",
Username: "wildansyah",
Email: "wildansyah2711@gmail.com",
No: 2,
},
TestUnique{
Name: "Poopsie",
Username: "wildani",
Email: "wildani@gmail.com",
No: 3,
},
}

c := newDgraphClient()
if _, err := CreateSchema(c, &TestUnique{}); err != nil {
t.Error(err)
}
defer dropAll(c)

tx := c.NewTxn()

for _, data := range testUnique {
_, err := Create(context.Background(), tx, &data)
if err != nil {
t.Error(err)
}
}
if err := tx.Commit(context.Background()); err != nil {
t.Error(err)
}

testDuplicate := []TestUnique{
TestUnique{
Name: "H3h3",
Username: "wildanjing",
Email: "wildan2711@gmail.com",
No: 4,
},
TestUnique{
Name: "PooDiePie",
Username: "wildansyah",
Email: "wildanodol2711@gmail.com",
No: 5,
},
TestUnique{
Name: "lalap",
Username: "lalap",
Email: "lalap@gmail.com",
No: 3,
},
}

tx = c.NewTxn()

var duplicates []UniqueError
for _, data := range testDuplicate {
_, err := Create(context.Background(), tx, &data)
if err != nil {
if uniqueError, ok := err.(UniqueError); ok {
duplicates = append(duplicates, uniqueError)
continue
}
t.Error(err)
}
}
if err := tx.Commit(context.Background()); err != nil {
t.Error(err)
}

assert.Len(t, duplicates, 3)
assert.Equal(t, duplicates[0].Field, "email")
assert.Equal(t, duplicates[0].Value, "wildan2711@gmail.com")
assert.Equal(t, duplicates[1].Field, "username")
assert.Equal(t, duplicates[1].Value, "wildansyah")
assert.Equal(t, duplicates[2].Field, "no")
assert.Equal(t, duplicates[2].Value, 3)
}
102 changes: 102 additions & 0 deletions query.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package dgman

import (
"context"
"encoding/json"
"errors"
"fmt"

"github.com/dgraph-io/dgo"
)

var (
ErrNodeNotFound = errors.New("node not found")
)

// GetByUID gets a single node by their UID and returns the value to the passed model struct
func GetByUID(ctx context.Context, tx *dgo.Txn, uid string, model interface{}) error {
query := fmt.Sprintf(`{
data(func: uid(%s)) {
expand(_all_)
}
}`, uid)

resp, err := tx.Query(ctx, query)
if err != nil {
return err
}

return singleResult(resp.Json, model)
}

// GetByFilter gets a single node by using a Dgraph query filter
// and returns the single value to the passed model struct
func GetByFilter(ctx context.Context, tx *dgo.Txn, filter string, model interface{}) error {
nodeType := GetNodeType(model)
query := fmt.Sprintf(`{
data(func: has(%s)) @filter(%s) {
expand(_all_)
}
}`, nodeType, filter)

resp, err := tx.Query(ctx, query)
if err != nil {
return err
}

return singleResult(resp.Json, model)
}

// Find returns multiple nodes that matches the specified Dgraph query filter,
// the passed model must be a slice
func Find(ctx context.Context, tx *dgo.Txn, filter string, model interface{}) error {
nodeType := GetNodeType(model)
query := fmt.Sprintf(`{
data(func: has(%s)) @filter(%s) {
expand(_all_)
}
}`, nodeType, filter)
resp, err := tx.Query(ctx, query)
if err != nil {
return err
}

return multipleResult(resp.Json, model)
}

func singleResult(jsonData []byte, model interface{}) error {
var result struct {
Data []json.RawMessage
}

if err := json.Unmarshal(jsonData, &result); err != nil {
return err
}

if len(result.Data) == 0 {
return ErrNodeNotFound
}

val := result.Data[0]
if err := json.Unmarshal(val, model); err != nil {
return err
}

return nil
}

func multipleResult(jsonData []byte, model interface{}) error {
var result struct {
Data json.RawMessage
}

if err := json.Unmarshal(jsonData, &result); err != nil {
return err
}

if err := json.Unmarshal(result.Data, model); err != nil {
return err
}

return nil
}

0 comments on commit 95611f1

Please sign in to comment.