Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: SCIM list provider users #3405

Merged
merged 6 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions api/scim.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package api

import "github.com/infrahq/infra/internal/validate"

type SCIMUserName struct {
GivenName string `json:"givenName"`
FamilyName string `json:"familyName"`
}

type SCIMUserEmail struct {
Primary bool `json:"primary"`
Value string `json:"value"`
}

func (r SCIMUserEmail) ValidationRules() []validate.ValidationRule {
return []validate.ValidationRule{
validate.Required("value", r.Value),
validate.Email("value", r.Value),
}
}

const UserSchema = "urn:ietf:params:scim:schemas:core:2.0:User"

type SCIMMetadata struct {
ResourceType string `json:"resourceType"`
}

type SCIMUser struct {
Schemas []string `json:"schemas"`
ID string `json:"id"`
UserName string `json:"userName"`
Name SCIMUserName `json:"name"`
Emails []SCIMUserEmail `json:"emails"`
Active bool `json:"active"`
Meta SCIMMetadata `json:"meta"`
}

type SCIMParametersRequest struct {
// these pagination parameters must conform to the SCIM spec, rather than our standard pagination
StartIndex int `form:"startIndex"`
Count int `form:"count"`
}

func (r SCIMParametersRequest) ValidationRules() []validate.ValidationRule {
return []validate.ValidationRule{
validate.IntRule{
Name: "startIndex",
Value: r.StartIndex,
Min: validate.Int(0),
},
validate.IntRule{
Name: "count",
Value: r.Count,
Min: validate.Int(0),
},
}
}

const ListResponseSchema = "urn:ietf:params:scim:api:messages:2.0:ListResponse"

type ListProviderUsersResponse struct {
Schemas []string `json:"schemas"`
TotalResults int `json:"totalResults"`
Resources []SCIMUser `json:"Resources"` // intentionally capitalized
StartIndex int `json:"startIndex"`
ItemsPerPage int `json:"itemsPerPage"`
}
20 changes: 20 additions & 0 deletions internal/access/scim.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package access

import (
"fmt"

"github.com/gin-gonic/gin"

"github.com/infrahq/infra/internal/server/data"
"github.com/infrahq/infra/internal/server/models"
)

func ListProviderUsers(c *gin.Context, p *data.SCIMParameters) ([]models.ProviderUser, error) {
// this can only be run by an access key issued for an identity provider
ctx := GetRequestContext(c)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this missing a call to RequireInfraRole (or IsAuthorized ) ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Originally that was part of this change, but after building out the rest of the SCIM functionality I realized that the SCIM role wasn't needed. This endpoint can only be called with an access key issued for a provider, and it can only modify the provider that the access key was issued for. In this way is acts similarly to our other isSelf endpoints.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, because we query with AccessKey.IssusedFor, which will have a providerID in the case of these SCIM access keys, and there should never be an overlap between userIDs and providerIDs, got it!

users, err := data.ListProviderUsers(ctx.DBTxn, ctx.Authenticated.AccessKey.IssuedFor, p)
BruceMacD marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, fmt.Errorf("list provider users: %w", err)
}
return users, nil
}
19 changes: 19 additions & 0 deletions internal/server/data/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func migrations() []*migrator.Migration {
removeDotFromDestinationName(),
destinationNameUnique(),
removeDeletedIdentityProviderUsers(),
addProviderUserSCIMFields(),
// next one here
}
}
Expand Down Expand Up @@ -777,3 +778,21 @@ func removeDeletedIdentityProviderUsers() *migrator.Migration {
},
}
}

func addProviderUserSCIMFields() *migrator.Migration {
return &migrator.Migration{
ID: "2022-09-28T13:00",
Migrate: func(tx migrator.DB) error {
stmt := `
ALTER TABLE provider_users
ADD COLUMN IF NOT EXISTS given_name text DEFAULT '',
ADD COLUMN IF NOT EXISTS family_name text DEFAULT '',
ADD COLUMN IF NOT EXISTS active boolean DEFAULT true;

CREATE UNIQUE INDEX IF NOT EXISTS idx_emails_providers ON provider_users (email, provider_id);
`
_, err := tx.Exec(stmt)
return err
},
}
}
6 changes: 6 additions & 0 deletions internal/server/data/migrations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,12 @@ DELETE FROM settings WHERE id=24567;
assert.Equal(t, count, 0)
},
},
{
label: testCaseLine("2022-09-28T13:00"),
expected: func(t *testing.T, db WriteTxn) {
// schema changes are tested with schema comparison
},
},
}

