Skip to content

Commit

Permalink
Limit query to single record in case of FirstOrDefault
Browse files Browse the repository at this point in the history
  • Loading branch information
henkmollema committed Jan 16, 2024
1 parent c216175 commit 4e1714a
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions src/Dommel/Select.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public static partial class DommelMapper
/// </returns>
public static IEnumerable<TEntity> Select<TEntity>(this IDbConnection connection, Expression<Func<TEntity, bool>> predicate, IDbTransaction? transaction = null, bool buffered = true)
{
var sql = BuildSelectSql(connection, predicate, out var parameters);
var sql = BuildSelectSql(connection, predicate, false, out var parameters);
LogQuery<TEntity>(sql);
return connection.Query<TEntity>(sql, parameters, transaction, buffered);
}
Expand All @@ -47,7 +47,7 @@ public static IEnumerable<TEntity> Select<TEntity>(this IDbConnection connection
/// </returns>
public static Task<IEnumerable<TEntity>> SelectAsync<TEntity>(this IDbConnection connection, Expression<Func<TEntity, bool>> predicate, IDbTransaction? transaction = null, CancellationToken cancellationToken = default)
{
var sql = BuildSelectSql(connection, predicate, out var parameters);
var sql = BuildSelectSql(connection, predicate, false, out var parameters);
LogQuery<TEntity>(sql);
return connection.QueryAsync<TEntity>(new CommandDefinition(sql, parameters, transaction: transaction, cancellationToken: cancellationToken));
}
Expand All @@ -66,7 +66,7 @@ public static Task<IEnumerable<TEntity>> SelectAsync<TEntity>(this IDbConnection
public static TEntity? FirstOrDefault<TEntity>(this IDbConnection connection, Expression<Func<TEntity, bool>> predicate, IDbTransaction? transaction = null)
where TEntity : class
{
var sql = BuildSelectSql(connection, predicate, out var parameters);
var sql = BuildSelectSql(connection, predicate, true, out var parameters);
LogQuery<TEntity>(sql);
return connection.QueryFirstOrDefault<TEntity>(sql, parameters, transaction);
}
Expand All @@ -86,22 +86,29 @@ public static Task<IEnumerable<TEntity>> SelectAsync<TEntity>(this IDbConnection
public static async Task<TEntity?> FirstOrDefaultAsync<TEntity>(this IDbConnection connection, Expression<Func<TEntity, bool>> predicate, IDbTransaction? transaction = null, CancellationToken cancellationToken = default)
where TEntity : class
{
var sql = BuildSelectSql(connection, predicate, out var parameters);
var sql = BuildSelectSql(connection, predicate, true, out var parameters);
LogQuery<TEntity>(sql);
return await connection.QueryFirstOrDefaultAsync<TEntity>(new CommandDefinition(sql, parameters, transaction, cancellationToken: cancellationToken));
}

private static string BuildSelectSql<TEntity>(IDbConnection connection, Expression<Func<TEntity, bool>> predicate, out DynamicParameters parameters)
private static string BuildSelectSql<TEntity>(IDbConnection connection, Expression<Func<TEntity, bool>> predicate, bool firstRecordOnly, out DynamicParameters parameters)
{
var type = typeof(TEntity);

// Build the select all part
var sql = BuildGetAllQuery(connection, type);

// Append the where statement
sql += CreateSqlExpression<TEntity>(GetSqlBuilder(connection))
.Where(predicate)
.ToSql(out parameters);
var sqlExpression = CreateSqlExpression<TEntity>(GetSqlBuilder(connection))
.Where(predicate);

if (firstRecordOnly)
{
// Only query the first result
sqlExpression.Page(1, 1);
}

sql += sqlExpression.ToSql(out parameters);
return sql;
}

Expand Down Expand Up @@ -153,7 +160,7 @@ public static Task<IEnumerable<TEntity>> SelectPagedAsync<TEntity>(this IDbConne
private static string BuildSelectPagedQuery<TEntity>(IDbConnection connection, Expression<Func<TEntity, bool>> predicate, int pageNumber, int pageSize, out DynamicParameters parameters)
{
// Start with the select query part
var sql = BuildSelectSql(connection, predicate, out parameters);
var sql = BuildSelectSql(connection, predicate, false, out parameters);

// Append the paging part including the order by
var keyColumns = Resolvers.KeyProperties(typeof(TEntity)).Select(p => Resolvers.Column(p.Property, connection));
Expand Down

0 comments on commit 4e1714a

Please sign in to comment.