Skip to content

Commit

Permalink
Support BulkCopy with DataContext (#3031)
Browse files Browse the repository at this point in the history
* support BulkCopy with DataContext

* use mapping schema from current context
  • Loading branch information
MaceWindu committed Jun 3, 2021
1 parent ce7232f commit 568c04d
Show file tree
Hide file tree
Showing 12 changed files with 210 additions and 111 deletions.
47 changes: 11 additions & 36 deletions Source/LinqToDB/Data/DataConnectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2245,10 +2245,7 @@ public static BulkCopyRowsCopied BulkCopy<T>(this ITable<T> table, BulkCopyOptio
{
if (table == null) throw new ArgumentNullException(nameof(table));

if (!(table.DataContext is DataConnection dataConnection))
throw new ArgumentException("DataContext must be of DataConnection type.");

return dataConnection.DataProvider.BulkCopy(table, options, source);
return table.GetDataProvider().BulkCopy(table, options, source);
}

/// <summary>
Expand All @@ -2264,10 +2261,7 @@ public static BulkCopyRowsCopied BulkCopy<T>(this ITable<T> table, int maxBatchS
{
if (table == null) throw new ArgumentNullException(nameof(table));

if (!(table.DataContext is DataConnection dataConnection))
throw new ArgumentException("DataContext must be of DataConnection type.");

return dataConnection.DataProvider.BulkCopy(table, new BulkCopyOptions { MaxBatchSize = maxBatchSize, }, source);
return table.GetDataProvider().BulkCopy(table, new BulkCopyOptions { MaxBatchSize = maxBatchSize, }, source);
}

/// <summary>
Expand All @@ -2282,12 +2276,11 @@ public static BulkCopyRowsCopied BulkCopy<T>(this ITable<T> table, IEnumerable<T
{
if (table == null) throw new ArgumentNullException(nameof(table));

if (!(table.DataContext is DataConnection dataConnection))
throw new ArgumentException("DataContext must be of DataConnection type.");

return dataConnection.DataProvider.BulkCopy(table, new BulkCopyOptions(), source);
return table.GetDataProvider().BulkCopy(table, new BulkCopyOptions(), source);
}



#endregion

#region BulkCopy IEnumerable async
Expand Down Expand Up @@ -2368,10 +2361,7 @@ public static Task<BulkCopyRowsCopied> BulkCopyAsync<T>(this ITable<T> table, Bu
if (table == null) throw new ArgumentNullException(nameof(table));
if (source == null) throw new ArgumentNullException(nameof(source));

if (!(table.DataContext is DataConnection dataConnection))
throw new ArgumentException("DataContext must be of DataConnection type.");

return dataConnection.DataProvider.BulkCopyAsync(table, options, source, cancellationToken);
return table.GetDataProvider().BulkCopyAsync(table, options, source, cancellationToken);
}

/// <summary>
Expand All @@ -2389,10 +2379,7 @@ public static Task<BulkCopyRowsCopied> BulkCopyAsync<T>(this ITable<T> table, in
if (table == null) throw new ArgumentNullException(nameof(table));
if (source == null) throw new ArgumentNullException(nameof(source));

if (!(table.DataContext is DataConnection dataConnection))
throw new ArgumentException("DataContext must be of DataConnection type.");

return dataConnection.DataProvider.BulkCopyAsync(table, new BulkCopyOptions { MaxBatchSize = maxBatchSize, }, source, cancellationToken);
return table.GetDataProvider().BulkCopyAsync(table, new BulkCopyOptions { MaxBatchSize = maxBatchSize, }, source, cancellationToken);
}

/// <summary>
Expand All @@ -2409,10 +2396,7 @@ public static Task<BulkCopyRowsCopied> BulkCopyAsync<T>(this ITable<T> table, IE
if (table == null) throw new ArgumentNullException(nameof(table));
if (source == null) throw new ArgumentNullException(nameof(source));

if (!(table.DataContext is DataConnection dataConnection))
throw new ArgumentException("DataContext must be of DataConnection type.");

return dataConnection.DataProvider.BulkCopyAsync(table, new BulkCopyOptions(), source, cancellationToken);
return table.GetDataProvider().BulkCopyAsync(table, new BulkCopyOptions(), source, cancellationToken);
}

#endregion
Expand Down Expand Up @@ -2496,10 +2480,7 @@ public static Task<BulkCopyRowsCopied> BulkCopyAsync<T>(this ITable<T> table, Bu
if (table == null) throw new ArgumentNullException(nameof(table));
if (source == null) throw new ArgumentNullException(nameof(source));

if (!(table.DataContext is DataConnection dataConnection))
throw new ArgumentException("DataContext must be of DataConnection type.");

return dataConnection.DataProvider.BulkCopyAsync(table, options, source, cancellationToken);
return table.GetDataProvider().BulkCopyAsync(table, options, source, cancellationToken);
}

/// <summary>
Expand All @@ -2517,10 +2498,7 @@ public static Task<BulkCopyRowsCopied> BulkCopyAsync<T>(this ITable<T> table, in
if (table == null) throw new ArgumentNullException(nameof(table));
if (source == null) throw new ArgumentNullException(nameof(source));

if (!(table.DataContext is DataConnection dataConnection))
throw new ArgumentException("DataContext must be of DataConnection type.");

return dataConnection.DataProvider.BulkCopyAsync(table, new BulkCopyOptions { MaxBatchSize = maxBatchSize, }, source, cancellationToken);
return table.GetDataProvider().BulkCopyAsync(table, new BulkCopyOptions { MaxBatchSize = maxBatchSize, }, source, cancellationToken);
}

/// <summary>
Expand All @@ -2537,10 +2515,7 @@ public static Task<BulkCopyRowsCopied> BulkCopyAsync<T>(this ITable<T> table, IA
if (table == null) throw new ArgumentNullException(nameof(table));
if (source == null) throw new ArgumentNullException(nameof(source));

if (!(table.DataContext is DataConnection dataConnection))
throw new ArgumentException("DataContext must be of DataConnection type.");

return dataConnection.DataProvider.BulkCopyAsync(table, new BulkCopyOptions(), source, cancellationToken);
return table.GetDataProvider().BulkCopyAsync(table, new BulkCopyOptions(), source, cancellationToken);
}

#endif
Expand Down
22 changes: 11 additions & 11 deletions Source/LinqToDB/DataProvider/DB2/DB2BulkCopy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ protected override BulkCopyRowsCopied ProviderSpecificCopy<T>(
BulkCopyOptions options,
IEnumerable<T> source)
{
if (table.DataContext is DataConnection dataConnection)
if (table.TryGetDataConnection(out var dataConnection))
{
var connection = _provider.TryGetProviderConnection(dataConnection.Connection, dataConnection.MappingSchema);
var connection = _provider.TryGetProviderConnection(dataConnection.Connection, table.DataContext.MappingSchema);
if (connection != null)
return ProviderSpecificCopyImpl(
table,
Expand All @@ -44,9 +44,9 @@ protected override BulkCopyRowsCopied ProviderSpecificCopy<T>(

protected override Task<BulkCopyRowsCopied> ProviderSpecificCopyAsync<T>(ITable<T> table, BulkCopyOptions options, IEnumerable<T> source, CancellationToken cancellationToken)
{
if (table.DataContext is DataConnection dataConnection)
if (table.TryGetDataConnection(out var dataConnection))
{
var connection = _provider.TryGetProviderConnection(dataConnection.Connection, dataConnection.MappingSchema);
var connection = _provider.TryGetProviderConnection(dataConnection.Connection, table.DataContext.MappingSchema);
if (connection != null)
// call the synchronous provider-specific implementation
return Task.FromResult(ProviderSpecificCopyImpl(
Expand All @@ -65,9 +65,9 @@ protected override Task<BulkCopyRowsCopied> ProviderSpecificCopyAsync<T>(ITable<
#if NATIVE_ASYNC
protected override async Task<BulkCopyRowsCopied> ProviderSpecificCopyAsync<T>(ITable<T> table, BulkCopyOptions options, IAsyncEnumerable<T> source, CancellationToken cancellationToken)
{
if (table.DataContext is DataConnection dataConnection)
if (table.TryGetDataConnection(out var dataConnection))
{
var connection = _provider.TryGetProviderConnection(dataConnection.Connection, dataConnection.MappingSchema);
var connection = _provider.TryGetProviderConnection(dataConnection.Connection, table.DataContext.MappingSchema);
if (connection != null)
{
var enumerator = source.GetAsyncEnumerator(cancellationToken);
Expand Down Expand Up @@ -100,11 +100,11 @@ internal static BulkCopyRowsCopied ProviderSpecificCopyImpl<T>(
Action<DataConnection, Func<string>, Func<int>> traceAction)
where T : notnull
{
var descriptor = dataConnection.MappingSchema.GetEntityDescriptor(typeof(T));
var descriptor = table.DataContext.MappingSchema.GetEntityDescriptor(typeof(T));
var columns = descriptor.Columns.Where(c => !c.SkipOnInsert || options.KeepIdentity == true && c.IsIdentity).ToList();
var rd = new BulkCopyReader<T>(dataConnection, columns, source);
var rc = new BulkCopyRowsCopied();
var sqlBuilder = dataConnection.DataProvider.CreateSqlBuilder(dataConnection.MappingSchema);
var sqlBuilder = dataConnection.DataProvider.CreateSqlBuilder(table.DataContext.MappingSchema);
var tableName = GetTableName(sqlBuilder, options, table);

var bcOptions = DB2BulkCopyOptions.Default;
Expand Down Expand Up @@ -159,7 +159,7 @@ internal static BulkCopyRowsCopied ProviderSpecificCopyImpl<T>(

protected override BulkCopyRowsCopied MultipleRowsCopy<T>(ITable<T> table, BulkCopyOptions options, IEnumerable<T> source)
{
var dataConnection = (DataConnection)table.DataContext;
var dataConnection = table.GetDataConnection();

if (((DB2DataProvider)dataConnection.DataProvider).Version == DB2Version.zOS)
return MultipleRowsCopy2(table, options, source, " FROM SYSIBM.SYSDUMMY1");
Expand All @@ -169,7 +169,7 @@ protected override BulkCopyRowsCopied MultipleRowsCopy<T>(ITable<T> table, BulkC

protected override Task<BulkCopyRowsCopied> MultipleRowsCopyAsync<T>(ITable<T> table, BulkCopyOptions options, IEnumerable<T> source, CancellationToken cancellationToken)
{
var dataConnection = (DataConnection)table.DataContext;
var dataConnection = table.GetDataConnection();

if (((DB2DataProvider)dataConnection.DataProvider).Version == DB2Version.zOS)
return MultipleRowsCopy2Async(table, options, source, " FROM SYSIBM.SYSDUMMY1", cancellationToken);
Expand All @@ -180,7 +180,7 @@ protected override Task<BulkCopyRowsCopied> MultipleRowsCopyAsync<T>(ITable<T> t
#if NATIVE_ASYNC
protected override Task<BulkCopyRowsCopied> MultipleRowsCopyAsync<T>(ITable<T> table, BulkCopyOptions options, IAsyncEnumerable<T> source, CancellationToken cancellationToken)
{
var dataConnection = (DataConnection)table.DataContext;
var dataConnection = table.GetDataConnection();

if (((DB2DataProvider)dataConnection.DataProvider).Version == DB2Version.zOS)
return MultipleRowsCopy2Async(table, options, source, " FROM SYSIBM.SYSDUMMY1", cancellationToken);
Expand Down
19 changes: 11 additions & 8 deletions Source/LinqToDB/DataProvider/Informix/InformixBulkCopy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ protected override BulkCopyRowsCopied ProviderSpecificCopy<T>(
BulkCopyOptions options,
IEnumerable<T> source)
{
if ((_provider.Adapter.InformixBulkCopy != null || _provider.Adapter.DB2BulkCopy != null) && table.DataContext is DataConnection dataConnection && dataConnection.Transaction == null)
if ((_provider.Adapter.InformixBulkCopy != null || _provider.Adapter.DB2BulkCopy != null)
&& table.TryGetDataConnection(out var dataConnection) && dataConnection.Transaction == null)
{
var connection = _provider.TryGetProviderConnection(dataConnection.Connection, dataConnection.MappingSchema);
var connection = _provider.TryGetProviderConnection(dataConnection.Connection, table.DataContext.MappingSchema);

if (connection != null)
{
Expand Down Expand Up @@ -60,9 +61,10 @@ protected override Task<BulkCopyRowsCopied> ProviderSpecificCopyAsync<T>(
IEnumerable<T> source,
CancellationToken cancellationToken)
{
if ((_provider.Adapter.InformixBulkCopy != null || _provider.Adapter.DB2BulkCopy != null) && table.DataContext is DataConnection dataConnection && dataConnection.Transaction == null)
if ((_provider.Adapter.InformixBulkCopy != null || _provider.Adapter.DB2BulkCopy != null)
&& table.TryGetDataConnection(out var dataConnection) && dataConnection.Transaction == null)
{
var connection = _provider.TryGetProviderConnection(dataConnection.Connection, dataConnection.MappingSchema);
var connection = _provider.TryGetProviderConnection(dataConnection.Connection, table.DataContext.MappingSchema);

if (connection != null)
{
Expand Down Expand Up @@ -97,9 +99,10 @@ protected override async Task<BulkCopyRowsCopied> ProviderSpecificCopyAsync<T>(
IAsyncEnumerable<T> source,
CancellationToken cancellationToken)
{
if ((_provider.Adapter.InformixBulkCopy != null || _provider.Adapter.DB2BulkCopy != null) && table.DataContext is DataConnection dataConnection && dataConnection.Transaction == null)
if ((_provider.Adapter.InformixBulkCopy != null || _provider.Adapter.DB2BulkCopy != null)
&& table.TryGetDataConnection(out var dataConnection) && dataConnection.Transaction == null)
{
var connection = _provider.TryGetProviderConnection(dataConnection.Connection, dataConnection.MappingSchema);
var connection = _provider.TryGetProviderConnection(dataConnection.Connection, table.DataContext.MappingSchema);

if (connection != null)
{
Expand Down Expand Up @@ -143,9 +146,9 @@ protected BulkCopyRowsCopied IDSProviderSpecificCopy<T>(
InformixProviderAdapter.BulkCopyAdapter bulkCopy)
where T: notnull
{
var ed = dataConnection.MappingSchema.GetEntityDescriptor(typeof(T));
var ed = table.DataContext.MappingSchema.GetEntityDescriptor(typeof(T));
var columns = ed.Columns.Where(c => !c.SkipOnInsert || options.KeepIdentity == true && c.IsIdentity).ToList();
var sb = _provider.CreateSqlBuilder(dataConnection.MappingSchema);
var sb = _provider.CreateSqlBuilder(table.DataContext.MappingSchema);
var rd = new BulkCopyReader<T>(dataConnection, columns, source);
var sqlopt = InformixProviderAdapter.IfxBulkCopyOptions.Default;
var rc = new BulkCopyRowsCopied();
Expand Down
19 changes: 13 additions & 6 deletions Source/LinqToDB/DataProvider/MultipleRowsHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,27 @@ public class MultipleRowsHelper<T> : MultipleRowsHelper
where T : notnull
{
public MultipleRowsHelper(ITable<T> table, BulkCopyOptions options)
: base((DataConnection)table.DataContext, options, typeof(T))
: base(table.DataContext, options, typeof(T))
{
TableName = BasicBulkCopy.GetTableName(SqlBuilder, options, table);
}
}

public abstract class MultipleRowsHelper
{
protected MultipleRowsHelper(DataConnection dataConnection, BulkCopyOptions options, Type entityType)
protected MultipleRowsHelper(IDataContext dataConnection, BulkCopyOptions options, Type entityType)
{
DataConnection = dataConnection;
DataConnection = dataConnection is DataConnection dc
? dc
: dataConnection is DataContext dx
? dx.GetDataConnection()
: throw new ArgumentException($"Must be of {nameof(DataConnection)} or {nameof(DataContext)} type but was {dataConnection.GetType()}", nameof(dataConnection));

MappingSchema = dataConnection.MappingSchema;
Options = options;
SqlBuilder = dataConnection.DataProvider.CreateSqlBuilder(dataConnection.MappingSchema);
ValueConverter = dataConnection.MappingSchema.ValueToSqlConverter;
Descriptor = dataConnection.MappingSchema.GetEntityDescriptor(entityType);
SqlBuilder = DataConnection.DataProvider.CreateSqlBuilder(MappingSchema);
ValueConverter = MappingSchema.ValueToSqlConverter;
Descriptor = MappingSchema.GetEntityDescriptor(entityType);
Columns = Descriptor.Columns
.Where(c => !c.SkipOnInsert || c.IsIdentity && options.KeepIdentity == true)
.ToArray();
Expand All @@ -41,6 +47,7 @@ protected MultipleRowsHelper(DataConnection dataConnection, BulkCopyOptions opti

public readonly ISqlBuilder SqlBuilder;
public readonly DataConnection DataConnection;
public readonly MappingSchema MappingSchema;
public readonly BulkCopyOptions Options;
public readonly ValueToSqlConverter ValueConverter;
public readonly EntityDescriptor Descriptor;
Expand Down
27 changes: 14 additions & 13 deletions Source/LinqToDB/DataProvider/MySql/MySqlBulkCopy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public MySqlBulkCopy(MySqlDataProvider provider)
{
_provider = provider;
}

protected override BulkCopyRowsCopied ProviderSpecificCopy<T>(
ITable<T> table,
BulkCopyOptions options,
Expand Down Expand Up @@ -80,13 +81,13 @@ protected override Task<BulkCopyRowsCopied> ProviderSpecificCopyAsync<T>(
private ProviderConnections? TryGetProviderConnections<T>(ITable<T> table)
where T : notnull
{
if (table.DataContext is DataConnection dataConnection && _provider.Adapter.BulkCopy != null)
if (table.TryGetDataConnection(out var dataConnection) && _provider.Adapter.BulkCopy != null)
{
var connection = _provider.TryGetProviderConnection(dataConnection.Connection, dataConnection.MappingSchema);
var connection = _provider.TryGetProviderConnection(dataConnection.Connection, table.DataContext.MappingSchema);

var transaction = dataConnection.Transaction;
if (connection != null && transaction != null)
transaction = _provider.TryGetProviderTransaction(transaction, dataConnection.MappingSchema);
transaction = _provider.TryGetProviderTransaction(transaction, table.DataContext.MappingSchema);

if (connection != null && (dataConnection.Transaction == null || transaction != null))
{
Expand All @@ -112,9 +113,9 @@ private async Task<BulkCopyRowsCopied> ProviderSpecificCopyInternalAsync<T>(
var dataConnection = providerConnections.DataConnection;
var connection = providerConnections.ProviderConnection;
var transaction = providerConnections.ProviderTransaction;
var ed = dataConnection.MappingSchema.GetEntityDescriptor(typeof(T));
var ed = table.DataContext.MappingSchema.GetEntityDescriptor(typeof(T));
var columns = ed.Columns.Where(c => !c.SkipOnInsert || options.KeepIdentity == true && c.IsIdentity).ToList();
var sb = _provider.CreateSqlBuilder(dataConnection.MappingSchema);
var sb = _provider.CreateSqlBuilder(table.DataContext.MappingSchema);
var rc = new BulkCopyRowsCopied();

var bc = _provider.Adapter.BulkCopy!.Create(connection, transaction);
Expand Down Expand Up @@ -191,9 +192,9 @@ private BulkCopyRowsCopied ProviderSpecificCopyInternal<T>(
var dataConnection = providerConnections.DataConnection;
var connection = providerConnections.ProviderConnection;
var transaction = providerConnections.ProviderTransaction;
var ed = dataConnection.MappingSchema.GetEntityDescriptor(typeof(T));
var ed = table.DataContext.MappingSchema.GetEntityDescriptor(typeof(T));
var columns = ed.Columns.Where(c => !c.SkipOnInsert || options.KeepIdentity == true && c.IsIdentity).ToList();
var sb = _provider.CreateSqlBuilder(dataConnection.MappingSchema);
var sb = _provider.CreateSqlBuilder(table.DataContext.MappingSchema);
var rc = new BulkCopyRowsCopied();

var bc = _provider.Adapter.BulkCopy!.Create(connection, transaction);
Expand Down Expand Up @@ -256,12 +257,12 @@ private async Task<BulkCopyRowsCopied> ProviderSpecificCopyInternalAsync<T>(
where T: notnull
{
var dataConnection = providerConnections.DataConnection;
var connection = providerConnections.ProviderConnection;
var transaction = providerConnections.ProviderTransaction;
var ed = dataConnection.MappingSchema.GetEntityDescriptor(typeof(T));
var columns = ed.Columns.Where(c => !c.SkipOnInsert || options.KeepIdentity == true && c.IsIdentity).ToList();
var sb = _provider.CreateSqlBuilder(dataConnection.MappingSchema);
var rc = new BulkCopyRowsCopied();
var connection = providerConnections.ProviderConnection;
var transaction = providerConnections.ProviderTransaction;
var ed = table.DataContext.MappingSchema.GetEntityDescriptor(typeof(T));
var columns = ed.Columns.Where(c => !c.SkipOnInsert || options.KeepIdentity == true && c.IsIdentity).ToList();
var sb = _provider.CreateSqlBuilder(table.DataContext.MappingSchema);
var rc = new BulkCopyRowsCopied();

var bc = _provider.Adapter.BulkCopy!.Create(connection, transaction);
if (options.NotifyAfter != 0 && options.RowsCopiedCallback != null)
Expand Down
Loading

0 comments on commit 568c04d

Please sign in to comment.