Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically Create Database if Not Present #49

Merged
merged 6 commits into from Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -24,6 +24,9 @@ class SqlDurabilityOptions
[JsonProperty("taskEventBatchSize")]
public int TaskEventBatchSize { get; set; } = 10;

[JsonProperty("createDatabaseIfNotExists")]
public bool CreateDatabaseIfNotExists { get; set; }

internal ILoggerFactory LoggerFactory { get; set; } = NullLoggerFactory.Instance;

internal SqlOrchestrationServiceSettings GetOrchestrationServiceSettings(
Expand Down Expand Up @@ -54,9 +57,10 @@ class SqlDurabilityOptions

var settings = new SqlOrchestrationServiceSettings(connectionString, this.TaskHubName)
{
CreateDatabaseIfNotExists = this.CreateDatabaseIfNotExists,
LoggerFactory = this.LoggerFactory,
WorkItemLockTimeout = this.TaskEventLockTimeout,
WorkItemBatchSize = this.TaskEventBatchSize,
WorkItemLockTimeout = this.TaskEventLockTimeout,
};

if (extensionOptions.MaxConcurrentActivityFunctions.HasValue)
Expand Down
19 changes: 18 additions & 1 deletion src/DurableTask.SqlServer/LogHelper.cs
Expand Up @@ -44,7 +44,7 @@ public void AcquiredAppLock(int statusCode, Stopwatch latencyStopwatch)
var logEvent = new LogEvents.AcquiredAppLockEvent(
statusCode,
latencyStopwatch.ElapsedMilliseconds);

this.WriteLog(logEvent);
}

Expand Down Expand Up @@ -112,6 +112,23 @@ public void PurgedInstances(string userId, int purgedInstanceCount, Stopwatch la
this.WriteLog(logEvent);
}

public void CommandCompleted(string commandText, Stopwatch latencyStopwatch, int retryCount, string? instanceId)
{
var logEvent = new LogEvents.CommandCompletedEvent(
commandText,
latencyStopwatch.ElapsedMilliseconds,
retryCount,
instanceId);

this.WriteLog(logEvent);
}

public void CreatedDatabase(string databaseName)
{
var logEvent = new LogEvents.CreatedDatabaseEvent(databaseName);
this.WriteLog(logEvent);
}

void WriteLog(ILogEvent logEvent)
{
// LogDurableEvent is an extension method defined in DurableTask.Core
Expand Down
34 changes: 34 additions & 0 deletions src/DurableTask.SqlServer/Logging/DefaultEventSource.cs
Expand Up @@ -232,5 +232,39 @@ unsafe void AcquiredAppLockCore(int eventId, int statusCode, long latencyMs, str
AppName,
ExtensionVersion);
}

[Event(EventIds.CommandCompleted, Level = EventLevel.Verbose)]
public void CommandCompleted(
string? InstanceId,
string CommandText,
long LatencyMs,
int RetryCount,
string AppName,
string ExtensionVersion)
{
// TODO: Switch to WriteEventCore for better performance
this.WriteEvent(
EventIds.CommandCompleted,
InstanceId ?? string.Empty,
CommandText,
LatencyMs,
RetryCount,
AppName,
ExtensionVersion);
}

[Event(EventIds.CreatedDatabase, Level = EventLevel.Informational)]
internal void CreatedDatabase(
string DatabaseName,
string AppName,
string ExtensionVersion)
{
// TODO: Use WriteEventCore for better performance
this.WriteEvent(
EventIds.CreatedDatabase,
DatabaseName,
AppName,
ExtensionVersion);
}
}
}
2 changes: 2 additions & 0 deletions src/DurableTask.SqlServer/Logging/EventIds.cs
Expand Up @@ -20,5 +20,7 @@ static class EventIds
public const int TransientDatabaseFailure = 308;
public const int ReplicaCountChangeRecommended = 309;
public const int PurgedInstances = 310;
public const int CommandCompleted = 311;
public const int CreatedDatabase = 312;
}
}
66 changes: 66 additions & 0 deletions src/DurableTask.SqlServer/Logging/LogEvents.cs
Expand Up @@ -415,5 +415,71 @@ public PurgedInstances(string userId, int purgedInstanceCount, long latencyMs)
DTUtils.AppName,
DTUtils.ExtensionVersionString);
}

