Skip to content

Commit

Permalink
activate mhsm in create/update with certificate
Browse files Browse the repository at this point in the history
  • Loading branch information
wuxu92 committed Mar 9, 2023
1 parent 8c7510d commit 601dd23
Show file tree
Hide file tree
Showing 6 changed files with 362 additions and 0 deletions.
7 changes: 7 additions & 0 deletions internal/clients/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ func Build(ctx context.Context, builder ClientBuilder) (*Client, error) {
return nil, fmt.Errorf("building account: %+v", err)
}

managedHSMAuth, err := auth.NewAuthorizerFromCredentials(ctx, *builder.AuthConfig, builder.AuthConfig.Environment.ManagedHSM)
if err != nil {
return nil, fmt.Errorf("unable to build authorizer for Managed HSM API: %+v", err)
}

client := Client{
Account: account,
}
Expand All @@ -121,6 +126,7 @@ func Build(ctx context.Context, builder ClientBuilder) (*Client, error) {
Authorizers: &common.Authorizers{
BatchManagement: batchManagementAuth,
KeyVault: keyVaultAuth,
ManagedHSM: managedHSMAuth,
ResourceManager: resourceManagerAuth,
Storage: storageAuth,
Synapse: synapseAuth,
Expand All @@ -137,6 +143,7 @@ func Build(ctx context.Context, builder ClientBuilder) (*Client, error) {

BatchManagementAuthorizer: authWrapper.AutorestAuthorizer(batchManagementAuth),
KeyVaultAuthorizer: authWrapper.AutorestAuthorizer(keyVaultAuth).BearerAuthorizerCallback(),
ManagedHSMAuthorizer: authWrapper.AutorestAuthorizer(managedHSMAuth).BearerAuthorizerCallback(),
ResourceManagerAuthorizer: authWrapper.AutorestAuthorizer(resourceManagerAuth),
StorageAuthorizer: authWrapper.AutorestAuthorizer(storageAuth),
SynapseAuthorizer: authWrapper.AutorestAuthorizer(synapseAuth),
Expand Down
2 changes: 2 additions & 0 deletions internal/common/client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
type Authorizers struct {
BatchManagement auth.Authorizer
KeyVault auth.Authorizer
ManagedHSM auth.Authorizer
ResourceManager auth.Authorizer
Storage auth.Authorizer
Synapse auth.Authorizer
Expand Down Expand Up @@ -54,6 +55,7 @@ type ClientOptions struct {
// Legacy authorizers for go-autorest
BatchManagementAuthorizer autorest.Authorizer
KeyVaultAuthorizer autorest.Authorizer
ManagedHSMAuthorizer autorest.Authorizer
ResourceManagerAuthorizer autorest.Authorizer
StorageAuthorizer autorest.Authorizer
SynapseAuthorizer autorest.Authorizer
Expand Down
10 changes: 10 additions & 0 deletions internal/services/keyvault/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import (
type Client struct {
ManagedHsmClient *keyvault.ManagedHsmsClient
ManagementClient *keyvaultmgmt.BaseClient
MHSMSDClient *keyvaultmgmt.HSMSecurityDomainClient
VaultsClient *keyvault.VaultsClient
MHSMRoleClient *keyvaultmgmt.RoleDefinitionsClient
options *common.ClientOptions
}

Expand All @@ -20,13 +22,21 @@ func NewClient(o *common.ClientOptions) *Client {
managementClient := keyvaultmgmt.New()
o.ConfigureClient(&managementClient.Client, o.KeyVaultAuthorizer)

sdClient := keyvaultmgmt.NewHSMSecurityDomainClient()
o.ConfigureClient(&sdClient.Client, o.ManagedHSMAuthorizer)

mhsmRoleDefineClient := keyvaultmgmt.NewRoleDefinitionsClient()
o.ConfigureClient(&mhsmRoleDefineClient.Client, o.ManagedHSMAuthorizer)

vaultsClient := keyvault.NewVaultsClientWithBaseURI(o.ResourceManagerEndpoint, o.SubscriptionId)
o.ConfigureClient(&vaultsClient.Client, o.ResourceManagerAuthorizer)

return &Client{
ManagedHsmClient: &managedHsmClient,
ManagementClient: &managementClient,
MHSMSDClient: &sdClient,
VaultsClient: &vaultsClient,
MHSMRoleClient: &mhsmRoleDefineClient,
options: o,
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,32 +1,43 @@
package keyvault

import (
"context"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"time"

"github.com/Azure/azure-sdk-for-go/services/keyvault/mgmt/2021-10-01/keyvault" // nolint: staticcheck
"github.com/gofrs/uuid"
"github.com/hashicorp/go-azure-helpers/lang/pointer"
"github.com/hashicorp/go-azure-helpers/lang/response"
"github.com/hashicorp/go-azure-helpers/resourcemanager/commonschema"
"github.com/hashicorp/go-azure-helpers/resourcemanager/location"
"github.com/hashicorp/terraform-provider-azurerm/helpers/azure"
"github.com/hashicorp/terraform-provider-azurerm/helpers/tf"
"github.com/hashicorp/terraform-provider-azurerm/internal/clients"
"github.com/hashicorp/terraform-provider-azurerm/internal/services/keyvault/client"
"github.com/hashicorp/terraform-provider-azurerm/internal/services/keyvault/parse"
"github.com/hashicorp/terraform-provider-azurerm/internal/services/keyvault/validate"
"github.com/hashicorp/terraform-provider-azurerm/internal/tags"
"github.com/hashicorp/terraform-provider-azurerm/internal/tf/pluginsdk"
"github.com/hashicorp/terraform-provider-azurerm/internal/tf/validation"
"github.com/hashicorp/terraform-provider-azurerm/internal/timeouts"
"github.com/hashicorp/terraform-provider-azurerm/utils"
kv73 "github.com/tombuildsstuff/kermit/sdk/keyvault/7.4/keyvault"
)

func resourceKeyVaultManagedHardwareSecurityModule() *pluginsdk.Resource {
return &pluginsdk.Resource{
Create: resourceArmKeyVaultManagedHardwareSecurityModuleCreate,
Read: resourceArmKeyVaultManagedHardwareSecurityModuleRead,
Delete: resourceArmKeyVaultManagedHardwareSecurityModuleDelete,
Update: resourceArmKeyVaultManagedHardwareSecurityModuleUpdate,

Importer: pluginsdk.ImporterValidatingResourceId(func(id string) error {
_, err := parse.ManagedHSMID(id)
Expand All @@ -36,6 +47,7 @@ func resourceKeyVaultManagedHardwareSecurityModule() *pluginsdk.Resource {
Timeouts: &pluginsdk.ResourceTimeout{
Create: pluginsdk.DefaultTimeout(60 * time.Minute),
Read: pluginsdk.DefaultTimeout(5 * time.Minute),
Update: pluginsdk.DefaultTimeout(30 * time.Minute),
Delete: pluginsdk.DefaultTimeout(60 * time.Minute),
},

Expand Down Expand Up @@ -131,6 +143,30 @@ func resourceKeyVaultManagedHardwareSecurityModule() *pluginsdk.Resource {
},
},

"security_domain_certificate": {
Type: pluginsdk.TypeSet,
MinItems: 3,
MaxItems: 10,
Optional: true,
Elem: &pluginsdk.Schema{
Type: pluginsdk.TypeString,
ValidateFunc: validate.NestedItemId,
},
},

"security_domain_quorum": {
Type: pluginsdk.TypeInt,
Optional: true,
RequiredWith: []string{"security_domain_certificate"},
ValidateFunc: validation.IntBetween(2, 10),
},

"security_domain_enc_data": {
Type: pluginsdk.TypeString,
Computed: true,
Sensitive: true,
},

// https://github.com/Azure/azure-rest-api-specs/issues/13365
"tags": tags.ForceNewSchema(),
},
Expand Down Expand Up @@ -181,6 +217,7 @@ func resourceArmKeyVaultManagedHardwareSecurityModuleCreate(d *pluginsdk.Resourc
hsm.Properties.PublicNetworkAccess = keyvault.PublicNetworkAccessDisabled
}

client.Client.RetryAttempts = 1 // retry if failed
future, err := client.CreateOrUpdate(ctx, id.ResourceGroup, id.Name, hsm)
if err != nil {
return fmt.Errorf("creating %s: %+v", id, err)
Expand All @@ -190,10 +227,68 @@ func resourceArmKeyVaultManagedHardwareSecurityModuleCreate(d *pluginsdk.Resourc
return fmt.Errorf("waiting on creation for %s: %+v", id, err)
}

// security domain download to activate this module
if certs := utils.ExpandStringSlice(d.Get("security_domain_certificate").(*pluginsdk.Set).List()); certs != nil && len(*certs) > 0 {
// get hsm uri
hsmRes, err := future.Result(*client)
if err != nil {
return fmt.Errorf("get hsm result: %v", err)
}
if hsmRes.Properties == nil || hsmRes.Properties.HsmURI == nil {
return fmt.Errorf("get nil hsmURI for %s", id)
}

encData, err := securityDomainDownload(ctx,
meta.(*clients.Client).KeyVault,
*hsmRes.Properties.HsmURI,
*certs,
d.Get("security_domain_quorum").(int),
)
if err == nil {
d.Set("security_domain_enc_data", encData)
} else {
log.Printf("security domain download: %v", err)
}
}

d.SetId(id.ID())
return resourceArmKeyVaultManagedHardwareSecurityModuleRead(d, meta)
}

// update to re-activate the security module
func resourceArmKeyVaultManagedHardwareSecurityModuleUpdate(d *pluginsdk.ResourceData, meta interface{}) error {
cli := meta.(*clients.Client).KeyVault.ManagedHsmClient
ctx, cancel := timeouts.ForUpdate(meta.(*clients.Client).StopContext, d)
defer cancel()

id, err := parse.ManagedHSMID(d.Id())
if err != nil {
return err
}

resp, err := cli.Get(ctx, id.ResourceGroup, id.Name)
if err != nil || resp.Properties == nil || resp.Properties.HsmURI == nil {
return fmt.Errorf("retrieving %s: %+v", id, err)
}

if d.HasChange("security_domain_certificate") {
if certs := utils.ExpandStringSlice(d.Get("security_domain_certificate").(*pluginsdk.Set).List()); len(*certs) > 0 {
// get hsm uri
encData, err := securityDomainDownload(ctx,
meta.(*clients.Client).KeyVault,
*resp.Properties.HsmURI,
*certs,
d.Get("security_domain_quorum").(int),
)
if err != nil {
return fmt.Errorf("security domain download: %v", err)
}
d.Set("security_domain_enc_data", encData)
}
}
return nil
}

func resourceArmKeyVaultManagedHardwareSecurityModuleRead(d *pluginsdk.ResourceData, meta interface{}) error {
client := meta.(*clients.Client).KeyVault.ManagedHsmClient
ctx, cancel := timeouts.ForRead(meta.(*clients.Client).StopContext, d)
Expand Down Expand Up @@ -328,3 +423,84 @@ func flattenMHSMNetworkAcls(acl *keyvault.MHSMNetworkRuleSet) []interface{} {
}
return []interface{}{res}
}

func securityDomainDownload(ctx context.Context, cli *client.Client, vaultBaseURL string, certIDs []string, qourum int) (encDataStr string, err error) {
sdClient := cli.MHSMSDClient
keyClient := cli.ManagementClient

var param kv73.CertificateInfoObject
param.Required = utils.Int32(int32(qourum))
var certs []kv73.SecurityDomainJSONWebKey
for _, certID := range certIDs {
keyID, _ := parse.ParseNestedItemID(certID)
certRes, err := keyClient.GetCertificate(ctx, keyID.KeyVaultBaseUrl, keyID.Name, keyID.Version)
if err != nil {
return "", fmt.Errorf("retriving key %s: %v", certID, err)
}
if certRes.Cer == nil {
return "", fmt.Errorf("got nil key for %s", certID)
}
cert := kv73.SecurityDomainJSONWebKey{
Kty: pointer.FromString("RSA"),
KeyOps: &[]string{""},
Alg: pointer.FromString("RSA-OAEP-256"),
}
if *cert.Alg == "" {
}
if certRes.Policy != nil && certRes.Policy.KeyProperties != nil {
cert.Kty = pointer.FromString(string(certRes.Policy.KeyProperties.KeyType))
}
x5c := ""
if contents := certRes.Cer; contents != nil {
x5c = base64.StdEncoding.EncodeToString(*contents)
}
cert.X5c = &[]string{x5c}

sum1 := sha1.Sum([]byte(x5c))
x5tDst := make([]byte, base64.StdEncoding.EncodedLen(len(sum1)))
base64.URLEncoding.Encode(x5tDst, sum1[:])
cert.X5t = pointer.FromString(string(x5tDst))

sum256 := sha256.Sum256([]byte(x5c))
s256Dst := make([]byte, base64.StdEncoding.EncodedLen(len(sum256)))
base64.URLEncoding.Encode(s256Dst, sum256[:])
cert.X5tS256 = pointer.FromString(string(s256Dst))
certs = append(certs, cert)
}
param.Certificates = &certs

future, err := sdClient.Download(ctx, vaultBaseURL, param)
if err != nil {
return "", fmt.Errorf("downloading for %s: %v", vaultBaseURL, err)
}

originResponse := future.Response()
data, err := io.ReadAll(originResponse.Body)
if err != nil {
return "", err
}
var encData struct {
Value string `json:"value"`
}

err = json.Unmarshal(data, &encData)
if err != nil {
return "", fmt.Errorf("unmarshal EncData: %v", err)
}
// wait download code has bug will never return
// limit ctx to wait 5 second(value from azcli)
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
if err := future.WaitForCompletionRef(ctx, sdClient.Client); err != nil {
if !response.WasStatusCode(future.Response(), http.StatusOK) {
log.Printf("waiting for download of %s: %v. ignore", vaultBaseURL, err)
}
}
result, err := future.Result(*sdClient)
if result.Value != nil {
encData.Value = pointer.ToString(result.Value)
}
encDataStr = encData.Value

return encDataStr, nil
}
Loading

0 comments on commit 601dd23

Please sign in to comment.