Skip to content

Commit

Permalink
Add TPM option for master key (#1)
Browse files Browse the repository at this point in the history
Add an option to store master keys on a TPM. When this option is used,
the data can only be decrypted with the same TPM.
  • Loading branch information
rthellend committed Mar 10, 2024
1 parent 87be173 commit 3e83705
Show file tree
Hide file tree
Showing 8 changed files with 592 additions and 254 deletions.
222 changes: 172 additions & 50 deletions crypto/aes.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
package crypto

import (
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/binary"
"errors"
Expand All @@ -38,6 +40,8 @@ import (
"path/filepath"
"runtime"

"github.com/c2FmZQ/tpm"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/pbkdf2"
)

Expand All @@ -58,6 +62,7 @@ type AESKey struct {

logger Logger
strictWipe bool
tpmKey *tpm.Key
}

func (k *AESKey) Logger() Logger {
Expand Down Expand Up @@ -98,24 +103,24 @@ type AESMasterKey struct {

// CreateAESMasterKey creates a new master key.
func CreateAESMasterKey(opts ...Option) (MasterKey, error) {
var logger Logger = defaultLogger{}
var strictWipe bool
for _, opt := range opts {
if opt.logger != nil {
logger = opt.logger
}
if opt.strictWipe != nil {
strictWipe = *opt.strictWipe
}
}
var opt option
opt.apply(opts)
b := make([]byte, 64)
if _, err := rand.Read(b); err != nil {
return nil, err
}
key := aesKeyFromBytes(b)
key.logger = logger
key.strictWipe = strictWipe
return &AESMasterKey{key}, nil
key.logger = opt.logger
key.strictWipe = opt.strictWipe
mk := &AESMasterKey{key}
if opt.tpm != nil {
tpmkey, err := opt.tpm.CreateKey()
if err != nil {
return nil, err
}
mk.tpmKey = tpmkey
}
return mk, nil
}

// CreateAESMasterKeyForTest creates a new master key to tests.
Expand All @@ -133,51 +138,90 @@ func CreateAESMasterKeyForTest() (MasterKey, error) {

// ReadAESMasterKey reads an encrypted master key from file and decrypts it.
func ReadAESMasterKey(passphrase []byte, file string, opts ...Option) (MasterKey, error) {
var logger Logger = defaultLogger{}
var strictWipe bool
for _, opt := range opts {
if opt.logger != nil {
logger = opt.logger
}
if opt.strictWipe != nil {
strictWipe = *opt.strictWipe
}
}
var opt option
opt.apply(opts)
b, err := os.ReadFile(file)
if err != nil {
return nil, err
}
if len(b) < 64 {
return nil, ErrDecryptFailed
}
version, b := b[0], b[1:]
if version != 1 {
logger.Debugf("ReadMasterKey: unexpected version: %d", version)
str := cryptobyte.String(b)
var version uint8
if !str.ReadUint8(&version) {
return nil, ErrDecryptFailed
}
salt, b := b[:16], b[16:]
numIter, b := int(binary.BigEndian.Uint32(b[:4])), b[4:]
dk := pbkdf2.Key(passphrase, salt, numIter, 32, sha256.New)
if version != 1 && version != 3 {
opt.logger.Debugf("ReadMasterKey: unexpected version: %d", version)
return nil, ErrDecryptFailed
}
if version == 3 && opt.tpm == nil {
opt.logger.Debug("ReadMasterKey: missing WithTPM option")
return nil, ErrDecryptFailed
}
salt := make([]byte, 16)
if !str.ReadBytes(&salt, 16) {
return nil, ErrDecryptFailed
}
var numIter uint32
if !str.ReadUint32(&numIter) {
return nil, ErrDecryptFailed
}
dk := pbkdf2.Key(passphrase, salt, int(numIter), 32, sha256.New)
block, err := aes.NewCipher(dk)
if err != nil {
logger.Debug(err)
opt.logger.Debug(err)
return nil, ErrDecryptFailed
}
gcm, err := cipher.NewGCM(block)
if err != nil {
logger.Debug(err)
opt.logger.Debug(err)
return nil, ErrDecryptFailed
}
nonce := b[:gcm.NonceSize()]
encMasterKey := b[gcm.NonceSize():]
mkBytes, err := gcm.Open(nil, nonce, encMasterKey, nil)
nonce := make([]byte, gcm.NonceSize())
if !str.ReadBytes(&nonce, len(nonce)) {
return nil, ErrDecryptFailed
}
mkBytes, err := gcm.Open(nil, nonce, []byte(str), nil)
if err != nil {
logger.Debug(err)
opt.logger.Debug(err)
return nil, ErrDecryptFailed
}
key := aesKeyFromBytes(mkBytes)
key.logger = logger
key.strictWipe = strictWipe
var key *AESKey
if version == 1 {
key = aesKeyFromBytes(mkBytes)
} else { // version == 3
str := cryptobyte.String(mkBytes)
var length uint16
if !str.ReadUint16(&length) {
return nil, ErrDecryptFailed
}
encKey := make([]byte, length)
if !str.ReadBytes(&encKey, len(encKey)) {
return nil, ErrDecryptFailed
}
if !str.ReadUint16(&length) {
return nil, ErrDecryptFailed
}
tpmCtx := make([]byte, length)
if !str.ReadBytes(&tpmCtx, len(tpmCtx)) {
return nil, ErrDecryptFailed
}
tpmKey, err := opt.tpm.UnmarshalKey(tpmCtx)
if err != nil {
return nil, err
}
decKey, err := tpmKey.Decrypt(nil, encKey, nil)
if err != nil {
opt.logger.Debug(err)
return nil, ErrDecryptFailed
}
key = aesKeyFromBytes(decKey)
key.tpmKey = tpmKey
}
key.logger = opt.logger
key.strictWipe = opt.strictWipe
return &AESMasterKey{key}, nil
}

Expand All @@ -191,8 +235,6 @@ func (mk AESMasterKey) Save(passphrase []byte, file string) error {
if len(passphrase) == 0 {
numIter = 10
}
numIterBin := make([]byte, 4)
binary.BigEndian.PutUint32(numIterBin, uint32(numIter))
dk := pbkdf2.Key(passphrase, salt, numIter, 32, sha256.New)
block, err := aes.NewCipher(dk)
if err != nil {
Expand All @@ -209,11 +251,44 @@ func (mk AESMasterKey) Save(passphrase []byte, file string) error {
mk.Logger().Debug(err)
return ErrEncryptFailed
}
encMasterKey := gcm.Seal(nonce, nonce, mk.key(), nil)
data := []byte{1} // version
data = append(data, salt...)
data = append(data, numIterBin...)
data = append(data, encMasterKey...)
var version uint8
var payload []byte
if mk.tpmKey == nil {
version = 1
payload = mk.key()
} else {
version = 3
buf := cryptobyte.NewBuilder(nil)
// encKey, err := mk.tpmKey.Encrypt(mk.key())
encKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, mk.tpmKey.Public().(*rsa.PublicKey), mk.key(), nil)
if err != nil {
mk.Logger().Debug(err)
return ErrEncryptFailed
}
buf.AddUint16(uint16(len(encKey)))
buf.AddBytes(encKey)
keyctx, err := mk.tpmKey.Marshal()
if err != nil {
mk.Logger().Debug(err)
return ErrEncryptFailed
}
buf.AddUint16(uint16(len(keyctx)))
buf.AddBytes(keyctx)
if payload, err = buf.Bytes(); err != nil {
mk.Logger().Debug(err)
return ErrEncryptFailed
}
}
encMasterKey := gcm.Seal(nonce, nonce, payload, nil)
buf := cryptobyte.NewBuilder([]byte{version})
buf.AddBytes(salt)
buf.AddUint32(uint32(numIter))
buf.AddBytes(encMasterKey)
data, err := buf.Bytes()
if err != nil {
mk.Logger().Debug(err)
return ErrEncryptFailed
}
dir, _ := filepath.Split(file)
if err := os.MkdirAll(dir, 0700); err != nil {
return err
Expand All @@ -237,6 +312,23 @@ func (k AESKey) Hash(b []byte) []byte {

// Decrypt decrypts data that was encrypted with Encrypt and the same key.
func (k AESKey) Decrypt(data []byte) ([]byte, error) {
if k.tpmKey != nil {
sigSize := k.tpmKey.Bits() / 8
if len(data) < 1+sigSize {
return nil, ErrDecryptFailed
}
version, data := data[0], data[1:]
if version != 3 {
return nil, ErrDecryptFailed
}
encData, data := data[:len(data)-sigSize], data[len(data)-sigSize:]
sig := data[:sigSize]
hashed := sha256.Sum256(encData)
if err := rsa.VerifyPKCS1v15(k.tpmKey.Public().(*rsa.PublicKey), crypto.SHA256, hashed[:], sig); err != nil {
return nil, ErrDecryptFailed
}
return k.tpmKey.Decrypt(nil, encData, nil)
}
if len(k.maskedKey) == 0 {
k.Logger().Fatal("key is not set")
}
Expand Down Expand Up @@ -274,6 +366,23 @@ func (k AESKey) Decrypt(data []byte) ([]byte, error) {

// Encrypt encrypts data using the key.
func (k AESKey) Encrypt(data []byte) ([]byte, error) {
if k.tpmKey != nil {
// encData, err := k.tpmKey.Encrypt(data)
encData, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, k.tpmKey.Public().(*rsa.PublicKey), data, nil)
if err != nil {
return nil, ErrEncryptFailed
}
hashed := sha256.Sum256(encData)
sig, err := k.tpmKey.Sign(nil, hashed[:], crypto.SHA256)
if err != nil {
return nil, ErrEncryptFailed
}
out := make([]byte, 1+len(encData)+len(sig))
out[0] = 3 // version
copy(out[1:], encData)
copy(out[1+len(encData):], sig)
return out, nil
}
if len(k.maskedKey) == 0 {
k.Logger().Fatal("key is not set")
}
Expand Down Expand Up @@ -347,10 +456,17 @@ func (k AESKey) NewKey() (EncryptionKey, error) {
return ek, nil
}

func (k AESKey) keysize() int {
if k.tpmKey != nil {
return 2*k.tpmKey.Bits()/8 + 1
}
return aesEncryptedKeySize
}

// DecryptKey decrypts an encrypted key.
func (k AESKey) DecryptKey(encryptedKey []byte) (EncryptionKey, error) {
if len(encryptedKey) != aesEncryptedKeySize {
k.Logger().Debugf("DecryptKey: unexpected encrypted key size %d != %d", len(encryptedKey), aesEncryptedKeySize)
if len(encryptedKey) != k.keysize() {
k.Logger().Debugf("DecryptKey: unexpected encrypted key size %d != %d", len(encryptedKey), k.keysize())
return nil, ErrDecryptFailed
}
b, err := k.Decrypt(encryptedKey)
Expand Down Expand Up @@ -509,6 +625,9 @@ func (r *AESStreamReader) Close() error {

// StartReader opens a reader to decrypt a stream of data.
func (k AESKey) StartReader(ctx []byte, r io.Reader) (StreamReader, error) {
if k.tpmKey != nil {
return nil, errors.New("operation not supported with TPM key")
}
var start int64
if seeker, ok := r.(io.Seeker); ok {
off, err := seeker.Seek(0, io.SeekCurrent)
Expand Down Expand Up @@ -577,6 +696,9 @@ func (w *AESStreamWriter) Close() (err error) {

// StartWriter opens a writer to encrypt a stream of data.
func (k AESKey) StartWriter(ctx []byte, w io.Writer) (StreamWriter, error) {
if k.tpmKey != nil {
return nil, errors.New("operation not supported with TPM key")
}
block, err := aes.NewCipher(k.key()[:32])
if err != nil {
k.Logger().Debug(err)
Expand All @@ -592,7 +714,7 @@ func (k AESKey) StartWriter(ctx []byte, w io.Writer) (StreamWriter, error) {

// ReadEncryptedKey reads an encrypted key and decrypts it.
func (k AESKey) ReadEncryptedKey(r io.Reader) (EncryptionKey, error) {
buf := make([]byte, aesEncryptedKeySize)
buf := make([]byte, k.keysize())
if _, err := io.ReadFull(r, buf); err != nil {
k.Logger().Debug(err)
return nil, ErrDecryptFailed
Expand All @@ -603,8 +725,8 @@ func (k AESKey) ReadEncryptedKey(r io.Reader) (EncryptionKey, error) {
// WriteEncryptedKey writes the encrypted key to the writer.
func (k AESKey) WriteEncryptedKey(w io.Writer) error {
n, err := w.Write(k.encryptedKey)
if n != aesEncryptedKeySize {
k.Logger().Debugf("WriteEncryptedKey: unexpected key size: %d != %d", n, aesEncryptedKeySize)
if n == 0 {
k.Logger().Debugf("WriteEncryptedKey: unexpected key size: %d", n)
return ErrEncryptFailed
}
return err
Expand Down
Loading

0 comments on commit 3e83705

Please sign in to comment.