internal class CommandCompletedEvent : StructuredLogEvent, IEventSourceEvent
{
public CommandCompletedEvent(string commandText, long latencyMs, int retryCount, string? instanceId)
{
this.CommandText = commandText;
this.LatencyMs = latencyMs;
this.RetryCount = retryCount;
this.InstanceId = instanceId;
}

[StructuredLogField]
public string CommandText { get; }

[StructuredLogField]
public long LatencyMs { get; }

[StructuredLogField]
public int RetryCount { get; }

[StructuredLogField]
public string? InstanceId { get; }

public override LogLevel Level => LogLevel.Debug;

public override EventId EventId => new EventId(
EventIds.CommandCompleted,
nameof(EventIds.CommandCompleted));

protected override string CreateLogMessage() =>
string.IsNullOrEmpty(this.InstanceId) ?
$"Executed SQL statement(s) '{this.CommandText}' in {this.LatencyMs}ms" :
$"{this.InstanceId}: Executed SQL statement(s) '{this.CommandText}' in {this.LatencyMs}ms";

void IEventSourceEvent.WriteEventSource() =>
DefaultEventSource.Log.CommandCompleted(
this.InstanceId,
this.CommandText,
this.LatencyMs,
this.RetryCount,
DTUtils.AppName,
DTUtils.ExtensionVersionString);
}

