Skip to content

Commit

Permalink
cacheroach: Add email-domain grouping.
Browse files Browse the repository at this point in the history
This change adds support for creating principals that represent an entire email
domain's worth of principals. This is a quick way to add support for
generalized principal grouping (#3), but is fairly coarse-grained.
  • Loading branch information
bobvawter committed Mar 19, 2021
1 parent e544447 commit b0b5f2f
Show file tree
Hide file tree
Showing 10 changed files with 279 additions and 106 deletions.
5 changes: 5 additions & 0 deletions api/principal.proto
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ message Principal {
}
}
];
// If present, indicates that the principal represents all users whose
// email address are in the given domain.
string email_domain = 5;
string refresh_token = 66 [
(capabilities.field_rule).never = true
];
Expand All @@ -81,6 +84,8 @@ message LoadRequest {
ID ID = 1;
// Load a Principal by email address.
string email = 2;
// Load a domain-level Principal.
string email_domain = 3;
}
}

Expand Down
173 changes: 102 additions & 71 deletions api/principal/principal.pb.go

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion doc/cacheroach_bootstrap.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ create a super-user principal using the server's HMAC key
This command should be used to create an initial user on a newly-created cacheroach installation. It requires access to the server's HMAC key that is used to sign tokens. The resulting session will have superuser access; the resulting configuration file should be treated with the same security as the key.

```
cacheroach bootstrap [flags] https://username[:password]@cacheroach.server/
cacheroach bootstrap [flags] https://cacheroach.server/
```

### Options
Expand Down
5 changes: 3 additions & 2 deletions doc/cacheroach_principal_create.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ cacheroach principal create <username> [flags]
### Options

```
-h, --help help for create
-o, --out string write a new configuration file, defaults to username.cfg
--emailDomain string create a unique principal that represents all principals with an email address in the given domain
-h, --help help for create
-o, --out string write a new configuration file, defaults to username.cfg
```

### Options inherited from parent commands
Expand Down
10 changes: 7 additions & 3 deletions pkg/cmd/cli/principal.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (c *CLI) principal() *cobra.Command {
Short: "principal management",
}

var createOut string
var createDomain, createOut string
create := &cobra.Command{
Use: "create <username>",
Short: "create a principal",
Expand All @@ -55,8 +55,9 @@ func (c *CLI) principal() *cobra.Command {
}
req := &principal.EnsureRequest{
Principal: &principal.Principal{
ID: principal.NewID(),
Claims: claimBytes,
ID: principal.NewID(),
Claims: claimBytes,
EmailDomain: createDomain,
},
}
prn, err := principal.NewPrincipalsClient(conn).Ensure(cmd.Context(), req)
Expand Down Expand Up @@ -96,6 +97,9 @@ func (c *CLI) principal() *cobra.Command {
return nil
},
}
create.Flags().StringVar(&createDomain, "emailDomain", "",
"create a unique principal that represents all principals with "+
"an email address in the given domain")
create.Flags().StringVarP(&createOut, "out", "o", "",
"write a new configuration file, defaults to username.cfg")

