Skip to content

Commit

Permalink
Added Tree sitter provider (#2488)
Browse files Browse the repository at this point in the history
## Checklist

- [/] I have added
[tests](https://www.cursorless.org/docs/contributing/test-case-recorder/)
- [] I have updated the
[docs](https://github.com/cursorless-dev/cursorless/tree/main/docs) and
[cheatsheet](https://github.com/cursorless-dev/cursorless/tree/main/cursorless-talon/src/cheatsheet)
- [/] I have not broken the cheatsheet

---------

Co-authored-by: Pokey Rule <755842+pokey@users.noreply.github.com>
  • Loading branch information
AndreasArvidsson and pokey committed Jul 12, 2024
1 parent 4d68be9 commit bc50059
Show file tree
Hide file tree
Showing 12 changed files with 237 additions and 82 deletions.
20 changes: 20 additions & 0 deletions packages/common/src/ide/types/RawTreeSitterQueryProvider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { Disposable } from "@cursorless/common";

/**
* Provides raw tree-sitter queries. These are usually read from `.scm` files
* on the filesystem, but this class abstracts away the details of how the
* queries are stored.
*/
export interface RawTreeSitterQueryProvider {
/**
* Listen for changes to queries. For now, this is only used during
* development, when we want to hot-reload queries.
*/
onChanges(listener: () => void): Disposable;

/**
* Return the raw text of the tree-sitter query of the given name. The query
* name is the name of one of the `.scm` files in our monorepo.
*/
readQuery(name: string): Promise<string | undefined>;
}
1 change: 1 addition & 0 deletions packages/common/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export * from "./ide/types/QuickPickOptions";
export * from "./ide/types/events.types";
export * from "./ide/types/Paths";
export * from "./ide/types/CommandHistoryStorage";
export * from "./ide/types/RawTreeSitterQueryProvider";
export * from "./ide/types/FileSystem.types";
export * from "./types/RangeExpansionBehavior";
export * from "./types/InputBoxOptions";
Expand Down
26 changes: 18 additions & 8 deletions packages/cursorless-engine/src/cursorlessEngine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
IDE,
ScopeProvider,
ensureCommandShape,
type FileSystem,
type RawTreeSitterQueryProvider,
} from "@cursorless/common";
import { KeyboardTargetUpdater } from "./KeyboardTargetUpdater";
import {
Expand All @@ -19,10 +19,15 @@ import { StoredTargetMap } from "./core/StoredTargets";
import { RangeUpdater } from "./core/updateSelections/RangeUpdater";
import { DisabledCommandServerApi } from "./disabledComponents/DisabledCommandServerApi";
import { DisabledHatTokenMap } from "./disabledComponents/DisabledHatTokenMap";
import { DisabledLanguageDefinitions } from "./disabledComponents/DisabledLanguageDefinitions";
import { DisabledSnippets } from "./disabledComponents/DisabledSnippets";
import { DisabledTalonSpokenForms } from "./disabledComponents/DisabledTalonSpokenForms";
import { DisabledTreeSitter } from "./disabledComponents/DisabledTreeSitter";
import { CustomSpokenFormGeneratorImpl } from "./generateSpokenForm/CustomSpokenFormGeneratorImpl";
import { LanguageDefinitions } from "./languages/LanguageDefinitions";
import {
LanguageDefinitionsImpl,
type LanguageDefinitions,
} from "./languages/LanguageDefinitions";
import { ModifierStageFactoryImpl } from "./processTargets/ModifierStageFactoryImpl";
import { ScopeHandlerFactoryImpl } from "./processTargets/modifiers/scopeHandlers";
import { runCommand } from "./runCommand";
Expand All @@ -39,8 +44,8 @@ import { TreeSitter } from "./typings/TreeSitter";
interface Props {
ide: IDE;
hats?: Hats;
treeSitter: TreeSitter;
fileSystem: FileSystem;
treeSitterQueryProvider?: RawTreeSitterQueryProvider;
treeSitter?: TreeSitter;
commandServerApi?: CommandServerApi;
talonSpokenForms?: TalonSpokenForms;
snippets?: Snippets;
Expand All @@ -49,8 +54,8 @@ interface Props {
export async function createCursorlessEngine({
ide,
hats,
treeSitter,
fileSystem,
treeSitterQueryProvider,
treeSitter = new DisabledTreeSitter(),
commandServerApi = new DisabledCommandServerApi(),
talonSpokenForms = new DisabledTalonSpokenForms(),
snippets = new DisabledSnippets(),
Expand All @@ -71,8 +76,13 @@ export async function createCursorlessEngine({
: new DisabledHatTokenMap();
void hatTokenMap.allocateHats();

const languageDefinitions = new LanguageDefinitions(fileSystem, treeSitter);
await languageDefinitions.init();
const languageDefinitions = treeSitterQueryProvider
? await LanguageDefinitionsImpl.create(
ide,
treeSitter,
treeSitterQueryProvider,
)
: new DisabledLanguageDefinitions();

ide.disposeOnExit(
rangeUpdater,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import type { TextDocument, Range, Listener } from "@cursorless/common";
import type { SyntaxNode } from "web-tree-sitter";
import type { LanguageDefinition } from "../languages/LanguageDefinition";
import type { LanguageDefinitions } from "../languages/LanguageDefinitions";

export class DisabledLanguageDefinitions implements LanguageDefinitions {
onDidChangeDefinition(_listener: Listener) {
return { dispose: () => {} };
}

loadLanguage(_languageId: string): Promise<void> {
return Promise.resolve();
}

get(_languageId: string): LanguageDefinition | undefined {
return undefined;
}

getNodeAtLocation(
_document: TextDocument,
_range: Range,
): SyntaxNode | undefined {
return undefined;
}

dispose(): void {
// Do nothing
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import type { TextDocument, Range } from "@cursorless/common";
import type { SyntaxNode, Tree, Language } from "web-tree-sitter";
import type { TreeSitter } from "../typings/TreeSitter";

export class DisabledTreeSitter implements TreeSitter {
getTree(_document: TextDocument): Tree {
throw new Error("Tree sitter not provided");
}

loadLanguage(_languageId: string): Promise<boolean> {
return Promise.resolve(false);
}

getLanguage(_languageId: string): Language | undefined {
throw new Error("Tree sitter not provided");
}

getNodeAtLocation(_document: TextDocument, _range: Range): SyntaxNode {
throw new Error("Tree sitter not provided");
}
}
53 changes: 26 additions & 27 deletions packages/cursorless-engine/src/languages/LanguageDefinition.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import {
FileSystem,
ScopeType,
SimpleScopeType,
showError,
type IDE,
type RawTreeSitterQueryProvider,
} from "@cursorless/common";
import { basename, dirname, join } from "pathe";
import { dirname, join } from "pathe";
import { TreeSitterScopeHandler } from "../processTargets/modifiers/scopeHandlers";
import { ide } from "../singletons/ide.singleton";
import { TreeSitter } from "../typings/TreeSitter";
import { matchAll } from "../util/regex";
import { TreeSitterQuery } from "./TreeSitterQuery";
Expand Down Expand Up @@ -36,16 +36,15 @@ export class LanguageDefinition {
* id doesn't have a new-style query definition
*/
static async create(
ide: IDE,
treeSitterQueryProvider: RawTreeSitterQueryProvider,
treeSitter: TreeSitter,
fileSystem: FileSystem,
queryDir: string,
languageId: string,
): Promise<LanguageDefinition | undefined> {
const languageQueryPath = join(queryDir, `${languageId}.scm`);

const rawLanguageQueryString = await readQueryFileAndImports(
fileSystem,
languageQueryPath,
ide,
treeSitterQueryProvider,
`${languageId}.scm`,
);

if (rawLanguageQueryString == null) {
Expand Down Expand Up @@ -91,43 +90,42 @@ export class LanguageDefinition {
* @returns The text of the query file, with all imports inlined
*/
async function readQueryFileAndImports(
fileSystem: FileSystem,
languageQueryPath: string,
ide: IDE,
provider: RawTreeSitterQueryProvider,
languageQueryName: string,
) {
// Seed the map with the query file itself
const rawQueryStrings: Record<string, string | null> = {
[languageQueryPath]: null,
[languageQueryName]: null,
};

const doValidation = ide().runMode !== "production";
const doValidation = ide.runMode !== "production";

// Keep reading imports until we've read all the imports. Every time we
// encounter an import in a query file, we add it to the map with a value
// of null, so that it will be read on the next iteration
while (Object.values(rawQueryStrings).some((v) => v == null)) {
for (const [queryPath, rawQueryString] of Object.entries(rawQueryStrings)) {
for (const [queryName, rawQueryString] of Object.entries(rawQueryStrings)) {
if (rawQueryString != null) {
continue;
}

const fileName = basename(queryPath);

let rawQuery = await fileSystem.readBundledFile(queryPath);
let rawQuery = await provider.readQuery(queryName);

if (rawQuery == null) {
if (queryPath === languageQueryPath) {
if (queryName === languageQueryName) {
// If this is the main query file, then we know that this language
// just isn't defined using new-style queries
return undefined;
}

showError(
ide().messages,
ide.messages,
"LanguageDefinition.readQueryFileAndImports.queryNotFound",
`Could not find imported query file ${queryPath}`,
`Could not find imported query file ${queryName}`,
);

if (ide().runMode === "test") {
if (ide.runMode === "test") {
throw new Error("Invalid import statement");
}

Expand All @@ -136,10 +134,10 @@ async function readQueryFileAndImports(
}

if (doValidation) {
validateQueryCaptures(fileName, rawQuery);
validateQueryCaptures(queryName, rawQuery);
}

rawQueryStrings[queryPath] = rawQuery;
rawQueryStrings[queryName] = rawQuery;
matchAll(
rawQuery,
// Matches lines like:
Expand All @@ -154,10 +152,10 @@ async function readQueryFileAndImports(
const relativeImportPath = match[1];

if (doValidation) {
validateImportSyntax(fileName, relativeImportPath, match[0]);
validateImportSyntax(ide, queryName, relativeImportPath, match[0]);
}

const importQueryPath = join(dirname(queryPath), relativeImportPath);
const importQueryPath = join(dirname(queryName), relativeImportPath);
rawQueryStrings[importQueryPath] =
rawQueryStrings[importQueryPath] ?? null;
},
Expand All @@ -169,6 +167,7 @@ async function readQueryFileAndImports(
}

function validateImportSyntax(
ide: IDE,
file: string,
relativeImportPath: string,
actual: string,
Expand All @@ -177,12 +176,12 @@ function validateImportSyntax(

if (actual !== canonicalSyntax) {
showError(
ide().messages,
ide.messages,
"LanguageDefinition.readQueryFileAndImports.malformedImport",
`Malformed import statement in ${file}: "${actual}". Import statements must be of the form "${canonicalSyntax}"`,
);

if (ide().runMode === "test") {
if (ide.runMode === "test") {
throw new Error("Invalid import statement");
}
}
Expand Down
Loading

0 comments on commit bc50059

Please sign in to comment.