Skip to content

Commit

Permalink
feat(core): implement transaction lifecycle hooks (#1213)
Browse files Browse the repository at this point in the history
Current lifecycle hooks cover entity and unit of work events, but
transaction-related ones are missing. Add beforeTransactionStart,
afterTransactionStart, beforeTransactionCommit, afterTransactionCommit,
beforeTransactionRollback, and afterTransactionRollback

Closes #1175
  • Loading branch information
rhyek committed Dec 17, 2020
1 parent 26f62ca commit 0f81ff1
Show file tree
Hide file tree
Showing 12 changed files with 868 additions and 42 deletions.
10 changes: 5 additions & 5 deletions packages/core/src/EntityManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { AnyEntity, Dictionary, EntityData, EntityMetadata, EntityName, FilterDe
import { LoadStrategy, LockMode, QueryOrderMap, ReferenceType, SCALAR_TYPES } from './enums';
import { MetadataStorage } from './metadata';
import { Transaction } from './connections';
import { EventManager } from './events';
import { EventManager, TransactionEventBroadcaster } from './events';
import { EntityComparator } from './utils/EntityComparator';
import { OptimisticLockError, ValidationError } from './errors';

Expand Down Expand Up @@ -335,31 +335,31 @@ export class EntityManager<D extends IDatabaseDriver = IDatabaseDriver> {
await em.flush();

return ret;
}, ctx);
}, ctx, new TransactionEventBroadcaster(em));
});
}

/**
* Starts new transaction bound to this EntityManager. Use `ctx` parameter to provide the parent when nesting transactions.
*/
async begin(ctx?: Transaction): Promise<void> {
this.transactionContext = await this.getConnection('write').begin(ctx);
this.transactionContext = await this.getConnection('write').begin(ctx, new TransactionEventBroadcaster(this));
}

/**
* Commits the transaction bound to this EntityManager. Flushes before doing the actual commit query.
*/
async commit(): Promise<void> {
await this.flush();
await this.getConnection('write').commit(this.transactionContext);
await this.getConnection('write').commit(this.transactionContext, new TransactionEventBroadcaster(this));
delete this.transactionContext;
}

/**
* Rollbacks the transaction bound to this EntityManager.
*/
async rollback(): Promise<void> {
await this.getConnection('write').rollback(this.transactionContext);
await this.getConnection('write').rollback(this.transactionContext, new TransactionEventBroadcaster(this));
delete this.transactionContext;
}

Expand Down
9 changes: 5 additions & 4 deletions packages/core/src/connections/Connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { Configuration, ConnectionOptions, Utils } from '../utils';
import { MetadataStorage } from '../metadata';
import { Dictionary } from '../typings';
import { Platform } from '../platforms/Platform';
import { TransactionEventBroadcaster } from '../events/TransactionEventBroadcaster';

