Skip to content

Commit

Permalink
feat: scim list provider users
Browse files Browse the repository at this point in the history
- add SCIM list users endpoint
  • Loading branch information
BruceMacD committed Oct 11, 2022
1 parent d924dbc commit af0b4d7
Show file tree
Hide file tree
Showing 13 changed files with 557 additions and 8 deletions.
67 changes: 67 additions & 0 deletions api/scim.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package api

import "github.com/infrahq/infra/internal/validate"

type SCIMUserName struct {
GivenName string `json:"givenName"`
FamilyName string `json:"familyName"`
}

type SCIMUserEmail struct {
Primary bool `json:"primary"`
Value string `json:"value"`
}

func (r SCIMUserEmail) ValidationRules() []validate.ValidationRule {
return []validate.ValidationRule{
validate.Required("value", r.Value),
validate.Email("value", r.Value),
}
}

const UserSchema = "urn:ietf:params:scim:schemas:core:2.0:User"

type SCIMMetadata struct {
ResourceType string `json:"resourceType"`
}

type SCIMUser struct {
Schemas []string `json:"schemas"`
ID string `json:"id"`
UserName string `json:"userName"`
Name SCIMUserName `json:"name"`
Emails []SCIMUserEmail `json:"emails"`
Active bool `json:"active"`
Meta SCIMMetadata `json:"meta"`
}

type SCIMParametersRequest struct {
// these pagination parameters must conform to the SCIM spec, rather than our standard pagination
StartIndex int `form:"startIndex"`
Count int `form:"count"`
}

func (r SCIMParametersRequest) ValidationRules() []validate.ValidationRule {
return []validate.ValidationRule{
validate.IntRule{
Name: "startIndex",
Value: r.StartIndex,
Min: validate.Int(0),
},
validate.IntRule{
Name: "count",
Value: r.Count,
Min: validate.Int(0),
},
}
}

const ListResponseSchema = "urn:ietf:params:scim:api:messages:2.0:ListResponse"

type ListProviderUsersResponse struct {
Schemas []string `json:"schemas"`
TotalResults int `json:"totalResults"`
Resources []SCIMUser `json:"Resources"` // intentionally capitalized
StartIndex int `json:"startIndex"`
ItemsPerPage int `json:"itemsPerPage"`
}
20 changes: 20 additions & 0 deletions internal/access/scim.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package access

import (
"fmt"

"github.com/gin-gonic/gin"

"github.com/infrahq/infra/internal/server/data"
"github.com/infrahq/infra/internal/server/models"
)

func ListProviderUsers(c *gin.Context, p *data.SCIMParameters) ([]models.ProviderUser, error) {
// this can only be run by an access key issued for an identity provider
ctx := GetRequestContext(c)
users, err := data.ListProviderUsers(ctx.DBTxn, ctx.Authenticated.AccessKey.IssuedFor, p)
if err != nil {
return nil, fmt.Errorf("list provider users: %w", err)
}
return users, nil
}
4 changes: 2 additions & 2 deletions internal/server/data/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func AssignIdentityToGroups(tx GormTxn, user *models.Identity, provider *models.
}
addIDs = append(addIDs, item)
}
if rows.Err() != nil {
if err := rows.Err(); err != nil {
return err
}

Expand Down Expand Up @@ -107,7 +107,7 @@ func AssignIdentityToGroups(tx GormTxn, user *models.Identity, provider *models.
}
ids = append(ids, item)
}
if rows.Err() != nil {
if err := rows.Err(); err != nil {
return err
}

Expand Down
19 changes: 19 additions & 0 deletions internal/server/data/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func migrations() []*migrator.Migration {
removeDotFromDestinationName(),
destinationNameUnique(),
removeDeletedIdentityProviderUsers(),
addProviderUserSCIMFields(),
// next one here
}
}
Expand Down Expand Up @@ -790,3 +791,21 @@ func removeDeletedIdentityProviderUsers() *migrator.Migration {
},
}
}

