Skip to content

Commit

Permalink
Merge pull request #33 from quentingodeau/feature/spn
Browse files Browse the repository at this point in the history
Feature/spn
  • Loading branch information
samansmink committed Feb 21, 2024
2 parents 923ff39 + 0c011b7 commit 51f680e
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 0 deletions.
55 changes: 55 additions & 0 deletions src/azure_secret.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,51 @@ static unique_ptr<BaseSecret> CreateAzureSecretFromCredentialChain(ClientContext
return std::move(result);
}

static unique_ptr<BaseSecret> CreateAzureSecretFromServicePrincipal(ClientContext &context, CreateSecretInput &input) {
auto tenant_id = input.options.find("tenant_id");
auto client_id = input.options.find("client_id");
auto client_secret = input.options.find("client_secret");
auto client_certificate_path = input.options.find("client_certificate_path");

auto account_name = input.options.find("account_name");
auto azure_endpoint = input.options.find("azure_endpoint");

auto scope = input.scope;
if (scope.empty()) {
scope.push_back("azure://");
scope.push_back("az://");
}

auto result = make_uniq<KeyValueSecret>(scope, input.type, input.provider, input.name);

FillWithAzureProxyInfo(context, input, *result);

// Add config to kv secret
if (tenant_id != input.options.end()) {
result->secret_map["tenant_id"] = tenant_id->second;
}
if (client_id != input.options.end()) {
result->secret_map["client_id"] = client_id->second;
}
if (client_secret != input.options.end()) {
result->secret_map["client_secret"] = client_secret->second;
result->redact_keys.insert("client_secret");
}
if (client_certificate_path != input.options.end()) {
result->secret_map["client_certificate_path"] = client_certificate_path->second;
result->redact_keys.insert("client_certificate_path");
}

if (account_name != input.options.end()) {
result->secret_map["account_name"] = account_name->second;
}
if (azure_endpoint != input.options.end()) {
result->secret_map["azure_endpoint"] = azure_endpoint->second;
}

return std::move(result);
}

static void RegisterCommonSecretParameters(CreateSecretFunction &function) {
// Register azure common parameters
function.named_parameters["account_name"] = LogicalType::VARCHAR;
Expand Down Expand Up @@ -119,6 +164,16 @@ void CreateAzureSecretFunctions::Register(DatabaseInstance &instance) {
cred_chain_function.named_parameters["azure_endpoint"] = LogicalType::VARCHAR;
RegisterCommonSecretParameters(cred_chain_function);
ExtensionUtil::RegisterFunction(instance, cred_chain_function);

CreateSecretFunction service_principal_function = {type, "service_principal",
CreateAzureSecretFromServicePrincipal};
service_principal_function.named_parameters["tenant_id"] = LogicalType::VARCHAR;
service_principal_function.named_parameters["client_id"] = LogicalType::VARCHAR;
service_principal_function.named_parameters["client_secret"] = LogicalType::VARCHAR;
service_principal_function.named_parameters["client_certificate_path"] = LogicalType::VARCHAR;
service_principal_function.named_parameters["azure_endpoint"] = LogicalType::VARCHAR;
RegisterCommonSecretParameters(service_principal_function);
ExtensionUtil::RegisterFunction(instance, service_principal_function);
}

} // namespace duckdb
51 changes: 51 additions & 0 deletions src/azure_storage_account_client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <azure/core/credentials/token_credential_options.hpp>
#include <azure/identity/azure_cli_credential.hpp>
#include <azure/identity/chained_token_credential.hpp>
#include <azure/identity/client_certificate_credential.hpp>
#include <azure/identity/client_secret_credential.hpp>
#include <azure/identity/default_azure_credential.hpp>
#include <azure/identity/environment_credential.hpp>
#include <azure/identity/managed_identity_credential.hpp>
Expand Down Expand Up @@ -71,6 +73,23 @@ CreateChainedTokenCredential(const std::string &chain,
return std::make_shared<Azure::Identity::ChainedTokenCredential>(sources);
}

