Skip to content

Commit

Permalink
fix(knex): support em.count() on virtual entities
Browse files Browse the repository at this point in the history
Works only for SQL drivers.
  • Loading branch information
B4nan committed Aug 28, 2022
1 parent 95c8dd5 commit 5bb4ebe
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 18 deletions.
5 changes: 5 additions & 0 deletions packages/core/src/drivers/DatabaseDriver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ export abstract class DatabaseDriver<C extends Connection> implements IDatabaseD
throw new Error(`Virtual entities are not supported by ${this.constructor.name} driver.`);
}

/* istanbul ignore next */
async countVirtual<T>(entityName: string, where: FilterQuery<T>, options: CountOptions<T>): Promise<number> {
throw new Error(`Counting virtual entities is not supported by ${this.constructor.name} driver.`);
}

async aggregate(entityName: string, pipeline: any[]): Promise<any[]> {
throw new Error(`Aggregations are not supported by ${this.constructor.name} driver`);
}
Expand Down
50 changes: 46 additions & 4 deletions packages/knex/src/AbstractSqlDriver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import type { AbstractSqlPlatform } from './AbstractSqlPlatform';
import { QueryBuilder } from './query/QueryBuilder';
import { SqlEntityManager } from './SqlEntityManager';
import type { Field } from './typings';
import { QueryType } from './query';

export abstract class AbstractSqlDriver<C extends AbstractSqlConnection = AbstractSqlConnection> extends DatabaseDriver<C> {

Expand Down Expand Up @@ -122,7 +123,33 @@ export abstract class AbstractSqlDriver<C extends AbstractSqlConnection = Abstra
return res as EntityData<T>[];
}