export abstract class Connection {

Expand Down Expand Up @@ -41,19 +42,19 @@ export abstract class Connection {
*/
abstract getDefaultClientUrl(): string;

async transactional<T>(cb: (trx: Transaction) => Promise<T>, ctx?: Transaction): Promise<T> {
async transactional<T>(cb: (trx: Transaction) => Promise<T>, ctx?: Transaction, eventBroadcaster?: TransactionEventBroadcaster): Promise<T> {
throw new Error(`Transactions are not supported by current driver`);
}

async begin(ctx?: Transaction): Promise<unknown> {
async begin(ctx?: Transaction, eventBroadcaster?: TransactionEventBroadcaster): Promise<unknown> {
throw new Error(`Transactions are not supported by current driver`);
}

async commit(ctx: Transaction): Promise<void> {
async commit(ctx: Transaction, eventBroadcaster?: TransactionEventBroadcaster): Promise<void> {
throw new Error(`Transactions are not supported by current driver`);
}

async rollback(ctx: Transaction): Promise<void> {
async rollback(ctx: Transaction, eventBroadcaster?: TransactionEventBroadcaster): Promise<void> {
throw new Error(`Transactions are not supported by current driver`);
}

Expand Down
8 changes: 8 additions & 0 deletions packages/core/src/enums.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,12 @@ export enum EventType {
beforeFlush = 'beforeFlush',
onFlush = 'onFlush',
afterFlush = 'afterFlush',
beforeTransactionStart = 'beforeTransactionStart',
afterTransactionStart = 'afterTransactionStart',
beforeTransactionCommit = 'beforeTransactionCommit',
afterTransactionCommit = 'afterTransactionCommit',
beforeTransactionRollback = 'beforeTransactionRollback',
afterTransactionRollback = 'afterTransactionRollback',
}

export type TransactionEventType = EventType.beforeTransactionStart | EventType.afterTransactionStart | EventType.beforeTransactionCommit | EventType.afterTransactionCommit | EventType.beforeTransactionRollback | EventType.afterTransactionRollback;
11 changes: 6 additions & 5 deletions packages/core/src/events/EventManager.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { AnyEntity, EntityMetadata } from '../typings';
import { EventArgs, EventSubscriber, FlushEventArgs } from './EventSubscriber';
import { EventArgs, EventSubscriber, FlushEventArgs, TransactionEventArgs } from './EventSubscriber';
import { Utils } from '../utils';
import { EventType } from '../enums';
import { EventType, TransactionEventType } from '../enums';

export class EventManager {

Expand All @@ -22,9 +22,10 @@ export class EventManager {
});
}

dispatchEvent<T extends AnyEntity<T>>(event: TransactionEventType, args: TransactionEventArgs): unknown;
dispatchEvent<T extends AnyEntity<T>>(event: EventType.onInit, args: Partial<EventArgs<T>>): unknown;
dispatchEvent<T extends AnyEntity<T>>(event: EventType, args: Partial<EventArgs<T> | FlushEventArgs>): Promise<unknown>;
dispatchEvent<T extends AnyEntity<T>>(event: EventType, args: Partial<EventArgs<T> | FlushEventArgs>): Promise<unknown> | unknown {
dispatchEvent<T extends AnyEntity<T>>(event: EventType, args: Partial<EventArgs<T> | FlushEventArgs | TransactionEventArgs>): Promise<unknown> | unknown {
const listeners: [EventType, EventSubscriber<T>][] = [];
const entity: T = (args as EventArgs<T>).entity;

Expand All @@ -41,10 +42,10 @@ export class EventManager {
}

if (event === EventType.onInit) {
return listeners.forEach(listener => listener[1][listener[0]]!(args as (EventArgs<T> & FlushEventArgs)));
return listeners.forEach(listener => listener[1][listener[0]]!(args as (EventArgs<T> & FlushEventArgs & TransactionEventArgs)));
}

return Utils.runSerial(listeners, listener => listener[1][listener[0]]!(args as (EventArgs<T> & FlushEventArgs)) as Promise<void>);
return Utils.runSerial(listeners, listener => listener[1][listener[0]]!(args as (EventArgs<T> & FlushEventArgs & TransactionEventArgs)) as Promise<void>);
}

hasListeners<T extends AnyEntity<T>>(event: EventType, meta: EntityMetadata<T>): boolean {
Expand Down
13 changes: 13 additions & 0 deletions packages/core/src/events/EventSubscriber.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { EntityName } from '../typings';
import { EntityManager } from '../EntityManager';
import { ChangeSet, UnitOfWork } from '../unit-of-work';
import { Transaction } from '../connections';

export interface EventArgs<T> {
entity: T;
Expand All @@ -12,6 +13,11 @@ export interface FlushEventArgs extends Omit<EventArgs<unknown>, 'entity'> {
uow: UnitOfWork;
}

export interface TransactionEventArgs extends Omit<EventArgs<unknown>, 'entity' | 'changeSet'> {
transaction?: Transaction;
uow?: UnitOfWork;
}

export interface EventSubscriber<T = any> {
getSubscribedEntities?(): EntityName<T>[];
onInit?(args: EventArgs<T>): void;
Expand All @@ -24,4 +30,11 @@ export interface EventSubscriber<T = any> {
beforeFlush?(args: FlushEventArgs): Promise<void>;
onFlush?(args: FlushEventArgs): Promise<void>;
afterFlush?(args: FlushEventArgs): Promise<void>;

beforeTransactionStart?(args: TransactionEventArgs): Promise<void>;
afterTransactionStart?(args: TransactionEventArgs): Promise<void>;
beforeTransactionCommit?(args: TransactionEventArgs): Promise<void>;
afterTransactionCommit?(args: TransactionEventArgs): Promise<void>;
beforeTransactionRollback?(args: TransactionEventArgs): Promise<void>;
afterTransactionRollback?(args: TransactionEventArgs): Promise<void>;
}
17 changes: 17 additions & 0 deletions packages/core/src/events/TransactionEventBroadcaster.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { Transaction } from '../connections';
import { EntityManager } from '../EntityManager';
import { TransactionEventType } from '../enums';
import { UnitOfWork } from '../unit-of-work';

export class TransactionEventBroadcaster {

constructor(
private entityManager: EntityManager,
private uow?: UnitOfWork
) {}

async dispatchEvent(event: TransactionEventType, transaction?: Transaction) {
await this.entityManager.getEventManager().dispatchEvent(event, { em: this.entityManager, transaction, uow: this.uow });
}

}
1 change: 1 addition & 0 deletions packages/core/src/events/index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
export * from './EventSubscriber';
export * from './EventManager';
export * from './TransactionEventBroadcaster';
3 changes: 2 additions & 1 deletion packages/core/src/unit-of-work/UnitOfWork.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { EntityManager } from '../EntityManager';
import { Cascade, EventType, LockMode, ReferenceType } from '../enums';
import { OptimisticLockError, ValidationError } from '../errors';
import { Transaction } from '../connections';
import { TransactionEventBroadcaster } from '../events';
import { IdentityMap } from './IdentityMap';

export class UnitOfWork {
Expand Down Expand Up @@ -223,7 +224,7 @@ export class UnitOfWork {
const runInTransaction = !this.em.isInTransaction() && platform.supportsTransactions() && this.em.config.get('implicitTransactions');

if (runInTransaction) {
await this.em.getConnection('write').transactional(trx => this.persistToDatabase(groups, trx));
await this.em.getConnection('write').transactional(trx => this.persistToDatabase(groups, trx), undefined, new TransactionEventBroadcaster(this.em, this));
} else {
await this.persistToDatabase(groups, this.em.getTransactionContext());
}
Expand Down
58 changes: 49 additions & 9 deletions packages/knex/src/AbstractSqlConnection.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import Knex, { Config, QueryBuilder, Raw, Client, Transaction as KnexTransaction } from 'knex';
import { readFile } from 'fs-extra';
import { AnyEntity, Configuration, Connection, ConnectionOptions, EntityData, QueryResult, Transaction, Utils } from '@mikro-orm/core';
import {
AnyEntity, Configuration, Connection, ConnectionOptions, EntityData, EventType, QueryResult,
Transaction, TransactionEventBroadcaster, Utils,
} from '@mikro-orm/core';
import { AbstractSqlPlatform } from './AbstractSqlPlatform';

const parentTransactionSymbol = Symbol('parentTransaction');

function isRootTransaction<T>(trx: Transaction<T>) {
return !Object.getOwnPropertySymbols(trx).includes(parentTransactionSymbol);
}

export abstract class AbstractSqlConnection extends Connection {

protected platform!: AbstractSqlPlatform;
Expand Down Expand Up @@ -30,21 +39,52 @@ export abstract class AbstractSqlConnection extends Connection {
}
}

async transactional<T>(cb: (trx: Transaction<KnexTransaction>) => Promise<T>, ctx?: Transaction<KnexTransaction>): Promise<T> {
return (ctx || this.client).transaction(cb);
async transactional<T>(cb: (trx: Transaction<KnexTransaction>) => Promise<T>, ctx?: Transaction<KnexTransaction>, eventBroadcaster?: TransactionEventBroadcaster): Promise<T> {
const trx = await this.begin(ctx, eventBroadcaster);
try {
const ret = await cb(trx);
await this.commit(trx, eventBroadcaster);
return ret;
} catch (error) {
await this.rollback(trx, eventBroadcaster);
throw error;
}
}

async begin(ctx?: KnexTransaction): Promise<KnexTransaction> {
return (ctx || this.client).transaction();
async begin(ctx?: KnexTransaction, eventBroadcaster?: TransactionEventBroadcaster): Promise<KnexTransaction> {
if (!ctx) {
await eventBroadcaster?.dispatchEvent(EventType.beforeTransactionStart);
}
const trx = await (ctx || this.client).transaction();
if (!ctx) {
await eventBroadcaster?.dispatchEvent(EventType.afterTransactionStart, trx);
} else {
trx[parentTransactionSymbol] = ctx;
}
return trx;
}

async commit(ctx: KnexTransaction): Promise<void> {
async commit(ctx: KnexTransaction, eventBroadcaster?: TransactionEventBroadcaster): Promise<void> {
const runTrxHooks = isRootTransaction(ctx);
if (runTrxHooks) {
await eventBroadcaster?.dispatchEvent(EventType.beforeTransactionCommit, ctx);
}
ctx.commit();
return ctx.executionPromise; // https://github.com/knex/knex/issues/3847#issuecomment-626330453
await ctx.executionPromise; // https://github.com/knex/knex/issues/3847#issuecomment-626330453
if (runTrxHooks) {
await eventBroadcaster?.dispatchEvent(EventType.afterTransactionCommit, ctx);
}
}

async rollback(ctx: KnexTransaction): Promise<void> {
return ctx.rollback();
async rollback(ctx: KnexTransaction, eventBroadcaster?: TransactionEventBroadcaster): Promise<void> {
const runTrxHooks = isRootTransaction(ctx);
if (runTrxHooks) {
await eventBroadcaster?.dispatchEvent(EventType.beforeTransactionRollback, ctx);
}
await ctx.rollback();
if (runTrxHooks) {
await eventBroadcaster?.dispatchEvent(EventType.afterTransactionRollback, ctx);
}
}

async execute<T extends QueryResult | EntityData<AnyEntity> | EntityData<AnyEntity>[] = EntityData<AnyEntity>[]>(queryOrKnex: string | QueryBuilder | Raw, params: any[] = [], method: 'all' | 'get' | 'run' = 'all', ctx?: Transaction): Promise<T> {
Expand Down
43 changes: 27 additions & 16 deletions packages/mongodb/src/MongoConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
import { inspect } from 'util';
import {
Connection, ConnectionConfig, QueryResult, Transaction, Utils, QueryOrder, QueryOrderMap,
FilterQuery, AnyEntity, EntityName, Dictionary, EntityData,
FilterQuery, AnyEntity, EntityName, Dictionary, EntityData, TransactionEventBroadcaster, EventType,
} from '@mikro-orm/core';

export class MongoConnection extends Connection {
Expand Down Expand Up @@ -148,39 +148,50 @@ export class MongoConnection extends Connection {
return this.runQuery<T, number>('countDocuments', collection, undefined, where, ctx);
}

async transactional<T>(cb: (trx: Transaction<ClientSession>) => Promise<T>, ctx?: Transaction<ClientSession>): Promise<T> {
const session = ctx || this.client.startSession();
let ret: T = null as unknown as T;

async transactional<T>(cb: (trx: Transaction<ClientSession>) => Promise<T>, ctx?: Transaction<ClientSession>, eventBroadcaster?: TransactionEventBroadcaster): Promise<T> {
const session = await this.begin(ctx, eventBroadcaster);
try {
this.logQuery('db.begin();');
await session.withTransaction(async () => ret = await cb(session));
const ret = await cb(session);
await this.commit(session, eventBroadcaster);
return ret;
} catch (error) {
await this.rollback(session, eventBroadcaster);
throw error;
} finally {
session.endSession();
this.logQuery('db.commit();');
} catch (e) {
this.logQuery('db.rollback();');
throw e;
}

return ret;
}

async begin(ctx?: ClientSession): Promise<ClientSession> {
async begin(ctx?: ClientSession, eventBroadcaster?: TransactionEventBroadcaster): Promise<ClientSession> {
if (!ctx) {
/* istanbul ignore next */
await eventBroadcaster?.dispatchEvent(EventType.beforeTransactionStart);
}
const session = ctx || this.client.startSession();
session.startTransaction();
this.logQuery('db.begin();');
/* istanbul ignore next */
await eventBroadcaster?.dispatchEvent(EventType.afterTransactionStart, session);

return session;
}

async commit(ctx: ClientSession): Promise<void> {
async commit(ctx: ClientSession, eventBroadcaster?: TransactionEventBroadcaster): Promise<void> {
/* istanbul ignore next */
await eventBroadcaster?.dispatchEvent(EventType.beforeTransactionCommit, ctx);
await ctx.commitTransaction();
this.logQuery('db.commit();');
/* istanbul ignore next */
await eventBroadcaster?.dispatchEvent(EventType.afterTransactionCommit, ctx);
}

async rollback(ctx: ClientSession): Promise<void> {
async rollback(ctx: ClientSession, eventBroadcaster?: TransactionEventBroadcaster): Promise<void> {
/* istanbul ignore next */
await eventBroadcaster?.dispatchEvent(EventType.beforeTransactionRollback, ctx);
await ctx.abortTransaction();
this.logQuery('db.rollback();');
/* istanbul ignore next */
await eventBroadcaster?.dispatchEvent(EventType.afterTransactionRollback, ctx);
}

protected logQuery(query: string, took?: number): void {
Expand Down
Loading

0 comments on commit 0f81ff1

Please sign in to comment.