Skip to content

Fix crash when running multiple filter smudge or clean operations concurrently #1260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 18, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions LibGit2Sharp.Tests/FilterFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.IO;
using LibGit2Sharp.Tests.TestHelpers;
using Xunit;
using System.Threading.Tasks;

namespace LibGit2Sharp.Tests
{
Expand Down Expand Up @@ -172,6 +173,49 @@ public void CleanFilterWritesOutputToObjectTree()
}
}

[Fact]
public void CanHandleMultipleSmudgesConcurrently()
{
const string decodedInput = "This is a substitution cipher";
const string encodedInput = "Guvf vf n fhofgvghgvba pvcure";

const string branchName = "branch";

Action<Stream, Stream> smudgeCallback = SubstitutionCipherFilter.RotateByThirteenPlaces;

var filter = new FakeFilter(FilterName, attributes, null, smudgeCallback);
var registration = GlobalSettings.RegisterFilter(filter);

try
{
int count = 30;
var tasks = new Task<FileInfo>[count];

for (int i = 0; i < count; i++)
{
tasks[i] = Task.Factory.StartNew(() =>
{
string repoPath = InitNewRepository();
return CheckoutFileForSmudge(repoPath, branchName, encodedInput);
});
}

Task.WaitAll(tasks);

foreach(var task in tasks)
{
FileInfo expectedFile = task.Result;

string readAllText = File.ReadAllText(expectedFile.FullName);
Assert.Equal(decodedInput, readAllText);
}
}
finally
{
GlobalSettings.DeregisterFilter(registration);
}
}