ids := make(map[string]struct{}, len(testCases))
Expand Down
2 changes: 1 addition & 1 deletion internal/server/data/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func DeleteProviders(db GormTxn, selectors ...SelectorFunc) error {
for _, p := range toDelete {
ids = append(ids, p.ID)

providerUsers, err := listProviderUsers(db, p.ID)
providerUsers, err := ListProviderUsers(db, p.ID, nil)
if err != nil {
return fmt.Errorf("listing provider users: %w", err)
}
Expand Down
51 changes: 45 additions & 6 deletions internal/server/data/provideruser.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ func (p providerUserTable) Table() string {
}

func (p providerUserTable) Columns() []string {
return []string{"identity_id", "provider_id", "email", "groups", "last_update", "redirect_url", "access_token", "refresh_token", "expires_at"}
return []string{"identity_id", "provider_id", "email", "groups", "last_update", "redirect_url", "access_token", "refresh_token", "expires_at", "given_name", "family_name", "active"}
}

func (p providerUserTable) Values() []any {
return []any{p.IdentityID, p.ProviderID, p.Email, p.Groups, p.LastUpdate, p.RedirectURL, p.AccessToken, p.RefreshToken, p.ExpiresAt}
return []any{p.IdentityID, p.ProviderID, p.Email, p.Groups, p.LastUpdate, p.RedirectURL, p.AccessToken, p.RefreshToken, p.ExpiresAt, p.GivenName, p.FamilyName, p.Active}
}

func (p *providerUserTable) ScanFields() []any {
return []any{&p.IdentityID, &p.ProviderID, &p.Email, &p.Groups, &p.LastUpdate, &p.RedirectURL, &p.AccessToken, &p.RefreshToken, &p.ExpiresAt}
return []any{&p.IdentityID, &p.ProviderID, &p.Email, &p.Groups, &p.LastUpdate, &p.RedirectURL, &p.AccessToken, &p.RefreshToken, &p.ExpiresAt, &p.GivenName, &p.FamilyName, &p.Active}
}

func (p *providerUserTable) OnInsert() error {
Expand Down Expand Up @@ -66,6 +66,7 @@ func CreateProviderUser(db GormTxn, provider *models.Provider, ident *models.Ide
IdentityID: ident.ID,
Email: ident.Name,
LastUpdate: time.Now().UTC(),
Active: true,
}
if err := validateProviderUser(pu); err != nil {
return nil, err
Expand Down Expand Up @@ -93,20 +94,51 @@ func UpdateProviderUser(tx WriteTxn, providerUser *models.ProviderUser) error {
return handleError(err)
}

func listProviderUsers(tx ReadTxn, providerID uid.ID) ([]models.ProviderUser, error) {
func ListProviderUsers(tx ReadTxn, providerID uid.ID, p *SCIMParameters) ([]models.ProviderUser, error) {
table := &providerUserTable{}
query := querybuilder.New("SELECT")
query.B(columnsForSelect(table))
if p != nil {
query.B(", count(*) OVER()")
}
query.B("FROM")
query.B(table.Table())
query.B("INNER JOIN providers ON provider_users.provider_id = providers.id AND providers.organization_id = ?", tx.OrganizationID())
query.B("WHERE provider_id = ?", providerID)

query.B("ORDER BY email ASC")

if p != nil {
// apply scim parameters
if p.Count != 0 {
query.B("LIMIT ?", p.Count)
}
if p.StartIndex > 0 {
offset := p.StartIndex - 1 // start index begins at 1, not 0
query.B("OFFSET ?", offset)
}
}

rows, err := tx.Query(query.String(), query.Args...)
if err != nil {
return nil, err
}
return scanRows(rows, func(pu *models.ProviderUser) []any {
return (*providerUserTable)(pu).ScanFields()
result, err := scanRows(rows, func(pu *models.ProviderUser) []any {
fields := (*providerUserTable)(pu).ScanFields()
if p != nil {
fields = append(fields, &p.TotalCount)
}
return fields
})
if err != nil {
return nil, fmt.Errorf("scan provider users: %w", err)
}

if p != nil && p.Count == 0 {
p.Count = p.TotalCount
}

return result, nil
}

type DeleteProviderUsersOptions struct {
Expand Down Expand Up @@ -179,3 +211,10 @@ func SyncProviderUser(ctx context.Context, tx GormTxn, user *models.Identity, pr

return nil
}

type SCIMParameters struct {
Count int // the number of items to return
StartIndex int // the offset to start counting from
TotalCount int // the total number of items that match the query
// TODO: filter query param
}
119 changes: 119 additions & 0 deletions internal/server/data/provideruser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/infrahq/infra/internal/server/models"
"github.com/infrahq/infra/internal/server/providers"
"github.com/infrahq/infra/uid"
)

// mockOIDC is a mock oidc identity provider
Expand Down Expand Up @@ -106,6 +107,7 @@ func TestSyncProviderUser(t *testing.T) {
AccessToken: "any-access-token",
ExpiresAt: time.Now().Add(time.Hour).UTC(),
LastUpdate: time.Now().UTC(),
Active: true,
}

cmpProviderUser := cmp.Options{
Expand Down Expand Up @@ -166,6 +168,7 @@ func TestSyncProviderUser(t *testing.T) {
AccessToken: "any-access-token",
ExpiresAt: time.Now().Add(5 * time.Minute).UTC(),
LastUpdate: time.Now().UTC(),
Active: true,
}

cmpProviderUser := cmp.Options{
Expand Down Expand Up @@ -254,3 +257,119 @@ func TestDeleteProviderUser(t *testing.T) {
assert.NilError(t, err)
})
}

func TestListProviderUsers(t *testing.T) {
type testCase struct {
name string
setup func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int)
}

testCases := []testCase{
{
name: "list all provider users",
setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int) {
provider := &models.Provider{
Name: "mockta",
Kind: models.ProviderKindOkta,
}

err := CreateProvider(tx, provider)
assert.NilError(t, err)

pu := createTestProviderUser(t, tx, provider, "david@example.com")
return provider.ID, nil, []models.ProviderUser{pu}, 0
},
},
{
name: "list all provider users invalid provider ID",
setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int) {
provider := &models.Provider{
Name: "mockta",
Kind: models.ProviderKindOkta,
}

err := CreateProvider(tx, provider)
assert.NilError(t, err)

_ = createTestProviderUser(t, tx, provider, "david@example.com")
return 1234, nil, nil, 0
},
},
{
name: "limit less than total",
setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int) {
provider := &models.Provider{
Name: "mockta",
Kind: models.ProviderKindOkta,
}

err := CreateProvider(tx, provider)
assert.NilError(t, err)

pu := createTestProviderUser(t, tx, provider, "david@example.com")
_ = createTestProviderUser(t, tx, provider, "lucy@example.com")
return provider.ID, &SCIMParameters{Count: 1}, []models.ProviderUser{pu}, 2
},
},
{
name: "offset from start",
setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int) {
provider := &models.Provider{
Name: "mockta",
Kind: models.ProviderKindOkta,
}

err := CreateProvider(tx, provider)
assert.NilError(t, err)

pu1 := createTestProviderUser(t, tx, provider, "david@example.com")
pu2 := createTestProviderUser(t, tx, provider, "lucy@example.com")
return provider.ID, &SCIMParameters{StartIndex: 1}, []models.ProviderUser{pu1, pu2}, 2
},
},
}

runDBTests(t, func(t *testing.T, db *DB) {
org := &models.Organization{Name: "something", Domain: "example.com"}
assert.NilError(t, CreateOrganization(db, org))

// create some dummy data for another org to test multi-tenancy
stmt := `
INSERT INTO provider_users(identity_id, provider_id, email)
VALUES (?, ?, ?);
`
_, err := db.Exec(stmt, 123, 123, "otherorg@example.com")
assert.NilError(t, err)

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tx := txnForTestCase(t, db, org.ID)

providerID, p, expected, totalCount := tc.setup(t, tx)

result, err := ListProviderUsers(tx, providerID, p)

assert.NilError(t, err)
assert.DeepEqual(t, result, expected, cmpTimeWithDBPrecision)
if p != nil {
assert.Equal(t, p.TotalCount, totalCount)
}
})
}
})
}

