diff --git a/packages/dandi-contrib/data-pg/src/pg-db-queryable.ts b/packages/dandi-contrib/data-pg/src/pg-db-queryable.ts index c9d882e4..3bee847c 100644 --- a/packages/dandi-contrib/data-pg/src/pg-db-queryable.ts +++ b/packages/dandi-contrib/data-pg/src/pg-db-queryable.ts @@ -23,23 +23,12 @@ export class PgDbQueryableBase implements D ) {} public async query(cmd: string, ...args: any[]): Promise { - let result: QueryResult - if (args) { - args.forEach((arg, index) => { - args[index] = this.formatArg(arg) - }) - } - try { - result = await this.client.query(cmd, args) - } catch (err) { - throw new PgDbQueryError(err) - } - return result.rows + return this.queryInternal(cmd, args) } public async queryModel(model: Constructor, cmd: string, ...args: any[]): Promise { cmd = this.replaceSelectList(model, cmd) - const result = await this.query(cmd, ...args) + const result = await this.queryInternal(cmd, args) if (!result || !result.length) { return result } @@ -93,6 +82,21 @@ export class PgDbQueryableBase implements D return cmd.replace(/select\s+([\w\s,._]+)\s+from/i, `select\n${newSelect.join(',\n')}\nfrom`) } + protected async queryInternal(cmd: string, args: any[]): Promise { + let result: QueryResult + if (args) { + args.forEach((arg, index) => { + args[index] = this.formatArg(arg) + }) + } + try { + result = await this.client.query(cmd, args) + } catch (err) { + throw new PgDbQueryError(err) + } + return result.rows + } + private formatArg(arg: any): any { if (arg instanceof Uuid) { return `{${arg}}` diff --git a/packages/dandi-contrib/data-pg/src/pg-db-transaction-client.spec.ts b/packages/dandi-contrib/data-pg/src/pg-db-transaction-client.spec.ts index 11d8d57b..8f1f527c 100644 --- a/packages/dandi-contrib/data-pg/src/pg-db-transaction-client.spec.ts +++ b/packages/dandi-contrib/data-pg/src/pg-db-transaction-client.spec.ts @@ -12,7 +12,7 @@ import { ModelBuilderFixture } from '@dandi/model-builder/testing' import { expect } from 'chai' import { stub } from 'sinon' -describe.only('PgDbTransactionClient', function() { +describe('PgDbTransactionClient', function() { const harness = stubHarness(PgDbTransactionClient, PgDbPoolClientFixture.factory, @@ -93,6 +93,88 @@ describe.only('PgDbTransactionClient', function() { expect(this.client.query.thirdCall.args).to.deep.equal(['SELECT more FROM stuff', []]) }) + + it('throws an error if called when the transaction cannot accept a query', async function() { + + this.transactionClient.state = 'COMMITTING' + + await expect(this.transactionClient.query('SELECT foo FROM bar')) + .to.be.rejectedWith(InvalidTransactionStateError) + + }) + + }) + + describe('queryModel', function() { + + class TestModel {} + + beforeEach(function() { + stub(this.transactionClient, 'rollback') + }) + + it('begins the transaction if it has not already begun and sets the internal state to READY', async function() { + + await this.transactionClient.queryModel(TestModel, 'SELECT foo FROM bar') + + expect(this.transactionClient.state).to.equal('READY') + expect(this.client.query).to.have.been.calledTwice + expect(this.client.query.firstCall.args).to.deep.equal(['BEGIN', []]) + expect(this.client.query.secondCall.args).to.deep.equal(['SELECT foo FROM bar', []]) + }) + + it('does not send additional BEGIN queries if the transaction has already begun', async function() { + await this.transactionClient.queryModel(TestModel, 'INSERT INTO bar (foo) VALUES ($1)', 42) + await this.transactionClient.queryModel(TestModel, 'SELECT foo FROM bar') + + expect(this.client.query).to.have.been.calledThrice + expect(this.client.query.firstCall.args).to.deep.equal(['BEGIN', []]) + expect(this.client.query.secondCall.args).to.deep.equal(['INSERT INTO bar (foo) VALUES ($1)', [42]]) + expect(this.client.query.thirdCall.args).to.deep.equal(['SELECT foo FROM bar', []]) + }) + + it('rolls the transaction back if an exception is thrown and rethrows the error', async function() { + const catcher = stub() + this.client.query.onSecondCall().callsFake(() => { + throw new Error() + }) + this.client.query.onThirdCall().returns({ rows: [] }) + + try { + await this.transactionClient.queryModel(TestModel, 'SELECT foo FROM bar') + } catch (err) { + catcher(err) + } + expect(catcher).to.have.been.calledOnce + expect(this.client.query).to.have.been.calledTwice + expect(this.client.query.firstCall.args).to.deep.equal(['BEGIN', []]) + expect(this.client.query.secondCall.args).to.deep.equal(['SELECT foo FROM bar', []]) + expect(this.transactionClient.rollback).to.have.been.calledOnce + }) + + it('waits for an existing state transition before continuing', async function() { + + const firstQuery = this.transactionClient.queryModel(TestModel, 'SELECT foo FROM bar') + const secondQuery = this.transactionClient.queryModel(TestModel, 'SELECT more FROM stuff') + + await secondQuery + + expect(this.client.query).to.have.been.calledThrice + expect(this.client.query.firstCall.args).to.deep.equal(['BEGIN', []]) + expect(this.client.query.secondCall.args).to.deep.equal(['SELECT foo FROM bar', []]) + expect(this.client.query.thirdCall.args).to.deep.equal(['SELECT more FROM stuff', []]) + + }) + + it('throws an error if called when the transaction cannot accept a query', async function() { + + this.transactionClient.state = 'COMMITTING' + + await expect(this.transactionClient.queryModel(TestModel, 'SELECT foo FROM bar')) + .to.be.rejectedWith(InvalidTransactionStateError) + + }) + }) describe('commit', function() { @@ -150,6 +232,21 @@ describe.only('PgDbTransactionClient', function() { expect(receivedErr).to.be.instanceof(PgDbQueryError) expect(receivedErr.innerError).to.equal(err) }) + + it('rethrows the error if it cannot roll back', async function() { + + this.transactionClient.state = 'READY' + const err = new Error('Your llama is lloose!') + + this.client.query.callsFake(() => { + this.transactionClient.state = 'ROLLED_BACK' + return Promise.reject(err) + }) + + const commitErr = await expect(this.transactionClient.commit()).to.be.rejected + expect(commitErr.innerError).to.equal(err) + + }) }) describe('rollback', function() { diff --git a/packages/dandi-contrib/data-pg/src/pg-db-transaction-client.ts b/packages/dandi-contrib/data-pg/src/pg-db-transaction-client.ts index 082179e3..4aae2a23 100644 --- a/packages/dandi-contrib/data-pg/src/pg-db-transaction-client.ts +++ b/packages/dandi-contrib/data-pg/src/pg-db-transaction-client.ts @@ -71,7 +71,7 @@ export class PgDbTransactionClient extends PgDbQueryableBase imp } public async query(cmd: string, ...args: any[]): Promise { - return this.mutex.runLocked(async (lock) => { + return await this.mutex.runLocked(async (lock) => { await this.safeBeginTransaction() try { return await super.query(cmd, ...args) @@ -85,7 +85,7 @@ export class PgDbTransactionClient extends PgDbQueryableBase imp } public async queryModel(model: Constructor, cmd: string, ...args: any[]): Promise { - return this.mutex.runLocked(async (lock) => { + return await this.mutex.runLocked(async (lock) => { await this.safeBeginTransaction() try { return await super.queryModel(model, cmd, ...args) @@ -147,10 +147,6 @@ export class PgDbTransactionClient extends PgDbQueryableBase imp } private validateTransactionAction(action: TransactionAction): void { - // if (this.state === TRANSITIONS[action]) { - // return - // } - if (!ALLOWED_ACTIONS[this.state].includes(action)) { throw new InvalidTransactionStateError(`Cannot perform action ${action} while in transaction state ${this.state}`) }