[Fact]
public void WhenCheckingOutAFileFileSmudgeWritesCorrectFileToWorkingDirectory()
{
Expand Down
122 changes: 87 additions & 35 deletions LibGit2Sharp/Filter.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -47,18 +50,34 @@ protected Filter(string name, IEnumerable<FilterAttributeEntry> attributes)
~Filter()
{
GlobalSettings.DeregisterFilter(this);

#if LEAKS_IDENTIFYING
int activeStreamCount = activeStreams.Count;
if (activeStreamCount > 0)
{
Trace.WriteLine(string.Format(CultureInfo.InvariantCulture, "{0} leaked {1} stream handles at finalization", GetType().Name, activeStreamCount));
}
#endif
}

private readonly string name;
private readonly IEnumerable<FilterAttributeEntry> attributes;
private readonly GitFilter gitFilter;
private readonly ConcurrentDictionary<IntPtr, StreamState> activeStreams = new ConcurrentDictionary<IntPtr, StreamState>();

private GitWriteStream thisStream;
private GitWriteStream nextStream;
private IntPtr thisPtr;
private IntPtr nextPtr;
private FilterSource filterSource;
private Stream output;
/// <summary>
/// State bag used to keep necessary reference from being
/// garbage collected during filter processing.
/// </summary>
private class StreamState
{
public GitWriteStream thisStream;
public GitWriteStream nextStream;
public IntPtr thisPtr;
public IntPtr nextPtr;
public FilterSource filterSource;
public Stream output;
}

/// <summary>
/// The name that this filter was registered with
Expand Down Expand Up @@ -226,33 +245,44 @@ int InitializeCallback(IntPtr filterPointer)
int StreamCreateCallback(out IntPtr git_writestream_out, GitFilter self, IntPtr payload, IntPtr filterSourcePtr, IntPtr git_writestream_next)
{
int result = 0;
var state = new StreamState();

try
{
Ensure.ArgumentNotZeroIntPtr(filterSourcePtr, "filterSourcePtr");
Ensure.ArgumentNotZeroIntPtr(git_writestream_next, "git_writestream_next");

thisStream = new GitWriteStream();
thisStream.close = StreamCloseCallback;
thisStream.write = StreamWriteCallback;
thisStream.free = StreamFreeCallback;
thisPtr = Marshal.AllocHGlobal(Marshal.SizeOf(thisStream));
Marshal.StructureToPtr(thisStream, thisPtr, false);
nextPtr = git_writestream_next;
nextStream = new GitWriteStream();
Marshal.PtrToStructure(nextPtr, nextStream);
filterSource = FilterSource.FromNativePtr(filterSourcePtr);
output = new WriteStream(nextStream, nextPtr);

Create(filterSource.Path, filterSource.Root, filterSource.SourceMode);
state.thisStream = new GitWriteStream();
state.thisStream.close = StreamCloseCallback;
state.thisStream.write = StreamWriteCallback;
state.thisStream.free = StreamFreeCallback;

state.thisPtr = Marshal.AllocHGlobal(Marshal.SizeOf(state.thisStream));
Marshal.StructureToPtr(state.thisStream, state.thisPtr, false);

state.nextPtr = git_writestream_next;
state.nextStream = new GitWriteStream();
Marshal.PtrToStructure(state.nextPtr, state.nextStream);

state.filterSource = FilterSource.FromNativePtr(filterSourcePtr);
state.output = new WriteStream(state.nextStream, state.nextPtr);

Create(state.filterSource.Path, state.filterSource.Root, state.filterSource.SourceMode);

if (!activeStreams.TryAdd(state.thisPtr, state))
{
// AFAICT this is a theoretical error that could only happen if we manage
// to free the stream pointer but fail to remove the dictionary entry.
throw new InvalidOperationException("Overlapping stream pointers");
}
}
catch (Exception exception)
{
// unexpected failures means memory clean up required
if (thisPtr != IntPtr.Zero)
if (state.thisPtr != IntPtr.Zero)
{
Marshal.FreeHGlobal(thisPtr);
thisPtr = IntPtr.Zero;
Marshal.FreeHGlobal(state.thisPtr);
state.thisPtr = IntPtr.Zero;
}

Log.Write(LogLevel.Error, "Filter.StreamCreateCallback exception");
Expand All @@ -261,24 +291,33 @@ int StreamCreateCallback(out IntPtr git_writestream_out, GitFilter self, IntPtr
result = (int)GitErrorCode.Error;
}

git_writestream_out = thisPtr;
git_writestream_out = state.thisPtr;

return result;
}

int StreamCloseCallback(IntPtr stream)
{
int result = 0;
StreamState state;

try
{
Ensure.ArgumentNotZeroIntPtr(stream, "stream");
Ensure.ArgumentIsExpectedIntPtr(stream, thisPtr, "stream");

using (BufferedStream outputBuffer = new BufferedStream(output, BufferSize))
if(!activeStreams.TryGetValue(stream, out state))
{
throw new ArgumentException("Unknown stream pointer", "stream");
}

Ensure.ArgumentIsExpectedIntPtr(stream, state.thisPtr, "stream");

using (BufferedStream outputBuffer = new BufferedStream(state.output, BufferSize))
{
Complete(filterSource.Path, filterSource.Root, outputBuffer);
Complete(state.filterSource.Path, state.filterSource.Root, outputBuffer);
}

result = state.nextStream.close(state.nextPtr);
}
catch (Exception exception)
{
Expand All @@ -288,19 +327,25 @@ int StreamCloseCallback(IntPtr stream)
result = (int)GitErrorCode.Error;
}

result = nextStream.close(nextPtr);

return result;
}

void StreamFreeCallback(IntPtr stream)
{
StreamState state;

try
{
Ensure.ArgumentNotZeroIntPtr(stream, "stream");
Ensure.ArgumentIsExpectedIntPtr(stream, thisPtr, "stream");

Marshal.FreeHGlobal(thisPtr);
if (!activeStreams.TryRemove(stream, out state))
{
throw new ArgumentException("Double free or invalid stream pointer", "stream");
}

Ensure.ArgumentIsExpectedIntPtr(stream, state.thisPtr, "stream");

Marshal.FreeHGlobal(state.thisPtr);
}
catch (Exception exception)
{
Expand All @@ -312,24 +357,31 @@ void StreamFreeCallback(IntPtr stream)
unsafe int StreamWriteCallback(IntPtr stream, IntPtr buffer, UIntPtr len)
{
int result = 0;
StreamState state;

try
{
Ensure.ArgumentNotZeroIntPtr(stream, "stream");
Ensure.ArgumentNotZeroIntPtr(buffer, "buffer");
Ensure.ArgumentIsExpectedIntPtr(stream, thisPtr, "stream");

if (!activeStreams.TryGetValue(stream, out state))
{
throw new ArgumentException("Invalid or already freed stream pointer", "stream");
}

Ensure.ArgumentIsExpectedIntPtr(stream, state.thisPtr, "stream");

using (UnmanagedMemoryStream input = new UnmanagedMemoryStream((byte*)buffer.ToPointer(), (long)len))
using (BufferedStream outputBuffer = new BufferedStream(output, BufferSize))
using (BufferedStream outputBuffer = new BufferedStream(state.output, BufferSize))
{
switch (filterSource.SourceMode)
switch (state.filterSource.SourceMode)
{
case FilterMode.Clean:
Clean(filterSource.Path, filterSource.Root, input, outputBuffer);
Clean(state.filterSource.Path, state.filterSource.Root, input, outputBuffer);
break;

case FilterMode.Smudge:
Smudge(filterSource.Path, filterSource.Root, input, outputBuffer);
Smudge(state.filterSource.Path, state.filterSource.Root, input, outputBuffer);
break;

default:
Expand Down