Fix serialization tests & don't block thread #846

Merged
merged 4 commits into from
Aug 9, 2019
5 changes: 5 additions & 0 deletions src/Microsoft.SqlTools.Hosting/Hosting/Contracts/Error.cs
@@ -20,5 +20,10 @@ public class Error
/// Error message
/// </summary>
public string Message { get; set; }

public override string ToString()
{
return $"Error(Code={Code},Message='{Message}')";
}
}
}
@@ -14,6 +14,7 @@
using Microsoft.SqlTools.Hosting.Protocol;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.Contracts;
using Microsoft.SqlTools.ServiceLayer.QueryExecution.DataStorage;
using Microsoft.SqlTools.ServiceLayer.Utility;
using Microsoft.SqlTools.Utility;


@@ -40,93 +41,111 @@ public override void InitializeService(IProtocolEndpoint serviceHost)
/// <summary>
/// Begin to process request to save a resultSet to a file in CSV format
/// </summary>
internal async Task HandleSerializeStartRequest(SerializeDataStartRequestParams serializeParams,
internal Task HandleSerializeStartRequest(SerializeDataStartRequestParams serializeParams,
RequestContext<SerializeDataResult> requestContext)
{
// Run in separate thread so that message thread isn't held up by a potentially time consuming file write
Task.Run(async () => {
await RunSerializeStartRequest(serializeParams, requestContext);
}).ContinueWithOnFaulted(async t => await SendErrorAndCleanup(serializeParams?.FilePath, requestContext, t.Exception));
Contributor

A potential problem: since this is now truly async, we return immediately, at which point the caller may send a continue-serialization request. But the start request may not have started or finished yet, so we have a race condition if the continue call starts doing work on the assumption that the start request has finished.

Same issue with the continue calls - those could then get out of order or start causing problems trying to write at the same time

Contributor Author

So, this is dependent on the caller waiting on responses. That's what I've implemented in ADS - we wait on the result to come back, then send the next request. Since each one is awaited, we will get correct ordering. Right now I would treat this as "by design" because:

  • the cost to implement this differently would be a lot higher, since you would then need an ordered queue model to track requests
  • most importantly since we're sending potentially large amounts of data, all of which are kept in memory until serialized, I explicitly do not want to encourage "send 5 messages in a row" and instead require "send message, wait until processed, send next message".

Contributor

Oh right, yeah, my mistake. I confused the return value here with the indication that the request was done. Never mind!

Contributor

I suggest leaving a comment in the code about this stated 'by design' contract.
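
For reference, the "send message, wait until processed, send next message" contract described above could look roughly like this on the caller side. This is only a minimal sketch, not the actual ADS code: `SerializeRequest`, `SerializeResult`, `SendRequest`, and `serializeInBatches` are hypothetical names introduced for illustration, and the real request and response shapes may differ.

```typescript
// Hypothetical shapes for illustration only; not the real ADS/SqlTools contracts.
interface SerializeRequest {
    filePath: string;
    rows: string[][];
    isLastBatch: boolean;
}

interface SerializeResult {
    succeeded: boolean;
    messages?: string;
}

// Hypothetical transport: the promise resolves once the service has replied.
type SendRequest = (
    method: "start" | "continue",
    params: SerializeRequest
) => Promise<SerializeResult>;

async function serializeInBatches(
    send: SendRequest,
    filePath: string,
    batches: string[][][]
): Promise<SerializeResult> {
    // With no batches, nothing is sent and this default is returned.
    let last: SerializeResult = { succeeded: true };
    for (let i = 0; i < batches.length; i++) {
        const params: SerializeRequest = {
            filePath,
            rows: batches[i],
            isLastBatch: i === batches.length - 1,
        };
        // Await each response before sending the next request, so at most one
        // batch is in flight and the service processes batches in order.
        last = await send(i === 0 ? "start" : "continue", params);
        if (!last.succeeded) {
            return last; // stop at the first failure instead of sending more data
        }
    }
    return last;
}
```

Because each request is awaited, ordering comes from the caller rather than from a queue in the service, and only one batch of rows needs to be held by the service at a time.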

return Task.CompletedTask;
}

internal async Task RunSerializeStartRequest(SerializeDataStartRequestParams serializeParams, RequestContext<SerializeDataResult> requestContext)
{
try
{
// Verify we have sensible inputs and there isn't a task running for this file already
Validate.IsNotNull(nameof(serializeParams), serializeParams);
Validate.IsNotNullOrWhitespaceString("FilePath", serializeParams.FilePath);

DataSerializer serializer = null;
bool hasSerializer = inProgressSerializations.TryGetValue(serializeParams.FilePath, out serializer);
if (hasSerializer)
if (inProgressSerializations.TryGetValue(serializeParams.FilePath, out serializer))
{
// Cannot proceed as there is an in progress serialization happening
throw new Exception(SR.SerializationServiceRequestInProgress(serializeParams.FilePath));
}


// Create a new serializer, save for future calls if needed, and write the request out
serializer = new DataSerializer(serializeParams);
if (!serializeParams.IsLastBatch)
{
inProgressSerializations.AddOrUpdate(serializer.FilePath, serializer, (key, old) => serializer);
}
Func<Task<SerializeDataResult>> writeData = () =>
{
return Task.Factory.StartNew(() =>
{
var result = serializer.ProcessRequest(serializeParams);
return result;
});
};
await HandleRequest(writeData, requestContext, "HandleSerializeStartRequest");

Logger.Write(TraceEventType.Verbose, "HandleSerializeStartRequest");
SerializeDataResult result = serializer.ProcessRequest(serializeParams);
await requestContext.SendResult(result);
}
catch (Exception ex)
{
await requestContext.SendError(ex.Message);
await SendErrorAndCleanup(serializeParams.FilePath, requestContext, ex);
}
}

private async Task SendErrorAndCleanup(string filePath, RequestContext<SerializeDataResult> requestContext, Exception ex)
{
if (filePath != null)
{
try
{
DataSerializer removed;
inProgressSerializations.TryRemove(filePath, out removed);
if (removed != null)
{
// Flush any contents to disk and remove the writer
removed.CloseStreams();
}
}
catch
{
// Do not care if there was an error removing this, must always delete if something failed
Contributor

"must always delete if something failed"

You're not actually deleting the file though - is that intentional?

Contributor Author

Great question. I have a gap here for sure. On consideration, I will ensure the write stream is closed out in this case and intentionally leave file contents that have been written as-is. The thinking is that for most formats (CSV, JSON) users are going to be happier with half the written contents on an unexpected error than none.

}
}
await requestContext.SendError(ex.Message);
}

/// <summary>
/// Process request to save a resultSet to a file in CSV format
/// </summary>
internal async Task HandleSerializeContinueRequest(SerializeDataContinueRequestParams serializeParams,
internal Task HandleSerializeContinueRequest(SerializeDataContinueRequestParams serializeParams,
RequestContext<SerializeDataResult> requestContext)
{
// Run in separate thread so that message thread isn't held up by a potentially time consuming file write
Task.Run(async () =>
{
await RunSerializeContinueRequest(serializeParams, requestContext);
}).ContinueWithOnFaulted(async t => await SendErrorAndCleanup(serializeParams?.FilePath, requestContext, t.Exception));
return Task.CompletedTask;
}

internal async Task RunSerializeContinueRequest(SerializeDataContinueRequestParams serializeParams, RequestContext<SerializeDataResult> requestContext)
{
try
{
// Verify we have sensible inputs and some data has already been sent for the file
Validate.IsNotNull(nameof(serializeParams), serializeParams);
Validate.IsNotNullOrWhitespaceString("FilePath", serializeParams.FilePath);

DataSerializer serializer = null;
bool hasSerializer = inProgressSerializations.TryGetValue(serializeParams.FilePath, out serializer);
if (!hasSerializer)
if (!inProgressSerializations.TryGetValue(serializeParams.FilePath, out serializer))
{
throw new Exception(SR.SerializationServiceRequestNotFound(serializeParams.FilePath));
}

Func<Task<SerializeDataResult>> writeData = () =>
{
return Task.Factory.StartNew(() =>
{
var result = serializer.ProcessRequest(serializeParams);
if (serializeParams.IsLastBatch)
{
// Cleanup the serializer
this.inProgressSerializations.TryRemove(serializer.FilePath, out serializer);
}
return result;
});
};
await HandleRequest(writeData, requestContext, "HandleSerializeContinueRequest");
}
catch (Exception ex)
{
await requestContext.SendError(ex.Message);
}
}

private async Task HandleRequest<T>(Func<Task<T>> handler, RequestContext<T> requestContext, string requestType)
{
Logger.Write(TraceEventType.Verbose, requestType);

try
{
T result = await handler();
// Write to file and cleanup if needed
Logger.Write(TraceEventType.Verbose, "HandleSerializeContinueRequest");
SerializeDataResult result = serializer.ProcessRequest(serializeParams);
if (serializeParams.IsLastBatch)
{
// Cleanup the serializer
this.inProgressSerializations.TryRemove(serializer.FilePath, out serializer);
}
await requestContext.SendResult(result);
}
catch (Exception ex)
{
await requestContext.SendError(ex.Message);
await SendErrorAndCleanup(serializeParams.FilePath, requestContext, ex);
}
}
}
@@ -242,9 +261,13 @@ private void EnsureWriterCreated()
this.writer = factory.GetWriter(requestParams.FilePath);
}
}
private void CloseStreams()
public void CloseStreams()
{
this.writer.Dispose();
if (this.writer != null)
{
this.writer.Dispose();
this.writer = null;
}
}

private SaveResultsAsJsonRequestParams CreateJsonRequestParams()
@@ -153,7 +153,8 @@ public void Validate()
ReceivedEvent received = ReceivedEvents[i];

// Step 1) Make sure the event type matches
Assert.Equal(expected.EventType, received.EventType);
Assert.True(expected.EventType.Equals(received.EventType),
string.Format("Expected EventType {0} but got {1}. Received object is {2}", expected.EventType, received.EventType, received.EventObject.ToString()));

// Step 2) Make sure the param type matches
Assert.True( expected.ParamType == received.EventObject.GetType()
@@ -92,7 +92,7 @@ private async Task TestSaveAsCsvSuccess(bool includeHeaders)
.AddStandardResultValidator()
.Complete();

await SerializationService.HandleSerializeStartRequest(saveParams, efv.Object);
await SerializationService.RunSerializeStartRequest(saveParams, efv.Object);

// Then:
// ... There should not have been an error
@@ -189,8 +189,8 @@ private async Task SendAndVerifySerializeStartRequest(SerializeDataStartRequestP
.AddStandardResultValidator()
.Complete();

await SerializationService.HandleSerializeStartRequest(request1, efv.Object);

await SerializationService.RunSerializeStartRequest(request1, efv.Object);
// Then:
// ... There should not have been an error
efv.Validate();
@@ -202,7 +202,7 @@ private async Task SendAndVerifySerializeContinueRequest(SerializeDataContinueRe
.AddStandardResultValidator()
.Complete();

await SerializationService.HandleSerializeContinueRequest(request1, efv.Object);
await SerializationService.RunSerializeContinueRequest(request1, efv.Object);

// Then:
// ... There should not have been an error
@@ -260,10 +260,10 @@ private static string GetCsvPrintValue(DbCellValue d)
private static void AssertLineEquals(string line, string[] expected)
{
var actual = line.Split(',');
Assert.True(actual.Length == expected.Length, string.Format("Line '{0}' does not match values {1}", line, string.Join(",", expected)));
Assert.True(actual.Length == expected.Length, $"Line '{line}' does not match values {string.Join(",", expected)}");
for (int i = 0; i < actual.Length; i++)
{
Assert.True(expected[i] == actual[i], string.Format("Line '{0}' does not match values '{1}' as '{2}' does not equal '{3}'", line, string.Join(",", expected), expected[i], actual[i]));
Assert.True(expected[i] == actual[i], $"Line '{line}' does not match values '{string.Join(",", expected)}' as '{expected[i]}' does not equal '{actual[i]}'");
}
}
