diff --git a/src/SqlClient.Tests/ProgrammabilityTests.fs b/src/SqlClient.Tests/ProgrammabilityTests.fs index 34986662..7af09a81 100644 --- a/src/SqlClient.Tests/ProgrammabilityTests.fs +++ b/src/SqlClient.Tests/ProgrammabilityTests.fs @@ -7,6 +7,7 @@ open FSharp.Data.SqlClient type AdventureWorks = SqlProgrammabilityProvider +type AdventureWorksDataTables = SqlProgrammabilityProvider type GetContactInformation = AdventureWorks.dbo.ufnGetContactInformation [] @@ -259,5 +260,9 @@ let PassingImageAsParamDoesntGetCut() = Assert.Equal(existing.LargePhoto.Value.Length, inserted.LargePhoto.Value.Length) Assert.Equal(existing.LargePhoto, inserted.LargePhoto) - - +[] +let ``honors result type parameter: datatable`` () = + let command = new AdventureWorksDataTables.Sales.GetUKSalesOrders(ConnectionStrings.AdventureWorksLiteral) + let gbp = 1.0M + let table : AdventureWorksDataTables.Sales.GetUKSalesOrders.Table = command.Execute(gbp) + Assert.Equal("Year", table.Columns.Year.ColumnName) \ No newline at end of file diff --git a/src/SqlClient/DesignTime.fs b/src/SqlClient/DesignTime.fs index 7ee45935..14ce04e9 100644 --- a/src/SqlClient/DesignTime.fs +++ b/src/SqlClient/DesignTime.fs @@ -30,6 +30,16 @@ type internal ReturnType = { | Some x -> Expr.Value( x.ErasedTo.AssemblyQualifiedName) | None -> <@@ null: string @@> +module internal SharedLogic = + /// Adds .Record or .Table inner type depending on resultType + let alterReturnTypeAccordingToResultType (returnType: ReturnType) (cmdProvidedType: ProvidedTypeDefinition) resultType = + if resultType = ResultType.Records then + // Add .Record + returnType.PerRow |> Option.iter (fun x -> cmdProvidedType.AddMember x.Provided) + elif resultType = ResultType.DataTable then + // add .Table + returnType.Single |> cmdProvidedType.AddMember + type DesignTime private() = static member internal AddGeneratedMethod (sqlParameters: Parameter list, hasOutputParameters, executeArgs: ProvidedParameter list, erasedType, providedOutputType, name) = diff --git a/src/SqlClient/SqlClientProvider.fs b/src/SqlClient/SqlClientProvider.fs index b81ff41a..4771bfec 100644 --- a/src/SqlClient/SqlClientProvider.fs +++ b/src/SqlClient/SqlClientProvider.fs @@ -44,9 +44,10 @@ type SqlProgrammabilityProvider(config : TypeProviderConfig) as this = ProvidedStaticParameter("ConfigFile", typeof, "") ProvidedStaticParameter("DataDirectory", typeof, "") ProvidedStaticParameter("UseReturnValue", typeof, false) - ], + ProvidedStaticParameter("ResultType", typeof, ResultType.Records) + ], instantiationFunction = (fun typeName args -> - let root = lazy this.CreateRootType(typeName, unbox args.[0], unbox args.[1], unbox args.[2], unbox args.[3]) + let root = lazy this.CreateRootType(typeName, unbox args.[0], unbox args.[1], unbox args.[2], unbox args.[3], unbox args.[4]) cache.GetOrAdd(typeName, root) ) ) @@ -54,9 +55,10 @@ type SqlProgrammabilityProvider(config : TypeProviderConfig) as this = providerType.AddXmlDoc """ Typed access to SQL Server programmable objects: stored procedures, functions and user defined table types. String used to open a SQL Server database or the name of the connection string in the configuration file in the form of “name=<connection string name>”. -A value that defines structure of result: Records, Tuples, DataTable, or SqlDataReader. The name of the configuration file that’s used for connection strings at DESIGN-TIME. The default value is app.config or web.config. The name of the data directory that replaces |DataDirectory| in connection strings. The default value is the project or script directory. +To be documented. +A value that defines structure of result: Records, Tuples, DataTable, or SqlDataReader, this affects only Stored Procedures. """ this.AddNamespace(nameSpace, [ providerType ]) @@ -68,7 +70,7 @@ type SqlProgrammabilityProvider(config : TypeProviderConfig) as this = |> defaultArg <| base.ResolveAssembly args - member internal this.CreateRootType( typeName, connectionStringOrName, configFile, dataDirectory, useReturnValue) = + member internal this.CreateRootType( typeName, connectionStringOrName, configFile, dataDirectory, useReturnValue, resultType) = if String.IsNullOrWhiteSpace connectionStringOrName then invalidArg "ConnectionStringOrName" "Value is empty!" let designTimeConnectionString = DesignTimeConnectionString.Parse(connectionStringOrName, config.ResolutionFolder, configFile) @@ -122,7 +124,7 @@ type SqlProgrammabilityProvider(config : TypeProviderConfig) as this = schemaType.AddMembersDelayed <| fun() -> [ - let routines = this.Routines(conn, schemaType.Name, udttsPerSchema, ResultType.Records, designTimeConnectionString, useReturnValue, uomPerSchema) + let routines = this.Routines(conn, schemaType.Name, udttsPerSchema, resultType, designTimeConnectionString, useReturnValue, uomPerSchema) routines |> List.iter tagProvidedType yield! routines @@ -182,8 +184,8 @@ type SqlProgrammabilityProvider(config : TypeProviderConfig) as this = let returnType = DesignTime.GetOutputTypes(outputColumns, resultType, rank, hasOutputParameters, unitsOfMeasurePerSchema) - do //Record - returnType.PerRow |> Option.iter (fun x -> cmdProvidedType.AddMember x.Provided) + do + SharedLogic.alterReturnTypeAccordingToResultType returnType cmdProvidedType resultType //ctors let sqlParameters = Expr.NewArray( typeof, parameters |> List.map QuotationsFactory.ToSqlParam) @@ -218,17 +220,17 @@ type SqlProgrammabilityProvider(config : TypeProviderConfig) as this = yield upcast DesignTime.AddGeneratedMethod(parameters, hasOutputParameters, executeArgs, cmdProvidedType.BaseType, returnType.Single, "Execute") if not hasOutputParameters - then + then let asyncReturnType = ProvidedTypeBuilder.MakeGenericType(typedefof<_ Async>, [ returnType.Single ]) yield upcast DesignTime.AddGeneratedMethod(parameters, hasOutputParameters, executeArgs, cmdProvidedType.BaseType, asyncReturnType, "AsyncExecute") if returnType.PerRow.IsSome - then + then let providedReturnType = ProvidedTypeBuilder.MakeGenericType(typedefof<_ option>, [ returnType.PerRow.Value.Provided ]) let providedAsyncReturnType = ProvidedTypeBuilder.MakeGenericType(typedefof<_ Async>, [ providedReturnType ]) if not hasOutputParameters - then + then yield upcast DesignTime.AddGeneratedMethod(parameters, hasOutputParameters, executeArgs, cmdProvidedType.BaseType, providedReturnType, "ExecuteSingle") yield upcast DesignTime.AddGeneratedMethod(parameters, hasOutputParameters, executeArgs, cmdProvidedType.BaseType, providedAsyncReturnType, "AsyncExecuteSingle") ] diff --git a/src/SqlClient/SqlCommandProvider.fs b/src/SqlClient/SqlCommandProvider.fs index c19c9364..41320d25 100644 --- a/src/SqlClient/SqlCommandProvider.fs +++ b/src/SqlClient/SqlCommandProvider.fs @@ -120,12 +120,7 @@ type SqlCommandProvider(config : TypeProviderConfig) as this = cmdProvidedType.AddMember(ProvidedProperty("ConnectionStringOrName", typeof, [], IsStatic = true, GetterCode = fun _ -> <@@ connectionStringOrName @@>)) do - if resultType = ResultType.Records then - // Add .Record - returnType.PerRow |> Option.iter (fun x -> cmdProvidedType.AddMember x.Provided) - elif resultType = ResultType.DataTable then - // add .Table - returnType.Single |> cmdProvidedType.AddMember + SharedLogic.alterReturnTypeAccordingToResultType returnType cmdProvidedType resultType do //ctors let designTimeConfig =