Skip to content

Commit

Permalink
feat (extras/kms): add WithTableNamePrefix(...) option (#182)
Browse files Browse the repository at this point in the history
  • Loading branch information
jimlambrt committed Aug 31, 2023
1 parent edb68c3 commit 61c41ef
Show file tree
Hide file tree
Showing 23 changed files with 409 additions and 215 deletions.
21 changes: 15 additions & 6 deletions extras/kms/data_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ type dataKey struct {
Purpose KeyPurpose `json:"purpose,omitempty" gorm:"default:null"`
// CreateTime from the RDBMS
CreateTime time.Time `json:"create_time,omitempty" gorm:"default:current_timestamp"`

// tableNamePrefix defines the prefix to use before the table name and
// allows us to support custom prefixes as well as multi KMSs within a
// single schema.
tableNamePrefix string `gorm:"-"`
}

// newDataKey creates a new in memory data key. This key is used for wrapper
Expand All @@ -47,10 +52,11 @@ func newDataKey(rootKeyId string, purpose KeyPurpose, _ ...Option) (*dataKey, er
// Clone creates a clone of the DataKey
func (k *dataKey) Clone() *dataKey {
return &dataKey{
PrivateId: k.PrivateId,
RootKeyId: k.RootKeyId,
Purpose: k.Purpose,
CreateTime: k.CreateTime,
PrivateId: k.PrivateId,
RootKeyId: k.RootKeyId,
Purpose: k.Purpose,
CreateTime: k.CreateTime,
tableNamePrefix: k.tableNamePrefix,
}
}

Expand All @@ -77,8 +83,11 @@ func (k *dataKey) vetForWrite(ctx context.Context, opType dbw.OpType) error {
return nil
}

// TableName returns the tablename
func (k *dataKey) TableName() string { return "kms_data_key" }
// TableName returns the table name
func (k *dataKey) TableName() string {
const tableName = "data_key"
return fmt.Sprintf("%s_%s", k.tableNamePrefix, tableName)
}

// GetPrivateId returns the key's private id
func (k *dataKey) GetPrivateId() string { return k.PrivateId }
Expand Down
13 changes: 11 additions & 2 deletions extras/kms/data_key_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ type dataKeyVersion struct {
Version uint32 `json:"version,omitempty" gorm:"default:null"`
// CreateTime from the RDBMS
CreateTime time.Time `json:"create_time,omitempty" gorm:"default:current_timestamp"`

// tableNamePrefix defines the prefix to use before the table name and
// allows us to support custom prefixes as well as multi KMSs within a
// single schema.
tableNamePrefix string `gorm:"-"`
}

// newDataKeyVersion creates a new in memory data key version. No options
Expand Down Expand Up @@ -63,6 +68,7 @@ func (k *dataKeyVersion) Clone() *dataKeyVersion {
RootKeyVersionId: k.RootKeyVersionId,
Version: k.Version,
CreateTime: k.CreateTime,
tableNamePrefix: k.tableNamePrefix,
}
clone.Key = make([]byte, len(k.Key))
copy(clone.Key, k.Key)
Expand Down Expand Up @@ -95,8 +101,11 @@ func (k *dataKeyVersion) vetForWrite(ctx context.Context, opType dbw.OpType) err
return nil
}

// TableName returns the tablename
func (k *dataKeyVersion) TableName() string { return "kms_data_key_version" }
// TableName returns the table name
func (k *dataKeyVersion) TableName() string {
const tableName = "data_key_version"
return fmt.Sprintf("%s_%s", k.tableNamePrefix, tableName)
}

// Encrypt will encrypt the data key version's key
func (k *dataKeyVersion) Encrypt(ctx context.Context, cipher wrapping.Wrapper) error {
Expand Down
55 changes: 37 additions & 18 deletions extras/kms/kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ const (

// scopeWrapperCache defines an scope wrapper cache
scopeWrapperCache

// DefaultTableNamePrefix defines the default prefix that will be add to
// every table name when an optional WithTableNamePrefix is not used when
// calling New(...). For example the root table name might be "data_key" and
// then the default is added to form "kms_data_key"
DefaultTableNamePrefix = "kms"
)

// Kms is a way to access wrappers for a given scope and purpose. Since keys can
Expand All @@ -48,25 +54,32 @@ type Kms struct {

purposes []KeyPurpose
repo *repository

// tableNamePrefix defines the prefix to use before the table name and
// allows us to support custom prefixes as well as multi KMSs within a
// single schema.
tableNamePrefix string
}

// New takes in a reader, writer and a list of key purposes it will support.
// Every kms will support a KeyPurposeRootKey by default and it doesn't need to
// be passed in as one of the supported purposes.
//
// No options are currently supported.
func New(r dbw.Reader, w dbw.Writer, purposes []KeyPurpose, _ ...Option) (*Kms, error) {
// Supported options: WithTableNamePrefix.
func New(r dbw.Reader, w dbw.Writer, purposes []KeyPurpose, opt ...Option) (*Kms, error) {
const op = "kms.New"
repo, err := newRepository(r, w)
repo, err := newRepository(r, w, opt...)
if err != nil {
return nil, fmt.Errorf("%s: unable to initialize repository: %w", op, err)
}
purposes = append(purposes, KeyPurposeRootKey)
removeDuplicatePurposes(purposes)

opts := getOpts(opt...)
return &Kms{
purposes: purposes,
repo: repo,
purposes: purposes,
repo: repo,
tableNamePrefix: opts.withTableNamePrefix,
}, nil
}

Expand Down Expand Up @@ -203,7 +216,7 @@ func (k *Kms) GetWrapper(ctx context.Context, scopeId string, purpose KeyPurpose
// Fast-path: we have a valid key at the scope/purpose. Verify the key with
// that ID is in the multiwrapper; if not, fall through to reload from the
// DB.
currVersion, err := currentCollectionVersion(ctx, k.repo.reader)
currVersion, err := currentCollectionVersion(ctx, k.repo.reader, k.tableNamePrefix)
if err != nil {
return nil, fmt.Errorf("%s: unable to determine current version of the kms collection: %w", op, err)
}
Expand Down Expand Up @@ -273,9 +286,12 @@ func (k *Kms) ListKeys(ctx context.Context, scopeId string) ([]Key, error) {
if err != nil {
return nil, fmt.Errorf("%s: unable to lookup root key: %w", op, err)
}
rkv := rootKeyVersion{
tableNamePrefix: k.tableNamePrefix,
}
var rkVersions []*rootKeyVersion
// we don't need to decrypt their keys, so we'll get them directly from the repo.list(...)
if err := k.repo.list(ctx, &rkVersions, "root_key_id = ?", []interface{}{dbKey.PrivateId}, withOrderByVersion(ascendingOrderBy)); err != nil {
if err := k.repo.list(ctx, &rkVersions, "root_key_id = ?", []interface{}{dbKey.PrivateId}, withOrderByVersion(ascendingOrderBy), withTableName(rkv.TableName())); err != nil {
return nil, fmt.Errorf("%s: unable to list root key versions: %w", op, err)
}
rk := newKeyFromRootKey(dbKey)
Expand All @@ -291,12 +307,15 @@ func (k *Kms) ListKeys(ctx context.Context, scopeId string) ([]Key, error) {
if err != nil {
return nil, fmt.Errorf("%s: unable to list data keys: %w", op, err)
}
dkv := dataKeyVersion{
tableNamePrefix: k.tableNamePrefix,
}
keys := []Key{rk}
for _, dk := range dataKeys {
dataKey := newKeyFromDataKey(dk, rk.Scope)
var versions []*dataKeyVersion
// we don't need to decrypt their keys, so we'll get them directly from the repo.list(...)
err := k.repo.list(ctx, &versions, "data_key_id = ?", []interface{}{dk.GetPrivateId()}, withOrderByVersion(ascendingOrderBy))
err := k.repo.list(ctx, &versions, "data_key_id = ?", []interface{}{dk.GetPrivateId()}, withOrderByVersion(ascendingOrderBy), withTableName(dkv.TableName()))
if err != nil {
return nil, fmt.Errorf("%s: %w", op, err)
}
Expand Down Expand Up @@ -349,12 +368,12 @@ func (k *Kms) CreateKeys(ctx context.Context, scopeId string, purposes []KeyPurp
return fmt.Errorf("%s: %w", op, err)
}

if err := updateKeyCollectionVersion(ctx, w); err != nil {
if err := updateKeyCollectionVersion(ctx, w, k.tableNamePrefix); err != nil {
return fmt.Errorf("%s: %w", op, err)
}

opts := getOpts(opt...)
if _, err := createKeysTx(ctx, r, w, rootWrapper, opts.withRandomReader, scopeId, purposes...); err != nil {
if _, err := createKeysTx(ctx, r, w, rootWrapper, opts.withRandomReader, k.tableNamePrefix, scopeId, purposes...); err != nil {
if localTx != nil {
if rollBackErr := localTx.Rollback(ctx); rollBackErr != nil {
err = multierror.Append(err, rollBackErr)
Expand Down Expand Up @@ -476,7 +495,7 @@ func (k *Kms) RotateKeys(ctx context.Context, scopeId string, opt ...Option) err
return fmt.Errorf("%s: %w", op, err)
}

if err := updateKeyCollectionVersion(ctx, writer); err != nil {
if err := updateKeyCollectionVersion(ctx, writer, k.tableNamePrefix); err != nil {
return fmt.Errorf("%s: %w", op, err)
}

Expand All @@ -492,13 +511,13 @@ func (k *Kms) RotateKeys(ctx context.Context, scopeId string, opt ...Option) err
if opts.withRewrap {
// rewrap the root key versions with the provided rootWrapper (assuming
// it has a new wrapper)
if err := rewrapRootKeyVersionsTx(ctx, reader, writer, rootWrapper, rk.PrivateId); err != nil {
if err := rewrapRootKeyVersionsTx(ctx, reader, writer, rootWrapper, rk.PrivateId, WithTableNamePrefix(k.tableNamePrefix)); err != nil {
return fmt.Errorf("%s: unable to rewrap root key versions: %w", op, err)
}
}

// rotate the root key version (adding a new version)
rkv, err := rotateRootKeyVersionTx(ctx, writer, rootWrapper, rk.PrivateId, WithRandomReader(opts.withRandomReader))
rkv, err := rotateRootKeyVersionTx(ctx, writer, rootWrapper, rk.PrivateId, WithRandomReader(opts.withRandomReader), WithTableNamePrefix(k.tableNamePrefix))
if err != nil {
return fmt.Errorf("%s: unable to rotate root key version: %w", op, err)
}
Expand All @@ -510,7 +529,7 @@ func (k *Kms) RotateKeys(ctx context.Context, scopeId string, opt ...Option) err

// we've got a new rootKeyVersion wrapper, so let's rewrap the existing DEKs.
if opts.withRewrap {
if err := rewrapDataKeyVersionsTx(ctx, reader, writer, rkvWrapper, rk.PrivateId); err != nil {
if err := rewrapDataKeyVersionsTx(ctx, reader, writer, k.tableNamePrefix, rkvWrapper, rk.PrivateId); err != nil {
return fmt.Errorf("%s: unable to rewrap data key versions: %w", op, err)
}
}
Expand All @@ -520,7 +539,7 @@ func (k *Kms) RotateKeys(ctx context.Context, scopeId string, opt ...Option) err
if purpose == KeyPurposeRootKey {
continue
}
if err := rotateDataKeyVersionTx(ctx, reader, writer, rkv.PrivateId, rkvWrapper, rk.PrivateId, purpose, WithRandomReader(opts.withRandomReader)); err != nil {
if err := rotateDataKeyVersionTx(ctx, reader, writer, k.tableNamePrefix, rkv.PrivateId, rkvWrapper, rk.PrivateId, purpose, WithRandomReader(opts.withRandomReader)); err != nil {
return fmt.Errorf("%s: unable to rotate data key version: %w", op, err)
}
}
Expand Down Expand Up @@ -571,7 +590,7 @@ func (k *Kms) RewrapKeys(ctx context.Context, scopeId string, opt ...Option) err
return fmt.Errorf("%s: %w", op, err)
}

if err := updateKeyCollectionVersion(ctx, writer); err != nil {
if err := updateKeyCollectionVersion(ctx, writer, k.tableNamePrefix); err != nil {
return fmt.Errorf("%s: %w", op, err)
}

Expand All @@ -585,7 +604,7 @@ func (k *Kms) RewrapKeys(ctx context.Context, scopeId string, opt ...Option) err

// rewrap the root key versions with the provided rootWrapper (assuming
// it has a new wrapper)
if err := rewrapRootKeyVersionsTx(ctx, reader, writer, rootWrapper, rk.PrivateId); err != nil {
if err := rewrapRootKeyVersionsTx(ctx, reader, writer, rootWrapper, rk.PrivateId, WithTableNamePrefix(k.tableNamePrefix)); err != nil {
return fmt.Errorf("%s: unable to rewrap root key versions: %w", op, err)
}

Expand All @@ -595,7 +614,7 @@ func (k *Kms) RewrapKeys(ctx context.Context, scopeId string, opt ...Option) err
}

// we've got a new rootKeyVersion wrapper, so let's rewrap the existing DEKs.
if err := rewrapDataKeyVersionsTx(ctx, reader, writer, rkvWrapper, rk.PrivateId); err != nil {
if err := rewrapDataKeyVersionsTx(ctx, reader, writer, k.tableNamePrefix, rkvWrapper, rk.PrivateId); err != nil {
return fmt.Errorf("%s: unable to rewrap data key versions: %w", op, err)
}

Expand Down
Loading

0 comments on commit 61c41ef

Please sign in to comment.