func addProviderUserSCIMFields() *migrator.Migration {
return &migrator.Migration{
ID: "2022-09-28T13:00",
Migrate: func(tx migrator.DB) error {
stmt := `
ALTER TABLE provider_users
ADD COLUMN IF NOT EXISTS given_name text,
ADD COLUMN IF NOT EXISTS family_name text,
ADD COLUMN IF NOT EXISTS active boolean DEFAULT true;
CREATE UNIQUE INDEX IF NOT EXISTS idx_emails_providers ON provider_users (email, provider_id);
`
_, err := tx.Exec(stmt)
return err
},
}
}
6 changes: 6 additions & 0 deletions internal/server/data/migrations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,12 @@ DELETE FROM settings WHERE id=24567;
assert.Equal(t, count, 0)
},
},
{
label: testCaseLine("2022-09-28T13:00"),
expected: func(t *testing.T, db WriteTxn) {
// schema changes are tested with schema comparison
},
},
}

ids := make(map[string]struct{}, len(testCases))
Expand Down
2 changes: 1 addition & 1 deletion internal/server/data/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func DeleteProviders(db GormTxn, selectors ...SelectorFunc) error {
for _, p := range toDelete {
ids = append(ids, p.ID)

providerUsers, err := listProviderUsers(db, p.ID)
providerUsers, err := ListProviderUsers(db, p.ID, nil)
if err != nil {
return fmt.Errorf("listing provider users: %w", err)
}
Expand Down
46 changes: 42 additions & 4 deletions internal/server/data/provideruser.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,29 +93,60 @@ func UpdateProviderUser(tx WriteTxn, providerUser *models.ProviderUser) error {
return handleError(err)
}

func listProviderUsers(tx ReadTxn, providerID uid.ID) ([]models.ProviderUser, error) {
func ListProviderUsers(tx ReadTxn, providerID uid.ID, p *SCIMParameters) ([]models.ProviderUser, error) {
table := &providerUserTable{}
query := querybuilder.New("SELECT")
query.B(columnsForSelect(table))
if p != nil {
query.B(", count(*) OVER()")
}
query.B("FROM")
query.B(table.Table())
query.B("INNER JOIN providers ON provider_users.provider_id = providers.id AND providers.organization_id = ?", tx.OrganizationID())
query.B("WHERE provider_id = ?", providerID)

query.B("ORDER BY email ASC")

if p != nil {
// apply scim parameters
if p.Count != 0 {
query.B("LIMIT ?", p.Count)
}
if p.StartIndex != 0 {
query.B("OFFSET ?", p.StartIndex)
}
}

rows, err := tx.Query(query.String(), query.Args...)
if err != nil {
return nil, err
}
defer rows.Close()

var result []models.ProviderUser
for rows.Next() {
var pu models.ProviderUser

if err := rows.Scan((*providerUserTable)(&pu).ScanFields()...); err != nil {
fields := (*providerUserTable)(&pu).ScanFields()
if p != nil {
fields = append(fields, &p.TotalCount)
}
if err := rows.Scan(fields...); err != nil {
return nil, err
}
result = append(result, pu)
}
return result, rows.Err()

if p != nil && p.Count == 0 {
p.Count = p.TotalCount
}

defer rows.Close()

if err := rows.Err(); err != nil {
return nil, err
}

return result, nil
}

type DeleteProviderUsersOptions struct {
Expand Down Expand Up @@ -188,3 +219,10 @@ func SyncProviderUser(ctx context.Context, tx GormTxn, user *models.Identity, pr

return nil
}

type SCIMParameters struct {
Count int // the number of items to return
StartIndex int // the offset to start counting from
TotalCount int
// TODO: filter query param
}
117 changes: 117 additions & 0 deletions internal/server/data/provideruser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/infrahq/infra/internal/server/models"
"github.com/infrahq/infra/internal/server/providers"
"github.com/infrahq/infra/uid"
)

// mockOIDC is a mock oidc identity provider
Expand Down Expand Up @@ -254,3 +255,119 @@ func TestDeleteProviderUser(t *testing.T) {
assert.NilError(t, err)
})
}

func TestListProviderUsers(t *testing.T) {
type testCase struct {
name string
setup func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int)
}

testCases := []testCase{
{
name: "list all provider users",
setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int) {
provider := &models.Provider{
Name: "mockta",
Kind: models.ProviderKindOkta,
}

err := CreateProvider(tx, provider)
assert.NilError(t, err)

pu := createTestProviderUser(t, tx, provider, "david@example.com")
return provider.ID, nil, []models.ProviderUser{pu}, 0
},
},
{
name: "list all provider users invalid provider ID",
setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int) {
provider := &models.Provider{
Name: "mockta",
Kind: models.ProviderKindOkta,
}

err := CreateProvider(tx, provider)
assert.NilError(t, err)

_ = createTestProviderUser(t, tx, provider, "david@example.com")
return 1234, nil, nil, 0
},
},
{
name: "limit less than total",
setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int) {
provider := &models.Provider{
Name: "mockta",
Kind: models.ProviderKindOkta,
}

err := CreateProvider(tx, provider)
assert.NilError(t, err)

pu := createTestProviderUser(t, tx, provider, "david@example.com")
_ = createTestProviderUser(t, tx, provider, "lucy@example.com")
return provider.ID, &SCIMParameters{Count: 1}, []models.ProviderUser{pu}, 2
},
},
{
name: "offset from start",
setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int) {
provider := &models.Provider{
Name: "mockta",
Kind: models.ProviderKindOkta,
}

err := CreateProvider(tx, provider)
assert.NilError(t, err)

_ = createTestProviderUser(t, tx, provider, "david@example.com")
pu := createTestProviderUser(t, tx, provider, "lucy@example.com")
return provider.ID, &SCIMParameters{StartIndex: 1}, []models.ProviderUser{pu}, 2
},
},
}

