Skip to content

Commit

Permalink
Add user patch endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
imulab committed Apr 23, 2017
1 parent 86c3f91 commit b194087
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 58 deletions.
4 changes: 4 additions & 0 deletions example/server.go
Expand Up @@ -85,6 +85,7 @@ func main() {
mux.GetFunc("/Users", wrap(web.QueryUserHandler, scim.QueryUser))
mux.PostFunc("/Users/.search", wrap(web.QueryUserHandler, scim.QueryUser))
mux.PutFunc("/Users/:resourceId", wrap(web.ReplaceUserHandler, scim.ReplaceUser))
mux.PatchFunc("/Users/:resourceId", wrap(web.PatchUserHandler, scim.PatchUser))

http.ListenAndServe(":8080", mux)
}
Expand Down Expand Up @@ -148,6 +149,9 @@ func (ss *simpleServer) InternalSchema(id string) *scim.Schema {
func (ss *simpleServer) CorrectCase(subj *scim.Resource, sch *scim.Schema, ctx context.Context) error {
return scim.CorrectCase(subj, sch, ctx)
}
func (ss *simpleServer) ApplyPatch(patch scim.Patch, subj *scim.Resource, sch *scim.Schema, ctx context.Context) error {
return scim.ApplyPatch(patch, subj, sch, ctx)
}
func (ss *simpleServer) ValidateType(subj *scim.Resource, sch *scim.Schema, ctx context.Context) error {
return scim.ValidateType(subj, sch, ctx)
}
Expand Down
16 changes: 16 additions & 0 deletions handlers/shared.go
Expand Up @@ -29,6 +29,9 @@ type ScimServer interface {
// case
CorrectCase(subj *Resource, sch *Schema, ctx context.Context) error

// patch
ApplyPatch(patch Patch, subj *Resource, sch *Schema, ctx context.Context) error

// validation
ValidateType(subj *Resource, sch *Schema, ctx context.Context) error
ValidateRequired(subj *Resource, sch *Schema, ctx context.Context) error
Expand Down Expand Up @@ -233,6 +236,19 @@ func ParseBodyAsResource(req *http.Request) (*Resource, error) {
return &Resource{Complex: Complex(data)}, nil
}

func ParseModification(req *http.Request) (Modification, error) {
m := Modification{}
reqBody, err := ioutil.ReadAll(req.Body)
if err != nil {
return Modification{}, err
}
err = json.Unmarshal(reqBody, &m)
if err != nil {
return Modification{}, err
}
return m, nil
}

func ParseSearchRequest(req *http.Request, server ScimServer) (SearchRequest, error) {
switch req.Method {
case http.MethodGet:
Expand Down
63 changes: 63 additions & 0 deletions handlers/users.go
Expand Up @@ -51,6 +51,69 @@ func CreateUserHandler(r *http.Request, server ScimServer, ctx context.Context)
return
}

func PatchUserHandler(r *http.Request, server ScimServer, ctx context.Context) (ri *ResponseInfo) {
ri = newResponse()
sch := server.InternalSchema(shared.UserUrn)
repo := server.Repository(shared.UserResourceType)

id, version := ParseIdAndVersion(r, server.UrlParam)
ctx = context.WithValue(ctx, shared.ResourceId{}, id)

resource, err := repo.Get(id, version)
ErrorCheck(err)

mod, err := ParseModification(r)
ErrorCheck(err)
err = mod.Validate()
ErrorCheck(err)

for _, patch := range mod.Ops {
err = server.ApplyPatch(patch, resource.(*shared.Resource), sch, ctx)
ErrorCheck(err)
}

reference, err := repo.Get(id, version)
ErrorCheck(err)

err = server.ValidateType(resource.(*shared.Resource), sch, ctx)
ErrorCheck(err)

err = server.CorrectCase(resource.(*shared.Resource), sch, ctx)
ErrorCheck(err)

err = server.ValidateRequired(resource.(*shared.Resource), sch, ctx)
ErrorCheck(err)

err = server.ValidateMutability(resource.(*shared.Resource), reference.(*shared.Resource), sch, ctx)
ErrorCheck(err)

err = server.ValidateUniqueness(resource.(*shared.Resource), sch, repo, ctx)
ErrorCheck(err)

err = server.AssignReadOnlyValue(resource.(*shared.Resource), ctx)
ErrorCheck(err)

err = repo.Update(id, version, resource)
ErrorCheck(err)

json, err := server.MarshalJSON(resource, sch, []string{}, []string{})
ErrorCheck(err)

location := resource.GetData()["meta"].(map[string]interface{})["location"].(string)
newVersion := resource.GetData()["meta"].(map[string]interface{})["version"].(string)

ri.Status(http.StatusOK)
ri.ScimJsonHeader()
if len(newVersion) > 0 {
ri.ETagHeader(newVersion)
}
if len(location) > 0 {
ri.LocationHeader(location)
}
ri.Body(json)
return
}

func ReplaceUserHandler(r *http.Request, server ScimServer, ctx context.Context) (ri *ResponseInfo) {
ri = newResponse()
sch := server.InternalSchema(shared.UserUrn)
Expand Down
1 change: 1 addition & 0 deletions mongo/repository.go
Expand Up @@ -118,6 +118,7 @@ func (r *repository) Get(id, version string) (DataProvider, error) {
return nil, r.handleError(err, id)
}

delete(data, "_id")
return r.construct(Complex(data)), nil
}

Expand Down
81 changes: 48 additions & 33 deletions shared/patch.go
Expand Up @@ -17,7 +17,7 @@ func ApplyPatch(patch Patch, subj *Resource, sch *Schema, ctx context.Context) (
}
}()

ps := patchState{patch:patch}
ps := patchState{patch: patch, sch: sch, ctx: ctx}

var path Path
if len(patch.Path) == 0 {
Expand All @@ -43,20 +43,22 @@ func ApplyPatch(patch Patch, subj *Resource, sch *Schema, ctx context.Context) (

switch patch.Op {
case Add:
ps.applyPatchAdd(path, v, subj, sch, ctx)
ps.applyPatchAdd(path, v, subj)
case Replace:
ps.applyPatchReplace(path, v, subj, sch, ctx)
ps.applyPatchReplace(path, v, subj)
case Remove:
ps.applyPatchRemove(path, subj, sch, ctx)
ps.applyPatchRemove(path, subj)
default:
err = Error.InvalidParam("Op", "one of [add|remove|replace]", patch.Op)
}
return
}

type patchState struct {
patch Patch
destAttr *Attribute
patch Patch
destAttr *Attribute
sch *Schema
ctx context.Context
}

func (ps *patchState) throw(err error, ctx context.Context) {
Expand All @@ -65,21 +67,21 @@ func (ps *patchState) throw(err error, ctx context.Context) {
}
}

func (ps *patchState) applyPatchRemove(p Path, subj *Resource, sch *Schema, ctx context.Context) {
func (ps *patchState) applyPatchRemove(p Path, subj *Resource) {
basePath, lastPath := p.SeparateAtLast()
baseChannel := make(chan interface{}, 1)
if basePath == nil {
go func(){
go func() {
baseChannel <- subj.Complex
close(baseChannel)
}()
} else {
baseChannel = subj.Get(basePath, sch)
baseChannel = subj.Get(basePath, ps.sch)
}

var baseAttr AttributeSource = sch
var baseAttr AttributeSource = ps.sch
if basePath != nil {
baseAttr = sch.GetAttribute(basePath, true)
baseAttr = ps.sch.GetAttribute(basePath, true)
}

for base := range baseChannel {
Expand All @@ -101,9 +103,9 @@ func (ps *patchState) applyPatchRemove(p Path, subj *Resource, sch *Schema, ctx
origVal := baseVal.MapIndex(keyVal)
baseAttr = baseAttr.GetAttribute(lastPath, false)
reverseRoot := &filterNode{
data:Not,
typ:LogicalOperator,
left:lastPath.FilterRoot().(*filterNode),
data: Not,
typ: LogicalOperator,
left: lastPath.FilterRoot().(*filterNode),
}
newElemChannel := MultiValued(origVal.Interface().([]interface{})).Filter(reverseRoot, baseAttr)
newArr := make([]interface{}, 0)
Expand All @@ -130,25 +132,25 @@ func (ps *patchState) applyPatchRemove(p Path, subj *Resource, sch *Schema, ctx
case reflect.Map:
elemVal.SetMapIndex(keyVal, reflect.Value{})
default:
ps.throw(Error.InvalidPath(ps.patch.Path, "array base contains non-map"), ctx)
ps.throw(Error.InvalidPath(ps.patch.Path, "array base contains non-map"), ps.ctx)
}
}
default:
ps.throw(Error.InvalidPath(ps.patch.Path, "base evaluated to non-map and non-array."), ctx)
ps.throw(Error.InvalidPath(ps.patch.Path, "base evaluated to non-map and non-array."), ps.ctx)
}
}
}

func (ps *patchState) applyPatchReplace(p Path, v reflect.Value, subj *Resource, sch *Schema, ctx context.Context) {
func (ps *patchState) applyPatchReplace(p Path, v reflect.Value, subj *Resource) {
basePath, lastPath := p.SeparateAtLast()
baseChannel := make(chan interface{}, 1)
if basePath == nil {
go func(){
go func() {
baseChannel <- subj.Complex
close(baseChannel)
}()
} else {
baseChannel = subj.Get(basePath, sch)
baseChannel = subj.Get(basePath, ps.sch)
}

for base := range baseChannel {
Expand All @@ -163,32 +165,32 @@ func (ps *patchState) applyPatchReplace(p Path, v reflect.Value, subj *Resource,
}
}

func (ps *patchState) applyPatchAdd(p Path, v reflect.Value, subj *Resource, sch *Schema, ctx context.Context) {
func (ps *patchState) applyPatchAdd(p Path, v reflect.Value, subj *Resource) {
if p == nil {
if v.Kind() != reflect.Map {
ps.throw(Error.InvalidParam("value of add op", "to be complex (for implicit path)", "non-complex"), ctx)
ps.throw(Error.InvalidParam("value of add op", "to be complex (for implicit path)", "non-complex"), ps.ctx)
}
for _, k := range v.MapKeys() {
v0 := v.MapIndex(k)
if err := ApplyPatch(Patch{
Op:Add,
Path:k.String(),
Value:v0.Interface(),
}, subj, sch, ctx); err != nil {
ps.throw(err, ctx)
Op: Add,
Path: k.String(),
Value: v0.Interface(),
}, subj, ps.sch, ps.ctx); err != nil {
ps.throw(err, ps.ctx)
}
}
} else {
basePath, lastPath := p.SeparateAtLast()
baseChannel := make(chan interface{}, 1)

if basePath == nil {
go func(){
go func() {
baseChannel <- subj.Complex
close(baseChannel)
}()
} else {
baseChannel = subj.Get(basePath, sch)
baseChannel = subj.Get(basePath, ps.sch)
}

for base := range baseChannel {
Expand All @@ -206,12 +208,26 @@ func (ps *patchState) applyPatchAdd(p Path, v reflect.Value, subj *Resource, sch
if ps.destAttr.MultiValued {
origVal := baseVal.MapIndex(keyVal)
if !origVal.IsValid() {
baseVal.SetMapIndex(keyVal, reflect.ValueOf([]interface{}{v.Interface()}))
switch v.Kind() {
case reflect.Array, reflect.Slice:
baseVal.SetMapIndex(keyVal, v)
default:
baseVal.SetMapIndex(keyVal, reflect.ValueOf([]interface{}{v.Interface()}))
}

} else {
if origVal.Kind() == reflect.Interface {
origVal = origVal.Elem()
}
newArr := MultiValued(origVal.Interface().([]interface{})).Add(v.Interface())
var newArr MultiValued
switch v.Kind() {
case reflect.Array, reflect.Slice:
for i := 0; i < v.Len(); i++ {
newArr = MultiValued(origVal.Interface().([]interface{})).Add(v.Index(i).Interface())
}
default:
newArr = MultiValued(origVal.Interface().([]interface{})).Add(v.Interface())
}
baseVal.SetMapIndex(keyVal, reflect.ValueOf(newArr))
}
} else {
Expand All @@ -227,13 +243,12 @@ func (ps *patchState) applyPatchAdd(p Path, v reflect.Value, subj *Resource, sch
case reflect.Map:
elemVal.SetMapIndex(reflect.ValueOf(lastPath.Base()), v)
default:
ps.throw(Error.InvalidPath(ps.patch.Path, "array base contains non-map"), ctx)
ps.throw(Error.InvalidPath(ps.patch.Path, "array base contains non-map"), ps.ctx)
}
}
default:
ps.throw(Error.InvalidPath(ps.patch.Path, "base evaluated to non-map and non-array."), ctx)
ps.throw(Error.InvalidPath(ps.patch.Path, "base evaluated to non-map and non-array."), ps.ctx)
}
}
}
}

0 comments on commit b194087

Please sign in to comment.