func createTestProviderUser(t *testing.T, tx *Transaction, provider *models.Provider, userName string) models.ProviderUser {
user := &models.Identity{
Name: userName,
}
err := CreateIdentity(tx, user)
assert.NilError(t, err)

pu, err := CreateProviderUser(tx, provider, user)
assert.NilError(t, err)

pu.Groups = models.CommaSeparatedStrings{}

return *pu
}
7 changes: 6 additions & 1 deletion internal/server/data/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,10 @@ CREATE TABLE provider_users (
redirect_url text,
access_token text,
refresh_token text,
expires_at timestamp with time zone
expires_at timestamp with time zone,
given_name text DEFAULT ''::text,
family_name text DEFAULT ''::text,
active boolean DEFAULT true
);

CREATE TABLE providers (
Expand Down Expand Up @@ -284,6 +287,8 @@ CREATE UNIQUE INDEX idx_destinations_name ON destinations USING btree (organizat

CREATE UNIQUE INDEX idx_destinations_unique_id ON destinations USING btree (organization_id, unique_id) WHERE (deleted_at IS NULL);

CREATE UNIQUE INDEX idx_emails_providers ON provider_users USING btree (email, provider_id);

CREATE UNIQUE INDEX idx_encryption_keys_key_id ON encryption_keys USING btree (key_id);

CREATE UNIQUE INDEX idx_grant_srp ON grants USING btree (organization_id, subject, privilege, resource) WHERE (deleted_at IS NULL);
Expand Down
Loading