Skip to content

Commit

Permalink
Automatically Create Database if Not Present (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsugarman committed Oct 28, 2021
1 parent 940a17e commit 215befc
Show file tree
Hide file tree
Showing 15 changed files with 525 additions and 109 deletions.
Expand Up @@ -24,6 +24,9 @@ class SqlDurabilityOptions
[JsonProperty("taskEventBatchSize")]
public int TaskEventBatchSize { get; set; } = 10;

[JsonProperty("createDatabaseIfNotExists")]
public bool CreateDatabaseIfNotExists { get; set; }

internal ILoggerFactory LoggerFactory { get; set; } = NullLoggerFactory.Instance;

internal SqlOrchestrationServiceSettings GetOrchestrationServiceSettings(
Expand Down Expand Up @@ -54,9 +57,10 @@ class SqlDurabilityOptions

var settings = new SqlOrchestrationServiceSettings(connectionString, this.TaskHubName)
{
CreateDatabaseIfNotExists = this.CreateDatabaseIfNotExists,
LoggerFactory = this.LoggerFactory,
WorkItemLockTimeout = this.TaskEventLockTimeout,
WorkItemBatchSize = this.TaskEventBatchSize,
WorkItemLockTimeout = this.TaskEventLockTimeout,
};

if (extensionOptions.MaxConcurrentActivityFunctions.HasValue)
Expand Down
19 changes: 18 additions & 1 deletion src/DurableTask.SqlServer/LogHelper.cs
Expand Up @@ -44,7 +44,7 @@ public void AcquiredAppLock(int statusCode, Stopwatch latencyStopwatch)
var logEvent = new LogEvents.AcquiredAppLockEvent(
statusCode,
latencyStopwatch.ElapsedMilliseconds);

this.WriteLog(logEvent);
}

Expand Down Expand Up @@ -112,6 +112,23 @@ public void PurgedInstances(string userId, int purgedInstanceCount, Stopwatch la
this.WriteLog(logEvent);
}

public void CommandCompleted(string commandText, Stopwatch latencyStopwatch, int retryCount, string? instanceId)
{
var logEvent = new LogEvents.CommandCompletedEvent(
commandText,
latencyStopwatch.ElapsedMilliseconds,
retryCount,
instanceId);

this.WriteLog(logEvent);
}

public void CreatedDatabase(string databaseName)
{
var logEvent = new LogEvents.CreatedDatabaseEvent(databaseName);
this.WriteLog(logEvent);
}

void WriteLog(ILogEvent logEvent)
{
// LogDurableEvent is an extension method defined in DurableTask.Core
Expand Down
34 changes: 34 additions & 0 deletions src/DurableTask.SqlServer/Logging/DefaultEventSource.cs
Expand Up @@ -232,5 +232,39 @@ unsafe void AcquiredAppLockCore(int eventId, int statusCode, long latencyMs, str
AppName,
ExtensionVersion);
}

[Event(EventIds.CommandCompleted, Level = EventLevel.Verbose)]
public void CommandCompleted(
string? InstanceId,
string CommandText,
long LatencyMs,
int RetryCount,
string AppName,
string ExtensionVersion)
{
// TODO: Switch to WriteEventCore for better performance
this.WriteEvent(
EventIds.CommandCompleted,
InstanceId ?? string.Empty,
CommandText,
LatencyMs,
RetryCount,
AppName,
ExtensionVersion);
}

[Event(EventIds.CreatedDatabase, Level = EventLevel.Informational)]
internal void CreatedDatabase(
string DatabaseName,
string AppName,
string ExtensionVersion)
{
// TODO: Use WriteEventCore for better performance
this.WriteEvent(
EventIds.CreatedDatabase,
DatabaseName,
AppName,
ExtensionVersion);
}
}
}
2 changes: 2 additions & 0 deletions src/DurableTask.SqlServer/Logging/EventIds.cs
Expand Up @@ -20,5 +20,7 @@ static class EventIds
public const int TransientDatabaseFailure = 308;
public const int ReplicaCountChangeRecommended = 309;
public const int PurgedInstances = 310;
public const int CommandCompleted = 311;
public const int CreatedDatabase = 312;
}
}
66 changes: 66 additions & 0 deletions src/DurableTask.SqlServer/Logging/LogEvents.cs
Expand Up @@ -415,5 +415,71 @@ public PurgedInstances(string userId, int purgedInstanceCount, long latencyMs)
DTUtils.AppName,
DTUtils.ExtensionVersionString);
}