protected async wrapVirtualExpressionInSubquery<T>(meta: EntityMetadata<T>, expression: string, where: FilterQuery<T>, options: FindOptions<T, any>) {
async countVirtual<T>(entityName: string, where: FilterQuery<T>, options: CountOptions<T>): Promise<number> {
const meta = this.metadata.get<T>(entityName);

/* istanbul ignore next */
if (!meta.expression) {
return 0;
}

if (typeof meta.expression === 'string') {
return this.wrapVirtualExpressionInSubquery(meta, meta.expression, where, options as Dictionary, QueryType.COUNT);
}

const em = this.createEntityManager(false);
em.setTransactionContext(options.ctx);
const res = meta.expression(em, where, options as Dictionary);

if (res instanceof QueryBuilder<T[]>) {
return this.wrapVirtualExpressionInSubquery(meta, res.getFormattedQuery(), where, options as Dictionary, QueryType.COUNT);
}

return res as any;
}

protected async wrapVirtualExpressionInSubquery<T>(meta: EntityMetadata<T>, expression: string, where: FilterQuery<T>, options: FindOptions<T, any>, type: QueryType.COUNT): Promise<number>;
protected async wrapVirtualExpressionInSubquery<T>(meta: EntityMetadata<T>, expression: string, where: FilterQuery<T>, options: FindOptions<T, any>, type: QueryType.SELECT): Promise<T[]>;
protected async wrapVirtualExpressionInSubquery<T>(meta: EntityMetadata<T>, expression: string, where: FilterQuery<T>, options: FindOptions<T, any>): Promise<T[]>;
protected async wrapVirtualExpressionInSubquery<T>(meta: EntityMetadata<T>, expression: string, where: FilterQuery<T>, options: FindOptions<T, any>, type = QueryType.SELECT): Promise<unknown> {
const qb = this.createQueryBuilder(meta.className, options?.ctx, options.connectionType, options.convertCustomTypes)
.limit(options?.limit, options?.offset);

Expand All @@ -132,11 +159,21 @@ export abstract class AbstractSqlDriver<C extends AbstractSqlConnection = Abstra

qb.where(where);

const kqb = qb.getKnexQuery();
kqb.clear('select').select('*');
kqb.fromRaw(`(${expression}) as ${this.platform.quoteIdentifier(qb.alias)}`);
const kqb = qb.getKnexQuery().clear('select');

if (type === QueryType.COUNT) {
kqb.select(qb.raw('count(*) as count'));
} else { // select
kqb.select('*');
}

kqb.fromRaw(`(${expression}) as ${this.platform.quoteIdentifier(qb.alias)}`);
const res = await this.execute<T[]>(kqb);

if (type === QueryType.COUNT) {
return (res[0] as Dictionary).count;
}

return res.map(row => this.mapResult(row, meta)!);
}

Expand Down Expand Up @@ -235,6 +272,11 @@ export abstract class AbstractSqlDriver<C extends AbstractSqlConnection = Abstra

async count<T extends AnyEntity<T>>(entityName: string, where: any, options: CountOptions<T> = {}): Promise<number> {
const meta = this.metadata.find(entityName);

if (meta?.virtual) {
return this.countVirtual<T>(entityName, where, options);
}

const qb = this.createQueryBuilder(entityName, options.ctx, options.connectionType, false)
.groupBy(options.groupBy!)
.having(options.having!)
Expand Down
4 changes: 4 additions & 0 deletions packages/mongodb/src/MongoDriver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ export class MongoDriver extends DatabaseDriver<MongoConnection> {
}

async count<T extends AnyEntity<T>>(entityName: string, where: FilterQuery<T>, options: CountOptions<T> = {}, ctx?: Transaction<ClientSession>): Promise<number> {
if (this.metadata.find(entityName)?.virtual) {
return this.countVirtual(entityName, where, options);
}

where = this.renameFields(entityName, where, true);
return this.rethrow(this.getConnection('read').countDocuments(entityName, where, ctx));
}
Expand Down
32 changes: 18 additions & 14 deletions tests/features/virtual-entities/virtual-entities.sqlite.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ describe('virtual entities (sqlite)', () => {
await createEntities(3);

const mock = mockLogger(orm);
const profiles = await orm.em.find(AuthorProfile, {});
const [profiles, total] = await orm.em.findAndCount(AuthorProfile, {});
expect(total).toBe(3);
expect(profiles).toEqual([
{
name: 'Jon Snow 1',
Expand Down Expand Up @@ -140,12 +141,13 @@ describe('virtual entities (sqlite)', () => {
expect(someProfiles4).toHaveLength(2);
expect(someProfiles4.map(p => p.name)).toEqual(['Jon Snow 2', 'Jon Snow 3']);

expect(mock.mock.calls).toHaveLength(5);
expect(mock.mock.calls[0][0]).toMatch(`select * from (${authorProfilesSQL}) as \`a0\``);
expect(mock.mock.calls[1][0]).toMatch(`select * from (${authorProfilesSQL}) as \`a0\` order by \`a0\`.\`name\` asc limit 2 offset 1`);
expect(mock.mock.calls[2][0]).toMatch(`select * from (${authorProfilesSQL}) as \`a0\` order by \`a0\`.\`name\` asc limit 2`);
expect(mock.mock.calls[3][0]).toMatch(`select * from (${authorProfilesSQL}) as \`a0\` where \`a0\`.\`name\` like 'Jon%' and \`a0\`.\`age\` >= 0 order by \`a0\`.\`name\` asc limit 2`);
expect(mock.mock.calls[4][0]).toMatch(`select * from (${authorProfilesSQL}) as \`a0\` where \`a0\`.\`name\` in ('Jon Snow 2', 'Jon Snow 3')`);
expect(mock.mock.calls).toHaveLength(6);
expect(mock.mock.calls[0][0]).toMatch(`select count(*) as count from (${authorProfilesSQL}) as \`a0\``);
expect(mock.mock.calls[1][0]).toMatch(`select * from (${authorProfilesSQL}) as \`a0\``);
expect(mock.mock.calls[2][0]).toMatch(`select * from (${authorProfilesSQL}) as \`a0\` order by \`a0\`.\`name\` asc limit 2 offset 1`);
expect(mock.mock.calls[3][0]).toMatch(`select * from (${authorProfilesSQL}) as \`a0\` order by \`a0\`.\`name\` asc limit 2`);
expect(mock.mock.calls[4][0]).toMatch(`select * from (${authorProfilesSQL}) as \`a0\` where \`a0\`.\`name\` like 'Jon%' and \`a0\`.\`age\` >= 0 order by \`a0\`.\`name\` asc limit 2`);
expect(mock.mock.calls[5][0]).toMatch(`select * from (${authorProfilesSQL}) as \`a0\` where \`a0\`.\`name\` in ('Jon Snow 2', 'Jon Snow 3')`);
expect(orm.em.getUnitOfWork().getIdentityMap().keys()).toHaveLength(0);
});

Expand All @@ -155,7 +157,8 @@ describe('virtual entities (sqlite)', () => {
await createEntities(3);

const mock = mockLogger(orm);
const books = await orm.em.find(BookWithAuthor, {});
const [books, total] = await orm.em.findAndCount(BookWithAuthor, {});
expect(total).toBe(9);
expect(books).toEqual([
{
title: 'My Life on the Wall, part 1/1',
Expand Down Expand Up @@ -230,12 +233,13 @@ describe('virtual entities (sqlite)', () => {
'inner join `tags_ordered` as `t1` on `b`.`id` = `t1`.`book4_id` ' +
'inner join `book_tag4` as `t` on `t1`.`book_tag4_id` = `t`.`id` ' +
'group by `b`.`id`';
expect(mock.mock.calls).toHaveLength(5);
expect(mock.mock.calls[0][0]).toMatch(`select * from (${sql}) as \`b0\``);
expect(mock.mock.calls[1][0]).toMatch(`select * from (${sql}) as \`b0\` order by \`b0\`.\`title\` asc limit 2 offset 1`);
expect(mock.mock.calls[2][0]).toMatch(`select * from (${sql}) as \`b0\` order by \`b0\`.\`title\` asc limit 2`);
expect(mock.mock.calls[3][0]).toMatch(`select * from (${sql}) as \`b0\` where \`b0\`.\`title\` like 'My Life%' and \`b0\`.\`author_name\` is not null order by \`b0\`.\`title\` asc limit 2`);
expect(mock.mock.calls[4][0]).toMatch(`select * from (${sql}) as \`b0\` where \`b0\`.\`title\` in ('My Life on the Wall, part 1/2', 'My Life on the Wall, part 1/3')`);
expect(mock.mock.calls).toHaveLength(6);
expect(mock.mock.calls[0][0]).toMatch(`select count(*) as count from (${sql}) as \`b0\``);
expect(mock.mock.calls[1][0]).toMatch(`select * from (${sql}) as \`b0\``);
expect(mock.mock.calls[2][0]).toMatch(`select * from (${sql}) as \`b0\` order by \`b0\`.\`title\` asc limit 2 offset 1`);
expect(mock.mock.calls[3][0]).toMatch(`select * from (${sql}) as \`b0\` order by \`b0\`.\`title\` asc limit 2`);
expect(mock.mock.calls[4][0]).toMatch(`select * from (${sql}) as \`b0\` where \`b0\`.\`title\` like 'My Life%' and \`b0\`.\`author_name\` is not null order by \`b0\`.\`title\` asc limit 2`);
expect(mock.mock.calls[5][0]).toMatch(`select * from (${sql}) as \`b0\` where \`b0\`.\`title\` in ('My Life on the Wall, part 1/2', 'My Life on the Wall, part 1/3')`);

expect(orm.em.getUnitOfWork().getIdentityMap().keys()).toHaveLength(0);
expect(mock.mock.calls[0][0]).toMatch(sql);
Expand Down

0 comments on commit 5bb4ebe

Please sign in to comment.