runDBTests(t, func(t *testing.T, db *DB) {
org := &models.Organization{Name: "something", Domain: "example.com"}
assert.NilError(t, CreateOrganization(db, org))

// create some dummy data for another org to test multi-tenancy
stmt := `
INSERT INTO provider_users(identity_id, provider_id, email)
VALUES (?, ?, ?);
`
_, err := db.Exec(stmt, 123, 123, "otherorg@example.com")
assert.NilError(t, err)

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tx := txnForTestCase(t, db, org.ID)

providerID, p, expected, totalCount := tc.setup(t, tx)

result, err := ListProviderUsers(tx, providerID, p)

assert.NilError(t, err)
assert.DeepEqual(t, result, expected, cmpTimeWithDBPrecision)
if p != nil {
assert.Equal(t, p.TotalCount, totalCount)
}
})
}
})
}

func createTestProviderUser(t *testing.T, tx *Transaction, provider *models.Provider, userName string) models.ProviderUser {
user := &models.Identity{
Name: userName,
}
err := CreateIdentity(tx, user)
assert.NilError(t, err)

pu, err := CreateProviderUser(tx, provider, user)
assert.NilError(t, err)

pu.Groups = models.CommaSeparatedStrings{}

return *pu
}
7 changes: 6 additions & 1 deletion internal/server/data/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,10 @@ CREATE TABLE provider_users (
redirect_url text,
access_token text,
refresh_token text,
expires_at timestamp with time zone
expires_at timestamp with time zone,
given_name text,
family_name text,
active boolean DEFAULT true
);

CREATE TABLE providers (
Expand Down Expand Up @@ -284,6 +287,8 @@ CREATE UNIQUE INDEX idx_destinations_name ON destinations USING btree (organizat

CREATE UNIQUE INDEX idx_destinations_unique_id ON destinations USING btree (organization_id, unique_id) WHERE (deleted_at IS NULL);

CREATE UNIQUE INDEX idx_emails_providers ON provider_users USING btree (email, provider_id);

CREATE UNIQUE INDEX idx_encryption_keys_key_id ON encryption_keys USING btree (key_id);

CREATE UNIQUE INDEX idx_grant_srp ON grants USING btree (organization_id, subject, privilege, resource) WHERE (deleted_at IS NULL);
Expand Down
Loading

0 comments on commit af0b4d7

Please sign in to comment.