diff --git a/extensions/ql-vscode/src/databases/github-database-download.ts b/extensions/ql-vscode/src/databases/github-database-download.ts index e19debf79cb..4f24ef38c68 100644 --- a/extensions/ql-vscode/src/databases/github-database-download.ts +++ b/extensions/ql-vscode/src/databases/github-database-download.ts @@ -2,10 +2,7 @@ import { window } from "vscode"; import { Octokit } from "@octokit/rest"; import { showNeverAskAgainDialog } from "../common/vscode/dialog"; import { getLanguageDisplayName } from "../common/query-language"; -import { - downloadGitHubDatabaseFromUrl, - promptForLanguage, -} from "./database-fetcher"; +import { downloadGitHubDatabaseFromUrl } from "./database-fetcher"; import { withProgress } from "../common/vscode/progress"; import { DatabaseManager } from "./local-databases"; import { CodeQLCliServer } from "../codeql-cli/cli"; @@ -66,40 +63,46 @@ export async function downloadDatabaseFromGitHub( cliServer: CodeQLCliServer, commandManager: AppCommandManager, ): Promise { - const languages = databases.map((database) => database.language); - - const language = await promptForLanguage(languages, undefined); - if (!language) { + const selectedDatabases = await promptForDatabases(databases); + if (selectedDatabases.length === 0) { return; } - const database = databases.find((database) => database.language === language); - if (!database) { - return; - } + await Promise.all( + selectedDatabases.map((database) => + withProgress( + async (progress) => { + await downloadGitHubDatabaseFromUrl( + database.url, + database.id, + database.created_at, + database.commit_oid ?? null, + owner, + repo, + octokit, + progress, + databaseManager, + storagePath, + cliServer, + true, + false, + ); - await withProgress(async (progress) => { - await downloadGitHubDatabaseFromUrl( - database.url, - database.id, - database.created_at, - database.commit_oid ?? null, - owner, - repo, - octokit, - progress, - databaseManager, - storagePath, - cliServer, - true, - false, - ); - - await commandManager.execute("codeQLDatabases.focus"); - void window.showInformationMessage( - `Downloaded ${getLanguageDisplayName(language)} database from GitHub.`, - ); - }); + await commandManager.execute("codeQLDatabases.focus"); + void window.showInformationMessage( + `Downloaded ${getLanguageDisplayName( + database.language, + )} database from GitHub.`, + ); + }, + { + title: `Adding ${getLanguageDisplayName( + database.language, + )} database from GitHub`, + }, + ), + ), + ); } /** @@ -126,3 +129,34 @@ function joinLanguages(languages: string[]): string { return result; } + +async function promptForDatabases( + databases: CodeqlDatabase[], +): Promise { + if (databases.length === 1) { + return databases; + } + + const items = databases + .map((database) => { + const bytesToDisplayMB = `${(database.size / (1024 * 1024)).toFixed( + 1, + )} MB`; + + return { + label: getLanguageDisplayName(database.language), + description: bytesToDisplayMB, + database, + }; + }) + .sort((a, b) => a.label.localeCompare(b.label)); + + const selectedItems = await window.showQuickPick(items, { + title: "Select databases to download", + placeHolder: "Databases found in this repository", + ignoreFocusOut: true, + canPickMany: true, + }); + + return selectedItems?.map((selectedItem) => selectedItem.database) ?? []; +} diff --git a/extensions/ql-vscode/test/vscode-tests/no-workspace/databases/github-database-download.test.ts b/extensions/ql-vscode/test/vscode-tests/no-workspace/databases/github-database-download.test.ts index 601a7dc318d..a9fc9309cda 100644 --- a/extensions/ql-vscode/test/vscode-tests/no-workspace/databases/github-database-download.test.ts +++ b/extensions/ql-vscode/test/vscode-tests/no-workspace/databases/github-database-download.test.ts @@ -1,6 +1,7 @@ import { faker } from "@faker-js/faker"; import { Octokit } from "@octokit/rest"; -import { mockedObject } from "../../utils/mocking.helpers"; +import { QuickPickItem, window } from "vscode"; +import { mockedObject, mockedQuickPickItem } from "../../utils/mocking.helpers"; import { askForGitHubDatabaseDownload, downloadDatabaseFromGitHub, @@ -103,15 +104,14 @@ describe("downloadDatabaseFromGitHub", () => { created_at: faker.date.past().toISOString(), commit_oid: faker.git.commitSha(), language: "swift", + size: 27389673, url: faker.internet.url({ protocol: "https", }), }), ]; - let promptForLanguageSpy: jest.SpiedFunction< - typeof databaseFetcher.promptForLanguage - >; + let showQuickPickSpy: jest.SpiedFunction; let downloadGitHubDatabaseFromUrlSpy: jest.SpiedFunction< typeof databaseFetcher.downloadGitHubDatabaseFromUrl >; @@ -121,9 +121,13 @@ describe("downloadDatabaseFromGitHub", () => { databaseManager = mockedObject({}); cliServer = mockedObject({}); - promptForLanguageSpy = jest - .spyOn(databaseFetcher, "promptForLanguage") - .mockResolvedValue(databases[0].language); + showQuickPickSpy = jest.spyOn(window, "showQuickPick").mockResolvedValue( + mockedQuickPickItem([ + mockedObject({ + database: databases[0], + }), + ]), + ); downloadGitHubDatabaseFromUrlSpy = jest .spyOn(databaseFetcher, "downloadGitHubDatabaseFromUrl") .mockResolvedValue(undefined); @@ -157,28 +161,6 @@ describe("downloadDatabaseFromGitHub", () => { true, false, ); - expect(promptForLanguageSpy).toHaveBeenCalledWith(["swift"], undefined); - }); - - describe("when not selecting language", () => { - beforeEach(() => { - promptForLanguageSpy.mockResolvedValue(undefined); - }); - - it("does not download the database", async () => { - await downloadDatabaseFromGitHub( - octokit, - owner, - repo, - databases, - databaseManager, - storagePath, - cliServer, - commandManager, - ); - - expect(downloadGitHubDatabaseFromUrlSpy).not.toHaveBeenCalled(); - }); }); describe("when there are multiple languages", () => { @@ -189,6 +171,7 @@ describe("downloadDatabaseFromGitHub", () => { created_at: faker.date.past().toISOString(), commit_oid: faker.git.commitSha(), language: "swift", + size: 27389673, url: faker.internet.url({ protocol: "https", }), @@ -198,16 +181,23 @@ describe("downloadDatabaseFromGitHub", () => { created_at: faker.date.past().toISOString(), commit_oid: null, language: "go", + size: 2930572385, url: faker.internet.url({ protocol: "https", }), }), ]; - - promptForLanguageSpy.mockResolvedValue(databases[1].language); }); - it("downloads the correct database", async () => { + it("downloads a single selected language", async () => { + showQuickPickSpy.mockResolvedValue( + mockedQuickPickItem([ + mockedObject({ + database: databases[1], + }), + ]), + ); + await downloadDatabaseFromGitHub( octokit, owner, @@ -235,10 +225,113 @@ describe("downloadDatabaseFromGitHub", () => { true, false, ); - expect(promptForLanguageSpy).toHaveBeenCalledWith( - ["swift", "go"], - undefined, + expect(showQuickPickSpy).toHaveBeenCalledWith( + [ + expect.objectContaining({ + label: "Go", + description: "2794.8 MB", + database: databases[1], + }), + expect.objectContaining({ + label: "Swift", + description: "26.1 MB", + database: databases[0], + }), + ], + expect.anything(), + ); + }); + + it("downloads multiple selected languages", async () => { + showQuickPickSpy.mockResolvedValue( + mockedQuickPickItem([ + mockedObject({ + database: databases[0], + }), + mockedObject({ + database: databases[1], + }), + ]), + ); + + await downloadDatabaseFromGitHub( + octokit, + owner, + repo, + databases, + databaseManager, + storagePath, + cliServer, + commandManager, + ); + + expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledTimes(2); + expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledWith( + databases[0].url, + databases[0].id, + databases[0].created_at, + databases[0].commit_oid, + owner, + repo, + octokit, + expect.anything(), + databaseManager, + storagePath, + cliServer, + true, + false, + ); + expect(downloadGitHubDatabaseFromUrlSpy).toHaveBeenCalledWith( + databases[1].url, + databases[1].id, + databases[1].created_at, + databases[1].commit_oid, + owner, + repo, + octokit, + expect.anything(), + databaseManager, + storagePath, + cliServer, + true, + false, + ); + expect(showQuickPickSpy).toHaveBeenCalledWith( + [ + expect.objectContaining({ + label: "Go", + description: "2794.8 MB", + database: databases[1], + }), + expect.objectContaining({ + label: "Swift", + description: "26.1 MB", + database: databases[0], + }), + ], + expect.anything(), ); }); + + describe("when not selecting language", () => { + beforeEach(() => { + showQuickPickSpy.mockResolvedValue(undefined); + }); + + it("does not download the database", async () => { + await downloadDatabaseFromGitHub( + octokit, + owner, + repo, + databases, + databaseManager, + storagePath, + cliServer, + commandManager, + ); + + expect(downloadGitHubDatabaseFromUrlSpy).not.toHaveBeenCalled(); + }); + }); }); });