internal class CreatedDatabaseEvent : StructuredLogEvent, IEventSourceEvent
{
public CreatedDatabaseEvent(string databaseName) =>
this.DatabaseName = databaseName;

[StructuredLogField]
public string DatabaseName { get; }

public override EventId EventId => new EventId(
EventIds.CreatedDatabase,
nameof(EventIds.CreatedDatabase));

public override LogLevel Level => LogLevel.Information;

protected override string CreateLogMessage() => $"Created database '{this.DatabaseName}'.";

void IEventSourceEvent.WriteEventSource() =>
DefaultEventSource.Log.CreatedDatabase(
this.DatabaseName,
DTUtils.AppName,
DTUtils.ExtensionVersionString);
}
}
}
50 changes: 48 additions & 2 deletions src/DurableTask.SqlServer/SqlDbManager.cs
Expand Up @@ -31,7 +31,7 @@ public SqlDbManager(SqlOrchestrationServiceSettings settings, LogHelper traceHel
public async Task CreateOrUpgradeSchemaAsync(bool recreateIfExists)
{
// Prevent other create or delete operations from executing at the same time.
await using DatabaseLock dbLock = await this.AcquireDatabaseLockAsync();
await using DatabaseLock dbLock = await this.AcquireDatabaseLockAsync(this.settings.CreateDatabaseIfNotExists);

var currentSchemaVersion = new SemanticVersion(0, 0, 0);
if (recreateIfExists)
Expand Down Expand Up @@ -131,8 +131,13 @@ public async Task DeleteSchemaAsync()

Task DropSchemaAsync(DatabaseLock dbLock) => this.ExecuteSqlScriptAsync("drop-schema.sql", dbLock);

async Task<DatabaseLock> AcquireDatabaseLockAsync()
async Task<DatabaseLock> AcquireDatabaseLockAsync(bool createDatabaseIfNotExists = false)
{
if (createDatabaseIfNotExists)
{
await this.EnsureDatabaseExistsAsync();
}

SqlConnection connection = this.settings.CreateConnection();
await connection.OpenAsync();

Expand Down Expand Up @@ -171,6 +176,47 @@ async Task<DatabaseLock> AcquireDatabaseLockAsync()
return new DatabaseLock(connection, lockTransaction);
}

async Task EnsureDatabaseExistsAsync()
{
// Note that we may not be able to connect to the DB, let alone obtain the lock,
// if the database does not exist yet. So we obtain a connection to the 'master' database for now.
using SqlConnection connection = this.settings.CreateConnection("master");
await connection.OpenAsync();

if (!await this.DoesDatabaseExistAsync(this.settings.DatabaseName, connection))
{
await this.CreateDatabaseAsync(this.settings.DatabaseName, connection);
}
}

async Task<bool> DoesDatabaseExistAsync(string databaseName, SqlConnection connection)
{
using SqlCommand command = connection.CreateCommand();
command.CommandText = "SELECT 1 FROM sys.databases WHERE name = @databaseName";
command.Parameters.AddWithValue("@databaseName", databaseName);

bool exists = (int?)await SqlUtils.ExecuteScalarAsync(command, this.traceHelper) == 1;
return exists;
}

async Task<bool> CreateDatabaseAsync(string databaseName, SqlConnection connection)
{
using SqlCommand command = connection.CreateCommand();
command.CommandText = $"CREATE DATABASE {SqlIdentifier.Escape(databaseName)} COLLATE Latin1_General_100_BIN2_UTF8";

try
{
await SqlUtils.ExecuteNonQueryAsync(command, this.traceHelper);
this.traceHelper.CreatedDatabase(databaseName);
cgillum marked this conversation as resolved.
Show resolved Hide resolved
return true;
}
catch (SqlException e) when (e.Number == 1801 /* Database already exists */)
{
// Ignore
return false;
}
}

async Task ExecuteSqlScriptAsync(string scriptName, DatabaseLock dbLock)
{
// We don't actually use the lock here, but want to make sure the caller is holding it.
Expand Down
40 changes: 40 additions & 0 deletions src/DurableTask.SqlServer/SqlIdentifier.cs
@@ -0,0 +1,40 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

namespace DurableTask.SqlServer
{
using System;
using System.Text;

static class SqlIdentifier
{
public static string Escape(string value)
{
if (value == null)
{
throw new ArgumentNullException(nameof(value));
}

if (value == "")
{
throw new ArgumentException("Value cannot be empty.", nameof(value));
}

StringBuilder builder = new StringBuilder();

builder.Append('[');
foreach (char c in value)
{
if (c == ']')
{
builder.Append(']');
}

builder.Append(c);
}
builder.Append(']');

return builder.ToString();
}
}
}
40 changes: 40 additions & 0 deletions src/DurableTask.SqlServer/SqlOrchestrationServiceSettings.cs
Expand Up @@ -35,6 +35,12 @@ public SqlOrchestrationServiceSettings(string connectionString, string? taskHubN
ApplicationName = this.TaskHubName,
};

if (string.IsNullOrEmpty(builder.InitialCatalog))
{
throw new ArgumentException("Database or Initial Catalog must be specified in the connection string.", nameof(connectionString));
}

this.DatabaseName = builder.InitialCatalog;
this.TaskHubConnectionString = builder.ToString();
}

Expand Down Expand Up @@ -79,6 +85,15 @@ public SqlOrchestrationServiceSettings(string connectionString, string? taskHubN
[JsonProperty("maxActiveOrchestrations")]
public int MaxActiveOrchestrations { get; set; } = Environment.ProcessorCount;

/// <summary>
/// Gets or sets a flag indicating whether the database should be automatically created if it does not exist.
/// </summary>
/// <remarks>
/// If <see langword="true"/>, the user requires the permission <c>CREATE DATABASE</c>.
/// </remarks>
[JsonProperty("createDatabaseIfNotExists")]
public bool CreateDatabaseIfNotExists { get; set; }
cgillum marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Gets a SQL connection string associated with the configured task hub.
/// </summary>
Expand All @@ -91,6 +106,31 @@ public SqlOrchestrationServiceSettings(string connectionString, string? taskHubN
[JsonIgnore]
public ILoggerFactory LoggerFactory { get; set; } = NullLoggerFactory.Instance;

/// <summary>
/// Gets or sets the name of the database that contains the instance store.
/// </summary>
/// <remarks>
/// This value is derived from the value of the <c>"Initial Catalog"</c> or <c>"Database"</c>
/// attribute in the <see cref="TaskHubConnectionString"/>.
/// </remarks>
[JsonIgnore]
public string DatabaseName { get; set; }

internal SqlConnection CreateConnection() => new SqlConnection(this.TaskHubConnectionString);

internal SqlConnection CreateConnection(string databaseName)
{
if (databaseName == this.DatabaseName)
cgillum marked this conversation as resolved.
Show resolved Hide resolved
{
return this.CreateConnection();
}

SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(this.TaskHubConnectionString)
{
InitialCatalog = databaseName
};

return new SqlConnection(builder.ToString());
}
}
}
25 changes: 23 additions & 2 deletions src/DurableTask.SqlServer/SqlUtils.cs
Expand Up @@ -52,7 +52,7 @@ public static HistoryEvent GetHistoryEvent(this DbDataReader reader, bool isOrch
int eventId = GetTaskId(reader);

HistoryEvent historyEvent;
switch(eventType)
switch (eventType)
{
case EventType.ContinueAsNew:
historyEvent = new ContinueAsNewEvent(eventId, GetPayloadText(reader));
Expand Down Expand Up @@ -396,6 +396,19 @@ static IEnumerable<SqlDataRecord> GetInstanceIdRecords(IEnumerable<string> insta
cmd => cmd.ExecuteNonQueryAsync(cancellationToken));
}

public static Task<object> ExecuteScalarAsync(
DbCommand command,
LogHelper traceHelper,
string? instanceId = null,
CancellationToken cancellationToken = default)
{
return ExecuteSprocAndTraceAsync(
command,
traceHelper,
instanceId,
cmd => cmd.ExecuteScalarAsync(cancellationToken));
}

static async Task<T> ExecuteSprocAndTraceAsync<T>(
DbCommand command,
LogHelper traceHelper,
Expand All @@ -410,7 +423,15 @@ static IEnumerable<SqlDataRecord> GetInstanceIdRecords(IEnumerable<string> insta
finally
{
context.LatencyStopwatch.Stop();
traceHelper.SprocCompleted(command.CommandText, context.LatencyStopwatch, context.RetryCount, instanceId);
switch (command.CommandType)
{
case CommandType.StoredProcedure:
traceHelper.SprocCompleted(command.CommandText, context.LatencyStopwatch, context.RetryCount, instanceId);
break;
default:
traceHelper.CommandCompleted(command.CommandText, context.LatencyStopwatch, context.RetryCount, instanceId);
break;
}
}
}

Expand Down