Expand Down
47 changes: 25 additions & 22 deletions pkg/store/principal/principal.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (
"github.com/google/wire"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
Expand Down Expand Up @@ -84,8 +83,10 @@ func (s *Server) Ensure(
defer tx.Rollback(ctx)

row := tx.QueryRow(ctx,
"INSERT INTO principals (principal, refresh_after, refresh_status, refresh_token, claims, version) "+
"VALUES ($1, $2, $3, $4, $5, 1) "+
"INSERT INTO principals ( "+
"principal, email_domain, refresh_after, refresh_status, refresh_token, "+
"claims, version"+
") VALUES ($1, $2, $3, $4, $5, $6, 1) "+
"ON CONFLICT (principal) "+
"DO UPDATE SET (refresh_after, refresh_status, refresh_token, claims, version) = "+
"("+
Expand All @@ -95,7 +96,8 @@ func (s *Server) Ensure(
" IFNULL (excluded.claims, principals.claims),"+
" principals.version + 1"+
") RETURNING name, refresh_after, refresh_status, refresh_token, claims, version",
p.ID, p.RefreshAfter.AsTime(), p.RefreshStatus, p.RefreshToken, p.Claims)
p.ID, strings.ToLower(p.EmailDomain), p.RefreshAfter.AsTime(),
p.RefreshStatus, p.RefreshToken, p.Claims)

var pendingVersion int64
var refreshAfter time.Time
Expand Down Expand Up @@ -128,7 +130,7 @@ func (s *Server) List(_ *emptypb.Empty, out principal.Principals_ListServer) err
defer tx.Rollback(ctx)

rows, err := tx.Query(ctx,
"SELECT principal, name, claims, version "+
"SELECT principal, email_domain, name, claims, version "+
"FROM principals")
if err != nil {
return err
Expand All @@ -139,7 +141,7 @@ func (s *Server) List(_ *emptypb.Empty, out principal.Principals_ListServer) err
for rows.Next() {
p := &principal.Principal{ID: &principal.ID{}}

if err := rows.Scan(p.ID, &p.Label, &p.Claims, &p.Version); err != nil {
if err := rows.Scan(p.ID, &p.EmailDomain, &p.Label, &p.Claims, &p.Version); err != nil {
return err
}

Expand All @@ -153,35 +155,36 @@ func (s *Server) List(_ *emptypb.Empty, out principal.Principals_ListServer) err

// Load implements principal.PrincipalsServer.
func (s *Server) Load(ctx context.Context, req *principal.LoadRequest) (*principal.Principal, error) {
var id *principal.ID
var col string
var val interface{}

switch t := req.Kind.(type) {
case *principal.LoadRequest_Email:
id = &principal.ID{}
row := s.DB.QueryRow(ctx, "SELECT principal FROM principals WHERE email = $1", strings.ToLower(t.Email))
if err := row.Scan(id); errors.Is(err, pgx.ErrNoRows) {
return nil, status.Error(codes.NotFound, t.Email)
} else if err != nil {
return nil, err
}
col = "email"
val = strings.ToLower(t.Email)
case *principal.LoadRequest_ID:
id = t.ID
col = "principal"
val = t.ID
case *principal.LoadRequest_EmailDomain:
col = "email_domain"
val = strings.ToLower(t.EmailDomain)
default:
return nil, status.Error(codes.Unimplemented, "unknown kind")
}
ret := &principal.Principal{ID: id}
ret := &principal.Principal{ID: &principal.ID{}}
err := util.Retry(ctx, func(ctx context.Context) error {
var refreshAfter time.Time
row := s.DB.QueryRow(ctx,
"SELECT name, refresh_after, refresh_status, refresh_token, claims, version "+
"FROM principals "+
"WHERE principal = $1", id)
err := row.Scan(&ret.Label, &refreshAfter, &ret.RefreshStatus, &ret.RefreshToken,
&ret.Claims, &ret.Version)
"SELECT principal, name, email_domain, refresh_after, refresh_status, refresh_token, "+
"claims, version "+
"FROM principals WHERE "+col+" = $1", val)
err := row.Scan(ret.ID, &ret.Label, &ret.EmailDomain, &refreshAfter, &ret.RefreshStatus,
&ret.RefreshToken, &ret.Claims, &ret.Version)
ret.RefreshAfter = timestamppb.New(refreshAfter)
return err
})
if err == pgx.ErrNoRows {
return nil, status.Error(codes.NotFound, id.AsUUID().String())
return nil, status.Error(codes.NotFound, req.String())
}
return ret, err
}
Expand Down
22 changes: 22 additions & 0 deletions pkg/store/principal/principal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,26 @@ func TestPrincipal(t *testing.T) {
a.True(errors.Is(err, util.ErrVersionSkew))
})

t.Run("domain", func(t *testing.T) {
a := assert.New(t)
resp, err := rig.p.Ensure(ctx, &EnsureRequest{
Principal: &Principal{EmailDomain: "example.com"}})
if !a.NoError(err) {
return
}
a.Equal(int64(1), resp.Principal.Version)

found, err := rig.p.Load(ctx, &LoadRequest{
Kind: &LoadRequest_EmailDomain{EmailDomain: "example.com"}})
if !a.NoError(err) {
return
}
a.Equal(resp.Principal.ID.String(), found.ID.String())

_, err = rig.p.Ensure(ctx, &EnsureRequest{
Principal: &Principal{EmailDomain: "example.com"}})
if a.NotNil(err) {
a.Contains(err.Error(), "duplicate key value")
}
})
}
4 changes: 4 additions & 0 deletions pkg/store/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ CREATE TABLE IF NOT EXISTS principals (
region STRING NOT NULL DEFAULT IFNULL(crdb_internal.locality_value('region'), 'global') CHECK (length(region)>0),
principal UUID NOT NULL UNIQUE,
-- A principal may be created to delegate access to all users within a given email domain.
email_domain STRING NOT NULL DEFAULT '',
refresh_after TIMESTAMPTZ NOT NULL DEFAULT 0::TIMESTAMPTZ, -- The time at which the claims must be revalidated
refresh_status INT8 NOT NULL DEFAULT 0, -- Refresh state enum
refresh_token STRING NOT NULL DEFAULT '', -- OAuth2 refresh token to achieve revalidation
Expand All @@ -137,6 +140,7 @@ CREATE TABLE IF NOT EXISTS principals (
version INT8 NOT NULL CHECK (version > 0),
PRIMARY KEY (region, principal),
UNIQUE INDEX (email_domain) WHERE email_domain != '',
UNIQUE INDEX (email) WHERE email != ''
)
`, `
Expand Down
14 changes: 9 additions & 5 deletions pkg/store/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,15 @@ func (s *Server) Find(scope *session.Scope, server token.Tokens_FindServer) erro
s.cache.Remove(cacheKey)
}
return util.RetryLoop(ctx, func(ctx context.Context, sideEffect *util.Marker) error {
rows, err := s.db.Query(ctx,
"SELECT session, tenant, path, capabilities, expires_at, note, name, super "+
"FROM sessions "+
"WHERE principal = $1 AND expires_at > now()",
sn.PrincipalId)
rows, err := s.db.Query(ctx, `
WITH
dom AS (SELECT substring(email, '@(.*)$') as email_domain FROM principals WHERE principal = $1 AND email != ''),
prns AS (SELECT principal FROM principals JOIN dom USING (email_domain) UNION SELECT $1::UUID)
SELECT session, tenant, path, capabilities, expires_at, note, name, super
FROM sessions
JOIN prns USING (principal)
WHERE expires_at > now()
`, sn.PrincipalId)
if err != nil {
return err
}
Expand Down
103 changes: 101 additions & 2 deletions pkg/store/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,101 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
)

// Check that a domain-level principal implicitly delegates to
// other principals with the name email domain.
func TestDomainInheritance(t *testing.T) {
a := assert.New(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

rig, cleanup, err := testRig(ctx)
if !a.NoError(err) {
return
}
defer cleanup()

tID := tenant.NewID()
_, err = rig.tenants.Ensure(ctx, &tenant.EnsureRequest{Tenant: &tenant.Tenant{
Label: "Some Tenant",
ID: tID,
}})
if !a.NoError(err) {
return
}

pID := principal.NewID()
_, err = rig.principals.Ensure(ctx, &principal.EnsureRequest{
Principal: &principal.Principal{
ID: pID,
Claims: []byte(`{"email":"user@example.com"}`),
}})
if !a.NoError(err) {
return
}

principalSession, err := rig.tokens.Issue(ctx, &IssueRequest{Template: &session.Session{
PrincipalId: pID,
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
Capabilities: capabilities.All(),
Scope: &session.Scope{Kind: &session.Scope_OnPrincipal{OnPrincipal: pID}},
}})
if !a.NoError(err) {
return
}

domainID := principal.NewID()
_, err = rig.principals.Ensure(ctx, &principal.EnsureRequest{
Principal: &principal.Principal{
Label: "Domain Principal",
ID: domainID,
EmailDomain: "example.com",
}})
if !a.NoError(err) {
return
}

domainSession, err := rig.tokens.Issue(ctx, &IssueRequest{Template: &session.Session{
PrincipalId: domainID,
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
Capabilities: capabilities.All(),
Scope: &session.Scope{
Kind: &session.Scope_OnLocation{
OnLocation: &session.Location{
TenantId: tID,
Path: "/*",
}}}}})
if !a.NoError(err) {
return
}

{
sink := &sink{ctx: session.WithSession(ctx, principalSession.Issued)}
err = rig.tokens.Find(&session.Scope{}, sink)
if !a.NoError(err) {
return
}
a.Len(sink.ret, 2)
}

// Ensure that the domain-level session can will be invalidated.
_, err = rig.tokens.Invalidate(session.WithSession(ctx, domainSession.Issued),
&InvalidateRequest{Kind: &InvalidateRequest_ID{ID: domainSession.Issued.ID}})
if !a.NoError(err) {
return
}

rig.tokens.cache.Purge()

{
sink := &sink{ctx: session.WithSession(ctx, principalSession.Issued)}
err = rig.tokens.Find(&session.Scope{}, sink)
if !a.NoError(err) {
return
}
a.Len(sink.ret, 1)
}
}

func TestTokenFlow(t *testing.T) {
a := assert.New(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
Expand All @@ -47,15 +142,19 @@ func TestTokenFlow(t *testing.T) {
Label: "Some Tenant",
ID: tID,
}})
a.NoError(err)
if !a.NoError(err) {
return
}

pID := principal.NewID()
p := &principal.Principal{
Label: "Some User",
ID: pID,
}
_, err = rig.principals.Ensure(ctx, &principal.EnsureRequest{Principal: p})
a.NoError(err)
if !a.NoError(err) {
return
}

tcs := []*session.Session{
{
Expand Down

0 comments on commit b0b5f2f

Please sign in to comment.