static std::shared_ptr<Azure::Core::Credentials::TokenCredential>
CreateClientCredential(const std::string &tenant_id, const std::string &client_id, const std::string &client_secret,
const std::string &client_certificate_path,
const Azure::Core::Http::Policies::TransportOptions &transport_options) {
auto credential_options = ToTokenCredentialOptions(transport_options);
if (!client_secret.empty()) {
return std::make_shared<Azure::Identity::ClientSecretCredential>(tenant_id, client_id, client_secret,
credential_options);
} else if (!client_certificate_path.empty()) {
return std::make_shared<Azure::Identity::ClientCertificateCredential>(
tenant_id, client_id, client_certificate_path, credential_options);
}

throw InvalidInputException("Failed to fetch key 'client_secret' or 'client_certificate_path' from secret "
"'service_principal' of type 'azure'");
}

static Azure::Core::Http::Policies::TransportOptions GetTransportOptions(const KeyValueSecret &secret) {
Azure::Core::Http::Policies::TransportOptions transport_options;

Expand Down Expand Up @@ -150,13 +169,45 @@ GetStorageAccountClientFromCredentialChainProvider(const KeyValueSecret &secret)
return Azure::Storage::Blobs::BlobServiceClient(account_url, std::move(credential), blob_options);
}

static Azure::Storage::Blobs::BlobServiceClient
GetStorageAccountClientFromServicePrincipalProvider(const KeyValueSecret &secret) {
auto transport_options = GetTransportOptions(secret);

constexpr bool error_on_missing = true;
auto tenant_id = secret.TryGetValue("tenant_id", error_on_missing);
auto client_id = secret.TryGetValue("client_id", error_on_missing);
auto client_secret_val = secret.TryGetValue("client_secret");
auto client_certificate_path_val = secret.TryGetValue("client_certificate_path");

std::string client_secret = client_secret_val.IsNull() ? "" : client_secret_val.ToString();
std::string client_certificate_path =
client_certificate_path_val.IsNull() ? "" : client_certificate_path_val.ToString();

auto token_credential = CreateClientCredential(tenant_id.ToString(), client_id.ToString(), client_secret,
client_certificate_path, transport_options);

auto account_name = secret.TryGetValue("account_name", error_on_missing);

std::string endpoint = DEFAULT_ENDPOINT;
auto endpoint_value = secret.TryGetValue("endpoint");
if (!endpoint_value.IsNull()) {
endpoint = endpoint_value.ToString();
}

auto account_url = "https://" + account_name.ToString() + "." + endpoint;
auto blob_options = ToBlobClientOptions(transport_options);
return Azure::Storage::Blobs::BlobServiceClient {account_url, token_credential, blob_options};
}

static Azure::Storage::Blobs::BlobServiceClient GetStorageAccountClient(const KeyValueSecret &secret) {
auto &provider = secret.GetProvider();
// default provider
if (provider == "config") {
return GetStorageAccountClientFromConfigProvider(secret);
} else if (provider == "credential_chain") {
return GetStorageAccountClientFromCredentialChainProvider(secret);
} else if (provider == "service_principal") {
return GetStorageAccountClientFromServicePrincipalProvider(secret);
}

throw InvalidInputException("Unsupported provider type %s for azure", provider);
Expand Down
31 changes: 31 additions & 0 deletions test/sql/azure_spn.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# name: test/sql/azure_spn.test
# description: test azure extension with service principal authentication
# group: [azure]

require azure

require-env TENANT_ID

require-env CLIENT_ID

require-env CLIENT_SECRET

require-env ACCOUNT_NAME

statement ok
CREATE SECRET s1 (
TYPE AZURE,
PROVIDER SERVICE_PRINCIPAL,
TENANT_ID '${TENANT_ID}',
CLIENT_ID '${CLIENT_ID}',
CLIENT_SECRET '${CLIENT_SECRET}',
ACCOUNT_NAME '${ACCOUNT_NAME}'
)

query I
SELECT count(*) FROM 'azure://testing-private/l.csv';
----
60175

statement ok
DROP SECRET s1

0 comments on commit 51f680e

Please sign in to comment.