Skip to content

Commit

Permalink
feat(query-builder): respect discriminator column when joining STI re…
Browse files Browse the repository at this point in the history
…lation

Closes #4351
  • Loading branch information
B4nan committed Nov 5, 2023
1 parent a0e2c7f commit 57b7094
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 22 deletions.
8 changes: 4 additions & 4 deletions packages/core/src/EntityManager.ts
Expand Up @@ -106,10 +106,10 @@ export class EntityManager<D extends IDatabaseDriver = IDatabaseDriver> {
* @internal
*/
constructor(readonly config: Configuration,
private readonly driver: D,
private readonly metadata: MetadataStorage,
private readonly useContext = true,
private readonly eventManager = new EventManager(config.get('subscribers'))) { }
protected readonly driver: D,
protected readonly metadata: MetadataStorage,
protected readonly useContext = true,
protected readonly eventManager = new EventManager(config.get('subscribers'))) { }

/**
* Gets the Driver instance used by this EntityManager.
Expand Down
16 changes: 8 additions & 8 deletions packages/core/src/metadata/MetadataDiscovery.ts
Expand Up @@ -474,12 +474,12 @@ export class MetadataDiscovery {
}

if (!prop.joinColumns) {
prop.joinColumns = prop.referencedColumnNames.map(referencedColumnName => this.namingStrategy.joinKeyColumnName(meta.root.className, referencedColumnName, meta.compositePK));
prop.joinColumns = prop.referencedColumnNames.map(referencedColumnName => this.namingStrategy.joinKeyColumnName(meta.className, referencedColumnName, meta.compositePK));
}

if (!prop.inverseJoinColumns) {
const meta2 = this.metadata.get(prop.type);
prop.inverseJoinColumns = this.initManyToOneFieldName(prop, meta2.root.className);
prop.inverseJoinColumns = this.initManyToOneFieldName(prop, meta2.className);
}
}

Expand Down Expand Up @@ -608,7 +608,7 @@ export class MetadataDiscovery {
}

schemaName ??= meta.schema;
const targetType = prop.targetMeta!.root.className;
const targetType = prop.targetMeta!.className;
const data = new EntityMetadata({
name: prop.pivotTable,
className: prop.pivotTable,
Expand All @@ -626,9 +626,9 @@ export class MetadataDiscovery {
}

// handle self-referenced m:n with same default field names
if (meta.root.name === targetType && prop.joinColumns.every((joinColumn, idx) => joinColumn === prop.inverseJoinColumns[idx])) {
prop.joinColumns = prop.referencedColumnNames.map(name => this.namingStrategy.joinKeyColumnName(meta.root.className + '_1', name, meta.compositePK));
prop.inverseJoinColumns = prop.referencedColumnNames.map(name => this.namingStrategy.joinKeyColumnName(meta.root.className + '_2', name, meta.compositePK));
if (meta.name === targetType && prop.joinColumns.every((joinColumn, idx) => joinColumn === prop.inverseJoinColumns[idx])) {
prop.joinColumns = prop.referencedColumnNames.map(name => this.namingStrategy.joinKeyColumnName(meta.className + '_1', name, meta.compositePK));
prop.inverseJoinColumns = prop.referencedColumnNames.map(name => this.namingStrategy.joinKeyColumnName(meta.className + '_2', name, meta.compositePK));

if (prop.inversedBy) {
const prop2 = this.metadata.get(targetType).properties[prop.inversedBy];
Expand All @@ -637,8 +637,8 @@ export class MetadataDiscovery {
}
}

data.properties[meta.root.name + '_owner'] = this.definePivotProperty(prop, meta.root.name + '_owner', meta.root.name!, targetType + '_inverse', true);
data.properties[targetType + '_inverse'] = this.definePivotProperty(prop, targetType + '_inverse', targetType, meta.root.name + '_owner', false);
data.properties[meta.name + '_owner'] = this.definePivotProperty(prop, meta.name + '_owner', meta.name!, targetType + '_inverse', true);
data.properties[targetType + '_inverse'] = this.definePivotProperty(prop, targetType + '_inverse', targetType, meta.name + '_owner', false);

return this.metadata.set(data.className, data);
}
Expand Down
28 changes: 21 additions & 7 deletions packages/core/src/utils/Utils.ts
Expand Up @@ -415,13 +415,27 @@ export class Utils {
/**
* Renames object key, keeps order of properties.
*/
static renameKey<T>(payload: T, from: string | keyof T, to: string): void {
if (Utils.isObject(payload) && (from as string) in payload && !(to in payload)) {
Object.keys(payload).forEach(key => {
const value = payload[key];
delete payload[key];
payload[from === key ? to : key as keyof T] = value;
}, payload);
static renameKey<T>(payload: T, from: string | keyof T, to: string, recursive = false): void {
if (Utils.isObject(payload)) {
if ((from as string) in payload && !(to in payload)) {
Object.keys(payload).forEach(key => {
const value = payload[key];
delete payload[key];
payload[from === key ? to : key as keyof T] = value;
}, payload);
}

if (recursive) {
Object.keys(payload).forEach(key => {
Utils.renameKey(payload[key], from, to, recursive);
});
}

return;
}

if (recursive && Array.isArray(payload)) {
payload.forEach(item => Utils.renameKey(item, from, to, recursive));
}
}

Expand Down
17 changes: 16 additions & 1 deletion packages/knex/src/SqlEntityManager.ts
@@ -1,5 +1,15 @@
import type { Knex } from 'knex';
import { EntityManager, type AnyEntity, type ConnectionType, type Dictionary, type EntityData, type EntityName, type EntityRepository, type GetRepository, type QueryResult } from '@mikro-orm/core';
import {
EntityManager,
type AnyEntity,
type ConnectionType,
type EntityData,
type EntityName,
type EntityRepository,
type GetRepository,
type QueryResult,
type FilterQuery,
} from '@mikro-orm/core';
import type { AbstractSqlDriver } from './AbstractSqlDriver';
import { QueryBuilder } from './query';
import type { SqlEntityRepository } from './SqlEntityRepository';
Expand Down Expand Up @@ -40,4 +50,9 @@ export class SqlEntityManager<D extends AbstractSqlDriver = AbstractSqlDriver> e
return super.getRepository<T, U>(entityName);
}

protected override applyDiscriminatorCondition<Entity extends object>(entityName: string, where: FilterQuery<Entity>): FilterQuery<Entity> {
// this is handled in QueryBuilder now for SQL drivers
return where;
}

}
27 changes: 25 additions & 2 deletions packages/knex/src/query/QueryBuilder.ts
Expand Up @@ -1091,6 +1091,28 @@ export class QueryBuilder<T extends object = AnyEntity> {
return qb;
}