internal class CommandCompletedEvent : StructuredLogEvent, IEventSourceEvent
{
public CommandCompletedEvent(string commandText, long latencyMs, int retryCount, string? instanceId)
{
this.CommandText = commandText;
this.LatencyMs = latencyMs;
this.RetryCount = retryCount;
this.InstanceId = instanceId;
}

[StructuredLogField]
public string CommandText { get; }

[StructuredLogField]
public long LatencyMs { get; }

[StructuredLogField]
public int RetryCount { get; }

[StructuredLogField]
public string? InstanceId { get; }

public override LogLevel Level => LogLevel.Debug;

public override EventId EventId => new EventId(
EventIds.CommandCompleted,
nameof(EventIds.CommandCompleted));

protected override string CreateLogMessage() =>
string.IsNullOrEmpty(this.InstanceId) ?
$"Executed SQL statement(s) '{this.CommandText}' in {this.LatencyMs}ms" :
$"{this.InstanceId}: Executed SQL statement(s) '{this.CommandText}' in {this.LatencyMs}ms";

void IEventSourceEvent.WriteEventSource() =>
DefaultEventSource.Log.CommandCompleted(
this.InstanceId,
this.CommandText,
this.LatencyMs,
this.RetryCount,
DTUtils.AppName,
DTUtils.ExtensionVersionString);
}

