From 99225c379432c0d57f454e054db39d635d546152 Mon Sep 17 00:00:00 2001 From: Anna Henningsen Date: Thu, 16 Dec 2021 13:02:08 +0100 Subject: [PATCH] fix(shell-api): align `Number*` type validation with legacy shell MONGOSH-1073 Extend our type validation helper to allow specifying individual BSON types, and use that to loosen the restrictions on the legacy `Number*()` constructors. --- packages/shell-api/src/helpers.ts | 12 ++++++++++-- packages/shell-api/src/shell-bson.spec.ts | 23 ++++++++++++++++++++--- packages/shell-api/src/shell-bson.ts | 21 +++++++++++---------- 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/packages/shell-api/src/helpers.ts b/packages/shell-api/src/helpers.ts index 2d35c82d62..9ae4808cc5 100644 --- a/packages/shell-api/src/helpers.ts +++ b/packages/shell-api/src/helpers.ts @@ -94,8 +94,16 @@ export function assertArgsDefinedType(args: any[], expectedTypes: Array e !== undefined).join(' or '); + const expectedTypesList: Array = + typeof expected === 'string' ? [expected] : expected; + const isExpectedTypeof = expectedTypesList.includes(typeof arg); + const isExpectedBson = expectedTypesList.includes(`bson:${arg?._bsontype}`); + + if (!isExpectedTypeof && !isExpectedBson) { + const expectedMsg = expectedTypesList + .filter(e => e !== undefined) + .map(e => e?.replace(/^bson:/, '')) + .join(' or '); throw new MongoshInvalidInputError( `Argument at position ${i} must be of type ${expectedMsg}, got ${typeof arg} instead${getAssertCaller(func)}`, CommonErrors.InvalidArgument diff --git a/packages/shell-api/src/shell-bson.spec.ts b/packages/shell-api/src/shell-bson.spec.ts index e7aa3ad5f4..2dd8d2c6d8 100644 --- a/packages/shell-api/src/shell-bson.spec.ts +++ b/packages/shell-api/src/shell-bson.spec.ts @@ -544,7 +544,7 @@ describe('Shell BSON', () => { try { (shellBson.NumberLong as any)({}); } catch (e) { - return expect(e.message).to.match(/string or number, got object.+\(NumberLong\)/); + return expect(e.message).to.match(/string or number or Long or Int32, got object.+\(NumberLong\)/); } expect.fail('Expecting error, nothing thrown'); }); @@ -576,7 +576,7 @@ describe('Shell BSON', () => { try { (shellBson.NumberDecimal as any)({}); } catch (e) { - return expect(e.message).to.match(/string or number, got object.+\(NumberDecimal\)/); + return expect(e.message).to.match(/string or number or Long or Int32 or Decimal128, got object.+\(NumberDecimal\)/); } expect.fail('Expecting error, nothing thrown'); }); @@ -609,12 +609,29 @@ describe('Shell BSON', () => { try { (shellBson.NumberInt as any)({}); } catch (e) { - return expect(e.message).to.match(/string or number, got object.+\(NumberInt\)/); + return expect(e.message).to.match(/string or number or Long or Int32, got object.+\(NumberInt\)/); } expect.fail('Expecting error, nothing thrown'); }); }); + describe('Number type cross-construction', () => { + it('matches the legacy shell', () => { + const { NumberInt, NumberLong, NumberDecimal } = shellBson as any; + expect(NumberInt(null).toString()).to.equal('0'); + expect(NumberLong(null).toString()).to.equal('0'); + + expect(NumberInt(NumberInt(1234)).toString()).to.equal('1234'); + expect(NumberInt(NumberLong(1234)).toString()).to.equal('1234'); + expect(NumberInt(NumberLong(1234)).toString()).to.equal('1234'); + expect(NumberLong(NumberInt(1234)).toString()).to.equal('1234'); + expect(NumberLong(NumberLong(1234)).toString()).to.equal('1234'); + expect(NumberDecimal(NumberInt(1234)).toString()).to.equal('1234'); + expect(NumberDecimal(NumberLong(1234)).toString()).to.equal('1234'); + expect(NumberDecimal(NumberDecimal(1234)).toString()).to.equal('1234'); + }); + }); + describe('EJSON', () => { it('serializes and de-serializes data', () => { const input = { a: new Date() }; diff --git a/packages/shell-api/src/shell-bson.ts b/packages/shell-api/src/shell-bson.ts index 83aa8d1aaa..affe70c892 100644 --- a/packages/shell-api/src/shell-bson.ts +++ b/packages/shell-api/src/shell-bson.ts @@ -121,24 +121,25 @@ export default function constructShellBson(bson: typeof BSON, printWarning: (msg return new bson.Code(c, s); }, { ...bson.Code, prototype: bson.Code.prototype }), NumberDecimal: Object.assign(function NumberDecimal(s = '0'): typeof bson.Decimal128.prototype { - assertArgsDefinedType([s], [['string', 'number']], 'NumberDecimal'); - if (typeof s === 'string') { - return bson.Decimal128.fromString(s); + assertArgsDefinedType([s], [['string', 'number', 'bson:Long', 'bson:Int32', 'bson:Decimal128']], 'NumberDecimal'); + if (typeof s === 'number') { + printWarning('NumberDecimal: specifying a number as argument is deprecated and may lead to loss of precision, pass a string instead'); } - printWarning('NumberDecimal: specifying a number as argument is deprecated and may lead to loss of precision, pass a string instead'); return bson.Decimal128.fromString(`${s}`); }, { prototype: bson.Decimal128.prototype }), NumberInt: Object.assign(function NumberInt(v = '0'): typeof bson.Int32.prototype { - assertArgsDefinedType([v], [['string', 'number']], 'NumberInt'); + v ??= '0'; + assertArgsDefinedType([v], [['string', 'number', 'bson:Long', 'bson:Int32']], 'NumberInt'); return new bson.Int32(parseInt(`${v}`, 10)); }, { prototype: bson.Int32.prototype }), NumberLong: Object.assign(function NumberLong(s: string | number = '0'): typeof bson.Long.prototype { - assertArgsDefinedType([s], [['string', 'number']], 'NumberLong'); - if (typeof s === 'string') { - return bson.Long.fromString(s); + s ??= '0'; + assertArgsDefinedType([s], [['string', 'number', 'bson:Long', 'bson:Int32']], 'NumberLong'); + if (typeof s === 'number') { + printWarning('NumberLong: specifying a number as argument is deprecated and may lead to loss of precision, pass a string instead'); + return bson.Long.fromNumber(s); } - printWarning('NumberLong: specifying a number as argument is deprecated and may lead to loss of precision, pass a string instead'); - return bson.Long.fromNumber(s); + return bson.Long.fromString(`${s}`); }, { prototype: bson.Long.prototype }), ISODate: function ISODate(input?: string): Date { if (!input) input = new Date().toISOString();