Skip to content

Commit

Permalink
Stored Procedure calls in SQL Server working.
Browse files Browse the repository at this point in the history
  • Loading branch information
markrendle committed Jan 20, 2011
1 parent eb1bb8a commit 3170839
Show file tree
Hide file tree
Showing 13 changed files with 256 additions and 30 deletions.
6 changes: 3 additions & 3 deletions CommonAssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

[assembly: AssemblyCompany("Simple.Data")]
[assembly: AssemblyProduct("Simple.Data")]
[assembly: AssemblyCopyright("Copyright © Mark Rendle 2010")]
[assembly: AssemblyCopyright("Copyright © Mark Rendle 2010-2011")]
[assembly: AssemblyTrademark("")]
[assembly: AssemblyCulture("")]

Expand All @@ -19,5 +19,5 @@
// COM, set the ComVisible attribute to true on that type.
[assembly: ComVisible(false)]

[assembly: AssemblyVersion("0.3.3.0")]
[assembly: AssemblyFileVersion("0.3.3.0")]
[assembly: AssemblyVersion("0.4.0.0")]
[assembly: AssemblyFileVersion("0.4.0.0")]
7 changes: 5 additions & 2 deletions Simple.Data.Ado/AdoAdapter.IAdapterWithFunctions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Text;
Expand All @@ -8,14 +9,16 @@ namespace Simple.Data.Ado
{
internal partial class AdoAdapter : IAdapterWithFunctions
{
private readonly ConcurrentDictionary<string, ProcedureExecutor> _executors = new ConcurrentDictionary<string, ProcedureExecutor>();

public bool IsValidFunction(string functionName)
{
return _schema.FindProcedure(functionName) != null;
}

public IEnumerable<ResultSet> Execute(string functionName, IEnumerable<KeyValuePair<string, object>> parameters)
public IEnumerable<ResultSet> Execute(string functionName, IDictionary<string, object> parameters)
{
var executor = new ProcedureExecutor(this, ObjectName.Parse(functionName));
var executor = _executors.GetOrAdd(functionName, f => new ProcedureExecutor(this, ObjectName.Parse(f)));
return executor.Execute(parameters);
}
}
Expand Down
58 changes: 46 additions & 12 deletions Simple.Data.Ado/ProcedureExecutor.cs
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data;
using System.Data.Common;
using System.Diagnostics;
using System.Linq;
using System.Text;
using Simple.Data.Extensions;
using Simple.Data.Ado.Schema;
using ResultSet = System.Collections.Generic.IEnumerable<System.Collections.Generic.IEnumerable<System.Collections.Generic.KeyValuePair<string, object>>>;
using ResultSet = System.Collections.Generic.IEnumerable<System.Collections.Generic.IDictionary<string, object>>;

namespace Simple.Data.Ado
{
internal class ProcedureExecutor
{
private const string SimpleReturnParameterName = "@__Simple_ReturnValue";

private readonly AdoAdapter _adapter;
private readonly ObjectName _procedureName;
private Func<DbCommand, IEnumerable<ResultSet>> _executeImpl;

public ProcedureExecutor(AdoAdapter adapter, ObjectName procedureName)
{
_adapter = adapter;
_procedureName = procedureName;
}

public IEnumerable<ResultSet> Execute(IEnumerable<KeyValuePair<string, object>> suppliedParameters)
{
return Execute(suppliedParameters.ToDictionary());
_executeImpl = ExecuteReader;
}

public IEnumerable<ResultSet> Execute(IDictionary<string, object> suppliedParameters)
Expand All @@ -40,10 +41,12 @@ public IEnumerable<ResultSet> Execute(IDictionary<string, object> suppliedParame
command.CommandText = procedure.QuotedName;
command.CommandType = CommandType.StoredProcedure;
SetParameters(procedure, command, suppliedParameters);

try
{
return Execute(command);
var result = _executeImpl(command);
suppliedParameters["__ReturnValue"] = command.Parameters[SimpleReturnParameterName].Value;
RetrieveOutputParameterValues(procedure, command, suppliedParameters);
return result;
}
catch (DbException ex)
{
Expand All @@ -52,20 +55,43 @@ public IEnumerable<ResultSet> Execute(IDictionary<string, object> suppliedParame
}
}

private static IEnumerable<ResultSet> Execute(DbCommand command)
private static void RetrieveOutputParameterValues(Procedure procedure, DbCommand command, IDictionary<string, object> suppliedParameters)
{
foreach (var outputParameter in procedure.Parameters.Where(p => p.Direction == ParameterDirection.InputOutput || p.Direction == ParameterDirection.Output))
{
suppliedParameters[outputParameter.Name.Replace("@", "")] =
command.Parameters[outputParameter.Name].Value;
}
}

private IEnumerable<ResultSet> ExecuteReader(DbCommand command)
{
command.Connection.Open();
using (var reader = command.ExecuteReader())
{
do
if (reader.FieldCount > 0)
{
yield return reader.ToDictionaries();
} while (reader.NextResult());
return reader.ToMultipleDictionaries();
}

// Don't call ExecuteReader for this function again.
_executeImpl = ExecuteNonQuery;
return Enumerable.Empty<ResultSet>();
}
}

private static IEnumerable<ResultSet> ExecuteNonQuery(DbCommand command)
{
Trace.TraceInformation("ExecuteNonQuery");
command.Connection.Open();
command.ExecuteNonQuery();
return Enumerable.Empty<ResultSet>();
}

private static void SetParameters(Procedure procedure, DbCommand cmd, IDictionary<string, object> suppliedParameters)
{
AddReturnParameter(cmd);

int i = 0;
foreach (var parameter in procedure.Parameters)
{
Expand All @@ -78,5 +104,13 @@ private static void SetParameters(Procedure procedure, DbCommand cmd, IDictionar
i++;
}
}

private static void AddReturnParameter(DbCommand cmd)
{
var returnParameter = cmd.CreateParameter();
returnParameter.ParameterName = SimpleReturnParameterName;
returnParameter.Direction = ParameterDirection.ReturnValue;
cmd.Parameters.Add(returnParameter);
}
}
}
2 changes: 1 addition & 1 deletion Simple.Data.SqlServer/SqlSchemaProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ private IEnumerable<DataRow> GetSchema(string collectionName, params string[] co

private static Procedure SchemaRowToStoredProcedure(DataRow row)
{
return new Procedure(row["ROUTINE_NAME"].ToString(), row["SPECIFIC_NAME"].ToString(), row["TABLE_SCHEMA"].ToString());
return new Procedure(row["ROUTINE_NAME"].ToString(), row["SPECIFIC_NAME"].ToString(), row["ROUTINE_SCHEMA"].ToString());
}

public IEnumerable<Parameter> GetParameters(Procedure storedProcedure)
Expand Down
2 changes: 1 addition & 1 deletion Simple.Data.SqlTest/DatabaseHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ internal static class DatabaseHelper
{
public static dynamic Open()
{
return Database.OpenConnection(Properties.Settings.Default.ConnectionString);
return Database.Opener.OpenConnection(Properties.Settings.Default.ConnectionString);
}
}
}
72 changes: 72 additions & 0 deletions Simple.Data.SqlTest/ProcedureTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using NUnit.Framework;

namespace Simple.Data.SqlTest
{
[TestFixture]
public class ProcedureTest
{
[Test]
public void GetCustomersTest()
{
var db = DatabaseHelper.Open();
var results = db.GetCustomers();
var actual = results.First();
Assert.AreEqual(1, actual.CustomerId);
}

[Test]
public void GetCustomerCountTest()
{
var db = DatabaseHelper.Open();
var results = db.GetCustomerCount();
Assert.AreEqual(1, results.ReturnValue);
}

[Test]
public void GetCustomerCountSecondCallExecutesNonQueryTest()
{
var listener = new TestTraceListener();
Trace.Listeners.Add(listener);
var db = DatabaseHelper.Open();
db.GetCustomerCount();
Assert.IsFalse(listener.Output.Contains("ExecuteNonQuery"));
db.GetCustomerCount();
Assert.IsTrue(listener.Output.Contains("ExecuteNonQuery"));
Trace.Listeners.Remove(listener);
}

[Test]
public void GetCustomerAndOrdersTest()
{
var db = DatabaseHelper.Open();
var results = db.GetCustomerAndOrders(1);
var customer = results.FirstOrDefault();
Assert.IsNotNull(customer);
Assert.AreEqual(1, customer.CustomerId);
Assert.IsTrue(results.NextResult());
var order = results.FirstOrDefault();
Assert.IsNotNull(order);
Assert.AreEqual(1, order.OrderId);
}

[Test]
public void GetCustomerAndOrdersStillWorksAfterZeroRecordCallTest()
{
var db = DatabaseHelper.Open();
db.GetCustomerAndOrders(1000);
var results = db.GetCustomerAndOrders(1);
var customer = results.FirstOrDefault();
Assert.IsNotNull(customer);
Assert.AreEqual(1, customer.CustomerId);
Assert.IsTrue(results.NextResult());
var order = results.FirstOrDefault();
Assert.IsNotNull(order);
Assert.AreEqual(1, order.OrderId);
}
}
}
2 changes: 2 additions & 0 deletions Simple.Data.SqlTest/Simple.Data.SqlTest.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
<Compile Include="DatabaseHelper.cs" />
<Compile Include="NaturalJoinTest.cs" />
<Compile Include="ObservableDataReaderTest.cs" />
<Compile Include="ProcedureTest.cs" />
<Compile Include="Properties\Resources.Designer.cs">
<AutoGen>True</AutoGen>
<DesignTime>True</DesignTime>
Expand All @@ -63,6 +64,7 @@
<Compile Include="FindTests.cs" />
<Compile Include="OrderDetailTests.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="TestTraceListener.cs" />
<Compile Include="TransactionTests.cs" />
<Compile Include="User.cs" />
</ItemGroup>
Expand Down
35 changes: 35 additions & 0 deletions Simple.Data.SqlTest/TestTraceListener.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;

namespace Simple.Data.SqlTest
{
public class TestTraceListener : TraceListener
{
private readonly StringBuilder _builder = new StringBuilder();
/// <summary>
/// When overridden in a derived class, writes the specified message to the listener you create in the derived class.
/// </summary>
/// <param name="message">A message to write. </param><filterpriority>2</filterpriority>
public override void Write(string message)
{
_builder.Append(message);
}

/// <summary>
/// When overridden in a derived class, writes a message to the listener you create in the derived class, followed by a line terminator.
/// </summary>
/// <param name="message">A message to write. </param><filterpriority>2</filterpriority>
public override void WriteLine(string message)
{
_builder.AppendLine(message);
}

public string Output
{
get { return _builder.ToString(); }
}
}
}
2 changes: 1 addition & 1 deletion Simple.Data/AdapterWithFunctionsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public static bool IsValidFunction(this Adapter adapter, string functionName)
return adapterWithFunctions.IsValidFunction(functionName);
}

public static bool Execute(this Adapter adapter, string functionName, IEnumerable<KeyValuePair<string,object>> parameters, out object result)
public static bool Execute(this Adapter adapter, string functionName, IDictionary<string,object> parameters, out object result)
{
var adapterWithFunctions = adapter as IAdapterWithFunctions;
if (adapterWithFunctions == null) throw new NotSupportedException("Adapter does not support Function calls.");
Expand Down
5 changes: 3 additions & 2 deletions Simple.Data/BinderHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ internal static IDictionary<string, object> NamedArgumentsToDictionary(this Invo
.ToDictionary();
}

public static IEnumerable<KeyValuePair<string, object>> ArgumentsToDictionary(this InvokeMemberBinder binder, IEnumerable<object> args)
public static IDictionary<string, object> ArgumentsToDictionary(this InvokeMemberBinder binder, IEnumerable<object> args)
{
return args.Reverse()
.Zip(binder.CallInfo.ArgumentNames.Reverse().ExtendInfinite(), (v, k) => new KeyValuePair<string, object>(k, v))
.Reverse()
.Select((kvp, i) => kvp.Key == null ? new KeyValuePair<string, object>("_" + i.ToString(), kvp.Value) : kvp);
.Select((kvp, i) => kvp.Key == null ? new KeyValuePair<string, object>("_" + i.ToString(), kvp.Value) : kvp)
.ToDictionary();
}
}
}
12 changes: 7 additions & 5 deletions Simple.Data/Commands/ExecuteFunctionCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ class ExecuteFunctionCommand
{
private readonly IAdapterWithFunctions _adapter;
private readonly string _functionName;
private readonly IEnumerable<KeyValuePair<string, object>> _arguments;
private readonly IDictionary<string, object> _arguments;

public ExecuteFunctionCommand(IAdapterWithFunctions adapter, string functionName, IEnumerable<KeyValuePair<string,object>> arguments)
public ExecuteFunctionCommand(IAdapterWithFunctions adapter, string functionName, IDictionary<string,object> arguments)
{
_adapter = adapter;
_functionName = functionName;
Expand All @@ -26,7 +26,7 @@ public bool Execute(out object result)
return true;
}

private static SimpleResultSet ToMultipleResultSets(object source)
private SimpleResultSet ToMultipleResultSets(object source)
{
if (source == null) return SimpleResultSet.Empty;
var resultSets = source as IEnumerable<IEnumerable<IEnumerable<KeyValuePair<string, object>>>>;
Expand All @@ -35,9 +35,11 @@ private static SimpleResultSet ToMultipleResultSets(object source)
return ToMultipleDynamicEnumerables(resultSets);
}

private static SimpleResultSet ToMultipleDynamicEnumerables(IEnumerable<IEnumerable<IEnumerable<KeyValuePair<string, object>>>> resultSets)
private SimpleResultSet ToMultipleDynamicEnumerables(IEnumerable<IEnumerable<IEnumerable<KeyValuePair<string, object>>>> resultSets)
{
return new SimpleResultSet(resultSets.Select(resultSet => resultSet.Select(dict => new SimpleRecord(dict))));
var result = new SimpleResultSet(resultSets.Select(resultSet => resultSet.Select(dict => new SimpleRecord(dict))));
result.SetOutputValues(_arguments);
return result;
}

private static SimpleResultSet ToResultSet(object source)
Expand Down
2 changes: 1 addition & 1 deletion Simple.Data/IAdapterWithFunctions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ namespace Simple.Data
public interface IAdapterWithFunctions
{
bool IsValidFunction(string functionName);
IEnumerable<ResultSet> Execute(string functionName, IEnumerable<KeyValuePair<string, object>> parameters);
IEnumerable<ResultSet> Execute(string functionName, IDictionary<string, object> parameters);
}
}
Loading

0 comments on commit 3170839

Please sign in to comment.