internal class CreatedDatabaseEvent : StructuredLogEvent, IEventSourceEvent
{
public CreatedDatabaseEvent(string databaseName) =>
this.DatabaseName = databaseName;

[StructuredLogField]
public string DatabaseName { get; }

public override EventId EventId => new EventId(
EventIds.CreatedDatabase,
nameof(EventIds.CreatedDatabase));

public override LogLevel Level => LogLevel.Information;

protected override string CreateLogMessage() => $"Created database '{this.DatabaseName}'.";

void IEventSourceEvent.WriteEventSource() =>
DefaultEventSource.Log.CreatedDatabase(
this.DatabaseName,
DTUtils.AppName,
DTUtils.ExtensionVersionString);
}
}
}
50 changes: 48 additions & 2 deletions src/DurableTask.SqlServer/SqlDbManager.cs
Expand Up @@ -31,7 +31,7 @@ public SqlDbManager(SqlOrchestrationServiceSettings settings, LogHelper traceHel
public async Task CreateOrUpgradeSchemaAsync(bool recreateIfExists)
{
// Prevent other create or delete operations from executing at the same time.
await using DatabaseLock dbLock = await this.AcquireDatabaseLockAsync();
await using DatabaseLock dbLock = await this.AcquireDatabaseLockAsync(this.settings.CreateDatabaseIfNotExists);

var currentSchemaVersion = new SemanticVersion(0, 0, 0);
if (recreateIfExists)
Expand Down Expand Up @@ -131,8 +131,13 @@ public async Task DeleteSchemaAsync()

Task DropSchemaAsync(DatabaseLock dbLock) => this.ExecuteSqlScriptAsync("drop-schema.sql", dbLock);

async Task<DatabaseLock> AcquireDatabaseLockAsync()
async Task<DatabaseLock> AcquireDatabaseLockAsync(bool createDatabaseIfNotExists = false)
{
if (createDatabaseIfNotExists)
{
await this.EnsureDatabaseExistsAsync();
}

SqlConnection connection = this.settings.CreateConnection();
await connection.OpenAsync();

Expand Down Expand Up @@ -171,6 +176,47 @@ async Task<DatabaseLock> AcquireDatabaseLockAsync()
return new DatabaseLock(connection, lockTransaction);
}

async Task EnsureDatabaseExistsAsync()
{
// Note that we may not be able to connect to the DB, let alone obtain the lock,
// if the database does not exist yet. So we obtain a connection to the 'master' database for now.
using SqlConnection connection = this.settings.CreateConnection("master");
await connection.OpenAsync();

if (!await this.DoesDatabaseExistAsync(this.settings.DatabaseName, connection))
{
await this.CreateDatabaseAsync(this.settings.DatabaseName, connection);
}
}

async Task<bool> DoesDatabaseExistAsync(string databaseName, SqlConnection connection)
{
using SqlCommand command = connection.CreateCommand();
command.CommandText = "SELECT 1 FROM sys.databases WHERE name = @databaseName";
command.Parameters.AddWithValue("@databaseName", databaseName);

bool exists = (int?)await SqlUtils.ExecuteScalarAsync(command, this.traceHelper) == 1;
return exists;
}

async Task<bool> CreateDatabaseAsync(string databaseName, SqlConnection connection)
{
using SqlCommand command = connection.CreateCommand();
command.CommandText = $"CREATE DATABASE {SqlIdentifier.Escape(databaseName)} COLLATE Latin1_General_100_BIN2_UTF8";

try
{
await SqlUtils.ExecuteNonQueryAsync(command, this.traceHelper);
this.traceHelper.CreatedDatabase(databaseName);
return true;
}
catch (SqlException e) when (e.Number == 1801 /* Database already exists */)
{
// Ignore
return false;
}
}

async Task ExecuteSqlScriptAsync(string scriptName, DatabaseLock dbLock)
{
// We don't actually use the lock here, but want to make sure the caller is holding it.
Expand Down
40 changes: 40 additions & 0 deletions src/DurableTask.SqlServer/SqlIdentifier.cs
@@ -0,0 +1,40 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See LICENSE in the project root for license information.

namespace DurableTask.SqlServer
{
using System;
using System.Text;

static class SqlIdentifier
{
public static string Escape(string value)
{
if (value == null)
{
throw new ArgumentNullException(nameof(value));
}

if (value == "")
{
throw new ArgumentException("Value cannot be empty.", nameof(value));
}

StringBuilder builder = new StringBuilder();

builder.Append('[');
foreach (char c in value)
{
if (c == ']')
{
builder.Append(']');
}

builder.Append(c);
}
builder.Append(']');

return builder.ToString();
}
}
}
40 changes: 40 additions & 0 deletions src/DurableTask.SqlServer/SqlOrchestrationServiceSettings.cs
Expand Up @@ -35,6 +35,12 @@ public SqlOrchestrationServiceSettings(string connectionString, string? taskHubN
ApplicationName = this.TaskHubName,
};

if (string.IsNullOrEmpty(builder.InitialCatalog))
{
throw new ArgumentException("Database or Initial Catalog must be specified in the connection string.", nameof(connectionString));
}

this.DatabaseName = builder.InitialCatalog;
this.TaskHubConnectionString = builder.ToString();
}

Expand Down Expand Up @@ -79,6 +85,15 @@ public SqlOrchestrationServiceSettings(string connectionString, string? taskHubN
[JsonProperty("maxActiveOrchestrations")]
public int MaxActiveOrchestrations { get; set; } = Environment.ProcessorCount;

/// <summary>
/// Gets or sets a flag indicating whether the database should be automatically created if it does not exist.
/// </summary>
/// <remarks>
/// If <see langword="true"/>, the user requires the permission <c>CREATE DATABASE</c>.
/// </remarks>
[JsonProperty("createDatabaseIfNotExists")]
public bool CreateDatabaseIfNotExists { get; set; }

/// <summary>
/// Gets a SQL connection string associated with the configured task hub.
/// </summary>
Expand All @@ -91,6 +106,31 @@ public SqlOrchestrationServiceSettings(string connectionString, string? taskHubN
[JsonIgnore]
public ILoggerFactory LoggerFactory { get; set; } = NullLoggerFactory.Instance;

/// <summary>
/// Gets or sets the name of the database that contains the instance store.
/// </summary>
/// <remarks>
/// This value is derived from the value of the <c>"Initial Catalog"</c> or <c>"Database"</c>
/// attribute in the <see cref="TaskHubConnectionString"/>.
/// </remarks>
[JsonIgnore]
public string DatabaseName { get; set; }

internal SqlConnection CreateConnection() => new SqlConnection(this.TaskHubConnectionString);

internal SqlConnection CreateConnection(string databaseName)
{
if (databaseName == this.DatabaseName)
{
return this.CreateConnection();
}

SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(this.TaskHubConnectionString)
{
InitialCatalog = databaseName
};

return new SqlConnection(builder.ToString());
}
}
}
25 changes: 23 additions & 2 deletions src/DurableTask.SqlServer/SqlUtils.cs
Expand Up @@ -52,7 +52,7 @@ public static HistoryEvent GetHistoryEvent(this DbDataReader reader, bool isOrch
int eventId = GetTaskId(reader);

HistoryEvent historyEvent;
switch(eventType)
switch (eventType)
{
case EventType.ContinueAsNew:
historyEvent = new ContinueAsNewEvent(eventId, GetPayloadText(reader));
Expand Down Expand Up @@ -396,6 +396,19 @@ static IEnumerable<SqlDataRecord> GetInstanceIdRecords(IEnumerable<string> insta
cmd => cmd.ExecuteNonQueryAsync(cancellationToken));
}

public static Task<object> ExecuteScalarAsync(
DbCommand command,
LogHelper traceHelper,
string? instanceId = null,
CancellationToken cancellationToken = default)
{
return ExecuteSprocAndTraceAsync(
command,
traceHelper,
instanceId,
cmd => cmd.ExecuteScalarAsync(cancellationToken));
}

static async Task<T> ExecuteSprocAndTraceAsync<T>(
DbCommand command,
LogHelper traceHelper,
Expand All @@ -410,7 +423,15 @@ static IEnumerable<SqlDataRecord> GetInstanceIdRecords(IEnumerable<string> insta
finally
{
context.LatencyStopwatch.Stop();
traceHelper.SprocCompleted(command.CommandText, context.LatencyStopwatch, context.RetryCount, instanceId);
switch (command.CommandType)
{
case CommandType.StoredProcedure:
traceHelper.SprocCompleted(command.CommandText, context.LatencyStopwatch, context.RetryCount, instanceId);
break;
default:
traceHelper.CommandCompleted(command.CommandText, context.LatencyStopwatch, context.RetryCount, instanceId);
break;
}
}
}

Expand Down

0 comments on commit 215befc

Please sign in to comment.