private applyDiscriminatorCondition(): void {
const meta = this.mainAlias.metadata;

if (!meta?.discriminatorValue) {
return;
}

const types = Object.values(meta.root.discriminatorMap!).map(cls => this.metadata.find(cls)!);
const children: EntityMetadata[] = [];
const lookUpChildren = (ret: EntityMetadata[], type: string) => {
const children = types.filter(meta2 => meta2.extends === type);
children.forEach(m => lookUpChildren(ret, m.className));
ret.push(...children.filter(c => c.discriminatorValue));

return children;
};
lookUpChildren(children, meta.className);
this.andWhere({
[meta.root.discriminatorColumn!]: children.length > 0 ? { $in: [meta.discriminatorValue, ...children.map(c => c.discriminatorValue)] } : meta.discriminatorValue,
});
}

private finalize(): void {
if (this.finalized) {
return;
Expand All @@ -1101,6 +1123,7 @@ export class QueryBuilder<T extends object = AnyEntity> {
}

const meta = this.mainAlias.metadata as EntityMetadata<T>;
this.applyDiscriminatorCondition();

if (meta && this.flags.has(QueryFlag.AUTO_JOIN_ONE_TO_ONE_OWNER)) {
const relationsToPopulate = this._populate.map(({ field }) => field);
Expand Down Expand Up @@ -1302,8 +1325,8 @@ export class QueryBuilder<T extends object = AnyEntity> {
const pivotAlias = this.getNextAlias(pivotMeta.name!);

this._joins[field] = this.helper.joinPivotTable(field, prop, this.mainAlias.aliasName, pivotAlias, 'leftJoin');
Utils.renameKey(this._cond, `${field}.${owner.name}`, Utils.getPrimaryKeyHash(owner.fieldNames.map(fieldName => `${pivotAlias}.${fieldName}`)));
Utils.renameKey(this._cond, `${field}.${inverse.name}`, Utils.getPrimaryKeyHash(inverse.fieldNames.map(fieldName => `${pivotAlias}.${fieldName}`)));
Utils.renameKey(this._cond, `${field}.${owner.name}`, Utils.getPrimaryKeyHash(owner.fieldNames.map(fieldName => `${pivotAlias}.${fieldName}`)), true);
Utils.renameKey(this._cond, `${field}.${inverse.name}`, Utils.getPrimaryKeyHash(inverse.fieldNames.map(fieldName => `${pivotAlias}.${fieldName}`)), true);
this._populateMap[field] = this._joins[field].alias;
}

Expand Down
6 changes: 6 additions & 0 deletions packages/knex/src/query/QueryBuilderHelper.ts
Expand Up @@ -260,6 +260,12 @@ export class QueryBuilderHelper {
conditions.push(`${this.knex.ref(left)} = ${this.knex.ref(right)}`);
});

if (join.prop.targetMeta!.discriminatorValue && !join.path?.endsWith('[pivot]')) {
const typeProperty = join.prop.targetMeta!.root.discriminatorColumn!;
const alias = !join.prop.owner ? join.inverseAlias ?? join.alias : join.ownerAlias;
join.cond[`${alias}.${typeProperty}`] = join.prop.targetMeta!.discriminatorValue;
}

Object.keys(join.cond).forEach(key => {
conditions.push(this.processJoinClause(key, join.cond[key], params));
});
Expand Down
97 changes: 97 additions & 0 deletions tests/features/single-table-inheritance/GH4351.test.ts
@@ -0,0 +1,97 @@
import { MikroORM } from '@mikro-orm/better-sqlite';
import { Collection, Entity, Enum, ManyToMany, ManyToOne, OneToMany, PrimaryKey, Property, Rel } from '@mikro-orm/core';

@Entity()
class User {

@PrimaryKey()
id!: number;

@Property({ nullable: true })
name?: string;

@OneToMany(() => Post, post => post.user)
posts = new Collection<Post>(this);

@ManyToMany(() => Post)
posts2 = new Collection<Post>(this);

}

@Entity({
discriminatorColumn: 'type',
abstract: true,
})
abstract class BaseEntity {

@PrimaryKey()
id!: number;

@Enum()
type!: 'post' | 'comment';

@ManyToOne(() => User, {
fieldName: 'userId',
deleteRule: 'set null',
nullable: true,
ref: true,
})
user?: Rel<User>;

}

@Entity({ discriminatorValue: 'post' })
class Post extends BaseEntity {

@Property({ nullable: true })
postField?: string;

}

@Entity({ discriminatorValue: 'comment' })
class Comment extends BaseEntity {

@Property({ nullable: true })
commentField?: string;

}

let orm: MikroORM;

beforeAll(async () => {
orm = await MikroORM.init({
dbName: ':memory:',
entities: [BaseEntity, Post, Comment],
});
await orm.schema.createSchema();
});

afterAll(async () => {
await orm.close(true);
});

test('it should add discriminator to the query', async () => {
const user = new User();
user.posts2.add(new Post());
const post = new Post();
post.user = user;
const comment = new Comment();
comment.user = user;

await orm.em.fork().persistAndFlush([user, post, comment]);

const qb1 = orm.em.qb(User, 'u').join('u.posts', 'p');
const res1 = await qb1;
expect(qb1.getFormattedQuery()).toMatch("select `u`.* from `user` as `u` inner join `base_entity` as `p` on `u`.`id` = `p`.`userId` and `p`.`type` = 'post'");
expect(res1).toHaveLength(1);

const qb2 = orm.em.qb(Post, 'p');
const res2 = await qb2;
expect(qb2.getFormattedQuery()).toMatch("select `p`.* from `base_entity` as `p` where `p`.`type` = 'post'");
expect(res2).toHaveLength(2);

const qb3 = orm.em.qb(User, 'u').join('u.posts2', 'p');
const res3 = await qb3;
expect(qb3.getFormattedQuery()).toMatch("select `u`.* from `user` as `u` inner join `user_posts2` as `u1` on `u`.`id` = `u1`.`user_id` inner join `base_entity` as `p` on `u1`.`post_id` = `p`.`id` and `p`.`type` = 'post'");
expect(res3).toHaveLength(1);
});

0 comments on commit 57b7094

Please sign in to comment.