Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make SecretEntry copyable so we don't lose the underlying BaseSecret obj for non-CatalogSet secrets #10518

Merged
merged 2 commits into from Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/function/table/system/duckdb_secrets.cpp
Expand Up @@ -97,17 +97,17 @@ void DuckDBSecretsFunction(ClientContext &context, TableFunctionInput &data_p, D
auto &secret_entry = secrets[data.offset];

vector<Value> scope_value;
for (const auto &scope_entry : secret_entry.get().secret->GetScope()) {
for (const auto &scope_entry : secret_entry.secret->GetScope()) {
scope_value.push_back(scope_entry);
}

const auto &secret = *secret_entry.get().secret;
const auto &secret = *secret_entry.secret;

output.SetValue(0, count, secret.GetName());
output.SetValue(1, count, Value(secret.GetType()));
output.SetValue(2, count, Value(secret.GetProvider()));
output.SetValue(3, count, Value(secret_entry.get().persist_type == SecretPersistType::PERSISTENT));
output.SetValue(4, count, Value(secret_entry.get().storage_mode));
output.SetValue(3, count, Value(secret_entry.persist_type == SecretPersistType::PERSISTENT));
output.SetValue(4, count, Value(secret_entry.storage_mode));
output.SetValue(5, count, Value::LIST(LogicalType::VARCHAR, scope_value));
output.SetValue(6, count, secret.ToString(bind_data.redact));

Expand Down
9 changes: 9 additions & 0 deletions src/include/duckdb/main/secret/secret.hpp
Expand Up @@ -100,6 +100,11 @@ class BaseSecret {
//! Serialize this secret
virtual void Serialize(Serializer &serializer) const;

virtual unique_ptr<const BaseSecret> Clone() const {
D_ASSERT(typeid(BaseSecret) == typeid(*this));
return make_uniq<BaseSecret>(*this);
stephaniewang526 marked this conversation as resolved.
Show resolved Hide resolved
}

//! Getters
const vector<string> &GetScope() const {
return prefix_paths;
Expand Down Expand Up @@ -188,6 +193,10 @@ class KeyValueSecret : public BaseSecret {
return duckdb::unique_ptr_cast<TYPE, BaseSecret>(std::move(result));
}

unique_ptr<const BaseSecret> Clone() const override {
return make_uniq<KeyValueSecret>(*this);
}

//! the map of key -> values that make up the secret
case_insensitive_tree_t<Value> secret_map;
//! keys that are sensitive and should be redacted
Expand Down
49 changes: 33 additions & 16 deletions src/include/duckdb/main/secret/secret_manager.hpp
Expand Up @@ -25,24 +25,42 @@ struct SecretMatch {
public:
SecretMatch() : secret_entry(nullptr), score(NumericLimits<int64_t>::Minimum()) {
}
SecretMatch(SecretEntry &secret_entry, int64_t score) : secret_entry(&secret_entry), score(score) {

SecretMatch(const SecretMatch &other)
: secret_entry((other.secret_entry != nullptr) ? make_uniq<SecretEntry>(*other.secret_entry) : nullptr),
score(other.score) {
}

SecretMatch(SecretEntry &secret_entry, int64_t score)
: secret_entry(make_uniq<SecretEntry>(secret_entry)), score(score) {
}

SecretMatch &operator=(const SecretMatch &other) {
this->secret_entry = (other.secret_entry != nullptr) ? make_uniq<SecretEntry>(*other.secret_entry) : nullptr;
this->score = other.score;
return *this;
};

//! Get the secret
const BaseSecret &GetSecret();
const BaseSecret &GetSecret() const;

bool HasMatch() {
return secret_entry;
return secret_entry != nullptr;
}

optional_ptr<SecretEntry> secret_entry;
unique_ptr<SecretEntry> secret_entry;
int64_t score;
};

//! A Secret Entry in the secret manager
struct SecretEntry {
public:
SecretEntry(unique_ptr<const BaseSecret> secret) : secret(std::move(secret)) {};
SecretEntry(unique_ptr<const BaseSecret> secret) : secret(secret != nullptr ? secret->Clone() : nullptr) {};

SecretEntry(const SecretEntry &other)
: persist_type(other.persist_type), storage_mode(other.storage_mode),
secret((other.secret != nullptr) ? other.secret->Clone() : nullptr) {
}

//! Whether the secret is persistent
SecretPersistType persist_type;
Expand Down Expand Up @@ -96,26 +114,25 @@ class SecretManager {
//! Register a Secret Function i.e. a secret provider for a secret type
DUCKDB_API void RegisterSecretFunction(CreateSecretFunction function, OnCreateConflict on_conflict);
//! Register a secret by providing a secret manually
DUCKDB_API optional_ptr<SecretEntry> RegisterSecret(CatalogTransaction transaction,
unique_ptr<const BaseSecret> secret,
OnCreateConflict on_conflict, SecretPersistType persist_type,
const string &storage = "");
DUCKDB_API unique_ptr<SecretEntry> RegisterSecret(CatalogTransaction transaction,
unique_ptr<const BaseSecret> secret, OnCreateConflict on_conflict,
SecretPersistType persist_type, const string &storage = "");
//! Create a secret from a CreateSecretInfo
DUCKDB_API optional_ptr<SecretEntry> CreateSecret(ClientContext &context, const CreateSecretInfo &info);
DUCKDB_API unique_ptr<SecretEntry> CreateSecret(ClientContext &context, const CreateSecretInfo &info);
//! The Bind for create secret is done by the secret manager
DUCKDB_API BoundStatement BindCreateSecret(CatalogTransaction transaction, CreateSecretInfo &info);
//! Lookup the best matching secret by matching the secret scopes to the path
DUCKDB_API SecretMatch LookupSecret(CatalogTransaction transaction, const string &path, const string &type);
//! Get a secret by name, optionally from a specific storage
DUCKDB_API optional_ptr<SecretEntry> GetSecretByName(CatalogTransaction transaction, const string &name,
const string &storage = "");
DUCKDB_API unique_ptr<SecretEntry> GetSecretByName(CatalogTransaction transaction, const string &name,
const string &storage = "");
//! Delete a secret by name, optionally by providing the storage to drop from
DUCKDB_API void DropSecretByName(CatalogTransaction transaction, const string &name,
OnEntryNotFound on_entry_not_found,
SecretPersistType persist_type = SecretPersistType::DEFAULT,
const string &storage = "");
//! List all secrets from all secret storages
DUCKDB_API vector<reference<SecretEntry>> AllSecrets(CatalogTransaction transaction);
DUCKDB_API vector<SecretEntry> AllSecrets(CatalogTransaction transaction);

//! Secret Manager settings
DUCKDB_API virtual void SetEnablePersistentSecrets(bool enabled);
Expand All @@ -141,9 +158,9 @@ class SecretManager {
//! Lookup a CreateSecretFunction
optional_ptr<CreateSecretFunction> LookupFunctionInternal(const string &type, const string &provider);
//! Register a new Secret
optional_ptr<SecretEntry> RegisterSecretInternal(CatalogTransaction transaction,
unique_ptr<const BaseSecret> secret, OnCreateConflict on_conflict,
SecretPersistType persist_type, const string &storage = "");
unique_ptr<SecretEntry> RegisterSecretInternal(CatalogTransaction transaction, unique_ptr<const BaseSecret> secret,
OnCreateConflict on_conflict, SecretPersistType persist_type,
const string &storage = "");
//! Initialize the secret catalog_set and persistent secrets (lazily)
void InitializeSecrets(CatalogTransaction transaction);
//! Load a secret storage
Expand Down
20 changes: 10 additions & 10 deletions src/include/duckdb/main/secret/secret_storage.hpp
Expand Up @@ -38,19 +38,19 @@ class SecretStorage {
};

//! Store a secret
virtual optional_ptr<SecretEntry> StoreSecret(unique_ptr<const BaseSecret> secret, OnCreateConflict on_conflict,
optional_ptr<CatalogTransaction> transaction = nullptr) = 0;
virtual unique_ptr<SecretEntry> StoreSecret(unique_ptr<const BaseSecret> secret, OnCreateConflict on_conflict,
optional_ptr<CatalogTransaction> transaction = nullptr) = 0;
//! Get all secrets
virtual vector<reference<SecretEntry>> AllSecrets(optional_ptr<CatalogTransaction> transaction = nullptr) = 0;
virtual vector<SecretEntry> AllSecrets(optional_ptr<CatalogTransaction> transaction = nullptr) = 0;
//! Drop secret by name
virtual void DropSecretByName(const string &name, OnEntryNotFound on_entry_not_found,
optional_ptr<CatalogTransaction> transaction = nullptr) = 0;
//! Get best match
virtual SecretMatch LookupSecret(const string &path, const string &type,
optional_ptr<CatalogTransaction> transaction = nullptr) = 0;
//! Get a secret by name
virtual optional_ptr<SecretEntry> GetSecretByName(const string &name,
optional_ptr<CatalogTransaction> transaction = nullptr) = 0;
virtual unique_ptr<SecretEntry> GetSecretByName(const string &name,
optional_ptr<CatalogTransaction> transaction = nullptr) = 0;

//! Return the offset associated to this storage for tie-breaking secrets between storages
virtual int64_t GetTieBreakOffset() = 0;
Expand Down Expand Up @@ -103,15 +103,15 @@ class CatalogSetSecretStorage : public SecretStorage {
return storage_name;
};

virtual optional_ptr<SecretEntry> StoreSecret(unique_ptr<const BaseSecret> secret, OnCreateConflict on_conflict,
optional_ptr<CatalogTransaction> transaction = nullptr) override;
vector<reference<SecretEntry>> AllSecrets(optional_ptr<CatalogTransaction> transaction = nullptr) override;
virtual unique_ptr<SecretEntry> StoreSecret(unique_ptr<const BaseSecret> secret, OnCreateConflict on_conflict,
optional_ptr<CatalogTransaction> transaction = nullptr) override;
vector<SecretEntry> AllSecrets(optional_ptr<CatalogTransaction> transaction = nullptr) override;
void DropSecretByName(const string &name, OnEntryNotFound on_entry_not_found,
optional_ptr<CatalogTransaction> transaction = nullptr) override;
SecretMatch LookupSecret(const string &path, const string &type,
optional_ptr<CatalogTransaction> transaction = nullptr) override;
optional_ptr<SecretEntry> GetSecretByName(const string &name,
optional_ptr<CatalogTransaction> transaction = nullptr) override;
unique_ptr<SecretEntry> GetSecretByName(const string &name,
optional_ptr<CatalogTransaction> transaction = nullptr) override;

protected:
//! Callback called on Store to allow child classes to implement persistence.
Expand Down
37 changes: 18 additions & 19 deletions src/main/secret/secret_manager.cpp
Expand Up @@ -31,8 +31,8 @@ SecretCatalogEntry::SecretCatalogEntry(unique_ptr<const BaseSecret> secret_p, Ca
secret = make_uniq<SecretEntry>(std::move(secret_p));
}

const BaseSecret &SecretMatch::GetSecret() {
return *secret_entry.get()->secret;
const BaseSecret &SecretMatch::GetSecret() const {
return *secret_entry->secret;
}

constexpr const char *SecretManager::TEMPORARY_STORAGE_NAME;
Expand Down Expand Up @@ -121,18 +121,17 @@ void SecretManager::RegisterSecretFunction(CreateSecretFunction function, OnCrea
secret_functions.insert({function.secret_type, new_set});
}

optional_ptr<SecretEntry> SecretManager::RegisterSecret(CatalogTransaction transaction,
unique_ptr<const BaseSecret> secret,
OnCreateConflict on_conflict, SecretPersistType persist_type,
const string &storage) {
unique_ptr<SecretEntry> SecretManager::RegisterSecret(CatalogTransaction transaction,
unique_ptr<const BaseSecret> secret, OnCreateConflict on_conflict,
SecretPersistType persist_type, const string &storage) {
InitializeSecrets(transaction);
return RegisterSecretInternal(transaction, std::move(secret), on_conflict, persist_type, storage);
}

optional_ptr<SecretEntry> SecretManager::RegisterSecretInternal(CatalogTransaction transaction,
unique_ptr<const BaseSecret> secret,
OnCreateConflict on_conflict,
SecretPersistType persist_type, const string &storage) {
unique_ptr<SecretEntry> SecretManager::RegisterSecretInternal(CatalogTransaction transaction,
unique_ptr<const BaseSecret> secret,
OnCreateConflict on_conflict,
SecretPersistType persist_type, const string &storage) {
//! Ensure we only create secrets for known types;
LookupTypeInternal(secret->GetType());

Expand Down Expand Up @@ -207,7 +206,7 @@ optional_ptr<CreateSecretFunction> SecretManager::LookupFunctionInternal(const s
return nullptr;
}

optional_ptr<SecretEntry> SecretManager::CreateSecret(ClientContext &context, const CreateSecretInfo &info) {
unique_ptr<SecretEntry> SecretManager::CreateSecret(ClientContext &context, const CreateSecretInfo &info) {
// Note that a context is required for CreateSecret, as the CreateSecretFunction expects one
auto transaction = CatalogTransaction::GetSystemCatalogTransaction(context);
InitializeSecrets(transaction);
Expand Down Expand Up @@ -294,15 +293,15 @@ SecretMatch SecretManager::LookupSecret(CatalogTransaction transaction, const st
InitializeSecrets(transaction);

int64_t best_match_score = NumericLimits<int64_t>::Minimum();
optional_ptr<SecretEntry> best_match = nullptr;
unique_ptr<SecretEntry> best_match = nullptr;

for (const auto &storage_ref : GetSecretStorages()) {
if (!storage_ref.get().IncludeInLookups()) {
continue;
}
auto match = storage_ref.get().LookupSecret(path, type, &transaction);
if (match.HasMatch() && match.score > best_match_score) {
best_match = match.secret_entry.get();
best_match = std::move(match.secret_entry);
best_match_score = match.score;
}
}
Expand All @@ -314,11 +313,11 @@ SecretMatch SecretManager::LookupSecret(CatalogTransaction transaction, const st
return SecretMatch();
}

optional_ptr<SecretEntry> SecretManager::GetSecretByName(CatalogTransaction transaction, const string &name,
const string &storage) {
unique_ptr<SecretEntry> SecretManager::GetSecretByName(CatalogTransaction transaction, const string &name,
const string &storage) {
InitializeSecrets(transaction);

optional_ptr<SecretEntry> result;
unique_ptr<SecretEntry> result = nullptr;
bool found = false;

if (!storage.empty()) {
Expand All @@ -339,7 +338,7 @@ optional_ptr<SecretEntry> SecretManager::GetSecretByName(CatalogTransaction tran
"Ambiguity detected for secret name '%s', secret occurs in multiple storage backends.", name);
}

result = lookup;
result = std::move(lookup);
found = true;
}
}
Expand Down Expand Up @@ -428,10 +427,10 @@ SecretType SecretManager::LookupTypeInternal(const string &type) {
throw InvalidInputException("Secret type '%s' not found", type);
}

vector<reference<SecretEntry>> SecretManager::AllSecrets(CatalogTransaction transaction) {
vector<SecretEntry> SecretManager::AllSecrets(CatalogTransaction transaction) {
InitializeSecrets(transaction);

vector<reference<SecretEntry>> result;
vector<SecretEntry> result;

// Add results from all backends to the result set
for (const auto &backend : secret_storages) {
Expand Down
18 changes: 9 additions & 9 deletions src/main/secret/secret_storage.cpp
Expand Up @@ -40,9 +40,9 @@ SecretMatch SecretStorage::SelectBestMatch(SecretEntry &secret_entry, const stri
}
}

optional_ptr<SecretEntry> CatalogSetSecretStorage::StoreSecret(unique_ptr<const BaseSecret> secret,
OnCreateConflict on_conflict,
optional_ptr<CatalogTransaction> transaction) {
unique_ptr<SecretEntry> CatalogSetSecretStorage::StoreSecret(unique_ptr<const BaseSecret> secret,
OnCreateConflict on_conflict,
optional_ptr<CatalogTransaction> transaction) {
if (secrets->GetEntry(GetTransactionOrDefault(transaction), secret->GetName())) {
if (on_conflict == OnCreateConflict::ERROR_ON_CONFLICT) {
string persist_string = persistent ? "Persistent" : "Temporary";
Expand Down Expand Up @@ -71,11 +71,11 @@ optional_ptr<SecretEntry> CatalogSetSecretStorage::StoreSecret(unique_ptr<const

auto secret_catalog_entry =
&secrets->GetEntry(GetTransactionOrDefault(transaction), secret_name)->Cast<SecretCatalogEntry>();
return secret_catalog_entry->secret;
return make_uniq<SecretEntry>(*secret_catalog_entry->secret);
}

vector<reference<SecretEntry>> CatalogSetSecretStorage::AllSecrets(optional_ptr<CatalogTransaction> transaction) {
vector<reference<SecretEntry>> ret_value;
vector<SecretEntry> CatalogSetSecretStorage::AllSecrets(optional_ptr<CatalogTransaction> transaction) {
vector<SecretEntry> ret_value;
const std::function<void(CatalogEntry &)> callback = [&](CatalogEntry &entry) {
auto &cast_entry = entry.Cast<SecretCatalogEntry>();
ret_value.push_back(*cast_entry.secret);
Expand Down Expand Up @@ -117,13 +117,13 @@ SecretMatch CatalogSetSecretStorage::LookupSecret(const string &path, const stri
return SecretMatch();
}

optional_ptr<SecretEntry> CatalogSetSecretStorage::GetSecretByName(const string &name,
optional_ptr<CatalogTransaction> transaction) {
unique_ptr<SecretEntry> CatalogSetSecretStorage::GetSecretByName(const string &name,
optional_ptr<CatalogTransaction> transaction) {
auto res = secrets->GetEntry(GetTransactionOrDefault(transaction), name);

if (res) {
auto &cast_entry = res->Cast<SecretCatalogEntry>();
return cast_entry.secret;
return make_uniq<SecretEntry>(*cast_entry.secret);
}

return nullptr;
Expand Down