diff --git a/packages/cli-repl/src/mongosh-repl.ts b/packages/cli-repl/src/mongosh-repl.ts index a5f50d7679..4fb6da5c3a 100644 --- a/packages/cli-repl/src/mongosh-repl.ts +++ b/packages/cli-repl/src/mongosh-repl.ts @@ -3,13 +3,12 @@ import { MongoshCommandFailed, MongoshInternalError, MongoshWarning } from '@mon import { changeHistory } from '@mongosh/history'; import i18n from '@mongosh/i18n'; import type { ServiceProvider } from '@mongosh/service-provider-core'; -import { EvaluationListener, ShellCliOptions, ShellInternalState } from '@mongosh/shell-api'; +import { EvaluationListener, ShellCliOptions, ShellInternalState, OnLoadResult } from '@mongosh/shell-api'; import { ShellEvaluator, ShellResult } from '@mongosh/shell-evaluator'; import type { MongoshBus, UserConfig } from '@mongosh/types'; import askpassword from 'askpassword'; import { Console } from 'console'; import { once } from 'events'; -import path from 'path'; import prettyRepl from 'pretty-repl'; import { ReplOptions, REPLServer } from 'repl'; import type { Readable, Writable } from 'stream'; @@ -318,27 +317,17 @@ class MongoshNodeRepl implements EvaluationListener { } } - async onLoad(filename: string): Promise { + async onLoad(filename: string): Promise { const repl = this.runtimeState().repl; const { contents, absolutePath } = await this.ioProvider.readFileUTF8(filename); - const previousFilename = repl.context.__filename; - repl.context.__filename = absolutePath; - repl.context.__dirname = path.dirname(absolutePath); - try { - await promisify(repl.eval.bind(repl))(contents, repl.context, filename); - } finally { - if (previousFilename) { - repl.context.__filename = previousFilename; - repl.context.__dirname = path.dirname(previousFilename); - } else { - delete repl.context.__filename; - delete repl.context.__dirname; - } - } + return { + resolvedFilename: absolutePath, + evaluate: () => promisify(repl.eval.bind(repl))(contents, repl.context, absolutePath) + }; } /** diff --git a/packages/shell-api/src/index.ts b/packages/shell-api/src/index.ts index c123793761..877673cc47 100644 --- a/packages/shell-api/src/index.ts +++ b/packages/shell-api/src/index.ts @@ -5,7 +5,7 @@ import Database from './database'; import Explainable from './explainable'; import ExplainableCursor from './explainable-cursor'; import Help, { HelpProperties } from './help'; -import ShellInternalState, { EvaluationListener, ShellCliOptions } from './shell-internal-state'; +import ShellInternalState, { EvaluationListener, ShellCliOptions, OnLoadResult } from './shell-internal-state'; import toIterator from './toIterator'; import Shard from './shard'; import ReplicaSet from './replica-set'; @@ -62,5 +62,6 @@ export { getShellApiType, ShellResult, ShellCliOptions, - TypeSignature + TypeSignature, + OnLoadResult }; diff --git a/packages/shell-api/src/shell-api.spec.ts b/packages/shell-api/src/shell-api.spec.ts index fc95720081..33a90c523d 100644 --- a/packages/shell-api/src/shell-api.spec.ts +++ b/packages/shell-api/src/shell-api.spec.ts @@ -570,9 +570,22 @@ describe('ShellApi', () => { }); describe('load', () => { it('asks the evaluation listener to load a file', async() => { - evaluationListener.onLoad.resolves(); + evaluationListener.onLoad.callsFake(async(filename: string) => { + expect(filename).to.equal('abc.js'); + expect(internalState.context.__filename).to.equal(undefined); + expect(internalState.context.__dirname).to.equal(undefined); + return { + resolvedFilename: '/resolved/abc.js', + evaluate: async() => { + expect(internalState.context.__filename).to.equal('/resolved/abc.js'); + expect(internalState.context.__dirname).to.equal('/resolved'); + } + }; + }); await internalState.context.load('abc.js'); - expect(evaluationListener.onLoad).to.have.been.calledWith('abc.js'); + expect(evaluationListener.onLoad).to.have.callCount(1); + expect(internalState.context.__filename).to.equal(undefined); + expect(internalState.context.__dirname).to.equal(undefined); }); }); for (const cmd of ['print', 'printjson']) { diff --git a/packages/shell-api/src/shell-api.ts b/packages/shell-api/src/shell-api.ts index 6df21d5593..f71ce548b2 100644 --- a/packages/shell-api/src/shell-api.ts +++ b/packages/shell-api/src/shell-api.ts @@ -20,6 +20,7 @@ import { CommonErrors, MongoshUnimplementedError, MongoshInternalError } from '@ import { DBQuery } from './deprecated'; import { promisify } from 'util'; import { ClientSideFieldLevelEncryptionOptions } from './field-level-encryption'; +import { dirname } from 'path'; @shellApiClassDefault @hasAsyncChild @@ -109,7 +110,25 @@ export default class ShellApi extends ShellApiClass { CommonErrors.NotImplemented ); } - await this.internalState.evaluationListener.onLoad(filename); + const { + resolvedFilename, evaluate + } = await this.internalState.evaluationListener.onLoad(filename); + + const context = this.internalState.context; + const previousFilename = context.__filename; + context.__filename = resolvedFilename; + context.__dirname = dirname(resolvedFilename); + try { + await evaluate(); + } finally { + if (previousFilename) { + context.__filename = previousFilename; + context.__dirname = dirname(previousFilename); + } else { + delete context.__filename; + delete context.__dirname; + } + } return true; } diff --git a/packages/shell-api/src/shell-internal-state.ts b/packages/shell-api/src/shell-internal-state.ts index bc5984e94e..604671be2d 100644 --- a/packages/shell-api/src/shell-internal-state.ts +++ b/packages/shell-api/src/shell-internal-state.ts @@ -35,6 +35,19 @@ export interface AutocompleteParameters { getCollectionCompletionsForCurrentDb: (collName: string) => Promise; } +export interface OnLoadResult { + /** + * The absolute path of the file that should be load()ed. + */ + resolvedFilename: string; + + /** + * The actual steps that are needed to evaluate the load()ed file. + * For the duration of this call, __filename and __dirname are set as expected. + */ + evaluate(): Promise; +} + export interface EvaluationListener { /** * Called when print() or printjson() is run from the shell. @@ -66,7 +79,7 @@ export interface EvaluationListener { /** * Called when load() is used in the shell. */ - onLoad?: (filename: string) => Promise; + onLoad?: (filename: string) => Promise | OnLoadResult; } /**